https://github.com/OpenSVBRDF/OpenSVBRDF_source_code
Tip revision: 27d60e6e95bc66f65d1780b114f2aa228992e7f4 authored by Xiaohe Ma on 29 October 2024, 07:13:38 UTC
Merge pull request #7 from BeAShaper/main
Merge pull request #7 from BeAShaper/main
Tip revision: 27d60e6
setup_config.py
import os
import cv2
from typing import Tuple, Dict
import torch
import numpy as np
from camera import Camera
class SetupConfig(object):
    """
    This class manages the setup of lightstage.

    The setup should be saved in `wallet_of_torch_renderer` folder, include many
    configs: `cam_pos`, `lights_data`, `lights_angular`, `lights_intensity`,
    `visualize_map`, `rgb_tensor` and so on.

    Everything is loaded once in `__init__` and stored on the instance as
    numpy data; the `get_*` accessors convert to torch tensors on demand.
    """

    # Geometry of the light board assumed by the index helpers below:
    # `x` runs over LIGHT_ROWS (vertical), `y` over LIGHT_COLS (horizontal).
    LIGHT_ROWS = 48
    LIGHT_COLS = 64
    # Edge length of the square group of lights merged into one light when
    # down/up-sampling (4x4 -> 3072 // 16 = 192 lights).
    DOWNSAMPLE_FACTOR = 4

    def __init__(
        self,
        config_dir: str,
        mask: torch.Tensor = None,
        low_res: bool = False,
    ) -> None:
        """
        Args:
            config_dir: the directory of the config
            mask: a bool tensor to remove unwanted lights. The shape is
                (lightnum, ), if true, the light will be kept, else
                the light will be removed.
            low_res: low camera resolution, usually for optix simulation.
        """
        self.config_dir = config_dir
        self.mask = mask
        self.camera = Camera(self.config_dir, low_res)

        self.light_poses, self.light_normals = self._load_light_data(
            self._config_path("lights.bin"))
        self.lights_angular = self._load_lights_angular(
            self._config_path("lights_angular.npy"))
        self.lights_intensity = self._load_lights_intensity(
            self._config_path("lights_intensity.npy"))
        self.img_size, self.visualize_map = self._load_visualize_map(
            self._config_path("visualize_config_torch.bin"))
        self.rgb_tensor = self._load_rgb_tensor(
            self._config_path("color_tensor.bin"))
        self.mask_data = self._load_mask_data(
            self._config_path("mask.npy"))
        self.trained_idx = self._load_trained_idx(
            self._config_path("trained_idx.txt"))

    def _config_path(self, file_name: str) -> str:
        """
        Build the path of a file inside `config_dir`.

        `os.path.join` yields the same path as plain concatenation when
        `config_dir` already ends with a separator, and inserts the
        separator when it does not.
        """
        return os.path.join(self.config_dir, file_name)

    def _load_light_data(
            self, config_file: str) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load light position and normal from config file.

        TODO: Attention! The normal of lights is reversed! You should generate new setup config.

        Args:
            config_file: The config file is a bin file, and has data of shape
                (2, lightnum, 3), dtype=float32. In the first dim, 0 is position
                and 1 is normal.

        Returns:
            light_poses: a ndarray of shape (lightnum, 3)
            light_normals: a ndarray of shape (lightnum, 3)
            When `self.mask` is set, only the lights selected by the mask
            are returned.
        """
        light_data = np.fromfile(config_file, np.float32).reshape([2, -1, 3])
        if self.mask is None:
            return light_data[0], light_data[1]
        return light_data[0, self.mask, :], light_data[1, self.mask, :]

    def _load_lights_angular(self, config_file: str) -> np.ndarray:
        """
        Load angular distribution parameters of light intensity.

        The distribution is fitted with two cubic polynomials, therefore each
        light is with 8 parameters. See `compute_light_distribution` function
        for more detail.

        Args:
            config_file: a npy file.

        Returns:
            lights_angular: a ndarray of shape (1, lightnum, 8), or None when
            `config_file` does not exist.
        """
        if not os.path.isfile(config_file):
            return None
        lights_angular = np.load(config_file)
        if self.mask is not None:
            # NOTE(review): indexing the first axis with 0 drops the leading
            # dimension, so the masked result is (kept_lights, 8) rather
            # than the documented (1, lightnum, 8). Kept as-is because
            # callers may rely on it -- confirm the intended shape.
            lights_angular = lights_angular[0, self.mask, :]
        return lights_angular

    def _load_lights_intensity(self, config_file: str) -> np.ndarray:
        """
        Load the relative intensity of lights.

        Args:
            config_file: a npy file

        Returns:
            lights_intensity: a ndarray of shape (lightnum, 3), or None when
            `config_file` does not exist.
        """
        if not os.path.isfile(config_file):
            return None
        lights_intensity = np.load(config_file)
        if self.mask is not None:
            lights_intensity = lights_intensity[self.mask, :]
        return lights_intensity

    def _load_visualize_map(
            self, config_file: str) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load visualize mapping data from config file.

        The sequence of lights in `light_poses` data. for example: our
        light board has 48*64 lights. The lights sequence in `light_poses`
        starts from right bottom corner, from bottom to top, from left to
        right. Then the map idx sequence will be 63 47 63 46 63 45 ...
        63 0 62 47 62 46 ... 0 0.

        Args:
            config_file: The config file is a bin file with dtype int32. The
                first 2 integers are the image size, after are map idx.

        Returns:
            img_size: a ndarray of shape (2, ), the first two integers of
                the file in file order.
                NOTE(review): the file layout was described as (W, H) while
                the accessors label the pair (H, W) -- confirm which holds.
            visualize_map: a ndarray of shape (lightnum, 2)
        """
        with open(config_file, "rb") as pf:
            # Both reads share the file handle: the second one consumes
            # whatever follows the two header integers.
            img_size = np.fromfile(pf, np.int32, 2)
            visualize_map = np.fromfile(pf, np.int32).reshape([-1, 2])
        return img_size, visualize_map

    def _load_rgb_tensor(self, config_file: str) -> np.ndarray:
        """
        Load rgb tensor from config file

        Args:
            config_file: a bin file holding 27 float32 values.

        Returns:
            rgb_tensor: a float32 ndarray of shape (3, 3, 3),
            (light, object, camera). When the file is missing, a default
            tensor with ones on the main diagonal is returned.
        """
        if os.path.isfile(config_file):
            return np.fromfile(config_file, np.float32).reshape([3, 3, 3])

        print("Note: color tensor file is not found, use default color tensor!")
        # Same dtype as the file-backed tensor, so downstream torch code
        # sees float32 in both cases.
        rgb_tensor = np.zeros((3, 3, 3), dtype=np.float32)
        for channel in range(3):
            rgb_tensor[channel, channel, channel] = 1.0
        return rgb_tensor

    def _load_mask_data(self, config_file: str) -> Dict:
        """
        Load mask data from config file

        Args:
            config_file: a npy file

        Returns:
            A dict contains `mask_anchor`, `mask_axis_x`, `mask_axis_y`
        """
        # SECURITY: allow_pickle=True executes pickle data stored in the
        # file; only load mask files from a trusted source.
        mask_data = np.load(config_file, allow_pickle=True).item()
        return mask_data

    def _load_trained_idx(self, config_file: str):
        """
        Load the trained indices from a text file.

        Args:
            config_file: a text file readable by `np.loadtxt`.

        Returns:
            The ndarray produced by `np.loadtxt` (default float dtype) when
            the file exists, otherwise an empty list.
        """
        if os.path.exists(config_file):
            return np.loadtxt(config_file)
        return []

    def print_configs(self) -> None:
        """
        Print the configs
        """
        print("[SETUP CONFIG]")
        # The instance has no `cam_pos` attribute; the camera position is
        # owned by `self.camera`, so read it through the accessor.
        print("cam_pos:", self.get_cam_pos("cpu"))

    def get_rgb_tensor(self, device: str) -> torch.Tensor:
        """
        Returns:
            rgb_tensor: a tensor of shape (3, 3, 3), (light, object, camera)
        """
        return torch.from_numpy(self.rgb_tensor).to(device)

    def get_cam_pos(self, device: str) -> torch.Tensor:
        """
        Returns:
            cam_pos: shape (3, )
        """
        return self.camera.get_cam_pos(device)

    def get_lights_intensity(self, device: str) -> torch.Tensor:
        """
        Returns:
            lights_intensity: a tensor of shape (lightnum, 3).

        Fails if the intensity file was missing at load time, because the
        stored value is then None.
        """
        return torch.from_numpy(self.lights_intensity).to(device)

    def get_lights_angular(self, device: str) -> torch.Tensor:
        """
        Returns:
            lights_angular: a tensor of shape (1, lightnum, 8)

        Fails if the angular file was missing at load time, because the
        stored value is then None.
        """
        return torch.from_numpy(self.lights_angular).to(device)

    def get_light_normals(self, device: str) -> torch.Tensor:
        """
        Returns:
            light_normals: a tensor of shape (lightnum, 3)
        """
        return torch.from_numpy(self.light_normals).to(device)

    def get_light_poses(self, device: str) -> torch.Tensor:
        """
        Returns:
            light_poses: a tensor of shape (lightnum, 3)
        """
        return torch.from_numpy(self.light_poses).to(device)

    def get_light_idx(self, light_pos: np.ndarray) -> int:
        """
        Get the nearest light

        Args:
            light_pos: ndarray, (1, 3)

        Returns:
            Index into `light_poses` of the light with the smallest squared
            euclidean distance to `light_pos`.
        """
        assert light_pos.shape[0] == 1 and light_pos.shape[1] == 3, \
            "light_pos must have shape (1, 3)"
        offset = self.light_poses - light_pos
        squared_dist = np.sum(offset * offset, axis=1)
        return int(np.argmin(squared_dist))

    def get_trained_idx(self) -> list:
        """
        Returns:
            trained_idx: the trained indices exactly as loaded -- the
            ndarray read from `trained_idx.txt`, or an empty list when that
            file did not exist.
        """
        return self.trained_idx

    def get_vis_img_size(self) -> Tuple[int, int]:
        """
        Returns:
            H, W (the two stored size integers, in stored order)
        """
        return self.img_size[0], self.img_size[1]

    def get_visualize_map(self, device: str) -> torch.Tensor:
        """
        Returns:
            visualize_map: a long tensor of shape (lightnum, 2)
        """
        return torch.from_numpy(self.visualize_map).long().to(device)

    def get_mask_data(
            self, device: str
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            mask_anchor, mask_axis_x, mask_axis_y: the three entries of the
            mask dict, each converted to a float tensor on `device`.
        """
        anchor, axis_x, axis_y = (
            torch.from_numpy(self.mask_data[key]).float().to(device)
            for key in ("mask_anchor", "mask_axis_x", "mask_axis_y")
        )
        return anchor, axis_x, axis_y

    def get_light_num(self) -> int:
        """
        Returns:
            Number of lights, taken from the first dim of `light_normals`.
        """
        return self.light_normals.shape[0]

    @classmethod
    def convert_xy_to_idx(cls, x, y):
        """
        Args:
            x, y: left top is (0, 0), x is vertical

        Returns:
            idx: light index, right bottom is 0, right top is 47.
        """
        return cls.LIGHT_ROWS * (cls.LIGHT_COLS - 1 - y) + (cls.LIGHT_ROWS - 1 - x)

    @classmethod
    def convert_idx_to_xy(cls, idx):
        """
        Inverse of `convert_xy_to_idx`.

        Args:
            idx: light index

        Returns:
            x, y: board coordinates, left top is (0, 0), x is vertical
        """
        x = cls.LIGHT_ROWS - 1 - idx % cls.LIGHT_ROWS
        y = cls.LIGHT_COLS - 1 - idx // cls.LIGHT_ROWS
        return x, y

    @classmethod
    def get_downsampled_light_poses(cls, ori_light_poses):
        """
        Pick one representative light out of every 4x4 group.

        Args:
            ori_light_poses: (3072, 3)

        Returns:
            light_poses downsampled, a float64 ndarray of shape (192, 3),
            ordered row by row over the coarse grid (row = x, col = y).
        """
        step = cls.DOWNSAMPLE_FACTOR
        center = step // 2
        rows = step * np.arange(cls.LIGHT_ROWS // step) + center
        cols = step * np.arange(cls.LIGHT_COLS // step) + center
        # Broadcasting rows against cols and flattening row-major gives the
        # same order as looping rows in the outer and cols in the inner loop.
        ori_idx = cls.convert_xy_to_idx(rows[:, None], cols[None, :]).reshape(-1)
        return np.asarray(ori_light_poses, dtype=np.float64)[ori_idx]

    @classmethod
    def upsample_data(
        cls,
        light_data: np.ndarray
    ):
        """
        Spread every coarse light value over its 4x4 group of full-res lights.

        Args:
            light_data: (192, *)

        Returns:
            A float64 ndarray of shape (3072, light_data.shape[1]); rows that
            no coarse light maps to stay zero.
        """
        step = cls.DOWNSAMPLE_FACTOR
        coarse_cols = cls.LIGHT_COLS // step
        new_light_data = np.zeros(
            (cls.LIGHT_ROWS * cls.LIGHT_COLS, light_data.shape[1]))
        for i in range(light_data.shape[0]):
            row, col = divmod(i, coarse_cols)
            xs = np.arange(step * row, step * (row + 1))
            ys = np.arange(step * col, step * (col + 1))
            idx = cls.convert_xy_to_idx(xs[:, None], ys[None, :]).reshape(-1)
            new_light_data[idx] = light_data[i]
        return new_light_data