在 PyTorch 中,view()
是一个非常常用的张量(Tensor)操作函数,用于 改变张量的形状(shape),但 不会改变其数据内容。你可以把它看作是 PyTorch 中的类似 NumPy 中 reshape()
的方法。
一、基本语法
tensor.view(shape)
shape
:目标张量的形状,可以是多个整数参数,也可以是一个 tuple。- 其中某个维度可以设为
-1
,PyTorch 会根据总元素数自动推断这个维度的大小。
二、注意事项
view()
要求 原张量是连续的内存(contiguous),否则需先.contiguous()
;view()
改变的是张量的“视图”,不拷贝数据,效率高;view()
返回的是一个新的张量,原张量不变;-1
只能出现一次,用于自动计算维度。
三、代码示例
示例 1:基本使用
import torch
a = torch.arange(12) # 创建一个 1D 张量 [0, 1, ..., 11]
print(a.shape) # torch.Size([12])
b = a.view(3, 4) # 变成 3x4 的二维张量
print(b)
输出:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
示例 2:使用 -1
自动推断维度
a = torch.arange(12)
b = a.view(3, -1) # 自动计算出第二个维度为 4
print(b.shape) # torch.Size([3, 4])
示例 3:使用 view()
时不连续的内存
a = torch.randn(4, 3)
b = a.transpose(0, 1) # 维度交换后,b 不是连续内存
try:
b.view(-1) # 报错
except RuntimeError as e:
print("错误:", e)
# 正确做法:先调用 contiguous()
b = b.contiguous().view(-1)
示例 4:结合 batch 维度 reshape 图像
img = torch.randn(16, 3, 32, 32) # 假设是一个 batch 的图像 (N=16, C=3, H=32, W=32)
flat = img.view(16, -1) # 展平成向量,每个图像是 3072 维
print(flat.shape) # torch.Size([16, 3072])
四、view() vs reshape()
a.view(3, 4)
a.reshape(3, 4)
相同点:功能类似,都能改变形状;
不同点:
view()
要求连续内存;reshape()
更灵活,自动处理非连续情况,内部可能会拷贝数据。
推荐:如果你确信内存连续,view()
更高效;否则用 reshape()
更稳妥。
五、工程实战应用示例
在使用 CNN(卷积神经网络)或 RNN(循环神经网络)时,view()
函数常被用于 调整张量的形状 以满足网络结构输入/输出的要求。以下是典型场景中的用法分析和代码示例:
1、CNN 中 view()
的常见用途
场景:Flatten 卷积层输出,接入全连接层(Linear)
卷积层输出通常是一个 4D 张量 (batch_size, channels, height, width)
,在接入 nn.Linear
全连接层前需要展平为 2D,即 (batch_size, features)
。
示例:
import torch
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 16, 3), # 输入:1x28x28,输出:16x26x26
nn.ReLU(),
nn.MaxPool2d(2) # 输出:16x13x13
)
self.fc = nn.Linear(16 * 13 * 13, 10)
def forward(self, x):
x = self.conv(x) # shape: (batch, 16, 13, 13)
x = x.view(x.size(0), -1) # shape: (batch, 2704)
x = self.fc(x)
return x
x = torch.randn(8, 1, 28, 28) # batch of 8 images
model = CNN()
output = model(x)
print(output.shape) # torch.Size([8, 10])
x.view(x.size(0), -1)
保证了展平时 batch size 不变,特征自动推断。
2、RNN 中 view()
的常见用途
场景 1:将 RNN 的输出展平后送入全连接层
RNN 输出形状通常是 (seq_len, batch_size, hidden_size)
,有时需要 reshape 为 (batch_size * seq_len, hidden_size)
再处理。
示例:
class RNNClassifier(nn.Module):
def __init__(self):
super().__init__()
self.rnn = nn.RNN(input_size=10, hidden_size=20)
self.fc = nn.Linear(20, 5)
def forward(self, x):
out, _ = self.rnn(x) # out: (seq_len, batch_size, 20)
out = out.view(-1, 20) # flatten 所有时间步,shape: (seq_len * batch_size, 20)
out = self.fc(out) # shape: (seq_len * batch_size, 5)
return out
x = torch.randn(15, 4, 10) # seq_len=15, batch_size=4, input_size=10
model = RNNClassifier()
output = model(x)
print(output.shape) # torch.Size([60, 5])
场景 2:将嵌套序列 batch + time 展平为单个输入批
有时会将输入 (batch, seq_len, input_dim)
reshape 为 (batch * seq_len, input_dim)
以方便线性层处理:
x = torch.randn(32, 10, 100) # batch_size=32, seq_len=10, input_dim=100
x = x.view(-1, 100) # -> (320, 100)
3、CNN + RNN 混合模型中的 view()
用法
比如:图像经过 CNN 得到特征,再按时间顺序输入 RNN。这里就需要在 CNN 输出后 reshape 成 RNN 接受的格式 (seq_len, batch, input_size)
。
示例:
class CNN_RNN(nn.Module):
def __init__(self):
super().__init__()
self.cnn = nn.Sequential(
nn.Conv2d(1, 8, kernel_size=3), # 1x28x28 -> 8x26x26
nn.ReLU(),
nn.MaxPool2d(2) # -> 8x13x13
)
self.rnn = nn.GRU(13 * 13, 64)
self.fc = nn.Linear(64, 10)
def forward(self, x):
batch_size, time_steps, C, H, W = x.shape # e.g., (16, 5, 1, 28, 28)
x = x.view(batch_size * time_steps, C, H, W)
x = self.cnn(x) # -> (batch*time, 8, 13, 13)
x = x.view(batch_size, time_steps, -1) # -> (batch, time, 1352)
x = x.permute(1, 0, 2) # -> (time, batch, features)
out, _ = self.rnn(x)
out = self.fc(out[-1])
return out
x = torch.randn(16, 5, 1, 28, 28) # batch=16, time=5
model = CNN_RNN()
output = model(x)
print(output.shape) # torch.Size([16, 10])
总结:常见 view()
用法对照表
用法场景 | 示例代码 | 解释 |
---|---|---|
Flatten CNN 输出 | x.view(x.size(0), -1) |
展平成 (B, C*H*W) |
展开 RNN 所有时间步输出 | x.view(-1, hidden_size) |
-> (T*B, H) |
输入 RNN 前展开为 (seq, B, dim) |
x.permute(1, 0, 2) |
时间维放前面 |
多帧图像展平成一个 batch | x.view(B*T, C, H, W) |
常用于视频输入 CNN |
总结
特性 | view() |
---|---|
功能 | 改变张量形状 |
是否拷贝数据 | 否(更高效) |
内存要求 | 连续(需 .contiguous()) |
支持 -1 吗 | 支持自动推断 |