本申请涉及深度学习与大模型领域,特别是涉及模型训练显存优化方法、装置、电子装置和存储介质。
背景技术:
1、在人工智能领域,大模型技术因其卓越的性能和广泛的应用前景而变得日益流行。这些模型,如自然语言处理和图像识别系统中使用的模型,因其巨大的参数数量和复杂的网络结构而闻名。然而,这种技术的进步也带来了对硬件资源,尤其是显存的巨大需求。显存,作为训练这些大型模型的关键资源,其容量往往限制了模型的规模和训练速度。
2、目前常用的方法是利用混合精度训练模型。混合精度训练是一种使用单精度(32位浮点)表示和半精度(16位浮点)表示相结合的方式来处理计算任务的技术,把深度学习模型中的部分网络参数转换为半精度类型,在训练时一部分计算使用单精度类型进行,一部分计算使用半精度进行,以提高模型精度的同时降低模型训练过程中的存储空间占用以及执行时间。这种方法主要是利用半精度表示在保持足够精度的同时可以减少所需存储空间的优势,从而减轻了显存的负担。但是由于人工智能的发展,模型的规模越来越大,单使用混合精度训练也很难满足所需的显存。
3、针对相关技术中存在模型训练中显存的利用率低和训练速度慢的问题,目前还没有提出有效的解决方案。
技术实现思路
1、在本实施例中提供了一种模型训练显存优化方法、装置、电子装置和存储介质,以解决相关技术中模型训练中显存的利用率低和训练速度慢的问题。
2、第一个方面,在本实施例中提供了一种模型训练显存优化方法,包括:
3、获取模型原始参数,将模型原始参数进行备份,得到备份参数;
4、根据模型结构确定候选暂退的模型连接;对候选暂退的模型连接不分配显存,并为除候选暂退的模型连接之外的其他模型连接分配显存;
5、执行循环训练过程直至达到预设的训练终止条件,得到目标训练模型;循环训练过程包括:
6、根据预设的暂退比例,对候选暂退的模型连接进行随机暂退后,得到目标连接;
7、根据目标连接,从备份参数中复制参数进行半精度训练,得到半精度参数梯度;
8、根据半精度参数梯度更新备份参数。
9、在其中的一些实施例中,根据模型结构确定候选暂退的模型连接,包括:
10、遍历模型结构,确定dropout在模型结构中的具体位置;
11、根据dropout的具体位置确定候选暂退的模型连接。
12、在其中的一些实施例中,根据dropout的具体位置确定候选暂退的模型连接,包括:
13、根据dropout的具体位置,以行或列为维度进行暂退,确定候选暂退的模型连接。
14、在其中的一些实施例中,候选暂退的连接为dropout前后模块中无效计算模块。
15、在其中的一些实施例中,根据目标连接,从备份参数中复制参数进行半精度训练,得到半精度参数梯度,包括:
16、从备份参数中复制参数,采用半精度输入在目标连接中进行前向计算得到损失函数;
17、根据损失函数计算得到半精度参数梯度。
18、在其中的一些实施例中,根据目标连接,从备份参数中复制参数进行半精度训练,得到半精度参数梯度,包括:
19、根据链式推导法则,根据目标连接,从备份参数中复制参数进行半精度训练,得到半精度参数梯度。
20、在其中的一些实施例中,根据半精度参数梯度更新备份参数,包括:
21、将根据预设的暂退比例对候选暂退的模型连接进行随机暂退的候选暂退的模型连接所对应的参数梯度补充至半精度参数梯度中,得到完整的半精度参数梯度;
22、根据完整的半精度参数梯度更新备份参数。
23、第二个方面,在本实施例中提供了一种模型训练显存优化装置,包括:获取模块、确定模块和训练模块,其中:
24、获取模块,用于获取模型原始参数,将模型原始参数进行备份,得到备份参数;
25、确定模块,用于根据模型结构确定候选暂退的模型连接;对候选暂退的模型连接不分配显存,并为除候选暂退的模型连接之外的其他模型连接分配显存;
26、训练模块,用于执行循环训练过程直至达到预设的训练终止条件,得到目标训练模型;循环训练过程包括:
27、根据预设的暂退比例,对候选暂退的模型连接进行随机暂退后,得到目标连接;
28、根据目标连接,从备份参数中复制参数进行半精度训练,得到半精度参数梯度;
29、根据半精度参数梯度更新备份参数。
30、第三个方面,在本实施例中提供了一种电子装置,包括存储器、处理器以及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述第一个方面所述的模型训练显存优化方法。
31、第四个方面,在本实施例中提供了一种存储介质,其上存储有计算机程序,该程序被处理器执行时实现上述第一个方面所述的模型训练显存优化方法。
32、与相关技术相比,在本实施例中提供的模型训练显存优化方法,通过获取模型原始参数,将模型原始参数进行备份,得到备份参数;根据模型结构确定候选暂退的模型连接;对候选暂退的模型连接不分配显存,并为除候选暂退的模型连接之外的其他模型连接分配显存;执行循环训练过程直至达到预设的训练终止条件,得到目标训练模型;循环训练过程包括:根据预设的暂退比例,对候选暂退的模型连接进行随机暂退后,得到目标连接;根据目标连接,从备份参数中复制参数进行半精度训练,得到半精度参数梯度;根据半精度参数梯度更新备份参数,提高了模型训练中显存的利用率和训练速度。
33、本申请的一个或多个实施例的细节在以下附图和描述中提出,以使本申请的其他特征、目的和优点更加简明易懂。
1.一种模型训练显存优化方法,其特征在于,包括:
2.根据权利要求1所述的模型训练显存优化方法,其特征在于,所述根据所述模型结构确定候选暂退的模型连接,包括:
3.根据权利要求2所述的模型训练显存优化方法,其特征在于,所述根据所述dropout的具体位置确定所述候选暂退的模型连接,包括:
4.根据权利要求2所述的模型训练显存优化方法,其特征在于,所述候选暂退的连接为所述dropout前后模块中无效计算模块。
5.根据权利要求1所述的模型训练显存优化方法,其特征在于,所述根据所述目标连接,从所述备份参数中复制参数进行半精度训练,得到半精度参数梯度,包括:
6.根据权利要求1所述的模型训练显存优化方法,其特征在于,所述根据所述目标连接,从所述备份参数中复制参数进行半精度训练,得到半精度参数梯度,包括:
7.根据权利要求1所述的模型训练显存优化方法,其特征在于,所述根据所述半精度参数梯度更新所述备份参数,包括:
8.一种模型训练显存优化装置,其特征在于,包括:获取模块、确定模块和训练模块,其中:
9.一种电子装置,包括存储器和处理器,其特征在于,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行权利要求1至权利要求7中任一项所述的模型训练显存优化方法。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1至权利要求7中任一项所述的模型训练显存优化方法的步骤。