Revision f20a1c968b2e0d2ec88319a2a17116793ee6fe4c authored by Peter Hawkins on 07 April 2020, 15:21:39 UTC, committed by GitHub on 07 April 2020, 15:21:39 UTC
1 parent fa383b4
Raw File
advi.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Automatic differentiation variational inference in Numpy and JAX.

This demo fits a Gaussian approximation to an intractable, unnormalized
density, by differentiating through a Monte Carlo estimate of the
variational evidence lower bound (ELBO)."""


from functools import partial
import matplotlib.pyplot as plt

from jax.api import jit, grad, vmap
from jax import random
from jax.experimental import optimizers
import jax.numpy as np
import jax.scipy.stats.norm as norm


# ========= Functions to define the evidence lower bound. =========

def diag_gaussian_sample(rng, mean, log_std):
    # Take a single sample from a diagonal multivariate Gaussian.
    return mean + np.exp(log_std) * random.normal(rng, mean.shape)

def diag_gaussian_logpdf(x, mean, log_std):
    # Evaluate a single point on a diagonal multivariate Gaussian.
    return np.sum(vmap(norm.logpdf)(x, mean, np.exp(log_std)))

def elbo(logprob, rng, mean, log_std):
    # Single-sample Monte Carlo estimate of the variational lower bound.
    sample = diag_gaussian_sample(rng, mean, log_std)
    return logprob(sample) - diag_gaussian_logpdf(sample, mean, log_std)

def batch_elbo(logprob, rng, params, num_samples):
    # Average over a batch of random samples.
    rngs = random.split(rng, num_samples)
    vectorized_elbo = vmap(partial(elbo, logprob), in_axes=(0, None, None))
    return np.mean(vectorized_elbo(rngs, *params))


# ========= Helper function for plotting. =========

@partial(jit, static_argnums=(0, 1, 2, 4))
def _mesh_eval(func, x_limits, y_limits, params, num_ticks):
    # Evaluate func on a 2D grid defined by x_limits and y_limits.
    x = np.linspace(*x_limits, num=num_ticks)
    y = np.linspace(*y_limits, num=num_ticks)
    X, Y = np.meshgrid(x, y)
    xy_vec = np.stack([X.ravel(), Y.ravel()]).T
    zs = vmap(func, in_axes=(0, None))(xy_vec, params)
    return X, Y, zs.reshape(X.shape)

def mesh_eval(func, x_limits, y_limits, params, num_ticks=101):
    return _mesh_eval(func, x_limits, y_limits, params, num_ticks)

# ========= Define an intractable unnormalized density =========

def funnel_log_density(params):
    return norm.logpdf(params[0], 0, np.exp(params[1])) + \
           norm.logpdf(params[1], 0, 1.35)


if __name__ == "__main__":
    num_samples = 40

    @jit
    def objective(params, t):
        rng = random.PRNGKey(t)
        return -batch_elbo(funnel_log_density, rng, params, num_samples)

    # Set up figure.
    fig = plt.figure(figsize=(8,8), facecolor='white')
    ax = fig.add_subplot(111, frameon=False)
    plt.ion()
    plt.show(block=False)
    x_limits = [-2, 2]
    y_limits = [-4, 2]
    target_dist = lambda x, _: np.exp(funnel_log_density(x))
    approx_dist = lambda x, params: np.exp(diag_gaussian_logpdf(x, *params))

    def callback(params, t):
        print("Iteration {} lower bound {}".format(t, objective(params, t)))

        plt.cla()
        X, Y, Z = mesh_eval(target_dist, x_limits, y_limits, 1)
        ax.contour(X, Y, Z, cmap='summer')
        X, Y, Z = mesh_eval(approx_dist, x_limits, y_limits, params)
        ax.contour(X, Y, Z, cmap='winter')
        ax.set_xlim(x_limits)
        ax.set_ylim(y_limits)
        ax.set_yticks([])
        ax.set_xticks([])

        # Plot random samples from variational distribution.
        # Here we clone the rng used in computing the objective
        # so that we can show exactly the same samples.
        rngs = random.split(random.PRNGKey(t), num_samples)
        samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params)
        ax.plot(samples[:, 0], samples[:, 1], 'b.')

        plt.draw()
        plt.pause(1.0/60.0)


    # Set up optimizer.
    D = 2
    init_mean = np.zeros(D)
    init_std  = np.zeros(D)
    init_params = (init_mean, init_std)
    opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
    opt_state = opt_init(init_params)

    @jit
    def update(i, opt_state):
        params = get_params(opt_state)
        gradient = grad(objective)(params, i)
        return opt_update(i, gradient, opt_state)


    # Main loop.
    print("Optimizing variational parameters...")
    for t in range(100):
        opt_state = update(t, opt_state)
        params = get_params(opt_state)
        callback(params, t)
    plt.show(block=True)
back to top