Day 51 Review: Model Improvements

Published: 2025-07-05

On Day 43 I trained a simple CNN on a dataset I collected myself. This time the same task is revisited with a pretrained backbone, attention modules, and a few other improvements.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np

# Matplotlib font configuration (SimHei kept from the original CJK setup;
# harmless with the English labels used below)
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 1. Data preprocessing
# # Compute per-channel mean and std (run once).
# # Note: averaging per-image stds only approximates the true dataset std,
# # but the approximation is close enough for normalization.
# def calculate_mean_std(dataloader):
#     mean = torch.zeros(3)
#     std = torch.zeros(3)
#     total_images = 0
#     for images, _ in dataloader:
#         batch_size = images.size(0)
#         images = images.view(batch_size, 3, -1)
#         mean += images.mean(2).sum(0)
#         std += images.std(2).sum(0)
#         total_images += batch_size
#     mean /= total_images
#     std /= total_images
#     return mean, std

# # Compute on a loader without augmentation (so augmentation does not skew the statistics)
# temp_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
# temp_dataset = datasets.ImageFolder(root=your_data_root, transform=temp_transform)
# temp_loader = DataLoader(temp_dataset, batch_size=32, shuffle=False)
# mean, std = calculate_mean_std(temp_loader)
# print(f"Dataset mean: {mean}, std: {std}")

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # ImageNet statistics
    transforms.Normalize((0.4790, 0.4813, 0.4370), (0.2123, 0.2066, 0.2085))  # this dataset's statistics
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # ImageNet statistics
    transforms.Normalize((0.4790, 0.4813, 0.4370), (0.2123, 0.2066, 0.2085))  # this dataset's statistics
])
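
# Illustrative check of the pipeline output (the image path is an assumption,
# not part of the run): both pipelines map a PIL image to a normalized
# (3, 224, 224) float tensor.
#   from PIL import Image
#   t = train_transform(Image.open('some_fish.jpg').convert('RGB'))
#   assert t.shape == (3, 224, 224)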


# 2. Load the custom dataset.
# random_split on a single ImageFolder would hand the test subset the training
# augmentations, so the folder is loaded twice (once per transform) and split
# with identically seeded generators so both views use the same indices.
data_root = r"BengaliFishImages\fish_images"
train_view = datasets.ImageFolder(root=data_root, transform=train_transform)
test_view = datasets.ImageFolder(root=data_root, transform=test_transform)

train_size = int(0.8 * len(train_view))
test_size = len(train_view) - train_size
train_dataset, _ = random_split(train_view, [train_size, test_size],
                                generator=torch.Generator().manual_seed(42))
_, test_dataset = random_split(test_view, [train_size, test_size],
                               generator=torch.Generator().manual_seed(42))

# 3. Create the data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
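
# Quick peek at one batch (illustrative, not part of the run):
#   images, labels = next(iter(train_loader))
#   print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 224, 224]) torch.Size([32])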

# 4. Attention modules

# Squeeze-and-Excitation (SE) channel attention
class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
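
# A minimal shape check for SEBlock (an illustrative sketch, not part of the
# training run; the tensor sizes are arbitrary). SE only reweights channels,
# so the output shape must equal the input shape:
#   se = SEBlock(64)
#   out = se(torch.randn(2, 64, 56, 56))
#   assert out.shape == (2, 64, 56, 56)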

# CBAM attention: channel attention followed by spatial attention
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class CBAMBlock(nn.Module):
    def __init__(self, channel, ratio=16, kernel_size=7):
        super(CBAMBlock, self).__init__()
        self.channel_attention = ChannelAttention(channel, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.channel_attention(x)
        x = x * self.spatial_attention(x)
        return x
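
# CBAM is likewise shape-preserving (an illustrative sketch with assumed sizes,
# not part of the training run): channel weights of shape (B, C, 1, 1), then a
# spatial map of shape (B, 1, H, W), are broadcast-multiplied onto the input.
#   cbam = CBAMBlock(128)
#   out = cbam(torch.randn(2, 128, 28, 28))
#   assert out.shape == (2, 128, 28, 28)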

# 5. Improved custom CNN (optionally with SE or CBAM attention)
class ImprovedCNN(nn.Module):
    def __init__(self, num_classes=20, attention_type=None):
        super(ImprovedCNN, self).__init__()
        self.attention_type = attention_type
        
        # Convolutional block 1
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)  # 224 -> 112
        if attention_type == 'se':
            self.att1 = SEBlock(32)
        elif attention_type == 'cbam':
            self.att1 = CBAMBlock(32)
        
        # Convolutional block 2
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)  # 112 -> 56
        if attention_type == 'se':
            self.att2 = SEBlock(64)
        elif attention_type == 'cbam':
            self.att2 = CBAMBlock(64)
        
        # Convolutional block 3
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(2)  # 56 -> 28
        if attention_type == 'se':
            self.att3 = SEBlock(128)
        elif attention_type == 'cbam':
            self.att3 = CBAMBlock(128)
        
        # Convolutional block 4
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(2)  # 28 -> 14
        if attention_type == 'se':
            self.att4 = SEBlock(256)
        elif attention_type == 'cbam':
            self.att4 = CBAMBlock(256)
        
        # Fully connected head. With 224x224 inputs, four 2x2 max-pools leave a
        # 14x14 feature map (the original 256 * 8 * 8 assumed 128x128 inputs)
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.relu_fc = nn.ReLU()  # dedicated activation for the head
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        # Block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        if self.attention_type is not None:
            x = self.att1(x)
        
        # Block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        if self.attention_type is not None:
            x = self.att2(x)
        
        # Block 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)
        if self.attention_type is not None:
            x = self.att3(x)
        
        # Block 4
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu4(x)
        x = self.pool4(x)
        if self.attention_type is not None:
            x = self.att4(x)
        
        # Fully connected head
        x = x.view(x.size(0), -1)  # flatten to (B, 256*14*14)
        x = self.fc1(x)
        x = self.relu_fc(x)  # dedicated ReLU instead of reusing relu4
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x
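
# Forward-pass sanity check for the custom CNN (an illustrative sketch; the
# batch size is arbitrary). With 224x224 inputs the head sees 256*14*14 features:
#   net = ImprovedCNN(num_classes=20, attention_type='se')
#   logits = net(torch.randn(2, 3, 224, 224))
#   assert logits.shape == (2, 20)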

# 6. Classifiers built on pretrained backbones
def create_pretrained_model(model_name, num_classes=20, freeze_feature=True, attention_type=None):
    """
    创建基于预训练模型的分类器
    
    Args:
        model_name: 预训练模型名称,如'resnet50', 'vgg16', 'mobilenet_v2'
        num_classes: 分类类别数
        freeze_feature: 是否冻结特征提取部分
        attention_type: 注意力类型,None, 'se' 或 'cbam'
        
    Returns:
        构建好的模型
    """
    if model_name == 'resnet50':
        model = models.resnet50(pretrained=True)  # newer torchvision prefers the weights= argument
        # Freeze the feature extractor first; the attention and head modules
        # added below are created afterwards, so they remain trainable
        if freeze_feature:
            for param in model.parameters():
                param.requires_grad = False

        # Optional attention, wrapped around layer4[0].conv1 (a 1x1 conv with
        # 512 output channels, which is why the blocks take channel=512)
        if attention_type == 'se':
            model.layer4[0].conv1 = nn.Sequential(
                model.layer4[0].conv1,
                SEBlock(512)
            )
        elif attention_type == 'cbam':
            model.layer4[0].conv1 = nn.Sequential(
                model.layer4[0].conv1,
                CBAMBlock(512)
            )
        
        # Replace the final fully connected layer
        num_ftrs = model.fc.in_features
        model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    
    elif model_name == 'vgg16':
        model = models.vgg16(pretrained=True)
        if freeze_feature:
            for param in model.features.parameters():
                param.requires_grad = False
        
        # Optional attention appended after the last feature layer (512 channels)
        if attention_type is not None:
            att_module = SEBlock(512) if attention_type == 'se' else CBAMBlock(512)
            model.features = nn.Sequential(
                *list(model.features.children()),
                att_module
            )
        
        # Replace the classifier head
        num_ftrs = model.classifier[6].in_features
        model.classifier[6] = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    
    elif model_name == 'mobilenet_v2':
        model = models.mobilenet_v2(pretrained=True)
        if freeze_feature:
            for param in model.features.parameters():
                param.requires_grad = False
        
        # Optional attention appended after the last feature layer (1280 channels)
        if attention_type is not None:
            att_module = SEBlock(1280) if attention_type == 'se' else CBAMBlock(1280)
            model.features = nn.Sequential(
                *list(model.features.children()),
                att_module
            )
        
        # Replace the classifier head
        num_ftrs = model.classifier[1].in_features
        model.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    
    else:
        raise ValueError(f"Unsupported model name: {model_name}")
    
    return model
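
# A hedged usage sketch (not executed here): with freeze_feature=True only the
# new head (and any attention modules added after freezing) should be trainable.
#   m = create_pretrained_model('resnet50', num_classes=20, freeze_feature=True)
#   trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
#   total = sum(p.numel() for p in m.parameters())
#   print(f"trainable: {trainable:,} / total: {total:,}")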

# 7. Training and evaluation (same behavior as before, lightly adjusted)
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):
    all_iter_losses = []
    iter_indices = []
    train_acc_history = []
    test_acc_history = []
    train_loss_history = []
    test_loss_history = []

    for epoch in range(epochs):
        # Re-enter train mode every epoch; eval() is set during testing below,
        # so calling train() only once would leave BatchNorm/Dropout frozen
        # from the second epoch onward
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
            
            running_loss += iter_loss
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            
            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} '
                      f'| Batch loss: {iter_loss:.4f} | Running avg loss: {running_loss/(batch_idx+1):.4f}')
        
        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct / total
        train_acc_history.append(epoch_train_acc)
        train_loss_history.append(epoch_train_loss)
        
        # Evaluation phase
        model.eval()
        test_loss = 0
        correct_test = 0
        total_test = 0
        
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()
        
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        test_acc_history.append(epoch_test_acc)
        test_loss_history.append(epoch_test_loss)
        
        scheduler.step(epoch_test_loss)
        
        print(f'Epoch {epoch+1}/{epochs} done | Train accuracy: {epoch_train_acc:.2f}% | Test accuracy: {epoch_test_acc:.2f}%')
    
    plot_iter_losses(all_iter_losses, iter_indices)
    plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
    
    return epoch_test_acc

# 8. Plotting helpers (unchanged)
def plot_iter_losses(losses, indices):
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    plt.xlabel('Iteration (batch index)')
    plt.ylabel('Loss')
    plt.title('Training loss per iteration')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):
    epochs = range(1, len(train_acc) + 1)
    
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_acc, 'b-', label='Train accuracy')
    plt.plot(epochs, test_acc, 'r-', label='Test accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Train vs. test accuracy')
    plt.legend()
    plt.grid(True)
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_loss, 'b-', label='Train loss')
    plt.plot(epochs, test_loss, 'r-', label='Test loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Train vs. test loss')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.show()

# 9. Training configuration and entry point
def main():
    # Model type: 'custom' (the CNN above), 'resnet50', 'vgg16', or 'mobilenet_v2'
    model_type = 'resnet50'

    # Attention mechanism: None, 'se', or 'cbam'
    attention_type = 'cbam'

    # Training parameters
    epochs = 30  # pretrained backbones usually need fewer epochs
    num_classes = 20
    
    # Build the model
    if model_type == 'custom':
        print(f"Using the custom CNN, attention: {attention_type}")
        model = ImprovedCNN(num_classes=num_classes, attention_type=attention_type).to(device)
    else:
        # Full fine-tuning alternative:
        # model = create_pretrained_model(
        #     model_name=model_type,
        #     num_classes=num_classes,
        #     freeze_feature=False,  # True: train only the head; False: fine-tune everything
        #     attention_type=attention_type
        # ).to(device)
        # Warm-up stage: freeze the feature extractor and train only the head.
        # Attention is disabled here, and attention_type is reset so the log
        # line and the saved filename stay consistent with the actual model.
        attention_type = None
        print(f"Using pretrained {model_type}, attention: {attention_type}")
        model = create_pretrained_model(
            model_name=model_type,
            num_classes=num_classes,
            freeze_feature=True,   # freeze the backbone; train only the new head
            attention_type=attention_type
        ).to(device)
    
    # Loss, optimizer, and learning-rate scheduler
    criterion = nn.CrossEntropyLoss()
    # optimizer = optim.Adam(model.parameters(), lr=0.001)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    #     optimizer, mode='min', patience=3, factor=0.5
    # )

    # A smaller learning rate suits transfer learning; passing only parameters
    # with requires_grad=True keeps the frozen backbone out of the optimizer
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=5, factor=0.5, min_lr=1e-6
    )
    
    # Train
    print("Starting training...")
    final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
    print(f"Training finished. Final test accuracy: {final_accuracy:.2f}%")

    # Save the model weights
    model_filename = f"{model_type}_{attention_type if attention_type else 'no_att'}_fish_model.pth"
    torch.save(model.state_dict(), model_filename)
    print(f"Model saved as: {model_filename}")

if __name__ == "__main__":
    main()
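
# A possible second stage after the frozen-backbone warm-up (a hedged sketch,
# not executed above; 'model' and the epoch count are assumptions): unfreeze
# everything and fine-tune at an even smaller learning rate.
#   for p in model.parameters():
#       p.requires_grad = True
#   optimizer = optim.Adam(model.parameters(), lr=1e-5)
#   train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs=10)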

@浙大疏锦行