# 代码 (Code): softmax regression implemented from scratch, trained on Fashion-MNIST
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils import data
# Fashion-MNIST train/test splits, downloaded to ./data on first use.
# ToTensor() turns each image into a float tensor with values scaled to [0, 1].
mnist_train = torchvision.datasets.FashionMNIST(
    root="./data", train=True, download=True, transform=transforms.ToTensor()
)
mnist_test = torchvision.datasets.FashionMNIST(
    root="./data", train=False, download=True, transform=transforms.ToTensor()
)

batch_size = 256
# Shuffle the training data so each epoch sees batches in a different order.
train_loader = data.DataLoader(mnist_train, batch_size, shuffle=True)
# Test loss/accuracy do not depend on sample order, so shuffling the test set
# is wasted work; iterate it in a fixed order instead.
test_loader = data.DataLoader(mnist_test, batch_size, shuffle=False)

# Each image is flattened to a 784-element vector (28 * 28) and classified
# into one of 10 classes.
num_inputs = 784
num_outputs = 10
def softmax(X):
    """Normalise each row of ``X`` into a probability distribution.

    The reduction runs along ``dim=1`` with ``keepdim=True``, so ``X`` must be
    at least 2-D; in this file it receives the ``(batch, num_outputs)`` score
    matrix produced by ``net``.

    The per-row maximum is subtracted before exponentiating. Softmax is
    invariant to adding a constant to every entry of a row, so the result is
    unchanged. The shift keeps ``exp`` from overflowing to ``inf`` on large
    scores, which would otherwise give ``inf / inf = nan``.

    Args:
        X: tensor of unnormalised scores.

    Returns:
        Tensor with the same shape as ``X`` whose rows sum to 1.
    """
    # Largest entry in each row becomes 0, so every exponent is <= 0 and exp() <= 1.
    X_shifted = X - X.max(dim=1, keepdim=True).values
    X_exp = X_shifted.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition
def net(X):
    """Forward pass of the softmax-regression model.

    Flattens ``X`` to shape ``(-1, num_inputs)`` and applies the affine map
    with the module-level parameters ``W`` and ``b``. The resulting scores
    are normalised with ``softmax``.
    """
    flat = X.view((-1, num_inputs))
    logits = flat @ W + b
    return softmax(logits)
def cross_entropy(y_hat, y):
    """Per-sample negative log-probability assigned to the true label.

    Args:
        y_hat: predicted probabilities, one row per sample.
        y: integer class indices, one per row of ``y_hat``.

    Returns:
        Column tensor of shape ``(n, 1)`` holding ``-log(y_hat[i, y[i]])``.
    """
    # Pick, for every row, the probability at that row's label index.
    picked = torch.gather(y_hat, 1, y.view(-1, 1))
    return -picked.log()
def sgd(params, lr, batch_size):
    """Apply one in-place minibatch SGD step to each tensor in ``params``.

    Each parameter is decreased by ``lr * grad / batch_size``. The update goes
    through ``.data``, so autograd does not record it. ``.grad`` is left
    untouched; the caller is responsible for zeroing it.
    """
    for p in params:
        step = lr * p.grad / batch_size
        p.data.sub_(step)
def accuracy(y_hat, y):
    """Return the fraction of rows whose arg-max column equals the label, as a Python float."""
    predicted = torch.argmax(y_hat, dim=1)
    hits = predicted.eq(y).float()
    return hits.mean().item()
# Model parameters: weights sampled from a normal distribution with mean 0 and
# standard deviation 0.01, biases initialised to zero. Both are created with
# requires_grad=True so backward() populates their .grad fields.
W = torch.tensor(
    np.random.normal(0, 0.01, (num_inputs, num_outputs)),
    dtype=torch.float,
    requires_grad=True,
)
b = torch.zeros(num_outputs, dtype=torch.float, requires_grad=True)
# Training configuration: 10 epochs of minibatch SGD with learning rate 0.1.
num_epochs, lr = 10, 0.1
loss = cross_entropy  # returns one loss value per sample, shape (n, 1)
optimizer = sgd

# NOTE(review): the original listing lost its indentation; here the test-set
# evaluation runs inside the epoch loop, so test metrics are printed every epoch.
for epoch in range(1, 1 + num_epochs):
    # ---- one pass over the shuffled training set ----
    total_loss = 0.0
    train_sample = 0
    train_acc_sum = 0
    for x, y in train_loader:
        y_hat = net(x)
        batch_loss = loss(y_hat, y).sum()
        batch_loss.backward()
        # Normalise by the size of *this* batch rather than the nominal
        # batch_size: the last batch of an epoch is smaller when the dataset
        # size is not a multiple of batch_size, and dividing by batch_size
        # there would silently shrink its update.
        n = y.shape[0]
        optimizer([W, b], lr, n)
        # backward() accumulates into .grad, so clear it before the next batch.
        W.grad.zero_()
        b.grad.zero_()
        total_loss += batch_loss.item()
        train_sample += n
        train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
    print('epoch %d, loss %.4f, train acc %.3f' % (epoch, total_loss / train_sample, train_acc_sum / train_sample,))

    # ---- evaluation on the test set; no gradients are needed ----
    with torch.no_grad():
        test_loss = 0.0
        test_sample = 0
        test_acc_sum = 0
        for x, y in test_loader:
            y_hat = net(x)
            test_loss += loss(y_hat, y).sum().item()
            test_sample += y.shape[0]
            test_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
    print('loss %.4f, test acc %.3f' % (test_loss / test_sample, test_acc_sum / test_sample,))
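# 结果 (Results): console output of the training script, shown as a screenshot:
# ![Console output of the training script](https://img-blog.csdnimg.cn/direct/8b46a6c945534245bf1ca91eff8dd2db.png)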