本发明涉及一种标签翻转攻击客户端检测方法、全局模型训练方法及装置,属于网络信息安全与数据隐私保护。
背景技术:
1、联邦学习是目前极为流行的一种分布式机器学习方法,它可以使参与到其中的各个数据拥有者,其原始数据均存留在本地(客户端),而不需要上传至中心服务器。同时在本地用户与中心服务器之间,仅需要交互模型参数或计算的最终结果,这种“数据不动模型动”的交互模式,可以保证中心服务器在进行学习的过程当中,无法推测出本地用户的原始数据,能够满足数据拥有者对于数据隐私保护的需要,极大提高数据拥有者参与机器学习的热情。
2、但也正是由于在联邦学习中,中心服务器无法直接接触到训练方的原始数据,使得中心服务器无法检测出参与方中可能存在的携带有有毒数据的恶意客户端,恶意客户端便可以通过携带或篡改自身所提供的训练数据来参与到整体的训练中来,以实现对于整体模型训练准确度的攻击,使得最终模型训练结果朝着攻击者所预想的方向靠近,这是目前的联邦学习中亟需解决的问题。
3、标签翻转攻击就是一种通过对训练样本中的数据进行修改,使整体模型训练的结果产生错误的攻击方式。尽管目前针对于标签翻转攻击者(即标签翻转攻击客户端)的检测方法已有一定的发展,但在联邦学习环境中,由于中心服务器并不会直接接触到训练客户端的原始数据,而是直接接触客户端训练后的数据,因此现有检测方法应用在联邦学习中时存在以下缺点:
4、1)现有检测方法是与联邦学习的整体训练同步进行的,当训练模型不断迭代后,后期伴随着整体模型训练准确率的增加,标签翻转攻击者的攻击行为会被隐藏在整体的高准确率之下,难以被检测出来。
5、2)现有检测方法多采取单一指标对客户端进行评价,会造成将正常客户端误检测为恶意客户端的情况发生,导致最后训练出的模型准确率及召回率降低。
技术实现思路
1、本发明提供了一种标签翻转攻击客户端检测方法、全局模型训练方法及装置,解决了背景技术中披露的问题。
2、为了解决上述技术问题,本发明所采用的技术方案是:
3、标签翻转攻击客户端检测方法,包括:
4、从客户端获取每轮次的信誉值以及训练数据的质量值;其中,轮次为客户端采用训练数据训练全局模型的轮次,信誉值根据全局模型训练的标签召回率计算获得,质量值根据标签的先验分布和标签的后验分布计算获得;
5、根据每轮次的信誉值,计算客户端的平均信誉值;
6、根据客户端的平均信誉值和训练数据的质量值,检测标签翻转攻击的客户端。
7、每轮次信誉值计算公式为:
8、
9、式中,为第d个客户端第i轮次的信誉值,m为数据种类,numj为第j种类别的数据数量,为根据标签召回率计算获得的第j种类别的被识别成功的数据数量。
10、训练数据质量值计算公式为:
11、
12、式中,dqd为第d个客户端的训练数据质量值,dd为第d个客户端的训练数据分布,表示自变量x从dd中采样,表示自变量z从gw(z|y)中采样,gw(z|y)为通过生成器gw生成的与标签对应的潜在特征值,y为训练时的标签序列,acc(argmaxh(z),argmaxh(f(x)))为计算argmaxh(z)、argmaxh(f(x))中相同元素总和的函数,argmaxh(z)为基于全局模型的关于标签所对应的特征z的预测,argmaxh(f(x))为基于客户端训练数据样本的预测。
13、根据客户端的平均信誉值和训练数据的质量值,检测标签翻转攻击的客户端,包括:
14、根据客户端的平均信誉值,查找出平均信誉值小于阈值的客户端,将查找出的客户端加入恶意客户端集合s1;
15、对所有客户端的训练数据质量值进行k-means聚类,获得两簇数据,将较低的一簇数据的客户端加入到恶意客户端集合s2;
16、计算恶意客户端集合s1和恶意客户端集合s2的交集,获得标签翻转攻击的客户端。
17、标签翻转攻击客户端检测装置,包括:
18、获取模块,从客户端获取每轮次的信誉值以及训练数据的质量值;其中,轮次为客户端采用训练数据训练全局模型的轮次,信誉值根据全局模型训练的标签召回率计算获得,质量值根据标签的先验分布和标签的后验分布计算获得;
19、平均信誉值计算模块,根据每轮次的信誉值,计算客户端的平均信誉值;
20、检测模块,根据客户端的平均信誉值和训练数据的质量值,检测标签翻转攻击的客户端。
21、检测模块被配置为:
22、根据客户端的平均信誉值,查找出平均信誉值小于阈值的客户端,将查找出的客户端加入恶意客户端集合s1;
23、对所有客户端的训练数据质量值进行k-means聚类,获得两簇数据,将较低的一簇数据的客户端加入到恶意客户端集合s2;
24、计算恶意客户端集合s1和恶意客户端集合s2的交集,获得标签翻转攻击的客户端。
25、全局模型训练方法,包括:
26、采用上述标签翻转攻击客户端检测方法,检测出标签翻转攻击的客户端,并从所有的客户端中剔除标签翻转攻击客户端;
27、采用剩余的客户端进行联邦学习全局模型训练。
28、全局模型训练装置,包括:
29、检测剔除模块,采用上述标签翻转攻击客户端检测方法,检测出标签翻转攻击的客户端,并从所有的客户端中剔除标签翻转攻击客户端;
30、训练模块,采用剩余的客户端进行联邦学习全局模型训练。
31、一种计算机可读存储介质,所述计算机可读存储介质存储一个或多个程序,所述一个或多个程序包括指令,所述指令当由计算设备执行时,使得所述计算设备执行标签翻转攻击客户端检测方法或全局模型训练方法。
32、一种计算机设备,包括一个或多个处理器、以及一个或多个存储器,一个或多个程序存储在所述一个或多个存储器中并被配置为由所述一个或多个处理器执行,所述一个或多个程序包括用于执行标签翻转攻击客户端检测方法或全局模型训练方法的指令。
33、本发明所达到的有益效果:本发明采用客户端的平均信誉值和训练数据的质量值,检测标签翻转攻击的客户端,相较于单一的指标,降低了正常客户端被误识别为恶意客户端的误检测率,并且客户端的平均信誉值是根据每轮次的信誉值计算获得,即使训练模型不断迭代,标签翻转攻击者的攻击行也可以被检测出来。
1.标签翻转攻击客户端检测方法,其特征在于,包括:
2.根据权利要求1所述的标签翻转攻击客户端检测方法,其特征在于,每轮次信誉值计算公式为:
3.根据权利要求1所述的标签翻转攻击客户端检测方法,其特征在于,训练数据质量值计算公式为:
4.根据权利要求1所述的标签翻转攻击客户端检测方法,其特征在于,根据客户端的平均信誉值和训练数据的质量值,检测标签翻转攻击的客户端,包括:
5.标签翻转攻击客户端检测装置,其特征在于,包括:
6.根据权利要求5所述的标签翻转攻击客户端检测装置,其特征在于,检测模块被配置为:
7.全局模型训练方法,其特征在于,包括:
8.全局模型训练装置,其特征在于,包括:
9.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储一个或多个程序,所述一个或多个程序包括指令,所述指令当由计算设备执行时,使得所述计算设备执行权利要求1~4、7所述的任一方法。
10.一种计算机设备,其特征在于,包括: