【Pytorch框架】无中生有,从0到1使用Dataset类处理MNIST数据集

发布于:2024-12-07 ⋅ 阅读:(245) ⋅ 点赞:(0)

一、Pytorch下载

Pytorch框架以包的类型存在,但是又不同于其他包。


这里只介绍通过anaconda安装pytorch,因为安装并不是这篇博文的重点,详细的安装介绍可以参考 pytorch安装介绍

🔔目前pytorch框架只支持 CPU版CUDA版,而CUDA目前只有NVIDIA显卡支持,所以没有NVIDIA显卡支持的请安装CPU版。

1、首先进入 pytorch官网,往下翻找到:

在这里插入图片描述
2、选择稳定版,操作系统根据自己的来,Linux系统选择Linux,Windows系统选择Windows,package这里我们使用的是anaconda安装,所以选择conda,语言不多说哈,版本根据上面说的,有NVIDIA显卡的选择CUDA版,没有的选择CPU版。
在这里插入图片描述
3、复制下面的安装语句,比如我这里是下面这个,粘贴到anaconda的命令行,至于anaconda的哪个环境可以自己选择,回车就可以自动下载啦😆

conda install pytorch torchvision torchaudio cpuonly -c pytorch

在这里插入图片描述

二、MNIST数据集下载

💥MNIST数据集是一个经典的机器学习和计算机视觉数据集,用于手写数字识别的训练和测试,内容包含70000张手写数字的灰度图像,其中60000张用于训练,10000张用于测试。每张图像的大小为28x28像素,表示手写数字0-9。

👇下载方式一:通过pytorch下载

pytorch中内置有MNIST数据集,下载非常方便。

import torch
from torchvision import datasets, transforms

# 定义数据转换:将图像转换为张量
transform = transforms.Compose([
    transforms.ToTensor(),
])

# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform)

🔔 datasets.MNIST中root代表数据集路径,train为True表示数据集为训练数据集,否则为测试数据集,download为True表示当路径中没有MNIST数据集时,自动下载,并将数据集保存在root路径中,transform表示数据转换器。

👇下载方式二:通过百度网盘下载

😙资源来自 详解MNIST数据集下载、解析及显示的Python实现

  • 原始格式数据下载
链接:https://pan.baidu.com/s/1jAPlVKLYamJn6I63GD6HDg?pwd=azq2 
提取码:azq2 

在这里插入图片描述

  • JPEG格式数据下载
链接:https://pan.baidu.com/s/1TaL3dCHxAj17LgvSSd_eTA?pwd=xl8n 
提取码:xl8n 

在这里插入图片描述

🌟上述两种格式下载后均是一个文件,这里我将 training类型test类型 分成了两个文件夹,图片格式为JEPG格式。

链接: https://pan.baidu.com/s/1i-hXHMBq1-dWKvZXhUYoAQ?pwd=6666 
提取码: 6666 

三、自定义Dataset类处理MNIST数据集

☀️自定义Dataset类需要继承torch.utils.data中的Dataset类,并重写__getitem__方法,使用PIL包下的Image类处理图片,os包读取数据集路径

from torch.utils.data import Dataset
from PIL import Image
import os

class MyDataset(Dataset):

在MyDataset类中有两个方法:__init__方法和__getitem__方法。

  • 在__init__方法中,传入数据集路径参数。对MyDataset类进行初始化,通过os.listdir方法将 数据集路径 对应的文件处理为列表,列表中存储每一张图片的完整名称(例如:test_0_7.jpg,其中test代表属于测试集数据,若为training则代表训练集数据;0代表图片的索引;7为label,表示图片所描述的数字)。
  • 在__getitem__方法中,传入索引参数,返回索引对应图片的JpegImageFile对象及对应的label标签。使用Image.open方法可将图片路径转换为JpegImageFile格式。
# 定义MyDataset数据集处理类,继承于Dataset,重写__getitem__方法
class MyDataset(Dataset):
    def __init__(self, root_dir):
        #数据集路径
        self.root_dir = root_dir
        #通过listdir函数将数据集中的图片转化成列表,列表中存储图片的完整名称,例如 test_0_7.jpg
        self.img_paths = os.listdir(self.root_dir)

    def __getitem__(self, idx):
        #获取下标为idx的图片路径
        img_path = self.img_paths[idx]
        #os.path.join函数可将两个路径拼接起来,Image.open函数将路径对应的图片打开为JpegImageFile格式
        img = Image.open(os.path.join(self.root_dir,img_path))
        #去除名称中的后缀,获得图片的label
        img_name = img_path.split('.')[0]

        label = img_name.split('_')[-1]
        #将图片路径和label返回
        return img, label

    def __len__(self):
        return len(self.img_paths)

🌟 其中的__len__方法可以返回img_paths列表的长度。


🍁测试完整代码

from torch.utils.data import Dataset
from PIL import Image
import os

# 定义MyDataset数据集处理类,继承于Dataset,重写__getitem__方法
class MyDataset(Dataset):
    def __init__(self, root_dir):
        #数据集路径
        self.root_dir = root_dir
        #通过listdir函数将数据集中的图片转化成列表,列表中存储图片的完整名称,例如 test_0_7.jpg
        self.img_paths = os.listdir(self.root_dir)

    def __getitem__(self, idx):
        #获取下标为idx的图片路径
        img_path = self.img_paths[idx]
        #os.path.join函数可将两个路径拼接起来,Image.open函数将路径对应的图片打开为JpegImageFile格式
        img = Image.open(os.path.join(self.root_dir,img_path))
        #去除名称中的后缀,获得图片的label
        img_name = img_path.split('.')[0]

        label = img_name.split('_')[-1]
        #将图片路径和label返回
        return img, label

    def __len__(self):
        return len(self.img_paths)


if __name__ == '__main__':
    #定义数据集路径
    train_dataset_path = "./mnist-20/training"
    test_dataset_path = "./mnist-20/test"

    #创建MyDataset类的对象
    train_dataset = MyDataset(train_dataset_path)
    test_dataset = MyDataset(test_dataset_path)

	#得到__getitem__函数返回的变量
    img,label = train_dataset[0] 
    #展示图片
    img.show()

在这里插入图片描述


网站公告

今日签到

点亮在社区的每一天
去签到