基于任务的焦点损失提升多语言元学习语音识别方法与流程

文档序号:31698531发布日期:2022-10-01 06:50阅读:来源:国知局

技术特征:
1.基于任务的焦点损失提升多语言元学习语音识别方法,其特征在于,所述方法采用端到端的语音识别网络架构,具体包括:步骤1:初始化语音识别模型f
θ
,输入原始语音特征序列(x1,x2,...,x
t
);步骤2:针对从多语言数据集中抽取的任务t
i
,将所述任务t
i
分为支持集和查询集和查询集表示第i种语言数据;步骤3:计算任务t
i
的asr损失,使用梯度下降得到在支持集上更新后的参数θ
i
;步骤4:使用在支持集上更新后的参数θ
i
在查询集上计算查询损失步骤5:根据任务t
i
的查询损失计算得到任务t
i
的难任务调节器m
tfl
(θ),所述难任务调节器m
tfl
(θ)用于表示任务t
i
的学习难度等级;其中,查询损失越大,则对应的难任务调节器m
tfl
(θ)越大;步骤6:重复n次步骤2至步骤5,计算得到n个任务对应的查询损失和难任务调节器;步骤7:基于所有n个任务对应的查询损失和难任务调节器计算得到基于任务的焦点损失l
tfl
;步骤8:使用所述焦点损失l
tfl
更新语音识别模型f
θ
的元参数θ;步骤9:重复步骤2至步骤8,直至更新后的语音识别模型f
θ
满足给定要求。2.根据权利要求1所述的基于任务的焦点损失提升多语言元学习语音识别方法,其特征在于,步骤5中,难任务调节器m
tfl
(θ)的计算公式为:其中,k≥0和γ≥0为可调超参数。3.根据权利要求1所述的基于任务的焦点损失提升多语言元学习语音识别方法,其特征在于,步骤7中,基于任务的焦点损失l
tfl
的计算公式为:其中,是基础学习器的损失函数。4.根据权利要求1所述的基于任务的焦点损失提升多语言元学习语音识别方法,其特征在于,步骤8中,元参数θ的更新公式为:其中,β表示学习率。5.根据权利要求1所述的基于任务的焦点损失提升多语言元学习语音识别方法,所述端到端的语音识别网络架构具体采用ctc-注意力联合架构;对应的,步骤3中,任务t
i
的asr损失的计算公式为:l=λl
ctc
+(1-λ)l
att
其中,l
ctc
为ctc损失,l
att
为解码损失,超参数λ表示l
ctc
的权重。

技术总结
本发明提供一种基于任务的焦点损失提升多语言元学习语音识别方法。该方法基于任务的焦点损失改进多语言元学习对任务不平衡的忽略,基于每个任务的查询损失引入了难任务调节器,引导模型更加关注难任务,并且为了充分利用难任务的数据,同时使用支持集梯度与查询集梯度来更新元参数。此外,本发明还在样本层面解释了难任务调节器的意义,经过公式推导,发现它与任务内样本的预测概率乘积成反相关。通过使用本发明方法,可以使模型学习到的初始化更加均衡,更加充分地利用了所有源语言的知识,从而能够有效的对目标语言进行泛化。从而能够有效的对目标语言进行泛化。从而能够有效的对目标语言进行泛化。


技术研发人员:屈丹 陈雅淇 杨绪魁 张文林 张昊 陈琦 李静涛
受保护的技术使用者:郑州信大先进技术研究院
技术研发日:2022.06.28
技术公布日:2022/9/30
当前第2页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1