Raw File
custom_root_test.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np

import jax
from jax import lax
from jax._src import test_util as jtu
from jax import tree_util
import jax.numpy as jnp  # scan tests use numpy
import jax.scipy as jsp

from jax.config import config
config.parse_flags_with_absl()


def high_precision_dot(a, b):
  return lax.dot(a, b, precision=lax.Precision.HIGHEST)


# Simple optimization routine for testing custom_root
def binary_search(func, x0, low=0.0, high=100.0):
  del x0  # unused

  def cond(state):
    low, high = state
    midpoint = 0.5 * (low + high)
    return (low < midpoint) & (midpoint < high)

  def body(state):
    low, high = state
    midpoint = 0.5 * (low + high)
    update_upper = func(midpoint) > 0
    low = jnp.where(update_upper, low, midpoint)
    high = jnp.where(update_upper, midpoint, high)
    return (low, high)

  solution, _ = lax.while_loop(cond, body, (low, high))
  return solution

# Optimization routine for testing custom_root.
def newton_raphson(func, x0):
  tol = 1e-16
  max_it = 20

  fx0, dfx0 = func(x0), jax.jacobian(func)(x0)
  initial_state = (0, x0, fx0, dfx0)  # (iteration, x, f(x), grad(f)(x))

  def cond(state):
    it, _, fx, _ = state
    return (jnp.max(jnp.abs(fx)) > tol) & (it < max_it)

  def body(state):
    it, x, fx, dfx = state
    step = jnp.linalg.solve(
      dfx.reshape((-1, fx.size)), fx.ravel()
    ).reshape(fx.shape)
    x_next = x - step
    fx, dfx = func(x_next), jax.jacobian(func)(x_next)
    return (it + 1, x_next, fx, dfx)

  _, x, _, _ = lax.while_loop(cond, body, initial_state)

  return x


class CustomRootTest(jtu.JaxTestCase):

  @parameterized.named_parameters(
      {"testcase_name": "binary_search", "solve_method": binary_search},
      {"testcase_name": "newton_raphson", "solve_method": newton_raphson},
  )
  def test_custom_root_scalar(self, solve_method):

    def scalar_solve(f, y):
      return y / f(1.0)

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      f = lambda y: y ** 2 - x ** 3
      # Note: Nonzero derivative at x0 required for newton_raphson
      return lax.custom_root(f, 1.0, solve_method, tangent_solve)

    value, grad = jax.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False, rtol=1e-6)
    self.assertAllClose(grad, jax.grad(pow)(5.0, 1.5), check_dtypes=False,
                        rtol=1e-7)
    jtu.check_grads(sqrt_cubed, (5.0,), order=2,
                    rtol={jnp.float32: 1e-2, jnp.float64: 1e-3})

    inputs = jnp.array([4.0, 5.0])
    results = jax.vmap(sqrt_cubed)(inputs)
    self.assertAllClose(
      results, inputs ** 1.5, check_dtypes=False,
      atol={jnp.float32: 1e-3, jnp.float64: 1e-6},
      rtol={jnp.float32: 1e-3, jnp.float64: 1e-6},
    )

    results = jax.jit(sqrt_cubed)(5.0)
    self.assertAllClose(
        results, 5.0**1.5, check_dtypes=False, rtol={np.float64: 1e-7})

  @jtu.skip_on_flag("jax_skip_slow_tests", True)
  def test_custom_root_vector_with_solve_closure(self):

    def vector_solve(f, y):
      return jnp.linalg.solve(jax.jacobian(f)(y), y)

    def linear_solve(a, b):
      f = lambda y: high_precision_dot(a, y) - b
      x0 = jnp.zeros_like(b)
      solution = jnp.linalg.solve(a, b)
      oracle = lambda func, x0: solution
      return lax.custom_root(f, x0, oracle, vector_solve)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)
    jtu.check_grads(linear_solve, (a, b), order=2,
                    atol={np.float32: 1e-2, np.float64: 1e-11})

    actual = jax.jit(linear_solve)(a, b)
    expected = jnp.linalg.solve(a, b)
    self.assertAllClose(expected, actual)

  def test_custom_root_vector_nonlinear(self):

    def nonlinear_func(x, y):
      # func(x, y) == 0 if and only if x == y.
      return (x - y) * (x**2 + y**2 + 1)

    def tangent_solve(g, y):
      return jnp.linalg.solve(
        jax.jacobian(g)(y).reshape(-1, y.size),
        y.ravel()
      ).reshape(y.shape)

    def nonlinear_solve(y):
      f = lambda x: nonlinear_func(x, y)
      x0 = -jnp.ones_like(y)
      return lax.custom_root(f, x0, newton_raphson, tangent_solve)

    y = self.rng().randn(3, 1)
    jtu.check_grads(nonlinear_solve, (y,), order=2,
                rtol={jnp.float32: 1e-2, jnp.float64: 1e-3})

    actual = jax.jit(nonlinear_solve)(y)
    self.assertAllClose(y, actual, rtol=1e-5, atol=1e-5)

  def test_custom_root_with_custom_linear_solve(self):

    def linear_solve(a, b):
      f = lambda x: high_precision_dot(a, x) - b
      factors = jsp.linalg.cho_factor(a)
      cho_solve = lambda f, b: jsp.linalg.cho_solve(factors, b)
      def pos_def_solve(g, b):
        return lax.custom_linear_solve(g, b, cho_solve, symmetric=True)
      return lax.custom_root(f, b, cho_solve, pos_def_solve)

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)

    actual = linear_solve(high_precision_dot(a, a.T), b)
    expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
    self.assertAllClose(expected, actual)

    actual = jax.jit(linear_solve)(high_precision_dot(a, a.T), b)
    expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)
    self.assertAllClose(expected, actual)

    jtu.check_grads(lambda x, y: linear_solve(high_precision_dot(x, x.T), y),
                    (a, b), order=2, rtol={jnp.float32: 1e-2})

  def test_custom_root_with_aux(self):
    def root_aux(a, b):
      f = lambda x: high_precision_dot(a, x) - b
      factors = jsp.linalg.cho_factor(a)
      cho_solve = lambda f, b: (jsp.linalg.cho_solve(factors, b), orig_aux)

      def pos_def_solve(g, b):
        # prune aux to allow use as tangent_solve
        cho_solve_noaux = lambda f, b: cho_solve(f, b)[0]
        return lax.custom_linear_solve(g, b, cho_solve_noaux, symmetric=True)

      return lax.custom_root(f, b, cho_solve, pos_def_solve, has_aux=True)

    orig_aux = {"converged": np.array(1.), "nfev": np.array(12345.), "grad": np.array([1.0, 2.0, 3.0])}

    rng = self.rng()
    a = rng.randn(2, 2)
    b = rng.randn(2)

    actual, actual_aux = root_aux(high_precision_dot(a, a.T), b)
    actual_jit, actual_jit_aux = jax.jit(root_aux)(high_precision_dot(a, a.T), b)
    expected = jnp.linalg.solve(high_precision_dot(a, a.T), b)

    self.assertAllClose(expected, actual)
    self.assertAllClose(expected, actual_jit)
    jtu.check_eq(actual_jit_aux, orig_aux)

    # grad check with aux
    jtu.check_grads(lambda x, y: root_aux(high_precision_dot(x, x.T), y),
                    (a, b), order=2, rtol={jnp.float32: 1e-2})

    # test vmap and jvp combined by jacfwd
    fwd = jax.jacfwd(lambda x, y: root_aux(high_precision_dot(x, x.T), y), argnums=(0, 1))
    expected_fwd = jax.jacfwd(lambda x, y: jnp.linalg.solve(high_precision_dot(x, x.T), y), argnums=(0, 1))

    fwd_val, fwd_aux = fwd(a, b)
    expected_fwd_val = expected_fwd(a, b)
    self.assertAllClose(fwd_val, expected_fwd_val, rtol={np.float32: 1E-6, np.float64: 1E-12})

    jtu.check_close(fwd_aux, tree_util.tree_map(jnp.zeros_like, fwd_aux))

  def test_custom_root_errors(self):
    with self.assertRaisesRegex(TypeError, re.escape("f() output pytree")):
      lax.custom_root(lambda x: (x, x), 0.0, lambda f, x: x, lambda f, x: x)
    with self.assertRaisesRegex(TypeError, re.escape("solve() output pytree")):
      lax.custom_root(lambda x: x, 0.0, lambda f, x: (x, x), lambda f, x: x)

    def dummy_root_usage(x):
      f = lambda y: x - y
      return lax.custom_root(f, 0.0, lambda f, x: x, lambda f, x: (x, x))

    with self.assertRaisesRegex(
        TypeError, re.escape("tangent_solve() output pytree")):
      jax.jvp(dummy_root_usage, (0.0,), (0.0,))


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
back to top