https://github.com/liam6699/TS-NeRF.git
train.py
import os
import torch

# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from torch import nn
from opt import get_opts
import glob
import imageio
import numpy as np
import cv2
from einops import rearrange

# data
from torch.utils.data import DataLoader
from datasets import dataset_dict
from datasets.ray_utils import axisangle_to_R, get_rays

# models
from kornia.utils.grid import create_meshgrid3d
from models.networks import NGP
from models.rendering import render, MAX_SAMPLES

# optimizer, losses
from apex.optimizers import FusedAdam
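# FusedAdam requires NVIDIA apex (https://github.com/NVIDIA/apex); if apex is
# unavailable, torch.optim.Adam(net_params, lr, eps=1e-15) would be a slower
# drop-in substitute (an assumption, not part of the original setup).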
from torch.optim.lr_scheduler import CosineAnnealingLR
from losses import NeRFLoss

# metrics
from torchmetrics import (
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure
)
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# pytorch-lightning
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available

# custom
from utils import slim_ckpt, load_ckpt
from models_ts.global_val import Global_instance
from models_ts import RAIN
from models_ts.RAIN import Net as RAIN_net
import utils_ts.utils as my_utils

# torch.set_float32_matmul_precision('medium')
# TF_ENABLE_ONEDNN_OPTS=0

# Constrain all sources of randomness
seed = 29
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
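# For stricter run-to-run determinism one could additionally set
# torch.backends.cudnn.deterministic = True and
# torch.backends.cudnn.benchmark = False; both are left unset here to
# preserve training speed.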


def depth2img(depth):
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    depth_img = cv2.applyColorMap((depth * 255).astype(np.uint8),
                                  cv2.COLORMAP_TURBO)

    return depth_img

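# Usage sketch for depth2img: it expects a non-constant (H, W) float array and
# returns an (H, W, 3) uint8 BGR colormap image, e.g.
#   depth_img = depth2img(results['depth'].cpu().numpy().reshape(h, w))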

class NeRFSystem(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.warmup_steps = 256
        self.update_interval = 16

        self.loss = NeRFLoss(lambda_distortion=self.hparams.distortion_loss_w)
        self.train_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_psnr = PeakSignalNoiseRatio(data_range=1)
        self.val_ssim = StructuralSimilarityIndexMeasure(data_range=1)
        if self.hparams.eval_lpips:
            self.val_lpips = LearnedPerceptualImagePatchSimilarity('vgg')
            for p in self.val_lpips.net.parameters():
                p.requires_grad = False

        rgb_act = 'None' if self.hparams.use_exposure else 'Sigmoid'

        self.model = NGP(scale=self.hparams.scale, rgb_act=rgb_act, stage=self.hparams.stage)

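        # Occupancy-grid buffers used by the ray marcher: one G^3 density
        # grid per cascade, plus integer coordinates for every grid cell.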
        G = self.model.grid_size
        self.model.register_buffer('density_grid',
                                   torch.zeros(self.model.cascades, G ** 3))
        self.model.register_buffer('grid_coords',
                                   create_meshgrid3d(G, G, G, False, dtype=torch.int32).reshape(-1, 3))

        """custom initialization"""
        if hparams.stage == "second_stage":
            # Freeze the geometry (xyz) encoder in the second stage
            for param in self.model.xyz_encoder.parameters():
                param.requires_grad = False

            # Create vgg and fc_encoder in RAIN_net
            vgg = RAIN.vgg
            fc_encoder = RAIN.fc_encoder
            # Load pretrained weights of vgg and fc_encoder
            vgg.load_state_dict(torch.load(hparams.vgg_pretrained_path))
            fc_encoder.load_state_dict(torch.load(hparams.fc_encoder_pretrained_path))
            vgg = nn.Sequential(*list(vgg.children())[:31])
            self.RAIN_net = RAIN_net(vgg, fc_encoder).to(device)

            # Freeze RAIN_net
            for param in self.RAIN_net.parameters():
                param.requires_grad = False

        # Whether to turn on the nearest neighbor finder
        Global_instance.clip_loss.set_ArtBench_search(self.hparams.enable_ArtBench_search)

    def forward(self, batch, split):
        if split == 'train':
            poses = self.poses[batch['img_idxs']]
            directions = self.directions[batch['pix_idxs']]
        else:
            poses = batch['pose']
            directions = self.directions

        if self.hparams.optimize_ext:
            dR = axisangle_to_R(self.dR[batch['img_idxs']])
            poses[..., :3] = dR @ poses[..., :3]
            poses[..., 3] += self.dT[batch['img_idxs']]

        rays_o, rays_d = get_rays(directions, poses)

        kwargs = {'test_time': split != 'train',
                  'random_bg': self.hparams.random_bg}
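        # Larger scenes (scale > 0.5) march rays with exponentially growing
        # step sizes (exp_step_factor) rather than uniform steps.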
        if self.hparams.scale > 0.5:
            kwargs['exp_step_factor'] = 1 / 256
        if self.hparams.use_exposure:
            kwargs['exposure'] = batch['exposure']

        return render(self.model, rays_o, rays_d, **kwargs)

    def setup(self, stage):
        dataset = dataset_dict[self.hparams.dataset_name]
        kwargs = {'root_dir': self.hparams.root_dir,
                  'downsample': self.hparams.downsample}
        self.train_dataset = dataset(split=self.hparams.split, **kwargs)
        self.train_dataset.batch_size = self.hparams.batch_size
        self.train_dataset.ray_sampling_strategy = self.hparams.ray_sampling_strategy

        self.test_dataset = dataset(split='test', **kwargs)

    def configure_optimizers(self):
        # define additional parameters
        self.register_buffer('directions', self.train_dataset.directions.to(self.device))
        self.register_buffer('poses', self.train_dataset.poses.to(self.device))

        if self.hparams.optimize_ext:
            N = len(self.train_dataset.poses)
            self.register_parameter('dR',
                                    nn.Parameter(torch.zeros(N, 3, device=self.device)))
            self.register_parameter('dT',
                                    nn.Parameter(torch.zeros(N, 3, device=self.device)))

        load_ckpt(self.model, self.hparams.weight_path)

        net_params = []
        for n, p in self.named_parameters():
            if n not in ['dR', 'dT']:
                net_params += [p]

        opts = []
        self.net_opt = FusedAdam(net_params, self.hparams.lr, eps=1e-15)
        opts += [self.net_opt]
        if self.hparams.optimize_ext:
            opts += [FusedAdam([self.dR, self.dT], 1e-6)]  # learning rate is hard-coded
        net_sch = CosineAnnealingLR(self.net_opt,
                                    self.hparams.num_epochs,
                                    self.hparams.lr / 30)

        return opts, [net_sch]

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          num_workers=16,
                          persistent_workers=True,
                          batch_size=None,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.test_dataset,
                          num_workers=8,
                          batch_size=None,
                          pin_memory=True)

    def on_train_start(self):
        self.model.mark_invisible_cells(self.train_dataset.K.to(self.device),
                                        self.poses,
                                        self.train_dataset.img_wh)

    def training_step(self, batch, batch_nb, *args):
        # self.hparams.is_valid = False  # do not touch this line

        if self.hparams.is_valid:
            loss = torch.zeros(1, requires_grad=True).to(device)
            return loss

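        # Periodically refresh the occupancy (density) grid so empty space
        # can be skipped during ray marching.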
        if self.global_step % self.update_interval == 0:
            self.model.update_density_grid(0.01 * MAX_SAMPLES / 3 ** 0.5,
                                           warmup=self.global_step < self.warmup_steps,
                                           erode=self.hparams.dataset_name == 'colmap')

        if self.hparams.stage == "second_stage":

            W, H = batch["W"], batch["H"]

            if self.hparams.enable_random_sampling:
                with torch.no_grad():
                    results = self(batch, split='train')  # full-image render without gradients
                """
                n = 4  # Splitting factor, suggested to be Nth power of 2
                w, h = int(batch["W"] / n), int(batch["H"] / n)
                idxs = np.random.choice(n, 2)  # Randomly select n indexes
                i, j = idxs[0], idxs[1]
                grad_idxs = np.arange(0, H * W).reshape(H, W)[i * h:(i + 1) * h, j * w:(j + 1) * w]
                grad_idxs = grad_idxs.flatten()
                """
                n = 7
                # sample n * 8192 pixels (with replacement) to carry the gradient
                grad_idxs = np.random.choice(H * W, n * 8192)
                batch_grad = batch.copy()
                batch_grad["pix_idxs"] = batch_grad["pix_idxs"][grad_idxs]
                batch_grad["rgb"] = batch_grad["rgb"][grad_idxs]

            if self.hparams.is_valid:
                with torch.no_grad():
                    if self.hparams.enable_random_sampling:
                        results_grad = self(batch_grad, split='train')
                    else:
                        results = self(batch, split='train')
            else:
                if self.hparams.enable_random_sampling:
                    results_grad = self(batch_grad, split='train')
                else:
                    results = self(batch, split='train')

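            # Stitch the gradient-carrying pixels back into the full no-grad
            # render so image-level losses see a complete frame, while
            # gradients flow only through the sampled subset.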
            if self.hparams.enable_random_sampling:
                results["rgb"][grad_idxs] = results_grad["rgb"]

            rgb_gt = batch["rgb"].reshape(H, W, 3)
            rgb_result = results["rgb"].reshape(H, W, 3).permute(2, 0, 1).unsqueeze(0)

            # VGG loss
            content_feat_pred = self.RAIN_net.get_content_feat(rgb_result)
            content_feat_gt = self.RAIN_net.get_content_feat(rgb_gt.permute(2, 0, 1).unsqueeze(0))
            content_loss = my_utils.get_content_loss(content_feat_gt, content_feat_pred)

            # CLIP direction loss
            clip_loss = Global_instance.clip_loss(rgb_gt.permute(2, 0, 1).unsqueeze(0), "photo", rgb_result,
                                                  self.hparams.style_target)
            # HCL loss
            hcl_loss = Global_instance.clip_loss_nce(rgb_gt.permute(2, 0, 1).unsqueeze(0), "photo", rgb_result,
                                                     self.hparams.style_target, lamda=0)  # default lamda=0

            if self.hparams.enable_NeRF_loss:
                loss_d = self.loss(results_grad, batch_grad, first_stage=False)  # NeRF loss
            else:
                loss_d = {}
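            # Combine the style-transfer losses; the weights below are
            # hard-coded (presumably tuned empirically for this pipeline)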
            loss_d["content_loss"] = content_loss * 0.01
            loss_d["hcl_loss"] = hcl_loss * 1
            loss_d["clip_loss"] = clip_loss * 28
            loss = sum(lo.mean() for lo in loss_d.values())
        else:
            results = self(batch, split='train')
            loss_d = self.loss(results, batch, first_stage=True)
            if self.hparams.use_exposure:
                zero_radiance = torch.zeros(1, 3, device=self.device)
                unit_exposure_rgb = self.model.log_radiance_to_rgb(zero_radiance,
                                                                   **{'exposure': torch.ones(1, 1, device=self.device)})
                loss_d['unit_exposure'] = \
                    0.5 * (unit_exposure_rgb - self.train_dataset.unit_exposure_rgb) ** 2
            loss = sum(lo.mean() for lo in loss_d.values())

        torch.cuda.empty_cache()

        with torch.no_grad():
            self.train_psnr(results['rgb'], batch['rgb'])
        self.log('lr', self.net_opt.param_groups[0]['lr'])
        self.log('train/loss', loss)
        # ray marching samples per ray (occupied space on the ray)
        self.log('train/rm_s', results['rm_samples'] / len(batch['rgb']), True)
        # volume rendering samples per ray (stops marching when transmittance drops below 1e-4)
        self.log('train/vr_s', results['vr_samples'] / len(batch['rgb']), True)
        self.log('train/psnr', self.train_psnr, True)

        return loss

    def on_validation_start(self):
        torch.cuda.empty_cache()
        if not self.hparams.no_save_test:
            self.val_dir = f'results/{self.hparams.dataset_name}/{self.hparams.exp_name}'
            os.makedirs(self.val_dir, exist_ok=True)

    def validation_step(self, batch, batch_nb):
        rgb_gt = batch['rgb']
        results = self(batch, split='test')

        logs = {}
        # compute each metric per image
        self.val_psnr(results['rgb'], rgb_gt)
        logs['psnr'] = self.val_psnr.compute()
        self.val_psnr.reset()

        w, h = self.train_dataset.img_wh
        rgb_pred = rearrange(results['rgb'], '(h w) c -> 1 c h w', h=h)
        rgb_gt = rearrange(rgb_gt, '(h w) c -> 1 c h w', h=h)
        self.val_ssim(rgb_pred, rgb_gt)
        logs['ssim'] = self.val_ssim.compute()
        self.val_ssim.reset()
        if self.hparams.eval_lpips:
            self.val_lpips(torch.clip(rgb_pred * 2 - 1, -1, 1),
                           torch.clip(rgb_gt * 2 - 1, -1, 1))
            logs['lpips'] = self.val_lpips.compute()
            self.val_lpips.reset()

        if not self.hparams.no_save_test:  # save test image to disk
            idx = batch['img_idxs']
            rgb_pred = rearrange(results['rgb'].cpu().numpy(), '(h w) c -> h w c', h=h)
            rgb_pred = (rgb_pred * 255).astype(np.uint8)
            depth = depth2img(rearrange(results['depth'].cpu().numpy(), '(h w) -> h w', h=h))
            imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}.png'), rgb_pred)
            imageio.imsave(os.path.join(self.val_dir, f'{idx:03d}_d.png'), depth)

        return logs

    def validation_epoch_end(self, outputs):
        psnrs = torch.stack([x['psnr'] for x in outputs])
        mean_psnr = all_gather_ddp_if_available(psnrs).mean()
        self.log('test/psnr', mean_psnr, True)

        ssims = torch.stack([x['ssim'] for x in outputs])
        mean_ssim = all_gather_ddp_if_available(ssims).mean()
        self.log('test/ssim', mean_ssim)

        if self.hparams.eval_lpips:
            lpipss = torch.stack([x['lpips'] for x in outputs])
            mean_lpips = all_gather_ddp_if_available(lpipss).mean()
            self.log('test/lpips_vgg', mean_lpips)

    def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items


if __name__ == '__main__':
    hparams = get_opts()
    if hparams.val_only and (not hparams.ckpt_path):
        raise ValueError('You need to provide a --ckpt_path for validation!')
    system = NeRFSystem(hparams)

    # Set the current training stage
    Global_instance.set_current_stage(hparams.stage)
    ckpt_num_epochs = 50
    ckpt_cb = ModelCheckpoint(dirpath=f'ckpts/{hparams.dataset_name}/{hparams.exp_name}',
                              filename='{epoch:d}',
                              save_weights_only=True,
                              every_n_epochs=ckpt_num_epochs,
                              save_on_train_epoch_end=True,
                              save_top_k=-1, save_last=True)
    callbacks = [ckpt_cb, TQDMProgressBar(refresh_rate=1)]

    logger = TensorBoardLogger(save_dir=f"logs/{hparams.dataset_name}",
                               name=hparams.exp_name,
                               default_hp_metric=False)

    trainer = Trainer(max_epochs=hparams.num_epochs,
                      check_val_every_n_epoch=hparams.num_epochs,
                      callbacks=callbacks,
                      logger=logger,
                      enable_model_summary=False,
                      accelerator='gpu',
                      devices=hparams.num_gpus,
                      strategy=DDPPlugin(find_unused_parameters=False)
                      if hparams.num_gpus > 1 else None,
                      num_sanity_val_steps=-1 if hparams.val_only else 0,
                      precision=16, amp_backend="apex", amp_level="O1")

    trainer.fit(system, ckpt_path=hparams.ckpt_path)

    if not hparams.val_only:  # save slimmed ckpt for the last epoch

        ckpt_ = \
            slim_ckpt(f'ckpts/{hparams.dataset_name}/{hparams.exp_name}/last.ckpt',
                      save_poses=hparams.optimize_ext)

        torch.save(ckpt_, f'ckpts/{hparams.dataset_name}/{hparams.exp_name}/epoch={hparams.num_epochs - 1}_slim.ckpt')

    if (not hparams.no_save_test) and hparams.dataset_name == 'nsvf' and 'Synthetic' in hparams.root_dir:  # save video
        imgs = sorted(glob.glob(os.path.join(system.val_dir, '*.png')))
        imageio.mimsave(os.path.join(system.val_dir, 'rgb.mp4'),
                        [imageio.imread(img) for img in imgs[::2]],
                        fps=30, macro_block_size=1)
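
# Example invocation (a sketch: the exact flag names are defined in opt.py,
# so the spellings below are assumptions based on the hparams used above):
#   python train.py --root_dir <data_dir> --dataset_name colmap \
#       --exp_name my_scene --num_epochs 30 --stage first_stage
#   python train.py --root_dir <data_dir> --dataset_name colmap \
#       --exp_name my_scene_style --stage second_stage \
#       --style_target "oil painting" \
#       --weight_path ckpts/colmap/my_scene/epoch=29_slim.ckpt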