本技术涉及人工智能领域,尤其涉及一种图像分类模型的训练方法、装置、介质及电子设备。
背景技术:
1、随着科技的发展,隐私数据日益受到关注。随着科技的发展,人工智能飞速发展,图像分类模型在许多领域广泛应用。
2、目前,联邦学习(federated learning,fl)技术可通过中心节点,在保护各边缘节点的本地样本数据隐私的情况下,使得各边缘节点共同训练图像分类模型,从而提高训练好的图像分类模型的性能。在实际场景中,由于不同边缘节点对应的计算资源以及样本数据均存在较大差异,因此不同边缘节点中部署的图像分类模型的架构不同,那么在训练图像分类模型时,中心节点可通过知识蒸馏技术,基于各边缘节点训练的图像分类模型也即教师模型,对中心节点中部署的图像分类模型也即学生模型进行训练,得到训练后的图像分类模型。但是,如前所述,不同边缘节点中部署的图像分类模型的架构不同,那么中心节点将训练后的图像分类模型发送至各边缘节点,会出现边缘节点无法运行或者存储基于训练后的图像分类模型的情况,从而导致训练图像分类模型的任务无法进行。
3、基于此,本技术说明书提供了一种图像分类模型的训练方法。
技术实现思路
1、本说明书提供一种图像分类模型的训练方法及装置、存储介质及电子设备,以至少部分的解决现有技术存在的上述问题。
2、本说明书采用下述技术方案:
3、本说明书提供了一种图像分类模型的训练方法,所述方法应用于中心节点,所述方法用于训练图像分类模型中的编码端;所述方法包括:
4、获取各边缘节点中部署的图像分类模型的局部编码端的参数;
5、根据各局部编码端的参数,对中心节点中部署的图像分类模型的全局编码端进行蒸馏,得到蒸馏后的全局编码端的参数;
6、针对每个边缘节点,根据所述蒸馏后的全局编码端的参数以及预设训练目标,对该边缘节点对应的局部编码端的参数进行调整,得到该边缘节点对应的调整后的局部编码端的参数;
7、将得到的该边缘节点对应的调整后的局部编码端的参数发送至该边缘节点,以使该边缘节点根据基于调整后的局部编码端的参数构建的编码端,得到该边缘节点对应的训练好的图像分类模型。
8、可选地,所述各边缘节点中部署的图像分类模型包括第一在线编码端以及第一目标编码端,所述第一在线编码端与所述第一目标编码端为不对称孪生网络;
9、针对每个边缘节点,该边缘节点确定该边缘节点中部署的图像分类模型的局部编码端的参数,具体包括:
10、该边缘节点针对该边缘节点中预先存储的每个第一样本图像,确定该第一样本图像的第一增强图像;
11、该边缘节点将该第一样本图像输入所述第一在线编码端,得到第一特征,以及将第一增强图像输入所述第一目标编码端,得到第二特征;
12、该边缘节点以最大化所述第一特征以及所述第二特征之间的相似度为目标,对所述第一在线编码端进行训练,得到训练后的第一在线编码端,将训练后的第一在线编码端作为局部编码端;
13、根据所述局部编码端,该边缘节点得到该边缘节点对应的局部编码端的参数。
14、可选地,所述各边缘节点中部署的图像分类模型还包括预测端,所述预测端与所述第一在线编码端连接;
15、该边缘节点将该第一样本图像输入所述第一在线编码端,得到第一特征,具体包括:
16、该边缘节点将该第一样本图像输入所述第一在线编码端,得到第一特征,并将所述第一特征输入所述预测端,得到第三特征;
17、对所述第一在线编码端进行训练,具体包括:
18、该边缘节点以最大化所述第三特征以及所述第二特征之间的相似度为目标,对所述第一在线编码端进行训练。
19、可选地,所述方法还包括:
20、根据第一在线编码端,采用指数移动平均算法,对所述第一目标编码端进行训练。
21、可选地,对该边缘节点对应的局部编码端的参数进行调整,得到该边缘节点对应的调整后的局部编码端的参数,具体包括:
22、将基于所述蒸馏后的全局编码端的参数构建的编码端作为第二目标编码端,将基于该边缘节点对应的局部编码端的参数构建的编码端作为第二在线编码端;并,针对中心节点中预先存储的每个第二样本图像,确定该第二样本图像对应的第二增强图像;
23、将该第二增强图像输入所述第二在线编码端,得到第四特征,以及将该第二样本图像输入所述第二目标编码端,得到第五特征;
24、以最大化所述第四特征以及所述第五特征之间的相似度为目标,调整所述第二在线编码端的参数,得到调整后的第二在线编码端的参数;
25、根据调整后的第二在线编码端的参数,得到该边缘节点对应的调整后的局部编码端的参数。
26、可选地,对中心节点中部署的图像分类模型的全局编码端进行蒸馏,得到蒸馏后的全局编码端的参数,具体包括:
27、将基于各局部编码端的参数构建的编码端作为教师模型,将全局编码端作为学生模型;
28、根据各教师模型以及所述学生模型,确定蒸馏损失;
29、根据所述蒸馏损失,调整所述学生模型的参数,得到参数调整后的学生模型;
30、根据参数调整后的学生模型,得到蒸馏后的全局编码端的参数。
31、可选地,根据各教师模型以及所述学生模型,确定蒸馏损失,具体包括:
32、针对中心节点中预先存储的每个第二样本图像,将该第二样本图像分别输入所述各教师模型,得到各第六特征;并将该第二样本图像输入所述学生模型,得到第七特征,以及将除该第二样本图像之外的各其他样本图像分别输入所述学生模型,得到各第八特征;
33、根据所述各第六特征以及所述第七特征,采用注意力机制,确定该第二样本图像对应的加权聚合特征;
34、根据所述第七特征、所述加权聚合特征以及所述各第八特征,确定蒸馏损失。
35、可选地,调整所述学生模型的参数,具体包括:
36、接收各边缘节点发送的局部损失,所述局部损失为边缘节点训练局部编码端的损失;
37、根据各局部损失以及所述蒸馏损失,调整所述学生模型的参数。
38、本说明书提供了一种语言模型的预训练装置,所述装置位于中心节点,所述装置用于训练图像分类模型中的编码端;所述装置具体包括:
39、获取模块,用于获取各边缘节点中部署的图像分类模型的局部编码端的参数;
40、蒸馏模块,用于根据各局部编码端的参数,对中心节点中部署的图像分类模型的全局编码端进行蒸馏,得到蒸馏后的全局编码端的参数;
41、调整模块,用于针对每个边缘节点,根据所述蒸馏后的全局编码端的参数以及预设训练目标,对该边缘节点对应的局部编码端的参数进行调整,得到该边缘节点对应的调整后的局部编码端的参数;
42、发送模块,用于将得到的该边缘节点对应的调整后的局部编码端的参数发送至该边缘节点,以使该边缘节点根据基于调整后的局部编码端的参数构建的编码端,得到该边缘节点对应的训练好的图像分类模型。
43、可选地,所述各边缘节点中部署的图像分类模型包括第一在线编码端以及第一目标编码端,所述第一在线编码端与所述第一目标编码端为不对称孪生网络;
44、所述获取模块具体用于,该边缘节点针对该边缘节点中预先存储的每个第一样本图像,确定该第一样本图像的第一增强图像;该边缘节点将该第一样本图像输入所述第一在线编码端,得到第一特征,以及将第一增强图像输入所述第一目标编码端,得到第二特征;该边缘节点以最大化所述第一特征以及所述第二特征之间的相似度为目标,对所述第一在线编码端进行训练,得到训练后的第一在线编码端,将训练后的第一在线编码端作为局部编码端;根据所述局部编码端,该边缘节点得到该边缘节点对应的局部编码端的参数。
45、可选地,所述各边缘节点中部署的图像分类模型还包括预测端,所述预测端与所述第一在线编码端连接;
46、所述获取模块具体用于,该边缘节点将该第一样本图像输入所述第一在线编码端,得到第一特征,并将所述第一特征输入所述预测端,得到第三特征;
47、所述获取模块具体用于,该边缘节点以最大化所述第三特征以及所述第二特征之间的相似度为目标,对所述第一在线编码端进行训练。
48、可选地,所述获取模块还用于,根据第一在线编码端,采用指数移动平均算法,对所述第一目标编码端进行训练。
49、可选地,所述调整模块具体用于,将基于所述蒸馏后的全局编码端的参数构建的编码端作为第二目标编码端,将基于该边缘节点对应的局部编码端的参数构建的编码端作为第二在线编码端;并,针对中心节点中预先存储的每个第二样本图像,确定该第二样本图像对应的第二增强图像;将该第二增强图像输入所述第二在线编码端,得到第四特征,以及将该第二样本图像输入所述第二目标编码端,得到第五特征;以最大化所述第四特征以及所述第五特征之间的相似度为目标,调整所述第二在线编码端的参数,得到调整后的第二在线编码端的参数;根据调整后的第二在线编码端的参数,得到该边缘节点对应的调整后的局部编码端的参数。
50、可选地,所述蒸馏模块具体用于,将基于各局部编码端的参数构建的编码端作为教师模型,将全局编码端作为学生模型;根据各教师模型以及所述学生模型,确定蒸馏损失;根据所述蒸馏损失,调整所述学生模型的参数,得到参数调整后的学生模型;根据参数调整后的学生模型,得到蒸馏后的全局编码端的参数。
51、可选地,所述蒸馏模块具体用于,针对中心节点中预先存储的每个第二样本图像,将该第二样本图像分别输入所述各教师模型,得到各第六特征;并将该第二样本图像输入所述学生模型,得到第七特征,以及将除该第二样本图像之外的各其他样本图像分别输入所述学生模型,得到各第八特征;根据所述各第六特征以及所述第七特征,采用注意力机制,确定该第二样本图像对应的加权聚合特征;根据所述第七特征、所述加权聚合特征以及所述各第八特征,确定蒸馏损失。
52、可选地,所述蒸馏模块具体用于,接收各边缘节点发送的局部损失,所述局部损失为边缘节点训练局部编码端的损失;根据各局部损失以及所述蒸馏损失,调整所述学生模型的参数。
53、本说明书提供了一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述图像分类模型的训练方法。
54、本说明书提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述图像分类模型的训练方法。
55、本说明书采用的上述至少一个技术方案能够达到以下有益效果:
56、从本说明书提供的图像分类模型的训练方法中可以看出,在各边缘节点训练适用于自身的图像分类模型时,中心节点可采用知识蒸馏技术,根据获取到的各边缘节点的图像分类模型的局部编码端的参数,调整中心节点中的全局编码端的参数,得到调整后的全局编码端的参数。进而中心节点针对每个边缘节点,根据调整后的全局编码端以及该边缘节点的局部编码端,调整该边缘节点的局部编码端的参数,并将调整后的局部编码端的参数发送至该边缘节点。该方法中,中心节点不直接将调整后的全局编码端的参数发送至边缘节点,而是基于调整后的全局编码端,再对局部编码端进行调整,从而将调整后的局部编码端的参数发送至边缘节点,在提高了各边缘节点训练出的图像分类模型的编码端的特征表示能力的同时,实现了在各边缘节点共享数据的情况下,各边缘节点得到适应于自身计算资源配置的图像分类模型,也即在各边缘节点可在数据共享的情况下,即便每个边缘节点的图像分类模型的架构不同,每个边缘节点仍旧能够得到适应于自身模型架构的图像分类模型。