本发明涉及联邦学习,尤其涉及一种联邦学习方法、装置、系统、存储介质及电子设备。
背景技术:
1、联邦学习作为一个分布式机器学习框架,多个计算节点协同训练同一个模型。由于在训练过程中,计算节点与中心服务器节点仅交互模型参数而没有交互原始数据,该训练方法可以避免各计算节点原始数据共享。
2、在实现本发明的过程中,发现现有技术中至少存在以下技术问题:由于各计算节点本地训练数据的异构性,导致了各计算节点的最优参数往往不一样,同时,这些最优参数的均值也往往不是全局最优参数。不同计算节点在本地迭代得到的模型参数的差异不利于模型整体训练。
技术实现思路
1、本发明提供了一种联邦学习方法、装置、系统、存储介质及电子设备,以降低各计算节点在训练过程中陷入本地局部最优的情况。
2、根据本发明的一方面,提供了一种联邦学习方法,应用于计算节点,所述方法包括:
3、在任一次全局迭代过程中,接收当前次全局迭代过程中的修正梯度项;
4、在当前次全局迭代过程中的任一次局部迭代过程中,确定机器学习模型在当前次局部迭代的局部梯度项;
5、基于所述修正梯度项对当前次局部迭代的局部梯度项进行修正,并基于修正后的目标梯度项对机器学习模型的模型参数进行更新,并基于所述更新后的模型参数执行下一次局部迭代过程;
6、在完成局部迭代过程的情况下,将当前次全局迭代过程中的模型参数变化发送至中心服务器节点,其中,所述中心服务器节点基于所述当前次全局迭代过程中的模型参数变化确定下一次全局迭代过程所需的修正梯度项。
7、根据本发明的另一方面,提供了一种联邦学习方法,应用于中心服务器节点,所述方法包括:
8、在任一次全局迭代过程中,接收各计算节点发送的模型参数变化;
9、基于各计算节点发送的模型参数变化对模型参数进行全局更新,并确定下一次全局迭代过程的修正梯度项;
10、将更新得到的模型参数和下一次全局迭代过程的修正梯度项发送至各计算节点,其中,各所述计算节点进行下一次全局迭代过程。
11、根据本发明的另一方面,提供了一种联邦学习装置,集成于计算节点设备,所述装置包括:
12、全局信息获取模块,用于在任一次全局迭代过程中,接收当前次全局迭代过程中的修正梯度项;
13、修正梯度项确定模块,用于在当前次全局迭代过程中的任一次局部迭代过程中,确定机器学习模型在当前次局部迭代的局部梯度项;
14、局部参数更新模块,用于基于所述修正梯度项对当前次局部迭代的局部梯度项进行修正,并基于修正后的目标梯度项对机器学习模型的模型参数进行更新,并基于所述更新后的模型参数执行下一次局部迭代过程;
15、信息发送模块,用于在完成局部迭代过程的情况下,将当前次全局迭代过程中的模型参数变化发送至中心服务器节点,其中,所述中心服务器节点基于所述当前次全局迭代过程中的模型参数变化确定下一次全局迭代过程所需的修正梯度项。
16、根据本发明的另一方面,提供了一种联邦学习装置,集成于中心服务器节点设备,所述装置包括:
17、局部信息接收模块,用于在任一次全局迭代过程中,接收各计算节点发送的模型参数变化;
18、全局更新模块,用于基于各计算节点发送的模型参数变化对模型参数进行全局更新,并确定下一次全局迭代过程的修正梯度项;
19、信息发送模块,用于将更新得到的模型参数和下一次全局迭代过程的修正梯度项发送至各计算节点,其中,各所述计算节点进行下一次全局迭代过程。
20、根据本发明的另一方面,提供了一种联邦学习系统,包括中心服务器节点和多个计算节点,其中,
21、所述中心服务器节点向各计算节点下发当前次全局迭代过程中的修正梯度项、全局模型参数和动量参数;
22、所述计算节点接收修正梯度项、全局模型参数和动量参数,在当前次全局迭代过程中的任一次局部迭代过程中,确定机器学习模型在当前次局部迭代的局部梯度项;基于所述修正梯度项对当前次局部迭代的局部梯度项进行修正,并基于修正后的目标梯度项对机器学习模型的模型参数进行更新,并基于所述更新后的模型参数执行下一次局部迭代过程;在完成局部迭代过程的情况下,将当前次全局迭代过程中的模型参数变化和局部动量参数发送至中心服务器节点;
23、所述中心服务器节点基于所述当前次全局迭代过程中的模型参数变化和局部动量参数确定下一次全局迭代过程所需的修正梯度项、全局模型参数和全局动量参数,并将所述下一次全局迭代过程所需的修正梯度项、全局模型参数和全局动量参数下发至各计算节点,直到完成全局迭代,得到训练完成的机器学习模型。
24、根据本发明的另一方面,提供了一种电子设备,所述电子设备包括:
25、至少一个处理器;以及
26、与所述至少一个处理器通信连接的存储器;其中,
27、所述存储器存储有可被所述至少一个处理器执行的计算机程序,所述计算机程序被所述至少一个处理器执行,以使所述至少一个处理器能够执行本发明任一实施例所述的联邦学习方法。
28、根据本发明的另一方面,提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现本发明任一实施例所述的联邦学习方法。
29、本实施例的技术方案,通过在计算节点进行本地局部迭代的过程中,通过中心服务器节点下发的修正梯度项,对每一局部迭代过程中的局部梯度项进行修正,减小不同计算节点上得到局部梯度项差异导致的模型参数差异,避免各计算节点在本地训练过程中陷入本地局部最优值的情况,提高训练得到的机器学习模型的泛化性能。
30、应当理解,本部分所描述的内容并非旨在标识本发明的实施例的关键或重要特征,也不用于限制本发明的范围。本发明的其它特征将通过以下的说明书而变得容易理解。
1.一种联邦学习方法,其特征在于,应用于计算节点,所述方法包括:
2.根据权利要求1所述的方法,其特征在于,所述确定机器学习模型在当前次局部迭代的局部梯度项,包括:
3.根据权利要求2所述的方法,其特征在于,所述确定机器学习模型在当前次局部迭代的局部动量参数,包括:
4.根据权利要求3所述的方法,其特征在于,所述基于所述随机梯度和上一次局部迭代的二阶动量,确定当前次局部迭代的二阶动量,包括:
5.根据权利要求1所述的方法,其特征在于,所述基于所述修正梯度项对当前次局部迭代的局部梯度项进行修正,包括:
6.根据权利要求1所述的方法,其特征在于,所述接收当前次全局迭代过程中的修正梯度项,包括:
7.一种联邦学习方法,其特征在于,应用于中心服务器节点,所述方法包括:
8.根据权利要求7所述的方法,其特征在于,所述将更新得到的模型参数和下一次全局迭代过程的修正梯度项发送至各计算节点,包括:
9.一种联邦学习装置,其特征在于,集成于计算节点设备,所述装置包括:
10.一种联邦学习装置,其特征在于,集成于中心服务器节点设备,所述装置包括:
11.一种联邦学习系统,其特征在于,包括中心服务器节点和多个计算节点,其中,
12.一种电子设备,其特征在于,所述电子设备包括:
13.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现权利要求1-6中任一项所述的联邦学习方法,和/或,权利要求7-8中任一项所述的联邦学习方法。