89c0f2f | George Necula | 05 December 2023, 13:31:40 UTC | adjust test tolerance to account for changes in XLA:TPU PiperOrigin-RevId: 588037215 | 06 December 2023, 06:33:06 UTC |
80e91fd | Sharad Vikram | 05 December 2023, 21:56:51 UTC | Disabling running all_gather test on non-TPU platforms PiperOrigin-RevId: 588180749 | 05 December 2023, 21:57:19 UTC |
dc469c3 | jax authors | 05 December 2023, 21:29:10 UTC | Merge pull request #18676 from 8bitmp3:jax-docs-autodiff-101-201 PiperOrigin-RevId: 588172128 | 05 December 2023, 21:29:10 UTC |
7090347 | jax authors | 05 December 2023, 20:54:44 UTC | Merge pull request #18739 from jakevdp:ci-precommit PiperOrigin-RevId: 588161526 | 05 December 2023, 20:54:44 UTC |
f25a51e | Jieying Luo | 05 December 2023, 17:41:40 UTC | Change bazel config name "rbe_cpu_linux_py39" to "rbe_cpu_linux_py3.9" to be consistent with cuda bazel build configs. Change other python versions bazel config similarly. PiperOrigin-RevId: 588100957 | 05 December 2023, 17:42:14 UTC |
720ff42 | Peter Hawkins | 05 December 2023, 16:04:41 UTC | [bazel] Add a macro if_building_jaxlib() to guard dependencies that should only be present if building jaxlib. Cleanup only, NFC intended. PiperOrigin-RevId: 588074047 | 05 December 2023, 16:05:17 UTC |
9d35b90 | Jevin Jiang | 05 December 2023, 12:24:26 UTC | [XLA:Mosaic] Support expanding lane dim in shapecast: (..., 128) -> (..., m * 128) and handle relayout from (1, 128) to (8, 128) for more general cases. PiperOrigin-RevId: 588024159 | 05 December 2023, 12:25:10 UTC |
a31129a | Sharad Vikram | 05 December 2023, 08:09:34 UTC | [Pallas/TPU] Add all gather kernel example PiperOrigin-RevId: 587963496 | 05 December 2023, 08:11:13 UTC |
91ef37b | jax authors | 05 December 2023, 08:03:09 UTC | Merge pull request #18818 from gnecula:fix_indexing_with_jax_array PiperOrigin-RevId: 587962024 | 05 December 2023, 08:03:09 UTC |
ec46058 | George Necula | 05 December 2023, 06:47:08 UTC | Fix indexing with slices when the slice elements are jax.Array. This fixes a bug introduced in #18679, for the case when some elements of the slice are `jax.Array`. We add a new test also. | 05 December 2023, 07:02:50 UTC |
7a3e214 | jax authors | 05 December 2023, 05:47:53 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/a83901622bca41cb58594e81f752a1609f4417df. PiperOrigin-RevId: 587936503 | 05 December 2023, 05:48:24 UTC |
7fa0f46 | Peter Hawkins | 05 December 2023, 01:48:17 UTC | [bazel] Add a BUILD file for jax/extend, and add more granular targets for individual pieces of extend. In general we'd like to use more granular BUILD targets rather than larger monolithic targets. If nothing else, they interact better with pytype. This change is in preparation for adding the JAX MLIR bindings to jax.extend, since they are something that JAX users sometimes need especially for defining custom ops. PiperOrigin-RevId: 587893573 | 05 December 2023, 01:48:50 UTC |
f29ec4e | David Majnemer | 05 December 2023, 01:05:32 UTC | Change seed for a test We got unlucky and hit a seed which happens to fail the KS test. PiperOrigin-RevId: 587885112 | 05 December 2023, 01:06:17 UTC |
46bad2b | 8bitmp3 | 28 November 2023, 23:47:48 UTC | Upgrade JAX Autodiff 101 | 05 December 2023, 00:00:13 UTC |
a9bfbd3 | Yash Katariya | 04 December 2023, 23:51:21 UTC | Finish jax and jaxlib 0.4.21 release PiperOrigin-RevId: 587866580 | 04 December 2023, 23:51:58 UTC |
193eb12 | Peter Hawkins | 04 December 2023, 22:42:44 UTC | Use sys.version_info to guard Python 3.11 only code. This makes pytype happier since it understands sys.version_info, but didn't understand the previous hasattr() test. PiperOrigin-RevId: 587846307 | 04 December 2023, 22:48:46 UTC |
aa27048 | jax authors | 04 December 2023, 22:40:51 UTC | Merge pull request #18809 from jakevdp:fix-numpy-warning PiperOrigin-RevId: 587845552 | 04 December 2023, 22:40:51 UTC |
150ab68 | jax authors | 04 December 2023, 22:24:34 UTC | Merge pull request #18580 from 8bitmp3:jax-docs-new-installation PiperOrigin-RevId: 587840299 | 04 December 2023, 22:24:34 UTC |
8b74b93 | Jake VanderPlas | 04 December 2023, 22:20:14 UTC | Test: fix casting warning in betainc test | 04 December 2023, 22:20:14 UTC |
b1a8bc6 | 8bitmp3 | 04 December 2023, 21:42:13 UTC | Upgrade JAX Installation doc | 04 December 2023, 21:42:13 UTC |
c4d2fc7 | Yash Katariya | 04 December 2023, 21:35:35 UTC | Replace device_buffers with addressable_shards in test because device_buffers is deprecated PiperOrigin-RevId: 587825636 | 04 December 2023, 21:36:12 UTC |
baa7756 | Peter Hawkins | 04 December 2023, 20:20:05 UTC | Use scoped disable_jit() in dynamic_api_test. This test was leaving jit disabled, affecting other tests. PiperOrigin-RevId: 587803847 | 04 December 2023, 20:20:36 UTC |
b04fd31 | Sharad Vikram | 04 December 2023, 19:54:12 UTC | Add option to pass in `unroll=True/False` into `scan` and `fori_loop`. PiperOrigin-RevId: 587795364 | 04 December 2023, 19:54:50 UTC |
5b97960 | Peter Hawkins | 04 December 2023, 18:28:28 UTC | Disable sanitizer builds of shape_poly_test. These take a very long time and sometimes timeout so it's probably not worth running them in CI. PiperOrigin-RevId: 587768399 | 04 December 2023, 18:42:56 UTC |
54e3b76 | Sharad Vikram | 04 December 2023, 18:26:25 UTC | Add support for unrolling to `lax.fori_loop` PiperOrigin-RevId: 587767613 | 04 December 2023, 18:34:53 UTC |
5942e15 | Yash Katariya | 04 December 2023, 18:26:07 UTC | Prepare for 0.4.21 release PiperOrigin-RevId: 587767502 | 04 December 2023, 18:26:48 UTC |
d91c13e | jax authors | 04 December 2023, 16:32:27 UTC | Merge pull request #18795 from gnecula:test_export_grad PiperOrigin-RevId: 587730171 | 04 December 2023, 16:32:27 UTC |
8a2d4a0 | George Necula | 04 December 2023, 08:24:01 UTC | [export] Add and fix a test for exporting higher-order gradients with sharding There was a test for export with gradients, we changed the test to (a) export 2nd order gradient also, and (b) to export both with a mesh context and without a mesh context (using NamedSharding). This test currently fails, only in the case when we do NOT have a mesh context, as explained below: When exporting gradient functions, we first export the primal functions and we use the in/out-shardings to construct shardings of the gradient function. Since Exported shardings now contain only HloSharding objects, and to lower the gradient function we must use `pjit(vjp(f)).lower()`, we construct GSPMDSharding objects using the current devices and the HloSharding object from the Exported primal. However, these objects do not have the `_original_sharding` attribute. Later in `pjit._resource_typing_pjit` we attempt to `parse_flatten_op_sharding` using the mesh context (which is empty). This fails. This PR contains one workaround, to skip `parse_flatten_op_sharding` if the physical mesh of the `resource_env` is empty. Another, probably better solution, is to ensure that `resource_env` is `None` when there is no mesh context. That seemed reasonable, but currently the code returns an empty mesh from the resource_env if there is no mesh context. Changing this would have effects in more parts of the code, so I have not done it here, but it may be worth doing. | 04 December 2023, 16:07:23 UTC |
1d95e79 | Peter Hawkins | 04 December 2023, 15:53:54 UTC | Disable export_harnesses_multi_platform_test under sanitizers. This test appears to hit some sort of LLVM bug on Sapphire Rapids CPUs. PiperOrigin-RevId: 587719850 | 04 December 2023, 15:54:35 UTC |
5e0993c | jax authors | 04 December 2023, 14:23:09 UTC | Merge pull request #18794 from olupton:qualname PiperOrigin-RevId: 587696743 | 04 December 2023, 14:23:09 UTC |
3c0c6b7 | Olli Lupton | 04 December 2023, 13:19:38 UTC | Use qualified name if possible. | 04 December 2023, 13:19:38 UTC |
a137edc | jax authors | 04 December 2023, 05:18:06 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/96c4f8749b6521ff3e2670d168c9095a5f6323c5. PiperOrigin-RevId: 587591208 | 04 December 2023, 05:18:37 UTC |
bf08411 | jax authors | 03 December 2023, 05:40:19 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/6f22d6a8d45ce00c2d6654fb5f99ca7f0f094b51. PiperOrigin-RevId: 587406410 | 03 December 2023, 05:40:59 UTC |
61e79cd | jax authors | 02 December 2023, 18:49:08 UTC | Merge pull request #18786 from gnecula:test_export_effects PiperOrigin-RevId: 587333913 | 02 December 2023, 18:49:08 UTC |
b51b80e | jax authors | 02 December 2023, 18:24:47 UTC | Merge pull request #18761 from gnecula:export_sharding PiperOrigin-RevId: 587330573 | 02 December 2023, 18:24:47 UTC |
bd7c1aa | George Necula | 02 December 2023, 15:19:57 UTC | [export] Improve the testing of exporting with effects We now test that when we call an Exported from a computation that already uses effects, the effects from the calling computation are identified with the events from the called Exported. | 02 December 2023, 18:18:22 UTC |
3eb3e2d | George Necula | 01 December 2023, 10:24:21 UTC | [export] Simplify the handling of shardings in Exported. Previously, Exported contained tuples of `XlaCompatibleSharding` for the input and output shardings. These shardings contain references to JAX devices, which is too much for exporting purposes and in fact it gets in the way when we want to serialize the Exported. We change Exported to carry `xla_client.HloSharding` instead, which conveniently can be serialized to proto. We use the value `None` to denote an unspecified sharding. We also add `nr_devices` and then for exporting purposes we can construct actual `XlaCompatibleSharding` when we need to. | 02 December 2023, 18:10:24 UTC |
b822801 | jax authors | 02 December 2023, 18:09:29 UTC | Merge pull request #18783 from gnecula:fix_indexing PiperOrigin-RevId: 587328492 | 02 December 2023, 18:09:29 UTC |
32fb1b4 | Peter Hawkins | 02 December 2023, 17:29:06 UTC | Remove the ml_program MLIR dialect from jaxlib. Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something. PiperOrigin-RevId: 587323808 | 02 December 2023, 17:29:39 UTC |
d2f6261 | George Necula | 02 December 2023, 08:08:50 UTC | Fix bug in indexing with slices that overflow, and add tests. This bug was introduced in #18679, and was not caught in unit tests because we were not testing cases when the slice needs to be clamped. | 02 December 2023, 14:47:06 UTC |
0418370 | Yash Katariya | 02 December 2023, 08:30:45 UTC | Put the input creation via `jnp` outside the pjit cpp cache miss counting code. PiperOrigin-RevId: 587256278 | 02 December 2023, 08:31:24 UTC |
965ba05 | Junwhan Ahn | 02 December 2023, 06:38:23 UTC | Rewrite `tf.aliasing_output` when wrapping the main StableHLO function When the original main function has tokens and the wrapped main function does not, there are fewer outputs in the wrapped main than the original main. This is problematic for `tf.aliasing_output`, which is an argument attribute that stores result indexes to which the argument can alias. This CL makes the wrapper main creation rewrite `tf.aliasing_output` according to the new result indexes. The newly added test verifies that the aliasing indexes are correct across all supported serialization versions. Confirmed that the test fails without changes in export.py (versions 6, 7, and 8 fail and 9 passes). PiperOrigin-RevId: 587237842 | 02 December 2023, 06:42:37 UTC |
1aab108 | jax authors | 02 December 2023, 06:34:26 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/33c73ddc733b38245f1e2ca46f092119615fb386. PiperOrigin-RevId: 587237109 | 02 December 2023, 06:34:53 UTC |
f0bc7e0 | Yash Katariya | 02 December 2023, 06:05:59 UTC | Reverts f0382a5838f4526d21631e804f6fe576bfc3f97e PiperOrigin-RevId: 587231484 | 02 December 2023, 06:06:33 UTC |
86661c8 | David Majnemer | 02 December 2023, 03:25:56 UTC | Re-enable passing tests on TPU These have been working for a while. PiperOrigin-RevId: 587199000 | 02 December 2023, 03:26:58 UTC |
e61f7a8 | jax authors | 02 December 2023, 01:09:02 UTC | Merge pull request #18757 from jakevdp:astype PiperOrigin-RevId: 587167285 | 02 December 2023, 01:09:02 UTC |
f5f885d | jax authors | 02 December 2023, 00:05:58 UTC | Merge pull request #18774 from hawkinsp:nightlyci PiperOrigin-RevId: 587150147 | 02 December 2023, 00:05:58 UTC |
595117b | Yash Katariya | 01 December 2023, 22:28:21 UTC | Add a test to check if arr.delete() is idempotent. PiperOrigin-RevId: 587121346 | 01 December 2023, 22:28:51 UTC |
7274960 | Peter Hawkins | 01 December 2023, 22:01:07 UTC | Remove nightly NVIDIA GPU multiprocess CI. This CI seems to be dead. | 01 December 2023, 22:13:24 UTC |
a999120 | Peter Hawkins | 01 December 2023, 18:52:19 UTC | Improve error message when cudnn is not found. We infer a missing cudnn if cudnnGetVersion() returns 0, since the stub implementation in TSL will do that if the library isn't found (https://github.com/openxla/xla/blob/10a378f49978aa4ee4564ceb105d33694fd48202/third_party/tsl/tsl/cuda/cudnn_stub.cc#L58). PiperOrigin-RevId: 587056454 | 01 December 2023, 18:52:48 UTC |
50c7223 | Peter Hawkins | 01 December 2023, 18:27:17 UTC | Fix Windows build failure. The TPU extension didn't build because the MLIR Python binding code requires pybind11 to be included first on Windows, per https://github.com/llvm/llvm-project/blob/9584f5834499e6093797d4a28fde209f927ea556/mlir/include/mlir-c/Bindings/Python/Interop.h#L24 PiperOrigin-RevId: 587049246 | 01 December 2023, 18:31:53 UTC |
95bc2ba | Peter Hawkins | 01 December 2023, 18:23:25 UTC | Inline sigmoid, isfinite, and isnan in jaxprs. In the common case (real values) these are all single-expression jaxprs themselves, so putting them out of line just makes things more verbose. There's no reason to include stuff like this in a jaxpr: ``` cxd:bool[8,16] = pjit[ jaxpr={ lambda ; cxe:f32[8,16]. let cxf:bool[8,16] = is_finite cxe in (cxf,) } name=isfinite ] cxc ``` PiperOrigin-RevId: 587047955 | 01 December 2023, 18:23:56 UTC |
ada5fe5 | Jake VanderPlas | 01 December 2023, 18:02:34 UTC | Remove numpy-dispatch CI job & simplify build specification The numpy-dispatch approach has been superseded by the Python Array API (Tracked for JAX in https://github.com/google/jax/issues/18353). While we're here, we'll reduce the github CI to only two jobs: the oldest and newest supported Python versions. Other versions can be covered by Kokoro. PiperOrigin-RevId: 587041291 | 01 December 2023, 18:03:15 UTC |
b3c579e | jax authors | 01 December 2023, 16:33:55 UTC | Merge pull request #18762 from gnecula:poly_getitem_next PiperOrigin-RevId: 587018677 | 01 December 2023, 16:33:55 UTC |
65fca0e | George Necula | 29 November 2023, 08:24:37 UTC | [shape_poly] Add heuristics for deciding >= 0 The rules for deciding inequalities of symbolic expressions are incomplete. Here we add two heuristics that help decide the bounds checking of indices computed for indexing with slices: To decide whether an expression that contains `non_negative(e)` is >= 0, it is sufficient to show that the expression is >=0 if we replace the `non_negative(e)` with `0` and with `e`. To decide whether `floordiv(e, k)` is >= 0, when `k >= 0`, it is sufficient to show that `e` is >= 0. These are sufficient for the bounds checking that JAX is doing internally, but may not be for the cases when the user program does index computations using those operators. This enables us to re-enable the shape_poly indexing tests. | 01 December 2023, 11:55:42 UTC |
54fee48 | jax authors | 01 December 2023, 10:51:49 UTC | Cast in/out shardings to tuple before passing to `Exported` ctor. PiperOrigin-RevId: 586951567 | 01 December 2023, 10:52:26 UTC |
e60aa3b | jax authors | 01 December 2023, 07:07:31 UTC | Merge pull request #18679 from gnecula:poly_getitem2 PiperOrigin-RevId: 586902301 | 01 December 2023, 07:07:31 UTC |
7c76f16 | jax authors | 01 December 2023, 06:58:11 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/6947d0481880b46f0deda592b1ad1fee365128a6. PiperOrigin-RevId: 586899328 | 01 December 2023, 06:58:46 UTC |
2d1ce13 | George Necula | 27 November 2023, 08:13:57 UTC | [shape_poly] Simplify the indexing with slice to make it compatible with shape polymorphism Currently, we do not support shape polymorphism when we index with a slice, e.g., `x[a:b:c]`, and instead we direct the user to use `lax.dynamic_slice`. This is only because so far we have not tried to ensure that the index and bounds checking computations in gather are compatible with shape polymorphism. The problem was that there were a lot of conditionals, e.g., `if start >= stop` that cannot be handled in general in presence of symbolic shapes. Here we introduce a new helper function `_preprocess_slice` to contain all the computations for the start and the size of the slice. To test that this does not break the JAX index computations, I ran the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`. | 01 December 2023, 06:40:07 UTC |
8ad774f | jax authors | 01 December 2023, 06:24:22 UTC | Automate arguments for jax.distributed.initialize for cloud TPU environments. PiperOrigin-RevId: 586892544 | 01 December 2023, 06:25:00 UTC |
d77cd9a | Jake VanderPlas | 30 November 2023, 23:50:22 UTC | Add jax.numpy.astype function | 30 November 2023, 23:50:22 UTC |
a07ed22 | jax authors | 30 November 2023, 22:56:11 UTC | Merge pull request #18756 from jakevdp:skips PiperOrigin-RevId: 586795988 | 30 November 2023, 22:56:11 UTC |
efb4924 | jax authors | 30 November 2023, 22:05:11 UTC | Merge pull request #18754 from mattjj:fix-float0 PiperOrigin-RevId: 586781308 | 30 November 2023, 22:05:11 UTC |
43ed74f | Matthew Johnson | 30 November 2023, 21:53:13 UTC | rewrite test not to include float0 broadcast | 30 November 2023, 21:53:13 UTC |
d6154e5 | Jake VanderPlas | 30 November 2023, 21:28:08 UTC | [array-api] remove some test skips | 30 November 2023, 21:28:08 UTC |
5b3fc1b | jax authors | 30 November 2023, 20:58:16 UTC | Merge pull request #18730 from jakevdp:dep-device PiperOrigin-RevId: 586761439 | 30 November 2023, 20:58:16 UTC |
569f06c | Mark Sandler | 30 November 2023, 20:36:27 UTC | In Python 3.11 asyncio.run() always tries to convert the repr of the result of a coroutine to an integer while fetching the sigint handler. This makes the test materialize the whole tensor in memory. This changes the test coroutine to return nothing to avoid triggering this bug. https://github.com/python/cpython/issues/112559 PiperOrigin-RevId: 586756112 | 30 November 2023, 20:37:12 UTC |
97beb01 | Jake VanderPlas | 30 November 2023, 00:52:09 UTC | Deprecate the device() method of JAX arrays | 30 November 2023, 19:43:02 UTC |
4de07b3 | jax authors | 30 November 2023, 19:37:24 UTC | Merge pull request #18753 from jakevdp:warnings-tests PiperOrigin-RevId: 586739682 | 30 November 2023, 19:37:24 UTC |
53e66c1 | jax authors | 30 November 2023, 19:02:20 UTC | Merge pull request #18752 from mattjj:shmap-remat-rule PiperOrigin-RevId: 586729063 | 30 November 2023, 19:02:20 UTC |
dab6379 | jax authors | 30 November 2023, 18:47:45 UTC | Merge pull request #18746 from olupton:fix-repeated-builds PiperOrigin-RevId: 586723814 | 30 November 2023, 18:47:45 UTC |
bd46e5c | Shashank Viswanadha | 30 November 2023, 18:38:38 UTC | Add `nb::arg` to nanobind definitions to generate better python annotations. PiperOrigin-RevId: 586721759 | 30 November 2023, 18:39:28 UTC |
d2b4800 | Jake VanderPlas | 30 November 2023, 18:35:24 UTC | tests: improve warnings-related tests | 30 November 2023, 18:35:24 UTC |
5862852 | Matthew Johnson | 30 November 2023, 18:00:00 UTC | [shard-map] add rewrite and replication checking rules for remat these rules enable shmap-of-remat with check_rep=True | 30 November 2023, 18:15:48 UTC |
11d7a2b | jax authors | 30 November 2023, 18:09:32 UTC | Merge pull request #18741 from mattjj:shmap-test-fix PiperOrigin-RevId: 586710378 | 30 November 2023, 18:09:32 UTC |
5c2635c | Matthew Johnson | 30 November 2023, 01:27:26 UTC | [shard-map] fix test running broken by 0aec40a16fad02f084ef0cabd350db78b86b335e | 30 November 2023, 17:56:34 UTC |
e50c35d | Olli Lupton | 30 November 2023, 12:31:12 UTC | Fix repeatedly building JAX. Reproducer was essentially running `pip install .` twice in a row in the same source directory. Closes google/jax#18252. | 30 November 2023, 12:31:17 UTC |
fe237cd | jax authors | 30 November 2023, 06:43:31 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/58e6b428e22e40c4100a7b66790fbe86dc9d7845. PiperOrigin-RevId: 586556206 | 30 November 2023, 06:44:13 UTC |
d9a0cc0 | jax authors | 30 November 2023, 04:24:54 UTC | Merge pull request #18731 from mattjj:shmap-custom-jvp-fix PiperOrigin-RevId: 586527134 | 30 November 2023, 04:24:54 UTC |
b8f758e | Matthew Johnson | 29 November 2023, 00:07:47 UTC | [shard-map] replace jaxpr interpreters with final-style-xform-of-eval-jaxpr | 30 November 2023, 04:06:12 UTC |
e624610 | Yash Katariya | 30 November 2023, 02:06:36 UTC | Replace apply_primitive internals with `jax.jit`. This allows deletion of a lot of code and leads to ~40% eager performance speedup. Benchmarks: ``` name old time/op new time/op delta eager_unary_dispatch 31.3µs ± 1% 19.4µs ± 6% -37.91% (p=0.016 n=4+5) eager_unary 32.1µs ± 0% 19.8µs ± 4% -38.26% (p=0.016 n=4+5) eager_binary_dispatch 35.9µs ± 1% 20.5µs ± 4% -42.93% (p=0.016 n=4+5) eager_binary 36.6µs ± 1% 21.1µs ± 4% -42.29% (p=0.016 n=4+5) jit_trivial_dispatch 3.87µs ± 2% 4.12µs ±25% ~ (p=1.000 n=5+5) jit_trivial 4.75µs ± 2% 4.82µs ±11% ~ (p=0.690 n=5+5) jit_simple_dispatch 2.95µs ± 2% 2.97µs ± 7% ~ (p=1.000 n=5+5) jit_simple 3.52µs ± 6% 3.51µs ± 5% ~ (p=0.841 n=5+5) jit_simple_dispatch_array 2.95µs ± 2% 2.96µs ± 6% ~ (p=1.000 n=5+5) jit_simple_array 3.46µs ± 2% 3.51µs ± 5% ~ (p=0.690 n=5+5) jit_small_matmul 3.01µs ± 1% 3.00µs ± 4% ~ (p=0.548 n=5+5) jit_big_matmul 34.0µs ±18% 35.5µs ±17% ~ (p=0.310 n=5+5) jit_simple_many_args_dispatch/num_args:10 6.93µs ± 6% 6.80µs ± 6% ~ (p=0.481 n=10+10) jit_simple_many_args_dispatch/num_args:100 47.7µs ± 7% 45.4µs ± 2% ~ (p=0.237 n=10+8) jit_simple_many_args_dispatch/num_args:1000 545µs ± 8% 516µs ± 2% ~ (p=0.101 n=10+8) jit_simple_many_args_dispatch/num_args:2000 1.12ms ± 7% 1.07ms ± 2% ~ (p=0.237 n=10+8) jit_simple_many_args/num_args:10 7.42µs ± 5% 7.23µs ± 2% ~ (p=0.173 n=10+8) jit_simple_many_args/num_args:100 48.4µs ± 7% 45.6µs ± 2% ~ (p=0.237 n=10+8) jit_simple_many_args/num_args:1000 542µs ± 6% 524µs ± 8% ~ (p=0.089 n=10+10) jit_simple_many_args/num_args:2000 1.12ms ± 7% 1.08ms ± 1% ~ (p=0.068 n=10+8) jit_simple_pruned_args_dispatch_10 4.79µs ± 8% 4.98µs ±10% ~ (p=0.421 n=5+5) jit_simple_pruned_args_10 5.32µs ± 6% 5.30µs ± 4% ~ (p=1.000 n=5+5) jit_simple_pruned_args_dispatch_100 24.7µs ± 6% 23.8µs ± 8% ~ (p=0.548 n=5+5) jit_simple_pruned_args_100 25.2µs ± 6% 24.4µs ± 8% ~ (p=0.690 n=5+5) jit_simple_pruned_args_dispatch_1000 238µs ± 7% 232µs ± 8% ~ (p=0.841 n=5+5) 
jit_simple_pruned_args_1000 240µs ± 7% 234µs ± 8% ~ (p=1.000 n=5+5) jit_simple_pruned_args_dispatch_2000 516µs ± 6% 497µs ± 1% ~ (p=0.413 n=5+4) jit_simple_pruned_args_2000 517µs ± 6% 505µs ± 7% ~ (p=0.690 n=5+5) jit_dispatch_without_transfer 719µs ± 9% 751µs ± 8% ~ (p=0.222 n=5+5) jit_dispatch_with_transfer 799µs ±14% 793µs ± 9% ~ (p=1.000 n=5+5) pmap_trivial_2_devices 49.9µs ±40% 48.2µs ±42% ~ (p=0.841 n=5+5) pmap_trivial_dispatch_8_devices 74.5µs ±24% 78.9µs ±29% ~ (p=0.421 n=5+5) pmap_trivial_8_devices 79.3µs ± 6% 82.7µs ±20% ~ (p=0.841 n=5+5) pmap_simple_2_devices 47.1µs ±17% 49.1µs ±20% ~ (p=0.548 n=5+5) pmap_simple_dispatch_8_devices 73.4µs ±16% 76.8µs ±21% ~ (p=0.690 n=5+5) pmap_simple_8_devices 76.0µs ±10% 80.6µs ±29% ~ (p=1.000 n=5+5) pmap_simple_dispatch_8_devices_100_args 1.12ms ±22% 1.08ms ±42% ~ (p=0.841 n=5+5) pmap_simple_8_devices_100_args 12.5ms ± 8% 12.8ms ±10% ~ (p=1.000 n=5+5) sda_index_1 413µs ± 1% 686µs ± 4% +66.08% (p=0.008 n=5+5) sda_index_2 850µs ± 1% 1378µs ± 4% +62.02% (p=0.008 n=5+5) sda_index_8 3.60ms ± 1% 5.69ms ± 4% +58.00% (p=0.008 n=5+5) bench_shaped_abstractify 300µs ± 1% 305µs ± 3% ~ (p=0.056 n=5+5) bench_xla_abstractify_scalar_int 6.45µs ± 1% 6.50µs ± 3% ~ (p=0.548 n=5+5) bench_xla_abstractify_scalar_float 3.73µs ± 1% 3.73µs ± 3% ~ (p=0.690 n=5+5) bench_xla_abstractify_scalar_numpy_int32 4.97µs ± 1% 4.83µs ± 3% ~ (p=0.095 n=5+5) bench_xla_abstractify_scalar_numpy_uint32 4.91µs ± 1% 4.75µs ± 0% -3.30% (p=0.016 n=5+4) bench_xla_abstractify_numpy_random 4.34µs ± 2% 4.31µs ± 3% ~ (p=0.310 n=5+5) bench_xla_abstractify_numpy_arange_100_float32 3.94µs ± 1% 3.93µs ± 3% ~ (p=0.548 n=5+5) bench_xla_abstractify_enum 6.85µs ± 1% 7.06µs ± 7% +3.07% (p=0.032 n=5+5) bench_are_op_shardings_equal 26.9µs ± 2% 27.0µs ± 3% ~ (p=0.841 n=5+5) bench_pjit_check_aval_sharding 691µs ± 2% 711µs ±13% ~ (p=0.841 n=5+5) bench_addressable_shards_index 656ns ± 4% 688ns ± 9% ~ (p=0.095 n=5+5) bench_remat_eager_retracing_overheads 12.7ms ± 4% 10.7ms ± 1% -15.48% 
(p=0.016 n=5+4) bench_remat_eager_retracing_overheads_static_argnums 13.0ms ± 2% 11.3ms ± 6% -13.71% (p=0.008 n=5+5) bench_slicing_compilation 12.1ms ± 1% 12.3ms ± 4% ~ (p=0.690 n=5+5) bench_slicing_compilation2 11.3ms ± 0% 11.5ms ± 6% ~ (p=0.690 n=5+5) bench_repeated_static_indexing 62.5ms ± 2% 40.8ms ± 8% -34.77% (p=0.008 n=5+5) bench_repeated_static_slicing 46.7ms ± 1% 31.4ms ± 2% -32.76% (p=0.008 n=5+5) pjit_simple_1_device/num_args:1 2.72µs ± 2% 2.68µs ± 5% ~ (p=0.151 n=5+5) pjit_simple_1_device/num_args:10 12.6µs ± 7% 12.3µs ± 3% ~ (p=0.310 n=5+5) pjit_simple_1_device/num_args:100 109µs ± 3% 108µs ± 4% ~ (p=0.548 n=5+5) pjit_simple_4_device/num_args:1 38.0µs ±26% 36.8µs ±19% ~ (p=0.690 n=5+5) pjit_simple_4_device/num_args:10 93.3µs ±19% 96.6µs ±23% ~ (p=0.841 n=5+5) pjit_simple_4_device/num_args:100 730µs ±16% 698µs ±48% ~ (p=0.841 n=5+5) pjit_aot_1_device/num_args:1 3.29µs ± 2% 3.12µs ± 4% -5.24% (p=0.016 n=4+5) pjit_aot_1_device/num_args:10 13.0µs ± 1% 12.7µs ± 2% ~ (p=0.063 n=4+5) pjit_aot_1_device/num_args:100 111µs ± 5% 110µs ±11% ~ (p=0.421 n=5+5) pjit_aot_4_device/num_args:1 38.4µs ±19% 38.9µs ±24% ~ (p=1.000 n=5+5) pjit_aot_4_device/num_args:10 91.3µs ±15% 96.9µs ±29% ~ (p=0.548 n=5+5) pjit_aot_4_device/num_args:100 676µs ±20% 689µs ±41% ~ (p=0.841 n=5+5) host_local_array_to_global_array 196µs ± 6% 194µs ± 4% ~ (p=0.548 n=5+5) device_put 50.8µs ± 1% 50.7µs ± 4% ~ (p=0.413 n=4+5) device_put_sharded 176µs ± 0% 177µs ± 4% ~ (p=0.190 n=4+5) device_get_8_devices 3.96ms ± 4% 4.03ms ± 7% ~ (p=0.413 n=4+5) np_asarray_8_devices 3.34ms ±18% 3.30ms ±10% ~ (p=0.548 n=5+5) jax_array_arrays_8_devices 5.01ms ±10% 5.09ms ±21% ~ (p=0.421 n=5+5) batch_inplace_while_scatter 440µs ± 1% 439µs ± 1% ~ (p=0.421 n=5+5) batch_inplace_while_dynamic_update_slice 454µs ± 0% 457µs ± 1% ~ (p=0.905 n=4+5) serial_dot_products 4.51µs ± 3% 4.41µs ± 2% ~ (p=0.151 n=5+5) bench_make_array_from_callback_fully_replicated_sharding 26.6µs ± 1% 27.0µs ± 2% ~ (p=0.056 n=5+5) ``` 
PiperOrigin-RevId: 586505950 | 30 November 2023, 02:07:13 UTC |
b6c73f8 | jax authors | 30 November 2023, 01:33:24 UTC | Merge pull request #18740 from mattjj:shmap-conv-rules PiperOrigin-RevId: 586499305 | 30 November 2023, 01:33:24 UTC |
6f20c0a | Matthew Johnson | 30 November 2023, 00:58:17 UTC | [shard-map] add conv replication rules fixes #18737 | 30 November 2023, 00:58:54 UTC |
57e19db | jax authors | 30 November 2023, 00:51:15 UTC | Merge pull request #18736 from mattjj:device-put-fixes PiperOrigin-RevId: 586490689 | 30 November 2023, 00:51:15 UTC |
6814db5 | Jake VanderPlas | 30 November 2023, 00:33:55 UTC | CI: add more pre-commit checks | 30 November 2023, 00:33:55 UTC |
c9ab0bf | Matthew Johnson | 30 November 2023, 00:08:31 UTC | fix grad device_put src inference, and a small device_put bug Co-authored-by: Yash Katariya <yashkatariya@google.com> | 30 November 2023, 00:24:24 UTC |
f0382a5 | jax authors | 30 November 2023, 00:10:56 UTC | Merge pull request #18728 from jakevdp:dep-device-buffer PiperOrigin-RevId: 586481166 | 30 November 2023, 00:10:56 UTC |
0aec40a | Jake VanderPlas | 29 November 2023, 23:31:01 UTC | Deprecate arr.device_buffer and arr.device_buffers | 29 November 2023, 23:31:01 UTC |
842ca2c | Peter Hawkins | 29 November 2023, 22:23:25 UTC | Use process_count(backend) in local_devices(). Due to what is arguably a bug, multiple TPU devices in the same job can have the same process index. When determining a process count for, say, CPU, make sure we use the same backend to compute the process_count. Otherwise we might see an apparently out-of-range process index from another backend. We should perhaps fix the TPU backend not to do this, but that's going to be a bigger change. PiperOrigin-RevId: 586453157 | 29 November 2023, 22:24:06 UTC |
d6637da | Yash Katariya | 29 November 2023, 20:50:04 UTC | Disable test_memory_cosumption test PiperOrigin-RevId: 586426753 | 29 November 2023, 20:50:46 UTC |
0fce77a | jax authors | 29 November 2023, 16:58:29 UTC | Merge pull request #18708 from jakevdp:array-equal-dep PiperOrigin-RevId: 586357829 | 29 November 2023, 16:58:29 UTC |
ef65ba8 | Adam Paszke | 29 November 2023, 13:47:13 UTC | Internal change PiperOrigin-RevId: 586312803 | 29 November 2023, 13:47:50 UTC |
458a896 | Peter Hawkins | 29 November 2023, 13:37:19 UTC | Always lower reduce_scatter_p as an HLO ReduceScatter. We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change. PiperOrigin-RevId: 586311031 | 29 November 2023, 13:37:58 UTC |
86d9398 | jax authors | 29 November 2023, 13:10:09 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/422b8e7fb70e61abeb0d4eba66df4bbd5bc3a193. PiperOrigin-RevId: 586305722 | 29 November 2023, 13:14:13 UTC |
528a90e | jax authors | 29 November 2023, 13:06:18 UTC | Merge pull request #18724 from apaszke:downgrade-logging PiperOrigin-RevId: 586304793 | 29 November 2023, 13:06:18 UTC |
1e961b8 | Peter Hawkins | 29 November 2023, 12:12:19 UTC | Remove fallback path that lowers all_gather via psum. As far as I can tell this is no longer necessary on GPU, which handles arbitrary allgather dimensions (by making the dimension the major-most dimension in layout assignment), and on CPU, where at present XLA would do the same lowering JAX would. I'm planning to improve the XLA:CPU lowering in a subsequent change. PiperOrigin-RevId: 586291911 | 29 November 2023, 12:14:11 UTC |
d80c15a | Adam Paszke | 29 November 2023, 12:10:53 UTC | Downgrade a bunch of logging to DEBUG The logs related to compilation cache ended up being quite chatty, which is quite unlike the other logs in JAX. This downgrades a bunch of them to debug, as they can always be enabled independently using JAX config. This should also fix the recent failures in logging_test.py. | 29 November 2023, 12:10:53 UTC |
8dfbf90 | Sharad Vikram | 29 November 2023, 12:02:30 UTC | [Pallas/Mosaic] Add support for barrier semaphores PiperOrigin-RevId: 586289340 | 29 November 2023, 12:04:11 UTC |
5bcf231 | jax authors | 29 November 2023, 06:22:49 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/1317a629f339500ab50665cfb35e9c3ec00215fa. PiperOrigin-RevId: 586206098 | 29 November 2023, 06:23:33 UTC |
896d4cf | jax authors | 29 November 2023, 01:00:09 UTC | Disable task_using_cache_metric unit test while debugging. This test is failing in the OSS environment. Temporarily disabling the test while debugging. PiperOrigin-RevId: 586144501 | 29 November 2023, 01:04:23 UTC |