https://github.com/google/jax

Revision Message Commit Date
604dc24 Update XLA dependency to use revision http://github.com/openxla/xla/commit/ecb73da4b7b2e3b54aa6d6b7f08a5c662bb19c6e. PiperOrigin-RevId: 574576317 18 October 2023, 20:39:03 UTC
f10e333 Start release for jax 0.4.19 PiperOrigin-RevId: 574575158 18 October 2023, 20:28:52 UTC
3778265 Merge pull request #18126 from niqodea:wrapcauchy PiperOrigin-RevId: 574572631 18 October 2023, 20:18:20 UTC
88fe0da Merge pull request #18078 from ROCmSoftwarePlatform:rocm-jax-triton PiperOrigin-RevId: 574546618 18 October 2023, 18:56:01 UTC
9435a0a Merge pull request #18138 from mattjj:shmap-axis-env-fix PiperOrigin-RevId: 574540561 18 October 2023, 18:35:02 UTC
d55085f [Mosaic] Support tpu.concatenate along the tiling dims as long as the shapes are aligned to native tiling. PiperOrigin-RevId: 574523694 18 October 2023, 17:43:54 UTC
cf65480 [XlaCallModule] Drop support for dim_args_spec attribute. This attribute was used to support shape polymorphism in versions up to and including version 4. Starting on March 28th 2023 with JAX version 0.4.6 we stopped using this attribute. We are now beyond the 6 month backward compatibility version and we drop support for this attribute. We also increase the minimum supported serialization version to 5. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions PiperOrigin-RevId: 574450204 18 October 2023, 12:54:11 UTC
890b762 feat: add wrapcauchy logpdf and pdf 18 October 2023, 11:47:10 UTC
e7dff2c Update XLA dependency to use revision http://github.com/openxla/xla/commit/761b642893e16450e965d1c89d797d78a53b06fe. PiperOrigin-RevId: 574406486 18 October 2023, 09:06:38 UTC
b4b97cd [ROCm]: Add jax-triton support for ROCm 18 October 2023, 07:09:20 UTC
5904747 Merge pull request #18156 from jakevdp:seed_with_impl PiperOrigin-RevId: 574327521 18 October 2023, 02:10:45 UTC
8353027 Merge pull request #18070 from sagyakwa:patch-1 PiperOrigin-RevId: 574300558 17 October 2023, 23:41:01 UTC
7498377 Add `GetOpSharding` to XLA/PjRt utils. PiperOrigin-RevId: 574287268 17 October 2023, 22:46:52 UTC
e9da4f5 Merge branch 'google:main' into patch-1 17 October 2023, 22:30:24 UTC
6da4750 [random] remove internal uses of deprecated prng.seed_with_impl() 17 October 2023, 20:18:08 UTC
d03bbc0 random_lax_test: Bump shards number for CPU config. PiperOrigin-RevId: 574239793 17 October 2023, 19:59:12 UTC
86023f5 [Pallas TPU] Add DMA descriptor abstraction for constructing but not starting DMAs PiperOrigin-RevId: 574210634 17 October 2023, 18:20:05 UTC
c16b893 [pallas:gpu] Simplify `broadcast_to`, `min`, `max` lowering. PiperOrigin-RevId: 574204406 17 October 2023, 18:00:50 UTC
2c9ea51 [Mosaic] apply_vector_layout C++ rewrite: add tpu.concatenate PiperOrigin-RevId: 574199634 17 October 2023, 17:47:55 UTC
b84ae98 Make sure we don't filter stack frames of packages that start with a `jax` prefix This ended up accidentally setting up filters for jax_triton. This change additionally adds an opt-in mechanism for paths, that overrides exclusions. We use this to avoid treating pallas ops implementations as JAX-internal. PiperOrigin-RevId: 574167963 17 October 2023, 16:06:30 UTC
2be6019 Rollback to fix internal breakage Reverts 7d203aebfa6206affde207c884b50172e203d177 PiperOrigin-RevId: 574101804 17 October 2023, 11:24:15 UTC
2604d0c Update XLA dependency to use revision http://github.com/openxla/xla/commit/147a47aa0dea586e2c21a565d58a370eca22ac7d. PiperOrigin-RevId: 574064031 17 October 2023, 08:36:56 UTC
497f809 Re-enable the test as the GPU plugin profiler fixes are in. Reverts 4cd4f3f3b380755ee4738aa73fd2a834e6a61fd7 PiperOrigin-RevId: 574020446 17 October 2023, 04:56:46 UTC
6540810 Merge pull request #18117 from jakevdp:ad-deprecations PiperOrigin-RevId: 573996165 17 October 2023, 02:33:59 UTC
7d203ae Merge pull request #18105 from jakevdp:keyarray PiperOrigin-RevId: 573995089 17 October 2023, 02:22:41 UTC
43fc423 Temporarily set TPU_LIBRARY_PATH in xla_bridge. This will be removed once tpu_tracer removes its dependency on TPU_LIBRARY_PATH (should be within next two weeks). PiperOrigin-RevId: 573958529 16 October 2023, 23:13:05 UTC
89b5449 [XLA:GPU] Fix bug in all-to-all for complex data types. The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier. Fixes https://github.com/google/jax/issues/18122 PiperOrigin-RevId: 573955671 16 October 2023, 23:02:22 UTC
5919c1f Merge pull request #18104 from gnecula:multi_jax2tf PiperOrigin-RevId: 573951693 16 October 2023, 22:46:23 UTC
f285359 Merge pull request #18120 from fehiepsi:slogdet PiperOrigin-RevId: 573939805 16 October 2023, 22:03:35 UTC
93fbf62 Fix testProfilerGetFDOProfile. PiperOrigin-RevId: 573936890 16 October 2023, 21:52:44 UTC
3bfe1d2 [shard_map] fix axis env extension bug Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> 16 October 2023, 19:36:51 UTC
dcc92e3 [pallas] `dot` fixes. - Check that operands are 2D. - Set `preferred_element_type`. - Fix dot output type on GPU. PiperOrigin-RevId: 573895904 16 October 2023, 19:35:43 UTC
0eef3b5 jax.interpreters.ad: deprecate inadvertently exported symbols 16 October 2023, 18:22:37 UTC
675cb15 [pallas:gpu] Directly use Triton built-in function lowering for `max_contiguous` and `multiple_of`. PiperOrigin-RevId: 573871324 16 October 2023, 18:17:10 UTC
ac2c228 [pallas:gpu] Simplify reshape lowering rule. PiperOrigin-RevId: 573865529 16 October 2023, 18:00:28 UTC
0290150 Build jaxlib without PJRT GPU deps when plugin will be built. PiperOrigin-RevId: 573844805 16 October 2023, 16:59:07 UTC
b65c1b2 [jax2tf] First step to enable multi-platform native lowering Enable experiments with jax2tf native serialization for multiple platforms. This feature is not yet fully functional but we need this change to enable further testing. Cleanup some of the places that are specific to single-platform serialization, e.g., `lowering_platform`, and generalize them to multiple platforms (`lowering_platforms`). 16 October 2023, 14:01:23 UTC
110b8d7 Update XLA dependency to use revision http://github.com/openxla/xla/commit/398b89072cf1e9e0b53b503306aed8fbf520eec7. PiperOrigin-RevId: 573731910 16 October 2023, 07:51:42 UTC
8bf605a Update XLA dependency to use revision http://github.com/openxla/xla/commit/a5187a333093f0815d82e0548b45f4d235a63e3d. PiperOrigin-RevId: 573575839 15 October 2023, 07:41:36 UTC
f9e4ab7 Merge pull request #18112 from jakevdp:broadcast-shape PiperOrigin-RevId: 573503405 14 October 2023, 19:44:53 UTC
8d700df Merge pull request #18111 from superbobry:fix-flags-underscore PiperOrigin-RevId: 573466372 14 October 2023, 14:19:37 UTC
c97ecf4 Enhance the speed of slogdet with qr method 14 October 2023, 12:22:33 UTC
9a3ec22 Update XLA dependency to use revision http://github.com/openxla/xla/commit/975cf53e0032484890f677a4242ff0565d1ebfbe. PiperOrigin-RevId: 573417519 14 October 2023, 08:21:04 UTC
12c2baa Merge pull request #18110 from jakevdp:jex-redirect PiperOrigin-RevId: 573305906 13 October 2023, 20:46:17 UTC
1815bc7 [typing] allow scalar shape for jnp.broadcast_to 13 October 2023, 20:37:20 UTC
f9087ab MAINT Drop underscore from the name of externally-referenced state objects 13 October 2023, 20:30:13 UTC
4cd4f3f Disable pgle_test.py for GPU plugin. PiperOrigin-RevId: 573304221 13 October 2023, 20:25:11 UTC
16061e6 Merge pull request #17980 from gnecula:multi_collective PiperOrigin-RevId: 573302491 13 October 2023, 20:24:57 UTC
ce4d0d2 Merge pull request #18108 from google:libtpu_import_fix PiperOrigin-RevId: 573302310 13 October 2023, 20:14:56 UTC
2edb66d jax.core: point deprecation to jax.extend 13 October 2023, 19:49:05 UTC
4ba7590 export jax.extend.source_info_util.current PiperOrigin-RevId: 573290435 13 October 2023, 19:31:11 UTC
8fe4fcc Use totalorder comparisons for sort PiperOrigin-RevId: 573289718 13 October 2023, 19:21:07 UTC
4e34fe0 Fix libtpu path on older jaxlibs. This is a follow-up to https://github.com/google/jax/commit/b81a3e1fd774ebdbc3015f1bc977bfacb5d4b745. We still need to set TPU_LIBRARY_PATH for jaxlibs that don't support the new mechanism for passing in the libtpu path. 13 October 2023, 18:39:15 UTC
c568110 Set up an API to top trace and fdo profile in memory. PiperOrigin-RevId: 573276173 13 October 2023, 18:34:25 UTC
a59ada0 [export] Adapt several collective lowering rules for multi-platform lowering This fixes a few more places where the lowering rules used module_context.platform, which is not supported for multi-platform lowering. 13 October 2023, 18:15:41 UTC
a2623f2 [random] Avoid references to PRNGKeyArray type See https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html 13 October 2023, 18:10:05 UTC
e088a8e Remove ancient XLA:TPU workaround We've supported S16 for a long time now PiperOrigin-RevId: 573260447 13 October 2023, 17:48:53 UTC
39a3cfc Merge pull request #18095 from parikshitadhikari:patch1 PiperOrigin-RevId: 573259065 13 October 2023, 17:48:35 UTC
7a2837b Merge pull request #18102 from jakevdp:jex-source-info-util PiperOrigin-RevId: 573258547 13 October 2023, 17:37:42 UTC
699ae85 Merge pull request #18096 from parikshitadhikari:patch2 PiperOrigin-RevId: 573255567 13 October 2023, 17:26:13 UTC
7478fbc [PJRT C API] Add "cuda_plugin_extension" to "gpu_only_test_deps" to support bazel test for GPU plugin. PiperOrigin-RevId: 573251982 13 October 2023, 17:12:16 UTC
1a5d6e8 Merge remote-tracking branch 'origin/patch-1' into patch-1 13 October 2023, 16:47:33 UTC
0cd6ef0 Merge remote-tracking branch 'origin/patch-1' into patch-1 13 October 2023, 16:47:10 UTC
4e463c0 JEX: add jax.extend.source_info_util 13 October 2023, 16:36:00 UTC
1ae2bbc Merge pull request #17892 from jakevdp:typing-test PiperOrigin-RevId: 573236275 13 October 2023, 16:06:10 UTC
2c64a0a typing: add some type assertions to typing_test 13 October 2023, 15:30:08 UTC
3c283db fixed typo 'primtive' as 'primitive' in How_JAX_primitives_work.ipynb 13 October 2023, 15:20:43 UTC
44daa56 Update XLA dependency to use revision http://github.com/openxla/xla/commit/cc2782f3ebfb3868e8edb8aa67b82cff0040d7b2. PiperOrigin-RevId: 573150536 13 October 2023, 09:16:59 UTC
ae2d6e2 [pallas:gpu] Implement `get` and `swap` using `load` and `masked_swap` lowering rules. PiperOrigin-RevId: 573146382 13 October 2023, 09:00:03 UTC
2bc2e17 [pallas:gpu] Fix `swap` Triton lowering. PiperOrigin-RevId: 573141426 13 October 2023, 08:48:20 UTC
0da5828 [pallas] Simplify `Slice.from_slice` code and add check for `Slice.size`. Slice behaviour for negative and out-of-range start/stop values now matches standard Python behaviour. PiperOrigin-RevId: 573141218 13 October 2023, 08:36:28 UTC
8a2c5d6 [PJRT C API] Set the gpu plugin allocator related options. PiperOrigin-RevId: 573111513 13 October 2023, 06:00:15 UTC
432506f [PJRT C API] Fixed pjrt_c_api_gpu and remove `noincompatible_remove_legacy_whole_archive` PiperOrigin-RevId: 573094387 13 October 2023, 04:25:25 UTC
ba1af01 fix: typo inside docs/notebooks/How_JAX_primitives_work.md 13 October 2023, 03:09:51 UTC
e21409f fix: typo inside docs/jep/12049-type-annotations.md 13 October 2023, 03:06:01 UTC
f5a1439 Delete cuda 12.0.1 rbe configs since JAX doesn't support it anymore PiperOrigin-RevId: 573059967 13 October 2023, 00:52:11 UTC
7bbd265 [pallas:gpu] De-duplicate lowering code for JAX primitives that map to Triton builtins. Add several missing ops in the process. PiperOrigin-RevId: 573054532 13 October 2023, 00:41:20 UTC
f9c6387 [pallas:gpu] Minor fix to Triton lowering. PiperOrigin-RevId: 573054173 13 October 2023, 00:31:04 UTC
ccc7113 [pallas:gpu] Minor cleanup to `div` lowering. PiperOrigin-RevId: 573053669 13 October 2023, 00:19:30 UTC
ef20526 Return PositionalSharding if input's rank is >= 3 or a NamedSharding if a mesh is available via the context from inspect_array_sharding. Never return GSPMDSharding from inspect_array_sharding. PiperOrigin-RevId: 573048344 12 October 2023, 23:55:12 UTC
726422f Merge remote-tracking branch 'origin/patch-1' into patch-1 12 October 2023, 23:43:03 UTC
91f1745 making ENABLE_PJRT_COMPATIBILITY conditionally true 12 October 2023, 23:42:39 UTC
489cd44 Merge pull request #18022 from 8bitmp3:add-banner-xmap PiperOrigin-RevId: 573044352 12 October 2023, 23:36:35 UTC
7e2e4a0 Add jit/shard_map banner to xmap docs Add jit/shard_map banner to xmap docs 12 October 2023, 22:06:15 UTC
65cfe1a Instrument metrics for the new JAX compilation cache key generation algorithm. Metrics: 1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation. 2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation. PiperOrigin-RevId: 573019115 12 October 2023, 21:56:02 UTC
6fb776b Update is_device_cuda to support testing for GPU plugin. GPU plugin platform version is "PJRT C API\ncuda ...". PiperOrigin-RevId: 573017348 12 October 2023, 21:45:24 UTC
ab161bb Cleanup lowering rule for hlo_unshard, to remove platform dependence. PiperOrigin-RevId: 572997889 12 October 2023, 20:36:11 UTC
a39d189 Merge remote-tracking branch 'origin/patch-1' into patch-1 12 October 2023, 19:52:35 UTC
d69e810 Merge branch 'google:main' into patch-1 12 October 2023, 19:50:25 UTC
3070748 Merge branch 'google:main' into patch-1 12 October 2023, 19:29:49 UTC
294fe80 Merge pull request #18041 from hawkinsp:rpath PiperOrigin-RevId: 572965294 12 October 2023, 18:30:38 UTC
4031c60 Merge pull request #18079 from superbobry:state-objects PiperOrigin-RevId: 572937114 12 October 2023, 17:00:23 UTC
d856ecc Set RPATH, not RUNPATH in JAX CUDA builds. Fixes https://github.com/google/jax/issues/17497 12 October 2023, 16:38:10 UTC
cbcaac2 MAINT Migrate remaining internal/test modules to use state objects The motivation here is to gradually replace all dynamic lookups on `jax.config` with statically-typed state objects, which are more type checker/IDE friendly. This is a follow up to #18008. 12 October 2023, 16:32:15 UTC
a736caa Merge pull request #17927 from gnecula:multi_cumsum PiperOrigin-RevId: 572913695 12 October 2023, 15:40:15 UTC
42ab110 Merge pull request #18001 from gnecula:export_shard_map PiperOrigin-RevId: 572913449 12 October 2023, 15:27:06 UTC
a06a5aa [Pallas] Use tpu.matmul instead of vector.contract in Mosaic lowering This will let us do mixed precision matmuls, which are rejected by the vector.contract verifier. PiperOrigin-RevId: 572901961 12 October 2023, 14:39:19 UTC
bdb4eda Merge pull request #18075 from jakevdp:callback-doc PiperOrigin-RevId: 572896857 12 October 2023, 14:27:42 UTC
ac6e779 Merge pull request #18030 from jakevdp:core-cleanup PiperOrigin-RevId: 572896527 12 October 2023, 14:15:13 UTC
6f90f65 [export] Adapt lowering caching to work with multi-platform lowering Previously the mlir.cache_lowering was assuming that a primitive has a unique lowering in a module for given input and output avals. But with multi-platform lowering we need to allow multiple lowerings. We fix this by adding the lowering function to the cache key. This fixes the multi-platform lowering tests for cumsum and cumprod. 12 October 2023, 13:24:51 UTC