本公开涉及计算机,尤其涉及一种大模型的微调方法、大模型的微调装置、电子设备和存储介质。
背景技术:
1、相关技术中,主要通过全量微调、固定部分参数(fix parameter)微调、适配器(adapter)微调或者前缀微调(prefix-tuning)的方式对源数据处理模型进行微调,得到目标数据处理模型。
2、其中,在数据处理模型为大模型的情况下,全量微调的成本过高。
3、在固定部分参数微调的方案中,只微调源数据处理模型的部分网络层的参数,其余网络层的参数冻结。然而,当源数据处理模型对应的源域与目标数据处理模型对应的目标域之间的差距较大时,只微调部分网络层的参数的效果难以保证。另外,由于层与层之间有紧密的联系,语义信息也是逐层传导的,因此,如果只微调部分层的参数,则会导致冻结层的语义信息与微调层的语义信息之间有很大的差别,进而导致微调效果较差。
4、在适配器微调的方案中,增加了推理成本,且只微调少量的参数,效果难以保证。
5、在前缀微调的方案中,训练难度较高,且只微调少量的参数,效果较差。
技术实现思路
1、本公开提供了一种数据处理模型的微调技术方案。
2、根据本公开的一方面,提供了一种大模型的微调方法,包括:
3、将目标数据处理模型的网络层划分为第一网络层组和第二网络层组;
4、根据所述第一网络层组对应的原参数矩阵的行数和列数,初始化第一参数矩阵和第二参数矩阵,其中,所述第一参数矩阵的行数等于所述原参数矩阵的行数,所述第一参数矩阵的列数小于所述原参数矩阵的列数,所述第二参数矩阵的列数等于所述原参数矩阵的列数,且所述第二参数矩阵的行数等于所述第一参数矩阵的列数;
5、将训练样本输入所述目标数据处理模型,通过所述目标数据处理模型输出所述训练样本对应的预测结果;
6、根据所述训练样本对应的预测结果和所述训练样本对应的标签,确定所述目标数据处理模型对应的损失函数的值;
7、根据所述损失函数的值,更新所述第一参数矩阵、所述第二参数矩阵以及所述第二网络层组对应的第三参数矩阵。
8、在一种可能的实现方式中,所述第一网络层组包括所述目标数据处理模型的前m个网络层,所述第二网络层组包括所述目标数据处理模型的后n个网络层,所述目标数据处理模型的网络层数为m+n,其中,m和n均为大于或等于1的整数。
9、在一种可能的实现方式中,所述第一参数矩阵的列数比所述原参数矩阵的列数小至少一个数量级。
10、在一种可能的实现方式中,在所述目标数据处理模型微调的过程中,所述原参数矩阵保持固定。
11、在一种可能的实现方式中,所述将训练样本输入所述目标数据处理模型,通过所述目标数据处理模型输出所述训练样本对应的预测结果,包括:
12、确定最新的所述第一参数矩阵与最新的所述第二参数矩阵的第一乘积;
13、将所述原参数矩阵与所述第一乘积之和,确定为所述第一网络层组对应的最新总参数矩阵;
14、将所述训练样本输入所述目标数据处理模型,基于所述最新总参数矩阵和最新的所述第三参数矩阵,得到所述训练样本对应的预测结果。
15、在一种可能的实现方式中,所述根据所述损失函数的值,更新所述第一参数矩阵、所述第二参数矩阵以及所述第二网络层组对应的第三参数矩阵,包括:
16、根据所述损失函数的值,确定所述第一参数矩阵对应的第一梯度、所述第二参数矩阵对应的第二梯度以及所述第三参数矩阵对应的第三梯度;
17、根据所述第一梯度,更新所述第一参数矩阵;
18、根据所述第二梯度,更新所述第二参数矩阵;
19、根据所述第三梯度,更新所述第三参数矩阵。
20、在一种可能的实现方式中,所述方法还包括:
21、在显存中,保存所述第一梯度、所述第二梯度和所述第三梯度。
22、在一种可能的实现方式中,所述方法还包括:
23、在显存中,保存所述第一参数矩阵对应的第一优化器状态信息、所述第二参数矩阵对应的第二优化器状态信息和所述第三参数矩阵对应的第三优化器状态信息。
24、在一种可能的实现方式中,在所述更新所述第一参数矩阵、所述第二参数矩阵以及所述第二网络层组对应的第三参数矩阵之后,所述方法还包括:
25、响应于所述目标数据处理模型微调结束,确定最新的所述第一参数矩阵与最新的所述第二参数矩阵的第二乘积;
26、将所述原参数矩阵与所述第二乘积之和,确定为所述第一网络层组对应的更新后的参数矩阵。
27、在一种可能的实现方式中,所述方法还包括:
28、获取显存的容量值;
29、根据所述容量值,确定所述第二网络层组中的网络层的数量。
30、在一种可能的实现方式中,所述损失函数包括预测下一个字的任务对应的第一损失函数。
31、在一种可能的实现方式中,所述损失函数包括强化学习任务对应的第二损失函数。
32、在一种可能的实现方式中,所述目标数据处理模型为文本处理模型,所述训练文本为训练样本。
33、根据本公开的一方面,提供了一种数据处理方法,包括:
34、获取所述大模型的微调方法训练得到的目标数据处理模型;
35、将待处理数据输入所述目标数据处理模型,通过所述目标数据处理模型输出所述待处理数据对应的数据处理结果。
36、在一种可能的实现方式中,所述待处理数据为待处理文本。
37、根据本公开的一方面,提供了一种大模型的微调装置,包括:
38、划分模块,用于将目标数据处理模型的网络层划分为第一网络层组和第二网络层组;
39、初始化模块,用于根据所述第一网络层组对应的原参数矩阵的行数和列数,初始化第一参数矩阵和第二参数矩阵,其中,所述第一参数矩阵的行数等于所述原参数矩阵的行数,所述第一参数矩阵的列数小于所述原参数矩阵的列数,所述第二参数矩阵的列数等于所述原参数矩阵的列数,且所述第二参数矩阵的行数等于所述第一参数矩阵的列数;
40、预测模块,用于将训练样本输入所述目标数据处理模型,通过所述目标数据处理模型输出所述训练样本对应的预测结果;
41、第一确定模块,用于根据所述训练样本对应的预测结果和所述训练样本对应的标签,确定所述目标数据处理模型对应的损失函数的值;
42、更新模块,用于根据所述损失函数的值,更新所述第一参数矩阵、所述第二参数矩阵以及所述第二网络层组对应的第三参数矩阵。
43、在一种可能的实现方式中,所述第一网络层组包括所述目标数据处理模型的前m个网络层,所述第二网络层组包括所述目标数据处理模型的后n个网络层,所述目标数据处理模型的网络层数为m+n,其中,m和n均为大于或等于1的整数。
44、在一种可能的实现方式中,所述第一参数矩阵的列数比所述原参数矩阵的列数小至少一个数量级。
45、在一种可能的实现方式中,在所述目标数据处理模型微调的过程中,所述原参数矩阵保持固定。
46、在一种可能的实现方式中,所述预测模块用于:
47、确定最新的所述第一参数矩阵与最新的所述第二参数矩阵的第一乘积;
48、将所述原参数矩阵与所述第一乘积之和,确定为所述第一网络层组对应的最新总参数矩阵;
49、将所述训练样本输入所述目标数据处理模型,基于所述最新总参数矩阵和最新的所述第三参数矩阵,得到所述训练样本对应的预测结果。
50、在一种可能的实现方式中,所述更新模块用于:
51、根据所述损失函数的值,确定所述第一参数矩阵对应的第一梯度、所述第二参数矩阵对应的第二梯度以及所述第三参数矩阵对应的第三梯度;
52、根据所述第一梯度,更新所述第一参数矩阵;
53、根据所述第二梯度,更新所述第二参数矩阵;
54、根据所述第三梯度,更新所述第三参数矩阵。
55、在一种可能的实现方式中,所述装置还包括:
56、第一保存模块,用于在显存中,保存所述第一梯度、所述第二梯度和所述第三梯度。
57、在一种可能的实现方式中,所述装置还包括:
58、第二保存模块,用于在显存中,保存所述第一参数矩阵对应的第一优化器状态信息、所述第二参数矩阵对应的第二优化器状态信息和所述第三参数矩阵对应的第三优化器状态信息。
59、在一种可能的实现方式中,所述装置还包括:
60、第二确定模块,用于响应于所述目标数据处理模型微调结束,确定最新的所述第一参数矩阵与最新的所述第二参数矩阵的第二乘积;
61、第三确定模块,用于将所述原参数矩阵与所述第二乘积之和,确定为所述第一网络层组对应的更新后的参数矩阵。
62、在一种可能的实现方式中,所述装置还包括:
63、第二获取模块,用于获取显存的容量值;
64、第四确定模块,用于根据所述容量值,确定所述第二网络层组中的网络层的数量。
65、在一种可能的实现方式中,所述损失函数包括预测下一个字的任务对应的第一损失函数。
66、在一种可能的实现方式中,所述损失函数包括强化学习任务对应的第二损失函数。
67、在一种可能的实现方式中,所述目标数据处理模型为文本处理模型,所述训练文本为训练样本。
68、根据本公开的一方面,提供了一种数据处理装置,包括:
69、第一获取模块,用于获取所述大模型的微调装置训练得到的目标数据处理模型;
70、数据处理模块,用于将待处理数据输入所述目标数据处理模型,通过所述目标数据处理模型输出所述待处理数据对应的数据处理结果。
71、在一种可能的实现方式中,所述待处理数据为待处理文本。
72、根据本公开的一方面,提供了一种电子设备,包括:一个或多个处理器;用于存储可执行指令的存储器;其中,所述一个或多个处理器被配置为调用所述存储器存储的可执行指令,以执行上述方法。
73、根据本公开的一方面,提供了一种计算机可读存储介质,其上存储有计算机程序指令,所述计算机程序指令被处理器执行时实现上述方法。
74、根据本公开的一方面,提供了一种计算机程序产品,包括计算机可读代码,或者承载有计算机可读代码的非易失性计算机可读存储介质,当所述计算机可读代码在电子设备中运行时,所述电子设备中的处理器执行上述方法。
75、在本公开实施例中,通过将目标数据处理模型的网络层划分为第一网络层组和第二网络层组,根据所述第一网络层组对应的原参数矩阵的行数和列数,初始化第一参数矩阵和第二参数矩阵,其中,所述第一参数矩阵的行数等于所述原参数矩阵的行数,所述第一参数矩阵的列数小于所述原参数矩阵的列数,所述第二参数矩阵的列数等于所述原参数矩阵的列数,且所述第二参数矩阵的行数等于所述第一参数矩阵的列数,将训练样本输入所述目标数据处理模型,通过所述目标数据处理模型输出所述训练样本对应的预测结果,根据所述训练样本对应的预测结果和所述训练样本对应的标签,确定所述目标数据处理模型对应的损失函数的值,并根据所述损失函数的值,更新所述第一参数矩阵、所述第二参数矩阵以及所述第二网络层组对应的第三参数矩阵,由此在目标数据处理模型的微调中,无需更新参数量较大的所述原参数矩阵,只需更新第三参数矩阵以及参数量较小的第一参数矩阵和第二参数矩阵。
76、与全量微调相比,本公开实施例能够降低显存开销,降低对显存容量的要求,且微调得到的目标数据处理模型在目标域的数据处理效果接近于全量微调的效果。与固定部分参数(fix parameter)微调、适配器(adapter)微调、前缀微调(prefix-tuning)的方案相比,本公开实施例微调得到的目标数据处理模型在目标域的数据处理效果有明显的提升。
77、应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,而非限制本公开。
78、根据下面参考附图对示例性实施例的详细说明,本公开的其它特征及方面将变得清楚。