7f3078b | Matthew Johnson | 08 September 2020, 15:54:13 UTC | update version and changelog for pypi (#4224) | 08 September 2020, 15:54:13 UTC |
ed0d8c0 | Matthew Johnson | 08 September 2020, 15:27:41 UTC | tweak lax.py shape broadcasting logic (#4217) This new implementation is faster, and works for polymorphic shapes without weird tricks. (This new implementation is faster even if we remove the weird tricks for polymorphism.) | 08 September 2020, 15:27:41 UTC |
798a264 | Benjamin Chetioui | 08 September 2020, 08:32:53 UTC | [jax2tf] Fix bug in population count and move expect_tf_exception (#4214) into correctness stats. The code was using `tf.bitcast` instead of `tf.cast`, but using `expect_tf_exception` in every case was hiding the errors. | 08 September 2020, 08:32:53 UTC |
e1340f3 | Benjamin Chetioui | 07 September 2020, 15:12:35 UTC | [jax2tf] Fix missing complex64 TPU corner case of scatter_{add,mul} (#4213) | 07 September 2020, 15:12:35 UTC |
0aed1f4 | Adam Paszke | 07 September 2020, 12:31:58 UTC | Add more context to the axis_frame error message. Some of the vmap and gmap collective tests have been failing on master and I can't seem to be able to reproduce them locally. Hopefully, if this happens again, this extra bit of information will be useful in debugging the problem. | 07 September 2020, 14:25:30 UTC |
4413bb8 | George Necula | 07 September 2020, 14:13:11 UTC | [jax2tf] Do not use jax.random.PRNGKey in the primitive harness before initialization (#4211) We cannot execute JAX functions before the program is initialized | 07 September 2020, 14:13:11 UTC |
be8ea14 | Benjamin Chetioui | 07 September 2020, 13:47:18 UTC | [jax2tf] Expand coverage of primitives by categorize. (#4209) * [jax2tf] Expand coverage of primitives by categorize. This commit adds handling logic for the limitations of: - qr - svd - select_and_gather_add - reduce_window/reduce_window_{min,max,sum} - add - mul - scatter/scatter_{min,max,mul,add} Also fixes a bug in a call to _infer_shape_jax, which wasn't compatible with boolean operands and went undetected due to the high-level handling of TF exceptions in higher-order primitives. | 07 September 2020, 13:47:18 UTC |
1e84cbe | George Necula | 07 September 2020, 11:41:50 UTC | [jax2tf] Fix random.split when jax_enable_x64 (#4208) Since we do the threefry with signed integers when converting to TF, we run into the type promotion 'uint32 - int32 = int64', which then results in lax.shift_right_logical(uint32, int64), which fails. | 07 September 2020, 11:41:50 UTC |
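The promotion described in #4208 can be reproduced with NumPy, whose rules also map a mixed uint32/int32 operation to int64. The sketch below is illustrative only: the operand values are hypothetical stand-ins for threefry intermediates, and "keep every operand uint32" is one way to avoid the promotion, not necessarily the exact change made in the commit.

```python
import numpy as np

# Hypothetical stand-ins for threefry intermediates: an unsigned 32-bit word
# and a rotation amount that has ended up as a signed 32-bit integer.
x = np.array([0xDEADBEEF], dtype=np.uint32)
rot = np.array([13], dtype=np.int32)
bits = np.array([32], dtype=np.uint32)

# Mixing uint32 and int32 promotes to int64: the 'uint32 - int32 = int64'
# case from the commit message. A shift by this amount would no longer be
# a pure uint32 operation.
mixed = bits - rot
print(mixed.dtype)  # int64

# Keeping every operand uint32 avoids the promotion entirely.
shift = bits - rot.astype(np.uint32)
result = np.right_shift(x, shift)
print(shift.dtype, result.dtype)  # uint32 uint32
```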
6c62935 | Benjamin Chetioui | 07 September 2020, 09:03:00 UTC | [jax2tf] Cleanup the correctness stats layout. (#4201) * [jax2tf] Cleanup the correctness stats layout. * Added Google license at the top of the file. * Cleanup: fix docstring for 80 char boundary. * Monkey patch/cleanup outside of the loop. * Removed tensorflow dependency. * Fixed the name of attributes of Limitation. | 07 September 2020, 09:03:00 UTC |
c6e6ee2 | George Necula | 07 September 2020, 08:26:52 UTC | [jax2tf] Use the JAX impl rule for threefry instead of writing our own (#4204) * performance is the same | 07 September 2020, 08:26:52 UTC |
96278e6 | AdrienCorenflos | 04 September 2020, 16:21:43 UTC | Add reverse flag in associative scan (#4181) Add optional 'reverse' argument in associative scan | 04 September 2020, 16:21:43 UTC |
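The `reverse` flag added in #4181 makes the scan run from the end of the leading axis instead of the start. The following is a sequential NumPy reference that shows the assumed semantics (flip the input, scan forward, flip the result back); it is not the parallel associative-scan algorithm JAX uses, and `sequential_scan` is a hypothetical helper name.

```python
import numpy as np

def sequential_scan(fn, xs, reverse=False):
    """Inclusive scan of fn over the leading axis of xs (reference sketch)."""
    xs = np.asarray(xs)
    if reverse:
        # Scan the flipped input forward, then flip the result back.
        return sequential_scan(fn, xs[::-1])[::-1]
    out = [xs[0]]
    for x in xs[1:]:
        out.append(fn(out[-1], x))
    return np.stack(out)

xs = np.array([1, 2, 3, 4])
print(sequential_scan(np.add, xs))                # [ 1  3  6 10]
print(sequential_scan(np.add, xs, reverse=True))  # [10  9  7  4]
```

With `np.add`, the reversed scan is a suffix sum, which matches `np.cumsum(xs[::-1])[::-1]`.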
bcf9777 | Benjamin Chetioui | 03 September 2020, 13:56:22 UTC | [jax2tf] Generator for the documentation of operations with limited support (WIP) (#4193) * [jax2tf] Draft of a generator for the documentation of operations with limited support. | 03 September 2020, 13:56:22 UTC |
abdd138 | George Necula | 03 September 2020, 11:24:04 UTC | [jax2tf] Flip the with_gradient=True; was flipped back by mistake (#4200) | 03 September 2020, 11:24:04 UTC |
5eac477 | George Necula | 03 September 2020, 11:18:35 UTC | [jax2tf] Implementation of random_gamma (#4192) * [jax2tf] implementation of random_gamma The simplest implementation is to convert JAX's own impl_rule, which rewrites gamma into other JAX primitives. On TPU with use_vmap=True the performance is the same for JAX and TF, provided we use tf.function(compile=True). | 03 September 2020, 11:18:35 UTC |
708d07d | Alex Riley | 02 September 2020, 23:13:17 UTC | Add jax.numpy.array_split (#4197) | 02 September 2020, 23:13:17 UTC |
04f9a7e | Matthew Johnson | 02 September 2020, 01:16:20 UTC | better jax.numpy.tile implementation (#4190) Use reshape, broadcast_to, reshape. | 02 September 2020, 01:16:20 UTC |
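The "reshape, broadcast_to, reshape" strategy from #4190 can be sketched in NumPy and checked against `np.tile`. This is an illustration of the idea under the assumption that it mirrors the commit's approach, not a copy of the jax.numpy code; `tile_via_broadcast` is a hypothetical name.

```python
import numpy as np

def tile_via_broadcast(a, reps):
    """Sketch of np.tile built from reshape -> broadcast_to -> reshape."""
    a = np.asarray(a)
    reps = tuple(reps) if np.iterable(reps) else (reps,)
    ndim = max(a.ndim, len(reps))
    # Left-pad reps and the array shape with 1s so both have `ndim` entries.
    reps = (1,) * (ndim - len(reps)) + reps
    shape = (1,) * (ndim - a.ndim) + a.shape
    # Step 1: reshape, inserting a size-1 axis in front of every original axis.
    expanded = a.reshape([d for s in shape for d in (1, s)])
    # Step 2: broadcast each inserted axis up to its repetition count.
    tiled = np.broadcast_to(expanded, [d for r, s in zip(reps, shape) for d in (r, s)])
    # Step 3: reshape, merging each (rep, size) pair into a single axis.
    return tiled.reshape([r * s for r, s in zip(reps, shape)])

x = np.arange(6).reshape(2, 3)
assert np.array_equal(tile_via_broadcast(x, (2, 2)), np.tile(x, (2, 2)))
```

The broadcast step creates a view without copying; only the final reshape materializes the tiled array.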
421550a | Jake Vanderplas | 01 September 2020, 22:48:40 UTC | copysign: promote to inexact to match numpy & support unsigned inputs (#4188) | 01 September 2020, 22:48:40 UTC |
0cdb1f7 | Benjamin Chetioui | 01 September 2020, 07:35:25 UTC | [jax2tf] Indicate the version of TF used in tests in README. (#4185) | 01 September 2020, 07:35:25 UTC |
bdd6545 | Jean-Baptiste Lespiau | 01 September 2020, 07:34:47 UTC | Add more features to the C++ jax.jit. (#4169) This mainly follows https://github.com/google/jax/pull/4089 by adding: - support for disable_jit from C++ - support for jax._cpp_jit on methods. - support for applying @jax.jit on top-level functions, by delaying the retrieval of the device and backend. - concurrency support. I am not aware of any missing feature (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable). See: - https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see cr/328899906 + benchmarks for how numbers were generated) - The results of the Jax tests when enabling this: http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 fails, 92 passes, but many of the failures share a common cause). | 01 September 2020, 07:34:47 UTC |
36368a2 | Jake Vanderplas | 31 August 2020, 21:11:49 UTC | jnp.abs(): support boolean inputs (#4186) | 31 August 2020, 21:11:49 UTC |
44bcf7e | Hamza Merzić | 31 August 2020, 14:00:34 UTC | Fix axis checking and remove extra print statement (#4184) A series of PRs renaming the frame entries have been submitted, one of them introducing a bug when using omnistaging. This PR fixes that and removes a print statement (presumably added for debugging purposes). | 31 August 2020, 14:00:34 UTC |
b6b1f5e | George Necula | 31 August 2020, 07:26:32 UTC | [jax2tf] Turn on with_gradient by default (#4180) As I was writing the demo I realized that it makes more sense for with_gradient to be set to True by default. I have also fixed a bug with tie_in in omnistaging. | 31 August 2020, 07:26:32 UTC |
634c625 | George Necula | 30 August 2020, 09:38:14 UTC | More renaming of master to main in JAX internals (#4179) | 30 August 2020, 09:38:14 UTC |
ffbfadd | Jake Vanderplas | 30 August 2020, 08:36:47 UTC | lax.associative_scan: fix docstring examples (#4172) * lax.associative_scan: fix docstring examples * add verbiage from #3583 | 30 August 2020, 08:36:47 UTC |
6b6789a | Matthew Johnson | 30 August 2020, 08:16:51 UTC | applied simple find+sed for 'master' -> 'main' (#4174) * applied simple find+sed for 'master' -> 'main' * Rename master->main in JAX API and internals (#4178) * Started with #4174 * Renamed Trace.master to Trace.main * Renamed core.new_master and core.new_base_master Co-authored-by: George Necula <gcnecula@gmail.com> | 30 August 2020, 08:16:51 UTC |
1a87fd3 | Benjamin Chetioui | 29 August 2020, 08:24:03 UTC | Implement a proper shape checking rule for gather. (#4166) * Implement a proper shape checking rule for gather. The implementation is based on the corresponding shape inference code in `tensorflow/compiler/xla/service/shape_inference.cc`. The tests added in `tests/lax_test.py` are similarly mirroring the corresponding tests in tensorflow, with slight adaptations for the particular setting of JAX. Fixes google/jax#2826, and in principle fixes google/jax#4154 and google/jax#3905. * Extracted common functions for gather/scatter shape checking rules. | 29 August 2020, 08:24:03 UTC |
a33f4dd | Adam Paszke | 28 August 2020, 18:03:39 UTC | Add support for axis_index inside vmap (#4168) Also, reorganize the code to put all `axis_index` related functions in `lax_parallel.py`, next to all other parallel collectives. | 28 August 2020, 18:03:39 UTC |
1dab791 | Jake Vanderplas | 28 August 2020, 16:07:30 UTC | Avoid calling jnp.sum() on list (#4163) | 28 August 2020, 16:07:30 UTC |
04f9ff7 | Benjamin Chetioui | 28 August 2020, 14:27:32 UTC | Addition of one more conclusive polynomial comparison case. (#4167) * Addition of one more conclusive polynomial comparison case. In the case when the difference between two polynomials is a constant, it is possible to conclusively compare them. This commit adds such a case to masking.Poly.__ge__. * Added a few relevant tests in tests.masking_test.test_Poly_compare. | 28 August 2020, 14:27:32 UTC |
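The rule added in #4167 says that when two polynomials differ only by a constant, comparing them is conclusive. The sketch below uses a hypothetical dict representation (monomial tuple of variable names, with `()` for the constant term, mapped to a coefficient) rather than `masking.Poly` itself, and it handles only the constant-difference case.

```python
from collections import Counter

def poly_ge(p, q):
    """Decide p >= q when p - q is a constant; otherwise report inconclusive.

    p and q are hypothetical dicts mapping a monomial (a tuple of variable
    names, with () for the constant term) to its integer coefficient.
    """
    diff = Counter(p)
    diff.subtract(q)
    # Any surviving non-constant term means the sign of p - q depends on
    # the variables, which this sketch does not try to reason about.
    if any(m != () and c != 0 for m, c in diff.items()):
        raise ValueError("inconclusive comparison")
    # Counter returns 0 for a missing constant term.
    return diff[()] >= 0

# n + 3 >= n + 1 is decidable: the difference is the constant 2.
print(poly_ge({("n",): 1, (): 3}, {("n",): 1, (): 1}))  # True
```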
7210d6f | Adam Paszke | 18 August 2020, 09:14:38 UTC | Add support for binding axis_name in gmap This allows executing collectives over the gmapped axes. This requires some extra manipulation of the gmapped jaxpr, since gmap exposes a single logical axis name, but evaluates the program using multiple "physical" axes. This also fixes some bugs around handling `multiple_returns` in vmap collective implementation. | 28 August 2020, 12:42:01 UTC |
e95d570 | Jean-Baptiste Lespiau | 27 August 2020, 14:02:13 UTC | Add benchmarks specifically for the dispatch time. (#4128) The goal is to distinguish the time it takes for `jitted_f` to return from the time it takes to return and wait for the result. We also add a benchmark that distinguishes the time it takes to call the function with and without the argument transfer. e.g.

name                            time/op
jit_trivial_dispatch            28.9µs ± 2%
jit_trivial                     31.5µs ± 5%
jit_simple_dispatch             60.7µs ± 4%
jit_simple                       129µs ±24%
jit_simple_many_args_disptch     390µs ±19%
jit_simple_many_args             388µs ±16%
jit_dispatch_without_transfer    379µs ± 6%
jit_dispatch_with_transfer       450µs ± 5%
| 27 August 2020, 14:02:13 UTC |
36846e0 | George Necula | 27 August 2020, 09:45:48 UTC | Revert "Delete batching.last. (#4148)" (#4160) This reverts commit 4bf3d6e9cccc5de3834e37affae2012e6e3d3180. This commit fails internal tests. | 27 August 2020, 09:45:48 UTC |
a7faf09 | Benjamin Chetioui | 27 August 2020, 09:24:13 UTC | [jax2tf] Added conversion for scatter*_p primitives. (#4091) * [jax2tf] Added conversion for scatter*_p primitives. Limitations: the conversion works as well as the conversion of the underlying reduction functions (e.g. lax.scatter_max is not properly converted for the int8 dtype, because tf.math.maximum is not defined for int8 tensors); the conversion cannot take advantage of the unique_indices parameter. This does not affect correctness, but may affect performance on certain platforms (as stated in the documentation of lax.scatter). * Put tf.function experimental compile wrapper back on scatter. * Removed unique_indices=True test cases * Remove non-deterministic test cases from the scatter harness. This commit also documents the reasons for ignoring these test cases and potential pitfalls, in case someone needs to perform these tests at a later time. | 27 August 2020, 09:24:13 UTC |
4d7396a | Benjamin Chetioui | 27 August 2020, 09:04:32 UTC | Implement a proper shape checking rule for scatter. (#4144) The implementation is based on the corresponding shape inference code in `tensorflow/compiler/xla/service/shape_inference.cc`. The tests added in `tests/lax_test.py` are similarly mirroring the corresponding tests in tensorflow, with slight adaptations for the particular setting of JAX. | 27 August 2020, 09:04:32 UTC |
80114e5 | Benjamin Chetioui | 27 August 2020, 07:47:19 UTC | Add a boolean to _check_shapelike to accept or reject shapes (#4108) * Add a boolean to _check_shapelike to accept or reject shapes corresponding to arrays of 0 elements. (Fixes google/jax#3972). * Added test for failures referenced in issue 3972. | 27 August 2020, 07:47:19 UTC |
1dc71b2 | Benjamin Chetioui | 27 August 2020, 07:46:32 UTC | [jax2tf] Add testing for add/mul/min/max conversion. (#4142) * [jax2tf] Add testing for add/mul/min/max conversion. Only certain types are supported for each of the operations above. This commit adds previously missing tests to make this explicit. | 27 August 2020, 07:46:32 UTC |
c76b84f | George Necula | 27 August 2020, 07:34:53 UTC | Revert "Increase tolerance for CPU test LaxBackedNumpyTests::testCorrCoef (#4080)" (#4151) This reverts commit 22b92c5122ab5af6f5e4560f9be08f5649ae7653. We revert this change because the LLVM bug that made us relax the test tolerance is now fixed. | 27 August 2020, 07:34:53 UTC |
57f49b6 | George Necula | 27 August 2020, 07:05:24 UTC | Fix bug in omnistaging_enabler (#4159) This code was failing with "KeyError: psum" for the tests "//third_party/py/flax/...". I suspect that the error is due to the ordering of the omnistaging enablers, changed in #4152. I am not sure of this fix, but this seemed to be enough for all the presubmit tests to pass and allow the copybara import. | 27 August 2020, 07:05:24 UTC |
417c9ff | George Necula | 27 August 2020, 06:41:16 UTC | Fix pytype error (#4158) | 27 August 2020, 06:41:16 UTC |
29073be | Jake Vanderplas | 27 August 2020, 04:13:33 UTC | cleanup: remove duplicate line (#4156) | 27 August 2020, 04:13:33 UTC |
f0fb7d0 | Tom Hennigan | 26 August 2020, 21:06:27 UTC | Use omnistaging env var even when not using absl flags for config. (#4152) | 26 August 2020, 21:06:27 UTC |
1d93991 | Matthew Johnson | 26 August 2020, 17:21:56 UTC | allow random.choice to accept ndarray input (#4145) * allow random.choice to accept ndarray `a` follow-up to #4137 to allow ndarray inputs to be passed * add jax.random.choice tests to cover ndarray input * don't use callables in test params it can mess with pytest-xdist because of hashing by id | 26 August 2020, 17:21:56 UTC |
01319fb | Peter Hawkins | 25 August 2020, 17:05:06 UTC | Speed up and clean up geomspace test. (#4149) * Speed up and clean up geomspace test. | 25 August 2020, 17:05:06 UTC |
4bf3d6e | Peter Hawkins | 25 August 2020, 16:53:18 UTC | Delete batching.last. (#4148) A -1 axis works just as well at head. | 25 August 2020, 16:53:18 UTC |
8c8060e | Peter Hawkins | 25 August 2020, 16:53:02 UTC | Remove workaround for illegal vmap out_axes. (#4147) | 25 August 2020, 16:53:02 UTC |
6d54eb5 | Jake Vanderplas | 25 August 2020, 12:47:43 UTC | Do not call asarray() on inputs of jax.random.choice (#4137) | 25 August 2020, 12:47:43 UTC |
f959219 | Jean-Baptiste Lespiau | 25 August 2020, 12:39:45 UTC | Rename collectives into "collective operations" for the pmap function. (#4136) The pmap function serves as the entry point, and the term "collective operations" leads to good Google results, such as https://en.wikipedia.org/wiki/Collective_operation, while "collectives" does not. | 25 August 2020, 12:39:45 UTC |
f4b05bc | Matthew Johnson | 25 August 2020, 12:38:41 UTC | make pe.abstract_eval_fun use omnistaging (#4139) | 25 August 2020, 12:38:41 UTC |
04173b3 | Matthew Johnson | 25 August 2020, 12:38:20 UTC | Merge pull request #4140 from sharadmv/patch-2 Remove frame check assertion in `extend_axis_env`. | 25 August 2020, 12:38:20 UTC |
774b5f6 | Sharad Vikram | 25 August 2020, 04:08:23 UTC | Remove frame check assertion in `extend_axis_env`. | 25 August 2020, 04:13:30 UTC |
e06a6ab | Peter Hawkins | 25 August 2020, 00:21:19 UTC | Add support for negative axes to vmap. (#4111) * Add support for negative axes to vmap. * Add workaround for out-of-range vmap axes. | 25 August 2020, 00:21:19 UTC |
603f0c1 | Philipp Thölke | 24 August 2020, 20:50:44 UTC | Fix scan carry types in gradient of complex ODE (#4130) * Cast t_bar from potential complex to float in ode.py * Add test case for complex odeint (currently failing) * Wrap odeint into complex-to-real function in test case * fixup Co-authored-by: Stephan Hoyer <shoyer@google.com> | 24 August 2020, 20:50:44 UTC |
0cc3802 | Benjamin Chetioui | 24 August 2020, 19:29:22 UTC | Fix documentation of scatter_* operations. (#4138) * Fix documentation of scatter_* operations. This commit changes the documentation of the `unique_indices` parameter to scatter to better capture its intended meaning in XLA. | 24 August 2020, 19:29:22 UTC |
e5c4ccb | Matthew Johnson | 22 August 2020, 03:36:02 UTC | Merge pull request #4125 from google/issue4124 make random.choice error when shape isn't sequence | 22 August 2020, 03:36:02 UTC |
56b3688 | Matthew Johnson | 22 August 2020, 02:58:06 UTC | make random.choice error when shape isn't sequence fixes #4124 | 22 August 2020, 02:58:06 UTC |
6bed4ee | Jean-Baptiste Lespiau | 22 August 2020, 01:44:52 UTC | Temporarily disable jax_jit tests. (#4118) | 22 August 2020, 01:44:52 UTC |
7082105 | Matthew Johnson | 22 August 2020, 01:35:30 UTC | Merge pull request #4123 from google/only-one-axis-index-primitive only construct one axis_index_p primitive | 22 August 2020, 01:35:30 UTC |
a62580c | Matthew Johnson | 22 August 2020, 00:56:59 UTC | deflake | 22 August 2020, 00:56:59 UTC |
66a02b6 | Matthew Johnson | 22 August 2020, 00:43:15 UTC | only construct one axis_index_p primitive Before this change, there were two versions, one used with omnistaging and one without. But that made bookkeeping hard and buggy. This change defines the axis_index_p primitive in core.py. Some of its rules are still changed when omnistaging is enabled. | 22 August 2020, 00:43:15 UTC |
7e77af4 | Matthew Johnson | 21 August 2020, 19:34:31 UTC | don't force backend creation in xla_computation (#4121) | 21 August 2020, 19:34:31 UTC |
519d57c | Matthew Johnson | 21 August 2020, 19:14:45 UTC | fix bugs | 21 August 2020, 19:14:45 UTC |
9d733dd | Wojciech Rzadkowski | 21 August 2020, 18:19:51 UTC | Doc: change suggested way of starting the profiler (#4120) | 21 August 2020, 18:19:51 UTC |
b2a239c | Matthew Johnson | 21 August 2020, 18:10:53 UTC | don't force backend creation in xla_computation | 21 August 2020, 18:10:53 UTC |
3063145 | Matthew Johnson | 20 August 2020, 21:44:26 UTC | use xla.backend_compile function in pxla.py (#4113) * use xla.backend_compile function in pxla.py Not only is this useful for profiling, it also helps us do google-internal logging for the XLA team. | 20 August 2020, 21:44:26 UTC |
1e6b809 | Benjamin Chetioui | 20 August 2020, 18:45:15 UTC | Fixes padding generation for padding == 'SAME' in reduce_window to (#4110) * Fixes padding generation for padding == 'SAME' in reduce_window to take window_dilation into account. (Fixes google/jax#3973). This commit applies the fix suggested by James on the issue, which is backed by the meaning of padding described on https://www.tensorflow.org/xla/operation_semantics#reducewindow. * Added shape tests for reduce_window when stride is 1 in each direction and padding is 'SAME'. | 20 August 2020, 18:45:15 UTC |
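For 'SAME' padding, the fix in #4110 accounts for window_dilation by using the dilated (effective) window size. A minimal sketch of that computation, assuming the usual convention that the output size is `ceil(in_size / stride)` and that any odd padding goes to the high side; `same_padding` is a hypothetical helper, not the lax function.

```python
import math

def same_padding(in_size, window, stride=1, dilation=1):
    """Return (pad_low, pad_high) for 'SAME' padding along one dimension."""
    # A dilated window spans (window - 1) * dilation + 1 input positions.
    effective_window = (window - 1) * dilation + 1
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + effective_window - in_size, 0)
    return pad_total // 2, pad_total - pad_total // 2

print(same_padding(5, 3))              # (1, 1)
print(same_padding(5, 3, dilation=2))  # (2, 2)
```

Ignoring the dilation would pad only (1, 1) in the second case, so the dilated window would no longer fit and the output would shrink below the input size at stride 1.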
1e8ac24 | Mihaela Rosca | 20 August 2020, 14:46:55 UTC | Add rademacher, maxwell, double_sided_maxwell and weibull_min to jax.random. (#4104) | 20 August 2020, 14:46:55 UTC |
d978808 | Stephan Hoyer | 20 August 2020, 07:36:35 UTC | Document the required form of tap_func for host_callback.id_tap (#4100) | 20 August 2020, 07:36:35 UTC |
22b593b | Benjamin Chetioui | 20 August 2020, 07:05:30 UTC | [jax2tf] General conversion of reduce_window. (#4093) * [jax2tf] General conversion of reduce_window. Much like scatter_p, the conversion works as well as the underlying reduction function (e.g. reduce_window_max is not properly converted for the int8 dtype, because tf.math.maximum is not defined for int8 tensors). | 20 August 2020, 07:05:30 UTC |
9ba282f | Matthew Johnson | 19 August 2020, 22:51:40 UTC | add axis_index to supported multi-host collectives (#4107) also make the error message less confusing | 19 August 2020, 22:51:40 UTC |
b4efb31 | George Thomas | 19 August 2020, 18:36:28 UTC | Docs: Fix broken link in quickstart (#4102) | 19 August 2020, 18:36:28 UTC |
9ca1020 | Jean-Baptiste Lespiau | 19 August 2020, 16:39:25 UTC | Add a fast C++ jit codepath. (#4089) This starts a C++ jit codepath to speed up dispatch time. Tracing is not supported yet. Supported features: - scalar, numpy array and DeviceArray argument support: - integer, float, boolean, and complex scalar arguments are supported. - The jax_enable_x64 flag will be used at object-creation time to cast scalars and numpy arrays. - The Jax `weak_type` attribute for arguments is supported (DeviceArray and scalars). - The donate_argnums argument. - Use an XLA tuple for more than 100 arguments Unsupported features: - jax._cpp_jit on methods, e.g. `@functools.partial(jax.jit, static_argnums=0) def _compute_log_data(self, ...) ...` This is currently not supported by the C++ codepath, because "self" won't be automatically added. - disable_jit. | 19 August 2020, 16:39:25 UTC |
5135fd1 | Roy Frostig | 19 August 2020, 14:48:25 UTC | fix jaxpr util test under enable_x64 | 19 August 2020, 15:28:56 UTC |
b892236 | Matthew Johnson | 19 August 2020, 04:04:14 UTC | remove check for TypedJaxpr literals aren't tracers (#4096) In the original usage of TypedJaxpr, literals could not be tracers because they were only produced by initial-style transformations of jaxprs. But now TypedJaxpr is used in several other ways, e.g. in make_jaxpr, and moreover its avals are redundant. It should probably be renamed ClosedJaxpr since it mainly serves to package a jaxpr together with its constant arrays. This check was limiting the utility of TypedJaxpr, and it was only added relatively recently anyway. | 19 August 2020, 04:04:14 UTC |
8e166ad | Peter Hawkins | 19 August 2020, 01:24:41 UTC | Unbreak jaxlib build. (#4098) | 19 August 2020, 01:24:41 UTC |
8cc9579 | Roy Frostig | 19 August 2020, 00:03:43 UTC | check path prefixes using os.path instead of string comparisons | 19 August 2020, 01:08:08 UTC |
d778a6d | Roy Frostig | 21 July 2020, 02:04:43 UTC | move experimental.jaxpr_stats to jaxpr_util | 19 August 2020, 01:07:38 UTC |
908d54a | Roy Frostig | 22 June 2020, 23:01:03 UTC | utilities to collect summary statistics of jaxprs | 19 August 2020, 01:07:38 UTC |
d70976c | Jake Vanderplas | 18 August 2020, 23:31:54 UTC | Cleanup: reduce redundant code (#4095) | 18 August 2020, 23:31:54 UTC |
afeefa6 | Alex Alemi | 18 August 2020, 22:49:29 UTC | Add typing and namedtuple to `optimizers.py`, improve documentation. (#3570) | 18 August 2020, 22:49:29 UTC |
29f7fa7 | gaurav pathak | 18 August 2020, 20:40:45 UTC | Add implementation of jax.numpy.trim_zeros (#4027) | 18 August 2020, 20:40:45 UTC |
29aa9bf | Jake Vanderplas | 18 August 2020, 17:17:38 UTC | Cleanup: avoid jnp.prod & np.prod on array shapes (#4086) | 18 August 2020, 17:17:38 UTC |
decd760 | Stephan Hoyer | 18 August 2020, 16:40:57 UTC | Add experimental __array_module__ method (#4076) * Add experimental __array_module__ method xref https://github.com/google/jax/issues/1565 `__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html)) is an experimental alternative to `__array_function__` and `__array_ufunc__` for "duck array" compatibility with NumPy that promises to be much less invasive. Example usage:

```python
import numpy as np

def duckarray_stack(arrays):
    """This "stack" function should work with any array library, including JAX."""
    # np.get_array_module is the NEP 37 proposal; it is not in stock NumPy
    # and needs numpy-dispatch to be tested.
    npx = np.get_array_module(*arrays)
    arrays = [npx.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
    return npx.concatenate(expanded_arrays, axis=0)
```

Support for this protocol has *not* yet been implemented in NumPy, but it can be tested with https://github.com/seberg/numpy-dispatch. My reasoning for merging it into JAX (on an experimental basis with no guarantees, of course) is that: 1. It's not invasive -- the implementation is small and self-contained. 2. No backwards compatibility issues. Unlike `__array_function__` and `__array_ufunc__`, `__array_module__` will always require an explicit opt-in by libraries that use it by calling `get_array_module()`. 3. Other NumPy developers [want evidence](https://github.com/numpy/numpy/pull/16935#issuecomment-673951287) that this is actually feasible. 4. Scikit-Learn developers like @thomasjpfan are interested in exploring supporting scikit-learn on top of NumPy-like libraries like JAX, and experimental support for this protocol will make that easier. Note: this PR does add `numpy-dispatch` as an optional testing requirement in order to verify that this works. If desired, we could remove this from CI, but installing numpy-dispatch (and its build requirement Cython) appears to add only a few seconds of build time. * don't explicitly list cython * remove UnshapedArray from _JAX_ARRAY_TYPES * Remove incorrect note about metaclasses * remove unnecessary numpy_dispatch.ensure_dispatching() | 18 August 2020, 16:40:57 UTC |
36f3a36 | Adam Paszke | 18 August 2020, 10:02:28 UTC | Separate axis splitting from collective handling (#4082) This makes the vmap collective handling a bit more flexible and allowed me to add ppermute support. | 18 August 2020, 10:02:28 UTC |
ace23fa | Benjamin Chetioui | 18 August 2020, 09:01:13 UTC | [jax2tf] Added tests for reduce_window translation (#4062) * [jax2tf] Added tests for reduce_window translation (WIP) * Added other non-floating types to the tests. | 18 August 2020, 09:01:13 UTC |
8c2ee37 | Jean-Baptiste Lespiau | 18 August 2020, 08:43:52 UTC | Prior refactoring before the C++ jax.jit. (#4045) | 18 August 2020, 08:43:52 UTC |
2ab6b42 | Jean-Baptiste Lespiau | 18 August 2020, 05:58:43 UTC | Use pytree defined in tensorflow. (#4087) It also adds some tests on the scalar C++ conversion. | 18 August 2020, 05:58:43 UTC |
fe69d3c | Roy Frostig | 15 August 2020, 01:52:55 UTC | always deref all locals that indirectly reach stack frames in the exception-reraise handler | 18 August 2020, 01:13:58 UTC |
dbca9e6 | Roy Frostig | 14 August 2020, 20:22:20 UTC | unrevert #3674 (revert #3791) | 18 August 2020, 01:13:58 UTC |
1ba4e06 | Adam Paszke | 17 August 2020, 18:11:43 UTC | Initial version of gmap (#4006) Co-authored-by: Matthew Johnson <mattjj@google.com> | 17 August 2020, 18:11:43 UTC |
4c22e01 | Benjamin Chetioui | 17 August 2020, 14:32:34 UTC | [jax2tf] Explicitly raise an error when attempting to convert _select_and_scatter_add_p. (#4084) | 17 August 2020, 14:32:34 UTC |
ec90c35 | Benjamin Chetioui | 17 August 2020, 10:57:41 UTC | [jax2tf] Fix bfloat16 bug in select_and_gather_add conversion. (#4058) * [jax2tf] Fix bfloat16 bug in select_and_gather_add conversion. This fix makes it possible to run bfloat16 tests for the jax2tf conversion of select_and_gather_add. | 17 August 2020, 10:57:41 UTC |
c7aff1d | George Necula | 17 August 2020, 09:53:18 UTC | Revert "Use pytree from xla_client. (#4063)" (#4081) This reverts commit d8de6b61411179dcd2f63d7639bbcd69b30ac15f. Trying to revert because it seems that this produces test failures in Google. | 17 August 2020, 09:53:18 UTC |
22b92c5 | George Necula | 17 August 2020, 06:50:47 UTC | Increase tolerance for CPU test LaxBackedNumpyTests::testCorrCoef (#4080) | 17 August 2020, 06:50:47 UTC |
9120701 | Matthew Johnson | 17 August 2020, 03:00:40 UTC | allow xla_computation to psum a constant (#4078) * allow xla_computation to psum a constant * allow axis_env to be None | 17 August 2020, 03:00:40 UTC |
8232f2d | Matthew Johnson | 16 August 2020, 05:55:18 UTC | adapt _TempAxisName for unhashable objs (#4077) adapt _TempAxisName for unhashable objs | 16 August 2020, 05:55:18 UTC |
16ab9cb | James Bradbury | 15 August 2020, 22:07:08 UTC | support multi-host pmap with omnistaging (#4075) | 15 August 2020, 22:07:08 UTC |
1316562 | Philipp Thölke | 15 August 2020, 15:47:28 UTC | Canonicalize result dtype to fix double precision problem in ldexp (#4069) | 15 August 2020, 15:47:28 UTC |
1dbdaac | George Necula | 15 August 2020, 05:55:02 UTC | [jax2tf] avoid import errors when omnistaging is enabled (#4072) * [jax2tf] avoid import errors when omnistaging is enabled | 15 August 2020, 05:55:02 UTC |
9ab07d8 | James Bradbury | 15 August 2020, 05:54:36 UTC | support axis_index_groups in psum(const) (#4070) * support axis_index_groups in psum(const) * add test for psum(constant, axis_index_groups) * rm trailing whitespace * Update lax_parallel.py | 15 August 2020, 05:54:36 UTC |
394a33c | Ryan Sepassi | 14 August 2020, 20:05:58 UTC | Add `in_parts` and `out_parts` optional arguments `jax.xla_computation`. (#4055) PR #3771 redux (reverted in #3780) Co-authored-by: Skye Wanderman-Milne <skyewm@google.com> | 14 August 2020, 20:05:58 UTC |