推断 (inference )是指利用观测到的数据或信息,对未知的概率分布、参数或假设进行推论的过程。在概率模型应用中,一项核心任务是对给定观测(可见)数据变量 X \mathbf{X} X 条件下潜在变量 Z \mathbf{Z} Z 的后验分布 p ( Z ∣ X ) p(\mathbf{Z}|\mathbf{X}) p ( Z ∣ X ) 进行评估,以及计算相对于该分布的期望。模型可能还包含一些确定性参数,这里我们暂时将其视为隐含的;或者,它可能是一个完全贝叶斯模型,其中任何未知参数都被赋予先验分布,并纳入由向量 Z \mathbf{Z} Z 表示的潜在变量集合中。例如,在EM算法中,我们需要相对于潜在变量的后验分布来计算完整数据对数似然的期望。对于许多实际应用中的模型而言,评估后验分布或计算相对于该分布的期望往往是不可行的。这可能是因为潜在空间的维度太高而无法直接处理,或是因为后验分布形式过于复杂,导致期望无法解析求解。对于连续变量,所需的积分可能没有闭合形式的解析解,而空间的维度以及被积函数的复杂性也可能使数值积分难以进行。对于离散变量,边缘化过程涉及对所有可能的隐藏变量配置进行求和。尽管原则上这总是可行的,但在实际应用中,我们常常发现隐藏状态的数量可能呈指数级增长,以至于精确计算的代价高昂得难以承受。
在这种情况下,我们需要借助近似方法,这些方法根据其依赖随机近似还是确定性近似,大致分为两类。诸如马尔可夫链蒙特卡洛之类的随机技术,已在许多领域推动了贝叶斯方法的广泛应用。这类方法通常具有一种特性:在拥有无限计算资源的情况下,它们能够产生精确结果,而近似性则源于使用有限的处理器时间。实际上,采样方法的计算需求可能很高,常常限制其只能用于小规模问题。此外,有时很难判断一个采样方案是否从所需分布中生成了独立样本。
在本章中,我们将介绍一系列确定性近似方法,其中一些方法能够很好地扩展到大规模应用中。这些方法基于对后验分布的解析近似,例如通过假设其以特定方式分解或具有特定的参数形式(如高斯分布)。因此,它们永远无法产生精确结果,所以其优势与劣势与采样方法互为补充。
泛函数与变分法
泛函 (functional )指以函数构成的向量空间为定义域,实数为值域为的“函数”,即某一个依赖于其它一个或者几个函数确定其值的量,往往被称为“函数的函数”。例如,连续函数空间上的定积分是典型的泛函,它将函数转化为实数。变分法 (variational method )处理泛函的极值问题,其基本思想源于对泛函的微小扰动分析,旨在找到路径、曲线、曲面等,使得给定的函数具有平稳值(在物理问题中,通常是最小值或最大值)。
我们可以将函数 y ( x ) y(x) y ( x ) 视为一个算子,对于任意输入值 x x x ,它返回一个输出值 y y y 。同样地,我们可以定义一个泛函 F [ y ] F[y] F [ y ] 为一个算子,它以函数 y ( x ) y(x) y ( x ) 作为输入并返回一个输出值 F F F 。泛函的一个例子是二维平面上绘制的曲线长度,其中曲线的路径由某个函数定义。在机器学习领域,一个广泛使用的泛函是连续变量 x x x 的熵 H [ x ] H[x] H [ x ] ,因为对于任意选择的概率密度函数 p ( x ) p(x) p ( x ) ,它返回一个标量值,表示在该密度下 x x x 的熵。因此,p ( x ) p(x) p ( x ) 的熵同样可以写为 H [ p ] H[p] H [ p ] 。
传统微积分中的一个常见问题是找到使函数 y ( x ) y(x) y ( x ) 最大化(或最小化)的 x x x 值。类似地,在变分法中,我们寻找使泛函 F [ y ] F[y] F [ y ] 最大化(或最小化)的函数 y ( x ) y(x) y ( x ) 。也就是说,在所有可能的函数 y ( x ) y(x) y ( x ) 中,我们希望找到使泛函 F [ y ] F[y] F [ y ] 达到最大值(或最小值)的那个特定函数。例如,变分法可以用来证明两点之间的最短路径是一条直线,或者最大熵分布是高斯分布。
在普通微积分中,我们可以通过对变量 x x x 进行一个微小改变 ϵ \epsilon ϵ ,然后按 ϵ \epsilon ϵ 的幂次展开来求常规导数 d y d x \frac{\mathrm{d}y}{\mathrm{d}x} d x d y ,即
y ( x + ϵ ) = y ( x ) + d y d x ϵ + O ( ϵ 2 ) y(x+\epsilon)=y(x)+\frac{\mathrm{d}y}{\mathrm{d}x}\epsilon+O(\epsilon^{2})
y ( x + ϵ ) = y ( x ) + d x d y ϵ + O ( ϵ 2 )
最后取极限 ϵ → 0 \epsilon \to 0 ϵ → 0 。类似地,对于多变量函数 y ( x 1 , … , x D ) y(x_{1},\ldots,x_{D}) y ( x 1 , … , x D ) ,相应的偏导数定义为
y ( x 1 + ϵ 1 , … , x D + ϵ D ) = y ( x 1 , … , x D ) + ∑ i = 1 D ∂ y ∂ x i ϵ i + O ( ϵ 2 ) y(x_{1}+\epsilon_{1},\ldots,x_{D}+\epsilon_{D})=y(x_{1},\ldots,x_{D})+\sum_{i=1 }^{D}\frac{\partial y}{\partial x_{i}}\epsilon_{i}+O(\epsilon^{2})
y ( x 1 + ϵ 1 , … , x D + ϵ D ) = y ( x 1 , … , x D ) + i = 1 ∑ D ∂ x i ∂ y ϵ i + O ( ϵ 2 )
类似地,当我们考虑对函数 y ( x ) y(x) y ( x ) 做一个微小改变 ϵ η ( x ) \epsilon\eta(x) ϵη ( x ) 时,泛函 F [ y ] F[y] F [ y ] 改变了多少,这就引出了泛函导数的类似定义。我们将 E [ f ] E[f] E [ f ] 关于 f ( x ) f(x) f ( x ) 的泛函导数记作 δ F δ f ( x ) \frac{\delta F}{\delta f(x)} δ f ( x ) δ F ,并通过以下关系定义它:
F [ y ( x ) + ϵ η ( x ) ] = F [ y ( x ) ] + ϵ ∫ δ F δ y ( x ) η ( x ) d x + O ( ϵ 2 ) F[y(x)+\epsilon\eta(x)]=F[y(x)]+\epsilon\int\frac{\delta F}{\delta y(x)}\eta(x)\,\mathrm{d}x+O(\epsilon^{2})
F [ y ( x ) + ϵη ( x )] = F [ y ( x )] + ϵ ∫ δy ( x ) δ F η ( x ) d x + O ( ϵ 2 )
这可以看作是多变量函数偏导数的自然推广,其中 F [ y ] F[y] F [ y ] 现在依赖于一组连续的变量,即 y y y 在所有点 x x x 处的值。要求泛函关于函数 y ( x ) y(x) y ( x ) 的小变化是平稳的,则有
∫ δ E δ y ( x ) η ( x ) d x = 0 \int\frac{\delta E}{\delta y(x)}\eta(x)\,\mathrm{d}x=0
∫ δy ( x ) δ E η ( x ) d x = 0
由于这对 η ( x ) \eta(x) η ( x ) 的任意选择都必须成立,因此泛函导数必须为零。要理解这一点,可以想象选择一个扰动 η ( x ) \eta(x) η ( x ) ,它除了在点 x ^ \widehat{x} x 的邻域内之外处处为零,那么泛函导数在 x = x ^ x = \widehat{x} x = x 处必须为零。然而,由于这对每个 x ^ \widehat{x} x 的选择都必须成立,所以泛函导数对所有 x x x 值都必须为零。
对于泛函
I = ∫ x 1 x 2 f ( x , y , y ′ ) d x I = \int_{x_1}^{x_2} f(x, y, y') dx
I = ∫ x 1 x 2 f ( x , y , y ′ ) d x
取极值的必要条件是:
∂ f ∂ y − d d x ( ∂ f ∂ y ′ ) = 0 \begin{equation}
\frac{\partial f}{\partial y} - \frac{d}{dx} \left( \frac{\partial f}{\partial y'} \right) = 0
\end{equation}
∂ y ∂ f − d x d ( ∂ y ′ ∂ f ) = 0
这,就是欧拉-拉格朗日方程,是我们确定泛函极值点的关键工具,相当于“函数在极值点处导数为零”这个工具。
两点之间为什么线段最短呢?
已知曲线的弧微分
d s = 1 + y ′ 2 d x ds = \sqrt{1 + y'^2} dx
d s = 1 + y ′2 d x
所以连接 A A A 、B B B 两点的曲线长度可以表示为:
S = ∫ x 1 x 2 1 + y ′ 2 d x S = \int_{x_1}^{x_2} \sqrt{1 + y'^2} dx
S = ∫ x 1 x 2 1 + y ′2 d x
这里面的 1 + y ′ 2 \sqrt{1+y'^2} 1 + y ′2 就是欧拉-拉格朗日方程中的 f f f ,于是有:
∂ f ∂ y = 0 \frac{\partial f}{\partial y} = 0
∂ y ∂ f = 0
∂ f ∂ y ′ = y ′ 1 + y ′ 2 \frac{\partial f}{\partial y'} = \frac{y'}{\sqrt{1+y'^2}}
∂ y ′ ∂ f = 1 + y ′2 y ′
代入欧拉-拉格朗日方程就得到:
0 − d d x ( y ′ 1 + y ′ 2 ) = 0 → d d x ( y ′ 1 + y ′ 2 ) = 0 0 - \frac{d}{dx} \left( \frac{y'}{\sqrt{1+y'^2}} \right) = 0 \rightarrow \frac{d}{dx} \left( \frac{y'}{\sqrt{1+y'^2}} \right) = 0
0 − d x d ( 1 + y ′2 y ′ ) = 0 → d x d ( 1 + y ′2 y ′ ) = 0
所以
y ′ 1 + y ′ 2 = C ( c o n s t ) \frac{y'}{\sqrt{1+y'^2}} = C(const)
1 + y ′2 y ′ = C ( co n s t )
所以 y ′ = c ( c o n s t ) y' = c(const) y ′ = c ( co n s t ) ,也即:y = c x + b y = cx + b y = c x + b ,是一条直线。这样就证明了两点之间线段最短。
KL散度
Kullback-Leibler散度 (Kullback-Leibler divergence ,KL divergence )又被称为相对熵(relative entropy ),记作K L ( P ∥ Q ) KL(P \parallel Q) K L ( P ∥ Q ) ,是一种衡量两个概率分布P P P 和Q Q Q 之间差异的非对称性度量。
对于离散概率分布,其定义可表述为:
K L ( P ∥ Q ) = ∑ P ( x ) log P ( x ) Q ( x ) , \begin{equation}
\mathrm{KL}(P \parallel Q) = \sum P(x) \log \frac{P(x)}{Q(x)},
\end{equation}
KL ( P ∥ Q ) = ∑ P ( x ) log Q ( x ) P ( x ) ,
对于连续概率分布,相应的定义通过概率密度函数以积分形式给出:
K L ( P ∥ Q ) = ∫ p ( x ) log p ( x ) q ( x ) d x , \begin{equation}
\mathrm{KL}(P \parallel Q) = \int p(x) \log \frac{p(x)}{q(x)} \, dx,
\end{equation}
KL ( P ∥ Q ) = ∫ p ( x ) log q ( x ) p ( x ) d x ,
KL散度是不对称的。
优化目标:
min K L ( P ∥ Q ) = ∫ p ( x ) log p ( x ) q ( x ) d x s.t. ∫ p ( x ) d x = 1 \begin{equation}
\begin{aligned}
\min & \quad \mathrm{KL}(P \parallel Q) = \int p(x) \log \frac{p(x)}{q(x)} \, dx\\
\text{s.t.} & \quad \int p(x)dx = 1
\end{aligned}
\end{equation}
min s.t. KL ( P ∥ Q ) = ∫ p ( x ) log q ( x ) p ( x ) d x ∫ p ( x ) d x = 1
构造拉格朗日对偶:
L [ p ] = ∫ p ( x ) log p ( x ) q ( x ) d x + λ ∫ p ( x ) d x = ∫ p ( x ) log p ( x ) q ( x ) + λ p ( x ) d x L[p] = \int p(x) \log \frac{p(x)}{q(x)} \, dx + \lambda \int p(x)dx = \int p(x) \log \frac{p(x)}{q(x)}+ \lambda p(x)dx
L [ p ] = ∫ p ( x ) log q ( x ) p ( x ) d x + λ ∫ p ( x ) d x = ∫ p ( x ) log q ( x ) p ( x ) + λ p ( x ) d x
使用欧拉–拉格朗日方程简化:
∂ ∂ p ( x ) ( p ( x ) log p ( x ) q ( x ) + λ p ( x ) ) = 0 \frac{\partial }{\partial p(x)}(p(x) \log \frac{p(x)}{q(x)}+ \lambda p(x)) = 0
∂ p ( x ) ∂ ( p ( x ) log q ( x ) p ( x ) + λ p ( x )) = 0
即:
log p ( x ) q ( x ) + 1 + λ = 0 p ( x ) q ( x ) = C \begin{aligned}
\log \frac{p(x)}{q(x)} + 1 + \lambda = 0 \\
\frac{p(x)}{q(x)} = C \\
\end{aligned}
log q ( x ) p ( x ) + 1 + λ = 0 q ( x ) p ( x ) = C
因为:
∫ p ( x ) d x = C ∫ q ( x ) d x = 1 \int p(x)dx = C\int q(x)dx = 1
∫ p ( x ) d x = C ∫ q ( x ) d x = 1
所以C = 1 C = 1 C = 1 ,即KL散度为0当且仅当P P P 与Q Q Q 在离散型变量的情况下是相同的分布,或者在连续型变量的情况下是“处处相同”的。因为KL散度是非负的,并且衡量的是两个分布之间的差异,所以往往被用作分布之间的某种“距离”。
变分推断
参考文章:证据下界(ELBO)、EM算法、变分推断、变分自编码器(VAE)和混合高斯模型(GMM) - 知乎
变分推断是一种用于近似复杂后验分布的机器学习方法。给定观测数据 x x x 、隐变量 z z z 和待估计参数θ \theta θ ,联合分布为 p θ ( x , z ) p_{\theta}(x,z) p θ ( x , z ) ,后验分布 p θ ( z ∣ x ) p_{\theta}(z|x) p θ ( z ∣ x ) 通常难以直接计算。变分推断引入一个变分分布 p ϕ ( z ) p_{\phi}(z) p ϕ ( z ) ,通过最大化证据下界来近似后验分布。
对x x x 做极大似然估计,即最大化以下目标函数:
L ( θ ) = ∫ x ∼ p ( x ) log p θ ( x ) = E x [ L x ( θ ) ] = E x [ log p θ ( x ) ] \begin{align*}
L(\theta) &= \int_{x \sim p(x)} \log p_{\theta}(x)\\
& = E_{x}[L_{x}(\theta)] \\
& = E_{x}[\log p_{\theta}(x)]
\end{align*}
L ( θ ) = ∫ x ∼ p ( x ) log p θ ( x ) = E x [ L x ( θ )] = E x [ log p θ ( x )]
其中:
L x ( θ ) = log p θ ( x ) = log ( E z [ p θ ( x , z ) ] ) = log ( E z [ p θ ( z ∣ x ) p θ ( x ) ] ) = log ( E z [ p θ ( z ∣ x ) p θ ( x , z ) p θ ( z ∣ x ) ] ) \begin{align*}
L_{x}(\theta) &= \log p_{\theta}(x)\\
&= \log(E_{z}[p_{\theta}(x,z)])\\
&= \log(E_{z}[p_{\theta}(z|x)p_{\theta}(x)])\\
&= \log\left( E_{z}\left[ p_{\theta}(z|x) \frac{p_{\theta}(x,z)}{p_{\theta}(z|x)} \right] \right)
\end{align*}
L x ( θ ) = log p θ ( x ) = log ( E z [ p θ ( x , z )]) = log ( E z [ p θ ( z ∣ x ) p θ ( x )]) = log ( E z [ p θ ( z ∣ x ) p θ ( z ∣ x ) p θ ( x , z ) ] )
对于我们的目标p θ ( z ∣ x ) p_{\theta}(z|x) p θ ( z ∣ x ) ,我们要找一个最好的p ϕ ( z ∣ x ) p_{\phi}(z|x) p ϕ ( z ∣ x ) 去近似该分布,即最小化函数:
K L ( p ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ) ) = ∫ z p ϕ ( z ∣ x ) log p ϕ ( z ∣ x ) p θ ( z ∣ x ) d z = ∫ z p ϕ ( z ∣ x ) log p ϕ ( z ∣ x ) p θ ( x ) p θ ( z , x ) d z = ∫ z p ϕ ( z ∣ x ) log p ϕ ( z ∣ x ) p θ ( z , x ) d z + ∫ z p ϕ ( z ∣ x ) log p θ ( x ) d z = ∫ z p ϕ ( z ∣ x ) log p ϕ ( z ∣ x ) p θ ( z , x ) d z + log p θ ( x ) \begin{align}
KL(p_{\phi}(z|x)\|p_{\theta}(z|x)) &= \int_{z} p_{\phi}(z|x)\log \frac{p_{\phi}(z|x)}{p_{\theta}(z|x)} dz\\
&= \int_{z} p_{\phi}(z|x)\log \frac{p_{\phi}(z|x)p_{\theta}(x)}{p_{\theta}(z,x)} dz \notag \\
&= \int_{z} p_{\phi}(z|x)\log \frac{p_{\phi}(z|x)}{p_{\theta}(z,x)} dz + \int_{z} p_{\phi}(z|x)\log p_{\theta}(x) dz\notag \\
&= \int_{z} p_{\phi}(z|x)\log \frac{p_{\phi}(z|x)}{p_{\theta}(z,x)} dz + \log p_{\theta}(x) \notag
\end{align}
K L ( p ϕ ( z ∣ x ) ∥ p θ ( z ∣ x )) = ∫ z p ϕ ( z ∣ x ) log p θ ( z ∣ x ) p ϕ ( z ∣ x ) d z = ∫ z p ϕ ( z ∣ x ) log p θ ( z , x ) p ϕ ( z ∣ x ) p θ ( x ) d z = ∫ z p ϕ ( z ∣ x ) log p θ ( z , x ) p ϕ ( z ∣ x ) d z + ∫ z p ϕ ( z ∣ x ) log p θ ( x ) d z = ∫ z p ϕ ( z ∣ x ) log p θ ( z , x ) p ϕ ( z ∣ x ) d z + log p θ ( x )
所以:
L x ( θ ) = ∫ z p ϕ ( z ∣ x ) log p θ ( z , x ) p ϕ ( z ∣ x ) d z + K L ( p ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ) ) \begin{align}
L_{x}(\theta) = \int_{z} p_{\phi}(z|x)\log \frac{p_{\theta}(z,x)}{p_{\phi}(z|x)} dz + KL(p_{\phi}(z|x)\|p_{\theta}(z|x))
\end{align}
L x ( θ ) = ∫ z p ϕ ( z ∣ x ) log p ϕ ( z ∣ x ) p θ ( z , x ) d z + K L ( p ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ))
其中:
E L B O = ∫ z p ϕ ( z ∣ x ) log p θ ( z , x ) p ϕ ( z ∣ x ) d z \begin{align}
ELBO = \int_{z} p_{\phi}(z|x)\log \frac{p_{\theta}(z,x)}{p_{\phi}(z|x)} dz
\end{align}
E L BO = ∫ z p ϕ ( z ∣ x ) log p ϕ ( z ∣ x ) p θ ( z , x ) d z
即为证据下界 (Evidence Lower Bound , ELBO )。因为K L ( p ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ) ) KL(p_{\phi}(z|x)\|p_{\theta}(z|x)) K L ( p ϕ ( z ∣ x ) ∥ p θ ( z ∣ x )) 非负,所以:
L x ( θ ) ≥ E L B O \begin{align}
L_{x}(\theta) \geq ELBO
\end{align}
L x ( θ ) ≥ E L BO
其中L x ( θ ) = log p θ ( x ) L_{x}(\theta) = \log p_{\theta}(x) L x ( θ ) = log p θ ( x ) 是与ϕ \phi ϕ 无关的常数,所以最大化证据下界等价于最小化K L ( p ϕ ( z ∣ x ) ∥ p θ ( z ∣ x ) ) KL(p_{\phi}(z|x)\|p_{\theta}(z|x)) K L ( p ϕ ( z ∣ x ) ∥ p θ ( z ∣ x )) 。