https://github.com/google/jax

Revision Message Commit Date
9267ac4 Clean up version switches from dense array migration PiperOrigin-RevId: 635965009 21 May 2024, 23:35:31 UTC
ccaf466 [XLA:Mosaic] Support retiling from (8, 128, -2) to (8, 128) for 32-bit data. This drops the implicit second-minor dimension and fixes the vector-layout inference rule for concatenate when concatenating constant vectors. PiperOrigin-RevId: 635917408 21 May 2024, 20:52:26 UTC
1181136 Merge pull request #21280 from gnecula:export_device_poly2 PiperOrigin-RevId: 635904723 21 May 2024, 20:12:33 UTC
e750a20 Merge pull request #21315 from jakevdp:dep-default-registry PiperOrigin-RevId: 635898094 21 May 2024, 19:51:08 UTC
6d5668d jax.tree_util: test serialize_using_proto 21 May 2024, 19:37:37 UTC
9948529 [export] Relax the check that exported modules are used with the same number of devices as when exported. Now we allow a module exported for 1 device and not using any sharding annotations to be called from a computation that uses multiple devices. Such exported modules can be parallelized trivially point-wise. 21 May 2024, 18:40:16 UTC
21aa723 Merge pull request #21316 from jakevdp:pytree-doc PiperOrigin-RevId: 635869738 21 May 2024, 18:19:29 UTC
92d892b Adds rewrite patterns for `arith` and `math` operations with `bf16` operands/results that are not supported by the underlying hardware. PiperOrigin-RevId: 635865752 21 May 2024, 18:08:19 UTC
6ff5532 Merge pull request #21324 from jakevdp:dep-warn-local PiperOrigin-RevId: 635863480 21 May 2024, 18:02:31 UTC
483f924 Bump shard count for experimental_rnn_test, which is timing out in CI when built under ASAN. PiperOrigin-RevId: 635850400 21 May 2024, 17:25:24 UTC
5350bc9 Merge pull request #21258 from olupton:skip-cusolver-test-with-cuda-12.4 PiperOrigin-RevId: 635845926 21 May 2024, 17:11:54 UTC
418b688 Automated Code Change PiperOrigin-RevId: 635818645 21 May 2024, 15:40:34 UTC
9ba77f8 Skip a test when run with cuSolver >= 11.6. This version is shipped with CUDA 12.4. The test assumes that a workspace size baked in with an older version of cuSolver can be used with a newer version of cuSolver. This is not safe and leads to an error when upgrading from 11.5 to 11.6. 21 May 2024, 14:46:43 UTC
4394bdc Change the log message in `pxla.py` to be less confusing. PiperOrigin-RevId: 635789016 21 May 2024, 13:42:07 UTC
d5d2fb0 Ignore deprecation warnings locally rather than globally 21 May 2024, 03:28:25 UTC
1327143 Better documentation for jax.tree_util 21 May 2024, 02:56:47 UTC
3f1b059 Update XLA dependency to use revision http://github.com/openxla/xla/commit/c82597f5554394bf56847cbc75ec1b1e54f74d29. PiperOrigin-RevId: 635634941 21 May 2024, 01:47:32 UTC
d33a568 Refactor & test internal deprecation APIs. The names and APIs were previously too similar and therefore somewhat confusing; this will be clearer, I think. PiperOrigin-RevId: 635615163 21 May 2024, 00:16:31 UTC
2eff241 Merge pull request #21319 from gnecula:exp_fix_mesh PiperOrigin-RevId: 635611557 21 May 2024, 00:02:44 UTC
6deeee2 [export] Fix device assignment error for grad of exported. Currently, the export code uses a manufactured device assignment for exporting the VJP function. We should instead use the same device assignment that was used when exporting the primal function. This PR fixes that for the case when the export is done through the direct use of `jax.experimental.export`, and leaves as future work the case when the use is from `jax2tf`. We add a disabled test for the latter case. Bug: #21314 20 May 2024, 23:11:01 UTC
b197ae5 [Mosaic] Also check bitwidth in apply-vector-layout's `layoutIsValidForValue`. PiperOrigin-RevId: 635595321 20 May 2024, 22:57:08 UTC
118ca21 Merge pull request #21318 from jakevdp:fix-mypy PiperOrigin-RevId: 635553370 20 May 2024, 20:31:33 UTC
329ab03 CI: fix mypy error 20 May 2024, 20:23:15 UTC
06d2e48 Copybara import of the project: -- f625317cc80639178882316df6f8775294adc6b7 by cjkkkk <ske@nvidia.com>: init COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21228 from Cjkkkk:sdpa_new_cudnn_frontend f625317cc80639178882316df6f8775294adc6b7 PiperOrigin-RevId: 635518631 20 May 2024, 18:31:15 UTC
e0a6453 Simplify JAX lowering rules for cumulative sum. The upstream fix has landed, so the CPU workaround is removed. PiperOrigin-RevId: 635505632 20 May 2024, 17:52:29 UTC
cc3a380 Add a unit test to check whether the result of backend serialization/deserialization is equal to the original executable. PiperOrigin-RevId: 635485374 20 May 2024, 16:52:38 UTC
61ff828 Add support for TPU delay in Mosaic PiperOrigin-RevId: 635473532 20 May 2024, 16:07:56 UTC
2f45830 [Mosaic GPU] Prepare matmul example so it can be exposed to other projects. PiperOrigin-RevId: 635442413 20 May 2024, 13:54:44 UTC
f600caa Use @register_lowering to register Pallas GPU lowering rules. This leads to slightly more compact code, but should otherwise be identical. PiperOrigin-RevId: 635442002 20 May 2024, 13:51:33 UTC
c3d0b0d Merge pull request #21305 from jakevdp:scalar-bool PiperOrigin-RevId: 635436437 20 May 2024, 13:20:32 UTC
974c72b Merge pull request #21292 from ROCm:rv_stable_051624 PiperOrigin-RevId: 635430659 20 May 2024, 12:52:36 UTC
4bac10e Finalize deprecation of the config module. To configure JAX, use `import jax` and reference the config object via `jax.config`. PiperOrigin-RevId: 635430169 20 May 2024, 12:49:31 UTC
bb616ef Merge pull request #21231 from nouiz:doc_experimental_serialize_executable PiperOrigin-RevId: 635428472 20 May 2024, 12:39:21 UTC
5b28170 Support scalar boolean indices in arr.at[idx].set(vals) 20 May 2024, 12:33:36 UTC
53ec2cd Add notap tag to Mosaic tests PiperOrigin-RevId: 635379982 20 May 2024, 08:35:56 UTC
ffdb9bb Update XLA dependency to use revision http://github.com/openxla/xla/commit/4c566c945ad38e69bc71b5ee67b741551187e011. PiperOrigin-RevId: 635307171 20 May 2024, 01:27:13 UTC
45a7c22 fix: Update hermetic Python dependencies to numpy=2.0.0rc2 and scipy=1.13.0 for all Python versions. Also install the built jaxlib in hermetic Python to support //jax:build_jaxlib=false tests. PiperOrigin-RevId: 635169327 19 May 2024, 06:39:09 UTC
8caeaa2 Update XLA dependency to use revision http://github.com/openxla/xla/commit/ea6c3d5916b1c5125f1d8d3e440cbbd9f62c2343. PiperOrigin-RevId: 635132837 19 May 2024, 02:18:57 UTC
6577f47 Make `eqn.ctx` context manager thread safe by creating `eqn.ctx.manager`. PiperOrigin-RevId: 635057475 18 May 2024, 15:46:18 UTC
e3a7a87 Update XLA dependency to use revision http://github.com/openxla/xla/commit/e631101673a39975c5852fb43be6511e6c62fe22. PiperOrigin-RevId: 634954649 18 May 2024, 03:09:52 UTC
25aa13c Support remat + compute_on. If the rematted computation is annotated to run on host, the backward pass will also execute on host. Also enable no-op nested compute tests. PiperOrigin-RevId: 634943450 18 May 2024, 01:59:49 UTC
79fccf6 add cholesky changes in bazel 18 May 2024, 00:37:09 UTC
641d5c8 jax/pallas support ellipsis indexing PiperOrigin-RevId: 634922391 17 May 2024, 23:57:53 UTC
02c19e9 Make `jax.grad` and `compute_on` work correctly. If the forward pass is annotated to execute on CPU, then its backward pass also executes on CPU. PiperOrigin-RevId: 634917402 17 May 2024, 23:38:35 UTC
1043e24 Add quantization support for PagedAttention TPU Pallas kernel. PiperOrigin-RevId: 634914369 17 May 2024, 23:17:33 UTC
2d6d408 Initial commit for `jax.experimental.compute_on` API. The currently supported values for the compute type are `device_host` and `device`. `device_sparse` will be allowed in a follow-up CL. Using `device_host` means that the device's PJRT client will orchestrate the execution of the computation on the host. `cpu` as a compute_type is reserved for pure CPU-only computations that are not orchestrated by a device's PJRT client. PiperOrigin-RevId: 634909918 17 May 2024, 22:59:21 UTC
7a3fc71 Merge pull request #21289 from jakevdp:stacklevel PiperOrigin-RevId: 634895282 17 May 2024, 22:01:37 UTC
9ad6729 jax.experimental.export: fix stacklevel for warning 17 May 2024, 21:36:34 UTC
56301d4 Merge pull request #21283 from dfm:gh21279 PiperOrigin-RevId: 634819832 17 May 2024, 17:32:37 UTC
8152566 [pallas:Mosaic GPU] Configurable smem scratch and a small bug fix in Mosaic GPU PiperOrigin-RevId: 634813241 17 May 2024, 17:10:20 UTC
be06954 Add RegularGridInterpolator to generated API docs. In responding to gh21279, I noticed that `RegularGridInterpolator` isn't currently listed in the API docs. I know that `scipy.interpolate` is out of scope (https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-interpolate), but since we do currently provide this wrapper, it seems like it makes sense to include it in the docs! 17 May 2024, 15:24:35 UTC
e93f36a Merge pull request #21281 from jaro-sevcik:enable-host-offloading-scan-tests-gpu PiperOrigin-RevId: 634770894 17 May 2024, 14:39:24 UTC
7be7c1e Enable remat-scan offloading test 17 May 2024, 14:29:31 UTC
f87be35 [Mosaic GPU] reduce_sum does an intra-warp reduction before communicating with the other warps PiperOrigin-RevId: 634765339 17 May 2024, 14:24:35 UTC
210f8bb Use absl.flags.FlagHolder to define --mosaic_gpu_debug PiperOrigin-RevId: 634713545 17 May 2024, 11:33:17 UTC
c455911 Internal BUILD file change PiperOrigin-RevId: 634713068 17 May 2024, 11:30:21 UTC
527aef3 Added a slow (but working!) implementation of layer norm in Pallas via Mosaic GPU PiperOrigin-RevId: 634710243 17 May 2024, 11:15:40 UTC
5e2710c Merge pull request #21261 from superbobry:mypy-ruff PiperOrigin-RevId: 634654578 17 May 2024, 07:10:27 UTC
1829a66 Merge pull request #21268 from jakevdp:register-dataclass PiperOrigin-RevId: 634624518 17 May 2024, 04:27:30 UTC
0e92433 [Mosaic GPU] Add a WGSplatLayout that trivially supports reshape and broadcast. PiperOrigin-RevId: 634610004 17 May 2024, 03:04:05 UTC
efa420b Update XLA dependency to use revision http://github.com/openxla/xla/commit/0821ce54085018643fba169e109f7b13bf2accb7. PiperOrigin-RevId: 634608759 17 May 2024, 03:01:07 UTC
8493fc9 Merge pull request #21181 from 8bitmp3:add-distributed-data-loading-doc PiperOrigin-RevId: 634595201 17 May 2024, 01:37:14 UTC
defb53f Document jax.tree_util.register_dataclass 17 May 2024, 01:34:54 UTC
15a3deb Add Distributed data loading doc 16 May 2024, 23:12:13 UTC
bc19f7f Merge pull request #21267 from jakevdp:quantile-warning PiperOrigin-RevId: 634556534 16 May 2024, 23:00:31 UTC
c3bc88d Bumped mypy to 1.10.0 and ruff to 0.4.4 16 May 2024, 22:16:32 UTC
bfbde5e jnp.quantile & friends: properly deprecate interpolation 16 May 2024, 22:13:25 UTC
9dd98dc Merge pull request #21244 from google:add-tpu-core-count PiperOrigin-RevId: 634510539 16 May 2024, 20:30:05 UTC
74ffed9 Switch the order of the sharding and memory kind custom calls PiperOrigin-RevId: 634505383 16 May 2024, 20:12:38 UTC
01194bd Clarified the type of the inputs to callback APIs. The callback APIs were migrated to use jax.Arrays for both inputs and outputs in JAX 0.4.27. PiperOrigin-RevId: 634473890 16 May 2024, 18:29:09 UTC
380503b [Mosaic GPU] Enable pointwise operation with scalar values. PiperOrigin-RevId: 634441605 16 May 2024, 17:01:12 UTC
2ec3994 Add a sharding test for concatenate with an internal operand that requires padding. PiperOrigin-RevId: 634389912 16 May 2024, 14:47:43 UTC
8f045ca Add jax.make_array_from_process_local_data to create a distributed tensor from host data, along with supporting scaffolding in sharding to figure out the required dimensions of the host data. PiperOrigin-RevId: 634205261 16 May 2024, 05:06:45 UTC
cd41b4f Use `util.cache` instead of `lru_cache` for `create_mesh_pspec_sharding`. Its return value depends on `jax.config.enable_memories` due to the memory kind canonicalization, so we should use `util.cache`, which uses the trace_context as an additional key. PiperOrigin-RevId: 634192701 16 May 2024, 04:14:53 UTC
6fe313c Update XLA dependency to use revision http://github.com/openxla/xla/commit/a7ed898e9b9ca1f8d97ea8bc438ab690f23ef737. PiperOrigin-RevId: 634159073 16 May 2024, 02:07:03 UTC
517e299 Use hermetic Python in JAX, see "Managing hermetic Python" in developer.md for details PiperOrigin-RevId: 634146391 16 May 2024, 01:20:56 UTC
a820387 [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args PiperOrigin-RevId: 634084928 15 May 2024, 21:55:16 UTC
6a949ae Merge pull request #21251 from jakevdp:fix-pmap-warning PiperOrigin-RevId: 634061434 15 May 2024, 20:43:18 UTC
f5a59cc Fix warning filter in test_pmap_of_prng_key 15 May 2024, 20:20:05 UTC
1538197 Merge pull request #21247 from dfm:gh20481 PiperOrigin-RevId: 634045828 15 May 2024, 19:53:19 UTC
bb40075 [Mosaic GPU] Fix argument checking in case the input shape is not a tuple. PiperOrigin-RevId: 634044759 15 May 2024, 19:49:27 UTC
09a4b38 Add informative error for invalid unroll in scan. As reported in #20481, setting `unroll=0` in `lax.scan` resulted in an uninformative `ZeroDivisionError`. This PR adds a check which raises a `ValueError` for `unroll<=0`. 15 May 2024, 19:40:27 UTC
76a7b19 Set V3 cores to 4 15 May 2024, 18:40:57 UTC
0232cb9 Run v3-8 tests with cores set at 8 15 May 2024, 18:37:27 UTC
181da12 Add missing bash symbol 15 May 2024, 18:33:01 UTC
3015699 Add core count to TPU nightly; fix v5 job. The current job assumes a 4-core TPU. Modify the matrix to enable defining the core count for each TPU. 15 May 2024, 17:46:40 UTC
4ccac4c Merge pull request #21237 from gnecula:clean_tokens PiperOrigin-RevId: 634003364 15 May 2024, 17:40:58 UTC
11a4513 [Mosaic GPU] Explicitly check the smem size so that we get a good error rather than a cryptic cuda error PiperOrigin-RevId: 633993345 15 May 2024, 17:10:46 UTC
874a20a [Mosaic GPU] Fix np.random API mistake in tests PiperOrigin-RevId: 633970332 15 May 2024, 15:53:03 UTC
7653db8 Fix CI test failure in line_search_test. A recent XLA change means that certain matrix-vector products may now be lowered as matrix-matrix multiplications. This may mean that we use lower precisions where we previously did not. PiperOrigin-RevId: 633949879 15 May 2024, 14:38:20 UTC
41153b1 Clean up token handling during lowering. Version 0.4.27 of jaxlib is now the minimum version, and it supports real stablehlo tokens as module inputs and outputs. Hence we can now clean up `mlir.lower_jaxpr_to_fun` to not use the kwargs `create_tokens` and `replace_tokens_with_dummy` (both of them are always False now). We also remove `num_output_tokens`, which is not used. 15 May 2024, 12:54:46 UTC
66a92c4 Reverts 9e7830df2df9362edcf2e18e353d327fdecae678 PiperOrigin-RevId: 633816901 15 May 2024, 05:41:44 UTC
b8cec53 Update XLA dependency to use revision http://github.com/openxla/xla/commit/a0f5d76e3dac2ee3293de60a79c29a46668a53b0. PiperOrigin-RevId: 633771201 15 May 2024, 01:42:27 UTC
936ec98 Add doc for jax.experimental.serialize_executable module 14 May 2024, 23:11:37 UTC
9af0a31 Merge pull request #21222 from tempoxylophone:patch-1 PiperOrigin-RevId: 633734390 14 May 2024, 23:05:17 UTC
48b9018 Merge pull request #21221 from nouiz:doc PiperOrigin-RevId: 633734299 14 May 2024, 23:04:56 UTC
1dff5b3 Merge pull request #21226 from Micky774:linalg-depr PiperOrigin-RevId: 633734284 14 May 2024, 23:01:50 UTC
d4f36e2 Merge pull request #21227 from jakevdp:eye-offset PiperOrigin-RevId: 633730820 14 May 2024, 22:48:53 UTC
e8b06cc Cholesky rank-1 update kernel for JAX. PiperOrigin-RevId: 633722940 14 May 2024, 22:21:38 UTC
e2918ca Added a very rough sketch of Mosaic GPU lowering for Pallas. Almost nothing is supported, including: PyTree inputs/outputs; indexers; non-trivial grids; block specs; any primitives beyond the ones added here; etc. PiperOrigin-RevId: 633713366 14 May 2024, 21:48:09 UTC
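Commit 4bac10e finalizes the deprecation of the separate config module; options are set through the `jax.config` object after a plain `import jax`. A minimal sketch, using the `jax_enable_x64` flag only as an illustrative option:

```python
import jax
import jax.numpy as jnp

# Configure JAX through the `jax.config` object on the top-level package,
# not through a separately imported config module.
jax.config.update("jax_enable_x64", True)

# Arrays created after the update pick up the 64-bit default float type.
x = jnp.array(1.0)
print(x.dtype)
```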
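Commit 5b28170 adds support for scalar boolean indices in `arr.at[idx].set(vals)`. The sketch below assumes the behavior mirrors NumPy, where `x[True]` acts like `x[None]` (selecting everything) and `x[False]` selects an empty slice:

```python
import jax.numpy as jnp

x = jnp.arange(3)

# A scalar True index selects the whole array, so every element is updated.
all_set = x.at[True].set(10)

# A scalar False index selects an empty slice, so nothing is updated.
none_set = x.at[False].set(10)
```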
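Commit be06954 adds `RegularGridInterpolator` to the generated API docs. A short sketch of the wrapper on a 1-D grid, assuming linear interpolation and in-bounds query points (the grid and values are made up for illustration):

```python
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

# Sample f(x) = 10 * x on a 1-D grid.
grid = jnp.array([0.0, 1.0, 2.0])
values = 10.0 * grid

# `points` is a tuple with one coordinate array per dimension.
interp = RegularGridInterpolator((grid,), values, method="linear")

# Query points have shape (num_points, ndim); here ndim == 1.
queries = jnp.array([[0.5], [1.5]])
result = interp(queries)
```

Query points outside the grid fall back to `fill_value` (NaN by default) rather than raising.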
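Commit f87be35 describes `reduce_sum` in Mosaic GPU as reducing within each warp before communicating across warps. The following is only a hypothetical emulation of that two-level idea in `jax.numpy`, not Mosaic GPU code: a butterfly (XOR-shuffle) tree within each 32-lane warp, followed by a sum over the per-warp totals. `WARP_SIZE` and `two_level_reduce_sum` are illustrative names.

```python
import jax.numpy as jnp

WARP_SIZE = 32  # assumed lane count per warp


def two_level_reduce_sum(x):
    # One row per warp; requires x.size to be a multiple of WARP_SIZE.
    lanes = x.reshape(-1, WARP_SIZE)
    offset = WARP_SIZE // 2
    while offset > 0:
        # Emulate a shuffle-xor: each lane adds the value held by lane ^ offset.
        partner = jnp.arange(WARP_SIZE) ^ offset
        lanes = lanes + lanes[:, partner]
        offset //= 2
    # After log2(WARP_SIZE) rounds every lane holds its warp's total,
    # so only one value per warp has to be exchanged between warps.
    warp_totals = lanes[:, 0]
    return warp_totals.sum()


total = two_level_reduce_sum(jnp.arange(128.0))
```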
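Commit defb53f documents `jax.tree_util.register_dataclass`. A minimal sketch, assuming a dataclass with two array fields treated as pytree leaves and one string field treated as static metadata (the `Layer` class is hypothetical):

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Layer:
    weight: jax.Array  # pytree leaf
    bias: jax.Array    # pytree leaf
    name: str          # static metadata, not a leaf


jax.tree_util.register_dataclass(
    Layer, data_fields=["weight", "bias"], meta_fields=["name"])

layer = Layer(weight=jnp.ones(2), bias=jnp.zeros(2), name="dense")

# Only the data fields are visited by tree utilities and transformations.
leaves = jax.tree_util.tree_leaves(layer)
scaled = jax.tree_util.tree_map(lambda a: 2 * a, layer)
```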
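Commit bfbde5e deprecates the `interpolation` keyword of `jnp.quantile` and related functions; the `method` keyword selects the interpolation scheme instead. A small sketch:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# The 0.5 quantile falls halfway between the 2nd and 3rd sorted values.
median_linear = jnp.quantile(x, 0.5, method="linear")  # interpolates: 2.5
median_lower = jnp.quantile(x, 0.5, method="lower")    # lower neighbor: 2.0
```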
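Commit 01194bd clarifies that callback inputs are `jax.Array`s as of JAX 0.4.27. A minimal `jax.pure_callback` sketch; the host function `host_sin` is hypothetical and converts its input with `np.asarray`, so it does not depend on the exact array type it receives:

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Explicit conversion keeps the host code independent of whether the
    # callback receives a jax.Array or a NumPy array.
    return np.sin(np.asarray(x))


@jax.jit
def f(x):
    # Declare the shape/dtype of the callback's result for tracing.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_spec, x)


y = f(jnp.array([0.0, np.pi / 2], dtype=jnp.float32))
```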
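Commit 8f045ca adds `jax.make_array_from_process_local_data`. The sketch below assumes a single-process run, where the process-local data is the entire global array, and a 1-D mesh over all available devices; the axis name `"data"` is arbitrary:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D mesh over every device; the leading array dimension is split across it.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

# In a single-process run, the local batch is the whole global batch.
# Its leading dimension is a multiple of the device count so it splits evenly.
local_batch = np.arange(4 * len(devices), dtype=np.float32).reshape(-1, 1)

global_array = jax.make_array_from_process_local_data(sharding, local_batch)
```

In a multi-process job, each process would pass only its own slice of the batch, and the global shape would be inferred from (or given via) the sharding.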
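Commit 09a4b38 makes `lax.scan` reject `unroll <= 0` with a `ValueError`. A small sketch of a running-sum scan, first with a valid unroll factor and then with `unroll=0`:

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    carry = carry + x
    return carry, carry  # emit the running (cumulative) sum


xs = jnp.arange(5.0)

# A positive unroll factor changes only how the loop is lowered, not the result.
total, cumsum = jax.lax.scan(step, jnp.float32(0.0), xs, unroll=2)

# unroll <= 0 is rejected up front with a ValueError.
try:
    jax.lax.scan(step, jnp.float32(0.0), xs, unroll=0)
    error_type = None
except ValueError as e:
    error_type = type(e)
```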
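Commit e8b06cc adds a Cholesky rank-1 update kernel, but the message does not describe its algorithm. For orientation only, here is the textbook rank-1 update (given lower-triangular `L` with `A = L @ L.T`, produce the factor of `A + outer(v, v)` without refactoring from scratch), written in `jax.numpy`; `cholesky_rank1_update` is a hypothetical name and this is not the kernel's implementation:

```python
import jax.numpy as jnp


def cholesky_rank1_update(L, v):
    """Return L' with L' @ L'.T == L @ L.T + outer(v, v) (textbook algorithm)."""
    n = v.shape[0]
    for k in range(n):
        # Rotate the diagonal entry together with v[k].
        r = jnp.sqrt(L[k, k] ** 2 + v[k] ** 2)
        c = r / L[k, k]
        s = v[k] / L[k, k]
        L = L.at[k, k].set(r)
        # Update the rest of column k, then fold it back into v.
        col = (L[k + 1:, k] + s * v[k + 1:]) / c
        L = L.at[k + 1:, k].set(col)
        v = v.at[k + 1:].set(c * v[k + 1:] - s * col)
    return L


# Usage: update the factor of a small symmetric positive-definite matrix.
A = jnp.array([[4.0, 2.0, 0.0], [2.0, 5.0, 1.0], [0.0, 1.0, 3.0]])
v = jnp.array([1.0, 0.5, -1.0])
L_new = cholesky_rank1_update(jnp.linalg.cholesky(A), v)
```

The update costs O(n^2) operations instead of the O(n^3) needed to refactor `A + outer(v, v)`.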