题目
手写数字识别的卷积神经网络(CNN)代码,实现前向传播
解答
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
# super(Net, self).__init__()
super().__init__()
self.model = nn.Sequential(
# The size of the picture is 28x28
nn.Conv2d(in_channels = 1,out_channels = 16,kernel_size = 3,stride = 1,padding = 1),
nn.ReLU(),
nn.MaxPool2d(kernel_size = 2,stride = 2),
# The size of the picture is 14x14
nn.Conv2d(in_channels = 16,out_channels = 32,kernel_size = 3,stride = 1,padding = 1),
nn.ReLU(),
nn.MaxPool2d(kernel_size = 2,stride = 2),
# The size of the picture is 7x7
nn.Conv2d(in_channels = 32,out_channels = 64,kernel_size = 3,stride = 1,padding = 1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(in_features = 7 * 7 * 64,out_features = 128),
nn.ReLU(),
nn.Linear(in_features = 128,out_features = 10),
nn.Softmax(dim=1)
)
def forward(self,input):
output = self.model(input)
return output
net = Net()
# 将模型转换到device中,并将其结构显示出来
# print(net.to(device))
trainImgs = torch.Tensor(32, 1, 28, 28) # [B, C, H, W]
outputs = net(trainImgs)
print(outputs.shape) # torch.Size([32, 10])
注意
在 Python 中,super(Net, self).__init__()或
super().__init__()
的作用是调用父类的构造函数,确保子类 Net
继承自父类(如 torch.nn.Module
)的属性和方法被正确初始化。
1. 代码含义
super()
:返回父类的代理对象,用于调用父类的方法。Net
:当前子类的名称。self
:当前子类的实例对象。__init__()
:父类的构造函数方法。
组合起来:
调用 Net
的父类(例如 torch.nn.Module
)的 __init__()
方法,确保父类的初始化逻辑被执行。
2. 为什么需要这行代码?
继承父类功能:
在 PyTorch 中,自定义神经网络模型必须继承torch.nn.Module
。
父类Module
内部定义了模型的核心机制(如参数管理、GPU 转换等)。
如果不调用父类的__init__()
,这些功能将无法正确初始化。避免潜在错误:
如果省略这行代码,子类Net
将无法使用Module
的功能,导致以下问题:模型参数(如
Conv2d
的权重)不会被识别和优化。无法将模型移动到 GPU(
.to(device)
)。无法正确保存或加载模型(
torch.save
/torch.load
)。
3.在 PyTorch 中的具体作用
在 PyTorch 模型中,父类 torch.nn.Module
的 __init__()
会做以下关键操作:
注册参数(Parameters)和子模块(Submodules):
将self.conv1
、self.linear
等子层添加到模型的参数列表中,优化器(如torch.optim.SGD
)才能找到并更新这些参数。设备管理:
跟踪模型所在的设备(CPU/GPU),确保输入数据和模型参数在同一设备上。模型序列化:
支持模型的保存(torch.save
)和加载(torch.load
)。