本说明书一个或多个实施例涉及模型训练,尤其涉及一种模型训练方法及装置、电子设备及存储介质。
背景技术:
1、随着人工智能技术和大数据技术的不断发展,各种自动化、智能化的服务越来越多,这些服务给用户带来了非常好的使用体验,尤其是给不同用户带来了针对性的使用体验。这些服务的质量取决于机器学习中网络模型的训练效果。相关技术中在训练网络模型时,终端设备采集用户数据并上传至云端,云端利用用户数据进行模型训练和模型下发,但是相关技术中的上述训练方式所取得的训练效果还有待提高。
技术实现思路
1、有鉴于此,本说明书一个或多个实施例提供一种模型训练方法及装置、电子设备及存储介质。
2、为实现上述目的,本说明书一个或多个实施例提供技术方案如下:
3、根据本说明书一个或多个实施例的第一方面,提出了一种模型训练方法,所述方法包括:
4、接收多个梯度数据,其中,所述梯度数据由终端根据样本数据和目标模型生成,并发送至边缘cdn节点;
5、对所述多个梯度数据进行聚合处理,得到第一聚合结果;
6、将所述第一聚合结果发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端。
7、在本说明书的一个实施例中,所述梯度数据在发送至边缘cdn节点之前由终端进行压缩处理;
8、在所述对所述多个梯度数据进行聚合处理,得到第一聚合结果之前,所述方法还包括:
9、分别对所述多个梯度数据中每个梯度数据进行解压缩处理。
10、在本说明书的一个实施例中,所述梯度数据在发送至cdn节点之前由终端进行同态加密处理;
11、所述对所述多个梯度数据进行聚合处理,得到第一聚合结果,包括:
12、对多个经过同态加密处理的梯度数据进行聚合处理,得到第一聚合结果;
13、所述将所述第一聚合结果发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,包括:
14、将所述第一聚合结果发送至云端,以使所述云端根据第一聚合结果的解密结果更新所述目标模型。
15、在本说明书的一个实施例中,所述对所述多个梯度数据进行聚合处理,得到第一聚合结果,包括:
16、将所述多个梯度数据进行聚合平均,并将得到的平均值确定为所述第一聚合结果。
17、在本说明书的一个实施例中,所述梯度数据包括梯度和模型版本;
18、所述对所述多个梯度数据进行聚合处理,得到第一聚合结果,包括:
19、将所述多个梯度数据中模型版本为最新版本的梯度数据进行聚合处理,得到第一聚合结果;
20、所述将所述第一聚合结果发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端,包括:
21、将所述第一聚合结果发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据中模型版本为最新版本的梯度数据的终端。
22、在本说明书的一个实施例中,所述对所述多个梯度数据进行聚合处理,得到第一聚合结果,包括:
23、响应于所述多个梯度数据的数量达到预设数量阈值,对所述多个梯度数据进行聚合处理,得到第一聚合结果。
24、在本说明书的一个实施例中,所述将所述第一聚合结果发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端,包括:
25、将所述第一聚合结果经过压缩处理后发送至云端,以使所述云端根据解压缩处理后的所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端。
26、在本说明书的一个实施例中,所述梯度数据包括梯度和终端标识;
27、所述将所述第一聚合结果发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端,包括:
28、将所述第一聚合结果和标识列表发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至所述标识列表内指示的终端,其中,所述标识列表包括所述多个梯度数据中每个梯度数据中的终端标识。
29、根据本说明书一个或多个实施例的第二方面,提出了一种模型训练方法,所述方法包括:
30、接收第一聚合结果,其中,所述第一聚合结果由边缘cdn节点对多个梯度数据进行聚合处理得到并发送至云端,所述梯度数据由终端根据样本数据和目标模型生成并发送至边缘cdn节点;
31、根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端。
32、在本说明书的一个实施例中,所述接收第一聚合结果,包括:
33、接收多个第一聚合结果,其中,所述多个第一聚合结果由至少一个边缘cdn节点生成并发送至云端;
34、所述根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端,包括:
35、对所述多个第一聚合结果进行聚合处理,得到第二聚合结果;
36、根据所述第二聚合结果更新所述目标模型,并将更新后的所述目标模型发送至所述多个第一聚合结果对应的终端。
37、在本说明书的一个实施例中,所述第一聚合结果由边缘cdn节点对多个经过同态加密的梯度数据进行聚合处理得到并发送至云端;
38、所述根据所述第二聚合结果更新所述目标模型,包括:
39、根据同态解密后的所述第二聚合结果更新所述目标模型。
40、在本说明书的一个实施例中,所述多个第一聚合结果由多个边缘cdn节点生成并发送至云端,所述多个边缘cdn节点属于至少一个区域;
41、所述接收第一聚合结果,还包括:
42、接收所述多个第一聚合结果中每个第一聚合结果对应节点标识,其中,所述节点标识包括生成所述第一聚合结果的边缘cdn节点的标识;
43、所述对所述多个第一聚合结果进行聚合处理,得到第二聚合结果,包括:
44、根据所述多个第一聚合结果中每个第一聚合结果对应的节点标识,对所述多个第一聚合结果进行分组,得到每个区域的至少一个第一聚合结果;
45、分别对每个区域的至少一个第一聚合结果进行聚合处理,得到每个区域的第二聚合结果;
46、所述根据所述第二聚合结果更新所述目标模型,并将更新后的所述目标模型发送至所述多个第一聚合结果对应的终端,包括:
47、对于每个区域,根据所述区域的第二聚合结果更新所述区域对应的目标模型,并将更新后的目标模型发送至所述区域的至少一个第一聚合结果对应的终端。
48、在本说明书的一个实施例中,所述接收第一聚合结果,还包括:
49、接收所述多个第一聚合结果中每个第一聚合结果对应的标识列表,其中,所述标识列表包括所述第一聚合结果对应的多个梯度数据中的终端标识;
50、所述将更新后的目标模型发送至所述多个第一聚合结果对应的终端,包括:
51、根据所述多个第一聚合结果中每个第一聚合结果对应的标识列表,将更新后的目标模型发送至所述多个第一聚合结果对应的终端。
52、在本说明书的一个实施例中,所述第一聚合结果由边缘cdn节点对多个梯度数据中模型版本为最新版本的梯度数据进行聚合处理得到并发送至云端;
53、所述根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端,包括:
54、将更新后的所述目标模型发送至生成所述多个梯度数据中模型版本为最新版本的梯度数据的终端。
55、在本说明书的一个实施例中,所述方法还包括:
56、将最新版本的所述目标模型的原始模型发送至生成所述多个梯度数据中模型版本非最新版本的梯度数据的终端。
57、根据本说明书一个或多个实施例的第三方面,提出了一种模型训练装置,所述装置包括:
58、第一接收模块,用于接收多个梯度数据,其中,所述梯度数据由终端根据样本数据和目标模型生成,并发送至边缘cdn节点;
59、聚合模块,用于对所述多个梯度数据进行聚合处理,得到第一聚合结果;
60、第一更新模块,用于将所述第一聚合结果发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端。
61、在本说明书的一个实施例中,所述梯度数据在发送至边缘cdn节点之前由终端进行压缩处理;
62、所述装置还包括解压模块,用于:
63、在所述对所述多个梯度数据进行聚合处理,得到第一聚合结果之前,分别对所述多个梯度数据中每个梯度数据进行解压缩处理。
64、在本说明书的一个实施例中,所述梯度数据在发送至cdn节点之前由终端进行同态加密处理;
65、所述聚合模块具体用于:
66、对多个经过同态加密处理的梯度数据进行聚合处理,得到第一聚合结果;
67、所述第一更新模块用于将所述第一聚合结果发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型时,具体用于:
68、将所述第一聚合结果发送至云端,以使所述云端根据第一聚合结果的解密结果更新所述目标模型。
69、在本说明书的一个实施例中,所述聚合模块具体用于:
70、将所述多个梯度数据进行聚合平均,并将得到的平均值确定为所述第一聚合结果。
71、在本说明书的一个实施例中,所述梯度数据包括梯度和模型版本;
72、所述聚合模块具体用于:
73、将所述多个梯度数据中模型版本为最新版本的梯度数据进行聚合处理,得到第一聚合结果;
74、所述第一更新模块具体用于:
75、将所述第一聚合结果发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据中模型版本为最新版本的梯度数据的终端。
76、在本说明书的一个实施例中,所述聚合模块具体用于:
77、响应于所述多个梯度数据的数量达到预设数量阈值,对所述多个梯度数据进行聚合处理,得到第一聚合结果。
78、在本说明书的一个实施例中,所述第一更新模块具体用于:
79、将所述第一聚合结果经过压缩处理后发送至云端,以使所述云端根据解压缩处理后的所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端。
80、在本说明书的一个实施例中,所述梯度数据包括梯度和终端标识;
81、所述第一更新模块具体用于:
82、将所述第一聚合结果和标识列表发送至云端,以使所述云端根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至所述标识列表内指示的终端,其中,所述标识列表包括所述多个梯度数据中每个梯度数据中的终端标识。
83、根据本说明书一个或多个实施例的第四方面,提出了一种模型训练装置,所述装置包括:
84、第二接收模块,用于接收第一聚合结果,其中,所述第一聚合结果由边缘cdn节点对多个梯度数据进行聚合处理得到并发送至云端,所述梯度数据由终端根据样本数据和目标模型生成并发送至边缘cdn节点;
85、第二更新模块,用于根据所述第一聚合结果更新所述目标模型,并将更新后的所述目标模型发送至生成所述多个梯度数据的终端。
86、在本说明书的一个实施例中,所述第二接收模块具体用于:
87、接收多个第一聚合结果,其中,所述多个第一聚合结果由至少一个边缘cdn节点生成并发送至云端;
88、所述第二更新模块具体用于:
89、对所述多个第一聚合结果进行聚合处理,得到第二聚合结果;
90、根据所述第二聚合结果更新所述目标模型,并将更新后的所述目标模型发送至所述多个第一聚合结果对应的终端。
91、在本说明书的一个实施例中,所述第一聚合结果由边缘cdn节点对多个经过同态加密的梯度数据进行聚合处理得到并发送至云端;
92、所述第二更新模块用于根据所述第二聚合结果更新所述目标模型时,具体用于:
93、根据同态解密后的所述第二聚合结果更新所述目标模型。
94、在本说明书的一个实施例中,所述多个第一聚合结果由多个边缘cdn节点生成并发送至云端,所述多个边缘cdn节点属于至少一个区域;
95、所述第二接收模块还用于:
96、接收所述多个第一聚合结果中每个第一聚合结果对应节点标识,其中,所述节点标识包括生成所述第一聚合结果的边缘cdn节点的标识;
97、所述第二更新模块用于对所述多个第一聚合结果进行聚合处理,得到第二聚合结果时,具体用于:
98、根据所述多个第一聚合结果中每个第一聚合结果对应的节点标识,对所述多个第一聚合结果进行分组,得到每个区域的至少一个第一聚合结果;
99、分别对每个区域的至少一个第一聚合结果进行聚合处理,得到每个区域的第二聚合结果;
100、所述第二更新模块用于根据所述第二聚合结果更新所述目标模型,并将更新后的所述目标模型发送至所述多个第一聚合结果对应的终端时,具体用于:
101、对于每个区域,根据所述区域的第二聚合结果更新所述区域对应的目标模型,并将更新后的目标模型发送至所述区域的至少一个第一聚合结果对应的终端。
102、在本说明书的一个实施例中,所述第二接收模块还用于:
103、接收所述多个第一聚合结果中每个第一聚合结果对应的标识列表,其中,所述标识列表包括所述第一聚合结果对应的多个梯度数据中的终端标识;
104、所述第二更新模块用于将更新后的目标模型发送至所述多个第一聚合结果对应的终端时,具体用于:
105、根据所述多个第一聚合结果中每个第一聚合结果对应的标识列表,将更新后的目标模型发送至所述多个第一聚合结果对应的终端。
106、在本说明书的一个实施例中,所述第一聚合结果由边缘cdn节点对多个梯度数据中模型版本为最新版本的梯度数据进行聚合处理得到并发送至云端;
107、所述第二更新模块用于将更新后的所述目标模型发送至生成所述多个梯度数据的终端时,具体用于:
108、将更新后的所述目标模型发送至生成所述多个梯度数据中模型版本为最新版本的梯度数据的终端。
109、在本说明书的一个实施例中,所述装置还包括版本模块,用于:
110、将最新版本的所述目标模型的原始模型发送至生成所述多个梯度数据中模型版本非最新版本的梯度数据的终端。
111、根据本说明书一个或多个实施例的第三方面,提出了一种电子设备,包括:
112、处理器;
113、用于存储处理器可执行指令的存储器;
114、其中,所述处理器通过运行所述可执行指令以实现如第一方面或第二方面所述的方法。
115、根据本说明书一个或多个实施例的第四方面,提出了一种计算机可读存储介质,其上存储有计算机指令,该指令被处理器执行时实现如第一方面或第二方面所述方法的步骤。
116、本说明书的实施例提供的技术方案可以包括以下有益效果:
117、本说明书实施例所提供的模型训练方法中,终端根据样本数据和目标模型生成梯度数据并发送至cdn节点,cdn节点对接收到的多个梯度数据进行聚合处理,并将得到的第一聚合结果发送至云端,云端根据接收的第一聚合结果更新目标模型,并将更新后的目标模型下发至生成上述多个梯度数据的终端。该方法结合终端、边缘cnd节点和云端三侧来完成目标模型的训练,从至少以下三个方面提高了目标模型的训练效果:由于在终端产生梯度数据,即终端的样本数据始终保持在本地,因此保护了用户数据的安全,避免用户隐私的泄露,而且终端能够利用本地的所有样本数据(包括用户的隐私数据),使得目标模型能够学习到全面的知识;由于多个梯度数据的聚合处理在边缘cdn节点完成,且只向云端上传第一聚合结果,从而减低云端的数据量,降低云端的带宽成本,提高云端的模型更新频率;由于目标模型的更新基于多个梯度数据聚合处理得到的第一聚合结果,从而使每个终端可以共享其他终端的数据,即能够得到基于全面、多样的数据集训练的目标模型,从而提高目标模型的准确性和鲁棒性。