https://github.com/pymc-devs/pymc3
Raw File
Tip revision: 02ef44f197cc55fc6aafa25898819ad0828b17f4 authored by Chris Fonnesbeck on 03 October 2016, 19:36:29 UTC
Removed interactive_sampling (#1423)
Tip revision: 02ef44f
plots.py
import numpy as np
from scipy.stats import kde, mode
from numpy.linalg import LinAlgError
import matplotlib.pyplot as plt
import pymc3 as pm
from .stats import quantiles, hpd

__all__ = ['traceplot', 'kdeplot', 'kde2plot',
           'forestplot', 'autocorrplot', 'plot_posterior']


def traceplot(trace, varnames=None, transform=lambda x: x, figsize=None,
              lines=None, combined=False, plot_transformed=False, grid=False,
              alpha=0.35, priors=None, prior_alpha=1, prior_style='--',
              ax=None):
    """Plot samples histograms and values

    Parameters
    ----------

    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted
    transform : callable
        Function to transform data (defaults to identity)
    figsize : figure size tuple
        If None, size is (12, num of variables * 2) inch
    lines : dict
        Dictionary of variable name / value  to be overplotted as vertical
        lines to the posteriors and horizontal lines on sample values
        e.g. mean of posteriors, true values of a simulation
    combined : bool
        Flag for combining multiple chains into a single chain. If False
        (default), chains will be plotted separately.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    grid : bool
        Flag for adding gridlines to histogram. Defaults to True.
    alpha : float
        Alpha value for plot line. Defaults to 0.35.
    priors : iterable of PyMC distributions
        PyMC prior distribution(s) to be plotted alongside poterior. Defaults
        to None (no prior plots).
    prior_alpha : float
        Alpha value for prior plot. Defaults to 1.
    prior_style : str
        Line style for prior plot. Defaults to '--' (dashed line).
    ax : axes
        Matplotlib axes. Defaults to None.

    Returns
    -------

    ax : matplotlib axes

    """

    if varnames is None:
        if plot_transformed:
            varnames = [name for name in trace.varnames]
        else:
            varnames = [name for name in trace.varnames if not name.endswith('_')]

    n = len(varnames)

    if figsize is None:
        figsize = (12, n * 2)

    if ax is None:
        fig, ax = plt.subplots(n, 2, squeeze=False, figsize=figsize)
    elif ax.shape != (n, 2):
        pm._log.warning('traceplot requires n*2 subplots')
        return None

    for i, v in enumerate(varnames):
        prior = None
        if priors is not None:
            prior = priors[i]
        for d in trace.get_values(v, combine=combined, squeeze=False):
            d = np.squeeze(transform(d))
            d = make_2d(d)
            if d.dtype.kind == 'i':
                histplot_op(ax[i, 0], d, alpha=alpha)
            else:
                kdeplot_op(ax[i, 0], d, prior, prior_alpha, prior_style)
            ax[i, 0].set_title(str(v))
            ax[i, 0].grid(grid)
            ax[i, 1].set_title(str(v))
            ax[i, 1].plot(d, alpha=alpha)

            ax[i, 0].set_ylabel("Frequency")
            ax[i, 1].set_ylabel("Sample value")

            if lines:
                try:
                    ax[i, 0].axvline(x=lines[v], color="r", lw=1.5)
                    ax[i, 1].axhline(y=lines[v], color="r",
                                     lw=1.5, alpha=alpha)
                except KeyError:
                    pass

    plt.tight_layout()
    return ax


def histplot_op(ax, data, alpha=.35):
    for i in range(data.shape[1]):
        d = data[:, i]

        mind = np.min(d)
        maxd = np.max(d)
        step = max((maxd - mind) // 100, 1)
        ax.hist(d, bins=range(mind, maxd + 2, step), alpha=alpha, align='left')
        ax.set_xlim(mind - .5, maxd + .5)


def kdeplot_op(ax, data, prior=None, prior_alpha=1, prior_style='--'):
    errored = []
    for i in range(data.shape[1]):
        d = data[:, i]
        try:
            density = kde.gaussian_kde(d)
            l = np.min(d)
            u = np.max(d)
            x = np.linspace(0, 1, 100) * (u - l) + l

            if prior is not None:
                p = prior.logp(x).eval()
                ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style)

            ax.plot(x, density(x))

        except LinAlgError:
            errored.append(i)

    if errored:
        ax.text(.27, .47, 'WARNING: KDE plot failed for: ' + str(errored), style='italic',
                bbox={'facecolor': 'red', 'alpha': 0.5, 'pad': 10})


def make_2d(a):
    """Ravel the dimensions after the first.
    """
    a = np.atleast_2d(a.T).T
    # flatten out dimensions beyond the first
    n = a.shape[0]
    newshape = np.product(a.shape[1:]).astype(int)
    a = a.reshape((n, newshape), order='F')
    return a


def kde2plot_op(ax, x, y, grid=200):
    xmin = x.min()
    xmax = x.max()
    ymin = y.min()
    ymax = y.max()

    grid = grid * 1j
    X, Y = np.mgrid[xmin:xmax:grid, ymin:ymax:grid]
    positions = np.vstack([X.ravel(), Y.ravel()])
    values = np.vstack([x, y])
    kernel = kde.gaussian_kde(values)
    Z = np.reshape(kernel(positions).T, X.shape)

    ax.imshow(np.rot90(Z), cmap=plt.cm.gist_earth_r,
              extent=[xmin, xmax, ymin, ymax])


def kdeplot(data, ax=None):
    if ax is None:
        f, ax = plt.subplots(1, 1, squeeze=True)
    kdeplot_op(ax, data)
    return ax


def kde2plot(x, y, grid=200, ax=None):
    if ax is None:
        f, ax = plt.subplots(1, 1, squeeze=True)
    kde2plot_op(ax, x, y, grid)
    return ax


def autocorrplot(trace, varnames=None, max_lag=100, burn=0, plot_transformed=False,
                 symmetric_plot=False, ax=None, figsize=None):
    """Bar plot of the autocorrelation function for a trace

    Parameters
    ----------
    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted.
        Vector-value stochastics are handled automatically.
    max_lag : int
        Maximum lag to calculate autocorrelation. Defaults to 100.
    burn : int
        Number of samples to discard from the beginning of the trace.
        Defaults to 0.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    symmetric_plot : boolean
        Plot from either [0, +lag] or [-lag, lag]. Defaults to False, [-, +lag].
    ax : axes
        Matplotlib axes. Defaults to None.
    figsize : figure size tuple
        If None, size is (12, num of variables * 2) inches.
        Note this is not used if ax is supplied.

    Returns
    -------
    ax : matplotlib axes
    """

    def _handle_array_varnames(varname):
        if trace[0][varname].__class__ is np.ndarray:
            k = trace[varname].shape[1]
            for i in range(k):
                yield varname + '_{0}'.format(i)
        else:
            yield varname

    if varnames is None:
        if plot_transformed:
            varnames = [name for name in trace.varnames]
        else:
            varnames = [name for name in trace.varnames if not name.endswith('_')]

    varnames = [item for sub in [[i for i in _handle_array_varnames(v)]
                                 for v in varnames] for item in sub]

    nchains = trace.nchains

    if figsize is None:
        figsize = (12, len(varnames) * 2)

    if ax is None:
        fig, ax = plt.subplots(len(varnames), nchains, squeeze=False,
                               sharex=True, sharey=True, figsize=figsize)
    elif ax.shape != (len(varnames), nchains):
        raise ValueError('autocorrplot requires {}*{} subplots'.format(
            len(varnames), nchains))
        return None

    max_lag = min(len(trace) - 1, max_lag)

    for i, v in enumerate(varnames):
        for j in range(nchains):
            try:
                d = np.squeeze(trace.get_values(v, chains=[j], burn=burn,
                                                combine=False))
            except KeyError:
                k = int(v.split('_')[-1])
                v_use = '_'.join(v.split('_')[:-1])
                d = np.squeeze(trace.get_values(v_use, chains=[j],
                                                burn=burn, combine=False)[:, k])

            ax[i, j].acorr(d, detrend=plt.mlab.detrend_mean, maxlags=max_lag)

            if not j:
                ax[i, j].set_ylabel("correlation")
            if i == len(varnames) - 1:
                ax[i, j].set_xlabel("lag")

            ax[i, j].set_title(v)

            if not symmetric_plot:
                ax[i, j].set_xlim(0, max_lag)

            if nchains > 1:
                ax[i, j].set_title("chain {0}".format(j + 1))

    return ax


def var_str(name, shape):
    """Return a sequence of strings naming the element of the tallyable object.
    This is a support function for forestplot.

    :Example:
    >>> var_str('theta', (4,))
    ['theta[1]', 'theta[2]', 'theta[3]', 'theta[4]']

    """

    size = np.prod(shape)
    ind = (np.indices(shape) + 1).reshape(-1, size)
    names = ['[' + ','.join(map(str, i)) + ']' for i in zip(*ind)]
    # if len(name)>12:
    #     name = '\n'.join(name.split('_'))
    #     name += '\n'
    names[0] = '%s %s' % (name, names[0])
    return names


def forestplot(trace_obj, varnames=None, transform=lambda x: x, alpha=0.05, quartiles=True,
               rhat=True, main=None, xtitle=None, xrange=None, ylabels=None,
               chain_spacing=0.05, vline=0, gs=None):
    """ Forest plot (model summary plot)

    Generates a "forest plot" of 100*(1-alpha)% credible intervals for either
    the set of variables in a given model, or a specified set of nodes.

    :Arguments:
        trace_obj: NpTrace or MultiTrace object
            Trace(s) from an MCMC sample.

        varnames: list
            List of variables to plot (defaults to None, which results in all
            variables plotted).

        transform : callable
            Function to transform data (defaults to identity)

        alpha (optional): float
            Alpha value for (1-alpha)*100% credible intervals (defaults to
            0.05).

        quartiles (optional): bool
            Flag for plotting the interquartile range, in addition to the
            (1-alpha)*100% intervals (defaults to True).

        rhat (optional): bool
            Flag for plotting Gelman-Rubin statistics. Requires 2 or more
            chains (defaults to True).

        main (optional): string
            Title for main plot. Passing False results in titles being
            suppressed; passing None (default) results in default titles.

        xtitle (optional): string
            Label for x-axis. Defaults to no label

        xrange (optional): list or tuple
            Range for x-axis. Defaults to matplotlib's best guess.

        ylabels (optional): list or array
            User-defined labels for each variable. If not provided, the node
            __name__ attributes are used.

        chain_spacing (optional): float
            Plot spacing between chains (defaults to 0.05).

        vline (optional): numeric
            Location of vertical reference line (defaults to 0).

        gs : GridSpec
            Matplotlib GridSpec object. Defaults to None.

        Returns
        -------

        gs : matplotlib GridSpec

    """
    from matplotlib import gridspec

    # Quantiles to be calculated
    qlist = [100 * alpha / 2, 50, 100 * (1 - alpha / 2)]
    if quartiles:
        qlist = [100 * alpha / 2, 25, 50, 75, 100 * (1 - alpha / 2)]

    # Range for x-axis
    plotrange = None

    # Number of chains
    chains = None

    # Subplots
    interval_plot = None
    rhat_plot = None

    nchains = trace_obj.nchains
    if nchains > 1:
        from .diagnostics import gelman_rubin

        R = gelman_rubin(trace_obj)
        if varnames is not None:
            R = {v: R[v] for v in varnames}
    else:
        # Can't calculate Gelman-Rubin with a single trace
        rhat = False

    if varnames is None:
        varnames = trace_obj.varnames

    # Empty list for y-axis labels
    labels = []

    if gs is None:
        # Initialize plot
        if rhat and nchains > 1:
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])

        else:

            gs = gridspec.GridSpec(1, 1)

        # Subplot for confidence intervals
        interval_plot = plt.subplot(gs[0])

    trace_quantiles = quantiles(
        trace_obj, qlist, transform=transform, squeeze=False)
    hpd_intervals = hpd(trace_obj, alpha, transform=transform, squeeze=False)

    for j, chain in enumerate(trace_obj.chains):
        # Counter for current variable
        var = 1
        for varname in varnames:
            var_quantiles = trace_quantiles[chain][varname]

            quants = [var_quantiles[v] for v in qlist]
            var_hpd = hpd_intervals[chain][varname].T

            # Substitute HPD interval for quantile
            quants[0] = var_hpd[0].T
            quants[-1] = var_hpd[1].T

            # Ensure x-axis contains range of current interval
            if plotrange:
                plotrange = [min(
                             plotrange[0],
                             np.min(quants)),
                             max(plotrange[1],
                                 np.max(quants))]
            else:
                plotrange = [np.min(quants), np.max(quants)]

            # Number of elements in current variable
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            # Append variable name(s) to list
            if not j:
                if k > 1:
                    names = var_str(varname, np.shape(value))
                    labels += names
                else:
                    labels.append(varname)
                    # labels.append('\n'.join(varname.split('_')))

            # Add spacing for each chain, if more than one
            e = [0] + [(chain_spacing * ((i + 2) / 2)) *
                       (-1) ** i for i in range(nchains - 1)]

            # Deal with multivariate nodes
            if k > 1:

                for i, q in enumerate(np.transpose(quants).squeeze()):

                    # Y coordinate with jitter
                    y = -(var + i) + e[j]

                    if quartiles:
                        # Plot median
                        plt.plot(q[2], y, 'bo', markersize=4)
                        # Plot quartile interval
                        plt.errorbar(
                            x=(q[1],
                                q[3]),
                            y=(y,
                                y),
                            linewidth=2,
                            color='b')

                    else:
                        # Plot median
                        plt.plot(q[1], y, 'bo', markersize=4)

                    # Plot outer interval
                    plt.errorbar(
                        x=(q[0],
                            q[-1]),
                        y=(y,
                            y),
                        linewidth=1,
                        color='b')

            else:

                # Y coordinate with jitter
                y = -var + e[j]

                if quartiles:
                    # Plot median
                    plt.plot(quants[2], y, 'bo', markersize=4)
                    # Plot quartile interval
                    plt.errorbar(
                        x=(quants[1],
                            quants[3]),
                        y=(y,
                            y),
                        linewidth=2,
                        color='b')
                else:
                    # Plot median
                    plt.plot(quants[1], y, 'bo', markersize=4)

                # Plot outer interval
                plt.errorbar(
                    x=(quants[0],
                        quants[-1]),
                    y=(y,
                        y),
                    linewidth=1,
                    color='b')

            # Increment index
            var += k

    labels = ylabels if ylabels is not None else labels

    # Update margins
    left_margin = np.max([len(x) for x in labels]) * 0.015
    gs.update(left=left_margin, right=0.95, top=0.9, bottom=0.05)

    # Define range of y-axis
    plt.ylim(-var + 0.5, -0.5)

    datarange = plotrange[1] - plotrange[0]
    plt.xlim(plotrange[0] - 0.05 * datarange, plotrange[1] + 0.05 * datarange)

    # Add variable labels
    plt.yticks([-(l + 1) for l in range(len(labels))], labels)

    # Add title
    if main is not False:
        plot_title = main or str(int((
            1 - alpha) * 100)) + "% Credible Intervals"
        plt.title(plot_title)

    # Add x-axis label
    if xtitle is not None:
        plt.xlabel(xtitle)

    # Constrain to specified range
    if xrange is not None:
        plt.xlim(*xrange)

    # Remove ticklines on y-axes
    for ticks in interval_plot.yaxis.get_major_ticks():
        ticks.tick1On = False
        ticks.tick2On = False

    for loc, spine in interval_plot.spines.items():
        if loc in ['bottom', 'top']:
            pass
            # spine.set_position(('outward',10)) # outward by 10 points
        elif loc in ['left', 'right']:
            spine.set_color('none')  # don't draw spine

    # Reference line
    plt.axvline(vline, color='k', linestyle='--')

    # Genenerate Gelman-Rubin plot
    if rhat and nchains > 1:

        # If there are multiple chains, calculate R-hat
        rhat_plot = plt.subplot(gs[1])

        if main is not False:
            plt.title("R-hat")

        # Set x range
        plt.xlim(0.9, 2.1)

        # X axis labels
        plt.xticks((1.0, 1.5, 2.0), ("1", "1.5", "2+"))
        plt.yticks([-(l + 1) for l in range(len(labels))], "")

        i = 1
        for varname in varnames:

            chain = trace_obj.chains[0]
            value = trace_obj.get_values(varname, chains=[chain])[0]
            k = np.size(value)

            if k > 1:
                plt.plot([min(r, 2) for r in R[varname]], [-(j + i)
                                                           for j in range(k)], 'bo', markersize=4)
            else:
                plt.plot(min(R[varname], 2), -i, 'bo', markersize=4)

            i += k

        # Define range of y-axis
        plt.ylim(-i + 0.5, -0.5)

        # Remove ticklines on y-axes
        for ticks in rhat_plot.yaxis.get_major_ticks():
            ticks.tick1On = False
            ticks.tick2On = False

        for loc, spine in rhat_plot.spines.items():
            if loc in ['bottom', 'top']:
                pass
                # spine.set_position(('outward',10)) # outward by 10 points
            elif loc in ['left', 'right']:
                spine.set_color('none')  # don't draw spine

    return gs


def plot_posterior(trace, varnames=None, transform=lambda x: x, figsize=None,
                   alpha_level=0.05, round_to=3, point_estimate='mean', rope=None,
                   ref_val=None, kde_plot=False, plot_transformed=False, ax=None, **kwargs):
    """Plot Posterior densities in style of John K. Kruschke book

    Parameters
    ----------

    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variable are plotted
    transform : callable
        Function to transform data (defaults to identity)
    figsize : figure size tuple
        If None, size is (12, num of variables * 2) inch
    alpha_level : float
        Defines range for High Posterior Density
    round_to : int
        Controls formatting for floating point numbers
    point_estimate: str
        Must be in ('mode', 'mean', 'median')
    rope: list or numpy array
        Lower and upper values of the Region Of Practical Equivalence
    ref_val: bool
        display the percentage below and above ref_val
    kde_plot: bool
        if True plot a KDE instead of a histogram
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False).
    ax : axes
        Matplotlib axes. Defaults to None.
    **kwargs
        Passed as-is to plt.hist() or plt.plot() function, depending on the
        value of the argument kde_plot
        Some defaults are added, if not specified
        color='#87ceeb' will match the style in the book


    Returns
    -------

    ax : matplotlib axes

    """

    def plot_posterior_op(trace_values, ax):

        def format_as_percent(x, round_to=0):
            value = np.round(100 * x, round_to)
            if round_to == 0:
                value = int(value)
            return '{}%'.format(value)

        def display_ref_val(ref_val):
            less_than_ref_probability = (trace_values < ref_val).mean()
            greater_than_ref_probability = (trace_values >= ref_val).mean()
            ref_in_posterior = format_as_percent(less_than_ref_probability, 1) + ' <{:g}< '.format(ref_val) + format_as_percent(
                greater_than_ref_probability, 1)
            ax.axvline(ref_val, ymin=0.02, ymax=.75, color='g',
                       linewidth=4, alpha=0.65)
            ax.text(trace_values.mean(), plot_height * 0.6, ref_in_posterior,
                    size=14, horizontalalignment='center')

        def display_rope(rope):
            pc_in_rope = format_as_percent(np.sum((trace_values > rope[0]) &
                                                  (trace_values < rope[1])) / len(trace_values), round_to)
            ax.plot(rope, (plot_height * 0.02, plot_height * 0.02),
                    linewidth=20, color='r', alpha=0.75)
            text_props = dict(size=16, horizontalalignment='center', color='r')
            ax.text(rope[0], plot_height * 0.14, rope[0], **text_props)
            ax.text(rope[1], plot_height * 0.14, rope[1], **text_props)

        def display_point_estimate():
            if not point_estimate:
                return
            if point_estimate not in ('mode', 'mean', 'median'):
                raise ValueError(
                    "Point Estimate should be in ('mode','mean','median', None)")
            if point_estimate == 'mean':
                point_value = trace_values.mean()
                point_text = '{}={}'.format(
                    point_estimate, point_value.round(round_to))
            elif point_estimate == 'mode':
                point_value = mode(trace_values.round(round_to))[0][0]
                point_text = '{}={}'.format(
                    point_estimate, point_value.round(round_to))
            elif point_estimate == 'median':
                point_value = np.median(trace_values)
                point_text = '{}={}'.format(
                    point_estimate, point_value.round(round_to))

            ax.text(point_value, plot_height * 0.8, point_text,
                    size=16, horizontalalignment='center')

        def display_hpd():
            hpd_intervals = hpd(trace_values, alpha=alpha_level)
            ax.plot(hpd_intervals, (plot_height * 0.02,
                                    plot_height * 0.02), linewidth=4, color='k')
            text_props = dict(size=16, horizontalalignment='center')
            ax.text(hpd_intervals[0], plot_height * 0.07,
                    hpd_intervals[0].round(round_to), **text_props)
            ax.text(hpd_intervals[1], plot_height * 0.07,
                    hpd_intervals[1].round(round_to), **text_props)
            ax.text((hpd_intervals[0] + hpd_intervals[1]) / 2, plot_height * 0.2,
                    format_as_percent(1 - alpha_level) + ' HPD', **text_props)

        def format_axes():
            ax.yaxis.set_ticklabels([])
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['left'].set_visible(False)
            ax.spines['bottom'].set_visible(True)
            ax.yaxis.set_ticks_position('none')
            ax.xaxis.set_ticks_position('bottom')
            ax.tick_params(axis='x', direction='out', width=1, length=3,
                           color='0.5')
            ax.spines['bottom'].set_color('0.5')

        def set_key_if_doesnt_exist(d, key, value):
            if key not in d:
                d[key] = value

        if kde_plot:
            density = kde.gaussian_kde(trace_values)
            l = np.min(trace_values)
            u = np.max(trace_values)
            x = np.linspace(0, 1, 100) * (u - l) + l
            ax.plot(x, density(x), **kwargs)
        else:
            set_key_if_doesnt_exist(kwargs, 'bins', 30)
            set_key_if_doesnt_exist(kwargs, 'edgecolor', 'w')
            set_key_if_doesnt_exist(kwargs, 'align', 'right')
            ax.hist(trace_values, **kwargs)

        plot_height = ax.get_ylim()[1]

        format_axes()
        display_hpd()
        display_point_estimate()
        if ref_val is not None:
            display_ref_val(ref_val)
        if rope is not None:
            display_rope(rope)

    def create_axes_grid(figsize, varnames):
        n = np.ceil(len(varnames) / 2.0).astype(int)
        if figsize is None:
            figsize = (12, n * 2.5)
        fig, ax = plt.subplots(n, 2, figsize=figsize)
        ax = ax.reshape(2 * n)
        if len(varnames) % 2 == 1:
            ax[-1].set_axis_off()
            ax = ax[:-1]
        return ax, fig

    if isinstance(trace, np.ndarray):
        if figsize is None:
            figsize = (6, 2)
        if ax is None:
            fig, ax = plt.subplots()
        plot_posterior_op(transform(trace), ax)
    else:
        if varnames is None:
            if plot_transformed:
                varnames = [name for name in trace.varnames]
            else:
                varnames = [name for name in trace.varnames if not name.endswith('_')]

        if ax is None:
            ax, fig = create_axes_grid(figsize, varnames)

        for a, v in zip(ax, varnames):
            tr_values = transform(trace.get_values(
                v, combine=True, squeeze=True))
            plot_posterior_op(tr_values, ax=a)
            a.set_title(v)

        fig.tight_layout()
    return ax
back to top