Pytorch维度转换操作:view,reshape,permute,flatten函数详解

发布于:2024-09-17 ⋅ 阅读:(65) ⋅ 点赞:(0)

引言

Pytorch中常见的维度转换函数有view, reshape, permute, flatten。本文将详细介绍这几个函数的作用与使用方式,并给出了具体的代码示例,希望能够帮助大家。

常见的维度有四维:比如(batch, channel, height, width);三维:比如(b,n,c);二维:比如(h,w)。下面介绍如何使用上述函数进行维度之间的转换。

1. view函数

作用

tensor.view() 可以用来调整张量的形状,这对于在网络层之间传递数据或者在处理图像数据时非常有用。需要注意的是,新的形状必须与原始张量的元素数量一致。

参数

size (tuple of ints) – 新的大小应该与原张量元素数量相匹配。可以指定一个尺寸为 -1 的维度来自动计算合适的大小。

代码示例:

将计算机视觉中的常见四维张量(Batch, Channel, Height, Width)转为三维(Batch,N,Channel)形式。

import torch
# view使用示例
x = torch.randn(16,3,64,64) # B, C, H, W
print(x.shape) #torch.Size([16,3,64,64])
B, C, H, W = x.size()

# 转为BNC
x = x.view(B, -1, C)
# 或者 x = x.view(B, H*W, C)
print(x.shape) #torch.Size([16, 4096, 3])

torch.randn() 是 PyTorch 中的一个函数,用于生成一个填充了从标准正态分布(均值为 0,方差为 1)中随机抽取的数字的张量。

2. permute函数

作用

permute() 函数用于改变张量的维度顺序。它接受一个新的维度顺序作为参数,并返回一个新的张量,其维度顺序按照给定的顺序排列。

参数说明

参数:一个元组,表示新的维度顺序。例如,对于一个形状为 (10, 3, 32, 32) 的张量,permute(0, 2, 3, 1) 表示新的维度顺序为 (10, 32, 32, 3)。其中0,1,2,3分别表示4个维度(10, 3, 32, 32)的索引。

代码示例:

依然将计算机视觉中的常见四维张量(Batch, Channel, Height, Width)转为三维(Batch,N,Channel)形式。

import torch
# permute使用示例:permute转换唯独顺序
x = torch.randn(16,3,64,64) # B, C, H, W
print(x.shape) #torch.Size([16,3,64,64])

# 16,3,64,64的维度索引分别为0,1,2,3
dim_change = x.permute(0,2,3,1) # 转为 B,H,W,C
# 然后将中间两个通道索引为[1,2]展平
out = dim_change.flatten(start_dim=1,end_dim=2)
print(out.shape) #torch.Size([16, 4096, 3])

flatten() 方法用于展平张量的一个或多个维度。它可以接受两个可选参数:start_dim:从哪个维度开始展平,默认为 0。 

end_dim:到哪个维度结束展平,默认为 -1,表示直到最后一个维度。 

此处的作用是将第二个和第三个维度进行展平。start_dim=1 表示从第二个维度(即 64)开始展平。end_dim=2 表示到第三个维度(即 64)结束展平。展平后的结果为 (16, 4096, 3),其中 4096= 64 * 64。 

通过这些步骤,你可以将原始张量从 (16,3,64,64) 转换为 (16, 4096, 3)。

3. Reshape函数

torch.reshape() 可以改变张量的形状,而不改变张量中的数据。与view函数的作用类似。

注意事项:新旧形状的元素总数必须相同。

import torch

# 创建一个简单的张量
x = torch.randn(4, 3)
print("Original tensor:")
print(x)

# 使用 torch.reshape() 来改变张量的形状
# 将 (4, 3) 的张量转换成 (2, 6) 的张量
reshaped_x = torch.reshape(x, (2, 6))
print("\nReshaped tensor:")
print(reshaped_x)

# 如果不确定某个维度的大小,可以使用 -1 让 PyTorch 自动计算
# 这里将 (4, 3) 转换为 (12,) 的一维张量
flat_x = torch.reshape(x, (-1))
print("\nFlattened tensor:")
print(flat_x)

# 更复杂的形状变换
# 将 (4, 3) 转换为 (3, 4) 的张量
complex_reshaped_x = torch.reshape(x, (3, 4))
print("\nComplex reshaped tensor:")
print(complex_reshaped_x)

4. flatten函数

torch.flatten 是 PyTorch 库中的一个函数,用于将一个多维张量转换为一维张量或降低其维度。

torch.flatten参数说明

input: 这是要被展平的张量。这是必需的参数。 

start_dim (可选): 指定从哪个维度开始展平。默认值为 0,这意味着展平将从第一个维度(通常是批量大小)开始。如果你希望保留前几个维度并只展平后续的维度,你可以设置这个参数。 

end_dim (可选): 指定展平到哪个维度结束。默认值为 -1,这表示展平将一直持续到最后一个维度。如果只想展平中间的一部分维度,可以设置这个参数来指定结束维度。

注意:当 start_dim 和 end_dim 都没有被显式地指定时,torch.flatten 将会展平除了第一个维度之外的所有维度,通常第一个维度是批量大小,会被保留以便于批次处理。

代码示例:

举个例子,假设有一个形状为 [batch_size, channels, height, width] 的四维张量,如果你想将其展平为 [batch_size, channels * height * width] 的二维张量,你可以直接调用 torch.flatten 而不需要额外的参数。但是,如果你想保留通道维度,并展平高度和宽度维度,你可以设置 start_dim=1 和 end_dim=2。

import torch

# 创建一个形状为 [8, 3, 64, 64] 的随机张量
x = torch.randn(8, 3, 64, 64)

# 展平除了第一个维度外的所有维度
y = torch.flatten(x)
print(y.shape)  # 输出: torch.Size([8, 12288])

# 只展平第二和第三个维度[也就是最后两个维度],0,1,2,3
z = torch.flatten(x, 1, 2)
print(z.shape)  # 输出: torch.Size([8, 3, 4096])