本公开实施例涉及数据处理,尤其涉及一种基于关键周期识别的联邦学习恶意模型更新检测方法。
背景技术:
1、目前,现存联邦学习恶意更新检测方法,分为三类:基于统计的方法;基于过滤的方法;基于差分隐私的方法。这三种方法有各自的缺点:①基于统计的方法用简单的统计机器学习方法构建模型,没有针对恶意更新的特点出发,构建检测模型,导致检测率低,仅能应对恶意模型更新与正常模型更新幅度差距较大的情况;②基于过滤的方法将异常样本看作边缘点处理,认为处于所有样本边缘的样本为异常点,或认为所处位置样本最稀疏的点为异常点。由于联邦学习模型更新维度很高,不进行降维会导致维度灾难。同时少部分方法使用基于偏差的检测方法:使用机器学习方法学习样本隐藏表达,将更新数据映射到隐藏表达,再通过隐藏表达还原重构更新数据,并计算重构误差,通过最小化重构误差来优化神经网络和该隐藏表达。依据重构误差判断样本是否异常。当异常模型更新较多时检测效率低下,直接将高维模型更新输入异常检测模型计算开销大。③基于差分隐私的方法将高斯噪声加入全局模型更新,使得隐匿型有特定目标的恶意更新无法起效,但无法应对以破坏模型整体为目标的恶意模型更新;同时引入噪声本身会降低最终得到的全局模型的表现,拖延联邦学习模型训练进度。
2、总的来说,现有方法无法同时应对无目标攻击(破坏模型整体性能)和有目标攻击(破坏模型特定子任务性能),且计算开销大、全局模型表现差。
3、可见,亟需一种检测效率、精准度和适应性强的基于关键周期识别的联邦学习恶意模型更新检测方法。
技术实现思路
1、有鉴于此,本公开实施例提供一种基于关键周期识别的联邦学习恶意模型更新检测方法,至少部分解决现有技术中存在检测效率、精准度和适应性较差的问题。
2、本公开实施例提供了一种基于关键周期识别的联邦学习恶意模型更新检测方法,包括:
3、步骤1,在每个客户端上计算本地更新数据后发送至服务端;
4、步骤2,服务端对本轮次收集的来自客户端的本地更新数据进行特征提取,得到本地模型更新元组;
5、步骤3,通过关键周期识别模块识别本地模型更新元组,得到关键周期识别符并据此将本地模型更新元组分为关键周期更新元组和非关键周期更新元组;
6、步骤4,将本地模型更新元组输入异常检测模块得到异常得分并筛选出正常模型更新;
7、步骤5,通过聚合算法和正常模型更新得到全局模型更新。
8、根据本公开实施例的一种具体实现方式,所述本地更新数据为三元组,所述三元组包括客户端本地模型的梯度、本地数据集平均交叉熵和客户端训练完成符。
9、根据本公开实施例的一种具体实现方式,所述步骤2具体包括:
10、步骤2.1,解析该本地更新数据,根据三元组中的客户端训练完成符判断该数据是否是客户端训练完毕后发送的,若判断结果为否则丢弃,若判断结果为是,则进行步骤2.2;
11、步骤2.2,根据收到本地更新数据的顺序判断其到达服务端的序号,若为第一个到达,则为本轮次的本地模型梯度和本地数据集平均交叉熵形成的二元组开辟存储空间;若为最后一个包,则进行步骤2.3,否则,将该二元组存入对应的存储空间;
12、步骤2.3,对本轮次的本地模型梯度进行minmax数据标准化处理,得到本地模型更新元组。
13、根据本公开实施例的一种具体实现方式,所述步骤3具体包括:
14、步骤3.1,从本地模型更新元组的二元组集合中提取本地数据集平均交叉熵,并计算各客户端交叉熵与上一轮次的差值,然后将各客户端的差值取加权平均值,并计算变化幅度;
15、步骤3.2,判断当前轮次是否在关键周期对应的轮次范围内,若是,则输出关键周期识别符为true,存储当前轮次变化幅度,并结束当前轮次识别,若否,则转入步骤3.3;
16、步骤3.3,将当前记录的变化幅度集合输入jenks-caspall自然断点分类法,识别出断点,若当前轮次在断点之后,输出关键周期识别符为false,否则为true。
17、根据本公开实施例的一种具体实现方式,当本地模型更新元组为关键周期更新元组时,所述步骤4具体包括:
18、步骤4.1,将本地模型更新元组分层,各层模型更新展平为一维向量,分别存储到不同的集合中;
19、步骤4.2,将各层一维向量分割为预设长度的子向量,分别存储到不同集合,每个集合大小为本轮次接收的本地模型更新数量,不满预设长度的子向量保留为新的子向量;
20、步骤4.3,将不同子向量集合分别输入dbscan异常检测方法,分别找出异常子向量,并分别记录异常标记向量,其大小为子向量集合容量,若为异常标记为0,否则为1,将所有子向量集合的异常标记向量合并为本地模型更新数量*本地模型更新子向量数量的二维矩阵;
21、步骤4.4,使用二维矩阵进行软投票,其中全连接层的子向量权重为1,其余为0.2,计算得到得票总和,将得票前50%高的本地模型更新保留,其余本地模型更新丢弃,输出本地模型更新异常标签作为正常模型更新。
22、根据本公开实施例的一种具体实现方式,当本地模型更新元组为非关键周期更新元组时,所述步骤4具体包括:
23、步骤4.1,设置随机种子;
24、步骤4.2,将本地模型更新元组分层,各层模型更新展平为一维向量,各层向量随机取预设长度的子向量,分别存储到不同的集合中,将各子向量分别存储到不同集合,每个集合大小为本轮次接收的本地模型更新数量;
25、步骤4.3,将不同子向量集合分别输入dbscan异常检测方法,分别找出异常子向量,并分别记录异常得分向量,其大小为子向量集合容量,异常得分为[0,1]的数值,将所有子向量集合的异常得分向量合并为本地模型更新数量*本地模型更新子向量数量的二维矩阵;
26、步骤4.4,根据二维矩阵计算得到各本地模型更新异常得分总和,将得分前50%高的本地模型更新保留,其余本地模型更新丢弃,输出本地模型更新异常标签作为正常模型更新。
27、本公开实施例中的基于关键周期识别的联邦学习恶意模型更新检测方案,包括:步骤1,在每个客户端上计算本地更新数据后发送至服务端;步骤2,服务端对本轮次收集的来自客户端的本地更新数据进行特征提取,得到本地模型更新元组;步骤3,通过关键周期识别模块识别本地模型更新元组,得到关键周期识别符并据此将本地模型更新元组分为关键周期更新元组和非关键周期更新元组;步骤4,将本地模型更新元组输入异常检测模块得到异常得分并筛选出正常模型更新;步骤5,通过聚合算法和正常模型更新得到全局模型更新。
28、本公开实施例的有益效果为:通过本公开的方案,利用关键周期识别,找出对全局模型有不可逆影响的周期,在关键周期中使用细粒度的异常检测方法,在非关键周期使用粗粒度的异常检测方法,实现检测能力和计算效率的平衡,使用交叉熵损失的变化幅度识别关键周期,识别精度与使用联邦fisher矩阵接近但计算开销远小于联邦fisher矩阵。同时,现有关键周期识别方法使用预设阈值的方法,数值固定,无法适应联邦学习客户端数据集分布、模型结构、损失函数的改变。本发明使用jenks-caspall自然断点分类法自适应识别断点,将断点前的训练轮次视为关键周期,断点后的训练轮次视为非关键周期,实现了自适应关键周期识别,提高了关键周期识别的稳定性,通过将本地模型更新分层分割处理为子向量,将子向量分别进行异常检测,并使用软投票汇总子向量的异常检测结果,实现对隐匿型攻击的有效检测,提高了检测效率、精准度和适应性。