https://github.com/google/jax

79021a4 Type compiler options correctly. PiperOrigin-RevId: 565394366 14 September 2023, 16:47:15 UTC
a2720ee Deprecate `jax.experimental.pjit.with_sharding_constraint`. The replacement is `jax.lax.with_sharding_constraint`, which has been available for a year. PiperOrigin-RevId: 565389746 14 September 2023, 16:23:03 UTC
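Migrating to the replacement looks roughly like the sketch below. The mesh axis name "data" and the function `scale_rows` are hypothetical choices for illustration; the mesh spans whatever devices are present, so it also runs on a single CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all available devices; "data" is an illustrative axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
n = len(jax.devices())

@jax.jit
def scale_rows(x):
    y = x * 2.0
    # Previously jax.experimental.pjit.with_sharding_constraint; now jax.lax.
    # Constrain y to be split along its only axis over the "data" mesh axis.
    y = jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("data")))
    return y + 1.0

# Length is a multiple of the device count so the sharding divides evenly.
x = jnp.arange(4.0 * n)
out = scale_rows(x)
```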
ed955ea Fully unroll the `scan` in `jnp.searchsorted`, when method 'scan_unrolled' is specified. On GPU, XLA's 'scan' (fori_loop) implementation launches multiple calls to the body_fun GPU kernel, whereas a fully unrolled scan can be fused into a single kernel launch. Since we only require log-many steps, this is often quite practical, and can be a nice speedup. (from 4.5ms down to 1.5ms in my scenario.) PiperOrigin-RevId: 565371859 14 September 2023, 15:10:49 UTC
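A minimal sketch of selecting the unrolled variant; the array sizes and query values are arbitrary examples. The result should match the default `method="scan"` and NumPy; only the generated loop structure differs.

```python
import jax
import jax.numpy as jnp
import numpy as np

sorted_vals = jnp.linspace(0.0, 1.0, 1024)          # must already be sorted
queries = jnp.array([-0.5, 0.0, 0.25, 0.5, 0.999, 2.0])

@jax.jit
def find(a, v):
    # "scan_unrolled" fully unrolls the ~log2(len(a)) binary-search steps,
    # letting XLA fuse them instead of launching one loop body per step.
    return jnp.searchsorted(a, v, method="scan_unrolled")

idx = find(sorted_vals, queries)
idx_default = jnp.searchsorted(sorted_vals, queries, method="scan")
```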
bbfba9a Remove code that disabled tests on "stream_executor" backends. These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more. Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices. PiperOrigin-RevId: 565367453 14 September 2023, 14:52:43 UTC
8acf597 [Mosaic] Add support for specifying estimated costs for Mosaic kernels PiperOrigin-RevId: 565310871 14 September 2023, 09:50:44 UTC
cab68db Update XLA dependency to use revision http://github.com/openxla/xla/commit/f33da8b05c619f9134b85dba8c0e19aa4e5422d7. PiperOrigin-RevId: 565291862 14 September 2023, 08:08:59 UTC
838f59e [Mosaic] apply_vector_layout C++ rewrite (4) Elementwise ops PiperOrigin-RevId: 565255860 14 September 2023, 05:19:10 UTC
6869000 Modify the name of the JAX compilation cache adoption metric. Rename the adoption metric from '/jax/compilation_cache/tasks_using_original_cache' to '/jax/compilation_cache/tasks_using_cache' to record the number of tasks using the compilation cache for both the original and new implementations. We realized that one adoption metric is enough to monitor and compare the original and new cache adoption rates. To avoid confusion over the word 'original' in the metric name, we decided to rename the metric so that it better describes its purpose. Testing: revised unit test. PiperOrigin-RevId: 565197027 13 September 2023, 23:44:10 UTC
1f8cc44 deprecate `PRNGKeyArray.unsafe_raw_array` in favor of `jax.random.key_data` The latter function is also better in that its behavior is invariant to `jit`, whereas the `unsafe_raw_array` method only works in eager mode. PiperOrigin-RevId: 565195381 13 September 2023, 23:33:56 UTC
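A sketch of the replacement path for reading a typed key's raw bits. It assumes a JAX version that provides `jax.random.key`, `jax.random.key_data` and `jax.random.wrap_key_data`; the exact shape of the raw data depends on the configured PRNG implementation, so it is not asserted here.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # new-style typed key array

# key_data behaves the same eagerly and under jit, unlike unsafe_raw_array.
raw_eager = jax.random.key_data(key)
raw_jit = jax.jit(jax.random.key_data)(key)

# The raw uint32 bits can be wrapped back into a typed key.
roundtrip = jax.random.wrap_key_data(raw_eager)
```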
cd2d419 [Pallas] Add support for remote DMAs on TPU PiperOrigin-RevId: 565190637 13 September 2023, 23:14:10 UTC
91fbf9d [PJRT C API] Set up jax xla cuda package. Add a build wheel, pyproject.toml and setup.py. The directory structure in jax repo is:
  jax/
  └── plugins/
      └── cuda/
          ├── __init__.py
          ├── pyproject.toml
          └── setup.py
Installed package structure is:
  jax_plugins/
  └── xla_cuda_cu12/
      ├── __init__.py
      └── xla_cuda_plugin.so
The major cuda version will be part of the package name. The plugin wheel can be built with command:
  python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla"
PiperOrigin-RevId: 565187954 13 September 2023, 23:03:53 UTC
a38a152 Disable lobpcg_test on GPU. This test is failing in CI, disable it while we debug. PiperOrigin-RevId: 565180431 13 September 2023, 22:44:02 UTC
ebc24c7 Pass sharded inputs to remat offloading tests. When these tests execute, the sharded inputs will be useful for validating the correctness of the compiler passes. PiperOrigin-RevId: 565180089 13 September 2023, 22:43:40 UTC
44bd916 [Pallas] Add support for local DMAs on TPU backend PiperOrigin-RevId: 565179693 13 September 2023, 22:32:44 UTC
11c2f16 Merge pull request #17594 from jakevdp:dep-prngkey PiperOrigin-RevId: 565163390 13 September 2023, 21:33:56 UTC
306c60d Remove references to deprecated "tpu_se" build configuration. PiperOrigin-RevId: 565156675 13 September 2023, 21:10:30 UTC
4e6c1b6 Deprecate random.KeyArray and random.PRNGKeyArray 13 September 2023, 21:05:42 UTC
270cc60 Update internal callers to avoid PRNGKeyArray 13 September 2023, 21:05:42 UTC
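With `random.KeyArray` and `random.PRNGKeyArray` deprecated, a plausible replacement pattern is to annotate keys as `jax.Array` and, where needed, test the dtype against `jax.dtypes.prng_key`. The helper names `is_typed_key` and `sample_noise` below are hypothetical.

```python
import jax
import jax.numpy as jnp

def is_typed_key(x) -> bool:
    # Typed keys are jax.Array instances whose dtype is a PRNG-key dtype.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jax.dtypes.prng_key)

def sample_noise(key: jax.Array, shape: tuple[int, ...]) -> jax.Array:
    # Annotate with jax.Array instead of the deprecated random.KeyArray.
    return jax.random.normal(key, shape)

typed = jax.random.key(42)       # new-style typed key
legacy = jax.random.PRNGKey(42)  # old-style raw uint32 key
noise = sample_noise(typed, (3,))
```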
729752b Disable XLA detailed logging and dumping for small computations. This significantly reduces the amount of logging from XLA on TPU. PiperOrigin-RevId: 565148809 13 September 2023, 20:45:00 UTC
eeb32a7 Finish deprecation cycle for abstract_arrays.ShapedArray & abstract_arrays.raise_to_shaped PiperOrigin-RevId: 565142019 13 September 2023, 20:21:46 UTC
22ff7bd Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product and jnp.cumproduct. These have been deprecated in JAX following similar deprecations in NumPy v1.25.0. PiperOrigin-RevId: 565122288 13 September 2023, 19:07:36 UTC
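Code that used the removed aliases can switch to the standard names, which mirror the NumPy replacements. A short sketch with an arbitrary input array:

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3], [4, 5, 6]])

all_pos = jnp.all(x > 0)          # instead of jnp.alltrue
any_big = jnp.any(x > 5)          # instead of jnp.sometrue
total = jnp.prod(x)               # instead of jnp.product
running = jnp.cumprod(x, axis=1)  # instead of jnp.cumproduct
```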
8340149 Check if the input which is donated is actually deleted along with the AOT check. PiperOrigin-RevId: 565098239 13 September 2023, 17:50:16 UTC
6abefa1 fast dispatch for functions over typed PRNG key arrays Before this change, JAX could dispatch compiled functions over new-style (typed) RNG key arrays, but it would always do so off of the fast (C++-based) dispatch path. In other words, switching from old-style `uint32` RNG keys to new-style keys would regress dispatch times. With this change, dispatch happens on the fast path again and performance regressions ought to be minimal. We currently maintain only one pytree registry, for all registered pytree node types. We want RNG key arrays to also be treated as pytree leaves everywhere *except* during dispatch. In other words: we want operations on (typed) RNG key arrays to appear in Jaxpr, but we want to unravel those arrays into their underlying `uint32` arrays only during dispatch. To do this, we add a new internal pytree registry that dispatch respects uniquely. This registry includes all items in the default registry, but also the RNG key array type. Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 565077758 13 September 2023, 16:43:58 UTC
8495128 Merge pull request #17584 from cabbagepatchman:cabbagepatchman-fix-sp PiperOrigin-RevId: 565066492 13 September 2023, 16:03:13 UTC
5a15ba9 Update XLA dependency to use revision http://github.com/openxla/xla/commit/551a620dafee38b6015028244a3b906d729e5a44. PiperOrigin-RevId: 564978874 13 September 2023, 09:10:36 UTC
4ded017 [Mosaic] Fix an incorrectly implemented heuristic for selecting tile sizes It did not properly account for the minimum tile size requirements on TPUv2 and v3. PiperOrigin-RevId: 564977666 13 September 2023, 09:00:40 UTC
d7a6eed Update jax2tf/tests/sharding_test.py for TPU runtime changes. Removes support for an older runtime (StreamExecutor) on TPU. PiperOrigin-RevId: 564927177 13 September 2023, 05:08:26 UTC
c41d271 Add memories support to remat. This PR adds basic support to remat for transferring intermediates (activations) to a destination memory in the forward pass. Currently JAX only supports the host memory kind, but the API allows transfers to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass. Introduce two singletons, `Recompute` and `Saveable`, and a NamedTuple (`Offloadable`) that each policy can return. Currently policies return a bool which, if True, means saveable and otherwise means recompute on the backward pass. This is a backwards-compatible change, i.e. policies can still return a bool. A very basic offloadable policy can look like this:
```python
def policy(prim, *avals, **params):
  return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```
Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 564914301 13 September 2023, 03:50:05 UTC
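The commit notes that bool-returning policies keep working. The sketch below shows only that backwards-compatible form with `jax.checkpoint` (the `Offloadable` path is not exercised); `save_only_matmuls` and `layer` are hypothetical names, and matching on the primitive name "dot_general" is an illustrative choice.

```python
import jax
import jax.numpy as jnp

def save_only_matmuls(prim, *_avals, **_params) -> bool:
    # True: keep this primitive's output as a residual for the backward pass.
    # False: recompute it during the backward pass.
    return prim.name == "dot_general"

def layer(w, x):
    return jnp.sum(jnp.tanh(jnp.sin(x @ w)))

layer_remat = jax.checkpoint(layer, policy=save_only_matmuls)

w = jnp.ones((4, 4)) * 0.1
x = jnp.arange(8.0).reshape(2, 4)
g_plain = jax.grad(layer)(w, x)
g_remat = jax.grad(layer_remat)(w, x)  # same values, different memory/compute trade-off
```

Built-in bool policies such as `jax.checkpoint_policies.dots_saveable` follow the same calling convention.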
125343b Update README.md 13 September 2023, 03:39:06 UTC
7c0abb1 Merge pull request #17564 from andportnoy:aportnoy/scipy-spatial-test-increase-tolerance PiperOrigin-RevId: 564891104 13 September 2023, 01:21:31 UTC
d4b564a [Mosaic] Support relayout from (1,128) to (8,128) when dst.offset is (0, 0). PiperOrigin-RevId: 564882618 13 September 2023, 00:35:09 UTC
d3950b9 Merge pull request #17568 from zhenying-liu:testdotgeneral PiperOrigin-RevId: 564871912 12 September 2023, 23:46:39 UTC
90da4f1 Fix an A100 nightly unit test failure on testDotGeneral() by replacing TF32 with float32 12 September 2023, 23:31:22 UTC
56791eb lax_test: adjust TPU tolerance for igamma & friends PiperOrigin-RevId: 564859109 12 September 2023, 22:59:41 UTC
a26125c Integrate LLVM at llvm/llvm-project@c1796be93fe5 Updates LLVM usage to match [c1796be93fe5](https://github.com/llvm/llvm-project/commit/c1796be93fe5) PiperOrigin-RevId: 564842806 12 September 2023, 22:00:09 UTC
801cbef [Mosaic] Use strided load to load one entire row more efficiently. PiperOrigin-RevId: 564831610 12 September 2023, 21:19:35 UTC
c617bcb Merge pull request #17566 from jakevdp:prngarray-type PiperOrigin-RevId: 564828726 12 September 2023, 21:09:02 UTC
9d86421 [Mosaic] Use strided store to store one row. PiperOrigin-RevId: 564821813 12 September 2023, 20:56:58 UTC
ea5f126 [custom prng] make PRNGKeyArray a subclass of jax.Array 12 September 2023, 20:48:12 UTC
b20b93e Merge pull request #17565 from jakevdp:fix-array-type PiperOrigin-RevId: 564821273 12 September 2023, 20:46:16 UTC
d44b038 [typing] fix a few array type declarations 12 September 2023, 20:21:48 UTC
34ea2b2 Increase comparison tolerance in SciPy spatial RotationMean subtest Previous value leads to failures on A100 runners in github.com/NVIDIA/JAX-Toolbox CI: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014 The suspected reason is the use of TF32 math for matmuls: decorating the function with @jax.default_matmul_precision("float32") allows the test to pass. We thought it's better to loosen the tolerance but preserve the original execution mode. The fully qualified test case is tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0 12 September 2023, 20:12:14 UTC
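The workaround mentioned above, forcing full-precision matmuls, can be sketched with the `jax.default_matmul_precision` context manager. The matrix sizes and tolerances are illustrative; on CPU the result is float32 either way, while on A100-class GPUs this setting should avoid TF32.

```python
import jax
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)

@jax.jit
def matmul(x, y):
    return x @ y

# Functions traced inside this context default to float32 matmul precision.
with jax.default_matmul_precision("float32"):
    out_f32 = matmul(a, b)

reference = a.astype(np.float64) @ b.astype(np.float64)
```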
d0df18a Merge pull request #17562 from jakevdp:version-import PiperOrigin-RevId: 564807097 12 September 2023, 19:57:29 UTC
1800015 Import jax.version first 12 September 2023, 19:27:20 UTC
63e51fe Merge pull request #17524 from jakevdp:pref-eltype-test PiperOrigin-RevId: 564792174 12 September 2023, 19:01:21 UTC
dbb0e8f [Mosaic] Add a pass for instantiating memory spaces PiperOrigin-RevId: 564723473 12 September 2023, 15:05:26 UTC
7dddb50 [XLA:Python] Remove the use_tfrt flag from make_cpu_client(). use_tfrt=True has been the default for over a year, and the flag currently does nothing. PiperOrigin-RevId: 564712316 12 September 2023, 14:15:04 UTC
462c0bd Update XLA dependency to use revision http://github.com/openxla/xla/commit/cd7379d9af2d1a6f0b18edf47720ec9198b39fe3. PiperOrigin-RevId: 564660543 12 September 2023, 09:59:25 UTC
2a7b8e6 Add `gpu_common_utils` to build_wheel to fix the gpu wheels build PiperOrigin-RevId: 564562958 12 September 2023, 01:40:55 UTC
76a5dc3 Move memories_test.py to JAX PiperOrigin-RevId: 564551723 12 September 2023, 00:41:55 UTC
3e06dc8 Update `jax_spmd_mode` flag docstring and remove unused `allow_pjit` option. PiperOrigin-RevId: 564543943 12 September 2023, 00:08:35 UTC
d4adf00 Add default jvp and transpose rule for jax.lax.reduce_precision. PiperOrigin-RevId: 564536160 11 September 2023, 23:35:44 UTC
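With a differentiation rule in place, `jax.lax.reduce_precision` can sit inside a loss. A sketch that rounds to bfloat16-like precision (8 exponent bits, 7 mantissa bits) while keeping float32 storage; the inputs are chosen to be exactly representable so the expected gradient is simply `2 * x`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def loss(x):
    # Round x to 8 exponent / 7 mantissa bits, then compute a squared loss.
    y = lax.reduce_precision(x, exponent_bits=8, mantissa_bits=7)
    return jnp.sum(y ** 2)

x = jnp.array([0.5, 1.0, 1.5, -2.0], dtype=jnp.float32)
g = jax.grad(loss)(x)

# 1 + 2**-10 is below half an ulp at 7 mantissa bits, so it rounds to 1.0.
rounded = lax.reduce_precision(jnp.float32(1.0 + 2.0 ** -10), exponent_bits=8, mantissa_bits=7)
```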
997b35e Improve the GPU lowering error message if users forget to link the GPU library. PiperOrigin-RevId: 564530960 11 September 2023, 23:14:18 UTC
6c3b42d Add flags to exclude from cache-key generation. Some flags do not affect the compilation output. These should not be part of the cache key; otherwise changing them will change the key, causing an unnecessary cache miss. Synchronize the exclusions between the command-line flags and DebugOptions. Add if-this-then-that lint checks to keep them in sync. PiperOrigin-RevId: 564474189 11 September 2023, 19:57:45 UTC
05d2432 Merge pull request #17527 from ROCmSoftwarePlatform:rocm_build_updates_1 PiperOrigin-RevId: 564472739 11 September 2023, 19:48:10 UTC
23e4f0b Hash the serialized topology description for new cache key generation. The original cache key generation hashes devices and backend. This is not future-proof: it does not work for accelerators other than TPUs. Change this to use the serialized version of PjRtTopologyDescription, which is supported for all accelerators. Notes: CPU and the PjRt C API are not supported yet; Stream Executor will not be supported. Testing: revised unit test. PiperOrigin-RevId: 564461564 11 September 2023, 19:08:26 UTC
a36598b Set the jax_enable_memories flag to True. If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes. If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted. PiperOrigin-RevId: 564457393 11 September 2023, 18:55:09 UTC
bfc12bd Update XLA dependency to use revision http://github.com/openxla/xla/commit/96beae40a196d9d990695d430ca68b3015bda9bb. PiperOrigin-RevId: 564327668 11 September 2023, 10:36:30 UTC
292deef Update XLA dependency to use revision http://github.com/openxla/xla/commit/c227585959ec96a4527b1b9f9023f8d6bbe976b3. PiperOrigin-RevId: 563982388 09 September 2023, 11:01:14 UTC
4091ac6 [ROCm]: Fix duplicate deps include 08 September 2023, 22:56:59 UTC
ef79c19 [ROCm]: Dockerfile and build script updates. Add hipblaslt to the Dockerfile. Update the Dockerfile to default to ROCm 5.6. Update CI scripts to handle multiple ROCm versions. 08 September 2023, 22:56:59 UTC
9289f32 Add missing preferred_element_type tests. Follow-up to https://github.com/google/jax/pull/17506 08 September 2023, 20:07:37 UTC
1c99dd5 Merge pull request #17502 from hawkinsp:fb PiperOrigin-RevId: 563832070 08 September 2023, 20:07:36 UTC
dd05048 Merge pull request #17521 from jakevdp:more-version-tests PiperOrigin-RevId: 563813319 08 September 2023, 18:58:50 UTC
34ba4f5 Merge pull request #17519 from jakevdp:git-stderr PiperOrigin-RevId: 563811588 08 September 2023, 18:48:40 UTC
bc91f2d Add more extensive tests for version strings 08 September 2023, 18:33:49 UTC
bc1ca0b Merge pull request #17515 from jakevdp:fix-can-cast PiperOrigin-RevId: 563793237 08 September 2023, 17:44:23 UTC
db6104a Build: catch stderr when attempting git commands 08 September 2023, 17:15:55 UTC
4adbc1f Merge pull request #17511 from hawkinsp:dlpack PiperOrigin-RevId: 563783931 08 September 2023, 17:12:09 UTC
c351523 Add -lm to Linux flatbuffer builds. 08 September 2023, 16:57:46 UTC
09a5970 [Mosaic] Pass hardware generation to C++ apply vector layout pass PiperOrigin-RevId: 563777135 08 September 2023, 16:47:19 UTC
3ce5cb6 Remove use of get_default_device_assignment(). This is the only caller of this API in JAX, and it can be simplified. Change in preparation for removing get_default_device_assignment() from the Python bindings. PiperOrigin-RevId: 563770199 08 September 2023, 16:18:32 UTC
9e9ea5a jax.dtypes: fix dtypes.safe_to_cast() when output dtype is weak 08 September 2023, 16:04:05 UTC
832820e CI changes PiperOrigin-RevId: 563754696 08 September 2023, 15:08:05 UTC
592eb44 Fix Pallas tests that broke after recent changes PiperOrigin-RevId: 563750775 08 September 2023, 14:50:20 UTC
3a4b60b Fix dlpack type signatures to match Array API spec. Fixes https://github.com/google/jax/issues/17510 08 September 2023, 14:12:32 UTC
601d67a Update XLA dependency to use revision http://github.com/openxla/xla/commit/cb2f349863479be68e73052b0e03500c0470f40d. PiperOrigin-RevId: 563710408 08 September 2023, 11:14:45 UTC
366a16f [Pallas] Add support for allocating TPU semaphores and basic semaphore ops PiperOrigin-RevId: 563632709 08 September 2023, 04:33:53 UTC
b8eccb1 Remove the date check from the jaxlib and jax version checks, since it causes problems when jaxlib runs ahead of jax in CI (depending on timezones). PiperOrigin-RevId: 563614108 08 September 2023, 02:52:35 UTC
47bf243 Merge pull request #17506 from jakevdp:pref-el-type PiperOrigin-RevId: 563613949 08 September 2023, 02:43:01 UTC
cb114f2 [Pallas] Refactor memory space handling PiperOrigin-RevId: 563586933 08 September 2023, 00:08:57 UTC
d0c4c9b [Pallas] Add support for scoped allocations to Pallas TPU PiperOrigin-RevId: 563580548 07 September 2023, 23:41:01 UTC
311dc9c Add truncated normal initializer to jax.nn PiperOrigin-RevId: 563576354 07 September 2023, 23:23:42 UTC
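A minimal usage sketch of the new initializer, assuming it is exposed as `jax.nn.initializers.truncated_normal(stddev=...)` like the other `jax.nn.initializers` factories. The weight shape and `stddev=0.02` are arbitrary; the test uses a deliberately loose bound because truncation keeps samples within a few standard deviations.

```python
import jax
import jax.numpy as jnp

# Factory returns an init function with signature init(key, shape, dtype).
init = jax.nn.initializers.truncated_normal(stddev=0.02)
w = init(jax.random.key(0), (256, 128), jnp.float32)
```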
bfd79b8 Support relayout of tiles in register when the layout tiling changes. PiperOrigin-RevId: 563570338 07 September 2023, 23:00:44 UTC
2451f34 jax.numpy: add preferred_element_type argument to matmul functions 07 September 2023, 22:16:22 UTC
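A sketch of why `preferred_element_type` matters: multiplying int8 matrices whose dot products exceed the int8 range, while requesting an int32 result. The values are chosen so each output element is 4 * 100 * 100 = 40000.

```python
import jax.numpy as jnp

a = jnp.full((2, 4), 100, dtype=jnp.int8)
b = jnp.full((4, 3), 100, dtype=jnp.int8)

# Inputs stay int8; the product is computed and returned as int32.
out = jnp.matmul(a, b, preferred_element_type=jnp.int32)
```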
a6eed40 [MOSAIC] apply_vector_layout C++ rewrite (3) applyLayoutOp and relayout PiperOrigin-RevId: 563556815 07 September 2023, 22:08:30 UTC
52ee404 Merge pull request #17505 from hawkinsp:windows PiperOrigin-RevId: 563550978 07 September 2023, 21:47:37 UTC
bda9292 Propagate ad.Zeros to the scan body function for jax.lax.scan for the output 'ys'. Example of what this fixes:
```python
import jax
import numpy as np

def grad_fn(x):
  def scan_body(x, params):
    return x, x.sum()
  pred, state = jax.lax.scan(scan_body, x, None, length=2)
  return pred.sum(), state

x = np.zeros((5, 10), dtype=np.float32)
loss_grad_fn = jax.value_and_grad(grad_fn, has_aux=True)
print(jax.make_jaxpr(loss_grad_fn)(x))
```
PiperOrigin-RevId: 563544684 07 September 2023, 21:36:53 UTC
de5aee8 Fix Bazel build failures from long command lines on Windows. Fixes https://github.com/google/jax/issues/14950 07 September 2023, 21:36:04 UTC
dbf1325 Copybara import of the project:
-- 3905d6123bdc22f505934242363fda426c99c4cf by Peter Hawkins <phawkins@google.com>:
Update flatbuffers. Use upstream flatbuffer bazel scripts, with a couple of small patches to fix:
* https://github.com/google/flatbuffers/issues/8087 (remove npm references)
* https://github.com/google/flatbuffers/pull/8088 (fix flatc build failure due to main() removal by linker)
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17502 from hawkinsp:fb 3905d6123bdc22f505934242363fda426c99c4cf PiperOrigin-RevId: 563543954 07 September 2023, 21:27:25 UTC
9566112 Merge pull request #17501 from jakevdp:safe-to-cast PiperOrigin-RevId: 563543026 07 September 2023, 21:17:03 UTC
8b700fa [Mosaic] Support relayout from (1, 128) to (8, 128). PiperOrigin-RevId: 563534657 07 September 2023, 20:49:44 UTC
8412781 Internal: add dtypes.safe_to_cast utility & use to generate indexing warning 07 September 2023, 19:18:14 UTC
a9410e5 Merge pull request #17489 from patrick-kidger:patch-5 PiperOrigin-RevId: 563472425 07 September 2023, 17:25:45 UTC
4e4f8b1 Merge pull request #17491 from hawkinsp:relnotes PiperOrigin-RevId: 563472424 07 September 2023, 17:25:22 UTC
f64235a Merge pull request #17453 from jakevdp:fix-version-string PiperOrigin-RevId: 563466394 07 September 2023, 17:06:32 UTC
6f3f0d5 build: write appropriate version strings to build artifacts 07 September 2023, 15:45:48 UTC
9b447aa Relax test tolerance to fix BCSR sparse matmul test failure on P100 GPU. PiperOrigin-RevId: 563441383 07 September 2023, 15:37:31 UTC
429422d Reverts 5fcd9265b1e20c41d684659af3d52c41f25ae2f3 PiperOrigin-RevId: 563426073 07 September 2023, 14:35:44 UTC
408c657 Add a release note about a fixed Windows crash. 07 September 2023, 13:35:25 UTC