MNIST手写数据集数据处理 (Pytorch)
内容概括:
基于pytorch框架,torchvision.datasets.MNIST导入数据集,对图片数据进行预处理以及使用matplotlib可视化。
import 文件
import torch
import torchvision
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"#和python环境有关的一种设置,使程序继续运行
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim import lr_scheduler #优化学习率
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt
参数设置
#parameters:
lr=0.008
train_batch_size=64
train_epochs=16
test_batch_size=1000
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
导入数据
train_dataset=torchvision.datasets.MNIST('./data/',train=True,download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(p=0.3),
torchvision.transforms.RandomVerticalFlip(p=0.3),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))]))
test_dataset=torchvision.datasets.MNIST('./data/',train=False,download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))]))
可视化
#查看图片
example_images,example_targets=(test_dataset._load_data())
fig = plt.figure()#生成图框
for i,c in enumerate(np.random.randint(0,1000,6)):#随机取0,1000里的6张图片
plt.subplot(2,3,i+1) #i+1表示当前图片摆放位置,位置从1开始(故i要加1)
plt.tight_layout()#自动调整间距
plt.imshow(example_images[c], cmap='gray', interpolation='none')#加灰度
plt.title("Ground Truth: {}".format(example_targets[c]))
plt.xticks([])#x轴坐标设为空
plt.yticks([])#y轴坐标设为空
plt.show()
注意这里为什么使用 test_dataset._load_data()
根据pytorch官方文档关于 torchvision.datasets.mnist 的解释
pytorch官网torchvision.datasets.mnist
下方_load_data 返回 data 和 target 正是我们所需得到的图像和标记
可视化结果:
加灰度(cmap='gray’参数)的图像:
不加灰度的图像: