https://github.com/tolstikhin/adagan
Tip revision: 746bd8e6a5277a3a95463a66f4e631e1b48fad48 authored by itolstikhin on 13 December 2017, 10:56:01 UTC
Fixed plots
Fixed plots
Tip revision: 746bd8e
fid_pics.py
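# This script creates plots for experiments from the
# ICLR 2018 submission by Bousquet, Gelly, Tolstikhin, Bernhard.
#   1. Random samples
#   2. Interpolations:
#      a. Between points of the test set, linearly
#      b. Take a random point from Pz and make a whole circle on a geodesic
#   3. Test reconstructions
# NOTE(review): the list above looks inherited from a sibling plotting script;
# the code in this file only dumps real dataset images and random decoder
# samples (as PNG files and .npy arrays) -- confirm and update if so.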
import os
import sys
import tensorflow as tf
import numpy as np
import ops
from metrics import Metrics
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import utils
from datahandler import DataHandler
# Number of pictures written per set (real pictures and decoder samples).
NUM_PICS = 10000
# When True, main() dumps real dataset pictures under <OUT_DIR>/<alias>/real
# and a stacked array at <OUT_DIR>/<alias>/real.npy.
SAVE_REAL_PICS = True
# When True, every picture is additionally written as an individual PNG file;
# the stacked .npy dumps are written regardless of this flag.
SAVE_PNG = True
# When True, main() restores the trained model and dumps decoder samples.
SAVE_FAKE_PICS = False
# Passed to DataHandler as opts['data_dir'] for the corresponding dataset.
CELEBA_DATA_DIR = 'celebA/datasets/celeba/img_align_celeba'
MNIST_DATA_DIR = 'mnist'
# Root directory for everything this script writes.
OUT_DIR = 'fid_pics_celeba'
class ExpInfo(object):
    """Plain record holding the settings of one trained model.

    Every attribute starts out as None; callers fill them in by assignment.
    """

    def __init__(self):
        field_names = (
            # Directory holding the model's 'checkpoints' folder.
            'trained_model_path',
            # Suffix of the checkpoint name to restore.
            'model_id',
            # Scale applied to the standard normal noise fed to the decoder.
            'pz_std',
            # Dimensionality of the noise vectors.
            'z_dim',
            # Forwarded as opts['input_normalize_sym']; when truthy,
            # pictures are rescaled with (pic + 1) / 2 before drawing.
            'symmetrize',
            # Dataset name, 'celebA' or 'mnist'.
            'dataset',
            # Name of the output sub-directory for this experiment.
            'alias',
            # Number of trailing shuffled ids treated as the test split.
            'test_size',
        )
        for field_name in field_names:
            setattr(self, field_name, None)
def main():
    """Dump real and/or generated pictures for the experiment named on argv.

    The last command-line argument selects one of the experiments returned by
    _known_experiments(). Output goes to <OUT_DIR>/<alias>/; which sets are
    written is controlled by SAVE_REAL_PICS, SAVE_FAKE_PICS and SAVE_PNG.

    Raises:
        ValueError: if the last command-line argument is not a known
            experiment name.
    """
    exp_name = sys.argv[-1]
    experiments = _known_experiments()
    if exp_name not in experiments:
        # Previously an unknown name fell through the if/elif chain and
        # crashed later with an unbound local variable.
        raise ValueError(
            'Unknown experiment %r; expected one of: %s'
            % (exp_name, ', '.join(sorted(experiments))))
    exp = experiments[exp_name]

    create_dir(OUT_DIR)
    output_dir = os.path.join(OUT_DIR, exp.alias)
    create_dir(output_dir)

    if SAVE_REAL_PICS:
        _save_real_pics(exp, output_dir)
    if SAVE_FAKE_PICS:
        _save_fake_pics(exp, output_dir)


def _make_exp(trained_model_path, model_id, pz_std, z_dim, symmetrize,
              dataset, alias, test_size):
    """Return an ExpInfo populated with the given experiment settings."""
    exp = ExpInfo()
    exp.trained_model_path = trained_model_path
    exp.model_id = model_id
    exp.pz_std = pz_std
    exp.z_dim = z_dim
    exp.symmetrize = symmetrize
    exp.dataset = dataset
    exp.alias = alias
    exp.test_size = test_size
    return exp


def _known_experiments():
    """Return a dict mapping command-line experiment names to ExpInfo.

    Note that for some experiments the command-line name differs from the
    alias used for the output directory (e.g. 'mnist_vae' -> 'mnist_vae_dcgan').
    """
    return {
        # CelebA with WAE+MMD on 64 dimensional Z space, DCGAN architecture
        'celeba_mmd_dcgan': _make_exp(
            trained_model_path='./mount/GANs/results_celeba_pot_worst2d_plateau_mmd_642',
            model_id='378720', pz_std=2.0, z_dim=64, symmetrize=True,
            dataset='celebA', alias='celeba_mmd_dcgan', test_size=512),
        # CelebA with WAE+GAN on 64 dimensional Z space, DCGAN architecture
        'celeba_gan_dcgan': _make_exp(
            trained_model_path='./mount/GANs/results_celeba_pot_worst2d_plateau_gan_jsmod_641',
            model_id='126480', pz_std=2.0, z_dim=64, symmetrize=True,
            dataset='celebA', alias='celeba_gan_dcgan', test_size=512),
        # MNIST with WAE+MMD on 8 dimensional Z space, DCGAN architecture
        'mnist_mmd_dcgan': _make_exp(
            trained_model_path='./mount/GANs/results_mnist_pot_sota_worst2d_plateau_mmd_tricks_expC_81',
            model_id='55200', pz_std=1.0, z_dim=8, symmetrize=False,
            dataset='mnist', alias='mnist_mmd_dcgan', test_size=1000),
        # MNIST with WAE+GAN on 8 dimensional Z space, DCGAN architecture
        'mnist_gan_dcgan': _make_exp(
            trained_model_path='./mount/GANs/results_mnist_pot_sota_worst2d1',
            model_id='62100', pz_std=2.0, z_dim=8, symmetrize=False,
            dataset='mnist', alias='mnist_gan_dcgan', test_size=1000),
        # CelebA with WAE+MMD on 64 dimensional Z space, BEGAN architecture
        'celeba_mmd_began': _make_exp(
            trained_model_path='./mount/GANs/results_celeba_pot_worst2d_plateau_mmd_began_642',
            model_id='157800', pz_std=2.0, z_dim=64, symmetrize=True,
            dataset='celebA', alias='celeba_mmd_began', test_size=512),
        # MNIST with VAE on 8 dimensional Z space, DCGAN architecture
        'mnist_vae': _make_exp(
            trained_model_path='./mount/GANs/results_mnist_vae_81',
            model_id='final-69000', pz_std=1.0, z_dim=8, symmetrize=False,
            dataset='mnist', alias='mnist_vae_dcgan', test_size=1000),
        # MNIST with VAE on 2 dimensional Z space, DCGAN architecture
        'mnist_vae_2d': _make_exp(
            trained_model_path='./mount/GANs/results_mnist_vae_zdim2_21',
            model_id='final-69000', pz_std=1.0, z_dim=2, symmetrize=False,
            dataset='mnist', alias='mnist_vae_2d_dcgan', test_size=1000),
        # MNIST with WAE-MMD on 2 dimensional Z space, DCGAN architecture
        'mnist_mmd_2d': _make_exp(
            trained_model_path='./mount/GANs/results_mnist_pot_smaller_zdim2_21',
            model_id='final-69000', pz_std=2.0, z_dim=2, symmetrize=False,
            dataset='mnist', alias='mnist_mmd_2d_dcgan', test_size=1000),
        # CelebA with VAE on 64 dimensional Z space, DCGAN architecture
        'celeba_vae': _make_exp(
            trained_model_path='./mount/GANs/results_celeba_vae_641',
            model_id='126240', pz_std=1.0, z_dim=64, symmetrize=True,
            dataset='celebA', alias='celeba_vae', test_size=512),
    }


def _save_real_pics(exp, output_dir):
    """Dump NUM_PICS real pictures: the test split topped up with train ones.

    Writes <output_dir>/real.npy and, when SAVE_PNG is set, one PNG per
    picture under <output_dir>/real/.
    """
    pic_dir = os.path.join(output_dir, 'real')
    create_dir(pic_dir)

    opts = {
        'dataset': exp.dataset,
        'input_normalize_sym': exp.symmetrize,
        'work_dir': output_dir,
        'celebA_crop': 'closecrop',
    }
    if exp.dataset == 'celebA':
        opts['data_dir'] = CELEBA_DATA_DIR
    elif exp.dataset == 'mnist':
        opts['data_dir'] = MNIST_DATA_DIR
    data = DataHandler(opts)

    if exp.dataset == 'celebA':
        # The last `test_size` shuffled ids are treated as the test split;
        # both splits index into the same underlying container.
        shuffled_ids = np.load(
            os.path.join(exp.trained_model_path, 'shuffled_training_ids'))
        test_ids = shuffled_ids[-exp.test_size:]
        train_ids = shuffled_ids[:-exp.test_size]
        test_images = data.data
        train_images = data.data
    else:
        test_images = data.test_data.X
        train_images = data.data.X
        train_ids = range(len(train_images))
        test_ids = range(len(test_images))

    # Materialise only the test pictures actually needed. For celebA
    # `test_images` is the whole dataset, so stacking it directly (as the code
    # used to do) dumped the leading pictures instead of the test split.
    # Assumes the container accepts a list of ids, which the code already
    # relied on for the random training pictures below.
    test_ids = list(test_ids)[:NUM_PICS]
    test_pics = test_images[test_ids]

    # Top up with distinct random training pictures. Positions are drawn from
    # len(train_ids) so they can never fall outside the training split.
    num_remain = max(NUM_PICS - len(test_ids), 0)
    chosen = np.random.choice(len(train_ids), num_remain, replace=False)
    rand_train_ids = [train_ids[idx] for idx in chosen]
    rand_train_pics = train_images[rand_train_ids]

    if SAVE_PNG:
        pic_id = 1
        for pics in (test_pics, rand_train_pics):
            for pic in pics:
                if pic_id % 1000 == 0:
                    print('Saved %d/%d' % (pic_id, NUM_PICS))
                save_pic(
                    pic,
                    os.path.join(
                        pic_dir, 'real_image{:05d}.png'.format(pic_id)),
                    exp)
                pic_id += 1

    # Builtin float is what the removed np.float alias pointed to.
    all_pics = np.vstack([test_pics, rand_train_pics]).astype(float)
    np.random.shuffle(all_pics)
    np.save(os.path.join(output_dir, 'real'), all_pics)


def _save_fake_pics(exp, output_dir, model_name_prefix='trained-pot-'):
    """Restore the trained model and dump NUM_PICS decoder samples.

    Writes <output_dir>/fake.npy and, when SAVE_PNG is set, one PNG per
    sample under <output_dir>/fake/.
    """
    checkpoint = os.path.join(
        exp.trained_model_path, 'checkpoints',
        model_name_prefix + exp.model_id)
    with tf.Session() as sess:
        with sess.graph.as_default():
            saver = tf.train.import_meta_graph(checkpoint + '.meta')
            saver.restore(sess, checkpoint)
            noise_ph = tf.get_collection('noise_ph')[0]
            is_training_ph = tf.get_collection('is_training_ph')[0]
            decoder = tf.get_collection('decoder')[0]

            # Saving random samples
            mean = np.zeros(exp.z_dim)
            cov = np.identity(exp.z_dim)
            noise = exp.pz_std * np.random.multivariate_normal(
                mean, cov, NUM_PICS).astype(np.float32)
            res = sess.run(
                decoder,
                feed_dict={noise_ph: noise, is_training_ph: False})

            pic_dir = os.path.join(output_dir, 'fake')
            create_dir(pic_dir)
            if SAVE_PNG:
                for i in range(1, NUM_PICS + 1):
                    if i % 1000 == 0:
                        print('Saved %d/%d' % (i, NUM_PICS))
                    save_pic(
                        res[i - 1],
                        os.path.join(
                            pic_dir, 'fake_image{:05d}.png'.format(i)),
                        exp)
            np.save(os.path.join(output_dir, 'fake'), res)
def save_pic(pic, path, exp):
    """Render one picture without axes or frame and save it to `path` as PNG.

    Args:
        pic: array-like picture; if it has four dimensions only its first
            element is drawn.
        path: destination file name.
        exp: experiment record; `exp.symmetrize` and `exp.dataset` select the
            preprocessing applied before drawing.
    """
    if len(pic.shape) == 4:
        pic = pic[0]
    rows, cols = pic.shape[0], pic.shape[1]

    # Figure size equals the picture size and the file is saved with dpi=1
    # below, so the axes span the whole canvas.
    canvas = plt.figure(frameon=False, figsize=(cols, rows))
    axes = plt.Axes(canvas, [0., 0., 1., 1.])
    axes.set_axis_off()
    canvas.add_axes(axes)

    if exp.symmetrize:
        # presumably maps pixel values from [-1, 1] to [0, 1] -- TODO confirm
        pic = (pic + 1.) / 2.

    if exp.dataset == 'mnist':
        # Keep the first channel only and invert it before drawing in greys.
        pic = 1. - pic[:, :, 0]
        axes.imshow(pic, cmap='Greys', interpolation='none')
    else:
        axes.imshow(pic, interpolation='none')

    canvas.savefig(path, dpi=1, format='png')
    plt.close()
# if exp.dataset == 'mnist':
# pic = pic[:, :, 0]
# pic = 1. - pic
# ax = plt.imshow(pic, cmap='Greys', interpolation='none')
# else:
# ax = plt.imshow(pic, interpolation='none')
# ax.axes.get_xaxis().set_ticks([])
# ax.axes.get_yaxis().set_ticks([])
# ax.axes.set_xlim([0, width])
# ax.axes.set_ylim([height, 0])
# ax.axes.set_aspect(1)
# fig.savefig(path, format='png')
# plt.close()
def create_dir(d):
    """Create directory `d` through tf.gfile unless it already is a directory."""
    if tf.gfile.IsDirectory(d):
        return
    tf.gfile.MakeDirs(d)
# Script entry point. The call is unconditional, so it also runs when this
# module is imported.
# NOTE(review): consider guarding it with `if __name__ == '__main__':`.
main()