Notations:

• KL-divergence: Given two distributions $p(x)$ and $q(x)$, the Kullback–Leibler divergence, also written as KL-divergence, quantifies the difference between $p(x)$ and $q(x)$ and is defined as
\begin{align}
\mathcal{D}_{\text{KL} } (q(x)||p(x))=\int q(x)\log \frac{q(x)}{p(x)}\text{d}x
\end{align}
The KL-divergence is also known as the relative entropy in information theory.
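As a quick numerical illustration (my own, not from the original text), the closed-form KL-divergence between two 1-D Gaussians can be checked against the integral definition above; the means and standard deviations below are arbitrary:

```python
import numpy as np

def kl_gaussians(mu_q, s_q, mu_p, s_p):
    """Closed-form KL(q||p) between 1-D Gaussians N(mu_q, s_q^2) and N(mu_p, s_p^2)."""
    return np.log(s_p / s_q) + (s_q**2 + (mu_q - mu_p)**2) / (2 * s_p**2) - 0.5

# Numerical check against the integral definition.
x = np.linspace(-20, 20, 200001)
q = np.exp(-0.5 * ((x - 1.0) / 1.5)**2) / (1.5 * np.sqrt(2 * np.pi))
p = np.exp(-0.5 * ((x - 0.0) / 2.0)**2) / (2.0 * np.sqrt(2 * np.pi))
numeric = np.sum(q * np.log(q / p)) * (x[1] - x[0])
print(kl_gaussians(1.0, 1.5, 0.0, 2.0), numeric)  # the two values agree
```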
• Gamma distribution [1]
\begin{align}
\text{Gam}(\alpha|a,b)=\frac{1}{\Gamma(a)}b^a\alpha^{a-1}e^{-b\alpha}
\end{align}
It has the following properties:
\begin{align}
\mathbb{E}[\alpha]&=\frac{a}{b}\\
\text{Var}[\alpha]&=\frac{a}{b^2}
\end{align}
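The moment formulas can be sanity-checked by Monte Carlo sampling. Note that NumPy parameterizes the Gamma distribution by shape and scale $=1/b$, whereas the density above uses the rate $b$; the values of $a$ and $b$ below are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 3.0, 2.0                          # shape a and rate b, as in the density above
# NumPy's gamma sampler takes shape and *scale*, where scale = 1/rate
samples = rng.gamma(shape=a, scale=1.0 / b, size=1_000_000)
print(samples.mean(), a / b)             # sample mean vs. E[alpha] = a/b
print(samples.var(), a / b**2)           # sample variance vs. Var[alpha] = a/b^2
```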
• Gaussian product lemma
\begin{align}
\mathcal{N}(x|a,A)\mathcal{N}(x|b,B)=\mathcal{N}(0|a-b,A+B)\mathcal{N}(x|c,C)
\end{align}
where $C=(1/A+1/B)^{-1}$ and $c=C\cdot(a/A+b/B)$.
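A minimal numerical sketch of the lemma, with arbitrary means, variances, and evaluation point (my own illustration, not part of the original derivation):

```python
import numpy as np

def normal_pdf(x, mean, var):
    return np.exp(-0.5 * (x - mean)**2 / var) / np.sqrt(2 * np.pi * var)

a, A, b, B = 1.0, 2.0, -0.5, 0.5         # arbitrary means and variances
C = 1.0 / (1.0 / A + 1.0 / B)
c = C * (a / A + b / B)

x = 0.7                                   # any evaluation point
lhs = normal_pdf(x, a, A) * normal_pdf(x, b, B)
rhs = normal_pdf(0.0, a - b, A + B) * normal_pdf(x, c, C)
print(lhs, rhs)                           # equal up to floating-point error
```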

Variational Inference

In signal processing, we are interested in the posterior distribution $p(\mathbf{x}|\mathbf{y})$, where $\mathbf{y}$ is the observed signal and $\mathbf{x}$ denotes the signal to be estimated. However, the posterior distribution is generally difficult to obtain. To avoid this intractable computation, we use a tractable distribution $q(\mathbf{x})$ to approximate the posterior. To this end, the KL-divergence is used to measure the difference between $q(\mathbf{x})$ and $p(\mathbf{x}|\mathbf{y})$, defined as
\begin{align}
\mathcal{D}_{\text{KL} }(q||p)=\int q(\mathbf{x})\log \frac{q(\mathbf{x})}{p(\mathbf{x}|\mathbf{y})}\text{d}\mathbf{x}
\end{align}
As the KL-divergence decreases, $q(\mathbf{x})$ becomes closer to $p(\mathbf{x}|\mathbf{y})$. In particular, when $q(\mathbf{x})$ equals $p(\mathbf{x}|\mathbf{y})$, the KL-divergence is zero.

For simplicity, we generally restrict $q(\mathbf{x})$ to a tractable family of distributions $\mathcal{S}$ (e.g., factorized or exponential-family forms), and $q(\mathbf{x})$ is found by minimizing the KL-divergence, i.e.,
\begin{align}
q(\mathbf{x})=\underset{q(\mathbf{x})\in \mathcal{S} }{\arg \min}\ \mathcal{D}_{\text{KL} }(q||p)
\end{align}

We rewrite $\mathcal{D}_{\text{KL} }(q||p)$ as
\begin{align}
\mathcal{D}_{\text{KL} }(q||p)
&=\int q(\mathbf{x}) \log \frac{q(\mathbf{x})p(\mathbf{y})}{p(\mathbf{x},\mathbf{y})}\text{d}\mathbf{x}\\
&=\int q(\mathbf{x})\log \frac{q(\mathbf{x})}{p(\mathbf{x},\mathbf{y})}\text{d}\mathbf{x}+\log p(\mathbf{y})\\
&=-\mathcal{L}(q)+\log p(\mathbf{y})
\end{align}
where
\begin{align}
\mathcal{L}(q)\overset{\triangle}{=}\int q(\mathbf{x})\log \frac{p(\mathbf{x},\mathbf{y})}{q(\mathbf{x})}\text{d}\mathbf{x}
\end{align}
Since $\log p(\mathbf{y})$ is fixed with respect to $q$, the minimum of $\mathcal{D}_{\text{KL} }(q||p)$ can be obtained by maximizing $\mathcal{L}(q)$, known as the evidence lower bound.
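The decomposition $\mathcal{D}_{\text{KL}}(q||p)=-\mathcal{L}(q)+\log p(\mathbf{y})$ can be verified numerically on a toy conjugate model. The model, the observation $y$, and the trial $q$ below are all arbitrary choices of mine for illustration:

```python
import numpy as np

def normal_pdf(x, mean, var):
    return np.exp(-0.5 * (x - mean)**2 / var) / np.sqrt(2 * np.pi * var)

# Toy conjugate model: p(x) = N(0,1), p(y|x) = N(x,1),
# so p(x|y) = N(y/2, 1/2) and p(y) = N(0,2) in closed form.
y = 1.3
x = np.linspace(-10, 10, 100001)
dx = x[1] - x[0]

q = normal_pdf(x, 0.2, 0.8)                   # an arbitrary trial q(x)
joint = normal_pdf(x, 0.0, 1.0) * normal_pdf(y, x, 1.0)
elbo = np.sum(q * np.log(joint / q)) * dx     # L(q)
kl = np.sum(q * np.log(q / normal_pdf(x, y / 2, 0.5))) * dx
log_py = np.log(normal_pdf(y, 0.0, 2.0))
print(elbo + kl, log_py)                      # the identity L(q) + KL = log p(y)
```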

Assumption: The factorized (mean-field) form is generally adopted:
\begin{align}
q(\mathbf{x})=\prod_{i=1}^M q(\mathbf{x}_i)
\end{align}
where $\mathbf{x}=\left\{\mathbf{x}_1,\cdots,\mathbf{x}_M\right\}$. Note that each group $\mathbf{x}_i$ contains at least one element.

With this assumption, we rewrite $\mathcal{L}(q)$ as
\begin{align}
\mathcal{L}(q)
&=\int \prod_{i=1}^M q(\mathbf{x}_i) \left[\log p(\mathbf{x},\mathbf{y})-\sum_{i=1}^M \log q(\mathbf{x}_i)\right]\text{d}\mathbf{x}\\
&=\int q(\mathbf{x}_j)\left[\log p(\mathbf{x},\mathbf{y})\prod_{i\ne j}\left(q(\mathbf{x}_i)\text{d}\mathbf{x}_i\right)\right]\text{d}\mathbf{x}_j-\int q(\mathbf{x}_j) \log q(\mathbf{x}_j)\text{d}\mathbf{x}_j+\text{const}\\
&=\int q(\mathbf{x}_j) \log \tilde{p}(\mathbf{y},\mathbf{x}_j)\text{d}\mathbf{x}_j-\int q(\mathbf{x}_j)\log q(\mathbf{x}_j)\text{d}\mathbf{x}_j+\text{const}\\
&=-\mathcal{D}_{\text{KL} }(q(\mathbf{x}_j)||\tilde{p}(\mathbf{y},\mathbf{x}_j))+\text{const}
\end{align}
Since we focus on the distribution involving $\mathbf{x}_j$, we use ‘const’ to absorb all terms independent of $\mathbf{x}_j$. This notation also appears in the rest of this blog. In addition, we define
\begin{align}
\log \tilde{p}(\mathbf{y},\mathbf{x}_j)=\mathbb{E}_{q^{\backslash j}(\mathbf{x})}[\log p(\mathbf{y},\mathbf{x})]+\text{const}
\end{align}
where $q^{\backslash j}(\mathbf{x})=\prod_{i\ne j}q(\mathbf{x}_i)$.

As a result, $\mathcal{D}_{\text{KL} }(q(\mathbf{x}_j)||\tilde{p}(\mathbf{y},\mathbf{x}_j))$ is minimized by choosing $q(\mathbf{x}_j)=\tilde{p}(\mathbf{y},\mathbf{x}_j)$. Hence, we obtain
\begin{align}
q^{\star}(\mathbf{x}_j)=\frac{\exp (\mathbb{E}_{q^{\backslash j}(\mathbf{x})}[\log p(\mathbf{y},\mathbf{x})])}{\int \exp(\mathbb{E}_{q^{\backslash j}(\mathbf{x})}[\log p(\mathbf{y},\mathbf{x})]) \text{d}\mathbf{x}_j} \quad (*1)
\end{align}
where the other factors $q(\mathbf{x}_i)\ (i\ne j)$ are held fixed.

Remarks:
Note that variational EM is the EM algorithm with a variational E-step. (The expectation–maximization (EM) algorithm is used to compute maximum-likelihood estimates in models with latent variables or parameters, and is a kind of minorize–maximization (MM) algorithm.)

Application in Bayesian Linear Regression

Given the observations $\mathbf{y}$, we assume the likelihood function $p(\mathbf{y}|\mathbf{x};\beta)$ and prior distribution $p(\mathbf{x}|\alpha)$:
\begin{align}
p(\mathbf{y}|\mathbf{x};\beta)
&=\prod_{n=1}^N \mathcal{N}(y_n|\mathbf{x}^T\boldsymbol{\phi}_n,\beta^{-1})\\
p(\mathbf{x}|\alpha)
&=\mathcal{N}(\mathbf{x}|\boldsymbol{0},\alpha^{-1}\mathbf{I})
\end{align}
where $\beta$ and $\alpha$ are unknown precision parameters, and $\boldsymbol{\phi}_n$ is the basis-function vector evaluated at the $n$-th input. In addition, $\alpha$ is assigned the prior distribution
\begin{align}
p(\alpha)=\text{Gam}(\alpha|a_0,b_0)
\end{align}
Thus the joint distribution of all the variables is given by
\begin{align}
p(\mathbf{y},\mathbf{x},\alpha)=p(\mathbf{y}|\mathbf{x};\beta)p(\mathbf{x}|\alpha)p(\alpha)
\end{align}

Variational Method

1. Training Stage: We define
\begin{align}
q(\mathbf{x},\alpha)=q(\mathbf{x})q(\alpha)
\end{align}
On the one hand, using $(*1)$, we have
\begin{align}
\log q^{\star} (\alpha)
&=\mathbb{E}_{q(\mathbf{x})}[\log p(\mathbf{y},\mathbf{x},\alpha)]+ \text{const}\\
&=\mathbb{E}_{q(\mathbf{x})}[\log p(\mathbf{y}|\mathbf{x};\beta)+\log p(\mathbf{x}|\alpha)+\log p(\alpha)]+\text{const}\\
&=\log p(\alpha)+\mathbb{E}_{q(\mathbf{x})}[\log p(\mathbf{x}|\alpha)]+\text{const}\\
&=(a_0-1)\log \alpha-b_0\alpha+\frac{M}{2}\log \alpha-\frac{\alpha}{2}\mathbb{E}_{q(\mathbf{x})}[\mathbf{x}^T\mathbf{x}]+\text{const}
\end{align}
This can be recognized as a Gamma distribution:
\begin{align}
q^{\star}(\alpha)=\text{Gam}(\alpha|a_N,b_N)
\end{align}
where
\begin{align}
a_N&=a_0+\frac{M}{2}\\
b_N&=b_0+\frac{1}{2}\mathbb{E}_{q(\mathbf{x})}\left[\mathbf{x}^T\mathbf{x}\right]
\end{align}
Here $\mathbb{E}_{q(\mathbf{x})}\left[\mathbf{x}^T\mathbf{x}\right]=\mathbf{m}_N^T\mathbf{m}_N+\text{tr}(\mathbf{S}_N)$, where $\mathbf{m}_N$ and $\mathbf{S}_N$ are the moments of $q^{\star}(\mathbf{x})$ derived next.
On the other hand, we have
\begin{align}
\log q^{\star}(\mathbf{x})
&=\mathbb{E}_{q(\alpha)}[\log p(\mathbf{y},\mathbf{x},\alpha)]+\text{const}\\
&=\mathbb{E}_{q(\alpha)}[\log p(\mathbf{y}|\mathbf{x};\beta)+\log p(\mathbf{x}|\alpha)+\log p(\alpha)]+\text{const}\\
&=\log p(\mathbf{y}|\mathbf{x};\beta)+\mathbb{E}_{q(\alpha)}[\log p(\mathbf{x}|\alpha)]+\text{const}\\
&=-\frac{\beta}{2}\sum_{n=1}^N(\mathbf{x}^T\boldsymbol{\phi}_n-y_n)^2-\frac{1}{2}\mathbb{E}_{q(\alpha)}[\alpha]\mathbf{x}^T\mathbf{x}+\text{const}\\
&=-\frac{1}{2}\mathbf{x}^T(\mathbb{E}_{q(\alpha)}[\alpha]\mathbf{I}+\beta \boldsymbol{\Phi}^T\boldsymbol{\Phi})\mathbf{x}+\beta \mathbf{x}^T\boldsymbol{\Phi}^T\mathbf{y}+\text{const}
\end{align}
where
\begin{align}
\boldsymbol{\Phi}=
\left(
\begin{matrix}
\phi_0(x_1) &\cdots &\phi_{M-1}(x_1)\\
\vdots &\ddots &\vdots\\
\phi_0(x_N) &\cdots &\phi_{M-1}(x_N)
\end{matrix}
\right)
\end{align}
This can be recognized as a Gaussian distribution:
\begin{align}
q^{\star} (\mathbf{x})=\mathcal{N}(\mathbf{x}|\mathbf{m}_N,\mathbf{S}_N)
\end{align}
where
\begin{align}
\mathbf{m}_N&=\beta \mathbf{S}_N \boldsymbol{\Phi}^T\mathbf{y}\\
\mathbf{S}_N&=(\mathbb{E}_{q(\alpha)}[\alpha]\mathbf{I}+\beta \boldsymbol{\Phi}^T\boldsymbol{\Phi})^{-1}
\end{align}
with $\mathbb{E}_{q(\alpha)}[\alpha]=\frac{a_N}{b_N}$. The updates of $q^{\star}(\alpha)$ and $q^{\star}(\mathbf{x})$ are iterated until convergence.
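The two coordinate updates above can be sketched as an iterative loop. Everything below (synthetic data, polynomial basis, hyperparameter values, the assumption that $\beta$ is known) is a hypothetical setup of mine for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical synthetic setup: polynomial basis, known noise precision beta.
N, M, beta = 50, 4, 25.0
inputs = rng.uniform(-1, 1, size=N)
Phi = np.vander(inputs, M, increasing=True)    # N x M design matrix
x_true = rng.normal(0, 1, size=M)              # "true" weights, for simulation only
y = Phi @ x_true + rng.normal(0, 1 / np.sqrt(beta), size=N)

a0, b0 = 1e-2, 1e-2                            # broad Gamma prior on alpha
aN = a0 + M / 2                                # fixed once; only b_N changes
bN = b0
for _ in range(100):
    # q(x) update: S_N and m_N given the current E[alpha] = a_N / b_N
    S_N = np.linalg.inv((aN / bN) * np.eye(M) + beta * Phi.T @ Phi)
    m_N = beta * S_N @ Phi.T @ y
    # q(alpha) update: E[x^T x] = m_N^T m_N + tr(S_N)
    bN = b0 + 0.5 * (m_N @ m_N + np.trace(S_N))

print(m_N, aN / bN)                            # posterior weight mean and E[alpha]
```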
2. Test Stage: Denote the new input by $\mathbf{w}$ and its corresponding output by $\mathbf{t}$; our goal is to predict $\mathbf{t}$ based on the results of the training stage.
\begin{align}
p(\mathbf{t}|\mathbf{w},\mathbf{y})
&=\int p(\mathbf{t}|\mathbf{w},\mathbf{x})p(\mathbf{x}|\mathbf{y})\text{d}\mathbf{x}\\
&\approx \int p(\mathbf{t}|\mathbf{w},\mathbf{x})q(\mathbf{x})\text{d}\mathbf{x}\\
&=\int \mathcal{N}(\mathbf{t}|\mathbf{x}^T \boldsymbol{\phi}(\mathbf{w}),\beta^{-1})\mathcal{N}(\mathbf{x}|\mathbf{m}_N,\mathbf{S}_N)\text{d}\mathbf{x}\\
&=\mathcal{N}(\mathbf{t}|\mathbf{m}_N^T\boldsymbol{\phi}(\mathbf{w}),\sigma^2(\mathbf{w}))
\end{align}
where the last equality follows from the Gaussian product lemma, with
\begin{align}
\sigma^2(\mathbf{w})=\beta^{-1}+\boldsymbol{\phi}(\mathbf{w})^T\mathbf{S}_{N}\boldsymbol{\phi}(\mathbf{w})
\end{align}
When the variance $\sigma^2(\mathbf{w})$ is small, we can regard $\mathbf{m}_N^T\boldsymbol{\phi}(\mathbf{w})$ as a reliable prediction of $\mathbf{t}$.
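Given the training-stage outputs $\mathbf{m}_N$ and $\mathbf{S}_N$, the predictive mean and variance follow directly. The posterior values and polynomial basis below are made-up placeholders, not results from the text:

```python
import numpy as np

# Hypothetical outputs of a finished training stage (placeholder values).
beta = 25.0
m_N = np.array([0.2, -1.1, 0.5, 0.8])
S_N = 0.01 * np.eye(4)

def predict(w):
    """Predictive mean m_N^T phi(w) and variance sigma^2(w) for a scalar input w."""
    phi = np.array([w**k for k in range(4)])   # polynomial basis phi(w)
    mean = m_N @ phi
    var = 1.0 / beta + phi @ S_N @ phi         # beta^{-1} + phi^T S_N phi
    return mean, var

mean, var = predict(0.3)
print(mean, var)
```

Note that the predictive variance is always at least the noise level $\beta^{-1}$; the second term reflects the remaining uncertainty in the weights.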

References

[1] C. M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.
[2] C. W. Fox and S. J. Roberts. A tutorial on variational Bayesian inference. Artificial Intelligence Review, 38(2):85–95, 2012.