本发明涉及脑电数据处理,尤其涉及一种运动想象脑电分类模型训练方法、装置、设备及存储介质。
背景技术:
1、哺乳动物大脑中的不同活动和大脑状态会引起皮层同步的不同模式,从而导致头皮上随时间产生多样化的电位变化。这些时空相关的头皮电位被认为是神经元计算的核心机制,可以通过脑电图(electroencephalography,eeg)有效捕获和记录。eeg在检测和精确定位多通道电极信号方面具有较高精度,已成为记录和分析各种大脑活动(如运动想象)的关键工具。因此,基于eeg的运动想象分类近年来受到了广泛关注。
2、随着计算设备性能的显著改进,深度学习模型在多领域特征提取方面表现出卓越的能力,广泛应用于运动想象信号的处理和分析。然而,尽管深度学习模型具有上述优点,但数据短缺和变异性仍对分类模型的性能和泛化性构成巨大挑战。为了解决这些挑战,已建立的许多跨用户的运动想象分类方法通过从源域用户转移已知知识来保持分类模型的学习性能,最终实现目标域用户有限数据下的无监督分类。
3、然而,现有的跨用户的运动想象分类方法并未关注用户间情境下跨时段问题和电极数据分布差异性问题,从而导致生成的跨用户运动想象分类网络的分类准确度较低。
技术实现思路
1、本发明提供了一种运动想象脑电分类模型训练方法、装置、设备及存储介质,以提高运动想象脑电分类模型的分类准确度。
2、根据本发明的一方面,提供了一种运动想象脑电分类模型训练方法,所述方法包括:
3、获取源域电极数据和目标域电极数据;
4、将所述源域电极数据和所述目标域电极数据输入至未经训练的分类网络模型中,得到所述分类网络模型中的空间特征提取器对所述源域电极数据进行特征提取后输出的源域深度特征数据,以及对所述目标域电极数据进行特征提取后输出的目标域深度特征数据;
5、根据所述源域深度特征数据和所述目标域深度特征数据,生成桥接域电极数据;
6、将所述源域深度特征数据和所述目标域深度特征数据输入至所述分类网络模型中的三维全连接层分别进行特征提取,得到所述源域深度特征数据对应的第一特征提取数据和所述目标域深度特征数据对应的第二特征提取数据,并根据所述第一特征提取数据、所述第二特征提取数据和所述桥接域电极数据,确定第一损失值;
7、将所述第一特征提取数据输入至所述分类网络模型中的全连接层,由所述全连接层进行运动类别预测,得到运动预测类别,并根据所述运动预测类别和所述源域电极数据对应的运动真实类别,确定第二损失值;
8、根据所述第一损失值和所述第二损失值,对所述分类网络模型进行迭代训练,直到满足模型训练结束条件,得到目标运动想象脑电分类模型,用于进行跨用户的运动想象动作分类。
9、根据本发明的另一方面,提供了一种运动想象脑电分类模型训练装置,所述装置包括:
10、电极数据获取模块,用于获取源域电极数据和目标域电极数据;
11、深度特征提取模块,用于将所述源域电极数据和所述目标域电极数据输入至未经训练的分类网络模型中,得到所述分类网络模型中的空间特征提取器对所述源域电极数据进行特征提取后输出的源域深度特征数据,以及对所述目标域电极数据进行特征提取后输出的目标域深度特征数据;
12、桥接域数据生成模块,用于根据所述源域深度特征数据和所述目标域深度特征数据,生成桥接域电极数据;
13、第一损失值确定模块,用于将所述源域深度特征数据和所述目标域深度特征数据输入至所述分类网络模型中的三维全连接层分别进行特征提取,得到所述源域深度特征数据对应的第一特征提取数据和所述目标域深度特征数据对应的第二特征提取数据,并根据所述第一特征提取数据、所述第二特征提取数据和所述桥接域电极数据,确定第一损失值;
14、第二损失值确定模块,用于将所述第一特征提取数据输入至所述分类网络模型中的全连接层,由所述全连接层进行运动类别预测,得到运动预测类别,并根据所述运动预测类别和所述源域电极数据对应的运动真实类别,确定第二损失值;
15、目标分类模型生成模块,用于根据所述第一损失值和所述第二损失值,对所述分类网络模型进行迭代训练,直到满足模型训练结束条件,得到目标运动想象脑电分类模型,用于进行跨用户的运动想象动作分类。
16、根据本发明的另一方面,提供了一种电子设备,所述电子设备包括:
17、至少一个处理器;以及
18、与所述至少一个处理器通信连接的存储器;其中,
19、所述存储器存储有可被所述至少一个处理器执行的计算机程序,所述计算机程序被所述至少一个处理器执行,以使所述至少一个处理器能够执行本发明任一实施例所述的运动想象脑电分类模型训练方法。
20、根据本发明的另一方面,提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现本发明任一实施例所述的运动想象脑电分类模型训练方法。
21、本发明实施例技术方案将源域电极数据和目标域电极数据输入至分类网络模型中,根据模型输出的源域深度特征数据和目标域深度特征数据,生成桥接域电极数据,同时将模型输出并输入至模型的三维全连接层进行特征提取,得到第一特征提取数据和第二特征提取数据,根据第一特征提取数据、第二特征提取数据和桥接域电极数据,确定第一损失值;将第一特征提取数据输入至模型的全连接层,并根据输出的运动预测类别和运动真实类别,确定第二损失值;根据第一损失值和第二损失值进行模型迭代训练,直到满足模型训练结束条件,得到目标运动想象脑电分类模型,用于进行跨用户的运动想象动作分类。上述技术方案通过结合源域深度特征和目标域深度特征,生成桥接域电极数据,解决了跨用户的时域电极数据分布差异问题,最小化了源域和目标域用户内和用户间的时域电极差异,完成了源域和目标域之间的数据对齐。通过结合源域深度特征数据、目标域深度特征数据和桥接域深度特征数据确定第一损失值,缩小了时间阶段间、用户间和电极间的分布差异性,提高了运动想象脑电分类模型的模型训练准确度,从而提高了运动想象分类的准确度。
22、应当理解,本部分所描述的内容并非旨在标识本发明的实施例的关键或重要特征,也不用于限制本发明的范围。本发明的其它特征将通过以下的说明书而变得容易理解。
1.一种运动想象脑电分类模型训练方法,其特征在于,包括:
2.根据权利要求1所述的方法,其特征在于,所述源域电极数据和所述目标域电极数据由至少两个实验阶段下的不同电极通道的电极数据组成;
3.根据权利要求2所述的方法,其特征在于,所述根据所述桥接域电极均值,生成桥接域电极数据,包括:
4.根据权利要求1所述的方法,其特征在于,所述将所述源域深度特征数据和所述目标域深度特征数据输入至所述分类网络模型中的三维全连接层分别进行特征提取,得到所述源域深度特征数据对应的第一提取特征数据和所述目标域深度特征数据对应的第二特征提取数据,并根据所述第一特征提取数据、所述第二特征提取数据和所述桥接域电极数据,确定第一损失值,包括:
5.根据权利要求4所述的方法,其特征在于,所述三维全连接层包括第一三维全连接层和第二三维全连接层;
6.根据权利要求5所述的方法,其特征在于,所述根据所述第一阶段源域特征提取数据和所述桥接域电极数据,确定第一阶段源域损失值,包括:
7.根据权利要求4所述的方法,其特征在于,所述三维全连接层包括第三三维全连接层和第四三维全连接层;
8.一种运动想象脑电分类模型训练装置,其特征在于,包括:
9.一种电子设备,其特征在于,所述电子设备包括:
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机指令,所述计算机指令用于使处理器执行时实现权利要求1-7中任一项所述的运动想象脑电分类模型训练方法。