# https://github.com/amirhertz/geometric-textures
# Revision 087c967a74b023bae3cd874816ab8e0e5e30ddd7 authored by Rana Hanocka on 21 October 2020, 10:05:00 UTC, committed by GitHub on 21 October 2020, 10:05:00 UTC
# 1 parent 0fb5bfd
# Tip revision: 087c967a74b023bae3cd874816ab8e0e5e30ddd7 authored by Rana Hanocka on 21 October 2020, 10:05:00 UTC
# Update README.md
# Update README.md
# Tip revision: 087c967
# dgts_base.py
from custom_types import *
import options as options
from process_data import mesh_utils
import models.factory as factory
from models.single_mesh_models import SingleMeshGenerator
from models.mesh_handler import MeshHandler, MeshInference, load_template_mesh
class DGTS:
    """Common base for the inference wrappers around a ``SingleMeshGenerator``.

    Stores the loaded options, the generator produced by ``factory.model_lc``
    and the target device, and offers helpers that sample random noise sized
    for each subdivision level of a base mesh.
    """

    def __init__(self, opt: Union[options.Options, options.TrainOption], device: D):
        """Load the options and build the generator on ``device``."""
        self.opt = opt.load()
        # NOTE: the generator is created from the ``opt`` argument itself, not
        # from the value returned by ``opt.load()`` that is stored above.
        self.generator: SingleMeshGenerator = factory.model_lc(opt, SingleMeshGenerator, device=device)
        self.level = 0
        self.device = device

    def get_random_z(self, num_randoms: int) -> T:
        """Sample Gaussian noise scaled by ``opt.noise_amplitude``.

        The result has shape ``(num_randoms, 3)`` when ``opt.noise_before`` is
        set, otherwise ``(1, generator.opt.in_nf, num_randoms)``.
        """
        if self.opt.noise_before:
            shape = (num_randoms, 3)
        else:
            shape = (1, self.generator.opt.in_nf, num_randoms)
        noise = torch.randn(*shape, device=self.device)
        return noise * self.opt.noise_amplitude

    def get_z_by_level(self, base_mesh: MeshHandler, level: int) -> T:
        """Sample noise for ``base_mesh`` after ``level`` subdivisions.

        The element count starts from ``len(base_mesh) * 4 ** level``
        (presumably the face count after subdividing). With
        ``opt.noise_before`` it is converted to
        ``(count - len(base_mesh)) // 2 + base_mesh.vs.shape[0]``, which looks
        like the matching vertex count -- TODO confirm against MeshHandler.
        """
        base_count = len(base_mesh)
        count = base_count * 4 ** level
        if self.opt.noise_before:
            count = (count - base_count) // 2 + base_mesh.vs.shape[0]
        return self.get_random_z(count)

    def get_z_sequence(self, base_mesh: MeshHandler, max_level: int) -> TS:
        """Return one noise tensor per level, for levels ``0..max_level`` inclusive."""
        sequence = []
        for level in range(max_level + 1):
            sequence.append(self.get_z_by_level(base_mesh, level))
        return sequence

    def __len__(self):
        """Number of levels held by the generator."""
        levels = self.generator.levels
        return len(levels)
class Mesh2Mesh(DGTS):
    """Inference wrapper that applies the generator to a caller-supplied mesh.

    The generator is switched to eval mode on construction. ``zero_places`` in
    the methods below is a per-level sequence of flags: for every truthy entry
    the noise of the corresponding level is replaced by ``0``.
    """

    def __init__(self, opt: options.Options, device: D):
        super(Mesh2Mesh, self).__init__(opt, device)
        self.generator.eval()

    @staticmethod
    def _expand_zero_places(zero_places: NoiseT, num_levels: int) -> NoiseT:
        """Repeat a single-entry ``zero_places`` so it spans ``num_levels`` levels."""
        if len(zero_places) == 1:
            return zero_places * num_levels
        return zero_places

    @staticmethod
    def _apply_zero_places(z: factory.Noise, zero_places: NoiseT) -> factory.Noise:
        """Set ``z[i] = 0`` in place for every truthy ``zero_places[i]``; return ``z``."""
        # zip stops at the shorter of the two, like the original min(len, len) bound.
        for i, flag in zip(range(len(z)), zero_places):
            if flag:
                z[i] = 0
        return z

    def trim(self, start: int, end: int) -> Tuple[int, int]:
        """Clamp a ``(start, end)`` level range to the generator's levels.

        ``end`` falls back to the last level when negative or too large, and
        ``start`` is lowered to ``end`` when it exceeds it.
        """
        if end < 0 or end > len(self) - 1:
            end = len(self) - 1
        if start > end:
            start = end
        return start, end

    def get_z_sequence(self, base_mesh: MeshHandler, max_level: int) -> factory.Noise:
        """Same as ``DGTS.get_z_sequence`` but wrapped in a ``factory.Noise``."""
        return factory.Noise(data=super(Mesh2Mesh, self).get_z_sequence(base_mesh, max_level))

    def growing(self, mesh: MeshInference, start: int, end: int, num_frames: int, zero_places: NoiseT = ()) -> factory.Noise:
        """Export a sequence of meshes that grows ``mesh`` level by level.

        Samples noise, obtains per-level vertex deltas from
        ``generator.grow_forward`` and, for every level, exports ``num_frames``
        meshes whose vertices move linearly from the current positions to
        ``positions + delta``. The mesh is upsampled between levels. Frames are
        written under ``<cp_folder>/inference/<mesh_name>/scene00``; ``mesh`` is
        modified in place.

        :return: the noise that was used (after zeroing).
        """
        export_name = f'{self.opt.cp_folder}/inference/{mesh.mesh_name}/scene00'
        start, end = self.trim(start, end)
        num_levels = end - start + 1
        zero_places = self._expand_zero_places(zero_places, num_levels)
        z = self._apply_zero_places(self.get_z_sequence(mesh, end - start), zero_places)
        deltas = self.generator.grow_forward(mesh.copy(), z, end, start)
        mesh.export(f'{export_name}/{0:02d}')
        for i in range(num_levels):
            base_vs, cur_delta = mesh.vs.clone(), deltas[i]
            for j in range(num_frames):
                mesh.vs = base_vs + cur_delta * (j + 1) / num_frames
                mesh.export(f'{export_name}/{(num_frames * i + j + 1):02d}')
                print(f'done: {num_frames * i + j + 1}/{num_frames * num_levels}')
            if i < end - start:
                mesh.upsample()
        return z

    def animate(self, mesh: MeshInference, start: int, end: int, num_scene: int, num_frames: Tuple[int, int], zero_places: NoiseT = ()):
        """Export a growing sequence followed by noise-interpolation scenes.

        First runs ``growing`` on a copy of ``mesh`` with ``num_frames[0]``
        frames per level. Then, for each of ``num_scene`` scenes, exports
        ``num_frames[1]`` meshes generated from noise linearly blended between
        the current noise and a target noise. The second-to-last scene selects
        the growing noise as the target of the last scene, so the animation
        ends where it started. Frames are written under
        ``<cp_folder>/inference/<mesh_name>/scene01``.
        """
        export_name = f'{self.opt.cp_folder}/inference/{mesh.mesh_name}/scene01'
        start, end = self.trim(start, end)
        zero_places = self._expand_zero_places(zero_places, end - start + 1)
        growing_frames, scene_frames = num_frames[0], num_frames[1]
        z_a = self.growing(mesh.copy(), start, end, growing_frames, zero_places)
        z_b = self._apply_zero_places(self.get_z_sequence(mesh, end - start).to(self.device), zero_places)
        z_start = z_a
        for s in range(num_scene):
            for i in range(scene_frames):
                alpha = (i + 1) / float(scene_frames)
                z = z_a * (1 - alpha) + z_b * alpha
                m = mesh.copy()
                out = self.generator(m, z, end, start, upsample=True)
                out.export(f'{export_name}/{s * scene_frames + i:02d}')
                print(f'frame {s * scene_frames + i + 1} / {num_scene * scene_frames}...')
            z_a = z_b
            if s == num_scene - 2:
                z_b = z_start
            else:
                # Fix: every freshly sampled target must honour zero_places too;
                # previously only the first target was masked.
                z_b = self._apply_zero_places(self.get_z_sequence(mesh, end - start).to(self.device), zero_places)

    def __call__(self, mesh: Union[str, MeshHandler, T_Mesh], start: int, end: int, zero_places: NoiseT = 0) -> MeshHandler:
        """Run the generator on ``mesh`` from level ``start`` to level ``end``.

        :param mesh: a ``MeshHandler`` (used as is), ``None`` (the template
            named by ``opt.template_name`` is loaded at level ``start``), or
            anything else accepted by the ``MeshHandler`` constructor.
        :param zero_places: a single int flag or a per-level sequence of flags.
        :return: the result of ``generator.forward`` with ``upsample=True``.
        """
        MeshHandler.reset()
        start, end = self.trim(start, end)
        if isinstance(zero_places, int):
            zero_places = [zero_places]
        zero_places = self._expand_zero_places(zero_places, end - start + 1)
        if mesh is None:
            mesh = MeshHandler(mesh_utils.load_real_mesh(self.opt.template_name, start), self.opt, 0).to(self.device)
        elif not isinstance(mesh, MeshHandler):
            # isinstance (rather than an exact type match) keeps MeshHandler
            # subclasses from being re-wrapped as if they were raw mesh data.
            mesh = MeshHandler(mesh, self.opt, 0).to(self.device)
        z = self._apply_zero_places(self.get_z_sequence(mesh, end - start), zero_places)
        return self.generator.forward(mesh, z, end, start, upsample=True)
class MeshGen(DGTS):
    """Inference wrapper that generates meshes from the stored template.

    On construction the generator is put in eval mode, the template mesh is
    loaded at ``opt.start_level`` and the saved reconstruction noise is loaded
    through ``factory.NoiseMem``.
    """

    def __init__(self, opt: options.Options, device: D):
        super(MeshGen, self).__init__(opt, device)
        self.generator.eval()
        template_name, template = load_template_mesh(opt, opt.start_level)
        self.template = MeshInference(template_name, template, self.opt, self.opt.start_level).to(self.device)
        self.reconstruction_z = factory.NoiseMem(opt).load().to(device)

    def compose_z(self, start_level) -> factory.Noise:
        """Combine reconstruction noise below ``start_level`` with fresh random noise from it upward."""
        random_noise = self.get_z_sequence(self.template, len(self) - 1)
        noise = self.reconstruction_z[: start_level] + random_noise[start_level:]
        return noise

    def generate_seq(self, num_seqs: int):
        """Run ``generator.inference_forward`` ``num_seqs`` times with fully random noise.

        Each run is given the output path
        ``<cp_folder>/inference/gen/<mesh_name>_<seq>``.
        """
        for seq in range(num_seqs):
            z = self.compose_z(0)
            # Fix: read the folder from self.opt. The previous code used the
            # global ``opt_``, which only exists when this file runs as a script.
            self.generator.inference_forward(self.template.copy(), z, len(self) - 1, 0,
                                             f'{self.opt.cp_folder}/inference/gen/{self.opt.mesh_name}_{seq}',
                                             upsample=True)

    def generate_all(self, num_samples: int):
        """Export ``num_samples`` generated meshes for every possible start level.

        Files are written to
        ``<cp_folder>/inference/gen/<mesh_name>_<level>_<sample>``.
        """
        num_levels = len(self.generator.levels)
        for i in range(num_levels):
            for j in range(num_samples):
                out_mesh = self(i)
                out_mesh.export(f'{self.opt.cp_folder}/inference/gen/{self.opt.mesh_name}_{i}_{j:02d}')
                print(f'gen {self.opt.mesh_name} {i * num_samples + j +1:02d} / {num_levels * num_samples}')

    def __call__(self, start_level: int):
        """Generate one mesh, randomising the noise from ``start_level`` upward.

        A negative ``start_level`` is replaced by ``len(self)`` and larger
        values are capped at ``len(self)``. Runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            if start_level < 0:
                start_level = len(self)
            start_level = min(len(self), start_level)
            z = self.compose_z(start_level)
            return self.generator(self.template.copy(), z, len(self) - 1)
if __name__ == '__main__':
    # Script entry point: parse the command line and run the selected mode on CPU.
    opt_ = options.Options()
    opt_.parse_cmdline()
    device = CPU
    with_noise = False
    gen_mode = opt_.gen_mode
    if gen_mode == 'animate':
        m2m = Mesh2Mesh(opt_, device)
        raw_target = mesh_utils.load_real_mesh(opt_.target, 0, True)
        in_mesh = MeshInference(opt_.target, raw_target, opt_, 0).to(device)
        first_level, last_level = opt_.gen_levels[0], opt_.gen_levels[1]
        # num_scene is 0 here, so only the growing pass of animate() produces frames.
        m2m.animate(in_mesh, first_level, last_level, 0, (12, 17), zero_places=(0, 0, 1, 1, 1))
    elif gen_mode == 'generate':
        mg = MeshGen(opt_, device)
        mg.generate_all(opt_.num_gen_samples)

# Computing file changes ...