Raw File
Basic_GAN_Distributed.py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse

import cntk as C
import cntk.tests.test_utils
from timeit import default_timer as timer

cntk.tests.test_utils.set_device_from_pytest_env() # (only needed for our build system)
C.cntk_py.set_fixed_random_seed(1) # fix a random seed for CNTK components

#%matplotlib inline
isFast = True 
# architectural parameters
g_input_dim = 100
g_hidden_dim = 128
g_output_dim = d_input_dim = 784
d_hidden_dim = 128
d_output_dim = 1

# Ensure the training data is generated and available for this tutorial
def create_reader(path, is_training, input_dim, label_dim):
    deserializer = C.io.CTFDeserializer(
        filename = path,
        streams = C.io.StreamDefs(
            labels_unused = C.io.StreamDef(field = 'labels', shape = label_dim, is_sparse = False),
            features = C.io.StreamDef(field = 'features', shape = input_dim, is_sparse = False
            )
        )
    )
    return C.io.MinibatchSource(
        deserializers = deserializer,
        randomize = is_training,
        max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1
    )

np.random.seed(123)
def noise_sample(num_samples):
    return np.random.uniform(
        low = -1.0,
        high = 1.0,
        size = [num_samples, g_input_dim]
    ).astype(np.float32)

def generator(z):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(g_hidden_dim, activation = C.relu)(z)
        return C.layers.Dense(g_output_dim, activation = C.tanh)(h1)

def discriminator(x):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(d_hidden_dim, activation = C.relu)(x)
        return C.layers.Dense(d_output_dim, activation = C.sigmoid)(h1)

# training config
minibatch_size = 1024
num_minibatches = 300 if isFast else 40000
lr = 0.00005

def build_graph(noise_shape, image_shape, G_progress_printer, D_progress_printer):
    input_dynamic_axes = [C.Axis.default_batch_axis()]
    Z = C.input_variable(noise_shape, dynamic_axes=input_dynamic_axes)
    X_real = C.input_variable(image_shape, dynamic_axes=input_dynamic_axes)
    X_real_scaled = 2*(X_real / 255.0) - 1.0

    # Create the model function for the generator and discriminator models
    X_fake = generator(Z)
    D_real = discriminator(X_real_scaled)
    D_fake = D_real.clone(
        method = 'share',
        substitutions = {X_real_scaled.output: X_fake.output}
    )

    # Create loss functions and configure optimazation algorithms
    G_loss = 1.0 - C.log(D_fake)
    D_loss = -(C.log(D_real) + C.log(1.0 - D_fake))

    G_learner = C.fsadagrad(
       parameters = X_fake.parameters,
        lr = C.learning_parameter_schedule_per_sample(lr),
        momentum = C.momentum_schedule_per_sample(0.9985724484938566)
    )
    D_learner = C.fsadagrad(
        parameters = D_real.parameters,
        lr = C.learning_parameter_schedule_per_sample(lr),
        momentum = C.momentum_schedule_per_sample(0.9985724484938566)
    )

    DistG_learner = C.train.distributed.data_parallel_distributed_learner(G_learner)
    
    # The following API marks a learner as the matric aggregator, which is used by 
    # the trainer to determine the training progress.
    # It is required, only when more than one learner is provided to a *single* trainer. 
    # In this example, we use two trainers each with a single learner, so it 
    # is not required and automatically set by CNTK for each single learner. However, if you 
    # plan to use both learners with a single trainer, then it needs to be call before 
    # creating the trainer.
    #DistG_learner.set_as_metric_aggregator()

    DistD_learner = C.train.distributed.data_parallel_distributed_learner(D_learner)

    # Instantiate the trainers
    G_trainer = C.Trainer(
        X_fake,
        (G_loss, None),
        DistG_learner,
        G_progress_printer
    )
    D_trainer = C.Trainer(
        D_real,
        (D_loss, None),
        DistD_learner,
        D_progress_printer
    )

    return X_real, X_fake, Z, G_trainer, D_trainer

def train(reader_train):
    k = 2
    worker_rank = C.Communicator.rank()
    # print out loss for each model for upto 50 times
    print_frequency_mbsize = num_minibatches // 50
    pp_G = C.logging.ProgressPrinter(print_frequency_mbsize, rank=worker_rank)
    pp_D = C.logging.ProgressPrinter(print_frequency_mbsize * k, rank=worker_rank)

    X_real, X_fake, Z, G_trainer, D_trainer = \
        build_graph(g_input_dim, d_input_dim, pp_G, pp_D)
    
    input_map = {X_real: reader_train.streams.features}

    num_partitions = C.Communicator.num_workers()
    worker_rank = C.Communicator.rank()
    distributed_minibatch_size  = minibatch_size // num_partitions

    for train_step in range(num_minibatches):

        # train the discriminator model for k steps
        for gen_train_step in range(k):
            Z_data = noise_sample(distributed_minibatch_size)
            X_data = reader_train.next_minibatch(minibatch_size, input_map, num_data_partitions=num_partitions, partition_index=worker_rank)

            if X_data[X_real].num_samples == Z_data.shape[0]:
                batch_inputs = {X_real: X_data[X_real].data, 
                                Z: Z_data}
                D_trainer.train_minibatch(batch_inputs)

        # train the generator model for a single step
        Z_data = noise_sample(distributed_minibatch_size)
        batch_inputs = {Z: Z_data}
        G_trainer.train_minibatch(batch_inputs)

        G_trainer_loss = G_trainer.previous_minibatch_loss_average

    return Z, X_fake, G_trainer_loss

def plot_images(images, subplot_shape):
    plt.style.use('ggplot')
    fig, axes = plt.subplots(*subplot_shape)
    for image, ax in zip(images, axes.flatten()):
        ax.imshow(image.reshape(28, 28), vmin = 0, vmax = 1.0, cmap = 'gray')
        ax.axis('off')
    plt.show()
    

#mpiexec entrance
if __name__=='__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-datadir', '--datadir')
    args = vars(parser.parse_args())

    data_found = False
    train_file = os.path.join(args['datadir'], "Train-28x28_cntk_text.txt")
    if os.path.isfile(train_file):
        data_found = True
    if not data_found:
        raise ValueError("Please generate the data by completing CNTK 103 Part A")
    
    worker_rank = C.Communicator.rank()
    
    start = timer()
    reader_train = create_reader(train_file, True, d_input_dim, label_dim=10)
    G_input, G_output, G_trainer_loss = train(reader_train)
    # Print the generator loss 
    C.Communicator.finalize()
    end = timer()
    print("Training loss of the generator at worker: {%d} is: {%f}, time taken is: {%d} seconds."%(worker_rank, G_trainer_loss, (end - start)))
    
    # Please uncomment below to display the generated images.
    #if worker_rank == 0:
    #    noise = noise_sample(36)
    #    images = G_output.eval({G_input: noise})
    #    plot_images(images, subplot_shape =[6, 6])
back to top