使用torchvision中的经典数据集

发布于:2024-10-10 ⋅ 阅读:(14) ⋅ 点赞:(0)

**前置知识:


train_set=torchvision.datasets.CIFAR10(
root="./dataset",
train=True,
transform=dataset_trans,
download=True
)

1、torchvision.datasets.CIFAR10:PyTorch 中的一个类,用于加载 CIFAR-10 数据集

(CIFAR-10 是一个包含 60,000 张 32x32 彩色图像的数据库,分为 10 个类别)

2、

root="./dataset" 数据集存储路径
train=True True说明是训练集,False说明是验证集
transform=dataset_trans 数据的预处理(可以用Compose)
download=True 如果指定路径下没有数据集,设置为True则会自动下载数据集

 运行代码,在python控制台可以看到下载链接,

复制下载链接,让迅雷来下载更快速

3、这里的test_set元素类型是(PIL.Image,target),即一个数据对象=图片+标签

 print(test_set[0])  #(<PIL.Image.Image image mode=RGB size=32x32 at 0x2A3F0419A30>, 6)

 target是分类的编号,共有十个类别,则target从0到9

**代码:

import torchvision #与视觉图像有关,transforms和dataset都是它的模块
from torch.utils.tensorboard import SummaryWriter

dataset_trans=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

#引入torchvision里的经典数据集(root:存储路径,train:是训练集还是验证集,download:是否需要下载)
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_trans,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_trans,download=True)


# print(test_set[0]) #(<PIL.Image.Image image mode=RGB size=32x32 at 0x2A3F0419A30>, 6)
# img,target=test_set[0]
# print(img) #一个数据对象=图片+标签
# print(target)
#
# print(test_set.classes) #classes表示有哪些类型,标签与其对应(若标签是6,则表明是6号类型,即frog)
# print(test_set.classes[target])
# img.show()


writer=SummaryWriter("logs")
for i in range(10):
    img,target=test_set[i]
    writer.add_image("images",img,i)
writer.close()