一种高泛化性的个性化联邦学习实现方法

文档序号:32658887发布日期:2022-12-23 22:48阅读:61来源:国知局
一种高泛化性的个性化联邦学习实现方法

1.本发明涉及联邦学习技术领域,尤其涉及一种高泛化性的个性化联邦学习实现方法。


背景技术:

2.联邦学习是指多个相互隔离的孤岛数据集上训练模型的任务,在愈加严格的隐私政策的要求下,传统中心式汇聚多个数据孤岛的数据来进行数据挖掘的方式变得不可行,而单个数据孤岛的有效数据不足,数据驱动的建模和数据挖掘受到限制,此时联邦学习便能发挥作用。通用联邦学习是指,所有客户端在不共享数据的情况下,共同训练一个共识模型,以尽可能地学到来自多个客户端数据的知识。通用联邦学习步骤主要包括:客户端选择、模型分发、模型训练和模型聚合,通过迭代直至收敛得到一个聚合的共识模型。
3.由于联邦学习数据隔离的固有属性,客户端的数据分布不可知,不同客户端模型的学习存在很强的异质性,如通过来自不同地理环境的客户端解决不同的任务客户端,但是此时聚合的共识模型偏向某些客户端从而整体表现不佳。为了处理客户端之间的这种异质性,个性化联邦学习允许每个客户端保留并优化独立的个性化模型,而不是使用全局的共识模型。旨在客户端从联邦学习中获得收益的同时,在本地可见的数据上有更好的表现,即个性化模型的表现优于客户端孤岛式独自训练产生的模型,同时优于联邦共识模型。
4.虽然个性化联邦学习方法为联邦客户端的异质性困境提供了解决方案,但是主流的个性化联邦学习实现方法侧重于在可见数据的性能提升。由于对可见数据的进一步优化,大多数主流方法生成的个性化模型容易过拟合,最终导致较强模型偏向性和模型泛化性降低。然而,模型泛化性是现实场景中需要关注的问题,例如,医院客户端接收来自未知医院的转诊患者的数据,不仅能关注联邦模型在本地可见数据的表现,还可以关注其在未知分布数据上的性能。
5.因此,亟需一种可侧重于模型的泛化性的个性化联邦学习实现方法,在保证个性化联邦学习有效性的同时,还可以提升模型的泛化性。


技术实现要素:

6.针对背景技术中的问题,本发明提供了一种高泛化性的个性化联邦学习实现方法,利用任务独立的个性化批归一化和全局批归一化特征,通过双分支结构同时学习模型的个性化能力和泛化能力,即不仅能有效地提升客户端本地模型面对未知数据的泛化能力,还能保证客户端本地模型在客户端本地数据分布下的个性化能力。
7.第一方面,本发明提供了一种高泛化性的个性化联邦学习实现方法,包括,
8.步骤1:服务端随机初始化双分支结构的全局模型,将得到的初始化模型参数发送至多个选定的客户端;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;
9.步骤2:每个客户端利用服务端发送的初始化模型参数,初始化双分支结构的客户端本地模型,并利用本地数据进行第一轮本地迭代训练,得到更新后的客户端本地模型;将
更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端;
10.步骤3:服务端将所有客户端的全局任务子模型的模型参数进行加权平均计算得到聚合后新的全局任务子模型的模型参数,并将更新后的模型参数发送给多个所选客户端;
11.步骤4:客户端利用服务端发送的全局任务子模型的模型参数,更新客户端本地模型中的全局任务子模型的模型参数,结合本轮联邦训练中迭代训练得到的客户端本地模型中的个性化任务子模型,得到更新的客户端本地模型,完成一轮联邦训练;
12.步骤5:客户端使用步骤4更新的客户端本地模型基于本地数据进行再一轮迭代训练,更新客户端本地模型参数,并将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端,返回步骤3,循环更新客户端本地模型直至满足预设标准。
13.进一步地,服务端使用的全局模型和客服端使用的客户端本地模型结构相同,即模型的特征提取层后添加批归一化层;其中,特征提取层为任务共享层,批归一化层为任务特定层;任务特定层包括全局批归一化层和个性化批归一化层。
14.进一步地,全局任务子模型由任务共享层和全局批归一化层构成;个性化任务子模型由任务共享层和个性化批归一化层构成。
15.进一步地,所述统计参数包括客户端参与训练的数据量。
16.优选地,步骤2中本地迭代训练得到更新后的客户端本地模型的过程具体为:
17.将本地数据x输入客户端本地模型后同时执行两个分支得到两个任务的输出,即全局任务输出yg和个性化任务输出y
l
,通过计算交叉熵损失分别得到全局任务损失lossg和个性化任务损失loss
l

18.交叉熵损失的表达式如下:
[0019][0020]
其中,a取g或l;yj为预测目标,是实际预测结果;m表示参与训练的客户端的数量;
[0021]
利用全局任务损失lossg和个性化任务损失loss
l
得到总体损失loss
overall
,表示式为:
[0022]
loss
overall
=αlossg+(1-α)loss
l
[0023]
其中,α为损失比例系数;
[0024]
结合总体损失和预设的学习率η,客户端通过随机梯度下降和反向传播得到更新的计算模型整体的梯度,得到更新的客户端本地模型的模型参数,客户端本地模型的模型参数更新表达式如下:
[0025][0026]
其中,g
l
表示个性化任务子模型优化得到的一次迭代的总体梯度;gg分别表示全局任务子模型优化得到的一次迭代的总体梯度;wg表示全局任务子模型的模型参数;w
l
表示个性化任务子模型的模型参数;t表示当前联邦训练的轮次;i表示第i个客户端。
[0027]
优选地,步骤3中通过加权平均计算得到聚合后新的全局任务子模型的模型参数具体为:
[0028]
计算客户端参与训练的数据量占所有客户端参与训练数据总量的比重;
[0029]
全局任务子模型的模型参数wg的更新公式如下:
[0030][0031]
其中,k表示参与训练的客户端总数;k表示第k个客户端;n表示所有客户端参与训练的数据总量;nk表示第k个客户端训练的数据量;表示第k个客户端在第t轮联邦训练中的全局任务子模型的模型参数;w
g,t+1
表示第k个客户端在第t+1轮联邦训练中的全局任务子模型的模型参数。
[0032]
优选地,步骤5中预设标准具体为:
[0033]
根据损失曲线对数据和客户端分布进行判断:
[0034]
若为稳定收敛的数据和客户端分布时,客户端本地模型经过预设轮次的联邦训练后,将最后一个轮次的模型参数作为训练结果;
[0035]
若为不能稳定收敛的数据和客户端分布时,通过将数据集中的验证集添加到联邦训练中,选取预设联邦训练轮次内在验证集中表现最优的客户端本地模型的模型参数作为训练结果。
[0036]
进一步地,在训练过程中:
[0037]
通过计算训练集的多任务损失以及反向传播即可更新模型参数,若需要对模型进行预测推理步骤,比如计算训练集、验证集和测试集的准确度,需要对输入数据的类别进行推理时,在推理阶段使用集成推理方法来得到客户端本地模型对输入数据的输出结果;
[0038]
其中,集成推理方法具体如下:
[0039]
本地数据输入到全局任务子模型,输出概率形式的全局任务输出yg;
[0040]
本地数据输入到个性化任务子模型,输出概率形式的个性化任务输出y
p

[0041]
比较上述两个子模型输出的所有类别对应的概率值,选择其中最大的概率值对应的类别作为客户端本地模型的分类结果,计算模型的准确率。
[0042]
第二方面,本发明提供了一种高泛化性的个性化联邦学习实现方法,应用于服务端,包括:
[0043]
step1:服务端随机初始化双分支结构的全局模型,与参与训练的客户端生成连接,并将初始化模型参数发送给参与训练的客户端,等待客户端训练;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;
[0044]
step2:接收所有参与训练的客户端上传的全局模型训练的统计参数、全局任务子模型的模型参数和客户端客户端本地模型的评价结果;
[0045]
step3:若联邦训练轮次或聚合的评价结果满足预设标准,则停止联邦训练;若联邦训练轮次和聚合的评价结果不满足预设标准,则将所上传的全局任务子模型参数进行加权平均计算得到聚合后的全局任务子模型的模型参数,并将聚合后的模型参数发送给参与训练的客户端,等待客户端训练,返回执行step2,进行循环更新;所述聚合的评价结果是指服务端对各参与训练的客户端上传的客户端本地模型的评价结果进行聚合后的最终结果。
[0046]
第三方面,一种高泛化性的个性化联邦学习实现方法,应用于客户端,包括:
[0047]
s1:与服务端生成连接,接收服务端发送的初始化模型参数对客户端本地模型进行初始化;其中,客户端本地模型包括全局任务子模型分支和个性化任务子模型分支;
[0048]
s2:利用本地数据对客户端本地模型进行进行一轮迭代训练,得到客户端本地模型的模型参数,将客户端本地模型中的统计参数、客户端本地模型是否满足预设标准的评价结果和全局子模型的模型参数上传至服务端;
[0049]
s3:等待服务端发送结束训练指令,若指令为结束训练,进而结束训练并保存预设最佳的客户端本地模型;若指令为继续训练,则等待服务端对全局子模型的模型参数进行聚合,接收服务器发送的聚合后的全局子模型参数,更新客户端本地模型中的全局子模型参数,返回s2,进行循环更新。
[0050]
有益效果
[0051]
本发明提供了一种高泛化的个性化联邦学习实现方法,所述方法利用双分支结构的全局模型,通过全局任务子模型和个性化化任务子模型同时学习全局泛化任务和局部个性化任务,利用任务之间的相关性相互促进,有效提升了客户端本地模型对未知分布数据的性能表现,改善了客户端本地模型泛化性差的问题。
[0052]
参与训练的客户端利用本地数据对客户端本地模型进行训练,将全局子模型参数上传服务端,服务端对所有参与客户端上传的全局任务子模型参数进行联邦聚合,有效的降低了联邦聚合对个性化特征学习的冲突,增强了客户端的客户端本地模型对个性化特征的学习,并且保留了全局特征,在未增加额外联邦通信轮次、局部训练轮次和训练模型的条件下,同时完成个性化特征和全局特征的学习,在提高客户端本地模型的泛化性的同时保证客户端本地模型的个性化性能。
附图说明
[0053]
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
[0054]
图1是本发明所述方法提供的双分支结构的全局模型结构图;
[0055]
图2是本发明所述方法提供的服务端和客户端的通信示意图。
具体实施方式
[0056]
为使本发明的目的、技术方案和优点更加清楚,下面将对本发明的技术方案进行详细的描述。显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所得到的所有其它实施方式,都属于本发明所保护的范围。
[0057]
本发明提供了一种高泛化的个性化联邦学习实现方法,用于解决个性化联邦学习在图像中未知分布的数据上表现差的问题,并同时关注联邦模型在本地图像数据的性能表现。本发明提供的技术方案适用于不同的神经网络模型,即在特征提取层后添加批归一化层,可根据不同的需求选取神经网络模型的类型。下面结合附图和具体实施例对本发明中技术方案作进一步详细的说明。
[0058]
实施例1
[0059]
如图1-2所示,本实施例提供了一种高泛化的个性化联邦学习方法,本实施例中选取神经网络中的卷积神经网络对图像进行分类任务为例,包括如下步骤:
[0060]
步骤1:服务端随机初始化双分支结构的全局模型,将得到的初始化模型参数发送至多个选定的客户端;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支。
[0061]
所有客户端使用相同的模型结构,且与服务端使用的模型结构相同。当服务端发送初始化模型参数时,使用相同模型结构的客户端承载服务端的初始化模型参数,可实现各个客户端具备相同的初始化状态。本发明的技术方案中模型结构采用双分支结构的全局模型,模型结构如图1所示。
[0062]
图1中参与联邦训练的模型包含卷积层和全连接层和其他神经网络层等。由于数据隔离是联邦学习的固有属性,不同客户端的数据分布存在常见的几种非独立同分布的情景,如类别分布不均衡和特征分布偏移等。现实应用中,不同医疗机构使用不同的设备产生的影像,就可能产生特征分布偏移。为了适应不同客户端的特征分布情况,本发明在特征提取层后添加批归一化层,具体在图1中表现为在每个卷积层和全连接层后面(除最后一个用于分类的全连接层)添加批归一化层。其实现方式是:通过统计一个批次训练数据的均值和方差;将两个可训练的参数将输出进行归一化,使其为预设合理范围内的一个数值,最终实现对原始输入产生特定于某个批次数据分布的偏移。不同客户端所持有的数据分布存在差异,因此,批归一化层训练得到的参数也不同。本发明使用独立的批归一化层来实现个性化联邦模型。
[0063]
由于个性化模型的低泛化性,本发明设计了全局任务以增强泛化性。针对不同的任务,本发明在原神经网络模型基础上,通过在特征提取层之后的相同位置设计两个用于不同任务的批归一化层,其中全局任务的批归一化层用于联邦训练中,个性化任务的批归一化层用于形成客户端的个性化模型参数。本实施例中,模型训练存在两个任务分支,分别对应一个子模型,子模型由任务共享层和任务特定层组成,任务共享层由两个分支共同优化,任务特定层则单独优化。全局任务子模型由任务共享层和全局批归一化层构成;个性化任务子模型由任务共享层和个性化批归一化层构成,由此构成了双分支结构的全局模型。超参数包括联邦训练轮次、本地迭代轮次、损失比例系数、学习率、数据批大小。
[0064]
步骤2:每个客户端利用服务端发送的初始化模型参数,初始化双分支结构的客户端本地模型,并利用本地数据进行第一轮本地迭代训练,得到更新后的客户端本地模型;将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端。
[0065]
客户端使用初始化后的客户端本地模型进行一轮联邦训练的本地迭代训练,其具体过程为:
[0066]
客户端i使用本地私有的图像数据其中,ni表示客户端i将ni个样本的集合作为训练数据;xj表示训练数据中第j个输入的本地图像数据;yj表示训练数据第j个本地图像数据的真实标签。在具体实施时,可根据实际需求对本地迭代轮次进行设置,本实例中选择1次本地迭代以避免局部模型过拟合局部数据。
[0067]
使用wg表示全局任务子模型,w
l
表示个性化任务子模型。将本地图像数据x输入客户端本地模型后同时执行两个分支得到两个任务的输出,即全局任务输出yg和个性化任务
输出y
l
,通过计算交叉熵损失分别得到全局任务损失lossg和个性化任务损失loss
l

[0068]
交叉熵损失的表达式如下:
[0069][0070]
其中,a取g或l;yj为预测目标,是实际预测结果;m表示参与训练的客户端的数量;利用全局任务损失lossg和个性化任务损失loss
l
得到总体损失loss
overakl
,表示式为:
[0071]
loss
overakl
=αlossg+(1-α)loss
l
[0072]
其中,α为损失比例系数。
[0073]
结合总体损失和预设的学习率η,客户端通过随机梯度下降和反向传播得到更新的计算模型整体的梯度,得到更新的客户端本地模型的模型参数。其中,在全局任务损失lossg计算梯度时,个性化任务特定的批归一化层的梯度为0;同理,在个性化任务损失loss
l
计算梯度时,全局任务特定的批归一化层梯度为0。也就是说,总体损失loss
pverall
对模型的优化相当于同时优化两个子模型,任务共享层的参数由两个损失共同优化,任务特定层由两个损失单独优化。客户端本地模型的模型参数更新表达式如下:
[0074][0075]
其中,g
l
表示个性化任务子模型优化得到的一次迭代的总体梯度;gg分别表示全局任务子模型优化得到的一次迭代的总体梯度;wg表示全局任务子模型的模型参数;w
l
表示个性化任务子模型的模型参数;t表示当前联邦训练的轮次;i表示第i个客户端。
[0076]
步骤3:服务端将所有客户端的全局任务子模型的模型参数进行加权平均计算得到聚合后新的全局任务子模型的模型参数,并将更新后的模型参数发送给多个所选客户端。
[0077]
其中,通过加权平均计算得到聚合后新的全局任务子模型的模型参数具体为:
[0078]
计算客户端参与训练的数据量占所有客户端参与训练数据总量的比重;
[0079]
全局任务子模型的模型参数wg的更新公式如下:
[0080][0081]
其中,k表示参与训练的客户端总数;k表示第k个客户端;n表示所有客户端参与训练的数据总量;nk表示第k个客户端训练的数据量;表示第k个客户端在第t轮联邦训练中的全局任务子模型的模型参数;w
g,t+1
表示第k个客户端在第t+1轮联邦训练中的全局任务子模型的模型参数。
[0082]
值得注意的是,全局任务子模型的联邦聚合与本地双分支结构的模型的本地迭代训练过程是解耦合的,全局任务子模型用于学习数据的一致性知识,也可以根据实际需求采用其他的聚合方法改善一致性特征的学习。
[0083]
步骤4:客户端利用服务端发送的全局任务子模型的模型参数,更新客户端本地模型中的全局任务子模型的模型参数,结合本轮联邦训练中迭代训练得到的客户端本地模型中的个性化任务子模型,得到更新的客户端本地模型,完成一轮联邦训练;
[0084]
其中,客户端客户端本地模型的全局子模型的模型参数来自步骤3中服务端聚合所有客户端的全局子模型的模型参数后更新的数据,包括任务共享层和全局任务特定的批归一化层;而个性化任务特定的批归一化层保持步骤2通过本地图像数据训练后的局部更新的个性化子模型的模型参数,不同客户端模型形成差异化,因此此步骤生成的客户端本地全局模型为个性化模型。
[0085]
步骤5:客户端使用步骤4更新的客户端本地模型基于本地数据进行再一轮迭代训练,更新客户端本地模型参数,并将更新后的客户端本地模型的统计参数和全局任务子模型的模型参数上传至服务端,返回步骤3,循环更新客户端本地模型直至满足预设标准。
[0086]
其中,预设标准具体为:
[0087]
根据损失曲线对数据和客户端分布进行判断:
[0088]
若为稳定收敛的数据和客户端分布时,客户端本地模型经过预设轮次的联邦训练后,将最后一个轮次的模型参数作为训练结果;
[0089]
若为不能稳定收敛的数据和客户端分布时,通过将数据集中的验证集添加到联邦训练中,选取预设联邦训练轮次内的验证集中表现最优的客户端本地模型的模型参数作为训练结果。
[0090]
进一步地,可以稳定收敛指模型的训练损失在一定轮次后变化较小,如手写数字分类为可以稳定收敛的数据和客户端分布;不能稳定收敛指模型的训练损失在较多联邦轮次后变化仍然较大。其中,验证集从训练集中按一定比例划分,不参与训练,仅用于选择训练的模型参数。
[0091]
在训练过程中:
[0092]
通过计算训练集的多任务损失以及反向传播即可更新模型参数,若需要对模型进行预测推理步骤,比如计算训练集、验证集和测试集的准确度,需要对输入数据的类别进行推理时,在推理阶段使用集成推理方法来得到客户端本地模型对输入数据的输出结果;
[0093]
其中,集成推理方法具体如下:
[0094]
将本地图像数据输入到全局任务子模型,输出概率形式的全局任务输出yg;
[0095]
将本地图像数据输入到个性化任务子模型,输出概率形式的个性化任务输出y
p

[0096]
比较上述两个子模型输出的所有类别对应的概率值,选择其中最大的概率值对应的类别作为客户端本地模型的分类结果,计算模型的准确率。若在无标签的预测任务上,同样进行以上三个步骤以输出预测结果。
[0097]
应用实例:
[0098]
本发明以5个客户端分别持有不同特征分布的手写数字数据为例,首先将每个客户端按比例7:3划分训练数据和测试数据,根据客户端是否参与联邦训练形成不同的客户端数据状态:联邦可见数据,联邦不可见数据。其中,采用留一法循环从5个客户端中选择4个客户端参与联邦训练,分别提供用于训练模型的训练数据和用于评估模型的测试数据,测试数据为对应客户端的可见数据;留下1个客户端不参与训练仅提供测试数据,该数据为联邦不可见数据。参与联邦训练的4个客户端分别生成客户端本地模型,在不可见数据上的表现即为模型的泛化性,在其对应客户端可见数据的表现则为模型的个性化性能。每个客户端分别有743张不重叠的手写数字图像。本地客户端本地模型选择如图1所示的卷积神经网络进行实际应用。
[0099]
当所述方法应用于求解不同医院之间的数据孤岛问题时,每个客户端可以视为一个独立的医院,同时不同医院具有不同的数据分布。综上所述,本发明中一个联邦中客户端的知识可以为其他客户端所理解,而无需显式地共享其私有数据,通过联邦聚合和个性化优化,进一步发掘各个参与方的数据价值,提高模型训练的收敛性、鲁棒性和泛化性。
[0100]
实施例2
[0101]
本实施例提供了一种高泛化性的个性化联邦学习实现方法,应用于服务端,包括:
[0102]
step1:服务端随机初始化双分支结构的全局模型,与参与训练的客户端生成连接,并将初始化模型参数发送给参与训练的客户端,等待客户端利用本地图像数据进行训练;其中,全局模型包括全局任务子模型分支和个性化任务子模型分支;
[0103]
step2:接收所有参与训练的客户端上传的全局模型训练的统计参数、全局任务子模型的模型参数和客户端客户端本地模型的评价结果;
[0104]
step3:若联邦训练轮次或聚合的评价结果满足预设标准,则停止联邦训练;若联邦训练轮次和聚合的评价结果不满足预设标准,则将所上传的全局任务子模型参数进行加权平均计算得到聚合后的全局任务子模型的模型参数,并将聚合后的模型参数发送给参与训练的客户端,等待客户端训练,返回执行step2,进行循环更新;所述聚合的评价结果是指服务端对各参与训练的客户端上传的客户端本地模型的评价结果进行聚合后的最终结果。本实施例中客户端客户端本地模型的评价结果的指标包括训练集、验证集、测试集的准确率。
[0105]
实施例3
[0106]
本实施例提供了一种高泛化性的个性化联邦学习实现方法,应用于客户端,包括:
[0107]
s1:与服务端生成连接,接收服务端发送的初始化模型参数对客户端本地模型进行初始化;其中,客户端本地模型包括全局任务子模型分支和个性化任务子模型分支;
[0108]
s2:利用本地图像数据对客户端本地模型进行进行一轮迭代训练,得到客户端本地模型的模型参数,将客户端本地模型中的统计参数、客户端本地模型是否满足预设标准的评价结果和全局子模型的模型参数上传至服务端;
[0109]
s3:等待服务端发送结束训练指令,若指令为结束训练,进而结束训练并保存预设最佳的客户端本地模型;若指令为继续训练,则等待服务端对全局子模型的模型参数进行聚合,接收服务器发送的聚合后的全局子模型参数,更新客户端本地模型中的全局子模型参数,返回s2,进行循环更新。
[0110]
可以理解的是,上述各实施例中相同或相似部分可以相互参考,在一些实施例中未详细说明的内容可以参见其他实施例中相同或相似的内容。
[0111]
尽管上面已经示出和描述了本发明的实施例,可以理解的是,上述实施例是示例性的,不能理解为对本发明的限制,本领域的普通技术人员在本发明的范围内可以对上述实施例进行变化、修改、替换和变型。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1