038dfee | Peter Hawkins | 08 May 2024, 18:55:53 UTC | Prepare 0.4.28 release. | 09 May 2024, 19:25:33 UTC |
f98e707 | jax authors | 09 May 2024, 19:21:39 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/d60579f54a0b6c37d1caf11dc3eb34488cf6922a. PiperOrigin-RevId: 632232639 | 09 May 2024, 19:22:25 UTC |
1a7a2aa | jax authors | 09 May 2024, 18:33:54 UTC | Merge pull request #21106 from jakevdp:linalg-precision PiperOrigin-RevId: 632217396 | 09 May 2024, 18:33:54 UTC |
0c4d81c | jax authors | 09 May 2024, 17:38:51 UTC | Merge pull request #21138 from jakevdp:einsum-doc PiperOrigin-RevId: 632198113 | 09 May 2024, 17:38:51 UTC |
2ddb7ff | Jake VanderPlas | 09 May 2024, 17:06:51 UTC | jnp.linalg: add precision & preferred_element_type to dot-like functions | 09 May 2024, 17:06:51 UTC |
eb0b1b0 | Justin Fu | 09 May 2024, 16:20:30 UTC | Merge pull request #21108 from justinjfu/skip_pallas_test_64 Skip float64 test_nextafter on TPU. | 09 May 2024, 16:20:30 UTC |
671fb12 | Yash Katariya | 09 May 2024, 15:37:43 UTC | Update the multi-process note in pjit's docstring PiperOrigin-RevId: 632160561 | 09 May 2024, 15:38:29 UTC |
2be3f6d | jax authors | 09 May 2024, 14:28:49 UTC | Merge pull request #21146 from jakevdp:fix-multidot PiperOrigin-RevId: 632142647 | 09 May 2024, 14:28:49 UTC |
89d25bb | Peter Hawkins | 09 May 2024, 14:08:48 UTC | Reenable examples_test in Bazel build. Fix bitrot. This test was disabled years ago because it was slow, but it isn't any more. PiperOrigin-RevId: 632138101 | 09 May 2024, 14:10:07 UTC |
5edfaa6 | Jake VanderPlas | 09 May 2024, 13:47:30 UTC | jnp.linalg.multi_dot: use optimize='auto' | 09 May 2024, 13:47:30 UTC |
1e88e2f | jax authors | 09 May 2024, 02:16:21 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/4872030c44e20556efdfea47829d942985e0ccf1. PiperOrigin-RevId: 631997979 | 09 May 2024, 02:17:08 UTC |
168f40e | Peter Hawkins | 09 May 2024, 01:04:40 UTC | [XLA:Python] Fix a memory corruption bug in the tp_name attribute of ArrayImpl and PjitFunction for Python 3.10 or earlier. This works around https://github.com/python/cpython/issues/89478, which was fixed in Python 3.11. PiperOrigin-RevId: 631984256 | 09 May 2024, 01:05:28 UTC |
96f888b | Yash Katariya | 09 May 2024, 00:22:24 UTC | Reverts 1956ff7d7b73794012fece2d8452e097196587fc PiperOrigin-RevId: 631974751 | 09 May 2024, 00:23:13 UTC |
f991dd8 | jax authors | 08 May 2024, 23:06:47 UTC | Merge pull request #21139 from jakevdp:fix-lpmn-test PiperOrigin-RevId: 631954696 | 08 May 2024, 23:06:47 UTC |
f556a17 | Jake VanderPlas | 08 May 2024, 22:55:20 UTC | TST: fix Lpmn test for new scipy | 08 May 2024, 22:55:20 UTC |
e870052 | Jake VanderPlas | 08 May 2024, 21:30:59 UTC | jnp.einsum: improve documentation | 08 May 2024, 21:30:59 UTC |
962f084 | jax authors | 08 May 2024, 21:20:10 UTC | Merge pull request #21137 from superbobry:pallas PiperOrigin-RevId: 631923082 | 08 May 2024, 21:20:10 UTC |
65d4c68 | jax authors | 08 May 2024, 20:59:50 UTC | Generic reduce window jvp The problem is that we want to generically jvp and transpose over any reduction_fn. JAX already handles some of the hard parts for us, namely ensuring that the user-provided fn is JAX-capturable. All that is left, then, is to write a jvp and transpose fn that use the JAX utils correctly. However, this is not so straightforward, because to get the transpose of a reduce window we need to be able to use both the tangents and the primals. The current reduce_fn operates on (x, y), but what we actually need, under jvp, is to operate on `(x_primal, y_primal, x_tangent, y_tangent)`. In turn, this means we need to push down notions of a jvp-specific reduction_fn (captured via the usual machinery of as_fun `as_fun(jvp_fn(closed(user_reduction_jaxp)))`). For the jvp fn, we stack the primal operand and the tangent operand together, and we stack their respective initial values together. This requires a good deal of changes to safety checks and assurances downstream (as well as unpacking), because the shape of the operand has changed from [K, ...Kt] to [K, ...Kt, 2], where the last dim holds the stacked primal and tangent values. In following CLs, we will add (1) re-entrant/recursive is_jvp and (2) transposition. PiperOrigin-RevId: 631916764 | 08 May 2024, 21:00:39 UTC |
4b62425 | Sergei Lebedev | 08 May 2024, 20:38:05 UTC | Renamed is_device_gpu_at_least to is_cuda_compute_capability_at_least This makes it clear that the predicate is only supposed to be used for NVIDIA GPUs at the moment. | 08 May 2024, 20:41:50 UTC |
bfdc87d | jax authors | 08 May 2024, 20:34:40 UTC | Merge pull request #21136 from superbobry:pallas PiperOrigin-RevId: 631908723 | 08 May 2024, 20:34:40 UTC |
575ba94 | Sergei Lebedev | 08 May 2024, 19:29:18 UTC | Removed get_compute_capability from jax.experimental.pallas.gpu Compute capability has been available as a `str` attribute on a GPU device since jaxlib 0.4.26. | 08 May 2024, 20:10:43 UTC |
a145109 | jax authors | 08 May 2024, 18:46:37 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/68b17a8571f197676ea07479637da1546b3b501d. PiperOrigin-RevId: 631874377 | 08 May 2024, 18:47:24 UTC |
f768cb7 | Anselm Levskaya | 08 May 2024, 18:21:48 UTC | Refactor pipeline emitter API. Building on Enrique's work, this CL refactors the emit_pipeline abstraction: 1) factors out the VMEM double-buffering bookkeeping into a helper class. 2) concentrates the intricate copy/wait scheduling logic in one place inside a scheduler helper while allowing manual overrides; callbacks no longer control scheduling, and loop scheduling is now explicit. 3) minimizes callbacks and simplifies the "defaults" for fusing pipelines together. Examples of fully overlapped versions of latency- and throughput-optimized AG-matmuls and matmul-RSs are included in new tests. PiperOrigin-RevId: 631865641 | 08 May 2024, 18:22:47 UTC |
e652d62 | Parker Schuh | 08 May 2024, 17:44:40 UTC | Clean up the second registration of custom_partitioning callbacks now that the jaxlib version has been bumped. PiperOrigin-RevId: 631852273 | 08 May 2024, 17:45:39 UTC |
8baa5d8 | jax authors | 08 May 2024, 17:07:42 UTC | Merge pull request #21128 from hawkinsp:loggingtest PiperOrigin-RevId: 631839432 | 08 May 2024, 17:07:42 UTC |
2967ec9 | jax authors | 08 May 2024, 16:25:49 UTC | Merge pull request #21129 from superbobry:pallas PiperOrigin-RevId: 631825682 | 08 May 2024, 16:25:49 UTC |
0feeaa5 | Sergei Lebedev | 08 May 2024, 14:57:04 UTC | Removed stale version guards and try/except blocks from Pallas GPU They are unnecessary now that the minimum jaxlib version is 0.4.27. | 08 May 2024, 16:05:45 UTC |
0e57c79 | jax authors | 08 May 2024, 15:46:52 UTC | Switch Windows jobs to use Clang. Remove the experimental/trial Clang job. PiperOrigin-RevId: 631814321 | 08 May 2024, 15:47:34 UTC |
11da3df | jax authors | 08 May 2024, 12:44:08 UTC | Merge pull request #21096 from gspschmid:gschmid/sourcemaps PiperOrigin-RevId: 631769572 | 08 May 2024, 12:44:08 UTC |
919832a | Peter Hawkins | 08 May 2024, 12:43:52 UTC | Enable logging_test on all CI platforms. Should catch issues like https://github.com/google/jax/issues/21121 | 08 May 2024, 12:43:52 UTC |
b0b322d | Georg Stefan Schmid | 06 May 2024, 14:14:11 UTC | Add sourcemap module to generate TC39-compliant source maps | 08 May 2024, 08:54:25 UTC |
335f27b | jax authors | 08 May 2024, 02:57:41 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/c6df436f9e3d2a706a3a9595e37224276fb6013b. PiperOrigin-RevId: 631629197 | 08 May 2024, 02:58:28 UTC |
e9e4e53 | jax authors | 08 May 2024, 01:13:05 UTC | [jax:mosaic-gpu] FragmentedArray can do tiled load. PiperOrigin-RevId: 631611060 | 08 May 2024, 01:13:55 UTC |
79f11d5 | Jevin Jiang | 07 May 2024, 23:51:46 UTC | [Pallas] Fix some typos. PiperOrigin-RevId: 631592201 | 07 May 2024, 23:52:38 UTC |
395d3cb | Yash Katariya | 07 May 2024, 23:06:48 UTC | Bump minimum jaxlib version to 0.4.27 xla_extension_version is 261 and mlir_api_version is 56 PiperOrigin-RevId: 631579739 | 07 May 2024, 23:07:59 UTC |
5ba56bb | Jieying Luo | 07 May 2024, 21:46:43 UTC | Recommend the plugin in the CUDA installation instructions. PiperOrigin-RevId: 631555876 | 07 May 2024, 21:47:39 UTC |
78e10ee | jax authors | 07 May 2024, 21:14:02 UTC | Merge pull request #21115 from jakevdp:multi-dot PiperOrigin-RevId: 631545161 | 07 May 2024, 21:14:02 UTC |
2b3251e | jax authors | 07 May 2024, 20:51:13 UTC | Merge pull request #21092 from jakevdp:dot-doc PiperOrigin-RevId: 631536980 | 07 May 2024, 20:51:13 UTC |
09810be | Jake VanderPlas | 07 May 2024, 20:40:25 UTC | Implement jnp.linalg.multi_dot using opt_einsum | 07 May 2024, 20:40:25 UTC |
5f70267 | jax authors | 07 May 2024, 20:11:43 UTC | Merge pull request #21103 from superbobry:mosaic-gpu-fix PiperOrigin-RevId: 631521771 | 07 May 2024, 20:11:43 UTC |
7153738 | jax authors | 07 May 2024, 20:06:33 UTC | Merge pull request #21104 from superbobry:triton-fixes PiperOrigin-RevId: 631521767 | 07 May 2024, 20:06:33 UTC |
174405c | jax authors | 07 May 2024, 19:57:55 UTC | The Bazel version used in JAX is bumped from 6.1.2 to 6.5.0. The update is needed for Windows/Clang builds and for the future hermetic CUDA implementation. PiperOrigin-RevId: 631519200 | 07 May 2024, 19:58:37 UTC |
b6fea37 | jax authors | 07 May 2024, 18:34:06 UTC | Merge pull request #21111 from jakevdp:fix-changelog PiperOrigin-RevId: 631493089 | 07 May 2024, 18:34:06 UTC |
9524188 | jax authors | 07 May 2024, 18:25:19 UTC | Merge pull request #21110 from jakevdp:upstream-nightly PiperOrigin-RevId: 631490005 | 07 May 2024, 18:25:19 UTC |
c18851b | Jake VanderPlas | 07 May 2024, 18:16:11 UTC | CHANGELOG: move change from 0.4.27 to 0.4.28 | 07 May 2024, 18:16:11 UTC |
496795e | Jake VanderPlas | 07 May 2024, 18:14:11 UTC | CI: fix typo in workflow | 07 May 2024, 18:14:11 UTC |
5031a1d | Yash Katariya | 07 May 2024, 18:13:21 UTC | Finish jax and jaxlib 0.4.27 release PiperOrigin-RevId: 631486157 | 07 May 2024, 18:14:09 UTC |
034d843 | Jake VanderPlas | 06 May 2024, 22:22:35 UTC | jax.numpy: better docs for matmul-like functions | 07 May 2024, 18:01:54 UTC |
3dd8af9 | jax authors | 07 May 2024, 17:59:10 UTC | Merge pull request #21090 from jakevdp:extract PiperOrigin-RevId: 631480474 | 07 May 2024, 17:59:10 UTC |
274e735 | jax authors | 07 May 2024, 17:38:06 UTC | Merge pull request #21105 from albanie:patch-3 PiperOrigin-RevId: 631472758 | 07 May 2024, 17:38:06 UTC |
daab7a0 | jax authors | 07 May 2024, 17:29:17 UTC | Handle ellipsis `...` in `_attempt_rewriting_take_via_slice`. Previously `model['some_array'][:,0,0,:]` would generate a `slice`, while `model['some_array'][...,0,0,:]` would generate a `gather`. Now both of these generate `slice` eqns. PiperOrigin-RevId: 631469837 | 07 May 2024, 17:30:08 UTC |
080b3f1 | Justin Fu | 07 May 2024, 17:06:08 UTC | Skip float64 test_nextafter on TPU | 07 May 2024, 17:06:08 UTC |
fdd8b61 | Samuel | 07 May 2024, 16:36:19 UTC | Minor doc fix to `jaxpr.rst` Update the written description of the example `jaxpr` to match the given snippet of code. | 07 May 2024, 16:36:19 UTC |
2ca3264 | Sergei Lebedev | 07 May 2024, 16:30:24 UTC | Updated the Pallas GPU lowering to work with older jaxlib versions Triton changed the signatures of LoadOp and DotOp upstream, and the lowering code was not ready to handle both old and new signatures. | 07 May 2024, 16:30:52 UTC |
8ccbeba | Sergei Lebedev | 07 May 2024, 16:02:58 UTC | Fixed Mosaic GPU build following #21029 | 07 May 2024, 16:08:00 UTC |
0c26a34 | Jake VanderPlas | 07 May 2024, 15:47:34 UTC | Add optional size argument to jnp.compress & jnp.extract. | 07 May 2024, 15:47:34 UTC |
9b79f65 | Jake VanderPlas | 07 May 2024, 15:18:05 UTC | Remove deprecated `kind` argument from `jnp.sort` and `jnp.argsort`. PiperOrigin-RevId: 631429900 | 07 May 2024, 15:18:59 UTC |
500da57 | jax authors | 07 May 2024, 14:07:04 UTC | Merge pull request #21077 from merrymercy:patch-1 PiperOrigin-RevId: 631409738 | 07 May 2024, 14:07:04 UTC |
70b4477 | Yash Katariya | 07 May 2024, 14:00:34 UTC | Start jax and jaxlib 0.4.27 release PiperOrigin-RevId: 631409685 | 07 May 2024, 14:01:24 UTC |
326adc0 | Adam Paszke | 07 May 2024, 13:35:43 UTC | [Mosaic GPU] Adjust memref.expand_shape construction to pass in the new args PiperOrigin-RevId: 631404097 | 07 May 2024, 13:36:36 UTC |
3e5a18f | jax authors | 07 May 2024, 03:37:05 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/873d09720f83cbbebf2a2a381c09be8fa0934b36. PiperOrigin-RevId: 631274530 | 07 May 2024, 03:37:43 UTC |
cb0c498 | jax authors | 07 May 2024, 00:33:12 UTC | Merge pull request #21081 from hawkinsp:sourcemap PiperOrigin-RevId: 631236806 | 07 May 2024, 00:33:12 UTC |
4de3464 | jax authors | 07 May 2024, 00:27:24 UTC | Fix insufficient initialization of the output HBM buffer, which caused <unk> tokens to be generated for quantized int8 models. PiperOrigin-RevId: 631235764 | 07 May 2024, 00:28:13 UTC |
eee2783 | jax authors | 06 May 2024, 23:22:15 UTC | Merge pull request #21070 from shuhand0:rel0.0.7 PiperOrigin-RevId: 631218770 | 06 May 2024, 23:22:15 UTC |
f6d8852 | jax authors | 06 May 2024, 21:30:06 UTC | Merge pull request #20327 from selamw1:add_examples PiperOrigin-RevId: 631186425 | 06 May 2024, 21:30:06 UTC |
aac3679 | Shuhan Ding | 06 May 2024, 20:51:22 UTC | fix jaxlib config name | 06 May 2024, 20:51:22 UTC |
9caf59d | Selam Waktola | 06 May 2024, 20:43:55 UTC | improve documentation for ix_ | 06 May 2024, 20:43:55 UTC |
3d3cb0b | jax authors | 06 May 2024, 20:39:56 UTC | Merge pull request #20842 from Micky774:array-api-default-promotion PiperOrigin-RevId: 631168892 | 06 May 2024, 20:39:56 UTC |
1013b1a | jax authors | 06 May 2024, 20:35:03 UTC | Merge pull request #21079 from jakevdp:tensorinv PiperOrigin-RevId: 631168828 | 06 May 2024, 20:35:03 UTC |
bb6aa12 | jax authors | 06 May 2024, 19:48:29 UTC | Merge pull request #21087 from jakevdp:upstream-print-version PiperOrigin-RevId: 631154080 | 06 May 2024, 19:48:29 UTC |
d014f5d | Peter Hawkins | 06 May 2024, 13:59:18 UTC | Compute source maps when pretty-printing jaxprs. This change is in preparation for adding support for emitting https://tc39.es/source-map/ source map information for jaxprs, so the relationship between a jaxpr and its Python code can be visualized using tooling for that purpose. This change adds a new `source_map()` pretty printer document, which causes the pretty-printer to populate a source_map side output during pretty printing. The change also teaches the core jaxpr pretty printer to populate source map information on each equation. | 06 May 2024, 19:45:25 UTC |
7d96e78 | Jake VanderPlas | 06 May 2024, 18:14:38 UTC | CI: print numpy/scipy version in upstream job | 06 May 2024, 18:14:38 UTC |
a265e42 | jax authors | 06 May 2024, 18:09:17 UTC | Add an experimental, Clang version of the Windows CI job. Once proven to work, this job will be deleted, and the MSVC job changed to use Clang. PiperOrigin-RevId: 631122967 | 06 May 2024, 18:10:09 UTC |
4a36315 | Jake VanderPlas | 06 May 2024, 14:04:11 UTC | jnp.linalg tensorinv & tensorsolve: improve implementation & docs | 06 May 2024, 18:08:36 UTC |
d26bd73 | jax authors | 06 May 2024, 17:49:34 UTC | Merge pull request #21084 from jakevdp:fix-upstream PiperOrigin-RevId: 631112318 | 06 May 2024, 17:49:34 UTC |
7e9ef1e | jax authors | 06 May 2024, 17:44:42 UTC | Merge pull request #21078 from jakevdp:numpy-linalg-doc PiperOrigin-RevId: 631112228 | 06 May 2024, 17:44:42 UTC |
fb65ba4 | jax authors | 06 May 2024, 17:38:35 UTC | Add a config for using Clang on Windows. PiperOrigin-RevId: 631112031 | 06 May 2024, 17:39:28 UTC |
8ba5c64 | jax authors | 06 May 2024, 17:04:27 UTC | Pass `bazel_options` directly to the Bazel command, instead of into .bazelrc. PiperOrigin-RevId: 631099970 | 06 May 2024, 17:05:19 UTC |
40b2d48 | Jake VanderPlas | 06 May 2024, 16:22:59 UTC | jnp.linalg: improve API documentation | 06 May 2024, 16:22:59 UTC |
6f7ebff | Jake VanderPlas | 06 May 2024, 16:20:07 UTC | random_lax_test: fix kstest for newer NumPy | 06 May 2024, 16:20:07 UTC |
34c5163 | Meekail Zain | 06 May 2024, 15:13:10 UTC | Refactored common upcast for integral-type accumulators | 06 May 2024, 15:13:10 UTC |
0eed28a | Lianmin Zheng | 06 May 2024, 11:59:23 UTC | Fix a typo in jax.jit docstring | 06 May 2024, 11:59:23 UTC |
7681493 | jax authors | 06 May 2024, 07:58:00 UTC | Don't create temp directory when module is getting imported. PiperOrigin-RevId: 630958402 | 06 May 2024, 07:58:45 UTC |
047ea21 | jax authors | 05 May 2024, 02:09:11 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/8833ecd8705f94c25502cb3580071ede8a2fe705. PiperOrigin-RevId: 630731298 | 05 May 2024, 02:09:55 UTC |
a1c8221 | jax authors | 04 May 2024, 02:18:52 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/25d66ce58c51869f1220bc3b4df96a6939d8fbf6. PiperOrigin-RevId: 630557277 | 04 May 2024, 02:19:41 UTC |
1b804a7 | jax authors | 04 May 2024, 02:06:46 UTC | Merge pull request #21056 from mattjj:vmap-grad-remat-shmap-bug PiperOrigin-RevId: 630555588 | 04 May 2024, 02:06:46 UTC |
7a87010 | Matthew Johnson | 03 May 2024, 16:12:24 UTC | [shard_map] better fix for spmd_axis_name issues with shmap residuals The fix in #21032 was not correct because it assumed that the set of all mesh axis names appearing in in_specs was an upper bound on the set of mesh axes over which residuals could be device-varying. But collectives can introduce device variance! So it's not an upper bound. We track device variance when check_rep=True, but often people set check_rep=False (e.g. when using pallas_call in a shard_map). So relying on our device variance tracking would be limiting. That may be a decent long-term solution, if we can make it easy to annotate pallas_calls with device variance information. But it's not a great short-term one to unblock things. So instead I temporarily went with context sensitivity: instead of making residuals sharded over all mesh.axis_names (as we did before these patches), we make them sharded over all mesh axis names _excluding_ any spmd_axis_names in our dynamic context (by looking at the traces in our trace stack). It's illegal to mention any spmd_axis_names in collectives (indeed anywhere in the body of the function being vmapped), but I don't think we check it. TODO(mattjj): add more testing (maybe in follow-ups) | 04 May 2024, 01:31:15 UTC |
e95173a | Jake VanderPlas | 03 May 2024, 23:54:22 UTC | Require arraylike input for several jax.numpy functions PiperOrigin-RevId: 630532821 | 03 May 2024, 23:55:10 UTC |
53208ff | jax authors | 03 May 2024, 22:14:12 UTC | Merge pull request #21058 from jakevdp:jnp-delete-doc PiperOrigin-RevId: 630510340 | 03 May 2024, 22:14:12 UTC |
88318e6 | Jake VanderPlas | 03 May 2024, 17:48:58 UTC | jnp.delete: better docs | 03 May 2024, 21:41:06 UTC |
fc6a5d3 | jax authors | 03 May 2024, 21:35:03 UTC | Merge pull request #21059 from jakevdp:numpy-doc-tests PiperOrigin-RevId: 630500410 | 03 May 2024, 21:35:03 UTC |
2072224 | jax authors | 03 May 2024, 20:39:02 UTC | Merge pull request #21064 from jakevdp:doc-new-tutorials PiperOrigin-RevId: 630484720 | 03 May 2024, 20:39:02 UTC |
10ed827 | Jake VanderPlas | 18 April 2024, 20:11:25 UTC | DOC: replace old tutorials with new content | 03 May 2024, 19:20:06 UTC |
f2c2892 | Jake VanderPlas | 03 May 2024, 18:04:43 UTC | Refactor jax.numpy docstring tests | 03 May 2024, 18:04:43 UTC |
7e20e53 | jax authors | 03 May 2024, 17:49:57 UTC | Merge pull request #21057 from jakevdp:scipy-imports PiperOrigin-RevId: 630435054 | 03 May 2024, 17:49:57 UTC |
ff67e51 | Jake VanderPlas | 03 May 2024, 17:20:05 UTC | Remove last scipy imports | 03 May 2024, 17:20:05 UTC |
e70191b | jax authors | 03 May 2024, 17:18:49 UTC | Merge pull request #21055 from shoyer:relu-doc PiperOrigin-RevId: 630425740 | 03 May 2024, 17:18:49 UTC |
c77370c | Stephan Hoyer | 03 May 2024, 16:10:00 UTC | Recursively pull out __wrapped__ in linkcode_resolve() This should actually fix the source lookup for `jax.nn.relu`, which uses both `custom_jvp` and `jit` decorators. | 03 May 2024, 16:10:00 UTC |
c0cfc7a | jax authors | 03 May 2024, 13:47:22 UTC | Merge pull request #21032 from mattjj:vmap-grad-shmap-bug PiperOrigin-RevId: 630375950 | 03 May 2024, 13:47:22 UTC |
989ea61 | jax authors | 03 May 2024, 13:00:40 UTC | Merge pull request #21047 from shoyer:linkcode-robust PiperOrigin-RevId: 630367113 | 03 May 2024, 13:00:40 UTC |