https://github.com/shaharharel/CDN_Molecule
model.py (revision d8cbe5ecf70e05d77b3ed2b60632720d2baee805, authored by shaharharel on 13 February 2018)
import tensorflow as tf


class CDN:
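    """Convolutional encoder with an LSTM decoder over tokenized molecule strings.

    Summary inferred from the graph construction below (not separate author
    documentation): character ids are embedded, encoded with parallel
    convolutions, and compressed to a 300-dim latent code (variational by
    default); decode_rnn then unrolls an LSTM for max_molecule_length steps
    to reconstruct the input sequence.
    """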

    def __init__(self, sequence_length, vocab_size, embedding_size, filter_sizes, num_filters, max_molecule_length,
                 gaussian_samples, variational=True, l2_reg_lambda=0.5, generation_mode=False, test_mode=False):
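        """Store hyperparameters and declare the input placeholders.

        gaussian_samples is the latent dimensionality; its placeholder is fed
        unit-Gaussian noise for the reparameterization step in encode(), so it
        must match the 300-dim latent code hard-coded there.
        """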

        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.filter_sizes = filter_sizes
        self.num_filters = num_filters
        self.max_molecule_length = max_molecule_length
        self.l2_reg_lambda = l2_reg_lambda
        self.gaussian_samples_dim = gaussian_samples
        self.variational = variational
        # Two views of the input: one prefixed with a GO token (fed to the
        # decoder for teacher forcing) and one without (the reconstruction target)
        self.encoder_input_GO = tf.placeholder(tf.int32, [None, sequence_length], name="encoder_input_GO")
        self.encoder_input = tf.placeholder(tf.int32, [None, sequence_length], name="encoder_input")  # no GO token
        self.gaussian_samples = tf.placeholder(tf.float32, [None, self.gaussian_samples_dim], name="unit_gaussians")
        self.generation_mode = generation_mode
        self.test_mode = test_mode

    def encode(self):
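        """Build the encoder and return (latent_code, mean_latent_loss).

        Pipeline: embedding lookup -> parallel convolutions (one per filter
        size) -> dense layer -> 300-dim latent code, with the mean KL term as
        the latent loss in variational mode.
        """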
        # Embedding layer
        with tf.name_scope("embedding"):
            self.E = tf.Variable(tf.random_uniform([self.vocab_size, self.embedding_size], -1.0, 1.0), name="W")
            self.embedded_chars = tf.nn.embedding_lookup(self.E, self.encoder_input)
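            # conv2d expects a 4-D input, so add a singleton channel dimension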
            self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)
            self.embedded_chars_go = tf.nn.embedding_lookup(self.E, self.encoder_input_GO)

        # Create a convolution layer for each filter size
        conv_flatten = []
        for i, filter_size in enumerate(self.filter_sizes):
            with tf.name_scope("conv-maxpool-%s" % filter_size):
                # Convolution Layer
                filter_shape = [filter_size, self.embedding_size, 1, self.num_filters]
                W = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1), name="W")
                b = tf.Variable(tf.constant(0.1, shape=[self.num_filters]), name="b")
                conv = tf.nn.conv2d(
                    self.embedded_chars_expanded,
                    W,
                    strides=[1, 1, 1, 1],
                    padding="VALID",
                    name="conv")
                h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
                conv_flatten.append(tf.contrib.layers.flatten(h))
        conv_output = tf.concat(conv_flatten, axis=1)

        # Project the concatenated convolution features to a dense hidden layer
        h_pool_flat3 = tf.nn.relu(tf.contrib.layers.linear(conv_output, 450))

        if self.variational:
            with tf.name_scope("Variational"):
                self.z_mean = tf.contrib.layers.linear(h_pool_flat3, 300)
                self.z_stddev = tf.contrib.layers.linear(h_pool_flat3, 300)
                # KL divergence between N(z_mean, z_stddev) and the unit Gaussian prior
                latent_loss = 0.5 * tf.reduce_sum(tf.square(self.z_mean) + tf.square(self.z_stddev) -
                                                  tf.log(tf.square(self.z_stddev)) - 1, 1)
                self.mean_latent_loss = tf.reduce_mean(latent_loss)
                if self.generation_mode:
                    # Generation: sample straight from the prior
                    h_pool_flat = self.gaussian_samples
                else:
                    # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
                    h_pool_flat = self.z_mean + (self.z_stddev * self.gaussian_samples)

                h_pool_flat = tf.identity(h_pool_flat, "encoded_final")
        else:
            # Deterministic fallback so both return values are defined when
            # variational=False: project to the same 300-dim latent width and
            # report a zero latent loss
            h_pool_flat = tf.identity(tf.contrib.layers.linear(h_pool_flat3, 300), "encoded_final")
            self.mean_latent_loss = tf.constant(0.0)

        return h_pool_flat, self.mean_latent_loss

    def decode_rnn(self, z):
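        """Unroll an LSTM decoder from the latent code z.

        Training teacher-forces with the embedded GO-prefixed inputs; test
        mode feeds back the argmax of the previous step's logits. Returns
        logits of shape [batch, max_molecule_length, vocab_size].
        """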
        def pick_next_argmax(former_output, step):
            # Greedy decoding: feed back the embedding of the most likely symbol
            next_symbol = tf.expand_dims(tf.stop_gradient(tf.argmax(former_output, 1)), axis=-1)
            return tf.nn.embedding_lookup(self.E, next_symbol), next_symbol

        def pick_next_top_k(former_output, step):
            # Stochastic decoding; note this samples from the full softmax
            # distribution rather than a top-k subset, and is unused below
            next_symbol = tf.multinomial(former_output, 1)
            return tf.nn.embedding_lookup(self.E, next_symbol), next_symbol

        with tf.name_scope("Decoder"):
            # decode_start is computed but not used below; z itself seeds the LSTM state
            self.decode_start = tf.nn.relu(tf.contrib.layers.linear(z, 150))
            decoder_inputs_list = tf.split(self.embedded_chars_go, self.max_molecule_length, axis=1)
            decoder_inputs_list = [tf.squeeze(i, axis=1) for i in decoder_inputs_list]
            rnn_cell = tf.nn.rnn_cell.LSTMCell(150, state_is_tuple=False)
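            # With state_is_tuple=False the LSTM state is the concatenated
            # [c, h] vector (2 * 150 = 300), which is why the 300-dim z can
            # seed the state directly at the first step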

            self.lstm_outputs = []
            temp_logits = []
            self.all_symbols = []
            symbol = tf.ones(1)  # placeholder; overwritten during test-mode decoding
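            # Step 0 always consumes the GO input; later steps are teacher-forced
            # in training and fed back their own predictions in test mode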
            for i in range(self.max_molecule_length):
                if not self.test_mode or i == 0:
                    if i == 0:
                        output, state = rnn_cell(decoder_inputs_list[i], state=z)
                    else:
                        output, state = rnn_cell(decoder_inputs_list[i], state=state)
                else:
                    next_decoder_input, symbol = pick_next_argmax(temp_logits[-1], i)
                    next_decoder_input = tf.squeeze(next_decoder_input, axis=1)
                    output, state = rnn_cell(next_decoder_input, state=state)
                with tf.variable_scope("decoder_output_to_logits") as scope_logits:
                    if i > 0:
                        scope_logits.reuse_variables()
                    temp_logits.append(tf.contrib.layers.linear(output, self.vocab_size))

                self.lstm_outputs.append(output)
                if i > 0:
                    self.all_symbols.append(symbol)
                if i == self.max_molecule_length - 1 and self.test_mode:
                    self.all_symbols.append(pick_next_argmax(temp_logits[-1], i)[1])
            if self.test_mode:
                self.all_symbols = tf.squeeze(tf.transpose(tf.stack(self.all_symbols), [1, 0, 2]), axis=-1)

            self.decoder_logits = tf.transpose(tf.stack(temp_logits), perm=[1, 0, 2])
            self.decoder_prediction = tf.argmax(self.decoder_logits, 2, name="decoder_predictions")

            return self.decoder_logits

    def loss(self, logits, latent_loss):
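        """Combine reconstruction cross-entropy with the weighted latent loss,
        and build a per-character reconstruction accuracy."""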
        with tf.name_scope("loss"):
            self.output_onehot = tf.one_hot(self.encoder_input, self.vocab_size)
            self.losses = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=self.output_onehot)
            self.CE_loss = tf.reduce_mean(self.losses)
            self.total_loss = self.CE_loss + 1e-5 * latent_loss  # small fixed weight on the latent (KL) term

        with tf.name_scope("accuracy"):
            decoder_prediction = tf.argmax(logits, 2, name="decoder_predictions")
            x_target = tf.to_int64(self.encoder_input)
            correct_predictions = tf.equal(decoder_prediction, x_target)
            self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")

        return self.total_loss, self.accuracy
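

# Minimal usage sketch (illustrative, not part of the repository): wires the
# graph together and runs one training step on dummy data. All hyperparameter
# values and the random batch below are assumptions for demonstration only,
# and the same ids are fed to both placeholders, whereas real input would
# prefix encoder_input_GO with a GO token.
if __name__ == "__main__":
    import numpy as np

    seq_len = 120      # assumed molecule string length (must equal max_molecule_length)
    latent_dim = 300   # must match the latent width hard-coded in encode()

    model = CDN(sequence_length=seq_len, vocab_size=40, embedding_size=64,
                filter_sizes=[3, 4, 5], num_filters=100,
                max_molecule_length=seq_len, gaussian_samples=latent_dim)

    z, latent_loss = model.encode()
    logits = model.decode_rnn(z)
    total_loss, accuracy = model.loss(logits, latent_loss)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(total_loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = np.random.randint(0, 40, size=(8, seq_len))  # dummy token ids
        _, loss_val = sess.run([train_op, total_loss], feed_dict={
            model.encoder_input: batch,
            model.encoder_input_GO: batch,
            model.gaussian_samples: np.random.randn(8, latent_dim),
        })
        print("loss after one step:", loss_val)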