Revision f20a1c968b2e0d2ec88319a2a17116793ee6fe4c authored by Peter Hawkins on 07 April 2020, 15:21:39 UTC, committed by GitHub on 07 April 2020, 15:21:39 UTC
1 parent fa383b4
Raw File
control_test.py
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial
from unittest import SkipTest

from absl.testing import absltest
import numpy as onp

from jax import lax
from jax import test_util as jtu
import jax.numpy as np

from examples import control

from jax.config import config
# Runs at import time, so absl flag parsing is hooked up for the JAX config
# before any test in this module executes.
config.parse_flags_with_absl()
# Handle to the config's flag values; nothing in this file reads it.
FLAGS = config.FLAGS


def one_step_lqr(dim, T):
  """Builds a time-invariant LQR problem with identity cost and dynamics.

  Every time step uses the identity matrix for the quadratic state cost `Q`
  and for both dynamics matrices `A` and `B`; all remaining terms (`q`, `R`,
  `r`, `M`) are zero.

  Args:
    dim: size of the state vector (the control vector has the same size).
    T: number of time steps; it is the leading axis of every array.

  Returns:
    A `control.LqrSpec` built from `(Q, q, R, r, M, A, B)`, in that order.
  """

  def identity_per_step():
    # One (dim, dim) identity matrix per time step, stacked to (T, dim, dim).
    return np.stack([np.eye(dim)] * T)

  def zeros_per_step(*trailing_shape):
    # All-zero array whose leading axis indexes time.
    return np.zeros((T,) + trailing_shape)

  return control.LqrSpec(
      identity_per_step(),       # Q: quadratic state cost
      zeros_per_step(dim),       # q: linear state cost
      zeros_per_step(dim, dim),  # R: quadratic control cost
      zeros_per_step(dim),       # r: linear control cost
      zeros_per_step(dim, dim),  # M: state/control cross term
      identity_per_step(),       # A: state transition
      identity_per_step())       # B: control input


def control_from_lqr(lqr):
  """Wraps an LQR specification as a general `control.ControlSpec`.

  The returned spec carries callables that evaluate the LQR cost and dynamics
  at a given time index:

    cost(t, x, u)     = x'Q[t]x + q[t]'x + u'R[t]u + r[t]'u + x'M[t]u
    dynamics(t, x, u) = A[t]x + B[t]u

  Args:
    lqr: object exposing arrays `Q, q, R, r, M, A, B` indexed by time along
      their first axis; `Q` must be three-dimensional.

  Returns:
    `control.ControlSpec(cost, dynamics, T, dim, dim)`, where `T` and `dim` are
    the first two axes of `lqr.Q`.
  """
  horizon, state_dim, _ = lqr.Q.shape

  def cost(t, x, u):
    state_quadratic = np.dot(np.dot(lqr.Q[t], x), x)
    state_linear = np.dot(lqr.q[t], x)
    control_quadratic = np.dot(np.dot(lqr.R[t], u), u)
    control_linear = np.dot(lqr.r[t], u)
    cross_term = np.dot(np.dot(lqr.M[t], u), x)
    return (state_quadratic + state_linear + control_quadratic +
            control_linear + cross_term)

  def dynamics(t, x, u):
    from_state = np.dot(lqr.A[t], x)
    from_control = np.dot(lqr.B[t], u)
    return from_state + from_control

  # The control dimension is taken to be the same as the state dimension.
  return control.ControlSpec(cost, dynamics, horizon, state_dim, state_dim)


def one_step_control(dim, T):
  """Builds a `control.ControlSpec` with cost `x.x` and dynamics `x + u`.

  Neither callable depends on the time index, and the cost ignores the control.

  Args:
    dim: size used for both the state and the control vector.
    T: number of time steps.

  Returns:
    `control.ControlSpec(cost, dynamics, T, dim, dim)`.
  """

  def squared_state_norm(t, x, u):
    # Only the state is penalised.
    return np.dot(x, x)

  def shift_state_by_control(t, x, u):
    # The control is added directly to the state.
    return x + u

  return control.ControlSpec(
      squared_state_norm, shift_state_by_control, T, dim, dim)


class ControlExampleTest(jtu.JaxTestCase):
  """Tests trajectory rollout, LQR, iLQR and MPC from `examples.control`."""

  def _assertOneStepSolution(self, X, U, x0, T, dim, **tol):
    """Asserts that (X, U) is the expected solution of a "one step" problem.

    The expected solution starts at `x0`, applies `-x0` as its first control,
    and has all-zero states and controls afterwards.

    Args:
      X: states to check; compared with `x0` followed by `T` zero rows.
      U: controls to check; compared with `-x0` followed by `T - 1` zero rows.
      x0: initial state.
      T: number of time steps.
      dim: state and control dimension.
      **tol: optional `atol`/`rtol`, forwarded to every comparison except the
        one on `X[0]`, which always uses the default tolerances.
    """
    self.assertAllClose(X[0], x0, check_dtypes=True)
    self.assertAllClose(U[0], -x0, check_dtypes=True, **tol)
    self.assertAllClose(X[1:], np.zeros((T, dim)), check_dtypes=True, **tol)
    self.assertAllClose(U[1:], np.zeros((T - 1, dim)), check_dtypes=True,
                        **tol)

  def testTrajectoryCyclicIntegerCounter(self):
    """Rolls out a modular counter with constant increments of 1 and 2."""
    num_states = 3

    def dynamics(t, x, u):
      return (x + u) % num_states

    T = 10

    # Increment by one each step: 0, 1, 2, 0, 1, ...
    U = np.ones((T, 1))
    X = control.trajectory(dynamics, U, np.zeros(1))
    expected = np.arange(T + 1) % num_states
    expected = np.reshape(expected, (T + 1, 1))
    self.assertAllClose(X, expected, check_dtypes=False)

    # Increment by two each step; the initial state 0 is prepended by hand.
    U = 2 * np.ones((T, 1))
    X = control.trajectory(dynamics, U, np.zeros(1))
    expected = np.cumsum(2 * np.ones(T)) % num_states
    expected = np.concatenate((np.zeros(1), expected))
    expected = np.reshape(expected, (T + 1, 1))
    self.assertAllClose(X, expected, check_dtypes=False)

  def testTrajectoryTimeVarying(self):
    """Rolls out dynamics that depend on the time index `t`."""
    T = 6

    def clip(x, lo, hi):
      return np.minimum(hi, np.maximum(lo, x))

    def dynamics(t, x, u):
      # computes `(x + u) if t > T else 0`
      return (x + u) * clip(t - T, 0, 1)

    U = np.ones((2 * T, 1))
    X = control.trajectory(dynamics, U, np.zeros(1))
    # The state stays at zero until the dynamics switch on, then counts up.
    expected = np.concatenate((np.zeros(T + 1), np.arange(T)))
    expected = np.reshape(expected, (2 * T + 1, 1))
    self.assertAllClose(X, expected, check_dtypes=True)

  def testTrajectoryCyclicIndicator(self):
    """Rolls out a one-hot state that cycles through the basis vectors."""
    num_states = 3

    def position(x):
      """Finds the index of a standard basis vector, e.g. [0, 1, 0] -> 1."""
      x = np.cumsum(x)
      x = 1 - x
      return np.sum(x, dtype=np.int32)

    def dynamics(t, x, u):
      """Advances the one-hot state by `u[0]` positions, wrapping around."""
      idx = (position(x) + u[0]) % num_states
      return lax.dynamic_slice_in_dim(np.eye(num_states), idx, 1)[0]

    T = 8

    U = np.ones((T, 1), dtype=np.int32)
    X = control.trajectory(dynamics, U, np.eye(num_states, dtype=np.int32)[0])
    # T + 1 states, advancing one position per step, make whole cycles through
    # the rows of the identity matrix.
    num_cycles = (T + 1) // num_states
    expected = np.vstack((np.eye(num_states),) * num_cycles)
    self.assertAllClose(X, expected, check_dtypes=True)

  def testLqrSolve(self):
    """`lqr_solve` on the one-step problem gives K = -I and k = 0."""
    dim, T = 2, 10
    p = one_step_lqr(dim, T)
    K, k = control.lqr_solve(p)
    K_ = -np.stack(T * (np.eye(dim),))
    self.assertAllClose(K, K_, check_dtypes=True, atol=1e-6, rtol=1e-6)
    self.assertAllClose(k, np.zeros((T, dim)), check_dtypes=True)

  def testLqrPredict(self):
    """`lqr_predict` drives the one-step LQR problem to zero in one step."""
    randn = onp.random.RandomState(0).randn
    dim, T = 2, 10
    p = one_step_lqr(dim, T)
    x0 = randn(dim)
    X, U = control.lqr_predict(p, x0)
    self._assertOneStepSolution(X, U, x0, T, dim, atol=1e-6, rtol=1e-6)

  def testIlqrWithLqrProblem(self):
    """iLQR on an LQR problem wrapped as a general control spec."""
    randn = onp.random.RandomState(0).randn
    dim, T, num_iters = 2, 10, 3
    lqr = one_step_lqr(dim, T)
    p = control_from_lqr(lqr)
    x0 = randn(dim)
    X, U = control.ilqr(num_iters, p, x0, np.zeros((T, dim)))
    self._assertOneStepSolution(X, U, x0, T, dim)

  def testIlqrWithLqrProblemSpecifiedGenerally(self):
    """iLQR on the one-step problem written directly as cost/dynamics."""
    randn = onp.random.RandomState(0).randn
    dim, T, num_iters = 2, 10, 3
    p = one_step_control(dim, T)
    x0 = randn(dim)
    X, U = control.ilqr(num_iters, p, x0, np.zeros((T, dim)))
    self._assertOneStepSolution(X, U, x0, T, dim)

  def testIlqrWithNonlinearProblem(self):
    """iLQR on a scalar problem with quadratic, time-scaled dynamics."""
    def cost(t, x, u):
      return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1.)

    def dynamics(t, x, u):
      return (x ** 2. - u ** 2.) / (t + 1.)

    T, num_iters, d = 10, 7, 1
    p = control.ControlSpec(cost, dynamics, T, d, d)

    x0 = np.array([0.2])
    X, U = control.ilqr(num_iters, p, x0, 1e-5 * np.ones((T, d)))
    assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
    assert_close(X[0], x0)
    # Only the square of the first control is compared, not its sign.
    assert_close(U[0] ** 2., x0 ** 2.)
    assert_close(X[1:], np.zeros((T, d)))
    assert_close(U[1:], np.zeros((T - 1, d)))

  def testMpcWithLqrProblem(self):
    """MPC with an iLQR solver on an LQR problem wrapped as a control spec."""
    randn = onp.random.RandomState(0).randn
    dim, T, num_iters = 2, 10, 3
    lqr = one_step_lqr(dim, T)
    p = control_from_lqr(lqr)
    x0 = randn(dim)
    solver = partial(control.ilqr, num_iters)
    X, U = control.mpc_predict(solver, p, x0, np.zeros((T, dim)))
    self._assertOneStepSolution(X, U, x0, T, dim)

  def testMpcWithLqrProblemSpecifiedGenerally(self):
    """MPC with an iLQR solver on the directly specified one-step problem."""
    randn = onp.random.RandomState(0).randn
    dim, T, num_iters = 2, 10, 3
    p = one_step_control(dim, T)
    x0 = randn(dim)
    solver = partial(control.ilqr, num_iters)
    X, U = control.mpc_predict(solver, p, x0, np.zeros((T, dim)))
    self._assertOneStepSolution(X, U, x0, T, dim)

  def testMpcWithNonlinearProblem(self):
    """MPC with an iLQR solver on the scalar nonlinear problem."""
    def cost(t, x, u):
      return (x[0] ** 2. + 1e-3 * u[0] ** 2.) / (t + 1.)

    def dynamics(t, x, u):
      return (x ** 2. - u ** 2.) / (t + 1.)

    T, num_iters, d = 10, 7, 1
    p = control.ControlSpec(cost, dynamics, T, d, d)

    x0 = np.array([0.2])
    solver = partial(control.ilqr, num_iters)
    X, U = control.mpc_predict(solver, p, x0, 1e-5 * np.ones((T, d)))
    assert_close = partial(self.assertAllClose, atol=1e-2, check_dtypes=True)
    assert_close(X[0], x0)
    # Only the square of the first control is compared, not its sign.
    assert_close(U[0] ** 2., x0 ** 2.)
    assert_close(X[1:], np.zeros((T, d)))
    assert_close(U[1:], np.zeros((T - 1, d)))


# When executed as a script, hand control to absl's test runner.
if __name__ == '__main__':
  absltest.main()
back to top