Revision 0a2a4ad4da69f2f8c53fd5ee96e895c72aeb9f26 authored by NicoZenith on 02 February 2022, 17:32:26 UTC, committed by GitHub on 02 February 2022, 17:32:26 UTC
1 parent 65c811b
utils.py
import torch
import numpy as np
# plotting
import matplotlib
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import Dataset, TensorDataset
from scipy import linalg
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# custom weights initialization called on netG and netD
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def get_dataset(dataset_name, dataroot, imageSize, is_train=True, drop_rate=0.0, tile_size=32):
if dataset_name == 'cifar10':
dataset = dset.CIFAR10(
train=is_train,
root=dataroot, download=False,
transform=transforms.Compose([
transforms.Resize(imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
Occlude(drop_rate=drop_rate, tile_size=tile_size),
]))
unorm = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
img_channels = 3
elif dataset_name == 'svhn':
if is_train:
split = 'train'
else:
split = 'test'
dataset = dset.SVHN(
root=dataroot, download=False,
split = split,
transform=transforms.Compose([
transforms.Resize(imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
Occlude(drop_rate=drop_rate, tile_size=tile_size)
]))
unorm = UnNormalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
img_channels = 3
elif dataset_name == 'mnist':
dataset = dset.MNIST(
train=is_train,
root=dataroot, download=False,
transform=transforms.Compose([
transforms.Resize(imageSize),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5,), std=(0.5,)),
Occlude(drop_rate=drop_rate, tile_size=tile_size)
])
)
unorm = UnNormalize(mean=(0.5,), std=(0.5,))
img_channels = 1
elif dataset_name == 'fashion':
dataset = dset.FashionMNIST(
train=is_train,
root=dataroot, download=False,
transform=transforms.Compose([
transforms.Resize(imageSize),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5,), std=(0.5,)),
Occlude(drop_rate=drop_rate, tile_size=tile_size)
]))
unorm = UnNormalize(mean=(0.5,), std=(0.5,))
img_channels = 1
else:
raise NotImplementedError("No such dataset {}".format(dataset_name))
assert dataset
return dataset, unorm, img_channels
# compute the current classification accuracy
def compute_acc(preds, labels):
correct = 0
preds_ = preds.data.max(1)[1]
correct = preds_.eq(labels.data).cpu().sum()
acc = float(correct) / float(len(labels.data)) * 100.0
return acc
def get_latent(dim_latent, batch_size, device):
latent_z = np.random.normal(0, 1, (batch_size, dim_latent)) # generate random labels
latent_z = torch.tensor(latent_z, dtype=torch.float32, device=device)
latent_z = latent_z.view(batch_size, dim_latent, 1, 1)
return latent_z
class UnNormalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensorBatch):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized image.
"""
for i in range(len(tensorBatch)):
for j in range(len(tensorBatch[i])):
tensorBatch[i][j].mul_(self.std[j]).add_(self.mean[j])
# The normalize code -> t.sub_(m).div_(s)
return tensorBatch
def save_fig_losses(epoch, d_losses, g_losses, r_losses_real, r_losses_fake, kl_losses, fid_NREM, fid_REM, dir_files):
e = np.arange(0, epoch+1)
fig = plt.figure(figsize=(10,5))
ax1 = fig.add_subplot(121)
if g_losses is not None:
ax1.plot(e, g_losses, label='generator (REM)')
if d_losses is not None:
ax1.plot(e, d_losses, color='green', label='discriminator (Wake, REM)')
#ax1.set_ylim(0, 10)
ax1.set_xlabel('epochs')
ax1.set_ylabel('loss')
ax1.set_title('losses with training')
if r_losses_real is not None:
ax1.plot(e, r_losses_real, color='orange', label='data rec. (Wake)')
if r_losses_fake is not None:
ax1.plot(e, r_losses_fake, color='magenta', label='latent rec. (NREM)')
if kl_losses is not None:
ax1.plot(e, kl_losses, color='brown', label='KL div. (Wake)')
ax1.legend()
if fid_NREM is not None and fid_REM is not None:
ax2 = fig.add_subplot(122)
ax2.plot(e, fid_NREM, color='darkorange', label='FID NREM')
ax2.plot(e, fid_REM, color='magenta', label='FID REM')
ax2.legend()
fig.savefig(dir_files+'/losses.pdf')
def save_fig_trainval(epoch, all_losses, all_accuracies, dir_files):
e = np.arange(0, epoch+1)
fig = plt.figure(figsize=(10,5))
ax1 = fig.add_subplot(121)
ax1.plot(e, all_losses['train'], label='train loss')
ax1.plot(e, all_losses['val'], label='validation loss')
ax1.set_xlabel('epochs')
ax1.set_ylabel('loss')
ax1.legend()
ax2 = fig.add_subplot(122)
ax2.plot(e, all_accuracies['train'], label='train accuracy')
ax2.plot(e, all_accuracies['val'], label='val accuracy')
ax2.set_xlabel('epochs')
ax2.set_ylabel('accuracy (%)')
ax2.set_ylim(0, 100)
ax2.legend()
fig.savefig(dir_files + '/trainval.pdf')
class Occlude(object):
def __init__(self, drop_rate=0.0, tile_size=7):
self.drop_rate = drop_rate
self.tile_size = tile_size
def __call__(self, imgs, d=0):
imgs_n = imgs.clone()
if d==0:
device='cpu'
else:
device = imgs.get_device()
if device ==-1:
device = 'cpu'
mask = torch.ones((imgs_n.size(d), imgs_n.size(d+1), imgs_n.size(d+2)), device=device) # only ones = no mask
i = 0
while i < imgs_n.size(d+1):
j = 0
while j < imgs_n.size(d+2):
if np.random.rand() < self.drop_rate:
for k in range(mask.size(0)):
mask[k, i:i + self.tile_size, j:j + self.tile_size] = 0 # set to zero the whole tile
j += self.tile_size
i += self.tile_size
imgs_n = imgs_n * mask # apply the mask to each image
return imgs_n
def kl_loss(latent_output):
m = torch.mean(latent_output, dim=0)
s = torch.std(latent_output, dim=0)
kl_loss = torch.mean((s ** 2 + m ** 2) / 2 - torch.log(s) - 1/2)
return kl_loss
def mean_and_sem(array, color=None, axis=0):
mean = array.mean(axis=0)
sem_plus = mean + stats.sem(array, axis=axis)
sem_minus = mean - stats.sem(array, axis=axis)
if color is not None:
ax.fill_between(np.arange(mean.shape[0]), sem_plus, sem_minus, color=color, alpha=0.5)
else:
ax.fill_between(np.arange(mean.shape[0]), sem_plus, sem_minus, alpha=0.5)
return mean
def calculate_activation_statistics(images,model,batch_size=128, dims=2048,
cuda=False):
model.eval()
act=np.empty((len(images), dims))
if cuda:
batch=images.cuda()
else:
batch=images
pred = model(batch)[0]
# If model output is not scalar, apply global spatial average pooling.
# This happens if you choose a dimensionality not equal 2048.
if pred.size(2) != 1 or pred.size(3) != 1:
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
act= pred.cpu().data.numpy().reshape(pred.size(0), -1)
return act
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)
return mu, sigma
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
"""
mu1 = np.atleast_1d(mu1)
mu2 = np.atleast_1d(mu2)
sigma1 = np.atleast_2d(sigma1)
sigma2 = np.atleast_2d(sigma2)
assert mu1.shape == mu2.shape, \
'Training and test mean vectors have different lengths'
assert sigma1.shape == sigma2.shape, \
'Training and test covariances have different dimensions'
diff = mu1 - mu2
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = ('fid calculation produces singular product; '
'adding %s to diagonal of cov estimates') % eps
print(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError('Imaginary component {}'.format(m))
covmean = covmean.real
tr_covmean = np.trace(covmean)
return (diff.dot(diff) + np.trace(sigma1) +
np.trace(sigma2) - 2 * tr_covmean)
def calculate_frechet(inception_real,inception_fake,model, return_statistics=False):
mu_1 = np.mean(inception_real, axis=0)
mu_2 = np.mean(inception_fake, axis=0)
std_1 = np.cov(inception_real, rowvar=False)
std_2 = np.cov(inception_fake, rowvar=False)
"""get fretched distance"""
fid_value = calculate_frechet_distance(mu_1, std_1, mu_2, std_2)
return fid_value
Computing file changes ...