本发明涉及一种基于对比学习的点云分类领域自适应方法,属于分类领域自适应。
背景技术:
1、基于深度神经网络的大规模学习方法,在机器人、无人机以及自动驾驶等领域中发挥着重要作用,这些领域通常使用实时深度传感器,如激光雷达,来捕捉场景的精确几何信息,这些信息由3d点云表示。然而,深度神经网络通常需要大量的标记点云来进行表示学习,这限制了其在现实世界中的可扩展性。很多区域存在样本标签不足的问题。
2、当我们已经在一个区域的环境中训练好了一个点云分类模型,能够准确识别该区域中的建筑、道路、车辆等物体。但如果我们直接将这个训练好的模型应用于另一个区域时,由于两个区域的环境风格不同,即特征分布的不同,模型的性能通常会出现显著的下降,并不能准确识别另一个区域中的建筑、道路、车辆等物体,称之为域偏移。为了缓解这个问题,点云领域自适应最近受到了越来越多的关注。
3、领域自适应是一种机器学习技术,旨在解决在训练过程中训练集和测试集之间存在的领域差异问题。当训练模型的数据来自一个领域(源域),而在实际应用中需要将模型应用于另一个领域(目标域)时,由于两个领域之间的数据分布不匹配,模型的性能可能会下降。自监督学习是缓解该问题的一种有效方法。
4、自监督学习通过利用不同输入信号之间的关系或相关性,直接学习无标记目标数据的内部结构。现有的工作大多通过设计辅助任务来学习有利于下游任务的特征表示。defrec首次将自监督学习应用于点云的领域自适应方向,通过对变形的输入样本进行重建来从点云中学习有用的表示,但该方法并没有进一步将无标签的目标域样本充分利用起来。gast将输入样本采样成两部分,分别沿着两个不同的轴进行旋转并对旋转角度进行预测,以及对输入样本的部分区域进行形变并定位这些区域,通过这两个辅助任务来获取对点云领域自适应有意义的特征表示。glrv对输入点云进行放大或缩小然后预测压缩尺度来学习全局结构,为了以自监督的方式捕捉局部结构,将3d区域局部区域投影到2d平面上,然后重建压缩区域,在全局和局部级别上对齐了特征。但对于更具挑战性的数据集,比如点云被严重遮挡,会导致gast和glrv中提出的辅助任务具有模糊性。
5、对比学习是目前较为流行的自监督方法,在其他领域均取得了不错的效果。通过对比学习,模型可以通过学习数据样本之间的相似性和差异性来提取更有用的特征表示,可以很好的应对这些更具挑战性的数据集。
6、数据增强是对比学习中一个很重要的组件,其目的是通过随机的变换将样本映射到不同的视图。数据增强方法的选择对于对比学习的结果有很大的影响,且结合多种数据增强方法可以产生更加有效的特征表示。但目前的增强方法大都只关注对样本的全局增强,忽视了局部的增强。而且点云数据集目前还存在标签样本较少的问题。因此如何有效的结合全局和局部增强,以及合理的利用无标签数据成了本领域亟待解决的问题。
7、公开于该背景技术部分的信息仅仅旨在增加对本发明的总体背景的理解,而不应当被视为承认或以任何形式暗示该信息构成已为本领域普通技术人员所公知的现有技术。
技术实现思路
1、本发明的目的在于克服现有技术中的不足,提供一种基于对比学习的点云分类领域自适应方法,通过结合局部增强和全局增强,能够有效地提高增强点云集样本的多样性和难度,从而提高模型的性能和泛化能力,进一步提高点云分类的准确性。
2、为达到上述目的,本发明是采用下述技术方案实现的:
3、本发明公开了一种基于对比学习的点云分类领域自适应方法,包括如下步骤:
4、获取目标域的实际点云集;
5、将所述实际点云集输入至训练好的点云分类模型,得到目标域的点云分类结果;
6、其中,所述点云分类模型的训练方法如下:
7、获取源域和目标域的原始点云集;
8、对所述源域和目标域的原始点云集分别进行增强操作,得到源域和目标域的增强点云集;
9、根据所述源域和目标域的增强点云集,以及源域的原始点云集,基于预构建的点云分类模型,计算总损失函数;
10、以最小化总损失函数为目标,优化训练预构建的点云分类模型,得到训练好的点云分类模型;
11、其中,所述增强操作包括如下步骤:
12、针对源域或目标域的原始点云集中的每一个原始点云,基于fps算法对原始点云的数据点进行采样,将一个原始点云分为两个部分点云;分别对两个部分点云进行噪声扰动处理后再进行拼接处理,得到一个更新点云;
13、对一个更新点云同时进行两次增强处理,得到两个增强点云;汇总所有的增强点云,得到增强点云集。
14、进一步的,所述源域的原始点云集包括多个源域的原始点云及其对应的真实分类标签;所述目标域的原始点云集包括多个目标域的原始点云。
15、进一步的,所述fps算法包括如下步骤:
16、针对任一个原始点云,基于beta分布获取采样系数lam;
17、根据所述采样系数lam,对原始点云的数据点进行采样,将一个原始点云分为两个部分点云;其中,一个部分点云的数据点的数量为n*lam,另一个部分点云的数据点的数量为n*(1-lam),n表示原始点云中数据点的数量。
18、进一步的,所述噪声扰动处理包括如下步骤:
19、针对任一个部分点云,生成与对应的部分点云大小相同的噪声数据;其中,所述噪声数据中每个元素从正态分布中进行采样;
20、将所述噪声数据加入到对应的部分点云中,得到增强后的部分点云。
21、进一步的,所述增强处理包括如下步骤:
22、对所述更新点云进行裁剪计算,得到更新点云三维坐标范围;
23、生成一个预设长度的三维坐标范围数组,在每个维度上生成随机均匀分布的数值,再将所述三维坐标范围数组归一化并缩放到更新点云三维坐标范围内,得到一个裁剪范围,处于所述裁剪范围内的数据点组成一个裁剪点云;
24、对所述裁剪点云进行压缩处理,得到压缩点云;
25、对所述压缩点云进行旋转处理,得到增强点云。
26、进一步的,所述总损失函数包括源域的对比损失和监督损失、以及目标域的对比损失和无监督损失,具体计算步骤如下:
27、根据所述源域的增强点云集,计算源域的对比损失;
28、根据所述目标域的增强点云集,计算目标域的对比损失;
29、将所述源域的原始点云集输入至预构建的点云分类模型,计算源域的监督损失;
30、将所述目标域的增强点云集输入至预构建的点云分类模型,计算目标域的无监督损失;
31、所述总损失函数的表达式如下:
32、ζtotal=ζcontrast1+ζcontrast2+ζs+ζt
33、式中,ζtotal表示总损失函数;ζcontrast1表示源域的对比损失;ζcontrast2表示目标域的对比损失;ζs表示源域的监督损失;ζr表示目标域的无监督损失。
34、进一步的,所述源域的对比损失的表达式如下:
35、
36、式中,ζcontrast1表示源域的对比损失;log表示对数函数;exp表示以自然常数e为底的指数函数;sim表示相似度计算;(zi,zj)表示源域的同一个原始点云对应的两个增强点云;t表示温度系数;2n1表示源域的增强点云集中增强点云的数量;表示一个指标函数,当k≠i时值为1;zk表示源域的第k个增强点云。
37、进一步的,所述目标域的对比损失的表达式如下:
38、
39、式中,ζcontrast2表示目标域的对比损失;log表示对数函数;exp表示以自然常数e为底的指数函数;sim表示相似度计算;(zp,zq)表示目标域的同一个原始点云对应的两个增强点云;t表示温度系数;2n2表示目标域的增强点云集中增强点云的数量;表示一个指标函数,当h≠p时值为1;zh表示目标域的第h个增强点云。
40、进一步的,所述源域的监督损失的计算步骤如下:
41、将所述源域的原始点云集输入至预构建的点云分类模型,得到每一个原始点云对应的类分布预测数据;
42、根据所述原始点云对应的类分布预测数据及对应的真实分类标签,计算源域的监督损失,所述源域的监督损失的表达式如下:
43、
44、式中,ζs表示源域的监督损失;h()表示交叉熵函数;yn表示第n个原始点云对应的真实分类标签;psrc表示第n个原始点云的类分布预测数据;n1表示源域的原始点云的数量。
45、进一步的,所述目标域的无监督损失的计算步骤如下:
46、一个原始点云对应两个增强点云,分别为增强点云pc_aug1和增强点云pc_aug2;
47、将所述增强点云pc_aug1输入至预构建的点云分类模型中,得到对应的类预测分布数据p1;
48、将所述增强点云pc_aug2输入至预构建的点云分类模型中,得到对应的类预测分布数据p2;
49、根据所述类预测分布数据p1,计算得到伪标签将所述伪标签作为增强点云pc_aug2的伪标签,进行目标域的无监督损失的计算,所述目标域的无监督损失的表达式如下:
50、
51、式中,ζt表示目标域的无监督损失;τ表示设定的置信度阈值;表示一个指标函数,当类预测分布数据p1m中最大的置信度高于设定的置信度阈值时值为1;表示第m个原始点云对应的增强点云pc_aug2的伪标签;p2m表示第m个原始点云对应的增强点云pc_aug2的类预测分布数据;h()表示交叉熵函数;n2表示目标域的原始点云的数量。
52、与现有技术相比,本发明所达到的有益效果:
53、本发明的基于对比学习的点云分类领域自适应方法。通过结合局部增强和全局增强,能够有效地提高增强点云集样本的多样性和难度,从而提高模型的性能和泛化能力,进一步提高点云分类的准确性。
54、本发明还利用伪标签解决了现有技术中标签样本不足的问题。