当客户端收到服务器下发的全局模型后,它的训练过程与传统的集中式训练在算法上非常相似,但有几个关键区别。其核心流程如下图所示:
1、客户端本地训练的核心步骤
假设客户端是一个手机,它收到了全局模型 W_global
。
1. 加载模型与数据:
手机将收到的全局模型参数
W_global
加载到本地的模型架构中。现在,手机上的这个模型和服务器上的模型一模一样。手机从自己的存储中加载本地私有数据(例如,你的照片、打字习惯等)。这些数据永远不会被发送出去。
2. 本地迭代训练 (Local Epochs > 1):
这是与简单一步梯度下降最关键的差别。客户端不会只看一次数据就更新,而是会进行多次(通常Epochs > 1)完整的本地训练。
Local Epoch (本地轮次):指本地数据集被完整遍历一次的次数。例如,如果本地有1000张图片,一个Epoch就是模型看完了这1000张图片。
Batch (批):为了高效,数据会分成多个小批(Batches)。例如,1000张图片,如果批量大小(Batch Size)是100,那么1个Epoch就包含10个Batches或10次梯度更新。
对于每一个Batch,客户端会执行以下标准步骤:
a. 前向传播 (Forward Pass):
输入一个Batch的本地数据到模型中,得到模型的预测结果。
b. 计算损失 (Loss Calculation):
将模型的预测结果与真实的标签进行比较,通过一个损失函数(如交叉熵、均方误差等)计算出“误差”有多大。
c. 反向传播 (Backward Pass / Backpropagation):
计算损失函数相对于每一个模型参数(权重)的梯度。梯度指明了“为了减小误差,各个参数应该朝哪个方向、以多大的幅度调整”。
d. 模型更新 (Model Update - SGD):
使用优化器(最常见的是随机梯度下降SGD)根据计算出的梯度来更新本地的模型参数。
更新公式:
这个步骤会在当前客户端的本地数据上重复多次(由Local Epochs和Batch Size决定)。
3. 训练完成:
在完成了所有指定的本地训练轮次(Local Epochs)后,初始的
W_global
已经被更新成了一个新的、更适合本地数据分布的模型,我们称之为 W_local。
4. 计算更新量 (Optional but Important):
有些实现中,客户端会直接上传整个
W_local
。但是也存在一定的风险:即使不拿到原始数据,攻击者(包括好奇的服务器)也可以通过分析模型参数来推断用户的隐私信息。可能会遭遇的攻击有:
(1) 模型逆向攻击 (Model Inversion Attacks):通过分析模型参数,攻击者有可能重构出训练数据的特征。例如,针对图像识别模型,可能生成出与训练图片相似的模糊图像。
(2)成员推断攻击 (Membership Inference Attacks):攻击者通过分析模型对某个数据点的响应(置信度等),判断该数据点是否属于模型的训练集。例如,判断某人的医疗记录是否被用于训练一个疾病诊断模型。
(3)属性推断攻击 (Property Inference Attacks):攻击者可以推断出训练数据集的全局属性。例如,通过分析多个手机输入法模型的更新,推断出某个特定人群的用词习惯或兴趣偏好。
而另一种更常见的做法是计算更新量(Update):
ΔW = W_local - W_global
上传
ΔW
有时在通信上更高效,同时也为后续的隐私保护技术(如差分隐私)提供了便利。
需要注意的是:客户端上传局部参数还是ΔW需要绝对统一,并且服务器必须预先知道客户端上传的是什么。 这依赖于预先定义好的、所有参与者都必须严格遵守的通信协议。
1. 为什么需要统一?通信协议的作用
联邦学习不是一个随意的过程,而是一个由中央服务器严格 orchestrated(编排)的、有状态的迭代过程。为了保证所有客户端和服务器能正确协同工作,它们必须遵循一个预先共同定义好的 “游戏规则”,也就是通信协议。
这个协议会明确规定每一轮通信中:
下行(Server -> Client):发送什么(例如:当前的全局模型
W_global
,本轮训练的配置如本地轮次E
、学习率η
等)。上行(Client -> Server):返回什么(例如:是完整的
W_local
还是更新量ΔW
)。
所有客户端都必须按照同一套规则行事。服务器在聚合时,假设所有客户端上传的都是同一种类型的数据。如果有的客户端传 ΔW
,有的传 W_local
,服务器将无法进行正确的聚合计算,会导致整个训练过程完全失败。
2. 服务器怎么知道?协议约定的具体内容
服务器知道客户端上传的是什么,是因为这是服务器自己要求的。
在每一轮通信开始时,服务器在下发给客户端的消息中,就已经包含了明确的指令,告诉客户端“你需要如何训练,以及完成后上传什么”。这个过程是这样的:
服务器下发任务包:
服务器发送给客户端的不仅仅是一个模型文件W_global
,而是一个结构化的任务消息(通常是一个配置文件或特定格式的数据包)。这个消息包可能包含:model_weights
: 当前的全局模型参数W_global
。training_hyperparameters
: 本地训练的配置,如local_epochs
,batch_size
,learning_rate
。upload_type
: 一个关键的指令字段,明确指定客户端上传的数据类型。如果这个字段的值是
"full_weights"
,客户端就知道它需要上传完整的W_local
。如果这个字段的值是
"weight_update"
,客户端就知道它需要计算并上传ΔW = W_local - W_global
。
客户端按指令执行:
客户端收到这个任务包后,会解析这些配置。它加载
model_weights
来初始化本地模型。它根据
training_hyperparameters
进行本地训练。训练完成后,它根据
upload_type
的指示,准备相应的数据并上传。
服务器按预期聚合:
因为服务器是自己发出的指令,所以它非常清楚地知道即将收到的每个更新是什么格式。它会用对应的方式进行聚合。如果要求上传
ΔW
,聚合公式可能是:
(其中
η
是一个全局学习率,ΔW_avg
是加权平均后的更新)如果要求上传
W_local
,聚合公式就是标准的加权平均:
2、关键特点与为什么重要
主要计算在本地 (Computation at the Edge):最耗计算资源的训练过程完全在客户端设备上完成,充分利用了边缘计算能力,减轻了服务器的负担。
通信效率 (Communication Efficiency):
这是FedAvg算法的精髓。通过进行多次本地迭代(Local Epochs > 1),客户端在本地做了大量“学习”工作,最终只将学习成果(一次更新)上传。
这极大地减少了服务器与客户端之间的通信轮数,使得在带宽受限的网络(如移动网络)中进行大规模协作学习成为可能。如果每次梯度更新都通信,成本是无法承受的。
应对非独立同分布数据 (Handling Non-IID Data):每个客户端的数据都是独特的(你的手机图片和我的完全不同)。本地训练允许模型在每个客户端的独特数据分布上进行微调,从而捕捉到更丰富的数据特征。
隐私保护 (Privacy Preservation):再次强调,所有上述过程都在设备本地完成。只有最终的模型参数(或更新)被加密上传,原始数据纹丝不动。
3、例子解释
把它想象成学生在家自学:
老师(服务器)发下一本标准教材(
W_global
)。学生(客户端)回家后,不是只做一道题就告诉老师,而是结合自己的练习册(本地数据)反复练习、复习了好几遍(Local Epochs),在教材上做了很多笔记和修正。
最后,学生不是把整本练习册(原始数据)交给老师,而是把自己的修改建议和笔记总结(
ΔW
或W_local
)提交上去。
总结来说,客户端的训练就是一个标准的、基于梯度下降的模型优化过程,但它发生在本地,并且在多次迭代后才进行一次通信,这是联邦学习高效且能保护隐私的关键所在。