Julia Community 🟣

Martin Roa Villescas

The Laplace approximation (part 2)

This is the second post in a two-part series on the Laplace approximation. In part 1, we saw how a well-behaved general function $f(x)$ can be approximated with an unnormalized normal distribution. In this post, I extend this idea by showing how the Laplace approximation is used in the context of Bayesian inference, and I present an example implemented in Julia that compares the Laplace approximation of a probability distribution with its exact counterpart.

Bayesian inference

Bayesian inference is a method for updating the probability of a hypothesis as more information becomes available. At the core of this process lies Bayes' theorem, which transforms the prior distribution of a variable into a posterior distribution in light of new or additional data:

$$
\overbrace{p(\lambda \mid D)}^{\text{posterior}} = \frac{\overbrace{p(D \mid \lambda)}^{\text{likelihood}} \, \overbrace{p(\lambda)}^{\text{prior}}}{\underbrace{p(D)}_{\text{evidence}}},
$$

where $\lambda$ is the variable of interest and $D$ the set of observations. Using the chain rule and the law of total probability, we can rewrite Bayes' theorem as

$$
\overbrace{p(\lambda \mid D)}^{\text{posterior}} = \frac{\overbrace{p(D, \lambda)}^{\text{joint likelihood}}}{\underbrace{\int_{\lambda} p(D, \lambda) \, \mathrm{d}\lambda}_{\text{evidence}}}.
$$

In many practical applications, the integral in the equation above is intractable, i.e., it cannot be solved analytically and is costly to evaluate numerically, especially when $\lambda$ is high-dimensional. Therefore, we have to resort to approximation methods. The Laplace approximation is one of them.
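In one dimension the evidence integral can still be approximated by brute force, which helps to see what is being normalized. The sketch below assumes the Poisson-gamma model from the example later in this post (α = θ = 3, a single observation y = 2, and an integer α so that Γ(α) = (α - 1)!), and uses trapz, a hypothetical trapezoidal-rule helper. The number of grid points grows exponentially with the dimension of λ, which is why approximations such as Laplace's are needed in general.

```julia
# Hypothetical helper: trapezoidal rule on a uniformly spaced range
trapz(xs, ys) = step(xs) * (sum(ys) - (first(ys) + last(ys)) / 2)

# Assumed model (from the example below): λ ~ Gamma(α, θ), y | λ ~ Poisson(λ)
α, θ, y = 3, 3.0, 2
prior(λ) = λ^(α - 1) * exp(-λ / θ) / (θ^α * factorial(α - 1))  # integer α: Γ(α) = (α-1)!
likelihood(λ) = λ^y / factorial(y) * exp(-λ)
joint(λ) = likelihood(λ) * prior(λ)                            # p(D, λ)

λs = range(0, 40, length=4001)        # wide enough to hold essentially all the mass
evidence = trapz(λs, joint.(λs))      # p(D) = ∫ p(D, λ) dλ
posterior = joint.(λs) ./ evidence    # p(λ | D) evaluated on the grid

println("p(D) ≈ ", evidence)
```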

The Laplace approximation

In contrast with part 1, where the goal was to find the Laplace approximation $q(x)$ of a general function $f(x)$, here we want to approximate a probability distribution, which by definition integrates to one over its entire domain. This means that in addition to finding the unnormalized distribution $q(x)$, we also need to calculate its normalizing constant $C$. That is, our goal is to find

$$
p_{\mathcal{L}}(\lambda \mid D) = \frac{1}{C} f(\lambda),
$$

which is nothing more than a disguised form of Bayes' theorem:

$$
p_{\mathcal{L}}(\lambda \mid D) = \frac{p_{\mathcal{L}}(D, \lambda)}{p_{\mathcal{L}}(D)}.
$$

In the following sections, we will derive the equations for $p_{\mathcal{L}}(D, \lambda)$ and $p_{\mathcal{L}}(D)$ according to the Laplace approximation.

Even though the derivation of these equations might seem tedious, the conclusion is succinct and elegant, and is all you need in practice.

Step 1: approximating the joint likelihood $p_{\mathcal{L}}(D, \lambda)$

We start by finding the mode $\lambda_0$ of the joint likelihood $p(D, \lambda)$, or equivalently, of the joint log-likelihood $\ln p(D, \lambda)$, which can be done either analytically or numerically. The latter option is demonstrated later in the example.

Then, we make a second-order Taylor expansion of the joint log-likelihood $\ln p(D, \lambda)$ around the mode $\lambda_0$, as we did in part 1. The procedure is repeated here for completeness.

To simplify the notation, let $f(\lambda) = p(D, \lambda)$ and $h(\lambda) = \ln f(\lambda)$. Consider the second-order Taylor expansion of $h(\lambda)$ around $\lambda_0$:

$$
h(\lambda) \approx h(\lambda_0) + h'(\lambda_0)(\lambda-\lambda_0) + \frac{1}{2}h''(\lambda_0)(\lambda-\lambda_0)^2.
$$

Note that $h'(\lambda_0) = \frac{\mathrm{d}\ln f(\lambda)}{\mathrm{d}\lambda}\Bigr|_{\lambda = \lambda_0} = \frac{f'(\lambda_0)}{f(\lambda_0)} = 0$, since $f'(\lambda_0) = 0$ at the mode, so we have:

$$
h(\lambda) \approx h(\lambda_0) + \frac{1}{2}h''(\lambda_0)(\lambda-\lambda_0)^2.
$$

Applying the exponential function to both sides yields:

$$
\begin{aligned}
\exp \left( h(\lambda) \right) &\approx \exp\left( h(\lambda_0) + \frac{1}{2}h''(\lambda_0)(\lambda-\lambda_0)^2 \right) \\
\exp \left( \ln f(\lambda) \right) &\approx \exp\left( h(\lambda_0) \right) \exp \left( \frac{1}{2}h''(\lambda_0)(\lambda-\lambda_0)^2 \right) \\
f(\lambda) &\approx f(\lambda_0) \exp \left( \frac{1}{2}h''(\lambda_0)(\lambda-\lambda_0)^2 \right),
\end{aligned}
$$

Hence,

$$
p_{\mathcal{L}}(D, \lambda) \approx p(D, \lambda_0) \exp \left( \frac{1}{2}h''(\lambda_0)(\lambda-\lambda_0)^2 \right).
$$
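To see the quadratic approximation at work, the sketch below borrows the Poisson-gamma joint likelihood from the example at the end of this post (α = θ = 3, y = 2). For that model, ln p(D, λ) = (α + y - 1) ln λ - λ(1 + 1/θ) + const, so the mode λ0 = 3 and h''(λ0) = -4/9 follow by hand. The hypothetical function joint_laplace is the Step 1 result; it matches the exact joint likelihood near the mode and drifts away from it in the tails.

```julia
α, θ, y = 3, 3.0, 2
# Exact joint likelihood p(D, λ) = Poisson(y | λ) * Gamma(λ | α, θ), with Γ(α) = (α-1)! for integer α
joint(λ) = (λ^y / factorial(y) * exp(-λ)) * (λ^(α - 1) * exp(-λ / θ) / (θ^α * factorial(α - 1)))

λ₀ = (α + y - 1) / (1 + 1 / θ)   # mode: solves h'(λ) = (α+y-1)/λ - (1 + 1/θ) = 0
h₂ = -(α + y - 1) / λ₀^2         # h''(λ₀), computed analytically

# Step 1 result: p(D, λ) ≈ p(D, λ₀) exp(h''(λ₀)(λ - λ₀)² / 2)
joint_laplace(λ) = joint(λ₀) * exp(0.5 * h₂ * (λ - λ₀)^2)

for λ in (2.5, 3.0, 3.5, 6.0)
    println("λ = $λ  exact = $(joint(λ))  Laplace = $(joint_laplace(λ))")
end
```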

Step 2: approximating the evidence $p_{\mathcal{L}}(D)$

Again, to simplify the notation, let $f(\lambda) = p(D, \lambda)$ and $C = p(D)$. According to the law of total probability, we have that

$$
C = \int_{\lambda} f(\lambda) \, \mathrm{d}\lambda.
$$

Substituting the result that we derived above for $f(\lambda)$, we obtain

$$
\begin{aligned}
C &\approx f(\lambda_0) \int_{\lambda} \exp \left( \frac{1}{2}h''(\lambda_0)(\lambda-\lambda_0)^2 \right) \, \mathrm{d}\lambda \\
&= f(\lambda_0) \, {\color{RoyalBlue} \int_{\lambda} \exp \left( -\frac{(-h''(\lambda_0))}{2}(\lambda-\lambda_0)^2 \right) \, \mathrm{d}\lambda}.
\end{aligned}
$$

Recall that the probability density function (PDF) of a normally distributed random variable $X$ with mean $\mu$ and variance $\sigma^2$, denoted as $X \sim \mathcal{N}(\mu, \sigma^2)$, is given by

$$
p(x) = \frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\frac{1}{2}\frac{(x-\mu)^2}{\sigma^2}\right),
$$

or equivalently by

$$
p(x) = \frac{\sqrt{\gamma}}{\sqrt{2\pi}} \exp \left(-\frac{\gamma}{2}(x-\mu)^2\right),
$$

where the precision is $\gamma = \frac{1}{\sigma^2}$.

Integrating both sides with respect to $x$ yields

$$
\begin{aligned}
\int_x p(x) \, \mathrm{d}x &= \frac{\sqrt{\gamma}}{\sqrt{2\pi}} \int_x \exp \left(-\frac{\gamma}{2}(x-\mu)^2\right) \, \mathrm{d}x \\
1 &= \frac{\sqrt{\gamma}}{\sqrt{2\pi}} \int_x \exp \left(-\frac{\gamma}{2}(x-\mu)^2\right) \, \mathrm{d}x \\
\frac{\sqrt{2\pi}}{\sqrt{\gamma}} &= {\color{RoyalBlue} \int_x \exp \left(-\frac{\gamma}{2}(x-\mu)^2\right) \, \mathrm{d}x}.
\end{aligned}
$$

Comparing the two integrals highlighted in blue above, with $x \to \lambda$, $\mu \to \lambda_0$, and $\gamma \to -h''(\lambda_0)$ (which is positive because $\lambda_0$ is a maximum), we can conclude that

$$
\int_{\lambda} \exp \left( -\frac{(-h''(\lambda_0))}{2}(\lambda-\lambda_0)^2 \right) \, \mathrm{d}\lambda = \frac{\sqrt{2\pi}}{\sqrt{-h''(\lambda_0)}}.
$$
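This Gaussian-integral identity is easy to check numerically. The sketch below (gauss_integral and trapz are hypothetical helpers) integrates the unnormalized Gaussian with the trapezoidal rule over ±20 standard deviations and compares the result with √(2π/γ); the value γ = 4/9 corresponds to -h''(λ0) in the Poisson-gamma example later in the post.

```julia
# Hypothetical helper: trapezoidal rule on a uniformly spaced range
trapz(xs, ys) = step(xs) * (sum(ys) - (first(ys) + last(ys)) / 2)

# ∫ exp(-γ/2 (x - μ)²) dx, truncated to ±20 standard deviations around μ
function gauss_integral(μ, γ)
    σ = 1 / sqrt(γ)
    xs = range(μ - 20σ, μ + 20σ, length=20_001)
    return trapz(xs, exp.(-γ / 2 .* (xs .- μ) .^ 2))
end

for γ in (0.25, 4 / 9, 1.0, 10.0)
    println("γ = $γ: numeric = $(gauss_integral(3.0, γ)), closed form = $(sqrt(2π / γ))")
end
```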

Hence,

$$
\begin{aligned}
C &\approx f(\lambda_0) \frac{\sqrt{2\pi}}{\sqrt{-h''(\lambda_0)}}, \\
p_{\mathcal{L}}(D) &\approx p(D, \lambda_0) \frac{\sqrt{2\pi}}{\sqrt{-h''(\lambda_0)}}.
\end{aligned}
$$
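For the Poisson-gamma model used in the example at the end of this post, the evidence is known in closed form (it is a negative binomial probability), so this approximation can be checked directly. The sketch assumes α = θ = 3, a single observation y = 2, and an integer α so that Γ(α) = (α - 1)! and only Base Julia is needed. The mode and h''(λ0) are derived by hand from ln p(D, λ) = (α + y - 1) ln λ - λ(1 + 1/θ) + const.

```julia
α, θ, y = 3, 3.0, 2
# Exact joint likelihood p(D, λ) = Poisson(y | λ) * Gamma(λ | α, θ)
joint(λ) = (λ^y / factorial(y) * exp(-λ)) * (λ^(α - 1) * exp(-λ / θ) / (θ^α * factorial(α - 1)))

λ₀ = (α + y - 1) / (1 + 1 / θ)   # mode of the joint likelihood
h₂ = -(α + y - 1) / λ₀^2         # h''(λ₀)

# Laplace approximation of the evidence: p(D, λ₀) √(2π) / √(-h''(λ₀))
evidence_laplace = joint(λ₀) * sqrt(2π) / sqrt(-h₂)
# Exact evidence: Γ(α+y) / (y! Γ(α)) * θ^y / (1+θ)^(α+y)
evidence_exact = factorial(α + y - 1) / (factorial(y) * factorial(α - 1)) * θ^y / (1 + θ)^(α + y)

println("Laplace: ", evidence_laplace, "  exact: ", evidence_exact)
```

The two values agree to within a few percent. Applying the Laplace method to the gamma integral hidden in the evidence reproduces Stirling's formula for Γ(α + y), so the remaining gap is exactly Stirling's approximation error.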

Step 3: bringing it all together

Recall that we seek to approximate the posterior distribution $p(\lambda \mid D)$ with its Laplace approximation $p_{\mathcal{L}}(\lambda \mid D) = \frac{p_{\mathcal{L}}(D, \lambda)}{p_{\mathcal{L}}(D)}$.

Using a second-order Taylor expansion, we derived that

$$
p_{\mathcal{L}}(D, \lambda) \approx p(D, \lambda_0) \exp \left( \frac{1}{2}h''(\lambda_0)(\lambda-\lambda_0)^2 \right).
$$

We also showed that the normalizing constant is approximately

$$
p_{\mathcal{L}}(D) \approx p(D, \lambda_0) \frac{\sqrt{2\pi}}{\sqrt{-h''(\lambda_0)}}.
$$

Substituting these two results into $p_{\mathcal{L}}(\lambda \mid D) = \frac{p_{\mathcal{L}}(D, \lambda)}{p_{\mathcal{L}}(D)}$ and rearranging, we obtain

$$
\begin{aligned}
p_\mathcal{L}(\lambda \mid D) &= \frac{\sqrt{-h''(\lambda_0)}}{p(D, \lambda_0)\sqrt{2\pi}} \, p(D, \lambda_0) \exp \left( \frac{1}{2} h''(\lambda_0)(\lambda - \lambda_0)^2 \right) \\
&= \frac{1}{\sqrt{2\pi}\sqrt{(-h''(\lambda_0))^{-1}}} \exp \left( -\frac{1}{2} \frac{(\lambda - \lambda_0)^2}{(-h''(\lambda_0))^{-1}} \right),
\end{aligned}
$$

or more succinctly

$$
\boxed{p_\mathcal{L}(\lambda \mid D) = \mathcal{N} \left( \lambda_0, (-h''(\lambda_0))^{-1} \right)} \, ,
$$

where

$$
\boxed{h''(\lambda_0) = \frac{\mathrm{d}^2 \ln p(D, \lambda)}{\mathrm{d}\lambda^2}\Bigr|_{\lambda = \lambda_0}} \, .
$$

In conclusion, in the context of Bayesian inference, the Laplace approximation of a posterior distribution is a normal distribution centered at the mode $\lambda_0$ of the joint likelihood (which is also the mode of the posterior, since the two differ only by the constant factor $p(D)$), with a precision equal to minus the second derivative of the joint log-likelihood evaluated at that mode.
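The boxed result translates into a few lines of code. The sketch below uses only Base Julia: it swaps in Newton's method for the mode and central finite differences for h' and h'' (instead of Optim.jl and ForwardDiff.jl, which the example below uses), and laplace_approximation is a hypothetical helper name. It assumes a smooth, unimodal log-joint and a starting point close enough to the mode for Newton's method to converge.

```julia
# Hypothetical helper: Laplace approximation of a 1-D posterior from its log-joint h(λ) = ln p(D, λ).
# Newton's method finds the mode; central finite differences stand in for h' and h''.
function laplace_approximation(h, λ_init; δ=1e-4, iters=50)
    d1(λ) = (h(λ + δ) - h(λ - δ)) / (2δ)          # ≈ h'(λ)
    d2(λ) = (h(λ + δ) - 2h(λ) + h(λ - δ)) / δ^2   # ≈ h''(λ)
    λ₀ = float(λ_init)
    for _ in 1:iters
        Δ = d1(λ₀) / d2(λ₀)      # Newton step for solving h'(λ) = 0
        λ₀ -= Δ
        abs(Δ) < 1e-10 && break
    end
    return λ₀, -1 / d2(λ₀)       # mean λ₀ and variance (-h''(λ₀))⁻¹
end

# Usage with the Poisson-gamma model from the example below (α = θ = 3, y = 2)
α, θ, y = 3, 3.0, 2
logjoint(λ) = (α + y - 1) * log(λ) - λ * (1 + 1 / θ) -
              log(factorial(y)) - α * log(θ) - log(factorial(α - 1))
μ, σ² = laplace_approximation(logjoint, 1.0)
println("p_L(λ | D) = N(", μ, ", ", σ², ")")
```

For this model the result should be close to λ0 = 3 and a variance of 9/4, i.e., h''(λ0) = -4/9, which can also be derived by hand.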

Example: Poisson data with a gamma prior

Let's look at a simple model that has an exact solution. This will allow us to compare the Laplace approximation of a posterior distribution with its exact counterpart. Our model assumes that the data are drawn from a Poisson distribution with a gamma-distributed rate parameter $\lambda$. That is, the model is defined as

$$
\begin{aligned}
Y &\sim \text{Poisson}(\lambda) \\
\lambda &\sim \text{Gamma}(\alpha, \theta),
\end{aligned}
$$

where $\alpha$ and $\theta$ are the shape and scale parameters of the gamma distribution. The objective is to calculate $p(\lambda \mid Y)$. Because the gamma prior is conjugate to the Poisson likelihood, the exact solution is given by

$$
\lambda \mid \{y_i\} \sim \text{Gamma}\left(\alpha + \sum_{i=1}^n y_i, \frac{\theta}{n\theta + 1} \right),
$$

where $\{y_i\}$ is the set of $n$ observed values.
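The conjugate update is a one-liner. The sketch below (posterior_params is a hypothetical helper name) returns the shape and scale of the exact posterior for any number of observations, using the same shape-scale parameterization as the Gamma function defined in the code that follows.

```julia
# Hypothetical helper: exact posterior of the Poisson-gamma model (shape α, scale θ),
# λ | y_1, ..., y_n ~ Gamma(α + Σ y_i, θ / (nθ + 1))
posterior_params(α, θ, ys) = (α + sum(ys), θ / (length(ys) * θ + 1))

k, s = posterior_params(3, 3.0, [2])   # the single observation used in the example below
println("posterior: Gamma(", k, ", ", s, "), mean = ", k * s, ", mode = ", (k - 1) * s)
```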

We start by importing the necessary packages and by setting some default options for the plotting package:

```julia
using Plots, LaTeXStrings         # used to visualize the result
using SpecialFunctions: gamma     # used to define the gamma PDF
using Optim: maximize, maximizer  # used to find x₀
using ForwardDiff: derivative     # used to find h''(x₀)
default(fill=true, fillalpha=0.3, xlabel="λ") # plot defaults
```

Next, we define the normal, Poisson, and gamma PDFs:

```julia
# Each PDF is a closure: fix the parameters, get back a function of x (or λ)
Normal(μ, σ²) = x -> (1/(sqrt(2π*σ²))) * exp(-(x-μ)^2/(2*σ²))   # mean μ, variance σ²
Poisson(Y::Int) = λ -> (λ^Y) / factorial(Y) * exp(-λ)           # P(Y | rate λ)
Gamma(α, θ) = λ -> (λ^(α-1) * exp(-λ/θ)) / (θ^(α) * gamma(α))   # shape α, scale θ
```
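A quick sanity check of these definitions is that each PDF integrates (or, for the Poisson, sums) to one. To keep the sketch dependency-free, it replaces SpecialFunctions.gamma with factorial, which assumes an integer shape α; GammaInt and trapz are hypothetical names.

```julia
# Same PDFs as above, but Γ(α) = (α-1)! replaces SpecialFunctions.gamma (integer α only)
Normal(μ, σ²) = x -> (1 / sqrt(2π * σ²)) * exp(-(x - μ)^2 / (2 * σ²))
Poisson(Y::Int) = λ -> λ^Y / factorial(Y) * exp(-λ)
GammaInt(α::Int, θ) = λ -> λ^(α - 1) * exp(-λ / θ) / (θ^α * factorial(α - 1))

# Hypothetical helper: trapezoidal rule on a uniformly spaced range
trapz(xs, ys) = step(xs) * (sum(ys) - (first(ys) + last(ys)) / 2)

normal_pdf = Normal(0.0, 1.0)
gamma_pdf = GammaInt(3, 3.0)
xs = range(-20, 20, length=40_001)
λs = range(0, 60, length=60_001)

total_normal = trapz(xs, normal_pdf.(xs))            # ∫ N(x | 0, 1) dx
total_gamma = trapz(λs, gamma_pdf.(λs))              # ∫ Gamma(λ | 3, 3) dλ (tail beyond 60 is negligible)
total_poisson = sum(Poisson(Y)(3.0) for Y in 0:20)   # Σ_Y Poisson(Y | λ = 3); 20! still fits in Int64
println((total_normal, total_gamma, total_poisson))
```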

The next step is to specify the model. To do so, we need to define the prior $p(\lambda)$ and the likelihood $p(Y \mid \lambda)$ according to the specification above. For the prior, we set both hyperparameters $\alpha$ and $\theta$ of the gamma distribution to 3. For the likelihood, we use a Poisson distribution with a placeholder $Y$ through which we can enter observed values. The model is now fully specified. Finally, we define the joint log-likelihood $\ln p(D, \lambda)$, i.e., the log of the product of the likelihood and the prior, which is the quantity needed by the Laplace approximation:

```julia
prior = Gamma(3, 3)            # p(λ) with α = 3, θ = 3
likelihood(Y) = Poisson(Y)     # p(Y | λ) for an observed count Y
joint_log_likelihood(Y) = λ -> log(likelihood(Y)(λ) * prior(λ))  # ln p(Y, λ)
```

Suppose we observe a single value $y = 2$ (so $n = 1$). The exact posterior is then given by:

```julia
y = 2  # a single observation, so n = 1
# prior.α and prior.θ read the values (both 3) captured by the prior closure
exact_posterior = Gamma(prior.α + y, prior.θ / (prior.θ + 1))
```

We can visualize the prior and exact posterior:

```julia
λ = range(0, 12, length=500)
plot(λ, prior, lab=L"p(\lambda)")
plot!(λ, exact_posterior, lab=L"p(\lambda|D)")
```

Example of the prior and exact posterior distributions

Let's now find the Laplace approximation of the exact posterior using the results we derived. The first step is to find the mode $\lambda_0$ of the joint log-likelihood $\ln p(Y, \lambda)$. We can do so with the maximize function from the Optim.jl package, which takes the function to be maximized and the endpoints of the interval over which to search. It returns a results object from which we can extract the mode with the maximizer function.

```julia
λ₀ = maximize(joint_log_likelihood(y), first(λ), last(λ)) |> maximizer
```
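Optim.jl's univariate solvers include Brent's method (the default) and golden-section search. If you would rather avoid the dependency, the sketch below implements a plain golden-section search for a unimodal objective on a bracketing interval. It is a stand-in, not Optim.jl's implementation; golden_max is a hypothetical name, and the log-joint drops additive constants, which does not move the maximizer.

```julia
# Hypothetical helper: golden-section search for the maximizer of a unimodal f on [a, b]
function golden_max(f, a, b; tol=1e-10)
    ϕ = (sqrt(5) - 1) / 2                  # inverse golden ratio ≈ 0.618
    c, d = b - ϕ * (b - a), a + ϕ * (b - a)
    fc, fd = f(c), f(d)
    while b - a > tol
        if fc > fd                         # maximizer lies in [a, d]
            b, d, fd = d, c, fc
            c = b - ϕ * (b - a)
            fc = f(c)
        else                               # maximizer lies in [c, b]
            a, c, fc = c, d, fd
            d = a + ϕ * (b - a)
            fd = f(d)
        end
    end
    return (a + b) / 2
end

# ln p(D, λ) for α = θ = 3, y = 2, dropping additive constants
α, θ, y = 3, 3.0, 2
logjoint(λ) = (α + y - 1) * log(λ) - λ * (1 + 1 / θ)
λ₀ = golden_max(logjoint, 0.0, 12.0)
println("λ₀ ≈ ", λ₀)
```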

The final step is to find $h''(\lambda_0)$, i.e., the second derivative of the joint log-likelihood $\ln p(Y, \lambda)$ at its mode $\lambda_0$. We apply the derivative function from the ForwardDiff.jl package twice:

```julia
# h''(λ): differentiate the joint log-likelihood twice with forward-mode AD
hꜛꜛ(λ) = derivative(λ -> derivative(joint_log_likelihood(y), λ), λ)
```
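ForwardDiff.jl computes h''(λ0) exactly up to floating-point error. A central finite difference is a dependency-free way to double-check it; the step size δ = 1e-4 is an assumption that balances truncation and round-off error, and second_derivative is a hypothetical helper. For α = θ = 3 and y = 2, the analytical value is h''(λ0) = -(α + y - 1)/λ0² = -4/9 at λ0 = 3.

```julia
# Hypothetical helper: central finite-difference estimate of h''(λ)
# (δ trades truncation error against floating-point round-off)
second_derivative(h, λ; δ=1e-4) = (h(λ + δ) - 2h(λ) + h(λ - δ)) / δ^2

α, θ, y = 3, 3.0, 2
logjoint(λ) = (α + y - 1) * log(λ) - λ * (1 + 1 / θ)  # ln p(D, λ) up to a constant
λ₀ = 3.0                                             # mode for α = θ = 3, y = 2
h₂ = second_derivative(logjoint, λ₀)                 # analytically -(α + y - 1)/λ₀² = -4/9
variance = -1 / h₂                                    # Laplace variance (-h''(λ₀))⁻¹
println("h''(λ₀) ≈ ", h₂, ", variance ≈ ", variance)
```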

Finally, we bring together the pieces to define the Laplace approximation:

```julia
# N(λ₀, (-h''(λ₀))⁻¹); the second argument of Normal is the variance σ²
laplace_approx = Normal(λ₀, -1 / hꜛꜛ(λ₀))
```

We can now compare the Laplace approximation to the exact posterior:

```julia
plot!(λ, laplace_approx, lab=L"p_{\mathcal{L}}(\lambda|D)", title=L"\textrm{Laplace~approximation}")
```
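The plot shows that the Laplace approximation is centered at the mode rather than the mean of the right-skewed gamma posterior. One way to quantify this, assuming the same α = θ = 3 and y = 2, is to compare the first two moments in closed form: a Gamma(k, s) with shape k and scale s has mean ks, variance ks², and mode (k - 1)s, and its log-density has h''(λ0) = -(k - 1)/λ0².

```julia
# Exact posterior for α = θ = 3 and one observation y = 2: Gamma(k, s) with shape k, scale s
α, θ, y = 3, 3.0, 2
k, s = α + y, θ / (θ + 1)

exact_mean, exact_var = k * s, k * s^2   # moments of Gamma(k, s)
laplace_mean = (k - 1) * s                # Laplace mean = posterior mode λ₀
laplace_var = (k - 1) * s^2               # (-h''(λ₀))⁻¹ = λ₀² / (k - 1)

println("mean:     exact = ", exact_mean, ", Laplace = ", laplace_mean)
println("variance: exact = ", exact_var, ", Laplace = ", laplace_var)
```

Both the mean and the variance of the Laplace approximation are a factor (k - 1)/k of the exact ones, so the gap closes as more observations increase the shape k = α + Σyᵢ.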

Example of the posterior Laplace approximation

A few remarks to conclude this post. The Laplace approximation replaces a problem of integration with one of optimization. This matters because intractable integrals are common in practical applications. Moreover, even when the integral can be solved, optimization is often faster to compute. However, the Laplace approximation has several limitations: if the posterior distribution being approximated is not unimodal, does not have its mass concentrated around the mode, or is not twice differentiable, the result can be inaccurate. This is worrying because in many applications we don't have access to the exact posterior, which means that we have no reference against which to test the accuracy of the approximation. In any case, the Laplace approximation has proven effective not only in Bayesian inference but also in many other mathematical domains that involve hard-to-compute integrals.


Credit to @OswaldoGressani for his awesome YouTube video
