{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mixing TensorFlow models with GPflow\n",
    "\n",
    "This notebook explores the combination of Keras TensorFlow neural networks with GPflow models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T10:53:12.620472Z",
     "start_time": "2018-06-20T10:53:11.541346Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from matplotlib import pyplot as plt\n",
    "import gpflow\n",
    "from gpflow.ci_utils import ci_niter\n",
    "from scipy.cluster.vq import kmeans2\n",
    "\n",
    "from typing import Dict, Optional, Tuple\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import gpflow\n",
    "from gpflow.config import default_float\n",
    "\n",
    "iterations = ci_niter(100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convolutional network inside a GPflow model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T10:53:16.316761Z",
     "start_time": "2018-06-20T10:53:12.621686Z"
    }
   },
   "outputs": [],
   "source": [
    "original_dataset, info = tfds.load(name=\"mnist\", split=tfds.Split.TRAIN, with_info=True)\n",
    "total_num_data = info.splits[\"train\"].num_examples\n",
    "image_shape = info.features['image'].shape\n",
    "image_size = tf.reduce_prod(image_shape)\n",
    "batch_size = 32\n",
    "\n",
    "def map_fn(input_slice: Dict[str, tf.Tensor]):\n",
    "    updated = input_slice\n",
    "    image = tf.cast(updated[\"image\"], default_float()) / 255.\n",
    "    label = tf.cast(updated[\"label\"], default_float())\n",
    "    return tf.reshape(image, [-1, image_size]), label\n",
    "\n",
    "autotune = tf.data.experimental.AUTOTUNE\n",
    "dataset = original_dataset\\\n",
    "    .shuffle(1024)\\\n",
    "    .batch(batch_size, drop_remainder=True)\\\n",
    "    .map(map_fn, num_parallel_calls=autotune)\\\n",
    "    .prefetch(autotune)\\\n",
    "    .repeat()\n",
    "    "
   ]
  },
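  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check (a small addition, not part of the original pipeline), we can pull one batch and inspect its shapes and dtypes: each batch should contain `batch_size` flattened images of length `image_size`, together with their labels, in GPflow's default float type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull a single batch from the (repeating) pipeline and inspect it.\n",
    "batch_images, batch_labels = next(iter(dataset))\n",
    "print(batch_images.shape, batch_images.dtype)  # expected with the defaults above: (32, 784), float64\n",
    "print(batch_labels.shape, batch_labels.dtype)  # expected: (32,), float64"
   ]
  },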
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we'll use the GPflow functionality, but put a non-GPflow model inside the kernel.\\\n",
    "Vanilla ConvNet. This gets 97.3% accuracy on MNIST when used on its own (+ final linear layer) after 20K iterations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T10:53:16.335949Z",
     "start_time": "2018-06-20T10:53:16.317967Z"
    }
   },
   "outputs": [],
   "source": [
    "class KernelWithConvNN(gpflow.kernels.Kernel):\n",
    "    def __init__(self, image_shape: Tuple,\n",
    "                 output_dim: int,\n",
    "                 base_kernel: gpflow.kernels.Kernel,\n",
    "                 batch_size: Optional[int] = None):\n",
    "        super().__init__()\n",
    "        with self.name_scope:\n",
    "            self.base_kernel = base_kernel\n",
    "            input_size = int(tf.reduce_prod(image_shape))\n",
    "            input_shape = (input_size, )\n",
    "            \n",
    "            self.cnn = tf.keras.Sequential([\n",
    "                tf.keras.layers.InputLayer(input_shape=input_shape, batch_size=batch_size),\n",
    "                tf.keras.layers.Reshape(image_shape),\n",
    "                tf.keras.layers.Conv2D(filters=32, kernel_size=image_shape[:-1], padding=\"same\", activation=\"relu\"),\n",
    "                tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2),\n",
    "                tf.keras.layers.Conv2D(filters=64, kernel_size=(5, 5), padding=\"same\", activation=\"relu\"),\n",
    "                tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2),\n",
    "                tf.keras.layers.Flatten(),\n",
    "                tf.keras.layers.Dense(output_dim, activation=\"relu\"),\n",
    "                tf.keras.layers.Lambda(lambda x: tf.cast(x, default_float()))\n",
    "            ])\n",
    "            \n",
    "            self.cnn.build()\n",
    "    \n",
    "    def K(self, a_input: tf.Tensor, b_input: Optional[tf.Tensor] = None, presliced: bool = False) -> tf.Tensor:\n",
    "        transformed_a = self.cnn(a_input)\n",
    "        transformed_b = self.cnn(b_input) if b_input is not None else b_input\n",
    "        return self.base_kernel.K(transformed_a, transformed_b, presliced)\n",
    "    \n",
    "    def K_diag(self, a_input: tf.Tensor, presliced: bool = False) -> tf.Tensor:\n",
    "        transformed_a = self.cnn(a_input)\n",
    "        return self.base_kernel.K_diag(transformed_a, presliced)"
   ]
  },
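  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check (a minimal sketch added here, not part of the original flow), we can instantiate the kernel on a small dummy batch of flattened images and confirm that `K` and `K_diag` return tensors of the expected shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the kernel should map N flattened images to an N x N covariance matrix.\n",
    "_test_kernel = KernelWithConvNN(image_shape, output_dim=5,\n",
    "                                base_kernel=gpflow.kernels.SquaredExponential())\n",
    "_dummy_images = tf.random.uniform((7, int(image_size)), dtype=default_float())\n",
    "print(_test_kernel.K(_dummy_images).shape)       # expected: (7, 7)\n",
    "print(_test_kernel.K_diag(_dummy_images).shape)  # expected: (7,)"
   ]
  },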
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$K_{uf}$ is in ConvNN output space, therefore we need to update `Kuf` multidispatch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T10:53:16.475537Z",
     "start_time": "2018-06-20T10:53:16.337305Z"
    }
   },
   "outputs": [],
   "source": [
    "class KernelSpaceInducingPoints(gpflow.inducing_variables.InducingPoints):\n",
    "    pass\n",
    "\n",
    "@gpflow.covariances.Kuu.register(KernelSpaceInducingPoints, KernelWithConvNN)\n",
    "def Kuu(inducing_variable, kernel, jitter=None):\n",
    "    func = gpflow.covariances.Kuu.dispatch(gpflow.inducing_variables.InducingPoints, gpflow.kernels.Kernel)\n",
    "    return func(inducing_variable, kernel.base_kernel, jitter=jitter)\n",
    "\n",
    "@gpflow.covariances.Kuf.register(KernelSpaceInducingPoints, KernelWithConvNN, object)\n",
    "def Kuf(inducing_variable, kernel, a_input):\n",
    "    return kernel.base_kernel(inducing_variable.Z, kernel.cnn(a_input))"
   ]
  },
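  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, we can confirm that the registrations above are the implementations that get dispatched for our custom classes (this assumes the `.dispatch` lookup used for `Kuu` above works the same way for `Kuf`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the multiple-dispatch registrations above are picked up for the custom classes.\n",
    "assert gpflow.covariances.Kuu.dispatch(KernelSpaceInducingPoints, KernelWithConvNN) is Kuu\n",
    "assert gpflow.covariances.Kuf.dispatch(KernelSpaceInducingPoints, KernelWithConvNN, object) is Kuf"
   ]
  },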
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to create and initialize the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-06-20T11:10:46.858524Z",
     "start_time": "2018-06-20T10:53:16.477021Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Layer reshape is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Layer reshape is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
      "\n",
      "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
      "\n",
      "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "num_mnist_classes = 10\n",
    "output_dim = 5\n",
    "num_inducing_points = 100\n",
    "images_subset, labels_subset = next(iter(dataset.batch(32)))\n",
    "images_subset = tf.reshape(images_subset, [-1, image_size])\n",
    "labels_subset = tf.reshape(labels_subset, [-1, 1])\n",
    "\n",
    "kernel = KernelWithConvNN(image_shape,\n",
    "                          output_dim,\n",
    "                          gpflow.kernels.SquaredExponential(),\n",
    "                          batch_size=batch_size)\n",
    "\n",
    "likelihood = gpflow.likelihoods.MultiClass(num_mnist_classes)\n",
    "\n",
    "inducing_variable_kmeans = kmeans2(images_subset.numpy(), num_inducing_points, minit='points')[0]\n",
    "inducing_variable_cnn = kernel.cnn(inducing_variable_kmeans)\n",
    "inducing_variable = KernelSpaceInducingPoints(inducing_variable_cnn)\n",
    "\n",
    "model = gpflow.models.SVGP(kernel, likelihood,\n",
    "                           inducing_variable=inducing_variable,\n",
    "                           num_data=total_num_data,\n",
    "                           num_latent=num_mnist_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And start optimization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/artemav/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/math_grad.py:281: setdiff1d (from tensorflow.python.ops.array_ops) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "This op will be removed after the deprecation date. Please switch to tf.sets.difference().\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/artemav/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/math_grad.py:281: setdiff1d (from tensorflow.python.ops.array_ops) is deprecated and will be removed after 2018-11-30.\n",
      "Instructions for updating:\n",
      "This op will be removed after the deprecation date. Please switch to tf.sets.difference().\n"
     ]
    }
   ],
   "source": [
    "data_iterator = iter(dataset)\n",
    "adam_opt = tf.optimizers.Adam(0.001)\n",
    "\n",
    "\n",
    "@tf.function(autograph=False)\n",
    "def loss_cb(batch: Tuple[tf.Tensor, tf.Tensor]):\n",
    "    loss_value = -model.elbo(batch)\n",
    "    return loss_value\n",
    "\n",
    "\n",
    "@tf.function(autograph=False)\n",
    "def optimization_step():\n",
    "    batch = next(data_iterator)\n",
    "    func = lambda: loss_cb(batch)\n",
    "    adam_opt.minimize(func, var_list=model.trainable_variables)\n",
    "\n",
    "\n",
    "for _ in range(iterations):\n",
    "    optimization_step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's do predictions after training. Don't expect that we will get a good accuracy, because we haven't run training for long enough."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy is 97.2266%\n"
     ]
    }
   ],
   "source": [
    "m, v = model.predict_y(images_subset)\n",
    "preds = np.argmax(m, 1).reshape(labels_subset.numpy().shape)\n",
    "correct = (preds == labels_subset.numpy().astype(int))\n",
    "acc = np.average(correct.astype(float)) * 100.\n",
    "\n",
    "print('Accuracy is {:.4f}%'.format(acc))"
   ]
  }
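  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a less optimistic estimate, we could evaluate on a batch from the held-out test split instead of the training subset. This is a sketch added here (assuming the `tfds` MNIST test split comes in the same format as the training split), not part of the original run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on a batch from the held-out test split instead of the training subset.\n",
    "test_dataset = tfds.load(name=\"mnist\", split=tfds.Split.TEST)\n",
    "test_images, test_labels = next(iter(test_dataset.batch(1024).map(map_fn)))\n",
    "test_m, _ = model.predict_y(test_images)\n",
    "test_preds = np.argmax(test_m, 1).reshape(test_labels.numpy().shape)\n",
    "test_correct = (test_preds == test_labels.numpy().astype(int))\n",
    "print('Test accuracy is {:.4f}%'.format(np.average(test_correct.astype(float)) * 100.))"
   ]
  }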
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}