PyTorch使用(7)-张量常见运算函数

发布于:2025-04-06 ⋅ 阅读:(19) ⋅ 点赞:(0)

1. 基本数学运算

1.1 平方根和幂运算

import torch

x = torch.tensor([4.0, 9.0, 16.0])

# 平方根
sqrt_x = torch.sqrt(x)  # tensor([2., 3., 4.])

# 平方
square_x = torch.square(x)  # tensor([16., 81., 256.])

# 任意幂次
pow_x = torch.pow(x, 3)  # tensor([64., 729., 4096.])

# 运算符形式
sqrt_x_alt = x ** 0.5
square_x_alt = x ** 2

1.2 指数和对数

# 自然指数
exp_x = torch.exp(x)  # tensor([5.4595e+01, 8.1031e+03, 8.8861e+06])

# 自然对数
log_x = torch.log(x)  # tensor([1.3863, 2.1972, 2.7726])

# 以10为底的对数
log10_x = torch.log10(x)  # tensor([0.6021, 0.9542, 1.2041])

# 带clip的最小值保护(避免log(0))
safe_log = torch.log(x + 1e-8)

2. 统计运算

2.1 求和与均值

x = torch.randn(3, 4)  # 3x4随机张量

# 全局求和
total = torch.sum(x)  # 标量

# 沿特定维度求和
sum_dim0 = torch.sum(x, dim=0)  # 形状(4,),沿行求和
sum_dim1 = torch.sum(x, dim=1)  # 形状(3,),沿列求和

# 均值计算
mean_val = torch.mean(x)  # 全局均值
mean_dim0 = torch.mean(x, dim=0)  # 沿行求均值

2.2 极值与排序

# 最大值/最小值
max_val = torch.max(x)  # 全局最大值
min_val = torch.min(x)  # 全局最小值

# 沿维度的极值及索引
max_vals, max_indices = torch.max(x, dim=1)  # 每行最大值及位置
min_vals, min_indices = torch.min(x, dim=0)  # 每列最小值及位置

# 排序
sorted_vals, sorted_indices = torch.sort(x, dim=1, descending=True)

2.3 方差与标准差

# 无偏方差(分母n-1)
var_x = torch.var(x, unbiased=True)  # 全局方差
var_dim0 = torch.var(x, dim=0)  # 沿行方差

# 标准差
std_x = torch.std(x)  # 全局标准差
std_dim1 = torch.std(x, dim=1)  # 沿列标准差

3. 矩阵运算

3.1 基本矩阵运算

A = torch.randn(3, 4)
B = torch.randn(4, 5)

# 矩阵乘法
matmul = torch.matmul(A, B)  # 形状(3,5)
matmul_alt = A @ B  # 等价写法

# 点积(向量)
v1 = torch.randn(3)
v2 = torch.randn(3)
dot_product = torch.dot(v1, v2)

# 批量矩阵乘法
batch_A = torch.randn(5, 3, 4)  # 5个3x4矩阵
batch_B = torch.randn(5, 4, 5)  # 5个4x5矩阵
batch_matmul = torch.bmm(batch_A, batch_B)  # 形状(5,3,5)

3.2 矩阵分解

# 特征分解(对称矩阵)
sym_matrix = torch.randn(3, 3)
sym_matrix = sym_matrix @ sym_matrix.T  # 构造对称矩阵
eigenvals, eigenvecs = torch.linalg.eigh(sym_matrix)

# SVD分解
U, S, V = torch.linalg.svd(A)

4. 比较运算

4.1 元素级比较

a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 2, 1])

# 比较运算
eq = torch.eq(a, b)  # tensor([False, True, False])
gt = torch.gt(a, b)  # tensor([False, False, True])
lt = torch.lt(a, b)  # tensor([True, False, False])

# 运算符形式
eq_alt = a == b
gt_alt = a > b

4.2 约简比较

# 判断所有元素为True
all_true = torch.all(eq)

# 判断任一元素为True
any_true = torch.any(gt)

# 判断张量相等(形状和值)
torch.equal(a, b)  # False

5. 规约运算

5.1 常用规约

x = torch.randn(2, 3)

# 求和规约
sum_all = x.sum()  # 全局求和
sum_dim = x.sum(dim=1)  # 沿维度规约

# 累积和
cumsum = x.cumsum(dim=0)  # 沿维度累积

# 乘积规约
prod_all = x.prod()  # 全局乘积

5.2 高级规约

# 加权平均
weights = torch.softmax(torch.randn(3), dim=0)
weighted_mean = torch.sum(x * weights, dim=1)

# 沿维度的logsumexp(数值稳定)
logsumexp = torch.logsumexp(x, dim=1)

6. 工程实践建议

6.1. 广播机制理解:确保运算张量的形状兼容

# 广播示例
a = torch.randn(3, 1)
b = torch.randn(1, 3)
c = a + b  # 形状(3,3)

6.2. 原地操作:使用_后缀节省内存

x.sqrt_()  # 原地平方根
x.add_(1)  # 原地加1

6.3. 设备一致性:确保运算张量在同一设备

if torch.cuda.is_available():
    x = x.cuda()
    y = y.cuda()
    z = x + y

6.4. 梯度保留:注意运算对计算图的影响

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()  # dy/dx = 2x = 4.0

6.5. 数值稳定性:使用稳定实现

# 不稳定的softmax实现
unstable = torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)

# 稳定的softmax实现
stable = torch.softmax(x, dim=1)

7. 性能优化技巧

7.1 向量化操作:避免Python循环

# 不好的做法
result = torch.zeros_like(x)
for i in range(x.size(0)):
    result[i] = x[i] * 2

# 好的做法
result = x * 2

7.2. 融合操作:减少中间结果

# 低效
temp = x + y
result = temp * z

# 高效
result = (x + y) * z

7.3. 使用内置函数:利用优化实现

# 自定义实现
custom_norm = torch.sqrt(torch.sum(x ** 2))

# 内置优化函数
optimized_norm = torch.norm(x)