A Variational Inference Perspective on Expectation Propagation

Notation:

  1. $\text{Diag}(\boldsymbol{a})$: the diagonal matrix whose diagonal entries are the elements of $\boldsymbol{a}$.
  2. $\text{diag}(\mathbf{A})$: the vector formed from the diagonal entries of $\mathbf{A}$.
  3. $\boldsymbol{a}\odot \boldsymbol{b}$: componentwise multiplication.
  4. $\boldsymbol{a}\oslash \boldsymbol{b}$: componentwise division.

Recap of Variational Inference

In [1], we introduced variational inference and its application to Bayesian linear regression. In this post, we present a variational inference perspective on expectation propagation (EP).

In signal processing, we are often interested in the posterior distribution. However, it is difficult to obtain because it involves high-dimensional integrals. For example, consider the linear Gaussian model
\begin{align}
\mathbf{y}=\mathbf{Hx}+\mathbf{w}
\end{align} Its posterior distribution is given by
\begin{align}
p(\mathbf{x}|\mathbf{y})=\frac{p(\mathbf{y}|\mathbf{x})p(\mathbf{x})}{\int p(\mathbf{y}|\mathbf{x})p(\mathbf{x}) \text{d}\mathbf{x} }
\end{align} where $p(\mathbf{y}|\mathbf{x})=p_{\mathbf{w} }(\mathbf{y}-\mathbf{Hx})$. Unless both $p(\mathbf{y}|\mathbf{x})$ and $p(\mathbf{x})$ are Gaussian, a closed form of $p(\mathbf{x}|\mathbf{y})$ is generally not available, so some approximation is necessary.

For that purpose, we use $q(\mathbf{x})$ to approximate the posterior distribution and apply the KL divergence to measure the difference between $q(\mathbf{x})$ and $p(\mathbf{x}|\mathbf{y})$. For simplicity, we generally restrict $q(\mathbf{x})$ to a distribution family $\mathcal{S}$, i.e.,
\begin{align}
q(\mathbf{x})=\underset{q(\mathbf{x})\in \mathcal{S} } {\arg \min} \ \mathcal{D}_{\text{KL} }(p||q)
\end{align} where $p$ denotes the posterior $p(\mathbf{x}|\mathbf{y})$. A distribution family with convenient properties greatly reduces the amount of computation. Luckily, the exponential family is one such family.

Exponential Family

The exponential family over $\mathbf{x}$ parameterized by $\boldsymbol{\eta}$ is defined by the following form
\begin{align}
p(\mathbf{x};\boldsymbol{\eta})=h(\mathbf{x})g(\boldsymbol{\eta})\exp\left(\boldsymbol{\eta}^T\boldsymbol{u}(\mathbf{x})\right)
\end{align} where $g(\boldsymbol{\eta})$ is the normalization coefficient, which satisfies
\begin{align}
g(\boldsymbol{\eta}) \left[\int h(\mathbf{x})\exp\left(\boldsymbol{\eta}^T\boldsymbol{u}(\mathbf{x})\right)\text{d}\mathbf{x}\right]=1
\end{align} Taking the gradient of both sides with respect to $\boldsymbol{\eta}$, we get
\begin{align}
\nabla g(\boldsymbol{\eta})\int h(\mathbf{x})\exp \left(\boldsymbol{\eta}^T\boldsymbol{u}(\mathbf{x})\right)\text{d}\mathbf{x}+g(\boldsymbol{\eta})\int h(\mathbf{x})\exp\left(\boldsymbol{\eta}^T\boldsymbol{u}(\mathbf{x})\right)\boldsymbol{u}(\mathbf{x})\text{d}\mathbf{x}=0
\end{align} Rearranging the above equation yields
\begin{align}
-\frac{1}{g(\boldsymbol{\eta})}\nabla g(\boldsymbol{\eta})
&=g(\boldsymbol{\eta}) \int \boldsymbol{u}(\mathbf{x})h(\mathbf{x})\exp\left(\boldsymbol{\eta}^T\boldsymbol{u}(\mathbf{x})\right)\text{d}\mathbf{x}\\
&=\frac{ \int \boldsymbol{u}(\mathbf{x})h(\mathbf{x})\exp\left(\boldsymbol{\eta}^T\boldsymbol{u}(\mathbf{x})\right)\text{d}\mathbf{x} }{ \int h(\mathbf{x})\exp\left(\boldsymbol{\eta}^T\boldsymbol{u}(\mathbf{x})\right)\text{d}\mathbf{x} }\\
&=\mathbb{E}[\boldsymbol{u}(\mathbf{x})]
\end{align} Using the fact $\nabla \log g(\boldsymbol{\eta})=\frac{1}{g(\boldsymbol{\eta})}\nabla g(\boldsymbol{\eta})$, we have
\begin{align}
-\nabla \log g(\boldsymbol{\eta})=\mathbb{E}[\boldsymbol{u}(\mathbf{x})] \quad \cdots\quad (*1)
\end{align}
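
As a quick numerical sanity check of $(*1)$, the following sketch uses a Bernoulli variable. This example is not part of the original derivation and is chosen only because its normalizer is simple: with $h(x)=1$, $u(x)=x$ and $x\in\{0,1\}$, we have $g(\eta)=1/(1+e^{\eta})$, and the integral becomes a sum. The function names are illustrative.

```python
import math

def log_g(eta):
    """log g(eta) for a Bernoulli variable, where g(eta) = 1 / (1 + exp(eta))."""
    return -math.log1p(math.exp(eta))

def expected_u(eta):
    """E[u(x)] with u(x) = x, computed by summing over x in {0, 1}."""
    weights = [math.exp(eta * x) for x in (0, 1)]
    return sum(x * w for x, w in zip((0, 1), weights)) / sum(weights)

def neg_grad_log_g(eta, step=1e-6):
    """Central finite-difference estimate of -d/d(eta) log g(eta)."""
    return -(log_g(eta + step) - log_g(eta - step)) / (2 * step)

# The two columns should agree up to the finite-difference error.
for eta in (-2.0, 0.0, 1.5):
    print(eta, neg_grad_log_g(eta), expected_u(eta))
```

The same check works for any exponential-family member whose normalizer can be evaluated, as long as the sum or integral defining $\mathbb{E}[\boldsymbol{u}(\mathbf{x})]$ is tractable.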

A Variational Inference Perspective on EP

For the distribution $q(\mathbf{x})$ in variational inference, we consider a member of the exponential family
\begin{align}
q(\mathbf{x})=h(\mathbf{x})g(\boldsymbol{\eta})\exp \left(\boldsymbol{\eta}^T\boldsymbol{u}(\mathbf{x})\right)
\end{align} We then write $\mathcal{D}_{\text{KL} }(p||q)$ as
\begin{align}
\mathcal{D}_{\text{KL} }(p||q)=-\log g(\boldsymbol{\eta})-\boldsymbol{\eta}^T\mathbb{E}_{p(\mathbf{x})}[\boldsymbol{u}(\mathbf{x})]+\text{const}
\end{align} where the constant collects the terms that do not depend on $\boldsymbol{\eta}$. Setting the gradient with respect to $\boldsymbol{\eta}$ to zero yields
\begin{align}
-\nabla \log g(\boldsymbol{\eta}) =\mathbb{E}_{p(\mathbf{x})}[\boldsymbol{u}(\mathbf{x})]
\end{align} Combining this with $(*1)$, we get
\begin{align}
\mathbb{E}_{q(\mathbf{x})}[\boldsymbol{u}(\mathbf{x})]=\mathbb{E}_{p(\mathbf{x})}[\boldsymbol{u}(\mathbf{x})]
\end{align} That is, the KL divergence is minimized by matching the expected sufficient statistics. In particular, if $q(\mathbf{x})$ is Gaussian $\mathcal{N}(\mathbf{x}|\boldsymbol{\mu},\mathbf{\Sigma})$, we minimize the KL divergence by setting $\boldsymbol{\mu}$ equal to the mean of $p(\mathbf{x})$ and $\mathbf{\Sigma}$ equal to the covariance of $p(\mathbf{x})$ (moment matching).
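
The moment-matching result can be illustrated numerically. The sketch below is an illustration under assumptions that are not in the original text: the target $p(x)$ is a hypothetical two-component Gaussian mixture, integrals are replaced by Riemann sums on a fixed grid, and all names are made up for this example. It computes the mean and variance of $p$, then compares $\mathcal{D}_{\text{KL}}(p||q)$ at the moment-matched Gaussian against a few perturbed Gaussians.

```python
import math

def log_normal_pdf(x, mean, var):
    """Log density of a real scalar Gaussian N(x | mean, var)."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def target_pdf(x):
    """Hypothetical non-Gaussian target p(x): a two-component mixture."""
    return (0.3 * math.exp(log_normal_pdf(x, -2.0, 0.25))
            + 0.7 * math.exp(log_normal_pdf(x, 1.5, 1.0)))

STEP = 0.01
GRID = [-10.0 + STEP * k for k in range(2001)]   # grid on [-10, 10]
DENSITY = [target_pdf(x) for x in GRID]

def target_moments():
    """Mean and variance of p(x), approximated by Riemann sums."""
    total = sum(DENSITY)
    mean = sum(x * p for x, p in zip(GRID, DENSITY)) / total
    var = sum((x - mean) ** 2 * p for x, p in zip(GRID, DENSITY)) / total
    return mean, var

def kl_to_gaussian(mean, var):
    """KL(p || N(mean, var)), approximated by a Riemann sum."""
    return STEP * sum(p * (math.log(p) - log_normal_pdf(x, mean, var))
                      for x, p in zip(GRID, DENSITY) if p > 0.0)

mean_p, var_p = target_moments()
best = kl_to_gaussian(mean_p, var_p)
# Each printed gap is the KL increase caused by moving away from the matched moments.
for d_mean, scale in [(0.3, 1.0), (-0.3, 1.0), (0.0, 1.5), (0.0, 0.7)]:
    print(d_mean, scale, kl_to_gaussian(mean_p + d_mean, var_p * scale) - best)
```

Every perturbation of the mean or variance increases the divergence, which is what the stationarity condition predicts for a Gaussian $q$.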

We exploit this result to obtain a practical algorithm for approximate inference. For many probabilistic models, the joint distribution of the data $\mathcal{D}=\left\{\mathbf{y}_1,\cdots,\mathbf{y}_N\right\}$ and the hidden variables (possibly including parameters) $\boldsymbol{\theta}$ is a product of factors of the form
\begin{align}
p(\mathcal{D},\boldsymbol{\theta})=\prod_i f_i(\boldsymbol{\theta})
\end{align} where $f_0(\boldsymbol{\theta})=p(\boldsymbol{\theta})$ and $f_n(\boldsymbol{\theta})=p(\mathbf{y}_n|\boldsymbol{\theta}),(n\ne 0)$. The posterior distribution is given by
\begin{align}
p(\boldsymbol{\theta}|\mathcal{D})=\frac{p(\mathcal{D},\boldsymbol{\theta})}{p(\mathcal{D})}=\frac{1}{p(\mathcal{D})}\prod_{i} f_i(\boldsymbol{\theta})
\end{align} where $p(\mathcal{D})$ is the model evidence (partition function), given by
\begin{align}
p(\mathcal{D})=\int \prod_i f_i(\boldsymbol{\theta})\text{d}\boldsymbol{\theta}
\end{align} We approximate the posterior by a product of factors of the same form
\begin{align}
q(\boldsymbol{\theta})=\frac{1}{Z}\prod_i q_i(\boldsymbol{\theta})
\end{align} where each $q_i(\boldsymbol{\theta})$ approximates the corresponding $f_i(\boldsymbol{\theta})$. Ideally, the factors would be determined by minimizing
\begin{align}
\{q_i(\boldsymbol{\theta})\}=\underset{\{q_i(\boldsymbol{\theta})\} }{\arg \min}\ \mathcal{D}_{\text{KL} }\left(\frac{1}{p(\mathcal{D})}\prod_{i}f_i(\boldsymbol{\theta})||\frac{1}{Z}\prod_{i}q_i(\boldsymbol{\theta})\right)
\end{align} This minimization is generally intractable because it requires averaging over the true posterior. A rough alternative is to approximate each factor $f_i(\boldsymbol{\theta})$ individually by $q_i(\boldsymbol{\theta})$, but the resulting product is usually a poor approximation. To remedy this, expectation propagation optimizes each factor in turn in the context of all of the remaining factors [2]. Below, we describe EP step by step.

$\underline{\text{Step 1} }$: Initialize all factors $q_i(\boldsymbol{\theta})$ from the distribution family $\mathcal{S}$ and form
\begin{align}
q(\boldsymbol{\theta})=\frac{1}{Z}\prod_i q_i(\boldsymbol{\theta})
\end{align} $\underline{\text{Step 2} }$: Choose a factor $j$ and compute the cavity distribution $q^{\backslash j}(\boldsymbol{\theta})$, defined by
\begin{align}
q^{\backslash j}(\boldsymbol{\theta})=C\frac{q(\boldsymbol{\theta})}{q_j(\boldsymbol{\theta})}
\end{align}
where $C$ is a normalization constant.

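$\underline{\text{Step 3} }$: Update
\begin{align}
q^{\text{new} }(\boldsymbol{\theta})=\underset{q(\boldsymbol{\theta})\in \mathcal{S} }{\arg \min}\ \mathcal{D}_{\text{KL} } \left(\frac{1}{Z_j}f_j(\boldsymbol{\theta})q^{\backslash j}(\boldsymbol{\theta})||q(\boldsymbol{\theta})\right)
\end{align}
where $q^{\text{new} }(\boldsymbol{\theta})$ is the update of $q(\boldsymbol{\theta})$ and $Z_j=\int f_j(\boldsymbol{\theta})q^{\backslash j}(\boldsymbol{\theta})\text{d}\boldsymbol{\theta}$ is a normalization constant. For an exponential-family $q$, this minimization amounts to matching the expected sufficient statistics of the two distributions.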

$\underline{\text{Step 4} }$: Update $q_j(\boldsymbol{\theta})$
\begin{align}
q_j(\boldsymbol{\theta})=C\frac{q^{\text{new} }(\boldsymbol{\theta})}{q^{\backslash j}(\boldsymbol{\theta})}
\end{align}
where $C$ is a normalization constant.

$\underline{\text{Step 5} }$: Return to Step 2 and repeat with the next factor until convergence.
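
The steps above can be sketched for a scalar $\theta$ with Gaussian approximating factors. This is a minimal illustration under assumptions that go beyond the text: the prior is Gaussian and is kept exact instead of being treated as a site, the non-Gaussian factors are initialized as flat (zero precision), the moments in Step 3 are computed by a Riemann sum, and the logistic factors in the example are hypothetical. The names `tilted_moments` and `ep_1d` are made up for this sketch.

```python
import math

def tilted_moments(factor, cav_mean, cav_var, half_width=10.0, n=4001):
    """Mean and variance of factor(t) * N(t | cav_mean, cav_var) by a Riemann sum."""
    std = math.sqrt(cav_var)
    lo = cav_mean - half_width * std
    step = 2 * half_width * std / (n - 1)
    ts = [lo + step * k for k in range(n)]
    ws = [factor(t) * math.exp(-(t - cav_mean) ** 2 / (2 * cav_var)) for t in ts]
    z = sum(ws)
    mean = sum(t * w for t, w in zip(ts, ws)) / z
    var = sum((t - mean) ** 2 * w for t, w in zip(ts, ws)) / z
    return mean, var

def ep_1d(factors, prior_mean=0.0, prior_var=1.0, sweeps=20):
    """Scalar EP; Gaussians are stored as precision r = 1/v and b = m/v."""
    prior_r, prior_b = 1.0 / prior_var, prior_mean / prior_var
    site_r = [0.0] * len(factors)          # Step 1: flat initial factors q_i
    site_b = [0.0] * len(factors)
    for _ in range(sweeps):
        for j, f in enumerate(factors):
            q_r = prior_r + sum(site_r)    # current q as a product of factors
            q_b = prior_b + sum(site_b)
            cav_r, cav_b = q_r - site_r[j], q_b - site_b[j]        # Step 2: q / q_j
            new_mean, new_var = tilted_moments(f, cav_b / cav_r, 1.0 / cav_r)  # Step 3
            site_r[j] = 1.0 / new_var - cav_r                      # Step 4: q_new / cavity
            site_b[j] = new_mean / new_var - cav_b
    q_r = prior_r + sum(site_r)
    q_b = prior_b + sum(site_b)
    return q_b / q_r, 1.0 / q_r

def sigmoid(t):
    """Logistic function written with tanh to avoid overflow."""
    return 0.5 * (1.0 + math.tanh(0.5 * t))

# Hypothetical example with three logistic factors.
factors = [lambda t: sigmoid(2.0 * t), lambda t: sigmoid(t - 1.0), lambda t: sigmoid(-0.5 * t)]
print(ep_1d(factors))
```

Working with precisions makes the division in Steps 2 and 4 a plain subtraction. If every factor is itself Gaussian, each site equals its factor after one visit and the result is exact, which gives a convenient way to test such a sketch.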

Application in Communication

We consider the standard linear model (SLM)
\begin{align}
\mathbf{y}=\mathbf{Hx}+\mathbf{w}
\end{align} where $\mathbf{x}\in \mathbb{C}^N$ is drawn from an $M$-QAM constellation with prior $p(\mathbf{x})=\prod_{i=1}^N p(x_i)$. The signal passes through the channel $\mathbf{H}\in \mathbb{C}^{M\times N}$ (assumed to be perfectly estimated beforehand) and is corrupted by white Gaussian noise $\mathbf{w}\sim \mathcal{N}_c(\mathbf{w}|\boldsymbol{0},\sigma^2\mathbf{I})$, which gives the observed signal $\mathbf{y}$.

We aim to design an efficient signal detector using EP. Based on the above, we write the posterior distribution of this model as
\begin{align}
p(\mathbf{x}|\mathbf{y})
&=\frac{p(\mathbf{y}|\mathbf{x})p(\mathbf{x})}{p(\mathbf{y})}\\
&\propto p(\mathbf{y}|\mathbf{x})p(\mathbf{x})
\end{align} Since $\mathbf{y}$ is given, $p(\mathbf{y})$ is regarded as a constant. Because the noise is white, the observations are conditionally independent given $\mathbf{x}$, i.e.,
\begin{align}
p(\mathbf{y}|\mathbf{x})=\prod_{a=1}^M p(y_a|\mathbf{x})
\end{align} $\underline{\text{Step 1} }$: Initialize $q(\mathbf{x})$, the approximation of $p(\mathbf{x})$. Since $p(\mathbf{y}|\mathbf{x})$ is Gaussian, we approximate $p(\mathbf{x})$ by a Gaussian, which is a member of the exponential family.
\begin{align}
q(\mathbf{x})=\mathcal{N}_c(\mathbf{x}|\mathbf{m},\text{Diag}(\mathbf{v}))
\end{align} Its marginal distribution is $q(x_i)=\mathcal{N}_c(x_i|m_i,v_i)$. Note that $q(x_i)$ here plays the role of $q_i(\boldsymbol{\theta})$ in the previous section.

$\underline{\text{Step 2} }$: Calculate the joint distribution $q(\mathbf{x},\mathbf{y})$
\begin{align}
q(\mathbf{x},\mathbf{y})
&=q(\mathbf{x})p(\mathbf{y}|\mathbf{x})\\
&=\mathcal{N}_c(\mathbf{x}|\mathbf{m},\text{Diag}(\mathbf{v}))\mathcal{N}_c(\mathbf{y}|\mathbf{Hx},\sigma^2\mathbf{I})\\
&\propto \mathcal{N}_c(\mathbf{x}|\mathbf{m},\text{Diag}(\mathbf{v}))\mathcal{N}_c(\mathbf{x}|(\mathbf{H}^H\mathbf{H})^{-1}\mathbf{H}^H\mathbf{y},(\sigma^{-2}\mathbf{H}^H\mathbf{H})^{-1}) \\
&\propto \mathcal{N}_c(\mathbf{x}|\boldsymbol{\mu},\mathbf{\Sigma})
\end{align} where the last step follows from the Gaussian product lemma mentioned in [2] with the definitions
\begin{align}
\mathbf{\Sigma}&=(\sigma^{-2}\mathbf{H}^H\mathbf{H}+\text{Diag}(\mathbf{1}\oslash \mathbf{v}))^{-1}\\
\boldsymbol{\mu}&=\mathbf{\Sigma}\left(\sigma^{-2}\mathbf{H}^H\mathbf{y}+\mathbf{m}\oslash \mathbf{v}\right)
\end{align} Here, with a slight abuse of notation, we use $\mathcal{N}_c(x_j|\mu_{j},\Sigma_{jj})$ to approximate $p(x_j,\mathbf{y})$, the marginal distribution of $p(\mathbf{x},\mathbf{y})$. This operation ignores the correlation between $x_j$ and $\mathbf{x}_{\backslash j}$, so we write it as $q(x_j,\mathbf{y})=\mathcal{N}_c(x_j|\mu_j,\Sigma_{jj})$. An alternative computation of $p(x_j|\mathbf{y})$ is given in [4, Appendix A], where the matrix inversion lemma is applied to rewrite $\mathbf{\Sigma}$.
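
The product step in Step 2 can be checked by collecting the terms in $\mathbf{x}$ in the exponent. The following is a sketch for two generic complex Gaussians, where $\mathbf{a},\mathbf{A},\mathbf{b},\mathbf{B}$ are placeholder symbols introduced only for this check, $\mathbf{A}$ and $\mathbf{B}$ are assumed Hermitian positive definite, and $\mathcal{N}_c(\mathbf{x}|\mathbf{a},\mathbf{A})\propto \exp\left(-(\mathbf{x}-\mathbf{a})^H\mathbf{A}^{-1}(\mathbf{x}-\mathbf{a})\right)$:
\begin{align}
\mathcal{N}_c(\mathbf{x}|\mathbf{a},\mathbf{A})\mathcal{N}_c(\mathbf{x}|\mathbf{b},\mathbf{B})
&\propto \exp\left(-\mathbf{x}^H(\mathbf{A}^{-1}+\mathbf{B}^{-1})\mathbf{x}+2\text{Re}\left\{\mathbf{x}^H(\mathbf{A}^{-1}\mathbf{a}+\mathbf{B}^{-1}\mathbf{b})\right\}\right)\\
&\propto \mathcal{N}_c(\mathbf{x}|\mathbf{c},\mathbf{C}),\quad \mathbf{C}=(\mathbf{A}^{-1}+\mathbf{B}^{-1})^{-1},\quad \mathbf{c}=\mathbf{C}(\mathbf{A}^{-1}\mathbf{a}+\mathbf{B}^{-1}\mathbf{b})
\end{align} Substituting $\mathbf{a}=\mathbf{m}$, $\mathbf{A}=\text{Diag}(\mathbf{v})$, $\mathbf{b}=(\mathbf{H}^H\mathbf{H})^{-1}\mathbf{H}^H\mathbf{y}$ and $\mathbf{B}=(\sigma^{-2}\mathbf{H}^H\mathbf{H})^{-1}$ gives $\mathbf{A}^{-1}\mathbf{a}=\mathbf{m}\oslash \mathbf{v}$ and $\mathbf{B}^{-1}\mathbf{b}=\sigma^{-2}\mathbf{H}^H\mathbf{y}$. The precision of the product is therefore $\sigma^{-2}\mathbf{H}^H\mathbf{H}+\text{Diag}(\mathbf{1}\oslash \mathbf{v})$ and its precision-weighted mean is $\sigma^{-2}\mathbf{H}^H\mathbf{y}+\mathbf{m}\oslash \mathbf{v}$. The intermediate form requires $\mathbf{H}^H\mathbf{H}$ to be invertible, while the final expressions do not.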

$\underline{\text{Step 3} }$: Compute $q^{\backslash j}(x_j)$
\begin{align}
q^{\backslash j}(x_j)=\frac{q(x_j,\mathbf{y})}{q(x_j)}=\frac{\mathcal{N}_c(x_j|\mu_j,\Sigma_{jj})}{\mathcal{N}_c(x_j|m_j,v_j)}\propto \mathcal{N}_c(x_j|m^{\text{tem} }_j,v^{\text{tem} }_j)
\end{align} where
\begin{align}
v_j^{\text{tem} }&=\left(\frac{1}{\Sigma_{jj} }-\frac{1}{v_j}\right)^{-1}\\
m_j^{\text{tem} }&=v_j^{\text{tem} } \left(\frac{\mu_j}{\Sigma_{jj} }-\frac{m_j}{v_j}\right)
\end{align} $\underline{\text{Step 4} }$: Update $q(x_j,\mathbf{y})$ by minimizing the KL divergence
\begin{align}
q^{\text{new} }(x_j,\mathbf{y})=\underset{q(x_j,\mathbf{y})\in \mathcal{S} }{\arg \min}\ \mathcal{D}_{\text{KL} } \left(\frac{1}{C}p(x_j)q^{\backslash j}(x_j)||q(x_j,\mathbf{y})\right)
\end{align} Thanks to the property of the exponential family, minimizing the KL divergence reduces to moment matching. For ease of notation, we define
\begin{align}
\hat{x}_j&\overset{\triangle}{=}\mathbb{E}\left[x_j|m_j^{\text{tem} },v_j^{\text{tem} }\right]\\
\hat{v}_j&\overset{\triangle}{=}\text{Var}\left[x_j|m_j^{\text{tem} },v_j^{\text{tem} }\right]
\end{align} where the expectation and variance are taken with respect to the approximate posterior distribution
\begin{align}
\hat{p}(x_j|\mathbf{y})=\frac{1}{C}p(x_j)q^{\backslash j}(x_j)=\frac{p(x_j)\mathcal{N}_c(x_j|m_j^{\text{tem} },v_j^{\text{tem} })}{\int p(x_j)\mathcal{N}_c(x_j|m_j^{\text{tem} },v_j^{\text{tem} })\text{d}x_j}
\end{align} For a discrete constellation, the integral reduces to a sum over the constellation points. Accordingly, $q(x_j,\mathbf{y})$ is updated to
\begin{align}
q^{\text{new} }(x_j,\mathbf{y})=\mathcal{N}_c(x_j|\hat{x}_j,\hat{v}_j)
\end{align} Here the superscript ‘new’ distinguishes it from the old $q(x_j,\mathbf{y})$.
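
The moment computation in Step 4 can be sketched as follows. The sketch assumes a uniform prior $p(x_j)$ over a unit-energy 4-QAM (QPSK) constellation and the density $\mathcal{N}_c(x|m,v)\propto \exp(-|x-m|^2/v)$; the constellation, the uniform prior, and the function name are assumptions made for illustration.

```python
import math

# Hypothetical unit-energy 4-QAM (QPSK) constellation with a uniform prior p(x_j).
QPSK = [complex(a, b) / math.sqrt(2) for a in (-1, 1) for b in (-1, 1)]

def constellation_moments(m_tem, v_tem, symbols=QPSK):
    """Mean and variance of p(x) * N_c(x | m_tem, v_tem) with p uniform on `symbols`."""
    # Log-weights up to a constant; subtracting the maximum avoids underflow of all weights.
    logw = [-abs(s - m_tem) ** 2 / v_tem for s in symbols]
    top = max(logw)
    w = [math.exp(lw - top) for lw in logw]
    z = sum(w)
    x_hat = sum(wi * s for wi, s in zip(w, symbols)) / z
    v_hat = sum(wi * abs(s - x_hat) ** 2 for wi, s in zip(w, symbols)) / z
    return x_hat, v_hat

# A cavity mean between two symbols gives a soft estimate with nonzero variance.
print(constellation_moments(0.7 + 0.0j, 0.2))
```

With a very small cavity variance the estimate collapses onto the nearest symbol and $\hat{v}_j$ approaches zero, so practical implementations usually bound $\hat{v}_j$ from below before the division in the next step.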

$\underline{\text{Step 5} }$: Update $q(x_j)$ based on
\begin{align}
q(x_j)\propto \frac{q^{\text{new} }(x_j,\mathbf{y})}{q^{\backslash j}(x_j)}
\end{align} Using the Gaussian product lemma, we get
\begin{align}
v_j&=\left(\frac{1}{\hat{v}_j}-\frac{1}{v_j^{\text{tem} } }\right)^{-1}\\
m_j&=v_j \left(\frac{\hat{x}_j}{\hat{v}_j}-\frac{m_j^{\text{tem} }}{v_j^{\text{tem} }}\right)
\end{align} $\underline{\text{Step 6} }$: Return to Step 2 and repeat until convergence.
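
The Gaussian divisions in Steps 3 and 5 share the same form and can be written as one helper. This is a minimal sketch with illustrative names. Raising an error when the resulting precision is not positive is a choice made here, not something prescribed by the text.

```python
def gaussian_divide(m_num, v_num, m_den, v_den):
    """Parameters (m, v) of N_c(x | m_num, v_num) / N_c(x | m_den, v_den), up to scale."""
    precision = 1.0 / v_num - 1.0 / v_den
    if precision <= 0.0:
        raise ValueError("quotient is not a proper Gaussian (non-positive precision)")
    v = 1.0 / precision
    m = v * (m_num / v_num - m_den / v_den)
    return m, v

def gaussian_multiply(m_a, v_a, m_b, v_b):
    """Parameters (m, v) of N_c(x | m_a, v_a) * N_c(x | m_b, v_b), up to scale."""
    v = 1.0 / (1.0 / v_a + 1.0 / v_b)
    m = v * (m_a / v_a + m_b / v_b)
    return m, v

# Multiplying by a Gaussian and then dividing by it recovers the original parameters.
m_prod, v_prod = gaussian_multiply(0.3 + 0.1j, 0.5, -1.0 + 2.0j, 2.0)
print(gaussian_divide(m_prod, v_prod, -1.0 + 2.0j, 2.0))
```

In Step 3 the numerator is $q(x_j,\mathbf{y})$ and the denominator is $q(x_j)$. In Step 5 the numerator is $q^{\text{new} }(x_j,\mathbf{y})$ and the denominator is $q^{\backslash j}(x_j)$.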

In summary, the EP algorithm for the standard linear model is as follows
\begin{align}
\mathbf{\Sigma}&=(\sigma^{-2}\mathbf{H}^H\mathbf{H}+\text{Diag}(\mathbf{1}\oslash \mathbf{v}))^{-1}\\
\boldsymbol{\mu}&=\mathbf{\Sigma}\left(\sigma^{-2}\mathbf{H}^H\mathbf{y}+\mathbf{m}\oslash \mathbf{v}\right)\\
\tilde{\mathbf{v} }&=\text{diag}(\mathbf{\Sigma})\\
\mathbf{v}^{\text{tem} }&=\mathbf{1}\oslash \left(\mathbf{1}\oslash \tilde{\mathbf{v} }-\mathbf{1}\oslash \mathbf{v}\right)\\
\mathbf{m}^{\text{tem} }&=\mathbf{v}^{\text{tem} }\odot \left(\boldsymbol{\mu}\oslash \tilde{\mathbf{v} } -\mathbf{m}\oslash \mathbf{v}\right)\\
\hat{\mathbf{x} }&=\mathbb{E}\left[\mathbf{x}|\mathbf{m}^{\text{tem} },\mathbf{v}^{\text{tem} }\right]\\
\hat{\mathbf{v} }&=\text{Var}\left[\mathbf{x}|\mathbf{m}^{\text{tem} },\mathbf{v}^{\text{tem} }\right]\\
\mathbf{v}&=\mathbf{1}\oslash (\mathbf{1}\oslash \hat{\mathbf{v} }-\mathbf{1}\oslash \mathbf{v}^{\text{tem} })\\
\mathbf{m}&=\mathbf{v}\odot (\hat{\mathbf{x} }\oslash \hat{\mathbf{v} }-\mathbf{m}^{\text{tem} }\oslash \mathbf{v}^{\text{tem} })
\end{align}
It is interesting to see that EP for the standard linear model is extremely similar to vector approximate message passing (VAMP) [3]. At least, EP for the SLM coincides with VAMP at the pseudo-code level.
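
The summarized iteration can be sketched in NumPy as below. Several parts are assumptions rather than part of the algorithm as stated: the prior is uniform over a unit-energy QPSK constellation, $\mathbf{\Sigma}$ is computed by direct matrix inversion, $\hat{\mathbf{v}}$ and the cavity precision are bounded from below, and entries whose updated precision is not positive keep their previous values. The name `ep_slm` and its parameters are hypothetical.

```python
import numpy as np

def ep_slm(y, H, sigma2, symbols, n_iter=10, v_floor=1e-8):
    """EP iterations for y = H x + w; returns the moment-matched means and variances."""
    n = H.shape[1]
    m = np.zeros(n, dtype=complex)        # means of q(x_j)
    v = np.ones(n)                        # variances of q(x_j), assuming unit-energy symbols
    gram = H.conj().T @ H / sigma2        # sigma^{-2} H^H H
    matched = H.conj().T @ y / sigma2     # sigma^{-2} H^H y
    for _ in range(n_iter):
        # Linear step: Sigma, mu and the marginal variances diag(Sigma)
        Sigma = np.linalg.inv(gram + np.diag(1.0 / v))
        mu = Sigma @ (matched + m / v)
        v_tilde = np.real(np.diag(Sigma))
        # Cavity parameters m_tem, v_tem (precision bounded below as a safeguard)
        cav_prec = np.maximum(1.0 / v_tilde - 1.0 / v, 1e-12)
        v_tem = 1.0 / cav_prec
        m_tem = v_tem * (mu / v_tilde - m / v)
        # Moment matching against the uniform discrete prior
        logw = -np.abs(symbols[None, :] - m_tem[:, None]) ** 2 / v_tem[:, None]
        w = np.exp(logw - logw.max(axis=1, keepdims=True))
        w /= w.sum(axis=1, keepdims=True)
        x_hat = w @ symbols
        v_hat = np.sum(w * np.abs(symbols[None, :] - x_hat[:, None]) ** 2, axis=1)
        v_hat = np.maximum(v_hat, v_floor)
        # Update of q(x_j); entries with non-positive precision keep their old values
        site_prec = 1.0 / v_hat - 1.0 / v_tem
        ok = site_prec > 0
        v_new = np.where(ok, 1.0 / np.where(ok, site_prec, 1.0), v)
        m = np.where(ok, v_new * (x_hat / v_hat - m_tem / v_tem), m)
        v = v_new
    return x_hat, v_hat

# Hypothetical test setup: 4 QPSK symbols, an 8 x 4 random channel, noise variance 0.01.
qpsk = np.array([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j]) / np.sqrt(2)
rng = np.random.default_rng(0)
x_true = rng.choice(qpsk, size=4)
H = (rng.standard_normal((8, 4)) + 1j * rng.standard_normal((8, 4))) / np.sqrt(2)
noise = np.sqrt(0.01 / 2) * (rng.standard_normal(8) + 1j * rng.standard_normal(8))
x_hat, v_hat = ep_slm(H @ x_true + noise, H, 0.01, qpsk)
print(np.round(x_hat, 3))
print(x_true)
```

Hard decisions can be obtained by mapping each entry of `x_hat` to the nearest constellation point. Damping of the updates, which is common in practice, is left out to keep the sketch close to the summarized equations.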

References

[1] https://www.qiuyun-blog.cn/2019/01/03/Variational-Inference-for-Bayesian-Linear-Regression/
[2] Bishop C M. Pattern Recognition and Machine Learning (Information Science and Statistics). 2006.
[3] Rangan S, Schniter P, Fletcher A. Vector Approximate Message Passing. 2016.
[4] Santos I, Murillo-Fuentes J J, Arias-de-Reyna E, et al. Turbo EP-based equalization: A filter-type implementation. IEEE Transactions on Communications, 2018, 66(9): 4259-4270.