基于排序约束的文本生成模型的迁移训练方法和装置

文档序号:40814650发布日期:2025-01-29 02:29阅读:3来源:国知局
基于排序约束的文本生成模型的迁移训练方法和装置

本发明涉及自然语言处理 ,尤其涉及一种基于排序约束的文本生成模型的迁移训练方法和装置。


背景技术:

1、大型语言模型具有解决文本生成任务的强大能力,开启了人工智能领域的巨大变革。然而,大型语言模型庞杂的参数量带来了巨大的计算负担,造成了训练和推理的高昂使用成本。因此,探索如何减少模型的大小和计算成本,且尽可能保持其文本生成性能的模型压缩技术具有重要意义。

2、语言模型的知识蒸馏技术是模型压缩技术的一种,是将大型语言模型(称为教师模型)的知识迁移到参数规模更小的模型(称为学生模型)的过程。通过该过程,学生模型可以学习到教师模型的隐式知识,即教师模型在输入数据上的复杂概率分布,而非仅学习输入数据的表面内容。经过知识蒸馏,学生模型可以在保持较高文本生成性能的同时,具有更小的参数规模和计算资源需求,使其更适合在资源受限的环境下部署和使用。

3、然而,由于对于同样的输入可能有多个正确的预测词,因此语言模型的预测分布往往冗长且多峰。语言模型分布的多峰现象造成了学生模型学习分布的困难,现有的蒸馏目标在学习多峰分布的过程中表现出较低的效率,知识蒸馏的效果较差。


技术实现思路

1、本发明提供一种基于排序约束的文本生成模型的迁移训练方法和装置,用以解决现有的蒸馏目标在学习多峰分布的过程中表现出较低的效率,知识蒸馏的效果较差的问题。

2、本发明提供一种基于排序约束的文本生成模型的迁移训练方法,包括:获取训练数据,每个训练数据包括输入文本和答复文本;将所述训练数据输入第一文本生成模型得到所述第一文本生成模型对答复文本的每个预测位置上的第一预测分布输出;将所述训练数据输入第二文本生成模型得到所述第二文本生成模型对答复文本的每个预测位置上的第二预测分布输出;计算所述第一预测分布输出和所述第二预测分布输出之间的排序损失和蒸馏损失;将所述排序损失和所述蒸馏损失融合得到混合损失,并将所有预测位置上的混合损失的均值作为总损失;基于所述总损失进行反向传播,以更新所述第二文本生成模型的模型参数;其中,所述第一文本生成模型的参数规模大于所述第二文本生成模型的参数规模。

3、根据本发明提供的一种基于排序约束的文本生成模型的迁移训练方法,所述计算所述第一预测分布输出和所述第二预测分布输出之间的排序损失和蒸馏损失,包括:分别获取所述第一文本生成模型和所述第二文本生成模型在每个预测位置上的头部预测序列;确定所述第一文本生成模型的头部预测序列与所述第二文本生成模型的头部预测序列的并集;分别获取所述并集内类别在所述第一文本生成模型和所述第二文本生成模型的原始预测分布中的预测概率值序列;根据所述预测概率值序列计算基于斯皮尔曼等级相关系数的排序损失。

4、根据本发明提供的一种基于排序约束的文本生成模型的迁移训练方法,所述排序损失的目标函数为:

5、;

6、其中,表示排序损失,表示第一文本生成模型和第二文本生成模型预测分布之间的斯皮尔曼等级相关系数,表示第一文本生成模型的预测概率值序列,表示第二文本生成模型的预测概率值序列,表示和的协方差,和表示和的标准差。

7、根据本发明提供的一种基于排序约束的文本生成模型的迁移训练方法,所述将所述排序损失和所述蒸馏损失融合得到混合损失,包括:采用固定配比将所述排序损失和所述蒸馏损失进行融合处理,得到所述混合损失。

8、根据本发明提供的一种基于排序约束的文本生成模型的迁移训练方法,所述获取训练数据之后,所述方法还包括:对所述输入文本和所述答复文本进行分词处理,并基于预设词表将所述输入文本和所述答复文本转换为数据序列。

9、本发明还提供一种基于排序约束的文本生成模型的迁移训练装置,包括如下模块:获取模块和处理模块;所述获取模块,用于获取训练数据,每个训练数据包括输入文本和答复文本;所述处理模块,用于将所述训练数据输入第一文本生成模型得到所述第一文本生成模型对答复文本的每个预测位置上的第一预测分布输出;将所述训练数据输入第二文本生成模型得到所述第二文本生成模型对答复文本的每个预测位置上的第二预测分布输出;计算所述第一预测分布输出和所述第二预测分布输出之间的排序损失和蒸馏损失;将所述排序损失和所述蒸馏损失融合得到混合损失,并将所有预测位置上的混合损失的均值作为总损失;基于所述总损失进行反向传播,以更新所述第二文本生成模型的模型参数;其中,所述第一文本生成模型的参数规模大于所述第二文本生成模型的参数规模。

10、根据本发明提供的一种基于排序约束的文本生成模型的迁移训练装置,所述处理模块,用于分别获取所述第一文本生成模型和所述第二文本生成模型在每个预测位置上的头部预测序列;确定所述第一文本生成模型的头部预测序列与所述第二文本生成模型的头部预测序列的并集;分别获取所述并集内类别在所述第一文本生成模型和所述第二文本生成模型的原始预测分布中的预测概率值序列;根据所述预测概率值序列计算基于斯皮尔曼等级相关系数的排序损失。

11、根据本发明提供的一种基于排序约束的文本生成模型的迁移训练装置,所述排序损失的目标函数为:

12、;

13、其中,表示排序损失,表示第一文本生成模型和第二文本生成模型预测分布之间的斯皮尔曼等级相关系数,表示第一文本生成模型的预测概率值序列,表示第二文本生成模型的预测概率值序列,表示和的协方差,和表示和的标准差。

14、根据本发明提供的一种基于排序约束的文本生成模型的迁移训练装置,所述处理模块,用于采用固定配比将所述排序损失和所述蒸馏损失进行融合处理,得到所述混合损失。

15、根据本发明提供的一种基于排序约束的文本生成模型的迁移训练装置,所述处理模块,用于对所述输入文本和所述答复文本进行分词处理,并基于预设词表将所述输入文本和所述答复文本转换为数据序列。

16、本发明还提供一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如上述任一种所述基于排序约束的文本生成模型的迁移训练方法。

17、本发明还提供一种非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现如上述任一种所述基于排序约束的文本生成模型的迁移训练方法。

18、本发明还提供一种计算机程序产品,包括计算机程序,所述计算机程序被处理器执行时实现如上述任一种所述基于排序约束的文本生成模型的迁移训练方法。

19、本发明提供的基于排序约束的文本生成模型的迁移训练方法和装置,可以获取训练数据,每个训练数据包括输入文本和答复文本;将所述训练数据输入第一文本生成模型得到所述第一文本生成模型对答复文本的每个预测位置上的第一预测分布输出;将所述训练数据输入第二文本生成模型得到所述第二文本生成模型对答复文本的每个预测位置上的第二预测分布输出;计算所述第一预测分布输出和所述第二预测分布输出之间的排序损失和蒸馏损失;将所述排序损失和所述蒸馏损失融合得到混合损失,并将所有预测位置上的混合损失的均值作为总损失;基于所述总损失进行反向传播,以更新所述第二文本生成模型的模型参数;其中,所述第一文本生成模型的参数规模大于所述第二文本生成模型的参数规模。通过该方案,由于可以计算第一预测分布输出和第二预测分布输出之间的排序损失,并结合蒸馏损失得到的总损失来更新第二文本生成模型的模型参数,因此可以通过词级别的排序损失计算两个模型峰值预测的一致程度,从而实现第一文本生成模型和第二文本生成模型的多峰预测分布的高效对齐。如此,不仅能够更充分地利用两个预测分布峰值类别之间的细粒度信息,还能够保证与现有蒸馏目标的良好兼容性。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1