pytorch学习笔记6

发布于:2024-06-16 ⋅ 阅读:(73) ⋅ 点赞:(0)

想要找一些官方的小工具数据集,可以进入pytorch官网,DOCS-》pytorch下拉至libraries,点击torchversion,调整版本至0.9.0就可以找到相应的一些数据集,训练集
ctrl+p可以看一个函数中需要设置哪些参数

下载数据集可以参考官方文档中的描述对数据集进行下载
在这里插入图片描述

import torchvision

train_set=torchvision.datasets.CIFAR10(root='./dataset',train=True,download=True)
#root表示数据集存放在那个位置./表示当前目录
#train如果为True,则从训练集创建数据集,否则从测试集创建数据集
#download如果为true,则从互联网下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。
#这些参数都可以从官方文档中获取
test_set=torchvision.datasets.CIFAR10(root='./dataset',train=False,download=True)

下载慢的话可以用迅雷下载
在这里插入图片描述
使用这个地址
在这里插入图片描述
<PIL.Image.Image image mode=RGB size=32x32 at 0x179A0B20A90> 表示这个样本是一个 32x32 像素的 RGB 彩色图像,使用 PIL 库表示。
3 是这个图像对应的标签,即这个图像所代表的物体类别在 CIFAR-10 数据集中的索引(CIFAR-10 数据集共有 10 个类别,索引从 0 到 9)

torchvision.transforms.Compose 是一个方便的工具,用于将多个图像变换操作组合成一个单一的变换。在图像处理和深度学习模型的训练过程中,通常需要对图像进行一系列的预处理操作,例如裁剪、缩放、归一化等。Compose 允许你将这些操作串联起来,使其按顺序应用于每个输入图像。

import torchvision.transforms as transforms

dataset_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

如下进行

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set=torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=dataset_transform,download=True)
print(test_set[0])

加上transform的流程对数据进行处理
这样批量将导入的数据进行处理

在图像处理中和深度学习模型的训练过程中,Normalize 变换的主要作用是对图像数据进行标准化处理。具体来说,Normalize 会调整图像的像素值,使其符合某个特定的分布。这通常有助于加速模型的训练过程并提高模型的性能。下面是 Normalize 的作用和原理的详细解释。

作用
加速模型收敛:通过将输入数据标准化,可以使模型的梯度更加稳定,避免某些特征对模型训练造成过大的影响,从而加速模型的收敛速度。
提高模型性能:标准化可以使不同特征的数据分布更加一致,有助于模型更好地理解和学习数据的特征,提高模型的性能。
防止梯度消失和梯度爆炸:标准化可以将输入数据的范围限制在一个较小的范围内,防止梯度在传播过程中变得过大或过小,稳定模型的训练过程。

transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

对于红色通道,减去均值 0.485,再除以标准差 0.229。
对于绿色通道,减去均值 0.456,再除以标准差 0.224。
对于蓝色通道,减去均值 0.406,再除以标准差 0.225。

ctrl+/可以快速对多行进行注释