swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f

Revision Message Commit Date
9e62994 Merge pull request #21135 from hawkinsp:release PiperOrigin-RevId: 632235600 09 May 2024, 19:32:51 UTC
038dfee Prepare 0.4.28 release. 09 May 2024, 19:25:33 UTC
f98e707 Update XLA dependency to use revision http://github.com/openxla/xla/commit/d60579f54a0b6c37d1caf11dc3eb34488cf6922a. PiperOrigin-RevId: 632232639 09 May 2024, 19:22:25 UTC
1a7a2aa Merge pull request #21106 from jakevdp:linalg-precision PiperOrigin-RevId: 632217396 09 May 2024, 18:33:54 UTC
0c4d81c Merge pull request #21138 from jakevdp:einsum-doc PiperOrigin-RevId: 632198113 09 May 2024, 17:38:51 UTC
2ddb7ff jnp.linalg: add precision & preferred_element_type to dot-like functions 09 May 2024, 17:06:51 UTC
eb0b1b0 Merge pull request #21108 from justinjfu/skip_pallas_test_64 Skip float64 test_nextafter on TPU. 09 May 2024, 16:20:30 UTC
671fb12 Update the multi-process note in pjit's docstring PiperOrigin-RevId: 632160561 09 May 2024, 15:38:29 UTC
2be3f6d Merge pull request #21146 from jakevdp:fix-multidot PiperOrigin-RevId: 632142647 09 May 2024, 14:28:49 UTC
89d25bb Reenable examples_test in Bazel build. Fix bitrot. This test was disabled years ago because it was slow, but it isn't any more. PiperOrigin-RevId: 632138101 09 May 2024, 14:10:07 UTC
5edfaa6 jnp.linalg.multi_dot: use optimize='auto' 09 May 2024, 13:47:30 UTC
1e88e2f Update XLA dependency to use revision http://github.com/openxla/xla/commit/4872030c44e20556efdfea47829d942985e0ccf1. PiperOrigin-RevId: 631997979 09 May 2024, 02:17:08 UTC
168f40e [XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier. This works around https://github.com/python/cpython/issues/89478, which was fixed in Python 3.11. PiperOrigin-RevId: 631984256 09 May 2024, 01:05:28 UTC
96f888b Reverts 1956ff7d7b73794012fece2d8452e097196587fc PiperOrigin-RevId: 631974751 09 May 2024, 00:23:13 UTC
f991dd8 Merge pull request #21139 from jakevdp:fix-lpmn-test PiperOrigin-RevId: 631954696 08 May 2024, 23:06:47 UTC
f556a17 TST: fix Lpmn test for new scipy 08 May 2024, 22:55:20 UTC
e870052 jnp.einsum: improve documentation 08 May 2024, 21:30:59 UTC
962f084 Merge pull request #21137 from superbobry:pallas PiperOrigin-RevId: 631923082 08 May 2024, 21:20:10 UTC
65d4c68 Generic reduce window jvp The problem is that we want to generically jvp and transpose over any reduction_fn. JAX already handles some of the hard parts for us, namely, ensuring that the user-provided fn is jax capturable. All that is left then, is to write a jvp and transpose fn that utilize the jax utils correctly. However, this is not so straightforward because in order to get the transpose of a reduction window, we need to be able to use both the tangents and primals. The current reduce_fn operates on (x, y) - but what we actually need is, under jvp, to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down notions of a jvp-specific reduction_fn (captured via the usual machinery of as_fun `as_fun(jvp_fn(closed(user_reduction_jaxp)))`). For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together - this means a good deal of changes to safety checks and assurances downstream (as well as unpacking) as the shape of the operand has changed from [K,...Kt] to [K, ...Kt, 2] where the last dim is the stacked primal and tangent values. In following CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition PiperOrigin-RevId: 631916764 08 May 2024, 21:00:39 UTC
4b62425 Renamed is_device_gpu_at_least to is_cuda_compute_capability_at_least This makes it clear that the predicate is only supposed to be used for NVIDIA GPUs at the moment. 08 May 2024, 20:41:50 UTC
bfdc87d Merge pull request #21136 from superbobry:pallas PiperOrigin-RevId: 631908723 08 May 2024, 20:34:40 UTC
575ba94 Removed get_compute_capability from jax.experimental.pallas.gpu Compute capability is available as a `str` attribute on a GPU device since jaxlib 0.4.26. 08 May 2024, 20:10:43 UTC
a145109 Update XLA dependency to use revision http://github.com/openxla/xla/commit/68b17a8571f197676ea07479637da1546b3b501d. PiperOrigin-RevId: 631874377 08 May 2024, 18:47:24 UTC
f768cb7 Refactor pipeline emitter API. Building on Enrique's work, this CL refactors the emit_pipeline abstraction: 1) factors out the VMEM double-buffering bookkeeping into a helper class. 2) concentrates the intricate copy/wait scheduling logic into one place inside a scheduler helper while allowing manual overrides; callbacks don't control scheduling anymore, rather we have explicit loop scheduling. 3) minimizes callbacks and simplifies the "defaults" for fusing pipelines together. Examples of fully overlapped versions of latency- and throughput-optimized AG-matmuls and matmul-RSs are included in new tests. PiperOrigin-RevId: 631865641 08 May 2024, 18:22:47 UTC
e652d62 Cleanup second registration of custom_partitioning callbacks now that the jaxlib version has been bumped. PiperOrigin-RevId: 631852273 08 May 2024, 17:45:39 UTC
8baa5d8 Merge pull request #21128 from hawkinsp:loggingtest PiperOrigin-RevId: 631839432 08 May 2024, 17:07:42 UTC
2967ec9 Merge pull request #21129 from superbobry:pallas PiperOrigin-RevId: 631825682 08 May 2024, 16:25:49 UTC
0feeaa5 Removed stale version guards and try/except blocks from Pallas GPU They are unnecessary now that the minimum jaxlib version is 0.4.27. 08 May 2024, 16:05:45 UTC
0e57c79 Switch Windows jobs to use Clang. Remove the experimental/trial Clang job. PiperOrigin-RevId: 631814321 08 May 2024, 15:47:34 UTC
11da3df Merge pull request #21096 from gspschmid:gschmid/sourcemaps PiperOrigin-RevId: 631769572 08 May 2024, 12:44:08 UTC
919832a Enable logging_test on all CI platforms. Should catch issues like https://github.com/google/jax/issues/21121 08 May 2024, 12:43:52 UTC
b0b322d Add sourcemap module to generate TC39-compliant source maps 08 May 2024, 08:54:25 UTC
335f27b Update XLA dependency to use revision http://github.com/openxla/xla/commit/c6df436f9e3d2a706a3a9595e37224276fb6013b. PiperOrigin-RevId: 631629197 08 May 2024, 02:58:28 UTC
e9e4e53 [jax:mosaic-gpu] FragmentedArray can do tiled load. PiperOrigin-RevId: 631611060 08 May 2024, 01:13:55 UTC
79f11d5 [Pallas] Fix some typos. PiperOrigin-RevId: 631592201 07 May 2024, 23:52:38 UTC
395d3cb Bump minimum jaxlib version to 0.4.27 xla_extension_version is 261 and mlir_api_version is 56 PiperOrigin-RevId: 631579739 07 May 2024, 23:07:59 UTC
5ba56bb Recommend the plugin in the CUDA installation instructions. PiperOrigin-RevId: 631555876 07 May 2024, 21:47:39 UTC
78e10ee Merge pull request #21115 from jakevdp:multi-dot PiperOrigin-RevId: 631545161 07 May 2024, 21:14:02 UTC
2b3251e Merge pull request #21092 from jakevdp:dot-doc PiperOrigin-RevId: 631536980 07 May 2024, 20:51:13 UTC
09810be Implement jnp.linalg.multi_dot using opt_einsum 07 May 2024, 20:40:25 UTC
5f70267 Merge pull request #21103 from superbobry:mosaic-gpu-fix PiperOrigin-RevId: 631521771 07 May 2024, 20:11:43 UTC
7153738 Merge pull request #21104 from superbobry:triton-fixes PiperOrigin-RevId: 631521767 07 May 2024, 20:06:33 UTC
174405c The Bazel version used in JAX is bumped from 6.1.2 to 6.5.0. The update is needed for Windows/Clang builds and for the future hermetic CUDA implementation. PiperOrigin-RevId: 631519200 07 May 2024, 19:58:37 UTC
b6fea37 Merge pull request #21111 from jakevdp:fix-changelog PiperOrigin-RevId: 631493089 07 May 2024, 18:34:06 UTC
9524188 Merge pull request #21110 from jakevdp:upstream-nightly PiperOrigin-RevId: 631490005 07 May 2024, 18:25:19 UTC
c18851b CHANGELOG: move change from 0.4.27 to 0.4.28 07 May 2024, 18:16:11 UTC
496795e CI: fix typo in workflow 07 May 2024, 18:14:11 UTC
5031a1d Finish jax and jaxlib 0.4.27 release PiperOrigin-RevId: 631486157 07 May 2024, 18:14:09 UTC
034d843 jax.numpy: better docs for matmul-like functions 07 May 2024, 18:01:54 UTC
3dd8af9 Merge pull request #21090 from jakevdp:extract PiperOrigin-RevId: 631480474 07 May 2024, 17:59:10 UTC
274e735 Merge pull request #21105 from albanie:patch-3 PiperOrigin-RevId: 631472758 07 May 2024, 17:38:06 UTC
daab7a0 Handle ellipsis `...` in `_attempt_rewriting_take_via_slice`. Previously `model['some_array'][:,0,0,:]` would generate a `slice`, while `model['some_array'][...,0,0,:]` would generate a `gather`. Now both of these generate `slice` eqns. PiperOrigin-RevId: 631469837 07 May 2024, 17:30:08 UTC
080b3f1 Skip float64 test_nextafter on TPU 07 May 2024, 17:06:08 UTC
fdd8b61 Minor doc fix to `jaxpr.rst` Update the written description of the example `jaxpr` to match the given snippet of code. 07 May 2024, 16:36:19 UTC
2ca3264 Updated the Pallas GPU lowering to work with older jaxlib versions Triton changed the signatures of LoadOp and DotOp upstream, and the lowering code was not ready to handle both old and new signatures. 07 May 2024, 16:30:52 UTC
8ccbeba Fixed Mosaic GPU build following #21029 07 May 2024, 16:08:00 UTC
0c26a34 Add optional size argument to jnp.compress & jnp.extract. 07 May 2024, 15:47:34 UTC
9b79f65 Remove deprecated `kind` argument from `jnp.sort` and `jnp.argsort`. PiperOrigin-RevId: 631429900 07 May 2024, 15:18:59 UTC
500da57 Merge pull request #21077 from merrymercy:patch-1 PiperOrigin-RevId: 631409738 07 May 2024, 14:07:04 UTC
70b4477 Start jax and jaxlib 0.4.27 release PiperOrigin-RevId: 631409685 07 May 2024, 14:01:24 UTC
326adc0 [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args PiperOrigin-RevId: 631404097 07 May 2024, 13:36:36 UTC
3e5a18f Update XLA dependency to use revision http://github.com/openxla/xla/commit/873d09720f83cbbebf2a2a381c09be8fa0934b36. PiperOrigin-RevId: 631274530 07 May 2024, 03:37:43 UTC
cb0c498 Merge pull request #21081 from hawkinsp:sourcemap PiperOrigin-RevId: 631236806 07 May 2024, 00:33:12 UTC
4de3464 Fix insufficient initialization of the output HBM buffer, which caused the <unk> token to be generated for quantized int8 models. PiperOrigin-RevId: 631235764 07 May 2024, 00:28:13 UTC
eee2783 Merge pull request #21070 from shuhand0:rel0.0.7 PiperOrigin-RevId: 631218770 06 May 2024, 23:22:15 UTC
f6d8852 Merge pull request #20327 from selamw1:add_examples PiperOrigin-RevId: 631186425 06 May 2024, 21:30:06 UTC
aac3679 fix jaxlib config name 06 May 2024, 20:51:22 UTC
9caf59d improve documentation for ix_ 06 May 2024, 20:43:55 UTC
3d3cb0b Merge pull request #20842 from Micky774:array-api-default-promotion PiperOrigin-RevId: 631168892 06 May 2024, 20:39:56 UTC
1013b1a Merge pull request #21079 from jakevdp:tensorinv PiperOrigin-RevId: 631168828 06 May 2024, 20:35:03 UTC
bb6aa12 Merge pull request #21087 from jakevdp:upstream-print-version PiperOrigin-RevId: 631154080 06 May 2024, 19:48:29 UTC
d014f5d Compute source maps when pretty-printing jaxprs. This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose. This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing. The change also teaches the core jaxpr pretty printer to populate source map information on each equation. 06 May 2024, 19:45:25 UTC
7d96e78 CI: print numpy/scipy version in upstream job 06 May 2024, 18:14:38 UTC
a265e42 Add an experimental, Clang version of the Windows CI job. Once proven to work, this job will be deleted, and the MSVC job changed to use Clang. PiperOrigin-RevId: 631122967 06 May 2024, 18:10:09 UTC
4a36315 jnp.linalg tensorinv & tensorsolve: improve implementation & docs 06 May 2024, 18:08:36 UTC
d26bd73 Merge pull request #21084 from jakevdp:fix-upstream PiperOrigin-RevId: 631112318 06 May 2024, 17:49:34 UTC
7e9ef1e Merge pull request #21078 from jakevdp:numpy-linalg-doc PiperOrigin-RevId: 631112228 06 May 2024, 17:44:42 UTC
fb65ba4 Add a config for using Clang on Windows. PiperOrigin-RevId: 631112031 06 May 2024, 17:39:28 UTC
8ba5c64 Pass `bazel_options` directly to the Bazel command, instead of into .bazelrc. PiperOrigin-RevId: 631099970 06 May 2024, 17:05:19 UTC
40b2d48 jnp.linalg: improve API documentation 06 May 2024, 16:22:59 UTC
6f7ebff random_lax_test: fix kstest for newer NumPy 06 May 2024, 16:20:07 UTC
34c5163 Refactored common upcast for integral-type accumulators 06 May 2024, 15:13:10 UTC
0eed28a Fix a typo in jax.jit docstring 06 May 2024, 11:59:23 UTC
7681493 Don't create temp directory when module is getting imported. PiperOrigin-RevId: 630958402 06 May 2024, 07:58:45 UTC
047ea21 Update XLA dependency to use revision http://github.com/openxla/xla/commit/8833ecd8705f94c25502cb3580071ede8a2fe705. PiperOrigin-RevId: 630731298 05 May 2024, 02:09:55 UTC
a1c8221 Update XLA dependency to use revision http://github.com/openxla/xla/commit/25d66ce58c51869f1220bc3b4df96a6939d8fbf6. PiperOrigin-RevId: 630557277 04 May 2024, 02:19:41 UTC
1b804a7 Merge pull request #21056 from mattjj:vmap-grad-remat-shmap-bug PiperOrigin-RevId: 630555588 04 May 2024, 02:06:46 UTC
7a87010 [shard_map] better fix for spmd_axis_name issues with shmap residuals The fix in #21032 was not correct because it assumed that the set of all mesh axis names appearing in in_specs was an upper bound on the set of mesh axes over which residuals could be device-varying. But collectives can introduce device variance! So it's not an upper bound. We track device variance when check_rep=True, but often people set check_rep=False (e.g. when using pallas_call in a shard_map). So relying on our device variance tracking would be limiting. That may be a decent long term solution, if we can make it easy to annotate pallas_calls with device variance information. But it's not a great short term one to unblock things. So instead I temporarily went with context sensitivity: instead of making residuals sharded over all mesh.axis_names (as we did before these patches), we make them sharded over all mesh axis names _excluding_ any spmd_axis_names in our dynamic context (by looking at the traces in our trace stack). It's illegal to mention any spmd_axis_names in collectives (indeed anywhere in the body of the function being vmapped), but I don't think we check it. TODO(mattjj): add more testing (maybe in follow-ups) 04 May 2024, 01:31:15 UTC
e95173a Require arraylike input for several jax.numpy functions PiperOrigin-RevId: 630532821 03 May 2024, 23:55:10 UTC
53208ff Merge pull request #21058 from jakevdp:jnp-delete-doc PiperOrigin-RevId: 630510340 03 May 2024, 22:14:12 UTC
88318e6 jnp.delete: better docs 03 May 2024, 21:41:06 UTC
fc6a5d3 Merge pull request #21059 from jakevdp:numpy-doc-tests PiperOrigin-RevId: 630500410 03 May 2024, 21:35:03 UTC
2072224 Merge pull request #21064 from jakevdp:doc-new-tutorials PiperOrigin-RevId: 630484720 03 May 2024, 20:39:02 UTC
10ed827 DOC: replace old tutorials with new content 03 May 2024, 19:20:06 UTC
f2c2892 Refactor jax.numpy docstring tests 03 May 2024, 18:04:43 UTC
7e20e53 Merge pull request #21057 from jakevdp:scipy-imports PiperOrigin-RevId: 630435054 03 May 2024, 17:49:57 UTC
ff67e51 Remove last scipy imports 03 May 2024, 17:20:05 UTC
e70191b Merge pull request #21055 from shoyer:relu-doc PiperOrigin-RevId: 630425740 03 May 2024, 17:18:49 UTC
c77370c Recursively pull out __wrapped__ in linkcode_resolve() This should actually fix the source lookup for `jax.nn.relu`, which uses both `custom_jvp` and `jit` decorators. 03 May 2024, 16:10:00 UTC
c0cfc7a Merge pull request #21032 from mattjj:vmap-grad-shmap-bug PiperOrigin-RevId: 630375950 03 May 2024, 13:47:22 UTC
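
Commit daab7a0 above describes two indexing forms, `[:,0,0,:]` and `[...,0,0,:]`, that should now lower to the same `slice` equations. The following is a minimal sketch of how one might compare them, not code from the commit itself. The array `some_array` and the two helper functions are hypothetical stand-ins for `model['some_array']`, and the primitives shown in the printed jaxprs depend on the installed JAX version.

```python
import jax
import jax.numpy as jnp

# Hypothetical 4-D array standing in for model['some_array'] in commit daab7a0.
some_array = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)


def index_with_colon(x):
    # Explicit leading `:` followed by two integer indices and a trailing `:`.
    return x[:, 0, 0, :]


def index_with_ellipsis(x):
    # For a 4-D input, `...` expands to a single leading `:`,
    # so this selects the same elements as index_with_colon.
    return x[..., 0, 0, :]


colon_result = index_with_colon(some_array)
ellipsis_result = index_with_ellipsis(some_array)

# Print both jaxprs to inspect which primitives each form lowers to.
# Which primitives appear (slice vs. gather) depends on the JAX version in use.
print(jax.make_jaxpr(index_with_colon)(some_array))
print(jax.make_jaxpr(index_with_ellipsis)(some_array))
```

Both forms return the same values for this input. Whether the two printed jaxprs match is what the commit message is about, so comparing them across JAX versions is one way to observe the change.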