604dc24 | jax authors | 18 October 2023, 20:28:43 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/ecb73da4b7b2e3b54aa6d6b7f08a5c662bb19c6e. PiperOrigin-RevId: 574576317 | 18 October 2023, 20:39:03 UTC |
f10e333 | Yash Katariya | 18 October 2023, 20:25:12 UTC | Start release for jax 0.4.19 PiperOrigin-RevId: 574575158 | 18 October 2023, 20:28:52 UTC |
3778265 | jax authors | 18 October 2023, 20:18:20 UTC | Merge pull request #18126 from niqodea:wrapcauchy PiperOrigin-RevId: 574572631 | 18 October 2023, 20:18:20 UTC |
88fe0da | jax authors | 18 October 2023, 18:56:01 UTC | Merge pull request #18078 from ROCmSoftwarePlatform:rocm-jax-triton PiperOrigin-RevId: 574546618 | 18 October 2023, 18:56:01 UTC |
9435a0a | jax authors | 18 October 2023, 18:35:02 UTC | Merge pull request #18138 from mattjj:shmap-axis-env-fix PiperOrigin-RevId: 574540561 | 18 October 2023, 18:35:02 UTC |
d55085f | Jevin Jiang | 18 October 2023, 17:43:06 UTC | [Mosaic] Support tpu.concatenate along the tiling dims as long as the shapes are aligned to native tiling. PiperOrigin-RevId: 574523694 | 18 October 2023, 17:43:54 UTC |
cf65480 | George Necula | 18 October 2023, 12:53:36 UTC | [XlaCallModule] Drop support for dim_args_spec attribute. This attribute was used to support shape polymorphism in versions up to and including version 4. Starting on March 28th 2023 with JAX version 0.4.6, we stopped using this attribute. We are now beyond the 6-month backward compatibility window, so we drop support for this attribute. We also increase the minimum supported serialization version to 5. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions PiperOrigin-RevId: 574450204 | 18 October 2023, 12:54:11 UTC |
890b762 | Nicola De Angeli | 16 October 2023, 01:45:24 UTC | feat: add wrapcauchy logpdf and pdf | 18 October 2023, 11:47:10 UTC |
e7dff2c | jax authors | 18 October 2023, 09:05:59 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/761b642893e16450e965d1c89d797d78a53b06fe. PiperOrigin-RevId: 574406486 | 18 October 2023, 09:06:38 UTC |
b4b97cd | Rahul Batra | 20 September 2023, 18:43:16 UTC | [ROCm]: Add jax-triton support for ROCm | 18 October 2023, 07:09:20 UTC |
5904747 | jax authors | 18 October 2023, 02:10:45 UTC | Merge pull request #18156 from jakevdp:seed_with_impl PiperOrigin-RevId: 574327521 | 18 October 2023, 02:10:45 UTC |
8353027 | jax authors | 17 October 2023, 23:41:01 UTC | Merge pull request #18070 from sagyakwa:patch-1 PiperOrigin-RevId: 574300558 | 17 October 2023, 23:41:01 UTC |
7498377 | jax authors | 17 October 2023, 22:46:13 UTC | Add `GetOpSharding` to XLA/PjRt utils. PiperOrigin-RevId: 574287268 | 17 October 2023, 22:46:52 UTC |
e9da4f5 | Samuel Agyakwa | 17 October 2023, 22:30:24 UTC | Merge branch 'google:main' into patch-1 | 17 October 2023, 22:30:24 UTC |
6da4750 | Jake VanderPlas | 17 October 2023, 20:18:08 UTC | [random] remove internal uses of deprecated prng.seed_with_impl() | 17 October 2023, 20:18:08 UTC |
d03bbc0 | jax authors | 17 October 2023, 19:58:39 UTC | random_lax_test: Bump shards number for CPU config. PiperOrigin-RevId: 574239793 | 17 October 2023, 19:59:12 UTC |
86023f5 | Sharad Vikram | 17 October 2023, 18:19:28 UTC | [Pallas TPU] Add DMA descriptor abstraction for constructing but not starting DMAs PiperOrigin-RevId: 574210634 | 17 October 2023, 18:20:05 UTC |
c16b893 | Chris Jones | 17 October 2023, 18:00:13 UTC | [pallas:gpu] Simplify `broadcast_to`, `min`, `max` lowering. PiperOrigin-RevId: 574204406 | 17 October 2023, 18:00:50 UTC |
2c9ea51 | Tomás Longeri | 17 October 2023, 17:46:35 UTC | [Mosaic] apply_vector_layout C++ rewrite: add tpu.concatenate PiperOrigin-RevId: 574199634 | 17 October 2023, 17:47:55 UTC |
b84ae98 | Adam Paszke | 17 October 2023, 16:05:07 UTC | Make sure we don't filter stack frames of packages that start with a `jax` prefix This ended up accidentally setting up filters for jax_triton. This change additionally adds an opt-in mechanism for paths, that overrides exclusions. We use this to avoid treating pallas ops implementations as JAX-internal. PiperOrigin-RevId: 574167963 | 17 October 2023, 16:06:30 UTC |
2be6019 | jax authors | 17 October 2023, 11:23:44 UTC | Rollback to fix internal breakage Reverts 7d203aebfa6206affde207c884b50172e203d177 PiperOrigin-RevId: 574101804 | 17 October 2023, 11:24:15 UTC |
2604d0c | jax authors | 17 October 2023, 08:36:18 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/147a47aa0dea586e2c21a565d58a370eca22ac7d. PiperOrigin-RevId: 574064031 | 17 October 2023, 08:36:56 UTC |
497f809 | Jieying Luo | 17 October 2023, 04:55:57 UTC | Re-enable the test as the GPU plugin profiler fixes are in. Reverts 4cd4f3f3b380755ee4738aa73fd2a834e6a61fd7 PiperOrigin-RevId: 574020446 | 17 October 2023, 04:56:46 UTC |
6540810 | jax authors | 17 October 2023, 02:33:59 UTC | Merge pull request #18117 from jakevdp:ad-deprecations PiperOrigin-RevId: 573996165 | 17 October 2023, 02:33:59 UTC |
7d203ae | jax authors | 17 October 2023, 02:22:41 UTC | Merge pull request #18105 from jakevdp:keyarray PiperOrigin-RevId: 573995089 | 17 October 2023, 02:22:41 UTC |
43fc423 | Jieying Luo | 16 October 2023, 23:12:26 UTC | Temporarily set TPU_LIBRARY_PATH in xla_bridge. This will be removed once tpu_tracer removes its dependency on TPU_LIBRARY_PATH (should be within the next two weeks). PiperOrigin-RevId: 573958529 | 16 October 2023, 23:13:05 UTC |
89b5449 | Peter Hawkins | 16 October 2023, 23:01:34 UTC | [XLA:GPU] Fix bug in all-to-all for complex data types. The multiplier for complex data types wasn't being applied correctly; the chunk_bytes calculation double-applied the multiplier. Fixes https://github.com/google/jax/issues/18122 PiperOrigin-RevId: 573955671 | 16 October 2023, 23:02:22 UTC |
5919c1f | jax authors | 16 October 2023, 22:46:23 UTC | Merge pull request #18104 from gnecula:multi_jax2tf PiperOrigin-RevId: 573951693 | 16 October 2023, 22:46:23 UTC |
f285359 | jax authors | 16 October 2023, 22:03:35 UTC | Merge pull request #18120 from fehiepsi:slogdet PiperOrigin-RevId: 573939805 | 16 October 2023, 22:03:35 UTC |
93fbf62 | Tao Wang | 16 October 2023, 21:51:23 UTC | Fix testProfilerGetFDOProfile. PiperOrigin-RevId: 573936890 | 16 October 2023, 21:52:44 UTC |
3bfe1d2 | Matthew Johnson | 16 October 2023, 19:36:51 UTC | [shard_map] fix axis env extension bug Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> | 16 October 2023, 19:36:51 UTC |
dcc92e3 | Chris Jones | 16 October 2023, 19:35:06 UTC | [pallas] `dot` fixes. - Check that operands are 2D. - Set `preferred_element_type`. - Fix dot output type on GPU. PiperOrigin-RevId: 573895904 | 16 October 2023, 19:35:43 UTC |
0eef3b5 | Jake VanderPlas | 16 October 2023, 18:22:37 UTC | jax.interpreters.ad: deprecate inadvertently exported symbols | 16 October 2023, 18:22:37 UTC |
675cb15 | Chris Jones | 16 October 2023, 18:16:34 UTC | [pallas:gpu] Directly use Triton built-in function lowering for `max_contiguous` and `multiple_of`. PiperOrigin-RevId: 573871324 | 16 October 2023, 18:17:10 UTC |
ac2c228 | Chris Jones | 16 October 2023, 17:59:29 UTC | [pallas:gpu] Simplify reshape lowering rule. PiperOrigin-RevId: 573865529 | 16 October 2023, 18:00:28 UTC |
0290150 | Jieying Luo | 16 October 2023, 16:58:30 UTC | Build jaxlib without PJRT GPU deps when plugin will be built. PiperOrigin-RevId: 573844805 | 16 October 2023, 16:59:07 UTC |
b65c1b2 | George Necula | 13 October 2023, 17:30:11 UTC | [jax2tf] First step to enable multi-platform native lowering Enable experiments with jax2tf native serialization for multiple platforms. This feature is not yet fully functional, but we need this change to enable further testing. Clean up some of the places that are specific to single-platform serialization, e.g., `lowering_platform`, and generalize them to multiple platforms (`lowering_platforms`). | 16 October 2023, 14:01:23 UTC |
110b8d7 | jax authors | 16 October 2023, 07:50:48 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/398b89072cf1e9e0b53b503306aed8fbf520eec7. PiperOrigin-RevId: 573731910 | 16 October 2023, 07:51:42 UTC |
8bf605a | jax authors | 15 October 2023, 07:41:00 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/a5187a333093f0815d82e0548b45f4d235a63e3d. PiperOrigin-RevId: 573575839 | 15 October 2023, 07:41:36 UTC |
f9e4ab7 | jax authors | 14 October 2023, 19:44:53 UTC | Merge pull request #18112 from jakevdp:broadcast-shape PiperOrigin-RevId: 573503405 | 14 October 2023, 19:44:53 UTC |
8d700df | jax authors | 14 October 2023, 14:19:37 UTC | Merge pull request #18111 from superbobry:fix-flags-underscore PiperOrigin-RevId: 573466372 | 14 October 2023, 14:19:37 UTC |
c97ecf4 | Du Phan | 14 October 2023, 12:22:33 UTC | Enhance the speed of slogdet with qr method | 14 October 2023, 12:22:33 UTC |
9a3ec22 | jax authors | 14 October 2023, 08:20:28 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/975cf53e0032484890f677a4242ff0565d1ebfbe. PiperOrigin-RevId: 573417519 | 14 October 2023, 08:21:04 UTC |
12c2baa | jax authors | 13 October 2023, 20:46:17 UTC | Merge pull request #18110 from jakevdp:jex-redirect PiperOrigin-RevId: 573305906 | 13 October 2023, 20:46:17 UTC |
1815bc7 | Jake VanderPlas | 13 October 2023, 20:37:20 UTC | [typing] allow scalar shape for jnp.broadcast_to | 13 October 2023, 20:37:20 UTC |
f9087ab | Sergei Lebedev | 13 October 2023, 20:27:14 UTC | MAINT Drop underscore from the name of externally-referenced state objects | 13 October 2023, 20:30:13 UTC |
4cd4f3f | Jieying Luo | 13 October 2023, 20:22:16 UTC | Disable pgle_test.py for GPU plugin. PiperOrigin-RevId: 573304221 | 13 October 2023, 20:25:11 UTC |
16061e6 | jax authors | 13 October 2023, 20:24:57 UTC | Merge pull request #17980 from gnecula:multi_collective PiperOrigin-RevId: 573302491 | 13 October 2023, 20:24:57 UTC |
ce4d0d2 | jax authors | 13 October 2023, 20:14:56 UTC | Merge pull request #18108 from google:libtpu_import_fix PiperOrigin-RevId: 573302310 | 13 October 2023, 20:14:56 UTC |
2edb66d | Jake VanderPlas | 13 October 2023, 19:49:05 UTC | jax.core: point deprecation to jax.extend | 13 October 2023, 19:49:05 UTC |
4ba7590 | Jake VanderPlas | 13 October 2023, 19:23:13 UTC | export jax.extend.source_info_util.current PiperOrigin-RevId: 573290435 | 13 October 2023, 19:31:11 UTC |
8fe4fcc | David Majnemer | 13 October 2023, 19:20:22 UTC | Use totalorder comparisons for sort PiperOrigin-RevId: 573289718 | 13 October 2023, 19:21:07 UTC |
4e34fe0 | Skye Wanderman-Milne | 13 October 2023, 18:20:46 UTC | Fix libtpu path on older jaxlibs. This is a follow-up to https://github.com/google/jax/commit/b81a3e1fd774ebdbc3015f1bc977bfacb5d4b745. We still need to set TPU_LIBRARY_PATH for jaxlibs that don't support the new mechanism for passing in the libtpu path. | 13 October 2023, 18:39:15 UTC |
c568110 | Tao Wang | 13 October 2023, 18:33:47 UTC | Set up an API to stop trace and fdo profile in memory. PiperOrigin-RevId: 573276173 | 13 October 2023, 18:34:25 UTC |
a59ada0 | George Necula | 06 October 2023, 07:36:14 UTC | [export] Adapt several collective lowering rules for multi-platform lowering This fixes a few more places where the lowering rules used module_context.platform, which is not supported for multi-platform lowering. | 13 October 2023, 18:15:41 UTC |
a2623f2 | Jake VanderPlas | 13 October 2023, 18:10:05 UTC | [random] Avoid references to PRNGKeyArray type See https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html | 13 October 2023, 18:10:05 UTC |
e088a8e | David Majnemer | 13 October 2023, 17:42:54 UTC | Remove ancient XLA:TPU workaround We've supported S16 for a long time now PiperOrigin-RevId: 573260447 | 13 October 2023, 17:48:53 UTC |
39a3cfc | jax authors | 13 October 2023, 17:48:35 UTC | Merge pull request #18095 from parikshitadhikari:patch1 PiperOrigin-RevId: 573259065 | 13 October 2023, 17:48:35 UTC |
7a2837b | jax authors | 13 October 2023, 17:37:42 UTC | Merge pull request #18102 from jakevdp:jex-source-info-util PiperOrigin-RevId: 573258547 | 13 October 2023, 17:37:42 UTC |
699ae85 | jax authors | 13 October 2023, 17:26:13 UTC | Merge pull request #18096 from parikshitadhikari:patch2 PiperOrigin-RevId: 573255567 | 13 October 2023, 17:26:13 UTC |
7478fbc | Jieying Luo | 13 October 2023, 17:11:37 UTC | [PJRT C API] Add "cuda_plugin_extension" to "gpu_only_test_deps" to support bazel test for GPU plugin. PiperOrigin-RevId: 573251982 | 13 October 2023, 17:12:16 UTC |
1a5d6e8 | Samuel Agyakwa | 13 October 2023, 16:47:33 UTC | Merge remote-tracking branch 'origin/patch-1' into patch-1 | 13 October 2023, 16:47:33 UTC |
0cd6ef0 | Samuel Agyakwa | 12 October 2023, 23:43:03 UTC | Merge remote-tracking branch 'origin/patch-1' into patch-1 | 13 October 2023, 16:47:10 UTC |
4e463c0 | Jake VanderPlas | 13 October 2023, 16:36:00 UTC | JEX: add jax.extend.source_info_util | 13 October 2023, 16:36:00 UTC |
1ae2bbc | jax authors | 13 October 2023, 16:06:10 UTC | Merge pull request #17892 from jakevdp:typing-test PiperOrigin-RevId: 573236275 | 13 October 2023, 16:06:10 UTC |
2c64a0a | Jake VanderPlas | 13 October 2023, 15:30:08 UTC | typing: add some type assertions to typing_test | 13 October 2023, 15:30:08 UTC |
3c283db | parikshit adhikari | 13 October 2023, 15:20:43 UTC | fixed typo 'primtive' as 'primitive' in How_JAX_primitives_work.ipynb | 13 October 2023, 15:20:43 UTC |
44daa56 | jax authors | 13 October 2023, 09:16:15 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/cc2782f3ebfb3868e8edb8aa67b82cff0040d7b2. PiperOrigin-RevId: 573150536 | 13 October 2023, 09:16:59 UTC |
ae2d6e2 | Chris Jones | 13 October 2023, 08:59:00 UTC | [pallas:gpu] Implement `get` and `swap` using `load` and `masked_swap` lowering rules. PiperOrigin-RevId: 573146382 | 13 October 2023, 09:00:03 UTC |
2bc2e17 | Chris Jones | 13 October 2023, 08:33:41 UTC | [pallas:gpu] Fix `swap` Triton lowering. PiperOrigin-RevId: 573141426 | 13 October 2023, 08:48:20 UTC |
0da5828 | Chris Jones | 13 October 2023, 08:32:50 UTC | [pallas] Simplify `Slice.from_slice` code and add check for `Slice.size`. Slice behaviour for negative and out-of-range start/stop values now matches standard Python behaviour. PiperOrigin-RevId: 573141218 | 13 October 2023, 08:36:28 UTC |
8a2c5d6 | Jieying Luo | 13 October 2023, 05:59:27 UTC | [PJRT C API] Set the gpu plugin allocator related options. PiperOrigin-RevId: 573111513 | 13 October 2023, 06:00:15 UTC |
432506f | Jieying Luo | 13 October 2023, 04:24:48 UTC | [PJRT C API] Fixed pjrt_c_api_gpu and remove `noincompatible_remove_legacy_whole_archive` PiperOrigin-RevId: 573094387 | 13 October 2023, 04:25:25 UTC |
ba1af01 | parikshit adhikari | 13 October 2023, 03:09:51 UTC | fix: typo inside docs/notebooks/How_JAX_primitives_work.md | 13 October 2023, 03:09:51 UTC |
e21409f | parikshit adhikari | 13 October 2023, 03:06:01 UTC | fix: typo inside docs/jep/12049-type-annotations.md | 13 October 2023, 03:06:01 UTC |
f5a1439 | Yash Katariya | 13 October 2023, 00:51:35 UTC | Delete cuda 12.0.1 rbe configs since JAX doesn't support it anymore PiperOrigin-RevId: 573059967 | 13 October 2023, 00:52:11 UTC |
7bbd265 | Chris Jones | 13 October 2023, 00:22:54 UTC | [pallas:gpu] De-duplicate lowering code for JAX primitives that map to Triton builtins. Add several missing ops in the process. PiperOrigin-RevId: 573054532 | 13 October 2023, 00:41:20 UTC |
f9c6387 | Chris Jones | 13 October 2023, 00:21:01 UTC | [pallas:gpu] Minor fix to Triton lowering. PiperOrigin-RevId: 573054173 | 13 October 2023, 00:31:04 UTC |
ccc7113 | Chris Jones | 13 October 2023, 00:18:43 UTC | [pallas:gpu] Minor cleanup to `div` lowering. PiperOrigin-RevId: 573053669 | 13 October 2023, 00:19:30 UTC |
ef20526 | Yash Katariya | 12 October 2023, 23:54:34 UTC | Return PositionalSharding if input's rank is >= 3 or a NamedSharding if a mesh is available via the context from inspect_array_sharding. Never return GSPMDSharding from inspect_array_sharding. PiperOrigin-RevId: 573048344 | 12 October 2023, 23:55:12 UTC |
726422f | Samuel Agyakwa | 12 October 2023, 23:43:03 UTC | Merge remote-tracking branch 'origin/patch-1' into patch-1 | 12 October 2023, 23:43:03 UTC |
91f1745 | Samuel Agyakwa | 12 October 2023, 19:52:35 UTC | making ENABLE_PJRT_COMPATIBILITY conditionally true | 12 October 2023, 23:42:39 UTC |
489cd44 | jax authors | 12 October 2023, 23:36:35 UTC | Merge pull request #18022 from 8bitmp3:add-banner-xmap PiperOrigin-RevId: 573044352 | 12 October 2023, 23:36:35 UTC |
7e2e4a0 | 8bitmp3 | 12 October 2023, 21:41:51 UTC | Add jit/shard_map banner to xmap docs | 12 October 2023, 22:06:15 UTC |
65cfe1a | jax authors | 12 October 2023, 21:51:18 UTC | Instrument metrics for the new JAX compilation cache key generation algorithm. Metrics: 1) '/jax/compilation_cache/cache_hits' to track the number of times the cached executable is successfully returned from a cache read using the new implementation. 2) '/jax/compilation_cache/compile_time_saved_sec' to record the time saved on cache hits using the new implementation. PiperOrigin-RevId: 573019115 | 12 October 2023, 21:56:02 UTC |
6fb776b | Jieying Luo | 12 October 2023, 21:44:35 UTC | Update is_device_cuda to support testing for GPU plugin. GPU plugin platform version is "PJRT C API\ncuda ...". PiperOrigin-RevId: 573017348 | 12 October 2023, 21:45:24 UTC |
ab161bb | George Necula | 12 October 2023, 20:32:47 UTC | Cleanup lowering rule for hlo_unshard, to remove platform dependence. PiperOrigin-RevId: 572997889 | 12 October 2023, 20:36:11 UTC |
a39d189 | Samuel Agyakwa | 12 October 2023, 19:52:35 UTC | Merge remote-tracking branch 'origin/patch-1' into patch-1 | 12 October 2023, 19:52:35 UTC |
d69e810 | Samuel Agyakwa | 12 October 2023, 19:29:49 UTC | Merge branch 'google:main' into patch-1 | 12 October 2023, 19:50:25 UTC |
3070748 | Samuel Agyakwa | 12 October 2023, 19:29:49 UTC | Merge branch 'google:main' into patch-1 | 12 October 2023, 19:29:49 UTC |
294fe80 | jax authors | 12 October 2023, 18:30:38 UTC | Merge pull request #18041 from hawkinsp:rpath PiperOrigin-RevId: 572965294 | 12 October 2023, 18:30:38 UTC |
4031c60 | jax authors | 12 October 2023, 17:00:23 UTC | Merge pull request #18079 from superbobry:state-objects PiperOrigin-RevId: 572937114 | 12 October 2023, 17:00:23 UTC |
d856ecc | Peter Hawkins | 10 October 2023, 14:57:56 UTC | Set RPATH, not RUNPATH in JAX CUDA builds. Fixes https://github.com/google/jax/issues/17497 | 12 October 2023, 16:38:10 UTC |
cbcaac2 | Sergei Lebedev | 12 October 2023, 12:15:22 UTC | MAINT Migrate remaining internal/test modules to use state objects The motivation here is to gradually replace all dynamic lookups on `jax.config` with statically-typed state objects, which are more type checker/IDE friendly. This is a follow up to #18008. | 12 October 2023, 16:32:15 UTC |
a736caa | jax authors | 12 October 2023, 15:40:15 UTC | Merge pull request #17927 from gnecula:multi_cumsum PiperOrigin-RevId: 572913695 | 12 October 2023, 15:40:15 UTC |
42ab110 | jax authors | 12 October 2023, 15:27:06 UTC | Merge pull request #18001 from gnecula:export_shard_map PiperOrigin-RevId: 572913449 | 12 October 2023, 15:27:06 UTC |
a06a5aa | Adam Paszke | 12 October 2023, 14:37:22 UTC | [Pallas] Use tpu.matmul instead of vector.contract in Mosaic lowering This will let us do mixed precision matmuls, which are rejected by the vector.contract verifier. PiperOrigin-RevId: 572901961 | 12 October 2023, 14:39:19 UTC |
bdb4eda | jax authors | 12 October 2023, 14:27:42 UTC | Merge pull request #18075 from jakevdp:callback-doc PiperOrigin-RevId: 572896857 | 12 October 2023, 14:27:42 UTC |
ac6e779 | jax authors | 12 October 2023, 14:15:13 UTC | Merge pull request #18030 from jakevdp:core-cleanup PiperOrigin-RevId: 572896527 | 12 October 2023, 14:15:13 UTC |
6f90f65 | George Necula | 04 October 2023, 18:39:18 UTC | [export] Adapt lowering caching to work with multi-platform lowering Previously the mlir.cache_lowering was assuming that a primitive has a unique lowering in a module for given input and output avals. But with multi-platform lowering we need to allow multiple lowerings. We fix this by adding the lowering function to the cache key. This fixes the multi-platform lowering tests for cumsum and cumprod. | 12 October 2023, 13:24:51 UTC |
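Commit 8fe4fcc switches sorting to "totalorder" comparisons. The pure-Python sketch below illustrates the IEEE 754 totalOrder ordering, which is what the TOTALORDER comparison type in StableHLO/XLA refers to. It is not XLA's implementation. In this ordering, negative NaNs come first, then -inf, then negative numbers, with -0.0 placed before +0.0. Positive numbers and +inf follow, and positive NaNs come last. The sketch uses the common bit trick of reinterpreting a float64 as a signed 64-bit integer and flipping the non-sign bits of negative values. `total_order_key` is a hypothetical name.

```python
import math
import struct


def total_order_key(x: float) -> int:
    """Hypothetical helper: map a float64 to an int ordered like IEEE 754 totalOrder."""
    # Reinterpret the 64 bits of the double as a signed 64-bit integer.
    (bits,) = struct.unpack("<q", struct.pack("<d", x))
    # For negative values (sign bit set), flip every bit except the sign bit so
    # that larger magnitudes map to smaller integers.
    if bits < 0:
        bits ^= 0x7FFF_FFFF_FFFF_FFFF
    return bits


neg_nan = math.copysign(math.nan, -1.0)  # NaN with the sign bit set
pos_nan = math.copysign(math.nan, 1.0)   # NaN with the sign bit clear
values = [1.0, pos_nan, -0.0, -math.inf, 0.0, neg_nan, -2.5, math.inf]
ordered = sorted(values, key=total_order_key)
print(ordered)  # [nan, -inf, -2.5, -0.0, 0.0, 1.0, inf, nan]
```

Unlike the ordinary `<` on floats, this key gives every value, including NaNs and signed zeros, a well-defined position. A sort therefore produces the same result regardless of the input order.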
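According to commit 0da5828, `Slice.from_slice` now follows standard Python behaviour for negative and out-of-range start/stop values. The sketch below shows one way to get that normalization: Python's built-in `slice.indices`, which wraps and clamps bounds exactly as list slicing does. It assumes a hypothetical `Slice` with `start` and `size` fields and is not the actual Pallas class. Only unit strides are handled, and the non-negative size check is an illustrative stand-in for the `Slice.size` check the commit mentions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Slice:
    """Hypothetical contiguous slice described by a start offset and a size."""

    start: int
    size: int

    def __post_init__(self):
        # Illustrative check in the spirit of the commit's `Slice.size` check.
        if self.size < 0:
            raise ValueError(f"Slice size must be non-negative, got {self.size}")

    @classmethod
    def from_slice(cls, s: slice, dim: int) -> "Slice":
        # slice.indices wraps negative bounds and clamps out-of-range ones,
        # matching how Python sequences interpret the same slice.
        start, stop, step = s.indices(dim)
        if step != 1:
            raise ValueError("only unit-stride slices are supported in this sketch")
        return cls(start, max(stop - start, 0))


data = list(range(10))
for s in (slice(-3, None), slice(2, 100), slice(8, 3), slice(None)):
    sl = Slice.from_slice(s, len(data))
    # The (start, size) pair selects the same elements as ordinary list slicing.
    assert data[sl.start:sl.start + sl.size] == data[s]
    print(s, "->", sl)
```

Delegating to `slice.indices` keeps the edge cases (negative starts, stops past the end, empty ranges) consistent with Python sequences rather than re-deriving them by hand.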
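PR #18126 (commit 890b762) adds `logpdf` and `pdf` for the wrapped Cauchy distribution. This sketch uses SciPy's parameterization from `scipy.stats.wrapcauchy`, where the shape parameter satisfies 0 < c < 1 and the support is 0 <= x <= 2π. Under that parameterization, the density is f(x; c) = (1 - c²) / (2π(1 + c² - 2c cos x)). The pure-Python sketch evaluates the density in log space and is not the JAX implementation. Returning -inf outside the support is an assumption of this sketch.

```python
import math


def wrapcauchy_logpdf(x: float, c: float) -> float:
    """Log-density of the wrapped Cauchy distribution (SciPy's parameterization).

    Sketch only: assumes 0 < c < 1 and treats points outside [0, 2*pi] as
    having zero density (log-density of -inf).
    """
    if not 0.0 < c < 1.0:
        raise ValueError("c must lie in the open interval (0, 1)")
    if not 0.0 <= x <= 2.0 * math.pi:
        return -math.inf
    return (math.log1p(-c * c)
            - math.log(2.0 * math.pi)
            - math.log(1.0 + c * c - 2.0 * c * math.cos(x)))


def wrapcauchy_pdf(x: float, c: float) -> float:
    """Density obtained by exponentiating the log-density."""
    return math.exp(wrapcauchy_logpdf(x, c))


# Sanity check: a midpoint-rule integral over one full period should be close to 1.
n = 10_000
h = 2.0 * math.pi / n
total = sum(wrapcauchy_pdf((i + 0.5) * h, 0.5) for i in range(n)) * h
print(round(total, 6))  # 1.0
```

Working in log space first and exponentiating afterwards mirrors the usual `logpdf`/`pdf` pairing in `scipy.stats`-style APIs. At x = 0 the density peaks at (1 + c) / (2π(1 - c)), which is 1.5/π for c = 0.5.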
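Commit b84ae98 describes a prefix-matching pitfall. Filtering stack frames by a `jax` path prefix also hid frames from `jax_triton`, and an opt-in list of paths now overrides exclusions. The sketch below illustrates both ideas using component-aware path matching. `is_internal_frame`, `_is_under`, and the example paths are hypothetical and do not reflect JAX's actual traceback-filtering code.

```python
import posixpath


def _is_under(path: str, root: str) -> bool:
    """True if `path` equals `root` or lies inside it, comparing whole components."""
    path = posixpath.normpath(path)
    root = posixpath.normpath(root)
    # Appending the separator stops ".../jax" from matching ".../jax_triton".
    return path == root or path.startswith(root + "/")


def is_internal_frame(filename: str, exclude_roots, include_roots) -> bool:
    """Hypothetical filter: True if a frame should be hidden as library-internal."""
    # Opt-in paths win over exclusions, e.g. kernels users want to see in tracebacks.
    if any(_is_under(filename, root) for root in include_roots):
        return False
    return any(_is_under(filename, root) for root in exclude_roots)


site = "/usr/lib/python3/site-packages"
exclude = [site + "/jax"]
include = [site + "/jax/experimental/pallas/ops"]

# A naive prefix test wrongly treats jax_triton as part of jax:
print((site + "/jax_triton/lib.py").startswith(site + "/jax"))  # True
print(is_internal_frame(site + "/jax_triton/lib.py", exclude, include))  # False
print(is_internal_frame(site + "/jax/_src/core.py", exclude, include))  # True
print(is_internal_frame(site + "/jax/experimental/pallas/ops/attention.py", exclude, include))  # False
```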