https://github.com/google/jax

Revision Author Date Message Commit Date
89c0f2f adjust test tolerance to account for changes in XLA:TPU PiperOrigin-RevId: 588037215 06 December 2023, 06:33:06 UTC
80e91fd Disabling running all_gather test on non-TPU platforms PiperOrigin-RevId: 588180749 05 December 2023, 21:57:19 UTC
dc469c3 Merge pull request #18676 from 8bitmp3:jax-docs-autodiff-101-201 PiperOrigin-RevId: 588172128 05 December 2023, 21:29:10 UTC
7090347 Merge pull request #18739 from jakevdp:ci-precommit PiperOrigin-RevId: 588161526 05 December 2023, 20:54:44 UTC
f25a51e Change bazel config name "rbe_cpu_linux_py39" to "rbe_cpu_linux_py3.9" to be consistent with cuda bazel build configs. Change the bazel configs for other Python versions similarly. PiperOrigin-RevId: 588100957 05 December 2023, 17:42:14 UTC
720ff42 [bazel] Add a macro if_building_jaxlib() to guard dependencies that should only be present if building jaxlib. Cleanup only, NFC intended. PiperOrigin-RevId: 588074047 05 December 2023, 16:05:17 UTC
9d35b90 [XLA:Mosaic] Support expanding lane dim in shapecast: (..., 128) -> (..., m * 128) and handle relayout from (1, 128) to (8, 128) for more general cases. PiperOrigin-RevId: 588024159 05 December 2023, 12:25:10 UTC
a31129a [Pallas/TPU] Add all gather kernel example PiperOrigin-RevId: 587963496 05 December 2023, 08:11:13 UTC
91ef37b Merge pull request #18818 from gnecula:fix_indexing_with_jax_array PiperOrigin-RevId: 587962024 05 December 2023, 08:03:09 UTC
ec46058 Fix indexing with slices when the slice elements are jax.Array. This fixes a bug introduced in #18679, for the case when some elements of the slice are `jax.Array`. We add a new test also. 05 December 2023, 07:02:50 UTC
7a3e214 Update XLA dependency to use revision http://github.com/openxla/xla/commit/a83901622bca41cb58594e81f752a1609f4417df. PiperOrigin-RevId: 587936503 05 December 2023, 05:48:24 UTC
7fa0f46 [bazel] Add a BUILD file for jax/extend, and add more granular targets for individual pieces of extend. In general we'd like to use more granular BUILD targets rather than larger monolithic targets. If nothing else, they interact better with pytype. This change is in preparation for adding the JAX MLIR bindings to jax.extend, since they are something that JAX users sometimes need especially for defining custom ops. PiperOrigin-RevId: 587893573 05 December 2023, 01:48:50 UTC
f29ec4e Change seed for a test We got unlucky and hit a seed which happens to fail the KS test. PiperOrigin-RevId: 587885112 05 December 2023, 01:06:17 UTC
46bad2b Upgrade JAX Autodiff 101 05 December 2023, 00:00:13 UTC
a9bfbd3 Finish jax and jaxlib 0.4.21 release PiperOrigin-RevId: 587866580 04 December 2023, 23:51:58 UTC
193eb12 Use sys.version_info to guard Python 3.11 only code. This makes pytype happier since it understands sys.version_info, but didn't understand the previous hasattr() test. PiperOrigin-RevId: 587846307 04 December 2023, 22:48:46 UTC
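A minimal sketch of the guard pattern this commit message describes, not the code from the commit itself. The helper name `group_errors` and the choice of `ExceptionGroup` as the Python 3.11-only feature are assumptions made for illustration.

```python
import sys


def group_errors(message, errors):
    """Hypothetical helper: bundle several exceptions into one."""
    if sys.version_info >= (3, 11):
        # ExceptionGroup is a builtin from Python 3.11 onward. A comparison
        # on sys.version_info is something static checkers can evaluate,
        # unlike a hasattr() probe.
        return ExceptionGroup(message, errors)
    # Older interpreters: fall back to a plain exception listing the causes.
    return RuntimeError(f"{message}: {[repr(e) for e in errors]}")


err = group_errors("two failures", [ValueError("a"), KeyError("b")])
```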
aa27048 Merge pull request #18809 from jakevdp:fix-numpy-warning PiperOrigin-RevId: 587845552 04 December 2023, 22:40:51 UTC
150ab68 Merge pull request #18580 from 8bitmp3:jax-docs-new-installation PiperOrigin-RevId: 587840299 04 December 2023, 22:24:34 UTC
8b74b93 Test: fix casting warning in betainc test 04 December 2023, 22:20:14 UTC
b1a8bc6 Upgrade JAX Installation doc 04 December 2023, 21:42:13 UTC
c4d2fc7 Replace device_buffers with addressable_shards in test because device_buffers is deprecated PiperOrigin-RevId: 587825636 04 December 2023, 21:36:12 UTC
baa7756 Use scoped disable_jit() in dynamic_api_test. This test was leaving jit disabled, affecting other tests. PiperOrigin-RevId: 587803847 04 December 2023, 20:20:36 UTC
b04fd31 Add option to pass in `unroll=True/False` into `scan` and `fori_loop`. PiperOrigin-RevId: 587795364 04 December 2023, 19:54:50 UTC
5b97960 Disable sanitizer builds of shape_poly_test. These take a very long time and sometimes timeout so it's probably not worth running them in CI. PiperOrigin-RevId: 587768399 04 December 2023, 18:42:56 UTC
54e3b76 Add support for unrolling to `lax.fori_loop` PiperOrigin-RevId: 587767613 04 December 2023, 18:34:53 UTC
5942e15 Prepare for 0.4.21 release PiperOrigin-RevId: 587767502 04 December 2023, 18:26:48 UTC
d91c13e Merge pull request #18795 from gnecula:test_export_grad PiperOrigin-RevId: 587730171 04 December 2023, 16:32:27 UTC
8a2d4a0 [export] Add and fix a test for exporting higher-order gradients with sharding There was a test for export with gradients, we changed the test to (a) export 2nd order gradient also, and (b) to export both with a mesh context and without a mesh context (using NamedSharding). This test currently fails, only in the case when we do NOT have a mesh context, as explained below: When exporting gradient functions, we first export the primal functions and we use the in/out-shardings to construct shardings of the gradient function. Since Exported shardings now contain only HloSharding objects, and to lower the gradient function we must use `pjit(vjp(f)).lower()`, we construct GSPMDSharding objects using the current devices and the HloSharding object from the Exported primal. However, these objects do not have the `_original_sharding` attribute. Later in `pjit._resource_typing_pjit` we attempt to `parse_flatten_op_sharding` using the mesh context (which is empty). This fails. This PR contains one workaround, to skip `parse_flatten_op_sharding` if the physical mesh of the `resource_env` is empty. Another, probably better solution, is to ensure that `resource_env` is `None` when there is no mesh context. That seemed reasonable, but currently the code returns an empty mesh from the resource_env if there is no mesh context. Changing this would have effects in more parts of the code, so I have not done it here, but it may be worth doing. 04 December 2023, 16:07:23 UTC
1d95e79 Disable export_harnesses_multi_platform_test under sanitizers. This test appears to hit some sort of LLVM bug on Sapphire Rapids CPUs. PiperOrigin-RevId: 587719850 04 December 2023, 15:54:35 UTC
5e0993c Merge pull request #18794 from olupton:qualname PiperOrigin-RevId: 587696743 04 December 2023, 14:23:09 UTC
3c0c6b7 Use qualified name if possible. 04 December 2023, 13:19:38 UTC
a137edc Update XLA dependency to use revision http://github.com/openxla/xla/commit/96c4f8749b6521ff3e2670d168c9095a5f6323c5. PiperOrigin-RevId: 587591208 04 December 2023, 05:18:37 UTC
bf08411 Update XLA dependency to use revision http://github.com/openxla/xla/commit/6f22d6a8d45ce00c2d6654fb5f99ca7f0f094b51. PiperOrigin-RevId: 587406410 03 December 2023, 05:40:59 UTC
61e79cd Merge pull request #18786 from gnecula:test_export_effects PiperOrigin-RevId: 587333913 02 December 2023, 18:49:08 UTC
b51b80e Merge pull request #18761 from gnecula:export_sharding PiperOrigin-RevId: 587330573 02 December 2023, 18:24:47 UTC
bd7c1aa [export] Improve the testing of exporting with effects We now test that when we call an Exported from a computation that already uses effects, the effects from the calling computation are identified with the events from the called Exported. 02 December 2023, 18:18:22 UTC
3eb3e2d [export] Simplify the handling of shardings in Exported. Previously, Exported contained tuples of `XlaCompatibleSharding` for the input and output shardings. These shardings contain references to JAX devices, which is too much for exporting purposes and in fact it gets in the way when we want to serialize the Exported. We change Exported to carry `xla_client.HloSharding` instead, which conveniently can be serialized to proto. We use the value `None` to denote an unspecified sharding. We also add `nr_devices` and then for exporting purposes we can construct actual `XlaCompatibleSharding` when we need to. 02 December 2023, 18:10:24 UTC
b822801 Merge pull request #18783 from gnecula:fix_indexing PiperOrigin-RevId: 587328492 02 December 2023, 18:09:29 UTC
32fb1b4 Remove the ml_program MLIR dialect from jaxlib. Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something. PiperOrigin-RevId: 587323808 02 December 2023, 17:29:39 UTC
d2f6261 Fix bug in indexing with slices that overflow, and add tests. This bug was introduced in #18679, and was not caught in unit tests because we were not testing cases when the slice needs to be clamped. 02 December 2023, 14:47:06 UTC
0418370 Put the input creating via `jnp` outside the pjit cpp cache miss counting code. PiperOrigin-RevId: 587256278 02 December 2023, 08:31:24 UTC
965ba05 Rewrite `tf.aliasing_output` when wrapping the main StableHLO function When the original main function has tokens and the wrapped main function does not, there are fewer outputs in the wrapped main than the original main. This is problematic for `tf.aliasing_output`, which is an argument attribute that stores result indexes to which the argument can alias. This CL makes the wrapper main creation rewrite `tf.aliasing_output` according to the new result indexes. The newly added test verifies that the aliasing indexes are correct across all supported serialization versions. Confirmed that the test fails without changes in export.py (versions 6, 7, and 8 fail and 9 passes). PiperOrigin-RevId: 587237842 02 December 2023, 06:42:37 UTC
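The index rewrite described in this commit message can be sketched as a plain old-to-new index mapping. This is an illustrative sketch only, not the logic in export.py: the function name `remap_aliasing_outputs`, its arguments, and the choice to map an alias of a dropped result to `None` are all assumptions.

```python
from typing import List, Optional, Sequence


def remap_aliasing_outputs(
    arg_aliases: Sequence[Optional[int]],
    result_is_token: Sequence[bool],
) -> List[Optional[int]]:
    """Translate per-argument result indexes after token results are dropped."""
    old_to_new = {}
    next_index = 0
    for old_index, is_token in enumerate(result_is_token):
        if not is_token:
            old_to_new[old_index] = next_index
            next_index += 1
    # Arguments with no alias stay None. An alias that pointed at a dropped
    # token result also becomes None (an assumption of this sketch).
    return [None if a is None else old_to_new.get(a) for a in arg_aliases]


# Original results: [token, r0, r1]; two arguments alias old results 1 and 2.
new_aliases = remap_aliasing_outputs([None, 1, 2], [True, False, False])
```

With one leading token result removed, every surviving result index shifts down by one, so the aliases 1 and 2 become 0 and 1.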
1aab108 Update XLA dependency to use revision http://github.com/openxla/xla/commit/33c73ddc733b38245f1e2ca46f092119615fb386. PiperOrigin-RevId: 587237109 02 December 2023, 06:34:53 UTC
f0bc7e0 Reverts f0382a5838f4526d21631e804f6fe576bfc3f97e PiperOrigin-RevId: 587231484 02 December 2023, 06:06:33 UTC
86661c8 Re-enable passing tests on TPU These have been working for a while. PiperOrigin-RevId: 587199000 02 December 2023, 03:26:58 UTC
e61f7a8 Merge pull request #18757 from jakevdp:astype PiperOrigin-RevId: 587167285 02 December 2023, 01:09:02 UTC
f5f885d Merge pull request #18774 from hawkinsp:nightlyci PiperOrigin-RevId: 587150147 02 December 2023, 00:05:58 UTC
595117b Add a test to check if arr.delete() is idempotent. PiperOrigin-RevId: 587121346 01 December 2023, 22:28:51 UTC
7274960 Remove nightly NVIDIA GPU multiprocess CI. This CI seems to be dead. 01 December 2023, 22:13:24 UTC
a999120 Improve error message when cudnn is not found. We infer a missing cudnn if cudnnGetVersion() returns 0, since the stub implementation in TSL will do that if the library isn't found (https://github.com/openxla/xla/blob/10a378f49978aa4ee4564ceb105d33694fd48202/third_party/tsl/tsl/cuda/cudnn_stub.cc#L58). PiperOrigin-RevId: 587056454 01 December 2023, 18:52:48 UTC
50c7223 Fix Windows build failure. The TPU extension didn't build because the MLIR Python binding code requires pybind11 to be included first on Windows, per https://github.com/llvm/llvm-project/blob/9584f5834499e6093797d4a28fde209f927ea556/mlir/include/mlir-c/Bindings/Python/Interop.h#L24 PiperOrigin-RevId: 587049246 01 December 2023, 18:31:53 UTC
95bc2ba Inline sigmoid, isfinite, and isnan in jaxprs. In the common case (real values) these are all single-expression jaxprs themselves, so putting them out of line just makes things more verbose. There's no reason to include stuff like this in a jaxpr:
```
cxd:bool[8,16] = pjit[
  jaxpr={ lambda ; cxe:f32[8,16]. let
      cxf:bool[8,16] = is_finite cxe
    in (cxf,) }
  name=isfinite
] cxc
```
PiperOrigin-RevId: 587047955 01 December 2023, 18:23:56 UTC
ada5fe5 Remove numpy-dispatch CI job & simplify build specification The numpy-dispatch approach has been superseded by the Python Array API (Tracked for JAX in https://github.com/google/jax/issues/18353). While we're here, we'll reduce the github CI to only two jobs: the oldest and newest supported Python versions. Other versions can be covered by Kokoro. PiperOrigin-RevId: 587041291 01 December 2023, 18:03:15 UTC
b3c579e Merge pull request #18762 from gnecula:poly_getitem_next PiperOrigin-RevId: 587018677 01 December 2023, 16:33:55 UTC
65fca0e [shape_poly] Add heuristics for deciding >= 0 The rules for deciding inequalities of symbolic expressions are incomplete. Here we add two heuristics that help decide the bounds checking of indices computed for indexing with slices: To decide whether an expression that contains `non_negative(e)` is >= 0, it is sufficient to show that the expression is >=0 if we replace the `non_negative(e)` with `0` and with `e`. To decide whether `floordiv(e, k)` is >= 0, when `k >= 0`, it is sufficient to show that `e` is >= 0. These are sufficient for the bounds checking that JAX is doing internally, but may not be for the cases when the user program does index computations using those operators. This enables us to re-enable the shape_poly indexing tests. 01 December 2023, 11:55:42 UTC
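A small numeric check of the two heuristics described in this commit message, offered as a sketch rather than the symbolic reasoning JAX performs. It assumes `non_negative(e)` means `max(e, 0)`; the function names and the sample expression `d - t` are made up for illustration. The floordiv check uses `k > 0`, since division by zero is undefined.

```python
def non_negative(e):
    # Assumed meaning: clamp from below at zero, i.e. max(e, 0).
    return max(e, 0)


def expr(t, d):
    # An arbitrary illustrative expression; t stands for non_negative(e).
    return d - t


def decide_ge0_with_non_negative(e, d):
    """Heuristic 1: check the expression with t replaced by 0 and by e."""
    return expr(0, d) >= 0 and expr(e, d) >= 0


checked = 0
for e in range(-6, 7):
    for d in range(-6, 7):
        if decide_ge0_with_non_negative(e, d):
            # non_negative(e) is always either 0 or e, so both cases are covered.
            assert expr(non_negative(e), d) >= 0
            checked += 1

# Heuristic 2: for k > 0, e >= 0 implies e // k >= 0.
floordiv_ok = all(e // k >= 0 for e in range(0, 20) for k in range(1, 8))
```

Both heuristics are sufficient conditions only: when the check returns `False`, the expression may still be non-negative.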
54fee48 Cast in/out shardings to tuple before passing to `Exported` ctor. PiperOrigin-RevId: 586951567 01 December 2023, 10:52:26 UTC
e60aa3b Merge pull request #18679 from gnecula:poly_getitem2 PiperOrigin-RevId: 586902301 01 December 2023, 07:07:31 UTC
7c76f16 Update XLA dependency to use revision http://github.com/openxla/xla/commit/6947d0481880b46f0deda592b1ad1fee365128a6. PiperOrigin-RevId: 586899328 01 December 2023, 06:58:46 UTC
2d1ce13 [shape_poly] Simplify the indexing with slice to make it compatible with shape polymorphism Currently, we do not support shape polymorphism when we index with a slice, e.g., `x[a:b:c]`, and instead we direct the user to use `lax.dynamic_slice`. This is only because so far we have not tried to ensure that the index and bounds checking computations in gather are compatible with shape polymorphism. The problem was that there were a lot of conditionals, e.g., `if start >= stop` that cannot be handled in general in presence of symbolic shapes. Here we introduce a new helper function `_preprocess_slice` to contain all the computations for the start and the size of the slice. To test that this does not break the JAX index computations, I ran the tests with `JAX_NUM_GENERATED_CASES=1000`, especially the `lax_numpy_indexer_test.py`. 01 December 2023, 06:40:07 UTC
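For concrete integer sizes, the start and size of a slice `x[a:b:c]` can be computed with Python built-ins, as sketched here. This is not JAX's `_preprocess_slice`, which must also handle symbolic dimensions; the helper name `slice_start_and_size` is hypothetical.

```python
def slice_start_and_size(s, axis_size):
    """Hypothetical helper: (start, step, size) of x[s] on an axis of length axis_size."""
    # slice.indices() resolves None and negative bounds and clamps
    # out-of-range bounds to the axis.
    start, stop, step = s.indices(axis_size)
    # The number of selected elements; 0 when the slice is empty.
    size = len(range(start, stop, step))
    return start, step, size


clamped = slice_start_and_size(slice(2, 100), 5)
reversed_all = slice_start_and_size(slice(None, None, -1), 4)
empty = slice_start_and_size(slice(3, 1), 5)
```

The `clamped` case covers a stop bound that overflows the axis, which is where a slice needs to be clamped.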
8ad774f Automate arguments for jax.distributed.initialize for cloud TPU environments. PiperOrigin-RevId: 586892544 01 December 2023, 06:25:00 UTC
d77cd9a Add jax.numpy.astype function 30 November 2023, 23:50:22 UTC
a07ed22 Merge pull request #18756 from jakevdp:skips PiperOrigin-RevId: 586795988 30 November 2023, 22:56:11 UTC
efb4924 Merge pull request #18754 from mattjj:fix-float0 PiperOrigin-RevId: 586781308 30 November 2023, 22:05:11 UTC
43ed74f rewrite test not to include float0 broadcast 30 November 2023, 21:53:13 UTC
d6154e5 [array-api] remove some test skips 30 November 2023, 21:28:08 UTC
5b3fc1b Merge pull request #18730 from jakevdp:dep-device PiperOrigin-RevId: 586761439 30 November 2023, 20:58:16 UTC
569f06c In Python 3.11, asyncio.run() always tries to convert the repr of a coroutine's result to an integer while fetching the SIGINT handler. This makes the test materialize the whole tensor in memory. This changes the test coroutine to return nothing to avoid triggering this bug. https://github.com/python/cpython/issues/112559 PiperOrigin-RevId: 586756112 30 November 2023, 20:37:12 UTC
97beb01 Deprecate the device() method of JAX arrays 30 November 2023, 19:43:02 UTC
4de07b3 Merge pull request #18753 from jakevdp:warnings-tests PiperOrigin-RevId: 586739682 30 November 2023, 19:37:24 UTC
53e66c1 Merge pull request #18752 from mattjj:shmap-remat-rule PiperOrigin-RevId: 586729063 30 November 2023, 19:02:20 UTC
dab6379 Merge pull request #18746 from olupton:fix-repeated-builds PiperOrigin-RevId: 586723814 30 November 2023, 18:47:45 UTC
bd46e5c Add `nb::arg` to nanobind definitions to generate better python annotations. PiperOrigin-RevId: 586721759 30 November 2023, 18:39:28 UTC
d2b4800 tests: improve warnings-related tests 30 November 2023, 18:35:24 UTC
5862852 [shard-map] add rewrite and replication checking rules for remat these rules enable shmap-of-remat with check_rep=True 30 November 2023, 18:15:48 UTC
11d7a2b Merge pull request #18741 from mattjj:shmap-test-fix PiperOrigin-RevId: 586710378 30 November 2023, 18:09:32 UTC
5c2635c [shard-map] fix test running broken by 0aec40a16fad02f084ef0cabd350db78b86b335e 30 November 2023, 17:56:34 UTC
e50c35d Fix repeatedly building JAX. Reproducer was essentially running `pip install .` twice in a row in the same source directory. Closes google/jax#18252. 30 November 2023, 12:31:17 UTC
fe237cd Update XLA dependency to use revision http://github.com/openxla/xla/commit/58e6b428e22e40c4100a7b66790fbe86dc9d7845. PiperOrigin-RevId: 586556206 30 November 2023, 06:44:13 UTC
d9a0cc0 Merge pull request #18731 from mattjj:shmap-custom-jvp-fix PiperOrigin-RevId: 586527134 30 November 2023, 04:24:54 UTC
b8f758e [shard-map] replace jaxpr interpreters with final-style-xform-of-eval-jaxpr 30 November 2023, 04:06:12 UTC
e624610 Replace apply_primitive internals with `jax.jit`. This allows deletion of a lot of code and leads to ~40% eager performance speedup. Benchmarks:
```
name  old time/op  new time/op  delta
eager_unary_dispatch  31.3µs ± 1%  19.4µs ± 6%  -37.91%  (p=0.016 n=4+5)
eager_unary  32.1µs ± 0%  19.8µs ± 4%  -38.26%  (p=0.016 n=4+5)
eager_binary_dispatch  35.9µs ± 1%  20.5µs ± 4%  -42.93%  (p=0.016 n=4+5)
eager_binary  36.6µs ± 1%  21.1µs ± 4%  -42.29%  (p=0.016 n=4+5)
jit_trivial_dispatch  3.87µs ± 2%  4.12µs ±25%  ~  (p=1.000 n=5+5)
jit_trivial  4.75µs ± 2%  4.82µs ±11%  ~  (p=0.690 n=5+5)
jit_simple_dispatch  2.95µs ± 2%  2.97µs ± 7%  ~  (p=1.000 n=5+5)
jit_simple  3.52µs ± 6%  3.51µs ± 5%  ~  (p=0.841 n=5+5)
jit_simple_dispatch_array  2.95µs ± 2%  2.96µs ± 6%  ~  (p=1.000 n=5+5)
jit_simple_array  3.46µs ± 2%  3.51µs ± 5%  ~  (p=0.690 n=5+5)
jit_small_matmul  3.01µs ± 1%  3.00µs ± 4%  ~  (p=0.548 n=5+5)
jit_big_matmul  34.0µs ±18%  35.5µs ±17%  ~  (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10  6.93µs ± 6%  6.80µs ± 6%  ~  (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100  47.7µs ± 7%  45.4µs ± 2%  ~  (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000  545µs ± 8%  516µs ± 2%  ~  (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000  1.12ms ± 7%  1.07ms ± 2%  ~  (p=0.237 n=10+8)
jit_simple_many_args/num_args:10  7.42µs ± 5%  7.23µs ± 2%  ~  (p=0.173 n=10+8)
jit_simple_many_args/num_args:100  48.4µs ± 7%  45.6µs ± 2%  ~  (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000  542µs ± 6%  524µs ± 8%  ~  (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000  1.12ms ± 7%  1.08ms ± 1%  ~  (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10  4.79µs ± 8%  4.98µs ±10%  ~  (p=0.421 n=5+5)
jit_simple_pruned_args_10  5.32µs ± 6%  5.30µs ± 4%  ~  (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100  24.7µs ± 6%  23.8µs ± 8%  ~  (p=0.548 n=5+5)
jit_simple_pruned_args_100  25.2µs ± 6%  24.4µs ± 8%  ~  (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000  238µs ± 7%  232µs ± 8%  ~  (p=0.841 n=5+5)
jit_simple_pruned_args_1000  240µs ± 7%  234µs ± 8%  ~  (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000  516µs ± 6%  497µs ± 1%  ~  (p=0.413 n=5+4)
jit_simple_pruned_args_2000  517µs ± 6%  505µs ± 7%  ~  (p=0.690 n=5+5)
jit_dispatch_without_transfer  719µs ± 9%  751µs ± 8%  ~  (p=0.222 n=5+5)
jit_dispatch_with_transfer  799µs ±14%  793µs ± 9%  ~  (p=1.000 n=5+5)
pmap_trivial_2_devices  49.9µs ±40%  48.2µs ±42%  ~  (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices  74.5µs ±24%  78.9µs ±29%  ~  (p=0.421 n=5+5)
pmap_trivial_8_devices  79.3µs ± 6%  82.7µs ±20%  ~  (p=0.841 n=5+5)
pmap_simple_2_devices  47.1µs ±17%  49.1µs ±20%  ~  (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices  73.4µs ±16%  76.8µs ±21%  ~  (p=0.690 n=5+5)
pmap_simple_8_devices  76.0µs ±10%  80.6µs ±29%  ~  (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args  1.12ms ±22%  1.08ms ±42%  ~  (p=0.841 n=5+5)
pmap_simple_8_devices_100_args  12.5ms ± 8%  12.8ms ±10%  ~  (p=1.000 n=5+5)
sda_index_1  413µs ± 1%  686µs ± 4%  +66.08%  (p=0.008 n=5+5)
sda_index_2  850µs ± 1%  1378µs ± 4%  +62.02%  (p=0.008 n=5+5)
sda_index_8  3.60ms ± 1%  5.69ms ± 4%  +58.00%  (p=0.008 n=5+5)
bench_shaped_abstractify  300µs ± 1%  305µs ± 3%  ~  (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int  6.45µs ± 1%  6.50µs ± 3%  ~  (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float  3.73µs ± 1%  3.73µs ± 3%  ~  (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32  4.97µs ± 1%  4.83µs ± 3%  ~  (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32  4.91µs ± 1%  4.75µs ± 0%  -3.30%  (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random  4.34µs ± 2%  4.31µs ± 3%  ~  (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32  3.94µs ± 1%  3.93µs ± 3%  ~  (p=0.548 n=5+5)
bench_xla_abstractify_enum  6.85µs ± 1%  7.06µs ± 7%  +3.07%  (p=0.032 n=5+5)
bench_are_op_shardings_equal  26.9µs ± 2%  27.0µs ± 3%  ~  (p=0.841 n=5+5)
bench_pjit_check_aval_sharding  691µs ± 2%  711µs ±13%  ~  (p=0.841 n=5+5)
bench_addressable_shards_index  656ns ± 4%  688ns ± 9%  ~  (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads  12.7ms ± 4%  10.7ms ± 1%  -15.48%  (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums  13.0ms ± 2%  11.3ms ± 6%  -13.71%  (p=0.008 n=5+5)
bench_slicing_compilation  12.1ms ± 1%  12.3ms ± 4%  ~  (p=0.690 n=5+5)
bench_slicing_compilation2  11.3ms ± 0%  11.5ms ± 6%  ~  (p=0.690 n=5+5)
bench_repeated_static_indexing  62.5ms ± 2%  40.8ms ± 8%  -34.77%  (p=0.008 n=5+5)
bench_repeated_static_slicing  46.7ms ± 1%  31.4ms ± 2%  -32.76%  (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1  2.72µs ± 2%  2.68µs ± 5%  ~  (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10  12.6µs ± 7%  12.3µs ± 3%  ~  (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100  109µs ± 3%  108µs ± 4%  ~  (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1  38.0µs ±26%  36.8µs ±19%  ~  (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10  93.3µs ±19%  96.6µs ±23%  ~  (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100  730µs ±16%  698µs ±48%  ~  (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1  3.29µs ± 2%  3.12µs ± 4%  -5.24%  (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10  13.0µs ± 1%  12.7µs ± 2%  ~  (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100  111µs ± 5%  110µs ±11%  ~  (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1  38.4µs ±19%  38.9µs ±24%  ~  (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10  91.3µs ±15%  96.9µs ±29%  ~  (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100  676µs ±20%  689µs ±41%  ~  (p=0.841 n=5+5)
host_local_array_to_global_array  196µs ± 6%  194µs ± 4%  ~  (p=0.548 n=5+5)
device_put  50.8µs ± 1%  50.7µs ± 4%  ~  (p=0.413 n=4+5)
device_put_sharded  176µs ± 0%  177µs ± 4%  ~  (p=0.190 n=4+5)
device_get_8_devices  3.96ms ± 4%  4.03ms ± 7%  ~  (p=0.413 n=4+5)
np_asarray_8_devices  3.34ms ±18%  3.30ms ±10%  ~  (p=0.548 n=5+5)
jax_array_arrays_8_devices  5.01ms ±10%  5.09ms ±21%  ~  (p=0.421 n=5+5)
batch_inplace_while_scatter  440µs ± 1%  439µs ± 1%  ~  (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice  454µs ± 0%  457µs ± 1%  ~  (p=0.905 n=4+5)
serial_dot_products  4.51µs ± 3%  4.41µs ± 2%  ~  (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding  26.6µs ± 1%  27.0µs ± 2%  ~  (p=0.056 n=5+5)
```
PiperOrigin-RevId: 586505950 30 November 2023, 02:07:13 UTC
b6c73f8 Merge pull request #18740 from mattjj:shmap-conv-rules PiperOrigin-RevId: 586499305 30 November 2023, 01:33:24 UTC
6f20c0a [shard-map] add conv replication rules fixes #18737 30 November 2023, 00:58:54 UTC
57e19db Merge pull request #18736 from mattjj:device-put-fixes PiperOrigin-RevId: 586490689 30 November 2023, 00:51:15 UTC
6814db5 CI: add more pre-commit checks 30 November 2023, 00:33:55 UTC
c9ab0bf fix grad device_put src inference, and a small device_put bug Co-authored-by: Yash Katariya <yashkatariya@google.com> 30 November 2023, 00:24:24 UTC
f0382a5 Merge pull request #18728 from jakevdp:dep-device-buffer PiperOrigin-RevId: 586481166 30 November 2023, 00:10:56 UTC
0aec40a Deprecate arr.device_buffer and arr.device_buffers 29 November 2023, 23:31:01 UTC
842ca2c Use process_count(backend) in local_devices(). Due to what is arguably a bug, multiple TPU devices in the same job can have the same process index. When determining a process count for, say, CPU, make sure we use the same backend to compute the process_count. Otherwise we might see an apparently out-of-range process index from another backend. We should perhaps fix the TPU backend not to do this, but that's going to be a bigger change. PiperOrigin-RevId: 586453157 29 November 2023, 22:24:06 UTC
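A rough sketch of why the process count has to be computed per backend. The `Device` class and both functions here are hypothetical stand-ins, not JAX's implementation; the sketch assumes the process count is taken as the largest process index on a backend plus one.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    # Hypothetical stand-in for a device record; not jax.Device.
    platform: str
    process_index: int


def process_count(devices, backend):
    """Assumed rule: highest process index on this backend, plus one."""
    return max(d.process_index for d in devices if d.platform == backend) + 1


def local_devices(devices, process_index, backend):
    """Devices of `backend` owned by `process_index`, validated per backend."""
    if process_index not in range(process_count(devices, backend)):
        raise ValueError(f"Unknown process_index {process_index}")
    return [d for d in devices
            if d.platform == backend and d.process_index == process_index]


all_devices = [Device("cpu", 0), Device("tpu", 0), Device("tpu", 3)]
cpu_count = process_count(all_devices, "cpu")
tpu_count = process_count(all_devices, "tpu")
tpu_local = local_devices(all_devices, 3, "tpu")
```

In this toy setup, process index 3 is valid for the TPU backend but would look out of range if it were checked against the CPU backend's count of 1.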
d6637da Disable test_memory_cosumption test PiperOrigin-RevId: 586426753 29 November 2023, 20:50:46 UTC
0fce77a Merge pull request #18708 from jakevdp:array-equal-dep PiperOrigin-RevId: 586357829 29 November 2023, 16:58:29 UTC
ef65ba8 Internal change PiperOrigin-RevId: 586312803 29 November 2023, 13:47:50 UTC
458a896 Always lower reduce_scatter_p as an HLO ReduceScatter. We don't need the fallback path for CPU: XLA:CPU already does its own lowering of ReduceScatter as AllReduce + DynamicSlice, and I plan to teach it a direct lowering in an upcoming change. PiperOrigin-RevId: 586311031 29 November 2023, 13:37:58 UTC
86d9398 Update XLA dependency to use revision http://github.com/openxla/xla/commit/422b8e7fb70e61abeb0d4eba66df4bbd5bc3a193. PiperOrigin-RevId: 586305722 29 November 2023, 13:14:13 UTC
528a90e Merge pull request #18724 from apaszke:downgrade-logging PiperOrigin-RevId: 586304793 29 November 2023, 13:06:18 UTC
1e961b8 Remove fallback path that lowers all_gather via psum. As far as I can tell this is no longer necessary on GPU, which handles arbitrary allgather dimensions (by making the dimension the major-most dimension in layout assignment), and on CPU, where at present XLA would do the same lowering JAX would. I'm planning to improve the XLA:CPU lowering in a subsequent change. PiperOrigin-RevId: 586291911 29 November 2023, 12:14:11 UTC
d80c15a Downgrade a bunch of logging to DEBUG The logs related to compilation cache ended up being quite chatty, which is quite unlike the other logs in JAX. This downgrades a bunch of them to debug, as they can always be enabled independently using JAX config. This should also fix the recent failures in logging_test.py. 29 November 2023, 12:10:53 UTC
8dfbf90 [Pallas/Mosaic] Add support for barrier semaphores PiperOrigin-RevId: 586289340 29 November 2023, 12:04:11 UTC
5bcf231 Update XLA dependency to use revision http://github.com/openxla/xla/commit/1317a629f339500ab50665cfb35e9c3ec00215fa. PiperOrigin-RevId: 586206098 29 November 2023, 06:23:33 UTC
896d4cf Disable task_using_cache_metric unit test while debugging. This test is failing in the OSS environment. Temporarily disabling the test while debugging. PiperOrigin-RevId: 586144501 29 November 2023, 01:04:23 UTC