SWHIDs (SoftWare Hash IDentifiers) for this file:
  • content: swh:1:cnt:ce9eb4f1eba08733ac310367a9dd6250e7fa31c4
  • directory: swh:1:dir:2aec7f959a197d6347e16cd0cda954702fe7482a
  • revision: swh:1:rev:da207d03e7994d9c5a097126dcd509abedc26bc0
  • snapshot: swh:1:snp:eee76444da62e238a10272cb71070ca8823b3f3d

Tip revision: da207d03e7994d9c5a097126dcd509abedc26bc0, authored by zachzhang07 on 21 November 2024, 08:07:14 UTC ("Update readme.md")
File: nerf/math.py
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mathy utility functions."""
import torch
import numpy as np


def matmul(a, b):
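    """Thin wrapper around torch.matmul.

    torch.matmul does not default to bfloat16: it computes in the inputs' dtype
    unless autocast or reduced-precision matmul settings (e.g. TF32) are enabled.
    """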
    return torch.matmul(a, b)


# def safe_trig_helper(x, fn, t=100 * torch.pi):
#     """Helper function used by safe_cos/safe_sin: mods x before sin()/cos()."""
#     return fn(torch.nan_to_num(torch.where(torch.abs(x) < t, x, x % t)))
# 
# 
# def safe_cos(x):
#     """torch.cos() on a TPU may NaN out for large values."""
#     return safe_trig_helper(x, torch.cos)
# 
# 
# def safe_sin(x):
#     """torch.sin() on a TPU may NaN out for large values."""
#     return safe_trig_helper(x, torch.sin)


def safe_exp(x):
    """torch.exp() but with finite output and gradients for large inputs."""
    # In float32, exp(89) overflows to inf while exp(88) (~1.65e38) is finite.
    return torch.exp(torch.clamp(x, max=88.0))
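The 88.0 cutoff is tied to float32 range. float32's largest finite value is about 3.4e38, so e^88 (about 1.65e38) fits and e^89 (about 4.5e38) does not. The quick NumPy check below assumes float32 inputs. In float64, the overflow point is near 709.78 instead.

```python
import numpy as np

# Evaluate exp() in float32, the dtype the clamp threshold is chosen for.
with np.errstate(over="ignore"):  # silence the expected overflow warning
    e88 = np.exp(np.float32(88.0))
    e89 = np.exp(np.float32(89.0))

print(np.finfo(np.float32).max)            # largest finite float32, ~3.4e38
print(np.isfinite(e88), np.isfinite(e89))  # True False
```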


def safe_exp_jvp(primals, tangents):
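    """JAX-style JVP rule that keeps safe_exp()'s gradient large for large x.

    Returns (exp(min(x, 88)), exp(min(x, 88)) * x_dot). Nothing in this file
    registers it with torch.autograd, so gradients of safe_exp() still come from
    torch.clamp and are zero wherever x > 88.
    """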
    [x] = primals
    [x_dot] = tangents
    exp_x = safe_exp(x)
    exp_x_dot = exp_x * x_dot
    return exp_x, exp_x_dot
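In JAX, a rule like safe_exp_jvp() is attached to a function with jax.custom_jvp. In this PyTorch file it is defined but never registered. One way to get the intended behavior in PyTorch is a custom torch.autograd.Function. The `SafeExp` class below is a hypothetical sketch and is not part of this codebase. It applies the same rule in reverse mode: backward multiplies the incoming gradient by exp(min(x, 88)).

```python
import torch


class SafeExp(torch.autograd.Function):
    """Hypothetical autograd wrapper mirroring safe_exp() + safe_exp_jvp()."""

    @staticmethod
    def forward(ctx, x):
        # Same forward pass as safe_exp(): clamp before exponentiating.
        y = torch.exp(torch.clamp(x, max=88.0))
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        # Same rule as safe_exp_jvp(): d/dx = exp(min(x, 88)), even above the clamp.
        return grad_out * y


x = torch.tensor([0.0, 100.0], requires_grad=True)
SafeExp.apply(x).sum().backward()
print(x.grad)  # approximately [1.0, 1.65e38]: finite and nonzero for the large input

# For comparison, plain clamp-then-exp gives a zero gradient above the clamp.
x2 = torch.tensor([0.0, 100.0], requires_grad=True)
torch.exp(torch.clamp(x2, max=88.0)).sum().backward()
print(x2.grad)  # [1.0, 0.0]
```

The same pattern could wrap safe_log() with safe_log_jvp()'s rule, dividing the incoming gradient by the input clamped to at least 1e-37.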


def safe_log(x):
    """torch.log() but with finite outputs/gradients for negative/huge inputs."""
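    # float32 values below ~1.18e-38 are subnormal and may be flushed to zero
    # (e.g. on TPUs), where log() returns -inf; hence the 1e-37 floor.
    return torch.log(torch.clip(x, 1e-37, 1e37))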


def safe_log_jvp(primals, tangents):
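    """JAX-style JVP rule that keeps safe_log()'s gradient finite.

    Returns (safe_log(x), x_dot / max(x, 1e-37)). Nothing in this file
    registers it with torch.autograd.
    """
    [x] = primals
    [x_dot] = tangents
    log_x = safe_log(x)
    # torch.maximum() requires two tensors, so clamp x from below instead.
    log_x_dot = x_dot / torch.clamp(x, min=1e-37)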
    return log_x, log_x_dot


def log_lerp(t, v0, v1):
    """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
    if v0 <= 0 or v1 <= 0:
        raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
    lv0 = np.log(v0)
    lv1 = np.log(v1)
    return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)
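Interpolating in log space makes the midpoint the geometric mean of the endpoints. This is why learning_rate_decay() below behaves like exponential decay. The check below uses the same formula, with the function re-declared so the snippet runs on its own. The sample values are arbitrary.

```python
import numpy as np


def log_lerp(t, v0, v1):
    """Same formula as log_lerp() above: exp(lerp(log v0, log v1, clip(t, 0, 1)))."""
    if v0 <= 0 or v1 <= 0:
        raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
    lv0, lv1 = np.log(v0), np.log(v1)
    return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)


mid = log_lerp(0.5, 1e-2, 1e-4)
print(mid)                        # ~1e-3, the geometric mean sqrt(1e-2 * 1e-4)
print(log_lerp(2.0, 1e-2, 1e-4))  # t is clipped to 1, so ~1e-4
```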


def learning_rate_decay(
        step, lr_init, lr_final, max_steps, lr_delay_steps=0, lr_delay_mult=1
):
    """Continuous learning rate decay function.

    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
    is log-linearly interpolated elsewhere (equivalent to exponential decay).
    If lr_delay_steps > 0, the learning rate is additionally scaled by a smooth
    function of lr_delay_mult, so that it is lr_init*lr_delay_mult at the start of
    optimization and is eased back to the normal learning rate once
    step >= lr_delay_steps.

    Args:
      step: int, the current optimization step.
      lr_init: float, the initial learning rate.
      lr_final: float, the final learning rate.
      max_steps: int, the number of steps during optimization.
      lr_delay_steps: int, the number of steps to delay the full learning rate.
      lr_delay_mult: float, the multiplier on the rate when delaying it.

    Returns:
      lr: the learning rate for the current step.
    """
    if lr_delay_steps > 0:
        # A kind of reverse cosine decay.
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
        )
    else:
        delay_rate = 1.0
    return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)
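The formula below restates the schedule computed above. It uses step s, max_steps S, lr_delay_steps D and lr_delay_mult m:

```latex
\mathrm{lr}(s) =
\Big[m + (1 - m)\sin\!\Big(\tfrac{\pi}{2}\,\mathrm{clip}\big(\tfrac{s}{D}, 0, 1\big)\Big)\Big]
\exp\!\Big((1 - t)\log \mathrm{lr}_{\mathrm{init}} + t\log \mathrm{lr}_{\mathrm{final}}\Big),
\qquad t = \mathrm{clip}\big(\tfrac{s}{S}, 0, 1\big)
```

When D = 0, the bracketed delay factor is replaced by 1. For 0 <= s <= S, the exponent is linear in s. Each step therefore multiplies the undelayed rate by the constant factor (lr_final / lr_init)^(1/S), which is the exponential-decay equivalence the docstring mentions.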


def density_to_alpha(x, step_size):
    """Opacity of a segment with density x: alpha = 1 - exp(-x * step_size)."""
    return 1.0 - safe_exp(-x * step_size)
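density_to_alpha() is the usual volume-rendering opacity for a segment of length delta with constant density sigma, alpha = 1 - exp(-sigma * delta). One consequence makes the formula consistent across step sizes. If you split a segment in two and composite the halves, you get the same total opacity as the whole segment. The check below uses plain floats and math.exp instead of safe_exp(), so it assumes small, finite inputs.

```python
import math


def density_to_alpha(sigma, step_size):
    # Same formula as above, using math.exp on scalars (no clamping).
    return 1.0 - math.exp(-sigma * step_size)


sigma, delta = 2.0, 0.5
alpha_full = density_to_alpha(sigma, delta)      # 1 - e^-1, about 0.632
alpha_half = density_to_alpha(sigma, delta / 2)  # 1 - e^-0.5, about 0.393

# Light passes both halves with probability (1 - alpha_half)**2.
alpha_two_halves = 1.0 - (1.0 - alpha_half) ** 2
print(alpha_full, alpha_two_halves)  # equal up to rounding
```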


def density_activation(x):
    """Shifted exponential activation, exp(x - 1), giving a positive density."""
    return safe_exp(x - 1.0)


def normalize(x, v_min, v_max):
    """[v_min, v_max] -> [0, 1]."""
    return (x - v_min) / (v_max - v_min)


def denormalize(x, v_min, v_max):
    """[0, 1] -> [v_min, v_max]."""
    return v_min + x * (v_max - v_min)
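normalize() and denormalize() are inverse affine maps between [v_min, v_max] and [0, 1]. Neither function clips, so values outside the source range land outside the target range. The NumPy check below copies both functions so it runs standalone. The sample values are arbitrary.

```python
import numpy as np


def normalize(x, v_min, v_max):
    """[v_min, v_max] -> [0, 1]."""
    return (x - v_min) / (v_max - v_min)


def denormalize(x, v_min, v_max):
    """[0, 1] -> [v_min, v_max]."""
    return v_min + x * (v_max - v_min)


x = np.array([2.0, 3.5, 6.0, 7.0])
u = normalize(x, 2.0, 6.0)       # values 0.0, 0.375, 1.0, 1.25 (7.0 is out of range)
back = denormalize(u, 2.0, 6.0)  # recovers x up to floating-point rounding
```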


def sorted_lookup(x, xp, fps, device_is_tpu=False):
    """Look up `x` in sorted locations `xp`; return indices and each `[fp]` value."""
    if not isinstance(fps, tuple):
        raise ValueError(f'Input `fps` must be a tuple, but is {type(fps)}.')

    if device_is_tpu:
        # Identify the location in `xp` that corresponds to each `x`.
        # The final `True` index in `mask` is the start of the matching interval.
        mask = x[Ellipsis, None, :] >= xp[Ellipsis, :, None]

        def find_interval(x):
            # Grab the value where `mask` switches from True to False, and vice versa.
            # This approach takes advantage of the fact that `x` is sorted.
            x0 = np.max(np.where(mask, x[Ellipsis, None], x[Ellipsis, :1, None]), -2)
            x1 = np.min(np.where(~mask, x[Ellipsis, None], x[Ellipsis, -1:, None]), -2)
            return x0, x1

        idx0, idx1 = find_interval(np.arange(xp.shape[-1]))
        vals = [find_interval(fp) for fp in fps]
    else:
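        # np.searchsorted() has slightly different conventions for boundary
        # handling than the rest of this codebase. It also only accepts a 1-D
        # `xp`, so vectorize it over any leading batch dimensions.
        idx = np.vectorize(
            lambda a, v: np.searchsorted(a, v, side='right'),
            signature='(n),(m)->(m)',
        )(xp, x)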
        idx1 = np.minimum(idx, xp.shape[-1] - 1)
        idx0 = np.maximum(idx - 1, 0)
        vals = []
        for fp in fps:
            fp0 = np.take_along_axis(fp, idx0, axis=-1)
            fp1 = np.take_along_axis(fp, idx1, axis=-1)
            vals.append((fp0, fp1))
    return (idx0, idx1), vals
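The two branches compute the same interval indices in different ways. The TPU branch builds a dense (n, m) comparison mask and uses only elementwise ops plus max/min reductions. The default branch uses binary search. The 1-D NumPy sketch below runs both index computations and checks that they agree, including for queries outside [xp[0], xp[-1]]. The sample xp and x values are made up.

```python
import numpy as np

xp = np.array([0.0, 1.0, 2.0, 4.0])       # sorted knot locations
x = np.array([-1.0, 0.5, 1.0, 3.0, 5.0])  # queries, including out-of-range ones
n = xp.shape[-1]

# Mask branch: mask[j, i] is True when x[i] >= xp[j].
mask = x[..., None, :] >= xp[..., :, None]
idx = np.arange(n)
idx0_mask = np.max(np.where(mask, idx[..., None], idx[..., :1, None]), -2)
idx1_mask = np.min(np.where(~mask, idx[..., None], idx[..., -1:, None]), -2)

# Binary-search branch: side='right' plus clamping gives the same convention.
i = np.searchsorted(xp, x, side='right')
idx1_search = np.minimum(i, n - 1)
idx0_search = np.maximum(i - 1, 0)

print(idx0_mask, idx1_mask)  # [0 0 1 2 3] [0 1 2 3 3]
```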


def sorted_interp(
        x, xp, fp, device_is_tpu, eps=np.finfo(np.float32).eps ** 2
):
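    """Like np.interp, but `xp` and `fp` must be sorted along the last axis.

    Queries outside [xp[..., 0], xp[..., -1]] take the corresponding end value
    of `fp`.
    """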
    (xp0, xp1), (fp0, fp1) = sorted_lookup(
        x, xp, (xp, fp), device_is_tpu=device_is_tpu
    )[1]
    offset = np.clip((x - xp0) / np.maximum(eps, xp1 - xp0), 0, 1)
    ret = fp0 + offset * (fp1 - fp0)
    return ret
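For a single 1-D, strictly increasing `xp`, sorted_interp() should agree with np.interp, which also clamps out-of-range queries to the end values. The sketch below checks this against np.interp. `sorted_interp_1d` is a hypothetical, simplified re-implementation of the searchsorted path, not the function above.

```python
import numpy as np


def sorted_interp_1d(x, xp, fp, eps=np.finfo(np.float32).eps ** 2):
    """Hypothetical 1-D re-implementation of the searchsorted path of sorted_interp()."""
    i = np.searchsorted(xp, x, side='right')
    i1 = np.minimum(i, xp.shape[-1] - 1)
    i0 = np.maximum(i - 1, 0)
    xp0, xp1 = xp[i0], xp[i1]
    fp0, fp1 = fp[i0], fp[i1]
    # eps guards the division when both indices hit the same knot (out of range).
    offset = np.clip((x - xp0) / np.maximum(eps, xp1 - xp0), 0, 1)
    return fp0 + offset * (fp1 - fp0)


xp = np.array([0.0, 1.0, 2.0, 4.0])
fp = np.array([0.0, 10.0, 15.0, 35.0])
x = np.array([-1.0, 0.5, 1.0, 3.0, 5.0])
print(sorted_interp_1d(x, xp, fp))  # matches np.interp(x, xp, fp)
```

The TPU branch of sorted_lookup() recovers fp0 and fp1 with max/min reductions over the mask. It therefore returns the right endpoint values only when `fp` is non-decreasing along the last axis, which is why `fp` must be sorted too.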
