Raw File
onnx2xla.py
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An ONNX to XLA compiler by JAX-tracing a Numpy-backed ONNX interpreter."""

from io import BytesIO
import hashlib
import urllib.request
import sys

import numpy as np
import onnx
from onnx import numpy_helper

import jax.numpy as jnp
from jax import jit, grad
from jax import lax


def _asarray(proto):
  return numpy_helper.to_array(proto).reshape(tuple(proto.dims))


attr_types = dict(onnx.AttributeProto.AttributeType.items())
attribute_handlers = {
    attr_types['FLOAT']: lambda a: a.f,
    attr_types['INT']: lambda a: a.i,
    attr_types['STRING']: lambda a: a.s,
    attr_types['TENSOR']: lambda a: _asarray(a.t),
    attr_types['FLOATS']: lambda a: a.floats,
    attr_types['INTS']: lambda a: a.ints,
    attr_types['STRINGS']: lambda a: a.strings,
    attr_types['TENSORS']: lambda a: [_asarray(x) for x in a.tensors],
}


def onnx_maxpool(x, kernel_shape, pads=None, strides=None):
  """Numpy-backed implementation of ONNX MaxPool op."""
  prefix = (1,) * (x.ndim - len(kernel_shape))
  dims = prefix + tuple(kernel_shape)
  pads = tuple(pads) if pads else [0] * len(kernel_shape)
  strides = (prefix + tuple(strides)) if strides else [1] * len(kernel_shape)
  return [lax.reduce_window(x, -jnp.inf, lax.max, dims, strides, 'VALID')]


def onnx_conv(x, w, b=0, group=1, kernel_shape=None, pads=None, strides=None,
              dilations=None, auto_pad=None):
  """Numpy-backed implementation of ONNX Conv op."""
  assert group == 1
  kernel_shape = kernel_shape or w.shape
  strides = strides or [1] * (w.ndim - 2)
  if auto_pad:
    auto_pad = 'SAME' if auto_pad.startswith(b'SAME') else 'VALID'
    pads = lax.padtype_to_pads(x.shape[2:], w.shape[2:], strides, auto_pad)
  else:
    pads = pads or [0] * (w.ndim - 2)
  lhs_dilation = [1] * (w.ndim - 2)
  rhs_dilation = dilations or [1] * (w.ndim - 2)
  return [lax.conv_with_general_padding(x, w, strides, pads,
                                        lhs_dilation, rhs_dilation) + b]


def onnx_add(a, b, axis=None, broadcast=True):
  """Numpy-backed implementation of ONNX Add op."""
  if broadcast:
    axis = (a.dim - b.ndim) if axis is None else axis % a.ndim
    assert a.shape[axis:][:b.ndim] == b.shape
    b_shape = np.ones(a.ndim, dtype='int64')
    b_shape[axis:axis + b.ndim] = b.shape
    b = jnp.reshape(b, b_shape)
  return [a + b]


onnx_ops = {
    'Add': onnx_add,
    'Constant': lambda value: [value],
    'Conv': onnx_conv,
    'MatMul': lambda x, y: [jnp.matmul(x, y)],
    'MaxPool': onnx_maxpool,
    'Relu': lambda x: [jnp.maximum(x, 0)],
    'Reshape': lambda x, shape: [jnp.reshape(x, shape)],
}


def interpret_onnx(graph, *args):
  vals = dict({n.name: a for n, a in zip(graph.input, args)},
              **{n.name: _asarray(n) for n in graph.initializer})
  for node in graph.node:
    args = (vals[name] for name in node.input)
    attrs = {a.name: attribute_handlers[a.type](a) for a in node.attribute}
    outputs = onnx_ops[node.op_type](*args, **attrs)
    for name, output in zip(node.output, outputs):
      vals[name] = output
  return [vals[n.name] for n in graph.output]


if __name__ == "__main__":
  # It seems that there are several ONNX proto versions (you had one job!) but
  # this implementation works with at least this one mnist example file.
  url = ('https://github.com/onnx/models/blob/'
         '81c4779096d1205edd0b809e191a924c58c38fef/'
         'mnist/model.onnx?raw=true')
  download = urllib.request.urlopen(url).read()
  if hashlib.md5(download).hexdigest() != 'bc8ad9bd19c5a058055dc18d0f089dad':
    print("onnx file checksum mismatch")
    sys.exit(1)
  model = onnx.load(BytesIO(download))

  predict = lambda inputs: interpret_onnx(model.graph, inputs)[0]

  # Run inference in Numpy-backed interpreter
  print("interpreted:")
  print(predict(jnp.ones((1, 1, 28, 28))))

  # JIT compile to XLA device, run inference on device
  compiled_predict = jit(predict)
  print("compiled:")
  print(compiled_predict(jnp.ones((1, 1, 28, 28))))

  # The interpreter is differentiable too! Even the compiled one:
  fun = lambda inputs: jnp.sum(compiled_predict(inputs))
  print("a derivative with respect to inputs:")
  print(grad(fun)(jnp.ones((1, 1, 28, 28)))[..., :3, :3])
back to top