https://github.com/zachzhang07/vosh
Tip revision: da207d03e7994d9c5a097126dcd509abedc26bc0 authored by zachzhang07 on 21 November 2024, 08:07:14 UTC
Update readme.md
Update readme.md
Tip revision: da207d0
main_vosh.py
import argparse
from nerf.utils import *
# from nerf.gui import NeRFGUI
from nerf import math
import time
def build_parser():
    """Build the command-line parser for the Vosh train/test + baking stage.

    The returned namespace is expected to be post-processed by
    ``finalize_options`` before use (it forces several flags on and
    validates option combinations).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str)
    parser.add_argument('--mesh_select', type=float, default=0.8,
                        help="select ratio in error_grid[error_grid > 0] in voxel to mesh, "
                             "0 means all mesh, 1 means only center mesh")
    parser.add_argument('--keep_center', type=float, default=0.25,
                        help="keep center ratio in mesh selection, valid only when mesh_select > 0")
    parser.add_argument('--lambda_mesh_weight', type=float, default=0, help="loss scale")
    parser.add_argument('--lambda_bg_weight', type=float, default=0, help="loss scale")
    parser.add_argument('--local_face_num', type=int, default=200000)
    parser.add_argument('--vol_path', type=str)
    parser.add_argument('--use_occ_grid', action='store_true')
    parser.add_argument('--use_mesh_occ_grid', action='store_true')
    parser.add_argument('--mesh_check_ratio', type=int, default=0)
    parser.add_argument('--lambda_ec_weight', type=float, default=0, help="loss scale")
    parser.add_argument('--use_vol_pth', action='store_true',
                        help="use pretrained volume encoder for volume render, default is mesh encoder")
    parser.add_argument('--mesh_encoder', action='store_true',
                        help="use mesh encoder for surface render, valid only when use_vol_pth is True")
    parser.add_argument('--no_baking', action='store_true',
                        help="skip the baking stage")
    parser.add_argument('--vert_offset', action='store_true')
    parser.add_argument('--lr_vert', type=float, default=1e-4,
                        help="initial learning rate for vert optimization")
    parser.add_argument('--workspace', type=str, default='workspace')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--ckpt', type=str, default='latest')
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
    parser.add_argument('--fast_baking', action='store_true',
                        help="faster baking at the cost of maybe missing blocks at background")

    # new options
    parser.add_argument('--render', type=str, default='mixed', choices=['mixed'])
    parser.add_argument('--criterion', type=str, default='MSE', choices=['L1', 'MSE'])
    parser.add_argument('--max_edge_len', type=float, default=1.0)
    parser.add_argument('--min_iso_size', type=int, default=0)
    parser.add_argument('--alpha_thres', type=float, default=0.005,
                        help="alpha threshold (only 0.005 is currently supported)")
    parser.add_argument('--clean_min_f', type=int, default=25,
                        help="mesh clean: min face count for isolated mesh")
    parser.add_argument('--clean_min_d', type=int, default=0,
                        help="mesh clean: min diameter for isolated mesh")

    # mesh options
    parser.add_argument('--ssaa', type=int, default=2, help="super sampling anti-aliasing ratio")
    parser.add_argument('--texture_size', type=int, default=4096, help="exported texture resolution")
    parser.add_argument('--pos_gradient_boost', type=float, default=1, help="nvdiffrast option")

    # model options
    parser.add_argument('--grid_resolution', type=int, default=1024)
    parser.add_argument('--triplane_resolution', type=int, default=-1)

    # testing options
    parser.add_argument('--save_cnt', type=int, default=1,
                        help="number of times to save checkpoints during training")
    parser.add_argument('--eval_cnt', type=int, default=1,
                        help="number of times to perform validation during training")
    parser.add_argument('--test', action='store_true', help="test mode")
    parser.add_argument('--test_no_video', action='store_true', help="test mode: do not save video")
    parser.add_argument('--test_no_baking', action='store_true', help="test mode: do not save baking")
    # NOTE(review): --test_no_mesh is not read in this script; confirm whether Trainer/model use it.
    parser.add_argument('--test_no_mesh', action='store_true', help="test mode: do not save mesh")
    parser.add_argument('--camera_traj', type=str, default='path',
                        help="interp for interpolation, circle for circular camera")

    # dataset options
    parser.add_argument('--data_format', type=str, default='colmap', choices=['nerf', 'colmap', 'dtu'])
    parser.add_argument('--train_split', type=str, default='train', choices=['train', 'all'])
    parser.add_argument('--test_split', type=str, default='test', choices=['train', 'val', 'test'])
    parser.add_argument('--preload', action='store_true',
                        help="preload all data into GPU, accelerate training but use more GPU memory")
    parser.add_argument('--random_image_batch', action='store_true',
                        help="randomly sample rays from all images per step in training")
    parser.add_argument('--downscale', type=int, default=4, help="downscale training images")
    parser.add_argument('--bound', type=float, default=128,
                        help="assume the scene is bounded in box[-bound, bound]^3, "
                             "if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=-1,
                        help="scale camera location into box[-bound, bound]^3, "
                             "-1 means automatically determine based on camera poses.")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--enable_cam_near_far', action='store_true',
                        help="colmap mode: use the sparse points to estimate camera near far per view.")
    parser.add_argument('--enable_cam_center', action='store_true',
                        help="use camera center instead of sparse point center (colmap dataset only)")
    parser.add_argument('--min_near', type=float, default=0.2, help="minimum near distance for camera")
    parser.add_argument('--T_thresh', type=float, default=2e-4,
                        help="minimum transmittance to continue ray marching")

    # optimization options
    parser.add_argument('--lr_init', type=float, default=1e-2, help="The initial learning rate")
    parser.add_argument('--lr_final', type=float, default=1e-3, help="The final learning rate")
    parser.add_argument('--lr_delay_steps', type=int, default=100,
                        help="The number of 'warmup' learning steps")
    parser.add_argument('--lr_delay_mult', type=float, default=0.01,
                        help="How severe the 'warmup' should be")

    # training options
    parser.add_argument('--iters', type=int, default=20000, help="training iters")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    parser.add_argument('--num_steps', type=int, nargs='*', default=[128, 64, 32],
                        help="num steps sampled per ray for each proposal level (only valid when NOT using --cuda_ray)")
    parser.add_argument('--contract', action='store_true',
                        help="apply spatial contraction as in MERF, only work for bound > 1, will override bound to 2.")
    parser.add_argument('--enable_dense_depth', action='store_true', help="dense depth supervision")
    parser.add_argument('--background', type=str, default='random', choices=['white', 'random', 'last_sample'],
                        help="training background mode")
    parser.add_argument('--max_ray_batch', type=int, default=4096 * 2,
                        help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
    parser.add_argument('--mark_untrained', action='store_true', help="mark_untrained grid")
    parser.add_argument('--dt_gamma', type=float, default=1 / 256,
                        help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, "
                             ">0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--density_thresh', type=float, default=10,
                        help="threshold for density grid to be occupied")
    parser.add_argument('--diffuse_step', type=int, default=0,
                        help="training iters that only trains diffuse color for better initialization")

    # batch size related
    parser.add_argument('--num_rays', type=int, default=4096,
                        help="num rays sampled per image for each training step")
    parser.add_argument('--adaptive_num_rays', action='store_true',
                        help="adaptive num rays for more efficient training")
    parser.add_argument('--num_points', type=int, default=2 ** 18,
                        help="target num points for each training step, only work with adaptive num_rays")

    # regularizations
    parser.add_argument('--lambda_entropy', type=float, default=0, help="loss scale")
    parser.add_argument('--lambda_tv', type=float, default=0, help="loss scale")
    parser.add_argument('--lambda_proposal', type=float, default=1,
                        help="loss scale (only for non-cuda-ray mode)")
    parser.add_argument('--lambda_distort', type=float, default=0.01,
                        help="loss scale (only for non-cuda-ray mode)")
    parser.add_argument('--lambda_specular', type=float, default=1e-5, help="loss scale")
    parser.add_argument('--lambda_depth', type=float, default=0, help="loss scale")
    parser.add_argument('--lambda_sparsity', type=float, default=0, help="loss scale")
    parser.add_argument('--lambda_mask', type=float, default=0, help="loss scale")

    # GUI options
    parser.add_argument('--vis_pose', action='store_true', help="visualize the poses")
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=1000, help="GUI width")
    parser.add_argument('--H', type=int, default=1000, help="GUI height")
    parser.add_argument('--radius', type=float, default=1, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy")
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")
    return parser


def finalize_options(opt, parser):
    """Force the fixed Vosh settings onto *opt* and validate option combinations.

    Args:
        opt: namespace returned by ``parser.parse_args()``; modified in place.
        parser: the parser that produced *opt*; invalid combinations are
            reported through ``parser.error`` (prints usage, exits with status 2).

    Returns:
        The same *opt* object.
    """
    # Always enabled for this stage, so the matching CLI flags are no-ops.
    opt.fp16 = True
    opt.preload = True
    opt.contract = True
    opt.adaptive_num_rays = True
    opt.enable_cam_center = True
    opt.enable_cam_near_far = True
    opt.use_vol_pth = True
    opt.mesh_encoder = True

    # Use parser.error instead of `assert`: asserts are stripped under `python -O`.
    if opt.vert_offset:
        parser.error('--vert_offset is not supported by this stage')
    if opt.alpha_thres != 0.005:
        parser.error(f'--alpha_thres must be 0.005, got {opt.alpha_thres}')
    if opt.vol_path is None:
        parser.error(f'vol_path is not valid: {opt.vol_path}')
    if opt.use_mesh_occ_grid:
        if opt.mesh_check_ratio < 1:
            parser.error('--mesh_check_ratio must be >= 1 when --use_mesh_occ_grid is set')
    elif opt.mesh_check_ratio >= 1:
        parser.error('--mesh_check_ratio must be < 1 unless --use_mesh_occ_grid is set')

    # TODO: change contract by opt.data_format
    if opt.data_format == 'nerf':
        opt.min_near = 0.05
        opt.contract = False
        # decimate_target / visibility_mask_dilation have no parser option;
        # they are only defined for the nerf format.
        opt.decimate_target = [1e5, 1e5, 1e5]
        opt.clean_min_f = 16
        opt.clean_min_d = 10
        opt.visibility_mask_dilation = 5

    if opt.contract:
        # mark untrained is not correct in contraction mode...
        opt.mark_untrained = False
    return opt


def _load_dataset_class(data_format):
    """Import and return the dataset class for *data_format* ('colmap', 'dtu' or 'nerf')."""
    if data_format == 'colmap':
        from nerf.colmap_provider import ColmapDataset as NeRFDataset
    elif data_format == 'dtu':
        from nerf.dtu_provider import NeRFDataset
    else:  # nerf
        from nerf.provider import NeRFDataset
    return NeRFDataset


def _build_criterion(name):
    """Return the unreduced per-element loss module selected by ``--criterion``."""
    if name == 'L1':
        return torch.nn.SmoothL1Loss(reduction='none')
    if name == 'MSE':
        return torch.nn.MSELoss(reduction='none')
    raise NotImplementedError(name)


def _run_test(opt, model, device, criterion, dataset_cls):
    """Load a checkpoint, evaluate/render the test split, and return the Trainer."""
    trainer = Trainer('vosh', opt, model, device=device, workspace=opt.workspace, criterion=criterion,
                      fp16=opt.fp16, use_checkpoint=opt.ckpt)
    if not opt.test_no_video:
        test_loader = dataset_cls(opt, device=device, type=opt.test_split).dataloader()
        if test_loader.has_gt:
            trainer.metrics = [PSNRMeter(), SSIMMeter(), LPIPSMeter(device=device)]  # set up metrics
            trainer.evaluate(test_loader, name='test')  # blender has gt, so evaluate it.
        trainer.test(test_loader, write_video=True)  # test and save video
    return trainer


def _run_train(opt, model, device, criterion, dataset_cls):
    """Train, run a final validation, evaluate the test split if it has GT, and return the Trainer."""
    # No --lr option is defined, so the base LR falls back to 1.0 and the LambdaLR
    # factor below effectively sets the per-step LR (presumably
    # math.learning_rate_decay returns an absolute LR -- confirm in nerf/math.py).
    lr = getattr(opt, 'lr', 1.0)

    def optimizer(net):
        return torch.optim.Adam(net.get_params(lr), eps=1e-15)

    def lr_factor(step):
        return math.learning_rate_decay(step + 1, opt.lr_init, opt.lr_final, opt.iters,
                                        opt.lr_delay_steps, opt.lr_delay_mult)

    def scheduler(optim):
        return torch.optim.lr_scheduler.LambdaLR(optim, lr_factor)

    train_loader = dataset_cls(opt, device=device, type=opt.train_split).dataloader()
    # Exact integer ceil(iters / batches_per_epoch).
    max_epoch = -(-opt.iters // len(train_loader))
    save_interval = max(1, max_epoch // max(1, opt.save_cnt))
    eval_interval = max(1, max_epoch // max(1, opt.eval_cnt))
    print(f'[INFO] max_epoch {max_epoch}, eval every {eval_interval}.')

    trainer = Trainer('vosh', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer,
                      criterion=criterion, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler,
                      scheduler_update_every_step=True, use_checkpoint=opt.ckpt, eval_interval=eval_interval,
                      save_interval=save_interval)
    valid_loader = dataset_cls(opt, device=device, type='val').dataloader()
    trainer.metrics = [PSNRMeter()]
    trainer.train(train_loader, valid_loader, max_epoch)

    # last validation, with the full metric set
    trainer.metrics = [PSNRMeter(), SSIMMeter(), LPIPSMeter(device=device)]
    trainer.evaluate(valid_loader)

    # also test
    test_loader = dataset_cls(opt, device=device, type=opt.test_split).dataloader()
    if test_loader.has_gt:
        trainer.evaluate(test_loader, name='test')  # blender has gt, so evaluate it.
    return trainer


def _load_or_compute_occ_grid(opt, trainer, loader):
    """Return the occupancy grid cached in the workspace, computing and caching it if absent."""
    # Build the path directly rather than globbing it: a workspace containing glob
    # metacharacters ('[', '*', '?') would otherwise never match the cache file.
    occ_grid_path = os.path.join(opt.workspace, 'merf_occ_grid_vol.pt')
    if os.path.isfile(occ_grid_path):
        print(f'[INFO] Load occ_grid from {occ_grid_path}')
        occ_grid = torch.load(occ_grid_path)
        print(f'[INFO] Occupancy rate: '
              f'{occ_grid.sum().item() / occ_grid.numel() * 100.:.2f}% '
              f'with resolution {occ_grid.shape[0]}')
    else:
        occ_grid = trainer.cal_occ_grid(loader)
        torch.save(occ_grid, occ_grid_path)
    return occ_grid


def _bake(opt, model, trainer, device, dataset_cls):
    """Bake the occupancy grid, the baked assets and the mesh into ``opt.workspace``."""
    all_loader = dataset_cls(opt, device=device, type='train_all')
    all_loader.training = False  # load full image from train split
    all_loader = all_loader.dataloader()

    occ_grid_start_t = time.time()
    print("==> Start baking occupancy grid.")
    occ_grid = _load_or_compute_occ_grid(opt, trainer, all_loader)
    occ_grid_end_t = time.time()
    print(f"[INFO] Baking occupancy grid takes {(occ_grid_end_t - occ_grid_start_t) / 60:.6f} minutes.")

    baking_start_t = time.time()
    trainer.save_baking(loader=all_loader, occupancy_grid=occ_grid)
    model.export_mesh_assert_by_number(path=os.path.join(opt.workspace, 'assets', 'mesh'))
    baking_end_t = time.time()
    print(f"[INFO] Baking takes {(baking_end_t - baking_start_t) / 60:.6f} minutes.")
    print("==> Finished baking.")


def main():
    """Parse options, train (or test) the Vosh model, then bake unless skipped."""
    start_t = time.time()
    parser = build_parser()
    opt = finalize_options(parser.parse_args(), parser)

    dataset_cls = _load_dataset_class(opt.data_format)
    seed_everything(opt.seed)

    from nerf.network_vosh import NeRFNetwork
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = NeRFNetwork(opt)
    criterion = _build_criterion(opt.criterion)

    if opt.test:
        trainer = _run_test(opt, model, device, criterion, dataset_cls)
    else:
        trainer = _run_train(opt, model, device, criterion, dataset_cls)

    # --no_baking always skips baking; --test_no_baking does so in test mode,
    # as its help text promises (it was previously never read).
    if opt.no_baking or (opt.test and opt.test_no_baking):
        end_t = time.time()
        print(f"[INFO] vosh stage takes {(end_t - start_t) / 60:.6f} minutes without baking.")
        return

    _bake(opt, model, trainer, device, dataset_cls)
    end_t = time.time()
    print(f"[INFO] vosh stage takes {(end_t - start_t) / 60:.6f} minutes.")


if __name__ == '__main__':
    main()