swh:1:snp:eee76444da62e238a10272cb71070ca8823b3f3d
Tip revision: da207d03e7994d9c5a097126dcd509abedc26bc0 authored by zachzhang07 on 21 November 2024, 08:07:14 UTC
Update readme.md
grid_utils.py
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Triplane and voxel grid simulation."""
import torch
import itertools

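# After contraction, points lie within [-2, 2]^3 (see coord.contract). Without
# contraction they lie within [-opt.bound, opt.bound]^3. The bounds are computed
# in calculate_grid_config and passed to the functions below as WORLD_MIN and
# WORLD_MAX instead of being module-level constants.
# WORLD_MIN = -2.0
# WORLD_MAX = 2.0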


def calculate_voxel_size(resolution, WORLD_MAX, WORLD_MIN):
    return (WORLD_MAX - WORLD_MIN) / resolution


def grid_to_world(x_grid, voxel_size, WORLD_MAX, WORLD_MIN):
    """Converts grid coordinates [0, res]^3 to world coordinates.

    World coordinates lie in [WORLD_MIN, WORLD_MAX]^3, i.e. [-2, 2]^3 when
    scene contraction is used.
    """
    # We also account for the fact that the grid is going to be queried in WebGL
    # by adopting WebGL's indexing: Y and Z coordinates are swapped and the X
    # coordinate is mirrored. Inverse of world_to_grid.

    x = torch.zeros_like(x_grid)

    def get_x():
        return (WORLD_MAX - voxel_size / 2) - voxel_size * x_grid[..., 0]

    def get_yz():
        return (WORLD_MIN + voxel_size / 2) + voxel_size * x_grid[..., [2, 1]]

    x[..., 0] = get_x()
    x[..., [1, 2]] = get_yz()
    return x


def world_to_grid(x, voxel_size, WORLD_MAX, WORLD_MIN):
    """Converts world coordinates to grid coordinates [0, res]^3.

    World coordinates lie in [WORLD_MIN, WORLD_MAX]^3, i.e. [-2, 2]^3 when
    scene contraction is used.
    """
    # Inverse of grid_to_world.
    x_grid = torch.zeros_like(x)

    def get_x():
        return ((WORLD_MAX - voxel_size / 2) - x[..., 0]) / voxel_size

    def get_yz():
        return (x[..., [1, 2]] - (WORLD_MIN + voxel_size / 2)) / voxel_size

    x_grid[..., 0] = get_x()
    x_grid[..., [2, 1]] = get_yz()
    return x_grid


def calculate_num_evaluations_per_sample(opt):
    """Calculates number of MLP evals required for each sample along the ray."""
    # For grid evaluation, we need to evaluate the MLP multiple times per query
    # to the representation. The number of samples depends on whether a sparse
    # grid, triplanes or both are used.
    assert (
            opt.triplane_resolution > 0 or opt.grid_resolution > 0
    ), 'one or both of these values needs to be specified'
    x = 0
    if opt.triplane_resolution > 0:
        x += 3 * 4  # Three planes and 4 samples for bi-linear interpolation.
    if opt.grid_resolution > 0:
        x += 8  # Tri-linear interpolation.
    return x


def calculate_grid_config(opt):
    """Computes voxel sizes from grid resolutions."""
    # `voxel_size_to_use` is, for instance, used to infer the step size used
    # during rendering, which should equal the voxel size of the finest grid
    # (triplane or sparse grid) that is used.
    if opt.contract:
        WORLD_MIN, WORLD_MAX = -2.0, 2.0
    else:
        WORLD_MIN, WORLD_MAX = -opt.bound, opt.bound

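    # A resolution of 0 disables that representation. Skip the division in that
    # case so that configurations using only one of the two grids do not raise
    # ZeroDivisionError.
    triplane_voxel_size = (
        calculate_voxel_size(opt.triplane_resolution, WORLD_MAX, WORLD_MIN)
        if opt.triplane_resolution > 0
        else None
    )
    sparse_grid_voxel_size = (
        calculate_voxel_size(opt.grid_resolution, WORLD_MAX, WORLD_MIN)
        if opt.grid_resolution > 0
        else None
    )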
    # Assuming that triplane_resolution is higher than sparse_grid_resolution
    # when the triplanes are used.
    voxel_size_to_use = (
        triplane_voxel_size
        if opt.triplane_resolution > 0
        else sparse_grid_voxel_size
    )
    resolution_to_use = (
        opt.triplane_resolution
        if opt.triplane_resolution > 0
        else opt.grid_resolution
    )
    return dict(
        triplane_voxel_size=triplane_voxel_size,
        sparse_grid_voxel_size=sparse_grid_voxel_size,
        voxel_size_to_use=voxel_size_to_use,
        resolution_to_use=resolution_to_use,
        WORLD_MIN=WORLD_MIN,
        WORLD_MAX=WORLD_MAX,
    )


def get_eval_positions_and_local_coordinates(positions, opt, grid_config):
    """Computes MLP query positions for a batch of positions of shape Lx3."""
    # Prepare grid simulation: the returned `positions` has shape S*Lx3, where
    #   S = 1 if no grid simulation is used (no-op)
    #   S = 8 if only a 3D grid is simulated
    #   S = 3*4 = 12 if only tri-planes are simulated
    #   S = 20 if both tri-planes and a 3D grid are used (MERF)
    #   see: calculate_num_evaluations_per_sample
    #
    # For every query to our grid-based representation we have to perform S
    # queries to the grid which is parameterized by an MLP.
    #
    # Further we compute positions (∈ [0,1]^3) local to a texel/voxel,
    # which are later used to compute interpolation weights:
    #     triplane_positions_local, sparse_grid_positions_local: Lx3
    triplane_positions_local = sparse_grid_positions_local = None
    if opt.triplane_resolution > 0:
        if opt.grid_resolution > 0:
            sparse_grid_positions, sparse_grid_positions_local = (
                sparse_grid_get_eval_positions_and_local_coordinates(
                    positions, grid_config['sparse_grid_voxel_size'], axis=1,
                    WORLD_MAX=grid_config['WORLD_MAX'], WORLD_MIN=grid_config['WORLD_MIN']
                )
            )  # 8*Lx3 and Lx3.
        positions, triplane_positions_local = (
            triplane_get_eval_posititons_and_local_coordinates(
                positions, grid_config['triplane_voxel_size'], axis=1,
                WORLD_MAX=grid_config['WORLD_MAX'], WORLD_MIN=grid_config['WORLD_MIN']
            )
        )  # 12*Lx3 and Lx3.
        if opt.grid_resolution > 0:
            # Concatenate sparse grid and triplane positions for MERF.
            positions = torch.cat([sparse_grid_positions, positions], dim=1)
    else:  # implies opt.grid_resolution > 0.
        positions, sparse_grid_positions_local = (
            sparse_grid_get_eval_positions_and_local_coordinates(
                positions, grid_config['sparse_grid_voxel_size'], axis=1,
                WORLD_MAX=grid_config['WORLD_MAX'], WORLD_MIN=grid_config['WORLD_MIN']
            )
        )  # 8*Lx3 and Lx3.
    positions = positions.reshape(-1, *positions.shape[2:])
    return positions, triplane_positions_local, sparse_grid_positions_local


def interpolate_based_on_local_coordinates(
        y, triplane_positions_local, sparse_grid_positions_local, opt
):
    """Linearly interpolates values fetched from grid corners."""
    # Linearly interpolates values fetched from grid corners based on
    # blending weights computed from within-voxel/texel local coordinates.
    #
    # y: S*LxC
    # triplane_positions_local: Lx3
    # sparse_grid_positions_local: Lx3
    #
    # Output: LxC
    s = calculate_num_evaluations_per_sample(opt)
    y = y.reshape(-1, s, *y.shape[1:])
    if opt.triplane_resolution > 0:
        if opt.grid_resolution > 0:
            sparse_grid_y, y = y.split([8, 12], dim=1)
        r = triplane_interpolate_based_on_local_coordinates(
            y, triplane_positions_local, axis=1
        )
        if opt.grid_resolution > 0:
            r += sparse_grid_interpolate_based_on_local_coordinates(
                sparse_grid_y, sparse_grid_positions_local, axis=1
            )
        return r
    else:  # implies opt.grid_resolution > 0.
        return sparse_grid_interpolate_based_on_local_coordinates(
            y, sparse_grid_positions_local, axis=1
        )


def sparse_grid_get_eval_positions_and_local_coordinates(x, voxel_size, axis, WORLD_MAX, WORLD_MIN):
    """Compute positions of 8 surrounding voxel corners and within-voxel coords."""
    x_grid = world_to_grid(x, voxel_size, WORLD_MAX, WORLD_MIN)
    x_floor = torch.floor(x_grid)
    x_ceil = torch.ceil(x_grid)
    local_coordinates = x_grid - x_floor
    positions_corner = []
    corner_coords = [[False, True] for _ in range(x.shape[-1])]
    for z in itertools.product(*corner_coords):
        l = []
        for i, b in enumerate(z):
            l.append(x_ceil[..., i] if b else x_floor[..., i])
        positions_corner.append(torch.stack(l, dim=-1))
    positions_corner = torch.stack(positions_corner, dim=axis)
    positions_corner = grid_to_world(positions_corner, voxel_size, WORLD_MAX, WORLD_MIN)
    return positions_corner, local_coordinates


def sparse_grid_interpolate_based_on_local_coordinates(
        y, local_coordinates, axis
):
    """Blend 8 MLP outputs based on weights computed from local coordinates."""
    y = torch.moveaxis(y, axis, -2)
    res = torch.zeros(y.shape[:-2] + (y.shape[-1],)).to(y)
    corner_coords = [[False, True] for _ in range(local_coordinates.shape[-1])]
    for corner_index, z in enumerate(itertools.product(*corner_coords)):
        w = torch.ones(local_coordinates.shape[:-1]).to(y)
        for i, b in enumerate(z):
            w = w * (
                local_coordinates[..., i] if b else (1 - local_coordinates[..., i])
            )
        res = res + w[..., None] * y[..., corner_index, :]
    return res


def triplane_get_eval_posititons_and_local_coordinates(x, voxel_size, axis, WORLD_MAX, WORLD_MIN):
    """For each of the 3 planes return the 4 sampling positions at texel corners."""
    x_grid = world_to_grid(x, voxel_size, WORLD_MAX, WORLD_MIN)
    x_floor = torch.floor(x_grid)
    x_ceil = torch.ceil(x_grid)
    local_coordinates = x_grid - x_floor
    corner_coords = [
        [False, True] for _ in range(2)
    ]  # (0, 0), (0, 1), (1, 0), (1, 1).
    r = []
    for plane_idx in range(3):  # Index of the plane to project to.
        # Indices of the two out of three dims along which we interpolate bilinearly.
        inds = [h for h in range(3) if h != plane_idx]
        for z in itertools.product(*corner_coords):
            l = [None for _ in range(3)]
            l[plane_idx] = torch.zeros_like(x_grid[..., 0])
            for i, b in enumerate(z):
                l[inds[i]] = x_ceil[..., inds[i]] if b else x_floor[..., inds[i]]
            r.append(torch.stack(l, dim=-1))
    r = torch.stack(r, dim=axis)
    return grid_to_world(r, voxel_size, WORLD_MAX, WORLD_MIN), local_coordinates


def triplane_interpolate_based_on_local_coordinates(y, local_coordinates, axis):
    """Blend 3*4=12 MLP outputs based on weights computed from local coordinates."""
    y = torch.moveaxis(y, axis, -2)
    res = torch.zeros(y.shape[:-2] + (y.shape[-1],)).to(y)
    corner_coords = [[False, True] for _ in range(2)]
    query_index = 0
    for plane_idx in range(3):
        # Indices of the two out of three dims along which we interpolate bilinearly.
        inds = [h for h in range(3) if h != plane_idx]
        for z in itertools.product(*corner_coords):
            w = torch.ones(local_coordinates.shape[:-1]).to(y)
            for i, b in enumerate(z):
                w = w * (
                    local_coordinates[..., inds[i]]
                    if b
                    else (1 - local_coordinates[..., inds[i]])
                )
            res += w[..., None] * y[..., query_index, :]
            query_index += 1
    return res
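
The coordinate convention shared by `calculate_voxel_size`, `grid_to_world` and `world_to_grid` can be checked without PyTorch. The sketch below restates the three functions for a single point in plain Python. The helper names `voxel_size`, `grid_to_world_point` and `world_to_grid_point` are hypothetical and not part of the module. The sketch also assumes that integer grid coordinates address voxel centers, which is what the `voxel_size / 2` offsets in the listing suggest.

```python
def voxel_size(resolution, world_max, world_min):
    """Edge length of one voxel when `resolution` voxels span [world_min, world_max]."""
    return (world_max - world_min) / resolution


def grid_to_world_point(g, vs, world_max, world_min):
    """Grid coords -> world coords: X is mirrored, grid Y/Z feed world Z/Y."""
    x = (world_max - vs / 2) - vs * g[0]  # mirrored X axis
    y = (world_min + vs / 2) + vs * g[2]  # world Y comes from grid Z
    z = (world_min + vs / 2) + vs * g[1]  # world Z comes from grid Y
    return (x, y, z)


def world_to_grid_point(p, vs, world_max, world_min):
    """Inverse of grid_to_world_point."""
    gx = ((world_max - vs / 2) - p[0]) / vs
    gy = (p[2] - (world_min + vs / 2)) / vs
    gz = (p[1] - (world_min + vs / 2)) / vs
    return (gx, gy, gz)


# Example: 4 voxels per axis over [-2, 2] gives unit-sized voxels.
vs = voxel_size(4, 2.0, -2.0)
print(grid_to_world_point((3, 1, 2), vs, 2.0, -2.0))  # (-1.5, 0.5, -0.5)
```

Applying `world_to_grid_point` to the printed point returns `(3.0, 1.0, 2.0)`, which is consistent with the listing's comment that the two functions are inverses of each other.
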
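
The per-sample query count used throughout the module, S, follows from simple arithmetic. Bilinear interpolation on each of the three planes needs 4 texel corners (3 * 4 = 12), and trilinear interpolation in the 3D grid needs 8 voxel corners, so a configuration that uses both needs 20. The following plain-Python restatement of `calculate_num_evaluations_per_sample` uses a hypothetical function name and raises a `ValueError` in place of the `assert`. The resolutions in the loop are example values only.

```python
def num_evaluations_per_sample(triplane_resolution, grid_resolution):
    """Number of MLP queries per ray sample (mirrors calculate_num_evaluations_per_sample)."""
    if triplane_resolution <= 0 and grid_resolution <= 0:
        raise ValueError("one or both resolutions must be positive")
    s = 0
    if triplane_resolution > 0:
        s += 3 * 4  # 3 planes, 4 bilinear corners each
    if grid_resolution > 0:
        s += 8  # 8 trilinear corners
    return s


for tri, grid in [(2048, 0), (0, 512), (2048, 512)]:
    print(tri, grid, num_evaluations_per_sample(tri, grid))
# 2048 0 12
# 0 512 8
# 2048 512 20
```
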
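
`calculate_grid_config` chooses the world bounds and derives the voxel size of each representation. With scene contraction the bounds are fixed at [-2, 2]^3; otherwise they are [-opt.bound, opt.bound]^3. When triplanes are enabled, their voxel size drives the rendering step size. The sketch below mirrors that logic. `GridOptions` is a hypothetical stand-in for `opt` that exposes only the four attributes the function reads. In this sketch a disabled (zero-resolution) grid gets `None` as its voxel size instead of triggering a division by zero.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class GridOptions:
    """Hypothetical stand-in for the `opt` object; only the fields used here."""
    contract: bool
    bound: float
    triplane_resolution: int
    grid_resolution: int


def grid_config(opt: GridOptions) -> dict:
    world_min, world_max = (-2.0, 2.0) if opt.contract else (-opt.bound, opt.bound)
    extent = world_max - world_min

    def size(res: int) -> Optional[float]:
        return extent / res if res > 0 else None  # None marks a disabled grid

    use_triplane = opt.triplane_resolution > 0
    res_to_use = opt.triplane_resolution if use_triplane else opt.grid_resolution
    return dict(
        triplane_voxel_size=size(opt.triplane_resolution),
        sparse_grid_voxel_size=size(opt.grid_resolution),
        voxel_size_to_use=size(res_to_use),
        resolution_to_use=res_to_use,
        WORLD_MIN=world_min,
        WORLD_MAX=world_max,
    )


cfg = grid_config(GridOptions(contract=True, bound=1.0, triplane_resolution=2048, grid_resolution=512))
print(cfg["voxel_size_to_use"], cfg["sparse_grid_voxel_size"])  # 0.001953125 0.0078125
```
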
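
The sparse-grid path (`sparse_grid_get_eval_positions_and_local_coordinates` followed by `sparse_grid_interpolate_based_on_local_coordinates`) is ordinary trilinear interpolation split into two steps. The first step enumerates the 8 surrounding corners. The second blends the 8 MLP outputs using weights built from the within-voxel coordinates. The plain-Python sketch below combines both steps for one point in grid coordinates and keeps the module's corner order (`itertools.product` over floor/ceil per axis). The function names and the affine test function are hypothetical. Trilinear weights sum to 1 and reproduce affine functions exactly, so blending an affine function's corner values recovers its value at the query point.

```python
import itertools
import math


def trilinear_corners_and_weights(g):
    """8 integer corners around grid point g and their trilinear weights.

    Corner order follows itertools.product([False, True], repeat=3); False
    selects floor and True selects ceil along that axis, as in the module.
    """
    lo = [math.floor(c) for c in g]
    hi = [math.ceil(c) for c in g]
    local = [c - f for c, f in zip(g, lo)]  # within-voxel coordinates in [0, 1)
    corners, weights = [], []
    for z in itertools.product([False, True], repeat=3):
        corners.append(tuple(hi[i] if b else lo[i] for i, b in enumerate(z)))
        w = 1.0
        for i, b in enumerate(z):
            w *= local[i] if b else 1.0 - local[i]
        weights.append(w)
    return corners, weights


def trilinear_interpolate(f, g):
    """Blend f evaluated at the 8 corners, like the MLP outputs in the module."""
    corners, weights = trilinear_corners_and_weights(g)
    return sum(w * f(c) for c, w in zip(corners, weights))


def affine(p):
    return 2 * p[0] - p[1] + 3 * p[2] + 1


print(trilinear_interpolate(affine, (0.25, 1.5, 2.0)), affine((0.25, 1.5, 2.0)))  # 6.0 6.0
```
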
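
The triplane path applies bilinear interpolation on each of the three axis-aligned planes. For plane `plane_idx`, the grid coordinate along that axis is set to 0 and the other two axes are interpolated, giving 3 * 4 = 12 corners. `triplane_interpolate_based_on_local_coordinates` accumulates all 12 weighted outputs into one result, so the three per-plane interpolations are summed. Each plane's 4 weights add up to 1, so the 12 weights add up to 3. The sketch below reproduces the corner order and weights for one point in grid coordinates. It is plain Python, and the function name is hypothetical.

```python
import itertools
import math


def triplane_corners_and_weights(g):
    """12 texel corners (4 per plane) and bilinear weights for grid point g."""
    lo = [math.floor(c) for c in g]
    hi = [math.ceil(c) for c in g]
    local = [c - f for c, f in zip(g, lo)]
    corners, weights = [], []
    for plane_idx in range(3):  # axis that is projected away
        inds = [h for h in range(3) if h != plane_idx]  # the two interpolated axes
        for z in itertools.product([False, True], repeat=2):
            corner = [0, 0, 0]  # projected axis stays at grid coordinate 0
            w = 1.0
            for i, b in enumerate(z):
                axis = inds[i]
                corner[axis] = hi[axis] if b else lo[axis]
                w *= local[axis] if b else 1.0 - local[axis]
            corners.append(tuple(corner))
            weights.append(w)
    return corners, weights


corners, weights = triplane_corners_and_weights((0.25, 1.5, 2.75))
for k in range(0, 12, 4):
    print(corners[k:k + 4], sum(weights[k:k + 4]))
```
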
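
`get_eval_positions_and_local_coordinates` and `interpolate_based_on_local_coordinates` rely on the same point-major layout. Corners are stacked on axis 1 to give L x S x 3. When both representations are used, the 8 sparse-grid corners are concatenated before the 12 triplane corners, and the result is flattened to (L*S) x 3 before the MLP is evaluated. The interpolation step reverses this with `reshape(-1, s, ...)` and `split([8, 12], dim=1)`. The plain-Python sketch below uses labelled tuples in place of tensors to show that the flatten, regroup, and split round trip returns each point's outputs to the correct group. The helper names are hypothetical.

```python
def flatten_queries(per_point):
    """L groups of S queries -> flat list of L*S queries, point-major like reshape(-1, 3)."""
    return [q for group in per_point for q in group]


def regroup(flat, s):
    """Inverse of flatten_queries, like y.reshape(-1, s, C)."""
    if len(flat) % s:
        raise ValueError("length must be a multiple of s")
    return [flat[i:i + s] for i in range(0, len(flat), s)]


def split_merf(group, n_sparse=8, n_triplane=12):
    """Sparse-grid outputs come first, then triplane outputs (torch.cat order)."""
    if len(group) != n_sparse + n_triplane:
        raise ValueError("expected %d outputs" % (n_sparse + n_triplane))
    return group[:n_sparse], group[n_sparse:]


# Two sample points, each with 8 sparse-grid and 12 triplane queries.
per_point = [
    [(p, "grid", k) for k in range(8)] + [(p, "plane", k) for k in range(12)]
    for p in range(2)
]
flat = flatten_queries(per_point)        # what the MLP sees: 40 queries
groups = regroup(flat, 20)               # back to 2 groups of 20
sparse, planes = split_merf(groups[1])
print(len(flat), sparse[0], planes[-1])  # 40 (1, 'grid', 0) (1, 'plane', 11)
```
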
