Pytorch下载Mnist手写数据识别训练数据集的代码详解

发布于:2025-07-20 ⋅ 阅读:(17) ⋅ 点赞:(0)
datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

1. datasets.MNIST

这是torchvision.datasets模块中的一个类,专门用于加载MNIST数据集。MNIST是一个著名的手写数字识别数据集,包含60,000个训练样本和10,000个测试样本,每个样本是28x28的灰度图像。

2. 参数解释

root='./data'
  • 作用:指定数据集下载和存储的根目录。

  • 解释:这里设置为当前目录下的data文件夹。如果该文件夹不存在,PyTorch会自动创建它。

  • 默认值:通常没有默认值,必须指定。

train=False
  • 作用:指定加载的是训练集还是测试集。

  • 解释

    • train=True:加载训练集(60,000个样本)。

    • train=False:加载测试集(10,000个样本)。

  • 默认值:通常是True(加载训练集),但明确指定是好的实践。

download=True
  • 作用:控制是否下载数据集。

  • 解释

    • download=True:如果数据集在root目录下不存在,则自动下载。

    • download=False:不下载,仅尝试从root目录加载。

  • 默认值:通常是False,但这里显式设置为True以确保下载。

transform=transforms.ToTensor()
  • 作用:指定对加载的数据进行何种预处理或转换。

  • 解释

    • transforms.ToTensor()是PyTorch的一个转换函数,它将PIL图像或NumPy数组转换为PyTorch张量(torch.Tensor),并自动进行以下操作:

      1. 将图像的像素值从[0, 255]缩放到[0.0, 1.0](除以255)。

      2. 将图像的形状从(H, W, C)(高度、宽度、通道)转换为(C, H, W)(通道、高度、宽度)。对于MNIST,因为是灰度图像,所以通道数为1,形状从(28, 28)变为(1, 28, 28)

    • 如果不指定transform,返回的是PIL图像格式。

  • 默认值:如果不指定,返回原始数据(通常是PIL图像)。

3. 返回值

这行代码的返回值是一个torchvision.datasets.MNIST对象,可以像数据集一样使用:

  • 可以通过索引(如dataset[0])访问单个样本。

  • 可以通过len(dataset)获取数据集大小。

  • 通常用于DataLoader中批量加载数据。

4. 完整示例

python

from torchvision import datasets, transforms

# 下载并加载MNIST测试集
test_dataset = datasets.MNIST(
    root='./data',      # 数据存储目录
    train=False,        # 加载测试集
    download=True,      # 如果数据不存在则下载
    transform=transforms.ToTensor()  # 转换为张量并归一化到[0,1]
)

# 打印数据集大小
print(len(test_dataset))  # 输出: 10000

# 获取第一个样本
image, label = test_dataset[0]
print(image.shape)  # 输出: torch.Size([1, 28, 28])
print(label)        # 输出: 7(标签)

5. 其他常见参数(虽然不是这里使用的)

  • target_transform:对标签(target)进行转换的函数(类似transform对图像的作用)。

  • 某些数据集可能有额外参数(如MNIST通常没有,但其他数据集可能有split等)。

总结:这行代码的作用是从PyTorch自动下载MNIST测试集到./data目录,并将图像转换为PyTorch张量格式,方便后续用于深度学习模型的测试。


网站公告

今日签到

点亮在社区的每一天
去签到