Revision 7356f29ce0d6123363d426c124ef1c71addcd730 authored by Smit Hinsu on 09 January 2023, 06:16:20 UTC, committed by jax authors on 09 January 2023, 06:34:12 UTC
The following commit removed the use of `client` but the check was left out and is not unnecessary.
https://github.com/tensorflow/tensorflow/commit/71ac02ec1f8a0af70d6a49447c9b5d49f4115f8a

PiperOrigin-RevId: 500618944
1 parent 44b97ae
Raw File
optimizers_test.py
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the optimizers module."""

import functools

from absl.testing import absltest
import numpy as np

import jax.numpy as jnp
import jax._src.test_util as jtu
from jax import jit, grad, jacfwd, jacrev
from jax import tree_util
from jax import lax
from jax.example_libraries import optimizers

from jax.config import config
config.parse_flags_with_absl()


class OptimizerTests(jtu.JaxTestCase):

  def _CheckOptimizer(self, optimizer, loss, x0, num_steps, *args, **kwargs):
    self._CheckFuns(optimizer, loss, x0, *args)
    self._CheckRun(optimizer, loss, x0, num_steps, *args, **kwargs)

  def _CheckFuns(self, optimizer, loss, x0, *args):
    init_fun, update_fun, get_params = optimizer(*args)
    opt_state = init_fun(x0)
    self.assertAllClose(x0, get_params(opt_state))
    opt_state2 = update_fun(0, grad(loss)(x0), opt_state)  # doesn't crash
    self.assertEqual(tree_util.tree_structure(opt_state),
                     tree_util.tree_structure(opt_state2))

  @jtu.skip_on_devices('gpu')
  def _CheckRun(self, optimizer, loss, x0, num_steps, *args, **kwargs):
    init_fun, update_fun, get_params = optimizer(*args)

    opt_state = init_fun(x0)
    for i in range(num_steps):
      x = get_params(opt_state)
      g = grad(loss)(x)
      opt_state = update_fun(i, g, opt_state)
    xstar = get_params(opt_state)
    self.assertLess(loss(xstar), 1e-2)

    update_fun_jitted = jit(update_fun)
    opt_state = init_fun(x0)
    for i in range(num_steps):
      x = get_params(opt_state)
      g = grad(loss)(x)
      opt_state = update_fun_jitted(i, g, opt_state)
    xstar = get_params(opt_state)
    self.assertLess(loss(xstar), 1e-2)

  def testSgdScalar(self):
    def loss(x): return x**2
    x0 = 1.
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_size)

  def testSgdVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_size)

  def testSgdNestedTuple(self):
    def loss(xyz):
      x, (y, z) = xyz
      return sum(jnp.dot(a, a) for a in [x, y, z])
    x0 = (jnp.ones(2), (jnp.ones(2), jnp.ones(2)))
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.sgd, loss, x0, num_iters, step_size)

  def testMomentumVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    mass = 0.
    self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)

  def testMomentumDict(self):
    def loss(dct): return jnp.dot(dct['x'], dct['x'])
    x0 = {'x': jnp.ones(2)}
    num_iters = 100
    step_size = 0.1
    mass = 0.
    self._CheckOptimizer(optimizers.momentum, loss, x0, num_iters, step_size, mass)

  def testRmspropVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.rmsprop, loss, x0, num_iters, step_size)

  def testAdamVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.adam, loss, x0, num_iters, step_size)

  def testSgdClosure(self):
    def loss(y, x): return y**2 * x**2
    x0 = 1.
    y = 1.
    num_iters = 20
    step_size = 0.1
    partial_loss = functools.partial(loss, y)
    self._CheckRun(optimizers.sgd, partial_loss, x0, num_iters, step_size)

  def testAdagrad(self):

    def loss(xs):
      x1, x2 = xs
      return jnp.sum(x1**2) + jnp.sum(x2**2)

    num_iters = 100
    step_size = 0.1
    x0 = (jnp.ones(2), jnp.ones((2, 2)))
    self._CheckOptimizer(optimizers.adagrad, loss, x0, num_iters, step_size)

  def testSM3Scalar(self):
    def loss(x): return x**2
    x0 = jnp.array(1.)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.sm3, loss, x0, num_iters, step_size)

  def testSM3Vector(self):
    def loss(xs):
      x1, x2 = xs
      return jnp.sum(x1 ** 2) + jnp.sum(x2 ** 2)

    num_iters = 100
    step_size = 0.1
    x0 = (jnp.ones(2), jnp.ones((2, 2)))
    self._CheckOptimizer(optimizers.sm3, loss, x0, num_iters, step_size)

  def testAdaMaxVector(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    num_iters = 100
    step_size = 0.1
    self._CheckOptimizer(optimizers.adamax, loss, x0, num_iters, step_size)

  def testSgdVectorExponentialDecaySchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.exponential_decay(0.1, 3, 2.)
    self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)

  def testSgdVectorInverseTimeDecaySchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
    self._CheckFuns(optimizers.sgd, loss, x0, step_schedule)

  def testAdamVectorInverseTimeDecaySchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.inverse_time_decay(0.1, 3, 2.)
    self._CheckFuns(optimizers.adam, loss, x0, step_schedule)

  def testMomentumVectorInverseTimeDecayStaircaseSchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_sched = optimizers.inverse_time_decay(0.1, 3, 2., staircase=True)
    mass = 0.9
    self._CheckFuns(optimizers.momentum, loss, x0, step_sched, mass)

  def testRmspropmomentumVectorPolynomialDecaySchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.polynomial_decay(1.0, 50, 0.1)
    self._CheckFuns(optimizers.rmsprop_momentum, loss, x0, step_schedule)

  def testRmspropVectorPiecewiseConstantSchedule(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_schedule = optimizers.piecewise_constant([25, 75], [1.0, 0.5, 0.1])
    self._CheckFuns(optimizers.rmsprop, loss, x0, step_schedule)

  def testTracedStepSize(self):
    def loss(x): return jnp.dot(x, x)
    x0 = jnp.ones(2)
    step_size = 0.1

    init_fun, _, _ = optimizers.sgd(step_size)
    opt_state = init_fun(x0)

    @jit
    def update(opt_state, step_size):
      _, update_fun, get_params = optimizers.sgd(step_size)
      x = get_params(opt_state)
      g = grad(loss)(x)
      return update_fun(0, g, opt_state)

    update(opt_state, 0.9)  # doesn't crash

  # TODO(mattjj): re-enable
  # def testDeviceTupleState(self):
  #   init_fun, update_fun, _ = optimizers.sgd(0.1)
  #   opt_state = init_fun(jnp.zeros(3))
  #   self.assertIsInstance(opt_state, optimizers.OptimizerState)
  #   self.assertIsInstance(opt_state.packed_state, core.JaxTuple)
  #   opt_state = jit(update_fun)(0, jnp.zeros(3), opt_state)
  #   self.assertIsInstance(opt_state, optimizers.OptimizerState)
  #   self.assertIsInstance(opt_state.packed_state, xla.DeviceTuple)

  def testUpdateFunStructureMismatchErrorMessage(self):
    @optimizers.optimizer
    def opt_maker():
      def init_fun(x0):
        return {'x': x0}
      def update_fun(i, g, opt_state):
        x = opt_state['x']
        return {'x': x - 0.1 * g, 'v': g}  # bug!
      def get_params(opt_state):
        return opt_state['x']
      return init_fun, update_fun, get_params

    init_fun, update_fun, get_params = opt_maker()
    opt_state = init_fun(jnp.zeros(3))
    self.assertRaises(TypeError, lambda: update_fun(opt_state))

  def testUtilityNorm(self):
    x0 = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
    norm = optimizers.l2_norm(x0)
    expected = np.sqrt(np.sum(np.ones(2+3+4)**2))
    self.assertAllClose(norm, expected, check_dtypes=False)

  def testUtilityClipGrads(self):
    g = (jnp.ones(2), (jnp.ones(3), jnp.ones(4)))
    norm = optimizers.l2_norm(g)

    ans = optimizers.clip_grads(g, 1.1 * norm)
    expected = g
    self.assertAllClose(ans, expected, check_dtypes=False)

    ans = optimizers.l2_norm(optimizers.clip_grads(g, 0.9 * norm))
    expected = 0.9 * norm
    self.assertAllClose(ans, expected, check_dtypes=False)

  def testIssue758(self):
    # code from https://github.com/google/jax/issues/758
    # this is more of a scan + jacfwd/jacrev test, but it lives here to use the
    # optimizers.py code

    def harmonic_bond(conf, params):
      return jnp.sum(conf * params)

    opt_init, opt_update, get_params = optimizers.sgd(5e-2)

    x0 = np.array([0.5], dtype=np.float64)

    def minimize_structure(test_params):
      energy_fn = functools.partial(harmonic_bond, params=test_params)
      grad_fn = grad(energy_fn, argnums=(0,))
      opt_state = opt_init(x0)

      def apply_carry(carry, _):
        i, x = carry
        g = grad_fn(get_params(x))[0]
        new_state = opt_update(i, g, x)
        new_carry = (i+1, new_state)
        return new_carry, _

      carry_final, _ = lax.scan(apply_carry, (0, opt_state), jnp.zeros((75, 0)))
      trip, opt_final = carry_final
      assert trip == 75
      return opt_final

    initial_params = jnp.array(0.5)
    minimize_structure(initial_params)

    def loss(test_params):
      opt_final = minimize_structure(test_params)
      return 1.0 - get_params(opt_final)[0]

    loss_opt_init, loss_opt_update, loss_get_params = optimizers.sgd(5e-2)

    J1 = jacrev(loss, argnums=(0,))(initial_params)
    J2 = jacfwd(loss, argnums=(0,))(initial_params)
    self.assertAllClose(J1, J2, rtol=1e-6)

  def testUnpackPackRoundTrip(self):
    opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
    params = [{'w': self.rng().randn(1, 2), 'bias': self.rng().randn(2)}]
    expected = opt_init(params)
    ans = optimizers.pack_optimizer_state(
        optimizers.unpack_optimizer_state(expected))
    self.assertEqual(ans, expected)

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())
back to top