【MNIST手写数据集数据图像可视化 (Pytorch)】

发布于:2023-01-23 ⋅ 阅读:(437) ⋅ 点赞:(0)

MNIST手写数据集数据处理 (Pytorch)

内容概括:

基于pytorch框架,torchvision.datasets.MNIST导入数据集,对图片数据进行预处理以及使用matplotlib可视化。

import 文件

import torch
import torchvision
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"#和python环境有关的一种设置,使程序继续运行
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler #优化学习率
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt

参数设置

#parameters:
lr=0.008
train_batch_size=64 
train_epochs=16
test_batch_size=1000
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

导入数据

train_dataset=torchvision.datasets.MNIST('./data/',train=True,download=True,
                                                 transform=torchvision.transforms.Compose([
                                                     torchvision.transforms.RandomHorizontalFlip(p=0.3),
                                                     torchvision.transforms.RandomVerticalFlip(p=0.3),
                                                     torchvision.transforms.ToTensor(),
                                                     torchvision.transforms.Normalize((0.1307,), (0.3081,))]))
test_dataset=torchvision.datasets.MNIST('./data/',train=False,download=True,
                                                 transform=torchvision.transforms.Compose([
                                                     torchvision.transforms.ToTensor(),
                                                     torchvision.transforms.Normalize((0.1307,), (0.3081,))]))

可视化

#查看图片
example_images,example_targets=(test_dataset._load_data())
fig = plt.figure()#生成图框
for i,c in enumerate(np.random.randint(0,1000,6)):#随机取0,1000里的6张图片
    plt.subplot(2,3,i+1) #i+1表示当前图片摆放位置,位置从1开始(故i要加1)
    plt.tight_layout()#自动调整间距
    plt.imshow(example_images[c], cmap='gray', interpolation='none')#加灰度
    plt.title("Ground Truth: {}".format(example_targets[c]))
    plt.xticks([])#x轴坐标设为空
    plt.yticks([])#y轴坐标设为空
plt.show()

注意这里为什么使用 test_dataset._load_data()
根据pytorch官方文档关于 torchvision.datasets.mnist 的解释
pytorch官网torchvision.datasets.mnist
下方_load_data 返回 data 和 target 正是我们所需得到的图像和标记
MNIST_vis_loaddata.png
可视化结果:
加灰度(cmap='gray’参数)的图像: MNIST_vis_gray.png

不加灰度的图像:

MNIST_vis_none.png


网站公告

今日签到

点亮在社区的每一天
去签到

热门文章