模型训练方法、任务处理方法及相关产品与流程

文档序号:40258086发布日期:2024-12-11 12:49阅读:14来源:国知局
模型训练方法、任务处理方法及相关产品与流程

本申请涉及人工智能,具体涉及一种模型训练方法、任务处理方法及相关产品。


背景技术:

1、在现实世界的数据处理过程中,噪声标注问题几乎无法避免。从社交媒体平台上用户生成的标签或者人工标注产生的标签,都可能导致标注数据中含有噪声,这种含有噪声的标注数据可称为噪声样本,存在的噪声称为标注噪声。如图1所示,把冲锋衣标注成羽绒服,将针织衫标注为衬衫。然而,使用这些噪声样本会对模型的训练造成混淆,尤其是对于大规模深度学习模型,它们会在学习过程中将这些噪声样认为有价值的信息,进而导致模型对于噪声样过拟合。

2、因此,如何削弱噪声样本对模型训练的影响,增强模型面对噪声样本时的训练鲁棒性是目前亟待解决的技术问题。


技术实现思路

1、本申请实施例提供了一种模型训练方法、任务处理方法及相关产品,降低噪声样本对模型训练的影响,提高模型的鲁棒性。

2、第一方面,本申请实施例提供一种模型训练方法,包括:

3、利用分类模型对训练样本进行第t轮类别预测,得到所述训练样本在第t轮的类别预测概率,其中,所述软标签用于指示所述训练样本属于各个类别的概率;

4、基于所述训练样本在t-1轮的软标签和第t轮的类别预测概率,确定所述训练样本在第t轮的软标签;

5、基于所述第t轮的软标签、所述第t轮的类别预测概率以及所述训练样本的硬标签,对所述分类模型进行第t轮训练,其中,所述硬标签用于指示所述训练样本所属的类别。

6、第二方面,本申请实施例提供一种任务处理方法,包括:

7、获取待处理内容;

8、通过目标特征提取网络,对所述待处理内容进行特征提取,得到第一目标特征,其中,所述目标特征提取网络通过第一方面所述的方法训练得到;

9、基于所述第一目标特征进行下游任务处理。

10、第三方面,本申请实施例提供一种模型训练装置,包括:分类单元和训练单元;

11、所述分类单元,用于利用所述分类模型对训练样本进行第t轮类别预测,得到所述训练样本在第t轮的类别预测概率;

12、所述训练单元,用于基于所述训练样本在t-1轮的软标签和第t轮的类别预测概率,确定所述训练样本在第t轮的软标签,其中,所述软标签用于指示所述训练样本属于各个类别的概率;基于所述第t轮的软标签、所述第t轮的类别预测概率以及所述训练样本的硬标签,对所述分类模型进行第t轮训练,其中,所述硬标签用于指示所述训练样本所属的类别。

13、第四方面,本申请实施例提供一种任务处理装置,包括:获取单元和处理单元;

14、所述获取单元,用于获取待处理内容;

15、所述处理单元,用于通过目标特征提取网络,对所述待处理内容进行特征提取,得到第一目标特征,其中,所述目标特征提取网络通过第一方面所述的方法训练得到;

16、基于所述第一目标特征进行下游任务处理。

17、第五方面,本申请实施例提供一种电子设备,包括:处理器和存储器,所述处理器与存储器相连,所述存储器用于存储计算机程序,所述处理器用于执行所述存储器中存储的计算机程序,以使得所述电子设备执行如第一方面或第二方面所述的方法。

18、第六方面,本申请实施例提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如第一方面或第二方面所述的方法。

19、第七方面,本申请实施例提供一种计算机程序产品,所述计算机程序产品包括计算机程序,所述计算机程序被处理器执行时实现如第一方面或第二方面所述的方法。

20、实施本申请实施例,具有如下有益效果:

21、可以看出,在本申请实施例中,在使用训练样本对模型训练时,会为训练样本构造软标签进行模型训练,从而可以从软标签中学习到内容,不会过度依赖硬标签,从而降低标注噪声对模型训练的影响,提高分类模型的鲁棒性,并且在每轮更新训练沿边的软标签时,并不是单纯只使用训练样本在上一轮的软标签更新出本轮的软标签,而是使用本轮的类别预测概率和上一轮的软标签同时更新出训练样本在本轮的软标签,从而实现借鉴本轮分类模型的类别预测情况,对上一轮的软标签进行更新,保持训练样本的软标签更新的稳定性和准确性,从而降低了软标签中标注噪声的含量,弱化了标注噪声对模型训练的影响,提高了模型的鲁棒性。



技术特征:

1.一种模型训练方法,其特征在于,包括:

2.根据权利要求1所述的方法,其特征在于,所述基于所述训练样本在t-1轮的软标签和第t轮的类别预测概率,确定所述训练样本在第t轮的软标签,包括:

3.根据权利要求1或2所述的方法,其特征在于,所述基于所述第t轮的软标签、所述第t轮的类别预测概率以及所述训练样本的硬标签,进行第t轮模型训练,包括:

4.根据权利要求3所述的方法,其特征在于,所述基于所述第t轮的软标签和所述第t轮的类别预测概率,确定第一交叉熵损失,包括:

5.根据权利要求3或4所述的方法,其特征在于,所述基于所述第t轮的软标签、所述第t轮的类别预测概率以及所述训练样本的硬标签,确定第二交叉熵损失,包括:

6.根据权利要求1-5任一项所述的方法,其特征在于,

7.根据权利要求1-6任一项所述的方法,其特征在于,

8.一种任务处理方法,其特征在于,包括:

9.一种模型训练装置,其特征在于,包括:分类单元和训练单元;

10.一种任务处理装置,其特征在于,包括:获取单元和处理单元;

11.一种电子设备,其特征在于,包括:处理器和存储器,所述处理器与所述存储器相连,所述存储器用于存储计算机程序,所述处理器用于执行所述存储器中存储的计算机程序,以使得所述电子设备执行如权利要求1-8任一项所述的方法。

12.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行以实现如权利要求1-8任一项所述的方法。


技术总结
本申请实施例公开了一种模型训练方法、任务处理方法及相关产品。该方法包括:利用所述分类模型对训练样本进行第t轮类别预测,得到所述训练样本在第t轮的类别预测概率;基于所述训练样本在t‑1轮的软标签和第t轮的类别预测概率,确定所述训练样本在第t轮的软标签,其中,所述软标签用于指示所述训练样本属于各个类别的概率;基于所述第t轮的软标签、所述第t轮的类别预测概率以及所述训练样本的硬标签,对所述分类模型进行第t轮训练,其中,所述硬标签用于指示所述训练样本所属的类别。本申请有利于提高模型的鲁棒性。

技术研发人员:王麒雄,蒋小龙
受保护的技术使用者:书行科技(北京)有限公司
技术研发日:
技术公布日:2024/12/10
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1