python打卡day45

发布于:2025-06-30 ⋅ 阅读:(23) ⋅ 点赞:(0)

@疏锦行

知识点回顾:

  1. tensorboard的发展历史和原理
  2. tensorboard的常见操作
  3. tensorboard在cifar上的实战:MLP和CNN模型

    作业:在 CIFAR-10 上对 ResNet18 采用微调策略,并用 TensorBoard 监控训练过程。

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torchvision import datasets, transforms, models
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    import os
    
    # Seed PyTorch's RNG so runs are repeatable
    torch.manual_seed(42)

    # Preprocessing pipelines. Both splits end with the same tensor conversion
    # and per-channel normalization (mean 0.5, std 0.5); only the training
    # split adds random crop/flip augmentation in front of them.
    _to_tensor = transforms.ToTensor()
    _normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    _augment = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    transform = dict(
        train=transforms.Compose(_augment + [_to_tensor, _normalize]),
        test=transforms.Compose([_to_tensor, _normalize]),
    )
    
    # Load the CIFAR-10 train/test splits (downloaded into ./data if missing)
    _data_root = './data'
    train_dataset = datasets.CIFAR10(
        root=_data_root,
        train=True,
        download=True,
        transform=transform['train'],
    )
    test_dataset = datasets.CIFAR10(
        root=_data_root,
        train=False,
        download=True,
        transform=transform['test'],
    )

    # Batch both splits the same way; only the training split is shuffled
    _batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=_batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=_batch_size, shuffle=False)
    
    # Create the TensorBoard SummaryWriter in a directory that does not exist
    # yet: use the base name if it is free, otherwise append _1, _2, ... until
    # an unused path is found, so earlier runs are never overwritten.
    _base_log_dir = 'runs/resnet18_cifar10_finetune'
    log_dir = _base_log_dir
    _suffix = 0
    while os.path.exists(log_dir):
        _suffix += 1
        log_dir = f"{_base_log_dir}_{_suffix}"
    writer = SummaryWriter(log_dir)
    
    # Load ResNet18 with ImageNet-pretrained weights.
    # `pretrained=True` is deprecated in torchvision >= 0.13 in favour of the
    # `weights=` enum; IMAGENET1K_V1 is the checkpoint `pretrained=True` used.
    if hasattr(models, 'ResNet18_Weights'):
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    else:
        # Older torchvision without the weights enum
        model = models.resnet18(pretrained=True)

    # Replace the final fully connected layer so it outputs the 10 CIFAR-10
    # classes; its input width is read from the layer being replaced.
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 10)
    
    # Select the training device and move the model onto it *before* the
    # optimizer is built, so the optimizer is constructed over parameters
    # that already live on the device they will be trained on.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Loss function and optimizer (all model parameters are fine-tuned)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    
    # Training and validation
    def _run_epoch(model, loader, criterion, device, optimizer=None):
        """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

        Args:
            model: network to run.
            loader: iterable yielding ``(inputs, labels)`` batches.
            criterion: loss callable taking ``(outputs, labels)``.
            device: device the batches are moved to before the forward pass.
            optimizer: when given, the model is put in training mode and its
                weights are updated after every batch; when ``None`` the model
                is put in evaluation mode and gradients are disabled.

        Returns:
            A tuple of Python floats: the per-sample mean loss and the
            fraction of correctly classified samples.
        """
        training = optimizer is not None
        model.train(training)

        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.set_grad_enabled(training):
            for inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if training:
                    optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                if training:
                    loss.backward()
                    optimizer.step()

                batch_size = inputs.size(0)
                # Assumes the criterion averages over the batch (the default
                # reduction), so weight it by batch size to get a per-sample
                # mean at the end even when the last batch is smaller.
                total_loss += loss.item() * batch_size
                # Accumulate plain Python ints instead of device tensors
                total_correct += (outputs.argmax(dim=1) == labels).sum().item()
                total_samples += batch_size

        return total_loss / total_samples, total_correct / total_samples

    num_epochs = 10
    for epoch in range(num_epochs):
        # Training phase
        epoch_loss, epoch_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )

        # Log training loss and accuracy to TensorBoard
        writer.add_scalar('Train/Loss', epoch_loss, epoch)
        writer.add_scalar('Train/Accuracy', epoch_acc, epoch)

        # Validation phase (no optimizer -> eval mode, no gradients)
        val_epoch_loss, val_epoch_acc = _run_epoch(
            model, test_loader, criterion, device
        )

        # Log validation loss and accuracy to TensorBoard
        writer.add_scalar('Validation/Loss', val_epoch_loss, epoch)
        writer.add_scalar('Validation/Accuracy', val_epoch_acc, epoch)

        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}, Val Loss: {val_epoch_loss:.4f}, Val Acc: {val_epoch_acc:.4f}')
    
    # Visualize the model structure
    try:
        # One batch from the training loader serves as the example input;
        # it is moved to the same device as the model before being passed in.
        images, _ = next(iter(train_loader))
        writer.add_graph(model, images.to(device))
    finally:
        # Close the writer even if fetching the batch or exporting the graph
        # fails, so the writer is not left open.
        writer.close()


网站公告

今日签到

点亮在社区的每一天
去签到