PyTorch(三)数据集与数据加载器

发布于:2024-06-29 ⋅ 阅读:(10) ⋅ 点赞:(0)

#c 目的 需要的目的

专门处理数据的代码可能会变得「混乱且难以维护」,理想情况下是将「数据集代码」与「模型训练代码」「解耦(decoupled)」,以提高可读性和模块性。

torch.utils.data.DataLoadertorch.utils.data.Dataset,可以使用「预加载数据集」以及「自定义数据」。
Dataset存储样本及其相应的标签。
DataLoader则围绕Dataset包装了一个可迭代对象,以便于轻松访问样本。

1 下载数据集

#e MNIST数据集

以下载TorchVision下的Fashion-MNIST为例。Fashion-MNIST是一个由Zalando的文章图片组成的数据集,包含60,000个训练样本和10,000个测试样本。每个样本包括一个28×28的灰度图像以及来自10个类别之一的相关标签。

# 下载训练数据集
train_data = datasets.FashionMNIST(
    root="data",  # 数据存储的路径
    train=True,   # 指定下载的是训练数据集
    download=True,  # 如果数据不存在,则通过网络下载
    transform=ToTensor()  # 将图片转换为Tensor
)

# 下载测试数据集
test_data = datasets.FashionMNIST(
    root="data",  # 数据存储的路径
    train=False,  # 指定下载的是测试数据集
    download=True,  # 如果数据不存在,则通过网络下载
    transform=ToTensor()  # 将图片转换为Tensor
)

2 迭代和可视化数据

#e 迭代和可视化

lables_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))#创建一个matplotlib图形对象,设置图形的大小为8x8英寸。
cols, rows = 3, 3#设置列数和行数
for i in range(1, cols * rows +1):#循环9次
    sample_idx = torch.randint(len(training_data),size=(1,)).item()
    '''
    使用torch.randint随机生成一个介于0和训练数据集长度之间的整数,作为随机选取的图像的索引。
    size=(1,)指定生成一个数,item()将其转换为Python的标准整数。
    '''
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)#添加子图,设置行数、列数和子图的索引,位置由i决定
    plt.title(lables_map[label])#设置标题
    plt.axis("off")#关闭坐标轴
    plt.imshow(img.squeeze(), cmap="gray")#灰度显示
    # plt.imshow(img.squeeze())#彩色显示,无需指定cmap
    '''
    img.squeeze()将图像张量的维度为1的轴删除,因为imshow函数预期的是一个二维图像。
    cmap="gray"指定了灰度图像。
    '''
plt.show()

3 自定义数据集

#c 要素 自定义数据集要素

自定义的Dataset类必须实现以下三个函数:

__init__:初始化函数,用于设置数据集的属性,如加载数据、预处理步骤等。

__len__:返回数据集中样本的数量。这个函数使得Dataset对象可以被len()函数调用,通常返回数据集中样本的总数。

__getitem__:根据索引获取单个样本。这个函数允许通过索引访问数据集中的每个样本。索引从0开始,对应于数据集中的第一个样本。

#e 三要素 自定义数据集要素

import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transfrom=None, target_tansform=None):
        self.img_labels = pd.read_csv(annotations_file)#读取CSV文件
        self.img_dir = img_dir#图像目录
        self.transfrom = transfrom#图像转换
        self.target_tansform = target_tansform#目标转换
    
    def __len__(self):
        return len(self.img_labels)#返回数据集的长度
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        imgage = read_image(img_path)#读取图像,转换成张量
        label = self.img_labels.iloc[idx, 1]#检索对应的标签
        if self.transfrom:#转换图像
            imgage = self.transfrom(imgage)
        if self.target_tansform:
            label = self.target_tansform(label)
        return imgage, label #以元组的形式返回图像和标签

4 使用DataLoader准备训练数据

#c 思路 数据准备思路

在训练模型的过程中,通常希望以“小批量”的形式传递样本,每个周期重新打乱数据以减少模型的过拟合,并使用Python的multiprocessing多进程来加速数据检索。数据集(Dataset)负责逐个样本地获取数据集的特征和标签。DataLoader是一个可迭代对象,它抽象了这些复杂性,提供了一个简单的API。

#e 准备代码

from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
#在这里,DataLoader将训练数据集传递给train_dataloader,每个小批量包含64个特征和标签对,shuffle=True表示在每个周期重新打乱数据。
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

5 通过DataLoader迭代

#c 特点 迭代特点

将该数据集加载到DataLoader中,可以根据需要迭代遍历数据集。每次迭代都会返回一批train_featurestrain_labels(分别包含batch_size=64个特征和标签)。若指定了shuffle=True,在遍历完所有的批次之后,数据会被重新打乱。这意味着每个周期(epoch)开始时,数据的顺序都会随机化,有助于模型学习到「更加泛化「的特征,从而减少「过拟合」的风险。

#e 迭代代码

train_features, train_labels = next(iter(train_dataloader))
#iter(train_dataloader)返回一个迭代器对象,next()函数返回迭代器的下一批数据
print(f"Feature batch shape: {train_features.size()}")#size()返回张量的形状(批量大小、通道数、高度、宽度)
print(f"Labels batch shape: {train_labels.size()}")#size()返回张量的形状(批量大小)
img = train_features[0].squeeze()#删除维度为1的轴,特别是当图像以(1,高度,宽度)或(1,通道数,高度,宽度)的形式存在时。
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
'''
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 6
'''

网站公告

今日签到

点亮在社区的每一天
去签到