https://github.com/google/jax

Revision | Message | Commit date
24c08b6 [XlaCallModule] Allow i64 platform index arguments. Previously, for multi-platform serialization the platform index argument was required to be an i32. Now we also allow i64, just like we do for dimension variables. This flexibility is useful for JAX when running in 64-bit mode. PiperOrigin-RevId: 573316975 13 October 2023, 21:27:27 UTC
12c2baa Merge pull request #18110 from jakevdp:jex-redirect PiperOrigin-RevId: 573305906 13 October 2023, 20:46:17 UTC
4cd4f3f Disable pgle_test.py for GPU plugin. PiperOrigin-RevId: 573304221 13 October 2023, 20:25:11 UTC
16061e6 Merge pull request #17980 from gnecula:multi_collective PiperOrigin-RevId: 573302491 13 October 2023, 20:24:57 UTC
ce4d0d2 Merge pull request #18108 from google:libtpu_import_fix PiperOrigin-RevId: 573302310 13 October 2023, 20:14:56 UTC
2edb66d jax.core: point deprecation to jax.extend 13 October 2023, 19:49:05 UTC
4ba7590 export jax.extend.source_info_util.current PiperOrigin-RevId: 573290435 13 October 2023, 19:31:11 UTC
8fe4fcc Use totalorder comparisons for sort PiperOrigin-RevId: 573289718 13 October 2023, 19:21:07 UTC
4e34fe0 Fix libtpu path on older jaxlibs. This is a follow-up to https://github.com/google/jax/commit/b81a3e1fd774ebdbc3015f1bc977bfacb5d4b745. We still need to set TPU_LIBRARY_PATH for jaxlibs that don't support the new mechanism for passing in the libtpu path. 13 October 2023, 18:39:15 UTC
c568110 Set up an API to top trace and fdo profile in memory. PiperOrigin-RevId: 573276173 13 October 2023, 18:34:25 UTC
a59ada0 [export] Adapt several collective lowering rules for multi-platform lowering. This fixes a few more places where the lowering rules used module_context.platform, which is not supported for multi-platform lowering. 13 October 2023, 18:15:41 UTC
e088a8e Remove ancient XLA:TPU workaround. We've supported S16 for a long time now. PiperOrigin-RevId: 573260447 13 October 2023, 17:48:53 UTC
39a3cfc Merge pull request #18095 from parikshitadhikari:patch1 PiperOrigin-RevId: 573259065 13 October 2023, 17:48:35 UTC
7a2837b Merge pull request #18102 from jakevdp:jex-source-info-util PiperOrigin-RevId: 573258547 13 October 2023, 17:37:42 UTC
699ae85 Merge pull request #18096 from parikshitadhikari:patch2 PiperOrigin-RevId: 573255567 13 October 2023, 17:26:13 UTC
7478fbc [PJRT C API] Add "cuda_plugin_extension" to "gpu_only_test_deps" to support bazel test for GPU plugin. PiperOrigin-RevId: 573251982 13 October 2023, 17:12:16 UTC
4e463c0 JEX: add jax.extend.source_info_util 13 October 2023, 16:36:00 UTC
1ae2bbc Merge pull request #17892 from jakevdp:typing-test PiperOrigin-RevId: 573236275 13 October 2023, 16:06:10 UTC
2c64a0a typing: add some type assertions to typing_test 13 October 2023, 15:30:08 UTC
3c283db fixed typo 'primtive' to 'primitive' in How_JAX_primitives_work.ipynb 13 October 2023, 15:20:43 UTC
44daa56 Update XLA dependency to use revision http://github.com/openxla/xla/commit/cc2782f3ebfb3868e8edb8aa67b82cff0040d7b2. PiperOrigin-RevId: 573150536 13 October 2023, 09:16:59 UTC
ae2d6e2 [pallas:gpu] Implement `get` and `swap` using `load` and `masked_swap` lowering rules. PiperOrigin-RevId: 573146382 13 October 2023, 09:00:03 UTC
2bc2e17 [pallas:gpu] Fix `swap` Triton lowering. PiperOrigin-RevId: 573141426 13 October 2023, 08:48:20 UTC
0da5828 [pallas] Simplify `Slice.from_slice` code and add check for `Slice.size`. Slice behaviour for negative and out-of-range start/stop values now matches standard Python behaviour. PiperOrigin-RevId: 573141218 13 October 2023, 08:36:28 UTC
8a2c5d6 [PJRT C API] Set the gpu plugin allocator related options. PiperOrigin-RevId: 573111513 13 October 2023, 06:00:15 UTC
432506f [PJRT C API] Fix pjrt_c_api_gpu and remove `noincompatible_remove_legacy_whole_archive` PiperOrigin-RevId: 573094387 13 October 2023, 04:25:25 UTC
ba1af01 fix: typo inside docs/notebooks/How_JAX_primitives_work.md 13 October 2023, 03:09:51 UTC
e21409f fix: typo inside docs/jep/12049-type-annotations.md 13 October 2023, 03:06:01 UTC
f5a1439 Delete cuda 12.0.1 rbe configs since JAX doesn't support it anymore PiperOrigin-RevId: 573059967 13 October 2023, 00:52:11 UTC
7bbd265 [pallas:gpu] De-duplicate lowering code for JAX primitives that map to Triton builtins. Add several missing ops in the process. PiperOrigin-RevId: 573054532 13 October 2023, 00:41:20 UTC
f9c6387 [pallas:gpu] Minor fix to Triton lowering. PiperOrigin-RevId: 573054173 13 October 2023, 00:31:04 UTC
ccc7113 [pallas:gpu] Minor cleanup to `div` lowering. PiperOrigin-RevId: 573053669 13 October 2023, 00:19:30 UTC
ef20526 Return PositionalSharding if input's rank is >= 3 or a NamedSharding if a mesh is available via the context from inspect_array_sharding. Never return GSPMDSharding from inspect_array_sharding. PiperOrigin-RevId: 573048344 12 October 2023, 23:55:12 UTC
489cd44 Merge pull request #18022 from 8bitmp3:add-banner-xmap PiperOrigin-RevId: 573044352 12 October 2023, 23:36:35 UTC
7e2e4a0 Add jit/shard_map banner to xmap docs 12 October 2023, 22:06:15 UTC
65cfe1a Instrument metrics for the new JAX compilation cache key generation algorithm. Metrics: 1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation. 2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation. PiperOrigin-RevId: 573019115 12 October 2023, 21:56:02 UTC
6fb776b Update is_device_cuda to support testing for GPU plugin. GPU plugin platform version is "PJRT C API\ncuda ...". PiperOrigin-RevId: 573017348 12 October 2023, 21:45:24 UTC
ab161bb Cleanup lowering rule for hlo_unshard, to remove platform dependence. PiperOrigin-RevId: 572997889 12 October 2023, 20:36:11 UTC
294fe80 Merge pull request #18041 from hawkinsp:rpath PiperOrigin-RevId: 572965294 12 October 2023, 18:30:38 UTC
4031c60 Merge pull request #18079 from superbobry:state-objects PiperOrigin-RevId: 572937114 12 October 2023, 17:00:23 UTC
d856ecc Set RPATH, not RUNPATH in JAX CUDA builds. Fixes https://github.com/google/jax/issues/17497 12 October 2023, 16:38:10 UTC
cbcaac2 MAINT Migrate remaining internal/test modules to use state objects. The motivation here is to gradually replace all dynamic lookups on `jax.config` with statically-typed state objects, which are more type checker/IDE friendly. This is a follow-up to #18008. 12 October 2023, 16:32:15 UTC
a736caa Merge pull request #17927 from gnecula:multi_cumsum PiperOrigin-RevId: 572913695 12 October 2023, 15:40:15 UTC
42ab110 Merge pull request #18001 from gnecula:export_shard_map PiperOrigin-RevId: 572913449 12 October 2023, 15:27:06 UTC
a06a5aa [Pallas] Use tpu.matmul instead of vector.contract in Mosaic lowering. This will let us do mixed precision matmuls, which are rejected by the vector.contract verifier. PiperOrigin-RevId: 572901961 12 October 2023, 14:39:19 UTC
bdb4eda Merge pull request #18075 from jakevdp:callback-doc PiperOrigin-RevId: 572896857 12 October 2023, 14:27:42 UTC
ac6e779 Merge pull request #18030 from jakevdp:core-cleanup PiperOrigin-RevId: 572896527 12 October 2023, 14:15:13 UTC
6f90f65 [export] Adapt lowering caching to work with multi-platform lowering. Previously, mlir.cache_lowering assumed that a primitive has a unique lowering in a module for a given set of input and output avals. But with multi-platform lowering we need to allow multiple lowerings. We fix this by adding the lowering function to the cache key. This fixes the multi-platform lowering tests for cumsum and cumprod. 12 October 2023, 13:24:51 UTC
65e86e3 [export] Fix call_exported in the presence of shardings. Previously, when we used call_exported on an Exported module with shardings, we invoked the right HLO, but the enclosing JAX computation did not know about the shardings of the called module. This resulted in errors when invoking the calling module. We change the call_exported lowering rules to add sharding constraints for the inputs and the outputs, and we add a check that we call the exported module on the same number of devices as at export time. 12 October 2023, 13:19:20 UTC
b6f7441 [pallas] Add Triton lowering rule for `custom_jvp_call_p`. PiperOrigin-RevId: 572852113 12 October 2023, 10:51:32 UTC
eba73a7 Update XLA dependency to use revision http://github.com/openxla/xla/commit/500c965b04709f15008c26b46df2f8406279d730. PiperOrigin-RevId: 572817863 12 October 2023, 08:19:39 UTC
b1e9628 Testing triton integration 2023-09-14 PiperOrigin-RevId: 572808737 12 October 2023, 07:35:36 UTC
7dd1d81 DOC: mention ShapeDtypeStruct in callback docstrings 12 October 2023, 02:55:39 UTC
e0944c9 jax.core: deprecate some inadvertent exports 11 October 2023, 22:22:19 UTC
9ae5a43 Merge pull request #18071 from hawkinsp:jitlink PiperOrigin-RevId: 572702571 11 October 2023, 22:20:29 UTC
927a182 Fix typo in setup.py 11 October 2023, 22:15:18 UTC
a1c4dc0 Merge pull request #18066 from jakevdp:update-mypy PiperOrigin-RevId: 572681129 11 October 2023, 21:07:32 UTC
4d247a4 Merge pull request #18062 from hawkinsp:jitlink PiperOrigin-RevId: 572668295 11 October 2023, 20:23:30 UTC
a794beb CI: update mypy to v1.6.0 11 October 2023, 19:54:51 UTC
b8c0ca1 Add a version constraint on nvidia-nvjitlink-cu12. This works around a missing version constraint on NVIDIA's CUDA packages, for example nvidia-cusolver-cu12 should have a versioned dependency on nvidia-nvjitlink-cu12. Issue https://github.com/google/jax/issues/18027 11 October 2023, 16:52:20 UTC
2f70ae7 Migrate another subset of internal modules to use state objects. The motivation here is to gradually replace all dynamic lookups on `jax.config` with statically-typed state objects, which are more type checker/IDE friendly. This is a follow-up to #18008. PiperOrigin-RevId: 572587137 11 October 2023, 15:46:06 UTC
04422b8 Update XLA dependency to use revision http://github.com/openxla/xla/commit/50baaece6db8e61ada5db1f49f19a9d7dc1fa297. PiperOrigin-RevId: 572492083 11 October 2023, 07:53:34 UTC
00e16de Merge pull request #18052 from jakevdp:prngkey-abstractify PiperOrigin-RevId: 572433469 11 October 2023, 02:27:12 UTC
fd09b35 Optimize make_array_from_callback for fully replicated shardings by going via batched_device_put. Benchmark bench_make_array_from_callback_fully_replicated_sharding: before, 467µs ± 3% (both cpu/op and time/op); after, 28.1µs ± 2% (both cpu/op and time/op). PiperOrigin-RevId: 572429822 11 October 2023, 02:02:04 UTC
e5c2a2c [random] add shaped_abstractify handler for custom PRNG key 10 October 2023, 23:15:19 UTC
899cc30 Merge pull request #18048 from superbobry:all-any-list-comp PiperOrigin-RevId: 572384827 10 October 2023, 22:32:55 UTC
2c6fbc2 [pallas] Simplify several primitives bindings. PiperOrigin-RevId: 572370538 10 October 2023, 21:46:04 UTC
4017a7b [pallas] Minor cleanup of `Slice` code. PiperOrigin-RevId: 572369345 10 October 2023, 21:35:13 UTC
5d9c39f MAINT Use a generator expression with `all()` and `any()`. There is no reason to allocate a list only for the purpose of iteration. 10 October 2023, 21:33:03 UTC
5ed6928 Merge pull request #18044 from hawkinsp:readme PiperOrigin-RevId: 572338172 10 October 2023, 19:46:36 UTC
b81a3e1 Remove calling configure_library_path during jax import and get libtpu path from libtpu_module.get_library_path(). PiperOrigin-RevId: 572306461 10 October 2023, 17:59:37 UTC
269d7ce Remove take_ownership support in DLPack. When take_ownership is true, the original buffer is marked as deleted, which enforces that JAX won't attempt to read or write the buffer. This provides better error checking, but at the cost of one more C++ API and two more C APIs. The same semantics can be achieved by not using take_ownership and being careful. Therefore, we decided to remove take_ownership support in DLPack. PiperOrigin-RevId: 572278488 10 October 2023, 16:43:02 UTC
04bee99 Update references to aarch64 support in the README. 10 October 2023, 16:36:29 UTC
8f283bc Merge pull request #18042 from jakevdp:dtypelike PiperOrigin-RevId: 572268035 10 October 2023, 16:12:00 UTC
911f745 Make jax._src.typing.DTypeLike more strictly defined. This is in preparation for exporting this to `jax.typing.DTypeLike`. Currently this is effectively just Any, and we want to make certain it's a meaningful type before exporting. PiperOrigin-RevId: 572260744 10 October 2023, 16:01:19 UTC
9edc803 Merge pull request #18035 from bb515:patch-1 PiperOrigin-RevId: 572260505 10 October 2023, 15:49:43 UTC
117f4bd Define jax.typing.DTypeLike 10 October 2023, 15:46:36 UTC
f919000 Merge pull request #18040 from hawkinsp:build2 PiperOrigin-RevId: 572259859 10 October 2023, 15:36:43 UTC
73db6ec Set -P when testing whether a package is installed during build.py (only on Python 3.11+). The test for the "build" package being installed always succeeded because of the subdirectory named "build". 10 October 2023, 14:30:37 UTC
00dd5b3 Merge pull request #18038 from apaszke:mypy-fix PiperOrigin-RevId: 572227791 10 October 2023, 13:26:52 UTC
3a7000e Fix mypy errors in flash_attention 10 October 2023, 13:05:48 UTC
b22a29f Update README.md so that the CPU install instruction works. The current README instruction `pip install -U jax[cpu]` fails with `no matches found: jax[cpu]`. Corrected the install instruction to `pip install -U "jax[cpu]"`, which successfully installs the CPU version of JAX via pip. 10 October 2023, 11:03:16 UTC
acb698e Update XLA dependency to use revision http://github.com/openxla/xla/commit/dc3703f1a77706af0a5e1cc7f7d40281357892fb. PiperOrigin-RevId: 572172179 10 October 2023, 08:56:30 UTC
60b77bd Add segment_ids support to pallas flash attention on TPU. PiperOrigin-RevId: 572172125 10 October 2023, 08:44:35 UTC
f2cda73 [Mosaic] apply_vector_layout C++ rewrite: Fix check for defining layout in disassemble. Though `relayout` would guarantee `equivalentTo` holding true, we skip relayout when the source layout generalizes the dest layout (because it's a no-op). PiperOrigin-RevId: 572095024 10 October 2023, 01:32:59 UTC
4c30638 Merge pull request #18023 from jakevdp:make-jaxpr-name PiperOrigin-RevId: 572094635 10 October 2023, 01:22:26 UTC
622472f [Mosaic] apply_vector_layout C++ rewrite: Handle elementwise ops by checking for the Elementwise trait and using the generic Operation interface, without templates PiperOrigin-RevId: 572065184 09 October 2023, 23:02:32 UTC
a86d4dd [Mosaic] apply_vector_layout C++ rewrite (18): vector.transpose PiperOrigin-RevId: 572061743 09 October 2023, 22:48:13 UTC
6ac063d Merge pull request #18017 from jakevdp:fix-frompyfunc PiperOrigin-RevId: 572053296 09 October 2023, 22:14:48 UTC
709b05f jax.make_jaxpr: fix __name__ & related attributes 09 October 2023, 22:12:28 UTC
a108547 [JAX] randint goes from [min, max) so [min - N, max + M] actually maps to [min, max] PiperOrigin-RevId: 572034040 09 October 2023, 21:05:39 UTC
4611d13 Only perform compilation cache writes from process 0. This avoids problems with contending writes on filesystems such as GCS. PiperOrigin-RevId: 572032482 09 October 2023, 20:55:07 UTC
cb51e37 [PJRT C API] Adding Profiler C APIs and related framework changes. C API changes: Profiler C APIs are added in profiler_c_api.h; add a PJRT C API extension for the profiler C APIs in pjrt_c_api_profiler_extension.h. Framework changes: add a plugin_tracer that calls profiler C APIs; add a pybind method xla_client.profiler.register_plugin_profiler to register plugin_tracer with the plugin's PJRT_Api*; update xla_bridge.register_plugin to call register_plugin_profiler to register the profiler for that plugin. PiperOrigin-RevId: 572027222 09 October 2023, 20:36:24 UTC
41a7d66 jnp.frompyfunc: fix .at() edge case 09 October 2023, 18:24:25 UTC
84b58ec Increase minimum scipy version to 1.9. Scipy 1.9 appears to fix some crashes on Mac ARM. PiperOrigin-RevId: 571977068 09 October 2023, 17:37:35 UTC
1ec51f4 Merge pull request #18010 from hawkinsp:macfailures PiperOrigin-RevId: 571975100 09 October 2023, 17:26:55 UTC
6a09b7c Merge pull request #18015 from andrinr:main PiperOrigin-RevId: 571960205 09 October 2023, 16:34:11 UTC
f043a80 fix typo in scatter_add doc 09 October 2023, 15:57:41 UTC
d61c340 Merge pull request #18012 from hawkinsp:patch PiperOrigin-RevId: 571945423 09 October 2023, 15:33:05 UTC
37d489b Ignore patch versions of cufft and cusolver in CUDA version checks. Fixes https://github.com/google/jax/issues/18009 09 October 2023, 15:06:17 UTC