849f837 | jax authors | 30 September 2022, 23:25:17 UTC | Merge pull request #12609 from skye:jax2tf_import PiperOrigin-RevId: 478104104 | 30 September 2022, 23:25:17 UTC |
0a69c9a | Skye Wanderman-Milne | 30 September 2022, 22:55:22 UTC | Fix jax2tf import so it works with both the latest tensorflow release (2.10.0) and tf-nightly | 30 September 2022, 22:55:22 UTC |
fb8558c | Yash Katariya | 30 September 2022, 21:20:57 UTC | Add jax_array coverage to debug_nans_test PiperOrigin-RevId: 478079509 | 30 September 2022, 21:21:32 UTC |
ec41de2 | jax authors | 30 September 2022, 20:33:03 UTC | Merge pull request #12603 from jbushago:patch-1 PiperOrigin-RevId: 478068347 | 30 September 2022, 20:33:03 UTC |
ea77c45 | jax authors | 30 September 2022, 20:17:47 UTC | Merge pull request #12602 from skye:colab PiperOrigin-RevId: 478063554 | 30 September 2022, 20:17:47 UTC |
8a1e0ed | jax authors | 30 September 2022, 20:11:34 UTC | Merge pull request #12594 from skye:cache_warnings PiperOrigin-RevId: 478063392 | 30 September 2022, 20:11:34 UTC |
15e5f38 | Skye Wanderman-Milne | 27 September 2022, 20:59:08 UTC | Make persistent compilation cache warn instead of raising an error on cache read/write failures Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising an exception instead of warning. Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled. | 30 September 2022, 18:38:22 UTC |
2038988 | jbushago | 30 September 2022, 18:14:05 UTC | Fix typo in faq.rst. Fixed a small typo in the FAQ: "inthe" -> "in the". | 30 September 2022, 18:14:05 UTC |
0cc4066 | Skye Wanderman-Milne | 30 September 2022, 17:43:58 UTC | Pin default jax.tools.colab_tpu.setup_tpu driver version. Prior to this change, we were defaulting to the TPU nightly driver version. We should instead pin to the version associated with the default jaxlib version that Colab uses. | 30 September 2022, 17:45:49 UTC |
aafc77d | Yash Katariya | 30 September 2022, 16:55:25 UTC | Improve the checks done in `Array` and apply them to all `Sharding`s rather than just `XLACompatibleSharding`. Also check the symmetric difference of sharding and `_arrays` devices. PiperOrigin-RevId: 478017409 | 30 September 2022, 16:56:16 UTC |
4e51d2d | Jake VanderPlas | 30 September 2022, 01:30:28 UTC | Roll back https://github.com/google/jax/pull/12588 because of test failures PiperOrigin-RevId: 477871341 | 30 September 2022, 01:30:58 UTC |
d498bd1 | jax authors | 29 September 2022, 23:52:42 UTC | Merge pull request #12588 from jakevdp:random-annotations PiperOrigin-RevId: 477855302 | 29 September 2022, 23:52:42 UTC |
9ff570e | Yash Katariya | 29 September 2022, 23:29:16 UTC | Make debug_nans_test.py pass with jax_array=1. With the jax_array flag both enabled and disabled, and with --pdb_post_mortem, execution ends up in the same place. PiperOrigin-RevId: 477850567 | 29 September 2022, 23:29:58 UTC |
3c7d927 | Yash Katariya | 29 September 2022, 22:09:23 UTC | Disable dynamic_api_test and custom_object_test.py with jax.Array. Re-enable them once support for jax.Array is added. Also don't use xla_shape since it's deprecated. PiperOrigin-RevId: 477833061 | 29 September 2022, 22:09:55 UTC |
aed46f3 | Jake VanderPlas | 29 September 2022, 21:29:02 UTC | [typing] use jax.Array annotations in random.py | 29 September 2022, 21:29:02 UTC |
eb0fa40 | Yash Katariya | 29 September 2022, 19:31:48 UTC | Fix `process_allgather` to work with `jax.Array`. PiperOrigin-RevId: 477793014 | 29 September 2022, 19:32:21 UTC |
4f90af9 | jax authors | 29 September 2022, 18:31:48 UTC | Remove unused jax_unique_mhlo_module_names flag. PiperOrigin-RevId: 477778135 | 29 September 2022, 18:32:22 UTC |
a770db0 | jax authors | 29 September 2022, 17:56:14 UTC | Merge pull request #12579 from jakevdp:gather-unique PiperOrigin-RevId: 477767679 | 29 September 2022, 17:56:14 UTC |
7b49a3f | Yash Katariya | 29 September 2022, 16:53:50 UTC | Run tests in multiprocess_gpu_test only if the backend is GPU. PiperOrigin-RevId: 477750739 | 29 September 2022, 16:54:32 UTC |
1bc161a | Jake VanderPlas | 29 September 2022, 16:34:03 UTC | random.permutation: use unique_indices=True for efficiency | 29 September 2022, 16:34:03 UTC |
d49c5c3 | Jake VanderPlas | 29 September 2022, 16:33:38 UTC | jnp.take: add optional arguments forwarded to lax.gather | 29 September 2022, 16:33:38 UTC |
137384d | Mehdi Amini | 29 September 2022, 12:40:22 UTC | Update xla_sharding import path to new location We are moving the TensorFlow APIs outside of XLA and will remove the old path soon. PiperOrigin-RevId: 477701988 | 29 September 2022, 12:40:56 UTC |
de5dd1a | jax authors | 29 September 2022, 11:18:18 UTC | Merge pull request #12444 from LenaMartens:checkify-switch PiperOrigin-RevId: 477688623 | 29 September 2022, 11:18:18 UTC |
0639ace | lenamartens | 16 September 2022, 18:52:18 UTC | Raise cond index into tracing context in case of effects, so that even if the cond is not data dependent at all, it is included in the dynamic trace and its effects can be discharged. | 29 September 2022, 10:36:04 UTC |
48b8956 | jax authors | 29 September 2022, 04:23:38 UTC | Merge pull request #12566 from mattjj:djax-slice-sick PiperOrigin-RevId: 477626935 | 29 September 2022, 04:23:38 UTC |
163b7e2 | Yash Katariya | 29 September 2022, 04:16:58 UTC | Convert shardings in `jit` path to OpShardingSharding to avoid recompilation when semantically similar shardings are used in `jit`. PiperOrigin-RevId: 477626548 | 29 September 2022, 04:17:29 UTC |
500f8b7 | Yash Katariya | 28 September 2022, 23:59:43 UTC | Add HLOSharding's repr to OpShardingSharding since it's more compact. PiperOrigin-RevId: 477587916 | 29 September 2022, 00:00:16 UTC |
84768d2 | Yash Katariya | 28 September 2022, 23:33:43 UTC | Replace `jax.xla.DeviceArray` private type with the new public type `jax.Array`. PiperOrigin-RevId: 477582562 | 28 September 2022, 23:34:10 UTC |
a8826e6 | Matthew Johnson | 26 September 2022, 23:31:18 UTC | [dynamic-shapes] Add basic slicing support If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`. To do that, we can use a generalized version of `dynamic_slice` which allows dynamic slice sizes (where the result shape depends on those slice sizes). Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> | 28 September 2022, 22:55:51 UTC |
33dbf0e | jax authors | 28 September 2022, 21:14:30 UTC | Merge pull request #12565 from hawkinsp:release PiperOrigin-RevId: 477549228 | 28 September 2022, 21:14:30 UTC |
b49e31a | Peter Hawkins | 28 September 2022, 18:49:22 UTC | Update version numbers after release. | 28 September 2022, 18:49:22 UTC |
c89cb5d | Yash Katariya | 28 September 2022, 15:57:07 UTC | Use `Array` in `__repr__` instead of the class name which is `ArrayImpl`. PiperOrigin-RevId: 477465432 | 28 September 2022, 15:57:53 UTC |
0282b4b | jax authors | 28 September 2022, 15:00:20 UTC | Merge pull request #12538 from jakevdp:bundle-pyi PiperOrigin-RevId: 477453094 | 28 September 2022, 15:00:20 UTC |
aafc70d | jax authors | 28 September 2022, 13:50:19 UTC | Merge pull request #12556 from hawkinsp:rocm PiperOrigin-RevId: 477440001 | 28 September 2022, 13:50:19 UTC |
5fe7a54 | jax authors | 28 September 2022, 13:50:06 UTC | Merge pull request #12555 from hawkinsp:release PiperOrigin-RevId: 477439236 | 28 September 2022, 13:50:06 UTC |
39eabe8 | jax authors | 28 September 2022, 13:43:45 UTC | Merge pull request #12552 from hawkinsp:nccl PiperOrigin-RevId: 477439228 | 28 September 2022, 13:43:45 UTC |
f7bafb3 | Peter Hawkins | 28 September 2022, 13:40:57 UTC | Disable multiprocess_gpu_test that fails on ROCm. | 28 September 2022, 13:40:57 UTC |
8d86436 | Peter Hawkins | 28 September 2022, 13:33:52 UTC | jax/jaxlib 0.3.20 release candidate. | 28 September 2022, 13:33:52 UTC |
eabb91e | Peter Hawkins | 28 September 2022, 13:06:04 UTC | Fix test failure in GPU CI if NCCL_DEBUG is enabled. If NCCL_DEBUG is enabled, NCCL prints extra status information. Make test accept this. | 28 September 2022, 13:06:04 UTC |
96abd9a | jax authors | 28 September 2022, 05:33:12 UTC | Merge pull request #12540 from sharadmv:cond-lowering-fix PiperOrigin-RevId: 477358889 | 28 September 2022, 05:33:12 UTC |
96a85bd | Yash Katariya | 28 September 2022, 05:26:51 UTC | Make addressable_shards a property like local_shards PiperOrigin-RevId: 477358276 | 28 September 2022, 05:27:19 UTC |
9489068 | jax authors | 28 September 2022, 05:16:02 UTC | Merge pull request #12546 from mattjj:issue12542 PiperOrigin-RevId: 477356925 | 28 September 2022, 05:16:02 UTC |
ddeaa8d | Sharad Vikram | 27 September 2022, 22:19:03 UTC | Fix lowering bug in effectful batched cond and add tests | 28 September 2022, 05:12:13 UTC |
b4e1d0a | Yash Katariya | 28 September 2022, 04:31:58 UTC | Propagate `name` through ExecuteReplicated for `dispatch.check_special` PiperOrigin-RevId: 477351323 | 28 September 2022, 04:32:32 UTC |
b175e11 | Matthew Johnson | 28 September 2022, 03:39:19 UTC | [c++ jit] only set use_fastpath in cache_miss if all args are DeviceArrays fixes #12542 Co-authored-by: Peter Hawkins <phawkins@google.com> Co-authored-by: Kuangyuan Chen <chky@google.com> | 28 September 2022, 03:51:07 UTC |
933b6a2 | Yash Katariya | 28 September 2022, 03:27:11 UTC | Fix the bug where XLA doesn't provide shardings for all the outputs if all the elements in the output tuple have the same sharding. In that case, XLA runs the `FusionTupleDeduplicator`, which puts the sharding on ROOT instead of on the tuple. PiperOrigin-RevId: 477343328 | 28 September 2022, 03:27:39 UTC |
c8bff11 | Yash Katariya | 28 September 2022, 02:18:49 UTC | Add `addressable_` counterparts of `local_` to GDA to make it easier for users to move to Array as both will have the same API. PiperOrigin-RevId: 477332697 | 28 September 2022, 02:19:29 UTC |
e4f2bff | Yash Katariya | 27 September 2022, 23:01:46 UTC | Disintegrate `Array` into DeviceBuffers inside GDA. This is required for backwards compatibility, since users can create GDAs and pass them to pjit even when Array is switched on. PiperOrigin-RevId: 477297406 | 27 September 2022, 23:02:23 UTC |
0919a67 | jax authors | 27 September 2022, 20:31:05 UTC | Merge pull request #12534 from google:update-pypi PiperOrigin-RevId: 477260550 | 27 September 2022, 20:31:05 UTC |
6e6fb10 | Jake VanderPlas | 27 September 2022, 19:55:20 UTC | setup: bundle *.pyi files with distribution | 27 September 2022, 19:55:42 UTC |
d028d93 | Skye Wanderman-Milne | 27 September 2022, 18:00:27 UTC | Update version and changelog for jax 0.3.19 release | 27 September 2022, 18:00:27 UTC |
9e4114f | Yash Katariya | 27 September 2022, 17:06:10 UTC | Move `array.py` and `sharding.py` from `experimental/` to `_src/`. PiperOrigin-RevId: 477201711 | 27 September 2022, 17:06:52 UTC |
0e11688 | jax authors | 27 September 2022, 15:38:46 UTC | Merge pull request #12382 from jakevdp:reduction-dtype PiperOrigin-RevId: 477179725 | 27 September 2022, 15:38:46 UTC |
1bcf8d6 | jax authors | 27 September 2022, 01:14:56 UTC | Merge pull request #12497 from mattjj:djax-dag-fix1 PiperOrigin-RevId: 477038279 | 27 September 2022, 01:14:56 UTC |
e42247b | jax authors | 27 September 2022, 01:08:45 UTC | Merge pull request #12524 from sharadmv:lax-import-fix PiperOrigin-RevId: 477038211 | 27 September 2022, 01:08:45 UTC |
389a2e5 | Yash Katariya | 27 September 2022, 00:49:48 UTC | Add a backwards compat path for `op_sharding.clone()` because it doesn't exist with the latest jaxlib on pypi PiperOrigin-RevId: 477034758 | 27 September 2022, 00:50:19 UTC |
1e7ca8f | Matthew Johnson | 23 September 2022, 21:21:18 UTC | fix bug in djax type signature inference logic Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> | 27 September 2022, 00:48:25 UTC |
1d895b2 | Sharad Vikram | 27 September 2022, 00:29:08 UTC | Fix lax imports | 27 September 2022, 00:32:44 UTC |
cbf34cb | Yash Katariya | 26 September 2022, 23:17:26 UTC | Rename the concrete class `Array` to `ArrayImpl` PiperOrigin-RevId: 477017236 | 26 September 2022, 23:18:30 UTC |
71bcabe | Tianjian Lu | 26 September 2022, 23:01:47 UTC | [sparse] Add BCSR format template. PiperOrigin-RevId: 477013899 | 26 September 2022, 23:02:16 UTC |
82636b0 | jax authors | 26 September 2022, 22:23:02 UTC | Merge pull request #12523 from jakevdp:fix-build PiperOrigin-RevId: 477005157 | 26 September 2022, 22:23:02 UTC |
6cae54f | Jake VanderPlas | 26 September 2022, 22:13:12 UTC | Fix bazel build alias | 26 September 2022, 22:13:12 UTC |
d63a944 | Peter Hawkins | 26 September 2022, 21:38:06 UTC | Change jax_jit_test to be a jax_test() under Bazel that works across backends. Make it pass under TPU if x64 types are enabled. PiperOrigin-RevId: 476994286 | 26 September 2022, 21:38:35 UTC |
265b39d | Jake VanderPlas | 26 September 2022, 21:17:43 UTC | Add pytype_srcs to main jax BUILD rule PiperOrigin-RevId: 476989241 | 26 September 2022, 21:18:13 UTC |
ddd8581 | jax authors | 26 September 2022, 20:41:31 UTC | Merge pull request #12480 from google:bug-template-gpu-smi PiperOrigin-RevId: 476979981 | 26 September 2022, 20:41:31 UTC |
1860f6d | Jake VanderPlas | 26 September 2022, 20:31:43 UTC | [x64] add promote_integers argument to jnp.prod & jnp.sum | 26 September 2022, 20:31:43 UTC |
69d1a2c | jax authors | 26 September 2022, 20:00:31 UTC | Merge pull request #12517 from skye:update-pypi PiperOrigin-RevId: 476969287 | 26 September 2022, 20:00:31 UTC |
b2b60d9 | Yash Katariya | 26 September 2022, 19:43:13 UTC | Add `make_array_from_single_device_arrays` to prepare for renaming the concrete `Array` to `ArrayImpl`. PiperOrigin-RevId: 476965287 | 26 September 2022, 19:43:59 UTC |
3c0d280 | Skye Wanderman-Milne | 26 September 2022, 19:38:32 UTC | Update version and changelog for jax 0.3.18 release | 26 September 2022, 19:43:39 UTC |
2a7b319 | Roy Frostig | 26 September 2022, 18:06:29 UTC | add `nvidia-smi` question to bug template | 26 September 2022, 18:06:29 UTC |
e034432 | jax authors | 26 September 2022, 17:04:14 UTC | Merge pull request #12513 from inoryy:patch-4 PiperOrigin-RevId: 476923412 | 26 September 2022, 17:04:14 UTC |
7962b01 | jax authors | 26 September 2022, 16:53:40 UTC | Merge pull request #12485 from LenaMartens:checkify-lower PiperOrigin-RevId: 476922387 | 26 September 2022, 16:53:40 UTC |
27e3981 | lenamartens | 23 September 2022, 14:10:41 UTC | lowerable errors behind a config flag. | 26 September 2022, 16:34:27 UTC |
8bcf358 | Roman Ring | 26 September 2022, 16:14:09 UTC | Remove unused _remat_static_argnums import. | 26 September 2022, 16:14:09 UTC |
78ecc14 | lenamartens | 23 September 2022, 13:30:49 UTC | Lowerable checks!! | 26 September 2022, 15:54:18 UTC |
28672cc | jax authors | 26 September 2022, 15:50:13 UTC | Merge pull request #12496 from mattjj:improve-leak-checker-2 PiperOrigin-RevId: 476907407 | 26 September 2022, 15:50:13 UTC |
9c66569 | jax authors | 26 September 2022, 15:23:02 UTC | Merge pull request #12468 from LenaMartens:checkify-but-better PiperOrigin-RevId: 476901601 | 26 September 2022, 15:23:02 UTC |
2df61b1 | jax authors | 26 September 2022, 15:07:11 UTC | Merge pull request #12421 from jakevdp:jax-array PiperOrigin-RevId: 476898184 | 26 September 2022, 15:07:11 UTC |
0cb233e | Jake VanderPlas | 23 September 2022, 16:59:46 UTC | Add initial jax.Array base class for instance checks & annotation | 26 September 2022, 14:48:43 UTC |
2180710 | jax authors | 26 September 2022, 14:24:44 UTC | Merge pull request #12511 from hawkinsp:release PiperOrigin-RevId: 476889960 | 26 September 2022, 14:24:44 UTC |
bcd36d8 | Peter Hawkins | 26 September 2022, 14:10:57 UTC | Jax and jaxlib 0.3.18 release candidate. | 26 September 2022, 14:10:57 UTC |
53de057 | jax authors | 26 September 2022, 13:58:46 UTC | Merge pull request #12510 from hawkinsp:context PiperOrigin-RevId: 476884674 | 26 September 2022, 13:58:46 UTC |
f4bc663 | Peter Hawkins | 26 September 2022, 13:48:56 UTC | Wrap multiprocess test popen() uses in a context manager. Ensures that resources from popen() are cleaned up. | 26 September 2022, 13:48:56 UTC |
ec15e83 | jax authors | 26 September 2022, 03:53:45 UTC | Wraps calls to lax.xeinsum and _einsum in a named call with their 'spec', the string specifying the computation. Makes xprof traces more interpretable. PiperOrigin-RevId: 476796185 | 26 September 2022, 03:54:17 UTC |
7c85ca3 | Yash Katariya | 24 September 2022, 00:31:00 UTC | Only look at hlo_modules for output sharding if there is more than 1 device because if there is only 1 device, the spmd partitioner won't run. PiperOrigin-RevId: 476497929 | 24 September 2022, 00:31:33 UTC |
8ee7129 | Peter Hawkins | 24 September 2022, 00:11:10 UTC | Fix jnp.unwrap() test failures on GPU. A recent XLA change allows XLA to use excess precision on GPU, which caused CompileAndCheck to report noticeable numerical changes for bfloat16. In passing, also enable the comparison against NumPy test for bfloat16 by using a wrapper function. PiperOrigin-RevId: 476494989 | 24 September 2022, 00:11:51 UTC |
d2fcfb6 | jax authors | 23 September 2022, 21:51:11 UTC | Merge pull request #12407 from hirwa-nshuti:docs-fix PiperOrigin-RevId: 476467728 | 23 September 2022, 21:51:11 UTC |
03abcc7 | Matthew Johnson | 23 September 2022, 19:42:15 UTC | fix typo in test | 23 September 2022, 21:43:24 UTC |
e76aa77 | jax authors | 23 September 2022, 20:38:59 UTC | Merge pull request #12437 from sudhakarsingh27:add_multi_host_pjit_tests PiperOrigin-RevId: 476451469 | 23 September 2022, 20:38:59 UTC |
1fa0dda | Yash Katariya | 23 September 2022, 20:29:47 UTC | Return single device Arrays from `.device_buffer` and `.device_buffers`. PiperOrigin-RevId: 476449591 | 23 September 2022, 20:30:26 UTC |
43bbce0 | jax authors | 23 September 2022, 20:09:26 UTC | Merge pull request #12486 from hawkinsp:debugging PiperOrigin-RevId: 476445041 | 23 September 2022, 20:09:26 UTC |
737327a | jax authors | 23 September 2022, 19:58:03 UTC | Merge pull request #12490 from mattjj:improve-leak-checker PiperOrigin-RevId: 476442352 | 23 September 2022, 19:58:03 UTC |
b6ef90f | Matthew Johnson | 23 September 2022, 18:24:13 UTC | fix leak checker internal error The issue was that partial_eval.py's _memoize, used in custom_jvp, was made into an identity function when config.jax_check_tracer_leaks was enabled, because memoizing would keep references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger the leak checker (which checks whether any references to the main trace persist after tracing of the user function finishes). But after #7345, the leak checker should only trigger when actual Tracers are leaked, so disabling the memoization when jax_check_tracer_leaks is active should no longer be necessary. (These PR numbers seem out of order! We're not sure why.) Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> | 23 September 2022, 19:33:45 UTC |
4dd0d85 | Sudhakar | 23 September 2022, 19:11:56 UTC | add multihost pjit tests | 23 September 2022, 19:11:56 UTC |
a6b24b3 | Jake VanderPlas | 23 September 2022, 19:06:35 UTC | Add regression test for lax.rev simplification error PiperOrigin-RevId: 476430486 | 23 September 2022, 19:07:15 UTC |
ecb27a9 | Yash Katariya | 23 September 2022, 18:40:01 UTC | Update the `_check_special` code to not use xla_shape since it's deprecated and does not work with Array. PiperOrigin-RevId: 476422732 | 23 September 2022, 18:40:32 UTC |
d078f3f | jax authors | 23 September 2022, 18:31:37 UTC | Merge pull request #12478 from sharadmv:sharding-docs PiperOrigin-RevId: 476420315 | 23 September 2022, 18:31:37 UTC |
e8865c8 | jax authors | 23 September 2022, 17:55:10 UTC | Merge pull request #12481 from kho:changelist/476272494 PiperOrigin-RevId: 476411483 | 23 September 2022, 17:55:10 UTC |
c823151 | Ke Wu | 23 September 2022, 04:33:07 UTC | Allow transpose axes to be negative to match (undocumented) NumPy behavior | 23 September 2022, 17:18:23 UTC |
0c08547 | Tres Popp | 23 September 2022, 16:34:12 UTC | Modify CorrCoef test to not rely on floating point representation of 1/3 The operation computed an average over a dimension of size 3. Compilers turn this into a multiplication by 1/3, but 1/3 cannot be represented exactly in floating point. That made this test case rely on a very precise result from a calculation that is not exactly representable. PiperOrigin-RevId: 476391389 | 23 September 2022, 16:39:01 UTC |