Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

swh:1:snp:eee76444da62e238a10272cb71070ca8823b3f3d
  • Code
  • Branches (1)
  • Releases (0)
    • Branches
    • Releases
    • HEAD
    • refs/heads/main
    No releases to show
  • 6250ce0
  • /
  • nerf
  • /
  • utils.py
Raw File Download

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • content
  • directory
  • revision
  • snapshot
content badge
swh:1:cnt:843b6a90682f5e70aeeca07dcf7efdc9bb68d5b7
directory badge
swh:1:dir:2aec7f959a197d6347e16cd0cda954702fe7482a
revision badge
swh:1:rev:da207d03e7994d9c5a097126dcd509abedc26bc0
snapshot badge
swh:1:snp:eee76444da62e238a10272cb71070ca8823b3f3d

This interface makes it possible to generate software citations, provided that the root directory of the browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • content
  • directory
  • revision
  • snapshot
(requires biblatex-software package)
Generating citation ...
(requires biblatex-software package)
Generating citation ...
(requires biblatex-software package)
Generating citation ...
(requires biblatex-software package)
Generating citation ...
Tip revision: da207d03e7994d9c5a097126dcd509abedc26bc0 authored by zachzhang07 on 21 November 2024, 08:07:14 UTC
Update readme.md
Tip revision: da207d0
utils.py
import os
import glob
import tqdm
import math
import imageio
import random
import warnings
import tensorboardX
import mcubes
import shutil

import numpy as np
import json

import time

import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torchmetrics.functional import structural_similarity_index_measure

import trimesh
import nvdiffrast.torch as dr
from rich.console import Console
from torch_ema import ExponentialMovingAverage

from packaging import version as pver
import lpips
from meshutils import *
from .renderer import contract, uncontract
from nerf import quantize

from skimage import measure


class CharbonnierLoss(nn.Module):
    """Element-wise Charbonnier penalty: sqrt((pred - gt)^2 + epsilon^2).

    No reduction is applied; the result has the broadcast shape of the inputs.
    """

    def __init__(self, epsilon=1e-3):
        super().__init__()
        # Keep epsilon squared, since that is the term added under the root.
        self.epsilon2 = epsilon * epsilon

    def forward(self, pred_rgb, gt_rgb):
        residual = pred_rgb - gt_rgb
        return (residual.pow(2) + self.epsilon2).sqrt()


def laplacian_cot(verts, faces):
    """
    Compute the cotangent laplacian
    Inspired by https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/loss/mesh_laplacian_smoothing.html
    Parameters
    ----------
    verts : torch.Tensor
        Vertex positions, indexed by the entries of ``faces``.
    faces : torch.Tensor
        array of triangle faces (integer vertex indices, three per face).
    Returns
    -------
    torch.Tensor
        Sparse (V, V) matrix ``D - W`` where ``W`` holds the symmetric
        cotangent weights and ``D`` is the diagonal of ``W``'s column sums.
        It lives on the same device as ``verts``.
    """

    # V = sum(V_n), F = sum(F_n)
    V, F = verts.shape[0], faces.shape[0]

    face_verts = verts[faces]
    v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

    # Side lengths of each triangle, of shape (sum(F_n),)
    # A is the side opposite v0, B is opposite v1, and C is opposite v2
    A = (v1 - v2).norm(dim=1)
    B = (v0 - v2).norm(dim=1)
    C = (v0 - v1).norm(dim=1)

    # Area of each triangle (with Heron's formula); shape is (sum(F_n),)
    s = 0.5 * (A + B + C)
    # note that the area can be negative (close to 0) causing nans after sqrt()
    # we clip it to a small positive value
    area = (s * (s - A) * (s - B) * (s - C)).clamp(min=1e-12).sqrt()

    # Compute cotangents of angles, of shape (sum(F_n), 3)
    A2, B2, C2 = A * A, B * B, C * C
    cota = (B2 + C2 - A2) / area
    cotb = (A2 + C2 - B2) / area
    cotc = (A2 + B2 - C2) / area
    cot = torch.stack([cota, cotb, cotc], dim=1) / 4.0

    # Construct a sparse matrix by basically doing:
    # L[v1, v2] = cota
    # L[v2, v0] = cotb
    # L[v0, v1] = cotc
    ii = faces[:, [1, 2, 0]]
    jj = faces[:, [2, 0, 1]]
    idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
    # sparse_coo_tensor (rather than the deprecated torch.sparse.FloatTensor)
    # keeps the dtype and device of the inputs.
    L = torch.sparse_coo_tensor(idx, cot.view(-1), (V, V))

    # Make it symmetric; this means we are also setting
    # L[v2, v1] = cota
    # L[v0, v2] = cotb
    # L[v1, v0] = cotc
    L = L + L.t()

    # Add the diagonal indices; they must live on the same device as the
    # mesh, otherwise the sparse constructor fails (e.g. on CPU-only runs).
    vals = torch.sparse.sum(L, dim=0).to_dense()
    indices = torch.arange(V, device=verts.device)
    idx = torch.stack([indices, indices], dim=0)
    return torch.sparse_coo_tensor(idx, vals, (V, V)) - L


def laplacian_uniform(verts, faces):
    """
    Compute the uniform laplacian
    Parameters
    ----------
    verts : torch.Tensor
        Vertex positions (only the vertex count and device are used).
    faces : torch.Tensor
        array of triangle faces.
    Returns
    -------
    torch.Tensor
        Coalesced sparse (V, V) matrix with -1 for every pair of adjacent
        vertices and the number of neighbours on the diagonal.
    """
    num_verts = verts.shape[0]

    # Directed edges of every triangle, then both orientations, de-duplicated
    src = faces[:, [1, 2, 0]].reshape(-1)
    dst = faces[:, [2, 0, 1]].reshape(-1)
    rows = torch.cat([src, dst])
    cols = torch.cat([dst, src])
    edges = torch.stack([rows, cols], dim=0).unique(dim=1)
    ones = torch.ones(edges.shape[1], device=verts.device, dtype=torch.float)

    # One (v, v) entry per unique edge leaving v; coalesce() sums these
    # duplicates into the vertex degree.
    start = edges[0]
    diagonal = torch.stack((start, start), dim=0)

    indices = torch.cat((edges, diagonal), dim=1)
    weights = torch.cat((-ones, ones))
    return torch.sparse_coo_tensor(indices, weights, (num_verts, num_verts)).coalesce()


def laplacian_smooth_loss(verts, faces, cotan=False):
    """Mean per-vertex norm of the Laplacian applied to ``verts``.

    The Laplacian (and, for ``cotan=True``, its row normaliser) is built under
    ``torch.no_grad()``; only the final sparse-dense product involving
    ``verts`` is differentiated.

    Parameters
    ----------
    verts : torch.Tensor
        Vertex positions.
    faces : torch.Tensor
        Triangle faces; cast to ``long`` before use.
    cotan : bool
        Use ``laplacian_cot`` when True, ``laplacian_uniform`` otherwise.
    """
    face_idx = faces.long()

    if cotan:
        with torch.no_grad():
            lap = laplacian_cot(verts, face_idx)
            # NOTE(review): laplacian_cot returns diag(column sums) - W, so
            # these row sums look like they are ~0 by construction; confirm
            # this normalisation is the intended one.
            norm_w = torch.sparse.sum(lap, dim=1).to_dense().view(-1, 1)
            positive = norm_w > 0
            norm_w[positive] = 1.0 / norm_w[positive]
        residual = lap.mm(verts) * norm_w - verts
    else:
        with torch.no_grad():
            lap = laplacian_uniform(verts, face_idx)
        residual = lap.mm(verts)

    return residual.norm(dim=1).mean()


def custom_meshgrid(*args):
    """Call ``torch.meshgrid`` with 'ij' indexing on every torch version.

    The ``indexing`` keyword is only passed for torch >= 1.10; for older
    versions ``torch.meshgrid`` is called without it.
    ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    """
    is_legacy_torch = pver.parse(torch.__version__) < pver.parse('1.10')
    extra = {} if is_legacy_torch else {'indexing': 'ij'}
    return torch.meshgrid(*args, **extra)


def plot_pointcloud(pc, color=None):
    """Show a point cloud in a trimesh viewer together with reference geometry.

    pc: [N, 3]
    color: [N, 3/4]

    The scene also contains a coordinate axis (length 4) and the grey outline
    of a 2 x 2 x 2 box.
    """
    print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0))

    cloud = trimesh.PointCloud(pc, color)

    # reference geometry: coordinate axes plus a box outline
    frame = trimesh.creation.axis(axis_length=4)
    outline = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
    grey = [[128, 128, 128]] * len(outline.entities)
    outline.colors = np.array(grey)

    scene = trimesh.Scene([cloud, frame, outline])
    scene.show()


def create_dodecahedron_cameras(radius=1, center=np.array([0, 0, 0])):
    """Build 20 camera-to-world poses placed on the vertices of a dodecahedron.

    Each vertex is rescaled to lie at distance ``radius`` from ``center``.
    The pose's third rotation column is the unit vector from ``center`` to the
    camera position, and [0, 1, 0] is used as the provisional up direction.

    Returns:
        float32 array of shape [20, 4, 4].
    """
    vertices = np.array([
        [-0.57735, -0.57735, 0.57735],
        [0.934172, 0.356822, 0],
        [0.934172, -0.356822, 0],
        [-0.934172, 0.356822, 0],
        [-0.934172, -0.356822, 0],
        [0, 0.934172, 0.356822],
        [0, 0.934172, -0.356822],
        [0.356822, 0, -0.934172],
        [-0.356822, 0, -0.934172],
        [0, -0.934172, -0.356822],
        [0, -0.934172, 0.356822],
        [0.356822, 0, 0.934172],
        [-0.356822, 0, 0.934172],
        [0.57735, 0.57735, -0.57735],
        [0.57735, 0.57735, 0.57735],
        [-0.57735, 0.57735, -0.57735],
        [-0.57735, 0.57735, 0.57735],
        [0.57735, -0.57735, -0.57735],
        [0.57735, -0.57735, 0.57735],
        [-0.57735, -0.57735, -0.57735],
    ])

    # push every vertex onto the sphere of the requested radius around center
    norms = np.linalg.norm(vertices, axis=1).reshape((-1, 1))
    vertices = vertices / norms * radius + center
    num_views = vertices.shape[0]

    def unit(v):
        return v / (np.linalg.norm(v, axis=-1, keepdims=True) + 1e-8)

    # look-at construction; forward points from the center to the camera,
    # i.e. it is the inverse of the viewing direction
    forward = unit(vertices - center)
    world_up = np.tile(np.array([0, 1, 0], dtype=np.float32), (num_views, 1))
    right = unit(np.cross(world_up, forward, axis=-1))
    # re-derive up so the three axes are mutually orthogonal
    up = unit(np.cross(forward, right, axis=-1))

    # assemble c2w: rotation columns are (right, up, forward), translation is the position
    poses = np.tile(np.eye(4, dtype=np.float32), (num_views, 1, 1))
    poses[:, :3, 0] = right
    poses[:, :3, 1] = up
    poses[:, :3, 2] = forward
    poses[:, :3, 3] = vertices

    return poses


@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, coords=None):
    ''' get rays
    Args:
        poses: [N/1, 4, 4], cam2world
        intrinsics: [N/1, 4] tensor or [4] ndarray, ordered (fx, fy, cx, cy)
        H, W: int, image height and width in pixels
        N: int, number of rays to sample; if N <= 0 every pixel is used
        patch_size: int, when > 1 (and coords is None) rays are sampled as
            N // patch_size ** 2 random patch_size x patch_size patches
        coords: optional [N, 2] integer tensor of (row, col) pixel coordinates
            used instead of random sampling (only read when N > 0)
    Returns:
        dict with
            rays_o, rays_d: [N, 3] (rays_d is deliberately not normalized)
            i, j: [N] integer pixel column / row, only present when N > 0
    '''

    device = poses.device

    if isinstance(intrinsics, np.ndarray):
        fx, fy, cx, cy = intrinsics
    else:
        fx, fy, cx, cy = intrinsics[:, 0], intrinsics[:, 1], intrinsics[:, 2], intrinsics[:, 3]

    # Pixel-centre coordinates flattened row-major, so flat index = row * W + col;
    # i is the column (x) and j the row (y).
    i, j = custom_meshgrid(torch.linspace(0, W - 1, W, device=device),
                           torch.linspace(0, H - 1, H, device=device))  # float
    i = i.t().contiguous().view(-1) + 0.5
    j = j.t().contiguous().view(-1) + 0.5

    results = {}

    if N > 0:

        if coords is not None:
            inds = coords[:, 0] * W + coords[:, 1]

        elif patch_size > 1:

            # random sample left-top corners.
            # torch.randint's upper bound is exclusive, so "+ 1" is needed for
            # patches to be able to touch the last row / column (and for
            # H == patch_size or W == patch_size not to raise).
            num_patch = N // (patch_size ** 2)
            inds_x = torch.randint(0, H - patch_size + 1, size=[num_patch], device=device)
            inds_y = torch.randint(0, W - patch_size + 1, size=[num_patch], device=device)
            inds = torch.stack([inds_x, inds_y], dim=-1)  # [np, 2]

            # create meshgrid for each patch
            pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device))
            offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1)  # [p^2, 2]

            inds = inds.unsqueeze(1) + offsets.unsqueeze(0)  # [np, p^2, 2]
            inds = inds.view(-1, 2)  # [N, 2]
            inds = inds[:, 0] * W + inds[:, 1]  # [N], flatten

        else:  # random sampling
            inds = torch.randint(0, H * W, size=[N], device=device)  # may duplicate

        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)

        # truncating the +0.5 pixel centres recovers the integer pixel indices
        results['i'] = i.long()
        results['j'] = j.long()

    zs = -torch.ones_like(i)  # z is flipped
    xs = (i - cx) / fx
    ys = -(j - cy) / fy  # y is flipped
    directions = torch.stack((xs, ys, zs), dim=-1)  # [N, 3]
    # do not normalize to get actual depth, ref: https://github.com/dunbar12138/DSNeRF/issues/29
    # rotate camera-space directions into world space: d @ R^T
    rays_d = (directions.unsqueeze(1) @ poses[:, :3, :3].transpose(-1, -2)).squeeze(
        1)  # [N, 1, 3] @ [N, 3, 3] --> [N, 1, 3]

    rays_o = poses[:, :3, 3].expand_as(rays_d)  # [N, 3]

    results['rays_o'] = rays_o
    results['rays_d'] = rays_d

    return results


def visualize_rays(rays_o, rays_d):
    """Show every 10th ray as a line segment in a trimesh viewer.

    Each drawn segment runs from ``rays_o[k]`` to ``rays_o[k] + 3 * rays_d[k]``.
    The scene also contains a coordinate axis and the grey outline of a
    2 x 2 x 2 box.
    """
    frame = trimesh.creation.axis(axis_length=4)
    outline = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
    outline.colors = np.array([[128, 128, 128]] * len(outline.entities))

    segments = [
        trimesh.load_path(np.array([[rays_o[k], rays_o[k] + rays_d[k] * 3]]))
        for k in range(0, rays_o.shape[0], 10)
    ]

    trimesh.Scene([frame, outline] + segments).show()


def seed_everything(seed):
    """Seed Python's ``random``, NumPy and torch (CPU and CUDA) with ``seed``.

    Also exports the seed as the ``PYTHONHASHSEED`` environment variable.
    cuDNN determinism / benchmark flags are left untouched.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)


# def torch_vis_2d(x, renormalize=False):
#     # x: [3, H, W], [H, W, 3] or [1, H, W] or [H, W]
#     import matplotlib.pyplot as plt
#     import numpy as np
#     import torch
#
#     if isinstance(x, torch.Tensor):
#         if len(x.shape) == 3 and x.shape[0] == 3:
#             x = x.permute(1, 2, 0).squeeze()
#         x = x.detach().cpu().numpy()
#
#     print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}')
#
#     x = x.astype(np.float32)
#
#     # renormalize
#     if renormalize:
#         x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)
#
#     plt.imshow(x)
#     plt.show()


class PSNRMeter:
    """Accumulates PSNR values and reports their running mean.

    ``V`` is the sum of the PSNR values seen so far and ``N`` their count.
    """

    def __init__(self):
        self.V, self.N = 0, 0

    def clear(self):
        """Reset the accumulated sum and count."""
        self.V, self.N = 0, 0

    def prepare_inputs(self, *inputs):
        """Return the inputs as a list, with torch tensors converted to numpy arrays."""
        return [x.detach().cpu().numpy() if torch.is_tensor(x) else x for x in inputs]

    def update(self, preds, truths):
        """Add the PSNR of one prediction / ground-truth pair and return it."""
        # [B, N, 3] or [B, H, W, 3], range[0, 1]
        preds, truths = self.prepare_inputs(preds, truths)

        # with a max pixel value of 1 the PSNR reduces to -10 * log10(MSE)
        mse = np.mean((preds - truths) ** 2)
        psnr = -10 * np.log10(mse)

        self.V += psnr
        self.N += 1
        return psnr

    def measure(self):
        """Mean PSNR over all updates since the last clear()."""
        return self.V / self.N

    def write(self, writer, global_step, prefix=""):
        """Log the current mean through ``writer.add_scalar``."""
        tag = os.path.join(prefix, "PSNR")
        writer.add_scalar(tag, self.measure(), global_step)

    def report(self):
        """Human-readable summary of the current mean."""
        value = self.measure()
        return f'PSNR = {value:.6f}'


class LPIPSMeter:
    """Accumulates LPIPS distances and reports their running mean.

    ``V`` is the sum of the values seen so far and ``N`` their count.
    """

    def __init__(self, net='vgg', device=None):
        self.V = 0
        self.N = 0
        self.net = net

        # fall back to CUDA when available, CPU otherwise
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.fn = lpips.LPIPS(net=net).eval().to(self.device)

    def clear(self):
        """Reset the accumulated sum and count."""
        self.V, self.N = 0, 0

    def prepare_inputs(self, *inputs):
        """Convert channel-last images to [B, 3, H, W] tensors on ``self.device``."""
        prepared = []
        for image in inputs:
            batched = image.unsqueeze(0) if image.ndim == 3 else image
            channels_first = batched.permute(0, 3, 1, 2).contiguous()  # [B, 3, H, W]
            prepared.append(channels_first.to(self.device))
        return prepared

    def update(self, preds, truths):
        """Add the LPIPS distance of one pair and return it as a float."""
        # [H, W, 3] --> [B, 3, H, W], range in [0, 1]
        preds, truths = self.prepare_inputs(preds, truths)
        # normalize=True: [0, 1] to [-1, 1]
        distance = self.fn(truths, preds, normalize=True).item()
        self.V += distance
        self.N += 1
        return distance

    def measure(self):
        """Mean LPIPS over all updates since the last clear()."""
        return self.V / self.N

    def write(self, writer, global_step, prefix=""):
        """Log the current mean through ``writer.add_scalar``."""
        tag = os.path.join(prefix, f"LPIPS ({self.net})")
        writer.add_scalar(tag, self.measure(), global_step)

    def report(self):
        """Human-readable summary of the current mean."""
        value = self.measure()
        return f'LPIPS ({self.net}) = {value:.6f}'


class SSIMMeter:
    """Accumulates SSIM values and reports their running mean.

    ``V`` is the sum of the values seen so far and ``N`` their count.
    """

    def __init__(self, device=None):
        self.V = 0
        self.N = 0

        # fall back to CUDA when available, CPU otherwise
        self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def clear(self):
        """Reset the accumulated sum and count."""
        self.V = 0
        self.N = 0

    def prepare_inputs(self, *inputs):
        """Convert channel-last images to [B, 3, H, W] tensors on ``self.device``."""
        outputs = []
        for inp in inputs:
            if len(inp.shape) == 3:
                inp = inp.unsqueeze(0)
            inp = inp.permute(0, 3, 1, 2).contiguous()  # [B, 3, H, W]
            inp = inp.to(self.device)
            outputs.append(inp)
        return outputs

    def update(self, preds, truths):
        """Add the SSIM of one prediction / ground-truth pair and return it.

        The value is stored and returned as a Python float, matching
        LPIPSMeter, so the meter does not keep device tensors alive.
        """
        preds, truths = self.prepare_inputs(preds, truths)  # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1]

        ssim = structural_similarity_index_measure(preds, truths).item()

        self.V += ssim
        self.N += 1

        return ssim

    def measure(self):
        """Mean SSIM over all updates since the last clear()."""
        return self.V / self.N

    def write(self, writer, global_step, prefix=""):
        """Log the current mean through ``writer.add_scalar``."""
        writer.add_scalar(os.path.join(prefix, "SSIM"), self.measure(), global_step)

    def report(self):
        """Human-readable summary of the current mean."""
        return f'SSIM = {self.measure():.6f}'


def create_point_pyramid(points):
    # points: Float[Tensor, "3 height width depth"]) -> List[Float[Tensor, "3 height width depth"]]
    """
    Build a coarse-to-fine pyramid of a point grid.

    The grid is average-pooled three times with a 2x2x2 window and stride 2.

    Args:
        points: A torch tensor containing 3D points.

    Returns:
        A list of four tensors ordered from the coarsest level to the
        original input.
    """
    levels = [points]
    for _ in range(3):
        # add a batch axis for the pooling call, then drop it again
        coarser = F.avg_pool3d(levels[-1][None], 2, stride=2)[0]
        levels.append(coarser)
    levels.reverse()
    return levels


class Trainer(object):
    def __init__(self,
                 name,  # name of this experiment
                 opt,  # extra conf
                 model,  # network
                 criterion=None,  # loss function, if None, assume inline implementation in train_step
                 optimizer=None,  # optimizer
                 ema_decay=None,  # if use EMA, set the decay
                 lr_scheduler=None,  # scheduler
                 metrics=None,
                 # metrics for evaluation, if None/empty, use val_loss to measure performance, else use the first metric.
                 local_rank=0,  # which GPU am I
                 world_size=1,  # total num of GPUs
                 device=None,  # device to use, usually setting to None is OK. (auto choose device)
                 mute=False,  # whether to mute all print
                 fp16=False,  # amp optimize level
                 eval_interval=1,  # eval once every $ epoch
                 save_interval=1,  # save once every $ epoch (independently from eval)
                 max_keep_ckpt=1,  # max num of saved ckpts in disk
                 workspace='workspace',  # workspace to save logs & ckpts
                 best_mode='min',  # the smaller/larger result, the better
                 use_loss_as_metric=True,  # use loss as the first metric
                 report_metric_at_train=False,  # also report metrics at training
                 use_checkpoint="latest",  # which ckpt to use at init time
                 use_tensorboardX=True,  # whether to use tensorboard for logging
                 scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step
                 vosh_interval=None
                 ):
        """Set up model, optimizer, scheduler, logging and checkpoints.

        ``optimizer`` and ``lr_scheduler`` are factories: they are called with
        the model and the optimizer respectively. Besides storing its
        arguments, construction has these side effects:

        * creates ``workspace`` (plus ``checkpoints``, ``backup`` and, unless
          ``opt.render == 'grid'``, ``mesh`` sub-directories) and opens the
          log file in append mode;
        * unless ``opt.render == 'grid'`` and when the workspace has no
          checkpoint for ``name``, copies one from
          ``opt.vol_path/checkpoints`` (asserting that one exists);
        * for ``opt.render == 'mixed'`` with ``0 < opt.mesh_select < 1``,
          loads the latest ``error_grid*.pt`` from ``opt.vol_path`` and raises
          ``Exception`` when none is found;
        * copies the ``.py`` files of the current directory and ``nerf/``
          into the backup directory;
        * loads a checkpoint according to ``use_checkpoint`` ("scratch",
          "latest", "latest_model", "best" or an explicit path).
        """

        self.opt = opt
        self.writer = None
        self.name = name
        self.mute = mute
        # a fresh list per instance: a mutable default would be shared by all trainers
        self.metrics = [] if metrics is None else metrics
        self.local_rank = local_rank
        self.world_size = world_size
        self.workspace = workspace
        self.ema_decay = ema_decay
        self.fp16 = fp16
        self.best_mode = best_mode
        self.use_loss_as_metric = use_loss_as_metric
        self.report_metric_at_train = report_metric_at_train
        self.max_keep_ckpt = max_keep_ckpt
        self.eval_interval = eval_interval
        self.save_interval = save_interval
        self.vosh_interval = vosh_interval
        self.use_checkpoint = use_checkpoint
        self.use_tensorboardX = use_tensorboardX
        self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
        self.scheduler_update_every_step = scheduler_update_every_step
        self.device = device if device is not None else torch.device(
            f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
        self.console = Console()
        self.glctx = None

        # try out torch 2.0
        if torch.__version__[0] == '2':
            model = torch.compile(model)

        model.to(self.device)
        if self.world_size > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
        self.model = model

        if isinstance(criterion, nn.Module):
            criterion.to(self.device)
        self.criterion = criterion

        # optimizer / lr_scheduler arguments are factories, kept so they can be re-invoked later
        self.optimizer_fn = optimizer
        if optimizer is None:
            # naive adam
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4)
        else:
            self.optimizer = self.optimizer_fn(self.model)

        self.lr_scheduler_fn = lr_scheduler
        if lr_scheduler is None:
            # fake scheduler
            self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1)
        else:
            self.lr_scheduler = self.lr_scheduler_fn(self.optimizer)

        if ema_decay is not None:
            self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay)
        else:
            self.ema = None

        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

        # variable init
        self.epoch = 0
        self.global_step = 0
        self.local_step = 0
        self.stats = {
            "loss": [],
            "valid_loss": [],
            "results": [],  # metrics[0], or valid_loss
            "checkpoints": [],  # record path of saved ckpt, to automatically remove old ckpt
            "best_result": None,
        }
        self.measures = []

        # auto fix
        if len(self.metrics) == 0 or self.use_loss_as_metric:
            self.best_mode = 'min'

        # workspace prepare
        self.log_ptr = None
        os.makedirs(self.workspace, exist_ok=True)
        self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
        self.log_ptr = open(self.log_path, "a+")

        self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
        self.best_path = f"{self.ckpt_path}/{self.name}.pth"
        os.makedirs(self.ckpt_path, exist_ok=True)

        self.voxel_error_grid_path = None
        error_grid = None

        if self.opt.render != 'grid':
            self.continue_vosh = False
            self.occ_grid_path = None

            self.mesh_path = os.path.join(self.workspace, 'mesh')
            os.makedirs(self.mesh_path, exist_ok=True)

            checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}*.pth'))

            if not checkpoint_list:
                # nothing in this workspace yet: seed it from the opt.vol_path run,
                # preferring mesh*.pth unless use_vol_pth asks for vol*.pth
                checkpoint_list = sorted(glob.glob(f'{opt.vol_path}/checkpoints/mesh*.pth'))
                if not checkpoint_list or ('use_vol_pth' in self.opt and self.opt.use_vol_pth):
                    checkpoint_list = sorted(glob.glob(f'{opt.vol_path}/checkpoints/vol*.pth'))
                assert len(checkpoint_list), f'No checkpoints found in {opt.vol_path}/checkpoints !'
                checkpoint = checkpoint_list[-1]
                shutil.copy(checkpoint, self.ckpt_path)
                self.log(f"[INFO] Copy checkpoint from {checkpoint}")
                if 'mesh_encoder' in self.opt and self.opt.mesh_encoder:
                    mesh_ckpt = sorted(glob.glob(f'{opt.vol_path}/checkpoints/mesh*.pth'))[-1]
                    shutil.copy(mesh_ckpt, self.ckpt_path)
                    self.log(f"[INFO] Copy mesh checkpoint from {mesh_ckpt}")

            if self.opt.render == 'mesh':
                # optional: only the path is remembered here, nothing is loaded yet
                voxel_error_grid_list = sorted(glob.glob(f'{opt.vol_path}/voxel_error_grid*.pt'))
                if len(voxel_error_grid_list):
                    self.voxel_error_grid_path = voxel_error_grid_list[-1]
            elif self.opt.render == 'mixed' and 0 < self.opt.mesh_select < 1.0:
                error_grid_list = sorted(glob.glob(f'{opt.vol_path}/error_grid*.pt'))
                if len(error_grid_list):
                    error_grid = torch.load(error_grid_list[-1]).to(self.device)
                    error_grid.requires_grad = False
                    self.log(f'[INFO] Load error grid {error_grid_list[-1]}!')
                    self.log(f'[INFO] error>0: {torch.sum(error_grid > 0)}, error<0: {torch.sum(error_grid < 0)}')
                    verts, faces = self.model.remove_faces_in_selected_voxels(error_grid,
                                                                              mesh_select=self.opt.mesh_select,
                                                                              keep_center=opt.keep_center)
                    self.log(f'[INFO] update mesh with {len(verts)} verts, {len(faces)} faces')
                else:
                    raise Exception('No error grid found for E_convert!')

        if error_grid is not None:
            self.model.set_error_grid(error_grid)

        # snapshot the python sources of the current directory and nerf/ next to the results
        backup_path = os.path.join(self.workspace, 'backup')
        os.makedirs(backup_path, exist_ok=True)
        os.makedirs(os.path.join(backup_path, 'nerf'), exist_ok=True)

        file_list = os.listdir('.')
        for filename in os.listdir('nerf/'):
            file_list.append(os.path.join('nerf', filename))
        for filename in file_list:
            # match the extension, not just a trailing "py" (e.g. an entry named "copy")
            if filename.endswith('.py'):
                shutil.copy(filename, os.path.join(backup_path, filename))

        self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | '
                 f'{"fp16" if self.fp16 else "fp32"} | {self.workspace}')
        self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}')

        self.log(opt)
        self.log(self.model)

        if self.use_checkpoint == "scratch":
            self.log("[INFO] Training from scratch ...")
        elif self.use_checkpoint == "latest":
            self.log("[INFO] Loading latest checkpoint ...")
            self.load_checkpoint()
        elif self.use_checkpoint == "latest_model":
            self.log("[INFO] Loading latest checkpoint (model only)...")
            self.load_checkpoint(model_only=True)
        elif self.use_checkpoint == "best":
            if os.path.exists(self.best_path):
                self.log("[INFO] Loading best checkpoint ...")
                self.load_checkpoint(self.best_path)
            else:
                self.log(f"[INFO] {self.best_path} not found, loading latest ...")
                self.load_checkpoint()
        else:  # path to ckpt
            self.log(f"[INFO] Loading {self.use_checkpoint} ...")
            self.load_checkpoint(self.use_checkpoint)

    def __del__(self):
        if self.log_ptr:
            self.log_ptr.close()

    def log(self, *args, **kwargs):
        if self.local_rank == 0:
            if not self.mute:
                # print(*args)
                self.console.print(*args, **kwargs)
            if self.log_ptr:
                print(*args, file=self.log_ptr)
                self.log_ptr.flush()  # write immediately to file

    ### ------------------------------

    def train_step(self, data):

        rays_o = data['rays_o']  # [N, 3]
        rays_d = data['rays_d']  # [N, 3]
        index = data['index']  # [1/N]
        cam_near_far = data['cam_near_far'] if 'cam_near_far' in data else None  # [1/N, 2] or None
        images = data['images']  # [N, 3/4]

        N, C = images.shape

        if self.opt.background == 'random':
            bg_color = torch.rand(N, 3, device=self.device)  # [N, 3], pixel-wise random.
        else:  # white / last_sample
            bg_color = 1

        if C == 4:
            gt_mask = images[..., 3:]
            gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:])
        else:
            gt_mask = None
            gt_rgb = images

        shading = 'diffuse' if self.global_step < self.opt.diffuse_step else 'full'
        update_proposal = self.global_step <= 3000 or self.global_step % 5 == 0

        if self.opt.render == 'mesh' or self.opt.alpha_thres == 0:
            alpha_threshold = None
        else:
            alpha_threshold = self.model.alpha_threshold(self.global_step)

        if self.opt.render != 'grid':
            outputs = self.model.render(rays_o, rays_d, index=index, bg_color=bg_color, perturb=True,
                                        cam_near_far=cam_near_far, shading=shading, update_proposal=update_proposal,
                                        mvp=data['mvp'], rays_d_all=data['rays_d_all'],
                                        rays_i=data['rays_i'], rays_j=data['rays_j'], H=data['H'], W=data['W'],
                                        alpha_threshold=alpha_threshold)

        else:
            outputs = self.model.render(rays_o, rays_d, index=index, bg_color=bg_color, perturb=True,
                                        cam_near_far=cam_near_far, shading=shading, update_proposal=update_proposal,
                                        alpha_threshold=alpha_threshold)

        # loss
        pred_rgb = outputs['image']
        mask_mesh = None
        # if self.opt.render == 'mesh':
        #     mask_mesh = outputs['mask_mesh']
        #     pred_rgb, gt_rgb = pred_rgb[mask_mesh], gt_rgb[mask_mesh]
        #     outputs['specular'] = outputs['specular'][mask_mesh]

        loss = self.criterion(pred_rgb, gt_rgb).mean(-1)  # [N, 3] --> [N]
        loss_dict = {'loss_smoothL1': loss.mean().item()}

        if gt_mask is not None and self.opt.lambda_mask > 0:
            if self.opt.render == 'mixed':
                pred_mask = outputs['weights_vosh']
            else:
                pred_mask = outputs['weights_sum']
            # if self.opt.render == 'mesh':
            #     pred_mask, gt_mask = pred_mask[mask_mesh], gt_mask[mask_mesh]
            assert self.opt.criterion == 'MSE'
            loss = loss + self.opt.lambda_mask * self.criterion(pred_mask, gt_mask.squeeze(1))

        if self.opt.render == 'mesh' and self.opt.refine:
            self.model.update_triangles_errors(loss.detach(), mask_mesh)

        # depth loss
        if 'depth' in data:
            gt_depth = data['depth'].view(-1, 1)
            pred_depth = outputs['depth'].view(-1, 1)

            lambda_depth = self.opt.lambda_depth * min(1.0, self.global_step / 1000)
            mask = gt_depth > 0

            loss_depth = lambda_depth * self.criterion(pred_depth * mask, gt_depth * mask)  # [N]
            loss = loss + loss_depth
            loss_dict['loss_depth'] = loss_depth.mean().item()

        loss = loss.mean()

        # extra loss
        if 'specular' in outputs and self.opt.lambda_specular > 0:
            loss_spec = self.opt.lambda_specular * (outputs['specular'] ** 2).mean()
            loss = loss + loss_spec
            loss_dict['loss_spec'] = loss_spec.item()

        if update_proposal and 'proposal_loss' in outputs and self.opt.lambda_proposal > 0:
            loss_prop = self.opt.lambda_proposal * outputs['proposal_loss']
            loss = loss + loss_prop
            loss_dict['loss_prop'] = loss_prop.item()

        if 'distort_loss' in outputs and self.opt.lambda_distort > 0:
            # progressive to avoid bad init
            lambda_distort = min(1.0, self.global_step / (0.5 * self.opt.iters)) * self.opt.lambda_distort
            loss_dist = lambda_distort * outputs['distort_loss']
            loss = loss + loss_dist
            loss_dict['loss_dist'] = loss_dist.item()

        if 'lambda_entropy' in self.opt and self.opt.lambda_entropy > 0:
            w = outputs['alphas'].clamp(1e-5, 1 - 1e-5)
            entropy = - w * torch.log2(w) - (1 - w) * torch.log2(1 - w)
            loss_entro = self.opt.lambda_entropy * (entropy.mean())
            loss = loss + loss_entro
            loss_dict['loss_entro'] = loss_entro.item()

        if 'sparsity_loss' in outputs and self.opt.lambda_sparsity > 0:
            loss_spar = outputs['sparsity_loss'] * self.opt.lambda_sparsity
            loss = loss + loss_spar
            loss_dict['loss_spar'] = loss_spar.item()

        if self.opt.render == 'mixed':
            if self.opt.lambda_mesh_weight > 0:
                loss_mesh_weight = outputs['mesh_weight_loss'] * self.opt.lambda_mesh_weight
                loss = loss + loss_mesh_weight
                loss_dict['loss_mesh_weight'] = loss_mesh_weight.item()

            if self.opt.lambda_bg_weight > 0:
                loss_bg_weight = outputs['bg_weight_loss'] * self.opt.lambda_bg_weight
                loss = loss + loss_bg_weight
                loss_dict['loss_bg_weight'] = loss_bg_weight.item()

        # if 'lambda_offsets' in self.opt and self.opt.lambda_entropy > 0:
        #     abs_offsets_inner = self.model.vertices_offsets[:self.model.v_cumsum[1]].abs()
        #     abs_offsets_outer = self.model.vertices_offsets[self.model.v_cumsum[1]:].abs()
        #
        #     # inner mesh
        #     loss_offsets = (abs_offsets_inner ** 2).sum(-1).mean()
        #
        #     # outer mesh
        #     if self.opt.bound > 1:
        #         loss_offsets = loss_offsets + 0.1 * (abs_offsets_outer ** 2).sum(-1).mean()
        #
        #     loss = loss + self.opt.lambda_offsets * loss_offsets

        if 'lambda_lap' in self.opt and self.opt.lambda_lap > 0 and self.model.vertices_offsets is not None:
            loss_lap = laplacian_smooth_loss(self.model.vertices + self.model.vertices_offsets, self.model.triangles)
            loss = loss + self.opt.lambda_lap * loss_lap

        # adaptive num_rays
        if 'adaptive_num_rays' in self.opt and self.opt.adaptive_num_rays:
            self.opt.num_rays = int(round((self.opt.num_points / outputs['num_points']) * self.opt.num_rays))

        # if 'lambda_mesh_weight' in self.opt and self.opt.lambda_mesh_weight > 0:
        #     print(f"\nrgb : {loss:.5f}, "
        #           f"distort: {loss_dict['loss_dist']:.5f}, "
        #           f"mesh_weight: {loss_dict['loss_mesh_weight']:.5f}")

        return pred_rgb, gt_rgb, loss, loss_dict

    def post_train_step(self):
        """Apply the in-place total-variation regularizer after the backward pass.

        Does nothing unless ``self.opt.lambda_tv`` is positive. Otherwise the
        gradients held by ``self.optimizer`` are unscaled through
        ``self.scaler`` and the model's ``apply_total_variation`` is called
        with the configured weight.
        """
        tv_weight = self.opt.lambda_tv
        if not (tv_weight > 0):
            return

        # Gradients must be unscaled before they are modified in place.
        # ref: https://pytorch.org/docs/stable/notes/amp_examples.html#gradient-clipping
        self.scaler.unscale_(self.optimizer)

        # the model decides how the weight is spread over inner / outer points
        self.model.apply_total_variation(tv_weight)

    def eval_step(self, data):
        """Render one evaluation view and score it against the ground truth.

        Args:
            data: mapping with ``rays_o``, ``rays_d`` and ``images``
                (``[H, W, 3/4]``); optionally ``cam_near_far``. When
                ``self.opt.render`` is not ``'grid'`` it must also provide
                ``mvp``, ``rays_d_all``, ``rays_i``, ``rays_j``, ``H``, ``W``.

        Returns:
            ``(pred, gt_rgb, loss)`` where ``pred`` maps ``'rgb'`` plus every
            other render output name to a tensor (reshaped to ``[H, W, 3]`` or
            ``[H, W]`` when the element count matches), ``gt_rgb`` is the
            ground truth composited on white, and ``loss`` is the mean of
            ``self.criterion`` over the RGB prediction.
        """
        origins, directions = data['rays_o'], data['rays_d']  # [N, 3] each
        target = data['images']  # [H, W, 3/4]
        H, W, C = target.shape

        near_far = None  # [1/N, 2] when present
        if 'cam_near_far' in data:
            near_far = data['cam_near_far']

        # evaluation always composites onto a fixed white background
        white = 1
        if C == 4:
            alpha = target[..., 3:]
            gt_rgb = target[..., :3] * alpha + white * (1 - alpha)
        else:
            gt_rgb = target

        thresholding = self.opt.render != 'mesh' and self.opt.alpha_thres != 0
        alpha_threshold = self.model.alpha_threshold(self.global_step) if thresholding else None

        render_kwargs = {'bg_color': white, 'perturb': False, 'cam_near_far': near_far}
        if self.opt.render != 'grid':
            # non-grid renderers additionally receive the camera / pixel inputs
            for key in ('mvp', 'rays_d_all', 'rays_i', 'rays_j', 'H', 'W'):
                render_kwargs[key] = data[key]
        render_kwargs['alpha_threshold'] = alpha_threshold
        outputs = self.model.render(origins, directions, **render_kwargs)

        pred = {'rgb': outputs['image'].reshape(H, W, 3)}
        for name, value in outputs.items():
            if name == 'image':
                continue
            count = value.numel()
            if count == H * W * 3:
                pred[name] = value.reshape(H, W, 3)
            elif count == H * W:
                pred[name] = value.reshape(H, W)
            else:
                # not image-shaped: pass through untouched
                pred[name] = value

        loss = self.criterion(pred['rgb'], gt_rgb).mean()

        return pred, gt_rgb, loss

    # moved out bg_color and perturb for more flexible control...
    def test_step(self, data, bg_color=None, perturb=False, shading='full'):
        """Render one view for inference (no ground truth, no loss).

        Args:
            data: mapping with ``rays_o``, ``rays_d``, ``index``, ``H``, ``W``;
                optionally ``cam_near_far``. When ``self.opt.render`` is not
                ``'grid'`` it must also provide ``mvp`` and ``rays_d_all``.
            bg_color: optional background tensor; moved to ``self.device``
                when given, forwarded as ``None`` otherwise.
            perturb: forwarded to ``self.model.render``.
            shading: forwarded to ``self.model.render``.

        Returns:
            Dict mapping ``'rgb'`` plus every other render output name to a
            tensor, reshaped to ``[H, W, 3]`` or ``[H, W]`` when the element
            count matches and passed through unchanged otherwise.
        """
        origins, directions = data['rays_o'], data['rays_d']  # [N, 3] each
        view_index = data['index']  # [1/N]
        H, W = data['H'], data['W']

        near_far = None  # [1/N, 2] when present
        if 'cam_near_far' in data:
            near_far = data['cam_near_far']

        if bg_color is not None:
            bg_color = bg_color.to(self.device)

        thresholding = self.opt.render != 'mesh' and self.opt.alpha_thres != 0
        alpha_threshold = self.model.alpha_threshold(self.global_step) if thresholding else None

        render_kwargs = {'index': view_index, 'bg_color': bg_color, 'perturb': perturb,
                         'cam_near_far': near_far, 'shading': shading}
        if self.opt.render != 'grid':
            # non-grid renderers additionally receive the camera inputs
            for key in ('mvp', 'rays_d_all', 'H', 'W'):
                render_kwargs[key] = data[key]
        render_kwargs['alpha_threshold'] = alpha_threshold
        outputs = self.model.render(origins, directions, **render_kwargs)

        pred = {'rgb': outputs['image'].reshape(H, W, 3)}
        for name, value in outputs.items():
            if name == 'image':
                continue
            count = value.numel()
            if count == H * W * 3:
                pred[name] = value.reshape(H, W, 3)
            elif count == H * W:
                pred[name] = value.reshape(H, W)
            else:
                # not image-shaped: pass through untouched
                pred[name] = value

        return pred

    @torch.no_grad()
    def save_baking(self, loader=None, occupancy_grid=None, save_path=None):
        """Bake the model into sparse-grid atlas assets and write them to disk.

        Files written into ``save_path``: ``occupancy_grid_{8,16,32,64,128}.png``,
        ``sparse_grid_block_indices.png``, one
        ``sparse_grid_rgb_and_density_XXX.png`` / ``sparse_grid_features_XXX.png``
        pair per atlas depth slice, and ``scene_params.json`` (atlas layout plus
        the parameters of ``self.model.DeferredMLP.view_mlp``). In ``'mixed'``
        render mode a ``mesh`` sub-directory is created as well.

        Args:
            loader: data loader; used to compute the occupancy grid via
                ``cal_occ_grid`` when ``occupancy_grid`` is None, and passed to
                ``self.model.update_mesh_occ_mask`` in mixed mode when
                ``self.opt.use_mesh_occ_grid`` is set.
            occupancy_grid: precomputed 3D occupancy tensor. If None it is
                computed from ``loader``.
            save_path: output directory; defaults to ``<workspace>/assets``.

        Raises:
            ValueError: if neither ``loader`` nor ``occupancy_grid`` is given.
            AssertionError: if ``self.opt.triplane_resolution`` is positive.
        """
        if save_path is None:
            save_path = os.path.join(self.workspace, 'assets')
        os.makedirs(save_path, exist_ok=True)

        assert self.opt.triplane_resolution <= 0

        self.model.eval()

        self.log(f"==> Start Baking, save results to {save_path}")

        device = self.device
        bound = self.model.bound  # always 2 in contracted mode
        BS = 8  # block size
        block_span = BS + 1  # 9 samples per block edge in the atlas
        LS = self.opt.grid_resolution
        AS = LS // BS  # 64, atlas size
        scene_params = {
            'sparse_grid_resolution': LS,
            "sparse_grid_voxel_size": 2 * bound / LS,  # 4/1024
            'data_block_size': BS,
            'has_vol': True,  # placeholder, overwritten below (keeps the key order)
            'slice_depth': 1,
            "format": "png",
            'range_features': [-7.0, 7.0],
            'range_density': [-14.0, 14.0],
        }

        # Resolve the occupancy grid BEFORE touching it: with the default
        # occupancy_grid=None the grid has to be computed from the loader.
        grid_was_given = occupancy_grid is not None
        if not grid_was_given:
            if loader is None:
                raise ValueError('No loader or occupancy_grid is provided.')
            occupancy_grid = self.cal_occ_grid(loader)

        # covers the [-2, 2]^3 space
        occupancy_rate = occupancy_grid.cpu().sum().item() / occupancy_grid.numel() * 100.
        if grid_was_given:
            self.log(f'[INFO] Occupancy rate: '
                     f'{occupancy_rate:.4f}% '
                     f'with resolution {occupancy_grid.shape[0]}')

        def max_pool(grid, kernel):
            # boolean max-pooling with a cubic, non-overlapping window
            pooled = F.max_pool3d(grid.float().unsqueeze(0).unsqueeze(0), kernel, stride=kernel)
            return pooled.squeeze(0).squeeze(0).bool()

        # maxpooling
        occupancy_grid_L1 = max_pool(occupancy_grid, 8)  # LS // 8
        self.log(f'[INFO] Occ blocks: {occupancy_grid_L1.sum().item()} with resolution {occupancy_grid_L1.shape[0]}')

        if self.opt.render == 'mixed' and self.opt.use_mesh_occ_grid:
            # clear the blocks flagged by the model's mesh occupancy mask
            self.model.update_mesh_occ_mask(loader)
            self.log(f'[INFO] mesh_check_ratio: {self.opt.mesh_check_ratio}')
            mesh_occ_mask = F.interpolate(self.model.mesh_occ_mask.float().unsqueeze(0).unsqueeze(0),
                                          size=occupancy_grid_L1.shape, mode='nearest').squeeze(0).squeeze(0).bool()
            self.log(f'[INFO] mesh_occ_mask: {mesh_occ_mask.sum().item()} with resolution {mesh_occ_mask.shape[0]}')
            self.log(f'[INFO] mesh_occ_grid in occupancy_grid_L1: '
                     f'{torch.where(mesh_occ_mask, occupancy_grid_L1, torch.zeros_like(occupancy_grid_L1)).sum().item()}')
            occupancy_grid_L1 = torch.where(mesh_occ_mask, torch.zeros_like(occupancy_grid_L1), occupancy_grid_L1)
            self.log(f'[INFO] Occ blocks after mesh_occ_grid: {occupancy_grid_L1.sum().item()} '
                     f'with resolution {occupancy_grid_L1.shape[0]}')

        occupancy_grid_L1 = occupancy_grid_L1.cpu()

        # occupancy pyramid: (cell size, grid) for LS // 8, 16, 32, 64, 128
        occupancy_pyramid = [(8, occupancy_grid_L1)]
        for cell_size in (16, 32, 64, 128):
            occupancy_pyramid.append((cell_size, max_pool(occupancy_pyramid[-1][1], 2)))

        # grid features
        coords = torch.nonzero(occupancy_grid_L1)  # [N, 3]

        if coords.shape[0] == 0:
            # nothing occupied: still emit a single block so the atlas is non-empty
            coords = torch.tensor([[0, 0, 0]])

        # total occ blocks to save
        N_occ = coords.shape[0]
        # per depth slice we can save at most 255x255 occ blocks.
        # due to limit of max texture size (2048), we actually only have 2048 / 9 = 227 max size
        atlas_blocks_z = N_occ // (227 * 227) + 1
        N_per_slice = math.ceil(N_occ / atlas_blocks_z)
        atlas_blocks_x = math.ceil(math.sqrt(N_per_slice))
        atlas_blocks_y = math.ceil(N_per_slice / atlas_blocks_x)
        atlas_width = block_span * atlas_blocks_x
        atlas_height = block_span * atlas_blocks_y
        atlas_depth = block_span * atlas_blocks_z

        self.log(f'atlas block size: ({atlas_blocks_x}, {atlas_blocks_y}, {atlas_blocks_z})')

        # volume data is flagged as absent when less than 1e-4 % of the grid is occupied
        has_vol = not (occupancy_rate * 10000.0 < 1.0)

        scene_params['atlas_width'] = atlas_width
        scene_params['atlas_height'] = atlas_height
        scene_params['atlas_depth'] = atlas_depth
        scene_params['num_slices'] = atlas_depth
        scene_params['atlas_blocks_x'] = atlas_blocks_x
        scene_params['atlas_blocks_y'] = atlas_blocks_y
        scene_params['atlas_blocks_z'] = atlas_blocks_z

        scene_params['has_vol'] = has_vol

        self.log('has vol: {}'.format(has_vol))

        atlas_indices = 255 * torch.ones([AS, AS, AS, 3], dtype=torch.long)  # [64, 64, 64, 3]
        atlas_data = torch.zeros([atlas_width, atlas_height, atlas_depth, 8], dtype=torch.float32)  # [W, H, D, 8]

        # grid indice: flat block id --> (w, h, d) block position in the atlas
        indices = torch.arange(N_occ)
        indices_w = indices // (atlas_blocks_y * atlas_blocks_z)
        indices_hd = indices % (atlas_blocks_y * atlas_blocks_z)
        indices_h = indices_hd // atlas_blocks_z
        indices_d = indices_hd % atlas_blocks_z

        # xyz --> whd
        atlas_indices[tuple(coords.T)] = torch.stack([indices_w, indices_h, indices_d], dim=1)  # [N, 3]

        # convert coords to 9x9x9 xyzs
        xyzs_000 = (coords / AS) * 2 * bound - bound  # [N, 3], left-top point of the 9x9x9 block
        grid_size = 2 * bound / AS
        step_size = grid_size / BS

        # loops
        if self.opt.grid_resolution > 0:
            self.log(f'[INFO] baking grid features...')
            for i in range(BS + 1):
                for j in range(BS + 1):
                    for k in range(BS + 1):
                        # [N, 3] sample positions, shifted by half a step from the block corner
                        xyzs = (xyzs_000 + step_size * 0.5 +
                                torch.tensor([i, j, k], dtype=torch.float32) * step_size).to(device)
                        # query features
                        features, density = self.model.DensityAndFeaturesMLP(xyzs, bound=bound)
                        features = quantize.quantize_float_to_byte(torch.sigmoid(features))
                        density = quantize.quantize_float_to_byte(torch.sigmoid(density))

                        # channel layout: first 3 feature channels, density, remaining features
                        f_grid = (torch.cat([features[..., :3], density,
                                             features[..., 3:]], dim=-1)).cpu()

                        # place in grid (note the coordinate conventions!)
                        atlas_data[indices_w * block_span + j, indices_h * block_span + k,
                                   indices_d * block_span + i] = f_grid

        # save all assets

        def imwrite(f, x):
            # cv2 expects BGR(A) channel order
            if x.shape[-1] == 1:
                cv2.imwrite(f, x)
            elif x.shape[-1] == 3:
                cv2.imwrite(f, cv2.cvtColor(x, cv2.COLOR_RGB2BGR))
            else:
                cv2.imwrite(f, cv2.cvtColor(x, cv2.COLOR_RGBA2BGRA))

        # occupancy_grid_8-128.png
        for cell_size, grid in occupancy_pyramid:
            res = grid.shape[0]
            imwrite(os.path.join(save_path, f'occupancy_grid_{cell_size}.png'),
                    grid.cpu().numpy().transpose(0, 2, 1).reshape(res * res, res, 1)
                    .astype(np.uint8).repeat(4, axis=-1))

        # atlas_indices.png
        imwrite(os.path.join(save_path, 'sparse_grid_block_indices.png'),
                atlas_indices.cpu().clamp(0, 255).numpy().transpose(0, 2, 1, 3).reshape(AS * AS, AS, 3).astype(
                    np.uint8))

        # feature_xxx.png & rgba_xxx.png
        for i in range(atlas_depth):
            rgb_and_density = atlas_data[..., i, :4]
            feature = atlas_data[..., i, 4:]

            print(f'[INFO] grid pre {i}: '
                  f'rgb: {atlas_data[..., i, :3].min().item()} ~ {atlas_data[..., i, :3].max().item()} '
                  f'density: {atlas_data[..., i, 3].min().item()} ~ {atlas_data[..., i, 3].max().item()} '
                  f'f: {atlas_data[..., i, 4:].min().item()} ~ {atlas_data[..., i, 4:].max().item()}')

            imwrite(os.path.join(save_path, f'sparse_grid_rgb_and_density_{i:03d}.png'),
                    rgb_and_density.cpu().numpy().transpose(1, 0, 2).astype(np.uint8))
            imwrite(os.path.join(save_path, f'sparse_grid_features_{i:03d}.png'),
                    feature.cpu().numpy().transpose(1, 0, 2).astype(np.uint8))

        # mlp
        params = dict(self.model.DeferredMLP.view_mlp.named_parameters())

        for k, p in params.items():
            p_np = p.detach().cpu().numpy().T
            # 'net.0.weight' --> '0_weights'
            # 'net.0.bias' --> '0_bias'
            scene_params[k[4:].replace('.', '_').replace('weight', 'weights')] = p_np.tolist()

        # save scene_params.json
        with open(os.path.join(save_path, 'scene_params.json'), 'w') as f:
            json.dump(scene_params, f, indent=2)

        if self.opt.render == 'mixed':
            self.log(f"==> Start Exporting mesh, save results to {save_path}")
            os.makedirs(os.path.join(save_path, 'mesh'), exist_ok=True)
        else:
            self.log(f"==> Finished baking.")

    @torch.no_grad()
    def cal_occ_grid(self, loader, resolution=None):
        """Build a boolean occupancy grid from the samples rendered over ``loader``.

        Every batch is rendered with ``baking=True``; a sample is kept when both
        its weight and its alpha exceed ``self.opt.alpha_thres``, and the grid
        cell containing it is marked occupied. The model is put in eval mode,
        ``set_training(False)`` is called before rendering and
        ``set_training(True)`` before returning.

        Args:
            loader: iterable of batches with ``rays_o``, ``rays_d``, ``H``,
                ``W``, optionally ``cam_near_far``, and (in ``'mixed'`` render
                mode) ``mvp`` and ``rays_d_all``. An empty loader yields an
                all-False grid.
            resolution: grid resolution per axis; defaults to
                ``self.opt.grid_resolution``.

        Returns:
            Bool tensor of shape ``[resolution, resolution, resolution]`` on
            ``self.device``.
        """
        self.model.eval()
        self.model.set_training(False)
        bound = self.model.bound  # always 2 in contracted mode
        if resolution is None:
            resolution = self.opt.grid_resolution

        # covers the [-2, 2]^3 space
        occupancy_grid = torch.zeros([resolution, resolution, resolution], dtype=torch.bool,
                                     device=self.device, requires_grad=False)

        if self.opt.render == 'mesh' or self.opt.alpha_thres == 0:
            alpha_threshold = None
        else:
            alpha_threshold = self.model.alpha_threshold(self.global_step)

        self.log(f"==> Calculating occupancy_grid...")
        # ray count of the last batch, reported in the final log line; tracked
        # explicitly so an empty loader does not leave the name unbound
        last_num_rays = 0
        for data in tqdm.tqdm(loader):
            rays_o = data['rays_o']  # [N, 3]
            rays_d = data['rays_d']  # [N, 3]
            last_num_rays = rays_o.shape[0]

            cam_near_far = data['cam_near_far'] if 'cam_near_far' in data else None  # [1/N, 2] or None

            if self.opt.render == 'mixed':
                outputs = self.model.render(rays_o, rays_d, cam_near_far=cam_near_far, shading='diffuse', baking=True,
                                            mvp=data['mvp'], rays_d_all=data['rays_d_all'],
                                            H=data['H'], W=data['W'], alpha_threshold=alpha_threshold)
            else:
                outputs = self.model.render(rays_o, rays_d, cam_near_far=cam_near_far, shading='diffuse', baking=True,
                                            H=data['H'], W=data['W'], alpha_threshold=alpha_threshold)

            xyzs = outputs['xyzs'].reshape(-1, 3)
            weights = outputs['weights'].reshape(-1)
            alphas = outputs['alphas'].reshape(-1)

            # thresholding
            mask = (weights > self.opt.alpha_thres) & (alphas > self.opt.alpha_thres)
            xyzs = xyzs[mask]  # [N*T, 3] in [-2, 2]

            # mark occupancy, write data
            # [-bound, bound] -> [0, resolution-1]
            coords = torch.floor((xyzs + bound) / (2 * bound) * resolution).long().clamp(0, resolution - 1)  # [N*T, 3]
            occupancy_grid[tuple(coords.T)] = 1

        self.log(f'[INFO] Occupancy rate: '
                 f'{occupancy_grid.cpu().sum().item() / (occupancy_grid.shape[0] ** 3) * 100.:.2f}% '
                 f'with resolution {resolution} and N = {last_num_rays}')
        self.model.set_training(True)
        return occupancy_grid

    # @torch.no_grad()
    # def run_mesh_error_epoch(self, loader, resolution=None):
    #     bound = self.model.bound  # always 2 in contracted mode
    #     if resolution is None:
    #         resolution = self.opt.mcubes_reso
    #     mesh_occ_mask = torch.zeros([resolution] * 3, dtype=torch.bool, device=self.device)
    #
    #     self.log(f"==> Calculating mesh_error_grid...")
    #
    #     # get mesh error
    #     for i, data in enumerate(tqdm.tqdm(loader)):
    #         results_mesh = self.model.render_mesh(rays_o=None, rays_d=data['rays_d_all'],
    #                                               h0=data['H'], w0=data['W'], mvp=data['mvp'])
    #         dirs = data['rays_d_all'] / torch.norm(data['rays_d_all'], dim=-1, keepdim=True)
    #         dirs = self.model.view_encoder(dirs)
    #         image = (results_mesh['features'][:, :3] +
    #                  self.model.view_mlp(torch.cat([results_mesh['features'][:, :3],
    #                                                 results_mesh['features'][:, 3:],
    #                                                 dirs], dim=1)))
    #         image = image.clamp(0, 1)
    #         loss = self.criterion(image, data['images_all']).mean(-1)  # [N, 3] --> [N]
    #
    #         self.model.update_triangles_errors(loss.detach())
    #
    #         xyzs = contract(results_mesh['xyzs'][results_mesh['mask']])
    #         coords = torch.floor((xyzs + bound) / (2 * bound) * resolution).long().clamp(0, resolution - 1)  # [N, 3]
    #         mesh_occ_mask[tuple(coords.T)] = 1
    #
    #     return self.model.get_error_grid(), mesh_occ_mask

    @torch.no_grad()
    def cal_diff(self, loader):
        """Compute, save and return the mesh-minus-voxel error grid.

        The mesh error grid is loaded from the last (sorted) file matching
        ``<workspace>/mesh_error_grid*.pt``; if none exists it is computed with
        ``run_mesh_error_epoch(loader)`` and saved as ``mesh_error_grid.pt``.
        The voxel error grid is loaded from ``self.voxel_error_grid_path``.
        The grid with the larger first dimension is trilinearly resampled to
        the smaller one, the difference is zeroed wherever either grid is
        exactly 0, and the result is written to ``<workspace>/error_grid.pt``.

        Args:
            loader: data loader, only used when the mesh error grid has to be
                computed.

        Returns:
            The error grid tensor.

        Raises:
            AssertionError: if ``self.voxel_error_grid_path`` is None.
        """
        workspace = self.opt.workspace

        cached = sorted(glob.glob(f'{workspace}/mesh_error_grid*.pt'))
        if not cached:
            mesh_error_grid = self.run_mesh_error_epoch(loader).half()
            torch.save(mesh_error_grid, os.path.join(workspace, 'mesh_error_grid.pt'))
        else:
            latest = cached[-1]
            print(f'[INFO] Load mesh_error_grid from {latest}')
            mesh_error_grid = torch.load(latest).to(self.device).half()

        assert self.voxel_error_grid_path is not None, 'Please specify the path of voxel_error_grid'
        print(f'[INFO] Load voxel_error_grid from {self.voxel_error_grid_path}')
        voxel_error_grid = torch.load(self.voxel_error_grid_path).to(self.device).half()

        # bring both grids to the coarser of the two resolutions
        mesh_res = mesh_error_grid.shape[0]
        voxel_res = voxel_error_grid.shape[0]
        if mesh_res < voxel_res:
            voxel_error_grid = F.interpolate(voxel_error_grid[None, None].float(),
                                             size=[mesh_res] * 3,
                                             mode='trilinear')[0, 0]
        elif mesh_res > voxel_res:
            mesh_error_grid = F.interpolate(mesh_error_grid[None, None].float(),
                                            size=[voxel_res] * 3,
                                            mode='trilinear')[0, 0]

        error_grid = mesh_error_grid - voxel_error_grid
        # cells where either grid holds an exact 0 carry no comparison
        error_grid[(voxel_error_grid == 0) | (mesh_error_grid == 0)] = 0

        torch.save(error_grid, os.path.join(workspace, 'error_grid.pt'))
        return error_grid

    @torch.no_grad()
    def run_mesh_error_epoch(self, loader, resolution=None):
        """Accumulate the mean per-voxel error of the surface render over ``loader``.

        Each batch is rendered with ``self.model.run_surface_render``. The
        per-pixel ``self.criterion`` error (averaged over the last dimension)
        of the pixels selected by ``mask_mesh`` is scattered into the voxel
        containing the corresponding ``xyzs`` point, which is first passed
        through ``contract`` when ``self.opt.contract`` is set.

        Args:
            loader: iterable of batches with ``rays_o_all``, ``rays_d_all``,
                ``H``, ``W``, ``mvp`` and ``images`` (3 or 4 channels; RGBA is
                composited onto a white background).
            resolution: grid resolution per axis; defaults to
                ``self.opt.grid_resolution``.

        Returns:
            float16 tensor of shape ``[resolution] * 3`` holding the mean error
            per voxel (0 for voxels that were never hit).
        """
        if resolution is None:
            resolution = self.opt.grid_resolution
        mesh_error_grid = torch.zeros([resolution] * 3, dtype=torch.float32, device=self.device)
        mesh_cnt_grid = torch.zeros([resolution] * 3, dtype=torch.long, device=self.device)

        self.log(f"==> Calculating mesh_error_grid...")

        # get mesh error
        for data in tqdm.tqdm(loader):
            results = self.model.run_surface_render(rays_o=data['rays_o_all'], rays_d=data['rays_d_all'],
                                                    H=data['H'], W=data['W'], mvp=data['mvp'],
                                                    rays_d_all=data['rays_d_all'])
            image = results['image']
            mask = results['mask_mesh']

            images = data['images']
            bg_color = 1  # white background
            if images.shape[-1] == 4:
                gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:])
            else:
                gt_rgb = images

            loss = self.criterion(image, gt_rgb).mean(-1)[mask]  # [the number of 1 in mask]
            xyzs = results['xyzs'][mask]  # [the number of 1 in mask, 3]
            if self.opt.contract:
                xyzs = contract(xyzs)  # [the number of 1 in mask, 3] in (-2, 2)

            if xyzs.shape[0] > 0:
                coords = (torch.floor((xyzs + self.model.bound) / (2 * self.model.bound) * resolution)
                          .long().clamp(0, resolution - 1))  # [N*T, 3]
                voxel = (coords[:, 0], coords[:, 1], coords[:, 2])
                # accumulate=True so several points landing in one voxel all count
                mesh_cnt_grid.index_put_(voxel, torch.ones_like(coords[:, 0]), accumulate=True)
                mesh_error_grid.index_put_(voxel, loss, accumulate=True)

        # avoid dividing by zero for voxels that were never hit
        mesh_cnt_grid[mesh_cnt_grid < 1] = 1
        # Average in float32 and only then cast: casting the raw sums to
        # float16 first would overflow to inf above 65504 and lose precision.
        return mesh_error_grid.div(mesh_cnt_grid).half()

    @torch.no_grad()
    def run_voxel_error_epoch(self, loader, resolution=None, thresh=0.005):
        """Accumulate the rendering error of every batch into a voxel grid.

        Each batch of ``loader`` is rendered with ``self.model.render``. The
        per-ray error ``self.criterion(image, gt_rgb)`` is spread over the
        samples of the ray in proportion to their weights, scattered into the
        voxel each sample falls in, and finally averaged per voxel.

        Args:
            loader: iterable of batch dicts providing 'rays_o', 'rays_d', 'H',
                'W', 'images' and optionally 'cam_near_far'.
            resolution: number of voxels per axis; defaults to
                ``self.opt.grid_resolution``.
            thresh: samples whose weight or alpha is not above this value
                contribute nothing.

        Returns:
            A float16 tensor of shape [resolution, resolution, resolution] with
            the mean weighted error per voxel (0 where no sample landed).
        """
        self.model.set_training(False)
        if resolution is None:
            resolution = self.opt.grid_resolution
        grid_shape = [resolution, resolution, resolution]
        voxel_error_grid = torch.zeros(grid_shape, dtype=torch.float32, device=self.device)
        voxel_cnt_grid = torch.zeros(grid_shape, dtype=torch.long, device=self.device)

        self.log(f"==> Calculating voxel_error_grid...")
        if self.opt.render == 'mesh' or self.opt.alpha_thres == 0:
            alpha_threshold = None
        else:
            alpha_threshold = self.model.alpha_threshold(self.global_step)

        bound = self.model.bound
        bg_color = 1

        # get grid error
        for data in tqdm.tqdm(loader):
            cam_near_far = data['cam_near_far'] if 'cam_near_far' in data else None  # [1/N, 2] or None
            results = self.model.render(data['rays_o'], data['rays_d'], cam_near_far=cam_near_far,
                                        H=data['H'], W=data['W'], alpha_threshold=alpha_threshold)

            image = results['image']  # [N, 3]
            weights = results['weights']  # [N, T]
            alphas = results['alphas']

            # thresholding: samples that barely contribute are zeroed (in place)
            keep = (weights > thresh) & (alphas > thresh)
            weights[~keep] = 0

            images = data['images']
            if images.shape[-1] == 4:
                # composite RGBA ground truth over the background color
                gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:])
            else:
                gt_rgb = images

            # per-ray error distributed over the samples of that ray: [N*T]
            loss = (self.criterion(image, gt_rgb.view(-1, 3)).mean(-1).detach()[:, None] * weights).view(-1)
            mask = loss > 0

            xyzs = results['xyzs'].view(-1, 3)[mask]  # [the number of 1 in mask, 3] in (-2, 2)

            if xyzs.shape[0] > 0:
                # [-bound, bound] -> [0, resolution - 1]
                coords = (torch.floor((xyzs + bound) / (2 * bound) * resolution)
                          .long().clamp(0, resolution - 1))  # [N*T, 3]
                index = (coords[:, 0], coords[:, 1], coords[:, 2])
                # accumulate=True so samples sharing a voxel all contribute
                voxel_cnt_grid.index_put_(index, torch.ones_like(coords[:, 0]), accumulate=True)
                voxel_error_grid.index_put_(index, loss[mask], accumulate=True)

        # Avoid dividing by zero in voxels that received no sample.
        voxel_cnt_grid.clamp_(min=1)
        # Normalise while still in float32: casting the accumulated sum to
        # float16 first can overflow (max 65504) and loses precision.
        voxel_error_grid.div_(voxel_cnt_grid)
        return voxel_error_grid.half()

    # @torch.no_grad()
    # def cal_mesh_occ_grid(self, loader, occ_grid=None, resolution=None):
    #
    #     self.model.eval()
    #     bound = self.model.bound  # always 2 in contracted mode
    #     # thresh = 0.005  # 0.005 in paper
    #     if resolution is None:
    #         resolution = self.opt.grid_resolution
    #
    #     # covers the [-2, 2]^3 space
    #     mesh_occ_mask = torch.zeros([resolution, resolution, resolution], dtype=torch.bool, device=self.device)
    #
    #     self.log(f"==> Calculating mesh_occ_grid...")
    #     # mark mesh position
    #     for i, data in enumerate(tqdm.tqdm(loader)):
    #         results_mesh = self.model.render_mesh(rays_o=None, rays_d=data['rays_d_all'],
    #                                               h0=data['H'], w0=data['W'], mvp=data['mvp'], early_return=True)
    #
    #         # mark occupancy, write data
    #         # [-bound, bound] -> [0, resolution-1]
    #         xyzs = contract(results_mesh['xyzs'][results_mesh['mask']])
    #         coords = torch.floor((xyzs + bound) / (2 * bound) * resolution).long().clamp(0, resolution - 1)  # [N, 3]
    #         mesh_occ_mask[tuple(coords.T)] = 1
    #
    #     # self.log(f'[INFO] Mesh_occ_mask rate: '
    #     #          f'{mesh_occ_mask.float().sum().item() / mesh_occ_mask.numel() * 100.:.2f}% '
    #     #          f'with resolution {resolution}')
    #     # torch.save(mesh_occ_mask, os.path.join(self.opt.workspace, 'mesh_occ_grid.pt'))
    #
    #     if occ_grid is not None:
    #         self.log(f'[INFO] Mesh in occ_grid rate: '
    #                  f'{(mesh_occ_mask * occ_grid).float().sum().item() / occ_grid.numel() * 100.:.2f}% '
    #                  f'with resolution {resolution}')
    #     return mesh_occ_mask

    @torch.no_grad()
    def mark_unseen_triangles(self, vertices, triangles, mvps, H, W):
        """Find the faces that are not rasterized in any of the given views.

        Args:
            vertices: vertex positions in world coordinates, np.ndarray or
                torch tensor.
            triangles: vertex indices per face, np.ndarray or torch tensor;
                M faces.
            mvps: iterable of 4x4 model-view-projection matrices, [B, 4, 4].
            H, W: rasterization resolution.

        Returns:
            Bool tensor of shape [M]; True for faces unseen by all cameras.
        """
        device = self.device

        if isinstance(vertices, np.ndarray):
            vertices = torch.from_numpy(vertices).contiguous().float().to(device)

        if isinstance(triangles, np.ndarray):
            triangles = torch.from_numpy(triangles).contiguous().int().to(device)

        # seen[i] turns True once face i covers at least one pixel of any view
        seen = torch.zeros(triangles.shape[0], dtype=torch.bool, device=triangles.device)

        if self.glctx is None:
            self.glctx = dr.RasterizeGLContext(output_db=False)

        # homogeneous coordinates do not depend on the view, build them once
        vertices_h = F.pad(vertices, pad=(0, 1), mode='constant', value=1.0)  # [N, 4]

        for mvp in tqdm.tqdm(mvps):
            vertices_clip = torch.matmul(vertices_h,
                                         torch.transpose(mvp.to(device), 0, 1)).float().unsqueeze(0)  # [1, N, 4]

            # ENHANCE: lower resolution since we don't need that high?
            rast, _ = dr.rasterize(self.glctx, vertices_clip, triangles, (H, W))  # [1, H, W, 4]

            # collect the triangle_id (it is offset by 1, 0 means "no triangle")
            trig_id = rast[..., -1].long().view(-1) - 1

            # Drop background pixels: their id is -1 after the shift and would
            # otherwise index (and wrongly mark as seen) the last face.
            trig_id = trig_id[trig_id >= 0]

            # just a 0/1 mask, duplicated indices are harmless here
            seen[trig_id] = True

        mask = ~seen  # unseen faces by all cameras

        self.log(f'[mark unseen trigs] {mask.sum()} from {mask.shape[0]}')

        return mask  # [M]

    @torch.no_grad()
    def get_occ_grid(self, loader):
        """Return the occupancy grid, preferring a cached file over recomputation.

        When ``self.occ_grid_path`` is set, the grid is loaded from it, moved to
        ``self.model.device`` and its occupancy rate is logged. Otherwise the
        grid is computed from ``loader`` via ``self.cal_occ_grid``.
        """
        if self.occ_grid_path is None:
            # nothing cached: compute it from the data
            return self.cal_occ_grid(loader)

        grid = torch.load(self.occ_grid_path).to(self.model.device)
        self.log(f'[INFO] Load occ_grid from {self.occ_grid_path}')
        occupied_percent = grid.float().sum().item() / grid.numel() * 100.
        grid_resolution = grid.shape[0]
        self.log(f'[INFO] Occupancy rate: {occupied_percent:.4f}% with resolution {grid_resolution}')
        return grid

    # @torch.no_grad()
    # def voxel_to_mesh(self, resolution=1024, S=128, loader=None):
    #
    #     def export_mesh(s_path, vv, tt, q_bound):
    #         m = trimesh.Trimesh(vv, tt, process=False)
    #         s_path = os.path.join(s_path, f'mesh_{q_bound}.ply')
    #         m.export(s_path)
    #         self.log(f'[INFO] Saved mesh into {s_path}')
    #         self.log('[INFO] <============================>')
    #
    #     self.log('[INFO] <============================>')
    #     self.log('[INFO] Extracting mesh from grid...')
    #     save_path = os.path.join(self.workspace, 'mesh')
    #     os.makedirs(save_path, exist_ok=True)
    #     device = self.device
    #     bound = self.model.bound  # always 2 in contracted mode
    #     real_bound = self.model.real_bound  # 128
    #     # sequentially load cascaded meshes
    #     """
    #         Computes the isosurface of a signed distance function (SDF) defined by the
    #         callable `sdf` in a given bounding box with a specified resolution. The SDF
    #         is sampled at a set of points within a regular grid, and the marching cubes
    #         algorithm is used to generate a mesh that approximates the isosurface at a
    #         specified isovalue `level`.
    #
    #         Args:
    #             sdf: A callable function that takes as input a tensor of size
    #                 (N, 3) containing 3D points, and returns a tensor of size (N,) containing
    #                 the signed distance function evaluated at those points.
    #             output_path: The output directory where the resulting mesh will be saved.
    #             resolution: The resolution of the grid used to sample the SDF.
    #             bounding_box_min: The minimum coordinates of the bounding box in which the SDF
    #                 will be evaluated.
    #             bounding_box_max: The maximum coordinates of the bounding box in which the SDF
    #                 will be evaluated.
    #             isosurface_threshold: The isovalue at which to approximate the isosurface.
    #             coarse_mask: A binary mask tensor of size ("height", "width", "depth") that indicates the regions
    #                 of the bounding box where the SDF is expected to have a zero-crossing. If
    #                 provided, the algorithm first evaluates the SDF at the coarse voxels where
    #                 the mask is True, and then refines the evaluation within these voxels using
    #                 a multi-scale approach. If None, evaluates the SDF at all points in the
    #                 bounding box.
    #         Returns:
    #             A torch tensor with the SDF values evaluated at the given points.
    #         """
    #
    #     resolution: int = 1024
    #     bounding_box_min, bounding_box_max = (-2.0, -2.0, -2.0), (2.0, 2.0, 2.0)
    #     isosurface_threshold = 10.0
    #
    #     # Check if resolution is divisible by 512
    #     assert (
    #             resolution % 512 == 0
    #     ), f"""resolution must be divisible by 512, got {resolution}.
    #            This is important because the algorithm uses a multi-resolution approach
    #            to evaluate the SDF where the minimum resolution is 512."""
    #
    #     # Initialize variables
    #     crop_n = 128
    #     N = resolution // crop_n
    #     grid_min = bounding_box_min
    #     grid_max = bounding_box_max
    #     xs = np.linspace(grid_min[0], grid_max[0], N + 1)
    #     ys = np.linspace(grid_min[1], grid_max[1], N + 1)
    #     zs = np.linspace(grid_min[2], grid_max[2], N + 1)
    #
    #     # Initialize meshes list
    #     meshes = []
    #
    #     # Iterate over the grid
    #     for i in range(N):
    #         for j in range(N):
    #             for k in range(N):
    #                 # Calculate grid cell boundaries
    #                 x_min, x_max = xs[i], xs[i + 1]
    #                 y_min, y_max = ys[j], ys[j + 1]
    #                 z_min, z_max = zs[k], zs[k + 1]
    #
    #                 # Create point grid
    #                 x = np.linspace(x_min, x_max, crop_n)
    #                 y = np.linspace(y_min, y_max, crop_n)
    #                 z = np.linspace(z_min, z_max, crop_n)
    #                 xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
    #                 points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).to(
    #                     device)
    #
    #                 # Construct point pyramids
    #                 points = points.reshape(crop_n, crop_n, crop_n, 3)
    #                 _, val = self.model.forward(points.reshape(-1, 3))
    #                 z = val.detach().cpu().numpy()
    #
    #                 if not (np.min(z) > isosurface_threshold or np.max(z) < isosurface_threshold):
    #                     z = z.astype(np.float32)
    #                     verts, faces, normals, _ = measure.marching_cubes(  # type: ignore
    #                         volume=z.reshape(crop_n, crop_n, crop_n),
    #                         level=isosurface_threshold,
    #                         spacing=(
    #                             (x_max - x_min) / (crop_n - 1),
    #                             (y_max - y_min) / (crop_n - 1),
    #                             (z_max - z_min) / (crop_n - 1),
    #                         ),
    #                     )
    #                     verts = verts + np.array([x_min, y_min, z_min])
    #
    #                     meshcrop = trimesh.Trimesh(verts, faces, normals)
    #                     meshes.append(meshcrop)
    #
    #     combined_mesh = trimesh.util.concatenate(meshes)
    #     # combined_mesh.export(os.path.join(save_path, f'mesh_ori_{resolution}.ply'))
    #     v, t = combined_mesh.vertices, combined_mesh.faces
    #     v = uncontract(torch.from_numpy(v)).numpy().astype(np.float32)
    #     combined_mesh_uncontracted = trimesh.Trimesh(v, t)
    #     combined_mesh_uncontracted.export(os.path.join(save_path, f'mesh_uncontracted_{resolution}.ply'))
    #
    #     return

    @torch.no_grad()
    def voxel_to_mesh(self, occ_grid, resolution=1024, S=128, loader=None):
        """Extract one mesh per cascade from the density field and export them.

        For every ``query_bound`` in ``self.model.cascade_list`` the density is
        sampled on a ``resolution``^3 grid spanning
        [-query_bound, query_bound]^3, turned into a mesh with marching cubes,
        decimated, cropped, cleaned and written to
        ``<workspace>/mesh/mesh_<query_bound>.ply``. All cascades are then
        concatenated, optionally filtered by visibility, and written to
        ``mesh_all.ply``.

        Args:
            occ_grid: 3D occupancy tensor; only read for cascades with
                ``query_bound < self.model.bound``, where its central crop is
                used to zero out densities in unoccupied space.
            resolution: number of density samples per axis.
            S: chunk length used to split each axis, so the model is queried
                on at most S^3 points at a time.
            loader: if not None, ``loader._data`` must expose ``mvps``, ``H``
                and ``W``; faces unseen from all those views are removed from
                the merged mesh.

        Returns:
            None. The meshes are only written to disk.
        """

        def export_mesh(s_path, vv, tt, q_bound):
            # Write vertices `vv` / faces `tt` to `<s_path>/mesh_<q_bound>.ply` and log it.
            m = trimesh.Trimesh(vv, tt, process=False)
            s_path = os.path.join(s_path, f'mesh_{q_bound}.ply')
            m.export(s_path)
            self.log(f'[INFO] Saved mesh into {s_path}')
            self.log('[INFO] <============================>')

        self.log('[INFO] <============================>')
        self.log('[INFO] Extracting mesh from grid...')
        save_path = os.path.join(self.workspace, 'mesh')
        os.makedirs(save_path, exist_ok=True)
        device = self.device
        bound = self.model.bound  # always 2 in contracted mode
        real_bound = self.model.real_bound  # 128
        # sequentially load cascaded meshes
        vertices = []
        triangles = []
        # running vertex / face totals; only v_cumsum[-1] is read below (to offset
        # face indices), f_cumsum is maintained but not otherwise used in this method
        v_cumsum = [0]
        f_cumsum = [0]

        # query
        # normalized sample positions in [-1, 1], split into chunks of at most S per axis
        X = torch.linspace(-1, 1, resolution).split(S)
        Y = torch.linspace(-1, 1, resolution).split(S)
        Z = torch.linspace(-1, 1, resolution).split(S)

        last_query_bound = 0
        for query_bound in self.model.cascade_list:
            sigmas = torch.zeros([resolution] * 3, dtype=torch.float32, device=device)

            # fill the density volume chunk by chunk
            for xi, xs in enumerate(X):
                for yi, ys in enumerate(Y):
                    for zi, zs in enumerate(Z):
                        xx, yy, zz = custom_meshgrid(xs, ys, zs)
                        pts = query_bound * torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1),
                                                       zz.reshape(-1, 1)], dim=-1).to(device)  # [S, 3]
                        with torch.cuda.amp.autocast(enabled=self.opt.fp16):
                            _, val = self.model.forward(pts, bound=bound)
                            # val = self.model.forward(x=pts, shading='diffuse')['sigma']

                        sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys),
                        zi * S: zi * S + len(zs)] = val.reshape(len(xs), len(ys), len(zs))

            # occupancy mask for this cascade; the scalar 1 leaves sigmas untouched
            query_occ_grid = 1
            if query_bound < bound:
                # crop the central part of occ_grid that corresponds to this cascade and
                # upsample it (nearest) to the sampling resolution
                # NOTE(review): the crop indices are derived from `resolution`, so occ_grid
                # presumably has `resolution` cells per axis, and the literal 2.0 looks like
                # the contracted `bound` -- confirm both against the caller.
                half_resolution = int(resolution / 2)
                half_itv = int(query_bound / 2.0 * resolution / 2)
                query_occ_grid = occ_grid[half_resolution - half_itv:half_resolution + half_itv,
                                 half_resolution - half_itv:half_resolution + half_itv,
                                 half_resolution - half_itv:half_resolution + half_itv, ]
                query_occ_grid = F.interpolate(query_occ_grid.float().unsqueeze(0).unsqueeze(0), size=[resolution] * 3,
                                               mode='nearest').squeeze(0).squeeze(0) > 0

            sigmas = torch.nan_to_num(sigmas, 0)
            # the mean is taken before masking, i.e. over the whole sampled volume
            mean_density = torch.mean(sigmas.clamp(min=0)).item()
            sigmas = sigmas * query_occ_grid

            sigmas = sigmas.cpu().numpy()

            # never use an iso-level above the mean density of the volume
            density_thresh = min(mean_density, self.model.density_thresh)
            self.log(f'[INFO] using density_thresh={density_thresh} in query_bound {query_bound}')
            v, t = mcubes.marching_cubes(sigmas, density_thresh)

            # grid indices -> coordinates in [-query_bound, query_bound]
            v = v / (resolution - 1.0) * 2 - 1  # range in [-1, 1]
            v = v * query_bound
            v = v.astype(np.float32)
            t = t.astype(np.int32)

            # half_grid_size = bound / resolution
            # if query_bound > 1.0:
            #     assert query_bound == bound
            #     v = v * (bound - half_grid_size)

            # decimation
            # len(vertices) is the index of the current cascade (one entry is appended per iteration)
            decimate_target = self.opt.decimate_target[len(vertices)]
            if 0 < decimate_target < t.shape[0]:
                v, t = decimate_mesh(v, t, decimate_target, remesh=False)

            # ## visibility test.
            # if loader is not None:
            #     dataset = loader._data
            #     visibility_mask = self.mark_unseen_triangles(v, t, dataset.mvps, dataset.H,
            #                                                  dataset.W).cpu().numpy()
            #     v, t = remove_masked_trigs(v, t, visibility_mask,
            #                                dilation=self.opt.visibility_mask_dilation)

            # remove the center (already covered by previous cascades)
            if last_query_bound:
                # keep a 0.05 margin so neighbouring cascades overlap slightly
                _r = last_query_bound - 0.05
                v, t = remove_selected_verts(v, t, f'(x <= {_r}) && (x >= -{_r}) && (y <= {_r}) && '
                                                   f'(y >= -{_r}) && (z <= {_r} ) && (z >= -{_r})')

            if query_bound > 1.0:
                # warp back (uncontract)
                if self.opt.contract:
                    v = uncontract(torch.from_numpy(v)).numpy().astype(np.float32)

                # drop everything on or outside the [-real_bound, real_bound]^3 box
                v, t = remove_selected_verts(v, t, f'(x <= {-real_bound}) || '
                                                   f'(x >= {real_bound}) || '
                                                   f'(y <= {-real_bound}) || '
                                                   f'(y >= {real_bound}) || '
                                                   f'(z <= {-real_bound} ) || '
                                                   f'(z >= {real_bound})')
                # export_mesh(save_path, v, t, str(query_bound)+'_uncleaned')
                # remove the faces which have the length over self.opt.max_edge_len
                # v, t = remove_selected_vt_by_edge_length(v, t, self.opt.max_edge_len)

                # remove the isolated component composed by a limited number of triangles
                # v, t = remove_selected_isolated_faces(v, t, self.opt.min_iso_size)

            ## reduce floaters by post-processing...
            v, t = clean_mesh(v, t, min_f=self.opt.clean_min_f,
                              min_d=self.opt.clean_min_d,
                              repair=True, remesh=False)

            # decimate_target = self.opt.decimate_target[len(vertices)]
            # if 0 < decimate_target < t.shape[0]:
            #     v, t = decimate_mesh(v, t, decimate_target, remesh=False)

            # v, t = close_holes_meshfix(v, t)
            export_mesh(save_path, v, t, query_bound)

            last_query_bound = query_bound

            vertices.append(v)
            # shift face indices so they stay valid after the vertex arrays are concatenated
            triangles.append(t + v_cumsum[-1])

            v_cumsum.append(v_cumsum[-1] + v.shape[0])
            f_cumsum.append(f_cumsum[-1] + t.shape[0])

        vertices = np.concatenate(vertices, axis=0)
        triangles = np.concatenate(triangles, axis=0)

        # visibility test.
        if loader is not None:
            dataset = loader._data
            visibility_mask = self.mark_unseen_triangles(vertices, triangles, dataset.mvps, dataset.H,
                                                         dataset.W).cpu().numpy()
            vertices, triangles = remove_masked_trigs(vertices, triangles, visibility_mask,
                                                      dilation=self.opt.visibility_mask_dilation)

        export_mesh(save_path, vertices, triangles, 'all')

        # # remove the faces which have the length over self.opt.max_edge_len
        # vertices, triangles = remove_selected_vt_by_edge_length(vertices, triangles, self.opt.max_edge_len)
        #
        # # remove the isolated component composed by a limited number of triangles
        # vertices, triangles = remove_selected_isolated_faces(vertices, triangles, self.opt.min_iso_size)

        # export_mesh(save_path, vertices, triangles, 'all_clean')

        # inner_idx = (np.array(self.opt.cascade_list) <= 1.0).sum()
        # if inner_idx > 1:
        #     self.log('[INFO] extract mesh which bound < self.opt.cascade_list[inner_idx-1]...')
        #     vertices, triangles = (remove_selected_verts(vertices, triangles,
        #                                                  f'(x <= {-self.opt.cascade_list[inner_idx - 1]}) || '
        #                                                  f'(x >= {self.opt.cascade_list[inner_idx - 1]}) || '
        #                                                  f'(y <= {-self.opt.cascade_list[inner_idx - 1]}) || '
        #                                                  f'(y >= {self.opt.cascade_list[inner_idx - 1]}) || '
        #                                                  f'(z <= {-self.opt.cascade_list[inner_idx - 1]} ) || '
        #                                                  f'(z >= {self.opt.cascade_list[inner_idx - 1]})'))
        #     export_mesh(save_path, vertices, triangles, f'{self.opt.cascade_list[inner_idx - 1]}_updated')

    def train(self, train_loader, valid_loader, max_epochs):
        """Run the epoch loop from ``self.epoch + 1`` up to ``max_epochs``.

        Every epoch trains on ``train_loader``; a checkpoint is saved every
        ``self.save_interval`` epochs and after the last one, and
        ``valid_loader`` is evaluated every ``self.eval_interval`` epochs except
        after the last one. On rank 0 a tensorboardX writer is opened for the
        duration of the call when ``self.use_tensorboardX`` is set.
        """

        def mesh_occ_mask_needs_update():
            # Re-evaluated on every call: the mask may have been created meanwhile.
            if self.opt.render != 'mixed' or not self.opt.use_mesh_occ_grid:
                return False
            return self.opt.vert_offset or self.model.mesh_occ_mask is None

        if self.use_tensorboardX and self.local_rank == 0:
            log_dir = os.path.join(self.workspace, "run", self.name)
            self.writer = tensorboardX.SummaryWriter(log_dir)

        started_at = time.time()
        if self.opt.render == 'mesh' and self.opt.vis_error:
            assert self.epoch == max_epochs, 'vis_error only works for evaluation.'

        # done once up front so the mask exists even if no epoch is left to run
        if mesh_occ_mask_needs_update():
            self.model.update_mesh_occ_mask(train_loader)

        first_epoch = self.epoch + 1
        for current_epoch in range(first_epoch, max_epochs + 1):
            self.epoch = current_epoch

            if mesh_occ_mask_needs_update():
                self.model.update_mesh_occ_mask(train_loader)

            self.train_one_epoch(train_loader)

            if self.epoch % self.save_interval == 0 or self.epoch == max_epochs:
                self.save_checkpoint(full=True, best=False, is_last=self.epoch == max_epochs)

            # no validation pass after the very last epoch
            if self.epoch % self.eval_interval == 0 and self.epoch != max_epochs:
                self.evaluate_one_epoch(valid_loader)

        elapsed_minutes = (time.time() - started_at) / 60
        self.log(f"[INFO] training takes {elapsed_minutes:.6f} minutes.")

        if self.use_tensorboardX and self.local_rank == 0:
            self.writer.close()

    def evaluate(self, loader, name=None):
        """Evaluate ``loader`` once with tensorboard logging temporarily disabled.

        ``self.use_tensorboardX`` is forced to False for the duration of
        ``self.evaluate_one_epoch(loader, name)`` and restored afterwards, also
        when the evaluation raises.
        """
        previous_flag = self.use_tensorboardX
        self.use_tensorboardX = False
        try:
            self.evaluate_one_epoch(loader, name)
        finally:
            # restore even on failure so later training keeps logging
            self.use_tensorboardX = previous_flag

    def test(self, loader, save_path=None, name=None, write_video=True):
        """Render every batch of ``loader`` and optionally save the results as videos.

        Each batch is passed to ``self.test_step``, which is expected to return
        a dict of image tensors. The entries 'rgb', 'depth', 'diffuse',
        'specular', 'full_mesh' and 'full_voxel' that are present are converted
        to uint8 frames (depth is min-max normalized per frame, everything else
        is clamped to [0, 1]).

        Args:
            loader: data loader; must provide ``len()`` and ``batch_size``.
            save_path: output directory, defaults to ``<workspace>/results``.
            name: file name prefix, defaults to ``<self.name>_ep<epoch>``.
            write_video: when True, one ``<name>_<suffix>.mp4`` is written per
                prediction type that produced at least one frame.
        """
        if save_path is None:
            save_path = os.path.join(self.workspace, 'results')

        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        os.makedirs(save_path, exist_ok=True)

        self.log(f"==> Start Test, save results to {save_path}")

        # (key in the dict returned by test_step, suffix of the video file)
        video_streams = (
            ('rgb', 'rgb'),
            ('depth', 'depth'),
            ('diffuse', 'diffuse'),
            ('specular', 'spec'),
            ('full_mesh', 'mesh'),
            ('full_voxel', 'voxel'),
        )
        frames = {key: [] for key, _ in video_streams}

        pbar = tqdm.tqdm(total=len(loader) * loader.batch_size,
                         bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')
        self.model.eval()

        with torch.no_grad():
            for data in loader:
                pred_dict = self.test_step(data)

                for key, _ in video_streams:
                    if key not in pred_dict:
                        continue
                    pred = pred_dict[key]
                    if 'depth' in key:
                        # depth has no fixed range: normalize each frame to [0, 1]
                        pred_item = pred.detach().cpu().numpy()
                        pred_item = (pred_item - pred_item.min()) / (
                                pred_item.max() - pred_item.min() + 1e-6)  # [N]
                    elif pred.dtype == torch.bool:
                        # masks are already 0/1 and clamp is not defined for bool tensors
                        pred_item = pred.detach().cpu().numpy()
                    else:
                        if pred.min() < 0 or pred.max() > 1.0:
                            print(f'\n{key}: {pred.shape}, '
                                  f'{pred.min()}, {pred.max()} '
                                  f'in test')
                        pred_item = pred.clamp(0, 1).detach().cpu().numpy()
                    pred_item = (pred_item * 255).astype(np.uint8)

                    if write_video:
                        frames[key].append(pred_item)

                pbar.update(loader.batch_size)

        pbar.close()

        if write_video:
            for key, suffix in video_streams:
                if not frames[key]:
                    # test_step did not produce this output; skip instead of failing
                    self.log(f"[WARN] no '{key}' predictions, skip {name}_{suffix}.mp4")
                    continue

                video = np.stack(frames[key], axis=0)  # [N, H, W] or [N, H, W, C]

                # fix ffmpeg not divisible by 2: pad odd H / W by one pixel,
                # leave the frame axis and any channel axis untouched
                pad_width = [(0, 0), (0, video.shape[1] % 2), (0, video.shape[2] % 2)]
                pad_width += [(0, 0)] * (video.ndim - 3)
                video = np.pad(video, pad_width)

                imageio.mimwrite(os.path.join(save_path, f'{name}_{suffix}.mp4'), video, fps=30, quality=8,
                                 macro_block_size=1)

        self.log(f"==> Finished Test.")

    # [GUI] just train for 16 steps, without any other overhead that may slow down rendering.
    def train_gui(self, train_loader, step=16):
        """Run ``step`` optimization steps and report the mean loss and current lr.

        The loader is restarted whenever it is exhausted, so ``step`` may exceed
        the number of batches. The EMA (if any) is updated once after all
        steps. A scheduler that is not stepped per iteration is stepped once at
        the end, with the mean loss when it is a ``ReduceLROnPlateau``.

        Returns:
            dict with 'loss' (mean loss over the steps) and 'lr' (learning rate
            of the first parameter group).
        """
        self.model.train()

        loss_sum = torch.tensor([0], dtype=torch.float32, device=self.device)
        batches = iter(train_loader)
        exhausted = object()  # sentinel returned by next() at the end of the loader

        # mark untrained grid
        if self.global_step == 0 and self.opt.mark_untrained:
            self.model.mark_untrained_grid(train_loader._data)

        for _ in range(step):
            # mimic an infinite loop dataloader (in case the total dataset is smaller than step)
            batch = next(batches, exhausted)
            if batch is exhausted:
                batches = iter(train_loader)
                batch = next(batches)

            # refresh the extra state at the configured interval
            if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
                self.model.update_extra_state()

            self.global_step += 1

            self.optimizer.zero_grad()
            _, _, loss_net, _ = self.train_step(batch)
            self.scaler.scale(loss_net).backward()

            self.post_train_step()  # for TV loss...

            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            loss_sum += loss_net.detach()

        if self.ema is not None:
            self.ema.update()

        average_loss = loss_sum.item() / step

        if not self.scheduler_update_every_step:
            on_plateau = isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
            scheduler_args = (average_loss,) if on_plateau else ()
            self.lr_scheduler.step(*scheduler_args)

        return {
            'loss': average_loss,
            'lr': self.optimizer.param_groups[0]['lr'],
        }

    # [GUI] test on a single image
    def test_gui(self, pose, intrinsics, mvp, W, H, bg_color=None, spp=1, downscale=1, shading='full'):
        """Render a single view for the GUI and return it as NumPy arrays.

        Args:
            pose: NumPy array holding the camera pose; a leading batch axis is
                added before it is handed to ``get_rays``.
            intrinsics: camera intrinsics; multiplied by ``downscale``.
            mvp: passed through unchanged as ``data['mvp']``.
            W, H: size of the returned image in pixels.
            bg_color: forwarded to ``test_step``.
            spp: when different from 1 it is forwarded as ``perturb`` (used as a
                perturbation seed according to the original comment); 1 disables
                perturbation.
            downscale: factor applied to the render resolution. When it is not 1
                the result is resized back to ``(H, W)`` with nearest-neighbour
                interpolation.
            shading: forwarded to ``test_step``.

        Returns:
            dict with ``'image'`` and ``'depth'`` NumPy arrays. Both keep the
            layout returned by ``test_step`` regardless of ``downscale``.
        """
        # render resolution (may need downscaling for a better frame rate)
        rH = int(H * downscale)
        rW = int(W * downscale)
        intrinsics = intrinsics * downscale

        pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)

        rays = get_rays(pose, intrinsics, rH, rW, -1)

        data = {
            'mvp': mvp,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'H': rH,
            'W': rW,
            'index': [0],
        }

        self.model.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        try:
            with torch.no_grad():
                # here spp is used as perturb random seed! (but not perturb the first sample)
                preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=False if spp == 1 else spp,
                                                    shading=shading)
        finally:
            # Always put the raw weights back, even if rendering failed, so that a
            # GUI which keeps running does not continue training on the EMA copy.
            if self.ema is not None:
                self.ema.restore()

        # interpolation to the original resolution
        if downscale != 1:
            # The calls below need a 3-D image (channels last) and a 2-D depth map:
            # F.interpolate with a 2-tuple size only accepts 4-D input.
            # TODO: have to permute twice with torch...
            preds = (F.interpolate(preds.unsqueeze(0).permute(0, 3, 1, 2), size=(H, W), mode='nearest')
                     .permute(0, 2, 3, 1).squeeze(0).contiguous())
            # Drop BOTH singleton axes added for interpolation. The previous
            # `.squeeze(0).squeeze(1)` left a stray leading axis, so the depth map
            # was [1, H, W] here but [H, W] when downscale == 1.
            preds_depth = (F.interpolate(preds_depth.unsqueeze(0).unsqueeze(0), size=(H, W), mode='nearest')
                           .squeeze(0).squeeze(0))

        pred = preds.detach().cpu().numpy()
        pred_depth = preds_depth.detach().cpu().numpy()

        outputs = {
            'image': pred,
            'depth': pred_depth,
        }

        return outputs

    def train_one_epoch(self, loader):
        """Train the model for one full pass over ``loader``.

        For every batch: zero the gradients, run ``train_step``, back-propagate
        through the gradient scaler, call ``post_train_step`` and step the
        optimizer (plus the scheduler when it is stepped per batch). Rank 0
        additionally updates metrics, TensorBoard scalars and the progress bar.
        The mean epoch loss is appended to ``self.stats["loss"]``.

        Args:
            loader: iterable of training batches. ``len(loader)`` and
                ``loader.batch_size`` are only read on rank 0, and
                ``loader.sampler`` only when ``world_size > 1``.
        """
        self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...")

        running_loss = 0
        if self.local_rank == 0 and self.report_metric_at_train:
            for metric in self.metrics:
                metric.clear()

        self.model.train()

        # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs
        # ref: https://pytorch.org/docs/stable/data.html
        if self.world_size > 1:
            loader.sampler.set_epoch(self.epoch)

        if self.local_rank == 0:
            pbar = tqdm.tqdm(
                total=len(loader) * loader.batch_size,
                bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
            )

        self.local_step = 0

        for batch in loader:

            # update grid every 16 steps
            # if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0:
            #     self.model.update_extra_state()

            self.local_step += 1
            self.global_step += 1

            self.optimizer.zero_grad()

            preds, truths, loss_net, loss_dict = self.train_step(batch)
            self.scaler.scale(loss_net).backward()

            self.post_train_step()  # for TV loss...

            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            batch_loss = loss_net.item()
            running_loss += batch_loss

            if self.local_rank == 0:
                if self.report_metric_at_train:
                    for metric in self.metrics:
                        metric.update(preds, truths)

                if self.use_tensorboardX:
                    self.writer.add_scalar("train/loss", batch_loss, self.global_step)
                    self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step)
                    for tag, value in loss_dict.items():
                        self.writer.add_scalar("train/" + tag, value, self.global_step)

                # Show "current (running mean)" and, for per-step schedulers, the live lr.
                description = f"loss={batch_loss:.6f} ({running_loss / self.local_step:.6f})"
                if self.scheduler_update_every_step:
                    description += f", lr={self.optimizer.param_groups[0]['lr']:.6f}"
                pbar.set_description(description)
                pbar.update(loader.batch_size)

            if self.opt.render == 'mesh' and self.opt.refine and self.global_step in self.opt.refine_steps:
                self.log(f'\n[INFO] refine and decimate mesh at {self.global_step} step')
                self.model.refine_and_decimate()

                # reinit optim since params changed.
                self.optimizer = self.optimizer_fn(self.model)
                self.lr_scheduler = self.lr_scheduler_fn(self.optimizer)

        if self.ema is not None:
            self.ema.update()

        average_loss = running_loss / self.local_step
        self.stats["loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if self.report_metric_at_train:
                for metric in self.metrics:
                    self.log(metric.report(), style="red")
                    if self.use_tensorboardX:
                        metric.write(self.writer, self.epoch, prefix="train")
                    metric.clear()

        # Per-epoch scheduler update; ReduceLROnPlateau is the only scheduler fed the loss.
        if not self.scheduler_update_every_step:
            wants_metric = isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
            step_args = (average_loss,) if wants_metric else ()
            self.lr_scheduler.step(*step_args)

        self.log(f"==> Finished Epoch {self.epoch}, loss={average_loss:.6f}.")

    def evaluate_one_epoch(self, loader, name=None, only_metric=False):
        """Evaluate the model on ``loader`` and record losses and metrics.

        The model is switched to eval mode and, when an EMA is configured, the
        averaged weights are used for the duration of the evaluation. Every rank
        accumulates the loss; only rank 0 updates metrics, writes validation
        images and reports results.

        Args:
            loader: iterable of evaluation batches. ``len(loader)`` and
                ``loader.batch_size`` are only read on rank 0.
            name: prefix of the written image files; defaults to
                ``'<opt.render>_ep<epoch>'``.
            only_metric: images are written only when this is exactly ``False``
                (identity check kept for backward compatibility).
        """
        self.log(f"++> Evaluate at epoch {self.epoch} ...")

        if name is None:
            name = f'{self.opt.render}_ep{self.epoch:04d}'

        total_loss = 0
        if self.local_rank == 0:
            for metric in self.metrics:
                metric.clear()

        self.model.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        try:
            if self.local_rank == 0:
                pbar = tqdm.tqdm(
                    total=len(loader) * loader.batch_size,
                    bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]')

            with torch.no_grad():
                self.local_step = 0

                for data in loader:
                    self.local_step += 1

                    pred_dict, truths, loss = self.eval_step(data)

                    loss_val = loss.item()
                    total_loss += loss_val

                    # only rank = 0 will perform evaluation.
                    if self.local_rank == 0:
                        metric_vals = [metric.update(pred_dict['rgb'], truths) for metric in self.metrics]

                        if only_metric is False:
                            self._save_validation_images(pred_dict, truths, name, metric_vals)

                        pbar.set_description(f"loss={loss_val:.6f} ({total_loss / self.local_step:.6f})")
                        pbar.update(loader.batch_size)
        finally:
            # Put the raw weights back even when evaluation fails half-way, so the
            # model is never left running on the EMA copy.
            if self.ema is not None:
                self.ema.restore()

        average_loss = total_loss / self.local_step
        self.stats["valid_loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if not self.use_loss_as_metric and len(self.metrics) > 0:
                result = self.metrics[0].measure()
                self.stats["results"].append(
                    result if self.best_mode == 'min' else - result)  # if max mode, use -result
            else:
                self.stats["results"].append(average_loss)  # if no metric, choose best by min loss

            measures = []
            for metric in self.metrics:
                measures.append(metric.measure())
                self.log(metric.report(), style="blue")
                if self.use_tensorboardX:
                    metric.write(self.writer, self.epoch, prefix="evaluate")
                metric.clear()
            self.measures.append(measures)

        self.log(f"++> Evaluate epoch {self.epoch} Finished.")

    def _save_validation_images(self, pred_dict, truths, name, metric_vals):
        """Write the visualisations of the current evaluation step as PNG files.

        Files go to ``<workspace>/validation/<name>_<local_step>_<tag>.png`` where
        ``tag`` is the ``pred_dict`` key, ``gt`` for the ground truth, or
        ``error_<metric>`` for the absolute rgb error. The metric suffix is omitted
        when no metric is configured.

        Args:
            pred_dict: predictions of ``eval_step``; must contain ``'rgb'``.
            truths: ground-truth image tensor (values presumably in [0, 1] - TODO confirm).
            name: file-name prefix.
            metric_vals: values returned by ``metric.update`` for this step; the
                first one (PSNR according to the original comment) tags the error image.
        """
        out_dir = os.path.join(self.workspace, 'validation')
        os.makedirs(out_dir, exist_ok=True)
        stem = f'{name}_{self.local_step:04d}'

        vis_keys = (
            'rgb', 'depth',
            'depth_mesh', 'mask_mesh',
            'weights_sum', 'weights_mesh', 'weights_bg',
            'specular', 'diffuse', 'diffuse_mesh', 'diffuse_voxel',
            'diffuse_mesh_raw', 'full_mesh', 'full_voxel', 'full_mesh_raw',
            'error_rgb_32', 'error_rgb_64', 'error_rgb_128', 'error_rgb_256',
            'error_rgb_512',
            'error_grey_32', 'error_grey_64', 'error_grey_128', 'error_grey_256',
            'error_grey_512',
        )

        for vis_name in vis_keys:
            if vis_name not in pred_dict:
                continue
            item = pred_dict[vis_name]
            if 'depth' in vis_name:
                # Depth maps are min-max normalised per image for display.
                pred_item = item.detach().cpu().numpy()
                pred_item = (pred_item - pred_item.min()) / (
                        pred_item.max() - pred_item.min() + 1e-6)  # [N]
            else:
                if item.dtype != torch.bool:
                    lowest, highest = item.min(), item.max()
                    if lowest < 0 or highest > 1.0:
                        # Report (but tolerate) values that are about to be clamped.
                        print(f'\n{vis_name}: {item.shape}, '
                              f'{lowest}, {highest} '
                              f'in evaluate_one_epoch')
                pred_item = item.clamp(0, 1).detach().cpu().numpy()
            pred_item = (pred_item * 255).astype(np.uint8)

            # Build the path from its parts. The old `save_path.replace('rgb', ...)`
            # also rewrote any 'rgb' inside the workspace directory or `name`.
            out_path = os.path.join(out_dir, f'{stem}_{vis_name}.png')
            if pred_item.shape[-1] == 3:
                cv2.imwrite(out_path, cv2.cvtColor(pred_item, cv2.COLOR_RGB2BGR))
            else:
                cv2.imwrite(out_path, pred_item)

        # Clip before the uint8 cast: out-of-range floats would otherwise wrap
        # around and corrupt both the saved image and the error map.
        truth = (np.clip(truths.detach().cpu().numpy(), 0, 1) * 255).astype(np.uint8)
        cv2.imwrite(os.path.join(out_dir, f'{stem}_gt.png'), cv2.cvtColor(truth, cv2.COLOR_RGB2BGR))

        pred = (np.clip(pred_dict['rgb'].detach().cpu().numpy(), 0, 1) * 255).astype(np.uint8)
        error = np.abs(truth.astype(np.float32) - pred.astype(np.float32)).mean(-1).astype(np.uint8)

        # With no metric configured there is nothing to tag the file with; the old
        # code raised IndexError on metric_vals[0] in that case.
        error_tag = f'error_{metric_vals[0]:.2f}' if metric_vals else 'error'
        cv2.imwrite(os.path.join(out_dir, f'{stem}_{error_tag}.png'), error)

    def save_checkpoint(self, name=None, full=False, best=False, remove_old=True, is_last=False):
        """Save a regular checkpoint, or the best-so-far checkpoint.

        Args:
            name: file stem of a regular checkpoint; defaults to
                ``'<self.name>_ep<epoch>'``. Ignored when ``best`` is true.
            full: also store optimizer, lr-scheduler, scaler and (if present) EMA state.
            best: when true, write to ``self.best_path``, but only if the latest
                entry of ``stats['results']`` is lower than ``stats['best_result']``.
            remove_old: track regular checkpoints in ``stats['checkpoints']`` and
                delete the oldest file once more than ``max_keep_ckpt`` are tracked.
            is_last: for ``opt.render == 'mesh'``, drop ``'vertices_offsets'``
                from the saved model state.
        """
        if name is None:
            name = f'{self.name}_ep{self.epoch:04d}'

        state = {
            'epoch': self.epoch,
            'global_step': self.global_step,
            'stats': self.stats,
        }

        if self.model.cuda_ray:
            state['mean_density'] = self.model.mean_density

        if full:
            state['optimizer'] = self.optimizer.state_dict()
            state['lr_scheduler'] = self.lr_scheduler.state_dict()
            state['scaler'] = self.scaler.state_dict()
            if self.ema is not None:
                state['ema'] = self.ema.state_dict()

        if not best:

            state['model'] = self.model.state_dict()

            file_path = f"{name}.pth"

            if remove_old:
                self.stats["checkpoints"].append(file_path)

                if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
                    old_ckpt = os.path.join(self.ckpt_path, self.stats["checkpoints"].pop(0))
                    if os.path.exists(old_ckpt):
                        os.remove(old_ckpt)

            self.log(f'[INFO] saving checkpoint to {os.path.join(self.ckpt_path, file_path)}...')
            if self.opt.render == 'mesh' and is_last and 'vertices_offsets' in state['model']:
                # Only the saved copy loses the entry; the model itself is untouched.
                del state['model']['vertices_offsets']
                # Log only when something was actually removed (this message used
                # to be emitted for every regular checkpoint).
                self.log(f'[INFO] delete vertex offsets from {file_path} in last step in mesh refine...')

            torch.save(state, os.path.join(self.ckpt_path, file_path))

        else:
            if len(self.stats["results"]) > 0:
                if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]:
                    self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
                    self.stats["best_result"] = self.stats["results"][-1]

                    # save ema results: temporarily swap the averaged weights in
                    if self.ema is not None:
                        self.ema.store()
                        self.ema.copy_to()

                    state['model'] = self.model.state_dict()

                    # we don't consider continued training from the best ckpt,
                    # so we discard the unneeded density_grid to save some storage (especially important for dnerf)
                    # if 'density_grid' in state['model']:
                    #     del state['model']['density_grid']

                    if self.ema is not None:
                        self.ema.restore()

                    torch.save(state, self.best_path)
                    self.log(f'[INFO] saving checkpoint to {self.best_path}...')

            else:
                self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.")

    def load_checkpoint(self, checkpoint=None, model_only=False):
        """Restore model (and optionally training) state from a checkpoint file.

        Args:
            checkpoint: path of the file to load. When ``None`` the latest file in
                ``self.ckpt_path`` is picked: files starting with ``self.name``
                first, otherwise ``vol*.pth`` (if ``opt.use_vol_pth``) or any
                ``*.pth``. If nothing is found a warning is logged and the method
                returns without loading.
            model_only: stop after the model weights, EMA and ``mean_density``
                have been restored.

        Raises:
            FileNotFoundError: ``opt.render == 'mixed'`` with ``opt.mesh_encoder``
                needs a separate ``mesh*.pth`` checkpoint and none exists.

        NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
        from trusted sources.
        """
        if checkpoint is None:  # load latest
            checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}*.pth'))
            if not checkpoint_list:
                if 'use_vol_pth' in self.opt and self.opt.use_vol_pth:
                    checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/vol*.pth'))
                else:
                    checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))

            if not checkpoint_list:
                self.log("[WARN] No checkpoint found, abort loading latest model.")
                return

            # os.path.basename instead of split('/') so the file-name test also
            # works with the path separators glob returns on Windows.
            if len(checkpoint_list) > 1 and self.name not in os.path.basename(checkpoint_list[-1]):
                checkpoint = checkpoint_list[-2]
            else:
                checkpoint = checkpoint_list[-1]
            if self.name in os.path.basename(checkpoint):
                # A checkpoint written by this very run: resume its counters below.
                self.continue_vosh = True
            self.log(f"[INFO] Latest checkpoint is {checkpoint}")

        checkpoint_dict = torch.load(checkpoint, map_location=self.device)

        # A file without a 'model' entry is treated as a bare state dict.
        if 'model' not in checkpoint_dict:
            self.model.load_state_dict(checkpoint_dict)
            self.log("[INFO] loaded model.")
            return

        # deal with unwanted prefix '_orig_mod.'
        state_dict = checkpoint_dict['model']
        # unwanted_prefix = '_orig_mod.'
        # for k, v in list(state_dict.items()):
        #     if not k.startswith(unwanted_prefix):
        #         state_dict[unwanted_prefix + k] = state_dict.pop(k)

        missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
        self.log("[INFO] loaded model.")
        if len(missing_keys) > 0:
            self.log(f"[WARN] missing keys: {missing_keys}")
        if len(unexpected_keys) > 0:
            self.log(f"[WARN] unexpected keys: {unexpected_keys}")

        # The optional pieces below are loaded best-effort: a failure is logged
        # with its reason instead of aborting. `except Exception` (not a bare
        # except) lets KeyboardInterrupt / SystemExit propagate.
        if self.ema is not None and 'ema' in checkpoint_dict:
            try:
                self.ema.load_state_dict(checkpoint_dict['ema'])
                self.log("[INFO] loaded EMA.")
            except Exception as err:
                self.log(f"[WARN] failed to loaded EMA: {err}")

        if self.model.cuda_ray:
            if 'mean_density' in checkpoint_dict:
                self.model.mean_density = checkpoint_dict['mean_density']

        if model_only:
            return

        self.log(f"[INFO] load at epoch {checkpoint_dict['epoch']}, global step {checkpoint_dict['global_step']}")
        if self.opt.render == 'grid' or self.continue_vosh:
            self.stats = checkpoint_dict['stats']
            self.epoch = checkpoint_dict['epoch']
            self.global_step = checkpoint_dict['global_step']
            self.log(f"[INFO] continue using epoch and global_step")

        # if self.opt.reset_lr:
        #     current_step = self.global_step
        #     self.log(f'[INFO] reset lr to 0.1**((iter-{current_step})/({self.opt.iters - current_step}))')
        #     self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        #         self.optimizer, lambda iter: 0.1 ** ((iter - current_step) / (self.opt.iters - current_step)))

        if self.opt.render == 'grid' or self.continue_vosh:
            if self.optimizer and 'optimizer' in checkpoint_dict:
                try:
                    self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
                    self.log("[INFO] loaded optimizer.")
                except Exception as err:
                    self.log(f"[WARN] Failed to load optimizer: {err}")

            if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
                try:
                    self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
                    self.log("[INFO] loaded scheduler.")
                except Exception as err:
                    self.log(f"[WARN] Failed to load scheduler: {err}")
        else:
            self.log("[INFO] skipped optimizer and scheduler in vosh training.")

        if self.scaler and 'scaler' in checkpoint_dict:
            try:
                self.scaler.load_state_dict(checkpoint_dict['scaler'])
                self.log("[INFO] loaded scaler.")
            except Exception as err:
                self.log(f"[WARN] Failed to load scaler: {err}")

        if self.opt.render == 'mixed' and self.opt.mesh_encoder:
            if 'DensityAndFeaturesMLP_mesh.mlp.net.0.weight' in state_dict.keys():
                self.log(f"[INFO] loaded mesh encoder from {checkpoint}")
            else:
                checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/mesh*.pth'))
                # sorted() never returns None, so the former `assert ... is not None`
                # could not fire and an empty result surfaced as a bare IndexError.
                if not checkpoint_list:
                    raise FileNotFoundError(f"No mesh checkpoint found in {self.ckpt_path}.")
                mesh_checkpoint = checkpoint_list[-1]
                mesh_checkpoint_dict = torch.load(mesh_checkpoint, map_location=self.device)['model']
                # Keep only the 'DensityAndFeaturesMLP*' entries and strip the
                # 'DensityAndFeaturesMLP.' prefix so the keys fit the sub-module.
                for key in list(mesh_checkpoint_dict.keys()):
                    v = mesh_checkpoint_dict.pop(key)
                    if key.startswith('DensityAndFeaturesMLP'):
                        mesh_checkpoint_dict[key.replace('DensityAndFeaturesMLP.', '')] = v

                self.model.DensityAndFeaturesMLP_mesh.load_state_dict(mesh_checkpoint_dict)
                self.log(f"[INFO] loaded mesh encoder from {mesh_checkpoint}")

back to top

Software Heritage — Copyright (C) 2015–2026, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API — Content policy — Contact — JavaScript license information — Web API