多任务模型训练方法、任务预测方法、装置、计算机设备及介质与流程

文档序号:35930530发布日期:2023-11-05 04:23阅读:72来源:国知局
多任务模型训练方法、任务预测方法、装置、计算机设备及介质与流程

本发明涉及数据处理,具体涉及一种多任务模型训练方法、任务预测方法、装置、计算机设备及介质。


背景技术:

1、随着深度学习技术的快速发展,通常要求算法多任务模型完成对目标进行检测和分类的所有任务,即算法多任务模型需要先基于获取的全局图像检测出目标的位置,再基于目标的位置对目标的属性信息进行分类。具体地,在自动驾驶场景中,若以车辆为目标,则要求算法多任务模型识别车辆的位置,并对车辆的车型、遮挡情况及车头方向等属性信息进行分类。若以车辆内的驾驶人员为目标,则要求算法多任务模型识别驾驶人员的位置,并对驾驶人员的穿戴及坐姿等属性信息进行识别。通常用于实现对目标进行检测和分类的算法多任务模型通常为包括多个任务头的多任务模型。现有的多任务模型一个任务头用于识别目标的位置,得到目标检测结果。其余的任务头用于对目标的属性信息进行分类,得到分类预测结果。

2、在训练多任务模型时,需要将训练集中的图像样本输入至所有的任务头,进而要求海量的图像样本中均包括对所有任务头的输出结果进行标注的标签结果,以完成对多个任务头的训练。然而,新增任务头等使得多任务模型的任务头发生变化的情况,将导致图像样本缺乏变化后的任务头所需的标签结果。当图像样本中不包括任务头的所需标签结果时,无法根据任务头输出的预测结果与标签结果对应的真实结果的偏离程度,对任务头进行数据迭代,导致了多任务模型的性能不佳。利用多任务模型对目标进行检测和分类时,无法得到准确的结果。


技术实现思路

1、本发明的目的之一在于提供一种多任务模型训练方法,以解决现有技术中利用多任务模型对目标进行检测和分类时,无法得到准确的结果的问题;目的之二在于提供一种任务预测方法;目的之三在于提供一种多任务模型训练装置;目的之四在于提供一种任务预测装置;目的之五在于提供一种计算机设备;目的之六在于提供一种机器可读存储介质。

2、为了实现上述目的,本发明采用的技术方案如下:

3、第一方面,本技术提供一种多任务模型的训练方法,多任务模型包括共享网络、目标检测头及至少一个任务头,多任务模型的训练方法包括:

4、获取多个图像样本,其中,图像样本包括至少一个任务头的标签结果;

5、将标签结果为同一任务头的图像样本进行合并,得到每个任务头的训练集,其中,每个训练集的名称与对应的任务头的名称相同;

6、将每个训练集输入至共享网络,得到图像样本的特征图;

7、将图像样本的特征图输入至目标检测头,得到待检测目标的特征图;

8、针对每个训练集,将待检测目标的特征图输入至与训练集的名称相同的任务头,迭代训练多任务模型。

9、结合第一方面,在第一种可能的实现方式中,多任务模型的构建步骤包括:

10、获取每个训练集的名称;

11、针对每个训练集,构建与训练集的名称相同的任务头;

12、基于共享网络、目标检测头及所有的任务头,构建多任务模型。

13、结合第一方面,在第二种可能的实现方式中,针对每个训练集,将待检测目标的特征图输入至与训练集的名称相同的任务头,迭代训练多任务模型,包括:

14、针对每个训练集,将待检测目标的特征图输入至与训练集的名称相同的任务头,得到每个任务头的损失;

15、基于每个任务头的损失,依次对每个任务头进行数据迭代;

16、在所有任务头均完成数据迭代的情况下,执行针对每个训练集,将待检测目标的特征图输入至与训练集的名称相同的任务头,得到每个任务头的损失的步骤,直到每个任务头的损失均小于预设损失阈值。

17、结合第一方面的第二种可能的实现方式,在第三种可能的实现方式中,每个任务头的损失均包括目标检测头的回归损失、目标检测头的分类损失及任务头的分类损失。

18、结合第一方面,在第四种可能的实现方式中,共享网络还包括主干网络和中间网络,将每个训练集输入至共享网络,得到图像样本的特征图,包括:

19、将训练集输入至主干网络,得到初始的图像样本特征图;

20、将初始的图像样本特征图输入至中间网络,得到更新尺度的图像样本特征图。

21、结合第一方面,在第五种可能的实现方式中,每个任务头包括特征聚集层和任务网络,针对每个训练集,将待检测目标的特征图输入至与训练集的名称相同的任务头,迭代训练多任务模型,包括:

22、针对每个训练集,将待检测目标的特征图输入至与训练集的名称相同的任务头的特征聚集层,得到预设尺度的待检测目标的特征图;

23、将预设尺度的待检测目标的特征图输入至任务网络,迭代训练多任务模型。

24、第二方面,本技术提供一种任务预测方法,任务预测方法包括:

25、获取目标图像;

26、将目标图像输入至多任务模型,得到目标检测结果和目标检测结果对应的至少一个任务的预测结果,其中,多任务模型根据如第一方面的多任务模型的训练方法得到。

27、第三方面,本技术提供一种多任务模型的训练装置,多任务模型包括目标检测头和至少一个任务头,多任务模型的训练装置包括:

28、图像样本获取模块,用于获取多个图像样本,其中,图像样本包括至少一个任务头的标签结果;

29、训练集得到模块,用于将标签结果为同一任务头的图像样本进行合并,得到每个任务头的训练集,其中,每个训练集的名称与对应的任务头的名称相同;

30、样本特征图得到模块,用于将每个训练集输入至共享网络,得到图像样本的特征图;

31、目标特征图得到模块,用于将图像样本的特征图输入至目标检测头,得到待检测目标的特征图;

32、模型迭代训练模块,用于针对每个训练集,将待检测目标的特征图输入至与训练集的名称相同的任务头,迭代训练多任务模型。

33、第四方面,本技术提供一种任务预测装置,任务预测装置包括:

34、图像获取模块,用于获取目标图像;

35、预测结果得到模块,用于将目标图像输入至多任务模型,得到目标检测结果和目标检测结果对应的至少一个任务的预测结果,其中,多任务模型根据如第一方面的多任务模型的训练方法得到。

36、第五方面,本技术提供一种计算机设备,计算机设备包括存储器及处理器,存储器存储有计算机程序,计算机程序在处理器执行时,实现如第一方面的多任务模型的训练方法,或实现如第二方面的任务预测方法。

37、第六方面,本技术提供一种机器可读存储介质,计算机可读存储介质上存储有计算机程序,计算机程序被处理器执行时,实现如第一方面的多任务模型的训练方法,或实现如第二方面的任务预测方法。

38、本发明的有益效果:

39、(1)针对每个训练集,将训练集中待检测目标的特征图输入至与训练集的名称相同的任务头。由于训练集中的图像样本并没有输入至所有的任务头,而根据训练集的名称和任务头名称,将训练集中图像样本输入至对应的任务头,使得输入至任务头的图像样本的特征图都具有对任务头进行迭代的标签结果。能够根据预测结果与真实结果的偏离程度,对任务头进行数据迭代,进而能够利用训练后的多任务模型,得到准确的目标检测结果和分类预测结果。

40、(2)在需要输出多个待检测目标的目标检测结果时,现有技术中利用模型的多个任务头独立输出目标检测结果和分类预测结果,需要将分类预测结果与目标检测结果进行匹配,才能得到待检测目标的分类预测结果。本技术由于先利用目标检测头确定了图像中的待检测目标,再将训练集中待检测目标的特征图输入至任务头,使得任务头能够直接输出待检测目标的分类预测结果,不需要再将分类预测结果与目标检测结果进行匹配。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1