分布式机器学习通信优化方法、装置、服务器及终端设备

文档序号:25880366发布日期:2021-07-16 18:29阅读:123来源:国知局
分布式机器学习通信优化方法、装置、服务器及终端设备

1.本申请属于机器学习技术领域,尤其涉及一种分布式机器学习通信优化方法、装置、服务器及终端设备。


背景技术:

2.现有的分布式机器学习方法,在一次迭代过程中,每一个计算节点负责执行训练并计算梯度参数,并执行push()操作(即推送操作),将梯度参数推送到服务器节点;服务器节点在接收到梯度参数后,根据更新公式计算更新参数,并执行pull()操作(即拉取操作),将更新参数拉取到每一个计算节点进行本地更新;由于节点中模型的深度越来越深,服务器节点与计算节点之间需要推送和拉取的参数数量急剧增加,导致网络开销较大、通信量较大,从而造成通信延迟,影响训练速度。现有技术是通过对梯度参数或权重参数做量化处理,从而压缩梯度,以减少节点间传输的数据量;但是上述方式需要增加量化处理,且在量化处理中需要利用复杂的量化、变换公式,导致压缩梯度过程较为复杂,不易于实现,压缩结果具有局限性,造成优化通信效果不佳。


技术实现要素:

3.本申请实施例提供了一种分布式机器学习通信优化方法、装置、服务器及终端设备,可以解决现有通信优化方法优化效果不佳的问题。
4.第一方面,本申请实施例提供了一种分布式机器学习通信优化方法,所述分布式机器学习通信优化方法包括:
5.获取多个计算节点的梯度;
6.根据所述多个计算节点的梯度,更新服务器节点的全局模型的参数;
7.获取所述全局模型更新后的参数;
8.从所述多个计算节点中选取预设个数的计算节点;
9.将第一标志位信息和所述全局模型更新后的参数输出给所述预设个数的计算节点,其中,所述第一标志位信息用于指示所述预设个数的计算节点根据所述全局模型更新后的参数更新本地模型;
10.将第二标志位信息和所述全局模型更新后的参数输出给剩余计算节点,其中,所述剩余计算节点是指所述多个计算节点中除所述预设个数的计算节点之外的计算节点,所述第二标志位信息用于指示所述剩余计算节点不接收所述服务器输出的所述全局模型更新后的参数,并指示所述剩余计算节点根据自身的梯度更新本地模型。
11.第二方面,本申请实施例提供了一种分布式机器学习通信优化方法,所述分布式机器学习通信优化方法包括:
12.获取计算节点的梯度;
13.将所述计算节点的梯度发送给服务器,其中,所述计算节点的梯度用于指示所述服务器根据接收到的多个计算节点的梯度更新服务器中服务器节点的全局模型的参数,并
向所述计算节点反馈标志位信息;
14.接收所述标志位信息;
15.根据所述标志位信息,获取目标参数;
16.根据所述目标参数更新所述计算节点的本地模型。
17.第三方面,本申请实施例提供了一种分布式机器学习通信优化装置,所述分布式机器学习通信优化装置包括:
18.梯度获取模块,用于获取多个计算节点的梯度;
19.参数更新模块,用于根据所述多个计算节点的梯度,更新服务器节点的全局模型的参数;
20.参数获取模块,用于获取所述全局模型更新后的参数;
21.节点选择模块,用于从所述多个计算节点中选取预设个数的计算节点;
22.第一输出模块,用于将第一标志位信息和所述全局模型更新后的参数输出给所述预设个数的计算节点,其中,所述第一标志位信息用于指示所述预设个数的计算节点根据所述全局模型更新后的参数更新本地模型;
23.第二输出模块,用于将第二标志位信息和所述全局模型更新后的参数输出给剩余计算节点,其中,所述剩余计算节点是指所述多个计算节点中除所述预设个数的计算节点之外的计算节点,所述第二标志位信息用于指示所述剩余计算节点不接收所述服务器输出的所述全局模型更新后的参数,并指示所述剩余计算节点根据自身的梯度更新本地模型。
24.第四方面,本申请实施例提供了一种分布式机器学习通信优化装置,所述分布式机器学习通信优化装置包括:
25.梯度获取模块,用于获取计算节点的梯度;
26.梯度输出模块,用于将所述计算节点的梯度发送给服务器,其中,所述计算节点的梯度用于指示所述服务器根据接收到的多个计算节点的梯度更新服务器中服务器节点的全局模型的参数,并向所述计算节点反馈标志位信息;
27.标志位接收模块,用于接收所述标志位信息;
28.目标参数获取模块,用于根据所述标志位信息,获取目标参数;
29.本地模型更新模块,用于根据所述目标参数更新所述计算节点的本地模型。
30.第五方面,本申请实施例提供了一种服务器,所述服务器包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如上述第一方面所述的分布式机器学习通信优化方法。
31.第六方面,本申请实施例提供了一种终端设备,所述终端设备包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如上述第二方面所述的分布式机器学习通信优化方法。
32.第七方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如上述第一方面所述的分布式机器学习通信优化方法。
33.第八方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如上述第二方面所述的分布式机器学习通信优化方法。
34.第九方面,本申请实施例提供了一种计算机程序产品,当计算机程序产品在服务器上运行时,使得所述服务器执行如上述第一方面所述的分布式机器学习通信优化方法。
35.第十方面,本申请实施例提供了一种计算机程序产品,当计算机程序产品在终端设备上运行时,使得所述终端设备执行如上述第二方面所述的分布式机器学习通信优化方法。
36.本申请实施例与现有技术相比存在的有益效果是:本申请在服务器节点更新过全局模型后,选取一部分的计算节点发送相应的标志位,使该部分计算节点从服务器节点上拉取全局模型更新后的参数,剩余的计算节点使用自身的梯度更新本地模型的参数,实现过程较为简单,有效减少了分布式机器学习模型的计算节点与服务器节点之间的拉取通信量,减少了交互通信开销,提高了分布式机器学习的通信效率,降低了在特殊场景下的计算节点的通信花费,例如边缘场景,收费网络,跨区域网络等场景下一部分计算节点进行拉取通信,另一部分计算节点不进行拉取通信,相比全部节点进行拉取通信,减少了计算节点的通信开销。
附图说明
37.为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
38.图1是本申请实施例一提供的一种分布式机器学习通信优化方法的流程示意图;
39.图2是本申请实施例一提供的一种分布式机器学习通信优化方法的原理示例图;
40.图3是本申请实施例二提供的一种分布式机器学习通信优化方法的流程示意图;
41.图4是本申请实施例三提供的一种分布式机器学习通信优化方法的交互示意图;
42.图5是本申请实施例三提供的一种分布式机器学习通信优化方法的采用mnistcnn模型实验的收敛效率图;
43.图6是本申请实施例三提供的一种分布式机器学习通信优化方法的采用mnistcnn模型实验的通信率对比图;
44.图7是本申请实施例三提供的一种分布式机器学习通信优化方法的采用alexnet模型实验的收敛效率图;
45.图8是本申请实施例三提供的一种分布式机器学习通信优化方法的采用alexnet模型实验的通信率对比图;
46.图9是本申请实施例四提供的一种分布式机器学习通信优化装置的结构示意图;
47.图10是本申请实施例五提供的一种分布式机器学习通信优化装置的结构示意图;
48.图11是本申请实施例六提供的服务器的结构示意图;
49.图12是本申请实施例七提供的终端设备的结构示意图。
具体实施方式
50.以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体
细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。
51.应当理解,当在本申请说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
52.还应当理解,在本申请说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
53.如在本申请说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当...时”或“一旦”或“响应于确定”或“响应于检测到”。类似地,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述条件或事件]”或“响应于检测到[所描述条件或事件]”。
[0054]
另外,在本申请说明书和所附权利要求书的描述中,术语“第一”、“第二”、“第三”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
[0055]
在本申请说明书中描述的参考“一个实施例”或“一些实施例”等意味着在本申请的一个或多个实施例中包括结合该实施例描述的特定特征、结构或特点。由此,在本说明书中的不同之处出现的语句“在一个实施例中”、“在一些实施例中”、“在其他一些实施例中”、“在另外一些实施例中”等不是必然都参考相同的实施例,而是意味着“一个或多个但不是所有的实施例”,除非是以其他方式另外特别强调。术语“包括”、“包含”、“具有”及它们的变形都意味着“包括但不限于”,除非是以其他方式另外特别强调。
[0056]
本申请实施例提供了一种分布式机器学习通信优化方法可以应用于桌上型计算机、笔记本电脑、超级移动个人计算机(ultra-mobile personal computer,umpc)、上网本、云端服务器、个人数字助理(personal digital assistant,pda)等终端设备上,本申请实施例对终端设备的具体类型不作任何限制。
[0057]
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本申请实施例的实施过程构成任何限定。
[0058]
为了说明本申请所述的技术方案,下面通过具体实施例来进行说明。
[0059]
参见图1,是本申请实施例一提供的一种分布式机器学习通信优化方法的流程,该分布式机器学习通信优化方法可应用于服务器,本申请实施例一是针对服务器侧的改进,如图所示,该分布式机器学习通信优化方法可以包括以下步骤:
[0060]
步骤s101,获取多个计算节点的梯度。
[0061]
参数服务器(parameter server)是一种分布式机器学习的架构,包括服务器节点(server)和计算节点(worker),其中,服务器中设置有服务器节点,每个终端设备中设置有一个计算节点,服务器与多个终端设备连接,终端设备中计算节点将计算的梯度发送给该服务器,服务器接收与其相连接所有终端设备的计算节点发送的梯度;其中,在终端设备的计算节点中,本地模型的训练目标就是使损失函数最小,一般采用梯度下降的方法求解,根据该计算节点的本地模型和该计算节点的训练数据集进行训练得到梯度,即为该计算节点计算的梯度。
[0062]
步骤s102,根据所述多个计算节点的梯度,更新服务器节点的全局模型的参数。
[0063]
全局模型是存储在服务器节点上的机器学习模型,服务器节点的全局模型可以包括结构和参数,全局模型更新前的参数为当前参数。
[0064]
可选的是,所述根据所述多个计算节点的梯度,更新服务器节点的全局模型的参数包括:
[0065]
根据所述多个计算节点的梯度和预设参数更新公式,更新所述服务器节点的全局模型的参数,其中,所述预设参数更新公式为:
[0066][0067]
其中,η为学习率,p为总的计算节点数,g
i
为第i个计算节点的梯度,ω

是所述全局模型的当前参数,ω是所述全局模型更新后的参数。
[0068]
步骤s103,获取所述全局模型更新后的参数。
[0069]
其中,更新服务器节点的全局模型的参数,即可得到更新的全局模型,其中,更新后的全局模型中结构没有改变,更新后的全局模型中参数与更新前的全局模型中的参数不同,获取全局模型更新后的参数,为后续服务器节点与计算节点的通信做准备。
[0070]
步骤s104,从所述多个计算节点中选取预设个数的计算节点。
[0071]
服务器从与其相连的多个终端设备中选取预设个数的终端设备,也就是从多个计算节点中选取预设个数的计算节点,服务器可以获取和识别输入的预设个数,根据该预设个数和相应的选择方法选取计算节点,其中,相应的选择方法可根据实际需求设定,例如,随机选取法。
[0072]
可选的是,所述从所述多个计算节点中选取预设个数的计算节点包括:
[0073]
从所述多个计算节点中随机选取所述预设个数的计算节点。
[0074]
例如,本申请实施例一的分布式机器学习采用一个服务器节点和四个计算节点,参见图2,是本申请实施例一提供的一种分布式机器学习通信优化方法的原理示例图,w1、w2、w3、w4分别为四个计算节点,w1、w2、w3、w4发送梯度(push gradient)给服务器节点(server),服务器节点更新全局模型(update global model),服务器节点随机选择计算节点w1和w4,那么显然,计算节点w2和w3即为剩余节点。
[0075]
步骤s105,将第一标志位信息和所述全局模型更新后的参数输出给所述预设个数的计算节点。
[0076]
其中,所述第一标志位信息用于指示所述预设个数的计算节点根据所述全局模型更新后的参数更新本地模型。
[0077]
本地模型是存储在计算节点上的机器学习模型,计算节点的本地模型可以包括结构和参数。
[0078]
服务器节点将第一标志位信息输出给选取的计算节点,这些计算节点根据第一标志位信息判定接收服务器节点输出的全局模型更新后的参数,该第一标志位信息可以是一个比特(bit),例如,“1”,参见图2,服务器节点将“1”输出给计算节点w1和w4,计算节点w1和w4根据“1”判定执行接收函数,以接收服务器节点输出的全局模型更新后的参数,其中,接收函数为计算节点中用于执行拉取(pull)动作的函数,拉取即为将参数从服务器节点拉到计算节点;随后,计算节点w1和w4根据该全局模型更新后的参数更新本地模型,如利用该全
局模型更新后的参数替换本地模型中的当前参数。
[0079]
步骤s106,将第二标志位信息和所述全局模型更新后的参数输出给剩余计算节点。
[0080]
其中,所述剩余计算节点是指所述多个计算节点中除所述预设个数的计算节点之外的计算节点,所述第二标志位信息用于指示所述剩余计算节点不接收所述服务器输出的所述全局模型更新后的参数,并指示所述剩余计算节点根据自身的梯度更新本地模型。
[0081]
服务器节点将第二标志位信息输出给剩余计算节点,这些剩余计算节点根据第二标志位信息判定不接收服务器节点输出的全局模型更新后的参数,该第二标志位信息可以是一个比特(bit),例如,“0”,参见图2,服务器节点将“0”输出给计算节点w2和w3,计算节点w2和w3根据“0”判定不执行接收函数,不接收服务器节点输出的全局模型更新后的参数;随后,计算节点w2和w3根据自身的梯度更新对应的本地模型,如计算节点w2将本地模型中的当前参数减去自身的梯度得到新的参数,用该新的参数替换本地模型中的当前参数。
[0082]
本申请实施例在服务器节点更新过全局模型后,选取一部分的计算节点对应的终端设备发送相应的标志位,使该部分计算节点从服务器节点上拉取全局模型更新后的参数,剩余的计算节点对应的终端设备使用自身的梯度更新本地模型的参数,通过将标志位发给计算节点的方式,简单实现部分计算节点的拉取通信,有助于解决全部计算节点拉取通信导致的通信花费、开销较大的问题。
[0083]
参见图3,是本申请实施例二提供的一种分布式机器学习通信优化方法的流程,该分布式机器学习通信优化方法可应用于终端设备,本申请实施例二是针对终端设备侧的改进,如图所示,该分布式机器学习通信优化方法可以包括以下步骤:
[0084]
步骤s301,获取计算节点的梯度。
[0085]
参数服务器(parameter server)是一种分布式机器学习的架构,包括服务器节点(server)和计算节点(worker),其中,服务器中设置有服务器节点,每个终端设备中设置有一个计算节点,服务器与多个终端设备连接,任一终端设备中计算节点根据该计算节点的本地模型和该计算节点的训练数据集进行训练得到该计算节点的梯度。
[0086]
可选的是,当服务器节点采用随机选取法选取计算节点时,所有计算节点采用的训练数据集为具有相同概率分布的数据集,该类数据集的数据点一般认为是从一个概率分布中采样得到,例如高斯分布。
[0087]
步骤s302,将所述计算节点的梯度发送给服务器。
[0088]
其中,所述计算节点的梯度用于指示所述服务器根据接收到的多个计算节点的梯度更新服务器中服务器节点的全局模型的参数,并向所述计算节点反馈标志位信息。
[0089]
每一个终端设备均与服务器连接,每一个终端设备在获取了自身的计算节点的梯度后将该计算节点的梯度发送给服务器;服务器接收到所有计算节点的梯度后,更新服务器中服务器节点的全局模型参数,在更新完全局模型参数后向计算节点反馈标志位信息,实现服务器与终端设备的交互。
[0090]
步骤s303,接收标志位信息。
[0091]
终端设备在将计算节点的梯度发送给服务器后,等待接收服务器反馈的标志位信息,该标志位信息可以是一个比特(bit),例如,“0”或“1”,根据不同的标志位信息,计算节点执行不同的步骤。
[0092]
步骤s304,根据所述标志位信息,获取目标参数。
[0093]
可选的是,所述根据所述标志位信息,获取目标参数包括:
[0094]
若所述标志位信息为第一标志位信息,则接收所述服务器反馈的所述全局模型更新后的参数,并确定该参数为所述目标参数;
[0095]
若所述标志位信息为第二标志位信息,则拒绝接收所述服务器反馈的所述全局模型更新后的参数,并确定所述计算节点的梯度为所述目标参数;
[0096]
其中,所述全局模型更新后的参数是所述服务器更新所述服务器节点的全局模型的参数后反馈的。
[0097]
例如,参见图2,计算节点w1接收标志位信息为“1”,执行接收函数,接收服务器反馈的全局模型更新后的参数,以该全局模型更新后的参数作为计算节点w1的目标参数;计算节点w2接收标志位信息为“0”,不执行接收函数,拒绝接收服务器反馈的全局模型更新后的参数,另外,获取该计算节点w2的当前梯度,以该当前梯度作为该计算节点w2的目标参数。
[0098]
步骤s305,根据所述目标参数更新所述计算节点的本地模型。
[0099]
本地模型是存储在计算节点上的机器学习模型,计算节点的本地模型可以包括结构和参数。根据目标参数更新计算节点的本地模型,本质上是更新计算节点的本地模型中的参数,对于计算节点的本地模型的结构不做更新。
[0100]
可选的是,所述根据所述目标参数更新所述计算节点的本地模型包括:
[0101]
获取所述计算节点的本地模型的当前参数;
[0102]
若所述标志位信息为第一标志位信息,则将所述计算节点的本地模型的当前参数替换为所述目标参数,所述目标参数为所述计算节点的本地模型更新后的参数;
[0103]
若所述标志位信息为第二标志位信息,则将所述计算节点的本地模型的当前参数减去所述目标参数,所得差值为所述计算节点的本地模型更新后的参数。
[0104]
其中,若计算节点接收的标志位信息为第一标志位信息,则该计算节点接收服务器反馈的全局模型更新后的参数,并确定该参数为目标参数,获取该计算节点的本地模型的当前参数,利用该全局模型更新后的参数(即目标参数)替换该计算节点的本地模型中的当前参数,该全局模型更新后的参数即为该计算节点的本地模型更新后的参数。
[0105]
若计算节点接收的标志位信息为第二标志位信息,则该计算节点拒绝接收服务器反馈的全局模型更新后的参数,并确定该计算节点的梯度为目标参数,获取该计算节点的本地模型的当前参数,将该计算节点的本地模型的当前参数减去该计算节点的梯度(即目标参数),所得差值为该计算节点的本地模型更新后的参数。
[0106]
本申请实施例计算节点根据标志位,一部分选择从服务器节点上拉取全局模型更新后的参数,剩余部分的计算节点使用自身的梯度更新本地模型的参数,通过选择性的拉取,明显降低了计算节点从服务器节点拉取参数的通信开销,从而降低了在特殊场景下的计算节点的通信花费,例如边缘场景,收费网络,跨区域网络等场景。
[0107]
参见图4,是本申请实施例三的一种分布式机器学习通信优化方法的交互示意图,为了便于说明,仅示出了与本申请实施例相关的部分。
[0108]
该分布式机器学习通信优化方法在服务器与多个终端设备之间进行交互,如图4所示,第一终端设备和第二终端设备将其计算节点发送给服务器;服务器根据多个计算节
点的梯度更新服务器中服务器节点的全局模型的参数,服务器选取第一终端设备发送第一标志位信息,选取第二终端设备发送第二标志位信息;第一终端设备根据第一标志位信息接收全局模型更新后的参数,并根据全局模型更新后的参数更新第一终端设备中计算节点的本地模型;第二终端设备根据第二标志位信息不接收全局模型更新后的参数,并根据第二终端设备中计算节点的梯度更新该计算节点的本地模型。
[0109]
第一终端设备将更新后的本地模型与对应的训练数据集结合训练得到新的梯度,第二终端设备将更新后的本地模型与对应的训练数据集结合训练得到新的梯度,第一终端设备和第二终端设备都将新的梯度发送给服务器,循环执行上述交互,直至收敛或达到迭代次数。
[0110]
参见图5至图8,是本申请实施例提供的分布式机器学习通信优化方法的实验结果,实验运行在10个计算节点和1个服务器节点的架构上。如图5和图6所示,全局模型和本地模型采用的结构均为mnistcnn模型的结构,mnistcnn模型是利用mnist数据集进行训练的卷积神经网络模型;本地模型中执行训练所使用的训练数据集是mnist数据集,mnist数据集由60000个训练样本和10000个测试样本组成,每个样本都是一张28
×
28像素的灰度手写数字图片,其中,图5表示收敛效率,即纵坐标为训练损失,横坐标为迭代次数,全部拉取表示每一个计算节点都进行拉取参数的情况(即不进行随机选取计算节点),拉取概率-0.1表示每一个计算节点进行拉取参数的拉取概率为0.1的情况,拉取概率-0.2表示每一个计算节点进行拉取参数的拉取概率为0.2的情况,拉取概率-0.4表示每一个计算节点进行拉取参数的拉取概率为0.4的情况,拉取概率-0.8表示每一个计算节点进行拉取参数的拉取概率为0.8的情况,图5中1、2、3、4、5分别表示全部拉取、拉取概率-0.1、拉取概率-0.2、拉取概率-0.4、拉取概率-0.8在训练时对应的训练损失,其中,由于2和3对应的拉取概率相差较小,两者对应的训练损失相差较小,因此,表示2和3的训练损失的线条出现了重合,相比而言,同一迭代次数下2对应的训练损失低于3对应的训练损失;图6表示不同拉取概率的压缩率,压缩率为不同拉取概率达到图5中虚线对应的训练损失时的迭代次数乘以对应拉取概率,再与全部拉取达到图5中虚线对应的训练损失时的迭代次数的比值,r=0.8表示拉取概率为0.8,r=0.4、r=0.2、r=0.1分别表示拉取概率为0.4、0.2、0.1,其中,2.1
×
、8.6
×
、20.2
×
、48.3
×
均为压缩倍数,2.1
×
表示2.1倍,压缩倍数为压缩率的倒数,压缩倍数越大,通信优化效果越好。
[0111]
如图7和图8所示,全局模型和本地模型采用的结构均为alexnet模型的结构,alexnet模型是一种深层的卷积神经网络模型,特点是使用了relu(rectified linear units)激励函数和多个gpu训练等;本地模型中执行训练所使用的训练数据集是cifar10数据集,cifar10数据集由10个类的60000个32
×
32彩色图像组成,每个类有6000个图像,有50000个训练图像和10000个测试图像,其中,图7表示收敛效率,图7中1、2、3、4、5分别表示全部拉取、拉取概率-0.1、拉取概率-0.2、拉取概率-0.4、拉取概率-0.8在训练时对应的训练损失,其中,由于2和3对应的拉取概率相差较小,两者对应的训练损失相差较小,因此,表示2和3的训练损失的线条出现了重合,相比而言,同一迭代次数下2对应的训练损失低于3对应的训练损失;图8表示不同拉取概率的压缩率,其中,1.8
×
、4.0
×
、7.9
×
、16.2
×
均为压缩倍数,1.8
×
表示1.8倍。
[0112]
本申请在服务器与多个终端设备交互过程中,在服务器中服务器节点更新过全局
模型后,选取一部分的计算节点对应的终端设备发送相应的标志位,使该部分计算节点从服务器节点上拉取全局模型更新后的参数,剩余的计算节点对应的终端设备使用自身的梯度更新本地模型的参数,实现过程较为简单,有效减少了分布式机器学习模型的计算节点与服务器节点之间的拉取通信量,减少了交互通信开销,提高了分布式机器学习的通信效率,降低了在特殊场景下的计算节点的通信花费,例如边缘场景,收费网络,跨区域网络等场景下一部分计算节点进行拉取通信,另一部分计算节点不进行拉取通信,相比全部节点进行拉取通信,减少了通信开销。
[0113]
参见图9,是本申请实施例四提供了一种分布式机器学习通信优化装置的结构框图,为了便于说明,仅示出了与本申请实施例相关的部分。
[0114]
该分布式机器学习通信优化装置包括:
[0115]
梯度获取模块91,用于获取多个计算节点的梯度;
[0116]
参数更新模块92,用于根据所述多个计算节点的梯度,更新服务器节点的全局模型的参数;
[0117]
参数获取模块93,用于获取所述全局模型更新后的参数;
[0118]
节点选择模块94,用于从所述多个计算节点中选取预设个数的计算节点;
[0119]
第一输出模块95,用于将第一标志位信息和所述全局模型更新后的参数输出给所述预设个数的计算节点,其中,所述第一标志位信息用于指示所述预设个数的计算节点根据所述全局模型更新后的参数更新本地模型;
[0120]
第二输出模块96,用于将第二标志位信息和所述全局模型更新后的参数输出给剩余计算节点,其中,所述剩余计算节点是指所述多个计算节点中除所述预设个数的计算节点之外的计算节点,所述第二标志位信息用于指示所述剩余计算节点不接收所述服务器输出的所述全局模型更新后的参数,并指示所述剩余计算节点根据自身的梯度更新本地模型。
[0121]
可选的是,该节点选择模块94具体用于:
[0122]
从所述多个计算节点中随机选取所述预设个数的计算节点。
[0123]
可选的是,该参数更新模块92具体用于:
[0124]
根据所述多个计算节点的梯度和预设参数更新公式,更新所述服务器节点的全局模型的参数,其中,所述预设参数更新公式为:
[0125][0126]
其中,η为学习率,p为总的计算节点数,g
i
为第i个计算节点的梯度,ω

是所述全局模型的当前参数,ω是所述全局模型更新后的参数。
[0127]
需要说明的是,上述模块之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例一部分,此处不再赘述。
[0128]
参见图10,是本申请实施例五提供了一种分布式机器学习通信优化装置的结构框图,为了便于说明,仅示出了与本申请实施例相关的部分。
[0129]
该分布式机器学习通信优化装置包括:
[0130]
梯度获取模块101,用于获取计算节点的梯度;
[0131]
梯度输出模块102,用于将所述计算节点的梯度发送给服务器,其中,所述计算节点的梯度用于指示所述服务器根据接收到的多个计算节点的梯度更新服务器中服务器节点的全局模型的参数,并向所述计算节点反馈标志位信息;
[0132]
标志位接收模块103,用于接收所述标志位信息;
[0133]
目标参数获取模块104,用于根据所述标志位信息,获取目标参数;
[0134]
本地模型更新模块105,用于根据所述目标参数更新所述计算节点的本地模型。
[0135]
可选的是,该目标参数获取模块104包括:
[0136]
第一目标参数获取单元,用于若所述标志位信息为第一标志位信息,则接收所述服务器反馈的所述全局模型更新后的参数,并确定该参数为所述目标参数;
[0137]
第二目标参数获取单元,用于若所述标志位信息为第二标志位信息,则拒绝接收所述服务器反馈的所述全局模型更新后的参数,并确定所述计算节点的梯度为所述目标参数;
[0138]
其中,所述全局模型更新后的参数是所述服务器更新所述服务器节点的全局模型的参数后反馈的。
[0139]
可选的是,该本地模型更新模块105包括:
[0140]
当前参数获取单元,用于获取所述计算节点的本地模型的当前参数;
[0141]
第一本地模型更新单元,用于若所述标志位信息为第一标志位信息,则将所述计算节点的本地模型的当前参数替换为所述目标参数,所述目标参数为所述计算节点的本地模型更新后的参数;
[0142]
第二本地模型更新单元,用于若所述标志位信息为第二标志位信息,则将所述计算节点的本地模型的当前参数减去所述目标参数,所得差值为所述计算节点的本地模型更新后的参数。
[0143]
需要说明的是,上述模块之间的信息交互、执行过程等内容,由于与本申请方法实施例基于同一构思,其具体功能及带来的技术效果,具体可参见方法实施例二部分,此处不再赘述。
[0144]
图11为本申请实施例六提供的一种服务器的结构示意图。服务器用于连接多个终端设备,其中,每个终端设备中设置一个计算节点,服务器中设置有服务器节点,如图11所示,该实施例的服务器11包括:至少一个处理器110(图11中仅示出一个)处理器、存储器111以及存储在存储器111中并可在至少一个处理器110上运行的计算机程序112,处理器110执行计算机程序112时实现上述实施例一中分布式机器学习通信优化方法的步骤。
[0145]
服务器11可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。该服务器可包括,但不仅限于,处理器110、存储器111。本领域技术人员可以理解,图11仅仅是服务器11的举例,并不构成对服务器11的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如还可以包括输入输出设备、网络接入设备等。
[0146]
所称处理器110可以是中央处理单元(central processing unit,cpu),该处理器110还可以是其他通用处理器、数字信号处理器(digital signal processor,dsp)、专用集成电路(application specific integrated circuit,asic)、现成可编程门阵列(field-programmable gate array,fpga)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、
分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
[0147]
所述存储器111在一些实施例中可以是服务器11的内部存储单元,例如服务器11的硬盘或内存。所述存储器111在另一些实施例中也可以是服务器11的外部存储设备,例如服务器11上配备的插接式硬盘,智能存储卡(smart media card,smc),安全数字(secure digital,sd)卡,闪存卡(flash card)等。进一步地,所述存储器111还可以既包括服务器11的内部存储单元也包括外部存储设备。所述存储器111用于存储操作系统、应用程序、引导装载程序(bootloader)、数据以及其他程序等,例如所述计算机程序的程序代码等。所述存储器111还可以用于暂时地存储已经输出或者将要输出的数据。
[0148]
图12为本申请实施例七提供的一种终端设备的结构示意图。终端设备用于连接服务器,其中,该服务器中设置有服务器节点,终端设备中设置一个计算节点,如图12所示,该实施例的终端设备12包括:至少一个处理器120(图12中仅示出一个)处理器、存储器121以及存储在存储器121中并可在至少一个处理器120上运行的计算机程序122,处理器120执行计算机程序122时实现上述实施例二的分布式机器学习通信优化方法的步骤。
[0149]
终端设备12可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。该终端设备可包括,但不仅限于,处理器120、存储器121。本领域技术人员可以理解,图12仅仅是终端设备12的举例,并不构成对终端设备12的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如还可以包括输入输出设备、网络接入设备等。
[0150]
所称处理器120可以是中央处理单元(central processing unit,cpu),该处理器120还可以是其他通用处理器、数字信号处理器(digital signal processor,dsp)、专用集成电路(application specific integrated circuit,asic)、现成可编程门阵列(field-programmable gate array,fpga)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
[0151]
所述存储器121在一些实施例中可以是终端设备12的内部存储单元,例如终端设备12的硬盘或内存。所述存储器121在另一些实施例中也可以是终端设备12的外部存储设备,例如终端设备12上配备的插接式硬盘,智能存储卡(smart media card,smc),安全数字(secure digital,sd)卡,闪存卡(flash card)等。进一步地,所述存储器121还可以既包括终端设备12的内部存储单元也包括外部存储设备。所述存储器121用于存储操作系统、应用程序、引导装载程序(bootloader)、数据以及其他程序等,例如所述计算机程序的程序代码等。所述存储器121还可以用于暂时地存储已经输出或者将要输出的数据。
[0152]
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述装置中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。所述集
成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请实现上述实施例方法中的全部或部分流程,可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质至少可以包括:能够携带所述计算机程序代码的任何实体或装置、记录介质、计算机存储器、只读存储器(rom,read-only memory)、随机存取存储器(ram,random access memory)、电载波信号、电信信号以及软件分发介质。例如u盘、移动硬盘、磁碟或者光盘等。在某些司法管辖区,根据立法和专利实践,计算机可读介质不可以是电载波信号和电信信号。
[0153]
本申请实现上述实施例方法中的全部或部分流程,也可以通过一种计算机程序产品来完成,当所述计算机程序产品在服务器上运行时,使得所述服务器执行时实现可实现上述方法实施例一中的步骤,或者,当所述计算机程序产品在终端设备上运行时,使得所述终端设备执行时实现可实现上述方法实施例二中的步骤。
[0154]
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
[0155]
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
[0156]
在本申请所提供的实施例中,应该理解到,所揭露的装置/终端设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/终端设备实施例仅仅是示意性的,例如,所述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
[0157]
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
[0158]
以上所述实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围,均应包含在本申请的保护范围之内。
当前第1页1 2 3 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1