{
 "cells": [
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "---\n",
    "Copyright 2021 Google LLC\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    https://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open in\n",
    "Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Autodidax: JAX core from scratch\n",
    "\n",
    "Ever want to learn how JAX works, but the implementation seemed impenetrable?\n",
    "Well, you're in luck! By reading this tutorial, you'll learn every big idea in\n",
    "JAX's core system. You'll even get clued into our weird jargon!\n",
    "\n",
    "**This is a work-in-progress draft.** There are some important ingredients\n",
    "missing, still to come in parts 5 and 6 (and more?). There are also some\n",
    "simplifications here that we haven't yet applied to the main system, but we\n",
    "will."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 1: Transformations as interpreters: standard evaluation, `jvp`, and `vmap`\n",
    "\n",
    "We want to transform functions that look like this:\n",
    "\n",
    "```python\n",
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "```\n",
    "\n",
    "Think of functions like `sin` and the arithmetic operations underlying the\n",
    "infix operators (`mul`, `add`, and `neg`) as primitive operations, meaning\n",
    "atomic units of processing rather than compositions.\n",
    "\n",
    "\"Transform\" means \"interpret differently.\" Instead of standard interpretation\n",
    "where we apply primitive operations to numerical inputs to produce numerical\n",
    "outputs, we want to override primitive application and let different values\n",
    "flow through our program. For example, we might want to replace the\n",
    "application of every primitive with an application of [its JVP\n",
    "rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n",
    "and let primal-tangent pairs flow through our program. Moreover, we want to be\n",
    "able to compose multiple transformations, leading to stacks of interpreters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### JAX core machinery\n",
    "\n",
    "We can implement stacks of interpreters and even have them all discharge on\n",
    "the fly as we execute the Python function to be transformed. To start, let's\n",
    "define these primitives so that we can intercept their application:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import NamedTuple\n",
    "\n",
    "class Primitive(NamedTuple):\n",
    "  name: str\n",
    "\n",
    "add_p = Primitive('add')\n",
    "mul_p = Primitive('mul')\n",
    "neg_p = Primitive(\"neg\")\n",
    "sin_p = Primitive(\"sin\")\n",
    "cos_p = Primitive(\"cos\")\n",
    "reduce_sum_p = Primitive(\"reduce_sum\")\n",
    "greater_p = Primitive(\"greater\")\n",
    "less_p = Primitive(\"less\")\n",
    "transpose_p = Primitive(\"transpose\")\n",
    "broadcast_p = Primitive(\"broadcast\")\n",
    "\n",
    "def add(x, y): return bind1(add_p, x, y)\n",
    "def mul(x, y): return bind1(mul_p, x, y)\n",
    "def neg(x): return bind1(neg_p, x)\n",
    "def sin(x): return bind1(sin_p, x)\n",
    "def cos(x): return bind1(cos_p, x)\n",
    "def greater(x, y): return bind1(greater_p, x, y)\n",
    "def less(x, y): return bind1(less_p, x, y)\n",
    "def transpose(x, perm): return bind1(transpose_p, x, perm=perm)\n",
    "def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)\n",
    "def reduce_sum(x, axis=None):\n",
    "  if axis is None:\n",
    "    axis = tuple(range(np.ndim(x)))\n",
    "  if type(axis) is int:\n",
    "    axis = (axis,)\n",
    "  return bind1(reduce_sum_p, x, axis=axis)\n",
    "\n",
    "def bind1(prim, *args, **params):\n",
    "  out, = bind(prim, *args, **params)\n",
    "  return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll set up array data types and infix operator methods in a moment.\n",
    "\n",
    "A `Primitive` is just an object with a name, to which we attach our\n",
    "interpretation rules (one for each transformation). The `bind` function is our\n",
    "interception point: it'll figure out which transformation rule to apply, based\n",
    "on how the arguments are boxed in tracers and what interpreters are active.\n",
    "\n",
    "The functions that user code calls, like `add` and `sin`, are just wrappers\n",
    "around calls to `bind`. These wrappers let us control how arguments are passed\n",
    "to `bind`, and in particular we follow a handy internal convention: when we\n",
    "call `bind`, we pass values representing array data as positional arguments,\n",
    "and we pass metadata like the `axis` argument to `sum_p` via keyword. This\n",
    "calling convention simplifies some core logic (since e.g. instances of the\n",
    "`Tracer` class to be defined below can only occur in positional arguments to\n",
    "`bind`). The wrappers can also provide docstrings!\n",
    "\n",
    "We represent active interpreters as a stack. The stack is just a simple\n",
    "`list`, and each element is a container with an integer level (corresponding\n",
    "to the element's height in the stack), an interpreter type (which we'll call a\n",
    "`trace_type`), and an optional field for any global data the interpreter\n",
    "needs. We call each element a `MainTrace`, though maybe \"Interpreter\" would be\n",
    "more descriptive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from contextlib import contextmanager\n",
    "from typing import Type, List, Tuple, Sequence, Optional, Any\n",
    "\n",
    "class MainTrace(NamedTuple):\n",
    "  level: int\n",
    "  trace_type: Type['Trace']\n",
    "  global_data: Optional[Any]\n",
    "\n",
    "trace_stack: List[MainTrace] = []\n",
    "dynamic_trace: Optional[MainTrace] = None  # to be employed in Part 3\n",
    "\n",
    "@contextmanager\n",
    "def new_main(trace_type: Type['Trace'], global_data=None):\n",
    "  level = len(trace_stack)\n",
    "  main = MainTrace(level, trace_type, global_data)\n",
    "  trace_stack.append(main)\n",
    "\n",
    "  try:\n",
    "    yield main\n",
    "  finally:\n",
    "    trace_stack.pop()"
   ]
  },
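  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the push/pop discipline in isolation, here's a minimal standalone\n",
    "sketch. The names `demo_stack` and `push_interpreter` are hypothetical\n",
    "stand-ins rather than part of this notebook's code, and the stack holds plain\n",
    "strings instead of `MainTrace` instances:\n",
    "\n",
    "```python\n",
    "from contextlib import contextmanager\n",
    "\n",
    "# Hypothetical stand-in: level 0 plays the role of the evaluation interpreter.\n",
    "demo_stack = ['eval']\n",
    "\n",
    "@contextmanager\n",
    "def push_interpreter(name):\n",
    "  level = len(demo_stack)  # the new entry's height in the stack\n",
    "  demo_stack.append(name)\n",
    "  try:\n",
    "    yield level\n",
    "  finally:\n",
    "    demo_stack.pop()  # always restore the stack, even on error\n",
    "\n",
    "with push_interpreter('jvp') as outer:\n",
    "  with push_interpreter('vmap') as inner:\n",
    "    snapshot = list(demo_stack)\n",
    "\n",
    "print(outer, inner, snapshot, demo_stack)\n",
    "# 1 2 ['eval', 'jvp', 'vmap'] ['eval']\n",
    "```\n",
    "\n",
    "Each entry's level equals its height in the stack, and the `finally` clause\n",
    "restores the stack even if the body raises."
   ]
  },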
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When we're about to apply a transformation, we'll push another interpreter\n",
    "onto the stack using `new_main`. Then, as we apply primitives in the function,\n",
    "we can think of the `bind` first being interpreted by the trace at the top of\n",
    "the stack (i.e. with the highest level). If that first interpreter itself\n",
    "binds other primitives in its interpretation rule for the primitive, like how\n",
    "the JVP rule of `sin_p` might bind `cos_p` and `mul_p`, then those `bind`\n",
    "calls will be handled by the interpreter at the next level down.\n",
    "\n",
    "What goes at the bottom of the interpreter stack? At the bottom, we know all\n",
    "the transformation interpreters are finished, and we just want to do standard\n",
    "evaluation. So at the bottom we'll put an evaluation interpreter.\n",
    "\n",
    "Let's sketch out the interface for interpreters, which is based on the `Trace`\n",
    "and `Tracer` base classes. A `Tracer` represents a boxed-up value, perhaps\n",
    "carrying some extra context data used by the interpreter. A `Trace` handles\n",
    "boxing up values into `Tracers` and also handles primitive application."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Trace:\n",
    "  main: MainTrace\n",
    "\n",
    "  def __init__(self, main: MainTrace) -> None:\n",
    "    self.main = main\n",
    "\n",
    "  def pure(self, val): assert False  # must override\n",
    "  def lift(self, val): assert False  # must override\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    assert False  # must override"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first two methods are about boxing up values in `Tracer`s, which are the\n",
    "objects that flow through the Python programs we transform. The last method is\n",
    "the callback we'll use to interpret primitive application.\n",
    "\n",
    "The `Trace` itself doesn't contain any data, other than a reference to its\n",
    "corresponding `MainTrace` instance. In fact, multiple instances of a `Trace`\n",
    "might be created and discarded during an application of a transformation,\n",
    "whereas only a single `MainTrace` instance is created per application of a\n",
    "transformation.\n",
    "\n",
    "As for `Tracer`s themselves, each one carries an abstract value (and forwards\n",
    "infix operators to it), and the rest is up to the transformation. (The\n",
    "relationship between `Tracer`s and `AbstractValue`s is that there's one\n",
    "`Tracer` per transformation, and at least one `AbstractValue` per base type,\n",
    "like arrays.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class Tracer:\n",
    "  _trace: Trace\n",
    "\n",
    "  __array_priority__ = 1000\n",
    "\n",
    "  @property\n",
    "  def aval(self):\n",
    "    assert False  # must override\n",
    "\n",
    "  def full_lower(self):\n",
    "    return self  # default implementation\n",
    "\n",
    "  def __neg__(self): return self.aval._neg(self)\n",
    "  def __add__(self, other): return self.aval._add(self, other)\n",
    "  def __radd__(self, other): return self.aval._radd(self, other)\n",
    "  def __mul__(self, other): return self.aval._mul(self, other)\n",
    "  def __rmul__(self, other): return self.aval._rmul(self, other)\n",
    "  def __gt__(self, other): return self.aval._gt(self, other)\n",
    "  def __lt__(self, other): return self.aval._lt(self, other)\n",
    "  def __bool__(self): return self.aval._bool(self)\n",
    "  def __nonzero__(self): return self.aval._nonzero(self)\n",
    "\n",
    "  def __getattr__(self, name):\n",
    "    try:\n",
    "      return getattr(self.aval, name)\n",
    "    except AttributeError:\n",
    "      raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")\n",
    "\n",
    "def swap(f): return lambda x, y: f(y, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ShapedArray:\n",
    "  array_abstraction_level = 1\n",
    "  shape: Tuple[int]\n",
    "  dtype: np.dtype\n",
    "\n",
    "  def __init__(self, shape, dtype):\n",
    "    self.shape = shape\n",
    "    self.dtype = dtype\n",
    "\n",
    "  @property\n",
    "  def ndim(self):\n",
    "    return len(self.shape)\n",
    "\n",
    "  _neg = staticmethod(neg)\n",
    "  _add = staticmethod(add)\n",
    "  _radd = staticmethod(swap(add))\n",
    "  _mul = staticmethod(mul)\n",
    "  _rmul = staticmethod(swap(mul))\n",
    "  _gt = staticmethod(greater)\n",
    "  _lt = staticmethod(less)\n",
    "\n",
    "  @staticmethod\n",
    "  def _bool(tracer):\n",
    "    raise Exception(\"ShapedArray can't be unambiguously converted to bool\")\n",
    "\n",
    "  @staticmethod\n",
    "  def _nonzero(tracer):\n",
    "    raise Exception(\"ShapedArray can't be unambiguously converted to bool\")\n",
    "\n",
    "  def str_short(self):\n",
    "    return f'{self.dtype.name}[{\",\".join(str(d) for d in self.shape)}]'\n",
    "\n",
    "  def __hash__(self):\n",
    "    return hash((self.shape, self.dtype))\n",
    "\n",
    "  def __eq__(self, other):\n",
    "    return (type(self) is type(other) and\n",
    "            self.shape == other.shape and self.dtype == other.dtype)\n",
    "\n",
    "  def __repr__(self):\n",
    "    return f\"ShapedArray(shape={self.shape}, dtype={self.dtype})\"\n",
    "\n",
    "class ConcreteArray(ShapedArray):\n",
    "  array_abstraction_level = 2\n",
    "  val: np.ndarray\n",
    "\n",
    "  def __init__(self, val):\n",
    "    self.val = val\n",
    "    self.shape = val.shape\n",
    "    self.dtype = val.dtype\n",
    "\n",
    "  @staticmethod\n",
    "  def _bool(tracer):\n",
    "    return bool(tracer.aval.val)\n",
    "\n",
    "  @staticmethod\n",
    "  def _nonzero(tracer):\n",
    "    return bool(tracer.aval.val)\n",
    "\n",
    "def get_aval(x):\n",
    "  if isinstance(x, Tracer):\n",
    "    return x.aval\n",
    "  elif type(x) in jax_types:\n",
    "    return ConcreteArray(np.asarray(x))\n",
    "  else:\n",
    "    raise TypeError(x)\n",
    "\n",
    "jax_types = {bool, int, float,\n",
    "             np.bool_, np.int32, np.int64, np.float32, np.float64, np.ndarray}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that we actually have two `AbstractValue`s for arrays, representing\n",
    "different levels of abstraction. A `ShapedArray` represents the set of all\n",
    "possible arrays with a given shape and dtype. A `ConcreteArray` represents a\n",
    "singleton set consisting of a single array value.\n",
    "\n",
    "Now that we've set up the interpreter stack, the Trace/Tracer API for\n",
    "interpreters, and abstract values, we can come back to implement `bind`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bind(prim, *args, **params):\n",
    "  top_trace = find_top_trace(args)\n",
    "  tracers = [full_raise(top_trace, arg) for arg in args]\n",
    "  outs = top_trace.process_primitive(prim, tracers, params)\n",
    "  return [full_lower(out) for out in outs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main action is that we call `find_top_trace` to figure out which\n",
    "interpreter should handle this primitive application. We then call that top\n",
    "trace's `process_primitive` so that the trace can apply its interpretation\n",
    "rule. The calls to `full_raise` just ensure that the inputs are boxed in the\n",
    "top trace's `Tracer` instances, and the call to `full_lower` is an optional\n",
    "optimization so that we unbox values out of `Tracer`s as much as possible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import operator as op\n",
    "\n",
    "def find_top_trace(xs) -> Trace:\n",
    "  top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),\n",
    "                 default=trace_stack[0], key=op.attrgetter('level'))\n",
    "  if dynamic_trace and dynamic_trace.level > top_main.level:\n",
    "    top_main = dynamic_trace\n",
    "  return top_main.trace_type(top_main)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In words, ignoring the `dynamic_trace` step until Part 3, `find_top_trace`\n",
    "returns the highest-level interpreter associated with the `Tracer`s on its\n",
    "inputs, and otherwise returns the interpreter at the bottom of the stack\n",
    "(which is always an evaluation trace, at least for now). This is a deviation\n",
    "from the description above, where we always start by running the interpreter\n",
    "at the top of the stack and then work our way down, applying every interpreter\n",
    "in the stack. Instead, we're only applying an interpreter when the input\n",
    "arguments to a primitive bind are boxed in a `Tracer` corresponding to that\n",
    "interpreter. This optimization lets us skip irrelevant transformations, but\n",
    "bakes in an assumption that transformations mostly follow data dependence\n",
    "(except for the special bottom-of-the-stack interpreter, which interprets\n",
    "everything).\n",
    "\n",
    "An alternative would be to have every interpreter in the stack interpret every\n",
    "operation. That's worth exploring! JAX is designed around data dependence in\n",
    "large part because that's so natural for automatic differentiation, and JAX's\n",
    "roots are in autodiff. But it may be over-fit."
   ]
  },
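  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The selection rule can be sketched on its own, ignoring `dynamic_trace`. In\n",
    "this standalone illustration, `Main`, `Boxed`, and `pick_main` are\n",
    "hypothetical stand-ins for `MainTrace`, `Tracer`, and `find_top_trace`:\n",
    "\n",
    "```python\n",
    "import operator as op\n",
    "from typing import NamedTuple\n",
    "\n",
    "class Main(NamedTuple):  # hypothetical stand-in for MainTrace\n",
    "  level: int\n",
    "  name: str\n",
    "\n",
    "class Boxed:  # hypothetical stand-in for a Tracer\n",
    "  def __init__(self, main, value):\n",
    "    self.main = main\n",
    "    self.value = value\n",
    "\n",
    "bottom = Main(0, 'eval')\n",
    "jvp_main = Main(1, 'jvp')\n",
    "vmap_main = Main(2, 'vmap')\n",
    "\n",
    "def pick_main(args):\n",
    "  # Highest-level interpreter among boxed args; fall back to the bottom.\n",
    "  return max((a.main for a in args if isinstance(a, Boxed)),\n",
    "             default=bottom, key=op.attrgetter('level'))\n",
    "\n",
    "print(pick_main([1.0, 2.0]).name)                                      # eval\n",
    "print(pick_main([Boxed(jvp_main, 1.0), 2.0]).name)                     # jvp\n",
    "print(pick_main([Boxed(jvp_main, 1.0), Boxed(vmap_main, 2.0)]).name)   # vmap\n",
    "```\n",
    "\n",
    "Unboxed arguments never select a transformation, which is the data-dependence\n",
    "assumption described above."
   ]
  },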
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def full_lower(val: Any):\n",
    "  if isinstance(val, Tracer):\n",
    "    return val.full_lower()\n",
    "  else:\n",
    "    return val\n",
    "\n",
    "def full_raise(trace: Trace, val: Any) -> Tracer:\n",
    "  if not isinstance(val, Tracer):\n",
    "    assert type(val) in jax_types\n",
    "    return trace.pure(val)\n",
    "  level = trace.main.level\n",
    "  if val._trace.main is trace.main:\n",
    "    return val\n",
    "  elif val._trace.main.level < level:\n",
    "    return trace.lift(val)\n",
    "  elif val._trace.main.level > level:\n",
    "    raise Exception(f\"Can't lift level {val._trace.main.level} to {level}.\")\n",
    "  else:  # val._trace.level == level\n",
    "    raise Exception(f\"Different traces at same level: {val._trace}, {trace}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The logic in `full_raise` serves to box values into `Tracer`s for a particular\n",
    "`Trace`, calling different methods on the `Trace` based on context:\n",
    "`Trace.pure` is called on non-`Tracer` constants, and `Trace.lift` is called\n",
    "for values that are already `Tracer`s from a lower-level interpreter. These\n",
    "two methods could share the same implementation, but by distinguishing them in\n",
    "the core logic we can provide more information to the `Trace` subclass.\n",
    "\n",
    "That's it for the JAX core! Now we can start adding interpreters."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation interpreter\n",
    "\n",
    "We'll start with the simplest interpreter: the evaluation interpreter that\n",
    "will sit at the bottom of the interpreter stack."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EvalTrace(Trace):\n",
    "  pure = lift = lambda self, x: x  # no boxing in Tracers needed\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    return impl_rules[primitive](*tracers, **params)\n",
    "\n",
    "trace_stack.append(MainTrace(0, EvalTrace, None))  # special bottom of the stack\n",
    "\n",
    "# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance\n",
    "impl_rules = {}\n",
    "\n",
    "impl_rules[add_p] = lambda x, y: [np.add(x, y)]\n",
    "impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]\n",
    "impl_rules[neg_p] = lambda x: [np.negative(x)]\n",
    "impl_rules[sin_p] = lambda x: [np.sin(x)]\n",
    "impl_rules[cos_p] = lambda x: [np.cos(x)]\n",
    "impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]\n",
    "impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]\n",
    "impl_rules[less_p] = lambda x, y: [np.less(x, y)]\n",
    "impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]\n",
    "\n",
    "def broadcast_impl(x, *, shape, axes):\n",
    "  for axis in sorted(axes):\n",
    "    x = np.expand_dims(x, axis)\n",
    "  return [np.broadcast_to(x, shape)]\n",
    "impl_rules[broadcast_p] = broadcast_impl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this interpreter, we can evaluate user functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "\n",
    "print(f(3.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Woo! Like going around in a big circle. But the point of this indirection is\n",
    "that now we can add some real transformations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward-mode autodiff with `jvp`\n",
    "\n",
    "First, a few helper functions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def zeros_like(val):\n",
    "  aval = get_aval(val)\n",
    "  return np.zeros(aval.shape, aval.dtype)\n",
    "\n",
    "def unzip2(pairs):\n",
    "  lst1, lst2 = [], []\n",
    "  for x1, x2 in pairs:\n",
    "    lst1.append(x1)\n",
    "    lst2.append(x2)\n",
    "  return lst1, lst2\n",
    "\n",
    "map_ = map\n",
    "def map(f, *xs):\n",
    "  return list(map_(f, *xs))\n",
    "\n",
    "zip_ = zip\n",
    "def zip(*args):\n",
    "  fst, *rest = args = map(list, args)\n",
    "  n = len(fst)\n",
    "  for arg in rest:\n",
    "    assert len(arg) == n\n",
    "  return list(zip_(*args))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The\n",
    "`Trace` applies JVP rules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class JVPTracer(Tracer):\n",
    "  def __init__(self, trace, primal, tangent):\n",
    "    self._trace = trace\n",
    "    self.primal = primal\n",
    "    self.tangent = tangent\n",
    "\n",
    "  @property\n",
    "  def aval(self):\n",
    "    return get_aval(self.primal)\n",
    "\n",
    "class JVPTrace(Trace):\n",
    "  pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)\n",
    "    jvp_rule = jvp_rules[primitive]\n",
    "    primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)\n",
    "    return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]\n",
    "\n",
    "jvp_rules = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice both `pure` and `lift` package a value into a `JVPTracer` with the\n",
    "minimal amount of context, which is a zero tangent value.\n",
    "\n",
    "Let's add some JVP rules for primitives:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_jvp(primals, tangents):\n",
    "  (x, y), (x_dot, y_dot) = primals, tangents\n",
    "  return [x + y], [x_dot + y_dot]\n",
    "jvp_rules[add_p] = add_jvp\n",
    "\n",
    "def mul_jvp(primals, tangents):\n",
    "  (x, y), (x_dot, y_dot) = primals, tangents\n",
    "  return [x * y], [x_dot * y + x * y_dot]\n",
    "jvp_rules[mul_p] = mul_jvp\n",
    "\n",
    "def sin_jvp(primals, tangents):\n",
    "  (x,), (x_dot,) = primals, tangents\n",
    "  return [sin(x)], [cos(x) * x_dot]\n",
    "jvp_rules[sin_p] = sin_jvp\n",
    "\n",
    "def cos_jvp(primals, tangents):\n",
    "  (x,), (x_dot,) = primals, tangents\n",
    "  return [cos(x)], [-sin(x) * x_dot]\n",
    "jvp_rules[cos_p] = cos_jvp\n",
    "\n",
    "def neg_jvp(primals, tangents):\n",
    "  (x,), (x_dot,) = primals, tangents\n",
    "  return [neg(x)], [neg(x_dot)]\n",
    "jvp_rules[neg_p] = neg_jvp\n",
    "\n",
    "def reduce_sum_jvp(primals, tangents, *, axis):\n",
    "  (x,), (x_dot,) = primals, tangents\n",
    "  return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]\n",
    "jvp_rules[reduce_sum_p] = reduce_sum_jvp\n",
    "\n",
    "def greater_jvp(primals, tangents):\n",
    "  (x, y), _ = primals, tangents\n",
    "  out_primal = greater(x, y)\n",
    "  return [out_primal], [zeros_like(out_primal)]\n",
    "jvp_rules[greater_p] = greater_jvp\n",
    "\n",
    "def less_jvp(primals, tangents):\n",
    "  (x, y), _ = primals, tangents\n",
    "  out_primal = less(x, y)\n",
    "  return [out_primal], [zeros_like(out_primal)]\n",
    "jvp_rules[less_p] = less_jvp"
   ]
  },
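  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on rules like these, here's a standalone sketch that\n",
    "composes a product rule and a sine rule by hand on plain Python floats and\n",
    "compares the result against a central finite difference. The helper names\n",
    "(`mul_rule`, `sin_rule`, `f_with_tangent`, `finite_difference`) are\n",
    "hypothetical and bypass the tracer machinery entirely:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def mul_rule(x, y, x_dot, y_dot):\n",
    "  # Product rule, mirroring mul_jvp: d(x * y) = x_dot * y + x * y_dot\n",
    "  return x * y, x_dot * y + x * y_dot\n",
    "\n",
    "def sin_rule(x, x_dot):\n",
    "  # Chain rule, mirroring sin_jvp: d(sin x) = cos(x) * x_dot\n",
    "  return math.sin(x), math.cos(x) * x_dot\n",
    "\n",
    "def f_with_tangent(x, x_dot):\n",
    "  # f(x) = x * sin(x), pushed forward by composing the two rules\n",
    "  s, s_dot = sin_rule(x, x_dot)\n",
    "  return mul_rule(x, s, x_dot, s_dot)\n",
    "\n",
    "def finite_difference(g, x, eps=1e-6):\n",
    "  return (g(x + eps) - g(x - eps)) / (2 * eps)\n",
    "\n",
    "y, y_dot = f_with_tangent(3.0, 1.0)\n",
    "approx = finite_difference(lambda x: x * math.sin(x), 3.0)\n",
    "print(y_dot, approx)  # the two values should agree to several digits\n",
    "```\n",
    "\n",
    "The interpreter below automates exactly this bookkeeping: each primitive\n",
    "application looks up its rule and threads the primal-tangent pair through."
   ]
  },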
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we add a transformation API to kick off the trace:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jvp_v1(f, primals, tangents):\n",
    "  with new_main(JVPTrace) as main:\n",
    "    trace = JVPTrace(main)\n",
    "    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]\n",
    "    out = f(*tracers_in)\n",
    "    tracer_out = full_raise(trace, out)\n",
    "    primal_out, tangent_out = tracer_out.primal, tracer_out.tangent\n",
    "  return primal_out, tangent_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And with that, we can differentiate!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = 3.0\n",
    "y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))\n",
    "print(sin_deriv_at_3)\n",
    "print(cos(3.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "\n",
    "x, xdot = 3., 1.\n",
    "y, ydot = jvp_v1(f, (x,), (xdot,))\n",
    "print(y)\n",
    "print(ydot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def deriv(f):\n",
    "  return lambda x: jvp_v1(f, (x,), (1.,))[1]\n",
    "\n",
    "print(deriv(sin)(3.))\n",
    "print(deriv(deriv(sin))(3.))\n",
    "print(deriv(deriv(deriv(sin)))(3.))\n",
    "print(deriv(deriv(deriv(deriv(sin))))(3.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  if x > 0.:  # Python control flow\n",
    "    return 2. * x\n",
    "  else:\n",
    "    return x\n",
    "\n",
    "print(deriv(f)(3.))\n",
    "print(deriv(f)(-3.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pytrees and flattening user functions' inputs and outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A limitation with `jvp_v1` is that it assumes the user function accepts arrays\n",
    "as positional arguments and produces a single array as output. What if it\n",
    "produced a list as output? Or accepted nested containers as inputs? It would\n",
    "be a pain to deal with all the possible containers in inputs and outputs at\n",
    "every layer of the stack. Instead, we can wrap the user function so that the\n",
    "wrapped version accepts arrays as inputs and returns a flat list of arrays as\n",
    "output. The wrapper just needs to unflatten its input, call the user function,\n",
    "and flatten the output.\n",
    "\n",
    "Here's how we'd like to write `jvp`, assuming the user always gives us\n",
    "functions that take arrays as inputs and produces a flat list of arrays as\n",
    "outputs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jvp_flat(f, primals, tangents):\n",
    "  with new_main(JVPTrace) as main:\n",
    "    trace = JVPTrace(main)\n",
    "    tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]\n",
    "    outs = f(*tracers_in)\n",
    "    tracers_out = [full_raise(trace, out) for out in outs]\n",
    "    primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)\n",
    "  return primals_out, tangents_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To support user functions that have arbitrary containers in the inputs and\n",
    "outputs, here's how we'd write the user-facing `jvp` wrapper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jvp(f, primals, tangents):\n",
    "  primals_flat, in_tree = tree_flatten(primals)\n",
    "  tangents_flat, in_tree2 = tree_flatten(tangents)\n",
    "  if in_tree != in_tree2: raise TypeError\n",
    "  f, out_tree = flatten_fun(f, in_tree)\n",
    "  primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)\n",
    "  primals_out = tree_unflatten(out_tree(), primals_out_flat)\n",
    "  tangents_out = tree_unflatten(out_tree(), tangents_out_flat)\n",
    "  return primals_out, tangents_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that we had to plumb the tree structure of the user function output\n",
    "back to the caller of `flatten_fun`. That information isn't available until we\n",
    "actually run the user function, so `flatten_fun` just returns a reference to a\n",
    "mutable cell, represented as a thunk. These side-effects are safe because we\n",
    "always run the user function exactly once. (This safe regime is the reason for\n",
    "the \"linear\" name in `linear_util.py`, in the sense of [linear\n",
    "types](https://en.wikipedia.org/wiki/Substructural_type_system).)\n",
    "\n",
    "All that remains is to write `tree_flatten`, `tree_unflatten`, and\n",
    "`flatten_fun`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def flatten_fun(f, in_tree):\n",
    "  store = Store()\n",
    "\n",
    "  def flat_fun(*args_flat):\n",
    "    pytree_args = tree_unflatten(in_tree, args_flat)\n",
    "    out = f(*pytree_args)\n",
    "    out_flat, out_tree = tree_flatten(out)\n",
    "    store.set_value(out_tree)\n",
    "    return out_flat\n",
    "\n",
    "  return flat_fun, store\n",
    "\n",
    "class Empty: pass\n",
    "empty = Empty()\n",
    "\n",
    "class Store:\n",
    "  val = empty\n",
    "\n",
    "  def set_value(self, val):\n",
    "    assert self.val is empty\n",
    "    self.val = val\n",
    "\n",
    "  def __call__(self):\n",
    "    return self.val"
   ]
  },
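  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The write-once thunk pattern can be seen in isolation in the standalone\n",
    "sketch below. `WriteOnce` and `with_output_kind` are hypothetical names, and\n",
    "the sketch records the output's type name rather than a pytree structure:\n",
    "\n",
    "```python\n",
    "class WriteOnce:  # hypothetical analogue of Store\n",
    "  _unset = object()\n",
    "\n",
    "  def __init__(self):\n",
    "    self._val = WriteOnce._unset\n",
    "\n",
    "  def set_value(self, val):\n",
    "    if self._val is not WriteOnce._unset:\n",
    "      raise RuntimeError('value already set')\n",
    "    self._val = val\n",
    "\n",
    "  def __call__(self):\n",
    "    return self._val\n",
    "\n",
    "def with_output_kind(f):\n",
    "  # Returns a wrapped function plus a thunk reporting the output's type\n",
    "  # name, which is only known once the wrapped function has run.\n",
    "  cell = WriteOnce()\n",
    "  def wrapped(*args):\n",
    "    out = f(*args)\n",
    "    cell.set_value(type(out).__name__)\n",
    "    return out\n",
    "  return wrapped, cell\n",
    "\n",
    "g, out_kind = with_output_kind(lambda x: [x, x])\n",
    "result = g(1.0)\n",
    "print(result, out_kind())  # [1.0, 1.0] list\n",
    "```\n",
    "\n",
    "Calling `g` a second time would raise, which is how the sketch enforces the\n",
    "run-exactly-once regime."
   ]
  },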
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import itertools as it\n",
    "from typing import Callable, Type, Hashable, Dict, Iterable, Iterator\n",
    "\n",
    "class NodeType(NamedTuple):\n",
    "  name: str\n",
    "  to_iterable: Callable\n",
    "  from_iterable: Callable\n",
    "\n",
    "def register_pytree_node(ty: Type, to_iter: Callable, from_iter: Callable\n",
    "                         ) -> None:\n",
    "  node_types[ty] = NodeType(str(ty), to_iter, from_iter)\n",
    "\n",
    "node_types: Dict[Type, NodeType] = {}\n",
    "register_pytree_node(tuple, lambda t: (None, t), lambda _, xs: tuple(xs))\n",
    "register_pytree_node(list,  lambda l: (None, l), lambda _, xs:  list(xs))\n",
    "register_pytree_node(dict,\n",
    "                     lambda d: map(tuple, unzip2(sorted(d.items()))),\n",
    "                     lambda keys, vals: dict(zip(keys, vals)))\n",
    "\n",
    "class PyTreeDef(NamedTuple):\n",
    "  node_type: NodeType\n",
    "  node_metadata: Hashable\n",
    "  child_treedefs: Tuple['PyTreeDef', ...]\n",
    "\n",
    "class Leaf: pass\n",
    "leaf = Leaf()\n",
    "\n",
    "def tree_flatten(x: Any) -> Tuple[List[Any], PyTreeDef]:\n",
    "  children_iter, treedef = _tree_flatten(x)\n",
    "  return list(children_iter), treedef\n",
    "\n",
    "def _tree_flatten(x: Any) -> Tuple[Iterable, PyTreeDef]:\n",
    "  node_type = node_types.get(type(x))\n",
    "  if node_type:\n",
    "    node_metadata, children = node_type.to_iterable(x)\n",
    "    children_flat, child_trees = unzip2(map(_tree_flatten, children))\n",
    "    flattened = it.chain.from_iterable(children_flat)\n",
    "    return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))\n",
    "  else:\n",
    "    return [x], leaf\n",
    "\n",
    "def tree_unflatten(treedef: PyTreeDef, xs: List[Any]) -> Any:\n",
    "  return _tree_unflatten(treedef, iter(xs))\n",
    "\n",
    "def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:\n",
    "  if treedef is leaf:\n",
    "    return next(xs)\n",
    "  else:\n",
    "    children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)\n",
    "    return treedef.node_type.from_iterable(treedef.node_metadata, children)"
   ]
  },
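  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the flatten/unflatten contract concrete, here is a self-contained\n",
    "sketch restricted to tuples, lists, and dicts. It is a simplification under\n",
    "stated assumptions, not the registry-based code above: `flatten`, `unflatten`,\n",
    "and the `LEAF` marker are hypothetical names, and the structure is encoded as\n",
    "plain nested tuples.\n",
    "\n",
    "```python\n",
    "from typing import Any, List, Tuple\n",
    "\n",
    "LEAF = '*'  # hypothetical marker for a leaf position\n",
    "\n",
    "def flatten(x: Any) -> Tuple[List[Any], Any]:\n",
    "  # Returns (leaves, structure); structure is (container type, metadata, children).\n",
    "  if isinstance(x, (tuple, list)):\n",
    "    parts = [flatten(child) for child in x]\n",
    "    leaves = [leaf for ls, _ in parts for leaf in ls]\n",
    "    return leaves, (type(x), None, tuple(s for _, s in parts))\n",
    "  if isinstance(x, dict):\n",
    "    keys = tuple(sorted(x))  # sorted keys give a deterministic leaf order\n",
    "    parts = [flatten(x[k]) for k in keys]\n",
    "    leaves = [leaf for ls, _ in parts for leaf in ls]\n",
    "    return leaves, (dict, keys, tuple(s for _, s in parts))\n",
    "  return [x], LEAF\n",
    "\n",
    "def unflatten(structure: Any, leaves: List[Any]) -> Any:\n",
    "  leaf_iter = iter(leaves)\n",
    "  def build(s):\n",
    "    if s == LEAF:\n",
    "      return next(leaf_iter)\n",
    "    ty, keys, child_structs = s\n",
    "    children = [build(c) for c in child_structs]\n",
    "    return dict(zip(keys, children)) if ty is dict else ty(children)\n",
    "  return build(structure)\n",
    "\n",
    "tree = {'hi': 1.0, 'there': [2.0, (3.0, 4.0)]}\n",
    "leaves, structure = flatten(tree)\n",
    "print(leaves)                                # [1.0, 2.0, 3.0, 4.0]\n",
    "print(unflatten(structure, leaves) == tree)  # True\n",
    "```\n",
    "\n",
    "The same structure can rebuild a container from any list of leaves of the\n",
    "right length, which is what lets a transformation work on flat lists and\n",
    "restore the user's containers at the end."
   ]
  },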
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With this pytree-handling `jvp` implementation, we can now handle arbitrary\n",
    "input and output containers. That'll come in handy with future transformations\n",
    "too!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return {'hi': z, 'there': [x, y]}\n",
    "\n",
    "x, xdot = 3., 1.\n",
    "y, ydot = jvp(f, (x,), (xdot,))\n",
    "print(y)\n",
    "print(ydot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vectorized batching with `vmap`\n",
    "\n",
    "First, a couple of helper functions, one for producing mapped abstract values\n",
    "from unmapped ones (by removing an axis), and one for moving batch dimensions\n",
    "around:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mapped_aval(batch_dim, aval):\n",
    "  shape = list(aval.shape)\n",
    "  del shape[batch_dim]\n",
    "  return ShapedArray(tuple(shape), aval.dtype)\n",
    "\n",
    "def move_batch_axis(axis_size, src, dst, x):\n",
    "  if src is not_mapped:\n",
    "    target_shape = list(np.shape(x))\n",
    "    target_shape.insert(dst, axis_size)\n",
    "    return broadcast(x, target_shape, [dst])\n",
    "  elif src == dst:\n",
    "    return x\n",
    "  else:\n",
    "    return moveaxis(x, src, dst)\n",
    "\n",
    "def moveaxis(x, src: int, dst: int):\n",
    "  perm = [i for i in range(np.ndim(x)) if i != src]\n",
    "  perm.insert(dst, src)\n",
    "  return transpose(x, perm)"
   ]
  },
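  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The permutation built by `moveaxis` can be checked against NumPy directly.\n",
    "The sketch below uses plain NumPy arrays instead of the tutorial's `transpose`\n",
    "primitive; `moveaxis_perm` is a hypothetical helper that isolates the\n",
    "permutation logic.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def moveaxis_perm(ndim: int, src: int, dst: int):\n",
    "  # Drop `src` from the identity permutation, then reinsert it at `dst`.\n",
    "  perm = [i for i in range(ndim) if i != src]\n",
    "  perm.insert(dst, src)\n",
    "  return perm\n",
    "\n",
    "x = np.arange(24).reshape(2, 3, 4)\n",
    "perm = moveaxis_perm(x.ndim, 2, 0)\n",
    "print(perm)                         # [2, 0, 1]\n",
    "print(np.transpose(x, perm).shape)  # (4, 2, 3)\n",
    "print(np.array_equal(np.transpose(x, perm), np.moveaxis(x, 2, 0)))  # True\n",
    "```"
   ]
  },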
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Tracer` for vectorized batching carries a batched value and an optional\n",
    "integer indicating which axis (if any) is the batch axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union\n",
    "\n",
    "class NotMapped: pass\n",
    "not_mapped = NotMapped()\n",
    "\n",
    "BatchAxis = Union[NotMapped, int]\n",
    "\n",
    "class BatchTracer(Tracer):\n",
    "  def __init__(self, trace, val, batch_dim: BatchAxis):\n",
    "    self._trace = trace\n",
    "    self.val = val\n",
    "    self.batch_dim = batch_dim\n",
    "\n",
    "  @property\n",
    "  def aval(self):\n",
    "    if self.batch_dim is not_mapped:\n",
    "      return get_aval(self.val)\n",
    "    else:\n",
    "      return mapped_aval(self.batch_dim, get_aval(self.val))\n",
    "\n",
    "  def full_lower(self):\n",
    "    if self.batch_dim is not_mapped:\n",
    "      return full_lower(self.val)\n",
    "    else:\n",
    "      return self\n",
    "\n",
    "class BatchTrace(Trace):\n",
    "  pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)\n",
    "    vmap_rule = vmap_rules[primitive]\n",
    "    val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)\n",
    "    return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]\n",
    "\n",
    "  @property\n",
    "  def axis_size(self):\n",
    "    return self.main.global_data\n",
    "\n",
    "vmap_rules = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we've implemented the optional `Tracer.full_lower` method, which lets us\n",
    "peel off a batching tracer if it's not needed because it doesn't represent a\n",
    "batched value.\n",
    "\n",
    "For `BatchTrace`, analogous to `JVPTrace`, the methods `pure` and `lift` just\n",
    "box a value in a `BatchTracer` with the minimal amount of context, which in\n",
    "this case is a `batch_dim` taking the sentinel value `not_mapped`. Notice we\n",
    "use the `MainTrace`'s interpreter-global data field to store the batch axis\n",
    "size.\n",
    "\n",
    "Next we can define batching interpreter rules for each primitive:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "def binop_batching_rule(op, axis_size, vals_in, dims_in):\n",
    "  (x, y), (x_bdim, y_bdim) = vals_in, dims_in\n",
    "  if x_bdim != y_bdim:\n",
    "    if x_bdim is not_mapped:\n",
    "      x = move_batch_axis(axis_size, x_bdim, y_bdim, x)\n",
    "      x_bdim = y_bdim\n",
    "    else:\n",
    "      y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n",
    "  return [op(x, y)], [x_bdim]\n",
    "vmap_rules[add_p] = partial(binop_batching_rule, add)\n",
    "vmap_rules[mul_p] = partial(binop_batching_rule, mul)\n",
    "\n",
    "def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):\n",
    "  (x,), (x_bdim,) = vals_in, dims_in\n",
    "  return [op(x)], [x_bdim]\n",
    "vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)\n",
    "vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)\n",
    "vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)\n",
    "\n",
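    "def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):\n",
    "  (x,), (x_bdim,) = vals_in, dims_in\n",
    "  # `axis` indexes the unbatched operand; shift axes at or after the batch axis\n",
    "  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)\n",
    "  # the batch axis moves left once for each reduced axis that precedes it\n",
    "  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)\n",
    "  return [reduce_sum(x, new_axis)], [out_bdim]\n",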
    "vmap_rules[reduce_sum_p] = reduce_sum_batching_rule"
   ]
  },
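  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The axis arithmetic in the `reduce_sum` rule is easy to get wrong, so here is\n",
    "a NumPy-only sketch that compares it with a per-example loop. It assumes a\n",
    "mapped input (an integer batch axis); `batched_reduce_sum` is a hypothetical\n",
    "name, and `np.sum` stands in for the `reduce_sum` primitive.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def batched_reduce_sum(x, x_bdim, axis):\n",
    "  # `axis` is given in per-example coordinates (as if the batch axis were absent).\n",
    "  new_axis = tuple(ax + (x_bdim <= ax) for ax in axis)\n",
    "  out_bdim = x_bdim - sum(ax < x_bdim for ax in axis)\n",
    "  return np.sum(x, axis=new_axis), out_bdim\n",
    "\n",
    "# A batch of 5 examples, each of shape (2, 3); the batch axis sits in the middle.\n",
    "x = np.arange(30.).reshape(2, 5, 3)\n",
    "out, out_bdim = batched_reduce_sum(x, x_bdim=1, axis=(0,))\n",
    "print(out.shape, out_bdim)  # (5, 3) 0\n",
    "\n",
    "# Reference: apply the per-example reduction to each example separately.\n",
    "expected = np.stack([np.sum(x[:, i, :], axis=0) for i in range(5)],\n",
    "                    axis=out_bdim)\n",
    "print(np.allclose(out, expected))  # True\n",
    "```"
   ]
  },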
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we add a transformation API to kick off the trace:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vmap_flat(f, in_axes, *args):\n",
    "  axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)\n",
    "                if ax is not not_mapped}\n",
    "  with new_main(BatchTrace, axis_size) as main:\n",
    "    trace = BatchTrace(main)\n",
    "    tracers_in = [BatchTracer(trace, x, ax) if ax is not not_mapped else x\n",
    "                  for x, ax in zip(args, in_axes)]\n",
    "    outs = f(*tracers_in)\n",
    "    tracers_out = [full_raise(trace, out) for out in outs]\n",
    "    vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)\n",
    "  outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)\n",
    "                     for val_out, bdim in zip(vals_out, bdims_out)]\n",
    "  return outs_transposed\n",
    "\n",
    "def vmap(f, in_axes):\n",
    "  def batched_f(*args):\n",
    "    args_flat, in_tree = tree_flatten(args)\n",
    "    in_axes_flat, in_tree2 = tree_flatten(in_axes)\n",
    "    if in_tree != in_tree2: raise TypeError\n",
    "    f_flat, out_tree = flatten_fun(f, in_tree)\n",
    "    outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)\n",
    "    return tree_unflatten(out_tree(), outs_flat)\n",
    "  return batched_f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_one_to_a_scalar(scalar):\n",
    "  assert np.ndim(scalar) == 0\n",
    "  return 1 + scalar\n",
    "\n",
    "vector_in = np.arange(3.)\n",
    "vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)\n",
    "\n",
    "print(vector_in)\n",
    "print(vector_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jacfwd(f, x):\n",
    "  pushfwd = lambda v: jvp(f, (x,), (v,))[1]\n",
    "  vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)\n",
    "  return vmap(pushfwd, (0,))(vecs_in)\n",
    "\n",
    "def f(x):\n",
    "  return sin(x)\n",
    "\n",
    "jacfwd(f, np.arange(3.))"
   ]
  },
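  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what pushing basis vectors through a JVP produces, here is a NumPy-only\n",
    "sketch that replaces `vmap` with an explicit loop and uses a linear map whose\n",
    "Jacobian is known. The matrix `A` and the function `pushfwd` below are\n",
    "illustrative assumptions, not part of the tutorial's code.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# f(x) = A @ x maps R^3 to R^2, so its Jacobian is A at every point.\n",
    "A = np.array([[1., 2., 3.],\n",
    "              [4., 5., 6.]])\n",
    "\n",
    "def pushfwd(v):\n",
    "  # JVP of the linear map f: the output tangent is A @ v.\n",
    "  return A @ v\n",
    "\n",
    "basis = np.eye(3)\n",
    "# Emulate a vmap over axis 0 with a loop, stacking the results along axis 0.\n",
    "stacked = np.stack([pushfwd(e) for e in basis], axis=0)\n",
    "print(stacked.shape)              # (3, 2)\n",
    "print(np.allclose(stacked, A.T))  # True\n",
    "```\n",
    "\n",
    "Row `i` of the stacked result is the Jacobian applied to the `i`-th basis\n",
    "vector, that is, column `i` of the Jacobian. Stacking along axis 0 therefore\n",
    "yields the transpose in general; for an elementwise function such as `sin` the\n",
    "Jacobian is diagonal, so the two coincide."
   ]
  },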
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it for `jvp` and `vmap`!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 2: Jaxprs\n",
    "\n",
    "The next transformations on the horizon are `jit` for just-in-time\n",
    "compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small\n",
    "wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to\n",
    "carry a little bit of extra context, for both `jit` and `vjp` we need much\n",
    "richer context: we need to represent _programs_. That is, we need jaxprs!\n",
    "\n",
    "Jaxprs are JAX's internal intermediate representation of programs. They are\n",
    "explicitly typed, functional, first-order, and in A-normal form (ANF). We need a\n",
    "program representation for `jit` because the purpose of `jit` is to stage\n",
    "computation out of Python. For any computation we want to stage out, we need\n",
    "to be able to represent it as data, and build it up as we trace a Python\n",
    "function. Similarly, `vjp` needs a way to represent the computation for the\n",
    "backward pass of reverse-mode autodiff. We use the same jaxpr program\n",
    "representation for both needs.\n",
    "\n",
    "(Building a program representation is the most\n",
    "[free](https://en.wikipedia.org/wiki/Free_object) kind of\n",
    "trace-transformation, and so except for issues around handling native Python\n",
    "control flow, any transformation could be implemented by first tracing to a\n",
    "jaxpr and then interpreting the jaxpr.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Jaxpr data structures\n",
    "\n",
    "The jaxpr term syntax is roughly:\n",
    "\n",
    "```\n",
    "jaxpr ::=\n",
    "  { lambda <binder> , ... .\n",
    "    let <eqn>\n",
    "        ...\n",
    "    in ( <atom> , ... ) }\n",
    "\n",
    "binder ::= <var>:<array_type>\n",
    "var ::= a | b | c | ...\n",
    "atom ::= <var> | <literal>\n",
    "literal ::= <int32> | <int64> | <float32> | <float64>\n",
    "\n",
    "eqn ::= <binder> , ... = <primitive> [ <params> ] <atom> , ...\n",
    "```\n",
    "\n",
    "The syntax of types is:\n",
    "\n",
    "```\n",
    "jaxpr_type ::= [ <array_type> , ... ] -> [ <array_type> , ... ]\n",
    "array_type ::= <dtype>[<shape>]\n",
    "dtype ::= f32 | f64 | i32 | i64\n",
    "shape ::= <int> , ...\n",
    "```\n",
    "\n",
    "How do we represent these as Python data structures? We reuse `ShapedArray`s to\n",
    "represent types, and we can represent the term syntax with a few Python\n",
    "structs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Set\n",
    "\n",
    "class Var:\n",
    "  aval: ShapedArray\n",
    "  def __init__(self, aval): self.aval = aval\n",
    "\n",
    "class Lit:\n",
    "  val: Any\n",
    "  aval: ShapedArray\n",
    "\n",
    "  def __init__(self, val):\n",
    "    self.aval = aval = raise_to_shaped(get_aval(val))\n",
    "    self.val = np.array(val, aval.dtype)\n",
    "\n",
    "Atom = Union[Var, Lit]\n",
    "\n",
    "class JaxprEqn(NamedTuple):\n",
    "  primitive: Primitive\n",
    "  inputs: List[Atom]\n",
    "  params: Dict[str, Any]\n",
    "  out_binders: List[Var]\n",
    "\n",
    "class Jaxpr(NamedTuple):\n",
    "  in_binders: List[Var]\n",
    "  eqns: List[JaxprEqn]\n",
    "  outs: List[Atom]\n",
    "\n",
    "  def __hash__(self): return id(self)\n",
    "  __eq__ = op.is_\n",
    "\n",
    "def raise_to_shaped(aval):\n",
    "  return ShapedArray(aval.shape, aval.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Type-checking a jaxpr involves checking that there are no unbound variables,\n",
    "that variables are only bound once, and that for each equation the type of\n",
    "the primitive application matches the type of the output binders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class JaxprType(NamedTuple):\n",
    "  in_types:  List[ShapedArray]\n",
    "  out_types: List[ShapedArray]\n",
    "\n",
    "  def __repr__(self):\n",
    "    in_types = ', '.join(aval.str_short() for aval in self.in_types)\n",
    "    out_types = ', '.join(aval.str_short() for aval in self.out_types)\n",
    "    return f'({in_types}) -> ({out_types})'\n",
    "\n",
    "def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:\n",
    "  env: Set[Var] = set()\n",
    "\n",
    "  for v in jaxpr.in_binders:\n",
    "    if v in env: raise TypeError\n",
    "    env.add(v)\n",
    "\n",
    "  for eqn in jaxpr.eqns:\n",
    "    in_types = [typecheck_atom(env, x) for x in eqn.inputs]\n",
    "    out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)\n",
    "    for out_binder, out_type in zip(eqn.out_binders, out_types):\n",
    "      if not out_type == out_binder.aval: raise TypeError\n",
    "    for out_binder in eqn.out_binders:\n",
    "      if out_binder in env: raise TypeError\n",
    "      env.add(out_binder)\n",
    "\n",
    "  in_types = [v.aval for v in jaxpr.in_binders]\n",
    "  out_types = [typecheck_atom(env, x) for x in jaxpr.outs]\n",
    "  return JaxprType(in_types, out_types)\n",
    "\n",
    "def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray:\n",
    "  if isinstance(x, Var):\n",
    "    if x not in env: raise TypeError(\"unbound variable\")\n",
    "    return x.aval\n",
    "  elif isinstance(x, Lit):\n",
    "    return raise_to_shaped(get_aval(x.val))\n",
    "  else:\n",
    "    assert False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can apply the function represented by a jaxpr to arguments with a simple\n",
    "interpreter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]:\n",
    "  env: Dict[Var, Any] = {}\n",
    "\n",
    "  def read(x: Atom) -> Any:\n",
    "    return env[x] if type(x) is Var else x.val\n",
    "\n",
    "  def write(v: Var, val: Any) -> None:\n",
    "    assert v not in env  # single-assignment\n",
    "    env[v] = val\n",
    "\n",
    "  map(write, jaxpr.in_binders, args)\n",
    "  for eqn in jaxpr.eqns:\n",
    "    in_vals = map(read, eqn.inputs)\n",
    "    outs = bind(eqn.primitive, *in_vals, **eqn.params)\n",
    "    map(write, eqn.out_binders, outs)\n",
    "  return map(read, jaxpr.outs)\n",
    "\n",
    "def jaxpr_as_fun(jaxpr: Jaxpr):\n",
    "  return lambda *args: eval_jaxpr(jaxpr, args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because this interpreter applies each primitive with `bind`, it is itself traceable."
   ]
  },
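  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shape of the data structures and the environment-passing interpreter can\n",
    "be seen in a stripped-down, self-contained form. This is a sketch under\n",
    "simplifying assumptions: `V`, `Eqn`, `Prog`, `impls`, and `eval_prog` are\n",
    "hypothetical names, each equation has a single output, there are no types or\n",
    "params, and operations call NumPy directly rather than `bind` (so, unlike\n",
    "`eval_jaxpr`, this version is not traceable).\n",
    "\n",
    "```python\n",
    "from typing import Any, Callable, Dict, List, NamedTuple, Union\n",
    "import numpy as np\n",
    "\n",
    "class V:  # hypothetical stand-in for `Var`; identity-hashed like any object\n",
    "  pass\n",
    "\n",
    "class Eqn(NamedTuple):  # hypothetical stand-in for `JaxprEqn`\n",
    "  op: str\n",
    "  inputs: List[Union[V, float]]  # floats play the role of literals\n",
    "  out: V\n",
    "\n",
    "class Prog(NamedTuple):  # hypothetical stand-in for `Jaxpr`\n",
    "  in_binders: List[V]\n",
    "  eqns: List[Eqn]\n",
    "  outs: List[Union[V, float]]\n",
    "\n",
    "impls: Dict[str, Callable] = {'sin': np.sin, 'mul': np.multiply, 'add': np.add}\n",
    "\n",
    "def eval_prog(prog: Prog, args: List[Any]) -> List[Any]:\n",
    "  env: Dict[V, Any] = dict(zip(prog.in_binders, args))\n",
    "  read = lambda x: env[x] if isinstance(x, V) else x\n",
    "  for eqn in prog.eqns:\n",
    "    assert eqn.out not in env  # single assignment\n",
    "    env[eqn.out] = impls[eqn.op](*[read(x) for x in eqn.inputs])\n",
    "  return [read(x) for x in prog.outs]\n",
    "\n",
    "# Roughly: { lambda a . let b = sin a ; c = mul b 2.0 in ( c ) }\n",
    "a, b, c = V(), V(), V()\n",
    "prog = Prog([a], [Eqn('sin', [a], b), Eqn('mul', [b, 2.0], c)], [c])\n",
    "print(eval_prog(prog, [0.5]))\n",
    "```"
   ]
  },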
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building jaxprs with tracing\n",
    "\n",
    "Now that we have jaxprs as a data structure, we need ways to produce these\n",
    "from tracing Python code. In general there are two variants of how we trace to\n",
    "a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one\n",
    "used by `jit`, which is also used by control flow primitives like `lax.cond`,\n",
    "`lax.while_loop`, and `lax.scan`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_list(lst: List[Any], n: int) -> Tuple[List[Any], List[Any]]:\n",
    "  assert 0 <= n <= len(lst)\n",
    "  return lst[:n], lst[n:]\n",
    "\n",
    "def partition_list(bs: List[bool], l: List[Any]) -> Tuple[List[Any], List[Any]]:\n",
    "  assert len(bs) == len(l)\n",
    "  lists = lst1, lst2 = [], []\n",
    "  for b, x in zip(bs, l):\n",
    "    lists[b].append(x)  # a bool indexes the pair: False -> lst1, True -> lst2\n",
    "  return lst1, lst2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NB: the analogous class in JAX is called 'DynamicJaxprTracer'\n",
    "class JaxprTracer(Tracer):\n",
    "  __slots__ = ['aval']\n",
    "  aval: ShapedArray\n",
    "\n",
    "  def __init__(self, trace, aval):\n",
    "    self._trace = trace\n",
    "    self.aval = aval\n",
    "\n",
    "# NB: the analogous class in JAX is called 'DynamicJaxprTrace'\n",
    "class JaxprTrace(Trace):\n",
    "  def new_arg(self, aval: ShapedArray) -> JaxprTracer:\n",
    "    aval = raise_to_shaped(aval)\n",
    "    tracer = self.builder.new_tracer(self, aval)\n",
    "    self.builder.tracer_to_var[id(tracer)] = Var(aval)\n",
    "    return tracer\n",
    "\n",
    "  def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:\n",
    "    tracer = self.builder.const_tracers.get(id(val))\n",
    "    if tracer is None:\n",
    "      tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))\n",
    "      self.builder.add_const(tracer, val)\n",
    "    return tracer\n",
    "  pure = lift = get_or_make_const_tracer\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    avals_in = [t.aval for t in tracers]\n",
    "    avals_out = abstract_eval_rules[primitive](*avals_in, **params)\n",
    "    out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]\n",
    "    inputs = [self.builder.getvar(t) for t in tracers]\n",
    "    outvars = [self.builder.add_var(t) for t in out_tracers]\n",
    "    self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))\n",
    "    return out_tracers\n",
    "\n",
    "  @property\n",
    "  def builder(self):\n",
    "    return self.main.global_data\n",
    "\n",
    "# NB: in JAX, we instead attach abstract eval rules to Primitive instances\n",
    "abstract_eval_rules = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that we keep a builder object as interpreter-global data. It keeps\n",
    "track of variables, constants, and eqns as we build up the jaxpr."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class JaxprBuilder:\n",
    "  eqns: List[JaxprEqn]\n",
    "  tracer_to_var: Dict[int, Var]\n",
    "  const_tracers: Dict[int, JaxprTracer]\n",
    "  constvals: Dict[Var, Any]\n",
    "  tracers: List[JaxprTracer]\n",
    "\n",
    "  def __init__(self):\n",
    "    self.eqns = []\n",
    "    self.tracer_to_var = {}\n",
    "    self.const_tracers = {}\n",
    "    self.constvals = {}\n",
    "    self.tracers = []\n",
    "\n",
    "  def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:\n",
    "    tracer = JaxprTracer(trace, aval)\n",
    "    self.tracers.append(tracer)\n",
    "    return tracer\n",
    "\n",
    "  def add_eqn(self, eqn: JaxprEqn) -> None:\n",
    "    self.eqns.append(eqn)\n",
    "\n",
    "  def add_var(self, tracer: JaxprTracer) -> Var:\n",
    "    assert id(tracer) not in self.tracer_to_var\n",
    "    var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)\n",
    "    return var\n",
    "\n",
    "  def getvar(self, tracer: JaxprTracer) -> Var:\n",
    "    var = self.tracer_to_var.get(id(tracer))\n",
    "    assert var is not None\n",
    "    return var\n",
    "\n",
    "  def add_const(self, tracer: JaxprTracer, val: Any) -> Var:\n",
    "    var = self.add_var(tracer)\n",
    "    self.const_tracers[id(val)] = tracer\n",
    "    self.constvals[var] = val\n",
    "    return var\n",
    "\n",
    "  def build(self, in_tracers: List[JaxprTracer], out_tracers: List[JaxprTracer]\n",
    "            ) -> Tuple[Jaxpr, List[Any]]:\n",
    "    constvars, constvals = unzip2(self.constvals.items())\n",
    "    t2v = lambda t: self.tracer_to_var[id(t)]\n",
    "    in_binders = constvars + [t2v(t) for t in in_tracers]\n",
    "    out_vars = [t2v(t) for t in out_tracers]\n",
    "    jaxpr = Jaxpr(in_binders, self.eqns, out_vars)\n",
    "    typecheck_jaxpr(jaxpr)\n",
    "    jaxpr, constvals = _inline_literals(jaxpr, constvals)\n",
    "    return jaxpr, constvals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _inline_literals(jaxpr: Jaxpr, consts: List[Any]) -> Tuple[Jaxpr, List[Any]]:\n",
    "  const_binders, other_binders = split_list(jaxpr.in_binders, len(consts))\n",
    "  scalars = [type(x) in jax_types and not get_aval(x).shape for x in consts]\n",
    "  new_const_binders, lit_binders = partition_list(scalars, const_binders)\n",
    "  new_consts, lit_vals = partition_list(scalars, consts)\n",
    "  literals = dict(zip(lit_binders, map(Lit, lit_vals)))\n",
    "  new_eqns = [JaxprEqn(eqn.primitive, [literals.get(x, x) for x in eqn.inputs],\n",
    "                       eqn.params, eqn.out_binders) for eqn in jaxpr.eqns]\n",
    "  new_outs = [literals.get(x, x) for x in jaxpr.outs]\n",
    "  new_jaxpr = Jaxpr(new_const_binders + other_binders, new_eqns, new_outs)\n",
    "  typecheck_jaxpr(new_jaxpr)\n",
    "  return new_jaxpr, new_consts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rules we need for `JaxprTrace.process_primitive` are essentially typing\n",
    "rules for primitive applications: given the primitive, its parameters, and\n",
    "types for the inputs, the rule must produce a type for the output, which is\n",
    "then packaged with the output `JaxprTracer`. We can use abstract evaluation\n",
    "rules for this purpose, even though they can be more general than typing\n",
    "rules: they must accept `ConcreteArray` inputs, and because they need only\n",
    "return an upper bound on the set of possible outputs, they can produce\n",
    "`ConcreteArray` outputs as well. We'll reuse these abstract evaluation rules\n",
    "for the other jaxpr-producing trace machinery, where the extra generality is\n",
    "useful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def binop_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:\n",
    "  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):\n",
    "    raise TypeError\n",
    "  if raise_to_shaped(x) != raise_to_shaped(y): raise TypeError\n",
    "  return [ShapedArray(x.shape, x.dtype)]\n",
    "\n",
    "abstract_eval_rules[add_p] = binop_abstract_eval\n",
    "abstract_eval_rules[mul_p] = binop_abstract_eval\n",
    "\n",
    "def compare_abstract_eval(x: ShapedArray, y: ShapedArray) -> List[ShapedArray]:\n",
    "  if not isinstance(x, ShapedArray) or not isinstance(y, ShapedArray):\n",
    "    raise TypeError\n",
    "  if x.shape != y.shape: raise TypeError\n",
    "  return [ShapedArray(x.shape, np.dtype('bool'))]\n",
    "abstract_eval_rules[greater_p] = compare_abstract_eval\n",
    "abstract_eval_rules[less_p] = compare_abstract_eval\n",
    "\n",
    "def vectorized_unop_abstract_eval(x: ShapedArray) -> List[ShapedArray]:\n",
    "  return [ShapedArray(x.shape, x.dtype)]\n",
    "\n",
    "abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval\n",
    "abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval\n",
    "abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval\n",
    "\n",
    "def reduce_sum_abstract_eval(x: ShapedArray, *, axis: Tuple[int, ...]\n",
    "                             ) -> List[ShapedArray]:\n",
    "  axis_ = set(axis)\n",
    "  new_shape = [d for i, d in enumerate(x.shape) if i not in axis_]\n",
    "  return [ShapedArray(tuple(new_shape), x.dtype)]\n",
    "abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval\n",
    "\n",
    "def broadcast_abstract_eval(x: ShapedArray, *, shape: Sequence[int],\n",
    "                            axes: Sequence[int]) -> List[ShapedArray]:\n",
    "  return [ShapedArray(tuple(shape), x.dtype)]\n",
    "abstract_eval_rules[broadcast_p] = broadcast_abstract_eval"
   ]
  },
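  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An abstract evaluation rule predicts the type of a result without computing\n",
    "it. The sketch below checks one such rule against what NumPy actually\n",
    "produces; `Shaped` and `reduce_sum_rule` are hypothetical stand-ins for\n",
    "`ShapedArray` and `reduce_sum_abstract_eval`, written so the example runs on\n",
    "its own.\n",
    "\n",
    "```python\n",
    "from typing import NamedTuple, Tuple\n",
    "import numpy as np\n",
    "\n",
    "class Shaped(NamedTuple):  # hypothetical stand-in for `ShapedArray`\n",
    "  shape: Tuple[int, ...]\n",
    "  dtype: np.dtype\n",
    "\n",
    "def reduce_sum_rule(x: Shaped, *, axis: Tuple[int, ...]) -> Shaped:\n",
    "  # Summing removes the reduced axes and keeps the dtype.\n",
    "  kept = tuple(d for i, d in enumerate(x.shape) if i not in set(axis))\n",
    "  return Shaped(kept, x.dtype)\n",
    "\n",
    "val = np.ones((2, 3, 4), dtype=np.float32)\n",
    "predicted = reduce_sum_rule(Shaped(val.shape, val.dtype), axis=(0, 2))\n",
    "actual = np.sum(val, axis=(0, 2))\n",
    "print(predicted.shape, actual.shape)    # (3,) (3,)\n",
    "print(predicted.dtype == actual.dtype)  # True\n",
    "```"
   ]
  },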
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check our implementation of jaxprs, we can add a `make_jaxpr`\n",
    "transformation and a pretty-printer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import lru_cache\n",
    "\n",
    "@lru_cache()  # ShapedArrays are hashable\n",
    "def make_jaxpr_v1(f, *avals_in):\n",
    "  avals_in, in_tree = tree_flatten(avals_in)\n",
    "  f, out_tree = flatten_fun(f, in_tree)\n",
    "\n",
    "  builder = JaxprBuilder()\n",
    "  with new_main(JaxprTrace, builder) as main:\n",
    "    trace = JaxprTrace(main)\n",
    "    tracers_in = [trace.new_arg(aval) for aval in avals_in]\n",
    "    outs = f(*tracers_in)\n",
    "    tracers_out = [full_raise(trace, out) for out in outs]\n",
    "    jaxpr, consts = builder.build(tracers_in, tracers_out)\n",
    "  return jaxpr, consts, out_tree()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "from typing import DefaultDict\n",
    "from collections import defaultdict\n",
    "import string\n",
    "\n",
    "class PPrint:\n",
    "  lines: List[Tuple[int, str]]\n",
    "\n",
    "  def __init__(self, lines):\n",
    "    self.lines = lines\n",
    "\n",
    "  def indent(self, indent: int) -> 'PPrint':\n",
    "    return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])\n",
    "\n",
    "  def __add__(self, rhs: 'PPrint') -> 'PPrint':\n",
    "    return PPrint(self.lines + rhs.lines)\n",
    "\n",
    "  def __rshift__(self, rhs: 'PPrint') -> 'PPrint':\n",
    "    if not rhs.lines: return self\n",
    "    if not self.lines: return rhs\n",
    "    indent, s = self.lines[-1]\n",
    "    indented_block = rhs.indent(indent + len(s))\n",
    "    common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]\n",
    "    return PPrint(self.lines[:-1]\n",
    "                  + [(indent, common_line)]\n",
    "                  + indented_block.lines[1:])\n",
    "\n",
    "  def __str__(self) -> str:\n",
    "    return '\\n'.join(' ' * indent + s for indent, s in self.lines)\n",
    "\n",
    "def pp(s: Any) -> PPrint:\n",
    "  return PPrint([(0, line) for line in str(s).splitlines()])\n",
    "\n",
    "def vcat(ps: List[PPrint]) -> PPrint:\n",
    "  return sum(ps, pp(''))\n",
    "\n",
    "def pp_jaxpr(jaxpr: Jaxpr) -> PPrint:\n",
    "  namegen = (''.join(s) for r in it.count(1)\n",
    "             for s in it.permutations(string.ascii_lowercase, r))\n",
    "  names = defaultdict(lambda: next(namegen))\n",
    "  in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)\n",
    "  eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])\n",
    "  outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)\n",
    "                   for v in jaxpr.outs)\n",
    "  return (pp(f'{{ lambda {in_binders} .') +\n",
    "          ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))\n",
    "\n",
    "def var_str(names: DefaultDict[Var, str], v: Var) -> str:\n",
    "  return f'{names[v]}:{v.aval.str_short()}'\n",
    "\n",
    "def pp_eqn(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
    "  rule = pp_rules.get(eqn.primitive)\n",
    "  if rule:\n",
    "    return rule(names, eqn)\n",
    "  else:\n",
    "    lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
    "    rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n",
    "           pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
    "                       for x in eqn.inputs)))\n",
    "    return lhs >> pp(' = ') >> rhs\n",
    "\n",
    "def pp_params(params: Dict[str, Any]) -> PPrint:\n",
    "  items = sorted(params.items())\n",
    "  if items:\n",
    "    return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')\n",
    "  else:\n",
    "    return pp(' ')\n",
    "\n",
    "Jaxpr.__repr__ = lambda self: str(pp_jaxpr(self))\n",
    "pp_rules: Dict[Primitive, Callable[..., PPrint]] = {}"
   ]
  },
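  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The variable-name generator in `pp_jaxpr` is worth a quick look on its own.\n",
    "This standalone snippet reproduces the generator expression and prints the\n",
    "names around the point where single letters run out.\n",
    "\n",
    "```python\n",
    "import itertools as it\n",
    "import string\n",
    "\n",
    "namegen = (''.join(s) for r in it.count(1)\n",
    "           for s in it.permutations(string.ascii_lowercase, r))\n",
    "names = list(it.islice(namegen, 28))\n",
    "print(names[:3])   # ['a', 'b', 'c']\n",
    "print(names[25:])  # ['z', 'ab', 'ac']\n",
    "```\n",
    "\n",
    "Because `permutations` never repeats an element, names such as `aa` are\n",
    "skipped; the sequence is still unbounded, which is all the printer needs."
   ]
  },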
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))\n",
    "print(jaxpr)\n",
    "print(typecheck_jaxpr(jaxpr))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
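    "But there's a limitation here: because `find_top_trace` chooses the\n",
    "interpreter by data dependence, looking only at the tracers among a\n",
    "primitive's inputs, `make_jaxpr_v1` can't stage out all the primitive\n",
    "operations performed by the Python callable it's given. An operation with no\n",
    "tracer inputs is evaluated eagerly instead of being recorded. For example:"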
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))\n",
    "print(jaxpr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is precisely the issue that\n",
    "[omnistaging](https://github.com/google/jax/pull/3370) fixed.\n",
    "We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n",
    "applied, regardless of whether any inputs to `bind` are boxed in corresponding\n",
    "`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`\n",
    "global defined in Part 1:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@contextmanager\n",
    "def new_dynamic(main: MainTrace):\n",
    "  global dynamic_trace\n",
    "  prev_dynamic_trace, dynamic_trace = dynamic_trace, main\n",
    "  try:\n",
    "    yield\n",
    "  finally:\n",
    "    dynamic_trace = prev_dynamic_trace\n",
    "\n",
    "@lru_cache()\n",
    "def make_jaxpr(f: Callable, *avals_in: ShapedArray,\n",
    "               ) -> Tuple[Jaxpr, List[Any], PyTreeDef]:\n",
    "  avals_in, in_tree = tree_flatten(avals_in)\n",
    "  f, out_tree = flatten_fun(f, in_tree)\n",
    "\n",
    "  builder = JaxprBuilder()\n",
    "  with new_main(JaxprTrace, builder) as main:\n",
    "    with new_dynamic(main):\n",
    "      trace = JaxprTrace(main)\n",
    "      tracers_in = [trace.new_arg(aval) for aval in avals_in]\n",
    "      outs = f(*tracers_in)\n",
    "      tracers_out = [full_raise(trace, out) for out in outs]\n",
    "      jaxpr, consts = builder.build(tracers_in, tracers_out)\n",
    "  return jaxpr, consts, out_tree()\n",
    "\n",
    "jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))\n",
    "print(jaxpr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using `dynamic_trace` this way is conceptually the same as stashing the\n",
    "current interpreter stack and starting a new one with the `JaxprTrace` at the\n",
    "bottom. That is, no interpreters lower in the stack than the `dynamic_trace`\n",
    "are applied (since `JaxprTrace.process_primitive` doesn't call `bind`), though\n",
    "if the Python callable being traced to a jaxpr itself uses transformations\n",
    "then those can be pushed onto the interpreter stack above the `JaxprTrace`.\n",
    "But temporarily stashing the interpreter stack would break up the system\n",
    "state. The `dynamic_trace` tag achieves the same goals while keeping the\n",
    "system state simpler."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it for jaxprs! With jaxprs in hand, we can implement the remaining\n",
    "major JAX features."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 3: `jit`, simplified\n",
    "\n",
    "While `jit` has a transformation-like API in that it accepts a Python callable\n",
    "as an argument, under the hood it's really a higher-order primitive rather\n",
    "than a transformation. A primitive is _higher-order_ when it's parameterized\n",
    "by a function."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### On-the-fly (\"final style\") and staged (\"initial style\") processing\n",
    "\n",
    "There are two options for how to handle higher-order primitives. Each requires\n",
    "a different approach to tracing and engenders different tradeoffs:\n",
    "1. **On-the-fly processing, where `bind` takes a Python callable as an\n",
    "   argument.** We defer forming a jaxpr until as late as possible, namely\n",
    "   until we're running the final interpreter at the bottom of the interpreter\n",
    "   stack. That way we can swap a `JaxprTrace` in at the bottom of the\n",
    "   interpreter stack and thus stage out rather than execute all primitive\n",
    "   operations. With this approach, transformations in the stack get applied as\n",
    "   we execute the Python callable as usual. This approach can be very tricky\n",
    "   to implement, but it's as general as possible because it allows\n",
    "   higher-order primitives not to raise the abstraction level of their\n",
    "   arguments and thus allows data-dependent Python control flow. We refer to\n",
    "   this approach as using a \"final-style higher-order primitive\" employing the\n",
    "   discharge-at-tracing-time \"final-style transformations\" we've used so far.\n",
    "2. **Staged processing, where `bind` takes a jaxpr as an argument.** Before we\n",
    "   call `bind`, in the primitive wrapper we can just use `make_jaxpr` to form\n",
    "   a jaxpr up-front and be done with the Python callable entirely. In this\n",
    "   case, `make_jaxpr` puts its `JaxprTrace` at the top of the interpreter\n",
    "   stack, and no transformations lower in the stack, which might enter via\n",
    "   closed-over Tracers, are applied to the Python callable as we trace it.\n",
    "   (Transformations applied within the Python callable are applied as usual,\n",
    "   being added to the stack above the JaxprTrace.) Instead, the\n",
    "   transformations lower in the stack are later applied to the call primitive,\n",
    "   and the call primitive's rules must then transform the jaxpr itself.\n",
    "   Because we trace to a jaxpr up-front, this approach can't support\n",
    "   data-dependent Python control flow, but it is more straightforward to\n",
    "   implement. We refer to this kind of higher-order primitive as an\n",
    "   \"initial-style higher-order primitive\", and say that its jaxpr-processing\n",
    "   transformation rules are \"initial-style transformation rules.\"\n",
    "\n",
    "The latter approach fits for `jit` because we don't need to support\n",
    "data-dependent Python control flow in the user-provided Python callable, as\n",
    "the whole purpose of `jit` is to stage computation out of Python to be\n",
    "executed by XLA. (In contrast, `custom_jvp` is a higher-order primitive in\n",
    "which we want to support data-dependent Python control flow.)\n",
    "\n",
    "Historically, we started using the \"initial-style\" and \"final-style\"\n",
    "terminology after reading the [typed tagless final\n",
    "interpreters](http://okmij.org/ftp/tagless-final/index.html) paper, and\n",
    "jokingly referring to JAX as an implementation of \"untyped tagful final\n",
    "interpreters.\" We don't claim to carry over (or understand) any deep meaning\n",
    "behind these terms; we loosely use \"initial style\" to mean \"build an AST and\n",
    "then transform it\", and we use \"final style\" to mean \"transform as we trace.\"\n",
    "But it's just imprecise yet sticky jargon."
   ]
  },
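  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the jargon concrete, here is a toy sketch that is independent of the\n",
    "code in this tutorial. It differentiates `sin(x) * 2.` in both styles. The\n",
    "names `Dual`, `deriv_final`, `deriv_initial`, and the list-of-tuples `program`\n",
    "are made up for illustration; `program` is a hand-written stand-in for a\n",
    "jaxpr, not something produced by `make_jaxpr`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Final style (transform as we trace): run the Python callable on dual\n",
    "# numbers, so differentiation happens while the function executes.\n",
    "class Dual:\n",
    "  def __init__(self, primal, tangent):\n",
    "    self.primal, self.tangent = primal, tangent\n",
    "\n",
    "  def __mul__(self, const):\n",
    "    # This toy only supports Dual * Python scalar.\n",
    "    return Dual(self.primal * const, self.tangent * const)\n",
    "\n",
    "def sin(x):\n",
    "  if isinstance(x, Dual):\n",
    "    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)\n",
    "  return math.sin(x)\n",
    "\n",
    "def f(x):\n",
    "  return sin(x) * 2.\n",
    "\n",
    "def deriv_final(f, x):\n",
    "  return f(Dual(x, 1.)).tangent\n",
    "\n",
    "# Initial style (build an AST, then transform it): the operations of f are\n",
    "# recorded as data first, and differentiation walks the recorded program.\n",
    "program = [('sin',), ('mul_const', 2.)]  # hypothetical stand-in for a jaxpr\n",
    "\n",
    "def deriv_initial(program, x):\n",
    "  primal, tangent = x, 1.\n",
    "  for op, *params in program:\n",
    "    if op == 'sin':\n",
    "      primal, tangent = math.sin(primal), math.cos(primal) * tangent\n",
    "    elif op == 'mul_const':\n",
    "      primal, tangent = primal * params[0], tangent * params[0]\n",
    "    else:\n",
    "      raise NotImplementedError(op)\n",
    "  return tangent\n",
    "\n",
    "assert math.isclose(deriv_final(f, 3.), deriv_initial(program, 3.))\n",
    "```\n",
    "\n",
    "In the final-style version the Python callable itself runs under the\n",
    "transformation, while in the initial-style version only the recorded program\n",
    "is transformed and the callable is no longer needed."
   ]
  },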
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the initial-style approach, here's the user-facing `jit` wrapper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jit(f):\n",
    "  def f_jitted(*args):\n",
    "    avals_in = [raise_to_shaped(get_aval(x)) for x in args]\n",
    "    jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)\n",
    "    outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))\n",
    "    return tree_unflatten(out_tree, outs)\n",
    "  return f_jitted\n",
    "\n",
    "xla_call_p = Primitive('xla_call')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With any new primitive, we need to give it transformation rules, starting with\n",
    "its evaluation rule. When we evaluate an application of the `xla_call`\n",
    "primitive, we want to stage out out the computation to XLA. That involves\n",
    "translating the jaxpr to an XLA HLO program, transferring the argument values\n",
    "to the XLA device, executing the XLA program, and transferring back the\n",
    "results. We'll cache the XLA HLO compilation so that for each `jit`ted\n",
    "function it only needs to be performed once per argument shape and dtype\n",
    "signature.\n",
    "\n",
    "First, some utilities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IDHashable:\n",
    "  val: Any\n",
    "\n",
    "  def __init__(self, val):\n",
    "    self.val = val\n",
    "\n",
    "  def __hash__(self) -> int:\n",
    "    return id(self.val)\n",
    "\n",
    "  def __eq__(self, other):\n",
    "    return type(other) is IDHashable and id(self.val) == id(other.val)"
   ]
  },
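  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a standalone illustration of why identity-based hashing helps with\n",
    "caching, the sketch below wraps an unhashable Python list (standing in for an\n",
    "array constant) and uses the wrapper as an `lru_cache` key. The `expensive`\n",
    "function and the `calls` list are hypothetical names used only here.\n",
    "\n",
    "```python\n",
    "from functools import lru_cache\n",
    "\n",
    "class IDHashable:\n",
    "  # Restated from the cell above so that this sketch is self-contained.\n",
    "  def __init__(self, val):\n",
    "    self.val = val\n",
    "\n",
    "  def __hash__(self):\n",
    "    return id(self.val)\n",
    "\n",
    "  def __eq__(self, other):\n",
    "    return type(other) is IDHashable and id(self.val) == id(other.val)\n",
    "\n",
    "calls = []\n",
    "\n",
    "@lru_cache()\n",
    "def expensive(wrapped):\n",
    "  calls.append(wrapped.val)  # record every cache miss\n",
    "  return sum(wrapped.val)\n",
    "\n",
    "data = [1, 2, 3]           # a list is unhashable, so it can't be a key itself\n",
    "same_contents = [1, 2, 3]  # equal contents, but a different object\n",
    "\n",
    "expensive(IDHashable(data))\n",
    "expensive(IDHashable(data))           # same wrapped object: cache hit\n",
    "expensive(IDHashable(same_contents))  # different object: cache miss\n",
    "assert len(calls) == 2\n",
    "```\n",
    "\n",
    "Two wrappers around the same object compare equal even though the wrappers\n",
    "are distinct, while equal contents in different objects do not share a cache\n",
    "entry."
   ]
  },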
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll define the evaluation rule for `xla_call`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax._src.lib import xla_bridge as xb\n",
    "from jax._src.lib import xla_client as xc\n",
    "xe = xc._xla\n",
    "xops = xc._xla.ops\n",
    "\n",
    "def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):\n",
    "  consts, args = args[:num_consts], args[num_consts:]\n",
    "  hashable_consts = tuple(map(IDHashable, consts))\n",
    "  execute = xla_callable(IDHashable(jaxpr), hashable_consts)\n",
    "  return execute(*args)\n",
    "impl_rules[xla_call_p] = xla_call_impl\n",
    "\n",
    "@lru_cache()\n",
    "def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable]):\n",
    "  jaxpr: Jaxpr = hashable_jaxpr.val\n",
    "  typecheck_jaxpr(jaxpr)\n",
    "  consts = [x.val for x in hashable_consts]\n",
    "  in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]\n",
    "  c = xc.XlaBuilder('xla_call')\n",
    "  xla_consts = _xla_consts(c, consts)\n",
    "  xla_params = _xla_params(c, in_avals)\n",
    "  outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)\n",
    "  out = xops.Tuple(c, outs)\n",
    "  compiled = xb.get_backend(None).compile(c.build(out))\n",
    "  return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n",
    "\n",
    "def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]:\n",
    "  unique_consts = {id(cnst): cnst for cnst in consts}\n",
    "  xla_consts = {\n",
    "      id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()}\n",
    "  return [xla_consts[id(cnst)] for cnst in consts]\n",
    "\n",
    "def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:\n",
    "  return [xops.Parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n",
    "\n",
    "def _xla_shape(aval: ShapedArray) -> xe.Shape:\n",
    "  return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main action is in `xla_callable`, which compiles a jaxpr into an XLA HLO\n",
    "program using `jaxpr_subcomp`, then returns a callable which executes the\n",
    "compiled program:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]\n",
    "                  ) -> xe.XlaOp:\n",
    "  env: Dict[Var, xe.XlaOp] = {}\n",
    "\n",
    "  def read(x: Atom) -> xe.XlaOp:\n",
    "    return env[x] if type(x) is Var else xops.Constant(c, np.asarray(x.val))\n",
    "\n",
    "  def write(v: Var, val: xe.XlaOp) -> None:\n",
    "    env[v] = val\n",
    "\n",
    "  map(write, jaxpr.in_binders, args)\n",
    "  for eqn in jaxpr.eqns:\n",
    "    in_avals = [x.aval for x in eqn.inputs]\n",
    "    in_vals = map(read, eqn.inputs)\n",
    "    rule = xla_translations[eqn.primitive]\n",
    "    out_vals = rule(c, in_avals, in_vals, **eqn.params)\n",
    "    map(write, eqn.out_binders, out_vals)\n",
    "  return map(read, jaxpr.outs)\n",
    "\n",
    "def execute_compiled(compiled, out_avals, *args):\n",
    "  input_bufs = [input_handlers[type(x)](x) for x in args]\n",
    "  out_bufs = compiled.execute(input_bufs)\n",
    "  return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]\n",
    "\n",
    "default_input_handler = xb.get_backend(None).buffer_from_pyval\n",
    "input_handlers = {ty: default_input_handler for ty in\n",
    "                  [bool, int, float, np.ndarray, np.float64, np.float32]}\n",
    "\n",
    "def handle_result(aval: ShapedArray, buf):\n",
    "  del aval  # Unused for now\n",
    "  return buf.to_py()\n",
    "\n",
    "xla_translations = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that `jaxpr_subcomp` has the structure of a simple interpreter. That's\n",
    "a common pattern: the way we process jaxprs is usually with an interpreter.\n",
    "And as with any interpreter, we need an interpretation rule for each\n",
    "primitive:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def direct_translation(op, c, in_avals, in_vals):\n",
    "  del c, in_avals\n",
    "  return [op(*in_vals)]\n",
    "\n",
    "xla_translations[add_p] = partial(direct_translation, xops.Add)\n",
    "xla_translations[mul_p] = partial(direct_translation, xops.Mul)\n",
    "xla_translations[neg_p] = partial(direct_translation, xops.Neg)\n",
    "xla_translations[sin_p] = partial(direct_translation, xops.Sin)\n",
    "xla_translations[cos_p] = partial(direct_translation, xops.Cos)\n",
    "xla_translations[greater_p] = partial(direct_translation, xops.Gt)\n",
    "xla_translations[less_p] = partial(direct_translation, xops.Lt)\n",
    "\n",
    "def reduce_sum_translation(c, in_avals, in_vals, *, axis):\n",
    "  (x_aval,), (x,) = in_avals, in_vals\n",
    "  zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))\n",
    "  subc = xc.XlaBuilder('add')\n",
    "  shape = _xla_shape(ShapedArray((), x_aval.dtype))\n",
    "  xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))\n",
    "  return [xops.Reduce(c, [x], [zero], subc.build(), axis)]\n",
    "xla_translations[reduce_sum_p] = reduce_sum_translation\n",
    "\n",
    "def broadcast_translation(c, in_avals, in_vals, *, shape, axes):\n",
    "  x, = in_vals\n",
    "  dims_complement = [i for i in range(len(shape)) if i not in axes]\n",
    "  return [xops.BroadcastInDim(x, shape, dims_complement)]\n",
    "xla_translations[broadcast_p] = broadcast_translation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that, we can now use `jit` to stage out, compile, and execute programs\n",
    "with XLA!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def f(x, y):\n",
    "  print('tracing!')\n",
    "  return sin(x) * cos(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z = f(3., 4.)  # 'tracing!' prints the first time\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z = f(4., 5.)  # 'tracing!' doesn't print, compilation cache hit!\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def f(x):\n",
    "  return reduce_sum(x, axis=0)\n",
    "\n",
    "print(f(np.array([1., 2., 3.])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "\n",
    "def deriv(f):\n",
    "  return lambda x: jvp(f, (x,), (1.,))[1]\n",
    "\n",
    "print(    deriv(deriv(f))(3.))\n",
    "print(jit(deriv(deriv(f)))(3.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of implementing `jit` to first trace to a jaxpr and then to lower the\n",
    "jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step\n",
    "and just lowered to HLO while tracing. That is, perhaps we could have instead\n",
    "implemented `jit` with a `Trace` and `Tracer` that appended to the XLA HLO\n",
    "graph incrementally on each primitive bind. That's correct for now, but won't\n",
    "be possible when we introduce compiled SPMD computations because there we must\n",
    "know the number of replicas needed before compiling the program."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We haven't yet defined any transformation rules for `xla_call_p` other than\n",
    "its evaluation rule. That is, we can't yet do `vmap`-of-`jit` or\n",
    "`jvp`-of-`jit` or even `jit`-of`-jit`. Instead `jit` has to be at the \"top\n",
    "level.\" Let's fix that!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):\n",
    "  del num_consts  # Unused\n",
    "  new_jaxpr, new_consts = jvp_jaxpr(jaxpr)\n",
    "  outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,\n",
    "              num_consts=len(new_consts))\n",
    "  n = len(outs) // 2\n",
    "  primals_out, tangents_out = outs[:n], outs[n:]\n",
    "  return primals_out, tangents_out\n",
    "jvp_rules[xla_call_p] = xla_call_jvp_rule\n",
    "\n",
    "@lru_cache()\n",
    "def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:\n",
    "  def jvp_traceable(*primals_and_tangents):\n",
    "    n = len(primals_and_tangents) // 2\n",
    "    primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]\n",
    "    return jvp(jaxpr_as_fun(jaxpr), primals, tangents)\n",
    "\n",
    "  in_avals = [v.aval for v in jaxpr.in_binders]\n",
    "  new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)\n",
    "  return new_jaxpr, new_consts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):\n",
    "  del num_consts  # Unused\n",
    "  new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, tuple(dims_in))\n",
    "  outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,\n",
    "              num_consts=len(new_consts))\n",
    "  return outs, [0] * len(outs)\n",
    "vmap_rules[xla_call_p] = xla_call_vmap_rule\n",
    "\n",
    "@lru_cache()\n",
    "def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: Tuple[BatchAxis, ...]\n",
    "               ) -> Tuple[Jaxpr, List[Any]]:\n",
    "  vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n",
    "  in_avals = [unmapped_aval(axis_size, d, v.aval)\n",
    "              for v, d in zip(jaxpr.in_binders, bdims_in)]\n",
    "  new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)\n",
    "  return new_jaxpr, new_consts\n",
    "\n",
    "def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray\n",
    "                  ) -> ShapedArray:\n",
    "  if batch_dim is not_mapped:\n",
    "    return aval\n",
    "  else:\n",
    "    shape = list(aval.shape)\n",
    "    shape.insert(batch_dim, axis_size)\n",
    "    return ShapedArray(tuple(shape), aval.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def xla_call_abstract_eval_rule(*in_types, jaxpr, num_consts):\n",
    "  del num_consts  # Unused\n",
    "  jaxpr_type = typecheck_jaxpr(jaxpr)\n",
    "  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):\n",
    "    raise TypeError\n",
    "  return jaxpr_type.out_types\n",
    "abstract_eval_rules[xla_call_p] = xla_call_abstract_eval_rule\n",
    "\n",
    "def xla_call_translation(c, in_avals, in_vals, *, jaxpr, num_consts):\n",
    "  del num_consts  # Only used at top-level.\n",
    "  # Calling jaxpr_subcomp directly would inline. We generate a Call HLO instead.\n",
    "  subc = xc.XlaBuilder('inner xla_call')\n",
    "  xla_params = _xla_params(subc, in_avals)\n",
    "  outs = jaxpr_subcomp(subc, jaxpr, xla_params)\n",
    "  subc = subc.build(xops.Tuple(subc, outs))\n",
    "  return destructure_tuple(c, xops.Call(c, subc, in_vals))\n",
    "xla_translations[xla_call_p] = xla_call_translation\n",
    "\n",
    "def destructure_tuple(c, tup):\n",
    "  num_elements = len(c.get_shape(tup).tuple_shapes())\n",
    "  return [xops.GetTupleElement(tup, i) for i in range(num_elements)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def f(x):\n",
    "  print('tracing!')\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "\n",
    "x, xdot = 3., 1.\n",
    "y, ydot = jvp(f, (x,), (xdot,))\n",
    "print(y)\n",
    "print(ydot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y, ydot = jvp(f, (x,), (xdot,))  # 'tracing!' not printed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ys = vmap(f, (0,))(np.arange(3.))\n",
    "print(ys)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One piece missing is device memory persistence for arrays. That is, we've\n",
    "defined `handle_result` to transfer results back to CPU memory as NumPy\n",
    "arrays, but it's often preferable to avoid transferring results just to\n",
    "transfer them back for the next operation. We can do that by introducing a\n",
    "`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type\n",
    "`numpy.ndarray`s:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_result(aval: ShapedArray, buf):  # noqa: F811\n",
    "  return DeviceArray(aval, buf)\n",
    "\n",
    "class DeviceArray:\n",
    "  buf: Any\n",
    "  aval: ShapedArray\n",
    "\n",
    "  def __init__(self, aval, buf):\n",
    "    self.aval = aval\n",
    "    self.buf = buf\n",
    "\n",
    "  dtype = property(lambda self: self.aval.dtype)\n",
    "  shape = property(lambda self: self.aval.shape)\n",
    "  ndim  = property(lambda self: self.aval.ndim)\n",
    "\n",
    "  def __array__(self): return self.buf.to_py()\n",
    "  def __repr__(self):  return repr(self.buf.to_py())\n",
    "  def __str__(self):   return str(self.buf.to_py())\n",
    "\n",
    "  _neg = staticmethod(neg)\n",
    "  _add = staticmethod(add)\n",
    "  _radd = staticmethod(add)\n",
    "  _mul = staticmethod(mul)\n",
    "  _rmul = staticmethod(mul)\n",
    "  _gt = staticmethod(greater)\n",
    "  _lt = staticmethod(less)\n",
    "input_handlers[DeviceArray] = lambda x: x.buf\n",
    "\n",
    "jax_types.add(DeviceArray)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "\n",
    "x, xdot = 3., 1.\n",
    "y, ydot = jvp(f, (x,), (xdot,))\n",
    "print(y)\n",
    "print(ydot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
    "  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
    "  params_without_jaxpr = {k:v for k, v in eqn.params.items() if k != 'jaxpr'}\n",
    "  rhs = (pp(eqn.primitive.name) >> pp_params(params_without_jaxpr) >>\n",
    "         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
    "                     for x in eqn.inputs)))\n",
    "  return vcat([lhs >> pp(' = ') >> rhs,\n",
    "               pp_jaxpr(eqn.params['jaxpr']).indent(2)])\n",
    "pp_rules[xla_call_p] = pprint_xla_call"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 4: `linearize` and `vjp` (and `grad`!)\n",
    "\n",
    "The `linearize` and `vjp` autodiff functions are built on `jvp`, but involve\n",
    "jaxprs as well. That's because both involve staging out, or delaying,\n",
    "computation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `linearize`\n",
    "\n",
    "In the case of `linearize`, we want to stage out the linear part of a `jvp`\n",
    "computation. That is, in terms of\n",
    "[Haskell-like type signatures](https://wiki.haskell.org/Type_signature),\n",
    "if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,\n",
    "then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to\n",
    "mean \"the tangent type of `a`\" and using the \"lollipop\" `-o` rather than the\n",
    "arrow `->` to indicate a _linear_ function. We define the semantics of\n",
    "`linearize` in terms of `jvp` too:\n",
    "```python\n",
    "y, f_lin = linearize(f, x)\n",
    "y_dot = f_lin(x_dot)\n",
    "```\n",
    "gives the same result for `(y, y_dot)` as\n",
    "```\n",
    "y, y_dot = jvp(f, (x,), (x_dot,))\n",
    "```\n",
    "where the application of `f_lin` does not redo any of the linearization work.\n",
    "We'll represent the delayed linear part `f_lin : T a -o T b` as a jaxpr.\n",
    "\n",
    "Tangentially, now that we have linear arrows `-o`, we can provide a slightly\n",
    "more informative type for `jvp`:\n",
    "```\n",
    "jvp : (a -> b) -> (UnrestrictedUse a, T a) -o (UnrestrictedUse b, T b)\n",
    "```\n",
    "Here we're writing `UnrestrictedUse` just to indicate that we have a special\n",
    "pair where the first element can be used in an unrestricted (nonlinear) way.\n",
    "In conjunction with the linear arrow, this notation is just meant to express\n",
    "that the function `jvp f` uses its first input in a nonlinear way but its\n",
    "second input in a linear way, producing a corresponding nonlinear output\n",
    "(which can be used in a nonlinear way) paired with a linear output. This more\n",
    "refined type signature encodes the data dependencies in `jvp f`, which are\n",
    "useful for partial evaluation.\n",
    "\n",
    "To build the `f_lin` jaxpr from a JVP, we need to perform partial evaluation:\n",
    "we evaluate all the primal values as we trace, but stage the tangent\n",
    "computations into a jaxpr. This is our second way to build jaxprs. But where\n",
    "`make_jaxpr` and its underlying `JaxprTrace`/`JaxprTracer` interpreters aim\n",
    "to stage out every primitive bind, this second approach stages out only those\n",
    "primitive binds with a data dependence on tangent inputs.\n",
    "\n",
    "First, some utilities:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_half(lst: List[Any]) -> Tuple[List[Any], List[Any]]:\n",
    "  assert not len(lst) % 2\n",
    "  return split_list(lst, len(lst) // 2)\n",
    "\n",
    "def merge_lists(which: List[bool], l1: List[Any], l2: List[Any]) -> List[Any]:\n",
    "  l1, l2 = iter(l1), iter(l2)\n",
    "  out = [next(l2) if b else next(l1) for b in which]\n",
    "  assert next(l1, None) is next(l2, None) is None\n",
    "  return out"
   ]
  },
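  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick usage sketch of these two helpers follows. To keep it self-contained\n",
    "it restates both of them along with `split_list`, which is assumed here to\n",
    "behave like the helper of the same name used elsewhere in this tutorial.\n",
    "\n",
    "```python\n",
    "def split_list(lst, n):\n",
    "  # Assumed behavior: split lst into its first n elements and the rest.\n",
    "  return lst[:n], lst[n:]\n",
    "\n",
    "def split_half(lst):\n",
    "  assert not len(lst) % 2\n",
    "  return split_list(lst, len(lst) // 2)\n",
    "\n",
    "def merge_lists(which, l1, l2):\n",
    "  l1, l2 = iter(l1), iter(l2)\n",
    "  out = [next(l2) if b else next(l1) for b in which]\n",
    "  assert next(l1, None) is next(l2, None) is None  # both lists used up\n",
    "  return out\n",
    "\n",
    "# Primals and tangents travel together in one flat list of even length.\n",
    "primals, tangents = split_half([1., 2., 10., 20.])\n",
    "assert (primals, tangents) == ([1., 2.], [10., 20.])\n",
    "\n",
    "# which[i] is False to take the next element of l1, True to take it from l2.\n",
    "merged = merge_lists([False, True, True, False], ['a', 'b'], ['X', 'Y'])\n",
    "assert merged == ['a', 'X', 'Y', 'b']\n",
    "```"
   ]
  },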
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll write `linearize` by combining `jvp` together with a general\n",
    "partial evaluation transformation, to be added next:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linearize_flat(f, *primals_in):\n",
    "  pvals_in = ([PartialVal.known(x) for x in primals_in] +\n",
    "              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])\n",
    "  def f_jvp(*primals_tangents_in):\n",
    "    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))\n",
    "    return [*primals_out, *tangents_out]\n",
    "  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)\n",
    "  primal_pvals, _ = split_half(pvals_out)\n",
    "  assert all(pval.is_known for pval in primal_pvals)\n",
    "  primals_out = [pval.const for pval in primal_pvals]\n",
    "  f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])\n",
    "  return primals_out, f_lin\n",
    "\n",
    "def linearize(f, *primals_in):\n",
    "  primals_in_flat, in_tree = tree_flatten(primals_in)\n",
    "  f, out_tree = flatten_fun(f, in_tree)\n",
    "  primals_out_flat, f_lin_flat = linearize_flat(f, *primals_in_flat)\n",
    "  primals_out = tree_unflatten(out_tree(), primals_out_flat)\n",
    "\n",
    "  def f_lin(*tangents_in):\n",
    "    tangents_in_flat, in_tree2 = tree_flatten(tangents_in)\n",
    "    if in_tree != in_tree2: raise TypeError\n",
    "    tangents_out_flat = f_lin_flat(*tangents_in_flat)\n",
    "    return tree_unflatten(out_tree(), tangents_out_flat)\n",
    "\n",
    "  return primals_out, f_lin\n",
    "\n",
    "def vspace(aval: ShapedArray) -> ShapedArray:\n",
    "  return raise_to_shaped(aval)  # TODO handle integers?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we turn to the general partial evaluation transformation. The goal is to\n",
    "accept a Python callable and a list of inputs, some known and some unknown,\n",
    "and to produce (1) all the outputs which can be computed from the known\n",
    "inputs, together with (2) a jaxpr representing the part of the Python\n",
    "callable's computation which can only be performed after the remaining inputs\n",
    "are known.\n",
    "\n",
    "This transformation is tricky to summarize in a type signature. If we\n",
    "assume the input function's type signature is `(a1, a2) -> (b1, b2)`, where\n",
    "`a1` and `a2` represent the known and unknown inputs, respectively, and where\n",
    "`b1` only has a data dependency on `a1` while `b2` has some data dependency on\n",
    "`a2`, then we might write\n",
    "\n",
    "```\n",
    "partial_eval : ((a1, a2) -> (b1, b2)) -> a1 -> exists r. (b1, r, (r, a2) -> b2)\n",
    "```\n",
    "\n",
    "In words, given values for the inputs of type `a1`, `partial_eval` produces\n",
    "the outputs of type `b1` along with \"residual\" values of\n",
    "existentially-quantified type `r` representing the intermediates required to\n",
    "complete the computation in the second stage. It also produces a function of\n",
    "type `(r, a2) -> b2` which accepts the residual values as well as the\n",
    "remaining inputs and produces the remaining outputs.\n",
    "\n",
    "We like to think of partial evaluation as \"unzipping\" one computation into\n",
    "two. For example, consider this jaxpr:\n",
    "```\n",
    "{ lambda a:float64[] .\n",
    "  let b:float64[] = sin a\n",
    "      c:float64[] = neg b\n",
    "  in ( c ) }\n",
    "```\n",
    "A jaxpr for the JVP would look like:\n",
    "```\n",
    "{ lambda a:float64[] b:float64[] .\n",
    "  let c:float64[] = sin a\n",
    "      d:float64[] = cos a\n",
    "      e:float64[] = mul d b\n",
    "      f:float64[] = neg c\n",
    "      g:float64[] = neg e\n",
    "  in ( f, g ) }\n",
    "```\n",
    "If we imagine applying partial evaluation to this jaxpr with the first input\n",
    "known and the second unknown, we end up 'unzipping' the JVP jaxpr into primal\n",
    "and tangent jaxprs:\n",
    "```\n",
    "{ lambda a:float64[] .\n",
    "  let c:float64[] = sin a\n",
    "      d:float64[] = cos a\n",
    "      f:float64[] = neg c\n",
    "  in ( f, d ) }\n",
    "```\n",
    "```\n",
    "{ lambda d:float64[] b:float64[] .\n",
    "  let e:float64[] = mul d b\n",
    "      g:float64[] = neg e\n",
    "  in ( g ) }\n",
    "```\n",
    "This second jaxpr represents the linear computation that we want from\n",
    "`linearize`.\n",
    "\n",
    "However, unlike in this jaxpr example, we want the computation on known values\n",
    "to occur while evaluating the input Python callable. That is, rather than\n",
    "forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all\n",
    "operations out of Python first before sorting out what can be evaluated now\n",
    "and what must be delayed, we want only to form a jaxpr for those operations\n",
    "that _must_ be delayed due to a dependence on unknown inputs. In the context\n",
    "of automatic differentiation, this is the feature that ultimately enables us\n",
    "to handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python\n",
    "control flow works because partial evaluation keeps the primal computation in\n",
    "Python. As a consequence, our `Trace` and `Tracer` subclasses must on the fly\n",
    "sort out what can be evaluated and what must be staged out into a jaxpr.\n",
    "\n",
    "First, we start with a `PartialVal` class, which represents a value that can\n",
    "be either known or unknown:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PartialVal(NamedTuple):\n",
    "  aval: ShapedArray\n",
    "  const: Optional[Any]\n",
    "\n",
    "  @classmethod\n",
    "  def known(cls, val: Any):\n",
    "    return PartialVal(get_aval(val), val)\n",
    "\n",
    "  @classmethod\n",
    "  def unknown(cls, aval: ShapedArray):\n",
    "    return PartialVal(aval, None)\n",
    "\n",
    "  is_known   = property(lambda self: self.const is not None)\n",
    "  is_unknown = property(lambda self: self.const is     None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Partial evaluation will take a list of `PartialVal`s representing inputs, and\n",
    "return a list of `PartialVal` outputs along with a jaxpr representing the\n",
    "delayed computation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def partial_eval_flat(f: Callable, pvals_in: List[PartialVal]\n",
    "                      ) -> Tuple[Jaxpr, List[PartialVal], List[Any]]:\n",
    "  with new_main(PartialEvalTrace) as main:\n",
    "    trace = PartialEvalTrace(main)\n",
    "    tracers_in = [trace.new_arg(pval) for pval in pvals_in]\n",
    "    outs = f(*tracers_in)\n",
    "    tracers_out = [full_raise(trace, out) for out in outs]\n",
    "    pvals_out = [t.pval for t in tracers_out]\n",
    "    unk_tracers_in  = [t for t in tracers_in  if t.pval.is_unknown]\n",
    "    unk_tracers_out = [t for t in tracers_out if t.pval.is_unknown]\n",
    "    jaxpr, consts = tracers_to_jaxpr(unk_tracers_in, unk_tracers_out)\n",
    "  return jaxpr, pvals_out, consts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we need to implement `PartialEvalTrace` and its `PartialEvalTracer`. This\n",
    "interpreter will build a jaxpr on the fly while tracking data dependencies. To\n",
    "do so, it builds a bipartite directed acyclic graph (DAG) between\n",
    "`PartialEvalTracer` nodes, representing staged-out values, and `JaxprRecipe`\n",
    "nodes, representing formulas for how to compute some values from others. One\n",
    "kind of recipe is a `JaxprEqnRecipe`, corresponding to a `JaxprEqn`'s\n",
    "primitive application, but we also have recipe types for constants and lambda\n",
    "binders:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from weakref import ref, ReferenceType\n",
    "\n",
    "class LambdaBindingRecipe(NamedTuple):\n",
    "  pass\n",
    "\n",
    "class ConstRecipe(NamedTuple):\n",
    "  val: Any\n",
    "\n",
    "class JaxprEqnRecipe(NamedTuple):\n",
    "  prim: Primitive\n",
    "  tracers_in: List['PartialEvalTracer']\n",
    "  params: Dict[str, Any]\n",
    "  avals_out: List[ShapedArray]\n",
    "  tracer_refs_out: List['ReferenceType[PartialEvalTracer]']\n",
    "\n",
    "JaxprRecipe = Union[LambdaBindingRecipe, ConstRecipe, JaxprEqnRecipe]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PartialEvalTracer(Tracer):\n",
    "  pval: PartialVal\n",
    "  recipe: Optional[JaxprRecipe]\n",
    "\n",
    "  def __init__(self, trace, pval, recipe):\n",
    "    self._trace = trace\n",
    "    self.pval = pval\n",
    "    self.recipe = recipe\n",
    "\n",
    "  aval = property(lambda self: self.pval.aval)\n",
    "\n",
    "  def full_lower(self):\n",
    "    if self.pval.is_known:\n",
    "      return full_lower(self.pval.const)\n",
    "    return self"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `PartialEvalTrace` contains the logic for constructing the graph of\n",
    "`JaxprRecipe`s and `PartialEvalTracer`s. Each argument corresponds to a\n",
    "`LambdaBindingRecipe` leaf node, and each constant is a `ConstRecipe` leaf\n",
    "node holding a reference to the constant. All other tracers and recipes come\n",
    "from `process_primitive`, which forms tracers with `JaxprEqnRecipe`s.\n",
    "\n",
    "For most primitives, the `process_primitive` logic is straightforward: if all\n",
    "inputs are known then we can bind the primitive on the known values\n",
    "(evaluating it in Python) and avoid forming tracers corresponding to the\n",
    "output. If instead any input is unknown then we instead stage out into a\n",
    "`JaxprEqnRecipe` representing the primitive application. To build the tracers\n",
    "representing unknown outputs, we need avals, which we get from the abstract\n",
    "eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and\n",
    "`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using\n",
    "weakrefs.)\n",
    "\n",
    "That `process_primitive` logic applies to most primitives, but `xla_call_p`\n",
    "requires recursive treatment. So we special-case its rule in a\n",
    "`partial_eval_rules` dict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PartialEvalTrace(Trace):\n",
    "  def new_arg(self, pval: PartialVal) -> Any:\n",
    "    return PartialEvalTracer(self, pval, LambdaBindingRecipe())\n",
    "\n",
    "  def lift(self, val: Any) -> PartialEvalTracer:\n",
    "    return PartialEvalTracer(self, PartialVal.known(val), None)\n",
    "  pure = lift\n",
    "\n",
    "  def instantiate_const(self, tracer: PartialEvalTracer) -> PartialEvalTracer:\n",
    "    if tracer.pval.is_unknown:\n",
    "      return tracer\n",
    "    else:\n",
    "      pval = PartialVal.unknown(raise_to_shaped(tracer.aval))\n",
    "      return PartialEvalTracer(self, pval, ConstRecipe(tracer.pval.const))\n",
    "\n",
    "  def process_primitive(self, primitive, tracers, params):\n",
    "    if all(t.pval.is_known for t in tracers):\n",
    "      return bind(primitive, *map(full_lower, tracers), **params)\n",
    "    rule = partial_eval_rules.get(primitive)\n",
    "    if rule: return rule(self, tracers, **params)\n",
    "    tracers_in = [self.instantiate_const(t) for t in tracers]\n",
    "    avals_in = [t.aval for t in tracers_in]\n",
    "    avals_out = abstract_eval_rules[primitive](*avals_in, **params)\n",
    "    tracers_out = [PartialEvalTracer(self, PartialVal.unknown(aval), None)\n",
    "                   for aval in avals_out]\n",
    "    eqn = JaxprEqnRecipe(primitive, tracers_in, params, avals_out,\n",
    "                         map(ref, tracers_out))\n",
    "    for t in tracers_out: t.recipe = eqn\n",
    "    return tracers_out\n",
    "\n",
    "partial_eval_rules = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we can build graph representations of jaxprs with `PartialEvalTrace`,\n",
    "we need a mechanism to convert the graph representation to a standard jaxpr.\n",
    "The jaxpr corresponds to a topological sort of the graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tracers_to_jaxpr(tracers_in: List[PartialEvalTracer],\n",
    "                     tracers_out: List[PartialEvalTracer]):\n",
    "  tracer_to_var: Dict[int, Var] = {id(t): Var(raise_to_shaped(t.aval))\n",
    "                                   for t in tracers_in}\n",
    "  constvar_to_val: Dict[int, Any] = {}\n",
    "  constid_to_var: Dict[int, Var] = {}\n",
    "  processed_eqns: Set[int] = set()\n",
    "  eqns: List[JaxprEqn] = []\n",
    "  for t in toposort(tracers_out, tracer_parents):\n",
    "    if isinstance(t.recipe, LambdaBindingRecipe):\n",
    "      assert id(t) in set(map(id, tracers_in))\n",
    "    elif isinstance(t.recipe, ConstRecipe):\n",
    "      val = t.recipe.val\n",
    "      var = constid_to_var.get(id(val))\n",
    "      if var is None:\n",
    "        aval = raise_to_shaped(get_aval(val))\n",
    "        var = constid_to_var[id(val)] = Var(aval)\n",
    "        constvar_to_val[var] = val\n",
    "      tracer_to_var[id(t)] = var\n",
    "    elif isinstance(t.recipe, JaxprEqnRecipe):\n",
    "      if id(t.recipe) not in processed_eqns:\n",
    "        eqns.append(recipe_to_eqn(tracer_to_var, t.recipe))\n",
    "        processed_eqns.add(id(t.recipe))\n",
    "    else:\n",
    "      raise TypeError(t.recipe)\n",
    "\n",
    "  constvars, constvals = unzip2(constvar_to_val.items())\n",
    "  in_binders = constvars + [tracer_to_var[id(t)] for t in tracers_in]\n",
    "  out_vars = [tracer_to_var[id(t)] for t in tracers_out]\n",
    "  jaxpr = Jaxpr(in_binders, eqns, out_vars)\n",
    "  typecheck_jaxpr(jaxpr)\n",
    "  return jaxpr, constvals\n",
    "\n",
    "def recipe_to_eqn(tracer_to_var: Dict[int, Var], recipe: JaxprEqnRecipe\n",
    "                  ) -> JaxprEqn:\n",
    "  inputs = [tracer_to_var[id(t)] for t in recipe.tracers_in]\n",
    "  out_binders = [Var(aval) for aval in recipe.avals_out]\n",
    "  for t_ref, var in zip(recipe.tracer_refs_out, out_binders):\n",
    "    if t_ref() is not None: tracer_to_var[id(t_ref())] = var\n",
    "  return JaxprEqn(recipe.prim, inputs, recipe.params, out_binders)\n",
    "\n",
    "def tracer_parents(t: PartialEvalTracer) -> List[PartialEvalTracer]:\n",
    "  return t.recipe.tracers_in if isinstance(t.recipe, JaxprEqnRecipe) else []"
   ]
  },
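  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about the ordering that `toposort` has to produce (every node\n",
    "after all of its parents), here's a small standalone sketch. It uses\n",
    "`graphlib.TopologicalSorter` from the Python standard library (3.9+) rather\n",
    "than the helper used in this tutorial, and the node names are just borrowed\n",
    "from the jaxpr example earlier in this part:\n",
    "```python\n",
    "from graphlib import TopologicalSorter\n",
    "\n",
    "# Toy dependency graph: each key maps to the set of its parents (inputs).\n",
    "parents = {\n",
    "    'a': set(),\n",
    "    'b': set(),\n",
    "    'c': {'a'},       # c = sin a\n",
    "    'd': {'a'},       # d = cos a\n",
    "    'e': {'d', 'b'},  # e = mul d b\n",
    "    'g': {'e'},       # g = neg e\n",
    "}\n",
    "\n",
    "order = list(TopologicalSorter(parents).static_order())\n",
    "position = {node: i for i, node in enumerate(order)}\n",
    "\n",
    "# Every node appears after all of its parents, so emitting equations in this\n",
    "# order never references a variable before it is bound.\n",
    "assert all(position[p] < position[n] for n, ps in parents.items() for p in ps)\n",
    "print(order)\n",
    "```"
   ]
  },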
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def toposort(out_nodes: List[Any], parents: Callable[[Any], List[Any]]):\n",
    "  if not out_nodes: return []\n",
    "  out_nodes = remove_duplicates(out_nodes)\n",
    "\n",
    "  child_counts = {}\n",
    "  stack = list(out_nodes)\n",
    "  while stack:\n",
    "    node = stack.pop()\n",
    "    if id(node) in child_counts:\n",
    "      child_counts[id(node)] += 1\n",
    "    else:\n",
    "      child_counts[id(node)] = 1\n",
    "      stack.extend(parents(node))\n",
    "  for node in out_nodes:\n",
    "    child_counts[id(node)] -= 1\n",
    "\n",
    "  sorted_nodes = []\n",
    "  childless_nodes = [node for node in out_nodes if not child_counts[id(node)]]\n",
    "  while childless_nodes:\n",
    "    node = childless_nodes.pop()\n",
    "    sorted_nodes.append(node)\n",
    "    for parent in parents(node):\n",
    "      if child_counts[id(parent)] == 1:\n",
    "        childless_nodes.append(parent)\n",
    "      else:\n",
    "        child_counts[id(parent)] -= 1\n",
    "\n",
    "  sorted_nodes = sorted_nodes[::-1]\n",
    "  check_toposort(sorted_nodes, parents)\n",
    "  return sorted_nodes\n",
    "\n",
    "def remove_duplicates(lst):\n",
    "  seen = set()\n",
    "  return [x for x in lst if id(x) not in seen and not seen.add(id(x))]\n",
    "\n",
    "def check_toposort(nodes: List[Any], parents: Callable[[Any], List[Any]]):\n",
    "  seen = set()\n",
    "  for node in nodes:\n",
    "    assert all(id(parent) in seen for parent in parents(node))\n",
    "    seen.add(id(node))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can linearize!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y, sin_lin = linearize(sin, 3.)\n",
    "print(y, sin(3.))\n",
    "print(sin_lin(1.), cos(3.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To handle `linearize`-of-`jit`, we still need to write a partial evaluation\n",
    "rule for `xla_call_p`. Other than tracer bookkeeping, the main task is to\n",
    "perform partial evaluation of a jaxpr, 'unzipping' it into two jaxprs.\n",
    "\n",
    "There are actually two rules to write: one for trace-time partial evaluation,\n",
    "which we'll call `xla_call_partial_eval`, and one for partial evaluation of\n",
    "jaxprs, which we'll call `xla_call_peval_eqn`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def xla_call_partial_eval(trace, tracers, *, jaxpr, num_consts):\n",
    "  del num_consts  # Unused\n",
    "  in_unknowns = [not t.pval.is_known for t in tracers]\n",
    "  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)\n",
    "  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)\n",
    "  known_vals = [t.pval.const for t in known_tracers]\n",
    "  outs1_res = bind(xla_call_p, *known_vals, jaxpr=jaxpr1, num_consts=0)\n",
    "  outs1, res = split_list(outs1_res, len(jaxpr1.outs) - num_res)\n",
    "  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]\n",
    "  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)\n",
    "           for v in jaxpr2.outs]\n",
    "  eqn = JaxprEqnRecipe(xla_call_p, res_tracers + unknown_tracers,\n",
    "                       dict(jaxpr=jaxpr2, num_consts=0),\n",
    "                       [v.aval for v in jaxpr2.outs], map(ref, outs2))\n",
    "  for t in outs2: t.recipe = eqn\n",
    "  return merge_lists(out_unknowns, outs1, outs2)\n",
    "partial_eval_rules[xla_call_p] = xla_call_partial_eval\n",
    "\n",
    "def partial_eval_jaxpr(jaxpr: Jaxpr, in_unknowns: List[bool],\n",
    "                       instantiate: Optional[List[bool]] = None,\n",
    "                       ) -> Tuple[Jaxpr, Jaxpr, List[bool], int]:\n",
    "  env: Dict[Var, bool] = {}\n",
    "  residuals: Set[Var] = set()\n",
    "\n",
    "  def read(x: Atom) -> bool:\n",
    "    return type(x) is Var and env[x]\n",
    "\n",
    "  def write(unk: bool, v: Var) -> None:\n",
    "    env[v] = unk\n",
    "\n",
    "  def new_res(x: Atom) -> Atom:\n",
    "    if type(x) is Var: residuals.add(x)\n",
    "    return x\n",
    "\n",
    "  eqns1, eqns2 = [], []\n",
    "  map(write, in_unknowns, jaxpr.in_binders)\n",
    "  for eqn in jaxpr.eqns:\n",
    "    unks_in = map(read, eqn.inputs)\n",
    "    rule = partial_eval_jaxpr_rules.get(eqn.primitive)\n",
    "    if rule:\n",
    "      eqn1, eqn2, unks_out, res = rule(unks_in, eqn)\n",
    "      eqns1.append(eqn1); eqns2.append(eqn2); residuals.update(res)\n",
    "      map(write, unks_out, eqn.out_binders)\n",
    "    elif any(unks_in):\n",
    "      inputs = [v if unk else new_res(v) for unk, v in zip(unks_in, eqn.inputs)]\n",
    "      eqns2.append(JaxprEqn(eqn.primitive, inputs, eqn.params, eqn.out_binders))\n",
    "      map(partial(write, True), eqn.out_binders)\n",
    "    else:\n",
    "      eqns1.append(eqn)\n",
    "      map(partial(write, False), eqn.out_binders)\n",
    "  out_unknowns = map(read, jaxpr.outs)\n",
    "  if instantiate is not None:\n",
    "    for v, uk, inst in zip(jaxpr.outs, out_unknowns, instantiate):\n",
    "      if inst and not uk: new_res(v)\n",
    "    out_unknowns = map(op.or_, out_unknowns, instantiate)\n",
    "\n",
    "  residuals, num_res = list(residuals), len(residuals)\n",
    "  assert all(type(v) is Var for v in residuals), residuals\n",
    "\n",
    "  ins1, ins2 = partition_list(in_unknowns, jaxpr.in_binders)\n",
    "  outs1, outs2 = partition_list(out_unknowns, jaxpr.outs)\n",
    "\n",
    "  jaxpr1 = Jaxpr(ins1, eqns1, outs1 + residuals)\n",
    "  jaxpr2 = Jaxpr(residuals + ins2, eqns2, outs2)\n",
    "  typecheck_partial_eval_jaxpr(jaxpr, in_unknowns, out_unknowns, jaxpr1, jaxpr2)\n",
    "\n",
    "  return jaxpr1, jaxpr2, out_unknowns, num_res\n",
    "\n",
    "def typecheck_partial_eval_jaxpr(jaxpr, unks_in, unks_out, jaxpr1, jaxpr2):\n",
    "  jaxprty = typecheck_jaxpr(jaxpr)    # (a1,  a2) -> (b1, b2 )\n",
    "  jaxpr1ty = typecheck_jaxpr(jaxpr1)  #  a1       -> (b1, res)\n",
    "  jaxpr2ty = typecheck_jaxpr(jaxpr2)  # (res, a2) -> b2\n",
    "\n",
    "  a1, a2 = partition_list(unks_in, jaxprty.in_types)\n",
    "  b1, b2 = partition_list(unks_out, jaxprty.out_types)\n",
    "  b1_, res = split_list(jaxpr1ty.out_types, len(b1))\n",
    "  res_, a2_ = split_list(jaxpr2ty.in_types, len(res))\n",
    "  b2_ = jaxpr2ty.out_types\n",
    "\n",
    "  if jaxpr1ty.in_types != a1: raise TypeError\n",
    "  if jaxpr2ty.out_types != b2: raise TypeError\n",
    "  if b1 != b1_: raise TypeError\n",
    "  if res != res_: raise TypeError\n",
    "  if a2 != a2_: raise TypeError\n",
    "  if b2 != b2_: raise TypeError\n",
    "\n",
    "partial_eval_jaxpr_rules = {}\n",
    "\n",
    "def xla_call_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,\n",
    "                       ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Var]]:\n",
    "  jaxpr = eqn.params['jaxpr']\n",
    "  jaxpr1, jaxpr2, unks_out, num_res = partial_eval_jaxpr(jaxpr, unks_in)\n",
    "  ins1, ins2 = partition_list(unks_in, eqn.inputs)\n",
    "  out_binders1, out_binders2 = partition_list(unks_out, eqn.out_binders)\n",
    "  residuals = [Var(v.aval) for v in jaxpr2.in_binders[:num_res]]\n",
    "  eqn1 = JaxprEqn(xla_call_p, ins1, dict(jaxpr=jaxpr1, num_consts=0),\n",
    "                  out_binders1 + residuals)\n",
    "  eqn2 = JaxprEqn(xla_call_p, residuals + ins2,\n",
    "                  dict(jaxpr=jaxpr2, num_consts=0), out_binders2)\n",
    "  return eqn1, eqn2, unks_out, residuals\n",
    "partial_eval_jaxpr_rules[xla_call_p] = xla_call_peval_eqn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With that, we can compose `linearize` and `jit` however we like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "\n",
    "y, f_lin = linearize(f, 3.)\n",
    "y_dot = f_lin(1.)\n",
    "print(y, y_dot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = g(x, y)\n",
    "  return z\n",
    "\n",
    "@jit\n",
    "def g(x, y):\n",
    "  return cos(x) + y\n",
    "\n",
    "y, f_lin = linearize(f, 3.)\n",
    "y_dot = f_lin(1.)\n",
    "print(y, y_dot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `vjp` and `grad`\n",
    "\n",
    "The `vjp` transformation works a lot like linearize. Its type signature is\n",
    "analogous:\n",
    "\n",
    "```\n",
    "linearize : (a -> b) -> a -> (b, T a -o T b)\n",
    "vjp       : (a -> b) -> a -> (b, T b -o T a)\n",
    "```\n",
    "\n",
    "The only difference is that we transpose the linear part of the computation\n",
    "before returning it, so that it goes from type `T a -o T b` to type `T b -o T\n",
    "a`. That is, we'll implement `vjp` as, essentially,\n",
    "\n",
    "```\n",
    "def vjp(f, x):\n",
    "  y, f_lin = linearize(f, x)\n",
    "  f_vjp = lambda y_bar: transpose(f_lin)(y_bar)\n",
    "  return y, f_vjp\n",
    "```\n",
    "\n",
    "Since we have the linear computation as a jaxpr, not just a Python callable,\n",
    "we can implement the transpose transformation as a jaxpr interpreter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vjp_flat(f, *primals_in):\n",
    "  pvals_in = ([PartialVal.known(x) for x in primals_in] +\n",
    "              [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])\n",
    "  primal_pvals_in, tangent_pvals_in = split_half(pvals_in)\n",
    "  def f_jvp(*primals_tangents_in):\n",
    "    primals_out, tangents_out = jvp(f, *split_half(primals_tangents_in))\n",
    "    return [*primals_out, *tangents_out]\n",
    "  jaxpr, pvals_out, consts = partial_eval_flat(f_jvp, pvals_in)  # linearize\n",
    "  primal_pvals, _ = split_half(pvals_out)\n",
    "  assert all(pval.is_known for pval in primal_pvals)\n",
    "  primals_out = [pval.const for pval in primal_pvals]\n",
    "  transpose_inputs = consts + [UndefPrimal(p.aval) for p in tangent_pvals_in]\n",
    "  f_vjp = lambda *cts: eval_jaxpr_transposed(jaxpr, transpose_inputs, cts)\n",
    "  return primals_out, f_vjp\n",
    "\n",
    "def vjp(f, *primals_in):\n",
    "  primals_in_flat, in_tree = tree_flatten(primals_in)\n",
    "  f, out_tree = flatten_fun(f, in_tree)\n",
    "  primals_out_flat, f_vjp_flat = vjp_flat(f, *primals_in_flat)\n",
    "  primals_out = tree_unflatten(out_tree(), primals_out_flat)\n",
    "\n",
    "  def f_vjp(*cotangents_out):\n",
    "    cotangents_out_flat, _ = tree_flatten(cotangents_out)\n",
    "    cotangents_in_flat = f_vjp_flat(*cotangents_out_flat)\n",
    "    return tree_unflatten(in_tree, cotangents_in_flat)\n",
    "\n",
    "  return primals_out, f_vjp\n",
    "\n",
    "class UndefPrimal(NamedTuple):\n",
    "  aval: ShapedArray\n",
    "\n",
    "register_pytree_node(UndefPrimal,\n",
    "                     lambda u: (u.aval, ()),\n",
    "                     lambda aval, _: UndefPrimal(aval))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use `UndefPrimal` instances to indicate which arguments with respect to\n",
    "which we want to transpose. These arise because in general, being explicit\n",
    "about closed-over values, we want to transpose functions of type\n",
    "`a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the\n",
    "inputs with respect to which the function is linear could be scattered through\n",
    "the argument list. So we indicate the linear positions using `UndefPrimal`.\n",
    "We register `UndefPrimal` as a pytree node because the pytree mechanism gives\n",
    "a handy way to prune these placeholders out of argument lists.\n",
    "\n",
    "Next, we can write `eval_jaxpr_transposed`, along with transpose rules for\n",
    "all primitives which can be linear in at least one argument:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NB: the analogous function in JAX is called 'backward_pass'\n",
    "def eval_jaxpr_transposed(jaxpr: Jaxpr, args: List[Any], cotangents: List[Any]\n",
    "                          ) -> List[Any]:\n",
    "  primal_env: Dict[Var, Any] = {}\n",
    "  ct_env: Dict[Var, Any] = {}\n",
    "\n",
    "  def read_primal(x: Atom) -> Any:\n",
    "    return primal_env.get(x, UndefPrimal(x.aval)) if type(x) is Var else x.val\n",
    "\n",
    "  def write_primal(v: Var, val: Any) -> None:\n",
    "    if type(val) is not UndefPrimal:\n",
    "      primal_env[v] = val\n",
    "\n",
    "  def read_cotangent(v: Var) -> Any:\n",
    "    return ct_env.pop(v, np.zeros(v.aval.shape, v.aval.dtype))\n",
    "\n",
    "  def write_cotangent(x: Atom, val: Any):\n",
    "    if type(x) is Var and val is not None:\n",
    "      ct_env[x] = add(ct_env[x], val) if x in ct_env else val\n",
    "\n",
    "  map(write_primal, jaxpr.in_binders, args)\n",
    "  map(write_cotangent, jaxpr.outs, cotangents)\n",
    "  for eqn in jaxpr.eqns[::-1]:\n",
    "    primals_in = map(read_primal, eqn.inputs)\n",
    "    cts_in = map(read_cotangent, eqn.out_binders)\n",
    "    rule = transpose_rules[eqn.primitive]\n",
    "    cts_out = rule(cts_in, *primals_in, **eqn.params)\n",
    "    map(write_cotangent, eqn.inputs, cts_out)\n",
    "\n",
    "  return [read_cotangent(v) for v, x in zip(jaxpr.in_binders, args)\n",
    "          if type(x) is UndefPrimal]\n",
    "\n",
    "transpose_rules = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mul_transpose_rule(cts, x, y):\n",
    "  z_bar, = cts\n",
    "  assert (type(x) is UndefPrimal) ^ (type(y) is UndefPrimal)\n",
    "  return [mul(z_bar, y), None] if type(x) is UndefPrimal else [None, mul(x, z_bar)]\n",
    "transpose_rules[mul_p] = mul_transpose_rule\n",
    "\n",
    "def neg_transpose_rule(cts, x):\n",
    "  ybar, = cts\n",
    "  assert type(x) is UndefPrimal\n",
    "  return [neg(ybar)]\n",
    "transpose_rules[neg_p] = neg_transpose_rule\n",
    "\n",
    "def add_transpose_rule(cts, x, y):\n",
    "  z_bar, = cts\n",
    "  return [z_bar, z_bar]\n",
    "transpose_rules[add_p] = add_transpose_rule\n",
    "\n",
    "def reduce_sum_transpose_rule(cts, x, *, axis):\n",
    "  y_bar, = cts\n",
    "  return [broadcast(y_bar, x.aval.shape, axis)]\n",
    "transpose_rules[reduce_sum_p] = reduce_sum_transpose_rule\n",
    "\n",
    "def xla_call_transpose_rule(cts, *invals, jaxpr, num_consts):\n",
    "  del num_consts  # Unused\n",
    "  undef_primals = [type(x) is UndefPrimal for x in invals]\n",
    "  transposed_jaxpr, new_consts = transpose_jaxpr(jaxpr, tuple(undef_primals))\n",
    "  residuals, _ = partition_list(undef_primals, invals)\n",
    "  outs = bind(xla_call_p, *new_consts, *residuals, *cts,\n",
    "              jaxpr=transposed_jaxpr, num_consts=len(new_consts))\n",
    "  outs = iter(outs)\n",
    "  return [next(outs) if undef else None for undef in undef_primals]\n",
    "transpose_rules[xla_call_p] = xla_call_transpose_rule\n",
    "\n",
    "@lru_cache()\n",
    "def transpose_jaxpr(jaxpr: Jaxpr, undef_primals: Tuple[bool, ...]\n",
    "                    ) -> Tuple[Jaxpr, List[Any]]:\n",
    "  avals_in, avals_out = typecheck_jaxpr(jaxpr)\n",
    "  traceable = partial(eval_jaxpr_transposed, jaxpr)\n",
    "  args = [UndefPrimal(a) if u else a for a, u in zip(avals_in, undef_primals)]\n",
    "  trans_jaxpr, consts, _ = make_jaxpr(traceable, tuple(args), tuple(avals_out))\n",
    "  typecheck_jaxpr(trans_jaxpr)\n",
    "  return trans_jaxpr, consts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we can linearize and transpose, we can finally write `grad`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad(f):\n",
    "  def gradfun(x, *xs):\n",
    "    y, f_vjp = vjp(f, x, *xs)\n",
    "    if np.shape(y) != (): raise TypeError\n",
    "    x_bar, *_ = f_vjp(np.ones(np.shape(y), np.result_type(y)))\n",
    "    return x_bar\n",
    "  return gradfun"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y, f_vjp = vjp(sin, 3.)\n",
    "print(f_vjp(1.), cos(3.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  y = sin(x) * 2.\n",
    "  z = - y + x\n",
    "  return z\n",
    "\n",
    "print(grad(f)(3.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def f(x):\n",
    "  y = x * 2.\n",
    "  z = g(y)\n",
    "  return z\n",
    "\n",
    "@jit\n",
    "def g(x):\n",
    "  return cos(x) * 2.\n",
    "\n",
    "print(grad(f)(3.))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's something of a compositionality stress test:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from core_test.py fun_with_nested_calls_2\n",
    "def foo(x):\n",
    "  @jit\n",
    "  def bar(y):\n",
    "    def baz(w):\n",
    "      q = jit(lambda x: y)(x)\n",
    "      q = q + jit(lambda: y)()\n",
    "      q = q + jit(lambda y: w + y)(y)\n",
    "      q = jit(lambda w: jit(sin)(x) * y)(1.0) + q\n",
    "      return q\n",
    "    p, t = jvp(baz, (x + 1.0,), (y,))\n",
    "    return t + (x * p)\n",
    "  return bar(x)\n",
    "\n",
    "def assert_allclose(*vals):\n",
    "  for v1, v2 in zip(vals[:-1], vals[1:]):\n",
    "    np.testing.assert_allclose(v1, v2)\n",
    "\n",
    "ans1 = f(3.)\n",
    "ans2 = jit(f)(3.)\n",
    "ans3, _ = jvp(f, (3.,), (5.,))\n",
    "ans4, _ = jvp(jit(f), (3.,), (5.,))\n",
    "assert_allclose(ans1, ans2, ans3, ans4)\n",
    "\n",
    "deriv1 = grad(f)(3.)\n",
    "deriv2 = grad(jit(f))(3.)\n",
    "deriv3 = jit(grad(jit(f)))(3.)\n",
    "_, deriv4 = jvp(f, (3.,), (1.,))\n",
    "_, deriv5 = jvp(jit(f), (3.,), (1.,))\n",
    "assert_allclose(deriv1, deriv2, deriv3, deriv4, deriv5)\n",
    "\n",
    "hess1 = grad(grad(f))(3.)\n",
    "hess2 = grad(grad(jit(f)))(3.)\n",
    "hess3 = grad(jit(grad(f)))(3.)\n",
    "hess4 = jit(grad(grad(f)))(3.)\n",
    "_, hess5 = jvp(grad(f), (3.,), (1.,))\n",
    "_, hess6 = jvp(jit(grad(f)), (3.,), (1.,))\n",
    "_, hess7 = jvp(jit(grad(f)), (3.,), (1.,))\n",
    "assert_allclose(hess1, hess2, hess3, hess4, hess5, hess6, hess7)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 5: the control flow primitives `cond`\n",
    "\n",
    "Next we'll add higher-order primitives for staged-out control flow. These\n",
    "resemble `jit` from Part 3, another higher-order primitive, but differ in that\n",
    "they are parameterized by multiple callables rather than just one."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adding `cond`\n",
    "\n",
    "We introduce a `cond` primitive to represent conditional application of one\n",
    "function or another inside a jaxpr. We write the type of `cond` as\n",
    "`Bool -> (a -> b) -> (a -> b) -> a -> b`. In words, `cond` takes a boolean\n",
    "representing the predicate and two functions of equal types. Depending on the\n",
    "value of the predicate, it applies one function or the other to its final\n",
    "argument.\n",
    "\n",
    "In Python, we represent it as a function which itself takes two functions as\n",
    "arguments. As with `jit`, the first step is to call `make_jaxpr` on its\n",
    "callable arguments to turn them into jaxprs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cond(pred, true_fn, false_fn, *operands):\n",
    "  avals_in = [raise_to_shaped(get_aval(x)) for x in operands]\n",
    "  true_jaxpr, true_consts, out_tree = make_jaxpr(true_fn, *avals_in)\n",
    "  false_jaxpr, false_consts, out_tree_ = make_jaxpr(false_fn, *avals_in)\n",
    "  if out_tree != out_tree_: raise TypeError\n",
    "  true_jaxpr, false_jaxpr = _join_jaxpr_consts(\n",
    "      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))\n",
    "  if typecheck_jaxpr(true_jaxpr) != typecheck_jaxpr(false_jaxpr):\n",
    "    raise TypeError\n",
    "  outs = bind_cond(pred, *true_consts, *false_consts, *operands,\n",
    "                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)\n",
    "  return tree_unflatten(out_tree, outs)\n",
    "cond_p = Primitive('cond')\n",
    "\n",
    "def _join_jaxpr_consts(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int\n",
    "                       ) -> Tuple[Jaxpr, Jaxpr]:\n",
    "  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)\n",
    "  assert jaxpr1_type.in_types[n1:] == jaxpr2_type.in_types[n2:]\n",
    "  consts1, rest1 = split_list(jaxpr1.in_binders, n1)\n",
    "  consts2, rest2 = split_list(jaxpr2.in_binders, n2)\n",
    "  new_jaxpr1 = Jaxpr(consts1 + consts2 + rest1, jaxpr1.eqns, jaxpr1.outs)\n",
    "  new_jaxpr2 = Jaxpr(consts1 + consts2 + rest2, jaxpr2.eqns, jaxpr2.outs)\n",
    "  return new_jaxpr1, new_jaxpr2\n",
    "\n",
    "def bind_cond(pred, *args, true_jaxpr, false_jaxpr):\n",
    "  assert len(args) == len(true_jaxpr.in_binders) == len(false_jaxpr.in_binders)\n",
    "  return bind(cond_p, pred, *args, true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We require `true_jaxpr` and `false_jaxpr` to have the same type, but because\n",
    "they might close over different constants (and because jaxprs can only\n",
    "represent closed terms, i.e. can't have free variables and are instead\n",
    "closure-converted) we need to use the helper `_join_jaxpr_consts` to make\n",
    "consistent the input binder lists of the two jaxprs. (To be more economical we\n",
    "could try to identify pairs of constants with the same shapes, but instead we\n",
    "just concatenate the lists of constants.)\n",
    "\n",
    "Next we can turn to adding interpreter rules for `cond`. Its evaluation rule\n",
    "is simple:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cond_impl(pred, *operands, true_jaxpr, false_jaxpr):\n",
    "  if pred:\n",
    "    return eval_jaxpr(true_jaxpr, operands)\n",
    "  else:\n",
    "    return eval_jaxpr(false_jaxpr, operands)\n",
    "impl_rules[cond_p] = cond_impl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = cond(True, lambda: 3, lambda: 4)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For its JVP and vmap rules, we only need to call the same `jvp_jaxpr` and\n",
    "`vmap_jaxpr` utilities we created for `jit`, followed by another pass of\n",
    "`_join_jaxpr_consts`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cond_jvp_rule(primals, tangents, *, true_jaxpr, false_jaxpr):\n",
    "  pred, *primals = primals\n",
    "  _   , *tangents = tangents\n",
    "  true_jaxpr , true_consts  = jvp_jaxpr(true_jaxpr)\n",
    "  false_jaxpr, false_consts = jvp_jaxpr(false_jaxpr)\n",
    "  true_jaxpr, false_jaxpr = _join_jaxpr_consts(\n",
    "      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))\n",
    "  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)\n",
    "  outs = bind_cond(pred, *true_consts, *false_consts, *primals, *tangents,\n",
    "                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)\n",
    "  primals_out, tangents_out = split_half(outs)\n",
    "  return primals_out, tangents_out\n",
    "jvp_rules[cond_p] = cond_jvp_rule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out, out_tan = jvp(lambda x: cond(True, lambda: x * x, lambda: 0.), (1.,), (1.,))\n",
    "print(out_tan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cond_vmap_rule(axis_size, vals_in, dims_in, *, true_jaxpr, false_jaxpr):\n",
    "  pred    , *vals_in = vals_in\n",
    "  pred_dim, *dims_in = dims_in\n",
    "  if pred_dim is not not_mapped: raise NotImplementedError  # TODO\n",
    "  true_jaxpr, true_consts = vmap_jaxpr(true_jaxpr, axis_size, tuple(dims_in))\n",
    "  false_jaxpr, false_consts = vmap_jaxpr(false_jaxpr, axis_size, tuple(dims_in))\n",
    "  true_jaxpr, false_jaxpr = _join_jaxpr_consts(\n",
    "      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))\n",
    "  assert typecheck_jaxpr(true_jaxpr) == typecheck_jaxpr(false_jaxpr)\n",
    "  outs = bind_cond(pred, *true_consts, *false_consts, *vals_in,\n",
    "                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)\n",
    "  return outs, [0] * len(outs)\n",
    "vmap_rules[cond_p] = cond_vmap_rule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = np.array([1., 2., 3])\n",
    "out = vmap(lambda x: cond(True, lambda: x + 1., lambda: 0.), (0,))(xs)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice that we're not currently supporting the case where the predicate value\n",
    "itself is batched. In mainline JAX, we handle this case by transforming the\n",
    "conditional to a [select primitive](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.select.html).\n",
    "That transformation is semantically correct so long as `true_fun` and\n",
    "`false_fun` do not involve any side-effecting primitives.\n",
    "\n",
    "Another thing not represented here, but present in the mainline JAX, is that\n",
    "applying transformations to two jaxprs of equal type might result in jaxprs of\n",
    "different types. For example, applying the mainline JAX version of\n",
    "`vmap_jaxpr` to the identity-function jaxpr\n",
    "\n",
    "```\n",
    "{ lambda a:float32[] .\n",
    "  let\n",
    "  in ( a ) }\n",
    "```\n",
    "\n",
    "would result in a jaxpr with a batched output, of type\n",
    "`[float32[10]] -> [float32[10]]` if the batch size were 10, while applying it\n",
    "to the zero-function jaxpr\n",
    "\n",
    "```\n",
    "{ lambda a:float32[] .\n",
    "  let\n",
    "  in ( 0. ) }\n",
    "```\n",
    "\n",
    "would result in a jaxpr with an unbatched output, of type\n",
    "`[float32[10]] -> [float32[]]`. This is an optimization, aimed at not batching\n",
    "values unnecessarily. But it means that in `cond` we'd need an extra step of\n",
    "joining the two transformed jaxprs to have consistent output types. We don't\n",
    "need this step here because we chose `vmap_jaxpr` always to batch all outputs\n",
    "over the leading axis."
   ]
  },
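  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the output-joining step described above more concrete, here is a\n",
    "hedged NumPy sketch (not mainline JAX's actual code): if one transformed\n",
    "branch returns an unbatched value while the other returns a batched one,\n",
    "broadcasting the unbatched value along a leading axis gives both branches the\n",
    "same output type. The helper `batch_if_needed` is a hypothetical name:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def batch_if_needed(out, is_batched, axis_size):\n",
    "  # Give an unbatched output a leading batch axis so that both branches\n",
    "  # produce values of the same shape.\n",
    "  if is_batched:\n",
    "    return out\n",
    "  return np.broadcast_to(out, (axis_size,) + np.shape(out))\n",
    "\n",
    "axis_size = 10\n",
    "identity_out = np.arange(axis_size, dtype=np.float32)  # batched: float32[10]\n",
    "zero_out = np.float32(0.)                              # unbatched: float32[]\n",
    "\n",
    "joined_identity = batch_if_needed(identity_out, True, axis_size)\n",
    "joined_zero = batch_if_needed(zero_out, False, axis_size)\n",
    "print(joined_identity.shape, joined_zero.shape)  # (10,) (10,)\n",
    "```"
   ]
  },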
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we can turn to abstract evaluation and XLA lowering rules:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cond_abstract_eval(pred_type, *in_types, true_jaxpr, false_jaxpr):\n",
    "  if pred_type != ShapedArray((), np.dtype('bool')): raise TypeError\n",
    "  jaxpr_type = typecheck_jaxpr(true_jaxpr)\n",
    "  if jaxpr_type != typecheck_jaxpr(false_jaxpr):\n",
    "    raise TypeError\n",
    "  if not all(t1 == t2 for t1, t2 in zip(jaxpr_type.in_types, in_types)):\n",
    "    raise TypeError\n",
    "  return jaxpr_type.out_types\n",
    "abstract_eval_rules[cond_p] = cond_abstract_eval\n",
    "\n",
    "def cond_translation(c, in_avals, in_vals, *, true_jaxpr, false_jaxpr):\n",
    "  del in_avals  # Unused\n",
    "  pred, *in_vals = in_vals\n",
    "  flat_vals, in_tree = tree_flatten(in_vals)\n",
    "  operand = xops.Tuple(c, flat_vals)\n",
    "  operand_shape = c.get_shape(operand)\n",
    "\n",
    "  def make_comp(name: str, jaxpr: Jaxpr) -> xe.XlaComputation:\n",
    "    c = xc.XlaBuilder(name)\n",
    "    operand = xops.Parameter(c, 0, operand_shape)\n",
    "    operands = tree_unflatten(in_tree, destructure_tuple(c, operand))\n",
    "    outs = jaxpr_subcomp(c, jaxpr, operands)\n",
    "    return c.build(xops.Tuple(c, outs))\n",
    "\n",
    "  true_comp = make_comp('true_fn', true_jaxpr)\n",
    "  false_comp = make_comp('false_fn', false_jaxpr)\n",
    "\n",
    "  int_etype = xc.dtype_to_etype(np.dtype('int32'))\n",
    "  out = xops.Conditional(xops.ConvertElementType(pred, int_etype),\n",
    "                         [false_comp, true_comp], [operand] * 2)\n",
    "  return destructure_tuple(c, out)\n",
    "xla_translations[cond_p] = cond_translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = jit(lambda: cond(False, lambda: 1, lambda: 2))()\n",
    "print(out)"
   ]
  },
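  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One detail in `cond_translation` that is easy to miss is the branch order: the\n",
    "boolean predicate is converted to an `int32` index, so `False` selects\n",
    "position 0 and `True` selects position 1, which is why the list reads\n",
    "`[false_comp, true_comp]`. A plain-Python sketch of that indexing (the helper\n",
    "`indexed_conditional` is hypothetical, not an XLA API):\n",
    "\n",
    "```python\n",
    "def indexed_conditional(index, branches, operand):\n",
    "  # Run the branch at position `index`, passing it the shared operand.\n",
    "  return branches[index](operand)\n",
    "\n",
    "true_fn = lambda x: x + 1\n",
    "false_fn = lambda x: x - 1\n",
    "\n",
    "# The boolean predicate becomes an integer index, so the false branch must\n",
    "# sit at position 0 and the true branch at position 1.\n",
    "out_false = indexed_conditional(int(False), [false_fn, true_fn], 10)\n",
    "out_true = indexed_conditional(int(True), [false_fn, true_fn], 10)\n",
    "print(out_false, out_true)  # 9 11\n",
    "```"
   ]
  },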
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, to support reverse-mode automatic differentiation, we need partial\n",
    "evaluation and transposition rules. For partial evaluation, we need to\n",
    "introduce another jaxpr-munging utility, `_join_jaxpr_res`, to handle the fact\n",
    "that applying partial evaluation to `true_fun` and `false_fun` will in general\n",
    "result in distinct residuals. We use `_join_jaxpr_res` to make the output\n",
    "types of the transformed jaxprs consistent (while `_join_jaxpr_consts` dealt\n",
    "with input types)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cond_partial_eval(trace, tracers, *, true_jaxpr, false_jaxpr):\n",
    "  pred_tracer, *tracers = tracers\n",
    "  assert pred_tracer.pval.is_known\n",
    "  pred = pred_tracer.pval.const\n",
    "  in_uks = [not t.pval.is_known for t in tracers]\n",
    "\n",
    "  *jaxprs, out_uks, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, in_uks)\n",
    "  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs\n",
    "\n",
    "  known_tracers, unknown_tracers = partition_list(in_uks, tracers)\n",
    "  known_vals = [t.pval.const for t in known_tracers]\n",
    "  outs1_res = bind_cond(pred, *known_vals,\n",
    "                        true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1)\n",
    "  outs1, res = split_list(outs1_res, len(outs1_res) - num_res)\n",
    "  pred_tracer_ = trace.instantiate_const(full_raise(trace, pred_tracer))\n",
    "  res_tracers = [trace.instantiate_const(full_raise(trace, x)) for x in res]\n",
    "  outs2 = [PartialEvalTracer(trace, PartialVal.unknown(v.aval), None)\n",
    "           for v in t_jaxpr2.outs]\n",
    "  eqn = JaxprEqnRecipe(cond_p, [pred_tracer_, *res_tracers, *unknown_tracers],\n",
    "                       dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),\n",
    "                       [v.aval for v in t_jaxpr2.outs], map(ref, outs2))\n",
    "  for t in outs2: t.recipe = eqn\n",
    "  return merge_lists(out_uks, outs1, outs2)\n",
    "partial_eval_rules[cond_p] = cond_partial_eval\n",
    "\n",
    "def _cond_partial_eval(true_jaxpr: Jaxpr, false_jaxpr: Jaxpr, in_uks: List[bool]\n",
    "                       ) -> Tuple[Jaxpr, Jaxpr, Jaxpr, Jaxpr, List[bool], int]:\n",
    "  _, _, t_out_uks, _ = partial_eval_jaxpr(true_jaxpr , in_uks)\n",
    "  _, _, f_out_uks, _ = partial_eval_jaxpr(false_jaxpr, in_uks)\n",
    "  out_uks = map(op.or_, t_out_uks, f_out_uks)\n",
    "\n",
    "  t_jaxpr1, t_jaxpr2, _, t_nres = partial_eval_jaxpr(true_jaxpr , in_uks, out_uks)\n",
    "  f_jaxpr1, f_jaxpr2, _, f_nres = partial_eval_jaxpr(false_jaxpr, in_uks, out_uks)\n",
    "\n",
    "  t_jaxpr1, f_jaxpr1 = _join_jaxpr_res(t_jaxpr1, f_jaxpr1, t_nres, f_nres)\n",
    "  t_jaxpr2, f_jaxpr2 = _join_jaxpr_consts(t_jaxpr2, f_jaxpr2, t_nres, f_nres)\n",
    "  assert typecheck_jaxpr(t_jaxpr1) == typecheck_jaxpr(f_jaxpr1)\n",
    "  assert typecheck_jaxpr(t_jaxpr2) == typecheck_jaxpr(f_jaxpr2)\n",
    "  num_res = t_nres + f_nres\n",
    "\n",
    "  return t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2, out_uks, num_res\n",
    "\n",
    "def _join_jaxpr_res(jaxpr1: Jaxpr, jaxpr2: Jaxpr, n1: int, n2: int\n",
    "                    ) -> Tuple[Jaxpr, Jaxpr]:\n",
    "  jaxpr1_type, jaxpr2_type = typecheck_jaxpr(jaxpr1), typecheck_jaxpr(jaxpr2)\n",
    "  out_types1, _ = split_list(jaxpr1_type.out_types, len(jaxpr1.outs) - n1)\n",
    "  out_types2, _ = split_list(jaxpr2_type.out_types, len(jaxpr2.outs) - n2)\n",
    "  assert out_types1 == out_types2\n",
    "  outs1, res1 = split_list(jaxpr1.outs, len(jaxpr1.outs) - n1)\n",
    "  outs2, res2 = split_list(jaxpr2.outs, len(jaxpr2.outs) - n2)\n",
    "  zeros_like1 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res1]\n",
    "  zeros_like2 = [Lit(np.zeros(v.aval.shape, v.aval.dtype)) for v in res2]\n",
    "  new_jaxpr1 = Jaxpr(jaxpr1.in_binders, jaxpr1.eqns, outs1 + res1 + zeros_like2)\n",
    "  new_jaxpr2 = Jaxpr(jaxpr2.in_binders, jaxpr2.eqns, outs2 + zeros_like1 + res2)\n",
    "  return new_jaxpr1, new_jaxpr2"
   ]
  },
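  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The zero-padding idea behind `_join_jaxpr_res` can be seen in isolation with\n",
    "plain NumPy values. In this sketch (the function `join_res` and the example\n",
    "values are hypothetical), each branch returns its own residuals in its own\n",
    "slots and zeros in the slots that belong to the other branch, so both branches\n",
    "end up with outputs of the same number, shapes, and dtypes:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def join_res(outs1, res1, outs2, res2):\n",
    "  # Branch 1 returns its residuals followed by zeros shaped like branch 2's\n",
    "  # residuals; branch 2 does the reverse.\n",
    "  zeros_like1 = [np.zeros_like(r) for r in res1]\n",
    "  zeros_like2 = [np.zeros_like(r) for r in res2]\n",
    "  return outs1 + res1 + zeros_like2, outs2 + zeros_like1 + res2\n",
    "\n",
    "# Hypothetical branch results: one primal output each, different residuals.\n",
    "outs_t, res_t = [np.float32(1.)], [np.ones(3, np.float32)]\n",
    "outs_f, res_f = [np.float32(2.)], [np.ones((2, 2), np.float32), np.float32(5.)]\n",
    "\n",
    "joined_t, joined_f = join_res(outs_t, res_t, outs_f, res_f)\n",
    "print([np.shape(x) for x in joined_t])  # [(), (3,), (2, 2), ()]\n",
    "print([np.shape(x) for x in joined_f])  # [(), (3,), (2, 2), ()]\n",
    "```\n",
    "\n",
    "The total number of residual outputs is the sum of the two branches' residual\n",
    "counts, matching `num_res = t_nres + f_nres` in `_cond_partial_eval`."
   ]
  },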
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, f_lin = linearize(lambda x: cond(True, lambda: x, lambda: 0.), 1.)\n",
    "out = f_lin(3.14)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cond_peval_eqn(unks_in: List[bool], eqn: JaxprEqn,\n",
    "                   ) -> Tuple[JaxprEqn, JaxprEqn, List[bool], List[Atom]]:\n",
    "  pred_unk, *unks_in = unks_in\n",
    "  assert not pred_unk\n",
    "  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']\n",
    "  *jaxprs, unks_out, num_res = _cond_partial_eval(true_jaxpr, false_jaxpr, unks_in)\n",
    "  t_jaxpr1, f_jaxpr1, t_jaxpr2, f_jaxpr2 = jaxprs\n",
    "  ins1, ins2 = partition_list(unks_in, eqn.inputs[1:])\n",
    "  outs1, outs2 = partition_list(unks_out, eqn.out_binders)\n",
    "  residuals, _ = split_list(t_jaxpr2.in_binders, num_res)\n",
    "  eqn1 = JaxprEqn(cond_p, [eqn.inputs[0], *ins1],\n",
    "                  dict(true_jaxpr=t_jaxpr1, false_jaxpr=f_jaxpr1),\n",
    "                  outs1 + residuals)\n",
    "  eqn2 = JaxprEqn(cond_p, [eqn.inputs[0], *residuals, *ins2],\n",
    "                  dict(true_jaxpr=t_jaxpr2, false_jaxpr=f_jaxpr2),\n",
    "                  outs2)\n",
    "  res = [eqn.inputs[0], *residuals] if type(eqn.inputs[0]) is Var else residuals\n",
    "  return eqn1, eqn2, unks_out, res\n",
    "partial_eval_jaxpr_rules[cond_p] = cond_peval_eqn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, f_lin = linearize(jit(lambda x: cond(True, lambda: x, lambda: 0.)), 1.)\n",
    "out = f_lin(3.14)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Transposition is a fairly straightforward application of `transpose_jaxpr`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cond_transpose_rule(cts, pred, *invals, true_jaxpr, false_jaxpr):\n",
    "  undef_primals = tuple([type(x) is UndefPrimal for x in invals])\n",
    "  true_jaxpr, true_consts = transpose_jaxpr(true_jaxpr, undef_primals)\n",
    "  false_jaxpr, false_consts = transpose_jaxpr(false_jaxpr, undef_primals)\n",
    "  true_jaxpr, false_jaxpr = _join_jaxpr_consts(\n",
    "      true_jaxpr, false_jaxpr, len(true_consts), len(false_consts))\n",
    "  res = [x for x in invals if type(x) is not UndefPrimal]\n",
    "  outs = bind_cond(pred, *true_consts, *false_consts, *res, *cts,\n",
    "                   true_jaxpr=true_jaxpr, false_jaxpr=false_jaxpr)\n",
    "  outs = iter(outs)\n",
    "  return [None] + [next(outs) if type(x) is UndefPrimal else None for x in invals]\n",
    "transpose_rules[cond_p] = cond_transpose_rule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = grad(lambda x: cond(True, lambda: x * x, lambda: 0.))(1.)\n",
    "print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def pprint_cond(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:\n",
    "  true_jaxpr, false_jaxpr = eqn.params['true_jaxpr'], eqn.params['false_jaxpr']\n",
    "  new_params = {k:v for k, v in eqn.params.items() if not k.endswith('jaxpr')}\n",
    "  lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n",
    "  rhs = (pp(eqn.primitive.name) >> pp_params(new_params) >>\n",
    "         pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n",
    "                     for x in eqn.inputs)))\n",
    "  return vcat([lhs >> pp(' = ') >> rhs,\n",
    "               pp_jaxpr(true_jaxpr).indent(2),\n",
    "               pp_jaxpr(false_jaxpr).indent(2)])\n",
    "pp_rules[cond_p] = pprint_cond"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,md:myst,py",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}