本文涉及一种数据处理技术,尤指一种扩散常微分模型的训练方法、概率密度估计方法及设备。
背景技术:
1、目前尚无针对常微分方程扩散模型的基于得分函数匹配的最大似然训练算法。现有的基于一阶去噪得分匹配的训练算法只能用于随机微分方程扩散模型的最大似然训练,无法用于常微分方程扩散模型。
2、此外,已有的高阶去噪得分匹配算法需要预训练的低阶得分函数模型来近似真实低阶得分函数,其近似误差在实践中往往不严格为零,这导致已有方法的高阶得分匹配误差很可能趋于无穷,从而给高阶得分匹配带来严重的困难。
技术实现思路
1、本申请实施例提供了一种扩散常微分模型的训练方法、概率密度估计方法及设备,使得训练出的模型能够提高对图像数据的概率密度估计的准确度。
2、本申请提供了一种扩散常微分模型的训练方法,包括:
3、从标准高斯分布中采样高斯噪声,对于预定的多个训练图片,分别计算每个所述训练图片和所述高斯噪声的线性组合,得到多个加噪图片;
4、依次通过每个加噪图片对所述扩散常微分模型的参数进行训练,直至满足训练结束条件;其中,每次训练的过程包括:
5、采用预定的一阶得分函数模型计算所述加噪图片的一阶得分估计,根据所述一阶得分估计得到所述加噪图片的一阶得分估计误差;
6、获取所述加噪图片的二阶得分估计,根据所述二阶得分估计得到所述加噪图片的二阶得分估计误差;
7、根据所述一阶得分估计误差与所述二阶得分估计误差训练所述扩散常微分模型的参数;所述扩散常微分模型包括所述一阶得分函数模型。
8、在一些示例性实施例中,所述根据所述一阶得分估计得到所述加噪图片的一阶得分估计差包括:
9、将所述一阶得分估计乘以所述高斯噪声的标准差,将得到的乘积与所述高斯噪声相加,得到所述加噪图片的一阶得分估计差。
10、在一些示例性实施例中,所述获取所述加噪图片的二阶得分估计包括:
11、计算所述一阶得分函数模型相对于加噪图片的雅克比矩阵,得到所述加噪图片的二阶得分估计。
12、在一些示例性实施例中,所述根据所述二阶得分估计得到所述加噪图片的二阶得分估计误差包括:
13、将所述加噪图片的二阶得分估计乘以所述高斯噪声的标准差的平方,将乘积与单位矩阵相加,得到二阶临时变量矩阵;
14、采用二阶临时变量矩阵减去所述加噪图片的一阶得分估计误差与该误差自身的外积,得到第一结果;计算所述第一结果的f-范数的平方,得到所述加噪图片的二阶得分估计误差。
15、在一些示例性实施例中,所述根据所述一阶得分估计误差与所述二阶得分估计误差训练所述扩散常微分模型的参数包括:
16、将所述一阶得分估计误差与所述二阶得分估计误差进行加权相加得到总误差,利用最小化后的所述总误差,训练所述扩散常微分模型的参数。
17、在一些示例性实施例中,所述根据所述一阶得分估计误差与所述二阶得分估计误差训练所述扩散常微分模型的参数包括:
18、采用所述二阶临时变量矩阵的迹减去所述加噪图片的一阶得分估计误差的二范数的平方,得到第二结果;计算所述第二结果的绝对值的平方,得到所述加噪图片的二阶得分迹估计的误差;
19、计算一阶得分函数模型对加噪图片的雅克比的矩阵的迹,并使用得到的迹对所述加噪图片进行求导,得到所述加噪图片的三阶得分估计;
20、将所述加噪图片的三阶得分估计乘以所述高斯噪声的标准差的三次方,与所述一阶得分估计误差的二范数的平方乘以一阶得分估计误差所得的乘积相加,减去所述二阶临时变量矩阵的迹与所述一阶得分估计误差的乘积,再减去所述二阶临时变量矩阵与所述一阶得分估计误差的乘积的两倍,得到加噪图片的三阶得分估计误差;
21、将所述一阶得分估计误差、所述二阶得分矩阵估计误差、所述二阶得分迹估计误差、所述三阶得分估计误差进行加权相加得到总误差,利用最小化后的所述总误差,训练所述扩散常微分模型的参数。
22、在一些示例性实施例中,所述获取所述加噪图片的二阶得分估计包括:
23、计算所述一阶得分函数模型相对于加噪图片的雅克比矩阵的迹,得到所述加噪图片的二阶得分估计。
24、在一些示例性实施例中,所述根据所述二阶得分估计得到所述加噪图片的二阶得分估计误差包括:
25、将所述加噪图片的二阶得分估计乘以所述高斯噪声的标准差的平方,与数据维度相加,减去所述加噪图片的一阶得分估计误差的二范数的平方,得到第三结果;计算所述第三结果的绝对值的平方,得到加噪图片的二阶得分估计的误差。
26、在一些示例性实施例中,所述根据所述一阶得分估计误差与所述二阶得分估计误差训练所述扩散常微分模型的参数包括:
27、将所述加噪图片的一阶得分估计误差与所述二阶得分估计误差进行加权相加得到总误差,利用最小化后的所述总误差,训练所述扩散常微分模型的参数。
28、本申请还提供了一种图片的概率密度估计方法,包括:
29、将待估计图片输入按照如上述任一项所述的扩散常微分模型的训练方法训练好的扩散常微分模型中;
30、通过所述扩散常微分模型对所述待估计图片进行概率密度估计。
31、本申请还提供了一种设备,包括存储器和处理器;
32、所述存储器用于保存可执行程序;
33、所述处理器用于读取执行所述可执行程序,进行如上述任一项所述的扩散常微分模型的训练方法,或进行如上述所述的图片的概率密度估计方法。
34、与相关技术相比,本申请提供的扩散常微分模型的训练方法、概率密度估计方法和设备,通过误差有界的高阶得分匹配训练方法来训练常微分方程扩散模型,从而使得训练出的模型能够提高对图像数据的概率密度估计的准确度。
35、本申请的其它特征和优点将在随后的说明书中阐述,并且,部分地从说明书中变得显而易见,或者通过实施本申请而了解。本申请的其他优点可通过在说明书以及附图中所描述的方案来实现和获得。
1.一种扩散常微分模型的训练方法,其特征在于,包括:
2.如权利要求1所述的扩散常微分模型的训练方法,其特征在于,所述根据所述一阶得分估计得到所述加噪图片的一阶得分估计差包括:
3.如权利要求1所述的扩散常微分模型的训练方法,其特征在于,所述获取所述加噪图片的二阶得分估计包括:
4.如权利要求3所述的扩散常微分模型的训练方法,其特征在于,所述根据所述二阶得分估计得到所述加噪图片的二阶得分估计误差包括:
5.如权利要求4所述的扩散常微分模型的训练方法,其特征在于,所述根据所述一阶得分估计误差与所述二阶得分估计误差训练所述扩散常微分模型的参数包括:
6.如权利要求4所述的扩散常微分模型的训练方法,其特征在于,所述根据所述一阶得分估计误差与所述二阶得分估计误差训练所述扩散常微分模型的参数包括:
7.如权利要求1所述的扩散常微分模型的训练方法,其特征在于,所述获取所述加噪图片的二阶得分估计包括:
8.如权利要求7所述的扩散常微分模型的训练方法,其特征在于,所述根据所述二阶得分估计得到所述加噪图片的二阶得分估计误差包括:
9.如权利要求8所述的扩散常微分模型的训练方法,其特征在于,所述根据所述一阶得分估计误差与所述二阶得分估计误差训练所述扩散常微分模型的参数包括:
10.一种图片的概率密度估计方法,其特征在于,包括:
11.一种设备,包括存储器和处理器;其特征在于: