9267ac4 | Michael Levesque-Dion | 21 May 2024, 23:29:43 UTC | Clean up version switches from dense array migration PiperOrigin-RevId: 635965009 | 21 May 2024, 23:35:31 UTC |
ccaf466 | Jevin Jiang | 21 May 2024, 20:51:35 UTC | [XLA:Mosaic] Support retiling from (8, 128, -2) to (8, 128) for 32-bit data. This drops the implicit second-minor dimension and fixes the infer-vector-layout rule for concatenate when concatenating constant vectors. PiperOrigin-RevId: 635917408 | 21 May 2024, 20:52:26 UTC |
1181136 | jax authors | 21 May 2024, 20:12:33 UTC | Merge pull request #21280 from gnecula:export_device_poly2 PiperOrigin-RevId: 635904723 | 21 May 2024, 20:12:33 UTC |
e750a20 | jax authors | 21 May 2024, 19:51:08 UTC | Merge pull request #21315 from jakevdp:dep-default-registry PiperOrigin-RevId: 635898094 | 21 May 2024, 19:51:08 UTC |
6d5668d | Jake VanderPlas | 21 May 2024, 19:37:37 UTC | jax.tree_util: test serialize_using_proto | 21 May 2024, 19:37:37 UTC |
9948529 | George Necula | 13 May 2024, 11:32:24 UTC | [export] Relax the check that exported modules are used with same number of devices as when exported Now we allow a module exported for 1 device and not using any sharding annotations to be called from a computation that uses multiple devices. Such exported modules can be parallelized trivially point-wise. | 21 May 2024, 18:40:16 UTC |
21aa723 | jax authors | 21 May 2024, 18:19:29 UTC | Merge pull request #21316 from jakevdp:pytree-doc PiperOrigin-RevId: 635869738 | 21 May 2024, 18:19:29 UTC |
92d892b | jax authors | 21 May 2024, 18:07:15 UTC | Adds rewrite patterns for `arith` and `math` operations with `bf16` operands/results that are not supported by the underlying hardware. PiperOrigin-RevId: 635865752 | 21 May 2024, 18:08:19 UTC |
6ff5532 | jax authors | 21 May 2024, 18:02:31 UTC | Merge pull request #21324 from jakevdp:dep-warn-local PiperOrigin-RevId: 635863480 | 21 May 2024, 18:02:31 UTC |
483f924 | Peter Hawkins | 21 May 2024, 17:24:29 UTC | Bump shard count for experimental_rnn_test, which is timing out in CI when built under ASAN. PiperOrigin-RevId: 635850400 | 21 May 2024, 17:25:24 UTC |
5350bc9 | jax authors | 21 May 2024, 17:11:54 UTC | Merge pull request #21258 from olupton:skip-cusolver-test-with-cuda-12.4 PiperOrigin-RevId: 635845926 | 21 May 2024, 17:11:54 UTC |
418b688 | Kyle Lucke | 21 May 2024, 15:39:41 UTC | Automated Code Change PiperOrigin-RevId: 635818645 | 21 May 2024, 15:40:34 UTC |
9ba77f8 | Olli Lupton | 16 May 2024, 09:01:35 UTC | Skip a test when run with cuSolver >= 11.6 This version is shipped with CUDA 12.4. The test assumes that a workspace size baked in with an older version of cuSolver can be used with a newer version of cuSolver. This is not safe, and leads to an error when upgrading from 11.5 to 11.6. | 21 May 2024, 14:46:43 UTC |
4394bdc | Dan Suh | 21 May 2024, 13:41:20 UTC | Change the log message in `pxla.py` to be less confusing. PiperOrigin-RevId: 635789016 | 21 May 2024, 13:42:07 UTC |
d5d2fb0 | Jake VanderPlas | 21 May 2024, 03:28:25 UTC | Ignore deprecation warnings locally rather than globally | 21 May 2024, 03:28:25 UTC |
1327143 | Jake VanderPlas | 21 May 2024, 02:56:47 UTC | Better documentation for jax.tree_util | 21 May 2024, 02:56:47 UTC |
3f1b059 | jax authors | 21 May 2024, 01:46:30 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/c82597f5554394bf56847cbc75ec1b1e54f74d29. PiperOrigin-RevId: 635634941 | 21 May 2024, 01:47:32 UTC |
d33a568 | Jake VanderPlas | 21 May 2024, 00:15:40 UTC | Refactor & test internal deprecation APIs The names and APIs were previously too similar and therefore somewhat confusing; this refactor should make them clearer. PiperOrigin-RevId: 635615163 | 21 May 2024, 00:16:31 UTC |
2eff241 | jax authors | 21 May 2024, 00:02:44 UTC | Merge pull request #21319 from gnecula:exp_fix_mesh PiperOrigin-RevId: 635611557 | 21 May 2024, 00:02:44 UTC |
6deeee2 | George Necula | 20 May 2024, 13:27:05 UTC | [export] Fix device assignment error for grad of exported. Currently, the export code uses a manufactured device assignment for exporting the VJP function. We should instead use the same device assignment that was used when exporting the primal function. This PR fixes that for the case when the export is done through the direct use of `jax.experimental.export`, and leaves as future work the case when the use is from `jax2tf`. We add a disabled test for the latter case. Bug: #21314 | 20 May 2024, 23:11:01 UTC |
b197ae5 | Tomás Longeri | 20 May 2024, 22:56:20 UTC | [Mosaic] Also check bitwidth in apply-vector-layout's `layoutIsValidForValue`. PiperOrigin-RevId: 635595321 | 20 May 2024, 22:57:08 UTC |
118ca21 | jax authors | 20 May 2024, 20:31:33 UTC | Merge pull request #21318 from jakevdp:fix-mypy PiperOrigin-RevId: 635553370 | 20 May 2024, 20:31:33 UTC |
329ab03 | Jake VanderPlas | 20 May 2024, 20:23:15 UTC | CI: fix mypy error | 20 May 2024, 20:23:15 UTC |
06d2e48 | Shanbin Ke | 20 May 2024, 18:29:26 UTC | Copybara import of the project: -- f625317cc80639178882316df6f8775294adc6b7 by cjkkkk <ske@nvidia.com>: init COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21228 from Cjkkkk:sdpa_new_cudnn_frontend f625317cc80639178882316df6f8775294adc6b7 PiperOrigin-RevId: 635518631 | 20 May 2024, 18:31:15 UTC |
e0a6453 | George Karpenkov | 20 May 2024, 17:51:27 UTC | Simplify JAX lowering rules for cumulative sum Upstream fix has landed => removing CPU workaround. PiperOrigin-RevId: 635505632 | 20 May 2024, 17:52:29 UTC |
cc3a380 | jax authors | 20 May 2024, 16:51:46 UTC | Add a unit test to check whether the result of backend serialization/deserialization is equal to the original executable. PiperOrigin-RevId: 635485374 | 20 May 2024, 16:52:38 UTC |
61ff828 | jax authors | 20 May 2024, 16:07:03 UTC | Add support for TPU delay in Mosaic PiperOrigin-RevId: 635473532 | 20 May 2024, 16:07:56 UTC |
2f45830 | jax authors | 20 May 2024, 13:53:28 UTC | [Mosaic GPU] Prepare matmul example so it can be exposed to other projects. PiperOrigin-RevId: 635442413 | 20 May 2024, 13:54:44 UTC |
f600caa | Sergei Lebedev | 20 May 2024, 13:50:56 UTC | Use @register_lowering to register Pallas GPU lowering rules This leads to slightly more compact code, but should otherwise be identical. PiperOrigin-RevId: 635442002 | 20 May 2024, 13:51:33 UTC |
c3d0b0d | jax authors | 20 May 2024, 13:20:32 UTC | Merge pull request #21305 from jakevdp:scalar-bool PiperOrigin-RevId: 635436437 | 20 May 2024, 13:20:32 UTC |
974c72b | jax authors | 20 May 2024, 12:52:36 UTC | Merge pull request #21292 from ROCm:rv_stable_051624 PiperOrigin-RevId: 635430659 | 20 May 2024, 12:52:36 UTC |
4bac10e | Jake VanderPlas | 20 May 2024, 12:48:49 UTC | Finalize deprecation of the config module. To configure JAX, use `import jax` and reference the config object via `jax.config`. PiperOrigin-RevId: 635430169 | 20 May 2024, 12:49:31 UTC |
bb616ef | jax authors | 20 May 2024, 12:39:21 UTC | Merge pull request #21231 from nouiz:doc_experimental_serialize_executable PiperOrigin-RevId: 635428472 | 20 May 2024, 12:39:21 UTC |
5b28170 | Jake VanderPlas | 20 May 2024, 12:33:36 UTC | Support scalar boolean indices in arr.at[idx].set(vals) | 20 May 2024, 12:33:36 UTC |
53ec2cd | Adam Paszke | 20 May 2024, 08:35:05 UTC | Add notap tag to Mosaic tests PiperOrigin-RevId: 635379982 | 20 May 2024, 08:35:56 UTC |
ffdb9bb | jax authors | 20 May 2024, 01:26:20 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/4c566c945ad38e69bc71b5ee67b741551187e011. PiperOrigin-RevId: 635307171 | 20 May 2024, 01:27:13 UTC |
45a7c22 | Vadym Matsishevskyi | 19 May 2024, 06:38:22 UTC | fix: Update hermetic python dependencies to numpy=2.0.0rc2 and scipy=1.13.0 for all python versions Also install built jaxlib in hermetic python to support //jax:build_jaxlib=false tests. PiperOrigin-RevId: 635169327 | 19 May 2024, 06:39:09 UTC |
8caeaa2 | jax authors | 19 May 2024, 02:18:11 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/ea6c3d5916b1c5125f1d8d3e440cbbd9f62c2343. PiperOrigin-RevId: 635132837 | 19 May 2024, 02:18:57 UTC |
6577f47 | Yash Katariya | 18 May 2024, 15:45:31 UTC | Make `eqn.ctx` context manager thread safe by creating `eqn.ctx.manager`. PiperOrigin-RevId: 635057475 | 18 May 2024, 15:46:18 UTC |
e3a7a87 | jax authors | 18 May 2024, 03:09:00 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/e631101673a39975c5852fb43be6511e6c62fe22. PiperOrigin-RevId: 634954649 | 18 May 2024, 03:09:52 UTC |
25aa13c | Yash Katariya | 18 May 2024, 01:58:48 UTC | Support remat + compute_on. If the rematted computation is annotated to run on host, the backward pass will also execute on host. Also enable no-op nested compute tests. PiperOrigin-RevId: 634943450 | 18 May 2024, 01:59:49 UTC |
79fccf6 | Ruturaj4 | 16 May 2024, 16:25:34 UTC | add cholesky changes in bazel | 18 May 2024, 00:37:09 UTC |
641d5c8 | jax authors | 17 May 2024, 23:57:05 UTC | jax/pallas support ellipsis indexing PiperOrigin-RevId: 634922391 | 17 May 2024, 23:57:53 UTC |
02c19e9 | Yash Katariya | 17 May 2024, 23:31:23 UTC | Make `jax.grad` and `compute_on` work correctly. If the forward pass has an annotation to execute on CPU, then its backward pass also executes on CPU. PiperOrigin-RevId: 634917402 | 17 May 2024, 23:38:35 UTC |
1043e24 | Ashish Shenoy | 17 May 2024, 23:16:28 UTC | Add quantization support for PagedAttention TPU Pallas kernel. PiperOrigin-RevId: 634914369 | 17 May 2024, 23:17:33 UTC |
2d6d408 | Yash Katariya | 17 May 2024, 22:58:25 UTC | Initial commit for `jax.experimental.compute_on` API. The currently supported values for compute type are `device_host` and `device`. `device_sparse` will be allowed in a follow-up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host. `cpu` as a compute_type is reserved for pure CPU-only computations without a device's PJRT client orchestrating the computation. PiperOrigin-RevId: 634909918 | 17 May 2024, 22:59:21 UTC |
7a3fc71 | jax authors | 17 May 2024, 22:01:37 UTC | Merge pull request #21289 from jakevdp:stacklevel PiperOrigin-RevId: 634895282 | 17 May 2024, 22:01:37 UTC |
9ad6729 | Jake VanderPlas | 17 May 2024, 21:36:34 UTC | jax.experimental.export: fix stacklevel for warning | 17 May 2024, 21:36:34 UTC |
56301d4 | jax authors | 17 May 2024, 17:32:37 UTC | Merge pull request #21283 from dfm:gh21279 PiperOrigin-RevId: 634819832 | 17 May 2024, 17:32:37 UTC |
8152566 | jax authors | 17 May 2024, 17:09:32 UTC | [pallas:Mosaic GPU] Configurable smem scratch and a small bug fix in Mosaic GPU PiperOrigin-RevId: 634813241 | 17 May 2024, 17:10:20 UTC |
be06954 | Dan Foreman-Mackey | 17 May 2024, 15:16:34 UTC | Add RegularGridInterpolator to generated API docs In responding to gh21279, I noticed that `RegularGridInterpolator` isn't currently listed in the API docs. I know that `scipy.interpolate` is out of scope (https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-interpolate), but since we do currently provide this wrapper, it seems like it makes sense to include it in the docs! | 17 May 2024, 15:24:35 UTC |
e93f36a | jax authors | 17 May 2024, 14:39:24 UTC | Merge pull request #21281 from jaro-sevcik:enable-host-offloading-scan-tests-gpu PiperOrigin-RevId: 634770894 | 17 May 2024, 14:39:24 UTC |
7be7c1e | Jaroslav Sevcik | 17 May 2024, 14:04:43 UTC | Enable remat-scan offloading test | 17 May 2024, 14:29:31 UTC |
f87be35 | jax authors | 17 May 2024, 14:23:45 UTC | [Mosaic GPU] reduce_sum does an intra-warp reduction before communicating with the other warps PiperOrigin-RevId: 634765339 | 17 May 2024, 14:24:35 UTC |
210f8bb | Sergei Lebedev | 17 May 2024, 11:32:07 UTC | Use absl.flags.FlagHolder to define --mosaic_gpu_debug PiperOrigin-RevId: 634713545 | 17 May 2024, 11:33:17 UTC |
c455911 | jax authors | 17 May 2024, 11:29:43 UTC | Internal BUILD file change PiperOrigin-RevId: 634713068 | 17 May 2024, 11:30:21 UTC |
527aef3 | Sergei Lebedev | 17 May 2024, 11:14:50 UTC | Added a slow (but working!) implementation of layer norm in Pallas via Mosaic GPU PiperOrigin-RevId: 634710243 | 17 May 2024, 11:15:40 UTC |
5e2710c | jax authors | 17 May 2024, 07:10:27 UTC | Merge pull request #21261 from superbobry:mypy-ruff PiperOrigin-RevId: 634654578 | 17 May 2024, 07:10:27 UTC |
1829a66 | jax authors | 17 May 2024, 04:27:30 UTC | Merge pull request #21268 from jakevdp:register-dataclass PiperOrigin-RevId: 634624518 | 17 May 2024, 04:27:30 UTC |
0e92433 | jax authors | 17 May 2024, 03:03:03 UTC | [Mosaic GPU] Add a WGSplatLayout that trivially supports reshape and broadcast. PiperOrigin-RevId: 634610004 | 17 May 2024, 03:04:05 UTC |
efa420b | jax authors | 17 May 2024, 03:00:06 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/0821ce54085018643fba169e109f7b13bf2accb7. PiperOrigin-RevId: 634608759 | 17 May 2024, 03:01:07 UTC |
8493fc9 | jax authors | 17 May 2024, 01:37:14 UTC | Merge pull request #21181 from 8bitmp3:add-distributed-data-loading-doc PiperOrigin-RevId: 634595201 | 17 May 2024, 01:37:14 UTC |
defb53f | Jake VanderPlas | 17 May 2024, 01:34:54 UTC | Document jax.tree_util.register_dataclass | 17 May 2024, 01:34:54 UTC |
15a3deb | 8bitmp3 | 10 May 2024, 20:36:17 UTC | Add Distributed data loading doc | 16 May 2024, 23:12:13 UTC |
bc19f7f | jax authors | 16 May 2024, 23:00:31 UTC | Merge pull request #21267 from jakevdp:quantile-warning PiperOrigin-RevId: 634556534 | 16 May 2024, 23:00:31 UTC |
c3bc88d | Sergei Lebedev | 16 May 2024, 14:10:01 UTC | Bumped mypy to 1.10.0 and ruff to 0.4.4 | 16 May 2024, 22:16:32 UTC |
bfbde5e | Jake VanderPlas | 16 May 2024, 22:13:25 UTC | jnp.quantile & friends: properly deprecate interpolation | 16 May 2024, 22:13:25 UTC |
9dd98dc | jax authors | 16 May 2024, 20:30:05 UTC | Merge pull request #21244 from google:add-tpu-core-count PiperOrigin-RevId: 634510539 | 16 May 2024, 20:30:05 UTC |
74ffed9 | Yash Katariya | 16 May 2024, 20:11:50 UTC | Switch the order of sharding and memory kind custom call PiperOrigin-RevId: 634505383 | 16 May 2024, 20:12:38 UTC |
01194bd | Sergei Lebedev | 16 May 2024, 18:27:56 UTC | Clarified the type of the inputs to callback APIs The callback APIs were migrated to use jax.Arrays for both inputs and outputs in JAX 0.4.27. PiperOrigin-RevId: 634473890 | 16 May 2024, 18:29:09 UTC |
380503b | jax authors | 16 May 2024, 17:00:25 UTC | [Mosaic GPU] Enable pointwise operation with scalar values. PiperOrigin-RevId: 634441605 | 16 May 2024, 17:01:12 UTC |
2ec3994 | Bixia Zheng | 16 May 2024, 14:47:02 UTC | Add a sharding test for concatenate with an internal operand that requires paddings. PiperOrigin-RevId: 634389912 | 16 May 2024, 14:47:43 UTC |
8f045ca | Mark Sandler | 16 May 2024, 05:06:11 UTC | Add jax.make_array_from_process_local_data to create a distributed tensor from host data and supporting scaffolding in sharding to be able to figure out dimensions of host data required. PiperOrigin-RevId: 634205261 | 16 May 2024, 05:06:45 UTC |
cd41b4f | Junwhan Ahn | 16 May 2024, 04:13:59 UTC | Use `util.cache` instead of `lru_cache` for `create_mesh_pspec_sharding` Its return value depends on `jax.config.enable_memories` due to the memory kind canonicalization, so we should use `util.cache` that uses the trace_context as an additional key. PiperOrigin-RevId: 634192701 | 16 May 2024, 04:14:53 UTC |
6fe313c | jax authors | 16 May 2024, 02:06:11 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/a7ed898e9b9ca1f8d97ea8bc438ab690f23ef737. PiperOrigin-RevId: 634159073 | 16 May 2024, 02:07:03 UTC |
517e299 | Vadym Matsishevskyi | 16 May 2024, 01:20:14 UTC | Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details PiperOrigin-RevId: 634146391 | 16 May 2024, 01:20:56 UTC |
a820387 | jax authors | 15 May 2024, 21:54:27 UTC | [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args PiperOrigin-RevId: 634084928 | 15 May 2024, 21:55:16 UTC |
6a949ae | jax authors | 15 May 2024, 20:43:18 UTC | Merge pull request #21251 from jakevdp:fix-pmap-warning PiperOrigin-RevId: 634061434 | 15 May 2024, 20:43:18 UTC |
f5a59cc | Jake VanderPlas | 15 May 2024, 20:20:05 UTC | Fix warning filter in test_pmap_of_prng_key | 15 May 2024, 20:20:05 UTC |
1538197 | jax authors | 15 May 2024, 19:53:19 UTC | Merge pull request #21247 from dfm:gh20481 PiperOrigin-RevId: 634045828 | 15 May 2024, 19:53:19 UTC |
bb40075 | jax authors | 15 May 2024, 19:48:37 UTC | [Mosaic GPU] Fix argument checking in case the input shape is not a tuple. PiperOrigin-RevId: 634044759 | 15 May 2024, 19:49:27 UTC |
09a4b38 | Dan Foreman-Mackey | 15 May 2024, 19:01:01 UTC | Add informative error for invalid unroll in scan As reported in #20481, setting `unroll=0` in `lax.scan` resulted in an uninformative `ZeroDivisionError`. This PR adds a check which raises a `ValueError` for `unroll<=0`. | 15 May 2024, 19:40:27 UTC |
76a7b19 | Michael Hudgins | 15 May 2024, 18:40:57 UTC | Set V3 cores to 4 | 15 May 2024, 18:40:57 UTC |
0232cb9 | Michael Hudgins | 15 May 2024, 18:37:27 UTC | Run v3-8 tests with cores set at 8 | 15 May 2024, 18:37:27 UTC |
181da12 | Michael Hudgins | 15 May 2024, 18:33:01 UTC | Add missing bash symbol | 15 May 2024, 18:33:01 UTC |
3015699 | Michael Hudgins | 15 May 2024, 17:46:40 UTC | Add core count to tpu nightly fix v5 job The current job assumes a 4 core TPU. Modify the matrix to enable defining the core count for each tpu | 15 May 2024, 17:46:40 UTC |
4ccac4c | jax authors | 15 May 2024, 17:40:58 UTC | Merge pull request #21237 from gnecula:clean_tokens PiperOrigin-RevId: 634003364 | 15 May 2024, 17:40:58 UTC |
11a4513 | jax authors | 15 May 2024, 17:09:47 UTC | [Mosaic GPU] Explicitly check the smem size so that we get a good error rather than a cryptic cuda error PiperOrigin-RevId: 633993345 | 15 May 2024, 17:10:46 UTC |
874a20a | jax authors | 15 May 2024, 15:52:16 UTC | [Mosaic GPU] Fix np.random API mistake in tests PiperOrigin-RevId: 633970332 | 15 May 2024, 15:53:03 UTC |
7653db8 | Peter Hawkins | 15 May 2024, 14:37:24 UTC | Fix CI test failure in line_search_test. A recent XLA change means that certain matrix-vector products may now be lowered as matrix-matrix multiplications. This may mean that we use lower precisions where we previously did not. PiperOrigin-RevId: 633949879 | 15 May 2024, 14:38:20 UTC |
41153b1 | George Necula | 14 May 2024, 10:20:58 UTC | Cleanup token handling during lowering Version 0.4.27 of jaxlib is now the minimum version and it supports real stablehlo tokens as module inputs and outputs. Hence we can now clean up `mlir.lower_jaxpr_to_fun` to not use the kwargs `create_tokens` and `replace_tokens_with_dummy` (both of them are always False now). We also remove `num_output_tokens`, which is not used. | 15 May 2024, 12:54:46 UTC |
66a92c4 | Yue Sheng | 15 May 2024, 05:40:53 UTC | Reverts 9e7830df2df9362edcf2e18e353d327fdecae678 PiperOrigin-RevId: 633816901 | 15 May 2024, 05:41:44 UTC |
b8cec53 | jax authors | 15 May 2024, 01:41:41 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/a0f5d76e3dac2ee3293de60a79c29a46668a53b0. PiperOrigin-RevId: 633771201 | 15 May 2024, 01:42:27 UTC |
936ec98 | Frederic Bastien | 14 May 2024, 20:36:35 UTC | Add doc for jax.experimental.serialize_executable module | 14 May 2024, 23:11:37 UTC |
9af0a31 | jax authors | 14 May 2024, 23:05:17 UTC | Merge pull request #21222 from tempoxylophone:patch-1 PiperOrigin-RevId: 633734390 | 14 May 2024, 23:05:17 UTC |
48b9018 | jax authors | 14 May 2024, 23:04:56 UTC | Merge pull request #21221 from nouiz:doc PiperOrigin-RevId: 633734299 | 14 May 2024, 23:04:56 UTC |
1dff5b3 | jax authors | 14 May 2024, 23:01:50 UTC | Merge pull request #21226 from Micky774:linalg-depr PiperOrigin-RevId: 633734284 | 14 May 2024, 23:01:50 UTC |
d4f36e2 | jax authors | 14 May 2024, 22:48:53 UTC | Merge pull request #21227 from jakevdp:eye-offset PiperOrigin-RevId: 633730820 | 14 May 2024, 22:48:53 UTC |
e8b06cc | jax authors | 14 May 2024, 22:20:37 UTC | Cholesky rank-1 update kernel for JAX. PiperOrigin-RevId: 633722940 | 14 May 2024, 22:21:38 UTC |
e2918ca | Sergei Lebedev | 14 May 2024, 21:47:24 UTC | Added a very rough sketch of Mosaic GPU lowering for Pallas Almost nothing is supported, including: PyTree inputs/outputs; indexers; non-trivial grids; block specs; any primitives beyond the ones added here; etc. PiperOrigin-RevId: 633713366 | 14 May 2024, 21:48:09 UTC |
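Commit 09a4b38 adds a check to `lax.scan`. It raises a `ValueError` for `unroll<=0` instead of letting `unroll=0` fail with a `ZeroDivisionError`. The following is a minimal sketch of that behavior. It assumes an installed JAX release that includes the change; the commit is dated 15 May 2024. The step function `cumsum_step` is a hypothetical example and is not taken from the commit.

```python
import jax.numpy as jnp
from jax import lax


def cumsum_step(carry, x):
    # Hypothetical running-sum step: the updated carry is also emitted as this step's output.
    carry = carry + x
    return carry, carry


xs = jnp.arange(5.0)
init = jnp.zeros((), dtype=xs.dtype)  # match the carry dtype to the scanned values

# A positive unroll factor is accepted as before.
final, partial_sums = lax.scan(cumsum_step, init, xs, unroll=2)

# unroll=0 is rejected with a ValueError (previously an uninformative ZeroDivisionError).
try:
    lax.scan(cumsum_step, init, xs, unroll=0)
    error_type = None
except ValueError as e:
    error_type = type(e).__name__
```

The check only rejects non-positive unroll factors. The scan result is the same for any accepted value, so `unroll=2` here changes how the loop is lowered, not the partial sums it produces.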