https://github.com/zachzhang07/vosh
Tip revision: da207d03e7994d9c5a097126dcd509abedc26bc0 authored by zachzhang07 on 21 November 2024, 08:07:14 UTC
Update readme.md
Update readme.md
Tip revision: da207d0
main_vol.py
import argparse
from nerf.utils import *
# from nerf.gui import NeRFGUI
from nerf import math
import time
def build_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with one positional argument (``path``)
        and all optional switches declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str)
    parser.add_argument('--workspace', type=str, default='workspace')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--ckpt', type=str, default='latest')
    parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
    parser.add_argument('--fast_baking', action='store_true',
                        help="faster baking at the cost of maybe missing blocks at background")

    ### new options
    parser.add_argument('--render', type=str, default='grid', choices=['grid'])
    parser.add_argument('--criterion', type=str, default='MSE', choices=['L1', 'MSE'])
    parser.add_argument('--max_edge_len', type=float, default=0.5)
    parser.add_argument('--min_iso_size', type=int, default=100)
    # NOTE(review): the previous help text ("initial learning rate for vert optimization") was a
    # copy-paste leftover; how the threshold is consumed is not visible in this file.
    parser.add_argument('--alpha_thres', type=float, default=0.005, help="alpha threshold")
    parser.add_argument('--cascade_list', type=float, nargs='*', default=[0.5, 1.0, 1.5, 2.0])
    # Defaults are int literals: argparse applies `type` only to command-line strings (and string
    # defaults), so float literals such as 5e5 would be handed over unchanged as floats.
    parser.add_argument('--decimate_target', type=int, nargs='*', default=[500000, 500000, 2000000, 2000000],
                        help="decimate target for number of triangles, <=0 to disable")

    ### mesh options
    parser.add_argument('--visibility_mask_dilation', type=int, default=50, help="visibility dilation")
    parser.add_argument('--clean_min_f', type=int, default=25, help="mesh clean: min face count for isolated mesh")
    parser.add_argument('--clean_min_d', type=int, default=10, help="mesh clean: min diameter for isolated mesh")
    parser.add_argument('--ssaa', type=int, default=2, help="super sampling anti-aliasing ratio")
    parser.add_argument('--texture_size', type=int, default=4096, help="exported texture resolution")

    ### model options
    # parser.add_argument('--backbone', type=str, default='merf_new', choices=['merf_new'], help="backbone type")
    parser.add_argument('--grid_resolution', type=int, default=1024)
    parser.add_argument('--triplane_resolution', type=int, default=-1)

    ### testing options
    parser.add_argument('--save_cnt', type=int, default=1,
                        help="save checkpoints for $ times during training")
    parser.add_argument('--eval_cnt', type=int, default=1,
                        help="perform validation for $ times during training")
    parser.add_argument('--test', action='store_true', help="test mode")
    parser.add_argument('--test_no_video', action='store_true', help="test mode: do not save video")
    parser.add_argument('--test_no_baking', action='store_true', help="test mode: do not save baking")
    parser.add_argument('--test_no_mesh', action='store_true', help="test mode: do not save mesh")
    parser.add_argument('--camera_traj', type=str, default='path',
                        help="interp for interpolation, circle for circular camera")

    ### dataset options
    parser.add_argument('--data_format', type=str, default='colmap', choices=['nerf', 'colmap', 'dtu'])
    parser.add_argument('--train_split', type=str, default='train', choices=['train', 'all'])
    parser.add_argument('--test_split', type=str, default='test', choices=['train', 'val', 'test'])
    parser.add_argument('--preload', action='store_true',
                        help="preload all data into GPU, accelerate training but use more GPU memory")
    parser.add_argument('--random_image_batch', action='store_true',
                        help="randomly sample rays from all images per step in training")
    parser.add_argument('--downscale', type=int, default=4, help="downscale training images")
    parser.add_argument('--bound', type=float, default=128,
                        help="assume the scene is bounded in box[-bound, bound]^3, "
                             "if > 1, will invoke adaptive ray marching.")
    parser.add_argument('--scale', type=float, default=-1,
                        help="scale camera location into box[-bound, bound]^3, "
                             "-1 means automatically determine based on camera poses..")
    parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location")
    parser.add_argument('--enable_cam_near_far', action='store_true',
                        help="colmap mode: use the sparse points to estimate camera near far per view.")
    parser.add_argument('--enable_cam_center', action='store_true',
                        help="use camera center instead of sparse point center (colmap dataset only)")
    parser.add_argument('--min_near', type=float, default=0.2, help="minimum near distance for camera")
    parser.add_argument('--T_thresh', type=float, default=2e-4,
                        help="minimum transmittance to continue ray marching")

    ### optimization options
    parser.add_argument('--lr_init', type=float, default=1e-2, help="The initial learning rate")
    parser.add_argument('--lr_final', type=float, default=1e-3, help="The final learning rate")
    parser.add_argument('--lr_delay_steps', type=int, default=100,
                        help="The number of 'warmup' learning steps")
    parser.add_argument('--lr_delay_mult', type=float, default=0.01,
                        help="How much sever the 'warmup' should be")
    # grad_max_norm: float = 0.001 # Gradient clipping magnitude, disabled if == 0.
    # grad_max_val: float = 0.0 # Gradient clipping value, disabled if == 0.

    ### training options
    parser.add_argument('--iters', type=int, default=25000, help="training iters")
    # parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
    parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
    # parser.add_argument('--max_steps', type=int, default=1024,
    #                     help="max num steps sampled per ray (only valid when using --cuda_ray)")
    parser.add_argument('--num_steps', type=int, nargs='*', default=[128, 64, 32],
                        help="num steps sampled per ray for each proposal level (only valid when NOT using --cuda_ray)")
    parser.add_argument('--contract', action='store_true',
                        help="apply spatial contraction as in MERF, only work for bound > 1, will override bound to 2.")
    parser.add_argument('--enable_dense_depth', action='store_true', help="dense depth supervision")
    parser.add_argument('--background', type=str, default='random',
                        choices=['white', 'random', 'last_sample'], help="training background mode")
    # parser.add_argument('--update_extra_interval', type=int, default=16,
    #                     help="iter interval to update extra status (only valid when using --cuda_ray)")
    parser.add_argument('--max_ray_batch', type=int, default=4096 * 4,
                        help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
    # parser.add_argument('--grid_size', type=int, default=128, help="density grid resolution")
    parser.add_argument('--mark_untrained', action='store_true', help="mark_untrained grid")
    parser.add_argument('--dt_gamma', type=float, default=1 / 256,
                        help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, "
                             ">0 to accelerate rendering (but usually with worse quality)")
    parser.add_argument('--density_thresh', type=float, default=10,
                        help="threshold for density grid to be occupied")
    parser.add_argument('--diffuse_step', type=int, default=0,
                        help="training iters that only trains diffuse color for better initialization")

    # batch size related
    parser.add_argument('--num_rays', type=int, default=4096,
                        help="num rays sampled per image for each training step")
    parser.add_argument('--adaptive_num_rays', action='store_true',
                        help="adaptive num rays for more efficient training")
    parser.add_argument('--num_points', type=int, default=2 ** 18,
                        help="target num points for each training step, only work with adaptive num_rays")

    # regularizations
    parser.add_argument('--lambda_entropy', type=float, default=0, help="loss scale")
    parser.add_argument('--lambda_tv', type=float, default=0, help="loss scale")
    parser.add_argument('--lambda_proposal', type=float, default=1,
                        help="loss scale (only for non-cuda-ray mode)")
    parser.add_argument('--lambda_distort', type=float, default=0.01,
                        help="loss scale (only for non-cuda-ray mode)")
    parser.add_argument('--lambda_specular', type=float, default=1e-5, help="loss scale")
    parser.add_argument('--lambda_depth', type=float, default=0, help="loss scale")
    parser.add_argument('--lambda_sparsity', type=float, default=0.01, help="loss scale")
    parser.add_argument('--lambda_mask', type=float, default=0.1, help="loss scale")

    ### GUI options
    parser.add_argument('--vis_pose', action='store_true', help="visualize the poses")
    parser.add_argument('--gui', action='store_true', help="start a GUI")
    parser.add_argument('--W', type=int, default=1000, help="GUI width")
    parser.add_argument('--H', type=int, default=1000, help="GUI height")
    parser.add_argument('--radius', type=float, default=1, help="default GUI camera radius from center")
    parser.add_argument('--fovy', type=float, default=50, help="default GUI camera fovy")
    parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

    return parser


if __name__ == '__main__':
    # Wall-clock start; the elapsed time is printed at the very end of the script.
    start_t = time.time()

    parser = build_parser()
    opt = parser.parse_args()

    # These switches are forced on here, so the matching command-line flags have no effect.
    opt.fp16 = True
    opt.preload = True
    opt.contract = True
    opt.adaptive_num_rays = True
    opt.enable_cam_center = True
    opt.enable_cam_near_far = True
    # assert opt.alpha_thres == 0.005
    opt.random_image_batch = True

    if opt.contract:
        # mark untrained is not correct in contraction mode...
        opt.mark_untrained = False

    # Select the dataset implementation; every variant is bound to the name NeRFDataset.
    if opt.data_format == 'colmap':
        from nerf.colmap_provider import ColmapDataset as NeRFDataset
    elif opt.data_format == 'dtu':
        from nerf.dtu_provider import NeRFDataset
    else:  # nerf
        from nerf.provider import NeRFDataset

    # TODO: change contract by opt.data_format
    if opt.data_format == 'nerf':
        # Overrides applied only for the 'nerf' data format. Note that mark_untrained was
        # already cleared above and is not restored when contract is switched off here.
        # opt.bound = 1
        opt.min_near = 0.05
        # opt.diffuse_step = 0
        # opt.cascade_list = [0.5, 1.0, 2.0]
        # opt.cascade_list = [0.5, 1.0]
        opt.contract = False
        # opt.alpha_thres = 0.001
        opt.decimate_target = [100000, 100000, 100000]  # ints, consistent with type=int of the option
        opt.random_image_batch = True
        opt.clean_min_f = 16
        opt.clean_min_d = 10
        opt.visibility_mask_dilation = 5
    # TODO end

    seed_everything(opt.seed)

    # if opt.backbone == 'merf_new':
    #     from nerf.network_vosh import NeRFNetwork
    # else:
    #     raise NotImplementedError
    from nerf.network_vosh import NeRFNetwork

    # Scene-specific override, keyed on the substring 'stump' in the data path.
    if 'stump' in opt.path:
        opt.diffuse_step = 5000

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = NeRFNetwork(opt)

    # Un-reduced losses (reduction='none'). Note that 'L1' selects SmoothL1Loss.
    if opt.criterion == 'L1':
        criterion = torch.nn.SmoothL1Loss(reduction='none')
    elif opt.criterion == 'MSE':
        criterion = torch.nn.MSELoss(reduction='none')
    else:
        raise NotImplementedError

    if opt.test:
        trainer = Trainer('vol', opt, model, device=device, workspace=opt.workspace, criterion=criterion,
                          fp16=opt.fp16, use_checkpoint=opt.ckpt)

        if not opt.test_no_video:
            test_loader = NeRFDataset(opt, device=device, type=opt.test_split).dataloader()

            if test_loader.has_gt:
                trainer.metrics = [PSNRMeter(), SSIMMeter(), LPIPSMeter(device=device)]  # set up metrics
                trainer.evaluate(test_loader, name='test')  # blender has gt, so evaluate it.

            trainer.test(test_loader, write_video=True)  # test and save video
    else:
        def optimizer(model):
            """Optimizer factory handed to the Trainer: Adam over model.get_params(1.0)."""
            return torch.optim.Adam(model.get_params(1.0), eps=1e-15)

        train_loader = NeRFDataset(opt, device=device, type=opt.train_split).dataloader()

        # Convert the iteration budget into whole epochs (rounded up), then derive how many
        # epochs lie between two checkpoints / validations; both intervals are at least 1.
        max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
        save_interval = max(1, max_epoch // max(1, opt.save_cnt))
        eval_interval = max(1, max_epoch // max(1, opt.eval_cnt))
        print(f'[INFO] max_epoch {max_epoch}, eval every {eval_interval}.')

        # colmap can estimate a more compact AABB
        # if not opt.contract and opt.data_format == 'colmap':
        #     model.update_aabb(train_loader._data.pts_aabb)

        # scheduler = lambda optimizer: torch.optim.lr_scheduler.LambdaLR(optimizer,
        #                                                                 lambda iter: 0.1 ** (iter / opt.iters))

        def lr_factor(step):
            """Per-step LambdaLR factor, delegated to nerf.math.learning_rate_decay."""
            return math.learning_rate_decay(step + 1, opt.lr_init, opt.lr_final, opt.iters,
                                            opt.lr_delay_steps, opt.lr_delay_mult)

        def scheduler(optimizer):
            """Scheduler factory handed to the Trainer."""
            return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_factor)

        trainer = Trainer('vol', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer,
                          criterion=criterion, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler,
                          scheduler_update_every_step=True, use_checkpoint=opt.ckpt, eval_interval=eval_interval,
                          save_interval=save_interval)

        if opt.gui:
            # The GUI is disabled in this script, so --gui currently skips training and evaluation.
            # gui = NeRFGUI(opt, trainer, train_loader)
            # gui.render()
            pass
        else:
            valid_loader = NeRFDataset(opt, device=device, type='val').dataloader()

            trainer.metrics = [PSNRMeter(), ]
            trainer.train(train_loader, valid_loader, max_epoch)

            # last validation
            trainer.metrics = [PSNRMeter(), SSIMMeter(), LPIPSMeter(device=device)]
            trainer.evaluate(valid_loader)

            # also test
            test_loader = NeRFDataset(opt, device=device, type=opt.test_split).dataloader()

            if test_loader.has_gt:
                trainer.evaluate(test_loader, name='test')  # blender has gt, so evaluate it.

            # trainer.test(test_loader, write_video=True)  # test and save video

    # NOTE(review): the original indentation of the section below was lost. It is placed at script
    # level (runs after both the test and the training branch) because it only uses names that are
    # bound in both branches -- confirm against the upstream file.
    all_loader = NeRFDataset(opt, device=device, type='train_all')
    all_loader.training = False  # load full image from train split
    all_loader = all_loader.dataloader()

    # Reuse the occupancy grid cached in the workspace if present; otherwise compute and cache it.
    occ_grid_list = sorted(glob.glob(f'{opt.workspace}/merf_occ_grid_vol.pt'))
    if occ_grid_list:
        print(f'[INFO] Load occ_grid from {occ_grid_list[-1]}')
        occ_grid = torch.load(occ_grid_list[-1])
        print(f'[INFO] Occupancy rate: '
              f'{occ_grid.sum().item() / occ_grid.numel() * 100.:.2f}% '
              f'with resolution {occ_grid.shape[0]}')
    else:
        occ_grid = trainer.cal_occ_grid(all_loader)
        torch.save(occ_grid, os.path.join(opt.workspace, 'merf_occ_grid_vol.pt'))

    # trainer.save_baking(loader=all_loader, occupancy_grid=occ_grid)
    trainer.voxel_to_mesh(occ_grid.cuda(), resolution=opt.grid_resolution, loader=all_loader)

    # Same load-or-compute pattern for the voxel error grid. The glob is sorted and the last
    # match is loaded; a freshly computed grid is converted to half precision before saving.
    voxel_error_grid_list = sorted(glob.glob(f'{opt.workspace}/voxel_error_grid*.pt'))
    if voxel_error_grid_list:
        print(f'[INFO] Load voxel_error_grid from {voxel_error_grid_list[-1]}')
        voxel_error_grid = torch.load(voxel_error_grid_list[-1]).to(device)
    else:
        voxel_error_grid = trainer.run_voxel_error_epoch(all_loader, resolution=512).half()
        torch.save(voxel_error_grid, os.path.join(opt.workspace, 'voxel_error_grid.pt'))

    end_t = time.time()
    print(f"[INFO] volume stage takes {(end_t - start_t) / 60:.6f} minutes.")
