pytorch小记(三十二):深度解析 PyTorch 的 `torch.remainder`:向下取整余数运算
深度解析 PyTorch 的 torch.remainder
:向下取整余数运算
在深度学习和科学计算中,经常需要对张量执行“取余”操作,将数值映射到某个固定范围内。与 Python 自身的 %
操作符不同,PyTorch 提供了更加明确的 符号约定 的两种取余函数:
torch.remainder(a, b)
:余数与除数同号或为零,基于向下取整。torch.fmod(a, b)
:余数与被除数同号或为零,基于截断取整。
本文重点介绍 torch.remainder
的工作原理、数学定义、与 fmod
的区别以及丰富示例,帮助你在周期化、环绕索引、模运算等场景中选用正确的工具。
一、函数签名
torch.remainder(input, other, *, out=None) → Tensor
- input (
Tensor
or scalar):被除数。 - other (
Tensor
or scalar):除数,形状可与input
相同或可广播。 - out (可选):将结果写入到指定张量。
返回值形状与 input
相同,每个元素是对应的余数。
二、数学定义:向下取整除法
对标量 a a a(被除数)和 b b b(除数, b ≠ 0 b\neq0 b=0),
r e m a i n d e r ( a , b ) = a − b × ⌊ a b ⌋ , \mathrm{remainder}(a,b) = a - b \times \left\lfloor \frac{a}{b} \right\rfloor, remainder(a,b)=a−b×⌊ba⌋,
其中 ⌊ x ⌋ \lfloor x\rfloor ⌊x⌋ 表示向下取整(不大于 x x x 的最大整数)。
- 若 b > 0 b>0 b>0,余数 r ∈ [ 0 , b ) r\in[0,b) r∈[0,b)。
- 若 b < 0 b<0 b<0,余数 r ∈ ( b , 0 ] r\in(b,0] r∈(b,0]。
三、与 torch.fmod
的区别
运算符 | 定义 | 余数符号归属 |
---|---|---|
remainder(a,b) |
向下取整: a − b ⌊ a / b ⌋ a - b\lfloor a/b\rfloor a−b⌊a/b⌋ | 与 除数 同号或为 0 |
fmod(a,b) |
截断取整: a − b t r u n c ( a / b ) a - b\mathrm{trunc}(a/b) a−btrunc(a/b) | 与 被除数 同号或为 0 |
示例( a = − 3 , b = 2 a=-3, b=2 a=−3,b=2):
torch.remainder
: − 3 − 2 ⌊ − 1.5 ⌋ = − 3 − 2 × ( − 2 ) = 1 -3 - 2\lfloor -1.5\rfloor = -3 - 2\times(-2) = 1 −3−2⌊−1.5⌋=−3−2×(−2)=1torch.fmod
: − 3 − 2 × t r u n c ( − 1.5 ) = − 3 − 2 × ( − 1 ) = − 1 -3 - 2\times\mathrm{trunc}(-1.5) = -3 - 2\times(-1) = -1 −3−2×trunc(−1.5)=−3−2×(−1)=−1
四、经典示例
import torch
# 标量示例
print(torch.remainder(torch.tensor( 7), torch.tensor( 3))) # 1
print(torch.remainder(torch.tensor(-7), torch.tensor( 3))) # 2
print(torch.remainder(torch.tensor( 7), torch.tensor(-3))) # -2
print(torch.remainder(torch.tensor(-7), torch.tensor(-3))) # -1
# 对比 fmod
print(torch.fmod(torch.tensor(-7), torch.tensor( 3))) # -1
print(torch.fmod(torch.tensor(-7), torch.tensor(-3))) # -1
结果:
1
2
-2
-1
-1
-1
五、广播与张量示例
import torch
x = torch.tensor([-2.7, -1.2, 0.5, 1.8, 3.3])
b = 1.0
print(torch.remainder(x, b))
# => tensor([0.3000, 0.8000, 0.5000, 0.8000, 0.3000])
b2 = torch.tensor([2.0, -2.0, 2.0, -2.0, 2.0])
print(torch.remainder(x, b2))
# => tensor([ 1.3000, -1.2000, 0.5000, -0.2000, 1.3000])
手算过程示例:
- − 2.7 m o d 1.0 = − 2.7 − 1 × ⌊ − 2.7 ⌋ = − 2.7 − 1 × ( − 3 ) = 0.3 -2.7 \bmod 1.0 = -2.7 - 1\times\lfloor-2.7\rfloor = -2.7 - 1\times(-3) = 0.3 −2.7mod1.0=−2.7−1×⌊−2.7⌋=−2.7−1×(−3)=0.3.
- − 2.7 m o d 2.0 = − 2.7 − 2 × ⌊ − 1.35 ⌋ = − 2.7 − 2 × ( − 2 ) = 1.3 -2.7 \bmod 2.0 = -2.7 - 2\times\lfloor-1.35\rfloor = -2.7 - 2\times(-2) = 1.3 −2.7mod2.0=−2.7−2×⌊−1.35⌋=−2.7−2×(−2)=1.3.
六、应用场景
- 周期化数据:将角度或索引循环到固定范围,如 [ 0 , 2 π ) [0,2\pi) [0,2π)。
- 环绕索引:在循环队列、环形缓冲区或图像平铺中计算有效索引。
- 同余运算:数学模型、密码学和信号处理中的模运算实现。
七、小结
torch.remainder(a,b)
基于向下取整,保证余数与除数同号或为零,适合需要“正向环绕”的场景。- 与
torch.fmod
(截断取整、余数符号随被除数)的区别在于符号归属。 - 支持标量、张量与广播,语义清晰,易于处理周期性和模运算。