200 行代码,深入分析动态计算图的原理及实现
原文地址:CSDN 博客
文章目录
1. 前言
机器学习这几年可是大红大紫,各行各业的人都往这里涌入,硬是在机器学习这一领域里挤出了一片人口红海。而在机器学习领域,神经网络由于自己下限低、上限高的特点,赢得了不少人的青睐。
在神经网络中,却有一件我们经常使用,经常耳闻,但又不太熟悉的东西——BP 算法。入门“炼丹”的小萌新往往会对这个一头雾水,久经沙场的老油条对这个也可能不了解细节。
我在查阅许多文章后,发现大多数文章对 BP 算法的介绍往往是点到为止,更深入者也就在数学公式推导层面止步,涉及到代码层面的博主鲜少,更很少提及 BP 算法在神经网络中的更广泛实现——计算图机制。
于是,秉持着“科普”的原则,笔者就撰写了这篇有关于 BP 算法以及计算图原理的文章,并在其中以笔者自己的代码实现,详细地讲解计算图的工作机制,并最终与成熟的计算框架进行比较。
2. BP 算法
BP 算法,又名反向传播算法,是目前深度学习的理论基石。其原始论文于 1986 年由 D. Rumelhart 发表在 Nature 上1。在其论文中,就已经使用 MSE(Mean Square Error) 均方误差作为训练目标,并使用多层的 MLP 感知机作为模型,进行亲戚关系的分类。
当前时代的神经网络,早已比当时的网络来的更加庞大,几百个万的模型参数比比皆是,GPT-3 甚至已经上千亿的模型,而其最基本的算法,却来自于 40 年前,让人感到不可思议。
对于 BP 算法的理解其实非常简单。假设神经网络的的损失是 L L L, x \bm{x} x 是输入向量, W i j \bm{W}_{ij} Wij 是第 i i i 层的第 j j j 个参数,那么根据梯度下降的原理,我们需要得到 L L L 对 W i j \bm{W}_{ij} Wij 偏微分值:
∇ W i j = ∂ L ∂ W i j \nabla\bm{W}_{ij}=\frac{\partial L}{\partial \bm{W}_{ij}} ∇Wij=∂Wij∂L
设 η \eta η 为学习率,则最终的参数更新算法为:
W i j t + 1 = W i j t − η ⋅ ∇ W i j \bm{W}_{ij}^{t+1}=\bm{W}_{ij}^{t}-\eta\cdot\nabla\bm{W}_{ij} Wijt+1=Wijt−η⋅∇Wij
然后问题来了:怎么计算 ∂ L ∂ W i j \frac{\partial L}{\partial \bm{W}_{ij}} ∂Wij∂L?
许多的博文都对这个问题作出众多的解释,大部分人会选择使用数学推导的形式阐述,最终结果或许可能如下:
这串花里胡哨的东西,对数学系的同学来说刚刚好,对笔者来说可不好。讲到底 BP 算法就是一个偏导数的链式法则应用,写这么复杂真的有用吗?
∂ y ∂ x 1 = ∂ y ∂ x n ⋅ ∂ y ∂ x n − 1 ⋅ ⋯ ⋅ ∂ x 2 ∂ x 1 \frac{\partial y}{\partial x_1}=\frac{\partial y}{\partial x_n}\cdot \frac{\partial y}{\partial x_{n-1}}\cdot\dots\cdot\frac{\partial x_2}{\partial x_1} ∂x1∂y=∂xn∂y⋅∂xn−1∂y⋅⋯⋅∂x1∂x2
看吧!如果我把上面这串链式法则的公式, y y y 换成 L L L, x 1 x_1 x1 换为 W i j \bm{W}_{ij} Wij,剩下的 x i x_i xi 换为神经网络中的一些其他变量,不就把 BP 算法拆成了许多更小的偏导数的乘积吗?
对于 BP 算法的数学机理,了解到这已经足够。下一节,笔者将以程序员的角度,带大家看 BP 算法的另一个视角——计算图机制。
3. 计算图
本章中通过一个实际的例子,给出计算图的详细说明,并引出了计算图反向传播机制的定理。
3.1 计算图定义
计算图是描述计算过程的数据结构,而且通常是 DAG 图(有向无环图)。
在计算图中,每一个节点表示一个变量(值),每一条边表示数据的流动方向,并且每一条边的值被定义为边的首尾节点的偏导数值。例如:
这幅图表示以下的三个算式:
c = a + b d = b + 1 e = c × d \begin{aligned} c&=a+b\\ d&=b+1\\ e&=c\times d \end{aligned} cde=a+b=b+1=c×d
在这副计算图中,每个节点都表示着一个变量值,每条边表示数据的流动。在每条边上,笔者提前算出了每条边的末尾节点对起始节点的偏导数,例如边 (b,d)
的偏导数就是 ∂ d ∂ b = 1 \frac{\partial d}{\partial b}=1 ∂b∂d=1。
3.2 计算图机制
拥有计算图的定义后,下面来详细介绍一下计算图是如何对应 BP 算法的。
3.2.1 前向传播
对应于 BP 算法的前向传播(Forward Pass)过程,计算图的前向传播其实相同,就是把计算图的每个节点的值都计算出来。
例如,在上面的示例图中,若设 a = 2 , b = 1 a=2,b=1 a=2,b=1,那么前向传播的过程就把其他的节点值都算出来:
c = a + b = 3 d = b + 1 = 2 e = c × d = 5 \begin{aligned} c&=a+b=3\\ d&=b+1=2\\ e&=c\times d=5 \end{aligned} cde=a+b=3=b+1=2=c×d=5
前向传播没有理解上的难点,大家一眼就能明白,而难点在于反向传播的过程中。
3.2.2 反向传播
在反向传播的机制中,笔者并不打算引入过于复杂的数学公式来证明,而是选择用更加浅显易懂的大白话,说明计算图在反向传播过程中的工作原理。
对于示例图中,如果想求 ∂ e ∂ b \frac{\partial e}{\partial b} ∂b∂e,该怎么办?首先,从节点 b b b 开始,可以发现 b b b 通过作用于 c c c 和 d d d,进而对节点 e e e 造成了影响。
这个连环影响的现象表达成数学的形式,即为
Δ e = ∂ e ∂ c ⋅ Δ c + ∂ e ∂ d ⋅ Δ d = ∂ e ∂ c ⋅ ( ∂ c ∂ b ⋅ Δ b ) + ∂ e ∂ d ⋅ ( ∂ d ∂ b ⋅ Δ b ) \begin{aligned} \Delta e&=\frac{\partial e}{\partial c}\cdot\Delta c + \frac{\partial e}{\partial d}\cdot\Delta d \\ &=\frac{\partial e}{\partial c}\cdot(\frac{\partial c}{\partial b}\cdot \Delta b) + \frac{\partial e}{\partial d}\cdot(\frac{\partial d}{\partial b}\cdot \Delta b) \end{aligned} Δe=∂c∂e⋅Δc+∂d∂e⋅Δd=∂c∂e⋅(∂b∂c⋅Δb)+∂d∂e⋅(∂b∂d⋅Δb)
上式左右两侧同时除以 Δ b \Delta b Δb,则可以不严谨的得到:
∂ e ∂ b = ∂ e ∂ c ⋅ ∂ c ∂ b + ∂ e ∂ d ⋅ ∂ d ∂ b \frac{\partial e}{\partial b}=\frac{\partial e}{\partial c}\cdot\frac{\partial c}{\partial b} + \frac{\partial e}{\partial d}\cdot\frac{\partial d}{\partial b} ∂b∂e=∂c∂e⋅∂b∂c+∂d∂e⋅∂b∂d
仔细地观察这个式子,对比下图可以发现:式子的前半部分 ∂ e ∂ c ⋅ ∂ c ∂ b \frac{\partial e}{\partial c}\cdot\frac{\partial c}{\partial b} ∂c∂e⋅∂b∂c,正好是路线 A 的边上梯度值的乘积;同理,式子的后半部分 ∂ e ∂ d ⋅ ∂ d ∂ b \frac{\partial e}{\partial d}\cdot\frac{\partial d}{\partial b} ∂d∂e⋅∂b∂d,也是路线 B 的边上梯度值的乘积。
从这里例子,可以总结出计算图的最终定理。
定理(计算图反向传播机制):计算图上任意两点 x x x 和 y y y,且 y y y 在 x x x 之后,则 ∂ y ∂ x \frac{\partial y}{\partial x} ∂x∂y 的值为点 x x x 到点 y y y 上所有的不重复路径上的边值乘积的总和。
如果觉得这个定理有点难懂,那么其详细的计算过程如下:
- 找到所有从点 x x x 到 y y y 的不重复路径,记作集合 P \mathcal{P} P
- 对任意 p i ∈ P p_i \in \mathcal{P} pi∈P,计算路径 p i p_i pi 上所有边值乘积 M i M_i Mi
- 则 ∂ y ∂ x = ∑ p i ∈ P M i \frac{\partial y}{\partial x}=\sum^{p_i\in \mathcal{P}} M_i ∂x∂y=∑pi∈PMi
对应到这个例子,就是说:从路线 A,得到其路径上的乘积为 d d d;从路线 B,得到其路径上的乘积为 c c c。那么最终的结果为
∂ e ∂ b = d + c = a + 2 b + 1 \frac{\partial e}{\partial b}=d+c=a+2b+1 ∂b∂e=d+c=a+2b+1
由于在前向传播的过程中,所有的变量值我们都已经确定,所以算出 ∂ e ∂ b \frac{\partial e}{\partial b} ∂b∂e 的过程也就迎刃而解了。
有兴趣的同学可以试着验证其他的变量,看它们是否符合此规律。此外,笔者更推荐对其他的计算图检查,可以加深对这条规则的理解。
4. 代码实现
下面就是代码实现的部分咯,觉得麻烦的小伙伴可以跳过不看哦,但还是希望能给我的代码点个 star 收藏一下,十分感激!ヾ(≧▽≦*)o
Github 仓库:toy_computational_graph
4.1 Operation 定义
在个人的 200 行代码的实现中,大部分代码用于实现加减乘除的操作,事实上真正涉及反向传播的代码可能不足 30 行。下面是关于 Operation
的基类定义:
class Operation(ABC):
def __init__(self):
super().__init__()
# 反向传播过程中所需要的上下文 ctx
self.ctx: Optional[Dict] = None
# 记录输入的节点
self.inputs: List[Value] = []
def __call__(self, *args) -> Scalar:
self.inputs = list(args)
self.ctx = dict()
ret = self.forward(args, ctx=self.ctx)
ret.op = self
return ret
@staticmethod
@abstractmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
# 进行前向传播,并将反向传播的必要信息存放于 ctx 中
pass
@staticmethod
@abstractmethod
def backward(grad_output: float, ctx=None) -> List[float]:
# 反向传播的过程,返回每条输入边的累积梯度值
# grad_output 是从更加往后的节点传播到此处的累积梯度乘积
pass
可见,每个 Operation
其实就有以下功能:
- 记录输入节点
- 记录前向传播过程中产生的上下文
- 前向传播
- 反向传播
根据这个基类,最终派生出了加减乘除操作的实现类:
class AddOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
return Scalar(x.value + y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
return [grad_output, grad_output]
class SubOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
return Scalar(x.value - y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
return [grad_output, -grad_output]
class MulOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
ctx["x"] = x
ctx["y"] = y
return Scalar(x.value * y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
x, y = ctx["x"].value, ctx["y"].value
return [grad_output * y, grad_output * x]
class DivOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
assert y.value != 0, "Division by zero"
ctx["x"] = x
ctx["y"] = y
return Scalar(x.value / y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
x, y = ctx["x"].value, ctx["y"].value
return [grad_output / y, -x * grad_output / (y ** 2)]
代码简短而且清爽,适合读者学习。
4.2 数值类型
由于这个 codebase 体量不大,因此只允许使用 float
的包装类 Scalar
作为数值类型。其中 Value
类是 Scalar
类的基类,其定义并实现了反向传播的机制,如下:
class Value:
def __init__(self, op: Optional[Operation]):
self.op = op
self.grad = 0.
def zero_grad(self):
# 梯度清零,类似于 PyTorch
self.grad = 0.
def backward(self, grad_output: Optional[float] = None):
# 反向传播的实际执行,就是从此节点,迭代地把累积梯度乘积向更前的节点传播
# 等节点根据所传入的累积梯度乘积,更新完自身的梯度值后,就继续进行此过程
# 注:在保证 DAG 的前提下,此过程相等于遍历图上的所有不同路径
grad_output = grad_output if grad_output is not None else 1.
self.grad += grad_output
if self.op is not None:
prev_grads = self.op.backward(grad_output, ctx=self.op.ctx)
for input, prev_grad in zip(self.op.inputs, prev_grads):
input.backward(prev_grad)
至于 Scalar
类,只是实现了 __add__
之类的加减乘除的 Dunder 函数的封装类,大致如下:
class Scalar(Value):
def __init__(self, value: numbers.Number, op: Optional[Operation] = None):
super().__init__(op)
self._value = float(value)
def __add__(self, other):
from operation import AddOperation
if isinstance(other, Scalar):
op = AddOperation()
return op(self, other)
elif isinstance(other, numbers.Number):
op = AddOperation()
return op(self, Scalar(other))
else:
raise TypeError("unsupported type")
... ...
由于 Scalar
类并不包括太多实际操作,因此完整代码供有兴趣的读者自行查看。
4.3 运行结果
详细代码可以查看代码仓库中的 example.py
,结果如下:
example1:
x=10.0, y=2.0, r=x+2*y=14.0
=> x.grad=1.0, y.grad=2.0
example2:
x=10.0, r=x*x=100.0
=> x.grad=20.0
example3:
x=10.0, r=x*(x+1)=110.0
=> x.grad=21.0
example4:
x=8.0, y=4.0, r=x/y=2.0
=> x.grad=0.25, y.grad=-0.5
example5:
x=3.0, r=1/(x*x+1)=0.1
=> x.grad=-0.06
example6:
x=8.0, y=3.0, r=(x*x+1)/(y*y-1)=8.125
=> x.grad=2.0, y.grad=-6.09375
以上六个例子的运算结果均正确。
5. 杂谈
事实上,我这个 demo 和 PyTorch 一样,采用的是动态计算图的形式,即计算图是在运算的过程中实时产生。相反的,Tensorflow 就是采用静态计算图,其计算图需要在一开始就进行编译并固定。
相较于我这个毫无优化的 demo,PyTorch 对于计算图的优化则是出神入化。首先在这个计算图的迭代过程中,明显可以发现,不同路径之间的乘积是可以并行计算的。
同时,从计算图机制的定理中可以发现,由于各个路径上的梯度最终是相加起来的,因此并行下最好的实现方式就是将各个变量的梯度都初始化为 0,否则梯度相加后会出错。这也是为什么 PyTorch 训练时,会需要 zero_grad()
这一步。当然,笔者的实现中也仿效了这一设计。
6. 总结
本文从程序员的角度,总结出了计算图机制下的运行定理,并给出了约 200 行的代码实现,希望能够帮助所有正在入门机器学习的人。
如果您觉得本文有价值,还希望您能给我的文章点个赞、收藏和关注的三连,我们下期再见!ヾ( ̄▽ ̄)ByeBye
最后的最后,附上本文代码的 repo 地址:toy_computational_graph,希望读者能点几个 star 支持一下!