https://github.com/google/jax

Revision Author Date Message Commit Date
7f3078b update version and changelog for pypi (#4224) 08 September 2020, 15:54:13 UTC
ed0d8c0 tweak lax.py shape broadcasting logic (#4217) This new implementation is faster, and works for polymorphic shapes without weird tricks. (This new implementation is faster even if we remove the weird tricks for polymorphism.) 08 September 2020, 15:27:41 UTC
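For readers unfamiliar with the rule this commit re-implements, here is a minimal pure-Python sketch of NumPy-style shape broadcasting. It is not the code from lax.py. The function name `broadcast_shapes_sketch` is hypothetical, and it handles only concrete integer shapes, not the polymorphic shapes mentioned above.

```python
def broadcast_shapes_sketch(*shapes):
    """Return the NumPy-style broadcast of the given shapes (tuples of ints)."""
    ndim = max(len(s) for s in shapes)
    # Left-pad every shape with 1s so that all shapes have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for sizes in zip(*padded):
        # Along each axis, all sizes other than 1 must agree.
        non_one = {d for d in sizes if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible shapes for broadcasting: {shapes}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)

print(broadcast_shapes_sketch((3, 1), (1, 4)))  # (3, 4)
```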
798a264 [jax2tf] Fix bug in population count and move expect_tf_exception (#4214) into correctness stats. The code was using `tf.bitcast` instead of `tf.cast`, but using `expect_tf_exception` in every case was hiding the errors. 08 September 2020, 08:32:53 UTC
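The distinction behind this bug is that a bitcast reinterprets the raw bits of a value as another dtype, whereas a cast converts the numeric value. The sketch below uses NumPy's `ndarray.view` and `ndarray.astype` as stand-ins for `tf.bitcast` and `tf.cast` (an analogy, not the jax2tf code) and shows how the two lead to different population counts for the same input.

```python
import numpy as np

x = np.array([1.0], dtype=np.float32)

# Bitcast-like: reinterpret the 32 bits of 1.0f (0x3F800000) as an int32.
bits = x.view(np.int32)
# Cast-like: convert the numeric value 1.0 to the integer 1.
value = x.astype(np.int32)

def popcount(n: int) -> int:
    """Number of set bits in a non-negative Python int."""
    return bin(n).count("1")

print(int(bits[0]), popcount(int(bits[0])))    # 1065353216 7
print(int(value[0]), popcount(int(value[0])))  # 1 1
```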
e1340f3 [jax2tf] Fix missing complex64 TPU corner case of scatter_{add,mul} (#4213) 07 September 2020, 15:12:35 UTC
0aed1f4 Add more context to the axis_frame error message. Some of the vmap and gmap collective tests have been failing on master and I can't seem to reproduce them locally. Hopefully, if this happens again, this extra bit of information will be useful in debugging the problem. 07 September 2020, 14:25:30 UTC
4413bb8 [jax2tf] Do not use jax.random.PRNGKey before in primitive harness (#4211) We cannot execute JAX functions before the program is initialized 07 September 2020, 14:13:11 UTC
be8ea14 [jax2tf] Expand coverage of primitives by categorize. (#4209) * [jax2tf] Expand coverage of primitives by categorize. This commit adds handling logic for the limitations of: - qr - svd - select_and_gather_add - reduce_window/reduce_window_{min,max,sum} - add - mul - scatter/scatter_{min,max,mul,add} Also fixes a bug in a call to _infer_shape_jax, which wasn't compatible with boolean operands and went undetected due to the high-level handling of TF exceptions in higher-order primitives. 07 September 2020, 13:47:18 UTC
1e84cbe [jax2tf] Fix random.split when jax_enable_x64 (#4208) Since we do the threefry with signed integers when converting to TF, we run into the type promotion 'uint32 - int32 = int64', which then results in lax.shift_right_logical(uint32, int64), which fails. 07 September 2020, 11:41:50 UTC
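The promotion mentioned here can be reproduced with NumPy, whose rule for mixing `uint32` and `int32` also yields `int64`. A minimal sketch, using NumPy as a stand-in for JAX's x64 promotion: casting the shift amount to the operand's own dtype keeps a logical right shift in `uint32`. This illustrates the pitfall; it is not the jax2tf fix itself.

```python
import numpy as np

# Mixing unsigned and signed 32-bit integers promotes to a 64-bit signed type.
print(np.result_type(np.uint32, np.int32))  # int64

x = np.array([0xFFFFFFF0], dtype=np.uint32)
shift = np.array([4], dtype=np.int32)

# Casting the shift amount to x's dtype keeps the whole computation in uint32,
# so right_shift on an unsigned type is a logical shift (zeros shifted in).
y = np.right_shift(x, shift.astype(x.dtype))
print(y.dtype, hex(int(y[0])))  # uint32 0xfffffff
```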
6c62935 [jax2tf] Cleanup the correctness stats layout. (#4201) * [jax2tf] Cleanup the correctness stats layout. * Added Google license at the top of the file. * Cleanup: fix docstring for 80 char boundary. * Monkey patch/cleanup outside of the loop. * Removed tensorflow dependency. * Fixed the name of attributes of Limitation. 07 September 2020, 09:03:00 UTC
c6e6ee2 [jax2tf] Use the JAX impl rule for threefry instead of writing our own (#4204) * performance is the same 07 September 2020, 08:26:52 UTC
96278e6 Add reverse flag in associative scan (#4181) Add optional 'reverse' argument in associative scan 04 September 2020, 16:21:43 UTC
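To pin down what a `reverse` argument means for a scan, here is a sequential pure-Python reference for an inclusive scan. It assumes, as a sketch, that `reverse=True` behaves like flipping the input along the scanned axis, scanning, and flipping the result back. It does not reproduce the parallel algorithm that `lax.associative_scan` uses, and `scan_reference` is a hypothetical name.

```python
import operator
from itertools import accumulate

def scan_reference(fn, elems, reverse=False):
    """Inclusive scan of a list with a binary associative function fn."""
    if reverse:
        # Scan from the end: flip, scan forward, flip back.
        return scan_reference(fn, elems[::-1])[::-1]
    # out[i] = fn(... fn(fn(elems[0], elems[1]), elems[2]) ..., elems[i])
    return list(accumulate(elems, fn))

print(scan_reference(operator.add, [1, 2, 3, 4]))                # [1, 3, 6, 10]
print(scan_reference(operator.add, [1, 2, 3, 4], reverse=True))  # [10, 9, 7, 4]
```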
bcf9777 [jax2tf] Generator for the documentation of operations with limited support (WIP) (#4193) * [jax2tf] Draft of a generator for the documentation of operations with limited support. 03 September 2020, 13:56:22 UTC
abdd138 [jax2tf] Flip the with_gradient=True; was flipped back by mistake (#4200) 03 September 2020, 11:24:04 UTC
5eac477 [jax2tf] Implementation of random_gamma (#4192) * [jax2tf] implementation of random_gamma The simplest implementation is by converting JAX's own impl_rule, which rewrites gamma into other JAX primitives. On TPU with use_vmap=True the performance is the same for JAX and TF, provided we use tf.function(compile=True). 03 September 2020, 11:18:35 UTC
708d07d Add jax.numpy.array_split (#4197) 02 September 2020, 23:13:17 UTC
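`numpy.array_split`, the NumPy function that `jax.numpy.array_split` is named after, differs from `split` in accepting a section count that does not divide the length evenly: for length l and n sections, the first l % n pieces get l // n + 1 elements and the rest get l // n. A small sketch of that sizing rule, checked against NumPy (the helper name is hypothetical):

```python
import numpy as np

def array_split_sizes(length, sections):
    """Sizes of the pieces produced by array_split for an integer section count."""
    q, r = divmod(length, sections)
    # The first r pieces absorb the remainder, one extra element each.
    return [q + 1] * r + [q] * (sections - r)

x = np.arange(7)
print(array_split_sizes(7, 3))                 # [3, 2, 2]
print([len(p) for p in np.array_split(x, 3)])  # [3, 2, 2]
```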
04f9a7e better jax.numpy.tile implementation (#4190) Use reshape, broadcast_to, reshape. 02 September 2020, 01:16:20 UTC
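The reshape / broadcast_to / reshape approach can be sketched in NumPy: insert a size-1 axis in front of every original axis, broadcast each new axis to its repetition count, then merge each (repetition, size) pair into a single axis. This illustrates the idea and is not the code from the commit; `tile_via_broadcast` is a hypothetical name, and its result is checked against `np.tile`.

```python
import numpy as np

def tile_via_broadcast(a, reps):
    """np.tile-like repetition built from reshape, broadcast_to, and reshape."""
    a = np.asarray(a)
    reps = tuple(reps)
    # Align ranks the way np.tile does: prepend 1s to whichever is shorter.
    ndim = max(a.ndim, len(reps))
    a = a.reshape((1,) * (ndim - a.ndim) + a.shape)
    reps = (1,) * (ndim - len(reps)) + reps
    # 1) Put a size-1 axis in front of every original axis: (1, d0, 1, d1, ...).
    expanded = a.reshape(tuple(s for d in a.shape for s in (1, d)))
    # 2) Broadcast each size-1 axis to its repetition count: (r0, d0, r1, d1, ...).
    target = tuple(s for r, d in zip(reps, a.shape) for s in (r, d))
    repeated = np.broadcast_to(expanded, target)
    # 3) Merge each (r_i, d_i) pair into one axis of size r_i * d_i.
    return repeated.reshape(tuple(r * d for r, d in zip(reps, a.shape)))

print(tile_via_broadcast([1, 2], (2, 2)).tolist())  # [[1, 2, 1, 2], [1, 2, 1, 2]]
```

Because `broadcast_to` returns a view, no data is repeated until the final reshape materializes the result.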
421550a copysign: promote to inexact to match numpy & support unsigned inputs (#4188) 01 September 2020, 22:48:40 UTC
0cdb1f7 [jax2tf] Indicate the version of TF used in tests in README. (#4185) 01 September 2020, 07:35:25 UTC
bdd6545 Add more features to the C++ jax.jit. (#4169) This mainly follows https://github.com/google/jax/pull/4089 by adding: - support for disable_jit from C++ - support for jax._cpp_jit on methods. - supporting applying @jax.jit on top-level functions, by delaying the retrieval of the device and backend. - concurrency support. I am not aware of any missing feature (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable.) See: - https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see cr/328899906 + benchmarks for how numbers were generated) - The results of the Jax tests when enabling this: http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 fails, 92 passes, but many of the failures share a common cause). 01 September 2020, 07:34:47 UTC
36368a2 jnp.abs(): support boolean inputs (#4186) 31 August 2020, 21:11:49 UTC
44bcf7e Fix axis checking and remove extra print statement (#4184) A series of PRs renaming the frame entries have been submitted, one of them introducing a bug when using omnistaging. This PR fixes that and removes a print statement (presumably added for debugging purposes). 31 August 2020, 14:00:34 UTC
b6b1f5e [jax2tf] Turn on with_gradient by default (#4180) As I was writing the demo I realized that it makes more sense for with_gradient to be set to True by default. I have also fixed a bug with tie_in in omnistaging. 31 August 2020, 07:26:32 UTC
634c625 More renaming of master to main in JAX internals (#4179) 30 August 2020, 09:38:14 UTC
ffbfadd lax.associative_scan: fix docstring examples (#4172) * lax.associative_scan: fix docstring examples * add verbiage from #3583 30 August 2020, 08:36:47 UTC
6b6789a applied simple find+sed for 'master' -> 'main' (#4174) * applied simple find+sed for 'master' -> 'main' * Rename master->main in JAX API and internals (#4178) * Started with #4174 * Renamed Trace.master to Trace.main * Renamed core.new_master and core.new_base_master Co-authored-by: George Necula <gcnecula@gmail.com> 30 August 2020, 08:16:51 UTC
1a87fd3 Implement a proper shape checking rule for gather. (#4166) * Implement a proper shape checking rule for gather. The implementation is based on the corresponding shape inference code in `tensorflow/compiler/xla/service/shape_inference.cc`. The tests added in `tests/lax_test.py` are similarly mirroring the corresponding tests in tensorflow, with slight adaptations for the particular setting of JAX. Fixes google/jax#2826, and in principle fixes google/jax#4154 and google/jax#3905. * Extracted common functions for gather/scatter shape checking rules. 29 August 2020, 08:24:03 UTC
a33f4dd Add support for axis_index inside vmap (#4168) Also, reorganize the code to put all `axis_index` related functions in `lax_parallel.py`, next to all other parallel collectives. 28 August 2020, 18:03:39 UTC
1dab791 Avoid calling jnp.sum() on list (#4163) 28 August 2020, 16:07:30 UTC
04f9ff7 Addition of one more conclusive polynomial comparison case. (#4167) * Addition of one more conclusive polynomial comparison case. In the case when the difference between two polynomials is a constant, it is possible to conclusively compare them. This commit adds such a case to masking.Poly.__ge__. * Added a few relevant tests in tests.masking_test.test_Poly_compare. 28 August 2020, 14:27:32 UTC
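The idea that a constant difference makes the comparison decidable can be shown with a toy polynomial representation: a dict mapping monomials (tuples of variable names, with `()` for the constant term) to integer coefficients. This is a hypothetical sketch, not `masking.Poly`, and it covers only the constant-difference case added here; every other case is reported as inconclusive, even though a fuller implementation may decide more of them.

```python
def poly_sub(p, q):
    """Coefficient-wise p - q, dropping monomials whose coefficient becomes zero."""
    diff = dict(p)
    for mono, coeff in q.items():
        diff[mono] = diff.get(mono, 0) - coeff
    return {mono: c for mono, c in diff.items() if c != 0}

def poly_ge(p, q):
    """Decide p >= q when p - q is a constant; otherwise raise ValueError."""
    diff = poly_sub(p, q)
    if all(mono == () for mono in diff):
        # Only the constant term (or nothing) is left, so the sign is known.
        return diff.get((), 0) >= 0
    raise ValueError("inconclusive comparison")

n_plus_3 = {("n",): 1, (): 3}
n_plus_1 = {("n",): 1, (): 1}
print(poly_ge(n_plus_3, n_plus_1))  # True
print(poly_ge(n_plus_1, n_plus_3))  # False
```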
7210d6f Add support for binding axis_name in gmap This allows executing collectives over the gmapped axes. This requires some extra manipulation of the gmapped jaxpr, since gmap exposes a single logical axis name, but evaluates the program using multiple "physical" axes. This also fixes some bugs around handling `multiple_returns` in vmap collective implementation. 28 August 2020, 12:42:01 UTC
e95d570 Add benchmarks specifically for the dispatch time. (#4128) The goal is to distinguish the time it takes for `jitted_f` to return from the time it takes to return and wait for the result. We also add benchmarks that distinguish the time it takes to call the function with and without the argument transfer. Example results:

name                            time/op
jit_trivial_dispatch            28.9µs ± 2%
jit_trivial                     31.5µs ± 5%
jit_simple_dispatch             60.7µs ± 4%
jit_simple                       129µs ±24%
jit_simple_many_args_disptch     390µs ±19%
jit_simple_many_args             388µs ±16%
jit_dispatch_without_transfer    379µs ± 6%
jit_dispatch_with_transfer       450µs ± 5%

27 August 2020, 14:02:13 UTC
36846e0 Revert "Delete batching.last. (#4148)" (#4160) This reverts commit 4bf3d6e9cccc5de3834e37affae2012e6e3d3180. This commit fails internal tests. 27 August 2020, 09:45:48 UTC
a7faf09 [jax2tf] Added conversion for scatter*_p primitives. (#4091) * [jax2tf] Added conversion for scatter*_p primitives. Limitations: the conversion works as well as the conversion of the underlying reduction functions (e.g. lax.scatter_max is not properly converted for the int8 dtype, because tf.math.maximum is not defined for int8 tensors); the conversion cannot take advantage of the unique_indices parameter. This does not affect correctness, but may affect performance on certain platforms (as stated in the documentation of lax.scatter). * Put tf.function experimental compile wrapper back on scatter. * Removed unique_indices=True test cases * Remove non-deterministic test cases from the scatter harness. This commit also documents the reasons for ignoring these test cases and potential pitfalls, in case someone needs to perform these tests at a later time. 27 August 2020, 09:24:13 UTC
4d7396a Implement a proper shape checking rule for scatter. (#4144) The implementation is based on the corresponding shape inference code in `tensorflow/compiler/xla/service/shape_inference.cc`. The tests added in `tests/lax_test.py` are similarly mirroring the corresponding tests in tensorflow, with slight adaptations for the particular setting of JAX. 27 August 2020, 09:04:32 UTC
80114e5 Add a boolean to _check_shapelike to accept or reject shapes (#4108) * Add a boolean to _check_shapelike to accept or reject shapes corresponding to arrays of 0 elements. (Fixes google/jax#3972). * Added test for failures referenced in issue 3972. 27 August 2020, 07:47:19 UTC
1dc71b2 [jax2tf] Add testing for add/mul/min/max conversion. (#4142) * [jax2tf] Add testing for add/mul/min/max conversion. Only certain types are supported for each of the operations above. This commit adds previously missing tests to make this explicit. 27 August 2020, 07:46:32 UTC
c76b84f Revert "Increase tolerance for CPU test LaxBackedNumpyTests::testCorrCoef (#4080)" (#4151) This reverts commit 22b92c5122ab5af6f5e4560f9be08f5649ae7653. We revert this change because the LLVM bug that made us relax the test tolerance is now fixed. 27 August 2020, 07:34:53 UTC
57f49b6 Fix bug in omnistaging_enabler (#4159) This code was failing with "KeyError: psum" for the tests "//third_party/py/flax/...". I suspect that the error is due to the ordering of the omnistaging enablers, changed in #4152. I am not sure of this fix, but this seemed to be enough for all the presubmit tests to pass and allow the copybara import. 27 August 2020, 07:05:24 UTC
417c9ff Fix pytype error (#4158) 27 August 2020, 06:41:16 UTC
29073be cleanup: remove duplicate line (#4156) 27 August 2020, 04:13:33 UTC
f0fb7d0 Use omnistaging env var even when not using absl flags for config. (#4152) 26 August 2020, 21:06:27 UTC
1d93991 allow random.choice to accept ndarray input (#4145) * allow random.choice to accept ndarray `a` follow-up to #4137 to allow ndarray inputs to be passed * add jax.random.choice tests to cover ndarray input * don't use callables in test params it can mess with pytest-xdist because of hashing by id 26 August 2020, 17:21:56 UTC
01319fb Speed up and clean up geomspace test. (#4149) * Speed up and clean up geomspace test. 25 August 2020, 17:05:06 UTC
4bf3d6e Delete batching.last. (#4148) A -1 axis works just as well at head. 25 August 2020, 16:53:18 UTC
8c8060e Remove workaround for illegal vmap out_axes. (#4147) 25 August 2020, 16:53:02 UTC
6d54eb5 Do not call asarray() on inputs of jax.random.choice (#4137) 25 August 2020, 12:47:43 UTC
f959219 Rename collectives into "collective operations" for the pmap function. (#4136) This is because pmap serves as the entry point, and this term leads to good Google results, such as https://en.wikipedia.org/wiki/Collective_operation, while "collectives" alone does not. 25 August 2020, 12:39:45 UTC
f4b05bc make pe.abstract_eval_fun use omnistaging (#4139) 25 August 2020, 12:38:41 UTC
04173b3 Merge pull request #4140 from sharadmv/patch-2 Remove frame check assertion in `extend_axis_env`. 25 August 2020, 12:38:20 UTC
774b5f6 Remove frame check assertion in `extend_axis_env`. 25 August 2020, 04:13:30 UTC
e06a6ab Add support for negative axes to vmap. (#4111) * Add support for negative axes to vmap. * Add workaround for out-of-range vmap axes. 25 August 2020, 00:21:19 UTC
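Supporting negative axes amounts to normalizing an axis index the usual NumPy way: count from the end for negative values and reject anything out of range. A minimal sketch (the helper name is hypothetical, not necessarily what vmap uses internally):

```python
def canonicalize_axis(axis, ndim):
    """Map an axis in [-ndim, ndim) to the equivalent non-negative index."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for an array of rank {ndim}")
    # Python's % maps -1 to ndim - 1, -2 to ndim - 2, and so on.
    return axis % ndim

print(canonicalize_axis(-1, 3))  # 2
```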
603f0c1 Fix scan carry types in gradient of complex ODE (#4130) * Cast t_bar from potential complex to float in ode.py * Add test case for complex odeint (currently failing) * Wrap odeint into complex-to-real function in test case * fixup Co-authored-by: Stephan Hoyer <shoyer@google.com> 24 August 2020, 20:50:44 UTC
0cc3802 Fix documentation of scatter_* operations. (#4138) * Fix documentation of scatter_* operations. This commit changes the documentation of the `unique_indices` parameter to scatter to better capture its intended meaning in XLA. 24 August 2020, 19:29:22 UTC
e5c4ccb Merge pull request #4125 from google/issue4124 make random.choice error when shape isn't sequence 22 August 2020, 03:36:02 UTC
56b3688 make random.choice error when shape isn't sequence fixes #4124 22 August 2020, 02:58:06 UTC
6bed4ee Temporarily disable jax_jit tests. (#4118) 22 August 2020, 01:44:52 UTC
7082105 Merge pull request #4123 from google/only-one-axis-index-primitive only construct one axis_index_p primitive 22 August 2020, 01:35:30 UTC
a62580c deflake 22 August 2020, 00:56:59 UTC
66a02b6 only construct one axis_index_p primitive Before this change, there were two versions, one used with omnistaging and one without. But that made bookkeeping hard and buggy. This change defines the axis_index_p primitive in core.py. Some of its rules are still changed when omnistaging is enabled. 22 August 2020, 00:43:15 UTC
7e77af4 don't force backend creation in xla_computation (#4121) 21 August 2020, 19:34:31 UTC
519d57c fix bugs 21 August 2020, 19:14:45 UTC
9d733dd Doc: change suggested way of starting the profiler (#4120) 21 August 2020, 18:19:51 UTC
b2a239c don't force backend creation in xla_computation 21 August 2020, 18:10:53 UTC
3063145 use xla.backend_compile function in pxla.py (#4113) * use xla.backend_compile function in pxla.py Not only is this useful for profiling, it also helps us do google-internal logging for the XLA team. 20 August 2020, 21:44:26 UTC
1e6b809 Fixes padding generation for padding == 'SAME' in reduce_window to (#4110) * Fixes padding generation for padding == 'SAME' in reduce_window to take window_dilation into account. (Fixes google/jax#3973). This commit applies the fix suggested by James on the issue, which is backed by the meaning of padding described on https://www.tensorflow.org/xla/operation_semantics#reducewindow. * Added shape tests for reduce_window when stride is 1 in each direction and padding is 'SAME'. 20 August 2020, 18:45:15 UTC
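The essence of the fix is that 'SAME' padding has to be computed from the dilated (effective) window size rather than the raw window size. A sketch of the per-dimension computation, assuming the common XLA/TensorFlow convention that the output size is ceil(input / stride) and that any odd leftover padding goes on the high side. The helper names are hypothetical.

```python
import math

def same_padding(in_size, window, stride=1, dilation=1):
    """(low, high) padding for 'SAME' along one dimension, accounting for dilation."""
    # A dilated window of size w spans (w - 1) * dilation + 1 input elements.
    effective_window = (window - 1) * dilation + 1
    out_size = math.ceil(in_size / stride)
    pad_total = max((out_size - 1) * stride + effective_window - in_size, 0)
    return pad_total // 2, pad_total - pad_total // 2

def output_size(in_size, window, stride, dilation, pads):
    """Number of window positions after padding the input by pads = (low, high)."""
    effective_window = (window - 1) * dilation + 1
    return (in_size + sum(pads) - effective_window) // stride + 1

print(same_padding(5, 3))              # (1, 1)
print(same_padding(5, 3, dilation=2))  # (2, 2)
```

With stride 1 in every direction, the padded output has the same size as the input, which is the property the added shape tests check.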
1e8ac24 Add rademacher, maxwell, double_sided_maxwell and weibull_min to jax.random. (#4104) 20 August 2020, 14:46:55 UTC
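For reference, the four distributions can be sampled from textbook definitions. The sketch below uses NumPy's `Generator` rather than `jax.random`, and the function and parameter names (`concentration`, `loc`, `scale`) are assumptions for illustration, not the signatures added in this commit.

```python
import numpy as np

rng = np.random.default_rng(0)

def rademacher(shape):
    # -1 or +1 with equal probability.
    return rng.choice(np.array([-1, 1]), size=shape)

def maxwell(shape):
    # Euclidean norm of a 3-D vector of independent standard normals.
    return np.linalg.norm(rng.standard_normal(size=shape + (3,)), axis=-1)

def double_sided_maxwell(shape, loc=0.0, scale=1.0):
    # A Maxwell magnitude with an independent random sign, then scaled and shifted.
    return loc + scale * rademacher(shape) * maxwell(shape)

def weibull_min(shape, concentration, scale=1.0):
    # Inverse-CDF sampling from F(x) = 1 - exp(-(x / scale) ** concentration).
    u = rng.uniform(size=shape)
    return scale * (-np.log1p(-u)) ** (1.0 / concentration)
```

As a sanity check, the Maxwell mean is 2 * sqrt(2 / pi) (about 1.596), and the Weibull mean is scale * Gamma(1 + 1 / concentration).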
d978808 Document the required form of tap_func for host_callback.id_tap (#4100) 20 August 2020, 07:36:35 UTC
22b593b [jax2tf] General conversion of reduce_window. (#4093) * [jax2tf] General conversion of reduce_window. Much like scatter_p, the conversion works as well as the underlying reduction function (e.g. reduce_window_max is not properly converted for the int8 dtype, because tf.math.maximum is not defined for int8 tensors). 20 August 2020, 07:05:30 UTC
9ba282f add axis_index to supported multi-host collectives (#4107) also make the error message less confusing 19 August 2020, 22:51:40 UTC
b4efb31 Docs: Fix broken link in quickstart (#4102) 19 August 2020, 18:36:28 UTC
9ca1020 Add a fast C++ jit codepath. (#4089) This starts a C++ jit codepath to speed up dispatch time. Tracing is not supported yet. Supported features: - scalar, numpy array and DeviceArray argument support: - integer, float, boolean, and complex scalar arguments are supported. - The jax_enable_x64 flag will be used at object-creation time to cast scalars and numpy arrays. - The Jax `weak_type` attribute for arguments is supported (DeviceArray and scalars). - The donate_argnums argument. - Use an XLA tuple for more than 100 arguments Unsupported features: - jax._cpp_jit on methods e.g. @functools.partial(jax.jit, static_argnums=0) def _compute_log_data(self, ...) ... This is currently not supported by the C++ codepath, because "self" won't be automatically added. - disable_jit. 19 August 2020, 16:39:25 UTC
5135fd1 fix jaxpr util test under enable_x64 19 August 2020, 15:28:56 UTC
b892236 remove check that TypedJaxpr literals aren't tracers (#4096) In the original usage of TypedJaxpr, literals could not be tracers because they were only produced by initial-style transformations of jaxprs. But now TypedJaxpr is used in several other ways, e.g. in make_jaxpr, and moreover its avals are redundant. It should probably be renamed ClosedJaxpr since it mainly serves to package a jaxpr together with its constant arrays. This check was limiting the utility of TypedJaxpr, and it was only added relatively recently anyway. 19 August 2020, 04:04:14 UTC
8e166ad Unbreak jaxlib build. (#4098) 19 August 2020, 01:24:41 UTC
8cc9579 check path prefixes using os.path instead of string comparisons 19 August 2020, 01:08:08 UTC
d778a6d move experimental.jaxpr_stats to jaxpr_util 19 August 2020, 01:07:38 UTC
908d54a utilities to collect summary statistics of jaxprs 19 August 2020, 01:07:38 UTC
d70976c Cleanup: reduce redundant code (#4095) 18 August 2020, 23:31:54 UTC
afeefa6 Add typing and namedtuple to `optimizers.py`, improve documentation. (#3570) 18 August 2020, 22:49:29 UTC
29f7fa7 Add implementation of jax.numpy.trim_zeros (#4027) 18 August 2020, 20:40:45 UTC
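`trim_zeros` is NumPy's function for dropping leading and/or trailing zeros. A NumPy sketch for 1-D input (a hypothetical helper, not the JAX source) that locates the first and last nonzero entries and slices between them:

```python
import numpy as np

def trim_zeros_sketch(filt, trim="fb"):
    """Drop leading ('f') and/or trailing ('b') zeros from a 1-D array."""
    filt = np.asarray(filt)
    nonzero = np.flatnonzero(filt)
    if nonzero.size == 0:
        # All zeros: nothing survives trimming from either end.
        return filt[:0]
    trim = trim.lower()
    start = nonzero[0] if "f" in trim else 0
    stop = nonzero[-1] + 1 if "b" in trim else len(filt)
    return filt[start:stop]

print(trim_zeros_sketch([0, 0, 1, 2, 0, 3, 0]).tolist())  # [1, 2, 0, 3]
```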
29aa9bf Cleanup: avoid jnp.prod & np.prod on array shapes (#4086) 18 August 2020, 17:17:38 UTC
decd760 Add experimental __array_module__ method (#4076) * Add experimental __array_module__ method xref https://github.com/google/jax/issues/1565 `__array_module__` (see [NEP 37](https://numpy.org/neps/nep-0037-array-module.html)) is an experimental alternative to `__array_function__` and `__array_ufunc__` for "duck array" compatibility with NumPy that promises to be much less invasive. Example usage:

```python
import numpy as np

def duckarray_stack(arrays):
    """This "stack" function should work with any array library, including JAX."""
    npx = np.get_array_module(*arrays)
    arrays = [npx.asarray(arr) for arr in arrays]
    shapes = {arr.shape for arr in arrays}
    if len(shapes) != 1:
        raise ValueError('all input arrays must have the same shape')
    expanded_arrays = [arr[npx.newaxis, ...] for arr in arrays]
    return npx.concatenate(expanded_arrays, axis=0)
```

Support for this protocol has *not* yet been implemented in NumPy, but it can be tested with https://github.com/seberg/numpy-dispatch. My reasoning for merging it into JAX (on an experimental basis with no guarantees, of course) is that:

1. It's not invasive -- the implementation is small and self-contained.
2. No backwards compatibility issues. Unlike `__array_function__` and `__array_ufunc__`, `__array_module__` will always require an explicit opt-in by libraries that use it by calling `get_array_module()`.
3. Other NumPy developers [want evidence](https://github.com/numpy/numpy/pull/16935#issuecomment-673951287) that this is actually feasible.
4. Scikit-Learn developers like @thomasjpfan are interested in exploring supporting scikit-learn on top of NumPy-like libraries like JAX, and experimental support for this protocol will make that easier.

Note: this PR does add `numpy-dispatch` as an optional testing requirement in order to verify that this works. If desired, we could remove this from CI, but installing numpy-dispatch (and its build requirement Cython) appears to add only a few seconds of build time.
* don't explicitly list cython * remove UnshapedArray from _JAX_ARRAY_TYPES * Remove incorrect note about metaclasses * remove unnecessary numpy_dispatch.ensure_dispatching() 18 August 2020, 16:40:57 UTC
36f3a36 Separate axis splitting from collective handling (#4082) This makes the vmap collective handling a bit more flexible and allowed me to add ppermute support. 18 August 2020, 10:02:28 UTC
ace23fa [jax2tf] Added tests for reduce_window translation (#4062) * [jax2tf] Added tests for reduce_window translation (WIP) * Added other non-floating types to the tests. 18 August 2020, 09:01:13 UTC
8c2ee37 Prior refactoring before the C++ jax.jit. (#4045) 18 August 2020, 08:43:52 UTC
2ab6b42 Use pytree defined in tensorflow. (#4087) It also adds some tests on the scalar C++ conversion. 18 August 2020, 05:58:43 UTC
fe69d3c always deref all locals that indirectly reach stack frames in the exception-reraise handler 18 August 2020, 01:13:58 UTC
dbca9e6 unrevert #3674 (revert #3791) 18 August 2020, 01:13:58 UTC
1ba4e06 Initial version of gmap (#4006) Co-authored-by: Matthew Johnson <mattjj@google.com> 17 August 2020, 18:11:43 UTC
4c22e01 [jax2tf] Explicitly raise an error when attempting to convert _select_and_scatter_add_p. (#4084) 17 August 2020, 14:32:34 UTC
ec90c35 [jax2tf] Fix bfloat16 bug in select_and_gather_add conversion. (#4058) * [jax2tf] Fix bfloat16 bug in select_and_gather_add conversion. This fix makes it possible to run bfloat16 tests for the jax2tf conversion of select_and_gather_add. 17 August 2020, 10:57:41 UTC
c7aff1d Revert "Use pytree from xla_client. (#4063)" (#4081) This reverts commit d8de6b61411179dcd2f63d7639bbcd69b30ac15f. Trying to revert because it seems that this produces test failures at Google. 17 August 2020, 09:53:18 UTC
22b92c5 Increase tolerance for CPU test LaxBackedNumpyTests::testCorrCoef (#4080) 17 August 2020, 06:50:47 UTC
9120701 allow xla_computation to psum a constant (#4078) * allow xla_computation to psum a constant * allow axis_env to be None 17 August 2020, 03:00:40 UTC
8232f2d adapt _TempAxisName for unhashable objs (#4077) adapt _TempAxisName for unhashable objs 16 August 2020, 05:55:18 UTC
16ab9cb support multi-host pmap with omnistaging (#4075) 15 August 2020, 22:07:08 UTC
1316562 Canonicalize result dtype to fix double precision problem in ldexp (#4069) 15 August 2020, 15:47:28 UTC
1dbdaac [jax2tf] avoid import errors when omnistaging is enabled (#4072) * [jax2tf] avoid import errors when omnistaging is enabled 15 August 2020, 05:55:02 UTC
9ab07d8 support axis_index_groups in psum(const) (#4070) * support axis_index_groups in psum(const) * add test for psum(constant, axis_index_groups) * rm trailing whitespace * Update lax_parallel.py 15 August 2020, 05:54:36 UTC
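To make the expected result concrete, here is a pure-Python model of `psum` over a simulated mapped axis, where each list position stands for one device. It sketches the semantics only and is not how JAX implements it; the name `simulated_psum` is hypothetical. With `axis_index_groups`, each device receives the sum over its own group, so a replicated constant `c` becomes `c` times the group size.

```python
def simulated_psum(values, axis_index_groups=None):
    """Sum values across simulated devices, optionally within disjoint groups."""
    n = len(values)
    groups = axis_index_groups if axis_index_groups is not None else [list(range(n))]
    out = [None] * n
    for group in groups:
        total = sum(values[i] for i in group)
        for i in group:
            out[i] = total  # every member of the group receives the group's sum
    return out

print(simulated_psum([2, 2, 2, 2]))                                      # [8, 8, 8, 8]
print(simulated_psum([2, 2, 2, 2], axis_index_groups=[[0, 1], [2, 3]]))  # [4, 4, 4, 4]
```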
394a33c Add `in_parts` and `out_parts` optional arguments to `jax.xla_computation`. (#4055) PR #3771 redux (reverted in #3780) Co-authored-by: Skye Wanderman-Milne <skyewm@google.com> 14 August 2020, 20:05:58 UTC