pytorch小记(三十二):深度解析 PyTorch 的 `torch.remainder`:向下取整余数运算

发布于:2025-07-18 ⋅ 阅读:(20) ⋅ 点赞:(0)


深度解析 PyTorch 的 torch.remainder:向下取整余数运算

在深度学习和科学计算中,经常需要对张量执行“取余”操作,将数值映射到某个固定范围内。与 Python 自身的 % 操作符不同,PyTorch 提供了更加明确的 符号约定 的两种取余函数:

  • torch.remainder(a, b):余数与除数同号或为零,基于向下取整
  • torch.fmod(a, b):余数与被除数同号或为零,基于截断取整

本文重点介绍 torch.remainder 的工作原理、数学定义、与 fmod 的区别以及丰富示例,帮助你在周期化、环绕索引、模运算等场景中选用正确的工具。


一、函数签名

torch.remainder(input, other, *, out=None) → Tensor
  • input (Tensor or scalar):被除数。
  • other (Tensor or scalar):除数,形状可与 input 相同或可广播。
  • out (可选):将结果写入到指定张量。

返回值形状与 input 相同,每个元素是对应的余数。


二、数学定义:向下取整除法

对标量 a a a(被除数)和 b b b(除数, b ≠ 0 b\neq0 b=0),

r e m a i n d e r ( a , b ) = a − b × ⌊ a b ⌋ , \mathrm{remainder}(a,b) = a - b \times \left\lfloor \frac{a}{b} \right\rfloor, remainder(a,b)=ab×ba,

其中 ⌊ x ⌋ \lfloor x\rfloor x 表示向下取整(不大于 x x x 的最大整数)。

  • b > 0 b>0 b>0,余数 r ∈ [ 0 , b ) r\in[0,b) r[0,b)
  • b < 0 b<0 b<0,余数 r ∈ ( b , 0 ] r\in(b,0] r(b,0]

三、与 torch.fmod 的区别

运算符 定义 余数符号归属
remainder(a,b) 向下取整: a − b ⌊ a / b ⌋ a - b\lfloor a/b\rfloor aba/b 除数 同号或为 0
fmod(a,b) 截断取整: a − b t r u n c ( a / b ) a - b\mathrm{trunc}(a/b) abtrunc(a/b) 被除数 同号或为 0

示例( a = − 3 , b = 2 a=-3, b=2 a=3,b=2):

  • torch.remainder: − 3 − 2 ⌊ − 1.5 ⌋ = − 3 − 2 × ( − 2 ) = 1 -3 - 2\lfloor -1.5\rfloor = -3 - 2\times(-2) = 1 321.5=32×(2)=1
  • torch.fmod: − 3 − 2 × t r u n c ( − 1.5 ) = − 3 − 2 × ( − 1 ) = − 1 -3 - 2\times\mathrm{trunc}(-1.5) = -3 - 2\times(-1) = -1 32×trunc(1.5)=32×(1)=1

四、经典示例

import torch

# 标量示例
print(torch.remainder(torch.tensor( 7), torch.tensor( 3)))  # 1
print(torch.remainder(torch.tensor(-7), torch.tensor( 3)))  # 2
print(torch.remainder(torch.tensor( 7), torch.tensor(-3)))  # -2
print(torch.remainder(torch.tensor(-7), torch.tensor(-3)))  # -1

# 对比 fmod
print(torch.fmod(torch.tensor(-7), torch.tensor( 3)))  # -1
print(torch.fmod(torch.tensor(-7), torch.tensor(-3)))  # -1

结果:

1
2
-2
-1
-1
-1

五、广播与张量示例

import torch

x = torch.tensor([-2.7, -1.2, 0.5, 1.8, 3.3])
b = 1.0
print(torch.remainder(x, b))
# => tensor([0.3000, 0.8000, 0.5000, 0.8000, 0.3000])

b2 = torch.tensor([2.0, -2.0, 2.0, -2.0, 2.0])
print(torch.remainder(x, b2))
# => tensor([ 1.3000, -1.2000,  0.5000, -0.2000,  1.3000])

手算过程示例:

  • − 2.7   m o d   1.0 = − 2.7 − 1 × ⌊ − 2.7 ⌋ = − 2.7 − 1 × ( − 3 ) = 0.3 -2.7 \bmod 1.0 = -2.7 - 1\times\lfloor-2.7\rfloor = -2.7 - 1\times(-3) = 0.3 2.7mod1.0=2.71×2.7=2.71×(3)=0.3.
  • − 2.7   m o d   2.0 = − 2.7 − 2 × ⌊ − 1.35 ⌋ = − 2.7 − 2 × ( − 2 ) = 1.3 -2.7 \bmod 2.0 = -2.7 - 2\times\lfloor-1.35\rfloor = -2.7 - 2\times(-2) = 1.3 2.7mod2.0=2.72×1.35=2.72×(2)=1.3.

六、应用场景

  1. 周期化数据:将角度或索引循环到固定范围,如 [ 0 , 2 π ) [0,2\pi) [0,2π)
  2. 环绕索引:在循环队列、环形缓冲区或图像平铺中计算有效索引。
  3. 同余运算:数学模型、密码学和信号处理中的模运算实现。

七、小结

  • torch.remainder(a,b) 基于向下取整,保证余数与除数同号或为零,适合需要“正向环绕”的场景。
  • torch.fmod(截断取整、余数符号随被除数)的区别在于符号归属
  • 支持标量、张量与广播,语义清晰,易于处理周期性和模运算。

网站公告

今日签到

点亮在社区的每一天
去签到