https://github.com/PeizhuoLi/ganimator
Tip revision: 2943064b456d13f0357e23e3c37fb43b6aa3fdaa authored by Peizhuo Li on 09 September 2022, 14:56:35 UTC
Update README.md
Update README.md
Tip revision: 2943064
demo.py
import torch
import os
from os.path import join as pjoin
from dataset.motion import MotionData, load_multiple_dataset
from models import create_model, create_conditional_model
from models.architecture import draw_example, get_pyramid_lengths, FullGenerator
from option import TestOptionParser, TrainOptionParser
from fix_contact import fix_contact_on_file
from models.utils import get_layered_mask
from interactive_utils import sliding_window
def load_all_from_path(save_path, device, use_class=False):
    """Restore a trained model and its training data from ``save_path``.

    Steps:
    1. Read the training options from ``<save_path>/args.txt``.
    2. Reload the training motion(s).
    3. Recompute each sequence's pyramid lengths.
    4. Rebuild one generator per level and load its weights from
       ``gen{step:03d}.pt``, falling back to ``gen{step}.pt``.
    5. Load ``z_star.pt`` and ``amps.pt``.

    Args:
        save_path: directory written by training; also stored into ``args.save_path``.
        device: device spec stored into ``args.device`` and used as ``map_location``
            for every ``torch.load``.
        use_class: when True, return a ``FullGenerator`` instead of the raw pieces.

    Returns:
        If ``use_class`` is True, a ``FullGenerator``. In that case ``z_star`` is
        reduced to its first element if it is a list, and ``amps`` to its first
        entry if it is not 1-D.

        Otherwise the tuple ``(args, multiple_data, gens, z_star, amps, lengths)``.
        Here a 1-D ``amps`` and a 3-D ``z_star`` tensor each get a leading dimension.
    """
    args = TrainOptionParser().load(pjoin(save_path, 'args.txt'))
    args.device = device
    args.save_path = save_path
    torch_device = torch.device(args.device)

    # Options shared by both ways of loading the training motion.
    data_kwargs = dict(padding=args.skeleton_aware, use_velo=args.use_velo, repr=args.repr,
                       contact=args.contact, keep_y_pos=args.keep_y_pos)
    if args.multiple_sequences:
        multiple_data = load_multiple_dataset(prefix=args.bvh_prefix,
                                              name_list=pjoin(args.bvh_prefix, args.bvh_name),
                                              no_scale=True, **data_kwargs)
        motion_data = multiple_data[0]
    else:
        motion_data = MotionData(pjoin(args.bvh_prefix, f'{args.bvh_name}.bvh'), **data_kwargs)
        multiple_data = [motion_data]

    pyramids = [get_pyramid_lengths(args, len(data)) for data in multiple_data]
    # The shortest pyramid is measured BEFORE the optional stage limit is applied;
    # 10000 acts as an upper bound on the number of levels.
    min_len = min([10000] + [len(pyramid) for pyramid in pyramids])
    if args.num_stages_limit != -1:
        pyramids = [pyramid[:args.num_stages_limit] for pyramid in pyramids]
    # Keep only the last `min_len` levels of each sequence so all pyramids line up.
    lengths = [pyramid[-min_len:] for pyramid in pyramids]

    def _load_state(step):
        # Checkpoints may be named with zero-padded (gen000.pt) or plain (gen0.pt) indices.
        try:
            return torch.load(pjoin(args.save_path, f'gen{step:03d}.pt'), map_location=torch_device)
        except FileNotFoundError:
            return torch.load(pjoin(args.save_path, f'gen{step}.pt'), map_location=torch_device)

    gens = []
    for step in range(len(lengths[0])):
        # The first `num_conditional_generator` levels of a conditional model are conditional.
        use_conditional = args.conditional_generator and step < args.num_conditional_generator
        factory = create_conditional_model if use_conditional else create_model
        gen = factory(args, motion_data, evaluation=True)
        gen.load_state_dict(_load_state(step))
        gens.append(gen)

    z_star = torch.load(pjoin(args.save_path, 'z_star.pt'), map_location=torch_device)
    amps = torch.load(pjoin(args.save_path, 'amps.pt'), map_location=torch_device)

    if not use_class:
        if len(amps.shape) == 1:
            amps = amps.unsqueeze(0)
        if isinstance(z_star, torch.Tensor) and len(z_star.shape) == 3:
            z_star = z_star.unsqueeze(0)
        return args, multiple_data, gens, z_star, amps, lengths

    if isinstance(z_star, list):
        z_star = z_star[0]
    if len(amps.shape) != 1:
        amps = amps[0]
    return FullGenerator(args, motion_data, gens, z_star, amps)
def write_multires(imgs, prefix, writer, interpolator, full_lengths=None, requires_con_loss=True):
    """Write every level of a multi-resolution result and optionally score level consistency.

    Each entry of ``imgs`` is passed through ``interpolator`` with a common target
    length and written to ``<prefix>/<step:02d>.bvh`` through ``writer``.

    Args:
        imgs: sequence of motion tensors, one per level. The last entry's final
            dimension is the default target length. Assumed layout is
            (batch, channels, frames) -- TODO confirm.
        prefix: output directory; created if it does not exist.
        writer: callable ``writer(path, motion)`` that saves one motion.
        interpolator: callable ``interpolator(motion, length)``; presumably
            resamples ``motion`` to ``length`` frames.
        full_lengths: target length passed to ``interpolator``; defaults to
            ``imgs[-1].shape[-1]``.
        requires_con_loss: when True, compute and return the consistency loss.

    Returns:
        When ``requires_con_loss`` is True, the MSE between rows 1 and 0 of the
        concatenated per-level norms of channels ``[-6:-3]`` (named ``velo`` here).
        With a batch size of 1 those rows are levels 1 and 0. Otherwise ``None``.
    """
    os.makedirs(prefix, exist_ok=True)
    length = imgs[-1].shape[-1] if full_lengths is None else full_lengths
    res = []
    for step, img in enumerate(imgs):
        full_length = interpolator(img, length)
        writer(pjoin(prefix, f'{step:02d}.bvh'), full_length)
        if requires_con_loss:
            # Norm over channels -6:-3 -> one value per frame (presumably root velocity).
            res.append(full_length[:, -6:-3].norm(dim=1))
    if not requires_con_loss:
        # No loss requested: return an explicit None instead of an unbound name.
        return None
    res = torch.cat(res, dim=0)
    return torch.nn.MSELoss()(res[1], res[0])
def gen_noise(n_channel, length, full_noise, device):
    """Sample standard-normal noise of shape ``(1, n_channel, length)`` on ``device``.

    With ``full_noise`` every channel is sampled independently. Otherwise a single
    ``(1, 1, length)`` sample is drawn and repeated across all ``n_channel`` channels.
    """
    sampled_channels = n_channel if full_noise else 1
    noise = torch.randn((1, sampled_channels, length))
    if not full_noise:
        # repeat (not expand) so the result owns its memory and can be modified in place.
        noise = noise.repeat(1, n_channel, 1)
    return noise.to(device)
def main():
    """Run the demo: reconstruct the training clips, then synthesize a new motion.

    Options come from ``TestOptionParser``. The model saved under
    ``test_args.save_path`` is restored, and outputs go to ``<save_path>/bvh``:

    * ``gt_{i}.bvh`` / ``rec_{i}.bvh``: each training sequence and its
      reconstruction from the stored ``z_star``.
    * ``result.bvh``: the generated motion, then post-processed by
      ``fix_contact_on_file``.
    * ``manipulate_input.bvh``: the style-transfer, keyframe or conditioning
      input, when one was given.
    * ``generated_traj.bvh``: otherwise, for conditional generators, the
      condition that was sampled.

    Raises:
        Exception: if conditional generation or interactive mode is requested for a
            non-conditional model, or if a condition needed for generation is missing.
    """
    test_parser = TestOptionParser()
    test_args = test_parser.parse_args()

    args, multiple_data, gens, z_stars, amps, lengths = load_all_from_path(test_args.save_path, test_args.device)
    device = torch.device(args.device)
    n_total_levels = len(gens)
    motion_data = multiple_data[0]
    # First-level noise has as many channels as z_star when full_noise is set, else one.
    noise_channel = z_stars[0].shape[1] if args.full_noise else 1

    if len(args.path_to_existing):
        # Separately trained model (restored as a FullGenerator) that can sample the
        # conditions consumed by the first `num_conditional_generator` levels.
        ConGen = load_all_from_path(args.path_to_existing, args.device, use_class=True)
        ConGen.output_mask = get_layered_mask(args.conditional_mode, motion_data.n_rot)
        conds_rec = [motion_data.sample(lengths[0][i]) for i in range(args.num_conditional_generator)]
    else:
        ConGen = None
        conds_rec = None

    print('levels:', lengths)
    save_path = pjoin(args.save_path, 'bvh')
    os.makedirs(save_path, exist_ok=True)

    # Index of the training sequence whose z_star / amps drive generation.
    base_id = 0

    # Evaluate with reconstruction noise: regenerate every training sequence and
    # write it next to its ground truth.
    for i, motion_data in enumerate(multiple_data):
        imgs = draw_example(gens, 'rec', z_stars[i], lengths[i] + [1], amps[i], 1, args, all_img=True,
                            conds=conds_rec, full_noise=args.full_noise)
        real = motion_data.sample(size=len(motion_data), slerp=args.slerp).to(device)
        motion_data.write(pjoin(save_path, f'gt_{i}.bvh'), real)
        motion_data.write(pjoin(save_path, f'rec_{i}.bvh'), imgs[-1])
        # Only score the reconstruction when its frame count matches the ground truth.
        if imgs[-1].shape[-1] == real.shape[-1]:
            rec_loss = torch.nn.MSELoss()(imgs[-1], real).item()
            print(f'rec_loss: {rec_loss:.07f}')
    # NOTE(review): after this loop `motion_data` is the LAST training sequence. It is
    # used to write the generated result below even though base_id selects sequence 0.
    # Confirm this is intended.

    generation_mode = 'manip' if test_args.style_transfer or test_args.keyframe_editing else 'random'
    # Condition used by interactive sliding-window generation and written as
    # generated_traj.bvh; only the conditional-generation paths below set it.
    conds_full = None
    if test_args.style_transfer:
        manip_data = MotionData(f'{test_args.style_transfer}',
                                padding=args.skeleton_aware, use_velo=args.use_velo, repr=args.repr,
                                contact=args.contact, keep_y_pos=args.keep_y_pos,
                                no_scale=False)
        target_len = len(manip_data)
        target_length = get_pyramid_lengths(args, target_len)
        conds = manip_data.sample(target_length[0])
    elif test_args.keyframe_editing:
        manip_data = MotionData(f'{test_args.keyframe_editing}',
                                padding=args.skeleton_aware, use_velo=args.use_velo, repr=args.repr,
                                contact=args.contact, keep_y_pos=args.keep_y_pos,
                                no_scale=True)
        target_len = len(multiple_data[0])  # Use original length of training data
        target_length = get_pyramid_lengths(args, target_len)
        conds = manip_data.sample(target_length[0])
    elif test_args.conditional_generation:
        manip_data = MotionData(f'{test_args.conditional_generation}',
                                padding=args.skeleton_aware, use_velo=args.use_velo, repr=args.repr,
                                contact=args.contact, keep_y_pos=args.keep_y_pos,
                                no_scale=True)
        target_len = len(manip_data)  # Generate as many frames as the conditioning input has
        target_length = get_pyramid_lengths(args, target_len)
        conds = [manip_data.sample(n_frames) for n_frames in target_length[:args.num_conditional_generator]]
        conds_full = conds[-1]
        generation_mode = 'cond'
        if not args.conditional_generator:
            raise Exception('Conditional generation only applicable to conditional generators.')
    elif args.conditional_generator:
        # This is a conditional model but no condition was given, so the condition
        # is sampled from the separately loaded ConGen model instead.
        if ConGen is None:
            raise Exception('Conditional generator without a given condition requires path_to_existing '
                            'to point at a model to sample the condition from.')
        target_len = test_args.target_length
        target_length = get_pyramid_lengths(args, target_len)
        manip_data = None
        conds = ConGen.random_generate(target_length)
        conds_full = conds[-1]
        conds = conds[:args.num_conditional_generator]
    else:
        target_len = test_args.target_length
        target_length = get_pyramid_lengths(args, target_len)
        manip_data = None
        conds = None

    # Keep at most one target length per trained generator by dropping leading entries.
    while len(target_length) > n_total_levels:
        target_length = target_length[1:]
    z_length = target_length[0]
    # Scale the first-level noise by that level's amplitude and zero the remaining
    # amplitudes (presumably so later levels add no fresh noise).
    z_target = gen_noise(noise_channel, z_length, args.full_noise, device)
    z_target *= amps[base_id][0]
    amps2 = amps[base_id].clone()
    amps2[1:] = 0

    if not test_args.interactive:
        imgs = draw_example(gens, generation_mode, z_stars[base_id], target_length, amps2, 1, args, all_img=True,
                            conds=conds, full_noise=args.full_noise, given_noise=[z_target])
    else:
        if not args.conditional_generator:
            raise Exception('Interactive mode only applicable to conditional generators.')
        if conds_full is None:
            raise Exception('Interactive mode requires a condition (a conditional_generation input '
                            'or one sampled from the path_to_existing model).')
        _, imgs = sliding_window(gens, z_stars[base_id], amps2, conds_full, args)

    motion_data.write(pjoin(save_path, 'result.bvh'), imgs[-1])
    fix_contact_on_file(save_path, name='result')

    if manip_data is not None:
        motion_data.write(pjoin(save_path, 'manipulate_input.bvh'), manip_data.sample())
    elif args.conditional_generator:
        motion_data.write(pjoin(save_path, 'generated_traj.bvh'), conds_full)
# Script entry point: run the demo only when executed directly, not when imported.
if __name__ == "__main__":
    main()
