目录
1. DenseNet 介绍
DenseNet是一种深度学习架构,卷积神经网络(CNN)的一种变体,旨在解决梯度消失的问题并提高网络连接性。
在传统的CNN中,信息流是顺序的,每一层只连接到下一层。这可能会导致梯度在网络中传播时减小,从而难以训练深度网络。DenseNet旨在通过引入密集连接来缓解这一问题,密集连接允许从网络中的任何层直接连接到任何其他层。
DenseNet由多个密集块组成,每个密集块包含多个层。密集块内的每一层都连接到同一块内的其他每一层。这种密集的连接促进了特征重用和信息流,使梯度更容易在整个网络中传播。此外,DenseNet在每个密集块后都加入了一个过渡层,以降低特征图的维度并控制网络的增长。
DenseNet的主要优势包括:
改进的梯度流:层之间的直接连接有助于克服梯度消失问题,并实现深度网络的高效训练。
强大的特征重用:密集的连接促进了特征重用,从而实现了更紧凑的网络和更好的参数效率。
参数数量减少:与传统的CNN架构相比,DenseNet通常需要更少的参数,从而使模型更容易训练,计算效率更高。
提高精度:DenseNet已被证明在各种计算机视觉任务上达到了最先进的性能,如图像分类和物体检测。
总体而言,DenseNet是一个强大的深度学习架构,可以解决训练深度网络的挑战。其密集的连接性和高效的参数共享使其成为各种计算机视觉任务的有效选择。
其中,denseNet不同版本的架构如下
2. DenseNet 实现的垃圾图像分类
densenet实现的model部分代码如下面所示,这里如果采用官方预训练权重的话,会自动导入官方提供的最新版本的权重
if model == 'densenet121':
net = m.densenet121(pretrained=m.DenseNet121_Weights.DEFAULT if weights else False,progress=True)
elif model == 'densenet161':
net = m.densenet161(pretrained=m.DenseNet161_Weights.DEFAULT if weights else False,progress=True)
elif model == 'densenet169':
net = m.densenet169(pretrained=m.DenseNet169_Weights.DEFAULT if weights else False,progress=True)
elif model == 'densenet201':
net = m.densenet201(pretrained=m.DenseNet201_Weights.DEFAULT if weights else False,progress=True)
else:
print('模型选择错误!!')
return None
2.1 垃圾10分类数据集
数据集的摆放如下:
字典文件:
{
"0": "dianchi",
"1": "lajiao",
"2": "pingguo",
"3": "qiezi",
"4": "taoci",
"5": "tudou",
"6": "xiangjiao",
"7": "yandi",
"8": "yilaguan",
"9": "yinliaoping"
}
其中,训练集总数700,验证集总数300,经过transform预处理的可视化结果如下:
2.2 训练
将数据集按照上述格式摆放好即可开始训练,训练的参数如下:
parser.add_argument("--model", default='densenet121', type=str,help='densenet121,densenet161,densenet169,densenet201')
parser.add_argument("--pretrained", default=True, type=bool) # 采用官方权重
parser.add_argument("--freeze_layers", default=True, type=bool) # 冻结权重
parser.add_argument("--batch-size", default=4, type=int)
parser.add_argument("--epochs", default=30, type=int)
parser.add_argument("--optim", default='SGD', type=str,help='SGD、Adam') # 优化器选择
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--lrf',default=0.0001,type=float) # 最终学习率 = lr * lrf
这里分类的个数不需要指定,代码会根据数据集自动生成!
2.3 训练结果
所有的结果都保存在 runs 目录下
这里只展示部分:
训练日志:
"epoch:29": {
"train info": {
"accuracy": 0.9985714285571633,
"dianchi": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"lajiao": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"pingguo": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"qiezi": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"taoci": {
"Precision": 1.0,
"Recall": 0.9857,
"Specificity": 1.0,
"F1 score": 0.9928
},
"tudou": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"xiangjiao": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"yandi": {
"Precision": 0.9859,
"Recall": 1.0,
"Specificity": 0.9984,
"F1 score": 0.9929
},
"yilaguan": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"yinliaoping": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"mean precision": 0.9985900000000001,
"mean recall": 0.99857,
"mean specificity": 0.9998400000000001,
"mean f1 score": 0.99857
},
"valid info": {
"accuracy": 0.9699999999676666,
"dianchi": {
"Precision": 0.9032,
"Recall": 0.9333,
"Specificity": 0.9889,
"F1 score": 0.918
},
"lajiao": {
"Precision": 1.0,
"Recall": 0.9667,
"Specificity": 1.0,
"F1 score": 0.9831
},
"pingguo": {
"Precision": 0.9375,
"Recall": 1.0,
"Specificity": 0.9926,
"F1 score": 0.9677
},
"qiezi": {
"Precision": 0.9677,
"Recall": 1.0,
"Specificity": 0.9963,
"F1 score": 0.9836
},
"taoci": {
"Precision": 1.0,
"Recall": 1.0,
"Specificity": 1.0,
"F1 score": 1.0
},
"tudou": {
"Precision": 1.0,
"Recall": 0.9,
"Specificity": 1.0,
"F1 score": 0.9474
},
"xiangjiao": {
"Precision": 0.9677,
"Recall": 1.0,
"Specificity": 0.9963,
"F1 score": 0.9836
},
"yandi": {
"Precision": 0.9375,
"Recall": 1.0,
"Specificity": 0.9926,
"F1 score": 0.9677
},
"yilaguan": {
"Precision": 1.0,
"Recall": 0.9333,
"Specificity": 1.0,
"F1 score": 0.9655
},
"yinliaoping": {
"Precision": 1.0,
"Recall": 0.9667,
"Specificity": 1.0,
"F1 score": 0.9831
},
"mean precision": 0.97136,
"mean recall": 0.97,
"mean specificity": 0.99667,
"mean f1 score": 0.96997
}
训练集和测试集的混淆矩阵:
2.4 推理
推理是指没有标签,只有图片数据的情况下对数据的预测,这里直接运行predict脚本即可
需要把待推理的数据放在 inference/img 下
推理结果:
3. 项目下载
关于本项目代码和数据集、训练结果的下载:
计算机视觉项目:DenseNet卷积神经网络网络【121,161,169,201四种版本】实现的自适应迁移学习、图像识别项目:10种生活中常见垃圾图像分类资源-CSDN文库
关于计算机视觉实战可以继续关注本专栏,会持续更新图像分类和医学图像分割项目
关于图像分类和语义分割的改进:改进系列_Ai 医学图像分割的博客-CSDN博客