本技术涉及自然语言处理,特别是涉及一种训练文本预测模型的方法、文本预测方法及装置。
背景技术:
1、nlp(natural language processing,自然语言处理)研究的目标是让机器能够理解人类语言。其中llm(large language model,大型语言模型)是自然语言处理领域中的一个核心工具,指的是具有大规模参数(通常数以亿计或更多)的深度学习模型。
2、llm因其具有极高的学习能力而被广泛地应用于文本预测领域。其中,情景学习(in context learning)是目前llm采用的其中一种文本预测方式。所谓情景学习指的是,给定标注数据后,llm进行观察和归纳,对无标签数据进行预测。由于情景学习通常没有训练过程,应用于文本预测时,预测效果不佳。因此亟需一种方式能够提高情景学习场景下基于llm的文本预测效果。
技术实现思路
1、有鉴于此,本技术提供了一种训练文本预测模型的方法、文本预测方法及装置,以便于提高情景学习场景下基于llm的文本预测效果。
2、本技术提供了如下方案:
3、第一方面,提供了一种训练文本预测模型的方法,所述方法包括:
4、获取训练数据集,所述训练数据集包括输入文本样本以及该输入文本样本对应的输出标签;
5、将包含输入文本样本和该输入文本样本对应的输出标签的文本序列作为文本预测模型的输入,训练所述文本预测模型;其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换transformer网络;
6、在所述训练中各transformer网络分别作为当前层transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与当前层transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与当前层transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层transformer网络输出的特征表示;利用所述第一键矩阵对所述第二键矩阵进行更新,将更新后的第二键矩阵作为所述当前层transformer网络在当前轮迭代得到的第二键矩阵;利用所述第一值矩阵对所述第二值矩阵进行更新,将更新后的第二值矩阵作为所述当前层transformer网络在当前轮迭代得到的第二值矩阵。
7、根据本技术实施例中一可实现的方式,所述大型语言模型还包括嵌入网络,所述文本预测模型还包括预测网络;
8、所述嵌入网络用以对所述文本序列进行嵌入处理;
9、若当前层transformer网络为第一层transformer网络,则所述上一层网络输出的特征表示为所述嵌入网络输出的特征表示;否则,所述上一层网络输出的特征表示为上一层transformer网络输出的特征表示;
10、所述预测网络用以利用最后一层transformer网络输出的特征表示预测输入文本样本对应的输出标签。
11、根据本技术实施例中一可实现的方式,利用所述第一键矩阵对所述第二键矩阵进行更新包括:利用所述第一键矩阵对所述当前层transformer网络在上一轮迭代得到的第二键矩阵采用动量梯度下降的方式进行更新;
12、利用所述第一值矩阵对所述第二值矩阵进行更新包括:利用所述第一值矩阵对所述当前层transformer网络在上一轮迭代得到的第二值矩阵采用动量梯度下降的方式进行更新。
13、根据本技术实施例中一可实现的方式,利用所述第一键矩阵对所述当前层transformer网络在上一轮迭代得到的第二键矩阵采用动量梯度下降的方式进行更新包括:利用所述第一键矩阵和所述当前层transformer网络在上一轮迭代得到的第二键矩阵进行逐元素求差,得到键矩阵梯度;利用所述当前层transformer网络在上一轮迭代得到的键矩阵动量和所述键矩阵梯度进行加权求和,得到所述当前层transformer网络在当前轮迭代得到的键矩阵动量;利用所述当前层transformer网络在当前轮迭代得到的键矩阵动量和当前层transformer网络在上一轮迭代得到的第二键矩阵,得到所述更新后的第二键矩阵;
14、利用所述第一值矩阵对所述当前层transformer网络在上一轮迭代得到的第二值矩阵采用动量梯度下降的方式进行更新包括:利用所述第一值矩阵和所述当前层transformer网络在上一轮迭代得到的第二值矩阵进行逐元素求差,得到值矩阵梯度;利用所述当前层transformer网络在上一轮迭代得到的值矩阵动量和所述值矩阵梯度进行加权求和,得到所述当前层transformer网络在当前轮迭代得到的值矩阵动量;利用所述当前层transformer网络在当前轮迭代得到的值矩阵动量和当前层transformer网络在上一轮迭代得到的第二值矩阵,得到所述更新后的第二值矩阵。
15、根据本技术实施例中一可实现的方式,每一轮迭代完成后,若确定达到预设的训练结束条件,则将各层transformer网络当前迭代得到的第二键矩阵和第二值矩阵分别作为训练得到的各层transformer网络的第二键矩阵和第二值矩阵进行存储。
16、第二方面,提供了一种文本预测方法,所述方法包括:
17、获取输入文本;
18、将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的输出标签;所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换transformer网络;
19、各transformer网络分别作为当前层transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层transformer网络输出的特征表示。
20、根据本技术实施例中一可实现的方式,所述大型语言模型还包括嵌入网络,所述文本预测模型还包括预测网络;
21、所述嵌入网络用以对所述文本序列进行嵌入处理;
22、若当前层transformer网络为第一层transformer网络,则所述上一层网络输出的特征表示为所述嵌入网络输出的特征表示;否则,所述上一层网络输出的特征表示为上一层transformer网络输出的特征表示;
23、所述预测网络用以利用最后一层transformer网络输出的特征表示预测输入文本对应的输出标签。
24、第三方面,提供了一种训练文本预测模型的方法,所述方法包括:
25、获取训练数据集,所述训练数据集包括输入文本样本以及该输入文本样本对应的情感类别标签;
26、将包含输入文本样本和该输入文本样本对应的情感类别标签的文本序列作为文本预测模型的输入,训练所述文本预测模型;其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的transformer网络;
27、在所述训练中各transformer网络分别作为当前层transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与当前层transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与当前层transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层transformer网络输出的特征表示;利用所述第一键矩阵对所述第二键矩阵进行更新,将更新后的第二键矩阵作为所述当前层transformer网络在当前轮迭代得到的第二键矩阵;利用所述第一值矩阵对所述第二值矩阵进行更新,将更新后的第二值矩阵作为所述当前层transformer网络在当前轮迭代得到的第二值矩阵。
28、第四方面,提供了一种情感分析方法,所述方法包括:
29、获取输入文本;
30、将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的情感类别标签;所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换transformer网络;
31、各transformer网络分别作为当前层transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层transformer网络输出的特征表示。
32、第五方面,提供了一种文本预测方法,由云端服务器执行,所述方法包括:
33、获取来自用户终端的输入文本;
34、将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的输出标签;
35、基于所述输出标签确定对应的服务内容,将所述服务内容发送至所述用户终端;
36、其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换transformer网络;
37、各transformer网络分别作为当前层transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层transformer网络输出的特征表示。
38、第六方面,提供了一种训练文本预测模型的装置,所述装置包括:
39、样本获取单元,被配置为获取训练数据集,所述训练数据集包括输入文本样本以及该输入文本样本对应的输出标签样本;
40、模型训练单元,被配置为将包含输入文本样本和该输入文本样本对应的输出标签的文本序列作为文本预测模型的输入,训练所述文本预测模型;其中,所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换transformer网络;在所述训练中各transformer网络分别作为当前层transformer网络执行:在当前轮迭代中,利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与当前层transformer网络在上一轮迭代得到的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与当前层transformer网络在上一轮迭代得到的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层transformer网络输出的特征表示;利用所述第一键矩阵对所述第二键矩阵进行更新,将更新后的第二键矩阵作为所述当前层transformer网络在当前轮迭代得到的第二键矩阵;利用所述第一值矩阵对所述第二值矩阵进行更新,将更新后的第二值矩阵作为所述当前层transformer网络在当前轮迭代得到的第二值矩阵。
41、第七方面,提供了一种文本预测装置,所述装置包括:
42、文本获取单元,被配置为获取输入文本;
43、文本预测单元,被配置为将包含所述输入文本的文本序列输入文本预测模型,获取所述文本预测模型预测得到的所述输入文本对应的输出标签;所述文本预测模型采用大型语言模型,所述大型语言模型包括多层串连的转换transformer网络;其中,各transformer网络分别作为当前层transformer网络执行:利用上一层网络输出的特征表示确定第一键矩阵、第一值矩阵和第一查询矩阵;将所述第一键矩阵与预先训练得到的当前层transformer网络的第二键矩阵进行拼接得到第三键矩阵,以及将所述第一值矩阵与预先训练得到的当前层transformer网络的第二值矩阵进行拼接得到第三值矩阵;利用所述第三键矩阵、第三值矩阵和第一查询矩阵进行自注意力机制的处理,得到当前层transformer网络输出的特征表示。
44、根据第八方面,提供了一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现上述第一方面中任一项所述的方法的步骤。
45、根据第九方面,提供了一种电子设备,包括:
46、一个或多个处理器;以及
47、与所述一个或多个处理器关联的存储器,所述存储器用于存储程序指令,所述程序指令在被所述一个或多个处理器读取执行时,执行上述第一方面中任一项所述的方法的步骤。
48、根据本技术提供的具体实施例,本技术公开了以下技术效果:
49、1)本技术利用训练样本集对llm进行训练来得到文本预测模型,这种方式实质上利用了已标注的样本训练llm,在第二键矩阵和第二值矩阵的更新过程中,利用了上一轮迭代得到的第二键矩阵和第二值矩阵与当前输入特征矩阵产生的第一键矩阵和第一值矩阵,既保留了历史信息又保持了当前输入文本的信息,使得llm能够充分对已标注的样本进行理解和学习,从而提高llm的文本预测效果。
50、2)本技术在llm的训练过程中,每一轮迭代过程中使用的第二键矩阵和第二值矩阵均由上一轮迭代得到,且进行更新后用于下一轮迭代使用。这种前向优化模型的方式,仅需要优化各transformer网络的第二键矩阵和第二值矩阵即可,大大缩减了需要更新的模型参数,降低了模型训练的成本,提高了效率。
51、3)本技术实施例中采用动量梯度下降的方式更新各transformer网络的第二键矩阵和第二值矩阵,能够加快梯度下降的速度,使得迭代效率更高,并且避免陷入局部最小值。
52、4)本技术在预测过程中使用了文本预测模型训练后得到的各transformer网络的第二键矩阵和第二值矩阵,与普通的情景学习相比,各transformer网络的第二键矩阵和第二值矩阵包含了文本预测模型对已标注的样本数据更好的观察和理解,能够显著提高预测准确性。
53、当然,实施本技术的任一产品并不一定需要同时达到以上所述的所有优点。