"""Compute a Frechet-distance (FID-style) metric between ground-truth and generated
facial keypoint sequences, using the latent features of a pose autoencoder."""

import os
import pickle
import random
import sys
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

import cv2
import numpy as np
import torch
from scipy import linalg

sys.path.append("..")
from models.pose_autoencoder import EmbeddingNet

try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is not available, provide a no-op stand-in.
    def tqdm(x):
        return x

parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch_size', '-bs', type=int, default=24)
parser.add_argument('--device', '-dev', type=int, default=0)
parser.add_argument('--path_gt', '-pgt', type=str, default="/home/rafael/masters/test_repro/test")
parser.add_argument('--path_generated', '-pg', type=str, default="/home/rafael/masters/to_move/outputs_reprotest2")
parser.add_argument('--keypoint_type', '-kt', type=str, default="ours")
parser.add_argument('--path_checkpoints', '-pck', type=str, default="/home/rafael/masters/to_move/checkpoints_ae2/step_70300/autoencoder.pth")


def _read_instance_gt(root, instance_name):
    """Load a ground-truth instance stored as a pickle file."""
    instance_fp = os.path.join(root, instance_name)
    with open(instance_fp, "rb") as handler:
        instance = pickle.load(handler)
    return instance


def _read_makeittalk_instance(root, instance_name):
    """Load a MakeItTalk baseline instance, dropping the last entry of the final axis."""
    instance_fp = os.path.join(root, instance_name)
    instance = np.load(instance_fp)[:, :, :-1]
    return instance


def _read_ours_instance(root, instance_name):
    """Load one of our generated instances from a compressed .npz file."""
    instance_fp = os.path.join(root, instance_name)
    instance = np.load(instance_fp)["arr_0"]
    return instance


def _read_pt_instance(baseline_root, instance_name):
    """Load a baseline instance stored as a plain .npy file."""
    instance_fp = os.path.join(baseline_root, instance_name)
    instance = np.load(instance_fp)
    return instance


def draw_keypoints(kps):
    """Render normalized keypoints onto a blank 256x256 image (debugging helper)."""
    img = np.ones((256, 256, 3)) * 255
    for kp in kps:
        x = kp[0]
        y = kp[1]
        img = cv2.circle(img, (int(x * 256), int(y * 256)), 2, (0, 255, 0), 2)
    return img


def get_activations(my_instances, model, batch_size=50, device='cpu'):
    """Calculate the latent-layer activations of the autoencoder for keypoint faces.

    Params:
    -- my_instances : List of keypoint instances of shape [64, 136]
    -- model        : Instance of the pose autoencoder
    -- batch_size   : Number of instances the model processes at once
    -- device       : Device to run calculations on

    Returns:
    -- A numpy array of dimension (num instances, dims) containing the latent-layer
       activations obtained by feeding the query instances through the autoencoder.
    """
    model.eval()

    # The autoencoder latent space is 32-dimensional.
    pred_arr = np.empty((len(my_instances), 32))
    start_idx = 0
    for batch_start in range(0, len(my_instances), batch_size):
        if start_idx % 10 == 0:
            print("Feeding " + str(start_idx))
        batch_final = min(batch_start + batch_size, len(my_instances))
        batch = np.asarray(my_instances[batch_start:batch_final])
        batch_tensor = torch.from_numpy(batch).to(device=device, dtype=torch.float)
        with torch.no_grad():
            # Only the latent features are needed for the distance; the other
            # outputs (including the reconstructed poses) are discarded.
            feat, a, b, recon_poses = model(batch_tensor)
        pred = feat.cpu().detach().numpy()
        pred_arr[start_idx:start_idx + pred.shape[0]] = pred
        start_idx = start_idx + pred.shape[0]
    return pred_arr


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
        d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : The sample mean of the latent activations (as computed by
               'calculate_activation_statistics') for generated samples.
    -- mu2   : The sample mean of the latent activations, precalculated on a
               representative (ground-truth) data set.
    -- sigma1: The covariance matrix of the activations for generated samples.
    -- sigma2: The covariance matrix of the activations, precalculated on a
               representative (ground-truth) data set.

    Returns:
    --       : The Frechet Distance.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    if np.iscomplexobj(covmean):
        # Numerical error might give a slight imaginary component
        # if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
        #     m = np.max(np.abs(covmean.imag))
        #     raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1)
            + np.trace(sigma2) - 2 * tr_covmean)


def calculate_activation_statistics(my_instances, model, batch_size=50, device='cpu'):
    """Calculate the statistics used by the FID.

    Params:
    -- my_instances : List of keypoint instances
    -- model        : Instance of the pose autoencoder
    -- batch_size   : The instances are split into batches of size batch_size;
                      a reasonable value depends on the available hardware.
    -- device       : Device to run calculations on

    Returns:
    -- mu    : The mean over samples of the latent-layer activations of the autoencoder.
    -- sigma : The covariance matrix of the latent-layer activations of the autoencoder.
""" act = get_activations(my_instances, model, batch_size, device) mu = np.mean(act, axis=0) sigma = np.cov(act, rowvar=False) return mu, sigma def compute_statistics_of_path(path, model, batch_size, keypoint_type, device): my_instances = list() if keypoint_type == "gt": print ("gt") elif keypoint_type == "pt": print ("pt") elif keypoint_type == "mt": print ("mt") elif keypoint_type == "ours": print ("ours") elif keypoint_type == "nslp": print ("ours") else: print ("Keypoint type error") assert False # todo: format mt and pt keypoints for instance_name in sorted(os.listdir(path)): try: if keypoint_type == "gt": my_instances.append(np.asarray(_read_instance_gt(path, instance_name)["kps"])[:, :-1, :].reshape(-1, 136)) elif keypoint_type == "mt": my_instances.append(_read_makeittalk_instance(path, instance_name).reshape(-1, 136)) elif keypoint_type == "pt": my_instances.append(_read_pt_instance(path, instance_name)[:,:-1].reshape(-1, 136)) elif keypoint_type == "ours": my_instances.append(_read_ours_instance(path, instance_name).squeeze(0)[:, :-1, :].reshape(-1, 136)) elif keypoint_type == "nslp": test = _read_instance_gt(path, instance_name)/256 test = test.reshape(-1, 136).numpy() my_instances.append(test) except: print("found no existing instance, skipping") import traceback print(traceback.format_exc()) m, s = calculate_activation_statistics(my_instances, model, batch_size, device) return m, s def calculate_fid_given_paths(paths_gt, paths_g, batch_size, device, keypoint_type,path_checkpoints): """Calculates the FID of two paths""" if not os.path.exists(paths_gt): raise RuntimeError('Invalid path: %s' % paths_gt) if not os.path.exists(paths_g): raise RuntimeError('Invalid path: %s' % paths_g) # pose_dim = 68*2,n_frames model = EmbeddingNet(136, 64).to(device) try: model.load_state_dict(torch.load(path_checkpoints, map_location=torch.device("cpu"))) except: import traceback print(traceback.format_exc()) import pdb pdb.set_trace() model.eval() m1, s1 = compute_statistics_of_path(paths_gt, model, batch_size, "gt", device) m2, s2 = compute_statistics_of_path(paths_g, model, batch_size, keypoint_type, device) fid_value = calculate_frechet_distance(m1, s1, m2, s2) return fid_value def main(): args = parser.parse_args() SEED = 1234 random.seed(SEED) np.random.seed(SEED) torch.manual_seed(SEED) torch.backends.cudnn.deterministic = True #device = torch.device('cuda:{}'.format(args.device)) if args.device != -1 else torch.device('cuda') device = torch.device("cpu") fid_value = calculate_fid_given_paths(args.path_gt, args.path_generated, args.batch_size, device, args.keypoint_type, args.path_checkpoints) print('FID: ', fid_value) if __name__ == '__main__': main()