用于姿态估计的网络训练方法及装置与流程

文档序号:35464139发布日期:2023-09-16 02:45阅读:18来源:国知局
用于姿态估计的网络训练方法及装置与流程

本公开涉及计算机,尤其涉及一种用于姿态估计的网络训练方法及装置。


背景技术:

1、在机器人视觉、动作跟踪和单照相机定标等很多应用领域,都需要对图像中的目标对象进行姿态估计,并基于姿态估计结果进行后续的业务处理。比如,在人脸识别的场景中,可以先通过头部姿态估计筛选出符合条件的人脸图像,再对筛选出的人脸图像进行识别;在增强现实(augmented reality,ar)交互场景中,可以对拍摄的人脸图像进行头部位姿估计,以便为用户提供更真实的虚拟特效(例如发卡等装饰品特效);在安全驾驶的检测场景中,可以对驾驶员的人脸图像进行头部位姿估计,判断驾驶员是否左顾右盼等。

2、其中,以安全驾驶的检测场景为例,随机智能座舱技术的发展,驾驶员监控系统(driver monitor system, dms)由于其实用性和安全性成为了智能座舱必备技术之一。dms在行车过程中实时地对驾驶员进行疲劳监测、注意力监测、危险驾驶行为监测,可以有效避免由驾驶员主观因素引起的交通事故。

3、在dms中进行疲劳检测时,不仅要基于人脸特征判断眼睛睁闭状态和嘴部哈欠动作,还依赖头部姿态的角度范围综合判定驾驶员的专心程度以及是否在打瞌睡。另外,在dms中进行分心检测中也会基于头部姿态初步判断驾驶员的低头扭头行为。因此,头部姿态估计成为dms技术中不可或缺的部分。


技术实现思路

1、本公开提出了一种用于姿态估计的网络训练方法及装置。

2、根据本公开的一方面,提供了一种用于姿态估计的网络训练方法,包括:获取训练数据集;将所述训练数据集中的样本图像输入预设的特征提取网络,确定所述样本图像中目标对象的特征张量;根据所述目标对象的特征张量,对初始姿态网络进行训练,得到目标姿态网络,其中,所述目标姿态网络用于结合所述特征提取网络,确定待处理图像中目标对象的姿态信息。

3、在一种可能的实现方式中,根据所述目标对象的特征张量,对初始姿态网络进行训练,得到目标姿态网络,包括:根据目标对象的特征张量,对初始姿态网络进行多轮的迭代训练,得到多个备选姿态网络;在满足训练结束条件的情况下,从多个备选姿态网络中选择目标姿态网络,所述训练结束条件包括对所述初始姿态网络的迭代训练轮数大于或等于第一预设阈值,和/或,当前轮迭代训练的学习率小于第二预设阈值。

4、在一种可能的实现方式中,根据目标对象的特征张量,对初始姿态网络进行多轮的迭代训练,得到多个备选姿态网络,包括:对于所述初始姿态网络的任一轮的迭代训练,将所述目标对象的特征张量输入初始姿态网络,得到姿态预测结果;根据预设的姿态损失函数、所述姿态预测结果、所述训练数据集的标注信息,确定姿态损失,其中,所述姿态损失用于指示所述初始姿态网络对目标对象的姿态信息的检测误差;根据所述姿态损失,对初始姿态网络进行迭代训练,得到当前轮迭代训练对应的备选姿态网络。

5、在一种可能的实现方式中,所述初始姿态网络包括用于预测目标对象的旋转矩阵的第一网络,将所述目标对象的特征张量输入初始姿态网络,得到姿态预测结果,包括:将所述目标对象的特征张量输入第一网络,得到姿态预测结果,所述姿态预测结果包括所述第一网络预测的目标对象的旋转矩阵;所述预设的姿态损失函数包括第一损失函数,所述第一损失函数用于指示所述第一网络预测目标对象的旋转矩阵的检测误差,根据预设的姿态损失函数、所述姿态预测结果、所述训练数据集的标注信息,确定姿态损失,包括:将所述述第一网络预测的目标对象的旋转矩阵、所述目标对象对应的所述训练数据集的标注信息输入第一损失函数,确定姿态损失。

6、在一种可能的实现方式中,所述初始姿态网络包括用于预测目标对象的旋转矩阵的第一网络,以及预测目标对象的姿态角度的第二网络,将所述目标对象的特征张量输入初始姿态网络,得到姿态预测结果,包括:将所述目标对象的特征张量分别输入第一网络和第二网络,得到姿态预测结果,所述姿态预测结果包括所述第一网络预测的目标对象的旋转矩阵,以及所述第二网络预测的目标对象的姿态角度;所述预设的姿态损失函数包括第一损失函数与第二损失函数的乘积,所述第一损失函数用于指示所述第一网络预测目标对象的旋转矩阵的检测误差,所述第二损失函数用于指示所述第二网络预测的目标对象的姿态角度的检测误差,根据预设的姿态损失函数、所述姿态预测结果、所述训练数据集的标注信息,确定姿态损失,包括:将所述第一网络预测的目标对象的旋转矩阵、所述第二网络预测的目标对象的姿态角度、所述目标对象对应的所述训练数据集的标注信息输入所述预设的姿态损失函数,确定姿态损失。

7、在一种可能的实现方式中,所述初始姿态网络包括用于预测目标对象的旋转矩阵的第一网络,所述第一网络包括第一网络层、第二网络层、第三网络层;其中,所述第一网络层包括卷积核尺寸为第一预设尺寸的卷积单元、归一化单元、非线性激活单元;所述第二网络层包括卷积核尺寸为第二预设尺寸的深度可分离卷积单元、归一化单元、非线性激活单元,所述第二预设尺寸大于所述第一预设尺寸;所述第三网络层包括单位卷积单元。

8、在一种可能的实现方式中,所述初始姿态网络包括用于预测目标对象的旋转矩阵的第一网络,以及用于预测目标对象的姿态角度的第二网络;所述第一网络包括第一网络层、第二网络层、第三网络层,所述第二网络包括第四网络层、第五网络层、第六网络层、第七网络层;其中,所述第一网络层包括卷积核尺寸为第一预设尺寸的卷积单元、归一化单元、非线性激活单元;所述第二网络层包括卷积核尺寸为第二预设尺寸的深度可分离卷积单元、归一化单元、非线性激活单元,所述第二预设尺寸大于所述第一预设尺寸;所述第三网络层包括单位卷积单元;所述第四网络层包括卷积核尺寸为所述第一预设尺寸的卷积单元、归一化单元、非线性激活单元;所述第五网络层包括卷积核尺寸为所述第二预设尺寸的深度可分离卷积单元、归一化单元、非线性激活单元;所述第六网络层包括单位卷积单元、归一化单元、非线性激活单元;所述第七网络层包括单位卷积单元。

9、在一种可能的实现方式中,获取训练数据集,包括:获取第一样本集,所述第一样本集包括目标对象的姿态角度大于预设角度阈值的样本图像;获取第二样本集,所述第二样本集包括多种传感器在不同场景下采集的样本图像;对所述第一样本集和/或所述第二样本集进行数据集增强处理,得到第三样本集;根据所述第一样本集、所述第二样本集、所述第三样本,确定训练数据集。

10、根据本公开的一方面,提供了一种姿态估计方法,所述方法包括:获取待处理图像;将所述待处理图像输入预设的特征提取网络,得到所述待处理图像中目标对象的特征张量;将所述待处理图像中目标对象的特征张量输入目标姿态网络中处理,确定所述待处理图像的检测结果,所述检测结果包括所述目标对象的姿态信息;其中,所述目标姿态网络是根据上述的用于姿态估计的网络训练方法训练得到的。

11、根据本公开的一方面,提供了一种用于姿态估计的网络训练装置,包括:获取模块,用于获取训练数据集;确定模块,用于将所述训练数据集中的样本图像输入预设的特征提取网络,确定所述样本图像中目标对象的特征张量;训练模块,用于根据所述目标对象的特征张量,对初始姿态网络进行训练,得到目标姿态网络,其中,所述目标姿态网络用于结合所述特征提取网络,确定待处理图像中目标对象的姿态信息。

12、在一种可能的实现方式中,所述训练模块用于:根据目标对象的特征张量,对初始姿态网络进行多轮的迭代训练,得到多个备选姿态网络;在满足训练结束条件的情况下,从多个备选姿态网络中选择目标姿态网络,所述训练结束条件包括对所述初始姿态网络的迭代训练轮数大于或等于第一预设阈值,和/或,当前轮迭代训练的学习率小于第二预设阈值。

13、在一种可能的实现方式中,根据目标对象的特征张量,对初始姿态网络进行多轮的迭代训练,得到多个备选姿态网络,包括:对于所述初始姿态网络的任一轮的迭代训练,将所述目标对象的特征张量输入初始姿态网络,得到姿态预测结果;根据预设的姿态损失函数、所述姿态预测结果、所述训练数据集的标注信息,确定姿态损失,其中,所述姿态损失用于指示所述初始姿态网络对目标对象的姿态信息的检测误差;根据所述姿态损失,对初始姿态网络进行迭代训练,得到当前轮迭代训练对应的备选姿态网络。

14、在一种可能的实现方式中,所述初始姿态网络包括用于预测目标对象的旋转矩阵的第一网络,将所述目标对象的特征张量输入初始姿态网络,得到姿态预测结果,包括:将所述目标对象的特征张量输入第一网络,得到姿态预测结果,所述姿态预测结果包括所述第一网络预测的目标对象的旋转矩阵;所述预设的姿态损失函数包括第一损失函数,所述第一损失函数用于指示所述第一网络预测目标对象的旋转矩阵的检测误差,根据预设的姿态损失函数、所述姿态预测结果、所述训练数据集的标注信息,确定姿态损失,包括:将所述述第一网络预测的目标对象的旋转矩阵、所述目标对象对应的所述训练数据集的标注信息输入第一损失函数,确定姿态损失。

15、在一种可能的实现方式中,所述初始姿态网络包括用于预测目标对象的旋转矩阵的第一网络,以及预测目标对象的姿态角度的第二网络,将所述目标对象的特征张量输入初始姿态网络,得到姿态预测结果,包括:将所述目标对象的特征张量分别输入第一网络和第二网络,得到姿态预测结果,所述姿态预测结果包括所述第一网络预测的目标对象的旋转矩阵,以及所述第二网络预测的目标对象的姿态角度;所述预设的姿态损失函数包括第一损失函数与第二损失函数的乘积,所述第一损失函数用于指示所述第一网络预测目标对象的旋转矩阵的检测误差,所述第二损失函数用于指示所述第二网络预测的目标对象的姿态角度的检测误差,根据预设的姿态损失函数、所述姿态预测结果、所述训练数据集的标注信息,确定姿态损失,包括:将所述第一网络预测的目标对象的旋转矩阵、所述第二网络预测的目标对象的姿态角度、所述目标对象对应的所述训练数据集的标注信息输入所述预设的姿态损失函数,确定姿态损失。

16、在一种可能的实现方式中,所述初始姿态网络包括用于预测目标对象的旋转矩阵的第一网络,所述第一网络包括第一网络层、第二网络层、第三网络层;其中,所述第一网络层包括卷积核尺寸为第一预设尺寸的卷积单元、归一化单元、非线性激活单元;所述第二网络层包括卷积核尺寸为第二预设尺寸的深度可分离卷积单元、归一化单元、非线性激活单元,所述第二预设尺寸大于所述第一预设尺寸;所述第三网络层包括单位卷积单元。

17、在一种可能的实现方式中,所述初始姿态网络包括用于预测目标对象的旋转矩阵的第一网络,以及用于预测目标对象的姿态角度的第二网络;所述第一网络包括第一网络层、第二网络层、第三网络层,所述第二网络包括第四网络层、第五网络层、第六网络层、第七网络层;其中,所述第一网络层包括卷积核尺寸为第一预设尺寸的卷积单元、归一化单元、非线性激活单元;所述第二网络层包括卷积核尺寸为第二预设尺寸的深度可分离卷积单元、归一化单元、非线性激活单元,所述第二预设尺寸大于所述第一预设尺寸;所述第三网络层包括单位卷积单元;所述第四网络层包括卷积核尺寸为所述第一预设尺寸的卷积单元、归一化单元、非线性激活单元;所述第五网络层包括卷积核尺寸为所述第二预设尺寸的深度可分离卷积单元、归一化单元、非线性激活单元;所述第六网络层包括单位卷积单元、归一化单元、非线性激活单元;所述第七网络层包括单位卷积单元。

18、在一种可能的实现方式中,所述获取模块用于:获取第一样本集,所述第一样本集包括目标对象的姿态角度大于预设角度阈值的样本图像;获取第二样本集,所述第二样本集包括多种传感器在不同场景下采集的样本图像;对所述第一样本集和/或所述第二样本集进行数据集增强处理,得到第三样本集;根据所述第一样本集、所述第二样本集、所述第三样本,确定训练数据集。

19、根据本公开的一方面,提供了一种姿态估计装置,包括:待处理图像获取模块,用于获取待处理图像;特征提取模块,用于将所述待处理图像输入预设的特征提取网络,得到所述待处理图像中目标对象的特征张量;姿态估计模块,用于将所述待处理图像中目标对象的特征张量输入目标姿态网络中处理,确定所述待处理图像的检测结果,所述检测结果包括所述目标对象的姿态信息;其中,所述目标姿态网络是根据上述的用于姿态估计的网络训练方法训练得到的。

20、根据本公开的一方面,提供了一种电子设备,包括:处理器;用于存储处理器可执行指令的存储器;其中,所述处理器被配置为调用所述存储器存储的指令,以执行上述方法。

21、根据本公开的一方面,提供了一种计算机可读存储介质,其上存储有计算机程序指令,所述计算机程序指令被处理器执行时实现上述方法。

22、在本公开实施例中,可将获取的训练数据集中的样本图像输入预设的特征提取网络(例如驾驶员监控系统dms内置的关键点特征提取网络),以根据特征提取网络输出的样本图像中目标对象的特征张量,对初始姿态网络进行训练,得到目标姿态网络。通过这种方式训练所得的目标姿态网络,在推理阶段进行网络部署的过程中,可以在已经训练好的特征提取网络(例如驾驶员监控系统dms内置的关键点特征提取网络)上增加目标姿态网络,不进行额外的数据预处理,直接复用已有的特征提取网络输出的特征张量,有利于降低部署中对计算资源的占用和相关算法库的依赖;而且,在网络训练过程中,不用再进行特征提取网络相关的训练,可提高训练效率,进一步减少对处理器算力资源和内存资源的消耗。

23、应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,而非限制本公开。根据下面参考附图对示例性实施例的详细说明,本公开的其它特征及方面将变得清楚。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1