在 PyTorch 中,Flatten
操作是将多维张量转换为一维向量的重要操作,常用于卷积神经网络(CNN)的全连接层之前。以下是 PyTorch 中实现 Flatten 的各种方法及其应用场景。
一、基本 Flatten 方法
1. 使用 torch.flatten()
函数
import torch
# 创建一个4D张量 (batch_size, channels, height, width)
x = torch.randn(32, 3, 28, 28) # 32张28x28的RGB图像
# 展平整个张量
flattened = torch.flatten(x) # 输出形状: [75264] (32*3*28*28)
# 从指定维度开始展平
flattened = torch.flatten(x, start_dim=1) # 输出形状: [32, 2352] (保持batch维度)
2. 使用 nn.Flatten
层
import torch.nn as nn
flatten = nn.Flatten() # 默认从第1维开始展平(保持batch维度)
x = torch.randn(32, 3, 28, 28)
output = flatten(x) # 输出形状: [32, 2352]
可以指定开始和结束维度:
flatten = nn.Flatten(start_dim=1, end_dim=2)
x = torch.randn(32, 3, 28, 28)
output = flatten(x) # 输出形状: [32, 84, 28] (合并了第1和2维)
二、不同场景下的 Flatten 应用
1. CNN 中的典型用法
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv_layers = nn.Sequential(
nn.Conv2d(1, 16, 3),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(16, 32, 3),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.flatten = nn.Flatten()
self.fc = nn.Linear(32 * 5 * 5, 10) # 计算展平后的尺寸
def forward(self, x):
x = self.conv_layers(x)
x = self.flatten(x) # 形状从 [B, 32, 5, 5] 变为 [B, 800]
x = self.fc(x)
return x
2. 手动计算展平后的尺寸
# 计算卷积层输出尺寸的辅助函数
def conv_output_size(input_size, kernel_size, stride=1, padding=0):
return (input_size - kernel_size + 2 * padding) // stride + 1
# 计算经过多层卷积和池化后的尺寸
h, w = 28, 28 # 输入尺寸
h = conv_output_size(h, 3) # conv1: 26
w = conv_output_size(w, 3) # conv1: 26
h = conv_output_size(h, 2, 2) # pool1: 13
w = conv_output_size(w, 2, 2) # pool1: 13
h = conv_output_size(h, 3) # conv2: 11
w = conv_output_size(w, 3) # conv2: 11
h = conv_output_size(h, 2, 2) # pool2: 5
w = conv_output_size(w, 2, 2) # pool2: 5
print(f"展平后的特征数: {32 * h * w}") # 32 * 5 * 5 = 800
三、高级用法
1. 部分展平
# 只展平图像空间维度,保留通道维度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(start_dim=2) # 形状: [32, 3, 784]
2. 自定义 Flatten 层
class ChannelLastFlatten(nn.Module):
"""将通道维度移到最后的展平层"""
def forward(self, x):
# 输入形状: [B, C, H, W]
x = x.permute(0, 2, 3, 1) # [B, H, W, C]
return x.reshape(x.size(0), -1) # [B, H*W*C]
3. 展平特定维度
# 展平批量维度和通道维度
x = torch.randn(32, 3, 28, 28)
flattened = x.flatten(end_dim=1) # 形状: [96, 28, 28] (32*3=96)
四、注意事项
维度计算:确保展平后的尺寸与全连接层的输入尺寸匹配
批量维度:通常保留第0维(batch维度)不被展平
内存连续性:
view()
需要连续内存,必要时先调用contiguous()
替代方法:
x.view(x.size(0), -1)
是flatten(start_dim=1)
的常见替代写法
五、性能比较
方法 | 优点 | 缺点 |
---|---|---|
torch.flatten() |
官方推荐,可读性好 | 无 |
nn.Flatten() |
可作为网络层使用 | 需要实例化对象 |
x.view() |
最简洁 | 需要手动计算尺寸 |
x.reshape() |
自动处理内存连续性 | 性能略低于view |
六、示例代码
import torch
import torch.nn as nn
# 定义一个包含Flatten的完整模型
class ImageClassifier(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.flatten = nn.Flatten()
self.classifier = nn.Sequential(
nn.Linear(256 * 4 * 4, 1024), # 假设输入图像是32x32
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(1024, 10)
)
def forward(self, x):
x = self.features(x)
x = self.flatten(x)
x = self.classifier(x)
return x
# 使用示例
model = ImageClassifier()
input_tensor = torch.randn(16, 3, 32, 32) # batch=16, 3通道, 32x32图像
output = model(input_tensor)
print(output.shape) # 输出形状: [16, 10]