6ba9fb6 | Yash Katariya | 13 April 2022, 00:47:28 UTC | Upgrade the bazel version to 5.1.1 PiperOrigin-RevId: 441338363 | 13 April 2022, 00:48:09 UTC |
c06eff8 | jax authors | 12 April 2022, 19:54:22 UTC | Merge pull request #10245 from google:yashk2810-patch-7 PiperOrigin-RevId: 441265709 | 12 April 2022, 19:54:22 UTC |
5fd78ea | Yash Katariya | 12 April 2022, 18:41:07 UTC | Bump the libtpu version to prepare for JAX release | 12 April 2022, 18:41:07 UTC |
9455254 | Peter Hawkins | 12 April 2022, 16:45:18 UTC | [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one. This is a second attempt at this change. In this version, check for and report an error on jit(pjit(...)), which was the root cause of the failure that led to the previous version being reverted. PiperOrigin-RevId: 441214076 | 12 April 2022, 17:30:52 UTC |
3136004 | Yash Katariya | 12 April 2022, 16:34:58 UTC | Fix the pytype error. PyType is looking for an __init__ method. This does not change the behavior of the class. ``` Function PartitionSpec.__init__ expects 1 arg(s), got 3 [wrong-arg-count] Expected: (self) Actually passed: (self, _, _) ``` PiperOrigin-RevId: 441211351 | 12 April 2022, 16:36:28 UTC |
a2c2d9a | jax authors | 12 April 2022, 13:51:29 UTC | [JAX] Adds the approx_top_k_p bridge. PiperOrigin-RevId: 441172779 | 12 April 2022, 13:52:47 UTC |
7be37ab | jax authors | 12 April 2022, 05:42:07 UTC | Merge pull request #10120 from mattjj:djax-latest PiperOrigin-RevId: 441088673 | 12 April 2022, 05:42:07 UTC |
4354f35 | Matthew Johnson | 31 March 2022, 00:52:55 UTC | prototyping dynamic shapes Co-authored-by: Dougal Maclaurin <dougalm@google.com> | 12 April 2022, 05:10:47 UTC |
fb6a143 | jax authors | 12 April 2022, 05:05:10 UTC | Merge pull request #9723 from sharadmv:jaxpr-effects PiperOrigin-RevId: 441083960 | 12 April 2022, 05:05:10 UTC |
b051be4 | Peter Hawkins | 11 April 2022, 21:36:45 UTC | [MHLO] Switch sharded_jit dispatch path to use MHLO lowering. This change runs some chance of breaking sharded_jit users due to a lack of testing, but the plan is to delete sharded_jit very soon anyway. PiperOrigin-RevId: 440999535 | 11 April 2022, 21:41:51 UTC |
0fa1edd | Sharad Vikram | 28 February 2022, 21:36:39 UTC | Adds simple effect types to jaxprs | 11 April 2022, 18:50:41 UTC |
902fc0c | Matthew Johnson | 11 April 2022, 14:56:18 UTC | Remove invertible_ad since it's not in use. PiperOrigin-RevId: 440890949 | 11 April 2022, 14:56:58 UTC |
35b32ee | jax authors | 09 April 2022, 21:25:01 UTC | Merge pull request #10215 from mattjj:dispatch-tweaks PiperOrigin-RevId: 440610850 | 09 April 2022, 21:25:01 UTC |
9f1fab2 | Matthew Johnson | 09 April 2022, 17:56:14 UTC | dispatch.py: type annotations, other minor tweaks | 09 April 2022, 21:07:19 UTC |
5fdad0e | Yash Katariya | 09 April 2022, 17:19:16 UTC | Roll forward manylinux2014 builds after fixes. PiperOrigin-RevId: 440589273 | 09 April 2022, 17:19:44 UTC |
a11b41f | Tianjian Lu | 09 April 2022, 15:33:20 UTC | [sparse] Use sorted indices instead of sorted rows only. PiperOrigin-RevId: 440579642 | 09 April 2022, 15:33:48 UTC |
e9f95fa | Yash Katariya | 09 April 2022, 01:51:14 UTC | Make jaxlib builds manylinux2014 compliant. PiperOrigin-RevId: 440497401 | 09 April 2022, 01:51:46 UTC |
506a85b | Yash Katariya | 08 April 2022, 23:21:23 UTC | Make jaxlib builds manylinux2014 compliant. PiperOrigin-RevId: 440476417 | 08 April 2022, 23:21:56 UTC |
94307a0 | Peter Hawkins | 08 April 2022, 21:21:43 UTC | Revert: [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one. PiperOrigin-RevId: 440452521 | 08 April 2022, 21:22:15 UTC |
272ed95 | Matthew Johnson | 08 April 2022, 20:50:15 UTC | remove experimental/djax PiperOrigin-RevId: 440445082 | 08 April 2022, 20:55:22 UTC |
0bfb3ef | jax authors | 08 April 2022, 20:50:04 UTC | [JAX] Fix batch logic for approx_min/max_k Previous logic was copied from lax.sort and was incorrect. Since approx_top_k can handle multi-rank tensors, the only mapping we need is to set the reduction_dim correctly. PiperOrigin-RevId: 440445041 | 08 April 2022, 20:50:36 UTC |
6cb7526 | jax authors | 08 April 2022, 20:19:09 UTC | Merge pull request #10176 from lgeiger:simplify-diagonal PiperOrigin-RevId: 440437605 | 08 April 2022, 20:19:09 UTC |
0f15fa3 | Peter Hawkins | 08 April 2022, 19:57:30 UTC | [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one. PiperOrigin-RevId: 440433044 | 08 April 2022, 19:57:59 UTC |
b8602d0 | jax authors | 08 April 2022, 19:12:12 UTC | Merge pull request #10198 from jakevdp:bcoo-duplicates PiperOrigin-RevId: 440423326 | 08 April 2022, 19:12:12 UTC |
654e5bd | Yash Katariya | 08 April 2022, 18:24:22 UTC | Roll forward again after the fix in the auto sharding pass. PiperOrigin-RevId: 440412218 | 08 April 2022, 18:25:07 UTC |
8b9efe7 | Jake VanderPlas | 08 April 2022, 18:04:26 UTC | [sparse] fix autodiff bug in spdot_general | 08 April 2022, 18:04:26 UTC |
60c828a | Lukas Geiger | 08 April 2022, 17:10:01 UTC | Simplify `jnp.trace` implementation | 08 April 2022, 17:10:01 UTC |
084adc7 | Lukas Geiger | 08 April 2022, 17:07:32 UTC | Simplify `jnp.diagonal` implementation | 08 April 2022, 17:07:32 UTC |
8b6b736 | Peter Hawkins | 08 April 2022, 15:48:26 UTC | Revert: Pin a newer @platforms in the Bazel workspace to fix Mac ARM builds. The @platforms repository has been updated in the @tf_runtime repository, which was pulling in the old version of @platforms. We no longer need to override @platforms in the JAX WORKSPACE. PiperOrigin-RevId: 440375016 | 08 April 2022, 15:49:00 UTC |
648a512 | Peter Hawkins | 08 April 2022, 15:43:23 UTC | [MHLO] Add direct MHLO lowerings for sparse primitives. PiperOrigin-RevId: 440374054 | 08 April 2022, 15:43:57 UTC |
1cb4fcc | jax authors | 08 April 2022, 14:41:33 UTC | Merge pull request #10187 from hawkinsp:jaxlib PiperOrigin-RevId: 440362035 | 08 April 2022, 14:41:33 UTC |
0c02f79 | Joan Puigcerver | 08 April 2022, 08:08:24 UTC | Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests. PiperOrigin-RevId: 440300882 | 08 April 2022, 08:08:55 UTC |
4dc6903 | Peter Hawkins | 07 April 2022, 20:38:09 UTC | Update version numbers after jax/jaxlib release. | 07 April 2022, 20:40:19 UTC |
58bdcb8 | jax authors | 07 April 2022, 20:20:36 UTC | Merge pull request #10185 from hawkinsp:jaxlib PiperOrigin-RevId: 440184287 | 07 April 2022, 20:20:36 UTC |
7f751c5 | Peter Hawkins | 07 April 2022, 20:14:13 UTC | Update libtpu version for jax 0.3.5 release. | 07 April 2022, 20:14:13 UTC |
28cb44e | jax authors | 07 April 2022, 19:57:19 UTC | Merge pull request #10184 from jakevdp:merge-bcoo-dot-general PiperOrigin-RevId: 440178509 | 07 April 2022, 19:57:19 UTC |
96af4d5 | Yash Katariya | 07 April 2022, 19:13:36 UTC | Remove sharded_jit usage from jax2tf because sharded_jit is deprecated. PiperOrigin-RevId: 440169129 | 07 April 2022, 19:14:10 UTC |
01e4fa8 | Jake VanderPlas | 07 April 2022, 18:28:12 UTC | [sparse] consolidate flavors of bcoo_dot_general | 07 April 2022, 18:28:12 UTC |
5522ed1 | Lena Martens | 07 April 2022, 16:18:48 UTC | jax2tf: Support uint32 keys in rng_bit_generator. This follows the rng_bit_generator_translation rule in JAX, which allows for both uint32 and uint64 keys and casts between them. The default rbg prng implementation in JAX uses a (4,) uint32 key. PiperOrigin-RevId: 440124048 | 07 April 2022, 16:19:22 UTC |
02fd875 | jax authors | 07 April 2022, 16:03:05 UTC | Add __init__ to PolyShape. PiperOrigin-RevId: 440120323 | 07 April 2022, 16:06:37 UTC |
b713d3c | jax authors | 07 April 2022, 15:33:10 UTC | Minor change to lax to support jax2tf shape polymorphic concatenation. PiperOrigin-RevId: 440113799 | 07 April 2022, 15:34:27 UTC |
cbdcdf7 | Peter Hawkins | 07 April 2022, 14:58:59 UTC | [MHLO] Add MHLO lowerings for parallel collectives. PiperOrigin-RevId: 440106753 | 07 April 2022, 14:59:26 UTC |
2884215 | jax authors | 07 April 2022, 03:00:20 UTC | Merge pull request #10167 from ROCmSoftwarePlatform:rocm_solver_api_consolidation PiperOrigin-RevId: 439997492 | 07 April 2022, 03:00:20 UTC |
8b3f039 | jax authors | 07 April 2022, 02:55:29 UTC | Merge pull request #10039 from ajcr:add_scipy_linalg_rsf2csf PiperOrigin-RevId: 439997145 | 07 April 2022, 02:55:29 UTC |
6c560b1 | Rohit Santhanam | 06 April 2022, 14:45:47 UTC | Consolidation of hipsolver/cusolver APIs. | 07 April 2022, 01:46:43 UTC |
832d9aa | jax authors | 07 April 2022, 01:01:05 UTC | Merge pull request #10175 from jakevdp:bcoo-spdot PiperOrigin-RevId: 439980561 | 07 April 2022, 01:01:05 UTC |
700d1e1 | jax authors | 07 April 2022, 00:27:38 UTC | Merge pull request #10177 from hawkinsp:jaxlib PiperOrigin-RevId: 439975370 | 07 April 2022, 00:27:38 UTC |
96ba290 | Peter Hawkins | 06 April 2022, 23:56:41 UTC | Jax 0.3.5 and jaxlib 0.3.5 release. | 06 April 2022, 23:56:41 UTC |
7ee6adb | jax authors | 06 April 2022, 22:50:10 UTC | Merge pull request #10173 from jakevdp:bcoo-add-batchdim PiperOrigin-RevId: 439955276 | 06 April 2022, 22:50:10 UTC |
aa0da8e | Jake VanderPlas | 06 April 2022, 22:40:26 UTC | [sparse] make bcoo_spdot_general return a BCOO array, not raw buffers | 06 April 2022, 22:40:26 UTC |
6a7a346 | Yash Katariya | 06 April 2022, 22:18:39 UTC | Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint). This move is because sharded_jit is being deprecated. PiperOrigin-RevId: 439948391 | 06 April 2022, 22:19:19 UTC |
93a24f3 | Jake VanderPlas | 06 April 2022, 21:44:29 UTC | [sparse] add bcoo_add_batchdim | 06 April 2022, 21:44:29 UTC |
bc658e7 | Peter Hawkins | 06 April 2022, 20:56:01 UTC | [MHLO] Add direct MHLO lowerings for most linear algebra kernels. PiperOrigin-RevId: 439927594 | 06 April 2022, 20:59:09 UTC |
4ed0660 | Yash Katariya | 06 April 2022, 20:53:34 UTC | Add deprecation warning for sharded_jit. PiperOrigin-RevId: 439926957 | 06 April 2022, 20:54:06 UTC |
3bfa6af | Peter Hawkins | 06 April 2022, 20:22:25 UTC | [MHLO] Add MHLO lowering for PRNG kernels. PiperOrigin-RevId: 439919104 | 06 April 2022, 20:23:01 UTC |
869596f | Alex Riley | 27 March 2022, 11:31:12 UTC | Add jax.scipy.linalg.rsf2csf | 06 April 2022, 20:06:23 UTC |
b9bb613 | Peter Hawkins | 06 April 2022, 19:53:19 UTC | [MHLO] Prefer backend-specific HLO lowerings over non-backend-specific MHLO lowerings. This allows us (in subsequent changes) to switch the generic case for translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a backend-specific case. Add a handful of direct MLIR lowerings for primitives that lacked them. PiperOrigin-RevId: 439912093 | 06 April 2022, 19:53:56 UTC |
4012267 | Peter Hawkins | 06 April 2022, 17:45:19 UTC | Revert: implement `jnp.trace` in terms of `jnp.diagonal` This change appears to blow up compilation times for some models on TPU. PiperOrigin-RevId: 439880940 | 06 April 2022, 17:46:01 UTC |
be64d8b | jax authors | 06 April 2022, 16:21:53 UTC | Merge pull request #10164 from hawkinsp:macarm PiperOrigin-RevId: 439857716 | 06 April 2022, 16:21:53 UTC |
21d81c0 | jax authors | 06 April 2022, 16:16:53 UTC | Merge pull request #10160 from lgeiger:jnp-trace PiperOrigin-RevId: 439857693 | 06 April 2022, 16:16:53 UTC |
d073a0f | Peter Hawkins | 06 April 2022, 14:43:56 UTC | Pin a newer @platforms in the Bazel workspace to fix Mac ARM builds. | 06 April 2022, 14:43:56 UTC |
3e877f3 | Lukas Geiger | 06 April 2022, 00:07:06 UTC | Implement `jnp.trace` in terms of `jnp.diagonal` | 06 April 2022, 00:07:06 UTC |
1df942f | jax authors | 05 April 2022, 22:06:47 UTC | Merge pull request #10148 from sharadmv:version-info PiperOrigin-RevId: 439686412 | 05 April 2022, 22:06:47 UTC |
d72a7b4 | Sharad Vikram | 05 April 2022, 19:42:43 UTC | Add version int tuple `__version_info__` to JAX | 05 April 2022, 20:26:05 UTC |
fef3670 | jax authors | 05 April 2022, 16:14:14 UTC | Merge pull request #10140 from jakevdp:jnp-diagonal PiperOrigin-RevId: 439596331 | 05 April 2022, 16:14:14 UTC |
2857579 | jax authors | 05 April 2022, 16:09:19 UTC | Merge pull request #10138 from lucasb-eyer:patch-5 PiperOrigin-RevId: 439596234 | 05 April 2022, 16:09:19 UTC |
152c210 | Peter Hawkins | 05 April 2022, 15:38:07 UTC | [MHLO] Implement return type inference for GetTupleElementOp and TupleOp. PiperOrigin-RevId: 439589720 | 05 April 2022, 15:38:38 UTC |
b7344ed | Jake VanderPlas | 05 April 2022, 00:02:11 UTC | jnp.diagonal: implement in terms of gather rather than sum | 05 April 2022, 00:02:11 UTC |
97d834f | jax authors | 04 April 2022, 23:29:02 UTC | Merge pull request #10139 from jakevdp:fix-gpu-test PiperOrigin-RevId: 439439327 | 04 April 2022, 23:29:02 UTC |
5a96c0c | Jake VanderPlas | 04 April 2022, 23:00:18 UTC | Skip test outside x64 | 04 April 2022, 23:00:18 UTC |
1246b6f | Jake VanderPlas | 04 April 2022, 21:39:43 UTC | Separate jax.test_util implementations into public and private sources. Eventually the private functionality will no longer be exported via the jax.test_util submodule. PiperOrigin-RevId: 439415485 | 04 April 2022, 21:43:39 UTC |
71a5eb2 | Peter Hawkins | 04 April 2022, 21:34:00 UTC | [GPU] Force an input buffer copy for double precision complex-to-real IRFFTs. Fixes https://github.com/google/jax/issues/9946 PiperOrigin-RevId: 439414091 | 04 April 2022, 21:38:52 UTC |
f7b749c | Lucas Beyer | 04 April 2022, 21:38:51 UTC | Explicit doc note about device_put* async | 04 April 2022, 21:38:51 UTC |
6825f65 | Yash Katariya | 04 April 2022, 21:33:17 UTC | * Disallow any type other than GDA and ShapedArray for auto sharding. * Raise errors in the following 4 cases when a GDA's sharding does not match the input sharding. **In all the 4 cases below, the check only runs once! There is no double checking going on. I have added tests for these cases. Please check them out.** * Auto sharding * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. Only doing this check when `compiled._auto_spmd_lowering` is True. * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch * NO auto sharding * f_pjitted(gda) -- This is already covered and tested and happens in `infer_params` * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch PiperOrigin-RevId: 439413895 | 04 April 2022, 21:33:51 UTC |
4949e78 | Jake VanderPlas | 04 April 2022, 19:18:11 UTC | Re-land changes from https://github.com/google/jax/pull/10069 PiperOrigin-RevId: 439381161 | 04 April 2022, 19:18:43 UTC |
41b6e00 | Colin Gaffney | 04 April 2022, 17:55:31 UTC | Enable use of `GlobalDeviceArray` (GDA) in T5X Checkpointer. Add a separate unit test, `gda_checkpoints_test`, to cover this use case. GDA is locked behind a `use_gda` bool in Checkpointer. The feature is currently not enabled anywhere. Our follow-up plan is to add code which would enable GDA use throughout T5X, and to fix any remaining issues with Checkpointer. PiperOrigin-RevId: 439358913 | 04 April 2022, 17:56:07 UTC |
1b8be90 | Peter Hawkins | 04 April 2022, 15:39:32 UTC | Remove the jax_enable_mlir flag. MLIR is now the only supported code path. This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes. PiperOrigin-RevId: 439324450 | 04 April 2022, 15:40:09 UTC |
e1bbbf5 | jax authors | 03 April 2022, 19:11:03 UTC | Merge pull request #10130 from mattjj:no-string-annotations PiperOrigin-RevId: 439174012 | 03 April 2022, 19:11:03 UTC |
c72d8f6 | Matthew Johnson | 03 April 2022, 18:17:57 UTC | remove string annotations from core.py | 03 April 2022, 18:19:07 UTC |
359b614 | jax authors | 02 April 2022, 16:02:21 UTC | Merge pull request #10122 from sharadmv:jax2tf-name-stack PiperOrigin-RevId: 439036794 | 02 April 2022, 16:02:21 UTC |
e64a57d | jax authors | 02 April 2022, 15:57:24 UTC | Merge pull request #10121 from hawkinsp:hcbcache PiperOrigin-RevId: 439036780 | 02 April 2022, 15:57:24 UTC |
cdf4177 | jax authors | 02 April 2022, 01:52:50 UTC | Merge pull request #10126 from jakevdp:tree-multimap PiperOrigin-RevId: 438956536 | 02 April 2022, 01:52:50 UTC |
c61a18b | Jake VanderPlas | 01 April 2022, 21:52:16 UTC | DOC: switch from tree_multimap to tree_map in docs | 01 April 2022, 21:52:16 UTC |
df1ceae | Jake VanderPlas | 01 April 2022, 21:51:54 UTC | Deprecate jax.tree_util.tree_multimap | 01 April 2022, 21:51:54 UTC |
9693898 | jax authors | 01 April 2022, 21:09:08 UTC | Merge pull request #10123 from fabianp:patch-4 PiperOrigin-RevId: 438906356 | 01 April 2022, 21:09:08 UTC |
0eaeff6 | Yash Katariya | 01 April 2022, 21:03:52 UTC | Give the auto sharder the mesh information, specifically the mesh_shape and the device ids of the devices in the mesh. PiperOrigin-RevId: 438906211 | 01 April 2022, 21:04:23 UTC |
4fd466b | Fabian Pedregosa | 01 April 2022, 19:50:29 UTC | remove $ from command line commands A few commands in this file were prefixed with $ which results in an invalid command when copied with the sphinx "copy" button. | 01 April 2022, 19:50:29 UTC |
1c3edc8 | jax authors | 01 April 2022, 19:45:35 UTC | Merge pull request #10110 from pschuh:weakref-bug PiperOrigin-RevId: 438887762 | 01 April 2022, 19:45:35 UTC |
aac8ec8 | Sharad Vikram | 01 April 2022, 19:39:56 UTC | Fixes `jax2tf`'s `test_name_scope` to use graph introspection instead of side effects | 01 April 2022, 19:39:56 UTC |
7b7458b | Yash Katariya | 01 April 2022, 19:18:56 UTC | Give the auto sharder the mesh information, specifically the mesh_shape and the device ids of the devices in the mesh. PiperOrigin-RevId: 438882041 | 01 April 2022, 19:19:25 UTC |
df1c478 | Parker Schuh | 31 March 2022, 20:54:06 UTC | Fix race condition for weakref destructor by catching rare exceptions. | 01 April 2022, 19:04:36 UTC |
8c3385c | jax authors | 01 April 2022, 18:47:26 UTC | Expose AutoSharding's mesh_shape and mesh_ids options to JAX. PiperOrigin-RevId: 438874347 | 01 April 2022, 18:47:56 UTC |
208e83c | Peter Hawkins | 01 April 2022, 18:03:14 UTC | Avoid retracing when a host_callback.call is called multiple times with the same function. If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality. This issue was found in passing while debugging #9970. | 01 April 2022, 18:41:14 UTC |
a4a551a | jax authors | 01 April 2022, 17:23:00 UTC | Merge pull request #10119 from jakevdp:pil-fix PiperOrigin-RevId: 438853940 | 01 April 2022, 17:23:00 UTC |
1f300e7 | Jake VanderPlas | 01 April 2022, 16:23:27 UTC | CI: pin pillow<9.1 to prevent deprecation warnings | 01 April 2022, 16:23:27 UTC |
e766b96 | jax authors | 01 April 2022, 15:43:27 UTC | Merge pull request #10058 from yotarok:istft PiperOrigin-RevId: 438832534 | 01 April 2022, 15:43:27 UTC |
4decbcb | jax authors | 01 April 2022, 14:40:45 UTC | Merge pull request #10103 from LenaMartens:changelist/438319917 PiperOrigin-RevId: 438821559 | 01 April 2022, 14:40:45 UTC |
aa5d6b4 | Yash Katariya | 01 April 2022, 06:07:01 UTC | Fix the breakage by including --experimental_cc_shared_library as done by TF. PiperOrigin-RevId: 438746867 | 01 April 2022, 06:07:42 UTC |
a7fd751 | Yotaro Kubo | 28 March 2022, 12:33:40 UTC | Add istft to jax.scipy.signal. | 01 April 2022, 05:28:53 UTC |
d42981b | jax authors | 01 April 2022, 01:29:52 UTC | Merge pull request #10113 from mihaimaruseac:patch-1 PiperOrigin-RevId: 438709484 | 01 April 2022, 01:29:52 UTC |
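
Commit 0bfb3ef notes that, because approx_top_k handles multi-rank operands, its batching rule only needs to set reduction_dim correctly. A minimal sketch of that index arithmetic follows. It assumes that reduction_dim refers to the unbatched operand and that the batch axis sits at position bdim of the batched operand. batched_reduction_dim is a hypothetical helper, not JAX's actual batching rule.

```python
def batched_reduction_dim(reduction_dim, unbatched_ndim, bdim):
    """Hypothetical helper: map reduction_dim from the unbatched operand to an
    operand with an extra batch axis inserted at position bdim."""
    # Normalize a negative reduction_dim (e.g. -1) against the unbatched rank.
    if reduction_dim < 0:
        reduction_dim += unbatched_ndim
    if not 0 <= reduction_dim < unbatched_ndim:
        raise ValueError("reduction_dim out of range")
    if not 0 <= bdim <= unbatched_ndim:
        raise ValueError("bdim out of range")
    # Axes at or after the inserted batch axis shift right by one position.
    return reduction_dim + 1 if bdim <= reduction_dim else reduction_dim
```

For example, an unbatched operand of shape (3, 5) reduced along its last axis, batched with the batch axis at position 1, has shape (3, B, 5), so the reduction moves to axis 2.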