https://github.com/google/jax

Revision Message Commit Date
5a39102 [Mosaic GPU] Implement non-transposing matmul for bf16 x s8 matmul PiperOrigin-RevId: 632526631 13 May 2024, 00:23:43 UTC
b4f2145 Update XLA dependency to use revision http://github.com/openxla/xla/commit/d3e881ad668b7aa44283d47bb553a04b86b71315. PiperOrigin-RevId: 632860505 12 May 2024, 01:00:43 UTC
a527b71 [Mosaic GPU] Prepare for writing warp-specialized kernels PiperOrigin-RevId: 632854287 12 May 2024, 00:09:08 UTC
49bd4d6 Reverts 586568f4fe44cf9ad8b1bd022148a10c4b69f33a PiperOrigin-RevId: 632818524 11 May 2024, 19:24:06 UTC
3b03e54 Raise a runtime error when trying to convert the `jax.Array` wrapped by `jax.core.Token` to a numpy array, as it is an internal implementation detail and the buffer has XLA token shape. PiperOrigin-RevId: 632682906 11 May 2024, 04:08:06 UTC
20646eb Update XLA dependency to use revision http://github.com/openxla/xla/commit/0b3dc68410d57f6cd3d0a484f46946d6d12f03a8. PiperOrigin-RevId: 632655986 11 May 2024, 01:19:21 UTC
9ac1d38 Finish jax and jaxlib 0.4.28 release PiperOrigin-RevId: 632653310 11 May 2024, 01:06:52 UTC
979d9ca Merge pull request #21168 from 8bitmp3:upgrade-sharded--doc PiperOrigin-RevId: 632648408 11 May 2024, 00:44:15 UTC
a4693db Add a jaxpr interpreter for propagating memory kinds to the output. It only triggers if we detect multiple memory kinds in the jaxpr. This should hopefully go away when XLA implements its own memory space propagation pass or JAX adds memory_kind to the type system of jaxpr, i.e. on avals. It is required to treat the following code blocks (1) and (2) as equivalent when lowering to StableHLO. In general, shardings should also be treated the same way, but we'll cross that bridge later. (1) `jit(f, out_shardings=s_host)` (2) `@jax.jit def f(x): return jax.device_put(x, s_host)` PiperOrigin-RevId: 632621025 10 May 2024, 22:34:57 UTC
27c932a Do not import from lowering in tests/pallas/pallas_test.py. This ensures that the test is importable even with a non-GPU jaxlib, which does not have Triton dialect bindings. PiperOrigin-RevId: 632603225 10 May 2024, 21:25:17 UTC
9ea3fcb Upgrade JAX Parallelism Sharded Computation 101 doc 10 May 2024, 21:24:16 UTC
17444fc Merge pull request #21174 from hawkinsp:spmm PiperOrigin-RevId: 632589433 10 May 2024, 20:35:04 UTC
dda428e Disable tests that trigger warning if x64 mode isn't enabled. 10 May 2024, 19:58:22 UTC
c3cab2e Reverts 6c425338d20c0c9be3fc69d2f07ababf79c881d3 PiperOrigin-RevId: 632579101 10 May 2024, 19:56:10 UTC
c231cd5 Merge pull request #21173 from hawkinsp:precision PiperOrigin-RevId: 632577567 10 May 2024, 19:50:07 UTC
24b4731 Force float32 matmuls in examples_test. This test started failing when we changed our CI to use L4 GPUs. Using highest precision resolves the problem. 10 May 2024, 19:30:02 UTC
0a3e432 [PJRT C API] Enable PJRT C API runtime in jax2tf dlpack. GetDefaultLayout added a fallback for GPU backend so it is no longer blocked by the fact that PJRT C API does not support GetDefaultLayout yet. PiperOrigin-RevId: 632555239 10 May 2024, 18:30:37 UTC
6c42533 Reverts 0267ed0ba9584bbc137792361b53aa80e9c4d306 PiperOrigin-RevId: 632548226 10 May 2024, 18:06:38 UTC
586568f Simplify JAX lowering rules for cumulative sum. Rely on XLA decomposition. JAX GPU microbenchmarks: 285us vs. 449us for cumsum over 1e8 elements. JAX CPU microbenchmarks: 1.8s vs. 0.7s for 50 iterations of cumsum over 1e7 elements. PiperOrigin-RevId: 632547166 10 May 2024, 18:03:28 UTC
13a1955 Merge pull request #21167 from jakevdp:einsum-path-func PiperOrigin-RevId: 632538144 10 May 2024, 17:35:32 UTC
bac3a6f Allow tokens being passed to `jit` and through dispatch and being returned from the jitted function. Fixes https://github.com/google/jax/issues/21160 PiperOrigin-RevId: 632531105 10 May 2024, 17:12:48 UTC
0267ed0 Replace the xla_extension symlink with a genrule that makes the xla_extension module accessible from jax._src.lib. The runfiles of the original targets were lost when the symlinked files were used. This change is needed for the future hermetic CUDA implementation: Bazel will download CUDA distributions into its cache, and CUDA executables and libraries will be added to the runfiles of the targets. When `xla_extension` is symlinked, the content of the runfiles is lost; with `genrule`, it is preserved. PiperOrigin-RevId: 632508121 10 May 2024, 15:48:12 UTC
d07951c jnp.einsum_path: improve docs & annotations 10 May 2024, 15:39:32 UTC
c2d78ab Merge pull request #21148 from jakevdp:einsum-path PiperOrigin-RevId: 632470123 10 May 2024, 12:58:30 UTC
2a01541 Merge pull request #21157 from djm3622:patch-1 PiperOrigin-RevId: 632469962 10 May 2024, 12:55:38 UTC
c3d3db9 jnp.einsum: support optimize=False, and improve docs for this keyword. 10 May 2024, 02:50:06 UTC
aa48a55 Update line 127 error in debugging.md. There was an error in "y, z = jnp.sin(x), jnp.cos(x)", where jnp.cos(x) was nested within jnp.sin(x) ==> jnp.sin(x, jnp.cos(x)). This caused an error to be thrown. This change fixes that. 10 May 2024, 02:44:27 UTC
f21e3e8 Update XLA dependency to use revision http://github.com/openxla/xla/commit/e8247c3ea1d4d7f31cf27def4c7ac6f2ce64ecd4. PiperOrigin-RevId: 632333365 10 May 2024, 01:37:25 UTC
6f79093 [XLA:TPU] Support output streaming and refactor TryOutputStreaming into a bottom-up approach. Previously, output streaming took a top-down approach, which indiscriminately checked whether a MoveToHost custom call would trace down to an output marked with host memory space. This did not work when a dynamic-update-slice existed between the MTH call and the output. This CL fixes the problem by handling output streaming before other MTH calls, while also improving efficiency with the bottom-up approach, so only a single path in the graph is traced. PiperOrigin-RevId: 632318740 10 May 2024, 00:34:21 UTC
a9460f2 Merge pull request #21130 from Micky774:reshape PiperOrigin-RevId: 632277406 09 May 2024, 21:51:56 UTC
e8ac011 Allow nested shard_map. PiperOrigin-RevId: 632275515 09 May 2024, 21:45:40 UTC
79005c1 Deprecate newshape argument of jnp.reshape 09 May 2024, 21:02:07 UTC
1cb6971 Merge pull request #21107 from justinjfu/pallas_splash_test_fix Fix failing smem_hbm_dma test 09 May 2024, 20:27:56 UTC
7245714 Add zeros initialization to failing smem-hbm copy test. 09 May 2024, 20:19:21 UTC
9e62994 Merge pull request #21135 from hawkinsp:release PiperOrigin-RevId: 632235600 09 May 2024, 19:32:51 UTC
038dfee Prepare 0.4.28 release. 09 May 2024, 19:25:33 UTC
f98e707 Update XLA dependency to use revision http://github.com/openxla/xla/commit/d60579f54a0b6c37d1caf11dc3eb34488cf6922a. PiperOrigin-RevId: 632232639 09 May 2024, 19:22:25 UTC
1a7a2aa Merge pull request #21106 from jakevdp:linalg-precision PiperOrigin-RevId: 632217396 09 May 2024, 18:33:54 UTC
0c4d81c Merge pull request #21138 from jakevdp:einsum-doc PiperOrigin-RevId: 632198113 09 May 2024, 17:38:51 UTC
2ddb7ff jnp.linalg: add precision & preferred_element_type to dot-like functions 09 May 2024, 17:06:51 UTC
eb0b1b0 Merge pull request #21108 from justinjfu/skip_pallas_test_64 Skip float64 test_nextafter on TPU. 09 May 2024, 16:20:30 UTC
671fb12 Update the multi-process note in pjit's docstring PiperOrigin-RevId: 632160561 09 May 2024, 15:38:29 UTC
2be3f6d Merge pull request #21146 from jakevdp:fix-multidot PiperOrigin-RevId: 632142647 09 May 2024, 14:28:49 UTC
89d25bb Reenable examples_test in Bazel build. Fix bitrot. This test was disabled years ago because it was slow, but it isn't any more. PiperOrigin-RevId: 632138101 09 May 2024, 14:10:07 UTC
5edfaa6 jnp.linalg.multi_dot: use optimize='auto' 09 May 2024, 13:47:30 UTC
1e88e2f Update XLA dependency to use revision http://github.com/openxla/xla/commit/4872030c44e20556efdfea47829d942985e0ccf1. PiperOrigin-RevId: 631997979 09 May 2024, 02:17:08 UTC
168f40e [XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier. This works around https://github.com/python/cpython/issues/89478, which was fixed in Python 3.11. PiperOrigin-RevId: 631984256 09 May 2024, 01:05:28 UTC
96f888b Reverts 1956ff7d7b73794012fece2d8452e097196587fc PiperOrigin-RevId: 631974751 09 May 2024, 00:23:13 UTC
f991dd8 Merge pull request #21139 from jakevdp:fix-lpmn-test PiperOrigin-RevId: 631954696 08 May 2024, 23:06:47 UTC
f556a17 TST: fix Lpmn test for new scipy 08 May 2024, 22:55:20 UTC
e870052 jnp.einsum: improve documentation 08 May 2024, 21:30:59 UTC
962f084 Merge pull request #21137 from superbobry:pallas PiperOrigin-RevId: 631923082 08 May 2024, 21:20:10 UTC
65d4c68 Generic reduce window jvp. The problem is that we want to generically jvp and transpose over any reduction_fn. JAX already handles some of the hard parts for us, namely ensuring that the user-provided fn is JAX-capturable. All that is left, then, is to write a jvp and transpose fn that use the JAX utils correctly. However, this is not so straightforward, because to get the transpose of a reduction window we need to be able to use both the tangents and the primals. The current reduce_fn operates on (x, y), but what we actually need, under jvp, is to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down notions of a jvp-specific reduction_fn (captured via the usual machinery of as_fun `as_fun(jvp_fn(closed(user_reduction_jaxp)))`). For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together. This means a good deal of changes to safety checks and assurances downstream (as well as unpacking), as the shape of the operand has changed from [K,...Kt] to [K, ...Kt, 2], where the last dim is the stacked primal and tangent values. In following CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition. PiperOrigin-RevId: 631916764 08 May 2024, 21:00:39 UTC
4b62425 Renamed is_device_gpu_at_least to is_cuda_compute_capability_at_least. This makes it clear that the predicate is only supposed to be used for NVIDIA GPUs at the moment. 08 May 2024, 20:41:50 UTC
bfdc87d Merge pull request #21136 from superbobry:pallas PiperOrigin-RevId: 631908723 08 May 2024, 20:34:40 UTC
575ba94 Removed get_compute_capability from jax.experimental.pallas.gpu. Compute capability is available as a `str` attribute on a GPU device since jaxlib 0.4.26. 08 May 2024, 20:10:43 UTC
a145109 Update XLA dependency to use revision http://github.com/openxla/xla/commit/68b17a8571f197676ea07479637da1546b3b501d. PiperOrigin-RevId: 631874377 08 May 2024, 18:47:24 UTC
f768cb7 Refactor pipeline emitter API. Building on enrique's work, this CL refactors the emit_pipeline abstraction: 1) factors out the VMEM double-buffering bookkeeping into a helper class; 2) concentrates the intricate copy/wait scheduling logic in one place inside a scheduler helper while allowing manual overrides, so callbacks no longer control scheduling and loop scheduling is explicit instead; 3) minimizes callbacks and simplifies the "defaults" for fusing pipelines together. Examples of fully overlapped versions of latency- and throughput-optimized AG-matmuls and matmul-RSs are included in new tests. PiperOrigin-RevId: 631865641 08 May 2024, 18:22:47 UTC
e652d62 Cleanup second registration of custom_partitioning callbacks now that the jaxlib version has been bumped. PiperOrigin-RevId: 631852273 08 May 2024, 17:45:39 UTC
8baa5d8 Merge pull request #21128 from hawkinsp:loggingtest PiperOrigin-RevId: 631839432 08 May 2024, 17:07:42 UTC
2967ec9 Merge pull request #21129 from superbobry:pallas PiperOrigin-RevId: 631825682 08 May 2024, 16:25:49 UTC
0feeaa5 Removed stale version guards and try/except blocks from Pallas GPU. They are unnecessary now that the minimum jaxlib version is 0.4.27. 08 May 2024, 16:05:45 UTC
0e57c79 Switch Windows jobs to use Clang. Remove the experimental/trial Clang job. PiperOrigin-RevId: 631814321 08 May 2024, 15:47:34 UTC
11da3df Merge pull request #21096 from gspschmid:gschmid/sourcemaps PiperOrigin-RevId: 631769572 08 May 2024, 12:44:08 UTC
919832a Enable logging_test on all CI platforms. Should catch issues like https://github.com/google/jax/issues/21121 08 May 2024, 12:43:52 UTC
b0b322d Add sourcemap module to generate TC39-compliant source maps 08 May 2024, 08:54:25 UTC
335f27b Update XLA dependency to use revision http://github.com/openxla/xla/commit/c6df436f9e3d2a706a3a9595e37224276fb6013b. PiperOrigin-RevId: 631629197 08 May 2024, 02:58:28 UTC
e9e4e53 [jax:mosaic-gpu] FragmentedArray can do tiled load. PiperOrigin-RevId: 631611060 08 May 2024, 01:13:55 UTC
79f11d5 [Pallas] Fix some typos. PiperOrigin-RevId: 631592201 07 May 2024, 23:52:38 UTC
395d3cb Bump minimum jaxlib version to 0.4.27. xla_extension_version is 261 and mlir_api_version is 56. PiperOrigin-RevId: 631579739 07 May 2024, 23:07:59 UTC
5ba56bb Recommend the plugin in the CUDA installation instructions. PiperOrigin-RevId: 631555876 07 May 2024, 21:47:39 UTC
78e10ee Merge pull request #21115 from jakevdp:multi-dot PiperOrigin-RevId: 631545161 07 May 2024, 21:14:02 UTC
2b3251e Merge pull request #21092 from jakevdp:dot-doc PiperOrigin-RevId: 631536980 07 May 2024, 20:51:13 UTC
09810be Implement jnp.linalg.multi_dot using opt_einsum 07 May 2024, 20:40:25 UTC
5f70267 Merge pull request #21103 from superbobry:mosaic-gpu-fix PiperOrigin-RevId: 631521771 07 May 2024, 20:11:43 UTC
7153738 Merge pull request #21104 from superbobry:triton-fixes PiperOrigin-RevId: 631521767 07 May 2024, 20:06:33 UTC
174405c The Bazel version used in JAX is bumped from 6.1.2 to 6.5.0. The update is needed for Windows/Clang builds and for the future hermetic CUDA implementation. PiperOrigin-RevId: 631519200 07 May 2024, 19:58:37 UTC
b6fea37 Merge pull request #21111 from jakevdp:fix-changelog PiperOrigin-RevId: 631493089 07 May 2024, 18:34:06 UTC
9524188 Merge pull request #21110 from jakevdp:upstream-nightly PiperOrigin-RevId: 631490005 07 May 2024, 18:25:19 UTC
c18851b CHANGELOG: move change from 0.4.27 to 0.4.28 07 May 2024, 18:16:11 UTC
496795e CI: fix typo in workflow 07 May 2024, 18:14:11 UTC
5031a1d Finish jax and jaxlib 0.4.27 release PiperOrigin-RevId: 631486157 07 May 2024, 18:14:09 UTC
034d843 jax.numpy: better docs for matmul-like functions 07 May 2024, 18:01:54 UTC
3dd8af9 Merge pull request #21090 from jakevdp:extract PiperOrigin-RevId: 631480474 07 May 2024, 17:59:10 UTC
274e735 Merge pull request #21105 from albanie:patch-3 PiperOrigin-RevId: 631472758 07 May 2024, 17:38:06 UTC
daab7a0 Handle ellipsis `...` in `_attempt_rewriting_take_via_slice`. Previously `model['some_array'][:,0,0,:]` would generate a `slice`, while `model['some_array'][...,0,0,:]` would generate a `gather`. Now both of these generate `slice` eqns. PiperOrigin-RevId: 631469837 07 May 2024, 17:30:08 UTC
080b3f1 Skip float64 test_nextafter on TPU 07 May 2024, 17:06:08 UTC
fdd8b61 Minor doc fix to `jaxpr.rst`. Update the written description of the example `jaxpr` to match the given snippet of code. 07 May 2024, 16:36:19 UTC
2ca3264 Updated the Pallas GPU lowering to work with older jaxlib versions. Triton changed the signatures of LoadOp and DotOp upstream, and the lowering code was not ready to handle both old and new signatures. 07 May 2024, 16:30:52 UTC
8ccbeba Fixed Mosaic GPU build following #21029 07 May 2024, 16:08:00 UTC
0c26a34 Add optional size argument to jnp.compress & jnp.extract. 07 May 2024, 15:47:34 UTC
9b79f65 Remove deprecated `kind` argument from `jnp.sort` and `jnp.argsort`. PiperOrigin-RevId: 631429900 07 May 2024, 15:18:59 UTC
500da57 Merge pull request #21077 from merrymercy:patch-1 PiperOrigin-RevId: 631409738 07 May 2024, 14:07:04 UTC
70b4477 Start jax and jaxlib 0.4.27 release PiperOrigin-RevId: 631409685 07 May 2024, 14:01:24 UTC
326adc0 [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args PiperOrigin-RevId: 631404097 07 May 2024, 13:36:36 UTC
3e5a18f Update XLA dependency to use revision http://github.com/openxla/xla/commit/873d09720f83cbbebf2a2a381c09be8fa0934b36. PiperOrigin-RevId: 631274530 07 May 2024, 03:37:43 UTC
cb0c498 Merge pull request #21081 from hawkinsp:sourcemap PiperOrigin-RevId: 631236806 07 May 2024, 00:33:12 UTC
4de3464 Fix insufficient initialization of the output HBM buffer, which caused an <unk> token to be generated for a quantized int8 model. PiperOrigin-RevId: 631235764 07 May 2024, 00:28:13 UTC
eee2783 Merge pull request #21070 from shuhand0:rel0.0.7 PiperOrigin-RevId: 631218770 06 May 2024, 23:22:15 UTC
f6d8852 Merge pull request #20327 from selamw1:add_examples PiperOrigin-RevId: 631186425 06 May 2024, 21:30:06 UTC
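Commits 09810be and 5edfaa6 move `jnp.linalg.multi_dot` onto opt_einsum's contraction-order search, and d07951c documents `jnp.einsum_path`. The sketch below assumes a recent JAX release; it checks that `multi_dot` agrees with a plain left-to-right product and prints the contraction path that `einsum_path` reports for the same chain. The matrix shapes are arbitrary illustration values.

```python
import jax.numpy as jnp

# A chain of matrices whose sizes make the contraction order matter for cost.
a = jnp.ones((10, 100))
b = jnp.ones((100, 5))
c = jnp.ones((5, 50))

# multi_dot chooses the contraction order itself; numerically it should
# match the naive left-to-right product.
chained = jnp.linalg.multi_dot([a, b, c])
naive = (a @ b) @ c

# einsum_path reports the pairwise contraction order (a list of operand-index
# tuples) plus a printable summary, without performing the contraction.
path, summary = jnp.einsum_path("ij,jk,kl->il", a, b, c, optimize="greedy")

print(chained.shape)  # (10, 50)
print(path)
print(summary)
```

With three operands, the reported path should contain two pairwise contractions.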
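Commit c3d3db9 lets `jnp.einsum` accept `optimize=False`. As in NumPy, this is assumed here to skip the contraction-path search and contract the operands in the order given; the numerical result should be the same either way, and only the intermediate computations differ. A small check, assuming JAX 0.4.28 or later:

```python
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
y = jnp.arange(12.0).reshape(3, 4)
z = jnp.arange(4.0)

# Default: JAX searches for a contraction order before evaluating.
optimized = jnp.einsum("ij,jk,k->i", x, y, z)

# optimize=False: no path search (assumed left-to-right contraction).
unoptimized = jnp.einsum("ij,jk,k->i", x, y, z, optimize=False)

print(optimized)    # [162. 504.]
print(unoptimized)  # [162. 504.]
```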
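Commit 0c26a34 adds an optional `size` argument to `jnp.compress` and `jnp.extract`. Without it, the output length depends on the data, which cannot be traced under `jit`; with a static `size`, the result is padded or truncated to that length. The sketch assumes JAX 0.4.28 or later and uses the `fill_value` keyword that accompanies `size` to choose the padding value; `positive_values` is a hypothetical helper for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def positive_values(x):
    # The number of positive entries is data-dependent, so a static `size`
    # is required under jit; unused slots are filled with `fill_value`.
    return jnp.extract(x > 0, x, size=4, fill_value=0)

x = jnp.array([1, -2, 3, -4, 5])

padded = positive_values(x)                # three matches, padded to length 4
truncated = jnp.extract(x > 0, x, size=2)  # keeps only the first two matches

print(padded)     # [1 3 5 0]
print(truncated)  # [1 3]
```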
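Two API changes in this window affect calling code: 79005c1 deprecates the `newshape` argument of `jnp.reshape`, and 9b79f65 removes the deprecated `kind` argument from `jnp.sort` and `jnp.argsort`. Below is a minimal sketch of forms that avoid both. It assumes a JAX release in which the sort functions accept the `stable` and `descending` keywords; treating them as the replacement for `kind` is an assumption, not something the commits state.

```python
import jax.numpy as jnp

x = jnp.arange(6)

# Pass the target shape positionally (or use the array method) instead of
# the deprecated `newshape=` keyword.
a = jnp.reshape(x, (2, 3))
b = x.reshape(3, 2)

v = jnp.array([3, 1, 2, 1])

# Instead of kind="stable", request stability explicitly: equal keys keep
# their original relative order (index 1 before index 3).
order = jnp.argsort(v, stable=True)

# Descending order without reversing the result by hand.
desc = jnp.sort(v, descending=True)

print(a.shape, b.shape)  # (2, 3) (3, 2)
print(order)             # [1 3 2 0]
print(desc)              # [3 2 1 1]
```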
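Commit daab7a0 makes indexing with a leading ellipsis take the same slice-based path as the equivalent explicit slices. One way to observe this, assuming JAX 0.4.28 or later, is to trace both spellings and search the printed jaxpr for a `gather`. The exact non-gather primitives emitted (for example `slice`, `dynamic_slice`, `squeeze`) may vary between releases, so the check only looks for the absence of `gather`; `uses_gather` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

def uses_gather(fn, *args):
    # Trace fn and search the printed jaxpr (including any nested jaxprs)
    # for a gather primitive.
    return "gather" in str(jax.make_jaxpr(fn)(*args))

x = jnp.zeros((4, 5, 6, 7))

explicit = uses_gather(lambda a: a[:, 0, 0, :], x)
ellipsis = uses_gather(lambda a: a[..., 0, 0, :], x)

print(explicit, ellipsis)
```

For a 4-D array, `a[..., 0, 0, :]` and `a[:, 0, 0, :]` select the same elements, so both should trace to slice-style primitives rather than `gather`.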
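Commit 24b4731 forces float32 matmuls in examples_test after CI moved to L4 GPUs, noting that highest precision resolves the failure. In user code, the same request can be made per call with the `precision` argument or for a scope with `jax.default_matmul_precision`; the sketch below shows both. On CPU the default is already full float32, so both results match the float64 reference closely there. Whether a given accelerator defaults to reduced-precision passes is a hardware assumption, not something the commit states.

```python
import jax
import jax.numpy as jnp
import numpy as np

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (128, 128), dtype=jnp.float32)
b = jax.random.normal(key_b, (128, 128), dtype=jnp.float32)

# Per call: request the most precise matmul algorithm for this product only.
c_call = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# Scoped: matmuls inside this context default to float32 precision.
with jax.default_matmul_precision("float32"):
    c_scoped = jnp.matmul(a, b)

# float64 reference computed with NumPy on the host.
reference = np.asarray(a, np.float64) @ np.asarray(b, np.float64)
print(np.max(np.abs(np.asarray(c_call) - reference)))
print(np.max(np.abs(np.asarray(c_scoped) - reference)))
```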