import torch
from models.layers.mesh import Mesh, PartMesh
from models.networks import init_net, sample_surface, local_nonuniform_penalty
import utils
import numpy as np
from models.losses import chamfer_distance, BeamGapLoss
from options import Options
import time
import os

# Build the option container once; `opts` is its parsed-argument namespace.
options = Options()
opts = options.args

# Use CUDA device number `opts.gpu` when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{opts.gpu}')
else:
    device = torch.device('cpu')
print(f'device: {device}')

# Load the initial (template) mesh on the selected device.
mesh = Mesh(opts.initial_mesh, device=device, hold_history=True)


def _to_batched_tensor(array):
    """Return `array` as a tensor of `options.dtype()` on `device`, with a leading batch axis."""
    return torch.Tensor(array).type(options.dtype()).to(device)[None, :, :]


# Read the target point cloud and move it into the initial mesh's normalized frame by
# applying the mesh's `scale` and `translations`, then batch positions and normals.
input_xyz, input_normals = utils.read_pts(opts.input_pc)
input_xyz /= mesh.scale
input_xyz += mesh.translations[None, :]
input_xyz = _to_batched_tensor(input_xyz)
input_normals = _to_batched_tensor(input_normals)

# Split the mesh into sub-meshes (part count derived from the face count, BFS depth from
# opts.overlap) and create the network, optimizer, random input and LR scheduler for them.
num_parts = options.get_num_parts(len(mesh.faces))
part_mesh = PartMesh(mesh, num_parts=num_parts, bfs_depth=opts.overlap)
print(f'number of parts {part_mesh.n_submeshes}')
net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)

beamgap_loss = BeamGapLoss(device)

if opts.beamgap_iterations > 0:
    print('beamgap on')
    # Beam-gap target: point positions with their normals appended along the last axis.
    target_pc = torch.cat([input_xyz, input_normals], dim=-1)
    beamgap_loss.update_pm(part_mesh, target_pc)

# Main optimization loop.
# In each iteration the network is run once and yields one vertex estimate per sub-mesh of
# `part_mesh`. Every part gets its own loss and backward pass:
#   * opts.global_step is true: gradients of all parts accumulate, and one optimizer and
#     scheduler step is taken per iteration.
#   * otherwise: the optimizer and scheduler step after every part.
# Every `opts.upsamp` iterations the mesh is re-meshed and the network is re-initialised.
# Forward --unoriented to the chamfer loss only when the option parser defines it.
chamfer_kwargs = {'unoriented': opts.unoriented} if hasattr(opts, 'unoriented') else {}

for i in range(opts.iterations):
    num_samples = options.get_num_samples(i % opts.upsamp)
    if opts.global_step:
        # Gradients are accumulated over all parts, so clear them once per iteration.
        optimizer.zero_grad()
    start_time = time.time()
    for part_i, est_verts in enumerate(net(rand_verts, part_mesh)):
        if not opts.global_step:
            optimizer.zero_grad()
        part_mesh.update_verts(est_verts[0], part_i)
        recon_xyz, recon_normals = sample_surface(part_mesh.main_mesh.faces, part_mesh.main_mesh.vs.unsqueeze(0), num_samples)
        # calc chamfer loss w/ normals
        recon_xyz, recon_normals = recon_xyz.type(options.dtype()), recon_normals.type(options.dtype())
        xyz_chamfer_loss, normals_chamfer_loss = chamfer_distance(recon_xyz, input_xyz, x_normals=recon_normals,
                                                                  y_normals=input_normals, **chamfer_kwargs)

        if (i < opts.beamgap_iterations) and (i % opts.beamgap_modulo == 0):
            # On beam-gap iterations the beam-gap loss replaces the chamfer objective.
            loss = beamgap_loss(part_mesh, part_i)
        else:
            loss = xyz_chamfer_loss + opts.ang_wt * normals_chamfer_loss
        if opts.local_non_uniform > 0:
            loss += opts.local_non_uniform * local_nonuniform_penalty(part_mesh.main_mesh).float()
        loss.backward()
        if not opts.global_step:
            optimizer.step()
            scheduler.step()
        # NOTE(review): presumably update_verts writes this part's estimate into main_mesh.vs,
        # which then carries this part's autograd graph. Detach it so the next part's
        # backward() does not traverse an already-freed graph.
        part_mesh.main_mesh.vs.detach_()
    if opts.global_step:
        optimizer.step()
        scheduler.step()
    end_time = time.time()

    print(f'{os.path.basename(opts.input_pc)}; iter: {i} out of: {opts.iterations}; loss: {loss.item():.4f};'
          f' sample count: {num_samples}; time: {end_time - start_time:.2f}')
    if i % opts.export_interval == 0 and i > 0:
        print('exporting reconstruction... current LR: {}'.format(optimizer.param_groups[0]['lr']))
        with torch.no_grad():
            part_mesh.export(os.path.join(opts.save_path, f'recon_iter_{i}.obj'))

    if (i > 0 and (i + 1) % opts.upsamp == 0):
        mesh = part_mesh.main_mesh
        num_faces = int(np.clip(len(mesh.faces) * 1.5, len(mesh.faces), opts.max_faces))

        if num_faces > len(mesh.faces) or opts.manifold_always:
            # up-sample mesh
            mesh = utils.manifold_upsample(mesh, opts.save_path, Mesh,
                                           num_faces=min(num_faces, opts.max_faces),
                                           res=opts.manifold_res, simplify=True)

            # The mesh topology changed, so rebuild the parts, the network and the optimizer.
            part_mesh = PartMesh(mesh, num_parts=options.get_num_parts(len(mesh.faces)), bfs_depth=opts.overlap)
            print(f'upsampled to {len(mesh.faces)} faces; number of parts {part_mesh.n_submeshes}')
            net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)
            if i < opts.beamgap_iterations:
                print('beamgap updated')
                # Same target layout as the initial update_pm call: xyz with normals appended.
                beamgap_loss.update_pm(part_mesh, torch.cat([input_xyz, input_normals], dim=-1))

# Write the final reconstruction (no autograd tracking needed for export).
last_recon_path = os.path.join(opts.save_path, 'last_recon.obj')
with torch.no_grad():
    mesh.export(last_recon_path)
