https://github.com/MarkMoHR/virtual_sketching
Revision 958efe45a9120b9d467ba7701efba28c11e38f8f authored by Your Name on 21 August 2021, 08:13 UTC, committed by Your Name on 21 August 2021, 08:13 UTC
1 parent 149a4be
Raw File
Tip revision: 958efe45a9120b9d467ba7701efba28c11e38f8f authored by Your Name on 21 August 2021, 08:13 UTC
Added bilibili links
Tip revision: 958efe4
subnet_tf_utils.py
import tensorflow as tf


def get_initializer(init_method):
    """Return the TensorFlow weight initializer selected by `init_method`.

    Args:
        init_method: One of 'xavier_normal', 'xavier_uniform', 'he_normal',
            'he_uniform', 'lecun_normal' or 'lecun_uniform'.

    Returns:
        A freshly constructed initializer object.

    Raises:
        ValueError: If `init_method` is not one of the names above.
            (ValueError is a subclass of Exception, so callers that catch
            Exception keep working.)
    """
    if init_method == 'xavier_normal':
        return tf.glorot_normal_initializer()
    if init_method == 'xavier_uniform':
        return tf.glorot_uniform_initializer()
    if init_method == 'he_normal':
        return tf.keras.initializers.he_normal()
    if init_method == 'he_uniform':
        return tf.keras.initializers.he_uniform()
    if init_method == 'lecun_normal':
        return tf.keras.initializers.lecun_normal()
    if init_method == 'lecun_uniform':
        return tf.keras.initializers.lecun_uniform()
    # Format the message explicitly; passing two args to the exception would
    # render it as a tuple, e.g. "('Unknown initializer:', 'foo')".
    raise ValueError('Unknown initializer: {}'.format(init_method))


def lrelu(x, leak=0.2, name="lrelu", alt_relu_impl=False):
    """Leaky ReLU activation.

    Args:
        x: Input tensor.
        leak: Slope multiplier applied to `x` in the leaky term.
        name: Variable scope the ops are created under.
        alt_relu_impl: If True, use the arithmetic form
            0.5 * (1 + leak) * x + 0.5 * (1 - leak) * |x|
            instead of max(x, leak * x).

    Returns:
        The activated tensor.
    """
    with tf.variable_scope(name):
        if not alt_relu_impl:
            return tf.maximum(x, leak * x)

        # Arithmetic formulation: evaluates to x for x >= 0 and leak * x for x < 0.
        linear_coef = 0.5 * (1 + leak)
        abs_coef = 0.5 * (1 - leak)
        return linear_coef * x + abs_coef * abs(x)


def batchnorm(input, name='batch_norm', init_method=None):
    """Normalize `input` with statistics computed over axes 0, 1 and 2.

    Creates two per-channel variables under scope `name`: "offset"
    (zero-initialized) and "scale". No moving averages are kept; the moments
    of the current tensor are always used.

    Args:
        input: Tensor with at least 4 dimensions; dimension 3 is taken as the
            channel count (presumably NHWC -- confirm against callers).
        name: Variable scope name.
        init_method: Optional initializer name understood by
            `get_initializer`, used for "scale". When None, "scale" is drawn
            from a normal distribution with mean 1.0 and stddev 0.02.

    Returns:
        The normalized tensor.
    """
    scale_init = (tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
                  if init_method is None else get_initializer(init_method))

    with tf.variable_scope(name):
        # Route through identity so this block shows a single input on the graph.
        x = tf.identity(input)

        n_channels = x.get_shape()[3]
        offset = tf.get_variable("offset", [n_channels], dtype=tf.float32,
                                 initializer=tf.zeros_initializer())
        scale = tf.get_variable("scale", [n_channels], dtype=tf.float32,
                                initializer=scale_init)

        mean, variance = tf.nn.moments(x, axes=[0, 1, 2], keep_dims=False)
        return tf.nn.batch_normalization(x, mean, variance, offset, scale,
                                         variance_epsilon=1e-5)


def layernorm(input, name='layer_norm', init_method=None):
    """Normalize each sample of `input` with statistics over axes 1, 2 and 3.

    Creates per-channel "offset" (zero-initialized) and "scale" variables
    under scope `name`; both are reshaped to [1, 1, C] so that they broadcast
    against the input.

    Args:
        input: Tensor with at least 4 dimensions; dimension 3 is taken as the
            channel count (presumably NHWC -- confirm against callers).
        name: Variable scope name.
        init_method: Optional initializer name understood by
            `get_initializer`, used for "scale". When None, "scale" is drawn
            from a normal distribution with mean 1.0 and stddev 0.02.

    Returns:
        The normalized tensor.
    """
    scale_init = (tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
                  if init_method is None else get_initializer(init_method))

    with tf.variable_scope(name):
        n_channels = input.get_shape()[3]
        offset_var = tf.get_variable("offset", [n_channels], dtype=tf.float32,
                                     initializer=tf.zeros_initializer())
        scale_var = tf.get_variable("scale", [n_channels], dtype=tf.float32,
                                    initializer=scale_init)

        # Broadcastable views of the per-channel parameters.
        beta = tf.reshape(offset_var, [1, 1, -1])
        gamma = tf.reshape(scale_var, [1, 1, -1])

        # keep_dims=True keeps the statistics broadcastable against `input`.
        mean, variance = tf.nn.moments(input, axes=[1, 2, 3], keep_dims=True)
        return tf.nn.batch_normalization(input, mean, variance, beta, gamma,
                                         variance_epsilon=1e-5)


def instance_norm(input, name="instance_norm", init_method=None):
    """Normalize `input` per sample and per channel (moments over axes 1, 2).

    Creates per-channel "scale" and "offset" (zero-initialized) variables
    under scope `name`.

    Args:
        input: Tensor with at least 4 dimensions; dimension 3 is taken as the
            channel count (presumably NHWC -- confirm against callers).
        name: Variable scope name.
        init_method: Optional initializer name understood by
            `get_initializer`, used for "scale". When None, "scale" is drawn
            from a normal distribution with mean 1.0 and stddev 0.02.

    Returns:
        scale * (input - mean) / sqrt(variance + 1e-5) + offset.
    """
    scale_init = (tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
                  if init_method is None else get_initializer(init_method))

    with tf.variable_scope(name):
        n_channels = input.get_shape()[3]
        scale = tf.get_variable("scale", [n_channels], initializer=scale_init)
        offset = tf.get_variable("offset", [n_channels], initializer=tf.constant_initializer(0.0))

        mean, variance = tf.nn.moments(input, axes=[1, 2], keep_dims=True)
        inv_std = tf.rsqrt(variance + 1e-5)
        centered_scaled = (input - mean) * inv_std
        return scale * centered_scaled + offset


def linear1d(inputlin, inputdim, outputdim, name="linear1d", init_method=None):
    """Fully connected layer: `inputlin @ weight + bias`.

    Args:
        inputlin: 2-D tensor whose last dimension is `inputdim` (it is fed
            straight into tf.matmul).
        inputdim: Number of input features.
        outputdim: Number of output features.
        name: Variable scope holding the "weight" and "bias" variables.
        init_method: Optional initializer name understood by
            `get_initializer` for "weight". When None, no initializer is
            passed and tf.get_variable falls back to its default.

    Returns:
        Tensor with `outputdim` features per row. "bias" starts at zero.
    """
    weight_init = None if init_method is None else get_initializer(init_method)

    with tf.variable_scope(name):
        weight = tf.get_variable("weight", [inputdim, outputdim], initializer=weight_init)
        bias = tf.get_variable("bias", [outputdim], dtype=tf.float32,
                               initializer=tf.constant_initializer(0.0))
        projected = tf.matmul(inputlin, weight)
        return projected + bias


def general_conv2d(inputconv, output_dim=64, filter_height=4, filter_width=4, stride_height=2, stride_width=2,
                   stddev=0.02, padding="SAME", name="conv2d", do_norm=True, norm_type='instance_norm', do_relu=True,
                   relufactor=0, is_training=True, init_method=None):
    """2-D convolution followed by optional normalization and activation.

    Args:
        inputconv: Input feature map (presumably NHWC -- confirm against callers).
        output_dim: Number of output channels.
        filter_height, filter_width: Kernel size along height / width.
        stride_height, stride_width: Stride along height / width.
        stddev: Stddev of the truncated-normal weight initializer used when
            `init_method` is None.
        padding: Padding mode forwarded to the convolution.
        name: Variable scope wrapping the whole layer.
        do_norm: Whether to apply normalization after the convolution.
        norm_type: 'instance_norm', 'batch_norm' or 'layer_norm'. Any other
            value leaves the convolution output un-normalized.
        do_relu: Whether to apply an activation.
        relufactor: 0 selects plain ReLU; any other value is used as the leak
            of `lrelu`.
        is_training: Forwarded to batch normalization only.
        init_method: Optional initializer name understood by
            `get_initializer`; used for the conv weights and forwarded to
            `instance_norm`.

    Returns:
        The output tensor of the layer.
    """
    if init_method is None:
        weights_init = tf.truncated_normal_initializer(stddev=stddev)
    else:
        weights_init = get_initializer(init_method)

    with tf.variable_scope(name):
        # tf.contrib.layers.conv2d expects spatial arguments in [height, width]
        # order; the previous [width, height] order swapped the axes whenever
        # a non-square kernel or stride was requested.
        conv = tf.contrib.layers.conv2d(inputconv, output_dim, [filter_height, filter_width],
                                        [stride_height, stride_width], padding, activation_fn=None,
                                        weights_initializer=weights_init,
                                        biases_initializer=tf.constant_initializer(0.0))
        if do_norm:
            if norm_type == 'instance_norm':
                conv = instance_norm(conv, init_method=init_method)
            elif norm_type == 'batch_norm':
                # updates_collections=None makes the moving-average updates
                # run in place rather than via a separate update collection.
                conv = tf.contrib.layers.batch_norm(conv, decay=0.9, is_training=is_training, updates_collections=None,
                                                    epsilon=1e-5, center=True, scale=True, scope="batch_norm")
            elif norm_type == 'layer_norm':
                conv = tf.contrib.layers.layer_norm(conv, center=True, scale=True, scope='layer_norm')

        if do_relu:
            if relufactor == 0:
                conv = tf.nn.relu(conv, "relu")
            else:
                conv = lrelu(conv, relufactor, "lrelu")

        return conv


def generative_cnn_c3_encoder(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Encode an image with ten 3x3 conv layers followed by a 128-unit FC layer.

    Layers alternate between stride 2 and stride 1; output widths are
    32, 32, 64, 64, 128, 128, 256, 256, 256, 256. Variables are created in
    the current variable scope with AUTO_REUSE.

    Args:
        inputs: Input image tensor. The flatten step assumes the final
            feature map is 4x4 spatially (presumably a 128x128 input, given
            five stride-2 'SAME' convolutions -- TODO confirm).
        is_training: Python bool; forwarded to every conv layer and enables
            dropout on the FC output.
        drop_keep_prob: Value passed to tf.nn.dropout as its second argument.
        init_method: Optional initializer name forwarded to all layers.

    Returns:
        (fc_output, spatial_features): the FC output with 128 features per
        row, and the last conv feature map with 256 channels.
    """
    # (scope name, output channels, stride)
    layer_specs = (
        ("CNN_ENC_1", 32, 2),
        ("CNN_ENC_1_2", 32, 1),
        ("CNN_ENC_2", 64, 2),
        ("CNN_ENC_2_2", 64, 1),
        ("CNN_ENC_3", 128, 2),
        ("CNN_ENC_3_2", 128, 1),
        ("CNN_ENC_4", 256, 2),
        ("CNN_ENC_4_2", 256, 1),
        ("CNN_ENC_5", 256, 2),
        ("CNN_ENC_5_2", 256, 1),
    )
    flat_dim = 256 * 4 * 4

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        feat = inputs
        for layer_name, channels, stride in layer_specs:
            feat = general_conv2d(feat, channels, filter_height=3, filter_width=3,
                                  stride_height=stride, stride_width=stride,
                                  is_training=is_training, name=layer_name, init_method=init_method)
        spatial_feat = feat  # [N, h, w, 256]

        embedding = tf.reshape(feat, (-1, flat_dim))
        embedding = linear1d(embedding, flat_dim, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            embedding = tf.nn.dropout(embedding, drop_keep_prob)

        return embedding, spatial_feat


def generative_cnn_c3_encoder_deeper(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Wider variant of the conv encoder: 512-channel last stage, 512-unit FC.

    Ten 3x3 conv layers alternating stride 2 and stride 1 with output widths
    32, 32, 64, 64, 128, 128, 256, 256, 512, 512. Variables are created in
    the current variable scope with AUTO_REUSE.

    Args:
        inputs: Input image tensor. The flatten step assumes the final
            feature map is 4x4 spatially (presumably a 128x128 input, given
            five stride-2 'SAME' convolutions -- TODO confirm).
        is_training: Python bool; forwarded to every conv layer and enables
            dropout on the FC output.
        drop_keep_prob: Value passed to tf.nn.dropout as its second argument.
        init_method: Optional initializer name forwarded to all layers.

    Returns:
        (fc_output, spatial_features): the FC output with 512 features per
        row, and the last conv feature map with 512 channels.
    """
    # (scope name, output channels, stride)
    layer_specs = (
        ("CNN_ENC_1", 32, 2),
        ("CNN_ENC_1_2", 32, 1),
        ("CNN_ENC_2", 64, 2),
        ("CNN_ENC_2_2", 64, 1),
        ("CNN_ENC_3", 128, 2),
        ("CNN_ENC_3_2", 128, 1),
        ("CNN_ENC_4", 256, 2),
        ("CNN_ENC_4_2", 256, 1),
        ("CNN_ENC_5", 512, 2),
        ("CNN_ENC_5_2", 512, 1),
    )
    flat_dim = 512 * 4 * 4

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        feat = inputs
        for layer_name, channels, stride in layer_specs:
            feat = general_conv2d(feat, channels, filter_height=3, filter_width=3,
                                  stride_height=stride, stride_width=stride,
                                  is_training=is_training, name=layer_name, init_method=init_method)
        spatial_feat = feat  # [N, h, w, 512]

        embedding = tf.reshape(feat, (-1, flat_dim))
        embedding = linear1d(embedding, flat_dim, 512, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            embedding = tf.nn.dropout(embedding, drop_keep_prob)

        return embedding, spatial_feat


def generative_cnn_c3_encoder_combine33(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Two-branch encoder whose branches are merged after conv stage 3.

    `local_inputs` and `global_inputs` each pass through their own stack of
    stage 1-3 conv layers (scopes 'Local_Encoder' / 'Global_Encoder'). The
    two 128-channel results are concatenated on the last axis and processed
    by shared stage 4-5 layers plus a 128-unit FC layer under
    'Combined_Encoder'. All scopes use AUTO_REUSE.

    Args:
        local_inputs: Input tensor of the local branch.
        global_inputs: Input tensor of the global branch; must yield the
            same spatial size as the local branch for the concat to work.
        is_training: Python bool; forwarded to every conv layer and enables
            dropout on the FC output.
        drop_keep_prob: Value passed to tf.nn.dropout as its second argument.
        init_method: Optional initializer name forwarded to all layers.

    Returns:
        (fc_output, spatial_features): the FC output with 128 features per
        row, and the last conv feature map with 512 channels. The flatten
        step assumes that map is 4x4 spatially -- TODO confirm input size.
    """
    # (scope name, output channels, stride)
    branch_specs = (
        ("CNN_ENC_1", 32, 2),
        ("CNN_ENC_1_2", 32, 1),
        ("CNN_ENC_2", 64, 2),
        ("CNN_ENC_2_2", 64, 1),
        ("CNN_ENC_3", 128, 2),
        ("CNN_ENC_3_2", 128, 1),
        ("CNN_ENC_3_3", 128, 1),
    )
    combined_specs = (
        ("CNN_ENC_4", 256, 2),
        ("CNN_ENC_4_2", 256, 1),
        ("CNN_ENC_4_3", 256, 1),
        ("CNN_ENC_5", 512, 2),
        ("CNN_ENC_5_2", 512, 1),
        ("CNN_ENC_5_3", 512, 1),
    )
    flat_dim = 512 * 4 * 4

    def _conv_stack(feat, specs):
        # Apply the listed 3x3 conv layers in order.
        for layer_name, channels, stride in specs:
            feat = general_conv2d(feat, channels, filter_height=3, filter_width=3,
                                  stride_height=stride, stride_width=stride,
                                  is_training=is_training, name=layer_name, init_method=init_method)
        return feat

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_feat = _conv_stack(local_inputs, branch_specs)

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_feat = _conv_stack(global_inputs, branch_specs)

    merged = tf.concat([local_feat, global_feat], axis=-1)

    with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
        merged = _conv_stack(merged, combined_specs)
        spatial_feat = merged  # [N, h, w, 512]

        embedding = tf.reshape(merged, (-1, flat_dim))
        embedding = linear1d(embedding, flat_dim, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            embedding = tf.nn.dropout(embedding, drop_keep_prob)

        return embedding, spatial_feat


def generative_cnn_c3_encoder_combine43(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Two-branch encoder whose branches are merged after conv stage 4.

    `local_inputs` and `global_inputs` each pass through their own stack of
    stage 1-4 conv layers (scopes 'Local_Encoder' / 'Global_Encoder'). The
    two 256-channel results are concatenated on the last axis and processed
    by the shared stage 5 layers plus a 128-unit FC layer under
    'Combined_Encoder'. All scopes use AUTO_REUSE.

    Args:
        local_inputs: Input tensor of the local branch.
        global_inputs: Input tensor of the global branch; must yield the
            same spatial size as the local branch for the concat to work.
        is_training: Python bool; forwarded to every conv layer and enables
            dropout on the FC output.
        drop_keep_prob: Value passed to tf.nn.dropout as its second argument.
        init_method: Optional initializer name forwarded to all layers.

    Returns:
        (fc_output, spatial_features): the FC output with 128 features per
        row, and the last conv feature map with 512 channels. The flatten
        step assumes that map is 4x4 spatially -- TODO confirm input size.
    """
    # (scope name, output channels, stride)
    branch_specs = (
        ("CNN_ENC_1", 32, 2),
        ("CNN_ENC_1_2", 32, 1),
        ("CNN_ENC_2", 64, 2),
        ("CNN_ENC_2_2", 64, 1),
        ("CNN_ENC_3", 128, 2),
        ("CNN_ENC_3_2", 128, 1),
        ("CNN_ENC_3_3", 128, 1),
        ("CNN_ENC_4", 256, 2),
        ("CNN_ENC_4_2", 256, 1),
        ("CNN_ENC_4_3", 256, 1),
    )
    combined_specs = (
        ("CNN_ENC_5", 512, 2),
        ("CNN_ENC_5_2", 512, 1),
        ("CNN_ENC_5_3", 512, 1),
    )
    flat_dim = 512 * 4 * 4

    def _conv_stack(feat, specs):
        # Apply the listed 3x3 conv layers in order.
        for layer_name, channels, stride in specs:
            feat = general_conv2d(feat, channels, filter_height=3, filter_width=3,
                                  stride_height=stride, stride_width=stride,
                                  is_training=is_training, name=layer_name, init_method=init_method)
        return feat

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_feat = _conv_stack(local_inputs, branch_specs)

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_feat = _conv_stack(global_inputs, branch_specs)

    merged = tf.concat([local_feat, global_feat], axis=-1)

    with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
        merged = _conv_stack(merged, combined_specs)
        spatial_feat = merged  # [N, h, w, 512]

        embedding = tf.reshape(merged, (-1, flat_dim))
        embedding = linear1d(embedding, flat_dim, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            embedding = tf.nn.dropout(embedding, drop_keep_prob)

        return embedding, spatial_feat


def generative_cnn_c3_encoder_combine53(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Two-branch encoder whose branches are merged after conv stage 5.

    `local_inputs` and `global_inputs` each pass through their own full
    stack of stage 1-5 conv layers (scopes 'Local_Encoder' /
    'Global_Encoder'). The two 512-channel results are concatenated on the
    last axis and fed to a 128-unit FC layer under 'Combined_Encoder'.
    All scopes use AUTO_REUSE.

    Args:
        local_inputs: Input tensor of the local branch.
        global_inputs: Input tensor of the global branch; must yield the
            same spatial size as the local branch for the concat to work.
        is_training: Python bool; forwarded to every conv layer and enables
            dropout on the FC output.
        drop_keep_prob: Value passed to tf.nn.dropout as its second argument.
        init_method: Optional initializer name forwarded to all layers.

    Returns:
        (fc_output, spatial_features): the FC output with 128 features per
        row, and the concatenated feature map with 1024 channels. The
        flatten step assumes that map is 4x4 spatially -- TODO confirm
        input size.
    """
    # (scope name, output channels, stride)
    branch_specs = (
        ("CNN_ENC_1", 32, 2),
        ("CNN_ENC_1_2", 32, 1),
        ("CNN_ENC_2", 64, 2),
        ("CNN_ENC_2_2", 64, 1),
        ("CNN_ENC_3", 128, 2),
        ("CNN_ENC_3_2", 128, 1),
        ("CNN_ENC_3_3", 128, 1),
        ("CNN_ENC_4", 256, 2),
        ("CNN_ENC_4_2", 256, 1),
        ("CNN_ENC_4_3", 256, 1),
        ("CNN_ENC_5", 512, 2),
        ("CNN_ENC_5_2", 512, 1),
        ("CNN_ENC_5_3", 512, 1),
    )
    flat_dim = 1024 * 4 * 4

    def _conv_stack(feat):
        # Apply the full per-branch list of 3x3 conv layers in order.
        for layer_name, channels, stride in branch_specs:
            feat = general_conv2d(feat, channels, filter_height=3, filter_width=3,
                                  stride_height=stride, stride_width=stride,
                                  is_training=is_training, name=layer_name, init_method=init_method)
        return feat

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_feat = _conv_stack(local_inputs)

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_feat = _conv_stack(global_inputs)

    merged = tf.concat([local_feat, global_feat], axis=-1)

    with tf.variable_scope('Combined_Encoder', reuse=tf.AUTO_REUSE):
        spatial_feat = merged  # [N, h, w, 1024]

        embedding = tf.reshape(merged, (-1, flat_dim))
        embedding = linear1d(embedding, flat_dim, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            embedding = tf.nn.dropout(embedding, drop_keep_prob)

        return embedding, spatial_feat


def generative_cnn_c3_encoder_combineFC(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Two fully independent encoders whose FC outputs are concatenated.

    `local_inputs` and `global_inputs` each go through their own stage 1-5
    conv stack, a 128-unit FC layer and (when training) dropout, under the
    scopes 'Local_Encoder' and 'Global_Encoder' (both AUTO_REUSE). The two
    128-feature outputs are then joined on the last axis.

    Args:
        local_inputs: Input tensor of the local branch.
        global_inputs: Input tensor of the global branch.
        is_training: Python bool; forwarded to every conv layer and enables
            dropout on each branch's FC output.
        drop_keep_prob: Value passed to tf.nn.dropout as its second argument.
        init_method: Optional initializer name forwarded to all layers.

    Returns:
        (fc_output, None): the concatenated FC outputs (256 features per
        row). No spatial feature map is returned by this variant. Each
        branch's flatten step assumes a 4x4 final feature map -- TODO
        confirm input size.
    """
    # (scope name, output channels, stride)
    branch_specs = (
        ("CNN_ENC_1", 32, 2),
        ("CNN_ENC_1_2", 32, 1),
        ("CNN_ENC_2", 64, 2),
        ("CNN_ENC_2_2", 64, 1),
        ("CNN_ENC_3", 128, 2),
        ("CNN_ENC_3_2", 128, 1),
        ("CNN_ENC_3_3", 128, 1),
        ("CNN_ENC_4", 256, 2),
        ("CNN_ENC_4_2", 256, 1),
        ("CNN_ENC_4_3", 256, 1),
        ("CNN_ENC_5", 512, 2),
        ("CNN_ENC_5_2", 512, 1),
        ("CNN_ENC_5_3", 512, 1),
    )
    flat_dim = 512 * 4 * 4

    def _encode_branch(feat):
        # Conv stack -> flatten -> FC -> optional dropout for one branch.
        for layer_name, channels, stride in branch_specs:
            feat = general_conv2d(feat, channels, filter_height=3, filter_width=3,
                                  stride_height=stride, stride_width=stride,
                                  is_training=is_training, name=layer_name, init_method=init_method)

        feat = tf.reshape(feat, (-1, flat_dim))
        feat = linear1d(feat, flat_dim, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            feat = tf.nn.dropout(feat, drop_keep_prob)
        return feat

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        local_code = _encode_branch(local_inputs)

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_code = _encode_branch(global_inputs)

    spatial_feat = None
    joint_code = tf.concat([local_code, global_code], axis=-1)
    return joint_code, spatial_feat


def generative_cnn_c3_encoder_combineFC_jointAttn(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5,
                                                  init_method=None, combine_manner='attn'):
    """Two-tower CNN encoder whose local tower predicts an attention map from its own features.

    The local tower (variable scope 'Local_Encoder') runs seven 3x3 convolutions,
    feeds the resulting feature map to a nested 'Attn_branch' scope (three
    convolutions followed by a sigmoid), re-weights the shared features as
    ``attn_map * share_x + share_x`` and continues with six more convolutions.
    The global tower (variable scope 'Global_Encoder') runs the same thirteen
    convolutions without attention. Each tower is flattened to 512 * 4 * 4
    values per sample, passed through `linear1d` ('CNN_ENC_FC') and, when
    `is_training` is truthy, through dropout.

    Args:
        local_inputs: input tensor of the local tower.
        global_inputs: input tensor of the global tower.
        is_training: forwarded to every `general_conv2d` call; also enables dropout.
        drop_keep_prob: second argument handed to `tf.nn.dropout`.
        init_method: forwarded to `general_conv2d` and `linear1d`.
        combine_manner: only 'attn' is accepted.

    Returns:
        A tuple ``(tensor_x, tensor_x_sp, attn_map)``: the concatenation of both
        tower outputs along the last axis, ``None``, and the sigmoid attention
        map (values in [0, 1]).

    Raises:
        Exception: if `combine_manner` is not 'attn'. The check runs after the
            attention branch has been built.
    """
    # (layer name, output channels, unit_stride). Layers with unit_stride=False
    # pass no strides at all and therefore use general_conv2d's default stride.
    stem_layers = (
        ("CNN_ENC_1", 32, False),
        ("CNN_ENC_1_2", 32, True),
        ("CNN_ENC_2", 64, False),
        ("CNN_ENC_2_2", 64, True),
        ("CNN_ENC_3", 128, False),
        ("CNN_ENC_3_2", 128, True),
        ("CNN_ENC_3_3", 128, True),
    )
    tail_layers = (
        ("CNN_ENC_4", 256, False),
        ("CNN_ENC_4_2", 256, True),
        ("CNN_ENC_4_3", 256, True),
        ("CNN_ENC_5", 512, False),
        ("CNN_ENC_5_2", 512, True),
        ("CNN_ENC_5_3", 512, True),
    )

    def conv(x, channels, layer_name, kernel=3, unit_stride=True):
        layer_args = dict(filter_height=kernel, filter_width=kernel,
                          is_training=is_training, name=layer_name, init_method=init_method)
        if unit_stride:
            layer_args.update(stride_height=1, stride_width=1)
        return general_conv2d(x, channels, **layer_args)

    def conv_stack(x, layers):
        for layer_name, channels, unit_stride in layers:
            x = conv(x, channels, layer_name, unit_stride=unit_stride)
        return x

    def fc_head(x):
        flat = tf.reshape(x, (-1, 512 * 4 * 4))
        feat = linear1d(flat, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            feat = tf.nn.dropout(feat, drop_keep_prob)
        return feat

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        share_x = conv_stack(local_inputs, stem_layers)

        with tf.variable_scope('Attn_branch', reuse=tf.AUTO_REUSE):
            attn_x = conv(share_x, 128, "CNN_ENC_1")
            attn_x = conv(attn_x, 32, "CNN_ENC_2", kernel=1)
            attn_x = conv(attn_x, 1, "CNN_ENC_3", kernel=1)
            # Single-channel map squashed into [0.0, 1.0].
            attn_map = tf.nn.sigmoid(attn_x)

        if combine_manner != 'attn':
            raise Exception('Unknown combine_manner', combine_manner)
        # Residual-style re-weighting: attended features are added to the originals.
        attn_feat = attn_map * share_x + share_x

        local_x = fc_head(conv_stack(attn_feat, tail_layers))

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = fc_head(conv_stack(global_inputs, stem_layers + tail_layers))

    tensor_x = tf.concat([local_x, global_x], axis=-1)
    return tensor_x, None, attn_map


def generative_cnn_c3_encoder_combineFC_sepAttn(local_inputs, global_inputs, is_training=True, drop_keep_prob=0.5,
                                                init_method=None, combine_manner='attn'):
    """Two-tower CNN encoder with a separate attention branch computed from the local input.

    The 'Attn_branch' variable scope runs seven convolutions directly on
    `local_inputs` and applies a sigmoid to obtain a single-channel attention
    map. The local tower ('Local_Encoder') runs seven 3x3 convolutions,
    combines the result with the attention map according to `combine_manner`
    and continues with six more convolutions. The global tower
    ('Global_Encoder') runs the same thirteen convolutions without attention.
    Each tower is flattened to 512 * 4 * 4 values per sample, passed through
    `linear1d` ('CNN_ENC_FC') and, when `is_training` is truthy, through dropout.

    Args:
        local_inputs: input tensor of the attention branch and the local tower.
        global_inputs: input tensor of the global tower.
        is_training: forwarded to every `general_conv2d` call; also enables dropout.
        drop_keep_prob: second argument handed to `tf.nn.dropout`.
        init_method: forwarded to `general_conv2d` and `linear1d`.
        combine_manner: 'attn' computes ``attn_map * local_x + local_x``;
            'channel' concatenates the attention map to the local features
            along the last axis.

    Returns:
        A tuple ``(tensor_x, tensor_x_sp, attn_map)``: the concatenation of both
        tower outputs along the last axis, ``None``, and the sigmoid attention
        map (values in [0, 1]).

    Raises:
        Exception: if `combine_manner` is neither 'attn' nor 'channel'. The
            check runs after the first seven local convolutions were built.
    """
    # (layer name, output channels, kernel size, unit_stride). Layers with
    # unit_stride=False pass no strides and use general_conv2d's default stride.
    attn_layers = (
        ("CNN_ENC_1", 32, 3, False),
        ("CNN_ENC_1_2", 32, 3, True),
        ("CNN_ENC_2", 64, 3, False),
        ("CNN_ENC_2_2", 64, 3, True),
        ("CNN_ENC_3", 128, 3, False),
        ("CNN_ENC_3_2", 32, 1, True),
        ("CNN_ENC_3_3", 1, 1, True),
    )
    stem_layers = (
        ("CNN_ENC_1", 32, 3, False),
        ("CNN_ENC_1_2", 32, 3, True),
        ("CNN_ENC_2", 64, 3, False),
        ("CNN_ENC_2_2", 64, 3, True),
        ("CNN_ENC_3", 128, 3, False),
        ("CNN_ENC_3_2", 128, 3, True),
        ("CNN_ENC_3_3", 128, 3, True),
    )
    tail_layers = (
        ("CNN_ENC_4", 256, 3, False),
        ("CNN_ENC_4_2", 256, 3, True),
        ("CNN_ENC_4_3", 256, 3, True),
        ("CNN_ENC_5", 512, 3, False),
        ("CNN_ENC_5_2", 512, 3, True),
        ("CNN_ENC_5_3", 512, 3, True),
    )

    def conv_stack(x, layers):
        for layer_name, channels, kernel, unit_stride in layers:
            layer_args = dict(filter_height=kernel, filter_width=kernel,
                              is_training=is_training, name=layer_name, init_method=init_method)
            if unit_stride:
                layer_args.update(stride_height=1, stride_width=1)
            x = general_conv2d(x, channels, **layer_args)
        return x

    def fc_head(x):
        flat = tf.reshape(x, (-1, 512 * 4 * 4))
        feat = linear1d(flat, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)
        if is_training:
            feat = tf.nn.dropout(feat, drop_keep_prob)
        return feat

    with tf.variable_scope('Attn_branch', reuse=tf.AUTO_REUSE):
        # Single-channel map squashed into [0.0, 1.0].
        attn_map = tf.nn.sigmoid(conv_stack(local_inputs, attn_layers))

    with tf.variable_scope('Local_Encoder', reuse=tf.AUTO_REUSE):
        stem_feat = conv_stack(local_inputs, stem_layers)

        if combine_manner == 'attn':
            # Residual-style re-weighting of the local features.
            attn_feat = attn_map * stem_feat + stem_feat
        elif combine_manner == 'channel':
            # Attention map appended as one extra channel.
            attn_feat = tf.concat([stem_feat, attn_map], axis=-1)
        else:
            raise Exception('Unknown combine_manner', combine_manner)

        local_x = fc_head(conv_stack(attn_feat, tail_layers))

    with tf.variable_scope('Global_Encoder', reuse=tf.AUTO_REUSE):
        global_x = fc_head(conv_stack(global_inputs, stem_layers + tail_layers))

    tensor_x = tf.concat([local_x, global_x], axis=-1)
    return tensor_x, None, attn_map


def generative_cnn_c3_encoder_deeper13(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Thirteen-layer 3x3 CNN encoder followed by a fully connected projection.

    All variables are created in (or reused from) the current variable scope
    with ``reuse=tf.AUTO_REUSE``. The output of the last convolution is kept
    as the spatial feature map; it is then flattened to 512 * 4 * 4 values per
    sample, passed through `linear1d` ('CNN_ENC_FC') and, when `is_training`
    is truthy, through dropout.

    Args:
        inputs: input tensor handed to the first convolution.
        is_training: forwarded to every `general_conv2d` call; also enables dropout.
        drop_keep_prob: second argument handed to `tf.nn.dropout`.
        init_method: forwarded to `general_conv2d` and `linear1d`.

    Returns:
        A tuple ``(tensor_x, tensor_x_sp)``: the projected feature vector and
        the feature map of the last convolution (512 output channels requested).
    """
    # (layer name, output channels, unit_stride). Layers with unit_stride=False
    # pass no strides at all and therefore use general_conv2d's default stride.
    layers = (
        ("CNN_ENC_1", 32, False),
        ("CNN_ENC_1_2", 32, True),
        ("CNN_ENC_2", 64, False),
        ("CNN_ENC_2_2", 64, True),
        ("CNN_ENC_3", 128, False),
        ("CNN_ENC_3_2", 128, True),
        ("CNN_ENC_3_3", 128, True),
        ("CNN_ENC_4", 256, False),
        ("CNN_ENC_4_2", 256, True),
        ("CNN_ENC_4_3", 256, True),
        ("CNN_ENC_5", 512, False),
        ("CNN_ENC_5_2", 512, True),
        ("CNN_ENC_5_3", 512, True),
    )

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        feat = inputs
        for layer_name, channels, unit_stride in layers:
            layer_args = dict(filter_height=3, filter_width=3,
                              is_training=is_training, name=layer_name, init_method=init_method)
            if unit_stride:
                layer_args.update(stride_height=1, stride_width=1)
            feat = general_conv2d(feat, channels, **layer_args)
        tensor_x_sp = feat  # output of CNN_ENC_5_3

        flat = tf.reshape(feat, (-1, 512 * 4 * 4))
        tensor_x = linear1d(flat, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

    return tensor_x, tensor_x_sp


def generative_cnn_c3_encoder_deeper13_attn(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Thirteen-layer 3x3 CNN encoder with a self-attention block after the second stage.

    All variables are created in (or reused from) the current variable scope
    with ``reuse=tf.AUTO_REUSE``. After the first four convolutions
    (CNN_ENC_1 .. CNN_ENC_2_2) the 64-channel features pass through
    `self_attention`; nine more convolutions follow. The output of the last
    convolution is kept as the spatial feature map; it is then flattened to
    512 * 4 * 4 values per sample, passed through `linear1d` ('CNN_ENC_FC')
    and, when `is_training` is truthy, through dropout.

    Args:
        inputs: input tensor handed to the first convolution.
        is_training: forwarded to every `general_conv2d` call; also enables dropout.
        drop_keep_prob: second argument handed to `tf.nn.dropout`.
        init_method: forwarded to `general_conv2d` and `linear1d`.

    Returns:
        A tuple ``(tensor_x, tensor_x_sp)``: the projected feature vector and
        the feature map of the last convolution (512 output channels requested).
    """
    # (layer name, output channels, unit_stride). Layers with unit_stride=False
    # pass no strides at all and therefore use general_conv2d's default stride.
    pre_attention_layers = (
        ("CNN_ENC_1", 32, False),
        ("CNN_ENC_1_2", 32, True),
        ("CNN_ENC_2", 64, False),
        ("CNN_ENC_2_2", 64, True),
    )
    post_attention_layers = (
        ("CNN_ENC_3", 128, False),
        ("CNN_ENC_3_2", 128, True),
        ("CNN_ENC_3_3", 128, True),
        ("CNN_ENC_4", 256, False),
        ("CNN_ENC_4_2", 256, True),
        ("CNN_ENC_4_3", 256, True),
        ("CNN_ENC_5", 512, False),
        ("CNN_ENC_5_2", 512, True),
        ("CNN_ENC_5_3", 512, True),
    )

    def conv_stack(x, layers):
        for layer_name, channels, unit_stride in layers:
            layer_args = dict(filter_height=3, filter_width=3,
                              is_training=is_training, name=layer_name, init_method=init_method)
            if unit_stride:
                layer_args.update(stride_height=1, stride_width=1)
            x = general_conv2d(x, channels, **layer_args)
        return x

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        feat = conv_stack(inputs, pre_attention_layers)
        feat = self_attention(feat, 64)
        feat = conv_stack(feat, post_attention_layers)
        tensor_x_sp = feat  # output of CNN_ENC_5_3

        flat = tf.reshape(feat, (-1, 512 * 4 * 4))
        tensor_x = linear1d(flat, 512 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

    return tensor_x, tensor_x_sp


def generative_cnn_encoder(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Ten-layer CNN encoder (default `general_conv2d` filter size) with a 128-unit projection.

    All variables are created in (or reused from) the current variable scope
    with ``reuse=tf.AUTO_REUSE``. The output of the last convolution is kept
    as the spatial feature map; it is then flattened to 256 * 4 * 4 values per
    sample, passed through `linear1d` ('CNN_ENC_FC') and, when `is_training`
    is truthy, through dropout.

    Args:
        inputs: input tensor handed to the first convolution.
        is_training: forwarded to every `general_conv2d` call; also enables dropout.
        drop_keep_prob: second argument handed to `tf.nn.dropout`.
        init_method: forwarded to `general_conv2d` and `linear1d`.

    Returns:
        A tuple ``(tensor_x, tensor_x_sp)``: the projected feature vector and
        the feature map of the last convolution (256 output channels requested).
    """
    # (layer name, output channels, unit_stride). Layers with unit_stride=False
    # pass no strides at all and therefore use general_conv2d's default stride.
    layers = (
        ("CNN_ENC_1", 32, False),
        ("CNN_ENC_1_2", 32, True),
        ("CNN_ENC_2", 64, False),
        ("CNN_ENC_2_2", 64, True),
        ("CNN_ENC_3", 128, False),
        ("CNN_ENC_3_2", 128, True),
        ("CNN_ENC_4", 256, False),
        ("CNN_ENC_4_2", 256, True),
        ("CNN_ENC_5", 256, False),
        ("CNN_ENC_5_2", 256, True),
    )

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        feat = inputs
        for layer_name, channels, unit_stride in layers:
            layer_args = dict(is_training=is_training, name=layer_name, init_method=init_method)
            if unit_stride:
                layer_args.update(stride_height=1, stride_width=1)
            feat = general_conv2d(feat, channels, **layer_args)
        tensor_x_sp = feat  # output of CNN_ENC_5_2

        flat = tf.reshape(feat, (-1, 256 * 4 * 4))
        tensor_x = linear1d(flat, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

    return tensor_x, tensor_x_sp


def generative_cnn_encoder_deeper(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Ten-layer CNN encoder (default `general_conv2d` filter size) with a 512-unit projection.

    All variables are created in (or reused from) the current variable scope
    with ``reuse=tf.AUTO_REUSE``. The output of the last convolution is kept
    as the spatial feature map; it is then flattened to 512 * 4 * 4 values per
    sample, passed through `linear1d` ('CNN_ENC_FC') and, when `is_training`
    is truthy, through dropout.

    Args:
        inputs: input tensor handed to the first convolution.
        is_training: forwarded to every `general_conv2d` call; also enables dropout.
        drop_keep_prob: second argument handed to `tf.nn.dropout`.
        init_method: forwarded to `general_conv2d` and `linear1d`.

    Returns:
        A tuple ``(tensor_x, tensor_x_sp)``: the projected feature vector and
        the feature map of the last convolution (512 output channels requested).
    """
    # (layer name, output channels, unit_stride). Layers with unit_stride=False
    # pass no strides at all and therefore use general_conv2d's default stride.
    layers = (
        ("CNN_ENC_1", 32, False),
        ("CNN_ENC_1_2", 32, True),
        ("CNN_ENC_2", 64, False),
        ("CNN_ENC_2_2", 64, True),
        ("CNN_ENC_3", 128, False),
        ("CNN_ENC_3_2", 128, True),
        ("CNN_ENC_4", 256, False),
        ("CNN_ENC_4_2", 256, True),
        ("CNN_ENC_5", 512, False),
        ("CNN_ENC_5_2", 512, True),
    )

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        feat = inputs
        for layer_name, channels, unit_stride in layers:
            layer_args = dict(is_training=is_training, name=layer_name, init_method=init_method)
            if unit_stride:
                layer_args.update(stride_height=1, stride_width=1)
            feat = general_conv2d(feat, channels, **layer_args)
        tensor_x_sp = feat  # output of CNN_ENC_5_2

        flat = tf.reshape(feat, (-1, 512 * 4 * 4))
        tensor_x = linear1d(flat, 512 * 4 * 4, 512, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

    return tensor_x, tensor_x_sp


def generative_cnn_encoder_deeper13(inputs, is_training=True, drop_keep_prob=0.5, init_method=None):
    """Thirteen-layer CNN encoder (default `general_conv2d` filter size) with a 128-unit projection.

    All variables are created in (or reused from) the current variable scope
    with ``reuse=tf.AUTO_REUSE``. The output of the last convolution is kept
    as the spatial feature map; it is then flattened to 256 * 4 * 4 values per
    sample, passed through `linear1d` ('CNN_ENC_FC') and, when `is_training`
    is truthy, through dropout.

    Args:
        inputs: input tensor handed to the first convolution.
        is_training: forwarded to every `general_conv2d` call; also enables dropout.
        drop_keep_prob: second argument handed to `tf.nn.dropout`.
        init_method: forwarded to `general_conv2d` and `linear1d`.

    Returns:
        A tuple ``(tensor_x, tensor_x_sp)``: the projected feature vector and
        the feature map of the last convolution (256 output channels requested).
    """
    # (layer name, output channels, unit_stride). Layers with unit_stride=False
    # pass no strides at all and therefore use general_conv2d's default stride.
    layers = (
        ("CNN_ENC_1", 32, False),
        ("CNN_ENC_1_2", 32, True),
        ("CNN_ENC_2", 64, False),
        ("CNN_ENC_2_2", 64, True),
        ("CNN_ENC_3", 128, False),
        ("CNN_ENC_3_2", 128, True),
        ("CNN_ENC_3_3", 128, True),
        ("CNN_ENC_4", 256, False),
        ("CNN_ENC_4_2", 256, True),
        ("CNN_ENC_4_3", 256, True),
        ("CNN_ENC_5", 256, False),
        ("CNN_ENC_5_2", 256, True),
        ("CNN_ENC_5_3", 256, True),
    )

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        feat = inputs
        for layer_name, channels, unit_stride in layers:
            layer_args = dict(is_training=is_training, name=layer_name, init_method=init_method)
            if unit_stride:
                layer_args.update(stride_height=1, stride_width=1)
            feat = general_conv2d(feat, channels, **layer_args)
        tensor_x_sp = feat  # output of CNN_ENC_5_3

        flat = tf.reshape(feat, (-1, 256 * 4 * 4))
        tensor_x = linear1d(flat, 256 * 4 * 4, 128, name='CNN_ENC_FC', init_method=init_method)

        if is_training:
            tensor_x = tf.nn.dropout(tensor_x, drop_keep_prob)

    return tensor_x, tensor_x_sp


def max_pooling(x):
    """Apply 2x2 max pooling with stride 2 and 'SAME' padding to `x`."""
    pooled = tf.layers.max_pooling2d(inputs=x, pool_size=2, strides=2, padding='SAME')
    return pooled


def hw_flatten(x):
    """Collapse every axis between the first and the last one into a single axis.

    For an (N, h, w, c) feature map the result is (N, h * w, c). The leading
    size is taken from the static shape when it is known and from the runtime
    shape otherwise, so tensors whose batch dimension is undefined at graph
    construction time can be flattened as well. The last dimension must be
    statically known.
    """
    static_dims = x.shape.as_list()
    batch, channels = static_dims[0], static_dims[-1]
    if batch is None:
        # Static batch size unavailable: defer to the dynamic shape.
        batch = tf.shape(x)[0]
    return tf.reshape(x, shape=[batch, -1, channels])


def self_attention(x, in_channel, name='self_attention'):
    """Self-attention block with a learned residual gate.

    Three 1x1 convolutions (no normalisation, no ReLU) project `x` into
    `f` and `g` with ``in_channel // 8`` channels and `h` with `in_channel`
    channels; `f` and `h` are additionally max-pooled. The attention weights
    are ``softmax(g . f^T)`` over the pooled positions and are used to mix the
    rows of `h`. The result is reshaped back to the shape of `x`, passed
    through a final 1x1 convolution and added to `x` scaled by the scalar
    variable ``gamma``, which is initialised to 0 so the block initially
    returns `x` unchanged.

    Args:
        x: input feature map; commented in the code as (N, h, w, c).
        in_channel: channel count used for the projections. It has to match
            the channel count of `x`, because the attended features are
            reshaped to the shape of `x`.
        name: variable scope that holds the block's variables.

    Returns:
        A tensor with the same shape as `x`.
    """
    with tf.variable_scope(name):
        f = general_conv2d(x, in_channel // 8, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
                           do_norm=False, do_relu=False, name='f_conv')  # (N, h, w, c')
        f = max_pooling(f)  # (N, h', w', c')
        g = general_conv2d(x, in_channel // 8, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
                           do_norm=False, do_relu=False, name='g_conv')  # (N, h, w, c')
        h = general_conv2d(x, in_channel, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
                           do_norm=False, do_relu=False, name='h_conv')  # (N, h, w, c)
        h = max_pooling(h)  # (N, h', w', c)

        # M = h * w, M' = h' * w'
        g_flat = hw_flatten(g)  # (N, M, c')
        f_flat = hw_flatten(f)  # (N, M', c')
        s = tf.matmul(g_flat, f_flat, transpose_b=True)  # (N, M, M')
        beta = tf.nn.softmax(s)  # attention map, normalised over the last axis

        o = tf.matmul(beta, hw_flatten(h))  # (N, M, c)
        gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

        # Reshape with the runtime shape so an undefined batch dimension is
        # accepted, then restore the static shape information for later layers.
        o = tf.reshape(o, shape=tf.shape(x))  # (N, h, w, c)
        o.set_shape(x.shape)
        o = general_conv2d(o, in_channel, filter_height=1, filter_width=1, stride_height=1, stride_width=1,
                           do_norm=False, do_relu=False, name='attn_conv')

        x = gamma * o + x

    return x


def global_avg_pooling(x):
    """Average `x` over axes 1 and 2, removing both axes from the result."""
    pooled_axes = [1, 2]
    return tf.reduce_mean(x, axis=pooled_axes)


def cnn_discriminator_wgan_gp(discrim_inputs, discrim_targets, init_method=None):
    """Convolutional critic that scores an (input, target) pair.

    Both tensors are concatenated along axis 3 and passed through five 3x3
    `general_conv2d` layers (32, 64, 128, 128 and 1 output channels), each
    using `general_conv2d`'s default stride and ``is_training=True``. The
    single-channel result is averaged over axes 1 and 2 by
    `global_avg_pooling`. Variables are created in (or reused from) the
    current variable scope with ``reuse=tf.AUTO_REUSE``.

    Args:
        discrim_inputs: first tensor of the pair; the original comment gives
            the concatenated shape as (N, H, W, 3 + 1).
        discrim_targets: second tensor of the pair, concatenated after
            `discrim_inputs`.
        init_method: forwarded to every `general_conv2d` call.

    Returns:
        The pooled critic output, commented in the original code as (N, 1).
    """
    features = tf.concat([discrim_inputs, discrim_targets], axis=3)

    # (layer name, output channels); no strides are passed for any layer.
    layers = (
        ("CNN_ENC_1", 32),
        ("CNN_ENC_2", 64),
        ("CNN_ENC_3", 128),
        ("CNN_ENC_4", 128),
        ("CNN_ENC_5", 1),
    )

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        for layer_name, channels in layers:
            features = general_conv2d(features, channels, filter_height=3, filter_width=3,
                                      is_training=True, name=layer_name, init_method=init_method)

        d_out = global_avg_pooling(features)

    return d_out
back to top