课程主页:CS224W | Home
课程视频链接:斯坦福CS224W《图机器学习》课程(2021) by Jure Leskovec
1 前言
上一篇文章介绍了如何对GNN模型进行图特征增强和图结构增强,本篇文章将会继续介绍如何来训练一个GNN模型:
2 GNN训练途径
GNN训练途径大致包含以下过程:
输入图数据→构建GNN模型→输出节点嵌入→预测头(在不同粒度的任务下,将节点嵌入转换为最终需要的预测向量)→得到预测向量和标签→选取损失函数 & 选取评估指标
(前三个部分已经学习过,接下来将学习剩下的部分)
2.1 预测头(Prediction Head)
不同任务级别下的预测头不同:节点级别,边级别,图级别
2.1.1 节点级别的预测头
节点级别:直接用节点嵌入进行预测
GNN得到的节点嵌入:(d维)
预测任务向量:k维
- 分类任务:在k个类别之间做分类
- 回归任务:在k个目标(节点特征)上做回归
我们通过将节点嵌入映射到预测空间,即d维到k维:
注:表示预测标签,
表示真实标签
2.1.2 边级别的预测头
边级别:用一对节点嵌入进行预测
对于 的可选方法有:
1. Concatenation + Linear(拼接 + 线性)
(这种方法在讲GAT的时候介绍过,注意力机制也可以才用这种方法将节点间传递的信息转换为注意力系数e)
这时,将2d维的嵌入(d维和d维拼接之后)映射到k维的输出 。
2. 点积
这种方法只能应用于1维的预测任务(因为点积输出结果只有一维),例如链接预测任务(预测边是否存在)
如果要应用到k维的预测任务中:跟GAT中的多头注意力机制类似,多算几组然后连接
式子中每一维的预测值都包含可训练的参数
2.1.3 图级别的预测头
图级别:用图中所有节点的嵌入向量来做预测
这时的头函数类似于GNN单个层中的聚合函数 ,都是将若干嵌入聚合为一个嵌入。
对的可选方法有:
如果想比较不同大小的图,Mean方法可能比较好(因为结果不受节点数量的影响);如果关心图的节点数或图的大小,Sum方法可能比较好。
Global pooling面临的问题
总之,以上三个方法在小图上的表现都很好,但是Global pooling与可能会面临丢失信息的问题。
举例:使用一维节点嵌入(每个节点嵌入仅是一个值),假设
显然两个图的节点嵌入差别很大,图结构很不相同。但是经过Global sum pooling后:
这两个图有一样的预测值,无法区分两个图。
解决方法:Hierarchical Global Pooling(分层聚合节点嵌入)
举例:使用进行聚合,先分别聚合前两个节点和后三个节点的嵌入,然后再聚合这两个嵌入,如下图
这样我们就可以区分两个图了
在这种分层的Global Pooling中,我们如何确定聚合内容的先后以及如何进行分层次的聚合呢?
在实际应用中,尤其是在图分类的任务上,我们通常使用DiffPool的想法:
DiffPool原文传送门:Hierarchical Graph Representation Learning with Differentiable Pooling
大致来说,DiffPool的思路是同时使用两个GNN模型,一个GNN模型用于正常计算节点嵌入,而另一个GNN则用于学习该节点属于哪一个聚类,然后再将每一个聚类看作一个子图单独进行pool。之后,将每一类得到的预测值作为一个新的节点,同时保留每个类之间的连边,产生一个新的图。接着在新的图上再进行聚类和pool,重复这一过程,直至得到最终的预测值。

2.2 预测值和标签值(Predictions & Labels)
2.2.1 有监督学习 vs. 无监督学习
- 有监督学习:标签来自图的外部,算法基于带标签的数据进行学习,同时数据提供答案,算法可利用该答案来评估其在训练数据方面的准确性,比如分类问题;
- 无监督学习:信号来自图的本身,使用的是无标签的数据,算法需要自行提取特征和规律来理解这些数据,比如聚类、链接预测,都不需要任何外部信息;
- 有时两种情况区分比较模糊,在无监督学习任务中也可能有“有监督任务”,比如训练GNN来预测节点聚类系数。
- 还有一种特殊的“无监督学习”被称为“自监督学习”
Self-supervised:自监督学习是指直接从大规模的无监督数据中挖掘自身监督信息来进行监督学习和训练的一种机器学习方法(可以看成是无监督学习的一种特殊情况),自监督学习需要标签,不过这个标签不来自于人工标注,而是来自于数据本身
2.2.2 有监督学习的标签
有监督学习的标签通常来自一些具体的情况中,例如:
- 节点级别——引用网络中,节点(论文)属于哪一学科
- 边级别——交易网络中,边(交易)是否有欺诈行为
- 图级别——图(分子图)是药的概率
2.2.3 无监督学习的信号
在没有外部标签时,可以使用“自监督”的思想,在图自身寻找信号来作为有监督学习的标签。举例来说,一个GNN模型可以预测:
- 节点级别:节点统计量(如聚类系数clustering coefficient, PageRank等)
- 边级别:链接预测(预测两节点间的隐藏边)
- 图级别:图统计量(如预测两个图是否同构)
这些任务都是不需要外部标签的。
2.3 损失函数
损失函数用于衡量预测值和标签值之间的差异,优化损失函数的过程就是反向传播优化神经网络参数的过程。
2.3.1 Settings
假设我们有N个数据点,每个数据点的类型可以是节点级别 / 边级别 / 图级别,其预测值和标签分别表示如下:
2.3.2 分类 / 回归
- 分类任务:节点的标签
是离散数值,例如节点分类任务的标签是节点的类别
- 回归任务:节点的标签
是连续数值,例如预测分子图是药的概率具体是多少
两种任务都能用GNN完成,其区别主要在于损失函数和评估方法不同。
2.3.3 分类任务的损失函数——交叉熵
已在第六讲中介绍过,此处不再赘述。
2.3.4 回归任务的损失函数——均方误差(L2损失)
Jure介绍这个MSE没有取均值,好像不是MSE,概念区分: 区分混淆概念之L2范数,L2范数损失,L2损失,均方误差
2.4 评估指标
2.4.1 回归任务的评估指标
均方根误差RMSE:
- 平均绝对误差MAE(L1损失):
2.4.2 分类任务的评估指标
- 二分类问题评价指标
- Accuracy:准确率(分类正确的观测占所有观测的比例)
- Precision:精度(预测为正的样本中真的为正(预测正确)的样本所占比例)
- Recall:召回率(真的为正的样本中预测为正(预测正确)的样本所占比例)
- F1-Score:F1分数(Precision和Recall的调和平均值,信息检索、文本挖掘等领域常用)
ROC(Receiver operating characteristic)曲线:
- 横轴FPR:FPR越大,预测正类中实际负类越多
- 纵轴TPR:TPR越大,预测正类中实际正类越多
理想目标:TPR=1,FPR=0,即图中(0,1)点,故ROC曲线越靠拢(0,1)点,越偏离45度对角线越好
ROC曲线详解:浅显易懂介绍ROC曲线 - 知乎
ROC AUC:(Area under the ROC Curve:ROC曲线下的面积)
随机抽取一个正样本和一个负样本,正样本被识别为正样本的概率比负样本被识别为正样本的概率高的概率。
人话:比较ROC曲线下面积做为二分类器优劣的判断标准,ROC曲线下面积越大,正确率越高。
ROC AUC详解:[概念回顾]ROC-AUC到底是个什么鬼 - 知乎
注:AUC=0.5叫随机分类器,AUC=1叫完美分类器.
2.5 拆分数据集
2.5.1 拆分方式
1. 固定拆分(Fixed split):只拆分一次数据集(此后一直使用这种拆分方式)
- 训练集:用来优化GNN模型的参数
- 验证集:用于调整超参数(用于控制模型行为的参数,这些参数不是通过模型本身学习而来的,例如网络层数、网络节点数、学习率等,需要多次使用)
- 测试集:测试GNN模型表现的数据集(仅使用一次)
传送门:验证集与测试集的区别
2. 随机拆分(Random split):随机将数据集划分为训练集 / 验证集 / 测试集,最后使用不同随机拆分后计算结果的平均值
2.5.2 拆分图结构数据集的特殊性
在拆分时,我们希望三部分数据集之间相互独立,没有交叉。但由于图结构数据的特殊性,如果直接像普通数据一样拆分,我们有时不能保证三部分数据集之间相互独立。
例如:在节点分类任务中,每个节点就是一个数据,这时,测试集里的节点可能与训练集里的节点有边相连,在消息传递的过程中就会互相影响,导致信息泄露。

2.5.3 解决方法1:Transductive setting
使输入的整个图在训练集 / 验证集 / 测试集中共享(即测试集、验证集、训练集在同一个图上,整个数据集由一张图构成),仅拆分节点标签。
具体来说:
- 在训练过程中,我们使用整个图来计算节点嵌入,但仅使用节点1、2的标签进行训练;
- 在验证过程中,我们使用整个图来计算节点嵌入,但仅使用节点3、4的标签进行评估。
这种方法仅适用于节点 / 边预测任务。
2.5.4 解决方法2:Inductive setting
去掉拆分后各个数据集之间的连边,得到多个相互独立的子图。(测试集、验证集、训练集分别在不同图上,整个数据集由多个图构成)这样,不同数据集之间的节点就不会互相影响。
具体来说:
- 在训练过程中,我们使用节点1、2构成的子图来计算嵌入,使用节点1、2的标签进行训练;
- 在验证过程中,我们使用节点3、4构成的子图来计算嵌入,使用节点3、4的标签进行训练。
这种方法适用于节点 / 边 / 图预测任务。
2.5.5 举例
1. 节点分类任务
Transductive:所有拆分后的数据集共享全图结构,但只能独享专属的节点的标签
Inductive:将三个图分别作为训练集 / 验证集 / 测试集,如果没有多个图就将一个图拆分成3部分,并去除各部分之间连接的边
2. 图分类任务
因为我们有独立的图,所以不需要考虑数据集之间会产生交叉的情况。仅需要使用Inductive setting的拆分方法,简单地拆分为训练集 / 验证集 / 测试集即可。
3. 链接预测任务
目标:预测缺失的边
这是个无监督/ 自监督的任务,需要自行创建标签,自行拆分数据集。
具体来说,我们还需要隐藏一些边,然后让GNN预测这些边是否存在。
在做法上,我们要将边划分两次。
第一步:将原始图中的边分为两个类型:message edges(用于GNN传递信息)和supervision edges(用于训练目标函数)。只留下message edges,不将supervision edges传入GNN。
第二步:将边划分为训练集 / 验证集 / 测试集。
第二步的划分选择又分为两种:
选择一:Inductive link prediction split
假设我们有一个包含三个图的数据集。 每个拆分后的split即一个独立的图,且每个split里的边按照第一步分为message edges和supervision edges。
选择二:Transductive link prediction split(通常是链接预测默认的设置方式)
假设我们的数据集仅是一个图:,Transductive的思想是让所有splits都共享整个图,而这时的边既是图结构又是标签,所以我们需要在验证和测试时留下验证集 / 测试集中的validation / test edges;在训练时留下训练集中的supervision edges(显然)
精确操作:
- 训练阶段:用 training message edges 预测 training supervision edges;
- 验证阶段:用 training message edges 和 training supervision edges 预测 validation edges;
- 测试阶段:用 training message edges 和 training supervision edges 和 validation edges 预测 test edges。
上述过程是个链接越来越多,图变得越来越稠密的过程。这是因为在训练之后,supervision edges就被GNN感知到了,所以一个理想的模型在验证时应该用 supervision edges 来进行消息传递,在测试时也应该用 supervision edges 来进行消息传递。
总之,链接预测的设置十分复杂, 不同论文中有不同的链接预测设置方式。
3 总结
至此,GNN的模型的整个训练途径介绍完毕:输入图数据→构建GNN模型→输出节点嵌入→预测头(在不同粒度的任务下,将节点嵌入转换为最终需要的预测向量)→得到预测向量和标签→选取损失函数 & 选取评估指标→拆分数据集(训练集 / 验证集 / 测试集)
与此同时,第七讲和第八讲所介绍的GNN Design Space的内容也全部结束:GNN单个层的设置(信息转换+信息聚合:GCN、GraphSAGE、GAT)→GNN层之间的连接方式(过平滑问题以及skip connections)→GNN图增强(特征增强和结构增强)→训练目标函数
4 参考文献
http://web.stanford.edu/class/cs224w/slides/08-GNN-application
cs224w(图机器学习)2021冬季课程学习笔记10 Applications of Graph Neural Networks_诸神缄默不语的博客