DAY 44 预训练模型
知识点回顾:
- 预训练的概念
- 常见的分类预训练模型
- 图像预训练模型的发展史
- 预训练的策略
- 预训练代码实战:resnet18
作业:
- 尝试在cifar10对比如下其他的预训练模型,观察差异,尽可能和他人选择的不同
- 尝试通过ctrl进入resnet的内部,观察残差究竟是什么
这句话可以忽略,这是未来的内容
豆包知识点回顾:
我会用 “学习知识”“盖房子” 等生活场景类比,把预训练相关概念变得通俗易懂,再整理成便于复习的学习笔记,帮你快速掌握要点。
一、知识点通俗解释 + 趣味比喻
1. 预训练的概念
含义:在大规模通用数据上先训练一个模型,学习到通用的基础特征(如物体轮廓、纹理等),然后将这个模型应用到自己的特定任务(如猫狗分类)中,通过微调参数适应新任务,避免从头开始训练。
比喻:就像学习知识,先在学校里学习语文、数学等基础知识(通用数据上预训练) ,掌握了基本的阅读、计算能力。之后学习特定专业(特定任务),比如编程,就不需要从零开始学识字和数数,直接基于已有基础学习编程知识(微调模型),效率更高 。
2. 常见的分类预训练模型
含义:在图像分类任务中被广泛使用的预训练模型,这些模型在大规模图像数据集(如 ImageNet,包含 1000 类物体)上预训练,具有强大的特征提取能力。
比喻:它们是建筑界的 “万能图纸” ,比如 ResNet、VGG、EfficientNet 等模型,就像已经设计好的经典建筑图纸,这些图纸(模型)在建造大量不同房子(处理大量图像)的过程中被优化得非常高效。当你要盖自己的小房子(完成特定分类任务)时,直接使用这些图纸(预训练模型),在其基础上修改细节(微调),能快速且高质量地完成建造(训练) 。
3. 图像预训练模型的发展史
含义:图像预训练模型从早期到现在不断发展,经历了结构复杂度增加、性能提升、计算效率优化等阶段,从简单的网络结构(如 LeNet)到复杂高效的网络(如 Swin Transformer)。
比喻:如同汽车的进化史 ,最开始的图像预训练模型(LeNet)像老式蒸汽汽车,结构简单,速度慢,只能完成简单任务(识别手写数字);后来出现了 VGG,像燃油汽车,层数增多(结构复杂),性能提升;再到 ResNet,引入残差连接,如同混合动力汽车,解决了网络深度增加带来的问题;如今的 Swin Transformer,像新能源智能汽车,采用新架构(Transformer),在性能和效率上都有巨大突破 。
4. 预训练的策略
含义:将预训练模型应用到新任务时的具体方法,包括冻结部分层、微调全部层、调整学习率等,以平衡训练速度和模型效果。
比喻:像是改造二手房 ,拿到一个预训练模型(二手房),有不同的改造策略:
- 冻结部分层:只翻新表面(如粉刷墙壁、换家具),底层结构(预训练好的底层特征提取层)保持不变,适合新任务和预训练任务相似时,快速完成改造(训练);
- 微调全部层:彻底翻新,连地基(所有层参数)都调整,适合新任务和预训练任务差异较大时,但花费时间和精力更多(训练时间长);
- 调整学习率:控制翻新的速度,学习率大(快速翻新),但容易出错;学习率小(缓慢翻新),更精细但耗时 。
5. 预训练代码实战:resnet18
含义:使用 PyTorch 等框架调用预训练的 ResNet18 模型,应用到自己的图像分类任务中,通过加载预训练权重、修改最后一层输出维度、微调模型参数完成训练。
比喻:用现成的乐高套装(ResNet18 预训练模型)搭建新造型 ,先拿到已经拼好大部分的乐高套装(加载预训练权重),然后根据自己想搭建的造型(分类任务),替换掉顶部的几块乐高(修改最后一层输出维度,适应分类类别数),接着按照自己的想法调整拼接方式(微调参数),最终完成独特的乐高作品(训练好的模型) 。
二、学习笔记
1. 核心概念速记表
知识点 | 核心定义 | 生活比喻 | 关键操作 / 作用 |
---|---|---|---|
预训练的概念 | 先在通用数据训练模型,再微调用于特定任务 | 先学基础知识,再学专业 | 减少训练时间,提高模型性能 |
常见分类预训练模型 | 在大规模图像数据预训练的分类模型 | 建筑万能图纸 | 提供高效特征提取能力 |
图像预训练模型发展史 | 模型结构、性能、效率不断进化 | 汽车进化史 | 反映技术从简单到复杂的演进 |
预训练的策略 | 应用预训练模型的具体方法 | 二手房改造 | 平衡训练速度与模型效果 |
预训练代码实战:resnet18 | 调用 ResNet18 预训练模型完成任务 | 乐高套装改造 | 加载权重→修改输出层→微调训练 |
2. 复习口诀
预训练,学基础,微调适配新任务;
经典模型像图纸,拿来就能盖新屋;
模型发展如车进,结构性能都更新;
应用策略有妙招,冻结微调把握好;
ResNet18 实战练,加载微调很关键。
通过这些比喻和学习笔记,希望能让你轻松理解预训练相关知识。如果还有疑问或想深入了解某个部分,随时可以告诉我。
作业(ai):
一、作业 1:CIFAR10 上对比不同预训练模型
1. 任务说明
在 CIFAR10 数据集(10 类图像,共 6 万张)上测试不同预训练模型的性能,观察准确率、训练速度、显存占用等差异。避免与他人重复,可选择小众但高效的模型(如MobileNetV3、ShuffleNetV2、EfficientNet-Lite)。
2. 对比实验设计
模型 | 参数量 | 特点 | 适用场景 | 预期结果 |
---|---|---|---|---|
ResNet18 | 11.7M | 经典残差结构,平衡精度与速度 | 通用任务 | 基准参考 |
MobileNetV3 | 5.4M | 轻量级,深度可分离卷积 | 移动设备部署 | 参数量少,速度快 |
ShuffleNetV2 | 1.4M | 通道混洗,极致轻量 | 低算力场景 | 参数量极低,速度极快 |
EfficientNet-Lite0 | 4.0M | 自动搜索架构,移动端优化 | 移动端 + 高精度需求 | 精度接近 ResNet,速度更快 |
3. 实验步骤(代码框架)
python
运行
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torchvision.models import resnet18, mobilenet_v3_large, shufflenet_v2_x1_0
from efficientnet_pytorch import EfficientNet # 需要额外安装
# 1. 数据准备(CIFAR10)
transform = transforms.Compose([
transforms.Resize(224), # 调整图像大小以适应预训练模型
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
shuffle=True)
# 2. 加载预训练模型并修改最后一层
def get_model(model_name):
if model_name == 'resnet18':
model = resnet18(pretrained=True)
model.fc = nn.Linear(512, 10) # 修改为10类输出
elif model_name == 'mobilenet_v3':
model = mobilenet_v3_large(pretrained=True)
model.classifier[3] = nn.Linear(1280, 10)
elif model_name == 'shufflenet_v2':
model = shufflenet_v2_x1_0(pretrained=True)
model.fc = nn.Linear(1024, 10)
elif model_name == 'efficientnet_lite0':
model = EfficientNet.from_pretrained('efficientnet-lite0')
model._fc = nn.Linear(1280, 10)
return model
# 3. 训练与评估(略,参考之前课程代码)
4. 预期差异分析
- 准确率:ResNet18 ≈ EfficientNet-Lite0 > MobileNetV3 > ShuffleNetV2(预训练模型在 CIFAR10 上微调后,准确率通常在 85%-95% 之间)。
- 训练速度:ShuffleNetV2 > MobileNetV3 > EfficientNet-Lite0 > ResNet18(轻量级模型计算量小,训练更快)。
- 显存占用:ShuffleNetV2 < MobileNetV3 < EfficientNet-Lite0 < ResNet18(参数量越少,显存占用越低)。
二、作业 2:探究 ResNet 的残差结构
1. 理论解释
残差块(Residual Block) 是 ResNet 的核心创新,解决了深层网络训练时的梯度消失 / 爆炸问题。它允许网络学习 “残差映射”(即输入与输出的差异),而非直接学习完整映射。
公式:
plaintext
输出 = 输入 + 残差
其中,残差 = F(x)(经过卷积、激活等操作)
2. 代码探究(以 PyTorch 的 ResNet18 为例)
步骤:
- 在 PyCharm/VSCode 中打开 Python 文件,输入
from torchvision.models import resnet18
。 - 将光标放在
resnet18
上,按Ctrl
并点击(或右键选择 “Go to Definition”),进入源码。 - 找到
BasicBlock
类(ResNet18 的残差块实现):
python
运行
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride) # 3x3卷积
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes) # 3x3卷积
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample # 用于调整输入维度,使与残差匹配
def forward(self, x):
identity = x # 保存原始输入(跳跃连接)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None: # 如果需要调整维度
identity = self.downsample(x)
out += identity # 核心:残差连接(输出 = 残差 + 原始输入)
out = self.relu(out) # 最后再通过ReLU激活
return out
3. 残差结构可视化
豆包ai生成失败
4. 关键发现
- 跳跃连接(Skip Connection):直接将输入
x
加到卷积后的输出上,实现 “残差学习”。 - 维度匹配:当输入输出维度不一致时(如
stride>1
),通过downsample
模块调整输入维度(通常用 1x1 卷积)。 - 激活函数位置:ReLU 激活在残差相加之后,确保非线性特性。
三、总结
- 作业 1:通过对比不同预训练模型,理解模型架构、参数量与性能的权衡,为实际任务选择最优模型。
- 作业 2:深入 ResNet 源码,发现残差结构通过 “跳跃连接” 和 “残差学习” 解决深层网络训练难题,这是现代 CNN 的核心思想之一。
动手实验后,建议记录各模型的训练日志(准确率、耗时、显存),并绘制对比图表,直观展示差异! 📊