一种模型训练方法及装置与流程

文档序号:34185124发布日期:2023-05-17 12:35阅读:58来源:国知局
一种模型训练方法及装置与流程

本技术涉及人工智能,特别是涉及一种模型训练方法及装置。


背景技术:

1、分子从头设计是理性药物设计当中的重要一环,但是目前的分子从头设计方法大部分是通过对化学空间采样,然后生成对应的二维分子,无法显式地捕捉到小分子和蛋白口袋结构的相互作用。同时,2d分子生成模型面临过拟合,难以泛化的问题。

2、目前基于深度学习框架的全新药物设计方法按照训练模型时使用的分子表征类型可以分为两类:第一类是基于2d的全新药物设计方法;第二类是基于3d的全新药物设计方法。

3、而2d的全新药物设计方法,存在没有底层物理的支持,模型面临过拟合,泛化能力低以及生成的分子合理但随机的问题。而3d的全新药物设计方法,则无法有效捕捉和蛋白质口袋之间的作用力模式,生成的分子结构不合理,亲和力无法超过原位配体,对算力显存消耗大,训练低效,难以直接应用于现实世界的药物设计。


技术实现思路

1、本技术实施例所要解决的技术问题是提供一种模型训练方法及装置,以捕捉到更高尺度的相互作用,使得模型生成的分子与蛋白口袋发生合理的几何匹配和能量匹配,实现了可靠的,有效的基于口袋的三维分子从头设计。同时,引入了两个尺度的自回归模式,即全局尺度和原子组件尺度,完成对口袋内分子的几何和拓扑结构学习,在给定口袋结构之后即可以生成与这个口袋相匹配的分子。

2、第一方面,本技术实施例提供了一种模型训练方法,所述方法包括:

3、获取模型训练样本,所述模型训练样本包括:蛋白口袋;

4、将所述模型训练样本输入至待训练蛋白口袋内分子生成模型,所述待训练蛋白口袋内分子生成模型包括:n个循环连接的网络层,n为正整数;

5、调用n个所述循环连接的网络层对所述蛋白口袋进行逐个原子预测处理,得到所述蛋白口袋对应的预测原子类型、预测中心原子、预测原子位置和预测原子键连关系;

6、基于所述预测原子类型、所述预测中心原子、所述预测原子位置和所述预测原子键连关系,计算得到所述待训练蛋白口袋内分子生成模型的损失值;

7、在所述损失值处于预设范围内的情况下,将训练后的待训练蛋白口袋内分子生成模型作为最终的蛋白口袋内分子生成模型。

8、可选地,每个所述网络层包括:特征提取网络层、向量表征网络层和原子预测网络层,

9、所述调用n个所述循环连接的网络层对所述蛋白口袋进行逐个原子预测处理,得到所述蛋白口袋对应的预测原子、预测中心原子、预测原子位置和预测原子键连关系,包括:

10、在采用第m个网络层对所述蛋白口袋进行处理时,调用所述第m个网络层的特征提取网络层提取所述蛋白口袋内的第m个原子的原子特征;m为大于1的正整数,且m≤n;

11、调用所述第m个网络层的向量表征网络层对所述原子特征进行向量表征处理,得到所述第m个原子的原子特征向量;

12、调用所述第m个网络层的原子预测网络层对所述原子特征向量进行处理,得到所述蛋白口袋的第m个原子对应的预测原子、预测中心原子、预测原子位置和预测原子键连关系。

13、可选地,所述原子预测网络层包括:位置预测层、原子类型预测层、中心原子预测层和键类型预测层,

14、所述调用所述第m个网络层的原子预测网络层对所述原子特征向量进行处理,得到所述第m个原子对应的预测原子类型、预测中心原子、预测原子位置和预测原子键连关系,包括:

15、调用所述位置预测层根据第m-1个网络层选定的中心原子对所述原子特征向量进行处理,预测得到所述第m个原子的预测原子位置;

16、调用所述原子类型预测层根据第m-1个网络层选定的中心原子对所述原子特征向量进行处理,预测得到所述第m个原子的预测原子类型;

17、调用所述键类型预测层根据第m-1个网络层选定的中心原子对所述原子特征向量进行处理,预测得到所述第m个原子与所述第m-1个网络层生成的原子之间的预测原子键连关系;

18、调用所述中心原子预测层根据已生成的原子的概率值,从m个已生成的原子中筛选出第m个网络层的预测中心原子。

19、可选地,所述基于所述预测原子类型、所述预测中心原子、所述预测原子位置和所述预测原子键连关系,计算得到所述待训练蛋白口袋内分子生成模型的损失值,包括:

20、基于所述预测中心原子和所述蛋白口袋的标注中心原子,计算得到原子中心损失值;

21、基于所述预测原子位置,计算得到位置损失值;

22、基于所述预测原子类型和所述蛋白口袋的标注原子类型,计算得到类型损失值;

23、基于所述预测原子键连关系和所述蛋白口袋的标注原子键连关系,计算得到键类型损失值;

24、基于所述原子中心损失值、所述位置损失值、所述类型损失值和所述键类型损失值,计算得到所述待训练蛋白口袋内分子生成模型的损失值。

25、可选地,在所述将训练后的待训练蛋白口袋内分子生成模型作为最终的蛋白口袋内分子生成模型之后,还包括:

26、获取待处理蛋白口袋;

27、将所述待处理蛋白口袋输入至所述蛋白口袋内分子生成模型;

28、调用n个所述循环连接的网络层对所述待处理蛋白口袋进行逐个原子处理,预测得到所述待处理蛋白口袋对应的原子类型、中心原子、原子位置和原子键连关系;

29、基于所述原子类型、所述中心原子、所述原子位置和所述原子键连关系,生成所述待处理蛋白口袋对应的分子三维结构。

30、第二方面,本技术实施例提供了一种模型训练装置,所述装置包括:

31、模型训练样本获取模块,用于获取模型训练样本,所述模型训练样本包括:蛋白口袋;

32、模型训练样本输入模块,用于将所述模型训练样本输入至待训练蛋白口袋内分子生成模型,所述待训练蛋白口袋内分子生成模型包括:n个循环连接的网络层,n为正整数;

33、蛋白口袋处理模块,用于调用n个所述循环连接的网络层对所述蛋白口袋进行逐个原子预测处理,得到所述蛋白口袋对应的预测原子类型、预测中心原子、预测原子位置和预测原子键连关系;

34、损失值计算模块,用于基于所述预测原子类型、所述预测中心原子、所述预测原子位置和所述预测原子键连关系,计算得到所述待训练蛋白口袋内分子生成模型的损失值;

35、分子生成模型获取模块,用于在所述损失值处于预设范围内的情况下,将训练后的待训练蛋白口袋内分子生成模型作为最终的蛋白口袋内分子生成模型。

36、可选地,每个所述网络层包括:特征提取网络层、向量表征网络层和原子预测网络层,

37、所述蛋白口袋处理模块包括:

38、原子特征提取单元,用于在采用第m个网络层对所述蛋白口袋进行处理时,调用所述第m个网络层的特征提取网络层提取所述蛋白口袋内的第m个原子的原子特征;m为大于1的正整数,且m≤n;

39、特征向量获取单元,用于调用所述第m个网络层的向量表征网络层对所述原子特征进行向量表征处理,得到所述第m个原子的原子特征向量;

40、原子预测单元,用于调用所述第m个网络层的原子预测网络层对所述原子特征向量进行处理,得到所述蛋白口袋的第m个原子对应的预测原子、预测中心原子、预测原子位置和预测原子键连关系。

41、可选地,所述原子预测网络层包括:位置预测层、原子类型预测层、中心原子预测层和键类型预测层,

42、所述原子预测单元包括:

43、预测原子位置预测子单元,用于调用所述位置预测层根据第m-1个网络层选定的中心原子对所述原子特征向量进行处理,预测得到所述第m个原子的预测原子位置;

44、预测原子类型预测子单元,用于调用所述原子类型预测层根据第m-1个网络层选定的中心原子对所述原子特征向量进行处理,预测得到所述第m个原子的预测原子类型;

45、预测键连关系预测子单元,用于调用所述键类型预测层根据第m-1个网络层选定的中心原子对所述原子特征向量进行处理,预测得到所述第m个原子与所述第m-1个网络层生成的原子之间的预测原子键连关系;

46、预测中心原子预测子单元,用于调用所述中心原子预测层根据已生成的原子的概率值,从m个已生成的原子中筛选出第m个网络层的预测中心原子。

47、可选地,所述损失值计算模块包括:

48、中心损失值计算单元,用于基于所述预测中心原子和所述蛋白口袋的标注中心原子,计算得到原子中心损失值;

49、位置损失值计算单元,用于基于所述预测原子位置,计算得到位置损失值;

50、类型损失值计算单元,用于基于所述预测原子类型和所述蛋白口袋的标注原子类型,计算得到类型损失值;

51、键类型损失值计算单元,用于基于所述预测原子键连关系和所述蛋白口袋的标注原子键连关系,计算得到键类型损失值;

52、损失值计算单元,用于基于所述原子中心损失值、所述位置损失值、所述类型损失值和所述键类型损失值,计算得到所述待训练蛋白口袋内分子生成模型的损失值。

53、可选地,所述装置还包括:

54、蛋白口袋获取模块,用于获取待处理蛋白口袋;

55、蛋白口袋输入模块,用于将所述待处理蛋白口袋输入至所述蛋白口袋内分子生成模型;

56、蛋白口袋原子预测模块,用于调用n个所述循环连接的网络层对所述待处理蛋白口袋进行逐个原子处理,预测得到所述待处理蛋白口袋对应的原子类型、中心原子、原子位置和原子键连关系;

57、分子三维结构生成模块,用于基于所述原子类型、所述中心原子、所述原子位置和所述原子键连关系,生成所述待处理蛋白口袋对应的分子三维结构。

58、第三方面,本技术实施例提供了一种电子设备,包括:

59、处理器、存储器以及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述程序时实现上述任一项所述的模型训练方法。

60、第四方面,本技术实施例提供了一种计算机可读存储介质,当所述存储介质中的指令由电子设备的处理器执行时,使得电子设备能够执行上述任一项所述的模型训练方法。

61、与现有技术相比,本技术实施例包括以下优点:

62、本技术实施例中,通过获取模型训练样本,模型训练样本包括:蛋白口袋。将模型训练样本输入至待训练蛋白口袋内分子生成模型,待训练蛋白口袋内分子生成模型包括:n个循环连接的网络层,n为正整数。调用n个循环连接的网络层对蛋白口袋进行逐个原子预测处理,得到蛋白口袋对应的预测原子类型、预测中心原子、预测原子位置和预测原子键连关系。基于预测原子类型、预测中心原子、预测原子位置和预测原子键连关系,计算得到待训练蛋白口袋内分子生成模型的损失值。在损失值处于预设范围内的情况下,将训练后的待训练蛋白口袋内分子生成模型作为最终的蛋白口袋内分子生成模型。本技术实施例通过采用多尺度建模方式,不仅增强了训练效率,而且可以捕捉到更高尺度的相互作用,使得模型生成的分子与蛋白口袋发生合理的几何匹配和能量匹配,实现了可靠的,有效的基于口袋的三维分子从头设计。同时,引入了两个尺度的自回归模式,即全局尺度和原子组件尺度,完成对口袋内分子的几何和拓扑结构学习,在给定感兴趣的口袋结构之后即可以生成与这个口袋相匹配的分子。

63、应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,并不能限制本技术。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1