DAY 38 Dataset和Dataloader类
对应5. 27作业
知识点回顾:
- Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
- Dataloader类
- minist手写数据集的了解
作业:了解下cifar数据集,尝试获取其中一张图片
MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练。
导入Dataset类和Dataloader类必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
其中torchvision是一个计算机视觉的库,它的常见方法如下:
torchvision
- datasets # 视觉数据集(如 MNIST、CIFAR)
- transforms # 视觉数据预处理(如裁剪、翻转、归一化)
- models # 预训练模型(如 ResNet、YOLO)
- utils # 视觉工具函数(如目标检测后处理)
- io # 图像/视频 IO 操作
1.其中的transforms模块提供了一系列常用的图像预处理操作:
# 先归一化,再标准化
transform = transforms.Compose([
transforms.ToTensor(), # 转换为张量并归一化到[0,1]
transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
2.MNIST数据集
# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform
)
test_dataset = datasets.MNIST(
root='./data',
train=False,
transform=transform
)
随机取出一张图片(包括图片和标签)
import matplotlib.pyplot as plt
# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
可视化取出的图像
# 可视化原始图像(需要反归一化)
def imshow(img):
img = img * 0.3081 + 0.1307 # 反标准化
npimg = img.numpy()
plt.imshow(npimg[0], cmap='gray') # 显示灰度图像
plt.show()
print(f"Label: {label}")
imshow(image)
当然也可以用相同的思路取出两张:
import matplotlib.pyplot as plt
import torch
# 随机选择两张图片的索引
sample_idx_1 = torch.randint(0, len(train_dataset), size=(1,)).item()
sample_idx_2 = torch.randint(0, len(train_dataset), size=(1,)).item()
# 获取图片和标签
image_1, label_1 = train_dataset[sample_idx_1]
image_2, label_2 = train_dataset[sample_idx_2]
# 定义一个函数来反归一化并显示图像
def imshow(img):
img = img * 0.3081 + 0.1307 # 反标准化
npimg = img.numpy()
plt.imshow(npimg[0], cmap='gray')
# 创建一个包含两个子图的画布
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# 显示第一张图片
plt.sca(axes[0])
imshow(image_1)
axes[0].set_title(f'Label: {label_1}')
axes[0].axis('off')
# 显示第二张图片
plt.sca(axes[1])
imshow(image_2)
axes[1].set_title(f'Label: {label_2}')
axes[1].axis('off')
plt.show()
我们是如何通过dataset类取出图像的呢??
PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:
- __len__():返回数据集的样本总数。
- __getitem__(idx):根据索引idx返回对应样本的数据和标签。
PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。
__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。
# 示例代码
class MyList:
def __init__(self):
self.data = [10, 20, 30, 40, 50]
def __getitem__(self, idx):
return self.data[idx]
# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2]) # 输出:30
__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。
class MyList:
def __init__(self):
self.data = [10, 20, 30, 40, 50]
def __len__(self):
return len(self.data)
# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj)) # 输出:5
再介绍一下Dataloader类
# 3. 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
shuffle=True # 随机打乱数据
)
test_loader = DataLoader(
test_dataset,
batch_size=1000 # 每个批次1000张图片
# shuffle=False # 测试时不需要打乱数据
)
作业:
维度 | CIFAR 数据集 | MNIST 手写数据集 |
---|---|---|
创建机构 / 背景 | 由加拿大先进研究所(CIFAR)开发,用于计算机视觉研究 | 由纽约大学柯朗数学科学研究所开发,用于手写数字识别研究 |
数据类型 | 自然物体彩色图像(如动物、交通工具等) | 手写数字灰度图像(0-9) |
图像分辨率 | 32×32 像素(RGB 三通道,彩色图像) | 28×28 像素(单通道,灰度图像) |
数据集规模 | - 总样本数:60,000 张 - 训练集:50,000 张 - 测试集:10,000 张 |
- 总样本数:70,000 张 - 训练集:60,000 张 - 测试集:10,000 张 |
类别数量 | - CIFAR-10:10 个大类 - CIFAR-100:100 个细分类别(20 个超类) |
10 个类别(数字 0-9) |
任务难度 | - 图像分辨率低但包含复杂背景和类内差异 - CIFAR-100 因类别多、区分度小,难度更高 |
图像背景简单,数字形态相对固定,难度较低 |
典型应用场景 | 图像分类、目标识别、深度学习算法基准测试(如 CNN 优化) | 手写数字识别、基础算法验证(如神经网络入门案例) |
数据预处理 | 需进行色彩归一化、数据增强(如裁剪、翻转)等处理 | 通常仅需灰度归一化和简单降噪处理 |
模型性能基准 | - CIFAR-10 顶尖模型准确率:~97% - CIFAR-100 顶尖模型准确率:~87% |
顶尖模型准确率:~99.7%(如 CNN) |
相似点 | - 均为图像分类领域经典基准数据集 - 均包含训练集和测试集,结构标准化 - 广泛用于算法教学、研究和性能对比 |
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化处理,将像素值从[0,1]缩放到[-1,1]
])
# 加载训练集
train_dataset = datasets.CIFAR10(
root='./data', # 数据存放路径
train=True, # 是否为训练集
download=True, # 如果数据不存在,是否自动下载
transform=transform # 数据预处理
)
# 加载测试集
test_dataset = datasets.CIFAR10(
root='./data', # 数据存放路径
train=False, # 是否为测试集
transform=transform # 数据预处理
)
import matplotlib.pyplot as plt
# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
# 可视化原始图像(需要反归一化)
def imshow(img, title=None):
img = img / 2 + 0.5 # 反归一化:将[-1,1]范围转回[0,1]
npimg = img.numpy()
plt.figure(figsize=(4, 4))
plt.imshow(np.transpose(npimg, (1, 2, 0))) # 调整通道顺序:从[C,H,W]到[H,W,C]
if title:
plt.title(title)
plt.axis('off')
plt.show()
print(f"Label: {label} ({classes[label]})")
imshow(image, f"Label: {classes[label]}")