https://github.com/google/jax

7f8a4c8 Remove PositionalSharding from distributed array doc 08 August 2024, 04:25:24 UTC
be53ee1 Set `jax_enable_memories` flag to `True` by default PiperOrigin-RevId: 660579462 07 August 2024, 23:25:25 UTC
7efca04 Merge pull request #22920 from jakevdp:fix-lint PiperOrigin-RevId: 660570457 07 August 2024, 23:01:09 UTC
a57d659 Update XLA dependency to use revision http://github.com/openxla/xla/commit/3bf7e1ae488174aa6b29cc3f2c216785dd161af8. PiperOrigin-RevId: 660570144 07 August 2024, 22:57:41 UTC
53af0d4 CI: fix mypy errors 07 August 2024, 22:15:45 UTC
de02988 Merge pull request #22909 from ROCm:ci_fix_solver_paths PiperOrigin-RevId: 660515208 07 August 2024, 20:26:17 UTC
cce7250 Merge pull request #22830 from kaixih:support_vmap PiperOrigin-RevId: 660509938 07 August 2024, 20:12:59 UTC
d3b6066 Merge pull request #22820 from Rifur13:mha-faster PiperOrigin-RevId: 660461104 07 August 2024, 18:11:15 UTC
32131d0 Merge pull request #22897 from jakevdp:bool-indexing PiperOrigin-RevId: 660444193 07 August 2024, 17:30:41 UTC
6fc57c0 Rolling forward #22836 This version, proposed by @dfm, does not have a custom JVP for the whole logsumexp and instead fixes #22398 directly. Reverts e416c6675acfd82866a6e83e8c221640c4d02f29 PiperOrigin-RevId: 660438802 07 August 2024, 17:17:55 UTC
893ae6e Merge pull request #22869 from dfm:custom-batching-polish PiperOrigin-RevId: 660421503 07 August 2024, 16:40:46 UTC
5cb9510 Merge pull request #22908 from gnecula:pallas_warn PiperOrigin-RevId: 660421476 07 August 2024, 16:37:15 UTC
930c8ca Merge pull request #22914 from rajasekharporeddy:testbranch1 PiperOrigin-RevId: 660421322 07 August 2024, 16:33:43 UTC
a2d7993 [ROCM] Fix BUILD.bazel library source paths 07 August 2024, 14:18:20 UTC
3a1567f Do not run nn_test under asan -- it times out PiperOrigin-RevId: 660377176 07 August 2024, 14:14:27 UTC
3095c57 Better docs for jnp.fft.rfft2 and jnp.fft.irfft2 07 August 2024, 12:29:53 UTC
3e5e947 Move some backwards compatibility tests from jax_triton to jax/pallas. While doing this I moved `matmul.py` to `jax/experimental/pallas/ops/tpu` PiperOrigin-RevId: 660341331 07 August 2024, 12:00:29 UTC
28ca734 Added another boxDim check to mosaic_gpu_init_tma_desc PiperOrigin-RevId: 660314586 07 August 2024, 10:16:54 UTC
64eb8e9 [pallas] Add a warning message about experimental and incomplete status 07 August 2024, 05:38:56 UTC
803453e [Pallas TPU] Close over consts in while_loop lowering to avoid passing refs in/out of loop PiperOrigin-RevId: 660238073 07 August 2024, 05:33:15 UTC
dd958ad Add `mesh_shape` to the lowering context. This allows custom partitioning to return NamedShardings for arguments that carry NamedShardings without depending on the mesh context manager. Since `shardy`, the sharding-in-types work, and world 2 dagger are all moving in the direction of making Mesh and PartitionSpec a first-class sharding type, let's pull the trigger now and start fixing these bad user interactions. Some things will break due to this change: previously, passing a NamedSharding and then an equivalent PositionalSharding to the same jitted function would lead to a lowering cache hit; now it will be a cache miss. In other words: `f(ns); f(ps) # cache hit before`. In follow-up CLs, we will make the tracing cache aware of the mesh shape too, to fix other issues related to tracing and lowering cache misses. PiperOrigin-RevId: 660177423 07 August 2024, 01:35:44 UTC
7f44edc Change log level of clearing JAX backend caches from info to debug. PiperOrigin-RevId: 660141868 06 August 2024, 23:27:56 UTC
798297a Update XLA dependency to use revision http://github.com/openxla/xla/commit/08b8d938eb56928970e65639b126794c01b75c3d. PiperOrigin-RevId: 660133285 06 August 2024, 23:05:14 UTC
53ab5eb Merge pull request #22900 from jakevdp:dep-bfloat16 PiperOrigin-RevId: 660102762 06 August 2024, 21:42:43 UTC
9074e85 Add test for zero-sized host memory parameter PiperOrigin-RevId: 660097039 06 August 2024, 21:31:41 UTC
aec6efb Merge pull request #22649 from ROCm:ci_jax_export_harness PiperOrigin-RevId: 660096296 06 August 2024, 21:27:13 UTC
cc96657 Merge pull request #22901 from ROCm:ci_test_harness_vmap PiperOrigin-RevId: 660089572 06 August 2024, 21:04:57 UTC
abe7982 Remove enable_gpu and xla_python_enable_gpu from jax .bazelrc. The plugin is released and the flag is no longer needed. Also set default value of enable_gpu to False. enable_gpu will be removed in the next change. PiperOrigin-RevId: 660059432 06 August 2024, 19:39:45 UTC
ae54120 Skip flaky test_weight_offload_with_dp_on_output test on GPU backend. PiperOrigin-RevId: 660057950 06 August 2024, 19:35:53 UTC
707cdd4 [ROCM] Fix hipsolverSsyevd tests to align with the ROCm behavior. 06 August 2024, 19:10:09 UTC
a009e1c deprecate jax.lib.xla_client.bfloat16 06 August 2024, 18:22:27 UTC
f67f73c Merge pull request #22834 from rajasekharporeddy:testbranch1 PiperOrigin-RevId: 660014909 06 August 2024, 17:46:56 UTC
799de71 Merge pull request #22896 from jakevdp:pin-sphinx PiperOrigin-RevId: 660012176 06 August 2024, 17:39:40 UTC
b45f0fe Support empty boolean indexing 06 August 2024, 16:56:03 UTC
4f8c5a3 CI: pin sphinx to avoid build errors on 8.0 06 August 2024, 16:16:41 UTC
35c70fd [ROCM] Fix export harness tests 06 August 2024, 15:12:31 UTC
8b9ceb5 Handle bool comparisons. PiperOrigin-RevId: 659919931 06 August 2024, 12:37:35 UTC
209f6cd [Mosaic GPU] Profiler improvements: 1. Each process now corresponds to an SM, showing how many blocks are executing concurrently. 2. The timeline now accounts for the start offset of each block, instead of aligning them together. This makes a lot more sense in the SM view. 3. We now use inline PTX to emit profiler events. This sometimes slightly pessimizes code generation, but allows us to predicate out writes on all threads other than the leader of each warpgroup, improving the trace quality. 4. We make sure each trace is monotonic. I can't explain why, but the clocks can behave very weirdly, potentially due to rescheduling at the SASS level. We now fix up all backward movements and emit a warning if big shifts are detected. PiperOrigin-RevId: 659911268 06 August 2024, 12:02:59 UTC
23da11b Re-land FFI port of GPU LU decomposition after fixing XLA FFI memory leak. PiperOrigin-RevId: 659867028 06 August 2024, 09:13:21 UTC
f255fb7 Async dispatch expensive computations on the JAX CPU backend. By setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`, one could opt out of the change and recover the old behavior. PiperOrigin-RevId: 659741822 06 August 2024, 00:48:17 UTC
0ab4d68 Merge pull request #22885 from jakevdp:dep-xla PiperOrigin-RevId: 659724044 05 August 2024, 23:43:46 UTC
06f29bb Deprecate jax.lib.xla_client._xla This is an alias for jax.lib.xla_extension. Why add a deprecation warning for this when #22844 removed other APIs without any warning? This one is relatively commonly used (I found a few dozen downstream references), so I felt that a deprecation warning might be helpful. 05 August 2024, 23:19:59 UTC
489fbc0 Add a test for streaming in closed over constants from host to device PiperOrigin-RevId: 659711557 05 August 2024, 23:00:45 UTC
a497fbc Update XLA dependency to use revision http://github.com/openxla/xla/commit/b33429c33d7037e0ad2678b1e0923c35d560a386. PiperOrigin-RevId: 659695240 05 August 2024, 22:08:18 UTC
3d857b0 export jax.lib.xla_extension.HloModule Followup to #22844, because the symbol is used downstream. PiperOrigin-RevId: 659678623 05 August 2024, 21:16:40 UTC
e416c66 Reverts 0f103d33849ca017e6a199d0f79fa0d83b373995 PiperOrigin-RevId: 659670593 05 August 2024, 20:52:04 UTC
c2c04e0 Merge pull request #22608 from kaixih:fix_cuda_version_check PiperOrigin-RevId: 659664879 05 August 2024, 20:34:43 UTC
09b8843 Fix CUDA version checks 05 August 2024, 20:09:17 UTC
0f103d3 Merge pull request #22836 from superbobry:maint-2 PiperOrigin-RevId: 659644462 05 August 2024, 19:30:51 UTC
af1a69e Merge pull request #22870 from google:dependabot/github_actions/actions/upload-artifact-4.3.5 PiperOrigin-RevId: 659604731 05 August 2024, 17:47:06 UTC
6d7cf3f Bump actions/upload-artifact from 4.3.3 to 4.3.5 Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.3.3 to 4.3.5. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/65462800fd760344b1a7b4382951275a0abb4808...89ef406dd8d7e03cfd12d9e0a4a378f454709029) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] <support@github.com> 05 August 2024, 17:09:11 UTC
5f99c55 Merge pull request #22844 from jakevdp:xla-extension PiperOrigin-RevId: 659586676 05 August 2024, 16:55:24 UTC
1d425b2 Small tweaks to custom_vmap UI. I'm working on some extensions to `custom_vmap` and came across these small UI improvements (I think!). This includes the two changes: 1. A weakening of the `kwargs` check to be consistent with the one in `custom_vjp`/`custom_jvp`, and 2. An improved error message when `def_vmap` isn't called. 05 August 2024, 16:51:47 UTC
1acff9c Better docs for jnp.fft.hfft and jnp.fft.ihfft 05 August 2024, 16:23:29 UTC
0a48aca Added a custom JVP rule for jax.nn.logsumexp Fixes #22398 where the Jacobian of jax.nn.logsumexp was wrong if b= contained exact zeros. 05 August 2024, 16:05:03 UTC
9762ac5 Move CostEstimate from pltu to pl * Move CostEstimate from TPU-specific `compiler_params` to a platform-independent argument of `pallas_call`. Passing a CostEstimate in `compiler_params` is now deprecated and will be removed in 3 months' time. * Update the CostEstimate when batching a kernel by scaling it by the size of the batch axis. PiperOrigin-RevId: 659560330 05 August 2024, 15:18:01 UTC
ecf9f64 Nit. Fix missing backticks in the documentation for jnp.where. PiperOrigin-RevId: 659549362 05 August 2024, 14:37:41 UTC
44a8c98 Merge pull request #22141 from dfm:update-cuda-call-example-to-ffi-call PiperOrigin-RevId: 659542133 05 August 2024, 14:09:12 UTC
5892747 Merge pull request #22849 from jakevdp:dlpack-doc PiperOrigin-RevId: 659541072 05 August 2024, 14:04:42 UTC
252032a [pallas] Improve error and debugging messages with source locations Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function. Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive. Added some more information to the `if debug` prints. Set the MLIR module names so that the debug dumps are named properly. I changed `import pallas.core as pl_core` to `... as pallas_core` for consistency, in a couple of modules. PiperOrigin-RevId: 659506675 05 August 2024, 11:23:55 UTC
b2a469b Port Eigenvalue Decompositions to XLA's FFI This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks. PiperOrigin-RevId: 659492696 05 August 2024, 10:18:13 UTC
9b35b76 [pallas] Enable check for GPU lowering that tensor sizes are powers of 2 Triton has a restriction that all operations have arguments and results that are tensors whose sizes are powers of 2. Added a lowering check for this. Without it, violating the condition produces an unfriendly crash. PiperOrigin-RevId: 659483450 05 August 2024, 09:34:21 UTC
0b87bf4 Update XLA dependency to use revision http://github.com/openxla/xla/commit/b94c84fa5417ffac64731ea76f64b8a69d216f20. PiperOrigin-RevId: 659354942 04 August 2024, 21:53:18 UTC
56ff247 Reverts 80560663d3fab4c0c3f87d7c8e52fb9931526dbb PiperOrigin-RevId: 659334027 04 August 2024, 19:11:30 UTC
83b5c7a Merge pull request #22857 from mattjj:improve-while-loop-error PiperOrigin-RevId: 659160966 03 August 2024, 22:32:31 UTC
06c6a73 Update XLA dependency to use revision http://github.com/openxla/xla/commit/83ef35ce9ecc60a9a796b4e9d34475968cad4562. PiperOrigin-RevId: 659159129 03 August 2024, 22:20:24 UTC
bdcd358 improve while_loop carry pytree/type mismatch errors Now we call into the same error utility as we use in scan. 03 August 2024, 21:57:29 UTC
4d637c8 Improve documentation for jnp.from_dlpack 03 August 2024, 12:48:48 UTC
521c94c Tighten the public API for jax.lib.xla_client & xla_extension 03 August 2024, 12:26:22 UTC
09beb33 Don't call `api.clean_up` when there is no default backend. PiperOrigin-RevId: 658936536 02 August 2024, 23:14:29 UTC
eb571c9 Fix lint in run_single_gpu.py PiperOrigin-RevId: 658933291 02 August 2024, 23:03:03 UTC
51abbf9 Update XLA dependency to use revision http://github.com/openxla/xla/commit/8a54f481f091f0d5a89f1836adcbe8d0189f089f. PiperOrigin-RevId: 658923035 02 August 2024, 22:25:51 UTC
261cf52 Merge pull request #22828 from ROCm:ci_add_packages_dockerfile PiperOrigin-RevId: 658901700 02 August 2024, 21:16:05 UTC
e7fd424 Merge pull request #22784 from pearu:pearu/accuracy-tests-update PiperOrigin-RevId: 658901408 02 August 2024, 21:12:05 UTC
780b10b Update complex functions accuracy tests 02 August 2024, 20:31:51 UTC
9f9e3e6 Address comments 02 August 2024, 19:55:28 UTC
88c8bac Add `util.clear_all_caches` to `api.clear_backends` and let `api.clear_backends` be called before the process terminates on JAX CPU. This could allow the PjRt CPU client object to be destroyed successfully during Python garbage collection. PiperOrigin-RevId: 658843789 02 August 2024, 18:08:48 UTC
958234a Thread the mesh context manager to the place where we recover out_shardings back from GSPMDShardings. Before, if you had a program like this:
```
with mesh:
  out = pjit(lambda: 1)()
```
the sharding of `out` was a `GSPMDSharding`, which is not ideal. This change fixes that and returns a `NamedSharding` instead. This is also required for `Shardy` integration. PiperOrigin-RevId: 658842350 02 August 2024, 18:04:48 UTC
ac52890 [jax] Shard pallas_vmap_test PiperOrigin-RevId: 658834942 02 August 2024, 17:41:22 UTC
6c79c10 Merge pull request #22840 from jakevdp:fix-backends PiperOrigin-RevId: 658817878 02 August 2024, 16:47:25 UTC
86c9903 [Pallas TPU] Make sure that the bug repros actually fail One of them was fixed in the meantime but we didn't realize it. PiperOrigin-RevId: 658799901 02 August 2024, 15:37:22 UTC
e6851e6 Fix the AOT check for sharding consistency, which skipped checking the devices of the sharding. Previously, for a TPU-compiled computation, a user could have passed in a committed array on CPU and JAX wouldn't have raised an error, which is wrong. This change fixes that. Also, `is_equivalent_to` should check devices, HloSharding and memory_kind (so the redundant `memory_kind` check is removed too). PiperOrigin-RevId: 658794885 02 August 2024, 15:15:32 UTC
f85b8e6 [Mosaic TPU] Add support for bf16 reductions PiperOrigin-RevId: 658787017 02 August 2024, 14:42:27 UTC
5474e0e Update CUDA custom call example code to use `ffi_call`. Following up on #21925, we can update the example code in `docs/cuda_custom_call` to use `ffi_call` instead of manually registering `core.Primitive`s. This removes quite a bit of boilerplate and doesn't require direct use of MLIR. 02 August 2024, 14:15:10 UTC
3fa86a9 remove jax.extend.backend.default_backend in favor of jax.backend I added this two days ago before realizing there is already a canonical API for this in the top-level namespace, so it should be safe to remove. 02 August 2024, 14:07:29 UTC
e88887e [Mosaic TPU] Add a missing reshape in relayout The fact that src generalizes dst does not mean that they have the same implicit tile shape (if one has an implicit dim and the other one doesn't, then they will differ by a singleton dimension). PiperOrigin-RevId: 658775019 02 August 2024, 13:44:31 UTC
02d836d Updated Mosaic GPU lowering registration in Pallas The lowering rule for mosaic_gpu_p now expects a serialized module. PiperOrigin-RevId: 658772330 02 August 2024, 13:30:04 UTC
959657a [Mosaic TPU] Remove special handling of implicit dim in relayout Now all changes happen inside the dedicated functions. PiperOrigin-RevId: 658763465 02 August 2024, 12:46:26 UTC
8056066 Enable FFI implementation of GPU Getrf FFI handler. PiperOrigin-RevId: 658755392 02 August 2024, 12:07:02 UTC
99625ff [Mosaic TPU] Break out implicit dim changes from relayout PiperOrigin-RevId: 658752228 02 August 2024, 11:50:40 UTC
efba5f6 Merge pull request #22812 from superbobry:maint PiperOrigin-RevId: 658751187 02 August 2024, 11:43:33 UTC
6b0b222 Activate LU decomposition via XLA's FFI PiperOrigin-RevId: 658721697 02 August 2024, 09:22:53 UTC
20e9c15 [pallas] Small cleanup in the Mosaic lowering Uses the helper functions for the calling convention from #22552 and #22593. PiperOrigin-RevId: 658692284 02 August 2024, 07:16:35 UTC
d57447a Double-buffer pipeline semaphores so we can hide DMA latency under compute, not just bandwidth. Also allow disabling automatic accumulation across pipelines. PiperOrigin-RevId: 658585671 01 August 2024, 23:58:19 UTC
28b8660 [pallas:mosaic_gpu] Make the linter happy. PiperOrigin-RevId: 658580241 01 August 2024, 23:37:12 UTC
7d6fa3c [ROCm]: Add support to continue on fail, fix script paths and update Dockerfile to add necessary packages 01 August 2024, 22:55:15 UTC
2241dad Merge pull request #22814 from superbobry:maint-2 PiperOrigin-RevId: 658560253 01 August 2024, 22:31:40 UTC
bc0229a Rollback as it broke some tests. Reverts ff17b76e3eec3e573788f64fafe23fabcfc09ce2 PiperOrigin-RevId: 658557091 01 August 2024, 22:21:42 UTC
16c868a Merge pull request #22825 from jakevdp:fix-old-array-api PiperOrigin-RevId: 658552229 01 August 2024, 22:07:57 UTC
8df0c3a Port Getrf GPU kernel from custom call to FFI. PiperOrigin-RevId: 658550170 01 August 2024, 22:02:25 UTC
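Commits 0a48aca and 6fc57c0 concern the gradient of `jax.nn.logsumexp` when the weights passed as `b=` contain exact zeros (#22398). A minimal check, assuming a JAX release that includes the fix: for f(x) = log sum_i b_i exp(x_i), the gradient is b_i exp(x_i) / sum_j b_j exp(x_j), so any entry with b_i = 0 should receive zero gradient. The inputs below are illustrative and are not the reproducer from the issue; the helper `weighted_lse` is hypothetical.

```python
import jax
import jax.numpy as jnp

# Illustrative inputs: the largest element of x carries a zero weight.
x = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.0, 1.0, 0.0])

def weighted_lse(x):
    # log(sum(b * exp(x))), computed stably by jax.nn.logsumexp
    return jax.nn.logsumexp(x, b=b)

grad = jax.grad(weighted_lse)(x)

# Analytic gradient: b_i * exp(x_i) / sum_j b_j * exp(x_j)
weights = b * jnp.exp(x)
expected = weights / weights.sum()

print(grad)
print(expected)
```

Comparing against the closed-form gradient, rather than a hard-coded array, keeps the check independent of floating-point formatting.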
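Commit f255fb7 makes the CPU backend dispatch expensive computations asynchronously, so a jitted call on CPU may return before its result has actually been computed, as already happens on GPU and TPU. A minimal sketch of a benchmark that stays accurate under async dispatch: compile once outside the timed region, then call `block_until_ready()` before stopping the clock. The helper `timed` and the 512 x 512 sizes are illustrative choices, not part of JAX.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return a @ b

def timed(fn, *args):
    # Hypothetical helper: wait for the result so async dispatch
    # does not make the measured time artificially small.
    start = time.perf_counter()
    out = fn(*args).block_until_ready()
    return out, time.perf_counter() - start

ka, kb = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(ka, (512, 512))
b = jax.random.normal(kb, (512, 512))

matmul(a, b).block_until_ready()  # warm-up: compile outside the timed region
out, seconds = timed(matmul, a, b)
print(f"matmul took {seconds * 1e3:.2f} ms")
```

Without the `block_until_ready()` call, the timer would mostly measure dispatch overhead rather than the matrix multiply itself.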