Python check-in, day 54: Inception variants on CIFAR-10

Today's exercise: build a plain Inception module, a version with a residual shortcut, and a version with CBAM attention, then train each for 5 epochs on CIFAR-10 and compare test accuracy.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Data preprocessing: convert to tensors and normalize each RGB channel to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the CIFAR-10 dataset (downloads on first run)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                     download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False)
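
# Optional sanity check: a training batch should come out as (128, 3, 32, 32)
images, labels = next(iter(trainloader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])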

# Original Inception module: four parallel branches (1x1; 1x1 -> 3x3; 1x1 -> 5x5;
# 3x3 max-pool -> 1x1) concatenated along the channel dimension (64 + 128 + 32 + 32 = 256)
class Inception(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=1),
            nn.ReLU()
        )
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.ReLU()
        )
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, 32, kernel_size=1),
            nn.ReLU()
        )

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)
        return torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], dim=1)
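
# Optional sanity check: the branches output 64, 128, 32, and 32 channels, so
# concatenation yields 256 channels while the padding keeps the spatial size
_block = Inception(64)
print(_block(torch.randn(2, 64, 8, 8)).shape)  # torch.Size([2, 256, 8, 8])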

# Inception module with a ResNet-style residual shortcut
class InceptionWithResidual(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=1),
            nn.ReLU()
        )
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.ReLU()
        )
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, 32, kernel_size=1),
            nn.ReLU()
        )
        
        # The four branches concatenate to 256 channels, so project the input
        # with a 1x1 conv whenever its channel count does not already match
        if in_channels != 256:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, 256, kernel_size=1),
                nn.BatchNorm2d(256)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)
        outputs = torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], dim=1)
        shortcut = self.shortcut(x)
        return F.relu(outputs + shortcut)
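
# Optional sanity check: the 1x1 shortcut projects 64 -> 256 channels so the
# element-wise addition with the concatenated branches is valid
_res = InceptionWithResidual(64)
print(_res(torch.randn(2, 64, 8, 8)).shape)  # torch.Size([2, 256, 8, 8])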

# CBAM attention module: channel attention followed by spatial attention
class CBAM(nn.Module):
    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(channels // reduction_ratio, channels)
        )
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        # Channel attention: a shared MLP scores the avg- and max-pooled
        # descriptors; flatten(1) (rather than squeeze()) stays correct
        # even when the batch size is 1
        avg_out = self.fc(self.avg_pool(x).flatten(1))
        max_out = self.fc(self.max_pool(x).flatten(1))
        channel_att = torch.sigmoid(avg_out + max_out).unsqueeze(-1).unsqueeze(-1)
        x = x * channel_att
        
        # Spatial attention: a 7x7 conv over the stacked channel-wise mean and max maps
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_att = torch.cat([avg_out, max_out], dim=1)
        spatial_att = torch.sigmoid(self.conv(spatial_att))
        return x * spatial_att
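
# Optional sanity check: CBAM re-weights features without changing the shape
_att = CBAM(256)
print(_att(torch.randn(2, 256, 8, 8)).shape)  # torch.Size([2, 256, 8, 8])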

# Inception module with CBAM applied to the concatenated branch output
class InceptionWithCBAM(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=1),
            nn.ReLU()
        )
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(96, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.ReLU()
        )
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, 32, kernel_size=1),
            nn.ReLU()
        )
        self.cbam = CBAM(256)  # the four branches concatenate to 256 channels

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)
        outputs = torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], dim=1)
        return self.cbam(outputs)

# Full network: a 7x7 stem convolution and max-pool, followed by two
# Inception-style blocks, global average pooling, and a linear classifier
class InceptionNet(nn.Module):
    def __init__(self, num_classes=10, module_type='original'):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        if module_type == 'original':
            self.inception1 = Inception(64)
            self.inception2 = Inception(256)
        elif module_type == 'residual':
            self.inception1 = InceptionWithResidual(64)
            self.inception2 = InceptionWithResidual(256)
        elif module_type == 'cbam':
            self.inception1 = InceptionWithCBAM(64)
            self.inception2 = InceptionWithCBAM(256)
        else:
            raise ValueError(f'unknown module_type: {module_type}')
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
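
# Optional end-to-end check: a CIFAR-10-sized batch should map to 10 class
# logits (the stem reduces 32x32 inputs to 8x8 before the Inception blocks)
_net = InceptionNet(module_type='original')
print(_net(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 10])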

# Training function (relies on the global device, optimizer, and criterion
# defined under __main__)
def train(model, epoch):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 100 == 99:  # report the average loss every 100 mini-batches
            print(f'Epoch: {epoch + 1}, Batch: {i + 1}, Loss: {running_loss / 100:.3f}')
            running_loss = 0.0

# Evaluation on the test set
def test(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # class with the highest logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy on test set: {100 * correct / total:.2f}%')
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Original Inception network
    print("Training the original Inception network:")
    model = InceptionNet(module_type='original').to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(5):
        train(model, epoch)
        test(model)
    
    # Inception network with residual connections
    print("\nTraining the Inception network with residual connections:")
    model = InceptionNet(module_type='residual').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(5):
        train(model, epoch)
        test(model)
    
    # Inception network with CBAM
    print("\nTraining the Inception network with CBAM:")
    model = InceptionNet(module_type='cbam').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(5):
        train(model, epoch)
        test(model)

Output:

Training the original Inception network:
Epoch: 1, Batch: 100, Loss: 1.982
Epoch: 1, Batch: 200, Loss: 1.718
Epoch: 1, Batch: 300, Loss: 1.602
Accuracy on test set: 43.35%
Epoch: 2, Batch: 100, Loss: 1.475
Epoch: 2, Batch: 200, Loss: 1.405
Epoch: 2, Batch: 300, Loss: 1.371
Accuracy on test set: 53.42%
Epoch: 3, Batch: 100, Loss: 1.279
Epoch: 3, Batch: 200, Loss: 1.239
Epoch: 3, Batch: 300, Loss: 1.197
Accuracy on test set: 59.10%
Epoch: 4, Batch: 100, Loss: 1.130
Epoch: 4, Batch: 200, Loss: 1.118
Epoch: 4, Batch: 300, Loss: 1.084
Accuracy on test set: 60.84%
Epoch: 5, Batch: 100, Loss: 1.061
Epoch: 5, Batch: 200, Loss: 1.015
Epoch: 5, Batch: 300, Loss: 1.005
Accuracy on test set: 59.86%

Training the Inception network with residual connections:
Epoch: 1, Batch: 100, Loss: 1.829
Epoch: 1, Batch: 200, Loss: 1.600
Epoch: 1, Batch: 300, Loss: 1.473
Accuracy on test set: 50.87%
Epoch: 2, Batch: 100, Loss: 1.324
Epoch: 2, Batch: 200, Loss: 1.267
Epoch: 2, Batch: 300, Loss: 1.231
Accuracy on test set: 58.51%
Epoch: 3, Batch: 100, Loss: 1.132
Epoch: 3, Batch: 200, Loss: 1.100
Epoch: 3, Batch: 300, Loss: 1.074
Accuracy on test set: 60.79%
Epoch: 4, Batch: 100, Loss: 1.027
Epoch: 4, Batch: 200, Loss: 1.000
Epoch: 4, Batch: 300, Loss: 0.987
Accuracy on test set: 60.19%
Epoch: 5, Batch: 100, Loss: 0.965
Epoch: 5, Batch: 200, Loss: 0.934
Epoch: 5, Batch: 300, Loss: 0.918
Accuracy on test set: 66.30%

Training the Inception network with CBAM:
Epoch: 1, Batch: 100, Loss: 2.038
Epoch: 1, Batch: 200, Loss: 1.754
Epoch: 1, Batch: 300, Loss: 1.653
Accuracy on test set: 40.46%
Epoch: 2, Batch: 100, Loss: 1.523
Epoch: 2, Batch: 200, Loss: 1.450
Epoch: 2, Batch: 300, Loss: 1.414
Accuracy on test set: 51.94%
Epoch: 3, Batch: 100, Loss: 1.324
Epoch: 3, Batch: 200, Loss: 1.287
Epoch: 3, Batch: 300, Loss: 1.225
Accuracy on test set: 56.31%
Epoch: 4, Batch: 100, Loss: 1.177
Epoch: 4, Batch: 200, Loss: 1.135
Epoch: 4, Batch: 300, Loss: 1.105
Accuracy on test set: 62.34%
Epoch: 5, Batch: 100, Loss: 1.072
Epoch: 5, Batch: 200, Loss: 1.029
Epoch: 5, Batch: 300, Loss: 1.008
Accuracy on test set: 64.95%
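
Summary after 5 epochs: the original Inception network ends at 59.86% test
accuracy, the residual variant at 66.30%, and the CBAM variant at 64.95%.
Under this short training budget the residual shortcut gives the largest
gain; CBAM also helps, though it starts slower in the first epochs.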

@浙大疏锦行

