b0046de | jax authors | 06 March 2024, 23:06:15 UTC | Support automating single slice GKE TPU clusters PiperOrigin-RevId: 613354286 | 08 March 2024, 04:46:39 UTC |
5c99473 | jax authors | 08 March 2024, 02:53:09 UTC | Merge pull request #20127 from jakevdp:fix-script PiperOrigin-RevId: 613774212 | 08 March 2024, 02:53:09 UTC |
1ed5883 | jax authors | 08 March 2024, 02:37:24 UTC | Merge pull request #20108 from selamw1:modify-nn-doc PiperOrigin-RevId: 613770878 | 08 March 2024, 02:37:24 UTC |
2eff1f0 | Jake VanderPlas | 08 March 2024, 02:32:18 UTC | Guard script against execution on import | 08 March 2024, 02:32:18 UTC |
b22c549 | jax authors | 08 March 2024, 02:28:19 UTC | Merge pull request #20126 from jakevdp:fix-lint PiperOrigin-RevId: 613768945 | 08 March 2024, 02:28:19 UTC |
6d67aa2 | Jake VanderPlas | 08 March 2024, 02:18:04 UTC | Fix mypy errors | 08 March 2024, 02:18:04 UTC |
0e79f95 | Jake VanderPlas | 08 March 2024, 02:12:42 UTC | lint: fix unused import | 08 March 2024, 02:12:42 UTC |
8ac2913 | Selam Waktola | 06 March 2024, 23:43:42 UTC | minor modification for silu and swish func description. Update 'aka' only inside functions.py: modify "SiLU (a.k.a. swish) activation function." to "SiLU (aka swish) activation function." | 07 March 2024, 23:40:39 UTC |
c986cbc | Yash Katariya | 07 March 2024, 21:58:34 UTC | Disable the input sharding propagation temporarily PiperOrigin-RevId: 613694719 | 07 March 2024, 21:59:26 UTC |
00c231c | Yash Katariya | 07 March 2024, 21:33:13 UTC | Reuse the utility `_gspmd_to_named_sharding_via_mesh` in other places PiperOrigin-RevId: 613686995 | 07 March 2024, 21:38:42 UTC |
fa17dac | jax authors | 07 March 2024, 21:28:57 UTC | Merge pull request #20042 from abhinavgoel95:patch-1 PiperOrigin-RevId: 613685424 | 07 March 2024, 21:28:57 UTC |
a7a9f85 | Abhinav Goel | 07 March 2024, 19:55:55 UTC | Added license information | 07 March 2024, 19:55:55 UTC |
f0afc1b | Kanglan Tang | 07 March 2024, 18:16:51 UTC | Add an experimental build-only continuous cross compile build for Linux Aarch64 PiperOrigin-RevId: 613624879 | 07 March 2024, 18:17:43 UTC |
c94ec17 | Sergei Lebedev | 07 March 2024, 12:36:45 UTC | Added a missing cast in _compute_pointers_from_indices This should fix DecodeAttentionTest on H100. PiperOrigin-RevId: 613540340 | 07 March 2024, 12:37:36 UTC |
d5ffc7f | jax authors | 07 March 2024, 08:32:14 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/421b738e400a15c053b02924712a0e915b73cf7b. PiperOrigin-RevId: 613486135 | 07 March 2024, 08:33:07 UTC |
ea7c1c6 | Tomás Longeri | 07 March 2024, 01:52:22 UTC | [Mosaic] Add debug-assert-insertion back into pipeline (as an option) This was accidentally removed when moving the pipeline into C++ (into custom_call_emitter.cc) in cl/596464480 and cl/597332393. Also fix the dump prefix from post-hlo-vectorization to post-hlo-conversion, which was also an unintended change from these CLs. PiperOrigin-RevId: 613400543 | 07 March 2024, 01:53:07 UTC |
75f2f75 | Jevin Jiang | 06 March 2024, 22:25:17 UTC | [XLA:Mosaic] Support input offset (replicated, 0) in shapecast. PiperOrigin-RevId: 613340933 | 06 March 2024, 22:26:00 UTC |
e7eb207 | Peter Hawkins | 06 March 2024, 21:49:54 UTC | Change xla_bridge_test to expect a bytes FDO profile instead of a string in future jaxlib versions. Change in preparation for nanobind migration. PiperOrigin-RevId: 613329489 | 06 March 2024, 21:50:36 UTC |
ef8f4fc | jax authors | 06 March 2024, 19:58:30 UTC | Merge pull request #20105 from jakevdp:dead-code PiperOrigin-RevId: 613293830 | 06 March 2024, 19:58:30 UTC |
1cb8d31 | Yash Katariya | 06 March 2024, 19:41:34 UTC | Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays. Also comment out key reuse check in cpp dispatch since it's True for jax tests which prevent prng keys from taking Cpp dispatch. PiperOrigin-RevId: 613289252 | 06 March 2024, 19:42:40 UTC |
b349328 | Jake VanderPlas | 06 March 2024, 19:30:48 UTC | Remove some dead code | 06 March 2024, 19:30:48 UTC |
fc8dc83 | Tongfei Guo | 06 March 2024, 18:33:46 UTC | Re-enable disabled pjit tests due to MSAN failure. PiperOrigin-RevId: 613266308 | 06 March 2024, 18:36:04 UTC |
026b2d2 | jax authors | 06 March 2024, 18:16:40 UTC | Merge pull request #20103 from froystig:random-dce PiperOrigin-RevId: 613259747 | 06 March 2024, 18:16:40 UTC |
612d4d5 | jax authors | 06 March 2024, 17:54:25 UTC | Merge pull request #20101 from andportnoy:aportnoy/include-cstdint PiperOrigin-RevId: 613250941 | 06 March 2024, 17:54:25 UTC |
6240dc6 | Roy Frostig | 06 March 2024, 17:44:40 UTC | remove dead code in `random` | 06 March 2024, 17:44:40 UTC |
30973a9 | Sharad Vikram | 06 March 2024, 17:15:36 UTC | [Pallas] Pass in compiler params via explicit compiler_params argument instead of passing via **kwargs This is a change that makes the API a bit more intuitive and avoids footguns like accidentally passing in `in_spec` instead of `in_specs` because previously kwargs that weren't used by any downstream lowering would be ignored and users would get weird errors as a result. This change doesn't deprecate the old way of passing in compiler params but it will be deprecated soon after this. PiperOrigin-RevId: 613239439 | 06 March 2024, 17:16:22 UTC |
dcb58bb | Andrey Portnoy | 06 March 2024, 16:58:15 UTC | Include <cstdint> in files where it is used | 06 March 2024, 16:58:15 UTC |
d0e0ca1 | Sergei Lebedev | 06 March 2024, 14:14:30 UTC | Ensured that Pallas GPU tests only run in x32 mode We do not yet properly handle x64. PiperOrigin-RevId: 613190060 | 06 March 2024, 14:15:13 UTC |
8aefe5e | jax authors | 06 March 2024, 07:42:26 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/08b6a3deeb1a90da06c2ca23ac795d8ae188f474. PiperOrigin-RevId: 613101323 | 06 March 2024, 07:43:13 UTC |
1a193ea | Peter Hawkins | 06 March 2024, 02:23:49 UTC | Fix segfault in cuda_plugin_extension. The nanobind switch for the GPU callback code means that we are now using the NumPy APIs rather than pybind11's clone of them. It is important to initialize the NumPy APIs before using them in each module. PiperOrigin-RevId: 613036056 | 06 March 2024, 02:31:50 UTC |
c37fd1b | Yash Katariya | 06 March 2024, 02:21:31 UTC | Move Device_put tests out of a disabled test class and execute them (so they are tested) PiperOrigin-RevId: 613035562 | 06 March 2024, 02:22:18 UTC |
ca3f3f0 | Yash Katariya | 06 March 2024, 00:35:57 UTC | Make sure that if gspmd_sharding1 == gspmd_sharding2, then their hash also is equal. PiperOrigin-RevId: 613009976 | 06 March 2024, 00:36:49 UTC |
15da713 | jax authors | 05 March 2024, 23:56:16 UTC | [XLA:SPMD] Do not propagate sharding to parameter/output if it does not evenly partition the parameter/output. PiperOrigin-RevId: 612998062 | 05 March 2024, 23:57:05 UTC |
0c9bbe4 | jax authors | 05 March 2024, 23:46:01 UTC | Merge pull request #20087 from jakevdp:fix-scipy-doc PiperOrigin-RevId: 612994633 | 05 March 2024, 23:46:01 UTC |
f35538b | Jake VanderPlas | 05 March 2024, 23:22:42 UTC | DOC: fix two minor doc issues | 05 March 2024, 23:22:42 UTC |
20090dd | jax authors | 05 March 2024, 21:10:25 UTC | Merge pull request #20083 from mattjj:attrs-fix-tracer-lifetime PiperOrigin-RevId: 612944372 | 05 March 2024, 21:10:25 UTC |
c44dda8 | Matthew Johnson | 05 March 2024, 20:08:44 UTC | [attrs] fix tracer lifetime bug, fixes #20082 | 05 March 2024, 20:08:44 UTC |
67e3542 | jax authors | 05 March 2024, 19:57:58 UTC | Merge pull request #20080 from jakevdp:key-reuse-srcinfo PiperOrigin-RevId: 612920951 | 05 March 2024, 19:57:58 UTC |
baed4eb | jax authors | 05 March 2024, 19:40:05 UTC | Merge pull request #20081 from jakevdp:setup-tpu-err PiperOrigin-RevId: 612914123 | 05 March 2024, 19:40:05 UTC |
288aca1 | jax authors | 05 March 2024, 19:31:42 UTC | Merge pull request #20078 from mattjj:remat-saving-collectives-fix PiperOrigin-RevId: 612912239 | 05 March 2024, 19:31:42 UTC |
5005890 | Benjamin Kramer | 05 March 2024, 19:17:05 UTC | Enable more tests on H100 https://github.com/llvm/llvm-project/commit/20895965b2ed1bd037c64430dba98245ffa1232b fixed these PiperOrigin-RevId: 612907679 | 05 March 2024, 19:22:56 UTC |
d8a4ea4 | Jake VanderPlas | 05 March 2024, 19:20:38 UTC | Use shorter error message for jax.tools.colab_tpu.setup_tpu() | 05 March 2024, 19:20:38 UTC |
05f54b6 | Jevin Jiang | 05 March 2024, 19:13:42 UTC | [XLA:Mosaic] Use different MXU shape based on the target PiperOrigin-RevId: 612906617 | 05 March 2024, 19:14:24 UTC |
3d32262 | Matthew Johnson | 05 March 2024, 18:39:25 UTC | ignore NamedAxisEffect for remat and dce purposes | 05 March 2024, 19:04:23 UTC |
735ec63 | Jake VanderPlas | 05 March 2024, 19:02:39 UTC | [key reuse] improve error message using source_info_util | 05 March 2024, 19:02:39 UTC |
32fec82 | jax authors | 05 March 2024, 18:53:45 UTC | Merge pull request #20077 from jakevdp:fix-dunder-array PiperOrigin-RevId: 612895096 | 05 March 2024, 18:53:45 UTC |
430c7ed | jax authors | 05 March 2024, 18:44:08 UTC | Merge pull request #20070 from jakevdp:key-reuse-errs PiperOrigin-RevId: 612895094 | 05 March 2024, 18:44:08 UTC |
851b82b | Jake VanderPlas | 05 March 2024, 17:31:16 UTC | Add copy argument to Array.__array__ | 05 March 2024, 17:31:16 UTC |
bb91bf2 | Jake VanderPlas | 04 March 2024, 22:50:39 UTC | [key reuse] improve some key reuse errors. | 05 March 2024, 16:14:20 UTC |
28fa886 | jax authors | 05 March 2024, 07:04:30 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/9e8b8b45b5289fe2e486737e9154908bf781aa5c. PiperOrigin-RevId: 612715965 | 05 March 2024, 07:05:31 UTC |
843aa21 | jax authors | 05 March 2024, 02:12:40 UTC | Merge pull request #20071 from jakevdp:key-reuse-docs PiperOrigin-RevId: 612654740 | 05 March 2024, 02:12:40 UTC |
9996b1f | Chris Jones | 05 March 2024, 01:36:23 UTC | [jax_triton] Add parameter allowing user to compile for specific compute capability. PiperOrigin-RevId: 612647104 | 05 March 2024, 01:37:04 UTC |
bc3f123 | jax authors | 05 March 2024, 01:17:49 UTC | Merge pull request #20069 from jakevdp:key-reuse-equality PiperOrigin-RevId: 612642622 | 05 March 2024, 01:17:49 UTC |
9a4b0fc | Jake VanderPlas | 05 March 2024, 01:16:55 UTC | [key reuse] improve module docs | 05 March 2024, 01:16:55 UTC |
6207977 | Peter Hawkins | 05 March 2024, 00:59:05 UTC | Disable some tests that fail on H100 in CI. PiperOrigin-RevId: 612637375 | 05 March 2024, 00:59:52 UTC |
3fe65e2 | Philip Pham | 05 March 2024, 00:33:13 UTC | Pipe `tiled` through `all_to_all` primitive The `_all_to_all_transpose_rule` calls `all_to_all` which can accept a `tiled` argument. Thus, for the transpose to know the right value of `tiled` to pass, we need to plumb the `tiled` argument through the primitive and various interpreters, even though it's a no-op because the `tiled` argument is handled outside the primitive. It would be cleaner to handle `tiled` inside the primitive, but I will leave that for followup work. Fixes #15982. PiperOrigin-RevId: 612628600 | 05 March 2024, 00:33:56 UTC |
40038d6 | Yash Katariya | 04 March 2024, 23:34:22 UTC | Rename test PiperOrigin-RevId: 612609237 | 04 March 2024, 23:35:02 UTC |
feda85d | Peter Hawkins | 04 March 2024, 22:07:56 UTC | Replace references to xla/python/status_casters.h with xla/pjrt/status_casters.h, which is its current home. PiperOrigin-RevId: 612578488 | 04 March 2024, 22:11:01 UTC |
6f7be3c | Peter Hawkins | 04 March 2024, 22:00:36 UTC | Define lax.Precision directly in Python, rather than inheriting from a C++ type in jaxlib. Historically, we defined Precision to be an enum exported from jaxlib using pybind11, since that was the type the old XLA ComputationBuilder classes expected as input. But we build IR using StableHLO MLIR builders these days, and there's no reason for the JAX-level Precision type to match the XLA-internal one. In a future change I plan to change the definition of Precision in jaxlib to be defined using nanobind instead of pybind11. Nanobind defines its enum classes to be final by default, which precludes this inheritance, and that's probably a good design decision by nanobind. But as discussed above, there's no good reason to inherit in the first place. PiperOrigin-RevId: 612575404 | 04 March 2024, 22:01:31 UTC |
84d11d7 | Jake VanderPlas | 04 March 2024, 21:32:35 UTC | [key reuse] don't consume on equality check | 04 March 2024, 21:32:35 UTC |
67b0eb3 | Yash Katariya | 04 March 2024, 21:14:47 UTC | Improve pytree mismatch error in AOT PiperOrigin-RevId: 612560820 | 04 March 2024, 21:15:32 UTC |
2480ca3 | Abhinav Goel | 04 March 2024, 19:42:01 UTC | respond to reviewer's comments | 04 March 2024, 19:42:01 UTC |
a745b8e | jax authors | 04 March 2024, 19:06:25 UTC | Merge pull request #20067 from jakevdp:copy-failure PiperOrigin-RevId: 612513014 | 04 March 2024, 19:06:25 UTC |
ee963a7 | jax authors | 04 March 2024, 18:42:29 UTC | Merge pull request #20065 from google:dependabot/github_actions/actions/cache-4.0.1 PiperOrigin-RevId: 612503993 | 04 March 2024, 18:42:29 UTC |
32da56f | Jake VanderPlas | 04 March 2024, 18:39:38 UTC | jnp.array: fix failure under numpy 2.0 copy semantics | 04 March 2024, 18:39:38 UTC |
8a62918 | dependabot[bot] | 04 March 2024, 17:06:05 UTC | Bump actions/cache from 4.0.0 to 4.0.1 Bumps [actions/cache](https://github.com/actions/cache) from 4.0.0 to 4.0.1. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/13aacd865c20de90d75de3b17ebe84f7a17d57d2...ab5e6d0c87105b4c9c2047343972218f562e4319) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] <support@github.com> | 04 March 2024, 17:06:05 UTC |
46913d7 | Yash Katariya | 04 March 2024, 16:50:58 UTC | Reverts 18344e8647f3459ae6a559a1e0a322120ac50782 PiperOrigin-RevId: 612466742 | 04 March 2024, 16:51:54 UTC |
81363ce | jax authors | 04 March 2024, 16:40:13 UTC | Merge pull request #19808 from Micky774:cc_check PiperOrigin-RevId: 612463272 | 04 March 2024, 16:40:13 UTC |
63cc46d | Adam Paszke | 04 March 2024, 16:25:17 UTC | Treat both aarch and arm as possible ARM prefixes On macOS `platform.machine()` returns `arm64` instead of `aarch64`. PiperOrigin-RevId: 612458486 | 04 March 2024, 16:26:04 UTC |
0549f18 | Adam Paszke | 04 March 2024, 14:16:24 UTC | Refine regions_with_inaccuracies to account for ARM numerics differences cc @pearu PiperOrigin-RevId: 612424597 | 04 March 2024, 14:17:05 UTC |
7514d5c | jax authors | 04 March 2024, 13:47:41 UTC | [triton] Add clustering support and test PiperOrigin-RevId: 612417957 | 04 March 2024, 13:51:10 UTC |
18344e8 | Adam Paszke | 04 March 2024, 13:47:26 UTC | Reverts 5c9c57fd6ff747ea37a2b74ff327a48fb72b3e69 PiperOrigin-RevId: 612417903 | 04 March 2024, 13:50:55 UTC |
5283d4b | Sergei Lebedev | 04 March 2024, 13:41:29 UTC | Axis names are now tracked via an effect This allows propagating the names bottom up -- from equations to the jaxpr, instead of "discovering" them top-down by traversing (and rebuilding) the jaxpr via core.subst_axis_names. PiperOrigin-RevId: 612416803 | 04 March 2024, 13:42:03 UTC |
2dd5e9e | jax authors | 04 March 2024, 06:23:46 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/e2af8488f6d3115045129c00215a80c2ab674224. PiperOrigin-RevId: 612318379 | 04 March 2024, 06:24:30 UTC |
9fff9ae | Meekail Zain | 03 March 2024, 19:57:26 UTC | Update | 03 March 2024, 19:57:26 UTC |
0b70244 | Yash Katariya | 02 March 2024, 21:34:46 UTC | Thread out_avals to MeshExecutable PiperOrigin-RevId: 612037684 | 02 March 2024, 21:35:31 UTC |
8569b89 | jax authors | 02 March 2024, 07:44:18 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/bbf2f8bcfae6f08a248e5e415a4f6d8be2c14f33. PiperOrigin-RevId: 611934527 | 02 March 2024, 07:44:58 UTC |
ab83469 | Blake Hechtman | 02 March 2024, 06:39:42 UTC | [PALLAS] add test for large indexing. PiperOrigin-RevId: 611925093 | 02 March 2024, 06:40:24 UTC |
51a31e5 | Jevin Jiang | 02 March 2024, 01:33:24 UTC | [Pallas] Use scratch_shapes for scratch operands in flash attention kernel. PiperOrigin-RevId: 611884935 | 02 March 2024, 01:34:07 UTC |
28f84eb | jax authors | 01 March 2024, 23:23:29 UTC | Merge pull request #20044 from mattjj:mutable-arrays PiperOrigin-RevId: 611866507 | 01 March 2024, 23:23:29 UTC |
3a403f2 | Matthew Johnson | 01 March 2024, 19:07:45 UTC | [mutable-arrays] move MutableArray, add eager, improve tests, fix bug 1. move MutableArray to core.py, and some handlers to their respective files 2. fix a bug in aliasing setup (it was just broken before, now better test coverage) 3. add eager support by enabling get_p, swap_p, and addupdate_p impls 4. improve tests slightly | 01 March 2024, 23:03:23 UTC |
04f6bfa | Anselm Levskaya | 01 March 2024, 22:23:47 UTC | Prevent accidental upcasting in jax.nn.initializers. Currently distribution parameters such as stddev and scale are expected to be weakly typed scalars. When they're passed as float32 they can cause an upcast of the initialized arrays even when the dtype is specified as e.g. bfloat16. Some users were surprised by this. PiperOrigin-RevId: 611858446 | 01 March 2024, 22:24:26 UTC |
5c9c57f | Yash Katariya | 01 March 2024, 21:14:50 UTC | Allow partially specified shardings in `in_shardings` and `out_shardings` parameters of `jax.jit`. PiperOrigin-RevId: 611848778 | 01 March 2024, 21:15:26 UTC |
3a89557 | Yue Sheng | 01 March 2024, 19:38:37 UTC | Update `CompiledMemoryStats` in xla/python to include host memory stats and add a few tests to memories_test.py PiperOrigin-RevId: 611834035 | 01 March 2024, 19:39:25 UTC |
1ae2022 | Eugene Zhulenev | 01 March 2024, 18:28:11 UTC | [jax-triton] Do not capture jax-triton calls that require autotuning PiperOrigin-RevId: 611823473 | 01 March 2024, 18:28:47 UTC |
8e2a8b7 | jax authors | 01 March 2024, 18:12:33 UTC | Merge pull request #20043 from mattjj:select-transpose-extra-code-flag PiperOrigin-RevId: 611820937 | 01 March 2024, 18:12:33 UTC |
dd0ce6e | Matthew Johnson | 01 March 2024, 17:48:48 UTC | add upgrade (aka to-be-removed) flag for new select rule | 01 March 2024, 17:50:01 UTC |
2761f26 | Yash Katariya | 01 March 2024, 17:27:57 UTC | Set out_mut to `None` as default on `from_hlo` instead of in `__init__` of `MeshComputation` and correct the types too. PiperOrigin-RevId: 611814102 | 01 March 2024, 17:28:42 UTC |
5a732ad | Abhinav Goel | 01 March 2024, 17:14:11 UTC | Adding script to convert NVIDIA nsys profiles to pbtxt | 01 March 2024, 17:14:11 UTC |
cfeb113 | Yash Katariya | 01 March 2024, 16:39:49 UTC | Partially roll back propagating sharding to inputs because SPMD chooses wrong shardings when shape is not divisible by shard_shape. PiperOrigin-RevId: 611806938 | 01 March 2024, 16:40:28 UTC |
1615e7a | jax authors | 01 March 2024, 08:15:12 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/1fc43bb444ad46b9ebf93915d27c02743d3a8503. PiperOrigin-RevId: 611728098 | 01 March 2024, 08:16:01 UTC |
47f5375 | Yash Katariya | 01 March 2024, 06:57:01 UTC | Correctly set the version number PiperOrigin-RevId: 611713955 | 01 March 2024, 06:57:50 UTC |
2e83fed | jax authors | 01 March 2024, 06:18:05 UTC | Merge pull request #20026 from mattjj:mutable-arrays PiperOrigin-RevId: 611707543 | 01 March 2024, 06:18:05 UTC |
ab0f706 | Matthew Johnson | 26 February 2024, 22:46:05 UTC | [mutable-arrays] allow state effects in jit by building in run_state with help from @sharadmv, @yashkatariya, @dougalm, and others The basic strategy is to apply discharge_state when lowering a jaxpr with state effects to HLO, and update the dispatch path accordingly. Specifically: 1. in tests only for now, introduce a MutableArray data type; 2. teach jit to abstract it to a Ref(ShapedArray) type, register an input handler, etc; 3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing; 4. teach the output side of the dispatch path to drop those outputs. As an alternative to (3), we could potentially lower away the effects at a higher level, like in _pjit_lower_cached. They are similar because _pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation. I decided to do it in lower_sharding_computation mainly because that's closer to where we set up aliases, and I wanted to make mutable arrays correspond to aliased inputs/outputs on the XLA computation. | 01 March 2024, 05:50:19 UTC |
48e6e0d | Yash Katariya | 01 March 2024, 02:45:00 UTC | Raise an error if the names intersect in `save_and_offload_only_these_names` policy PiperOrigin-RevId: 611666221 | 01 March 2024, 02:45:43 UTC |
32bb3b0 | jax authors | 01 March 2024, 01:39:28 UTC | Use `$(RULEDIR)` to avoid an implicit dependency on `output_to_genfiles`. PiperOrigin-RevId: 611652089 | 01 March 2024, 01:40:18 UTC |
30d3bb4 | jax authors | 01 March 2024, 00:08:38 UTC | Merge pull request #19795 from jakevdp:key-reuse-eager PiperOrigin-RevId: 611626544 | 01 March 2024, 00:08:38 UTC |
d08e9a0 | Jake VanderPlas | 29 February 2024, 23:30:19 UTC | [key reuse] add eager checks | 29 February 2024, 23:30:19 UTC |
087f99a | Jieying Luo | 29 February 2024, 23:14:15 UTC | Support mocking number of GPUs in CUDA plugin. Also move reading jax config value to be right before the client is created. Previously they were read before calling register_plugin, which happens during import and before any call of jax.config.update. The decorator in mock_gpu_test was used wrongly. jtu.run_on_devices will create the client before jax.config.update is called, which is not desired. Removing the decorator will not fail CPU/TPU tests because the mesh will check the num_shard and the number of devices in the client and skip it if it does not match. generate_pjrt_gpu_plugin_options is only used in places that do not require compatibility, so there is no need to update the xla_client version. PiperOrigin-RevId: 611610915 | 29 February 2024, 23:15:06 UTC |
b2058d7 | jax authors | 29 February 2024, 20:22:54 UTC | Merge pull request #20025 from jakevdp:fix-key-reuse-test PiperOrigin-RevId: 611557123 | 29 February 2024, 20:22:54 UTC |