本发明实施例涉及深度学习,尤其涉及一种面向分裂联邦学习的客户端选取方法和装置。
背景技术:
1、为了在边缘设备资源受限,系统异构和数据异构的边缘计算系统中,实现高效的分布式模型训练,目前已有一些分裂联邦学习技术被提出。
2、分裂联邦学习将整个模型分为两个子模型,即底层模型和顶层模型,分割层将模型分割开来。底层模型在工作者上进行训练,而顶层模型在服务器上进行训练,极大的减轻了边缘节点的计算负担。具体参见图1,在常见的分裂联邦学习方法中,其基本过程包括三个主要步骤:即在边缘节点和服务器上进行前向传播,在边缘节点和服务器上进行反向传播,在服务器上对节点端模型进行全局聚合。
3、然而,由于不同的边缘节点常常拥有异构的计算能力和通信能力,能力强的边缘节点通常需要等待慢的节点,造成了训练效率的低下。并且边缘节点采集的本地数据通常取决于它们的功能和/或位置,在所有边缘节点上造成了非独立同分布的本地数据。非独立同分布数据会降低收敛速度,甚至破坏模型的准确性。所以现有的分裂联邦学习只关注了如何训练大模型,减少边缘设备的负担,并不能解决系统异构和数据异构的问题,无法在边缘计算系统中进行高效的分布式模型训练。
技术实现思路
1、本发明提供一种面向分裂联邦学习的客户端选取方法和装置,以实现在系统异构和数据异构的边缘计算系统中的高效联邦模型训练。
2、第一方面,本发明实施例提供了一种面向分裂联邦学习的客户端选取方法,完整的模型被分裂成服务端模型和节点端模型,所述服务端模型在服务器上进行训练,所述节点端模型在边缘节点上进行训练,包括:
3、s1、在每个通信轮开始时,服务器收集各候选边缘节点的状态信息;
4、s2、根据各候选边缘节点的状态信息,从各候选边缘节点中确定参与本轮训练的目标边缘节点数量,以及所述目标边缘节点的批量大小配置和特征合并配置;
5、s3、所述目标边缘节点根据所述批量大小配置执行前向传播,将训练后的模型特征推送到服务器中进行合并;
6、s4、根据所述特征合并配置对模型特征进行合并,服务器根据合并的特征对所述参数端模型执行前向传播和反向传播,以获得更新的服务端模型;
7、s5、当本轮通信的迭代完成后,服务器将所有节点端模型聚合以获得下一通信轮的节点端模型。
8、可选的,所述各候选边缘节点的状态信息包括各候选边缘节点的数据分布信息、计算和通信的状态信息。
9、可选的,所述s2具体包括:
10、将目标边缘节点总数的入口带宽之和不能超过服务器的入口带宽作为约束条件,采用贪心的方式确定目标边缘节点的数量和各目标边缘节点对应的批量大小;
11、根据所述目标边缘节点的数据分布与目标边缘节点的批量大小之间的关系,对所有目标边缘节点采取搜索的方式确定批量大小配置和特征合并配置。
12、可选的,目标边缘节点的数据分布与目标边缘节点的批量大小之间的关系具体为:
13、
14、其中,bi+…+bj表示目标边缘节点中节点i到节点j总批量大小,节点i到节点j的数据分布为独立同分布;
15、mh为在第h轮将所有模型特征合并成的独立同分布的模型特征,kh为所选的目标边缘节点的数量。第二方面,本发明实施例还提供了一种面向分裂联邦学习的客户端选取装置,完整的模型被分裂成服务端模型和节点端模型,所述服务端模型在服务器上进行训练,所述节点端模型在边缘节点上进行训练,包括:
16、信息收集模块,用于在每个通信轮开始时,服务器收集各候选边缘节点的状态信息;
17、决策确定模块,用于根据各候选边缘节点的状态信息,从各候选边缘节点中确定参与本轮训练的目标边缘节点数量,以及所述目标边缘节点的批量大小配置和特征合并配置;
18、目标边缘节点训练模块,用于所述目标边缘节点根据所述批量大小配置执行前向传播,将训练后的模型特征推送到服务器中进行合并;
19、服务端模型更新模块,用于根据所述特征合并配置对模型特征进行合并,服务器根据合并的特征对所述参数端模型执行前向传播和反向传播,以获得更新的服务端模型;
20、节点端模型更新模块,用于当本轮通信的迭代完成后,服务器将所有节点端模型聚合以获得下一通信轮的节点端模型。
21、本发明通过根据各候选边缘节点的状态信息,进而从各候选边缘节点中确定参与本轮训练的目标边缘节点数量,以及所述目标边缘节点的批量大小配置和特征合并配置,不仅要考虑了节点的计算能力和通信能力,还考虑了节点的数据分布情况,能够同时解决边缘计算系统中边缘设备资源受限,系统异构和数据异构的问题,可以实现在系统异构和数据异构的边缘计算系统中的高效联邦模型训练。
1.一种面向分裂联邦学习的客户端选取方法,完整的模型被分裂成服务端模型和节点端模型,所述服务端模型在服务器上进行训练,所述节点端模型在边缘节点上进行训练,其特征在于,包括:
2.根据权利要求1所述的方法,其特征在于,所述各候选边缘节点的状态信息包括各候选边缘节点的数据分布信息、计算和通信的状态信息。
3.根据权利要求2所述的方法,其特征在在于,所述s2具体包括:
4.根据权利要求3所述的方法,其特征在于,目标边缘节点的数据分布与目标边缘节点的批量大小之间的关系具体为:
5.一种面向分裂联邦学习的客户端选取装置,完整的模型被分裂成服务端模型和节点端模型,所述服务端模型在服务器上进行训练,所述节点端模型在边缘节点上进行训练,其特征在于,包括: