基于双分支交叉注意力池化的眼前节疾病多标签分类方法

文档序号:35576371发布日期:2023-09-24 18:00阅读:46来源:国知局
基于双分支交叉注意力池化的眼前节疾病多标签分类方法

本发明属于计算机辅助医学图像处理领域,涉及一种基于双分支交叉注意力池化的眼前节疾病多标签分类方法。


背景技术:

1、在眼科领域,裂隙灯检查对角膜、结膜和晶状体的评估是眼表疾病诊断的基础,早期的眼病筛查可以帮助患者预防视力损害和其他问题。目前基于裂隙灯图像的眼前节疾病诊断系统多针对单一疾病。由于实际生活中病人可能同时患有多种眼部疾病,使得眼前节疾病多标签分类具有重大的临床意义和研究意义。

2、目前针对多标签图像分类的方法主要包括构建标签相关性、基于目标检测边界框以及注意力机制等。通过构建标签共现矩阵的方法计算成本大并且应用在医学图像小数据集上容易导致过拟合。基于目标检测边界框的方法需要高昂的标注成本并且模型复杂,由于成本和耗时的原因,大多数情况下医生无法对病灶区域进行密集注释。现有的基于视觉注意力的方法已被证明可以提升多标签图像分类任务的准确率,并应用在医学图像数据集上,但是仍然存在如下不足:(1)由于病灶位置和大小不同,很难针对性地区分和提取特征,部分标签对象之间存在视觉相似性很难区分;(2)图像中包含的正样本远远少于负样本,存在严重的正负样本失衡,导致在预测结果时出现可能会出现一些假阴性现象;(3)大多数方法忽视了多标签图像分类中的标签的粗细粒度问题。


技术实现思路

1、本发明针对现有技术的不足和对裂隙灯图像眼前节疾病分类的需求,提出一种基于双分支交叉注意力池化的眼前节疾病多标签分类方法。

2、为实现上述目的,本发明包括以下步骤:

3、s1、对裂隙灯图像进行数据预处理,统一数据集中图像尺寸;

4、s2、将图像输入双分支网络中,通过分支1运算输出切片序列,通过分支2运算输出特征图;

5、s3、分支1和分支2的特征图、注意力图进行双分支交叉注意力池化,输出疾病种类和疾病区域结果;

6、s4、利用分支2中输出的注意力图对该分支经过数据预处理的图像进行注意力引导的数据增强,并将增强的图片再次输入主网络,得到的结果作为数据增强模块的结果输出;

7、s5、将主网络的结果输出和数据增强模块的结果输出分别各自和标签进行损失计算,并将得到的损失加权求和,来对模型进行监督训练,模型训练到设定的迭代次数时停止训练,期间以验证集的最优评价指标结果保存模型参数;

8、其中所述分支1和分支2的主干特征图由resnet50网络模型提取得到。对分支1,主干特征图进行投影得到切片序列,并输入transformer模块,输出切片序列。对分支2,主干特征图直接作为特征图输入卷积模块,输出特征图。分支1和分支2之间通过交互模块进行特征耦合。

9、进一步地,上述基于双分支交叉注意力池化的眼前节疾病多标签分类方法中,所述步骤s3具体包括以下步骤:

10、s31、将分支1输出的切片序列经过reshape,并利用卷积核为1的卷积层改变通道数,得到对齐特征图,记作aligned feature maps。

11、同时,分支2也输出特征图;

12、s32、两个分支的特征图各自使用卷积核为1的卷积层进行卷积操作得到注意力图;

13、s33、将分支1的特征图和注意力图记作f1、a1,将分支2的特征图和注意力图记作f2、a2。对f1和a1、a2分别进行双线性注意力池化,得到的结果在维度为1的通道上进行拼接,记为特征矩阵1。同理对f2和a1、a2分别进行双线性注意力池化,得到的结果在维度为1的通道上进行拼接,记为特征矩阵2;

14、s34、细粒度的疾病分类损失。对特征矩阵1和特征矩阵2分别连接输出通道为疾病类别数的全连接层,疾病类别数为c,得到的logits经过sigmoid激活函数输出疾病类别预测1、疾病类别预测2,和真实标签计算交叉熵损失函数;

15、s35、粗粒度的疾病区域分类损失。上述c个类别属于细粒度级别的标签。同时,通过将疾病按照其所在区域进行划分,可以得到粗粒度级别的标签,即区域类别,区域类别数为r。因此对于每一张输入的裂隙灯图像,除了图像的细粒度疾病类别标签,还有图像的粗粒度疾病区域标签。粗粒度的疾病区域预测计算方法如下:对特征矩阵1、特征矩阵2后分别连接类别数为r的全连接层和relu激活函数,得到疾病区域预测值;对疾病类别预测1、疾病类别预测2后分别连接类别数为r的全连接层,得到疾病区域预测值,将上述得到的预测值相加,最后输入sigmoid激活函数即可得到粗粒度的疾病区域预测值,和真实标签计算交叉熵损失函数。

16、进一步地,上述基于双分支交叉注意力池化的眼前节疾病多标签分类方法中,所述步骤s4具体包括以下步骤:

17、s41、将分支2输出的注意力图进行一次上采样,并和经过数据预处理的图像进行注意力引导的数据增强;

18、s42、注意力引导的数据增强方式分为3种,分别是裁剪、cutmix、mixup。裁剪即在经过预处理的原始图像中将注意力分数较高的区域裁剪出来,并重新调整大小为原始图像大小;cutmix即将原始图像中注意力分数较高的区域进行裁剪,并叠加在原始图像的左上角;mixup即将原始图像中注意力分数较高的区域进行裁剪并重新调整大小为原始图像大小,将其和原始图像进行等比例混合;

19、s43、三种注意力引导的数据增强方式都可以得到一个新的输入,并且该输入无需再进行数据预处理。将三张图片分别输入到网络中,可以得到数据增强分支的三个结果。

20、进一步地,上述基于双分支交叉注意力池化的眼前节疾病多标签分类方法中,所述步骤s5具体包括以下步骤:

21、s51、对主网络的结果输出和数据增强模块的结果输出分别和标签做损失计算,主网络损失函数为交叉熵损失函数,数据增强模块为不对称性损失函数,其中主网络和数据增强模块所占权重分别为0.5。将得到的损失加权求和,对模型进行监督训练;

22、s52、模型训练到设定的迭代次数时停止训练,期间以验证集的最优评价指标结果保存模型参数,其中验证集的最优评价指标定为对疾病种类预测的f1-score。

23、本发明的有益效果:

24、本发明的基于双分支交叉注意力池化的眼前节疾病多标签分类方法,首先考虑到多标签图像病灶区域的位置、大小不一,选择用注意力图进行特征的提取,并使用双分支交叉注意力池化对易混淆或者视觉高度相似特征进行区分,从细粒度视角对多标签眼部图像进行分类;其次,考虑到正负样本不均衡问题,使用裁剪、cutmix、mixup等三种不同的数据增强方式对经过预处理的原始图像进行数据增强,突出正例区域特征,并利用非对称损失函数引导网络在训练时侧重正样本的损失贡献,减少负样本的损失贡献,起到缓解正负样本不均衡、修正部分假阴性标签的作用;最后,本方法考虑了多标签图像分类中的粗细粒度问题,以一个端到端的网络结构实现了多标签眼部图像的分级分类,网络可以同时输出准确的疾病类别信息以及疾病区域信息。



技术特征:

1.一种基于双分支交叉注意力池化的眼前节疾病多标签分类方法,其特征在于,包括如下步骤:

2.根据权利要求1所述的基于双分支交叉注意力池化的眼前节疾病多标签分类方法,其特征在于,所述步骤s1中预处理方法为:统一数据集中图像尺寸,并将数据集划分为训练集和验证集。

3.根据权利要求1所述的基于双分支交叉注意力池化的眼前节疾病多标签分类方法,其特征在于,所述步骤s2中,由resnet50网络模型构成的分支1和分支2分别提取得到主干特征图,对分支1提取得到的主干特征图进行投影得到切片序列,并输入transformer模块,输出切片序列;对分支2提取得到的主干特征图直接作为特征图输入卷积模块,输出特征图,所述分支1和分支2之间通过交互模块进行特征耦合,特征耦合以后,分支1输出切片序列,该切片序列形状与输入交互模块前切片序列形状一致,分支2输出特征图,该特征图与输入交互模块前特征图形状一致。

4.根据权利要求3所述的基于双分支交叉注意力池化的眼前节疾病多标签分类方法,其特征在于,所述分支1是由卷积模块构成,用于提取图片的局部特征;分支2利用transformer中自注意力机制的全局感受野捕获图片的长距离依赖关系,分支1和分支2之间通过交互模块进行特征耦合,将基于cnn的局部特征和基于transformer的全局特征进行融合。

5.根据权利要求4所述的基于双分支交叉注意力池化的眼前节疾病多标签分类方法,其特征在于,所述步骤s3中,得到注意力图的方法为:

6.根据权利要求5所述的基于双分支交叉注意力池化的眼前节疾病多标签分类方法,其特征在于,所述步骤s3中,双分支交叉注意力池化的方法为:

7.根据权利要求6所述的基于双分支交叉注意力池化的眼前节疾病多标签分类方法,其特征在于,所述步骤s3中,输出疾病种类和疾病区域分类结果的方法为:

8.根据权利要求1所述的基于双分支交叉注意力池化的眼前节疾病多标签分类方法,其特征在于:所述步骤s4中

9.根据权利要求1所述的基于双分支交叉注意力池化的眼前节疾病多标签分类方法,其特征在于:所述步骤s5中,对于经过数据预处理的图像,输入双分支网络中,生成一组预测结果,将其和标签进行损失计算,记为损失1,损失函数为交叉熵损失函数,增强图片1、增强图片2、增强图片3,分别通过双分支网络各自输出一组疾病种类和疾病区域结果,对于数据增强图片的每一组预测结果,将其和标签进行损失计算,损失函数为不对称性损失函数,计算得到的3个损失结果取平均值,得到的结果记为损失2,损失1和损失2所占权重分别为0.5,进行加权求和,得到最后的损失。


技术总结
本发明公开了一种基于双分支交叉注意力池化的眼前节疾病多标签分类方法,本发明包括以下步骤:(1)对裂隙灯图像进行数据预处理;(2)将图像输入双分支网络,输出token序列和特征图;(3)对特征图、注意力图进行双分支交叉注意力池化,输出疾病类别和区域的结果;(4)利用CNN分支中的注意力图对该分支输入进行注意力引导的数据增强,再次输入主网络;(5)将各个结果和标签进行损失计算来进行监督训练,训练好的模型可用于疾病类别和区域的诊断。本发明通过构建双分支交叉注意力池化模块,解决了多标签图像中对象大小位置不一且部分特征之间的视觉相似性问题,能够基于裂隙灯图像准确地进行疾病类别和区域的多标签分类。

技术研发人员:王莉莎,顾人舒,陈德潮,高红依,贾刚勇
受保护的技术使用者:杭州电子科技大学
技术研发日:
技术公布日:2024/1/15
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1