【动手学深度学习】Fashion-MNIST图片分类数据集

发布于:2025-03-13 ⋅ 阅读:(19) ⋅ 点赞:(0)


1,Fashion-MNIST数据集

实现softmax回归之前,我们先来学习一下读取多分类数据集,以便能够进行更好地演示。以常用的Fashion-MNIST数据集为例。

Fashion-MNIST数据集是一个广泛使用的图像分类数据集,旨在替代经典的MNIST手写数字识别数据集。该数据集包含了来自10个类别的70,000张灰度图像,每张图像的尺寸为28x28像素。类别涵盖了各种衣物和配件,包括T恤、裤子、套头衫、裙子、外套、凉鞋、衬衫、运动鞋、包和短靴。

Fashion-MNIST数据集分为两个主要部分:

  • 训练集:包含60,000张图像,用于模型训练。
  • 测试集:包含10,000张图像,用于评估模型性能。

每个类别在训练集和测试集中都有相同数量的样本。这意味着训练集中的每个类别有6,000张图像,而测试集中的每个类别有1,000张图像。


2,Fashion-MNIST数据集下载


导入模块

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

# 表示使用svg格式显示图片
d2l.use_svg_display()

通过torchvision库将Fashion-MNIST数据集下载并读取到内存中

trans = transforms.ToTensor()  

# 加载训练数据集
mnist_train = torchvision.datasets.FashionMNIST(
    """
    root表示数据集下载和存储的位置。此处指定的是当前目录上一级目录的data文件夹下
    train=True表示是训练数据
    transform=trans表示应用前面定义的图像转换操作
    download=True表示如果数据集尚未下载到指定的root目录,则自动下载
    """
    root="../data", train=True, transform=trans, download=True)

# 加载测试数据集
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

代码解释:

transforms.ToTensor() 是 PyTorch 中的一个预处理函数,它是一个图像转换操作,用于将 PIL 图像或 NumPy ndarray 转换为 FloatTensor。此类转换是图像数据预处理的一部分,常在将数据输入神经网络之前使用。代码中,transforms.ToTensor() 做了以下几件事情:

  • 将输入图像的像素值从 [0, 255] 缩放到 [0.0, 1.0]。
  • 将图像从 HWC(高度、宽度、通道)格式转换为 CHW(通道、高度、宽度)格式,这是 PyTorch 期望的输入格式。
  • 返回一个 FloatTensor 对象。

简单来说就是:将图片转为pytorch期望的tensor类型

代码运行结果如下:

在这里插入图片描述


3,验证数据集完整性

下载结束后可验证下载的数据集是否正常。

①打开文件资源管理器查看数据集是否下载完成

在这里插入图片描述

查看训练数据集和测试数据集的大小

len(mnist_train), len(mnist_test)

输出结果如下(训练集和测试集分别包含60000和10000张图像):

在这里插入图片描述

③查看图像形状

每个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成,其通道数为1

mnist_train[0][0].shape

输出结果如下:

在这里插入图片描述


4,加载Fashion-MNIST数据集

Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。

定义get_fashion_mnist_labels函数用于在数字标签索引及文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]	

创建用于可视化样本的函数(绘制图象函数):

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  
    """
    绘制图像列表,参数解释:
    * imgs: 要显示的图像列表或数组。这些图像可以是 PIL 图像对象或 PyTorch 张量
    * num_rows: 显示的行数
    * num_cols: 显示的列数
    * titles: 可选,指定每个图像下方的标题列表。若提供,标题数应该与图像数相匹配。
    * scale: 可选参数,用于调整图像显示的大小。默认值为 1.5
    """

	# 计算图像尺寸
    figsize = (num_cols * scale, num_rows * scale)
	# 创建一个子图网格,大小为 figsize
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    # 将子图网格展平成一维数组,方便迭代
    axes = axes.flatten()
    # 迭代遍历每一对子图和图像
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
        	# matplotlib的imshow()函数的期望输入NumPy数组或者是PIL图像对象
            # 因此pytorch张量类型图片需要转为nump数组
            ax.imshow(img.numpy())
        else:
            # PIL类型图片直接展示
            ax.imshow(img)
        
        # 隐藏x轴和y轴
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        # 如果提供了标题,则设置子图标题
        if titles:
            ax.set_title(titles[i])
            
	# 返回子图数组
    return axes

展示训练数据集中前18个样本的图像及其相应的标签

# DataLoader可拿到批量数据
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))

# 打印输出数字标签索引 y
print(y)

# 批量大小18;2行,每行9个图片;title设置成转换后的文本标签
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

在这里插入图片描述

读取小批量数据

# 批量大小256
batch_size = 256

def get_dataloader_workers():  
    """使用4个进程来读取数据"""
    return 4

train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

打印一下读取训练数据所需的时间(一般需要确保读取数据比训练更快)

timer = d2l.Timer()

for X, y in train_iter:
    continue

# 打印输出读取一次数据需要的时间
f'{timer.stop():.2f} sec'

输出结果如下:

在这里插入图片描述


接下来整合所有组件,实现一个完整加载数据集的流程

定义load_data_fashion_mnist函数,加载Fashion-MNIST数据集

"""
下载Fashion-MNIST数据集,然后将其加载到内存中

参数resize表示调整图片大小
"""
def load_data_fashion_mnist(batch_size, resize=None): 
    
    # trans是一个用于转换的 *列表*
    trans = [transforms.ToTensor()]
    
    if resize:    # resize不为空,表示需要调整图片大小
        trans.insert(0, transforms.Resize(resize))
        
    trans = transforms.Compose(trans)
    
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
  • torchvision.transforms模块中,transforms.Resize(size)是一个图像预处理操作,能将输入的PIL图像或者Tensor按照指定的size大小进行调整。这里的size可以是一个表示目标宽高的二元组(width, height),也可以是一个表示等比例缩放的目标大小的整数;
  • 转换1: transforms.ToTensor():将PIL图像或numpy.ndarray转换成PyTorch的Tensor格式;
  • 转换2: 若resize参数不为空,将在列表的最前面插入transforms.Resize(resize)转换操作,根据提供的尺寸调整图像的大小;
  • 转换3: 最后,使用transforms.Compose(trans)可以将列表中的所有转换操作组合成一个单一的转换。调用后trans变量不再是一个列表,而是一个transforms.Compose对象,它封装了所有的转换操作。

演示通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。

# 批量大小32; resize=64表示图像大小(高宽)调整为64×64
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    """
    x是图像数据;y是标签
    .dtype 将显示数据的数据类型
    .shape 将显示数据的形状
    """
    print(X.shape, X.dtype, y.shape, y.dtype)
    break

运行结果如下:

在这里插入图片描述