**前置知识:
train_set=torchvision.datasets.CIFAR10(
root="./dataset",
train=True,
transform=dataset_trans,
download=True
)
1、torchvision.datasets.CIFAR10:PyTorch 中的一个类,用于加载 CIFAR-10 数据集
(CIFAR-10 是一个包含 60,000 张 32x32 彩色图像的数据库,分为 10 个类别)
2、
root="./dataset" | 数据集存储路径 |
train=True | True说明是训练集,False说明是验证集 |
transform=dataset_trans | 数据的预处理(可以用Compose) |
download=True | 如果指定路径下没有数据集,设置为True则会自动下载数据集 |
运行代码,在python控制台可以看到下载链接,
复制下载链接,让迅雷来下载更快速
3、这里的test_set元素类型是(PIL.Image,target),即一个数据对象=图片+标签
print(test_set[0]) #(<PIL.Image.Image image mode=RGB size=32x32 at 0x2A3F0419A30>, 6)
target是分类的编号,共有十个类别,则target从0到9
**代码:
import torchvision #与视觉图像有关,transforms和dataset都是它的模块
from torch.utils.tensorboard import SummaryWriter
dataset_trans=torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
#引入torchvision里的经典数据集(root:存储路径,train:是训练集还是验证集,download:是否需要下载)
train_set=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_trans,download=True)
test_set=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_trans,download=True)
# print(test_set[0]) #(<PIL.Image.Image image mode=RGB size=32x32 at 0x2A3F0419A30>, 6)
# img,target=test_set[0]
# print(img) #一个数据对象=图片+标签
# print(target)
#
# print(test_set.classes) #classes表示有哪些类型,标签与其对应(若标签是6,则表明是6号类型,即frog)
# print(test_set.classes[target])
# img.show()
writer=SummaryWriter("logs")
for i in range(10):
img,target=test_set[i]
writer.add_image("images",img,i)
writer.close()