200 行代码,深入分析动态计算图的原理及实现

发布于:2023-02-18 ⋅ 阅读:(491) ⋅ 点赞:(0)

200 行代码,深入分析动态计算图的原理及实现

原文地址:CSDN 博客

代码实现:toy_computational_graph



1. 前言

机器学习这几年可是大红大紫,各行各业的人都往这里涌入,硬是在机器学习这一领域里挤出了一片人口红海。而在机器学习领域,神经网络由于自己下限低、上限高的特点,赢得了不少人的青睐。

在神经网络中,却有一件我们经常使用,经常耳闻,但又不太熟悉的东西——BP 算法。入门“炼丹”的小萌新往往会对这个一头雾水,久经沙场的老油条对这个也可能不了解细节。

我在查阅许多文章后,发现大多数文章对 BP 算法的介绍往往是点到为止,更深入者也就在数学公式推导层面止步,涉及到代码层面的博主鲜少,更很少提及 BP 算法在神经网络中的更广泛实现——计算图机制

于是,秉持着“科普”的原则,笔者就撰写了这篇有关于 BP 算法以及计算图原理的文章,并在其中以笔者自己的代码实现,详细地讲解计算图的工作机制,并最终与成熟的计算框架进行比较。


2. BP 算法

BP 算法,又名反向传播算法,是目前深度学习的理论基石。其原始论文于 1986 年由 D. Rumelhart 发表在 Nature 上1。在其论文中,就已经使用 MSE(Mean Square Error) 均方误差作为训练目标,并使用多层的 MLP 感知机作为模型,进行亲戚关系的分类。

BP算法原始论文

当前时代的神经网络,早已比当时的网络来的更加庞大,几百个万的模型参数比比皆是,GPT-3 甚至已经上千亿的模型,而其最基本的算法,却来自于 40 年前,让人感到不可思议。

对于 BP 算法的理解其实非常简单。假设神经网络的的损失是 L L L x \bm{x} x 是输入向量, W i j \bm{W}_{ij} Wij 是第 i i i 层的第 j j j 个参数,那么根据梯度下降的原理,我们需要得到 L L L W i j \bm{W}_{ij} Wij 偏微分值:
∇ W i j = ∂ L ∂ W i j \nabla\bm{W}_{ij}=\frac{\partial L}{\partial \bm{W}_{ij}} Wij=WijL
η \eta η 为学习率,则最终的参数更新算法为:
W i j t + 1 = W i j t − η ⋅ ∇ W i j \bm{W}_{ij}^{t+1}=\bm{W}_{ij}^{t}-\eta\cdot\nabla\bm{W}_{ij} Wijt+1=WijtηWij

然后问题来了:怎么计算 ∂ L ∂ W i j \frac{\partial L}{\partial \bm{W}_{ij}} WijL

许多的博文都对这个问题作出众多的解释,大部分人会选择使用数学推导的形式阐述,最终结果或许可能如下:
BP算法数学形式
这串花里胡哨的东西,对数学系的同学来说刚刚好,对笔者来说可不好。讲到底 BP 算法就是一个偏导数的链式法则应用,写这么复杂真的有用吗?
∂ y ∂ x 1 = ∂ y ∂ x n ⋅ ∂ y ∂ x n − 1 ⋅ ⋯ ⋅ ∂ x 2 ∂ x 1 \frac{\partial y}{\partial x_1}=\frac{\partial y}{\partial x_n}\cdot \frac{\partial y}{\partial x_{n-1}}\cdot\dots\cdot\frac{\partial x_2}{\partial x_1} x1y=xnyxn1yx1x2
看吧!如果我把上面这串链式法则的公式, y y y 换成 L L L x 1 x_1 x1 换为 W i j \bm{W}_{ij} Wij,剩下的 x i x_i xi 换为神经网络中的一些其他变量,不就把 BP 算法拆成了许多更小的偏导数的乘积吗?

对于 BP 算法的数学机理,了解到这已经足够。下一节,笔者将以程序员的角度,带大家看 BP 算法的另一个视角——计算图机制。


3. 计算图

本章中通过一个实际的例子,给出计算图的详细说明,并引出了计算图反向传播机制的定理。

3.1 计算图定义

计算图是描述计算过程的数据结构,而且通常是 DAG 图(有向无环图)。

在计算图中,每一个节点表示一个变量(值),每一条边表示数据的流动方向,并且每一条边的值被定义为边的首尾节点的偏导数值。例如:
计算图示例
这幅图表示以下的三个算式:
c = a + b d = b + 1 e = c × d \begin{aligned} c&=a+b\\ d&=b+1\\ e&=c\times d \end{aligned} cde=a+b=b+1=c×d
在这副计算图中,每个节点都表示着一个变量值,每条边表示数据的流动。在每条边上,笔者提前算出了每条边的末尾节点对起始节点的偏导数,例如边 (b,d) 的偏导数就是 ∂ d ∂ b = 1 \frac{\partial d}{\partial b}=1 bd=1

3.2 计算图机制

拥有计算图的定义后,下面来详细介绍一下计算图是如何对应 BP 算法的。

3.2.1 前向传播

对应于 BP 算法的前向传播(Forward Pass)过程,计算图的前向传播其实相同,就是把计算图的每个节点的值都计算出来

例如,在上面的示例图中,若设 a = 2 , b = 1 a=2,b=1 a=2,b=1,那么前向传播的过程就把其他的节点值都算出来:
c = a + b = 3 d = b + 1 = 2 e = c × d = 5 \begin{aligned} c&=a+b=3\\ d&=b+1=2\\ e&=c\times d=5 \end{aligned} cde=a+b=3=b+1=2=c×d=5
前向传播没有理解上的难点,大家一眼就能明白,而难点在于反向传播的过程中。

3.2.2 反向传播

在反向传播的机制中,笔者并不打算引入过于复杂的数学公式来证明,而是选择用更加浅显易懂的大白话,说明计算图在反向传播过程中的工作原理。

计算图示例

对于示例图中,如果想求 ∂ e ∂ b \frac{\partial e}{\partial b} be,该怎么办?首先,从节点 b b b 开始,可以发现 b b b 通过作用于 c c c d d d,进而对节点 e e e 造成了影响。

这个连环影响的现象表达成数学的形式,即为
Δ e = ∂ e ∂ c ⋅ Δ c + ∂ e ∂ d ⋅ Δ d = ∂ e ∂ c ⋅ ( ∂ c ∂ b ⋅ Δ b ) + ∂ e ∂ d ⋅ ( ∂ d ∂ b ⋅ Δ b ) \begin{aligned} \Delta e&=\frac{\partial e}{\partial c}\cdot\Delta c + \frac{\partial e}{\partial d}\cdot\Delta d \\ &=\frac{\partial e}{\partial c}\cdot(\frac{\partial c}{\partial b}\cdot \Delta b) + \frac{\partial e}{\partial d}\cdot(\frac{\partial d}{\partial b}\cdot \Delta b) \end{aligned} Δe=ceΔc+deΔd=ce(bcΔb)+de(bdΔb)
上式左右两侧同时除以 Δ b \Delta b Δb,则可以不严谨的得到:
∂ e ∂ b = ∂ e ∂ c ⋅ ∂ c ∂ b + ∂ e ∂ d ⋅ ∂ d ∂ b \frac{\partial e}{\partial b}=\frac{\partial e}{\partial c}\cdot\frac{\partial c}{\partial b} + \frac{\partial e}{\partial d}\cdot\frac{\partial d}{\partial b} be=cebc+debd

仔细地观察这个式子,对比下图可以发现:式子的前半部分 ∂ e ∂ c ⋅ ∂ c ∂ b \frac{\partial e}{\partial c}\cdot\frac{\partial c}{\partial b} cebc,正好是路线 A 的边上梯度值的乘积;同理,式子的后半部分 ∂ e ∂ d ⋅ ∂ d ∂ b \frac{\partial e}{\partial d}\cdot\frac{\partial d}{\partial b} debd,也是路线 B 的边上梯度值的乘积。

在这里插入图片描述
从这里例子,可以总结出计算图的最终定理。

定理(计算图反向传播机制):计算图上任意两点 x x x y y y,且 y y y x x x 之后,则 ∂ y ∂ x \frac{\partial y}{\partial x} xy 的值为点 x x x 到点 y y y 上所有的不重复路径上的边值乘积的总和。

如果觉得这个定理有点难懂,那么其详细的计算过程如下:

  1. 找到所有从点 x x x y y y 的不重复路径,记作集合 P \mathcal{P} P
  2. 对任意 p i ∈ P p_i \in \mathcal{P} piP,计算路径 p i p_i pi 上所有边值乘积 M i M_i Mi
  3. ∂ y ∂ x = ∑ p i ∈ P M i \frac{\partial y}{\partial x}=\sum^{p_i\in \mathcal{P}} M_i xy=piPMi

对应到这个例子,就是说:从路线 A,得到其路径上的乘积为 d d d;从路线 B,得到其路径上的乘积为 c c c。那么最终的结果为
∂ e ∂ b = d + c = a + 2 b + 1 \frac{\partial e}{\partial b}=d+c=a+2b+1 be=d+c=a+2b+1

由于在前向传播的过程中,所有的变量值我们都已经确定,所以算出 ∂ e ∂ b \frac{\partial e}{\partial b} be 的过程也就迎刃而解了。

有兴趣的同学可以试着验证其他的变量,看它们是否符合此规律。此外,笔者更推荐对其他的计算图检查,可以加深对这条规则的理解。


4. 代码实现

下面就是代码实现的部分咯,觉得麻烦的小伙伴可以跳过不看哦,但还是希望能给我的代码点个 star 收藏一下,十分感激!ヾ(≧▽≦*)o

Github 仓库:toy_computational_graph

4.1 Operation 定义

在个人的 200 行代码的实现中,大部分代码用于实现加减乘除的操作,事实上真正涉及反向传播的代码可能不足 30 行。下面是关于 Operation 的基类定义:

class Operation(ABC):
    def __init__(self):
        super().__init__()
        # 反向传播过程中所需要的上下文 ctx
        self.ctx: Optional[Dict] = None
        # 记录输入的节点
        self.inputs: List[Value] = []

    def __call__(self, *args) -> Scalar:
        self.inputs = list(args)
        self.ctx = dict()
        ret = self.forward(args, ctx=self.ctx)
        ret.op = self
        return ret

    @staticmethod
    @abstractmethod
    def forward(inputs: List[Scalar], ctx=None) -> Scalar:
    	# 进行前向传播,并将反向传播的必要信息存放于 ctx 中
        pass

    @staticmethod
    @abstractmethod
    def backward(grad_output: float, ctx=None) -> List[float]:
    	# 反向传播的过程,返回每条输入边的累积梯度值
    	# grad_output 是从更加往后的节点传播到此处的累积梯度乘积
        pass

可见,每个 Operation 其实就有以下功能:

  • 记录输入节点
  • 记录前向传播过程中产生的上下文
  • 前向传播
  • 反向传播

根据这个基类,最终派生出了加减乘除操作的实现类:

class AddOperation(Operation):
    @staticmethod
    def forward(inputs: List[Scalar], ctx=None) -> Scalar:
        x, y = inputs
        return Scalar(x.value + y.value)

    @staticmethod
    def backward(grad_output: float, ctx=None) -> List[float]:
        return [grad_output, grad_output]


class SubOperation(Operation):
    @staticmethod
    def forward(inputs: List[Scalar], ctx=None) -> Scalar:
        x, y = inputs
        return Scalar(x.value - y.value)

    @staticmethod
    def backward(grad_output: float, ctx=None) -> List[float]:
        return [grad_output, -grad_output]


class MulOperation(Operation):
    @staticmethod
    def forward(inputs: List[Scalar], ctx=None) -> Scalar:
        x, y = inputs
        ctx["x"] = x
        ctx["y"] = y
        return Scalar(x.value * y.value)

    @staticmethod
    def backward(grad_output: float, ctx=None) -> List[float]:
        x, y = ctx["x"].value, ctx["y"].value
        return [grad_output * y, grad_output * x]


class DivOperation(Operation):
    @staticmethod
    def forward(inputs: List[Scalar], ctx=None) -> Scalar:
        x, y = inputs
        assert y.value != 0, "Division by zero"
        ctx["x"] = x
        ctx["y"] = y
        return Scalar(x.value / y.value)

    @staticmethod
    def backward(grad_output: float, ctx=None) -> List[float]:
        x, y = ctx["x"].value, ctx["y"].value
        return [grad_output / y, -x * grad_output / (y ** 2)]

代码简短而且清爽,适合读者学习。

4.2 数值类型

由于这个 codebase 体量不大,因此只允许使用 float 的包装类 Scalar 作为数值类型。其中 Value 类是 Scalar 类的基类,其定义并实现了反向传播的机制,如下:

class Value:
    def __init__(self, op: Optional[Operation]):
        self.op = op
        self.grad = 0.

    def zero_grad(self):
    	# 梯度清零,类似于 PyTorch
        self.grad = 0.

    def backward(self, grad_output: Optional[float] = None):
    	# 反向传播的实际执行,就是从此节点,迭代地把累积梯度乘积向更前的节点传播
    	# 等节点根据所传入的累积梯度乘积,更新完自身的梯度值后,就继续进行此过程
    	# 注:在保证 DAG 的前提下,此过程相等于遍历图上的所有不同路径
        grad_output = grad_output if grad_output is not None else 1.
        self.grad += grad_output
        if self.op is not None:
            prev_grads = self.op.backward(grad_output, ctx=self.op.ctx)
            for input, prev_grad in zip(self.op.inputs, prev_grads):
                input.backward(prev_grad)

至于 Scalar 类,只是实现了 __add__ 之类的加减乘除的 Dunder 函数的封装类,大致如下:

class Scalar(Value):
    def __init__(self, value: numbers.Number, op: Optional[Operation] = None):
        super().__init__(op)
        self._value = float(value)

    def __add__(self, other):
        from operation import AddOperation
        if isinstance(other, Scalar):
            op = AddOperation()
            return op(self, other)
        elif isinstance(other, numbers.Number):
            op = AddOperation()
            return op(self, Scalar(other))
        else:
            raise TypeError("unsupported type")

	... ...

由于 Scalar 类并不包括太多实际操作,因此完整代码供有兴趣的读者自行查看。

4.3 运行结果

详细代码可以查看代码仓库中的 example.py,结果如下:

example1:
x=10.0, y=2.0, r=x+2*y=14.0
=> x.grad=1.0, y.grad=2.0

example2:
x=10.0, r=x*x=100.0
=> x.grad=20.0

example3:
x=10.0, r=x*(x+1)=110.0
=> x.grad=21.0

example4:
x=8.0, y=4.0, r=x/y=2.0
=> x.grad=0.25, y.grad=-0.5

example5:
x=3.0, r=1/(x*x+1)=0.1
=> x.grad=-0.06

example6:
x=8.0, y=3.0, r=(x*x+1)/(y*y-1)=8.125
=> x.grad=2.0, y.grad=-6.09375

以上六个例子的运算结果均正确。


5. 杂谈

事实上,我这个 demo 和 PyTorch 一样,采用的是动态计算图的形式,即计算图是在运算的过程中实时产生。相反的,Tensorflow 就是采用静态计算图,其计算图需要在一开始就进行编译并固定。

相较于我这个毫无优化的 demo,PyTorch 对于计算图的优化则是出神入化。首先在这个计算图的迭代过程中,明显可以发现,不同路径之间的乘积是可以并行计算的。

同时,从计算图机制的定理中可以发现,由于各个路径上的梯度最终是相加起来的,因此并行下最好的实现方式就是将各个变量的梯度都初始化为 0,否则梯度相加后会出错。这也是为什么 PyTorch 训练时,会需要 zero_grad() 这一步。当然,笔者的实现中也仿效了这一设计。


6. 总结

本文从程序员的角度,总结出了计算图机制下的运行定理,并给出了约 200 行的代码实现,希望能够帮助所有正在入门机器学习的人。

如果您觉得本文有价值,还希望您能给我的文章点个赞、收藏和关注的三连,我们下期再见!ヾ( ̄▽ ̄)ByeBye

最后的最后,附上本文代码的 repo 地址:toy_computational_graph,希望读者能点几个 star 支持一下!


  1. Learning representations by back-propagating errors ↩︎

本文含有隐藏内容,请 开通VIP 后查看