python学习打卡day38

发布于:2025-05-30 ⋅ 阅读:(21) ⋅ 点赞:(0)
DAY 38 Dataset和Dataloader类

对应5. 27作业

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

作业:了解下cifar数据集,尝试获取其中一张图片

MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练。

 导入Dataset类和Dataloader类必要的库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt

# 设置随机种子,确保结果可复现
torch.manual_seed(42)

其中torchvision是一个计算机视觉的库,它的常见方法如下:

torchvision

  1. datasets       # 视觉数据集(如 MNIST、CIFAR)
  2.  transforms     # 视觉数据预处理(如裁剪、翻转、归一化)
  3. models         # 预训练模型(如 ResNet、YOLO)
  4. utils          # 视觉工具函数(如目标检测后处理)
  5. io             # 图像/视频 IO 操作

 1.其中的transforms模块提供了一系列常用的图像预处理操作:

# 先归一化,再标准化
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量并归一化到[0,1]
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])

2.MNIST数据集

# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

随机取出一张图片(包括图片和标签) 

import matplotlib.pyplot as plt

# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签

可视化取出的图像 

# 可视化原始图像(需要反归一化)
def imshow(img):
    img = img * 0.3081 + 0.1307  # 反标准化
    npimg = img.numpy()
    plt.imshow(npimg[0], cmap='gray') # 显示灰度图像
    plt.show()

print(f"Label: {label}")
imshow(image)

当然也可以用相同的思路取出两张:

import matplotlib.pyplot as plt
import torch

# 随机选择两张图片的索引
sample_idx_1 = torch.randint(0, len(train_dataset), size=(1,)).item()
sample_idx_2 = torch.randint(0, len(train_dataset), size=(1,)).item()

# 获取图片和标签
image_1, label_1 = train_dataset[sample_idx_1]
image_2, label_2 = train_dataset[sample_idx_2]

# 定义一个函数来反归一化并显示图像
def imshow(img):
    img = img * 0.3081 + 0.1307  # 反标准化
    npimg = img.numpy()
    plt.imshow(npimg[0], cmap='gray')

# 创建一个包含两个子图的画布
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# 显示第一张图片
plt.sca(axes[0])
imshow(image_1)
axes[0].set_title(f'Label: {label_1}')
axes[0].axis('off')

# 显示第二张图片
plt.sca(axes[1])
imshow(image_2)
axes[1].set_title(f'Label: {label_2}')
axes[1].axis('off')

plt.show()

我们是如何通过dataset类取出图像的呢?? 

PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

- __len__():返回数据集的样本总数。

- __getitem__(idx):根据索引idx返回对应样本的数据和标签。

PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。

__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。

# 示例代码
class MyList:
    def __init__(self):
        self.data = [10, 20, 30, 40, 50]

    def __getitem__(self, idx):
        return self.data[idx]

# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30

__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。

class MyList:
    def __init__(self):
        self.data = [10, 20, 30, 40, 50]

    def __len__(self):
        return len(self.data)

# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj))  # 输出:5

再介绍一下Dataloader类

# 3. 创建数据加载器
train_loader = DataLoader(
    train_dataset,
    batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关
    shuffle=True # 随机打乱数据
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1000 # 每个批次1000张图片
    # shuffle=False # 测试时不需要打乱数据
)

作业: 

维度 CIFAR 数据集 MNIST 手写数据集
创建机构 / 背景 由加拿大先进研究所(CIFAR)开发,用于计算机视觉研究 由纽约大学柯朗数学科学研究所开发,用于手写数字识别研究
数据类型 自然物体彩色图像(如动物、交通工具等) 手写数字灰度图像(0-9)
图像分辨率 32×32 像素(RGB 三通道,彩色图像) 28×28 像素(单通道,灰度图像)
数据集规模 - 总样本数:60,000 张
- 训练集:50,000 张
- 测试集:10,000 张
- 总样本数:70,000 张
- 训练集:60,000 张
- 测试集:10,000 张
类别数量 - CIFAR-10:10 个大类
- CIFAR-100:100 个细分类别(20 个超类)
10 个类别(数字 0-9)
任务难度 - 图像分辨率低但包含复杂背景和类内差异
- CIFAR-100 因类别多、区分度小,难度更高
图像背景简单,数字形态相对固定,难度较低
典型应用场景 图像分类、目标识别、深度学习算法基准测试(如 CNN 优化) 手写数字识别、基础算法验证(如神经网络入门案例)
数据预处理 需进行色彩归一化、数据增强(如裁剪、翻转)等处理 通常仅需灰度归一化和简单降噪处理
模型性能基准 - CIFAR-10 顶尖模型准确率:~97%
- CIFAR-100 顶尖模型准确率:~87%
顶尖模型准确率:~99.7%(如 CNN)
相似点 - 均为图像分类领域经典基准数据集
- 均包含训练集和测试集,结构标准化
- 广泛用于算法教学、研究和性能对比
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt

# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为Tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化处理,将像素值从[0,1]缩放到[-1,1]
])

 

# 加载训练集
train_dataset = datasets.CIFAR10(
    root='./data',  # 数据存放路径
    train=True,  # 是否为训练集
    download=True,  # 如果数据不存在,是否自动下载
    transform=transform  # 数据预处理
)

# 加载测试集
test_dataset = datasets.CIFAR10(
    root='./data',  # 数据存放路径
    train=False,  # 是否为测试集
    transform=transform  # 数据预处理
)
import matplotlib.pyplot as plt
# 类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签
# 可视化原始图像(需要反归一化)
def imshow(img, title=None):
    img = img / 2 + 0.5  # 反归一化:将[-1,1]范围转回[0,1]
    npimg = img.numpy()
    plt.figure(figsize=(4, 4))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 调整通道顺序:从[C,H,W]到[H,W,C]
    if title:
        plt.title(title)
    plt.axis('off')
    plt.show()

print(f"Label: {label} ({classes[label]})")
imshow(image, f"Label: {classes[label]}")

@浙大疏锦行