https://github.com/google/jax
Raw File
Tip revision: 3acbd44952b86f54de6c937d9ca0874e47b382f9 authored by Yash Katariya on 02 February 2022, 00:38:12 UTC
Remove isinstance checks
Tip revision: 3acbd44
api_util.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
from functools import partial
from typing import Any, Dict, Iterable, Tuple, Union, Optional

import numpy as np

from jax import core
from jax._src import dtypes
from jax._src.tree_util import (
    PyTreeDef, tree_flatten, tree_unflatten, tree_multimap, tree_structure,
    treedef_children, treedef_is_leaf)
from jax._src.tree_util import _replace_nones
from jax import linear_util as lu
from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable
from jax.core import unit

from jax._src import traceback_util
traceback_util.register_exclusion(__file__)

map = safe_map

def _ensure_index(x: Any) -> Union[int, Tuple[int, ...]]:
  """Ensure x is either an index or a tuple of indices."""
  try:
    return operator.index(x)
  except TypeError:
    return tuple(map(operator.index, x))

def _ensure_index_tuple(x: Any) -> Tuple[int, ...]:
  """Convert x to a tuple of indices."""
  try:
    return (operator.index(x),)
  except TypeError:
    return tuple(map(operator.index, x))

def _ensure_str(x: str) -> str:
  if not isinstance(x, str):
    raise TypeError(f"argument is not a string: {x}")
  return x

def _ensure_str_tuple(x: Union[str, Iterable[str]]) -> Tuple[str, ...]:
  """Convert x to a tuple of strings."""
  if isinstance(x, str):
    return (x,)
  else:
    return tuple(map(_ensure_str, x))

@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
  py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
  ans = yield py_args, py_kwargs
  yield tree_flatten(ans)

def apply_flat_fun(fun, io_tree, *py_args):
  in_tree_expected, out_tree = io_tree
  args, in_tree = tree_flatten((py_args, {}))
  if in_tree != in_tree_expected:
    raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
  ans = fun(*args)
  return tree_unflatten(out_tree, ans)

@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
  py_args = tree_unflatten(in_tree, args_flat)
  ans = yield py_args, {}
  yield tree_flatten(ans)

def apply_flat_fun_nokwargs(fun, io_tree, py_args):
  in_tree_expected, out_tree = io_tree
  args, in_tree = tree_flatten(py_args)
  if in_tree != in_tree_expected:
    raise TypeError("Expected {}, got {}".format(in_tree_expected, in_tree))
  ans = fun(*args)
  return tree_unflatten(out_tree, ans)

def flattened_fun_in_tree(fn: lu.WrappedFun) -> Optional[Tuple[PyTreeDef, bool]]:
  # This implementation relies on internal details of linear_util.py's
  # WrappedFun, but it's for the worthy cause of better user error messages.
  # It can fail (i.e. return None) if its WrappedFun argument is not transformed
  # with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
  # core.eval_jaxpr encounters a call primitive (though at that point we're just
  # round-tripping jaxprs and the user errors in question are impossible).
  assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
  assert (isinstance(flatten_fun_nokwargs, partial) and
          len(flatten_fun_nokwargs.args) == 1)
  flat_xforms = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
  try:
    (in_tree, has_kwargs), = ((args[0], f is flatten_fun.args[0])
                              for f, args in fn.transforms if f in flat_xforms)
  except ValueError:
    return None
  else:
    return in_tree, has_kwargs

@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
  py_args = tree_unflatten(in_tree, args_flat)
  pair = yield py_args, {}
  if not isinstance(pair, (list, tuple)) or len(pair) != 2:
    raise TypeError("expected function with aux output to return a two-element "
                    f"tuple, but got type {type(pair)} with value {repr(pair)}")
  ans, aux = pair
  ans_flat, ans_tree = tree_flatten(ans)
  aux_flat, aux_tree = tree_flatten(aux)
  yield (ans_flat, aux_flat), (ans_tree, aux_tree)

class _HashableWithStrictTypeEquality:
  """Box object used when comparing static arguments as a jit key.

  Requires exact type equality using `is` and value equality."""
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return hash(self.val)

  def __eq__(self, other):
    return type(self.val) is type(other.val) and self.val == other.val


def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
  dyn_argnums = _ensure_index_tuple(dyn_argnums)
  fixed_args = [unit] * len(args)
  for i, arg in enumerate(args):
    if i in dyn_argnums: continue
    if require_static_args_hashable:
      if not is_hashable(arg):
        raise ValueError(
            "Non-hashable static arguments are not supported, as this can lead "
            f"to unexpected cache-misses. Static argument (index {i}) of type "
            f"{type(arg)} for function {f.__name__} is non-hashable.")
      fixed_args[i] = _HashableWithStrictTypeEquality(arg)
    else:
      fixed_args[i] = Unhashable(arg)

  dyn_args = tuple(args[i] for i in dyn_argnums)
  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args

def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
                           args: Tuple[Any], *, allow_invalid: bool):
  """Version of ``argnums_partial`` that checks hashability of static_argnums."""
  if not static_argnums:
    return f, args
  dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
  dyn_args = tuple(args[i] for i in dyn_argnums)

  fixed_args = [unit] * len(args)  # type: ignore
  for i in static_argnums:
    # TODO(shoyer): set allow_invalid=True permanently after enabling
    # static_argnames.
    if allow_invalid and i >= len(args):
      continue
    static_arg = args[i]
    try:
      hash(static_arg)
    except TypeError:
      raise ValueError(
          "Non-hashable static arguments are not supported, as this can lead "
          f"to unexpected cache-misses. Static argument (index {i}) of type "
          f"{type(static_arg)} for function {f.__name__} is non-hashable.")
    else:
      fixed_args[i] = _HashableWithStrictTypeEquality(static_arg)  # type: ignore

  return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args


@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
  args = [None if arg is unit else arg.val for arg in fixed_args]
  for i, arg in zip(dyn_argnums, dyn_args):
    args[i] = arg
  ans = yield args, kwargs
  yield ans


def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],
                            kwargs: Dict[str, Any]):
  if not static_argnames:
    return f, kwargs
  dyn_kwargs = {k: v for k, v in kwargs.items() if k not in static_argnames}

  fixed_kwargs: Dict[str, Any] = {}
  for k, arg in kwargs.items():
    if k in dyn_kwargs:
      fixed_kwargs[k] = unit
    else:
      try:
        hash(arg)
      except TypeError:
        raise ValueError(
            "Non-hashable static arguments are not supported, as this can lead "
            f"to unexpected cache-misses. Static argument (name {k}) of type "
            f"{type(arg)} for function {f.__name__} is non-hashable.")
      else:
        fixed_kwargs[k] = Hashable(arg)  # type: ignore

  return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs


@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
  kwargs = {k: None if arg is unit else arg.val
            for k, arg in fixed_kwargs.val.items()}
  kwargs.update(dyn_kwargs)
  ans = yield args, kwargs
  yield ans


def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
  """Returns a tuple with a boolean value for each leaf in args."""
  res = []
  for i, arg in enumerate(args):
    donate = bool(i in donate_argnums)
    res.extend((donate,) * tree_structure(arg).num_leaves)
  res.extend((False,) * tree_structure(kwargs).num_leaves)
  return tuple(res)

def rebase_donate_argnums(donate_argnums, static_argnums) -> Tuple[int, ...]:
  """Shifts donate to account for static.

  >>> rebase_donate_argnums((3, 4), (0, 1))
  (1, 2)

  Args:
    donate_argnums: An iterable of ints.
    static_argnums: An iterable of ints.

  Returns:
    A tuple of unique, sorted integer values based on donate_argnums with each
    element offset to account for static_argnums.
  """
  if not (static_argnums or donate_argnums):
    return tuple(sorted(donate_argnums))

  static_argnums = sorted(set(static_argnums))
  donate_argnums = sorted(set(donate_argnums))
  i = j = o = 0
  out = []
  while j < len(donate_argnums):
    if i < len(static_argnums) and static_argnums[i] == donate_argnums[j]:
      raise ValueError(f"`static_argnums` {static_argnums} and "
                       f"`donate_argnums` {donate_argnums} cannot intersect.")

    if i < len(static_argnums) and static_argnums[i] < donate_argnums[j]:
      o += 1
      i += 1
    else:
      out.append(donate_argnums[j] - o)
      j += 1
  return tuple(out)


def is_hashable(arg):
  try:
    hash(arg)
    return True
  except TypeError:
    return False


def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False):
  # given an axis spec tree axis_tree (a pytree with integers and Nones at the
  # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of
  # the given treedef, build a complete axis spec tree with the same structure
  # and return the flattened result
  # TODO(mattjj,phawkins): improve this implementation

  proxy = object()
  dummy = tree_unflatten(treedef, [object()] * treedef.num_leaves)
  axes = []
  add_leaves = lambda i, x: axes.extend([i] * len(tree_flatten(x)[0]))
  try:
    tree_multimap(add_leaves, _replace_nones(proxy, axis_tree), dummy)
  except ValueError:
    if kws:
      # if keyword arguments are included in the tree, we make adapt the error
      # message only to be about the positional arguments
      treedef, leaf = treedef_children(treedef)
      assert treedef_is_leaf(leaf)
      axis_tree, _ = axis_tree
    hint = ""
    if tupled_args:
      hint += (f" Note that {name} that are non-trivial pytrees should always be "
               f"wrapped in a tuple representing the argument list.")
      if len(treedef.children()) == 1:
        try:
          flatten_axes(name, treedef, (axis_tree,))
        except ValueError:
          pass  # That's not the issue.
        else:
          hint += (f" In particular, you're passing in a single argument which "
                   f"means that {name} might need to be wrapped in "
                   f"a singleton tuple.")
    raise ValueError(f"{name} specification must be a tree prefix of the "
                     f"corresponding value, got specification {axis_tree} "
                     f"for value tree {treedef}.{hint}") from None
  axes = [None if a is proxy else a for a in axes]
  assert len(axes) == treedef.num_leaves
  return axes

def _dtype(x):
  try:
    return dtypes.result_type(x)
  except ValueError:
    return dtypes.result_type(getattr(x, 'dtype'))

def shaped_abstractify(x):
  try:
    return core.raise_to_shaped(core.get_aval(x))
  except TypeError:
    pass

  weak_type = getattr(x, 'weak_type', False)
  named_shape = getattr(x, 'named_shape', {})
  return core.ShapedArray(np.shape(x), _dtype(x), weak_type=weak_type,
                          named_shape=named_shape)

# This decorator exists to make it easier to monkey-patch APIs in JAX.
# By default it does nothing, but it can be monkey-patched to do other things.
def api_hook(fun, tag: str):
  return fun
back to top