图片检测模型训练方法、装置、计算机设备及存储介质与流程

文档序号:34301238发布日期:2023-05-31 16:34阅读:74来源:国知局
图片检测模型训练方法、装置、计算机设备及存储介质与流程

本技术涉及人工智能,具体涉及一种图片检测模型训练方法、装置、计算机设备及存储介质。


背景技术:

1、随着人工智能热潮的兴起和大数据时代的到来,基于图像和视频理解世界的计算机视觉技术得到了蓬勃的发展。目前在物流领域,图像识别技术得到广泛应用,例如违禁品的识别、包裹的计数等等。但是目前市场上很多模型并不能适应复杂的应用场景,在物流安检领域包裹的多样性、不同摆放姿态以及不同的成像角度,都可能导致漏检错检,因此如何有效检测待检测品类中小型物体和遮挡物体成为亟待解决的问题。


技术实现思路

1、基于此,有必要针对上述技术问题,提供一种图片检测模型训练方法、装置、计算机设备及存储介质,引导神经网络加强对小型物体可判别特征的识别度、以及增加遮挡部位识别的敏感性。

2、第一方面,本技术提供一种图片检测模型训练方法,包括:

3、获取训练图片样本集,所述训练图片样本集的待训练图片中包含检测物标记类别和检测物标记框;

4、通过图片检测模型对所述训练图片样本集中的各待训练图片进行训练,得到检测物预测类别和检测物预测框;

5、根据所述检测物标记类别、所述检测物标记框、所述检测物预测类别以及所述检测物预测框,计算所述图片检测模型的总损失值;

6、根据所述总损失值对所述图片检测模型进行修正,得到已训练的图片检测模型。

7、在本技术一些实施例中,所述根据所述总损失值对所述图片检测模型进行修正,得到修正后的图片检测模型之后,包括:

8、对所述图片检测模型进行修正,直至图片检测模型的修正次数达到预设次数,将最后一次修正后图片检测模型确定为已训练的图片检测模型。

9、在本技术一些实施例中,所述图片检测模型包括特征提取层、类别预测分支网络以及坐标预测分支网络,所述通过图片检测模型对所述训练图片样本集中的各待训练图片进行训练,得到检测物预测类别和检测物预测框,包括:

10、通过所述特征提取层对所述训练图片样本集中的各待训练图片进行特征识别,得到所述各待训练图片的图片特征;

11、将所述图片特征输入到类别预测分支网络得到检测物预测类别;

12、将所述图片特征输入到坐标预测分支网络得到检测物预测框;

13、根据所述检测物标记类别、所述检测物标记框、所述检测物预测类别以及检测物预测框,计算所述图片检测模型的总损失值,包括:

14、根据所述检测物标记类别、所述检测物标记框、所述检测物预测类别以及检测物预测框,计算级联所述类别预测分支网络的损失值和所述坐标预测分支网络的损失值的总损失值。

15、在本技术一些实施例中,所述特征提取层、类别预测分支网络以及坐标预测分支网络均包含神经网络层,所述根据所述总损失值对所述图片检测模型进行修正,得到已训练的图片检测模型,包括:

16、通过梯度下降算法将所述总损失值进行反向传播,对所述神经网络层的参数进行修正,得到已训练的图片检测模型。

17、在本技术一些实施例中,所述根据所述检测物标记类别、所述检测物标记框、所述检测物预测类别以及检测物预测框,计算所述图片检测模型的总损失值,包括:

18、根据所述检测物标记类别和所述检测物预测类别,计算分类损失值;

19、根据所述检测物标记框和所述检测物预测框,计算坐标损失值;

20、根据所述分类损失值和所述坐标损失值,确定所述图片检测模型的所述总损失值。

21、在本技术一些实施例中,所述根据所述检测物标记类别和所述检测物预测类别,计算分类损失值,包括:

22、根据所述检测物标记类别和所述检测物预测类别,计算所述待训练图片的检测物预测类别梯度模长;

23、根据所述梯度模长和预设的超参数,计算所述待训练图片的检测物预测类别的有效示例数量;

24、根据所述梯度模长、所述超参数以及有效示例数量,得到所述待训练图片的检测物预测类别的梯度密度;

25、根据所述梯度密度计算所述分类损失值。

26、在本技术一些实施例中,所述根据所述检测物标记框和所述检测物预测框,计算坐标损失值,包括:

27、根据所述检测物预测框计算自排斥函数损失值lr,计算方式为:其中,smoothln为平滑函数,bi和bj分别为同一待训练图片中第i个和第j个预测框,i和j均小于s,s为同一待训练图片中预测框数量,if为判断函数,λ为常量;

28、根据所述检测物标记框和所述检测物预测框,确定giou损失值;

29、根据所述检测物标记框的中心点坐标和所述检测物预测框的中心点坐标,确定中心点距离损失值;

30、根据所述自排斥函数损失值、所述giou损失值以及所述中心点距离损失值,计算所述坐标损失值。

31、第二方面,本技术提供一种图片检测方法,包括:

32、获取待检测图片;

33、通过图片检测模型对所述待检测图片进行识别,得到所述待检测图片中包含的检测物类别和检测物坐标,所述图片检测模型通过如上述任意一项所述的图片检测模型训练方法得到。

34、第三方面,本技术提供一种图片检测模型训练装置,包括:

35、样本获取模块,用于获取训练图片样本集,所述训练图片样本集的待训练图片中包含检测物标记类别和检测物标记框;

36、模型训练模块,与所述样本获取模块通讯连接,用于通过图片检测模型对所述训练图片样本集中的各待训练图片进行训练,得到检测物预测类别和检测物预测框;根据所述检测物标记类别、所述检测物标记框、所述检测物预测类别以及检测物预测框,计算所述图片检测模型的总损失值;根据所述总损失值对所述图片检测模型进行修正,得到已训练的图片检测模型。

37、在本技术一些实施例中,模型训练模块还用于对所述图片检测模型进行修正,直至图片检测模型的修正次数达到预设次数,将最后一次修正后图片检测模型确定为已训练的图片检测模型。

38、在本技术一些实施例中,模型训练模块还用于通过所述特征提取层对所述训练图片样本集中的各待训练图片进行特征识别,得到所述各待训练图片的图片特征;将所述图片特征输入到类别预测分支网络得到检测物预测类别;将所述图片特征输入到坐标预测分支网络得到检测物预测框;根据所述检测物标记类别、所述检测物标记框、所述检测物预测类别以及检测物预测框,计算级联所述类别预测分支网络的损失值和所述坐标预测分支网络的损失值的总损失值,所述图片检测模型包括特征提取层、类别预测分支网络以及坐标预测分支网络。

39、在本技术一些实施例中,模型训练模块还用于通过梯度下降算法将所述总损失值进行反向传播,对所述神经网络层的参数进行修正,得到已训练的图片检测模型,所述特征提取层、类别预测分支网络以及坐标预测分支网络均包含神经网络层。

40、在本技术一些实施例中,模型训练模块还用于根据所述检测物标记类别和所述检测物预测类别,计算分类损失值;根据所述检测物标记框和所述检测物预测框,计算坐标损失值;根据所述分类损失值和所述坐标损失值,确定所述图片检测模型的所述总损失值。

41、在本技术一些实施例中,模型训练模块还用于根据所述检测物标记类别和所述检测物预测类别,计算所述待训练图片的检测物预测类别梯度模长;根据所述梯度模长和预设的超参数,计算所述待训练图片的检测物预测类别的有效示例数量;根据所述梯度模长、所述超参数以及有效示例数量,得到所述待训练图片的检测物预测类别的梯度密度;根据所述梯度密度计算所述分类损失值。

42、在本技术一些实施例中,模型训练模块还用于根据所述检测物预测框计算自排斥函数损失值lr,计算方式为:其中,smoothln为平滑函数,bi和bj分别为同一待训练图片中第i个和第j个预测框,i和j均小于s,s为同一待训练图片中预测框数量,if为判断函数,λ为常量;根据所述检测物标记框和所述检测物预测框,确定giou损失值;根据所述检测物标记框的中心点坐标和所述检测物预测框的中心点坐标,确定中心点距离损失值;根据所述自排斥函数损失值、所述giou损失值以及所述中心点距离损失值,计算所述坐标损失值。

43、第四方面,本技术还提供一种服务器,服务器包括:

44、一个或多个处理器;

45、存储器;以及

46、一个或多个应用程序,其中一个或多个应用程序被存储于存储器中,并配置为由处理器执行以实现的图片检测模型训练方法。

47、第五方面,本技术还提供一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器进行加载,以执行的图片检测模型训练方法中的步骤。

48、上述图片检测模型训练方法、装置、计算机设备及存储介质,通过综合检测物标记类别、检测物标记框、检测物预测类别以及检测物预测框计算总损失用于反向传播修正图片检测模型,引导神经网络加强对小型物体可判别特征的识别度、以及增加遮挡部位识别的敏感性。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1