大模型LLM:为什么简单的乘法ChatGPT会算错?

发布于:2024-05-03 ⋅ 阅读:(21) ⋅ 点赞:(0)

首先“心算”三位整数乘法不管对人类还是对模型来说都不简单的。如果使用CoT的方式就类似于“笔算”,如果使用编程的方式就类似于人拿着计算器算。我将问题更精确一点地表述为“模型如何在心算多位整数乘法上接近或超过人的水平?”

这个问题困扰了我很久,简单乘法是推理能力的一种体现,如果可以解决,那大模型的整体性能应该还能大幅提升。于是我做了很多实验来寻找GPT连简单乘法都难以解决的技术层面原因。直到我借助CRNN里的CTCLoss来解码不定长序列,而不是通过TransformerDecoder以自回归的方式生成,这个问题才得到了很好的解决。下面我从技术层面上分析下可能的原因以及如何解决。

首先让我们看看GPT是如何训练的。 假设有足够多的样本,形式如下:

56*123=6888
122*222=27084
777*512=397824
612*324=198288
243*753=182979
896*267=239232
368*12=4416
......

GPT的原理是每次只看前n个字符,预测后续的1个字符 以56*123=6888为例,每一次模型迭代,等同于10次小的迭代的相加:

1、 输入5,预测6
2、 输入56,预测*
3、 输入56*,预测1
4、 输入56*1,预测2
5、 输入56*12,预测3
6、 输入56*123,预测=
7、 输入56*123=,预测6
8、 输入56*123=6,预测8
9、 输入56*123=68,预测8
10、输入56*123=688,预测8

细心的读者应该很快能发现,迭代1~迭代5是几乎无效的,因为对于等号前面的数字只能预测出是个0~9的字符,而没法预测出具体的值。实验表明在模型迭代过程中,屏蔽掉迭代1~迭代5,可以加快模型的收敛速度。不过虽然加快了收敛速度,但模型仍很难达到较高精度,特别是预测结果中靠近中间的数字,例如896*267,结果“239232”中间的“9”、“2”两个数字更难预测。显然还有更深层的原因我们没有发现。

那是不是因为GPT每一次小的迭代“只看前n个字符,预测后续的1个字符”导致的?也就是说每一次小的迭代都只告诉模型答案中的一个字符,而不告诉模型完整的答案。可是如果在每一次小的迭代中不告诉模型完整的答案,那么即使神经网络本身的推理能力再强,每一次小的跌代也很难学习到足够多的信息。打个比方,一个数学老师这么教他的学生:

同学们,今天我教大家2位数乘法,
11*11的结果中第一位是1
21*11的结果中第一位是2
21*21的结果中第一位是4
21*32的结果中第一位是6
32*68的结果中第一位是2
......
请同学们好好学习,学会了第一位我再教你们第二位,学会了第二位我再教你们第三位...

试想一下如果老师真这么教,那学生更多的是依靠记忆能力,而不是推理能力。有读者可能会问在迭代10中完整的答案不是出现了吗?可惜的是神经网络只会学习输入与目标之间的映射关系,而不会从输入中学到任何信息。

所以我先假设“Transformer的推理能力是足够的,之所以学不会是因为老师教的方式不对”,那接下来该如何验证该假设为真呢?由于我对CRNN比较熟悉,CRNN可以预测不定长序列,并且每一次迭代,都会告诉模型完整的答案。于是我将CRNN模型进行了一定的改造来学习3位数乘法,模型大致结构如下: (B, T) -- nn.Embedding --> (B, T, C) -- TransformerEncoder层 --> (B, T, C) -- MLP层 --> (B, T, C o u t C_{out} )

训练样本对的形式如下:

--------56*123=    6888
-------122*222=    27084
-------777*512=    397824
-------612*324=    198288
-------243*753=    182979
-------896*267=    239232
--------368*12=    4416
......

使用CTCLoss进行梯度更新。经过单卡20分钟左右的训练,训练损失基本接近0,测试准确率约99.994%。

为了进一步验证假设,我设计了对比实验:保持Transformer层不变,将模型改成多分类模型,共10个类别代表数字0~9。训练样本对的形式如下:

--------56*123=        6
-------56*123=6        8
------56*123=68        8
-----56*123=688        8
-------122*222=        2
------122*222=2        7
-----122*222=27        0
----122*222=270        8
---122*222=2708        4
......

原始的一条样本被拆成了多条子样本,打包在同一个batch里,模拟GPT的训练方式。使用交叉熵损失函数进行梯度更新。实验结果:模型收敛很慢,训练了很长时间,测试准确率都不是很高。

结论:Transformer的推理能力是足够强大的,基于Transformer的GPT作为一个学生推理能力也是足够的,但由于老师的教学方法不对,导致连3位整数乘法都很难学会。

相关实验细节:

补充一下,既然改用CTCLoss后可以学会连GPT都很难学会的3位整数乘法,那是否意味着CTCLoss可以广泛应用于大语言模型呢?很遗憾,答案暂时是否定的,因为GPT可以更好地解决模糊性问题,比如同一个问题,有多种回答都是正确的,这种情况不适合CTCLoss。


2024.04.12更新

这几天关注到Cohere公司开源的大模型Command R+,我简单测试了一下,发现Command R+对三位/四位整数乘法的精度相当高(不使用CoT,编程等任何辅助方式)。这让我意识到GPT也是能够将简单乘法训练到几乎100%测试准确率的,我之前训练不出来是因为自己对TransformerDecoder的认识不够深入。

受 的启发,我重新设计了基于TransformerDecoder Only的GPT模型方案,通过合成大量的3位整数乘法训练样本,使用CrossEntropyLoss损失函数进行梯度更新,单卡训练约5个小时后,训练损失几乎为0,测试准确率约99.991%。随机生成10万个样本进行测试,错了5个:

698 * 716 != 509768, 499768(expected)
949 * 959 != 900091, 910091(expected)
616 * 13 != 7008, 8008(expected)
95 * 63 != 6085, 5985(expected)
237 * 38 != 8006, 9006(expected)

结论:使用足够多的样本,训练足够多的轮数,GPT也是可以在3位整数乘法上达到99.99%以上准确率的,相对于TransformerEncoder+CTCLoss方案,TransformerDecoder+CrossEntropyLoss方案收敛更慢,需要数倍的训练时间。但现实情况是,大模型的训练成本太高,一般也就训练一轮,自然在简单乘法上达不到很高的准确率。

相关实验细节:


由于本人水平有限,以上观点难免存在错误,若有错误,欢迎在评论中指出。


网站公告

今日签到

点亮在社区的每一天
去签到