pytorch建立线性回归神经网络
模型建立,简单回归问题
import torch.nn as nn
x_data =torch.tensor([[1.0],[2.0],[3.0]])
y_data=torch.tensor([[2.0],[4.0],[6.0]])
#重点在于构造计算图 pytorch会自动计算梯度
#Z=wx+b 就是一个线性单元
class LinearModel(nn.Module):
#Module的对象会自动实现backword()的过程
#构造函数
def __init__(self) :
super(LinearModel, self).__init__()
#Linear()构建y=wx+b,且继承于Module自动完成backword()的过程
self.layer=nn.Sequential(nn.Linear(1,20)
,nn.Linear(20,20)
,nn.Linear(20,20)
,nn.Linear(20,1))
#前馈计算的函数 必须有
def forward(self,x):
#调用linear的__call__(),在此函数中会调用forward()
y_pred=self.layer(x)
return y_pred
def train(model, optimizer, criterion, num_epochs):
losses = []
for epoch in range(num_epochs):
optimizer.zero_grad()
y_pred=model(x_data)
loss=criterion(y_pred,y_data)
loss.backward()
optimizer.step()
losses.append(loss.item())
if epoch % 300 == 0:
print(f'Epoch {epoch}, Loss: {loss.item()}')
return losses
model = LinearModel()
#调用损失函数
criterion=nn.MSELoss(size_average=False)
#优化器,lr学习率
optimizer=torch.optim.Adam(model.parameters(),lr=0.01)
losses = train(model, optimizer, criterion, num_epochs=1000)#迭代步数可增大,文章中用的是10000
# print(losses)
可视化
import matplotlib.pyplot as plt
import numpy as np
plt.figure(figsize=(3,2))
import matplotlib.pyplot as plt
plt.plot(losses)
plt.show()
plt.figure(2,figsize=(3,2))
plt.scatter(x_data,y_data)
x_new=torch.Tensor(np.arange(0,4,0.01)).reshape(-1,1)
y_new=model(x_new)
# print(y_new)
plt.plot(x_new.detach().numpy(),y_new.detach().numpy(),color='r')
plt.show()
模型保存,保存为pth结构
保存模型,仅保存网络参数
#保存模型,仅保存网络参数
torch.save(model.state_dict(), 'model_params.pth')
模型的调用,注意要有有原来的网络结构
import torch.nn as nn
class LinearModel(torch.nn.Module):
#Module的对象会自动实现backword()的过程
#构造函数
def __init__(self) :
super(LinearModel, self).__init__()
#Linear()构建y=wx+b,且继承于Module自动完成backword()的过程
self.layer=nn.Sequential(nn.Linear(1,20)
,nn.Linear(20,20)
,nn.Linear(20,20)
,nn.Linear(20,1))
#前馈计算的函数 必须有
def forward(self,x):
#调用linear的__call__(),在此函数中会调用forward()
y_pred=self.layer(x)
return y_pred
net=LinearModel()
net.load_state_dict(torch.load('model_params.pth'))
应用调用的模型, 可视化2
import matplotlib.pyplot as plt
import numpy as np
plt.figure(2,figsize=(3,2))
# plt.scatter(x_data,y_data)
x_new=torch.Tensor(np.arange(0,4,0.01)).reshape(-1,1)
y_new=net(x_new)
# print(y_new)
plt.plot(x_new.detach().numpy(),y_new.detach().numpy(),color='r')
plt.show()