https://github.com/GPflow/GPflow
Raw File
Tip revision: 2fa87edc7d9b2b6c6201a4900a1a7da6f089c604 authored by Vincent Adam on 23 June 2020, 19:55:46 UTC
seeger demo
Tip revision: 2fa87ed
convolutional.py
import numpy as np
import tensorflow as tf

from .base import Kernel
from ..base import Parameter
from ..config import default_float
from ..utilities import to_default_float


class Convolutional(Kernel):
    r"""
    Plain convolutional kernel as described in \citet{vdw2017convgp}. Defines
    a GP f( ) that is constructed from a sum of responses of individual patches
    in an image:
      f(x) = \sum_p x^{[p]}
    where x^{[p]} is the pth patch in the image.

    @incollection{vdw2017convgp,
      title = {Convolutional Gaussian Processes},
      author = {van der Wilk, Mark and Rasmussen, Carl Edward and Hensman, James},
      booktitle = {Advances in Neural Information Processing Systems 30},
      year = {2017},
      url = {http://papers.nips.cc/paper/6877-convolutional-gaussian-processes.pdf}
    }
    """

    def __init__(self, base_kernel, image_shape, patch_shape, weights=None, colour_channels=1):
        super().__init__()
        self.image_shape = image_shape
        self.patch_shape = patch_shape
        self.base_kernel = base_kernel
        self.colour_channels = colour_channels
        self.weights = Parameter(
            np.ones(self.num_patches, dtype=default_float()) if weights is None else weights
        )

    # @lru_cache() -- Can we do some kind of memoizing with TF2?
    def get_patches(self, X):
        """
        Extracts patches from the images X. Patches are extracted separately for each of the colour channels.
        :param X: (N x input_dim)
        :return: Patches (N, num_patches, patch_shape)
        """
        # Roll the colour channel to the front, so it appears to
        # `tf.extract_image_patches()` as separate images. Then extract patches
        # and reshape to have the first axis the same as the number of images.
        # The separate patches will then be in the second axis.
        num_data = tf.shape(X)[0]
        castX = tf.transpose(tf.reshape(X, [num_data, -1, self.colour_channels]), [0, 2, 1])
        patches = tf.image.extract_patches(
            tf.reshape(castX, [-1, self.image_shape[0], self.image_shape[1], 1], name="rX"),
            [1, self.patch_shape[0], self.patch_shape[1], 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            "VALID",
        )
        shp = tf.shape(patches)  # img x out_rows x out_cols
        reshaped_patches = tf.reshape(
            patches, [num_data, self.colour_channels * shp[1] * shp[2], shp[3]]
        )
        return to_default_float(reshaped_patches)

    def K(self, X, X2=None):
        Xp = self.get_patches(X)  # [N, P, patch_len]
        Xp2 = Xp if X2 is None else self.get_patches(X2)

        bigK = self.base_kernel.K(Xp, Xp2)  # [N, num_patches, N, num_patches]

        W2 = self.weights[:, None] * self.weights[None, :]  # [P, P]
        W2bigK = bigK * W2[None, :, None, :]
        return tf.reduce_sum(W2bigK, [1, 3]) / self.num_patches ** 2.0

    def K_diag(self, X):
        Xp = self.get_patches(X)  # N x num_patches x patch_dim
        W2 = self.weights[:, None] * self.weights[None, :]  # [P, P]
        bigK = self.base_kernel.K(Xp)  # [N, P, P]
        return tf.reduce_sum(bigK * W2[None, :, :], [1, 2]) / self.num_patches ** 2.0

    @property
    def patch_len(self):
        return np.prod(self.patch_shape)

    @property
    def num_patches(self):
        return (
            (self.image_shape[0] - self.patch_shape[0] + 1)
            * (self.image_shape[1] - self.patch_shape[1] + 1)
            * self.colour_channels
        )
back to top