torchvision中数据集的使用
import torchvision
import tarfile
from torch.utils.tensorboard import SummaryWriter
# 定义数据预处理流水线:只包含ToTensor转换
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor() # 将PIL图像转换为Tensor格式
])
# CIFAR10数据集说明:
# - 包含6万张32x32彩色图片
# - 10个类别:['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# - 训练集:50,000张,测试集:10,000张
train_set = torchvision.datasets.CIFAR10(root="./P_10_dataset", # 数据集保存路径
train=True, # 加载训练集
transform=dataset_transform, # 应用转换
download=True) # 自动下载
test_set = torchvision.datasets.CIFAR10(root="./P_10_dataset",
train=False, # 加载测试集
transform=dataset_transform,
download=True)
# 如果需要手动解压(通常不需要,PyTorch会自动处理)
# with tarfile.open('P_10_dataset/cifar-10-python.tar.gz', 'r:gz') as tar:
# tar.extractall(path='P_10_dataset')
# 检查数据集样本:转换后是(tensor, 标签)元组
# print(test_set[0]) # 输出示例: (<Tensor>, 3)
# 获取数据集类别名称
# print(test_set.classes) # 输出10个类别的名称
# 获取单个样本
# img, target = test_set[0]
# print(img) # 输出Tensor对象
# print(target) # 输出整数标签(如3)
# print(test_set.classes[target]) # 输出对应的类别名称(如'cat')
# 验证转换是否成功
print(test_set[0]) # 确认输出为Tensor格式
# 创建TensorBoard写入器
writer = SummaryWriter("p10")
# 将测试集前10个样本写入TensorBoard
for i in range(10):
img, target = test_set[i] # 获取图像Tensor和标签
writer.add_image("test_set", img, i) # 添加到TensorBoard
# 扩展:同时添加类别标签作为标题
# writer.add_image(f"test_set/{test_set.classes[target]}", img, i)
writer.close() # 关闭写入器
1. torchvision.datasets 模块
作用:提供常用计算机视觉数据集的便捷访问
常用数据集:
CIFAR10/100
:小尺寸彩色图像分类MNIST/FashionMNIST
:手写数字/服装灰度图像ImageNet
:大规模图像分类(需单独下载)COCO
:目标检测与分割数据集
核心参数:
torchvision.datasets.XXX(root, train, transform, download)
root
:数据集存储路径train
:True=训练集,False=测试集transform
:数据预处理流水线download
:自动下载数据集
CIFAR10数据集,有6w张彩色图片,5W张用作train,1w张用于test 类别包含以下10个类别,每个类别6000张 ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
DataLoader的使用
视频代码
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 准备测试集:使用CIFAR10数据集,应用ToTensor转换
test_data = torchvision.datasets.CIFAR10(
root="./P_10_dataset", # 数据集存储路径
train=False, # 使用测试集
transform=torchvision.transforms.ToTensor() # 将PIL图像转为Tensor
)
# 创建DataLoader
test_loader = DataLoader(
dataset=test_data, # 要加载的数据集
batch_size=64, # 每批加载的样本数
shuffle=True, # 每个epoch是否打乱数据顺序
num_workers=0, # 数据加载使用的子进程数(0表示主进程)
drop_last=True # 是否丢弃最后不足batch_size的批次
)
# 从测试数据集中取出第一个样本
img, target = test_data[0]
print(img.shape) # 输出: torch.Size([3, 32, 32]) - 3通道,32x32大小
print(target) # 输出: 3 - 对应的类别标签
# 创建TensorBoard写入器
writer = SummaryWriter('dataloader_logs')
step = 0 # 步数计数器
# 进行2个epoch的迭代
for epoch in range(2): # epoch数
# 遍历DataLoader中的所有批次
for data in test_loader:
imgs, targets = data # 解包批次数据
# 将当前批次的所有图像写入TensorBoard
writer.add_images("Epoch:{}".format(epoch), imgs, step)
step += 1 # 增加步数
# 关闭写入器
writer.close()
# 批次数据形状示例:
# imgs: torch.Size([64, 3, 32, 32])
# targets: tensor([5, 6, 5, 0, ...]) - 64个标签
1. DataLoader 核心功能
功能 | 说明 |
---|---|
批量加载 | 将数据集分成多个批次(batch) |
数据打乱 | 每个epoch重新随机排序数据 |
并行加载 | 使用多进程加速数据加载 |
自动分批 | 处理最后不足batch大小的批次 |
2. DataLoader 重要参数
参数 | 作用 | 常用值 |
---|---|---|
batch_size |
每批样本数 | 32/64/128 |
shuffle |
是否打乱数据 | True(训练)/False(测试) |
num_workers |
加载数据的子进程数 | 0(主进程)/2/4/8 |
drop_last |
是否丢弃最后不足batch的样本 | True/False |
pin_memory |
是否将数据复制到CUDA固定内存 | True(GPU训练) |
3. 批次数据结构
图像数据:
(batch_size, channels, height, width)
示例:
[64, 3, 32, 32]
表示64张32x32的RGB图像
标签数据:
(batch_size,)
示例:
tensor([3, 5, 9, ...])
64个类别标签
4. TensorBoard可视化
add_image()
: 添加单张图像add_images()
: 添加多张图像(整个批次)命名技巧: 包含epoch信息便于区分不同训练阶段
DataLoader是PyTorch数据管道的核心组件,它不仅仅是简单的数据分批工具,更是连接数据预处理与模型训练的桥梁。合理配置DataLoader参数可以显著提升训练效率和资源利用率,特别是在处理大规模数据集时。