# https://github.com/PeizhuoLi/ganimator
# Raw File
# Tip revision: 2943064b456d13f0357e23e3c37fb43b6aa3fdaa authored by Peizhuo Li on 09 September 2022, 14:56:35 UTC
# Update README.md
# Tip revision: 2943064
# demo.py
import torch
import os
from os.path import join as pjoin
from dataset.motion import MotionData, load_multiple_dataset
from models import create_model, create_conditional_model
from models.architecture import draw_example, get_pyramid_lengths, FullGenerator
from option import TestOptionParser, TrainOptionParser
from fix_contact import fix_contact_on_file
from models.utils import get_layered_mask
from interactive_utils import sliding_window


def load_all_from_path(save_path, device, use_class=False):
    """Restore a trained model (options, data, per-level generators, noise, amplitudes).

    Args:
        save_path: directory holding ``args.txt``, the per-level ``gen*.pt`` checkpoints,
            ``z_star.pt`` and ``amps.pt``.
        device: device string; stored into the loaded options and used as ``map_location``.
        use_class: if True, wrap the result into a ``FullGenerator`` built from the first
            sequence's noise and amplitudes instead of returning the raw pieces.

    Returns:
        A ``FullGenerator`` when ``use_class`` is True, otherwise the tuple
        ``(args, multiple_data, gens, z_star, amps, lengths)``.
    """
    args = TrainOptionParser().load(pjoin(save_path, 'args.txt'))
    args.device = device
    args.save_path = save_path
    torch_device = torch.device(args.device)

    # Options shared by the single- and multi-sequence loaders.
    data_kwargs = dict(padding=args.skeleton_aware, use_velo=args.use_velo, repr=args.repr,
                       contact=args.contact, keep_y_pos=args.keep_y_pos)
    if args.multiple_sequences:
        multiple_data = load_multiple_dataset(prefix=args.bvh_prefix,
                                              name_list=pjoin(args.bvh_prefix, args.bvh_name),
                                              no_scale=True, **data_kwargs)
        motion_data = multiple_data[0]
    else:
        motion_data = MotionData(pjoin(args.bvh_prefix, f'{args.bvh_name}.bvh'), **data_kwargs)
        multiple_data = [motion_data]

    pyramids = [get_pyramid_lengths(args, len(data)) for data in multiple_data]
    # The shortest pyramid is measured before the optional stage limit is applied.
    min_len = min([10000] + [len(pyramid) for pyramid in pyramids])
    if args.num_stages_limit != -1:
        pyramids = [pyramid[:args.num_stages_limit] for pyramid in pyramids]
    # Every sequence keeps only its trailing ``min_len`` levels so all share one depth.
    lengths = [pyramid[-min_len:] for pyramid in pyramids]

    gens = []
    for step in range(len(lengths[0])):
        use_conditional = args.conditional_generator and step < args.num_conditional_generator
        factory = create_conditional_model if use_conditional else create_model
        gen = factory(args, motion_data, evaluation=True)
        try:
            state = torch.load(pjoin(args.save_path, f'gen{step:03d}.pt'), map_location=torch_device)
        except FileNotFoundError:
            # Fall back to the unpadded checkpoint name (presumably an older naming scheme).
            state = torch.load(pjoin(args.save_path, f'gen{step}.pt'), map_location=torch_device)
        gen.load_state_dict(state)
        gens.append(gen)

    z_star = torch.load(pjoin(args.save_path, 'z_star.pt'), map_location=torch_device)
    amps = torch.load(pjoin(args.save_path, 'amps.pt'), map_location=torch_device)

    if use_class:
        # FullGenerator is built for one sequence: keep only the first entry of each.
        if isinstance(z_star, list):
            z_star = z_star[0]
        if amps.dim() != 1:
            amps = amps[0]
        return FullGenerator(args, motion_data, gens, z_star, amps)

    # Give amps / z_star a leading per-sequence axis so callers can index them by sequence.
    if amps.dim() == 1:
        amps = amps.unsqueeze(0)
    if isinstance(z_star, torch.Tensor) and z_star.dim() == 3:
        z_star = z_star.unsqueeze(0)
    return args, multiple_data, gens, z_star, amps, lengths


def write_multires(imgs, prefix, writer, interpolator, full_lengths=None, requires_con_loss=True):
    """Write every level of a multi-resolution motion as BVH, resampled to one common length.

    Args:
        imgs: per-level motion tensors. Assumed layout is (batch, channels, time): the last
            axis is the time length and channels ``-6:-3`` are reduced with a norm over dim 1.
        prefix: output directory (created if missing); level ``k`` is written to ``{k:02d}.bvh``.
        writer: callable ``writer(path, motion)`` that saves one motion.
        interpolator: callable ``interpolator(motion, length)`` resampling a level to ``length``.
        full_lengths: target length for every level; defaults to the time length of the last level.
        requires_con_loss: whether to compute and return the consistency loss.

    Returns:
        The MSE between the channel ``-6:-3`` norms (named "velo" here, presumably the root
        velocity) of level 1 and level 0 after resampling, or ``None`` when
        ``requires_con_loss`` is False.
    """
    os.makedirs(prefix, exist_ok=True)
    length = imgs[-1].shape[-1] if full_lengths is None else full_lengths
    velos = []
    for step, img in enumerate(imgs):
        full_length = interpolator(img, length)
        writer(pjoin(prefix, f'{step:02d}.bvh'), full_length)
        velos.append(full_length[:, -6:-3].norm(dim=1))
    if not requires_con_loss:
        return None
    # Stack (rather than cat) so the first axis indexes the level. With cat, a batch of
    # B > 1 interleaves batch items and levels, and [1]/[0] would compare two batch items
    # of level 0 instead of level 1 against level 0. For B == 1 both give the same value.
    velos = torch.stack(velos, dim=0)
    return torch.nn.MSELoss()(velos[1], velos[0])


def gen_noise(n_channel, length, full_noise, device):
    """Draw standard-normal noise of shape ``(1, n_channel, length)`` and move it to ``device``.

    With ``full_noise`` every channel gets independent noise. Otherwise a single noise
    channel is drawn and repeated across all ``n_channel`` channels.
    """
    drawn_channels = n_channel if full_noise else 1
    noise = torch.randn((1, drawn_channels, length))
    if not full_noise:
        # Share the single drawn sequence across every channel.
        noise = noise.repeat(1, n_channel, 1)
    return noise.to(device)


def _trim_levels(level_lengths, n_levels):
    """Drop leading pyramid levels until at most ``n_levels`` remain.

    Equivalent to repeatedly slicing ``[1:]``; returns the input unchanged when it is short enough.
    """
    n_extra = len(level_lengths) - n_levels
    return level_lengths[n_extra:] if n_extra > 0 else level_lengths


def _report_reconstruction(gens, multiple_data, z_stars, amps, lengths, conds_rec, args, save_path, device):
    """Reconstruct each training sequence from its stored noise, write gt/rec BVHs and print the MSE."""
    for i, data in enumerate(multiple_data):
        imgs = draw_example(gens, 'rec', z_stars[i], lengths[i] + [1], amps[i], 1, args, all_img=True,
                            conds=conds_rec, full_noise=args.full_noise)
        real = data.sample(size=len(data), slerp=args.slerp).to(device)
        data.write(pjoin(save_path, f'gt_{i}.bvh'), real)
        data.write(pjoin(save_path, f'rec_{i}.bvh'), imgs[-1])

        # The loss is only reported when the reconstruction has the ground-truth length.
        if imgs[-1].shape[-1] == real.shape[-1]:
            rec_loss = torch.nn.MSELoss()(imgs[-1], real).item()
            print(f'rec_loss: {rec_loss:.07f}')


def main():
    """Load a trained model, write reconstructions and one generated or edited motion.

    Everything is written to ``<save_path>/bvh``:
    ``gt_i.bvh`` / ``rec_i.bvh`` for every training sequence;
    ``result.bvh`` for the generated motion (post-processed by ``fix_contact_on_file``);
    ``manipulate_input.bvh`` (editing modes) or ``generated_traj.bvh`` (conditional model).
    """
    test_args = TestOptionParser().parse_args()

    args, multiple_data, gens, z_stars, amps, lengths = load_all_from_path(test_args.save_path, test_args.device)
    device = torch.device(args.device)
    n_total_levels = len(gens)

    # Generation uses this sequence's noise/amplitudes, so its MotionData also writes the
    # results. It is not rebound later: previously the reconstruction loop reassigned it,
    # so results were written with the *last* sequence's MotionData.
    base_id = 0
    motion_data = multiple_data[base_id]

    noise_channel = z_stars[0].shape[1] if args.full_noise else 1

    if len(args.path_to_existing):
        ConGen = load_all_from_path(args.path_to_existing, args.device, use_class=True)
        ConGen.output_mask = get_layered_mask(args.conditional_mode, motion_data.n_rot)
        conds_rec = [motion_data.sample(lengths[0][i]) for i in range(args.num_conditional_generator)]
    else:
        ConGen = None
        conds_rec = None

    print('levels:', lengths)
    save_path = pjoin(args.save_path, 'bvh')
    os.makedirs(save_path, exist_ok=True)

    _report_reconstruction(gens, multiple_data, z_stars, amps, lengths, conds_rec, args, save_path, device)

    def load_clip(path, no_scale):
        return MotionData(f'{path}', padding=args.skeleton_aware, use_velo=args.use_velo, repr=args.repr,
                          contact=args.contact, keep_y_pos=args.keep_y_pos, no_scale=no_scale)

    generation_mode = 'manip' if test_args.style_transfer or test_args.keyframe_editing else 'random'
    # Condition used by interactive mode / written as generated_traj.bvh; only set by the
    # conditional branches below.
    conds_full = None

    # In every branch the target pyramid is trimmed to the number of trained generators
    # BEFORE conditions are sampled from it. Trimming afterwards (as before) could sample
    # conditions at the length of a level that is then dropped.
    if test_args.style_transfer:
        manip_data = load_clip(test_args.style_transfer, no_scale=False)
        target_length = _trim_levels(get_pyramid_lengths(args, len(manip_data)), n_total_levels)
        conds = manip_data.sample(target_length[0])
    elif test_args.keyframe_editing:
        manip_data = load_clip(test_args.keyframe_editing, no_scale=True)
        # Use original length of training data
        target_length = _trim_levels(get_pyramid_lengths(args, len(multiple_data[0])), n_total_levels)
        conds = manip_data.sample(target_length[0])
    elif test_args.conditional_generation:
        # Checked before loading anything so the misuse fails fast.
        if not args.conditional_generator:
            raise Exception('Conditional generation only applicable to conditional generators.')
        manip_data = load_clip(test_args.conditional_generation, no_scale=True)
        # Use the length of the given condition clip.
        target_length = _trim_levels(get_pyramid_lengths(args, len(manip_data)), n_total_levels)
        conds = [manip_data.sample(n_frames) for n_frames in target_length[:args.num_conditional_generator]]
        conds_full = conds[-1]
        generation_mode = 'cond'
    elif args.conditional_generator:
        # Conditional model without a given condition: sample the condition from ConGen.
        if ConGen is None:
            raise Exception('A conditional generator without a given condition needs path_to_existing '
                            'to sample the condition from.')
        target_length = _trim_levels(get_pyramid_lengths(args, test_args.target_length), n_total_levels)
        manip_data = None
        conds = ConGen.random_generate(target_length)
        conds_full = conds[-1]
        conds = conds[:args.num_conditional_generator]
    else:
        target_length = _trim_levels(get_pyramid_lengths(args, test_args.target_length), n_total_levels)
        manip_data = None
        conds = None

    z_length = target_length[0]
    z_target = gen_noise(noise_channel, z_length, args.full_noise, device)
    z_target *= amps[base_id][0]

    # Keep only the level-0 amplitude; later levels get zero (presumably: no extra noise there).
    amps2 = amps[base_id].clone()
    amps2[1:] = 0

    if not test_args.interactive:
        imgs = draw_example(gens, generation_mode, z_stars[base_id], target_length, amps2, 1, args, all_img=True,
                            conds=conds, full_noise=args.full_noise, given_noise=[z_target])
    else:
        if not args.conditional_generator:
            raise Exception('Interactive mode only applicable to conditional generators.')
        if conds_full is None:
            raise Exception('Interactive mode needs a full condition sequence; it cannot be combined '
                            'with style transfer or keyframe editing.')
        _, imgs = sliding_window(gens, z_stars[base_id], amps2, conds_full, args)

    motion_data.write(pjoin(save_path, 'result.bvh'), imgs[-1])
    fix_contact_on_file(save_path, name='result')

    if manip_data is not None:
        motion_data.write(pjoin(save_path, 'manipulate_input.bvh'), manip_data.sample())
    elif args.conditional_generator:
        motion_data.write(pjoin(save_path, 'generated_traj.bvh'), conds_full)


if __name__ == '__main__':
    # Run the demo only when this file is executed as a script, not when it is imported.
    main()
# back to top