https://github.com/jramapuram/LifelongVAE
Raw File
Tip revision: 91a1d61826637ed5d173679af25faca997532ee4 authored by Jason Ramapuram on 06 December 2017, 10:57:01 UTC
add fashion; fixes for newer tf distributions
Tip revision: 91a1d61
lifelong_vae.py
import os
import sys
import datetime
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.distributions as distributions
# from tensorflow.python.training.moving_averages import weighted_moving_average
from reparameterizations import gumbel_reparmeterization, gaussian_reparmeterization
from encoders import forward, DenseEncoder, CNNEncoder, copy_layer, reinit_last_layer
from decoders import CNNDecoder
from utils import *

# Short aliases for the contrib bayesflow stochastic graph / tensor modules.
sg = tf.contrib.bayesflow.stochastic_graph
st = tf.contrib.bayesflow.stochastic_tensor
# NOTE(review): this sets the interpreter recursion limit to 200, which is
# lower than CPython's usual default (1000) -- confirm this is intentional.
sys.setrecursionlimit(200)

# Global variables
# (VAE.__init__ snapshots GLOBAL_ITER into self.global_iter_base.)
GLOBAL_ITER = 0  # keeps track of the iteration ACROSS models
TRAIN_ITER = 0  # the iteration of the current model


class VAE(object):
    """ Online Variational Autoencoder with consistent sampling.

    See "Auto-Encoding Variational Bayes" by Kingma and Welling
    for more details on the original work.

    Note: reparam_type is fixed; provided in ctor due to compatibility
          with the vanilla VAE
    """
    def __init__(self, sess, x, input_size, batch_size, latent_size,
                 encoder, decoder, is_training, discrete_size, activation=tf.nn.elu,
                 reconstr_loss_type="binary_cross_entropy", learning_rate=1e-4,
                 submodel=0, total_true_models=0, vae_tm1=None,
                 p_x_given_z_func=distributions.Bernoulli,
                 base_dir=".", mutual_info_reg=0.0, img_shape=[28, 28, 1]):
        self.x = x
        self.activation = activation
        self.learning_rate = learning_rate
        self.is_training = is_training
        self.encoder_model = encoder
        self.decoder_model = decoder
        self.vae_tm1 = vae_tm1
        self.p_x_given_z_func = p_x_given_z_func
        self.global_iter_base = GLOBAL_ITER
        self.input_size = input_size
        self.latent_size = latent_size
        self.batch_size = batch_size
        self.img_shape = img_shape
        self.iteration = 0
        self.test_epoch = 0
        self.submodel = submodel
        self.total_true_models = total_true_models
        self.mutual_info_reg = mutual_info_reg
        self.reconstr_loss_type = reconstr_loss_type
        self.num_discrete = discrete_size
        print 'latent size = ', self.latent_size, ' | disc size = ', self.num_discrete
        self.base_dir = base_dir  # dump all our stuff into this dir

        # gumbel params
        self.tau0 = 1.0
        self.tau_host = self.tau0
        self.anneal_rate = 0.00003
        # self.anneal_rate = 0.0003 #1e-5
        self.min_temp = 0.5

        # sess & graph
        self.sess = sess
        # self.graph = tf.Graph()

        # create these in scope
        self._create_variables(x)

        # Create autoencoder network
        self._create_network()

        # Define loss function based variational upper-bound and
        # corresponding optimizer
        self._create_loss_optimizer()

        # create the required directories to hold data for this specific model
        self._create_local_directories()

        # Create all the summaries and their corresponding ops
        self._create_summaries()

        # Check for NaN's
        # self.check_op = tf.add_check_numerics_ops()

        # collect variables & build saver
        self.vae_vars = [v for v in tf.global_variables()
                         if v.name.startswith(self.get_name())]
        self.vae_local_vars = [v for v in tf.local_variables()
                               if v.name.startswith(self.get_name())]
        self.saver = tf.train.Saver(tf.global_variables())  # XXX: use local
        self.init_op = tf.variables_initializer(self.vae_vars
                                                + self.vae_local_vars)

        print 'model: ', self.get_name()
        # print 'there are ', len(self.vae_vars), ' vars in ', \
        #     tf.get_variable_scope().name, ' out of a total of ', \
        #     len(tf.global_variables()), ' with %d total trainable vars' \
        #     % len(tf.trainable_variables())

    '''
    Helper to create the :
         1) experiment_%d/models directory
         2) experiment_%d/imgs directory
         3) experiment_%d/logs directory
    '''
    def _create_local_directories(self):
        models_dir = '%s/models' % (self.base_dir)
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)

        imgs_dir = '%s/imgs' % (self.base_dir)
        if not os.path.exists(imgs_dir):
            os.makedirs(imgs_dir)

        logs_dir = '%s/logs' % (self.base_dir)
        if not os.path.exists(logs_dir):
            os.makedirs(logs_dir)

    def _create_variables(self, x_placeholder):
        """Create the non-trainable bookkeeping variables of this model.

        Everything is created inside the variable scope named after
        `get_name()`, so the variables carry the model's name prefix.

        Args:
            x_placeholder: not used by the current body (only the
                commented-out placeholder logic below refers to inputs).
        """
        with tf.variable_scope(self.get_name()):
            # Create the placeholders if we are at the first model
            # Else simply pull the references
            # if self.submodel == 0:
            #     self.x = tf.placeholder(tf.float32, shape=[self.batch_size,
            #                                                self.input_size],
            #                             name="input_placeholder")
            # else:
            #     self.x = self.vae_tm1.x

            # gpu iteration count
            # (float counter plus an op that increments it by one)
            self.iteration_gpu = tf.Variable(0.0, trainable=False)
            self.iteration_gpu_op = self.iteration_gpu.assign_add(1.0)

            # gumbel related
            # NOTE(review): initial value is 5.0 while __init__ sets
            # self.tau0 = 1.0 -- confirm which one is intended.
            self.tau = tf.Variable(5.0, trainable=False, dtype=tf.float32,
                                   name="temperature")
            # self.ema = tf.train.ExponentialMovingAverage(decay=0.9999)

    '''
    A helper function to create all the summaries.
    Adds things like image_summary, histogram_summary, etc.
    '''
    def _create_summaries(self):
        """Build all summary ops and the train/test summary writers.

        Sets:
            self.summaries: merged scalar/histogram summaries.
            self.image_summaries: merged image summaries PLUS the
                scalar/histogram ones.
            self.accuracy: only when `xhat_tm1` exists; fraction of replayed
                rows for which this model and the previous model pick the
                same arg-max class.
            self.train_summary_writer / self.test_summary_writer:
                FileWriters under <base_dir>/logs/{train,test}.

        Requires the attributes produced by `_create_network` and
        `_create_loss_optimizer` (losses, KL terms, grad_norm, z, ...).
        """
        # Summaries and saver
        summaries = [tf.summary.scalar("vae_loss_mean", self.cost_mean),
                     tf.summary.scalar("vae_negative_elbo", self.elbo_mean),
                     tf.summary.scalar("vae_latent_loss_mean", self.latent_loss_mean),
                     tf.summary.scalar("vae_grad_norm", self.grad_norm),
                     tf.summary.scalar("bits_per_dim", self.generate_bits_per_dim()),
                     tf.summary.scalar("vae_selected_class", tf.argmax(tf.reduce_sum(self.z_pre_gumbel, 0), 0)),
                     tf.summary.scalar("vae_selected_class_xtm1", tf.argmax(tf.reduce_sum(self.z_pre_gumbel[self.num_current_data:], 0), 0)),
                     tf.summary.histogram("vae_kl_normal", self.kl_normal),
                     tf.summary.histogram("vae_kl_discrete", self.kl_discrete),
                     tf.summary.histogram("vae_latent_dist", self.latent_kl),
                     tf.summary.scalar("vae_latent_loss_max", tf.reduce_max(self.latent_kl)),
                     tf.summary.scalar("vae_latent_loss_min", tf.reduce_min(self.latent_kl)),
                     tf.summary.scalar("vae_reconstr_loss_mean", self.reconstr_loss_mean),
                     tf.summary.scalar("vae_reconstr_loss_max", tf.reduce_max(self.reconstr_loss)),
                     tf.summary.scalar("vae_reconstr_loss_min", tf.reduce_min(self.reconstr_loss)),
                     tf.summary.histogram("z_dist", self.z)]

        # Display image summaries : i.e. samples from P(X|Z=z_i)
        # Visualize:
        #           1) augmented images;
        #           2) original images[current distribution]
        #           3) reconstructed images
        dimensions = len(shp(self.x))
        if dimensions == 2:
            # flat (rank-2) inputs are jointly shuffled before being displayed
            x_orig, x_aug, x_reconstr = shuffle_jointly(self.x, self.x_augmented, # noqa
                                                        self.p_x_given_z.mean())
        else:
            # TODO: modify shuffle jointly for 3d images
            x_orig, x_aug, x_reconstr = [self.x,
                                         self.x_augmented,
                                         self.p_x_given_z.mean()]

        # img_shape carries no batch axis, so prepend the batch size here
        img_shp = [self.batch_size] + self.img_shape
        image_summaries = [tf.summary.image("x_augmented_t", tf.reshape(x_aug, img_shp), # noqa
                                            max_outputs=self.batch_size),
                           tf.summary.image("x_t", tf.reshape(x_orig, img_shp),
                                            max_outputs=self.batch_size),
                           tf.summary.image("x_reconstr_mean_activ_t",
                                            tf.reshape(x_reconstr, img_shp),
                                            max_outputs=self.batch_size)]

        # In addition show the following if they exist:
        #          4) Images from previous interval
        #          5) Distilled KL Divergence
        if hasattr(self, 'xhat_tm1'):
            with tf.variable_scope(self.get_name()):  # accuracy operator
                # selected_classes_for_xtm1 = tf.argmax(self.z_discrete[self.num_current_data:], 0)
                # selected_classes_by_vae_tm1 = tf.argmax(self.q_z_t_given_x_t, 0)
                # compare per-row arg-max of this model vs. the previous model
                selected_classes_for_xtm1 = self.z_pre_gumbel[self.num_current_data:]  # self.z_discrete[self.num_current_data:]
                selected_classes_by_vae_tm1 = self.q_z_t_given_x_t
                correct_prediction = tf.equal(tf.argmax(selected_classes_by_vae_tm1, 1),
                                              tf.argmax(selected_classes_for_xtm1, 1))
                self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

            num_xhat_tm1 = self.xhat_tm1.get_shape().as_list()
            image_summaries += [tf.summary.image("xhat_tm1",
                                                 tf.reshape(self.xhat_tm1, img_shp),
                                                 max_outputs=num_xhat_tm1[0])]
            summaries += [tf.summary.scalar("vae_tm1_selected_class", tf.argmax(tf.reduce_sum(self.q_z_t_given_x_t, 0), 0)),
                          tf.summary.scalar("consistency_accuracy", self.accuracy),
                          tf.summary.scalar("vae_kl_distill_mean",
                                            tf.reduce_mean(self.kl_consistency))]

        # Merge all the summaries, but ensure we are post-activation
        # keep the image summaries separate, but also include the regular
        # summaries in them
        #with tf.control_dependencies([self.p_x_given_z_logits]):
        self.summaries = tf.summary.merge(summaries)
        self.image_summaries = tf.summary.merge(image_summaries
                                                + summaries)

        # Write all summaries to logs, but VARY the model name AND add a TIMESTAMP
        # current_summary_name = self.get_name() + self.get_formatted_datetime()
        # NOTE: the log directories below currently contain neither the model
        # name nor a timestamp; every model writes to the same train/test dirs.
        self.train_summary_writer = tf.summary.FileWriter("%s/logs/train" % self.base_dir,
                                                          self.sess.graph,
                                                          flush_secs=60)
        self.test_summary_writer = tf.summary.FileWriter("%s/logs/test" % self.base_dir,
                                                         self.sess.graph,
                                                         flush_secs=60)

    def generate_bits_per_dim(self):
        num_pixels = np.prod(self.img_shape[1:])
        batch_size = self.img_shape[0]
        return self.elbo_mean / (np.log(2.) * num_pixels *
                                 batch_size)

    '''
    A helper function to format the name as a function of the hyper-parameters
    '''
    def get_name(self):
        if self.submodel == 0:
            full_hash_str = self.activation.__name__ \
                            + '_enc' + str(self.encoder_model.get_sizing()) \
                            + '_dec' + str(self.decoder_model.get_sizing()) \
                            + "_learningrate" + str(self.learning_rate) \
                            + "_latentsize" + str(self.latent_size) \
                            + "_discsize" + str(self.num_discrete) \
                            + "_mutintoreg" + str(self.mutual_info_reg)

            full_hash_str = full_hash_str.strip().lower().replace('[', '')  \
                                                         .replace(']', '')  \
                                                         .replace(' ', '')  \
                                                         .replace('{', '') \
                                                         .replace('}', '') \
                                                         .replace(',', '_') \
                                                         .replace(':', '') \
                                                         .replace('(', '') \
                                                         .replace(')', '') \
                                                         .replace('\'', '')
            return 'vae%d_' % self.submodel + full_hash_str
        else:
            vae_tm1_name = self.vae_tm1.get_name()
            indexof = vae_tm1_name.find('_')
            return 'vae%d_' % self.submodel + vae_tm1_name[indexof+1:]

    def get_formatted_datetime(self):
        return str(datetime.datetime.now()).replace(" ", "_") \
                                           .replace("-", "_") \
                                           .replace(":", "_")

    def save(self):
        model_filename = "%s/models/%s.cpkt" % (self.base_dir, self.get_name())
        print 'saving vae model to %s...' % model_filename
        self.saver.save(self.sess, model_filename)

    def restore(self):
        model_filename = "%s/models/%s.cpkt" % (self.base_dir, self.get_name())
        print 'into restore, model name = ', model_filename
        if os.path.isfile(model_filename):
            print 'restoring vae model from %s...' % model_filename
            self.saver.restore(self.sess, model_filename)

    @staticmethod
    def kl_categorical(p=None, q=None, p_logits=None, q_logits=None, eps=1e-6):
        '''
        Given p and q (as EITHER BOTH logits or softmax's)
        then this func returns the KL between them.

        Utilizes an eps in order to resolve divide by zero / log issues
        '''
        if p_logits is not None and q_logits is not None:
            Q = distributions.Categorical(logits=q_logits, dtype=tf.float32)
            P = distributions.Categorical(logits=p_logits, dtype=tf.float32)
        elif p is not None and q is not None:
            print 'p shp = ', p.get_shape().as_list(), \
                ' | q shp = ', q.get_shape().as_list()
            Q = distributions.Categorical(probs=q+eps, dtype=tf.float32)
            P = distributions.Categorical(probs=p+eps, dtype=tf.float32)
        else:
            raise Exception("please provide either logits or dists")

        return distributions.kl_divergence(P, Q)

    @staticmethod
    def zero_pad_smaller_cat(cat1, cat2):
        """Pad the narrower of two tensors with zero columns (axis 1).

        Returns [cat1, cat2] with equal static size along axis 1; tensors
        that already match are returned unchanged.
        """
        shape1 = cat1.get_shape().as_list()
        shape2 = cat2.get_shape().as_list()
        width1, width2 = shape1[1], shape2[1]
        missing = abs(width1 - width2)

        # append `missing` zero columns to whichever side is narrower
        if width1 > width2:
            filler = tf.zeros([shape2[0], missing])
            cat2 = tf.concat([cat2, filler], axis=1)
        elif width2 > width1:
            filler = tf.zeros([shape1[0], missing])
            cat1 = tf.concat([cat1, filler], axis=1)

        return [cat1, cat2]

    def _create_constraints(self):
        """Build `self.kl_consistency`, the per-row consistency KL term.

        When replayed data (`xhat_tm1`) exists, the categorical posterior of
        this model (student) and of the previous model (teacher) are
        compared on the replayed rows; rows holding current data get a zero
        penalty. Without replayed data the term is all zeros.
        """
        # 0.) add in a kl term between the old model's posterior
        #     and the current model's posterior using the
        #     data generated from the previous model [for the discrete ONLY]
        #
        # Recall data is : [current_data ; old_data]
        if hasattr(self, 'xhat_tm1'):
            # First we encode the generated data w/the student
            # Note: encoder() returns z, z_normal, z_discrete,
            #                         softmax(discrete logits),
            #                         kl_normal, kl_discrete
            # Note2: discrete dimension is self.submodel
            # The student posterior is taken from the rows of the main
            # encoder pass that hold replayed data.
            self.q_z_s_given_x_t = self.z_pre_gumbel[self.num_current_data:]
            assert self.q_z_s_given_x_t.get_shape().as_list()[0] \
                == self.num_old_data
            # _, _, self.q_z_s_given_x_t, _, _ \
            #     = self.encoder(self.xhat_tm1,
            #                    rnd_sample=None,
            #                    hard=False,  # True?
            #                    reuse=True)


            # We also need to encode the data back through the teacher
            # This is necessary because we need to evaluate the posterior
            # in order to compare Q^T(x|z) against Q^S(x|z)
            # Note2: discrete dimension is self.submodel - 1 [possibly?]
            # Only the first vae_tm1.num_discrete noise columns are passed
            # to the teacher.
            rnd_sample = self.rnd_sample[:, 0:self.vae_tm1.num_discrete]
            _, _, _, self.q_z_t_given_x_t, _, _ \
                = self.vae_tm1.encoder(self.xhat_tm1,
                                       rnd_sample=rnd_sample,
                                       hard=False,  # True?
                                       reuse=True)

            # Get the number of gaussians for student and teacher
            # We also only consider num_old_data of the batch
            self.q_z_t_given_x_t = self.q_z_t_given_x_t[0:self.num_old_data]
            # zero-pad whichever posterior has fewer categories so that
            # both have the same width
            self.q_z_s_given_x_t, self.q_z_t_given_x_t \
                = VAE.zero_pad_smaller_cat(self.q_z_s_given_x_t,
                                           self.q_z_t_given_x_t)

            # Now we ONLY want eval the KL on the discrete z
            # below is the reverse KL:
            # (P = student, Q = teacher in kl_categorical)
            kl = self.kl_categorical(q=self.q_z_t_given_x_t,
                                     p=self.q_z_s_given_x_t)

            # forward KL :
            # kl = self.kl_categorical(q=self.q_z_s_given_x_t,
            #                          p=self.q_z_t_given_x_t)
            print 'kl_consistency [prepad] : ', kl.get_shape().as_list()
            # rows holding current data carry no consistency penalty
            kl = [tf.zeros([self.num_current_data]), kl]
            self.kl_consistency = tf.concat(axis=0, values=kl)
        else:
            # first model / nothing replayed: no consistency term
            self.q_z_given_x = tf.zeros_like(self.x)
            self.kl_consistency = tf.zeros([self.batch_size], dtype=tf.float32)

    @staticmethod
    def reparameterize(encoded, num_discrete, tau, hard=False,
                       rnd_sample=None, eps=1e-20):
        eshp = encoded.get_shape().as_list()
        print("encoded = ", eshp)
        num_normal = eshp[1] - num_discrete
        print 'num_normal = ', num_normal
        logits_normal = encoded[:, 0:num_normal]
        logits_gumbel = encoded[:, num_normal:eshp[1]]

        # we reparameterize using both the N(0, I) and the gumbel(0, 1)
        z_discrete, kl_discrete = gumbel_reparmeterization(logits_gumbel,
                                                           tau,
                                                           rnd_sample,
                                                           hard)
        z_n, kl_n = gaussian_reparmeterization(logits_normal)

        # merge and pad appropriately
        z = tf.concat([z_n, z_discrete], axis=1)

        return [slim.flatten(z),
                slim.flatten(z_n),
                slim.flatten(z_discrete),
                slim.flatten(tf.nn.softmax(logits_gumbel)),
                kl_n,
                kl_discrete]

    def encoder(self, X, rnd_sample=None, reuse=False, hard=False):
        """Encode X and reparameterize; returns what `reparameterize` returns.

        Runs inside the '<model name>/encoder' variable scope; pass
        reuse=True to share the already created encoder variables.
        """
        scope_name = self.get_name() + "/encoder"
        with tf.variable_scope(scope_name, reuse=reuse):
            hidden = forward(X, self.encoder_model)
            outputs = VAE.reparameterize(hidden, self.num_discrete,
                                         self.tau, hard=hard,
                                         rnd_sample=rnd_sample)
        return outputs

    def generator(self, Z, reuse=False):
        with tf.variable_scope(self.get_name() + "/generator", reuse=reuse):
            print 'generator scope: ', tf.get_variable_scope().name
            logits = forward(Z, self.decoder_model)

            if self.p_x_given_z_func == distributions.Bernoulli:
                print 'generator: using bernoulli'
                return self.p_x_given_z_func(logits=logits)
            elif (self.p_x_given_z_func == distributions.Normal or self.p_x_given_z_func == distributions.Logistic) \
                 and self.encoder_model.layer_type == 'cnn':
                print 'generator: using exponential family [cnn]'
                channels = shp(logits)[3]
                assert channels % 2 == 0, "need to project to 2x the channels for gaussian p(x|z)"
                loc = logits[:, :, :, channels/2:]  # tf.nn.sigmoid(logits[:, :, :, channels/2:])
                scale = 1e-6 + tf.nn.softplus(logits[:, :, :, 0:channels/2])
                return self.p_x_given_z_func(loc=loc,
                                             scale=scale)
            elif (self.p_x_given_z_func == distributions.Normal or self.p_x_given_z_func == distributions.Logistic) \
                 and self.encoder_model.layer_type == 'dnn':
                print 'generator: using exponential family [dnn]'
                features = shp(logits)[-1]
                assert features % 2 == 0, "need to project to 2x the channels for gaussian p(x|z)"
                loc = logits[:, features/2:]  # tf.nn.sigmoid(logits[:, :, :, channels/2:])
                scale = 1e-6 + tf.nn.softplus(logits[:, 0:features/2])
                return self.p_x_given_z_func(loc=loc,
                                             scale=scale)
            else:
                raise Exception("unknown distribution provided for likelihood")

    def _augment_data(self):
        '''
        Augments [current_data ; old_data]
        '''
        def _train():
            if hasattr(self, 'xhat_tm1'):  # make sure we have forked
                # zero pad the current data on the bottom and add to
                # the data we generated in _generate_vae_tm1_data()
                full_data = [self.x[0:self.num_current_data],
                             self.xhat_tm1[0:self.num_old_data]]
                combined = tf.concat(axis=0, values=full_data,
                                     name="current_data")
            else:
                combined = self.x

            print 'augmented data = ', combined.get_shape().as_list()
            return combined

        def _test():
            return self.x

        #return tf.cond(self.is_training, _train, _test)
        return _train()

    def generate_at_least(self, vae_tm1, batch_size):
        # Returns :
        # 1) a categorical and a Normal distribution concatenated
        # 2) x_hat_tm1 : the reconstructed data from the old model
        print 'generating data from previous #discrete: ', vae_tm1.num_discrete
        z_cat = generate_random_categorical(vae_tm1.num_discrete,
                                            batch_size)
        z_normal = tf.random_normal([batch_size, vae_tm1.latent_size])
        z = tf.concat([z_normal, z_cat], axis=1)
        zshp = z.get_shape().as_list()  # TODO: debug trace
        print 'z_generated = ', zshp

        # Generate reconstructions of historical Z's
        # xr = tf.stop_gradient(tf.nn.sigmoid(vae_tm1.generator(z, reuse=True)))
        p_x_given_z_tm1 = vae_tm1.generator(z, reuse=True)
        return [z, z_cat, p_x_given_z_tm1.mean()]

    def _generate_vae_tm1_data(self):
        if self.vae_tm1 is not None:
            num_instances = self.x.get_shape().as_list()[0]
            self.num_current_data = int((1.0/(self.total_true_models + 1.0))
                                        * float(num_instances))
            self.num_old_data = num_instances - self.num_current_data
            # TODO: Remove debug trace
            print 'total instances: %d | current_model: %d | current_true_models: %d | current data number: %d | old data number: %d'\
                % (num_instances, self.submodel, self.total_true_models,
                   self.num_current_data, self.num_old_data)

            if self.num_old_data > 0:  # make sure we aren't in base case
                # generate data by randomly sampling a categorical for
                # N-1 positions; also sample a N(0, I) in order to
                # generate variability
                self.z_tm1, self.z_discrete_tm1, self.xhat_tm1 \
                    = self.generate_at_least(self.vae_tm1,
                                             self.batch_size)

                print 'z_tm1 = ', self.z_tm1.get_shape().as_list(), \
                    '| xhat_tm1 = ', self.xhat_tm1.get_shape().as_list()

    @staticmethod
    def _z_to_one_hot(z, latent_size):
        """Return a float32 one-hot of depth `latent_size` at argmax(z, axis 1).

        Uses tf.argmax (as the rest of this file does) instead of the
        deprecated alias tf.arg_max.
        """
        indices = tf.argmax(z, 1)
        return tf.one_hot(indices, latent_size, dtype=tf.float32)

    def _shuffle_all_data_together(self):
        if not hasattr(self, 'shuffle_indices'):
            self.shuffle_indices = np.random.permutation(self.batch_size)

        if self.vae_tm1 is not None:
            # we get the total size of the cols and jointly shuffle
            # using the perms generated above.
            self.x_augmented = shuffle_rows_based_on_indices(self.shuffle_indices,
                                                             self.x_augmented)

    '''
    Helper op to create the network structure
    '''
    def _create_network(self, num_test_memories=10):
        """Build the encoder -> latent -> generator part of the graph.

        Sets num_current_data, rnd_sample, x_augmented, the encoder outputs
        (z, z_normal, z_discrete, z_pre_gumbel, kl_normal, kl_discrete) and
        p_x_given_z.

        Args:
            num_test_memories: not used by the current body.
        """
        # default: the whole batch is current data; this is overwritten in
        # _generate_vae_tm1_data() when a previous model exists
        self.num_current_data = self.x.get_shape().as_list()[0]

        # use the same rnd_sample for all the discrete generations
        self.rnd_sample = sample_gumbel([self.x.get_shape().as_list()[0],
                                         self.num_discrete])

        # generate & shuffle data together
        self._generate_vae_tm1_data()
        self.x_augmented = self._augment_data()
        # the augmented batch must keep the static shape of the input
        assert self.x_augmented.get_shape().as_list() \
            == self.x.get_shape().as_list()
        # TODO: self._shuffle_all_data_together() possible?

        # run the encoder operation
        # (z_pre_gumbel is the softmax over the discrete logits)
        self.z, \
            self.z_normal,\
            self.z_discrete, \
            self.z_pre_gumbel, \
            self.kl_normal, \
            self.kl_discrete = self.encoder(self.x_augmented,
                                            rnd_sample=self.rnd_sample)
        print 'z_encoded = ', self.z.get_shape().as_list()
        print 'z_discrete = ', self.z_discrete.get_shape().as_list()
        print 'z_normal = ', self.z_normal.get_shape().as_list()

        # reconstruct x via the generator & run activation
        # (generator() returns a distribution object, not raw logits)
        #self.p_x_given_z_logits = self.generator(self.z)
        self.p_x_given_z = self.generator(self.z)
        print 'pxgivenz = ', shp(self.p_x_given_z.mean())
        # self.x_reconstr_mean_activ = tf.nn.sigmoid(self.x_reconstr_mean)

    def _loss_helper(self, truth, pred):
        """Per-example reconstruction loss, summed over the data axes.

        Uses the sigmoid cross-entropy when reconstr_loss_type is
        "binary_cross_entropy", the squared error otherwise. Inputs of rank
        greater than 3 are reduced over axes [1, 2, 3], everything else
        over axis [1].
        """
        if self.reconstr_loss_type == "binary_cross_entropy":
            loss_fn = self._cross_entropy
        else:
            loss_fn = self._l2_loss
        per_element = loss_fn(truth, pred)

        rank = len(truth.get_shape().as_list())
        axes = [1, 2, 3] if rank > 3 else [1]
        return tf.reduce_sum(per_element, axes)

    @staticmethod
    def _cross_entropy(x, x_reconstr):
        """Element-wise sigmoid cross-entropy.

        `x` holds the labels and `x_reconstr` the logits; the computation is
        delegated to tf.nn.sigmoid_cross_entropy_with_logits.
        """
        return tf.nn.sigmoid_cross_entropy_with_logits(labels=x,
                                                       logits=x_reconstr)

    @staticmethod
    def _l2_loss(x, x_reconstr):
        """Element-wise squared error between x and x_reconstr."""
        residual = x - x_reconstr
        return tf.square(residual)

    @staticmethod
    def mutual_information_bernouilli_cat(Q_z_given_x_softmax, eps=1e-9):
        """Mutual-information style regularizer for the categorical posterior.

        A random categorical sample with the same [rows, classes] layout as
        the posterior is drawn and two batch means are added together:
          * the cross term  -sum(log(Q + eps) * sample)
          * the sample term -sum(log(sample + eps) * sample)
        """
        qzshp = Q_z_given_x_softmax.get_shape().as_list()
        num_rows, num_classes = qzshp[0], qzshp[1]
        prior_sample = generate_random_categorical(num_classes, num_rows)

        # conditional-entropy term: posterior log-probs weighted by sample
        cross_term = -tf.reduce_sum(tf.log(Q_z_given_x_softmax + eps)
                                    * prior_sample, 1)
        cond_ent = tf.reduce_mean(cross_term)

        # entropy term computed on the sample itself
        prior_term = -tf.reduce_sum(tf.log(prior_sample + eps)
                                    * prior_sample, 1)
        ent = tf.reduce_mean(prior_term)

        return cond_ent + ent

    # @staticmethod
    # def mutual_information_bernouilli_cat(bern_logits, cat_probs, eps=1e-9):
    #     '''
    #     I(\hat{X} ; Z) = H(Z) - H(Z | \hat{X}) = H(\hat{X}) - H(\hat{X} | Z)
    #     '''
    #     p_x_given_z = distributions.Bernoulli(logits=bern_logits,
    #                                           dtype=tf.float32)
    #     q_z = distributions.Categorical(probs=cat_probs + eps,
    #                                     dtype=tf.float32)
    #     # TODO: debug traces
    #     # print 'q_z_entropy = ', q_z.entropy().get_shape().as_list()
    #     # print 'p_x_given_z.entropy() = ', p_x_given_z.entropy().get_shape().as_list()
    #     return q_z.entropy() - tf.reduce_sum(p_x_given_z.entropy(), 1)

    def vae_loss(self, x, p_x_given_z, latent_kl, consistency_kl):
        """Assemble the per-example cost and its batch means.

        Args:
            x: the data the likelihood is evaluated on.
            p_x_given_z: likelihood distribution (must provide `log_prob`).
            latent_kl: per-example KL of the latent posterior.
            consistency_kl: per-example consistency (distillation) KL.

        Returns:
            [log_likelihood, log_likelihood_mean, latent_loss_mean,
             cost, cost_mean, elbo_mean] where
              elbo = -log_likelihood + latent_kl          (negative ELBO)
              cost = elbo + consistency_kl
                     - mutual_info_reg * mutual-information term
            Note that the first two entries are log-likelihoods, i.e. they
            are NOT negated.
        """
        # the loss is composed of two terms:
        # 1.) the reconstruction term: the log probability of the input
        #     under the distribution induced by the decoder in data space,
        #     summed over the data axes of each example.
        channels = x.get_shape().as_list()
        reduction_indices = [1, 2, 3] if len(channels) > 3 else [1]
        # Use the distribution that was passed in; previously the argument
        # was ignored in favour of self.p_x_given_z.
        log_likelihood = tf.reduce_sum(p_x_given_z.log_prob(x),
                                       reduction_indices)

        # the regularizer is always computed on this model's own
        # categorical posterior
        mutual_info_regularizer \
            = VAE.mutual_information_bernouilli_cat(self.z_pre_gumbel)

        # 2.) the latent loss: the kullback leibler divergence between the
        #     distribution in latent space induced by the encoder on the
        #     data and some prior. this acts as a kind of regularizer.
        elbo = -log_likelihood + latent_kl
        cost = elbo + consistency_kl - self.mutual_info_reg * mutual_info_regularizer

        # create the reductions only once
        latent_loss_mean = tf.reduce_mean(latent_kl)
        log_likelihood_mean = tf.reduce_mean(log_likelihood)
        elbo_mean = tf.reduce_mean(elbo)
        cost_mean = tf.reduce_mean(cost)

        return [log_likelihood, log_likelihood_mean,
                latent_loss_mean, cost, cost_mean,
                elbo_mean]

    def _create_loss_optimizer(self):
        """Build the consistency term, the total loss and the training op.

        Sets latent_kl, reconstr_loss(_mean), latent_loss_mean, cost(_mean),
        elbo_mean and optimizer. Only trainable variables whose name starts
        with this model's name are optimized.
        """
        # build constraint graph
        self._create_constraints()

        model_name = self.get_name()
        with tf.variable_scope(model_name + "/loss_optimizer"):
            self.latent_kl = self.kl_normal + self.kl_discrete

            # tabulate total loss
            loss_terms = self.vae_loss(self.x_augmented,
                                       self.p_x_given_z,
                                       self.latent_kl,
                                       self.kl_consistency)
            (self.reconstr_loss, self.reconstr_loss_mean,
             self.latent_loss_mean,
             self.cost, self.cost_mean, self.elbo_mean) = loss_terms

            # construct our optimizer over this model's variables only
            own_vars = [var for var in tf.trainable_variables()
                        if var.name.startswith(model_name)]
            self.optimizer = self._create_optimizer(own_vars,
                                                    self.cost_mean,
                                                    self.learning_rate)

    def _create_optimizer(self, tvars, cost, lr):
        """Create the Adam update op for `cost` with respect to `tvars`.

        As a side effect the norm of all (flattened, concatenated)
        gradients is stored in `self.grad_norm`.

        Args:
            tvars: list of variables to differentiate against and update.
            cost: tensor to minimize; its `name` is printed.
            lr: learning rate passed to `tf.train.AdamOptimizer`.

        Returns:
            The op produced by `optimizer.apply_gradients`.
        """
        optimizer = tf.train.AdamOptimizer(learning_rate=lr)

        print('there are %d trainable vars in cost %s\n'
              % (len(tvars), cost.name))
        grads = tf.gradients(cost, tvars)

        # tf.gradients yields None for a variable that `cost` does not
        # depend on; reshaping None raises, so such entries are left out
        # of the norm instead of aborting graph construction.
        flat_grads = [tf.reshape(grad, [-1])
                      for grad in grads if grad is not None]
        self.grad_norm = tf.norm(tf.concat(flat_grads, axis=0))

        # the (grad, var) pairs are passed on unchanged, exactly as before
        return optimizer.apply_gradients(zip(grads, tvars))

    def partial_fit(self, inputs, iteration_print=10,
                    iteration_save_imgs=2000,
                    is_forked=False, summary="train"):
        """Run the model on one mini-batch and return its mean losses.

        With summary == "train" the optimizer and the iteration op are run
        together with the losses, and the gumbel-softmax temperature is
        re-annealed on every 10th iteration (iteration > 0). Any other
        value only evaluates the losses and logs to the test summary
        writer. The iteration counter is advanced on every call, for
        training and evaluation alike.

        Args:
            inputs: batch fed into `self.x`.
            iteration_print: `self.summaries` is written whenever the
                iteration counter is a multiple of this value.
            iteration_save_imgs: `self.image_summaries` is written whenever
                the iteration counter is a multiple of this value; it takes
                precedence over the regular summaries. Values <= 0 disable
                image summaries.
            is_forked: not used by this method.
            summary: "train" for a training step, anything else for an
                evaluation step.

        Returns:
            Tuple `(cost, elbo, rloss, lloss)` with the evaluated
            `cost_mean`, `elbo_mean`, `reconstr_loss_mean` and
            `latent_loss_mean`.

        Raises:
            Exception: any error raised while running the step is printed
                and re-raised. (Previously it was swallowed and the return
                statement then failed with an UnboundLocalError that hid
                the real cause.)
        """
        is_train = summary == "train"

        # NOTE: the feed is assembled before tau is re-annealed below, so a
        # freshly annealed temperature is first fed on the following call.
        feed_dict = {self.x: inputs,
                     self.is_training: is_train,
                     self.tau: self.tau_host}
        writer = self.train_summary_writer if is_train \
            else self.test_summary_writer

        try:
            # update tau for gumbel-softmax
            if is_train and self.iteration > 0 and self.iteration % 10 == 0:
                rate = -self.anneal_rate*self.iteration
                self.tau_host = np.maximum(self.tau0 * np.exp(rate),
                                           self.min_temp)
                # single-argument form prints the same text as the former
                # `print 'updated tau to ', value` statement
                print('%s %s' % ('updated tau to ', self.tau_host))

            loss_ops = [self.cost_mean, self.elbo_mean,
                        self.reconstr_loss_mean,
                        self.latent_loss_mean]

            # a training step additionally runs the update + iteration ops
            if is_train:
                head_ops = [self.optimizer, self.iteration_gpu_op]
            else:
                head_ops = []

            # at most one summary op is fetched; images take precedence
            if iteration_save_imgs > 0 \
               and self.iteration % iteration_save_imgs == 0:
                summary_ops = [self.image_summaries]
            elif self.iteration % iteration_print == 0:
                summary_ops = [self.summaries]
            else:
                summary_ops = []

            fetched = self.sess.run(head_ops + loss_ops + summary_ops,
                                    feed_dict=feed_dict)

            # the losses sit right after the (optional) training ops
            first = len(head_ops)
            cost, elbo, rloss, lloss = fetched[first:first + len(loss_ops)]

            if summary_ops:
                writer.add_summary(fetched[-1], self.iteration)
        except Exception as e:
            print('%s %s' % ('caught exception in partial fit: ', e))
            raise
        finally:
            # the counter advances even when the step failed, as before
            self.iteration += 1

        return cost, elbo, rloss, lloss

    def write_classes_to_file(self, filename, all_classes):
        """Evaluate `all_classes` in the session and append it to `filename`.

        The values are written with `np.savetxt` using "," as delimiter;
        the file is opened in append mode, so earlier content is kept.
        """
        with open(filename, 'a') as out_file:
            class_values = self.sess.run(all_classes)
            np.savetxt(out_file, class_values, delimiter=",")

    def build_new_encoder_decoder_pair(self, num_new_classes=1):
        """Build a fresh encoder/decoder pair for an enlarged latent space.

        The new encoder outputs
        `2*latent_size + num_discrete + num_new_classes` values
        (presumably two parameters per continuous latent plus one logit
        per discrete class -- confirm against the encoder users).
        Layer-/batch-norm flags are taken over from the current models.

        Args:
            num_new_classes: number of discrete classes to add.

        Returns:
            Tuple `(encoder, decoder)`. Dense models are returned unless
            the current encoder's `layer_type` equals 'cnn'.
        """
        updated_latent_size = 2*self.latent_size \
                              + self.num_discrete \
                              + num_new_classes

        # Compare by value: `is not 'cnn'` tested object identity and only
        # worked when the string happened to be interned.
        if self.encoder_model.layer_type != 'cnn':
            layer_sizes = self.encoder_model.sizes

            encoder = DenseEncoder(self.sess, updated_latent_size,
                                   self.is_training,
                                   scope="encoder",
                                   sizes=layer_sizes,
                                   use_ln=self.encoder_model.use_ln,
                                   use_bn=self.encoder_model.use_bn)
            is_dec_doubled = self.decoder_model.double_features > 1
            # the dense decoder is a DenseEncoder mapping to input_size;
            # it reuses the encoder's layer sizes
            decoder = DenseEncoder(self.sess, self.input_size,
                                   self.is_training,
                                   scope="decoder",
                                   sizes=layer_sizes,
                                   double_features=is_dec_doubled,
                                   use_ln=self.decoder_model.use_ln,
                                   use_bn=self.decoder_model.use_bn)
        else:
            encoder = CNNEncoder(self.sess, updated_latent_size,
                                 self.is_training,
                                 scope="encoder",
                                 use_ln=self.encoder_model.use_ln,
                                 use_bn=self.encoder_model.use_bn,)
            decoder = CNNDecoder(self.sess,
                                 scope="decoder",
                                 double_channels=self.decoder_model.double_channels,
                                 input_size=self.input_size,
                                 is_training=self.is_training,
                                 use_ln=self.decoder_model.use_ln,
                                 use_bn=self.decoder_model.use_bn)

        return encoder, decoder

    def fork(self, num_new_class=1):
        '''
        Create the successor model of this VAE.

        A new encoder/decoder pair is built, a new VAE with
        `num_new_class` additional discrete classes is constructed with
        this model as its `vae_tm1`, its variables are initialized and the
        encoder and decoder layers of this model are copied into it.

        Note (original author): This is a slow op in tensorflow
              because the session needs to be run

        NOTE(review): `reconstr_loss_type` and `mutual_info_reg` are not
              forwarded, so the successor gets the constructor defaults
              for them -- confirm this is intended.
        '''
        new_encoder, new_decoder \
            = self.build_new_encoder_decoder_pair(num_new_class)
        # single-argument prints emit the same text as `print a, b`
        print('%s %s' % ('encoder = ', new_encoder.get_info()))
        print('%s %s' % ('decoder = ', new_decoder.get_info()))

        successor_args = dict(
            input_size=self.input_size,
            batch_size=self.batch_size,
            latent_size=self.latent_size,
            discrete_size=self.num_discrete + num_new_class,
            encoder=new_encoder,
            decoder=new_decoder,
            p_x_given_z_func=self.p_x_given_z_func,
            is_training=self.is_training,
            activation=self.activation,
            learning_rate=self.learning_rate,
            submodel=self.submodel+1,
            total_true_models=self.total_true_models+num_new_class,
            vae_tm1=self,
            img_shape=self.img_shape,
            base_dir=self.base_dir)
        successor = VAE(self.sess, self.x, **successor_args)

        # reset the successor's weights and biases to their defaults
        # first; whatever can be copied over is copied afterwards
        self.sess.run([successor.init_op])

        # copying the encoder and decoder layers helps convergence time
        old_name = self.get_name()
        copy_layer(self.sess, self.encoder_model, old_name,
                   new_encoder, successor.get_name())
        copy_layer(self.sess, self.decoder_model, self.get_name(),
                   new_decoder, successor.get_name())

        return successor

    def transform(self, X):
        """Map the data `X` into the latent space by evaluating `self.z`.

        The current host-side temperature is fed for `self.tau` and
        `is_training` is fed as False.
        """
        # NOTE(review): the original comment states that this maps to the
        # mean of the distribution (rather than a sample); what is fetched
        # is `self.z` -- confirm.
        inference_feed = {self.x: X,
                          self.tau: self.tau_host,
                          self.is_training: False}
        return self.sess.run(self.z, feed_dict=inference_feed)

    def generate(self, z=None):
        """ Generate data by decoding a latent code.

        If z is not None it is fed in place of `self.z`. Otherwise a code
        is drawn with
        `generate_random_categorical(self.latent_size, self.batch_size)`.

        Returns the evaluated mean of `p_x_given_z`.
        """
        latent_code = z
        if latent_code is None:
            latent_code = generate_random_categorical(self.latent_size,
                                                      self.batch_size)

        # the mean of the output distribution is returned; sampling from
        # it would be the alternative
        mean_op = self.p_x_given_z.mean()
        decode_feed = {self.z: latent_code,
                       self.tau: self.tau_host,
                       self.is_training: False}
        return self.sess.run(mean_op, feed_dict=decode_feed)

    def reconstruct(self, X, return_losses=False):
        """ Reconstruct the data `X` with the VAE.

        Returns the evaluated mean of `p_x_given_z`. With `return_losses`
        set, a list is returned instead: the mean followed by
        reconstr_loss, reconstr_loss_mean, latent_kl, latent_loss_mean,
        cost, cost_mean and elbo_mean.
        """
        ops = self.p_x_given_z.mean()
        if return_losses:
            ops = [ops,
                   self.reconstr_loss, self.reconstr_loss_mean,
                   self.latent_kl, self.latent_loss_mean,
                   self.cost, self.cost_mean, self.elbo_mean]

        eval_feed = {self.x: X,
                     self.tau: self.tau_host,
                     self.is_training: False}
        return self.sess.run(ops, feed_dict=eval_feed)

    def test(self, source, batch_size, iteration_save_imgs=10):
        """Evaluate the model on `source` and print the averaged losses.

        `int(num_examples / batch_size)` batches are drawn from `source`
        and pushed through `partial_fit` in "test" mode with summaries
        written for every batch. Image summaries are requested only when
        the current test epoch is a multiple of `iteration_save_imgs`.
        Afterwards the test epoch counter is advanced and the weighted
        sums of the four losses are printed. Nothing is returned.
        """
        n_samples = source.num_examples
        num_batches = int(n_samples / batch_size)
        avg_cost = avg_elbo = avg_recon = avg_latent = 0.

        for _ in range(num_batches):
            batch_xs, _labels = source.next_batch(batch_size)

            # 1 makes partial_fit save images on every call, -1 never;
            # images are only wanted on the Nth test epoch
            save_imgs_now = self.test_epoch % iteration_save_imgs == 0
            iteration_save_imgs_pf = 1 if save_imgs_now else -1

            cost, elbo, recon_cost, latent_cost \
                = self.partial_fit(batch_xs, summary="test",
                                   iteration_print=1,  # always print
                                   iteration_save_imgs=iteration_save_imgs_pf)

            # every batch contributes with weight batch_size / n_samples
            avg_cost += cost / n_samples * batch_size
            avg_elbo += elbo / n_samples * batch_size
            avg_recon += recon_cost / n_samples * batch_size
            avg_latent += latent_cost / n_samples * batch_size

        self.test_epoch += 1

        # joined with single spaces: the same text the multi-item print
        # statement produced
        report = ["[Test]",
                  "avg cost = ", "{:.4f} | ".format(avg_cost),
                  "avg latent cost = ", "{:.4f} | ".format(avg_latent),
                  "avg elbo loss = ", "{:.4f} | ".format(avg_elbo),
                  "avg recon loss = ", "{:.4f}".format(avg_recon)]
        print(" ".join(report))

    def train(self, source, batch_size, training_epochs=10, display_step=5):
        """Train the model on `source.train` for a number of epochs.

        Every epoch draws `int(num_examples / batch_size)` batches and
        passes each one to `partial_fit`. On every `display_step`-th epoch
        the last batch cost and the weighted sums of the four losses are
        printed.

        Args:
            source: object whose `train` attribute offers `num_examples`
                and `next_batch(batch_size)` returning a pair whose first
                element is the batch.
            batch_size: number of examples per batch.
            training_epochs: number of passes over the data.
            display_step: epochs between two log lines.
        """
        n_samples = source.train.num_examples
        # loop invariant: number of full batches per epoch
        total_batch = int(n_samples / batch_size)

        # Without a single full batch `cost` was never bound and the log
        # line raised UnboundLocalError; it is reported as nan instead.
        cost = float('nan')

        for epoch in range(training_epochs):
            avg_cost = avg_elbo = avg_recon = avg_latent = 0.

            # Loop over all batches
            for _ in range(total_batch):
                batch_xs, _labels = source.train.next_batch(batch_size)

                # Fit training using batch data
                cost, elbo, recon_cost, latent_cost \
                    = self.partial_fit(batch_xs)

                # every batch contributes with weight batch_size / n_samples
                avg_cost += cost / n_samples * batch_size
                avg_elbo += elbo / n_samples * batch_size
                avg_recon += recon_cost / n_samples * batch_size
                avg_latent += latent_cost / n_samples * batch_size

            # Display logs per epoch step
            if epoch % display_step == 0:
                # joined with single spaces: the same text the multi-item
                # print statement produced
                report = ["[Epoch:", '%04d]' % (epoch+1),
                          "current cost = ", "{:.4f} | ".format(cost),
                          "avg cost = ", "{:.4f} | ".format(avg_cost),
                          "avg elbo = ", "{:.4f} | ".format(avg_elbo),
                          "avg latent = ", "{:.4f} | ".format(avg_latent),
                          "avg recon = ", "{:.4f}".format(avg_recon)]
                print(" ".join(report))
back to top