"""Orient the normals of a point cloud with a network ensemble plus field propagation.

``run`` loads the cloud, optionally estimates normals, splits it into patches,
orients random representative points of each kept patch with an ensemble of
models, propagates that orientation with ``strongest_field_propagation_reps``,
fixes the global sign and exports ``final_result.xyz``.
"""
from pathlib import Path

import torch

import field_utils
import options
import util
from field_utils import *
from util import orient_center
from inference_utils import load_model_from_file, fix_n_filter, voting_policy

# Fixed seed so the random choice of representative points is reproducible.
torch.manual_seed(1)

# Default maximum number of representative points sampled per patch.
_MAX_PATCH_SIZE = 500


def _select_device():
    """Return the current CUDA device when CUDA is available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device(torch.cuda.current_device())
    return torch.device('cpu')


def _pick_representatives(patches, max_patch_size):
    """Randomly split every patch into (representatives, remaining points).

    :param patches: sequence of 1-D index tensors, one per patch.
    :param max_patch_size: maximum number of representatives per patch.
    :return: list of ``(reps, rest)`` index-tensor pairs, in the order of ``patches``.
    """
    split = []
    for patch in patches:
        order = torch.randperm(patch.shape[0])
        split.append((patch[order[:max_patch_size]], patch[order[max_patch_size:]]))
    return split


def _orient_with_network(input_pc, patch_indices, represent, models, iters, device):
    """Flip normals (columns ``3:``) of representative points chosen by the model ensemble.

    Modifies ``input_pc`` in place. For every kept patch, each model scores the
    patch's representative points, ``voting_policy`` combines the class-1
    probabilities, and the selected points get their normals negated. This is
    repeated ``iters`` times per patch.

    :param input_pc: point-cloud tensor; columns ``3:`` hold the normals.
    :param patch_indices: pairs whose first element indexes ``represent``
        (as produced by ``fix_n_filter`` -- presumably the patch id; confirm).
    :param represent: output of ``_pick_representatives`` for all patches.
    :param models: ensemble of orientation models.
    :param iters: number of vote-and-flip rounds per patch.
    :param device: device the models expect their input on.
    """
    with torch.no_grad():
        for patch_id, _ in patch_indices:
            reps, _ = represent[patch_id]
            for _ in range(iters):
                # Re-gather on every round: advanced indexing returns a copy, so a
                # tensor gathered once would not reflect flips from earlier rounds
                # and the same points would keep being flipped back and forth.
                data = input_pc[reps].to(device)
                # clone() so a model that mutates its input cannot affect the others.
                votes = [model(data.clone()) for model in models]
                vote_probabilities = [torch.softmax(scores, dim=-1)[:, 1] for scores in votes]
                flip, _ = voting_policy(vote_probabilities)
                input_pc[reps[flip], 3:] *= -1


def run(opts):
    """Orient the normals of ``opts.pc`` and write ``final_result.xyz`` into ``opts.export_dir``.

    Options read: ``export_dir``, ``pc``, ``estimate_normals``, ``n``, ``models``,
    ``number_parts``, ``minimum_points_per_patch``, ``curvature_threshold``,
    ``iters`` and, optionally, ``max_patch_size`` (defaults to 500).
    """
    export_path = Path(opts.export_dir)
    export_path.mkdir(parents=True, exist_ok=True)
    max_patch_size = getattr(opts, 'max_patch_size', _MAX_PATCH_SIZE)
    device = _select_device()
    MyTimer = util.timer_factory()

    with MyTimer('load pc', count=False):
        input_pc = util.xyz2tensor(Path(opts.pc).read_text(), append_normals=False).to(device)
        # `transform` is inverted again right before export.
        input_pc, transform = util.Transform.trans(input_pc)

    if opts.estimate_normals:
        with MyTimer('estimating normals'):
            input_pc = util.estimate_normals(input_pc, max_nn=opts.n)

    models = [load_model_from_file(model_path, device) for model_path in opts.models]

    with MyTimer('divide patches'):
        patch_indices = util.divide_pc(input_pc[:, :3], opts.number_parts,
                                       min_patch=opts.minimum_points_per_patch)
        # Untouched copy of every patch; representatives are drawn from all of them.
        all_patches_indices = [x.clone() for x in patch_indices]

    with MyTimer('filter patches'):
        # Yields pairs whose first element indexes all_patches_indices/represent
        # (see _orient_with_network); may keep fewer patches than it receives.
        patch_indices = fix_n_filter(input_pc, patch_indices, opts.curvature_threshold)

    print(f'number of patches {len(patch_indices)}/{len(all_patches_indices)}')

    with MyTimer('orient center'):
        for _, points in patch_indices:
            input_pc[points] = orient_center(input_pc[points])

    with MyTimer('find reps'):
        represent = _pick_representatives(all_patches_indices, max_patch_size)

    with MyTimer('network orientation'):
        _orient_with_network(input_pc, patch_indices, represent, models, opts.iters, device)

    # The models are not used after this point; move them off the accelerator.
    for model in models:
        model.to('cpu')

    with MyTimer('propagating field'):
        strongest_field_propagation_reps(input_pc, represent, diffuse=True)

    with MyTimer('fix global orientation'):
        # If the average global potential is negative, flip all normals.
        if field_utils.measure_mean_potential(input_pc) < 0:
            input_pc[:, 3:] *= -1

    with MyTimer('exporting result', count=False):
        util.export_pc(transform.inverse(input_pc).transpose(0, 1), export_path / 'final_result.xyz')

    MyTimer.print_total_time()


if __name__ == '__main__':
    opts = options.get_parser().parse_args()
    opts.export_dir.mkdir(exist_ok=True, parents=True)
    options.export_options(opts)
    run(opts)