DAY38打卡

发布于:2025-06-01 ⋅ 阅读:(26) ⋅ 点赞:(0)

作业:了解下cifar数据集,尝试获取其中一张图片

cifar数据集是一个三通道的彩色图像集,包括10个不同物种。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt

# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作

# CIFAR-10的归一化和标准化转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize(
        (0.4914, 0.4822, 0.4465),  # CIFAR-10数据集的RGB通道均值
        (0.2470, 0.2435, 0.2616)   # CIFAR-10数据集的RGB通道标准差
    )
])
# 2. 加载cifar-10数据集,如果没有会自动下载
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform
)

# 定义类别
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
 
# 随机选择一张图片
idx = torch.randint(0, len(train_dataset), size=(1,))
img, label = train_dataset[idx]
 
# 反标准化函数
def denormalize(x):
    mean = torch.tensor([0.4914, 0.4822, 0.4465])
    std = torch.tensor([0.2470, 0.2435, 0.2616])
    # CIFAR-10是彩色图像,需要对所有通道进行反标准化
    return x * std[:, None, None] + mean[:, None, None]
 
# 显示图片
plt.figure()
plt.imshow(denormalize(img).permute(1, 2, 0))  # 调整通道顺序以正确显示彩色图像
plt.title(f'Label: {classes[label]}')
plt.axis('off')
plt.show()
 
 
# 3. 创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
)

@浙大疏锦行


网站公告

今日签到

点亮在社区的每一天
去签到