import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from ..base import Parameter
from ..config import default_int


class RobustMax(tf.Module):
    """
    Multi-class inverse-link function.

    Given a vector f = [f_1, f_2, ... f_k], the mapping produces
    y = [y_1 ... y_k] with

        y_i = (1 - eps)      if i == argmax(f)
              eps / (k - 1)  otherwise.

    :param num_classes: number of classes k (depth of the one-hot output).
    :param epsilon: initial value of eps; stored as a non-trainable
        Parameter with a Sigmoid transform and a Beta(0.2, 5.) prior.
    :param kwargs: forwarded unchanged to ``tf.Module.__init__``.
    """

    def __init__(self, num_classes, epsilon=1e-3, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = Parameter(
            epsilon,
            transform=tfp.bijectors.Sigmoid(),
            prior=tfp.distributions.Beta(0.2, 5.),
            trainable=False,
        )
        self.num_classes = num_classes

    def __call__(self, F):
        """
        Apply the robust-max mapping along axis 1 of F.

        The position of the largest entry along axis 1 receives 1 - eps,
        every other position receives eps / (k - 1).
        """
        winners = tf.argmax(F, 1)
        on_value = tf.squeeze(1. - self.epsilon)
        off_value = tf.squeeze(self.eps_k1)
        return tf.one_hot(winners, self.num_classes, on_value, off_value)

    @property
    def eps_k1(self):
        """Value assigned to each non-argmax class: eps / (k - 1)."""
        return self.epsilon / (self.num_classes - 1.)

    def prob_is_largest(self, Y, mu, var, gh_x, gh_w):
        """
        Quadrature estimate of the probability that the latent function
        indicated by Y is the largest one.

        :param Y: class labels; cast to the default int type and flattened.
        :param mu: means of the latent functions
            (presumably [N, num_classes] - confirm against callers).
        :param var: variances of the latent functions, same layout as mu.
        :param gh_x: Gauss-Hermite evaluation points.
        :param gh_w: Gauss-Hermite weights matching gh_x.
        :return: product of the CDFs over the latent functions, summed over
            the quadrature grid with weights gh_w / sqrt(pi).
        """
        labels = tf.reshape(tf.cast(Y, default_int()), (-1, ))

        # Masks that pick out (is_target) or exclude (not_target) the latent
        # function indicated by each label.
        is_target = tf.cast(tf.one_hot(labels, self.num_classes, 1., 0.), dtype=mu.dtype)
        not_target = tf.cast(tf.one_hot(labels, self.num_classes, 0., 1.), dtype=mu.dtype)

        # Mean and variance of the indicated latent function.
        target_mean = tf.reduce_sum(is_target * mu, 1)
        target_var = tf.reduce_sum(is_target * var, 1)

        # Gauss-Hermite grid centred on the indicated latent function; the
        # variance is clipped from below so the square root stays well defined.
        target_scale = tf.sqrt(tf.clip_by_value(2. * target_var, 1e-10, np.inf))
        grid = tf.reshape(target_mean, (-1, 1)) + gh_x * tf.reshape(target_scale, (-1, 1))

        # Gaussian CDF of every latent function (the indicated one included)
        # evaluated at each grid point.
        latent_std = tf.sqrt(tf.clip_by_value(var, 1e-10, np.inf))
        standardised = (tf.expand_dims(grid, 1) - tf.expand_dims(mu, 2)) / tf.expand_dims(latent_std, 2)
        cdfs = 0.5 * (1.0 + tf.math.erf(standardised / np.sqrt(2.0)))

        # Squash the CDF values into [1e-4, 1 - 1e-4].
        cdfs = cdfs * (1 - 2e-4) + 1e-4

        # Replace the entries of the indicated latent function by 1 so they
        # drop out of the product below.
        cdfs = cdfs * tf.expand_dims(not_target, 2) + tf.expand_dims(is_target, 2)

        # Product over the latent functions, then weighted sum over the grid.
        quadrature_weights = tf.reshape(gh_w / np.sqrt(np.pi), (-1, 1))
        return tf.reduce_prod(cdfs, axis=[1]) @ quadrature_weights