一种基于小样本学习的图神经网络分类方法及装置

文档序号:24620403发布日期:2021-04-09 20:25阅读:362来源:国知局
一种基于小样本学习的图神经网络分类方法及装置

本发明属于人工智能技术领域,具体涉及一种基于小样本学习的图神经网络分类方法及装置。



背景技术:

深度神经网络的成功推动了诸多计算机视觉任务的研究,如:图像分类、对象检测和语义分割。然而,深度模型的成功部分归因于大型训练数据的可用性。这个前提条件不仅限制了可以应用模型的领域,而且不符合人类的认知过程。因为人可以根据过去的经验,仅通过一个或几个示例就可以快速学习新概念,所以越来越多的研究人员将注意力转向小样本学习。小样本学习是通过仅针对每个对象的几个训练示例来学习新对象。对于人类来说这不是很困难,但是对于机器而言仍然是一个具有挑战性的问题。

受人类学习的启发,研究人员探索出了一种用于小样本学习的元学习过程,该过程可以基于先前的经验获得知识,并以很少的标记数据解决新任务。具体地,元学习策略可以通过分配相似的任务来学习如何以很少的训练数据,并有效地识别未见过的类型。元学习策略会从多个相似的任务中学习一个跨任务元学习器,进而总结出一个通用的表示形式,从而为看不见的类的新任务提供更好的初始化。

目前,有研究表明利用上述元学习范例可以解决少量镜头图像分类的问题。本质上,这些方法学习相似性度量,并将标签信息从图像的支持集传播到查询集。由于在学习中非常需要充分利用支持集和查询之间的关系,因此引入了图神经网络来处理每个识别任务上的丰富关系结构。图神经网络通过消息传递算法迭代地聚合邻居的特征,因此表达了支持和查询实例之间的复杂交互。尤其是,几次学习中的图神经网络方法通过优化节点和边更新特征来获得更好的性能,从而学习类间的唯一性和类内的通用性,使用以节点为中心的图神经网络在连接的节点之间传播消息,以对未标记的样本进行分类,迭代更新边缘标签,以推断与现有支持集的查询关联。

然而,在现有的利用图神经网络处理关系结构的方法中,因为每项少量任务都不具有固有的图结构,所以它们被构造为具有边权重的完全连接图。而这种结构导致图神经网络的关系感应偏弱,从而带来因图结构不精确而无法在图中学习准确的关系的问题,同时,图结构可能会通过不相关节点之间的边缘传播噪声。另外,少量任务中的关系是指样本之间的相似性,而“相似”又是一个模糊的概念,没有明确定义,因此少量任务中的关系较难学习。



技术实现要素:

为解决上述问题,提供一种利用模糊理论对小样本数据的图结构进行处理从而增强关系归纳偏差的图神经网络分类方法及装置,本发明采用了如下技术方案:

本发明提供了一种基于小样本学习的图神经网络分类方法,用于对小样本数据进行分类得到节点分类结果,其特征在于,包括如下步骤:步骤s1,利用训练好的特征提取器对小样本数据进行特征提取得到未知节点;步骤s2,将未知节点初始化为全连接图结构并输入训练好的关系编码器进行边缘预测,得到边预测结果;步骤s3,利用预定的成员函数对边预测结果计算隶属度得到隶属度值μ:

式中,eij为边预测结果,ζ、η均为超参数,fμ为边预测结果eij的线性函数,并根据隶属度值对全连接图结构中的边进行删除从而得到非完全连接的模糊状态的图结构作为更新后图结构;步骤s4,将更新后图结构输入训练好的图神经网络分类器得到节点分类结果。

根据本发明提供的一种基于小样本学习的图神经网络分类方法,还可以具有这样的技术特征,其中,训练好的特征提取器、训练好的关系编码器以及训练好的图神经网络分类器构成一个模糊图神经网络模型,该模糊图神经网络模型训练过程包括如下步骤:步骤e1,利用预定的采样方法对训练数据集进行采样得到小样本的支撑集以及查询集;步骤e2,搭建包括特征提取器、关系编码器以及图神经网络分类器的初始模糊图神经网络模型;步骤e3,利用特征提取器对支撑集进行特征提取得到训练节点;步骤e4,将训练节点初始化为全连接状态的图结构并输入关系编码器进行边缘预测,得到边标签预测结果;步骤e5,利用成员函数对边标签预测结果计算隶属度得到边隶属度值,并根据边隶属度值对图结构中的边进行删除从而得到不再完全连接的图结构作为模糊图结构;步骤e6,将模糊图结构输入图神经网络分类器得到节点预测结果;步骤e7,基于查询集、边标签预测结果以及节点预测结果构建损失函数,并根据该损失函数更新初始模糊图神经网络模型直到收敛从而得到模糊图神经网络模型。

根据本发明提供的一种基于小样本学习的图神经网络分类方法,还可以具有这样的技术特征,其中,关系编码器为基于gnn的度量网络。

根据本发明提供的一种基于小样本学习的图神经网络分类方法,还可以具有这样的技术特征,其中,特征提取器的网络结构为resnet-12。

根据本发明提供的一种基于小样本学习的图神经网络分类方法,还可以具有这样的技术特征,其中,图神经网络分类器为以节点为中心的图神经网络。

本发明提供了一种基于小样本学习的图神经网络分类装置,用于对小样本数据进行分类得到节点分类结果,其特征在于,包括:特征提取部,利用训练好的特征提取器对小样本数据进行特征提取得到未知节点;边预测部,将未知节点初始化为全连接图结构并输入训练好的关系编码器进行边缘预测,得到边预测结果;图结构模糊处理部,利用预定的成员函数对边预测结果计算隶属度得到隶属度值,并根据隶属度值对全连接图结构中的边进行删除从而得到非完全连接的模糊状态的图结构作为更新后图结构;以及节点分类部,将更新后图结构输入训练好的图神经网络分类器得到节点分类结果。

发明作用与效果

根据本发明的一种基于小样本学习的图神经网络分类方法及装置,由于利用预定的成员函数对边预测结果计算隶属度得到隶属度值,并根据隶属度值对全连接图结构中的边进行删除从而得到非完全连接的模糊状态的图结构作为更新后图结构,因此解决了与模棱两可、主观或不精确的判断有关的问题,删除了两个具有显著差异的未知节点之间的边缘,并通过线性函数处理了不可靠的边缘,使得模型具有较强的关系归纳偏差,并且不受噪声影响。另外,由于将未知节点初始化为全连接图结构并输入训练好的关系编码器进行边缘预测得到边预测结果,因此不仅可以聚合来自其邻居的信息,还能聚合图结构中所有未知节点和边缘的信息,从而提高最后的分类准确率。同时,由于图神经网络分类器网络层较少,因此输出的特征不会过平滑,可以很好的区分不同群集的顶点。

附图说明

图1为本发明实施例的一种基于小样本学习的图神经网络分类方法的流程图;

图2为本发明实施例的成员函数的示意图;

图3为本发明实施例的fgnn模型训练过程的流程图;

图4为本发明实施例的fgnn模型训练过程的示意图;

图5为本发明实施例的一种基于小样本学习的图神经网络分类装置的结构框图;以及

图6为本发明实施例的针对tieredimagenet实验的实验结果示意图。

具体实施方式

为了使本发明实现的技术手段、创作特征、达成目的与功效易于明白了解,以下结合实施例及附图对本发明的一种基于小样本学习的图神经网络分类方法及装置作具体阐述。

<实施例>

本实施例通过miniimagenet数据集与tieredimagenet数据集对本发明的一种基于小样本学习的图神经网络分类方法及装置的工作流程以及效果进行阐述。

其中,miniimagenet是最受欢迎的小样本学习数据集,有100个类别,每个类别有600个84×84彩色图像样本。它分为训练集,验证集和测试集,分别具有64、16和20个类别。

另外,tieredimagenet与miniimagenet类似,tieredimagenet是ilsvrc-12数据集的子集,但它具有来自ilsvrc-12的更多类,有608个类别,每个类别平均有1281个样本,具有84×84彩色图像。与miniimagenet不同的是tieredimagenet与imagenet中的高级节点相对应的更广泛的类别采用分层类别结构。其中,属于最高层级的34个类别分为20个训练类别(351个类别),6个验证类别(97个类别)和8个测试类别(160个类别),从而确保了训练类别在语义上不同于测试类别。

本实施例的一种基于小样本学习的图神经网络分类方法及装置的实现的硬件平台为一张nvidiatitanx显卡(gpu加速),深度学习框架为pytorch。

图1为本发明实施例的一种基于小样本学习的图神经网络分类方法的流程图。

如图1所示,一种基于小样本学习的图神经网络分类方法包括如下步骤:

步骤s1,利用训练好的特征提取器对小样本数据进行特征提取得到未知节点。

其中,小样本数据是指只包含一个或几个示例的数据。

步骤s2,将未知节点初始化为全连接图结构并输入训练好的关系编码器进行边缘预测,得到边预测结果。

图2为本发明实施例的成员函数的示意图。

步骤s3,利用预定的成员函数(如图2所示)对边预测结果计算隶属度得到隶属度值μ,并根据隶属度值对全连接图结构中的边进行删除从而得到非完全连接的模糊状态的图结构作为更新后图结构。

其中,成员函数为来自话语范围的数据映射,具体为:

式中,eij为边预测结果,ζ、η均为超参数,fμ为边预测结果eij的线性函数。

如图2所示,成员函数的横坐标为范围域(即边预测结果),纵坐标为隶属度,本实施例中η设为0.3,ζ设为0.7。

根据隶属度值对全连接图结构中的边进行删除后,形成图神经网络分类器(即gnn分类器)的图的稀疏的相邻矩阵,并为图提供了相关的归纳偏置。

步骤s4,将更新后图结构输入训练好的图神经网络分类器得到节点分类结果。

图3为本发明实施例的fgnn模型训练过程的流程图;以及

图4为本发明实施例的fgnn模型训练过程的示意图。

其中,训练好的特征提取器、训练好的关系编码器以及训练好的图神经网络分类器构成一个模糊图神经网络模型(fgnn),该模糊图神经网络模型训练过程如图3以及4所示,包括如下步骤:

步骤e1,利用预定的采样方法对训练数据集进行采样得到小样本的支撑集以及查询集。具体地:

以miniimagenet为例,从训练集(64个类,每类600个样本)中随机采样5个类,每个类5个样本,从而构成支撑集;并从训练集的样本(采出的5个类,每类剩下的样本)中进行采样,得到查询集,查询集中每类有15个样本。

步骤e2,搭建包括特征提取器、关系编码器以及图神经网络分类器的初始模糊图神经网络模型。

步骤e3,利用特征提取器对支撑集进行特征提取得到训练节点。

其中,特征提取器的网络结构为resnet-12。具体地,resnet-12包含4个残差块,每个残差块由3个conv层和3×3内核组成。在每个残差块的末尾都有一个2×2的最大池化层,用来对特征图进行下采样。第一个残差块中的过滤器数量为64,在下一个残差块中的过滤器数量加倍。另外,在cnn残余块和分类器之间,使用了平均池层。

步骤e4,将训练节点初始化为全连接状态的图结构并输入关系编码器进行边缘预测,得到边标签预测结果。

其中,关系编码器为以边为中心的图神经网络,具有边缘更新功能和节点更新功能,用于计算样本之间的相似性得分,随着消息在图形中传播,节点和边缘可以聚合来自图形中所有节点和边的信息,而不仅仅是来自其邻居的信息。因此,相似度分数的计算不仅取决于两个节点,而还取决于其他成对的节点。

本实施例中,关系编码器的边缘更新功能通过fc层、批量归一化层以及s型激活层来实现;关系编码器的节点更新功能通过fc层、批处理规范化以及leakyrelu激活层实现。

步骤e5,利用成员函数对边标签预测结果计算隶属度得到边隶属度值,并根据边隶属度值对图结构中的边进行删除从而得到不再完全连接的图结构作为模糊图结构。

具体地,利用成员函数对边标签预测结果进行隶属度计算得到边隶属度值,根据隶属度值将图结构中的边进行删除从而得到为以节点为中心的gnn的偏差作为模糊图结构。

步骤e6,将模糊图结构输入图神经网络分类器得到节点预测结果。

其中,图神经网络分类器为以节点为中心的图神经网络。与关系编码器相比,gnn分类器没有专有的edgeupdate函数,它的边缘特征由成员函数提供。gnn分类器为2或3层。

步骤e7,基于查询集、边标签预测结果以及节点预测结果构建损失函数,并根据该损失函数更新初始模糊图神经网络模型直到收敛从而得到训练好的模糊图神经网络模型。

具体地,根据查询集、边标签预测结果以及节点预测结果构建损失函数,计算节点损耗以及边损耗,并根据节点损耗以及边损耗更新fgnn中的参数,直到fgnn模型收敛从而得到训练好的fgnn模型。

上述fgnn模型训练时各个部分的具体参数为:

针对特征提取器resnet-12部分,梯度下降方法为adam优化器,初始学习速率为0.1,权重衰减为0.00001,迭代次数为100,学习率每30个迭代减半一次。

另外,在针对训练数据采样前会执行标准的数据增强技术,例如随机水平翻转、裁剪等常规数据增强技术。

针对关系编码器部分,梯度下降方法为adam优化器,学习率为0.01,迭代次数达到1000后,学习率降低4/5。

针对gnn分类器部分,梯度下降方法为adam优化器,学习率为0.001,迭代次数达到1000后,学习率降低4/5。

上述一种基于小样本学习的图神经网络分类方法可以应用在计算机中并形成一个基于小样本学习的图神经网络分类装置1,该装置1包括特征提取部11、边预测部12、图结构模糊处理部13、节点分类部14以及控制部15(如图5所示)。

特征提取部11利用训练好的特征提取器对小样本数据进行特征提取得到未知节点。

边预测部12将未知节点初始化为全连接图结构并输入训练好的关系编码器进行边缘预测,得到边预测结果。

图结构模糊处理部13利用预定的成员函数对边预测结果计算隶属度得到隶属度值,并根据隶属度值对全连接图结构中的边进行删除从而得到非完全连接的模糊状态的图结构作为更新后图结构。

节点分类部14将更新后图结构输入训练好的图神经网络分类器得到节点分类结果。

控制部15控制上述各个部运行。

利用上述的基于小样本学习的图神经网络分类方法及装置对miniimagenet数据集与tieredimagenet数据集进行了5路单命中和5次命中的训练,以12000次迭代来训练fgnn模型,每次迭代包含1个或5个支持以及来自5个类中的每个类的15个查询。

针对miniimagenet,fgnn模型在进行5次1-shot学习时达到了64.15%的最佳性能。在5次5-shot实验中,fgnn达到了80.08%的准确性,并排名第一。其中,tpn使用拉普拉斯矩阵而不是特征相似性来将支持集的标签传播到查询集。在egnn中,支持集不仅传播其节点特征,而且还传播边缘标签信息,以考虑查询样本之间更复杂的交互。相比之下,fgnn通过成员函数优化关系表示从而得到更合理的图结构,并且关系表示(即隶属度值)将被冻结以进行节点分类以在邻居之间交换更多相关信息。

在5次1-shot和5次5-shot实验中,fgnn的准确度比wdae-gnn高3.08%和3.33%,wdae-gnn是以前很少发动学习的最佳gnn方法。

图6为本发明实施例的针对tieredimagenet实验的实验结果示意图。

针对tieredimagenet,如图6所示,通过度量学习、梯度下降以及gnn三种类型的小样本学习方法对tieredimagenet进行5次1-shot和5次5-shot实验,fgnn在5次1-shot和5次5-shot实验中也获得了最佳结果。由图6可知,fgnn很大程度上优于egnn*,对于1-shot,差距达到了10.06%,对于5-shot,fgnn超过egnn*约3.91%。尽管tieredimagenet中的训练集和测试集之间的差距比miniimagenet中的要大,但是由于可获得的数据更多,所以fgnn模型的性能反而更好。

实施例作用与效果

根据本实施例提供的一种基于小样本学习的图神经网络分类方法及装置,由于利用预定的成员函数对边预测结果计算隶属度得到隶属度值,并根据隶属度值对全连接图结构中的边进行删除从而得到非完全连接的模糊状态的图结构作为更新后图结构,因此解决了与模棱两可、主观或不精确的判断有关的问题,删除了两个具有显著差异的未知节点之间的边缘,并通过线性函数处理了不可靠的边缘,使得模型具有较强的关系归纳偏差,并且不受噪声影响。另外,由于将未知节点初始化为全连接图结构并输入训练好的关系编码器进行边缘预测得到边预测结果,因此不仅可以聚合来自其邻居的信息,还能聚合图结构中所有未知节点和边缘的信息,从而提高最后的分类准确率。同时,由于图神经网络分类器网络层较少,因此输出的特征不会过平滑,可以很好的区分不同群集的顶点。

另外,在实施例中,由于特征提取器的网络结构为resnet-12,因此不会模型结构简单,构建快速方便。

上述实施例仅用于举例说明本发明的具体实施方式,而本发明不限于上述实施例的描述范围。

当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1