https://github.com/google/jax

b0046de Support automating single slice GKE TPU clusters PiperOrigin-RevId: 613354286 08 March 2024, 04:46:39 UTC
5c99473 Merge pull request #20127 from jakevdp:fix-script PiperOrigin-RevId: 613774212 08 March 2024, 02:53:09 UTC
1ed5883 Merge pull request #20108 from selamw1:modify-nn-doc PiperOrigin-RevId: 613770878 08 March 2024, 02:37:24 UTC
2eff1f0 Guard script against execution on import 08 March 2024, 02:32:18 UTC
b22c549 Merge pull request #20126 from jakevdp:fix-lint PiperOrigin-RevId: 613768945 08 March 2024, 02:28:19 UTC
6d67aa2 Fix mypy errors 08 March 2024, 02:18:04 UTC
0e79f95 lint: fix unused import 08 March 2024, 02:12:42 UTC
8ac2913 minor modification for silu and swish func description Update 'aka' only inside functions.py modify SiLU (a.k.a. swish) activation function. to SiLU (aka swish) activation function. 07 March 2024, 23:40:39 UTC
c986cbc Disable the input sharding propagation temporarily PiperOrigin-RevId: 613694719 07 March 2024, 21:59:26 UTC
00c231c Reuse the utility `_gspmd_to_named_sharding_via_mesh` in other places PiperOrigin-RevId: 613686995 07 March 2024, 21:38:42 UTC
fa17dac Merge pull request #20042 from abhinavgoel95:patch-1 PiperOrigin-RevId: 613685424 07 March 2024, 21:28:57 UTC
a7a9f85 Added license information 07 March 2024, 19:55:55 UTC
f0afc1b Add an experimental build-only continuous cross compile build for Linux Aarch64 PiperOrigin-RevId: 613624879 07 March 2024, 18:17:43 UTC
c94ec17 Added a missing cast in _compute_pointers_from_indices This should fix DecodeAttentionTest on H100. PiperOrigin-RevId: 613540340 07 March 2024, 12:37:36 UTC
d5ffc7f Update XLA dependency to use revision http://github.com/openxla/xla/commit/421b738e400a15c053b02924712a0e915b73cf7b. PiperOrigin-RevId: 613486135 07 March 2024, 08:33:07 UTC
ea7c1c6 [Mosaic] Add debug-assert-insertion back into pipeline (as an option) This was accidentally removed when moving the pipeline into C++ (into custom_call_emitter.cc) in cl/596464480 and cl/597332393. Also fix the dump prefix from post-hlo-vectorization to post-hlo-conversion, which was also an unintended change from these CLs. PiperOrigin-RevId: 613400543 07 March 2024, 01:53:07 UTC
75f2f75 [XLA:Mosaic] Support input offset (replicated, 0) in shapecast. PiperOrigin-RevId: 613340933 06 March 2024, 22:26:00 UTC
e7eb207 Change xla_bridge_test to expect a bytes FDO profile instead of a string in future jaxlib versions. Change in preparation for nanobind migration. PiperOrigin-RevId: 613329489 06 March 2024, 21:50:36 UTC
ef8f4fc Merge pull request #20105 from jakevdp:dead-code PiperOrigin-RevId: 613293830 06 March 2024, 19:58:30 UTC
1cb8d31 Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays. Also comment out key reuse check in cpp dispatch since it's True for jax tests, which prevents prng keys from taking Cpp dispatch. PiperOrigin-RevId: 613289252 06 March 2024, 19:42:40 UTC
b349328 Remove some dead code 06 March 2024, 19:30:48 UTC
fc8dc83 Re-enable disabled pjit tests due to MSAN failure. PiperOrigin-RevId: 613266308 06 March 2024, 18:36:04 UTC
026b2d2 Merge pull request #20103 from froystig:random-dce PiperOrigin-RevId: 613259747 06 March 2024, 18:16:40 UTC
612d4d5 Merge pull request #20101 from andportnoy:aportnoy/include-cstdint PiperOrigin-RevId: 613250941 06 March 2024, 17:54:25 UTC
6240dc6 remove dead code in `random` 06 March 2024, 17:44:40 UTC
30973a9 [Pallas] Pass in compiler params via explicit compiler_params argument instead of passing via **kwargs This is a change that makes the API a bit more intuitive and avoids footguns like accidentally passing in `in_spec` instead of `in_specs` because previously kwargs that weren't used by any downstream lowering would be ignored and users would get weird errors as a result. This change doesn't deprecate the old way of passing in compiler params but it will be deprecated soon after this. PiperOrigin-RevId: 613239439 06 March 2024, 17:16:22 UTC
dcb58bb Include <cstdint> in files where it is used 06 March 2024, 16:58:15 UTC
d0e0ca1 Ensured that Pallas GPU tests only run in x32 mode We do not yet properly handle x64. PiperOrigin-RevId: 613190060 06 March 2024, 14:15:13 UTC
8aefe5e Update XLA dependency to use revision http://github.com/openxla/xla/commit/08b6a3deeb1a90da06c2ca23ac795d8ae188f474. PiperOrigin-RevId: 613101323 06 March 2024, 07:43:13 UTC
1a193ea Fix segfault in cuda_plugin_extension. The nanobind switch for the GPU callback code means that we are now using the NumPy APIs rather than pybind11's clone of them. It is important to initialize the NumPy APIs before using them in each module. PiperOrigin-RevId: 613036056 06 March 2024, 02:31:50 UTC
c37fd1b Move Device_put tests out of a disabled test class and execute them (so they are tested) PiperOrigin-RevId: 613035562 06 March 2024, 02:22:18 UTC
ca3f3f0 Make sure that if gspmd_sharding1 == gspmd_sharding2, then their hash also is equal. PiperOrigin-RevId: 613009976 06 March 2024, 00:36:49 UTC
15da713 [XLA:SPMD] Do not propagate sharding to parameter/output if it does not evenly partition the parameter/output. PiperOrigin-RevId: 612998062 05 March 2024, 23:57:05 UTC
0c9bbe4 Merge pull request #20087 from jakevdp:fix-scipy-doc PiperOrigin-RevId: 612994633 05 March 2024, 23:46:01 UTC
f35538b DOC: fix two minor doc issues 05 March 2024, 23:22:42 UTC
20090dd Merge pull request #20083 from mattjj:attrs-fix-tracer-lifetime PiperOrigin-RevId: 612944372 05 March 2024, 21:10:25 UTC
c44dda8 [attrs] fix tracer lifetime bug, fixes #20082 05 March 2024, 20:08:44 UTC
67e3542 Merge pull request #20080 from jakevdp:key-reuse-srcinfo PiperOrigin-RevId: 612920951 05 March 2024, 19:57:58 UTC
baed4eb Merge pull request #20081 from jakevdp:setup-tpu-err PiperOrigin-RevId: 612914123 05 March 2024, 19:40:05 UTC
288aca1 Merge pull request #20078 from mattjj:remat-saving-collectives-fix PiperOrigin-RevId: 612912239 05 March 2024, 19:31:42 UTC
5005890 Enable more tests on H100 https://github.com/llvm/llvm-project/commit/20895965b2ed1bd037c64430dba98245ffa1232b fixed these PiperOrigin-RevId: 612907679 05 March 2024, 19:22:56 UTC
d8a4ea4 Use shorter error message for jax.tools.colab_tpu.setup_tpu() 05 March 2024, 19:20:38 UTC
05f54b6 [XLA:Mosaic] Use different MXU shape based on the target PiperOrigin-RevId: 612906617 05 March 2024, 19:14:24 UTC
3d32262 ignore NamedAxisEffect for remat and dce purposes 05 March 2024, 19:04:23 UTC
735ec63 [key reuse] improve error message using source_info_util 05 March 2024, 19:02:39 UTC
32fec82 Merge pull request #20077 from jakevdp:fix-dunder-array PiperOrigin-RevId: 612895096 05 March 2024, 18:53:45 UTC
430c7ed Merge pull request #20070 from jakevdp:key-reuse-errs PiperOrigin-RevId: 612895094 05 March 2024, 18:44:08 UTC
851b82b Add copy argument to Array.__array__ 05 March 2024, 17:31:16 UTC
bb91bf2 [key reuse] improve some key reuse errors. 05 March 2024, 16:14:20 UTC
28fa886 Update XLA dependency to use revision http://github.com/openxla/xla/commit/9e8b8b45b5289fe2e486737e9154908bf781aa5c. PiperOrigin-RevId: 612715965 05 March 2024, 07:05:31 UTC
843aa21 Merge pull request #20071 from jakevdp:key-reuse-docs PiperOrigin-RevId: 612654740 05 March 2024, 02:12:40 UTC
9996b1f [jax_triton] Add parameter allowing user to compile for specific compute capability. PiperOrigin-RevId: 612647104 05 March 2024, 01:37:04 UTC
bc3f123 Merge pull request #20069 from jakevdp:key-reuse-equality PiperOrigin-RevId: 612642622 05 March 2024, 01:17:49 UTC
9a4b0fc [key reuse] improve module docs 05 March 2024, 01:16:55 UTC
6207977 Disable some tests that fail on H100 in CI. PiperOrigin-RevId: 612637375 05 March 2024, 00:59:52 UTC
3fe65e2 Pipe `tiled` through `all_to_all` primitive The `_all_to_all_transpose_rule` calls `all_to_all` which can accept a `tiled` argument. Thus, for the transpose to know the right value of `tiled` to pass, we need to plumb the `tiled` argument through the primitive and various interpreters, even though it's a no-op because the `tiled` argument is handled outside the primitive. It would be cleaner to handle `tiled` inside the primitive, but I will leave that for followup work. Fixes #15982. PiperOrigin-RevId: 612628600 05 March 2024, 00:33:56 UTC
40038d6 Rename test PiperOrigin-RevId: 612609237 04 March 2024, 23:35:02 UTC
feda85d Replace references to xla/python/status_casters.h with xla/pjrt/status_casters.h, which is its current home. PiperOrigin-RevId: 612578488 04 March 2024, 22:11:01 UTC
6f7be3c Define lax.Precision directly in Python, rather than inheriting from a C++ type in jaxlib. Historically, we defined Precision to be an enum exported from jaxlib using pybind11, since that was the type the old XLA ComputationBuilder classes expected as input. But we build IR using StableHLO MLIR builders these days, and there's no reason for the JAX-level Precision type to match the XLA-internal one. In a future change I plan to change the definition of Precision in jaxlib to be defined using nanobind instead of pybind11. Nanobind defines its enum classes to be final by default, which precludes this inheritance, and that's probably a good design decision by nanobind. But as discussed above, there's no good reason to inherit in the first place. PiperOrigin-RevId: 612575404 04 March 2024, 22:01:31 UTC
84d11d7 [key reuse] don't consume on equality check 04 March 2024, 21:32:35 UTC
67b0eb3 Improve pytree mismatch error in AOT PiperOrigin-RevId: 612560820 04 March 2024, 21:15:32 UTC
2480ca3 respond to reviewer's comments 04 March 2024, 19:42:01 UTC
a745b8e Merge pull request #20067 from jakevdp:copy-failure PiperOrigin-RevId: 612513014 04 March 2024, 19:06:25 UTC
ee963a7 Merge pull request #20065 from google:dependabot/github_actions/actions/cache-4.0.1 PiperOrigin-RevId: 612503993 04 March 2024, 18:42:29 UTC
32da56f jnp.array: fix failure under numpy 2.0 copy semantics 04 March 2024, 18:39:38 UTC
8a62918 Bump actions/cache from 4.0.0 to 4.0.1 Bumps [actions/cache](https://github.com/actions/cache) from 4.0.0 to 4.0.1. - [Release notes](https://github.com/actions/cache/releases) - [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md) - [Commits](https://github.com/actions/cache/compare/13aacd865c20de90d75de3b17ebe84f7a17d57d2...ab5e6d0c87105b4c9c2047343972218f562e4319) --- updated-dependencies: - dependency-name: actions/cache dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] <support@github.com> 04 March 2024, 17:06:05 UTC
46913d7 Reverts 18344e8647f3459ae6a559a1e0a322120ac50782 PiperOrigin-RevId: 612466742 04 March 2024, 16:51:54 UTC
81363ce Merge pull request #19808 from Micky774:cc_check PiperOrigin-RevId: 612463272 04 March 2024, 16:40:13 UTC
63cc46d Treat both aarch and arm as possible ARM prefixes On macOS `platform.machine()` returns `arm64` instead of `aarch64`. PiperOrigin-RevId: 612458486 04 March 2024, 16:26:04 UTC
0549f18 Refine regions_with_inaccuracies to account for ARM numerics differences cc @pearu PiperOrigin-RevId: 612424597 04 March 2024, 14:17:05 UTC
7514d5c [triton] Add clustering support and test PiperOrigin-RevId: 612417957 04 March 2024, 13:51:10 UTC
18344e8 Reverts 5c9c57fd6ff747ea37a2b74ff327a48fb72b3e69 PiperOrigin-RevId: 612417903 04 March 2024, 13:50:55 UTC
5283d4b Axis names are now tracked via an effect This allows propagating the names bottom up -- from equations to the jaxpr, instead of "discovering" them top-down by traversing (and rebuilding) the jaxpr via core.subst_axis_names. PiperOrigin-RevId: 612416803 04 March 2024, 13:42:03 UTC
2dd5e9e Update XLA dependency to use revision http://github.com/openxla/xla/commit/e2af8488f6d3115045129c00215a80c2ab674224. PiperOrigin-RevId: 612318379 04 March 2024, 06:24:30 UTC
9fff9ae Update 03 March 2024, 19:57:26 UTC
0b70244 Thread out_avals to MeshExecutable PiperOrigin-RevId: 612037684 02 March 2024, 21:35:31 UTC
8569b89 Update XLA dependency to use revision http://github.com/openxla/xla/commit/bbf2f8bcfae6f08a248e5e415a4f6d8be2c14f33. PiperOrigin-RevId: 611934527 02 March 2024, 07:44:58 UTC
ab83469 [PALLAS] add test for large indexing. PiperOrigin-RevId: 611925093 02 March 2024, 06:40:24 UTC
51a31e5 [Pallas] Use scratch_shapes for scratch operands in flash attention kernel. PiperOrigin-RevId: 611884935 02 March 2024, 01:34:07 UTC
28f84eb Merge pull request #20044 from mattjj:mutable-arrays PiperOrigin-RevId: 611866507 01 March 2024, 23:23:29 UTC
3a403f2 [mutable-arrays] move MutableArray, add eager, improve tests, fix bug 1. move MutableArray to core.py, and some handlers to their respective files 2. fix a bug in aliasing setup (it was just broken before, now better test coverage) 3. add eager support by enabling get_p, swap_p, and addupdate_p impls 4. improve tests slightly 01 March 2024, 23:03:23 UTC
04f6bfa Prevent accidental upcasting in jax.nn.initializers. Currently distribution parameters such as stddev and scale are expected to be weakly typed scalars. When they're passed as float32 they can cause an upcast of the initialized arrays even when the dtype is specified as e.g. bfloat16. Some users were surprised by this. PiperOrigin-RevId: 611858446 01 March 2024, 22:24:26 UTC
5c9c57f Allow partially specified shardings in `in_shardings` and `out_shardings` parameters of `jax.jit`. PiperOrigin-RevId: 611848778 01 March 2024, 21:15:26 UTC
3a89557 Update `CompiledMemoryStats` in xla/python to include host memory stats and add a few tests to memories_test.py PiperOrigin-RevId: 611834035 01 March 2024, 19:39:25 UTC
1ae2022 [jax-triton] Do not capture jax-triton calls that require autotuning PiperOrigin-RevId: 611823473 01 March 2024, 18:28:47 UTC
8e2a8b7 Merge pull request #20043 from mattjj:select-transpose-extra-code-flag PiperOrigin-RevId: 611820937 01 March 2024, 18:12:33 UTC
dd0ce6e add upgrade (aka to-be-removed) flag for new select rule 01 March 2024, 17:50:01 UTC
2761f26 Set out_mut to `None` as default on `from_hlo` instead of in `__init__` of `MeshComputation` and correct the types too. PiperOrigin-RevId: 611814102 01 March 2024, 17:28:42 UTC
5a732ad Adding script to convert NVIDIA nsys profiles to pbtxt 01 March 2024, 17:14:11 UTC
cfeb113 Partially roll back propagating sharding to inputs because SPMD chooses wrong shardings when shape is not divisible by shard_shape. PiperOrigin-RevId: 611806938 01 March 2024, 16:40:28 UTC
1615e7a Update XLA dependency to use revision http://github.com/openxla/xla/commit/1fc43bb444ad46b9ebf93915d27c02743d3a8503. PiperOrigin-RevId: 611728098 01 March 2024, 08:16:01 UTC
47f5375 Correctly set the version number PiperOrigin-RevId: 611713955 01 March 2024, 06:57:50 UTC
2e83fed Merge pull request #20026 from mattjj:mutable-arrays PiperOrigin-RevId: 611707543 01 March 2024, 06:18:05 UTC
ab0f706 [mutable-arrays] allow state effects in jit by building in run_state with help from @sharadmv, @yashkatariya, @dougalm, and others The basic strategy is to apply discharge_state when lowering a jaxpr with state effects to HLO, and update the dispatch path accordingly. Specifically: 1. in tests only for now, introduce a MutableArray data type; 2. teach jit to abstract it to a Ref(ShapedArray) type, register an input handler, etc; 3. call discharge_state in `lower_sharding_computation` to lower a jaxpr with refs to a jaxpr (and then to an HLO) with extra outputs, and set up aliasing; 4. teach the output side of the dispatch path to drop those outputs. As an alternative to (3), we could potentially lower away the effects at a higher level, like in _pjit_lower_cached. They are similar because _pjit_lower_cached is the only (non-xmap) caller of lower_sharding_computation. I decided to do it in lower_sharding_computation mainly because that's closer to where we set up aliases, and I wanted to make mutable arrays correspond to aliased inputs/outputs on the XLA computation. 01 March 2024, 05:50:19 UTC
48e6e0d Raise an error if the names intersect in `save_and_offload_only_these_names` policy PiperOrigin-RevId: 611666221 01 March 2024, 02:45:43 UTC
32bb3b0 Use `$(RULEDIR)` to avoid an implicit dependency on `output_to_genfiles`. PiperOrigin-RevId: 611652089 01 March 2024, 01:40:18 UTC
30d3bb4 Merge pull request #19795 from jakevdp:key-reuse-eager PiperOrigin-RevId: 611626544 01 March 2024, 00:08:38 UTC
d08e9a0 [key reuse] add eager checks 29 February 2024, 23:30:19 UTC
087f99a Support mocking number of GPUs in CUDA plugin. Also move reading jax config value to be right before the client is created. Previously they were read before calling register_plugin, which happens during import and before any call of jax.config.update. The decorator in mock_gpu_test was used wrongly: jtu.run_on_devices will create the client before jax.config.update is called, which is not desired. Removing the decorator will not fail CPU/TPU tests because the mesh will check the num_shard and the number of devices in the client and skip the test if they do not match. generate_pjrt_gpu_plugin_options is only used in places that do not require compatibility, so the xla_client version does not need to be updated. PiperOrigin-RevId: 611610915 29 February 2024, 23:15:06 UTC
b2058d7 Merge pull request #20025 from jakevdp:fix-key-reuse-test PiperOrigin-RevId: 611557123 29 February 2024, 20:22:54 UTC
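Commit 63cc46d treats both `aarch` and `arm` as ARM prefixes because `platform.machine()` reports `arm64` on macOS rather than `aarch64`. The sketch below shows one way such a prefix check could look; the helper name `is_arm_machine` is hypothetical and is not taken from the JAX source.

```python
import platform
from typing import Optional

# Hypothetical helper for illustration; not the function used in the JAX commit.
# Linux typically reports "aarch64" while macOS reports "arm64", so both
# prefixes are accepted.
_ARM_PREFIXES = ("aarch", "arm")


def is_arm_machine(machine: Optional[str] = None) -> bool:
    """Return True if the machine string looks like an ARM architecture."""
    if machine is None:
        machine = platform.machine()
    # str.startswith accepts a tuple and matches if any prefix matches.
    return machine.lower().startswith(_ARM_PREFIXES)


if __name__ == "__main__":
    for name in ("aarch64", "arm64", "x86_64", "AMD64"):
        print(name, is_arm_machine(name))
```

Checking only `aarch` would misclassify macOS machines, which is the failure mode the commit message describes.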
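Commit ca3f3f0 enforces the Python invariant that objects which compare equal must also hash equal. A minimal sketch of the pattern, assuming a hypothetical `ToySharding` class (not JAX's GSPMD sharding type) whose equality ignores trailing unsharded dimensions: both `__eq__` and `__hash__` derive from the same normalized key, so the invariant holds by construction.

```python
from typing import Optional, Tuple


class ToySharding:
    """Hypothetical stand-in for a sharding object; not a JAX class."""

    def __init__(self, devices: Tuple[int, ...], spec: Tuple[Optional[str], ...]):
        self.devices = tuple(devices)
        self.spec = tuple(spec)

    def _key(self):
        # Assumption for illustration: trailing None (unsharded) dims do not
        # change the meaning, so strip them before comparing or hashing.
        spec = list(self.spec)
        while spec and spec[-1] is None:
            spec.pop()
        return (self.devices, tuple(spec))

    def __eq__(self, other):
        if not isinstance(other, ToySharding):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Hashing the raw spec instead would break a == b => hash(a) == hash(b),
        # causing equal shardings to miss each other in dicts and caches.
        return hash(self._key())


if __name__ == "__main__":
    a = ToySharding((0, 1), ("x", None))
    b = ToySharding((0, 1), ("x",))
    print(a == b, hash(a) == hash(b), len({a, b}))
```

Deriving both methods from one key function is the simplest way to keep them from drifting apart as the class evolves.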
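Commit 30973a9 moves Pallas compiler params from `**kwargs` to an explicit `compiler_params` argument, because unused kwargs were silently ignored and a typo such as `in_spec` for `in_specs` surfaced only as a confusing downstream error. The sketch below contrasts the two calling conventions using hypothetical functions `call_with_kwargs` and `call_explicit`; neither is the real `pallas_call` signature.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical signatures for illustration only; not the Pallas API.


def call_with_kwargs(kernel: Callable, in_specs=None, **compiler_params: Any) -> Dict[str, Any]:
    # Old style: any unrecognized keyword (e.g. the typo `in_spec`) is
    # silently swallowed into compiler_params, and in_specs stays None.
    return {"in_specs": in_specs, "compiler_params": compiler_params}


def call_explicit(kernel: Callable, *, in_specs=None,
                  compiler_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # New style: compiler params travel in one explicit dict, so a
    # misspelled keyword raises TypeError at the call site.
    return {"in_specs": in_specs, "compiler_params": dict(compiler_params or {})}


def kernel(*refs):
    pass


if __name__ == "__main__":
    print(call_with_kwargs(kernel, in_spec=[1]))  # typo accepted silently
    try:
        call_explicit(kernel, in_spec=[1])
    except TypeError as err:
        print("rejected:", err)
```

The keyword-only `*` marker plus a single dict parameter lets Python itself reject unknown names, instead of relying on downstream lowering code to notice them.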