Python check-in: Day 50

Published: 2025-07-05

Knowledge review:

  1. ResNet architecture walkthrough
  2. Where to place CBAM modules
  3. Training strategies for pretrained models (see the sketch after this list)
    1. Differential learning rates
    2. Three-stage fine-tuning
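
Here is a minimal sketch of both strategies, assuming a torchvision ResNet-18 backbone; the learning rates, the layer4 split point, and the set_stage helper are illustrative choices, not from the original post.

import torch.optim as optim
import torchvision.models as models

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Differential learning rates: the pretrained backbone gets a small LR,
# while the freshly initialized head gets a larger one.
head_params = list(model.fc.parameters())
head_ids = {id(p) for p in head_params}
backbone_params = [p for p in model.parameters() if id(p) not in head_ids]
optimizer = optim.Adam([
    {'params': backbone_params, 'lr': 1e-4},  # pretrained layers: small LR
    {'params': head_params, 'lr': 1e-3},      # new classifier head: larger LR
])

# Three-stage fine-tuning (illustrative schedule):
#   stage 1: freeze the backbone, train only the head
#   stage 2: also unfreeze the last residual stage (layer4)
#   stage 3: unfreeze everything and fine-tune end to end with a small LR
def set_stage(model, stage):
    for p in model.parameters():
        p.requires_grad = (stage >= 3)
    if stage >= 2:
        for p in model.layer4.parameters():
            p.requires_grad = True
    for p in model.fc.parameters():
        p.requires_grad = True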

@疏锦行

PS: today's code takes quite a while to train, roughly 40 minutes on an RTX 3080 Ti.

Homework:

  1. Work through the ResNet-18 model structure until it makes sense
  2. Try the fine-tuning strategies on VGG16 + CBAM

ResNet-18 is a classic convolutional neural network. By introducing residual blocks, it alleviates the vanishing-gradient problem in training deep networks: each block learns a residual F(x) and outputs F(x) + x, so gradients can flow back through the identity shortcut.
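
As a quick illustration, here is a minimal sketch of the BasicBlock that ResNet-18 stacks (two 3x3 convolutions plus an identity shortcut); it follows the structure of torchvision's implementation but is simplified.

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Residual block used in ResNet-18/34: two 3x3 convs plus a shortcut."""
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 conv on the shortcut when the spatial size or channel count changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)  # the residual connection
        return torch.relu(out)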

# Load the pretrained VGG16 model
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)  # pretrained=True on older torchvision
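
The code below references a CBAM class that the original post does not define. A minimal sketch of CBAM (channel attention followed by spatial attention, after Woo et al., 2018), with a reduction ratio of 16 and a 7x7 spatial kernel as illustrative defaults:

class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        # shared MLP applied to both average- and max-pooled descriptors
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        avg = self.mlp(x.mean(dim=(2, 3)))
        mx = self.mlp(x.amax(dim=(2, 3)))
        scale = torch.sigmoid(avg + mx).view(b, c, 1, 1)
        return x * scale

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        # conv over channel-wise mean and max maps yields a spatial mask
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        avg = x.mean(dim=1, keepdim=True)
        mx = x.amax(dim=1, keepdim=True)
        scale = torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))
        return x * scale

class CBAM(nn.Module):
    """Channel attention followed by spatial attention."""
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel = ChannelAttention(channels, reduction)
        self.spatial = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial(self.channel(x))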

# Modify VGG16: insert a CBAM module after each conv block (i.e. after each max-pooling layer)
class VGG16_CBAM(nn.Module):
    def __init__(self, vgg16):
        super().__init__()
        layers = []
        in_channels = 3
        for layer in vgg16.features:
            layers.append(layer)
            if isinstance(layer, nn.Conv2d):
                # track the channel count of the most recent conv; the layer
                # right before a pool is a ReLU, which has no out_channels
                in_channels = layer.out_channels
            if isinstance(layer, nn.MaxPool2d):
                # insert a CBAM module after each max-pooling layer
                layers.append(CBAM(in_channels))
        self.features = nn.Sequential(*layers)
        self.avgpool = vgg16.avgpool
        self.classifier = vgg16.classifier

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Build the CBAM-augmented VGG16 model
# (for a dataset other than 1000-class ImageNet, also replace classifier[6])
vgg16_cbam = VGG16_CBAM(vgg16)

# Fine-tuning strategy: freeze the feature-extraction layers and train only
# the fully connected layers and the CBAM modules
for param in vgg16_cbam.features.parameters():
    param.requires_grad = False

# Unfreeze the CBAM module parameters
for layer in vgg16_cbam.features:
    if isinstance(layer, CBAM):
        for param in layer.parameters():
            param.requires_grad = True

# Define the loss function and an optimizer over the trainable parameters only
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    [param for param in vgg16_cbam.parameters() if param.requires_grad],
    lr=0.001
)

# Train the model (train_loader is assumed to be defined elsewhere)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg16_cbam.to(device)
num_epochs = 10
for epoch in range(num_epochs):
    vgg16_cbam.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = vgg16_cbam(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}')
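
The loop above is effectively stage 1 of the three-stage strategy: frozen backbone, trainable head and CBAM. A hedged sketch of how a later stage might combine unfreezing with differential learning rates; the slice point and learning rates are illustrative, not from the original post.

# Stage 2 (illustrative): unfreeze the tail of the feature stack; the slice
# point is a hyperparameter, not something fixed by the original post.
for layer in vgg16_cbam.features[-8:]:
    for param in layer.parameters():
        param.requires_grad = True

# Differential learning rates: unfrozen backbone layers (incl. CBAM) train
# slowly, while the classifier head trains faster.
optimizer = optim.Adam([
    {'params': [p for p in vgg16_cbam.features.parameters() if p.requires_grad],
     'lr': 1e-4},
    {'params': vgg16_cbam.classifier.parameters(), 'lr': 1e-3},
])
# Stage 3 would unfreeze all layers and drop the learning rate further.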

