对象类别识别方法、装置、设备及存储介质与流程

文档序号:32969656发布日期:2023-01-17 20:06阅读:47来源:国知局
对象类别识别方法、装置、设备及存储介质与流程

1.本技术涉及计算机技术领域,尤其涉及一种对象类别识别方法、装置、设备及存储介质。


背景技术:

2.相关技术中,针对包含cnn或mlp的神经网络,通常是根据参数的重要性来剪除网络中的不重要参数,从而达到减小模型的目的。而对于attention层,由于层内的参数主要由全连接层(fc)组成,而这些fc层的连接方式与cnn、mlp均不同,如果直接进行通道剪枝,会使得剪枝后的网络无法正确进行计算。


技术实现要素:

3.本技术提供了一种对象类别识别方法、装置、设备及存储介质,可以提高对象类别的识别速率。
4.一方面,本技术提供了一种对象类别识别方法,所述方法包括:
5.获取目标对象的目标数据;
6.基于对象分类模型对所述目标数据进行类别识别处理,得到所述目标对象的目标类别标签;所述对象分类模型为基于样本对象的样本数据,对剪枝分类模型进行对象类别识别训练得到,所述剪枝分类模型为在初始分类模型中剪除待剪除通道后的模型,所述待剪除通道为所述初始分类模型中缩放参数绝对值小于预设阈值的通道;所述初始分类模型为基于所述样本数据对预设网络进行对象类别识别训练得到;所述预设网络中包括更新注意力网络,所述更新注意力网络为设置有缩放层的注意力网络;所述初始分类模型中每个通道的缩放参数基于所述缩放层确定;所述样本数据标注有所述样本对象的样本类别标签。
7.另一方面提供了一种对象类别识别装置,所述装置包括:
8.目标数据获取模块,用于获取目标对象的目标数据;
9.目标类别确定模块,用于基于对象分类模型对所述目标数据进行类别识别处理,得到所述目标对象的目标类别标签;所述对象分类模型为基于样本对象的样本数据,对剪枝分类模型进行对象类别识别训练得到,所述剪枝分类模型为在初始分类模型中剪除待剪除通道后的模型,所述待剪除通道为所述初始分类模型中缩放参数绝对值小于预设阈值的通道;所述初始分类模型为基于所述样本数据对预设网络进行对象类别识别训练得到;所述预设网络中包括更新注意力网络,所述更新注意力网络为设置有缩放层的注意力网络;所述初始分类模型中每个通道的缩放参数基于所述缩放层确定;所述样本数据标注有所述样本对象的样本类别标签。
10.另一方面提供了一种对象类别识别设备,所述设备包括处理器和存储器,所述存储器中存储有至少一条指令或至少一段程序,所述至少一条指令或所述至少一段程序由所述处理器加载并执行以实现如上所述的对象类别识别方法。
11.另一方面提供了一种计算机存储介质,所述计算机存储介质存储有至少一条指令或至少一段程序,所述至少一条指令或至少一段程序由处理器加载并执行以实现如上所述的对象类别识别方法。
12.另一方面提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行以实现如上所述的对象类别识别方法。
13.本技术提供的对象类别识别方法、装置、设备及存储介质,具有如下技术效果:
14.本技术获取目标对象的目标数据;基于对象分类模型对所述目标数据进行类别识别处理,得到所述目标对象的目标类别标签;所述对象分类模型为基于样本对象的样本数据,对剪枝分类模型进行对象类别识别训练得到,所述剪枝分类模型为在初始分类模型中剪除待剪除通道后的模型,所述待剪除通道为所述初始分类模型中缩放参数绝对值小于预设阈值的通道;所述初始分类模型为基于所述样本数据对预设网络进行对象类别识别训练得到;所述预设网络中包括更新注意力网络,所述更新注意力网络为设置有缩放层的注意力网络;所述初始分类模型中每个通道的缩放参数基于所述缩放层确定;所述样本数据标注有所述样本对象的样本类别标签;本技术通过在预设网络的注意力网络中设置缩放层,然后通过缩放层参数确定待剪枝通道,实现了在包含注意力网络的模型中进行剪枝,再进一步根据剪枝分类模型确定对象分类模型,从而减少了对象分类模型的运算量,提高了模型的计算速度,提高了对象类别的识别速度。
附图说明
15.为了更清楚地说明本技术实施例或现有技术中的技术方案和优点,下面将对实施例或现有技术描述中所需要使用的附图作简单的介绍,显而易见地,下面描述中的附图仅仅是本技术的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它附图。
16.图1是本技术实施例提供的一种对象类别识别系统的示意图;
17.图2是本技术实施例提供的一种对象类别识别方法的流程示意图;
18.图3是本技术实施例提供的一种对象分类模型的训练方法的流程示意图;
19.图4是本技术实施例提供的一种在所述原始注意力网络中加入缩放层,得到所述更新注意力网络的方法的流程示意图;
20.图5是本技术实施例提供的基于所述初始分类模型,确定所述对象分类模型的方法的流程示意图;
21.图6是本技术实施例提供的在所述初始分类模型中剪除所述待剪除通道,得到所述剪枝分类模型方法的流程示意图;
22.图7是本技术实施例提供的一种基于所述样本数据对所述剪枝分类模型进行对象类别识别训练,得到所述对象分类模型的方法的流程示意图;
23.图8是本技术实施例提供的一种图片分类模型的结构示意图;
24.图9是本技术实施例提供的增加缩放层前后的attention的结构对比图;
25.图10是本技术实施例提供的初始分类模型中通道剪除前后对比图;
26.图11是本技术实施例提供的一种增加索引池化层前后的attention的结构对比图;
27.图12是本技术实施例提供的一种对象分类模型的构建方法流程图;
28.图13是本技术实施例提供的一种对象类别识别装置的结构示意图;
29.图14是本技术实施例提供的一种服务器的结构示意图。
具体实施方式
30.下面将结合本技术实施例中的附图,对本技术实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本技术一部分实施例,而不是全部的实施例。基于本技术中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本技术保护的范围。
31.首先,在对本技术实施例进行描述的过程中出现的部分名词或者术语作如下解释:
32.人工智能(artificial intelligence,ai)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。具体地,本技术实施例提供的方案涉及人工智能的机器学习领域。机器学习(machine learning,ml)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。
33.智慧交通是在整个交通运输领域充分利用物联网、空间感知、云计算、移动互联网等新一代信息技术,综合运用交通科学、系统方法、人工智能、知识挖掘等理论与工具,以全面感知、深度融合、主动服务、科学决策为目标,通过建设实时的动态信息服务体系,深度挖掘交通运输相关数据,形成问题分析模型,实现行业资源配置优化能力、公共决策能力、行业管理能力、公众服务能力的提升,推动交通运输更安全、更高效、更便捷、更经济、更环保、更舒适的运行和发展,带动交通运输相关产业转型、升级。
34.attention:一种神经网络层,attention(注意力)机制的本质是从人类视觉注意力机制中获得灵感。大致是我们视觉在感知东西的时候,一般不会是一个场景从到头看到尾每次全部都看,而往往是根据需求观察注意特定的一部分。而且当我们发现一个场景经常在某部分出现自己想观察的东西时,我们就会进行学习在将来再出现类似场景时把注意力放到该部分上。
35.transformer:一种神经网络模块,一个利用注意力机制来提高模型训练速度的模型。
36.imagenet数据集:一个计算机视觉数据集,是由斯坦福大学的李飞飞教授带领创建。该数据集包合14,197,122张图片和21,841个synset索引。synset是wordnet层次结构中的一个节点,它又是一组同义词集合。imagenet数据集一直是评估图像分类算法性能的基
准。
37.需要说明的是,本技术的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本技术的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或服务器不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
38.请参阅图1,图1是本技术实施例提供的一种对象类别识别系统的示意图,如图1所示,该对象类别识别系统可以至少包括服务器01和客户端02。
39.具体的,本技术实施例中,所述服务器01可以包括一个独立运行的服务器,或者分布式服务器,或者由多个服务器组成的服务器集群,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、cdn(content delivery network,内容分发网络)、以及大数据和人工智能平台等基础云计算服务的云服务器。服务器01可以包括有网络通信单元、处理器和存储器等等。具体的,所述服务器01可以用于训练对象分类模型,并基于对象分类模型确定目标对象的类别标签。
40.具体的,本技术实施例中,所述客户端02可以包括智能手机、台式电脑、平板电脑、笔记本电脑、数字助理、智能可穿戴设备、智能音箱、车载终端、智能电视等类型的实体设备,也可以包括运行于实体设备中的软体,例如一些服务商提供给用户的网页页面,也可以为该些服务商提供给用户的应用。具体的,所述客户端02可以用于在线查询目标类型的类别。
41.以下介绍本技术的一种对象类别识别方法,图2是本技术实施例提供的一种对象类别识别方法的流程示意图,本说明书提供了如实施例或流程图所述的方法操作步骤,但基于常规或者无创造性的劳动可以包括更多或者更少的操作步骤。实施例中列举的步骤顺序仅仅为众多步骤执行顺序中的一种方式,不代表唯一的执行顺序。在实际中的系统或服务器产品执行时,可以按照实施例或者附图所示的方法顺序执行或者并行执行(例如并行处理器或者多线程处理的环境)。具体的如图2所示,所述方法可以应用于图1所示的服务器01中,可以包括:
42.s201:获取目标对象的目标数据。
43.在本技术实施例中,目标对象可以为各个不同应用场景、不同领域的对象,可以包括人、动物、商品、生活用品等,例如可以包括但不限于用户、商店、地址、动物、电子设备等。目标数据为目标对象对应的数据,可以表征目标对象的属性。目标数据可以包括但不限于字符、文本、图像等。
44.在本技术实施例中,所述获取目标对象的目标数据可以包括:
45.接收终端响应于对象类别识别指令,发送的目标数据。
46.在本技术实施例中,可以通过目标对象对应的终端,获取其对应的目标数据。
47.s203:基于对象分类模型对所述目标数据进行类别识别处理,得到所述目标对象的目标类别标签;所述对象分类模型为基于样本对象的样本数据,对剪枝分类模型进行对象类别识别训练得到,所述剪枝分类模型为在初始分类模型中剪除待剪除通道后的模型,
所述待剪除通道为所述初始分类模型中缩放参数绝对值小于预设阈值的通道;所述初始分类模型为基于所述样本数据对预设网络进行对象类别识别训练得到;所述预设网络中包括更新注意力网络,所述更新注意力网络为设置有缩放层的注意力网络;所述初始分类模型中每个通道的缩放参数基于所述缩放层确定;所述样本数据标注有所述样本对象的样本类别标签。
48.在本技术实施例中,样本对象与目标对象为同一应用场景、同一类型的对象,样本数据与目标数据为同一类型的数据,例如,若目标数据为图像,那么样本数据也是图像。预设网络可以为包含注意力网络(attention)的各种网络,例如,预设网络可以为transformer,还可以为其他类型的网络;更新注意力网络为设置有缩放(scale)层的注意力网络;缩放参数对应的预设阈值可以根据实际情况进行设置,可以设置为接近零的数值。
49.在本技术实施例中,首先根据预设网络训练得到初始分类模型,再根据模型收敛时缩放层对应的每个通道的缩放参数,确定待剪除通道,从而确定出包含注意力网络的剪枝分类模型,再对该模型进行进一步训练得到对象分类模型。本实施例中的对象分类模型可以应用于不同场景,对不同的对象进行分类;本实施例的对象分类模型减少了网络部署时attention的计算量,可以运用在文本分类、图片分类、视频分类等用到attention算法的模型中。例如,在图像分类场景中,可以通过对象分类模型对各种图像进行分类;在app场景中,可以根据用户的关联数据以及app业务指标,对用户进行分类;在广告场景中,可以根据用户的关联数据以及广告业务指标,判断用户是否对特定广告感兴趣。
50.在一个具体的实施例中,transformer作为一种神经网络基础模块,在自然语言处理和计算机视觉的任务中发挥了重要作用,例如在图片分类任务中,可以堆叠transformer层来搭建图片分类网络,最后接入分类器,进行图片分类;如图8所示,图8为一种图片分类模型的结构示意图,图片分类模型由多层transformer组成,每个transformer中均包括改进的attention,由此可以减少分类模型的运算量,提高模型的分类速度。
51.在一些实施例中,如图3所示,所述对象分类模型的训练方法包括:
52.s301:获取原始注意力网络;
53.在本技术实施例中,原始注意力网络可以为attention。
54.s303:在所述原始注意力网络中加入缩放层,得到所述更新注意力网络;
55.在本技术实施例中,可以在attention中加入缩放(scale)层,从而得到更新注意力网络。
56.在一些实施例中,如图4所示,所述缩放层包括第一缩放层和第二缩放层,所述在所述原始注意力网络中加入缩放层,得到所述更新注意力网络,包括:
57.s3031:确定所述原始注意力网络中的第一线性层、第二线性层、第一矩阵乘法层和第二矩阵乘法层;所述第一线性层与所述第一矩阵乘法层连接,所述第二线性层与所述第二矩阵乘法层连接;
58.在本技术实施例中,如图9所示,图9中(a)为一种attention的结构示意图,attention中包括线性层和矩阵乘法层(mat multiply),其中,线性层包括第一线性层(对应value线性变换矩阵,v)和第二线性层(对应key线性变换矩阵,k);矩阵乘法层包括第一矩阵乘法层和第二矩阵乘法层;所述第一线性层与所述第一矩阵乘法层连接,所述第二线性层与所述第二矩阵乘法层连接;所述原始注意力网络还包括归一化指数函数层
(softmax)、第三线性层(对应query线性变换矩阵,q),所述第三线性层与所述第二矩阵乘法层连接,所述第二矩阵乘法层与所述归一化指数函数层连接,所述归一化指数函数层与所述第一矩阵乘法层连接。所述第一线性层用于确定数据的内容向量,所述第二线性层用于确定所述数据的内容向量查询标识;所述第三线性层用于确定所述数据的查询向量;q就是词的查询向量,k是“被查”向量,v是内容向量。其中,q是最适合查找目标的,k是最适合接收查找的,v就是内容,这三者不一定要一致,所以网络这么设置了三个向量,然后学习出最适合的q,k,v,以此增强网络的能力。
59.s3033:在所述第一线性层与所述第一矩阵乘法层连接层之间增加所述第一缩放层,在所述第二线性层与所述第二矩阵乘法层连接层之间增加所述第二缩放层,得到所述更新注意力网络。
60.在本技术实施例中,可以同时在attention中增加两个缩放层,得到所述更新注意力网络(scaled attention);如图9所示,图9中(b)为一种增加了缩放层的attention(scaled attention)的结构示意图。
61.s305:基于所述更新注意力网络,构建所述预设网络;
62.在本技术实施例中,可以根据更新注意力网络,构建预设网络;例如,可以根据scaled attention,构建transformer。
63.s307:根据所述样本对象的样本数据,对所述预设网络进行对象类别识别训练,得到初始分类模型;所述初始分类模型包括至少两个通道;
64.在本技术实施例中,预设网络包括至少两个通道,训练得到的初始分类模型与预设网络的通道数量相同。
65.在一些实施例中,所述根据所述样本对象的样本数据,对所述预设网络进行对象类别识别训练,得到初始分类模型,可以包括:
66.将所述样本对象的样本数据输入所述预设网络进行对象类别识别训练,在训练过程中,不断调整所述预设网络的参数,直至所述预设网络输出的对象类别标签与标注的对象类别标签相匹配;
67.将输出的对象类别标签与标注的对象类别标签相匹配时的参数对应的预设网络,作为所述初始分类模型。
68.s309:基于所述初始分类模型,确定所述对象分类模型。
69.在一些实施例中,如图5所示,所述基于所述初始分类模型,确定所述对象分类模型,包括:
70.s3091:获取所述初始分类模型中所述缩放层对应的目标函数;
71.s3093:根据所述目标函数,确定所述初始分类模型中每个通道对应的缩放参数;
72.在本技术实施例中,可以根据目标函数中的系数确定初始分类模型中每个通道对应的缩放参数。
73.s3095:基于所述初始分类模型中每个通道对应的缩放参数,确定待剪除通道;所述待剪除通道为缩放参数绝对值小于预设阈值的通道;
74.在本技术实施例中,预设阈值可以根据实际需求进行设置,例如预设阈值可以为接近0的数值。
75.在一些实施例中,所述基于所述初始分类模型中每个通道对应的缩放参数,确定
待剪除通道,可以包括:
76.确定所述初始分类模型中通道总数量以及待剪除通道的比例;
77.在本技术实施例中,待剪除通道的比例为待剪除通道占模型中通道总数量的比例,可以根据实际情况进行设置,例如可以设置为10%、20%等。
78.根据所述通道总数量以及所述待剪除通道的比例,确定待剪除通道数量;
79.在本技术实施例中,可以计算通道总数量与所述待剪除通道的比例的乘积,得到待剪除通道数量。
80.根据所述初始分类模型中每个通道对应的缩放参数以及所述待剪除通道数量,确定所述预设阈值;
81.根据所述初始分类模型中每个通道对应的缩放参数以及所述预设阈值,确定待剪除通道的标识信息。
82.在本技术实施例中,将模型中所有的attention转化为scaledattention之后,在原有的任务上对模型进行训练,得到的模型与原有的模型在任务精度上保持一致,然后根据scale层的参数选择一部分通道剪除,由于scale层越接近0,说明该通道内的特征越不重要,因此可以根据scale层参数的绝对值进行对应通道剪除,例如预设剪除20%的通道,则从scale层参数中选择20%绝对值最接近0的通道,即预设阈值可以为接近0的数值,记录下通道的序号,后续计算中不考虑。在一个具体的实施例中,如图10所示,图10中(a)为初始分类模型中5个通道的缩放参数,(b)为模型中剪除两个通道后其他通道的缩放参数。
83.在一些实施例中,模型中的一个通道对应一组特征参数,可以计算每个通道对应特征参数的平方和,通道的特征参数包括缩放参数,然后根据模型每个通道对应特征参数的平方和,来确定待剪除通道;例如,可以将平方和小于预设值的通道确定为待剪除通道。
84.s3097:在所述初始分类模型中剪除所述待剪除通道,得到所述剪枝分类模型;
85.在一些实施例中,如图6所示,所述在所述初始分类模型中剪除所述待剪除通道,得到所述剪枝分类模型,包括:
86.s30971:确定所述待剪除通道的待剪除标识信息;
87.s30973:在所述初始分类模型的缩放层后增加索引池化层;
88.在本技术实施例中,通过scale层找到了k和v中不重要的通道,将通道的序号记录下来,然后在scaled attention中的scale层后面加入indexpooling层,即通过pooling选择保留下来的通道,得到的新attention,称为indexpoolingattention。
89.在一些实施例中,所述索引池化层包括第一索引池化层和第二索引池化层,所述在所述初始分类模型的缩放层后增加索引池化层,包括:
90.在第一缩放层与第一矩阵乘法层之间增加第一索引池化层;
91.在第二缩放层与第二矩阵乘法层之间增加第二索引池化层。
92.在一个具体的实施例中,如图11所示,图11中(a)为增加了缩放层的attention(scaled attention)的结构示意图,图11中(b)为增加了索引池化层的attention(indexpoolingattention)的结构示意图。
93.s30975:基于所述索引池化层从所述初始分类模型的通道中,剪除所述待剪除标识信息对应的待剪除通道,得到所述剪枝分类模型。
94.在本技术实施例中,通过pooling选择保留下来的通道,得到的新attention,称为
indexpoolingattention。
95.s3099:基于所述样本数据对所述剪枝分类模型进行对象类别识别训练,得到所述对象分类模型。
96.在本技术实施例中,将模型中的scaledattention都转化为indexpoolingattention之后,对模型进行重新训练。
97.在本技术实施例中,由于scale层和indexpooling的计算量远远小于attention中的矩阵乘法的计算量,通过剪除了k和v的一些通道特征,使得矩阵乘法和softmax操作中的计算量都得到降低,因此达到了降低计算量,从而提升模型识别速度的剪枝目的。
98.在一些实施例中,如图7所示,所述基于所述样本数据对所述剪枝分类模型进行对象类别识别训练,得到所述对象分类模型,包括:
99.s30991:获取所述初始分类模型的初始模型参数;
100.s30993:将所述初始模型参数作为所述剪枝分类模型的初始训练参数;
101.s30995:基于所述样本数据以及所述初始训练参数,对所述剪枝分类模型进行对象类别识别训练,得到所述对象分类模型。
102.在本技术实施例中,由于初始分类模型已经训练到收敛,这里将初始分类模型中训练后参数加载到当前的剪枝分类模型中,再进行训练,从而可以提高模型的收敛速度,且容易保持模型的精度。
103.在一个具体的实施例中,如图12所示,图12为一种对象分类模型的构建方法流程图,包括:
104.首先将模型中的attention模块加入scale层。设原有的attention模型输入经过linear层(线性层)后,得到经过linear层(线性层)后,得到则原有的attention运算可以由公式(1)表示。当输入特征被一个linear层接收时,它们以一个展平成一维张量的形式接收,然后乘以权重矩阵。这个矩阵乘法产生输出特征。
[0105][0106]
在经过linear层之后,对和加入scale层运算,即对向量的n个通道每个通道乘以一个参数,从而可以在训练中通过该参数的学习,得到每个通道的重要性,每个通道的重要性可以通过这个参数的绝对值大小来表征。加入scale层后得到的attention运算可以由公式(2)表示,其中s(*)代表scale层。
[0107][0108]
由于scale层中的参数可以在训练中学习,所以当scale的参数全为1时,等价于原有的attention运算,经过学习后,scale层对应通道的参数越接近0,代表该通道的重要性越低,可以将该通道剪除。
[0109]
具体的,在本技术实施例中,通过剪枝后的模型与未剪枝的模型相比,计算量和模型推理速度都得到了优化。如表1所示,deit是一种基于transformer的图像分类模型,根据模型大小分为deit-small,deit-base等几种,采用本实施例的方法对原deit网络进行剪枝
后,得到的剪枝后模型在保持原有模型在imagenet数据集上的精度前提下,计算量降低,推理速度提升。
[0110]
表1本实施例剪枝后的deit与剪枝前模型的数据对比表
[0111]
模型top1精度计算量(gflops)推理时间(images/s)deit-small79.84.6930deit-base81.817.6290剪枝后deit-small79.53.71120剪枝后deit-base81.314.0350
[0112]
由以上本技术实施例提供的技术方案可见,本技术实施例获取目标对象的目标数据;基于对象分类模型对所述目标数据进行类别识别处理,得到所述目标对象的目标类别标签;所述对象分类模型为基于样本对象的样本数据,对剪枝分类模型进行对象类别识别训练得到,所述剪枝分类模型为在初始分类模型中剪除待剪除通道后的模型,所述待剪除通道为所述初始分类模型中缩放参数绝对值小于预设阈值的通道;所述初始分类模型为基于所述样本数据对预设网络进行对象类别识别训练得到;所述预设网络中包括更新注意力网络,所述更新注意力网络为设置有缩放层的注意力网络;所述初始分类模型中每个通道的缩放参数基于所述缩放层确定;所述样本数据标注有所述样本对象的样本类别标签;本技术通过在预设网络的注意力网络中设置缩放层,然后通过缩放层参数确定待剪枝通道,实现了在包含注意力网络的模型中进行剪枝,再进一步根据剪枝分类模型确定对象分类模型,从而减少了对象分类模型的运算量,提高了模型的计算速度,提高了对象类别的识别速度。
[0113]
本技术实施例还提供了一种对象类别识别装置,如图13所示,所述装置包括:
[0114]
目标数据获取模块1310,用于获取目标对象的目标数据;
[0115]
目标类别确定模块1320,用于基于对象分类模型对所述目标数据进行类别识别处理,得到所述目标对象的目标类别标签;所述对象分类模型为基于样本对象的样本数据,对剪枝分类模型进行对象类别识别训练得到,所述剪枝分类模型为在初始分类模型中剪除待剪除通道后的模型,所述待剪除通道为所述初始分类模型中缩放参数绝对值小于预设阈值的通道;所述初始分类模型为基于所述样本数据对预设网络进行对象类别识别训练得到;所述预设网络中包括更新注意力网络,所述更新注意力网络为设置有缩放层的注意力网络;所述初始分类模型中每个通道的缩放参数基于所述缩放层确定;所述样本数据标注有所述样本对象的样本类别标签。
[0116]
在一些实施例中,所述装置还可以包括:
[0117]
原始注意力网络获取模块,用于获取原始注意力网络;
[0118]
注意力网络更新模块,用于在所述原始注意力网络中加入缩放层,得到所述更新注意力网络;
[0119]
预设网络构建模块,用于基于所述更新注意力网络,构建所述预设网络;
[0120]
初始分类模型确定模块,用于根据所述样本对象的样本数据,对所述预设网络进行对象类别识别训练,得到初始分类模型;所述初始分类模型包括至少两个通道;
[0121]
对象分类模型确定模块,用于基于所述初始分类模型,确定所述对象分类模型。
[0122]
在一些实施例中,所述对象分类模型确定模块包括:
[0123]
目标函数获取单元,用于获取所述初始分类模型中所述缩放层对应的目标函数;
[0124]
缩放参数确定单元,用于根据所述目标函数,确定所述初始分类模型中每个通道对应的缩放参数;
[0125]
待剪除通道确定单元,用于基于所述初始分类模型中每个通道对应的缩放参数,确定待剪除通道;所述待剪除通道为缩放参数绝对值小于预设阈值的通道;
[0126]
剪枝分类模型确定单元,用于在所述初始分类模型中剪除所述待剪除通道,得到所述剪枝分类模型;
[0127]
对象分类模型确定单元,用于基于所述样本数据对所述剪枝分类模型进行对象类别识别训练,得到所述对象分类模型。
[0128]
在一些实施例中,所述缩放层包括第一缩放层和第二缩放层,所述注意力网络更新模块可以包括:
[0129]
网络层确定单元,用于确定所述原始注意力网络中的第一线性层、第二线性层、第一矩阵乘法层和第二矩阵乘法层;所述第一线性层与所述第一矩阵乘法层连接,所述第二线性层与所述第二矩阵乘法层连接;所述第一线性层用于确定数据的内容向量,所述第二线性层用于确定所述数据的内容向量查询标识;
[0130]
缩放层增加单元,用于在所述第一线性层与所述第一矩阵乘法层连接层之间增加所述第一缩放层,在所述第二线性层与所述第二矩阵乘法层连接层之间增加所述第二缩放层,得到所述更新注意力网络。
[0131]
在一些实施例中,所述剪枝分类模型确定单元可以包括:
[0132]
待剪除标识信息确定子单元,用于确定所述待剪除通道的待剪除标识信息;
[0133]
索引池化层增加子单元,用于在所述初始分类模型的缩放层后增加索引池化层;
[0134]
通道剪除子单元,用于基于所述索引池化层从所述初始分类模型的通道中,剪除所述待剪除标识信息对应的待剪除通道,得到所述剪枝分类模型。
[0135]
在一些实施例中,所述索引池化层包括第一索引池化层和第二索引池化层,所述装置还可以包括:
[0136]
在一些实施例中,所述索引池化层增加子单元可以包括:
[0137]
第一增加子单元,用于在第一缩放层与第一矩阵乘法层之间增加第一索引池化层;
[0138]
第二增加子单元,用于在第二缩放层与第二矩阵乘法层之间增加第二索引池化层。
[0139]
在一些实施例中,所述对象分类模型确定单元可以包括:
[0140]
初始模型参数确定子单元,用于获取所述初始分类模型的初始模型参数;
[0141]
初始训练参数确定子单元,用于将所述初始模型参数作为所述剪枝分类模型的初始训练参数;
[0142]
对象分类模型确定子单元,用于基于所述样本数据以及所述初始训练参数,对所述剪枝分类模型进行对象类别识别训练,得到所述对象分类模型。
[0143]
所述的装置实施例中的装置与方法实施例基于同样地发明构思。
[0144]
本技术实施例提供了一种对象类别识别设备,该设备包括处理器和存储器,该存储器中存储有至少一条指令或至少一段程序,该至少一条指令或至少一段程序由该处理器
加载并执行以实现如上述方法实施例所提供的对象类别识别方法。
[0145]
本技术的实施例还提供了一种计算机存储介质,所述存储介质可设置于终端之中以保存用于实现方法实施例中一种对象类别识别方法相关的至少一条指令或至少一段程序,该至少一条指令或至少一段程序由该处理器加载并执行以实现上述方法实施例提供的对象类别识别方法。
[0146]
本技术的实施例还提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行以实现上述方法实施例提供的对象类别识别方法。
[0147]
可选地,在本技术实施例中,存储介质可以位于计算机网络的多个网络服务器中的至少一个网络服务器。可选地,在本实施例中,上述存储介质可以包括但不限于:u盘、只读存储器(rom,read-only memory)、随机存取存储器(ram,random access memory)、移动硬盘、磁碟或者光盘等各种可以存储程序代码的介质。
[0148]
本技术实施例所述存储器可用于存储软件程序以及模块,处理器通过运行存储在存储器的软件程序以及模块,从而执行各种功能应用以及数据处理。存储器可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、功能所需的应用程序等;存储数据区可存储根据所述设备的使用所创建的数据等。此外,存储器可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。相应地,存储器还可以包括存储器控制器,以提供处理器对存储器的访问。
[0149]
本技术实施例所提供的对象类别识别方法实施例可以在移动终端、计算机终端、服务器或者类似的运算装置中执行。以运行在服务器上为例,图14是本技术实施例提供的一种对象类别识别方法的服务器的硬件结构框图。如图14所示,该服务器1400可因配置或性能不同而产生比较大的差异,可以包括一个或一个以上中央处理器(central processing units,cpu)1410(中央处理器1410可以包括但不限于微处理器mcu或可编程逻辑器件fpga等的处理装置)、用于存储数据的存储器1430,一个或一个以上存储应用程序1423或数据1422的存储介质1420(例如一个或一个以上海量存储设备)。其中,存储器1430和存储介质1420可以是短暂存储或持久存储。存储在存储介质1420的程序可以包括一个或一个以上模块,每个模块可以包括对服务器中的一系列指令操作。更进一步地,中央处理器1410可以设置为与存储介质1420通信,在服务器1400上执行存储介质1420中的一系列指令操作。服务器1400还可以包括一个或一个以上电源1460,一个或一个以上有线或无线网络接口1450,一个或一个以上输入输出接口1440,和/或,一个或一个以上操作系统1421,例如windows servertm,mac os xtm,unixtm,linuxtm,freebsdtm等等。
[0150]
输入输出接口1440可以用于经由一个网络接收或者发送数据。上述的网络具体实例可包括服务器1400的通信供应商提供的无线网络。在一个实例中,输入输出接口1440包括一个网络适配器(network interface controller,nic),其可通过基站与其他网络设备相连从而可与互联网进行通讯。在一个实例中,输入输出接口1440可以为射频(radio frequency,rf)模块,其用于通过无线方式与互联网进行通讯。
[0151]
本领域普通技术人员可以理解,图14所示的结构仅为示意,其并不对上述电子装
置的结构造成限定。例如,服务器1400还可包括比图14中所示更多或者更少的组件,或者具有与图14所示不同的配置。
[0152]
由上述本技术提供的对象类别识别方法、装置、设备或存储介质的实施例可见,本技术获取目标对象的目标数据;基于对象分类模型对所述目标数据进行类别识别处理,得到所述目标对象的目标类别标签;所述对象分类模型为基于样本对象的样本数据,对剪枝分类模型进行对象类别识别训练得到,所述剪枝分类模型为在初始分类模型中剪除待剪除通道后的模型,所述待剪除通道为所述初始分类模型中缩放参数绝对值小于预设阈值的通道;所述初始分类模型为基于所述样本数据对预设网络进行对象类别识别训练得到;所述预设网络中包括更新注意力网络,所述更新注意力网络为设置有缩放层的注意力网络;所述初始分类模型中每个通道的缩放参数基于所述缩放层确定;所述样本数据标注有所述样本对象的样本类别标签;本技术通过在预设网络的注意力网络中设置缩放层,然后通过缩放层参数确定待剪枝通道,实现了在包含注意力网络的模型中进行剪枝,再进一步根据剪枝分类模型确定对象分类模型,从而减少了对象分类模型的运算量,提高了模型的计算速度,提高了对象类别的识别速度。
[0153]
需要说明的是:上述本技术实施例先后顺序仅仅为了描述,不代表实施例的优劣。且上述对本说明书特定实施例进行了描述。其它实施例在所附权利要求书的范围内。在一些情况下,在权利要求书中记载的动作或步骤可以按照不同于实施例中的顺序来执行并且仍然可以实现期望的结果。另外,在附图中描绘的过程不一定要求示出的特定顺序或者连续顺序才能实现期望的结果。在某些实施方式中,多任务处理和并行处理也是可以的或者可能是有利的。
[0154]
本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置、设备、存储介质实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
[0155]
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
[0156]
以上所述仅为本技术的较佳实施例,并不用以限制本技术,凡在本技术的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本技术的保护范围之内。
当前第1页1 2 
网友询问留言 已有0条留言
  • 还没有人留言评论。精彩留言会获得点赞!
1