https://github.com/google/jax

Revision Message Commit Date
6ba9fb6 Upgrade the bazel version to 5.1.1 PiperOrigin-RevId: 441338363 13 April 2022, 00:48:09 UTC
c06eff8 Merge pull request #10245 from google:yashk2810-patch-7 PiperOrigin-RevId: 441265709 12 April 2022, 19:54:22 UTC
5fd78ea Bump the libtpu version to prepare for JAX release 12 April 2022, 18:41:07 UTC
9455254 [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one. This is a second attempt at this change. In this version, check for and report an error on jit(pjit(...)), which was the root cause of the failure that led to the previous version being reverted. PiperOrigin-RevId: 441214076 12 April 2022, 17:30:52 UTC
3136004 Fix the pytype error. PyType is looking for a __init__ method. This does not change the behavior of the class. ``` Function PartitionSpec.__init__ expects 1 arg(s), got 3 [wrong-arg-count] Expected: (self) Actually passed: (self, _, _) ``` PiperOrigin-RevId: 441211351 12 April 2022, 16:36:28 UTC
a2c2d9a [JAX] Adds the approx_top_k_p bridge. PiperOrigin-RevId: 441172779 12 April 2022, 13:52:47 UTC
7be37ab Merge pull request #10120 from mattjj:djax-latest PiperOrigin-RevId: 441088673 12 April 2022, 05:42:07 UTC
4354f35 prototyping dynamic shapes Co-authored-by: Dougal Maclaurin <dougalm@google.com> 12 April 2022, 05:10:47 UTC
fb6a143 Merge pull request #9723 from sharadmv:jaxpr-effects PiperOrigin-RevId: 441083960 12 April 2022, 05:05:10 UTC
b051be4 [MHLO] Switch sharded_jit dispatch path to use MHLO lowering. This change runs some chance of breaking sharded_jit users due to a lack of testing, but the plan is to delete sharded_jit very soon anyway. PiperOrigin-RevId: 440999535 11 April 2022, 21:41:51 UTC
0fa1edd Adds simple effect types to jaxprs 11 April 2022, 18:50:41 UTC
902fc0c Remove invertible_ad since it's not in use. PiperOrigin-RevId: 440890949 11 April 2022, 14:56:58 UTC
35b32ee Merge pull request #10215 from mattjj:dispatch-tweaks PiperOrigin-RevId: 440610850 09 April 2022, 21:25:01 UTC
9f1fab2 dispatch.py: type annotations, other minor tweaks 09 April 2022, 21:07:19 UTC
5fdad0e Roll forward manylinux2014 builds after fixes. PiperOrigin-RevId: 440589273 09 April 2022, 17:19:44 UTC
a11b41f [sparse] Use sorted indices instead of sorted rows only. PiperOrigin-RevId: 440579642 09 April 2022, 15:33:48 UTC
e9f95fa Make jaxlib builds manylinux2014 compliant. PiperOrigin-RevId: 440497401 09 April 2022, 01:51:46 UTC
506a85b Make jaxlib builds manylinux2014 compliant. PiperOrigin-RevId: 440476417 08 April 2022, 23:21:56 UTC
94307a0 Revert: [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one. PiperOrigin-RevId: 440452521 08 April 2022, 21:22:15 UTC
272ed95 remove experimental/djax PiperOrigin-RevId: 440445082 08 April 2022, 20:55:22 UTC
0bfb3ef [JAX] Fix batch logic for approx_min/max_k Previous logic was copied from lax.sort and was incorrect. Since approx_top_k can handle multi-rank tensors, the only mapping we need is to set the reduction_dim correctly. PiperOrigin-RevId: 440445041 08 April 2022, 20:50:36 UTC
6cb7526 Merge pull request #10176 from lgeiger:simplify-diagonal PiperOrigin-RevId: 440437605 08 April 2022, 20:19:09 UTC
0f15fa3 [MHLO] Add a direct MHLO lowering for pjit_p, which lacked one. PiperOrigin-RevId: 440433044 08 April 2022, 19:57:59 UTC
b8602d0 Merge pull request #10198 from jakevdp:bcoo-duplicates PiperOrigin-RevId: 440423326 08 April 2022, 19:12:12 UTC
654e5bd Roll forward again after the fix in the auto sharding pass. PiperOrigin-RevId: 440412218 08 April 2022, 18:25:07 UTC
8b9efe7 [sparse] fix autodiff bug in spdot_general 08 April 2022, 18:04:26 UTC
60c828a Simplify `jnp.trace` implementation 08 April 2022, 17:10:01 UTC
084adc7 Simplify `jnp.diagonal` implementation 08 April 2022, 17:07:32 UTC
8b6b736 Revert: Pin a newer @platforms in the Bazel workspace to fix Mac ARM builds. The @platforms repository has been updated in the @tf_runtime repository, which was pulling in the old version of @platforms. We no longer need to override @platforms in the JAX WORKSPACE. PiperOrigin-RevId: 440375016 08 April 2022, 15:49:00 UTC
648a512 [MHLO] Add direct MHLO lowerings for sparse primitives. PiperOrigin-RevId: 440374054 08 April 2022, 15:43:57 UTC
1cb4fcc Merge pull request #10187 from hawkinsp:jaxlib PiperOrigin-RevId: 440362035 08 April 2022, 14:41:33 UTC
0c02f79 Enable tests related to the Gamma distribution for non-default PRNG implementations only when jax_enable_custom_prng is enabled, for consistency with other tests. PiperOrigin-RevId: 440300882 08 April 2022, 08:08:55 UTC
4dc6903 Update version numbers after jax/jaxlib release. 07 April 2022, 20:40:19 UTC
58bdcb8 Merge pull request #10185 from hawkinsp:jaxlib PiperOrigin-RevId: 440184287 07 April 2022, 20:20:36 UTC
7f751c5 Update libtpu version for jax 0.3.5 release. 07 April 2022, 20:14:13 UTC
28cb44e Merge pull request #10184 from jakevdp:merge-bcoo-dot-general PiperOrigin-RevId: 440178509 07 April 2022, 19:57:19 UTC
96af4d5 Remove sharded_jit usage from jax2tf because sharded_jit is deprecated. PiperOrigin-RevId: 440169129 07 April 2022, 19:14:10 UTC
01e4fa8 [sparse] consolidate flavors of bcoo_dot_general 07 April 2022, 18:28:12 UTC
5522ed1 jax2tf: Support uint32 keys in rng_bit_generator. This follows the rng_bit_generator_translation rule in JAX, which allows for both uint32 and uint64 keys and casts between them. The default rbg prng implementation in JAX uses a (4,) uint32 key. PiperOrigin-RevId: 440124048 07 April 2022, 16:19:22 UTC
02fd875 Add __init__ to PolyShape. PiperOrigin-RevId: 440120323 07 April 2022, 16:06:37 UTC
b713d3c Minor change to lax to support jax2tf shape polymorphic concatenation. PiperOrigin-RevId: 440113799 07 April 2022, 15:34:27 UTC
cbdcdf7 [MHLO] Add MHLO lowerings for parallel collectives. PiperOrigin-RevId: 440106753 07 April 2022, 14:59:26 UTC
2884215 Merge pull request #10167 from ROCmSoftwarePlatform:rocm_solver_api_consolidation PiperOrigin-RevId: 439997492 07 April 2022, 03:00:20 UTC
8b3f039 Merge pull request #10039 from ajcr:add_scipy_linalg_rsf2csf PiperOrigin-RevId: 439997145 07 April 2022, 02:55:29 UTC
6c560b1 Consolidation of hipsolver/cusolver APIs. 07 April 2022, 01:46:43 UTC
832d9aa Merge pull request #10175 from jakevdp:bcoo-spdot PiperOrigin-RevId: 439980561 07 April 2022, 01:01:05 UTC
700d1e1 Merge pull request #10177 from hawkinsp:jaxlib PiperOrigin-RevId: 439975370 07 April 2022, 00:27:38 UTC
96ba290 Jax 0.3.5 and jaxlib 0.3.5 release. 06 April 2022, 23:56:41 UTC
7ee6adb Merge pull request #10173 from jakevdp:bcoo-add-batchdim PiperOrigin-RevId: 439955276 06 April 2022, 22:50:10 UTC
aa0da8e [sparse] make bcoo_spdot_general return a BCOO array, not raw buffers 06 April 2022, 22:40:26 UTC
6a7a346 Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint). This move is because sharded_jit is being deprecated. PiperOrigin-RevId: 439948391 06 April 2022, 22:19:19 UTC
93a24f3 [sparse] add bcoo_add_batchdim 06 April 2022, 21:44:29 UTC
bc658e7 [MHLO] Add direct MHLO lowerings for most linear algebra kernels. PiperOrigin-RevId: 439927594 06 April 2022, 20:59:09 UTC
4ed0660 Add deprecation warning for sharded_jit. PiperOrigin-RevId: 439926957 06 April 2022, 20:54:06 UTC
3bfa6af [MHLO] Add MHLO lowering for PRNG kernels. PiperOrigin-RevId: 439919104 06 April 2022, 20:23:01 UTC
869596f Add jax.scipy.linalg.rsf2csf 06 April 2022, 20:06:23 UTC
b9bb613 [MHLO] Prefer backend-specific HLO lowerings instead of non-backend-specific MHLO lowerings. This allows (in subsequent changes) to switch the generic case for translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a backend-specific case yet. Add a handful of direct MLIR lowerings for primitives that lacked them. PiperOrigin-RevId: 439912093 06 April 2022, 19:53:56 UTC
4012267 Revert: implement `jnp.trace` in terms of `jnp.diagonal` This change appears to blow up compilation times for some models on TPU. PiperOrigin-RevId: 439880940 06 April 2022, 17:46:01 UTC
be64d8b Merge pull request #10164 from hawkinsp:macarm PiperOrigin-RevId: 439857716 06 April 2022, 16:21:53 UTC
21d81c0 Merge pull request #10160 from lgeiger:jnp-trace PiperOrigin-RevId: 439857693 06 April 2022, 16:16:53 UTC
d073a0f Pin a newer @platforms in the Bazel workspace to fix Mac ARM builds. 06 April 2022, 14:43:56 UTC
3e877f3 Implement `jnp.trace` in terms of `jnp.diagonal` 06 April 2022, 00:07:06 UTC
1df942f Merge pull request #10148 from sharadmv:version-info PiperOrigin-RevId: 439686412 05 April 2022, 22:06:47 UTC
d72a7b4 Add version int tuple `__version_info__` to JAX 05 April 2022, 20:26:05 UTC
fef3670 Merge pull request #10140 from jakevdp:jnp-diagonal PiperOrigin-RevId: 439596331 05 April 2022, 16:14:14 UTC
2857579 Merge pull request #10138 from lucasb-eyer:patch-5 PiperOrigin-RevId: 439596234 05 April 2022, 16:09:19 UTC
152c210 [MHLO] Implement return type inference for GetTupleElementOp and TupleOp. PiperOrigin-RevId: 439589720 05 April 2022, 15:38:38 UTC
b7344ed jnp.diagonal: implement in terms of gather rather than sum 05 April 2022, 00:02:11 UTC
97d834f Merge pull request #10139 from jakevdp:fix-gpu-test PiperOrigin-RevId: 439439327 04 April 2022, 23:29:02 UTC
5a96c0c Skip test outside x64 04 April 2022, 23:00:18 UTC
1246b6f Separate jax.test_util implementations into public and private sources. Eventually the private functionality will no longer be exported via the jax.test_util submodule. PiperOrigin-RevId: 439415485 04 April 2022, 21:43:39 UTC
71a5eb2 [GPU] Force an input buffer copy for double precision complex-to-real IRFFTs. Fixes https://github.com/google/jax/issues/9946 PiperOrigin-RevId: 439414091 04 April 2022, 21:38:52 UTC
f7b749c Explicit doc note about device_put* async 04 April 2022, 21:38:51 UTC
6825f65 * Disallow any other type other than GDA and ShapedArray for auto sharding. * Raise errors in the following 4 cases when GDAs sharding does not match the input sharding. **In all the 4 cases below, the check only runs once! There is no double checking going on. I have added tests for these cases. Please check them out.** * Auto sharding * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. Only doing this check when `compiled._auto_spmd_lowering` is True. * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch * NO auto sharding * f_pjitted(gda) -- This is already covered and tested and happens in `infer_params` * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch PiperOrigin-RevId: 439413895 04 April 2022, 21:33:51 UTC
4949e78 Re-land changes from https://github.com/google/jax/pull/10069 PiperOrigin-RevId: 439381161 04 April 2022, 19:18:43 UTC
41b6e00 Enable use of `GlobalDeviceArray` (GDA) in T5X Checkpointer. Add a separate unit test, `gda_checkpoints_test`, to cover this use case. GDA is locked behind a `use_gda` bool in Checkpointer. The feature is currently not enabled anywhere. Our follow-up plan is to add code which would enable GDA use throughout T5X, and to fix any remaining issues with Checkpointer. PiperOrigin-RevId: 439358913 04 April 2022, 17:56:07 UTC
1b8be90 Remove the jax_enable_mlir flag. MLIR is now the only supported code path. This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes. PiperOrigin-RevId: 439324450 04 April 2022, 15:40:09 UTC
e1bbbf5 Merge pull request #10130 from mattjj:no-string-annotations PiperOrigin-RevId: 439174012 03 April 2022, 19:11:03 UTC
c72d8f6 remove string annotations from core.py 03 April 2022, 18:19:07 UTC
359b614 Merge pull request #10122 from sharadmv:jax2tf-name-stack PiperOrigin-RevId: 439036794 02 April 2022, 16:02:21 UTC
e64a57d Merge pull request #10121 from hawkinsp:hcbcache PiperOrigin-RevId: 439036780 02 April 2022, 15:57:24 UTC
cdf4177 Merge pull request #10126 from jakevdp:tree-multimap PiperOrigin-RevId: 438956536 02 April 2022, 01:52:50 UTC
c61a18b DOC: switch from tree_multimap to tree_map in docs 01 April 2022, 21:52:16 UTC
df1ceae Deprecate jax.tree_util.tree_multimap 01 April 2022, 21:51:54 UTC
9693898 Merge pull request #10123 from fabianp:patch-4 PiperOrigin-RevId: 438906356 01 April 2022, 21:09:08 UTC
0eaeff6 Give auto sharder the mesh information specifically the mesh_shape and the devices ids of devices in the mesh. PiperOrigin-RevId: 438906211 01 April 2022, 21:04:23 UTC
4fd466b remove $ from command line commands A few commands in this file were prefixed with $ which results in an invalid command when copied with the sphinx "copy" button. 01 April 2022, 19:50:29 UTC
1c3edc8 Merge pull request #10110 from pschuh:weakref-bug PiperOrigin-RevId: 438887762 01 April 2022, 19:45:35 UTC
aac8ec8 Fixes `jax2tf`'s `test_name_scope` to use graph introspection instead of side-effect 01 April 2022, 19:39:56 UTC
7b7458b Give auto sharder the mesh information specifically the mesh_shape and the devices ids of devices in the mesh. PiperOrigin-RevId: 438882041 01 April 2022, 19:19:25 UTC
df1c478 Fix race condition for weakref destructor by catching rare exceptions. 01 April 2022, 19:04:36 UTC
8c3385c Expose AutoSharding's mesh_shape and mesh_ids options to JAX. PiperOrigin-RevId: 438874347 01 April 2022, 18:47:56 UTC
208e83c Avoid retracing when a host_callback.call is called multiple times with the same function. If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality. This issue was found in passing while debugging #9970. 01 April 2022, 18:41:14 UTC
a4a551a Merge pull request #10119 from jakevdp:pil-fix PiperOrigin-RevId: 438853940 01 April 2022, 17:23:00 UTC
1f300e7 CI: pin pillow<9.1 to prevent deprecation warnings 01 April 2022, 16:23:27 UTC
e766b96 Merge pull request #10058 from yotarok:istft PiperOrigin-RevId: 438832534 01 April 2022, 15:43:27 UTC
4decbcb Merge pull request #10103 from LenaMartens:changelist/438319917 PiperOrigin-RevId: 438821559 01 April 2022, 14:40:45 UTC
aa5d6b4 Fix the breakage by including --experimental_cc_shared_library as done by TF. PiperOrigin-RevId: 438746867 01 April 2022, 06:07:42 UTC
a7fd751 Add istft to jax.scipy.signal. 01 April 2022, 05:28:53 UTC
d42981b Merge pull request #10113 from mihaimaruseac:patch-1 PiperOrigin-RevId: 438709484 01 April 2022, 01:29:52 UTC
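
Commit 208e83c in the log above describes a caching pitfall: a lambda built inside `host_callback.call()` has a new identity on every call, so it never produces a compilation cache hit, and the fix is a wrapper object with hash/equality. The following is a minimal plain-Python sketch of that general technique, not the code from the commit. The names `CallbackWrapper`, `compile_for`, `call_with_lambda`, and `call_with_wrapper` are hypothetical, and `functools.lru_cache` stands in for JAX's primitive compilation cache as an assumption made purely for illustration.

```python
import functools


class CallbackWrapper:
    """Hypothetical wrapper whose hash and equality follow the wrapped function.

    Two wrappers around the same function compare equal, so a cache keyed
    on the wrapper sees them as the same entry.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args):
        return self.fn(*args)

    def __hash__(self):
        return hash(self.fn)

    def __eq__(self, other):
        return isinstance(other, CallbackWrapper) and self.fn == other.fn


@functools.lru_cache(maxsize=None)
def compile_for(callback):
    # Stand-in for an expensive trace/compile step keyed on the callback.
    # (Assumption: the real cache is more involved; this only models the key.)
    return callback


def call_with_lambda(fn, x):
    # A fresh lambda is created on every call, so the cache key is always new.
    return compile_for(lambda *args: fn(*args))(x)


def call_with_wrapper(fn, x):
    # The wrapper hashes and compares by `fn`, so repeated calls share a key.
    return compile_for(CallbackWrapper(fn))(x)


def double(x):
    return 2 * x


if __name__ == "__main__":
    call_with_lambda(double, 3)
    call_with_lambda(double, 3)
    print("lambda:", compile_for.cache_info())

    compile_for.cache_clear()
    call_with_wrapper(double, 3)
    call_with_wrapper(double, 3)
    print("wrapper:", compile_for.cache_info())
```

In this sketch, calling `call_with_lambda` twice records two cache misses, while calling `call_with_wrapper` twice records one miss and one hit, which is the behavior the commit message attributes to the wrapper approach.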