import os
import sys
import pickle
import numpy as np
import random
import torch
import torchvision
from utils.util import Utils
from models.decoder_stgcn import Decoder
from data.dataset import SignProdDataset
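
# Qualitative inspection script (a sketch of its intent, inferred from the code
# below): load a trained ST-GCN decoder and its saved latent codes, decode one
# latent combination into a face-keypoint sequence, render it to a video, and
# check how far the result drifts from the mean face.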
def read_pickle(file_fp):
    with open(file_fp, "rb") as handler:
        return pickle.load(handler)

def make_exp(test_input, exp_name, decoder, Z, device):
    with torch.no_grad():
        test_input = torch.Tensor(test_input).to(device)
        # Move the channel dimension last before visualization.
        faces_fake = decoder(None, None, test_input).permute(0, 2, 3, 1).cpu().numpy()
        Utils.visualize_data_single(faces_fake, exp_name)
    return faces_fake

def printinfo(dataset_root, file):
    file_fp = os.path.join(dataset_root, file)
    instance = read_pickle(file_fp)
    print("Instance label: {}".format(instance["label"]))
    print("Instance name: {}".format(file))

def project_l2_ball(z):
    """Project the vectors in z onto the l2 unit-norm ball."""
    return z / np.maximum(np.sqrt(np.sum(z**2, axis=1))[:, np.newaxis], 1)


def project_l2_ball_torch(zt, unit_tensor):
    """Torch version of project_l2_ball; unit_tensor should be torch.Tensor([1])."""
    return zt / torch.maximum(torch.sqrt(torch.sum(zt**2, dim=1))[:, None], unit_tensor)
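
# Both projections map each row z_i to z_i / max(||z_i||_2, 1): vectors already
# inside the unit ball are left unchanged, longer ones are rescaled onto the sphere.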

def is_not_mean(mean_face, kps):
    """Return True as soon as a frame's keypoints lie farther than 0.65 from the mean face."""
    for kp in kps[0]:
        kp = np.asarray(kp)
        dist = np.linalg.norm(kp - mean_face)
        print(dist)
        if dist > 0.65:
            # Drop into the debugger to inspect the outlier frame.
            import pdb
            pdb.set_trace()
            return True
    return False

def main():
    dataset_root = "/srv/storage/datasets/rafaelvieira/new_data/new_sent_embeddings"
    decoder_ckpt_fp = "/srv/storage/datasets/rafaelvieira/text2expression/should_be_good/decoder.pth"
    zs_ckpt_fp = "/srv/storage/datasets/rafaelvieira/text2expression/should_be_good/Zs.pkl"
    zsent_ckpt_fp = "/srv/storage/datasets/rafaelvieira/text2expression/should_be_good/Zsent.pkl"

    mean_face = np.load("mean_face.npy") / 256
    torch.backends.cudnn.deterministic = True
    device = torch.device("cuda:0")
    unit_tensor = torch.Tensor([1])

    decoder = Decoder(device).to(device)
    decoder.load_state_dict(torch.load(decoder_ckpt_fp, map_location=device))

    # Latent codes learned alongside the decoder checkpoint.
    Z = read_pickle(zs_ckpt_fp)
    Zs = read_pickle(zsent_ckpt_fp)
    # encoder.train()
    decoder.eval()

    files = sorted(os.listdir(dataset_root))
    idx1 = 2
    idx2 = 2
    printinfo(dataset_root, files[idx1])
    printinfo(dataset_root, files[idx2])
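
    # Combine one latent from Zs.pkl with one from Zsent.pkl (assumed to be the
    # instance-level and sentence-level codes); their sum is the decoder input.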
    # torch.autograd.Variable is a no-op in modern PyTorch; plain tensors with
    # requires_grad left at its default of False are equivalent.
    zi = torch.zeros((1, 768, 8))
    zs = torch.zeros((1, 768, 8))

    with torch.no_grad():
        # test_input = Z[130] + Z[20]
        zi.data = torch.Tensor(Z[idx1])
        zs.data = torch.Tensor(Zs[idx2])
        input_ = zi + zs
        input_ = input_.unsqueeze(0)
        # input_ = project_l2_ball_torch(input_, unit_tensor)
        faces_fake = make_exp(input_, "youdont.mp4", decoder, Z, device)
        is_not_mean(mean_face, faces_fake)
        # Drop into the debugger to inspect the generated sequence.
        import pdb
        pdb.set_trace()
"""
printinfo(dataset_root, files[idx1])
printinfo(dataset_root, files[idx2])
for idx, file_ in enumerate(files):
print("Looking in video")
file_fp = os.path.join(dataset_root, file_)
instance = read_pickle(file_fp)
label = instance["label"]
if label != "joy":
with torch.no_grad():
#test_input = Z[130] + Z[20]
test_input1 = Z[idx1] + Zs[idx1]
test_input1 = project_l2_ball(test_input1)
make_exp(test_input1, "video_{}.mp4".format(idx), decoder, Z, device)
"""
if __name__ == "__main__":
main()