https://github.com/GPflow/GPflow
Raw File
Tip revision: 00073d8dfa0c4cee80597fe8adb0324a7f72e7a5 authored by Sergio Diaz on 16 September 2019, 10:10:17 UTC
Merge branch 'awav/gpflow-2.0' into sergio_pasc/gpflow-2.0/ordinal_regression
Tip revision: 00073d8
intro_to_gpflow2_plotting.py
import io
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf


def summary_matplotlib_image(figures, step, fmt="png"):
    """Log each matplotlib figure in ``figures`` as an image summary.

    Every figure is rendered into an in-memory buffer, decoded into a
    4-channel image tensor and written with ``tf.summary.image``. Whether
    anything is actually recorded depends on a default summary writer being
    active in the caller's context (not visible here).

    Args:
        figures: Mapping from summary name to a matplotlib ``Figure``.
        step: Step value forwarded unchanged to ``tf.summary.image``.
        fmt: Image format passed to ``Figure.savefig``. The bytes are decoded
            with ``tf.image.decode_image`` requesting 4 channels, so the
            format must be decodable that way; the default "png" is the
            format this helper was written for.
    """
    for name, fig in figures.items():
        # Close the buffer deterministically once the encoded bytes are read.
        with io.BytesIO() as buf:
            fig.savefig(buf, format=fmt, bbox_inches='tight')
            # getvalue() returns the whole buffer regardless of the current
            # stream position, so no seek(0) is required.
            encoded = buf.getvalue()
        # Request 4 channels so every logged image has the same channel count.
        image = tf.image.decode_image(encoded, channels=4)
        # tf.summary.image expects a batch of images: [k, height, width, channels].
        image = tf.expand_dims(image, 0)
        tf.summary.image(name=name, data=image, step=step)


def plotting_regression(X, Y, xx, mean, var, samples, *,
                        xlim=(0, 10), ylim=(-2., +2.)):
    """Plot a 1-D regression fit and return the (closed) figure.

    Draws the mean line, a band of +/- 1.96 * sqrt(var) around it (about a
    95% interval if the predictive distribution is Gaussian), the training
    data as black crosses, and the individual samples as thin lines.

    Args:
        X, Y: Training inputs and targets, plotted as black crosses.
        xx: Plotting inputs, 2-D; column 0 is used as the band's x-values.
        mean: Mean at ``xx``, 2-D; column 0 is used for the band.
        var: Variance at ``xx``, 2-D; column 0 is used for the band.
        samples: Samples at ``xx``, 3-D with shape (S, len(xx), D); only
            output dimension 0 is plotted. NumPy arrays and eager TensorFlow
            tensors are both accepted (converted with ``np.asarray``).
        xlim, ylim: Axis limits as (low, high). Pass ``None`` to keep
            matplotlib's automatic limits for that axis.

    Returns:
        The matplotlib ``Figure``. It is closed through pyplot so it is not
        kept by the pyplot figure manager, but it can still be saved.
    """
    # Convert once so TF tensors and NumPy arrays are handled uniformly.
    xx = np.asarray(xx)
    mean = np.asarray(mean)
    var = np.asarray(var)
    samples = np.asarray(samples)

    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(111)
    ax.plot(xx, mean, 'C0', lw=2)

    # 1.96 standard deviations: the two-sided 95% quantile of a Gaussian.
    half_width = 1.96 * np.sqrt(var[:, 0])
    ax.fill_between(xx[:, 0],
                    mean[:, 0] - half_width,
                    mean[:, 0] + half_width,
                    color='C0',
                    alpha=0.2)
    ax.plot(X, Y, 'kx')
    # Transpose to (len(xx), S) so matplotlib draws one line per sample.
    ax.plot(xx, samples[:, :, 0].T, 'C0', linewidth=.5)

    if ylim is not None:
        ax.set_ylim(*ylim)
    if xlim is not None:
        ax.set_xlim(*xlim)

    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    return fig
back to top