一种基于CNN与Transformer的糖尿病视网膜病变分级方法与系统

文档序号:35283563发布日期:2023-09-01 04:04阅读:165来源:国知局
一种基于CNN与Transformer的糖尿病视网膜病变分级方法与系统

本发明属于深度学习细粒度视觉分类,具体涉及一种基于cnn与transformer的糖尿病视网膜病变分级方法与系统。


背景技术:

1、根据国际糖尿病联合会发布的数据,2021年患有糖尿病的成年人人数达到5.37亿,约为全球成年人人数的十分之一,预计2045年将达到7.84亿。糖尿病视网膜病变是糖尿病的常见并发症之一,据不完全统计每三名糖尿病患者中就有一人会患有糖尿病视网膜病变。糖尿病视网膜病变的主要症状包括微动脉瘤、渗出物和出血。眼底图像中病变的种类和数量决定糖尿病视网膜病变等级,因此如何检测这些微小的病变实现糖尿病视网膜病变分级是一项具有挑战性的任务。

2、基于cnn的糖尿病视网膜病变分类方法具有较小的感受野,有利于局部特征提取,但可用信息的范围有限。基于transformer的糖尿病视网膜病变分类方法有很大的感受野,有利于对长距离的关系进行建模,但很容易丢失局部细节。此外,dr的严重程度遵循从轻度到重度的自然顺序,然而大对数dr分级方法忽略了不同严重程度糖尿病视网膜病变之间序的信息,将糖尿病视网膜病变分类视为多类分类问题。


技术实现思路

1、本发明的目的在于,提供一种基于cnn与transformer的糖尿病视网膜病变分级方法与系统,通过特征融合得到更合适的感受野,发现眼底图像中更多鉴别性区域,实现更准确的糖尿病视网膜病变分级,将糖尿病视网膜病变分级视为一项联合有序回归的多类分类问题,同时获得类别监督信息和有序监督信息,使得模型最终的分类性能达到目前先进的水平。

2、为解决上述技术问题,本发明的技术方案为:一种基于cnn与transformer的糖尿病视网膜病变分级方法,包括以下步骤:

3、s1:导入带有多组原始图像的原始图像数据集,并将所述原始图像数据集划分为图像训练集和图像测试集;

4、s2:将图像训练集中的各个原始图像等分成若干个图像块,并生成不同粒度版本的打乱图像;

5、s3:通过残差网络resnet50和swin-transformer建立糖尿病视网膜病变分级网络,采用渐进式策略训练糖尿病视网膜病变分级网络;输入打乱图像,通过残差网络resnet50的最后三层输出三个中间阶段不同粒度大小的特征谱;输入原始图像,通过swin-transformer的最后三层输出三个中间阶段不同大小的特征谱;通过特征融合模块融合不同阶段卷积神经网络输出的特征谱与swin-transformer输出的特征谱;串联残差网络resnet50与swin-transformer融合的最后三层输出的特征谱,得到串联阶段输出的多尺度多粒度的特征谱;

6、s4:定义分类损失函数和加权卡帕损失函数,并根据分类损失函数和加权卡帕损失函数构建损失层;

7、s5:利用图像训练集优化糖尿病视网膜病变分级网络;利用测试样本集对糖尿病视网膜病变分级网络进行测试。

8、s2具体为:

9、s21:将原始图像i等分成k×k个图像块,根据图像块的索引得到大小为j×j的索引矩阵pi;

10、s22:随机打乱图像块,通过拼图生成器生成打乱图像,根据图像块的索引矩阵pi得到打乱图像的索引矩阵ps;

11、s23:根据打乱图像的索引矩阵ps得到独热形式的大小为k2×j2的矩阵p。

12、s3具体为:

13、建立糖尿病视网膜病变分级网络,包括残差网络、swin-transformer、特征融合模块、卷积层和分类层;

14、采用渐进式策略训练糖尿病视网膜病变分级网络;训练过程具体包括,步骤1:通过所述特征融合模块融合残差网络第3层输出的特征谱a3与swin-transformer第2层输出的特征谱b2,得到第一融合特征t3;步骤2:通过所述特征融合模块融合残差网络第4层输出的特征谱a4与swin-transformer第3层输出的特征谱b3,得到第二融合特征t4;步骤3:通过所述特征融合模块融合残差网络第5层输出的特征谱a5与swin-transformer第4层输出的特征谱b4,得到第三融合特征t5;步骤4:使用原始图像作为残差网络resnet50与swin-transformer的输入图像,串联残差网络resnet50与swin-transformer融合的特征谱x3,x4,x5,得到串联阶段输出的多尺度多粒度的特征谱xconcat=concat(x3,x4,x5);

15、步骤1,步骤2和步骤3的融合特征t3,t4,t5经过卷积层分别对得到对应的特征谱x3,x4,x5;

16、残差网络resnet50提取打乱图像的图像特征ai,使用残差网络resnet50的最后三层(a3,a4,a5)分别处理打乱图像;对不同粒度版本的图像输出不同中间阶段的特征谱,设k=25-i+1,k×k表示输入拼图中块的数量,其中i={3,4,5},分别对应输出三个阶段的特征谱a3,a4,a5;

17、swin-transformer提取图像训练集的图像特征bl,使用swin-transformer的最后三层(b2,b3,b4)分别处理图像训练集的图像,分别对应输出三个阶段的特征谱b2,b3,b4;

18、特征融合模块包括空间注意力模块,通道注意力模块和细粒度交互模块;在步骤1,步骤2,步骤3和步骤4中,使用特征融合模块融合相对应的残差网络与swin-transformer提取的特征谱,具体描述如下:

19、所述空间注意力模块包括通道池化层,7×7卷积层conv和sigmoid激活函数,通过空间注意力对cnn分支特征进行增强,其方法表示为:

20、

21、所述通道注意力模块包括平均池化层,全连接层和sigmoid激活函数,通过通道注意力对cnn分支特征进行增强,其方法表示为:

22、

23、所述细粒度交互模块包括2个1×1卷积conv1和conv2,哈达玛积和一个3×3卷积conv3,通过细粒度交互模块实现特征交互,其方法表示为:

24、ci=conv3(conv1(ai)⊙conv2(bl));

25、将两个分支的增强特征与交互特征融合在一起,经过一个残差块,以实现不同分支输出特征之间的互补关系,其方法表示为:

26、

27、通过卷积层将融合后的特征谱的通道维度统一到1024维:

28、

29、所述分类层由两个具有batchnorm和elu非线性的两个全连接层组成;糖尿病视网膜病变分为5个等级,分类层通过分类器处理1024维的特征向量后得到一个5维的特征向量作为预测概率分布表示为:

30、

31、其中,步骤1,步骤2,步骤3和步骤4的预测概率分布分别表示为v3,v4,v5,vconcat。

32、s4具体为:

33、s41:在步骤1,步骤2和步骤3中,步骤i中使用vi表示预测概率分布,在步骤4中,使用vnconcat表示预测概率分布;使用v表示图像真实标签,通过交叉熵损失计算分类损失函数,表示为:

34、

35、

36、其中,n表示图像索引,m表示一个批量中图像总数,可表示步骤1,步骤2和步骤3中的分类损失函数,表示步骤4的分类损失函数。

37、s42:通过加权卡帕损失函数计算分类损失函数,表示为:

38、

39、

40、

41、其中,n表示样本的数量,nm表示类别m的样本数量,q表示类别总数,wm,n表示二次加权矩阵,其中|m-n|表示预测类别m和实际类别n之间的距离,tk表示第k个样本xk的真实类别,p(q|xk)表示第k个样本xk预测类别属于q的条件概率;

42、s43:将交叉熵损失函数和加权卡帕损失函数加权平均获得总损失函数:

43、

44、

45、其中,β为超参数,可表示步骤1,步骤2和步骤3中的总损失函数,表示步骤4的总损失函数。

46、s5具体为:

47、s51:采用图像训练集,通过自动微分技术,使用基于随机梯度下降和反向传播算法,根据总损失函数优化糖尿病视网膜病变分级网络;

48、s52:采用图像测试集在图像训练集权重的基础上对糖尿病视网膜病变分级网络进行测试。

49、还提供一种基于cnn与transformer的糖尿病视网膜病变分级系统,包括:拼图生成模块、残差网络模块、swin-transformer模块、特征融合模块、损失函数模块和训练测试模块;其中,

50、拼图生成模块,用于处理图像训练数据集,生成不同粒度大小版本的打乱图像;

51、残差网络模块,用于使用resnet50的最后三层进行渐进式训练,在步骤1至步骤3中使用不同粒度大小的图像作为输入图像,并且选取不同的中间阶段特征谱输出,得到对象的局部细节信息;在步骤4中使用原始图像作为输入图像,残差网络resnet50同时输出最后三层输出的特征谱,得到多尺度的特征信息;

52、swin-transformer模块,用于使用swin-transformer的最后三层进行渐进式训练,使用原始图像作为输入图像,在步骤1至步骤3中选取不同的中间阶段特征谱输出,得到对象的全局表示信息;在步骤4中swin-transformer同时输出最后三层输出的特征谱,得到多尺度的特征信息;

53、特征融合模块,用于在步骤1至步骤3中融合残差网络与swin-transformer相应阶段的特征谱,输出融合特征;在步骤4中串联步骤1至步骤3中融合的特征谱,得到串联阶段多尺度多粒度的特征谱;

54、损失函数模块,用于定义分类损失函数和加权卡帕损失函数,并根据分类损失函数和加权卡帕损失函数构建损失层;

55、训练测试模块,用于利用图像训练集优化糖尿病视网膜病变分级网络,利用图像测试集测试糖尿病视网膜病变分级网络。

56、损失函数模块的工作流程具体为:

57、在步骤1,步骤2,步骤3中,步骤i中使用vi表示预测概率分布,在步骤4中,使用表示预测概率分布,第i阶段使用vi表示预测概率分布,使用v表示图像真实标签,通过交叉熵损失计算分类损失函数,表示为:

58、

59、

60、其中,n表示图像索引,m表示一个批量中图像总数,可表示步骤1,步骤2,步骤3中的分类损失函数,表示步骤4的分类损失函数。;

61、通过加权卡帕损失函数计算分类损失函数,表示为:

62、

63、

64、

65、其中,n表示样本的数量,nm表示类别m的样本数量,q表示类别总数,wm,n表示二次加权矩阵,其中|m-n|表示预测类别m和实际类别n之间的距离,tk表示第k个样本xk的真实类别,p(q|xk)表示第k个样本xk预测类别属于q的条件概率;

66、将交叉熵损失函数和加权卡帕损失函数加权平均获得总损失函数:

67、

68、

69、其中,β为超参数,可表示步骤1,步骤2,步骤3中的总损失函数,表示步骤4的总损失函数。

70、拼图生成模块的工作流程具体为:

71、将原始图像i等分成k×k个图像块,根据图像块的索引得到大小为j×j的索引矩阵pi;

72、随机打乱图像块,通过拼图生成器生成打乱图像,根据图像块的索引矩阵pi得到打乱图像的索引矩阵ps;

73、根据打乱图像的索引矩阵ps得到独热形式的大小为k2×j2的矩阵p。

74、还提供一种计算机设备,包括存储器、处理器以及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如上述任一项所述方法的步骤。

75、还提供一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现如上述任一项所述方法的步骤。

76、与现有技术相比,本发明的有益效果为:

77、1.本发明的一种基于cnn与transformer的基于糖尿病视网膜病变分级方法,通过使用特征融合模块融合残差网络(resnet50)和swin-transformer多层次特征,实现全局信息与局部信息的交互。通过特征融合得到更合适的感受野,发现眼底图像中更多鉴别性区域,实现更准确的糖尿病视网膜病变分级。

78、2.本发明通过同时利用分类损失函数和加权卡帕损失函数,将糖尿病视网膜病变分级视为一项联合有序回归的多类分类问题,同时获得类别监督信息和有序监督信息,使得模型最终的分类性能达到目前先进的水平。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1