PyTorch中flatten()函数详解以及与view()和 reshape()的对比和实战代码示例

发布于:2025-07-30 ⋅ 阅读:(18) ⋅ 点赞:(0)

在 PyTorch 中,flatten() 函数常用于将张量(tensor)展平成一维或多维结构,尤其在构建神经网络(如 CNN)时,从卷积层输出进入全连接层前经常使用它。


一、基本语法

torch.flatten(input, start_dim=0, end_dim=-1)

参数说明:

参数 说明
input 输入张量
start_dim 开始展平的维度(包含该维)
end_dim 结束展平的维度(包含该维)

展平操作会把 start_dimend_dim 之间的维度合并成一维。


二、常见示例

示例 1:基本使用

import torch

x = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])  # shape = (2, 2, 2)

out = torch.flatten(x)
print(out)
print(out.shape)  # torch.Size([8])

等价于 x.view(-1),即将所有维度展平成一维。


示例 2:保留前维度(常见于 CNN)

x = torch.randn(10, 3, 32, 32)  # 10张图片,3通道,32x32大小
out = torch.flatten(x, start_dim=1)

print(out.shape)  # torch.Size([10, 3072])

解释:

  • 展平从第 1 维开始(channel, height, width)→ 展平成一个维度
  • 第 0 维(batch size)保留,适合连接到 nn.Linear

示例 3:多维展开(指定 end_dim)

x = torch.randn(2, 3, 4, 5)  # shape = (2, 3, 4, 5)
out = torch.flatten(x, start_dim=1, end_dim=2)

print(out.shape)  # torch.Size([2, 12, 5]) -> (3*4 = 12)

三、与 .view() 的区别

函数 说明
view() 更底层、需要张量是连续的,手动指定形状
flatten() 更高层、更安全、自动处理维度合并,常用于模型构建中

四、常见用法:在模型中使用

1、示例1

import torch.nn as nn

class MyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)              # shape: (N, 16, 1, 1)
        x = torch.flatten(x, 1)       # shape: (N, 16)
        x = self.fc(x)
        return x

2、示例2

下面使用了 torch.flatten() 将卷积层的输出展平,并连接到全连接层。这个结构常见于 CNN 图像分类模型。


使用 flatten() 的 CNN 训练流程(以 CIFAR-10 为例)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ==== 1. 定义 CNN 模型,使用 flatten() ====
class FlattenCNN(nn.Module):
    def __init__(self):
        super(FlattenCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),  # 输入: [B, 3, 32, 32]
            nn.ReLU(),
            nn.MaxPool2d(2),                # 输出: [B, 16, 16, 16]

            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)                 # 输出: [B, 32, 8, 8]
        )

        self.fc = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, 10)              # CIFAR-10 共 10 类
        )

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)  # 👈 仅展平通道和空间维度,保留 batch
        x = self.fc(x)
        return x

# ==== 2. 准备数据 ====
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# ==== 3. 模型训练设置 ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FlattenCNN().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# ==== 4. 训练过程 ====
def train(model, loader, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")

# ==== 5. 开始训练 ====
train(model, train_loader, epochs=5)

重点说明

使用 torch.flatten(x, 1) 的原因:

  • 只展平通道、高、宽三维(保留 batch size)
  • 替代 x.view(x.size(0), -1) 更安全,避免非连续张量报错
  • 推荐在模型中构建更加模块化、清晰

五、三种张量展平方式:flatten()view()reshape() 的对比

下面从功能差异使用限制和**性能对比(benchmark)**进行三者的比较。


1、三者功能对比

函数 特点说明
flatten() 高级 API,自动处理维度合并,不要求张量连续。推荐模型中使用。
view() 底层操作,速度快,但要求张量是连续(tensor.is_contiguous()True
reshape() 更灵活,如果张量不连续,会自动复制为连续版本。性能略慢但更安全

2、代码功能对比

x = torch.randn(32, 3, 64, 64)  # batch of images

# flatten
f1 = torch.flatten(x, 1)

# view
f2 = x.view(32, -1)

# reshape
f3 = x.reshape(32, -1)

print(f1.shape, f2.shape, f3.shape)

输出一致:torch.Size([32, 12288])


3、非连续张量对比(view 会报错)

x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1)  # 非连续张量

try:
    y.view(-1)  # 会报错
except RuntimeError as e:
    print("view error:", e)

print("reshape:", y.reshape(-1).shape)   # reshape 正常
print("flatten:", torch.flatten(y).shape)  # flatten 正常

4、性能测试(benchmark)

import torch
import time

x = torch.randn(1024, 512, 28, 28)

# 保证是连续的
x_contig = x.contiguous()

N = 1000

def benchmark(op, name):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(N):
        _ = op(x_contig)
    torch.cuda.synchronize()
    end = time.time()
    print(f"{name}: {(end - start)*1000:.2f} ms")

benchmark(lambda x: torch.flatten(x, 1), "flatten()")
benchmark(lambda x: x.view(x.size(0), -1), "view()")
benchmark(lambda x: x.reshape(x.size(0), -1), "reshape()")

示例结果(A100 GPU):

flatten(): 58.12 ms
view():    41.76 ms
reshape(): 47.32 ms

总结view()最快,但要求张量连续;flatten()最安全但稍慢;reshape()是折中方案。


5、 建议总结

场景 推荐方式 原因
模型中展平 CNN 输出 flatten() 简洁、安全,尤其在复杂网络中
确保连续张量、追求速度 view() 性能最佳
张量可能非连续 reshape() 自动处理不连续情况,代码更鲁棒

六、小结

用法 效果
torch.flatten(x) 将所有维展平成一维
torch.flatten(x, 1) 保留 batch 维,常用于 CNN
torch.flatten(x, 1, 2) 展平指定维度区间


网站公告

今日签到

点亮在社区的每一天
去签到