https://github.com/Hosein47/LOGAN
Tip revision: 7717db273a0a0f018d0c0c44db5f888ac14f9672 authored by Hosein Hashemi on 21 December 2020, 21:56:05 UTC
Update README.md
Update README.md
Tip revision: 7717db2
utils.py
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision
from argparse import ArgumentParser
def prepare_parser():
    """Build the argument parser shared by all scripts.

    Returns:
        An ``argparse.ArgumentParser`` with the dataset, model, optimizer,
        batching/precision, bookkeeping, EMA, numerical, ortho-reg, resume
        and logging options registered below.

    Notes:
        argparse %-formats every help string, so a literal percent sign
        must be written as ``%%``; the ``--logstyle`` help relies on this.
    """
    usage = 'Parser for all scripts.'
    parser = ArgumentParser(description=usage)

    ### Dataset/Dataloader stuff ###
    parser.add_argument(
        '--dataset', type=str, default='I128_hdf5',
        help='Which Dataset to train on, out of I128, I256, C10, C100;'
             'Append "_hdf5" to use the hdf5 version for ISLVRC '
             '(default: %(default)s)')
    parser.add_argument(
        '--augment', action='store_true', default=False,
        help='Augment with random crops and flips (default: %(default)s)')
    parser.add_argument(
        '--num_workers', type=int, default=8,
        help='Number of dataloader workers; consider using less for HDF5 '
             '(default: %(default)s)')
    # Negative flag: passing --no_pin_memory stores False into `pin_memory`.
    parser.add_argument(
        '--no_pin_memory', action='store_false', dest='pin_memory', default=True,
        help='Pin data into memory through dataloader? (default: %(default)s)')
    parser.add_argument(
        '--shuffle', action='store_true', default=False,
        help='Shuffle the data (strongly recommended)? (default: %(default)s)')
    parser.add_argument(
        '--load_in_mem', action='store_true', default=False,
        help='Load all data into memory? (default: %(default)s)')
    parser.add_argument(
        '--use_multiepoch_sampler', action='store_true', default=False,
        help='Use the multi-epoch sampler for dataloader? (default: %(default)s)')

    ### Model stuff ###
    parser.add_argument(
        '--model', type=str, default='BigGAN',
        help='Name of the model module (default: %(default)s)')
    parser.add_argument(
        '--G_param', type=str, default='SN',
        help='Parameterization style to use for G, spectral norm (SN) or SVD (SVD)'
             ' or None (default: %(default)s)')
    parser.add_argument(
        '--D_param', type=str, default='SN',
        help='Parameterization style to use for D, spectral norm (SN) or SVD (SVD)'
             ' or None (default: %(default)s)')
    parser.add_argument(
        '--G_ch', type=int, default=64,
        help='Channel multiplier for G (default: %(default)s)')
    parser.add_argument(
        '--D_ch', type=int, default=64,
        help='Channel multiplier for D (default: %(default)s)')
    parser.add_argument(
        '--G_depth', type=int, default=1,
        help='Number of resblocks per stage in G? (default: %(default)s)')
    parser.add_argument(
        '--D_depth', type=int, default=1,
        help='Number of resblocks per stage in D? (default: %(default)s)')
    # Negative flag: passing --D_thin stores False into `D_wide`.
    parser.add_argument(
        '--D_thin', action='store_false', dest='D_wide', default=True,
        help='Use the SN-GAN channel pattern for D? (default: %(default)s)')
    parser.add_argument(
        '--G_shared', action='store_true', default=False,
        help='Use shared embeddings in G? (default: %(default)s)')
    parser.add_argument(
        '--shared_dim', type=int, default=0,
        help="G's shared embedding dimensionality; if 0, will be equal to dim_z. "
             '(default: %(default)s)')
    parser.add_argument(
        '--dim_z', type=int, default=128,
        help='Noise dimensionality: %(default)s)')
    parser.add_argument(
        '--z_var', type=float, default=1.0,
        help='Noise variance: %(default)s)')
    parser.add_argument(
        '--hier', action='store_true', default=False,
        help='Use hierarchical z in G? (default: %(default)s)')
    parser.add_argument(
        '--cross_replica', action='store_true', default=False,
        help='Cross_replica batchnorm in G?(default: %(default)s)')
    parser.add_argument(
        '--mybn', action='store_true', default=False,
        help='Use my batchnorm (which supports standing stats?) %(default)s)')
    parser.add_argument(
        '--G_nl', type=str, default='relu',
        help='Activation function for G (default: %(default)s)')
    parser.add_argument(
        '--D_nl', type=str, default='relu',
        help='Activation function for D (default: %(default)s)')
    parser.add_argument(
        '--G_attn', type=str, default='64',
        help='What resolutions to use attention on for G (underscore separated) '
             '(default: %(default)s)')
    parser.add_argument(
        '--D_attn', type=str, default='64',
        help='What resolutions to use attention on for D (underscore separated) '
             '(default: %(default)s)')
    parser.add_argument(
        '--norm_style', type=str, default='bn',
        help='Normalizer style for G, one of bn [batchnorm], in [instancenorm], '
             'ln [layernorm], gn [groupnorm] (default: %(default)s)')

    ### Model init stuff ###
    parser.add_argument(
        '--seed', type=int, default=0,
        help='Random seed to use; affects both initialization and '
             ' dataloading. (default: %(default)s)')
    parser.add_argument(
        '--G_init', type=str, default='ortho',
        help='Init style to use for G (default: %(default)s)')
    parser.add_argument(
        '--D_init', type=str, default='ortho',
        help='Init style to use for D(default: %(default)s)')
    parser.add_argument(
        '--skip_init', action='store_true', default=False,
        help='Skip initialization, ideal for testing when ortho init was used '
             '(default: %(default)s)')

    ### Optimizer stuff ###
    parser.add_argument(
        '--G_lr', type=float, default=5e-5,
        help='Learning rate to use for Generator (default: %(default)s)')
    parser.add_argument(
        '--D_lr', type=float, default=2e-4,
        help='Learning rate to use for Discriminator (default: %(default)s)')
    parser.add_argument(
        '--G_B1', type=float, default=0.0,
        help='Beta1 to use for Generator (default: %(default)s)')
    parser.add_argument(
        '--D_B1', type=float, default=0.0,
        help='Beta1 to use for Discriminator (default: %(default)s)')
    parser.add_argument(
        '--G_B2', type=float, default=0.999,
        help='Beta2 to use for Generator (default: %(default)s)')
    parser.add_argument(
        '--D_B2', type=float, default=0.999,
        help='Beta2 to use for Discriminator (default: %(default)s)')

    ### Batch size, parallel, and precision stuff ###
    parser.add_argument(
        '--batch_size', type=int, default=64,
        help='Default overall batchsize (default: %(default)s)')
    parser.add_argument(
        '--G_batch_size', type=int, default=0,
        help='Batch size to use for G; if 0, same as D (default: %(default)s)')
    parser.add_argument(
        '--num_G_accumulations', type=int, default=1,
        help="Number of passes to accumulate G's gradients over "
             '(default: %(default)s)')
    parser.add_argument(
        '--num_D_steps', type=int, default=2,
        help='Number of D steps per G step (default: %(default)s)')
    parser.add_argument(
        '--num_D_accumulations', type=int, default=1,
        help="Number of passes to accumulate D's gradients over "
             '(default: %(default)s)')
    parser.add_argument(
        '--split_D', action='store_true', default=False,
        help='Run D twice rather than concatenating inputs? (default: %(default)s)')
    parser.add_argument(
        '--num_epochs', type=int, default=100,
        help='Number of epochs to train for (default: %(default)s)')
    parser.add_argument(
        '--parallel', action='store_true', default=False,
        help='Train with multiple GPUs (default: %(default)s)')
    parser.add_argument(
        '--G_fp16', action='store_true', default=False,
        help='Train with half-precision in G? (default: %(default)s)')
    parser.add_argument(
        '--D_fp16', action='store_true', default=False,
        help='Train with half-precision in D? (default: %(default)s)')
    parser.add_argument(
        '--D_mixed_precision', action='store_true', default=False,
        help='Train with half-precision activations but fp32 params in D? '
             '(default: %(default)s)')
    parser.add_argument(
        '--G_mixed_precision', action='store_true', default=False,
        help='Train with half-precision activations but fp32 params in G? '
             '(default: %(default)s)')
    parser.add_argument(
        '--accumulate_stats', action='store_true', default=False,
        help='Accumulate "standing" batchnorm stats? (default: %(default)s)')
    parser.add_argument(
        '--num_standing_accumulations', type=int, default=16,
        help='Number of forward passes to use in accumulating standing stats? '
             '(default: %(default)s)')

    ### Bookkeeping stuff ###
    parser.add_argument(
        '--G_eval_mode', action='store_true', default=False,
        help='Run G in eval mode (running/standing stats?) at sample/test time? '
             '(default: %(default)s)')
    parser.add_argument(
        '--save_every', type=int, default=2000,
        help='Save every X iterations (default: %(default)s)')
    parser.add_argument(
        '--num_save_copies', type=int, default=2,
        help='How many copies to save (default: %(default)s)')
    parser.add_argument(
        '--num_best_copies', type=int, default=2,
        help='How many previous best checkpoints to save (default: %(default)s)')
    parser.add_argument(
        '--which_best', type=str, default='IS',
        help='Which metric to use to determine when to save new "best"'
             'checkpoints, one of IS or FID (default: %(default)s)')
    parser.add_argument(
        '--no_fid', action='store_true', default=False,
        help='Calculate IS only, not FID? (default: %(default)s)')
    parser.add_argument(
        '--test_every', type=int, default=5000,
        help='Test every X iterations (default: %(default)s)')
    parser.add_argument(
        '--num_inception_images', type=int, default=50000,
        help='Number of samples to compute inception metrics with '
             '(default: %(default)s)')
    parser.add_argument(
        '--hashname', action='store_true', default=False,
        help='Use a hash of the experiment name instead of the full config '
             '(default: %(default)s)')
    parser.add_argument(
        '--base_root', type=str, default='',
        help='Default location to store all weights, samples, data, and logs '
             ' (default: %(default)s)')
    parser.add_argument(
        '--data_root', type=str, default='data',
        help='Default location where data is stored (default: %(default)s)')
    parser.add_argument(
        '--label_root', type=str, default='labels',
        help='Default location where annotations is stored (default: %(default)s)')
    parser.add_argument(
        '--weights_root', type=str, default='weights',
        help='Default location to store weights (default: %(default)s)')
    parser.add_argument(
        '--logs_root', type=str, default='logs',
        help='Default location to store logs (default: %(default)s)')
    parser.add_argument(
        '--samples_root', type=str, default='samples',
        help='Default location to store samples (default: %(default)s)')
    parser.add_argument(
        '--pbar', type=str, default='mine',
        help='Type of progressbar to use; one of "mine" or "tqdm" '
             '(default: %(default)s)')
    parser.add_argument(
        '--name_suffix', type=str, default='',
        help='Suffix for experiment name for loading weights for sampling '
             '(consider "best0") (default: %(default)s)')
    parser.add_argument(
        '--experiment_name', type=str, default='',
        help='Optionally override the automatic experiment naming with this arg. '
             '(default: %(default)s)')
    # NOTE(review): this help text is identical to --hashname's and looks like
    # a copy-paste; confirm what --config_from_name does and reword it.
    parser.add_argument(
        '--config_from_name', action='store_true', default=False,
        help='Use a hash of the experiment name instead of the full config '
             '(default: %(default)s)')

    ### EMA Stuff ###
    parser.add_argument(
        '--ema', action='store_true', default=False,
        help="Keep an ema of G's weights? (default: %(default)s)")
    parser.add_argument(
        '--ema_decay', type=float, default=0.9999,
        help='EMA decay rate (default: %(default)s)')
    parser.add_argument(
        '--use_ema', action='store_true', default=False,
        help='Use the EMA parameters of G for evaluation? (default: %(default)s)')
    parser.add_argument(
        '--ema_start', type=int, default=0,
        help='When to start updating the EMA weights (default: %(default)s)')

    ### Numerical precision and SV stuff ###
    parser.add_argument(
        '--adam_eps', type=float, default=1e-8,
        help='epsilon value to use for Adam (default: %(default)s)')
    parser.add_argument(
        '--BN_eps', type=float, default=1e-5,
        help='epsilon value to use for BatchNorm (default: %(default)s)')
    parser.add_argument(
        '--SN_eps', type=float, default=1e-8,
        help='epsilon value to use for Spectral Norm(default: %(default)s)')
    parser.add_argument(
        '--num_G_SVs', type=int, default=1,
        help='Number of SVs to track in G (default: %(default)s)')
    parser.add_argument(
        '--num_D_SVs', type=int, default=1,
        help='Number of SVs to track in D (default: %(default)s)')
    parser.add_argument(
        '--num_G_SV_itrs', type=int, default=1,
        help='Number of SV itrs in G (default: %(default)s)')
    parser.add_argument(
        '--num_D_SV_itrs', type=int, default=1,
        help='Number of SV itrs in D (default: %(default)s)')

    ### Ortho reg stuff ###
    parser.add_argument(
        '--G_ortho', type=float, default=0.0,  # 1e-4 is default for BigGAN
        help='Modified ortho reg coefficient in G(default: %(default)s)')
    parser.add_argument(
        '--D_ortho', type=float, default=0.0,
        help='Modified ortho reg coefficient in D (default: %(default)s)')
    # store_true with default=True means this flag alone can never turn the
    # behaviour off; --no_toggle_grads below writes False to the same dest.
    parser.add_argument(
        '--toggle_grads', action='store_true', default=True,
        help='Toggle D and G\'s "requires_grad" settings when not training them? '
             ' (default: %(default)s)')
    parser.add_argument(
        '--no_toggle_grads', action='store_false', dest='toggle_grads',
        default=True,
        help='Do not toggle "requires_grad"; the opposite of --toggle_grads.')

    ### Which train function ###
    parser.add_argument(
        '--which_train_fn', type=str, default='GAN',
        help='How2trainyourbois (default: %(default)s)')

    ### Resume training stuff
    parser.add_argument(
        '--load_weights', type=str, default='',
        help='Suffix for which weights to load (e.g. best0, copy0) '
             '(default: %(default)s)')
    parser.add_argument(
        '--resume', action='store_true', default=False,
        help='Resume training? (default: %(default)s)')

    ### Log stuff ###
    # The literal percent signs are doubled because argparse %-formats the
    # help text; a bare "%#.#f" is an invalid format spec and raises
    # ValueError when the help is rendered.
    parser.add_argument(
        '--logstyle', type=str, default='%3.3e',
        help='What style to use when logging training metrics?'
             'One of: %%#.#f/ %%#.#e (float/exp, text),'
             'pickle (python pickle),'
             'npz (numpy zip),'
             'mat (MATLAB .mat file) (default: %(default)s)')
    parser.add_argument(
        '--log_G_spectra', action='store_true', default=False,
        help='Log the top 3 singular values in each SN layer in G? '
             '(default: %(default)s)')
    parser.add_argument(
        '--log_D_spectra', action='store_true', default=False,
        help='Log the top 3 singular values in each SN layer in D? '
             '(default: %(default)s)')
    parser.add_argument(
        '--sv_log_interval', type=int, default=10,
        help='Iteration interval for logging singular values '
             ' (default: %(default)s)')
    parser.add_argument(
        '--log_interval', type=int, default=100,
        help='Iteration interval for general logs '
             ' (default: %(default)s)')
    return parser
# Arguments for sample.py; not presently used in train.py
def add_sample_parser(parser):
    """Register the sampling-only options on an existing parser.

    Args:
        parser: an ``argparse.ArgumentParser``-like object; it is mutated
            in place.

    Returns:
        The same ``parser`` object, now with ``--sample_num`` and
        ``--trunc_z`` registered.
    """
    # (flag, type, default, help) for each sampling option.
    sample_options = (
        ('--sample_num', int, 10000,
         'Number of images to sample'
         '(default: %(default)s)'),
        ('--trunc_z', float, 0.0,
         'Truncate noise vector to this value'
         '(default: %(default)s)'),
    )
    for flag, value_type, fallback, description in sample_options:
        parser.add_argument(flag, type=value_type, default=fallback,
                            help=description)
    return parser
# Lookup table from activation name to a ready-made activation module.
# 'ir' is an alias for 'inplace_relu'; both LeakyReLU variants use a
# negative slope of 0.2.
activation_dict = {
    'inplace_relu': nn.ReLU(inplace=True),
    'relu': nn.ReLU(inplace=False),
    'ir': nn.ReLU(inplace=True),
    'inplace_Lrelu': nn.LeakyReLU(0.2, inplace=True),
    'Lrelu': nn.LeakyReLU(0.2, inplace=False),
}
def denorm(x):
    """Shift and halve ``x`` (``(x + 1) / 2``), then clamp the result to [0, 1].

    The input is left untouched; the clamp is applied in place on the
    freshly computed result, which is what gets returned.
    """
    return ((x + 1) / 2).clamp_(0, 1)
# Function to join strings or ignore them
# Base string is the string to link "strings," while strings
# is a list of strings or Nones.
def join_strings(base_string, strings):
    """Join the truthy entries of ``strings`` using ``base_string`` as glue.

    Falsy entries (``None``, empty strings) are dropped before joining.
    """
    return base_string.join(filter(None, strings))
# Save a model's weights, optimizer, and the state_dict
def save_weights(G, D, state_dict, weights_root, experiment_name, name_suffix=None, G_ema=None):
    """Save network weights, optimizer states and the training state dict.

    Files are written to ``<weights_root>/<experiment_name>/`` as
    ``<stem>[_<name_suffix>].pth`` for the stems ``G``, ``G_optim``, ``D``,
    ``D_optim``, ``state_dict`` and, when ``G_ema`` is given, ``G_ema``.

    Args:
        G, D: objects exposing ``state_dict()`` and an ``optim`` attribute
            that itself exposes ``state_dict()``.
        state_dict: object passed to ``torch.save`` as-is.
        weights_root: root folder for all experiments' weights.
        experiment_name: sub-folder name for this experiment.
        name_suffix: optional suffix appended to every file stem.
        G_ema: optional object exposing ``state_dict()``.
    """
    root = '/'.join([weights_root, experiment_name])
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # exists() + mkdir() and also creates a missing weights_root.
    os.makedirs(root, exist_ok=True)
    if name_suffix:
        print('Saving weights to %s/%s...' % (root, name_suffix))
    else:
        print('Saving weights to %s...' % root)

    # (file stem, object to serialise), in the order they are written.
    checkpoints = [
        ('G', G.state_dict()),
        ('G_optim', G.optim.state_dict()),
        ('D', D.state_dict()),
        ('D_optim', D.optim.state_dict()),
        ('state_dict', state_dict),
    ]
    if G_ema is not None:
        checkpoints.append(('G_ema', G_ema.state_dict()))

    for stem, payload in checkpoints:
        torch.save(payload,
                   '%s/%s.pth' % (root, join_strings('_', [stem, name_suffix])))
# Load a model's weights, optimizer, and the state_dict
def load_weights(G, D, state_dict, weights_root, experiment_name,
                 name_suffix=None, G_ema=None, strict=True, load_optim=True):
    """Load network weights, optimizer states and the training state dict.

    Reads the ``<stem>[_<name_suffix>].pth`` files found under
    ``<weights_root>/<experiment_name>/``.

    Args:
        G, D: networks to restore, or ``None`` to skip that network (and
            its optimizer).
        state_dict: dict updated in place; every key already present is
            overwritten with the value stored in the checkpoint (a key
            missing from the checkpoint raises ``KeyError``).
        weights_root: root folder for all experiments' weights.
        experiment_name: sub-folder name for this experiment.
        name_suffix: optional suffix appended to every file stem.
        G_ema: optional network to restore from the ``G_ema`` file.
        strict: forwarded to ``load_state_dict`` of G, D and G_ema.
        load_optim: when true, also restore ``G.optim`` / ``D.optim``.
    """
    root = '/'.join([weights_root, experiment_name])
    if name_suffix:
        print('Loading %s weights from %s...' % (name_suffix, root))
    else:
        print('Loading weights from %s...' % root)

    def _load(stem):
        # Read one checkpoint file identified by its stem.
        return torch.load(
            '%s/%s.pth' % (root, join_strings('_', [stem, name_suffix])))

    for stem, net in (('G', G), ('D', D)):
        if net is None:
            continue
        net.load_state_dict(_load(stem), strict=strict)
        if load_optim:
            net.optim.load_state_dict(_load('%s_optim' % stem))

    # Load the training state once instead of re-reading the file from disk
    # for every key; skip the read entirely when there is nothing to fill.
    if state_dict:
        saved_state = _load('state_dict')
        for item in state_dict:
            state_dict[item] = saved_state[item]

    if G_ema is not None:
        G_ema.load_state_dict(_load('G_ema'), strict=strict)
# Utility function to seed the RNGs (torch, torch.cuda and numpy)
def seed_rng(seed):
    """Seed the torch, torch.cuda and numpy random number generators."""
    seeders = (torch.manual_seed, torch.cuda.manual_seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
# Utility to peg all roots to a base root
# If a base root folder is provided, peg all other root folders to it.
def update_config_roots(config):
    """Re-root the data/weights/logs/samples folders under ``base_root``.

    When ``config['base_root']`` is truthy, ``config['<name>_root']`` is
    overwritten with ``'<base_root>/<name>'`` for each of ``data``,
    ``weights``, ``logs`` and ``samples``. Otherwise ``config`` is left
    as it is. The (mutated) ``config`` is returned in both cases.
    """
    base = config['base_root']
    if not base:
        return config
    print('Pegging all root folders to base root %s' % base)
    config.update({
        '%s_root' % name: '%s/%s' % (base, name)
        for name in ('data', 'weights', 'logs', 'samples')
    })
    return config
# Utility to prepare root folders if they don't exist; parent folder must exist
def prepare_root(config):
    """Create the weights, logs and samples folders that do not exist yet.

    Uses a non-recursive ``os.mkdir``, so each folder's parent must
    already exist.
    """
    for root_key in ('weights_root', 'logs_root', 'samples_root'):
        directory = config[root_key]
        if os.path.exists(directory):
            continue
        print('Making directory %s for %s...' % (directory, root_key))
        os.mkdir(directory)
# Simple wrapper that applies EMA to a model. Could be done better in 1.0 using
# the parameters() and buffers() module functions, but for now this works
# with state_dicts using .copy_
class ema(object):
    """Keep ``target`` as an exponential moving average of ``source``.

    Both models are accessed through their ``state_dict()`` entries, which
    are updated in place with ``copy_``.

    Args:
        source: model whose weights are tracked.
        target: model that receives the averaged weights; it must have the
            same state-dict keys as ``source``.
        decay: EMA decay rate used once averaging has started.
        start_itr: iteration at which averaging starts; before it the
            target is pegged to the source.
    """

    def __init__(self, source, target, decay=0.9999, start_itr=0):
        self.source = source
        self.target = target
        self.decay = decay
        # Optional parameter indicating what iteration to start the decay at
        self.start_itr = start_itr
        # Initialize target's params to be source's
        self.source_dict = self.source.state_dict()
        self.target_dict = self.target.state_dict()
        print('Initializing EMA parameters to be source parameters...')
        with torch.no_grad():
            for key in self.source_dict:
                # Copy in place; rebinding `.data` would not update the model.
                self.target_dict[key].data.copy_(self.source_dict[key].data)

    def update(self, itr=None):
        """Blend the current source weights into the target.

        Args:
            itr: optional iteration counter. When given and smaller than
                ``start_itr`` the target is set equal to the source
                (decay 0); otherwise ``self.decay`` is used.
        """
        # Compare against None rather than relying on truthiness: iteration
        # 0 is a valid counter and must also be pegged during warm-up.
        if itr is not None and itr < self.start_itr:
            decay = 0.0
        else:
            decay = self.decay
        with torch.no_grad():
            for key in self.source_dict:
                self.target_dict[key].data.copy_(
                    self.target_dict[key].data * decay
                    + self.source_dict[key].data * (1 - decay))
# Apply modified ortho reg to a model
# This function is an optimized version that directly computes the gradient,
# instead of computing and then differentiating the loss.
def ortho(model, strength=1e-4, blacklist=[]):
    """Add the modified orthogonal-regularisation gradient to ``.grad``.

    For every parameter with at least two axes that is not (by identity)
    in ``blacklist``, the parameter is flattened to ``(shape[0], -1)`` and
    ``strength * 2 * ((W W^T) * (1 - I)) W`` is accumulated into its
    existing gradient. Runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        for weight in model.parameters():
            # Skip vectors/scalars and anything explicitly blacklisted.
            if weight.dim() < 2:
                continue
            if any(weight is skipped for skipped in blacklist):
                continue
            flat = weight.view(weight.shape[0], -1)
            gram = torch.mm(flat, flat.t())
            off_diagonal = 1. - torch.eye(flat.shape[0], device=flat.device)
            penalty_grad = 2 * torch.mm(gram * off_diagonal, flat)
            weight.grad.data += strength * penalty_grad.view(weight.shape)
# Default ortho reg
# This function is an optimized version that directly computes the gradient,
# instead of computing and then differentiating the loss.
def default_ortho(model, strength=1e-4, blacklist=[]):
    """Add the default orthogonal-regularisation gradient to ``.grad``.

    For every parameter with at least two axes that is not in
    ``blacklist``, the parameter is flattened to ``(shape[0], -1)`` and
    ``strength * 2 * (W W^T - I) W`` is accumulated into its existing
    gradient. Runs under ``torch.no_grad()``.

    Args:
        model: object exposing ``parameters()``; each regularised
            parameter must already have a ``.grad``.
        strength: coefficient applied to the regularisation gradient.
        blacklist: parameters to skip, matched by identity.
    """
    with torch.no_grad():
        for param in model.parameters():
            # Only apply this to parameters with at least 2 axes & not in
            # blacklist. Membership is tested by identity: `param in
            # blacklist` would compare tensors element-wise with `==` and
            # raise for a non-empty blacklist.
            if len(param.shape) < 2 or any(param is item for item in blacklist):
                continue
            w = param.view(param.shape[0], -1)
            grad = (2 * torch.mm(torch.mm(w, w.t())
                                 - torch.eye(w.shape[0], device=w.device), w))
            param.grad.data += strength * grad.view(param.shape)
# Convenience utility to switch off requires_grad
def toggle_grad(model, on_or_off):
    """Set ``requires_grad`` to ``on_or_off`` on every parameter of ``model``."""
    for weight in model.parameters():
        weight.requires_grad = on_or_off
def prepare_z_y(G_batch_size, dim_z, nclasses, device='cuda', fp16=False, z_var=1.0, ngd=True, fixed=False):
    """Sample a batch of latent vectors and class labels.

    Args:
        G_batch_size: number of (z, y) pairs to draw.
        dim_z: dimensionality of each latent vector.
        nclasses: number of classes; labels are drawn from
            ``[0, nclasses)``.
        device: device on which both tensors are created.
        fp16: when true, the latents are returned in half precision.
        z_var: accepted for interface compatibility but not applied.
            NOTE(review): the original never used it either; confirm
            whether the Gaussian branch should be scaled by it.
        ngd: when true, latents are uniform in [-1, 1), drawn from numpy's
            global RNG, and track gradients unless ``fixed``; when false,
            latents are standard normal and do not track gradients.
        fixed: only used when ``ngd`` is true; disables gradient tracking.

    Returns:
        ``(z_, y_)``: a ``(G_batch_size, dim_z)`` float tensor and a
        ``(G_batch_size,)`` int64 tensor.
    """
    if ngd:
        # Drawn through numpy, as before, so np.random.seed governs z.
        noise = np.random.uniform(-1, 1, (G_batch_size, dim_z))
        z_ = torch.tensor(noise, dtype=torch.float32, device=device)
    else:
        z_ = torch.randn((G_batch_size, dim_z), dtype=torch.float32, device=device)
    if fp16:
        z_ = z_.half()
    if ngd:
        # Enable gradients only after the optional cast so that the
        # returned tensor is itself the leaf that collects `.grad`.
        z_.requires_grad_(not fixed)
    y_ = torch.randint(0, nclasses, size=(G_batch_size,), dtype=torch.long, device=device)
    return z_, y_
# Sample function for sample sheets
def sample_sheet(G, classes_per_sheet, num_classes, samples_per_class, parallel,
                 samples_root, experiment_name, folder_number, z_=None):
    """Render per-class sample sheets and save them as JPEG files.

    One sheet is written per group of ``classes_per_sheet`` consecutive
    classes to
    ``<samples_root>/<experiment_name>/<folder_number>/samples<i>.jpg``.

    Args:
        G: generator; called as ``G(z, G.shared(y))``. ``G.dim_z`` is read
            when ``z_`` is ``None``.
        classes_per_sheet: number of classes shown on each sheet.
        num_classes: total number of classes; a trailing group smaller
            than ``classes_per_sheet`` is not rendered.
        samples_per_class: number of generator passes per sheet.
        parallel: when true, run G through ``nn.parallel.data_parallel``.
        samples_root, experiment_name, folder_number: output location.
        z_: optional latent batch; its first ``classes_per_sheet`` rows
            are reused for every pass. When ``None``, fresh Gaussian noise
            is sampled for each pass.
    """
    # Prepare sample directory (creates missing parents, tolerates existing).
    sheet_dir = '%s/%s/%d' % (samples_root, experiment_name, folder_number)
    os.makedirs(sheet_dir, exist_ok=True)

    # Keep labels (and fallback noise) on the same device as the supplied
    # latents; without latents fall back to 'cuda' as before.
    device = z_.device if z_ is not None else 'cuda'

    # loop over total number of sheets
    for i in range(num_classes // classes_per_sheet):
        ims = []
        y = torch.arange(i * classes_per_sheet, (i + 1) * classes_per_sheet,
                         device=device)
        for _ in range(samples_per_class):
            if z_ is None:
                # The default z_=None used to crash on slicing; draw noise.
                noise = torch.randn(classes_per_sheet, G.dim_z, device=device)
                if getattr(G, 'fp16', False):
                    noise = noise.half()
            else:
                noise = z_[:classes_per_sheet]
            with torch.no_grad():
                if parallel:
                    o = nn.parallel.data_parallel(G, (noise, G.shared(y)))
                else:
                    o = G(noise, G.shared(y))
            ims += [o.data.cpu()]
        # This line should properly unroll the images
        out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2],
                                           ims[0].shape[3]).data.float().cpu()
        # The path for the samples
        image_filename = '%s/%s/%d/samples%d.jpg' % (samples_root, experiment_name,
                                                     folder_number, i)
        torchvision.utils.save_image(out_ims, image_filename,
                                     nrow=samples_per_class, normalize=True)
# Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..)
def interp(x0, x1, num_midpoints):
    """Linearly interpolate between ``x0`` and ``x1``.

    Args:
        x0, x1: tensors of shape ``(shape0, 1, dim)``.
        num_midpoints: number of interior points; the endpoints are added,
            so ``num_midpoints + 2`` steps are produced.

    Returns:
        A tensor of shape ``(shape0, num_midpoints + 2, dim)`` going from
        ``x0`` (step 0) to ``x1`` (last step).
    """
    # Build the weights on x0's own device rather than a hard-coded 'cuda'
    # so the function also works on CPU or a non-default GPU.
    lerp = torch.linspace(0, 1.0, num_midpoints + 2, device=x0.device).to(x0.dtype)
    weights = lerp.view(1, -1, 1)
    return (x0 * (1 - weights)) + (x1 * weights)
# Convenience function to sample an index, not actually a 1-hot
def sample_1hot(batch_size, num_classes, device='cuda'):
    """Draw ``batch_size`` random class indices in ``[0, num_classes)``.

    Despite the name, the result is an int64 index tensor of shape
    ``(batch_size,)``, not a one-hot encoding.
    """
    shape = (batch_size,)
    labels = torch.randint(low=0, high=num_classes, size=shape,
                           dtype=torch.int64, device=device,
                           requires_grad=False)
    return labels
# interp sheet function
# Supports full, class-wise and intra-class interpolation
def interp_sheet(G, num_per_sheet, num_midpoints, num_classes, parallel,
                 samples_root, experiment_name, folder_number, sheet_number=0,
                 fix_z=False, fix_y=False, device='cuda'):
    """Render one interpolation sheet and save it as a JPEG file.

    Each row interpolates over ``num_midpoints + 2`` steps in latent space
    and/or class-embedding space. The image is written to
    ``<samples_root>/<experiment_name>/<folder_number>/interp<style><sheet_number>.jpg``
    where ``<style>`` contains ``Z`` when z is interpolated and ``Y`` when
    y is interpolated. The output folder is not created here.

    Args:
        G: generator; ``G.dim_z``, ``G.shared`` and ``G.fp16`` are read and
            G is called as ``G(zs, ys)`` with already-embedded labels.
        num_per_sheet: number of rows on the sheet.
        num_midpoints: interior interpolation points per row.
        num_classes: number of classes to draw labels from.
        parallel: when true, run G through ``nn.parallel.data_parallel``.
        samples_root, experiment_name, folder_number, sheet_number: output
            location.
        fix_z: when true, keep a single z per row instead of interpolating.
        fix_y: when true, keep a single y per row instead of interpolating.
        device: device on which latents and labels are sampled.
    """
    steps = num_midpoints + 2
    # Prepare zs and ys
    if fix_z:  # If fix Z, only sample 1 z per row
        zs = torch.randn(num_per_sheet, 1, G.dim_z, device=device)
        zs = zs.repeat(1, steps, 1).view(-1, G.dim_z)
    else:
        zs = interp(torch.randn(num_per_sheet, 1, G.dim_z, device=device),
                    torch.randn(num_per_sheet, 1, G.dim_z, device=device),
                    num_midpoints).view(-1, G.dim_z)
    # Labels are sampled on `device` too; previously sample_1hot silently
    # fell back to its own default and ignored this function's argument.
    if fix_y:  # If fix y, only sample 1 y per row
        ys = sample_1hot(num_per_sheet, num_classes, device=device)
        ys = G.shared(ys).view(num_per_sheet, 1, -1)
        ys = ys.repeat(1, steps, 1).view(num_per_sheet * steps, -1)
    else:
        ys = interp(
            G.shared(sample_1hot(num_per_sheet, num_classes, device=device)).view(num_per_sheet, 1, -1),
            G.shared(sample_1hot(num_per_sheet, num_classes, device=device)).view(num_per_sheet, 1, -1),
            num_midpoints).view(num_per_sheet * steps, -1)
    # Run the net--note that we've already passed y through G.shared.
    if G.fp16:
        zs = zs.half()
    with torch.no_grad():
        if parallel:
            out_ims = nn.parallel.data_parallel(G, (zs, ys)).data.cpu()
        else:
            out_ims = G(zs, ys).data.cpu()
    # 'Z' / 'Y' mark which of the two inputs was interpolated.
    interp_style = ('' if fix_z else 'Z') + ('' if fix_y else 'Y')
    image_filename = '%s/%s/%d/interp%s%d.jpg' % (samples_root, experiment_name,
                                                  folder_number, interp_style,
                                                  sheet_number)
    torchvision.utils.save_image(out_ims, image_filename,
                                 nrow=steps, normalize=True)