1. Basic Mathematical Operations
1.1 Square Roots and Powers
import torch
x = torch.tensor([4.0, 9.0, 16.0])
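sqrt_x = torch.sqrt(x)      # [2., 3., 4.]
square_x = torch.square(x)  # [16., 81., 256.]
pow_x = torch.pow(x, 3)     # [64., 729., 4096.]
sqrt_x_alt = x ** 0.5       # operator form of torch.sqrt
square_x_alt = x ** 2       # operator form of torch.square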
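A quick check, assuming PyTorch 2.x, that the function and operator forms agree, that `torch.pow` also accepts a tensor exponent (applied element-wise), and that the square root of a negative float is NaN. The exponent tensor `e` is an illustrative value, not from the text above.

```python
import torch

x = torch.tensor([4.0, 9.0, 16.0])

# Function and operator forms give the same results
same_sqrt = torch.allclose(torch.sqrt(x), x ** 0.5)
same_square = torch.allclose(torch.square(x), x ** 2)

# The exponent can itself be a tensor: element-wise x[i] ** e[i]
e = torch.tensor([0.5, 2.0, 1.0])
mixed = torch.pow(x, e)  # [4 ** 0.5, 9 ** 2, 16 ** 1] = [2., 81., 16.]

# The square root of a negative real number is undefined for float tensors: NaN
neg_root = torch.sqrt(torch.tensor([-1.0]))
print(same_sqrt, same_square, mixed, neg_root)
```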
1.2 Exponentials and Logarithms
exp_x = torch.exp(x)
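log_x = torch.log(x)            # natural logarithm (base e)
log10_x = torch.log10(x)        # base-10 logarithm
safe_log = torch.log(x + 1e-8)  # a small epsilon keeps log(0) from returning -inf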
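A small sketch of why the epsilon matters when the input can contain zeros, and that `exp` and `log` invert each other on positive inputs. The input values are illustrative.

```python
import torch

x = torch.tensor([0.0, 1.0, 10.0, 100.0])

plain = torch.log(x)          # log(0) = -inf
safe = torch.log(x + 1e-8)    # finite everywhere; about log(1e-8) ≈ -18.42 at x = 0

# exp and log are inverses on positive inputs (up to float rounding)
pos = x[1:]
roundtrip = torch.exp(torch.log(pos))
log10_pos = torch.log10(pos)  # [0., 1., 2.]

print(plain)  # first entry is -inf
print(safe)
```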
2. Statistical Operations
2.1 Sums and Means
x = torch.randn(3, 4)
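total = torch.sum(x)              # 0-D tensor: sum of all 12 elements
sum_dim0 = torch.sum(x, dim=0)    # shape (4,): sums down each column
sum_dim1 = torch.sum(x, dim=1)    # shape (3,): sums across each row
mean_val = torch.mean(x)
mean_dim0 = torch.mean(x, dim=0)  # shape (4,)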
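With random data it is hard to see what `dim` does, so here is a worked example on a fixed 2×3 tensor (values chosen for illustration). It also shows `keepdim=True`, which keeps the reduced axis as size 1 so the result broadcasts back against the input.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])   # shape (2, 3)

total = torch.sum(x)                  # → 21.
col_sums = torch.sum(x, dim=0)        # → [5., 7., 9.]  (dim 0 is collapsed)
row_sums = torch.sum(x, dim=1)        # → [6., 15.]     (dim 1 is collapsed)

row_means = torch.mean(x, dim=1, keepdim=True)  # shape (2, 1): [[2.], [5.]]
centered = x - row_means              # broadcasts back to (2, 3): [[-1., 0., 1.], [-1., 0., 1.]]
```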
2.2 Extrema and Sorting
max_val = torch.max(x)
min_val = torch.min(x)
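max_vals, max_indices = torch.max(x, dim=1)  # per-row maxima and their column indices
min_vals, min_indices = torch.min(x, dim=0)  # per-column minima and their row indices
sorted_vals, sorted_indices = torch.sort(x, dim=1, descending=True)  # sort each row, largest first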
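A sketch on a fixed tensor (illustrative values) showing that `torch.max` with `dim` returns a named tuple, how `torch.argmax` and `torch.topk` relate to it, and that the indices from `torch.sort` can be fed to `torch.gather` to reproduce the sorted values.

```python
import torch

x = torch.tensor([[3.0, 1.0, 2.0],
                  [0.0, 5.0, 4.0]])

# With dim, torch.max returns a named tuple (values, indices)
row_max = torch.max(x, dim=1)        # values → [3., 5.], indices → [0, 1]
row_argmax = torch.argmax(x, dim=1)  # indices only → [0, 1]

# topk returns the k largest entries per row, sorted in descending order
top2 = torch.topk(x, k=2, dim=1)     # values → [[3., 2.], [5., 4.]], indices → [[0, 2], [1, 2]]

# Gathering with the sort indices reproduces the sorted values
sorted_vals, sorted_idx = torch.sort(x, dim=1, descending=True)
regathered = torch.gather(x, 1, sorted_idx)
```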
2.3 Variance and Standard Deviation
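var_x = torch.var(x, correction=1)  # sample (unbiased) variance, divides by N - 1; older code writes unbiased=True
var_dim0 = torch.var(x, dim=0)
std_x = torch.std(x)                # correction=1 (unbiased) is the default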
std_dim1 = torch.std(x, dim=1)
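A worked example on `[1, 2, 3, 4]` (illustrative data) contrasting the default sample variance (divide by N - 1) with the population variance (divide by N). It assumes PyTorch 2.0 or later, where `torch.var` takes the `correction` keyword.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
# mean = 2.5; squared deviations sum to 2.25 + 0.25 + 0.25 + 2.25 = 5.0

sample_var = torch.var(x)                    # default correction=1: 5.0 / 3 ≈ 1.6667
population_var = torch.var(x, correction=0)  # 5.0 / 4 = 1.25

# The population variance computed by hand
manual_pop = ((x - x.mean()) ** 2).mean()

# The standard deviation is the square root of the variance with the same correction
sample_std = torch.std(x)
```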
3. Matrix Operations
3.1 Basic Matrix Operations
A = torch.randn(3, 4)
B = torch.randn(4, 5)
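matmul = torch.matmul(A, B)      # shape (3, 5)
matmul_alt = A @ B               # @ is the operator form of torch.matmul
v1 = torch.randn(3)
v2 = torch.randn(3)
dot_product = torch.dot(v1, v2)  # inner product of two 1-D tensors; returns a 0-D tensor
batch_A = torch.randn(5, 3, 4)
batch_B = torch.randn(5, 4, 5)
batch_matmul = torch.bmm(batch_A, batch_B)  # 5 independent (3, 4) @ (4, 5) products → shape (5, 3, 5)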
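A sketch of the shape rules behind these calls: `torch.bmm` needs two 3-D tensors with the same batch size, while `@` (`torch.matmul`) covers that case and additionally broadcasts, for example applying one matrix `W` (a hypothetical weight, not from the text) to every batch entry. For two 1-D tensors, `@` is a dot product.

```python
import torch

batch_A = torch.randn(5, 3, 4)
batch_B = torch.randn(5, 4, 5)

# Same batch size on both sides: bmm and @ agree
out_bmm = torch.bmm(batch_A, batch_B)  # shape (5, 3, 5)
out_matmul = batch_A @ batch_B

# matmul broadcasts a single (4, 5) matrix over the batch dimension;
# torch.bmm(batch_A, W) would raise an error because W is 2-D
W = torch.randn(4, 5)
out_bcast = batch_A @ W                # shape (5, 3, 5)

# 1-D @ 1-D is the dot product and returns a 0-D tensor
v1, v2 = torch.randn(3), torch.randn(3)
dot_op = v1 @ v2
```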
3.2 Matrix Decompositions
sym_matrix = torch.randn(3, 3)
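sym_matrix = sym_matrix @ sym_matrix.T  # M @ M.T is symmetric (and positive semi-definite)
eigenvals, eigenvecs = torch.linalg.eigh(sym_matrix)  # for symmetric/Hermitian input; eigenvalues in ascending order
U, S, Vh = torch.linalg.svd(A, full_matrices=False)   # returns Vh (V conjugate-transposed), not V; A ≈ U @ diag(S) @ Vh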
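A sanity check one might run, assuming float64 for tight tolerances: both factorizations should reconstruct their input. Note that `torch.linalg.svd` returns `Vh` (the conjugate transpose of V), so the reconstruction multiplies by `Vh` directly, and that `torch.linalg.eigh` returns eigenvalues in ascending order.

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4, dtype=torch.float64)

# Reduced SVD: U (3, 3), S (3,), Vh (3, 4); A = U @ diag(S) @ Vh
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
A_rec = U @ torch.diag(S) @ Vh

# eigh on a symmetric matrix: M = Q @ diag(L) @ Q.T with orthonormal columns in Q
M = A @ A.T                      # (3, 3), symmetric positive semi-definite
L, Q = torch.linalg.eigh(M)
M_rec = Q @ torch.diag(L) @ Q.T
```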
4. Comparison Operations
4.1 Element-wise Comparisons
a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 2, 1])
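eq = torch.eq(a, b)  # tensor([False, True, False])
gt = torch.gt(a, b)  # tensor([False, False, True])
lt = torch.lt(a, b)  # tensor([True, False, False])
eq_alt = a == b      # operator forms return the same bool tensors
gt_alt = a > b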
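The bool tensors produced by comparisons are most often used as masks. A minimal sketch with the same `a` and `b` as above: boolean indexing, `torch.where` to choose between two tensors, and counting matches with `sum`.

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([3, 2, 1])

mask = a > b                              # tensor([False, False, True])
selected = a[mask]                        # keeps a[i] where mask[i] is True → [3]
elementwise_max = torch.where(mask, a, b) # a where a > b, else b → [3, 2, 3]
count = mask.sum()                        # True counts as 1 → 1
```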
4.2 Reducing Comparisons
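all_true = torch.all(eq)  # tensor(False): not every element is equal
any_true = torch.any(gt)  # tensor(True): at least one a[i] > b[i]
same = torch.equal(a, b)  # Python bool, True only if shapes and all elements match; False here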
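For floating-point data, exact equality is usually the wrong test. A sketch contrasting `torch.equal` with `torch.allclose` and `torch.isclose`; float64 is used (an assumption for this example) so that `0.1 + 0.2` rounds the same way as in plain Python.

```python
import torch

a = torch.tensor([0.1, 1.0], dtype=torch.float64) + torch.tensor([0.2, 1.0], dtype=torch.float64)
b = torch.tensor([0.3, 2.0], dtype=torch.float64)

exact = torch.equal(a, b)          # False: 0.1 + 0.2 is 0.30000000000000004 in binary floating point
close = torch.allclose(a, b)       # True: equal within the default rtol=1e-5, atol=1e-8
per_element = torch.isclose(a, b)  # element-wise version: tensor([True, True])

# torch.equal returns False (rather than raising) when the shapes differ
shape_mismatch = torch.equal(torch.zeros(2), torch.zeros(3))
```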
5. Reduction Operations
5.1 Common Reductions
x = torch.randn(2, 3)
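sum_all = x.sum()         # method form of torch.sum
sum_dim = x.sum(dim=1)    # shape (2,)
cumsum = x.cumsum(dim=0)  # running sum down each column; shape stays (2, 3)
prod_all = x.prod()       # product of all elements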
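A worked example on a fixed tensor (illustrative values) showing how `cumsum` differs from `sum`: it keeps the input shape and stores running totals, so the last entry along the scanned dimension equals the full sum.

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])

down_cols = x.cumsum(dim=0)   # → [[1., 2., 3.], [5., 7., 9.]]
along_rows = x.cumsum(dim=1)  # → [[1., 3., 6.], [4., 9., 15.]]
product = x.prod()            # → 720. (1 * 2 * 3 * 4 * 5 * 6)

# The last column of a row-wise cumsum equals the row sums
last_col = along_rows[:, -1]  # → [6., 15.]
```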
5.2 Advanced Reductions
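weights = torch.softmax(torch.randn(3), dim=0)  # 3 non-negative weights summing to 1
weighted_mean = torch.sum(x * weights, dim=1)   # weights broadcast across each row of x → shape (2,)
logsumexp = torch.logsumexp(x, dim=1)           # log(sum(exp(x))) computed without overflow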
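Why `torch.logsumexp` is worth using: a sketch with deliberately large inputs (illustrative) where the naive formula overflows in float32, plus the usual max-shift identity log Σ exp(x) = m + log Σ exp(x - m), with m the row maximum, written out by hand for comparison.

```python
import torch

x = torch.tensor([[1000.0, 1000.0],
                  [0.0, 0.0]])

naive = torch.log(torch.exp(x).sum(dim=1))  # exp(1000) overflows to inf in float32, so naive[0] is inf
stable = torch.logsumexp(x, dim=1)          # ≈ [1000 + log 2, log 2]

# The max-shift trick: subtract the row max before exponentiating, then add it back
m = x.max(dim=1, keepdim=True).values
manual = (m + torch.log(torch.exp(x - m).sum(dim=1, keepdim=True))).squeeze(1)
```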
6. Engineering Best Practices
6.1 Understand broadcasting: make sure operand shapes are compatible
a = torch.randn(3, 1)
b = torch.randn(1, 3)
c = a + b  # size-1 dims are stretched: (3, 1) + (1, 3) → (3, 3)
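A sketch of the broadcasting rule: shapes are aligned from the right, and each pair of sizes must be equal or contain a 1 (missing leading dims count as 1). `torch.broadcast_shapes` checks this without allocating data. The last lines show a common pitfall where a shape mismatch silently produces a larger result instead of an error.

```python
import torch

s1 = torch.broadcast_shapes((3, 1), (1, 3))   # torch.Size([3, 3])
s2 = torch.broadcast_shapes((5, 3, 1), (4,))  # torch.Size([5, 3, 4])

# Incompatible trailing sizes (3 vs 4) raise a RuntimeError
try:
    torch.randn(3) + torch.randn(4)
    failed = False
except RuntimeError:
    failed = True

# Pitfall: (3,) + (3, 1) broadcasts to (3, 3) rather than adding element-wise
v = torch.randn(3)
col = torch.randn(3, 1)
surprise = v + col
```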
6.2 In-place operations: use the trailing-underscore variants to save memory
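x = torch.rand(2, 3)  # values in [0, 1), so sqrt_ does not produce NaN (torch.randn could be negative)
x.sqrt_()             # overwrites x in place; no new tensor is allocated
x.add_(1)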
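A minimal sketch, assuming PyTorch 2.x, that makes the memory claim observable via `data_ptr()` (an in-place op keeps the same storage, an out-of-place op allocates new storage), and shows the main caveat: autograd rejects in-place ops on a leaf tensor that requires grad.

```python
import torch

x = torch.rand(2, 3)
ptr_before = x.data_ptr()
x.add_(1)                                  # in place: same underlying memory
same_storage = x.data_ptr() == ptr_before

y = x + 1                                  # out of place: new storage is allocated
new_storage = y.data_ptr() != x.data_ptr()

# In-place modification of a leaf tensor that requires grad raises a RuntimeError
w = torch.ones(3, requires_grad=True)
try:
    w.add_(1)
    rejected = False
except RuntimeError:
    rejected = True
```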
6.3 Device consistency: make sure all operands are on the same device
y = torch.rand(2, 3)
if torch.cuda.is_available():
    x = x.cuda()
    y = y.cuda()
z = x + y  # adding a CUDA tensor to a same-shaped CPU tensor would raise a RuntimeError
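A common device-agnostic pattern, offered as a sketch rather than the only approach: pick the device once, create tensors directly on it with the `device=` argument (or move existing ones with `.to()`), and move results back to the CPU before converting them to NumPy or Python values. It runs on CPU-only machines as well.

```python
import torch

# Choose the device once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.rand(2, 3, device=device)  # created directly on the target device
y = torch.rand(2, 3).to(device)      # moved after creation
z = x + y

# Bring results back to the CPU before calling .numpy() or similar
z_cpu = z.cpu()
```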
6.4 Gradient tracking: be aware of how operations affect the computation graph
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2    # recorded in the graph because x requires grad
y.backward()  # fills x.grad with dy/dx = 2x = tensor(4.)
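Two graph-related behaviours that commonly trip people up, sketched on the same scalar example: gradients accumulate across `backward()` calls until reset, and operations inside `torch.no_grad()` or on a `detach()`ed tensor are not tracked.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
y.backward()
first = x.grad.clone()        # dy/dx = 2x = 4

# A second backward pass adds to the existing gradient
(x ** 2).backward()
accumulated = x.grad.clone()  # 4 + 4 = 8
x.grad.zero_()                # reset, as an optimizer's zero_grad() would

# Not recorded in the graph
with torch.no_grad():
    z = x * 3
d = x.detach() * 3
```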
6.5 Numerical stability: prefer stable built-in implementations
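x = torch.randn(2, 3) * 100  # a 2-D batch of large-magnitude logits (softmax over dim=1 needs at least 2 dims)
unstable = torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)  # float32 exp overflows to inf above ~88.7, giving NaN
stable = torch.softmax(x, dim=1)  # built-in softmax avoids the overflow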
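The standard stabilization for softmax is to subtract the row maximum before exponentiating; this leaves the result unchanged mathematically but keeps `exp` in range. A sketch with deterministic large inputs (illustrative) comparing the naive formula, the hand-written stable version, and `torch.softmax`.

```python
import torch

x = torch.tensor([[1000.0, 1001.0, 1002.0]])

naive = torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)  # inf / inf → NaN

# Subtract the row max: the largest exponent becomes exp(0) = 1
shifted = x - x.max(dim=1, keepdim=True).values
manual = torch.exp(shifted) / torch.exp(shifted).sum(dim=1, keepdim=True)

builtin = torch.softmax(x, dim=1)
```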
7. Performance Optimization Tips
7.1 Vectorize: avoid Python loops
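# Slow: one Python-level iteration (and one small tensor op) per row
result = torch.zeros_like(x)
for i in range(x.size(0)):
    result[i] = x[i] * 2
# Fast: a single vectorized operation over the whole tensor
result = x * 2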
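A less trivial case of replacing loops with broadcasting: pairwise Euclidean distances between points. The point set and helper functions are hypothetical, for illustration. The double loop issues one small tensor op per pair, while the vectorized version computes all differences at once via shapes (n, 1, 2) and (1, n, 2). `torch.cdist` is the built-in equivalent; `compute_mode="donot_use_mm_for_euclid_dist"` is set here only so its result matches the direct formula closely.

```python
import torch

torch.manual_seed(0)
pts = torch.randn(50, 2)  # 50 hypothetical points in 2-D

def pairwise_loop(p):
    # n * n Python iterations, each launching a few tiny tensor ops
    n = p.size(0)
    out = torch.empty(n, n)
    for i in range(n):
        for j in range(n):
            out[i, j] = torch.sqrt(((p[i] - p[j]) ** 2).sum())
    return out

def pairwise_vectorized(p):
    # (n, 1, 2) - (1, n, 2) broadcasts to (n, n, 2); reduce over the coordinate dim
    diff = p[:, None, :] - p[None, :, :]
    return torch.sqrt((diff ** 2).sum(dim=-1))

d_loop = pairwise_loop(pts)
d_vec = pairwise_vectorized(pts)
d_builtin = torch.cdist(pts, pts, compute_mode="donot_use_mm_for_euclid_dist")
```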
7.2 Fuse operations: reduce intermediate results
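x, y, z = (torch.randn(1000, 1000) for _ in range(3))
temp = x + y
result = temp * z                 # the named intermediate temp stays alive until it goes out of scope
result = (x + y) * z              # eager mode still allocates x + y, but frees it right after use
result = torch.add(x, y).mul_(z)  # one allocation fewer: the sum is multiplied in place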
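Two further ways to cut allocations in eager mode, sketched under the assumption of same-shaped float tensors: reusing a preallocated buffer through the `out=` argument, and calling a built-in fused op where one exists (`torch.addcmul` computes `x + y * z` in a single call). `torch.compile` in PyTorch 2.x can also fuse chains of element-wise ops automatically; that is not shown here.

```python
import torch

x, y, z = (torch.randn(1000, 1000) for _ in range(3))

# Reuse one buffer for both steps of (x + y) * z
buf = torch.empty_like(x)
torch.add(x, y, out=buf)
torch.mul(buf, z, out=buf)  # element-wise ops may write into one of their own inputs

# A single fused built-in for the pattern x + y * z
fused = torch.addcmul(x, y, z)
```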
7.3 Use built-in functions: take advantage of optimized implementations
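custom_norm = torch.sqrt(torch.sum(x ** 2))   # hand-written L2 norm: materializes x ** 2 first
optimized_norm = torch.linalg.vector_norm(x)  # built-in L2 norm; torch.norm is deprecated in favor of torch.linalg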
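A short tour of `torch.linalg.vector_norm` on small illustrative tensors: the `ord` parameter selects L2 (default), L1, or max norm, and on a matrix the norm is taken over all elements unless `dim` restricts it.

```python
import torch

v = torch.tensor([3.0, -4.0])

l2 = torch.linalg.vector_norm(v)                      # sqrt(9 + 16) = 5
l1 = torch.linalg.vector_norm(v, ord=1)               # |3| + |-4| = 7
linf = torch.linalg.vector_norm(v, ord=float("inf"))  # max(|3|, |-4|) = 4

m = torch.tensor([[3.0, 4.0],
                  [6.0, 8.0]])
row_norms = torch.linalg.vector_norm(m, dim=1)        # → [5., 10.]
flat_norm = torch.linalg.vector_norm(m)               # over all four elements
```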