The previous section introduced the tasks that Bayesian methods address in linear regression; this section walks through the derivation of the inference task in Bayesian linear regression.
The inference task (Inference) in Bayesian linear regression is essentially to compute the posterior distribution $\mathcal P(\mathcal W \mid Data)$ of the model parameters $\mathcal W$. Here $Data$ denotes the data set, consisting of the sample collection $\mathcal X$ and the corresponding label collection $\mathcal Y$:
$$
\begin{aligned}
\mathcal P(\mathcal W \mid Data) & = \frac{\mathcal P(\mathcal Y \mid \mathcal W,\mathcal X) \cdot \mathcal P(\mathcal W)}{\int_{\mathcal W} \mathcal P(\mathcal Y \mid \mathcal W,\mathcal X) \cdot \mathcal P(\mathcal W) \, d\mathcal W} \\
& \propto \mathcal P(\mathcal Y \mid \mathcal W,\mathcal X) \cdot \mathcal P(\mathcal W)
\end{aligned}
$$
Here $\mathcal P(\mathcal Y \mid \mathcal W,\mathcal X)$ is the likelihood (Likelihood). By the definition of the linear regression model, and assuming the samples are independent and identically distributed, the likelihood follows a Gaussian distribution:
$$
\begin{aligned}
\mathcal Y & = \mathcal W^T\mathcal X + \epsilon, \quad \epsilon \sim \mathcal N(0,\sigma^2) \\
\mathcal P(\mathcal Y \mid \mathcal W,\mathcal X) & \sim \mathcal N(\mathcal W^T \mathcal X,\sigma^2) \\
& = \prod_{i=1}^N \mathcal N(\mathcal W^T x^{(i)},\sigma^2)
\end{aligned}
$$
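As a concrete illustration of this generative assumption, here is a minimal NumPy sketch that simulates data from the model; the variable names (`w_true`, `sigma`, `N`, `p`) are illustrative choices, not part of the derivation.

```python
import numpy as np

# Minimal simulation of the assumed model: y^(i) = W^T x^(i) + eps,
# eps ~ N(0, sigma^2). All concrete values below are arbitrary.
rng = np.random.default_rng(0)
N, p = 50, 3                        # number of samples, feature dimension
sigma = 0.5                         # known noise standard deviation
w_true = rng.normal(size=p)         # "ground-truth" parameters for the demo

X = rng.normal(size=(N, p))         # design matrix, rows are samples x^(i)
Y = X @ w_true + rng.normal(scale=sigma, size=N)   # labels Y = XW + noise
```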
$\mathcal P(\mathcal W)$ is the prior distribution (Prior Distribution), i.e., the distribution assumed for the parameters before inference. Strictly, the prior would be written as $\mathcal P(\mathcal W \mid \mathcal X)$, but since $\mathcal W$ is independent of the samples $\mathcal X$, the conditioning is omitted. Here $\mathcal P(\mathcal W)$ is likewise assumed to be Gaussian:
$$
\mathcal P(\mathcal W) \sim \mathcal N(0,\Sigma_{prior})
$$
By the conjugacy of exponential-family distributions, and in particular the self-conjugacy of the Gaussian, the posterior $\mathcal P(\mathcal W \mid Data)$ is also Gaussian. Denote it by $\mathcal N(\mu_{\mathcal W},\Sigma_{\mathcal W})$; then:
$$
\begin{aligned}
\mathcal N(\mu_{\mathcal W},\Sigma_{\mathcal W}) & \propto \mathcal N(\mathcal W^T\mathcal X,\sigma^2) \cdot \mathcal N(0,\Sigma_{prior}) \\
& = \left[\prod_{i=1}^N \mathcal N(y^{(i)} \mid \mathcal W^T x^{(i)},\sigma^2)\right] \cdot \mathcal N(0,\Sigma_{prior})
\end{aligned}
$$
The goal of the inference task is therefore to determine the form of $\mathcal N(\mu_{\mathcal W},\Sigma_{\mathcal W})$, i.e., to solve for the distribution parameters $\mu_{\mathcal W}$ and $\Sigma_{\mathcal W}$.

First, examine the likelihood and expand it. Note that each factor $\mathcal N(y^{(i)} \mid \mathcal W^T x^{(i)},\sigma^2)\ (i=1,2,\cdots,N)$ is a one-dimensional Gaussian:
$$
\begin{aligned}
\mathcal P(\mathcal Y \mid \mathcal W,\mathcal X) & \sim \prod_{i=1}^N \mathcal N(y^{(i)} \mid \mathcal W^T x^{(i)},\sigma^2) \\
& = \prod_{i=1}^N \frac{1}{\sigma \sqrt{2\pi}} \exp\left[-\frac{1}{2 \sigma^2} \left(y^{(i)} - \mathcal W^T x^{(i)}\right)^2\right]
\end{aligned}
$$
Carrying the product $\prod$ into the exponential turns it into a summation, which can then be expressed with matrix multiplication. The key step is to rewrite $\sum_{i=1}^N \left(y^{(i)} - \mathcal W^T x^{(i)}\right)^2$ as follows:
$$
\begin{aligned}
\sum_{i=1}^N \left(y^{(i)} - \mathcal W^T x^{(i)}\right)^2 & = \left(y^{(1)} - \mathcal W^T x^{(1)},\cdots,y^{(N)} - \mathcal W^T x^{(N)}\right) \begin{pmatrix}y^{(1)} - \mathcal W^T x^{(1)} \\ \vdots \\ y^{(N)} - \mathcal W^T x^{(N)}\end{pmatrix} \\
& = (\mathcal Y^T - \mathcal W^T\mathcal X^T)(\mathcal Y - \mathcal X\mathcal W) \\
& = (\mathcal Y - \mathcal X \mathcal W)^T(\mathcal Y - \mathcal X \mathcal W)
\end{aligned}
$$
Since $\frac{1}{2\sigma^2}$ does not depend on $i$, it is pulled outside the sum; $\mathcal I$ below denotes the identity matrix. Substituting back into the likelihood:
$$
\begin{aligned}
& = \frac{1}{(2\pi)^{\frac{N}{2}}\sigma^N} \exp \left[-\frac{1}{2\sigma^2} \sum_{i=1}^N \left(y^{(i)} - \mathcal W^T x^{(i)}\right)^2\right] \\
& = \frac{1}{(2\pi)^{\frac{N}{2}}\sigma^N} \exp \left[- \frac{1}{2} (\mathcal Y - \mathcal X \mathcal W)^T \sigma^{-2} \mathcal I(\mathcal Y - \mathcal X \mathcal W)\right]
\end{aligned}
$$
This expression has exactly the form of a Gaussian density, which confirms that the likelihood $\mathcal P(\mathcal Y \mid \mathcal W,\mathcal X)$ is indeed Gaussian; note that the middle factor $\sigma^{-2} \mathcal I$ is the precision matrix. In compact form:
$$
\mathcal P(\mathcal Y \mid \mathcal W,\mathcal X) \sim \mathcal N(\mathcal X\mathcal W,\sigma^2 \mathcal I)
$$
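As a quick sanity check on this compact form, the following self-contained sketch verifies numerically that the product of the $N$ one-dimensional factors equals the $N$-dimensional Gaussian density $\mathcal N(\mathcal X\mathcal W,\sigma^2 \mathcal I)$; all concrete values in it are arbitrary test inputs.

```python
import numpy as np
from scipy.stats import norm, multivariate_normal

# Check: prod_i N(y^(i) | W^T x^(i), sigma^2) == N(Y | XW, sigma^2 I).
# The comparison is done in log space to keep it numerically meaningful.
rng = np.random.default_rng(1)
N, p, sigma = 20, 3, 0.5
W = rng.normal(size=p)
X = rng.normal(size=(N, p))
Y = X @ W + rng.normal(scale=sigma, size=N)

log_prod = norm.logpdf(Y, loc=X @ W, scale=sigma).sum()      # sum of 1-D log densities
log_joint = multivariate_normal.logpdf(Y, mean=X @ W, cov=sigma**2 * np.eye(N))

assert np.isclose(log_prod, log_joint)
```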
With this, the posterior $\mathcal P(\mathcal W \mid Data)$ can be written as:
$$
\mathcal P(\mathcal W \mid Data) \propto \mathcal N(\mathcal X \mathcal W,\sigma^2 \mathcal I) \cdot \mathcal N(0,\Sigma_{prior})
$$
Back to the main question: how do we solve for $\mu_{\mathcal W}$ and $\Sigma_{\mathcal W}$? Expand the expression above, keeping in mind that only the terms involving $\mathcal W$ matter; everything else can be treated as a constant:
$$
\begin{aligned}
\mathcal P(\mathcal W \mid Data) & \propto \left\{ \frac{1}{(2\pi)^{\frac{N}{2}}\sigma^N} \exp \left[- \frac{1}{2} (\mathcal Y - \mathcal X \mathcal W)^T \sigma^{-2} \mathcal I(\mathcal Y - \mathcal X \mathcal W)\right] \right\} \cdot \left\{\frac{1}{(2\pi)^{\frac{p}{2}}|\Sigma_{prior}|^{\frac{1}{2}}} \exp \left[ - \frac{1}{2} \mathcal W^T \Sigma_{prior}^{-1}\mathcal W \right]\right\} \\
& \propto \exp \left[- \frac{1}{2} (\mathcal Y - \mathcal X \mathcal W)^T \sigma^{-2} \mathcal I(\mathcal Y - \mathcal X \mathcal W)\right] \cdot \exp \left[- \frac{1}{2} \mathcal W^T \Sigma_{prior}^{-1}\mathcal W\right] \\
& = \exp \left\{-\frac{1}{2\sigma^2}(\mathcal Y^T - \mathcal W^T\mathcal X^T)(\mathcal Y - \mathcal X\mathcal W) - \frac{1}{2} \mathcal W^T\Sigma_{prior}^{-1} \mathcal W\right\}
\end{aligned}
$$
The idea is to complete the square: rewrite the exponent in the form $-\frac{1}{2}(\mathcal W - \mu_{\mathcal W})^T\Sigma_{\mathcal W}^{-1}(\mathcal W - \mu_{\mathcal W})$ (plus a constant), from which $\mu_{\mathcal W}$ and $\Sigma_{\mathcal W}^{-1}$ can be read off.

First expand this target expression, denoted by $\Delta$. Since $\mu_{\mathcal W}^T \Sigma_{\mathcal W}^{-1} \mathcal W$ and $\mathcal W^T\Sigma_{\mathcal W}^{-1}\mu_{\mathcal W}$ are transposes of each other and both are scalars, we have $\mu_{\mathcal W}^T \Sigma_{\mathcal W}^{-1} \mathcal W = \mathcal W^T\Sigma_{\mathcal W}^{-1}\mu_{\mathcal W}$:
$$
\begin{aligned}
\Delta & = -\frac{1}{2} \left[\mathcal W^T\Sigma_{\mathcal W}^{-1} \mathcal W - \mu_{\mathcal W}^T \Sigma_{\mathcal W}^{-1} \mathcal W - \mathcal W^T\Sigma_{\mathcal W}^{-1}\mu_{\mathcal W} + \mu_{\mathcal W}^T\Sigma_{\mathcal W}^{-1} \mu_{\mathcal W}\right] \\
& = -\frac{1}{2} \left[\mathcal W^T\Sigma_{\mathcal W}^{-1} \mathcal W - 2 \mu_{\mathcal W}^T \Sigma_{\mathcal W}^{-1} \mathcal W + \mu_{\mathcal W}^T\Sigma_{\mathcal W}^{-1} \mu_{\mathcal W}\right]
\end{aligned}
$$
Here the quadratic term is $-\frac{1}{2}\mathcal W^T\Sigma_{\mathcal W}^{-1} \mathcal W$, the linear term is $\mu_{\mathcal W}^T \Sigma_{\mathcal W}^{-1} \mathcal W$, and the constant term is $-\frac{1}{2}\mu_{\mathcal W}^T\Sigma_{\mathcal W}^{-1} \mu_{\mathcal W}$. We will match these three terms against the corresponding terms in the posterior exponent.
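For readers who want to double-check the expansion of $\Delta$, here is a tiny numerical sketch; the random values and the name `S` (standing in for $\Sigma_{\mathcal W}^{-1}$) are purely illustrative.

```python
import numpy as np

# Verify: -1/2 (W - mu)^T S (W - mu)
#         == -1/2 (W^T S W - 2 mu^T S W + mu^T S mu)
# for a random symmetric positive-definite S and random vectors W, mu.
rng = np.random.default_rng(0)
p = 4
W = rng.normal(size=p)
mu = rng.normal(size=p)
B = rng.normal(size=(p, p))
S = B @ B.T + p * np.eye(p)        # stands in for the precision matrix

lhs = -0.5 * (W - mu) @ S @ (W - mu)
rhs = -0.5 * (W @ S @ W - 2 * mu @ S @ W + mu @ S @ mu)
assert np.isclose(lhs, rhs)
```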
Now fully expand the posterior exponent. The terms $\mathcal Y^T\mathcal X\mathcal W$ and $\mathcal W^T\mathcal X^T\mathcal Y$ are transposes of each other and both are scalars, so $\mathcal Y^T\mathcal X\mathcal W = \mathcal W^T\mathcal X^T\mathcal Y$:
$$
\begin{aligned}
\mathcal P(\mathcal W \mid Data) & \propto \exp \left\{- \frac{1}{2\sigma^2} \left(\mathcal Y^T\mathcal Y - \mathcal Y^T\mathcal X\mathcal W - \mathcal W^T\mathcal X^T\mathcal Y + \mathcal W^T\mathcal X^T\mathcal X\mathcal W\right) - \frac{1}{2} \mathcal W^T\Sigma_{prior}^{-1}\mathcal W\right\} \\
& = \exp\left\{- \frac{1}{2\sigma^2} \left(\mathcal Y^T\mathcal Y - 2\mathcal Y^T\mathcal X\mathcal W + \mathcal W^T\mathcal X^T\mathcal X\mathcal W\right) - \frac{1}{2} \mathcal W^T\Sigma_{prior}^{-1}\mathcal W\right\}
\end{aligned}
$$
Group the exponent by the order of $\mathcal W$: the quadratic part is $-\frac{1}{2}\mathcal W^T\left(\frac{\mathcal X^T\mathcal X}{\sigma^2} + \Sigma_{prior}^{-1}\right)\mathcal W$, the linear part is $\frac{\mathcal Y^T\mathcal X}{\sigma^2}\mathcal W$, and the remaining term $-\frac{\mathcal Y^T\mathcal Y}{2\sigma^2}$ does not involve $\mathcal W$. For brevity, let $\mathcal A = \frac{\mathcal X^T\mathcal X}{\sigma^2} + \Sigma_{prior}^{-1}$.

The constant part need not be examined further, since only $\mu_{\mathcal W}$ and $\Sigma_{\mathcal W}$ are required. Matching the quadratic and linear terms against those of $\Delta$ yields two equations:

$$
\begin{cases}
\mu_{\mathcal W}^T \Sigma_{\mathcal W}^{-1} = \frac{\mathcal Y^T\mathcal X}{\sigma^2} \\
\Sigma_{\mathcal W}^{-1} = \mathcal A
\end{cases}
$$
Solving these equations (using the symmetry of $\mathcal A$) gives:
$$
\begin{cases}
\mu_{\mathcal W} = \frac{\mathcal A^{-1}\mathcal X^T\mathcal Y}{\sigma^2} \\
\Sigma_{\mathcal W}^{-1} = \mathcal A
\end{cases}
$$
With $\mu_{\mathcal W}$ and $\Sigma_{\mathcal W}^{-1}$ both determined, the posterior distribution $\mathcal P(\mathcal W \mid Data)$ is:
$$
\mathcal P(\mathcal W \mid Data) \sim \mathcal N(\mu_{\mathcal W},\Sigma_{\mathcal W}) \quad
\begin{cases}
\mu_{\mathcal W} = \frac{\mathcal A^{-1}\mathcal X^T\mathcal Y}{\sigma^2} \\
\Sigma_{\mathcal W} = \mathcal A^{-1} \\
\mathcal A = \frac{\mathcal X^T\mathcal X}{\sigma^2} + \Sigma_{prior}^{-1}
\end{cases}
$$
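To make the closed form concrete, here is a minimal NumPy sketch that computes $\mathcal A$, $\Sigma_{\mathcal W}$, and $\mu_{\mathcal W}$ on synthetic data; the choices of $N$, $p$, $\sigma$, $\Sigma_{prior}$, and `w_true` are illustrative assumptions, not prescribed by the derivation.

```python
import numpy as np

# Closed-form posterior of Bayesian linear regression on synthetic data.
rng = np.random.default_rng(0)
N, p = 200, 3
sigma = 0.5                                          # known noise std
w_true = np.array([1.0, -2.0, 0.5])                  # used only to simulate data
Sigma_prior = np.eye(p)                              # prior: W ~ N(0, I)

X = rng.normal(size=(N, p))                          # design matrix, rows x^(i)
Y = X @ w_true + rng.normal(scale=sigma, size=N)     # Y = XW + eps

# A = X^T X / sigma^2 + Sigma_prior^{-1},  Sigma_W = A^{-1},
# mu_W = A^{-1} X^T Y / sigma^2
A = X.T @ X / sigma**2 + np.linalg.inv(Sigma_prior)
Sigma_W = np.linalg.inv(A)
mu_W = Sigma_W @ X.T @ Y / sigma**2

print("posterior mean:", mu_W)            # approaches w_true as N grows
print("posterior covariance:", Sigma_W)
```

Note that $\mu_{\mathcal W} = (\mathcal X^T\mathcal X + \sigma^2\Sigma_{prior}^{-1})^{-1}\mathcal X^T\mathcal Y$, which coincides with the ridge-regression estimate with penalty matrix $\sigma^2\Sigma_{prior}^{-1}$; this makes for a useful cross-check of the result.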
The next section will cover the prediction task (Prediction).
Related reference:
机器学习-贝叶斯线性回归(3)-推导Inference