https://github.com/google/jax

Revision Message Commit Date
700d1e1 Merge pull request #10177 from hawkinsp:jaxlib PiperOrigin-RevId: 439975370 07 April 2022, 00:27:38 UTC
96ba290 Jax 0.3.5 and jaxlib 0.3.5 release. 06 April 2022, 23:56:41 UTC
7ee6adb Merge pull request #10173 from jakevdp:bcoo-add-batchdim PiperOrigin-RevId: 439955276 06 April 2022, 22:50:10 UTC
6a7a346 Move PartitionSpec from sharded_jit.py to pxla.py. The public endpoint is via jax.experimental so that should be used (no changes to the public endpoint). This move is because sharded_jit is being deprecated. PiperOrigin-RevId: 439948391 06 April 2022, 22:19:19 UTC
93a24f3 [sparse] add bcoo_add_batchdim 06 April 2022, 21:44:29 UTC
bc658e7 [MHLO] Add direct MHLO lowerings for most linear algebra kernels. PiperOrigin-RevId: 439927594 06 April 2022, 20:59:09 UTC
4ed0660 Add deprecation warning for sharded_jit. PiperOrigin-RevId: 439926957 06 April 2022, 20:54:06 UTC
3bfa6af [MHLO] Add MHLO lowering for PRNG kernels. PiperOrigin-RevId: 439919104 06 April 2022, 20:23:01 UTC
b9bb613 [MHLO] Prefer backend-specific HLO lowerings instead of non-backend-specific MHLO lowerings. This allows (in subsequent changes) switching the generic case for translating a primitive to MHLO, even if we can't yet use an MHLO lowering for a backend-specific case. Add a handful of direct MLIR lowerings for primitives that lacked them. PiperOrigin-RevId: 439912093 06 April 2022, 19:53:56 UTC
4012267 Revert: implement `jnp.trace` in terms of `jnp.diagonal`. This change appears to blow up compilation times for some models on TPU. PiperOrigin-RevId: 439880940 06 April 2022, 17:46:01 UTC
be64d8b Merge pull request #10164 from hawkinsp:macarm PiperOrigin-RevId: 439857716 06 April 2022, 16:21:53 UTC
21d81c0 Merge pull request #10160 from lgeiger:jnp-trace PiperOrigin-RevId: 439857693 06 April 2022, 16:16:53 UTC
d073a0f Pin a newer @platforms in the Bazel workspace to fix Mac ARM builds. 06 April 2022, 14:43:56 UTC
3e877f3 Implement `jnp.trace` in terms of `jnp.diagonal` 06 April 2022, 00:07:06 UTC
1df942f Merge pull request #10148 from sharadmv:version-info PiperOrigin-RevId: 439686412 05 April 2022, 22:06:47 UTC
d72a7b4 Add version int tuple `__version_info__` to JAX 05 April 2022, 20:26:05 UTC
fef3670 Merge pull request #10140 from jakevdp:jnp-diagonal PiperOrigin-RevId: 439596331 05 April 2022, 16:14:14 UTC
2857579 Merge pull request #10138 from lucasb-eyer:patch-5 PiperOrigin-RevId: 439596234 05 April 2022, 16:09:19 UTC
152c210 [MHLO] Implement return type inference for GetTupleElementOp and TupleOp. PiperOrigin-RevId: 439589720 05 April 2022, 15:38:38 UTC
b7344ed jnp.diagonal: implement in terms of gather rather than sum 05 April 2022, 00:02:11 UTC
97d834f Merge pull request #10139 from jakevdp:fix-gpu-test PiperOrigin-RevId: 439439327 04 April 2022, 23:29:02 UTC
5a96c0c Skip test outside x64 04 April 2022, 23:00:18 UTC
1246b6f Separate jax.test_util implementations into public and private sources. Eventually the private functionality will no longer be exported via the jax.test_util submodule. PiperOrigin-RevId: 439415485 04 April 2022, 21:43:39 UTC
71a5eb2 [GPU] Force an input buffer copy for double precision complex-to-real IRFFTs. Fixes https://github.com/google/jax/issues/9946 PiperOrigin-RevId: 439414091 04 April 2022, 21:38:52 UTC
f7b749c Explicit doc note about device_put* async 04 April 2022, 21:38:51 UTC
6825f65
* Disallow any type other than GDA and ShapedArray for auto sharding.
* Raise errors in the following 4 cases when a GDA's sharding does not match the input sharding. **In all the 4 cases below, the check only runs once! There is no double checking going on. I have added tests for these cases. Please check them out.**
  * Auto sharding
    * f_pjitted(gda) -- `_pjit_call_impl` catches this mismatch. Only doing this check when `compiled._auto_spmd_lowering` is True.
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch
  * NO auto sharding
    * f_pjitted(gda) -- This is already covered and tested and happens in `infer_params`
    * compiled(gda) -- `def call(*args)` in `MeshExecutable` catches this mismatch
PiperOrigin-RevId: 439413895 04 April 2022, 21:33:51 UTC
4949e78 Re-land changes from https://github.com/google/jax/pull/10069 PiperOrigin-RevId: 439381161 04 April 2022, 19:18:43 UTC
41b6e00 Enable use of `GlobalDeviceArray` (GDA) in T5X Checkpointer. Add a separate unit test, `gda_checkpoints_test`, to cover this use case. GDA is locked behind a `use_gda` bool in Checkpointer. The feature is currently not enabled anywhere. Our follow-up plan is to add code which would enable GDA use throughout T5X, and to fix any remaining issues with Checkpointer. PiperOrigin-RevId: 439358913 04 April 2022, 17:56:07 UTC
1b8be90 Remove the jax_enable_mlir flag. MLIR is now the only supported code path. This change does not yet remove all the XLA translation rule code since it may be used in various fallback paths. Only the top-level lowering function is removed. Further cleanup is left to subsequent changes. PiperOrigin-RevId: 439324450 04 April 2022, 15:40:09 UTC
e1bbbf5 Merge pull request #10130 from mattjj:no-string-annotations PiperOrigin-RevId: 439174012 03 April 2022, 19:11:03 UTC
c72d8f6 remove string annotations from core.py 03 April 2022, 18:19:07 UTC
359b614 Merge pull request #10122 from sharadmv:jax2tf-name-stack PiperOrigin-RevId: 439036794 02 April 2022, 16:02:21 UTC
e64a57d Merge pull request #10121 from hawkinsp:hcbcache PiperOrigin-RevId: 439036780 02 April 2022, 15:57:24 UTC
cdf4177 Merge pull request #10126 from jakevdp:tree-multimap PiperOrigin-RevId: 438956536 02 April 2022, 01:52:50 UTC
c61a18b DOC: switch from tree_multimap to tree_map in docs 01 April 2022, 21:52:16 UTC
df1ceae Deprecate jax.tree_util.tree_multimap 01 April 2022, 21:51:54 UTC
9693898 Merge pull request #10123 from fabianp:patch-4 PiperOrigin-RevId: 438906356 01 April 2022, 21:09:08 UTC
0eaeff6 Give auto sharder the mesh information specifically the mesh_shape and the devices ids of devices in the mesh. PiperOrigin-RevId: 438906211 01 April 2022, 21:04:23 UTC
4fd466b remove $ from command line commands. A few commands in this file were prefixed with $, which results in an invalid command when copied with the sphinx "copy" button. 01 April 2022, 19:50:29 UTC
1c3edc8 Merge pull request #10110 from pschuh:weakref-bug PiperOrigin-RevId: 438887762 01 April 2022, 19:45:35 UTC
aac8ec8 Fixes `jax2tf`'s `test_name_scope` to use graph introspection instead of side-effect 01 April 2022, 19:39:56 UTC
7b7458b Give auto sharder the mesh information specifically the mesh_shape and the devices ids of devices in the mesh. PiperOrigin-RevId: 438882041 01 April 2022, 19:19:25 UTC
df1c478 Fix race condition for weakref destructor by catching rare exceptions. 01 April 2022, 19:04:36 UTC
8c3385c Expose AutoSharding's mesh_shape and mesh_ids options to JAX. PiperOrigin-RevId: 438874347 01 April 2022, 18:47:56 UTC
208e83c Avoid retracing when a host_callback.call is called multiple times with the same function. If we build a lambda in the host_callback.call() method, the identity of that lambda is different each time and will never lead to a primitive compilation cache hit. Instead, use a custom wrapper object with hash/equality. This issue was found in passing while debugging #9970. 01 April 2022, 18:41:14 UTC
a4a551a Merge pull request #10119 from jakevdp:pil-fix PiperOrigin-RevId: 438853940 01 April 2022, 17:23:00 UTC
1f300e7 CI: pin pillow<9.1 to prevent deprecation warnings 01 April 2022, 16:23:27 UTC
e766b96 Merge pull request #10058 from yotarok:istft PiperOrigin-RevId: 438832534 01 April 2022, 15:43:27 UTC
4decbcb Merge pull request #10103 from LenaMartens:changelist/438319917 PiperOrigin-RevId: 438821559 01 April 2022, 14:40:45 UTC
aa5d6b4 Fix the breakage by including --experimental_cc_shared_library as done by TF. PiperOrigin-RevId: 438746867 01 April 2022, 06:07:42 UTC
a7fd751 Add istft to jax.scipy.signal. 01 April 2022, 05:28:53 UTC
d42981b Merge pull request #10113 from mihaimaruseac:patch-1 PiperOrigin-RevId: 438709484 01 April 2022, 01:29:52 UTC
8ca8f74 First attempt to enable auto-sharding. This CL adds support for GDA (no SDA support yet). An example of using auto sharding with GDA:
```
f = pjit(lambda x: x, in_axis_resources=pjit.AUTO, out_axis_resources=pjit.AUTO)
sharding_info = pjit.get_sharding_from_xla(f, mesh, [(8, 2)], [np.int32])
inputs = [GlobalDeviceArray.from_callback(shape, mesh, ip, cb) for ip in sharding_info.in_pspec]
# Use the compiled function (which was compiled in get_sharding_from_xla)
out = sharding_info.compiled(*inputs)  # Recommended way!
# OR
out = f(*inputs)
```
PiperOrigin-RevId: 438708483 01 April 2022, 01:22:02 UTC
9b68e1a Add missing `:` 01 April 2022, 01:17:51 UTC
113cfe7 Fix syntax error causing typo 01 April 2022, 00:48:59 UTC
f09aff1 Manually split the version string. It seems `pkg_resources` also needs to be installed. TF does not provide a version tuple, so split manually. We only use the first 2 components, as in `tf.nightly` the patch number contains additional letters, which would cause `int` to fail. 01 April 2022, 00:36:56 UTC
cdd703b Don't compare versions as strings. TF nightly is now 2.10. Comparing as strings will give the wrong answer:
```python
>>> "2.10.0" >= "2.8.0"
False
>>> from packaging import version
>>> version.parse("2.10.0") >= version.parse("2.8.0")
True
>>> import pkg_resources
>>> pkg_resources.parse_version("2.10.0") >= pkg_resources.parse_version("2.8.0")
True
```
Use `pkg_resources.parse_version` as it should not require additional installs. 31 March 2022, 23:50:01 UTC
3184dd6 [sparse] Update docstrings for bcoo primitives. PiperOrigin-RevId: 438685829 31 March 2022, 23:17:05 UTC
ced2cbe Merge pull request #10097 from lgeiger:expand-dims PiperOrigin-RevId: 438649114 31 March 2022, 20:39:34 UTC
d9403f6 Merge pull request #10092 from jakevdp:sharp-bits-divergences PiperOrigin-RevId: 438644183 31 March 2022, 20:18:47 UTC
5181692 Merge pull request #10102 from gnecula:hcb_fix PiperOrigin-RevId: 438638041 31 March 2022, 19:52:40 UTC
72f14a5 Sharp bits: add miscellaneous divergences from numpy 31 March 2022, 19:15:06 UTC
f31b7c9 [jax2tf] Add links to Google-internal documentation. PiperOrigin-RevId: 438599124 31 March 2022, 17:17:47 UTC
137df31 Merge pull request #10089 from sharadmv:name-stack-fix PiperOrigin-RevId: 438591565 31 March 2022, 16:47:50 UTC
0bcd447 If `jax_parallel_functions_output_gda` flag is set to True, then all outputs are GDA. In `abstract_eval` rule of pjit, don't convert global avals to local avals if the `jax_parallel_functions_output_gda` flag is set. Fixes: https://github.com/google/jax/issues/10084 PiperOrigin-RevId: 438584351 31 March 2022, 16:17:20 UTC
6bff559 Merge pull request #10101 from ROCmSoftwarePlatform:upgrade_to_rocm_51 PiperOrigin-RevId: 438576411 31 March 2022, 15:43:04 UTC
dcd3006 Merge pull request #10027 from jakevdp:fix-vmap-weaktype PiperOrigin-RevId: 438565124 31 March 2022, 14:41:13 UTC
15d2cca Checkify: add axis and axis size to OOB error message. 31 March 2022, 14:16:35 UTC
84e7359 [host_callback] Fix tests to ensure we use the correct platform. In host_callback_test, there are a few tests that inspect compiled HLO. In some cases, we're explicitly creating a CPU XLA computation, but we're handing it off to the default backend. When we're on a TPU machine, we're asking a TPU backend to compile a CPU XLA computation. Fixes internal b/227521177. 31 March 2022, 12:25:21 UTC
190501b Upgrade ROCm build docker to ROCm version 5.1. 31 March 2022, 11:56:38 UTC
c5f9c42 Merge pull request #10096 from superbobry:mlir-types PiperOrigin-RevId: 438503354 31 March 2022, 08:21:27 UTC
19c6657 Suppressed a few more errors in jax/interpreters/mlir.py 31 March 2022, 08:03:14 UTC
50e8bc4 Replace `reshape` with `expand_dims` if possible 31 March 2022, 00:34:26 UTC
19e3592 [jax2tf] Updates `custom_assert` for jax2tf SVD (primitive) limitations. PiperOrigin-RevId: 438421090 30 March 2022, 23:10:14 UTC
0694dbd Merge pull request #10095 from hawkinsp:mlir PiperOrigin-RevId: 438400696 30 March 2022, 21:43:14 UTC
ade9f1a Share compare_mhlo function between lax.py and mlir.py. Use the .shape property on RankedTensorType. 30 March 2022, 21:02:19 UTC
ee1ca3f Merge pull request #10091 from rsuderman:tf_update PiperOrigin-RevId: 438374710 30 March 2022, 19:57:20 UTC
e9717c9 Merge pull request #10090 from sharadmv:docs-typo-fix PiperOrigin-RevId: 438364440 30 March 2022, 19:11:11 UTC
c8fbe75 Update tensorflow version to 0dfdb8 30 March 2022, 18:44:58 UTC
fb97717 Fix link to buffer-donation docs 30 March 2022, 18:23:23 UTC
c233a97 Remove redundant name-stack setting in `DynamicJaxprTrace` 30 March 2022, 18:18:56 UTC
34f116c vmap: preserve weak_type in batching tracer 30 March 2022, 18:06:56 UTC
a04b777 [mhlo] Clean up ops that can use InferTensorTypeWithReify. This means we can get rid of custom builders and return type inference. This all goes through inferReturnTypeComponents now, so fix obvious bugs in those implementations. There should be no behavioral change. However, Python bindings no longer generate a result type builder for mhlo.CompareOp, which is unfortunate. PiperOrigin-RevId: 438341237 30 March 2022, 17:44:16 UTC
b1a50fd Merge pull request #10087 from superbobry:patch-1 PiperOrigin-RevId: 438323358 30 March 2022, 16:35:12 UTC
b315212 Fixed a typo in the return type of _array_ir_types. There could be other typing issues in that module, but I will address them separately. 30 March 2022, 15:31:59 UTC
1555ba1 Copybara import of the project: -- de9a948d1ce407056de545b5717c3441298e2f36 by Jake VanderPlas <jakevdp@google.com>: make device_array.copy() return a device array PiperOrigin-RevId: 438308145 30 March 2022, 15:30:18 UTC
3dee82d Merge pull request #10073 from jakevdp:wrap-links PiperOrigin-RevId: 438307525 30 March 2022, 15:24:36 UTC
c46ad2f Merge pull request #10085 from jakevdp:changelog-10073 PiperOrigin-RevId: 438305623 30 March 2022, 15:15:06 UTC
b359b8a Add CHANGELOG entry for #10069 30 March 2022, 15:05:34 UTC
ef2efec Merge pull request #10069 from jakevdp:devicearray-copy PiperOrigin-RevId: 438292130 30 March 2022, 14:01:19 UTC
17fc5bd Merge pull request #9290 from fehiepsi:named PiperOrigin-RevId: 438290209 30 March 2022, 13:54:10 UTC
4ac8eb0 Merge pull request #9645 from yalaudah:main PiperOrigin-RevId: 438290146 30 March 2022, 13:49:16 UTC
02c2eb5 Merge pull request #9682 from jblespiau:changelist/430664138 PiperOrigin-RevId: 438285304 30 March 2022, 13:22:20 UTC
f8cddf0 Merge pull request #9689 from fehiepsi:image PiperOrigin-RevId: 438282642 30 March 2022, 13:05:30 UTC
9da5f4e Merge pull request #10074 from jakevdp:jit-partial-doc PiperOrigin-RevId: 438193490 30 March 2022, 03:02:17 UTC
8884ce5 Migrate 'jaxlib' CPU custom-calls to the status-returning API PiperOrigin-RevId: 438165260 30 March 2022, 00:14:14 UTC
f4b64f4 doc: add examples of using partial with jit 29 March 2022, 22:43:58 UTC
b31cf89 Merge pull request #10072 from jakevdp:fromiter PiperOrigin-RevId: 438141629 29 March 2022, 22:28:36 UTC
23c783a Merge pull request #10068 from hawkinsp:docs PiperOrigin-RevId: 438131830 29 March 2022, 21:49:47 UTC
4f6ea7b docs: use intersphinx links for wrapped functions 29 March 2022, 21:43:59 UTC
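
Commits cdd703b and f09aff1 above explain why version strings should not be compared as strings and why only the first two components are split off manually. The following is a minimal sketch of that idea in plain Python. The helper name `version_tuple` and the sample nightly-style version string are illustrative assumptions, not code taken from the repository.

```python
from typing import Tuple


def version_tuple(version: str, components: int = 2) -> Tuple[int, ...]:
    """Return the first `components` dotted fields of `version` as integers.

    Hypothetical helper: later fields (such as a patch number carrying
    extra letters) are dropped before conversion, so `int` never sees them.
    """
    return tuple(int(part) for part in version.split(".")[:components])


# String comparison is lexicographic: "1" sorts before "8", so this is False.
print("2.10.0" >= "2.8.0")  # False

# Integer tuples compare numerically, field by field.
print(version_tuple("2.10.0") >= version_tuple("2.8.0"))  # True

# A patch field with extra letters (assumed format) does not affect the result.
print(version_tuple("2.10.0-dev20220331"))  # (2, 10)
```

Keeping only two components trades precision for robustness: patch-level differences are ignored, which is acceptable when the check is "at least 2.8".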
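
Commit 208e83c above describes replacing a freshly built lambda with a wrapper object that defines hash and equality, so that repeated calls with the same function can hit a cache. The sketch below illustrates the general idea only: `functools.lru_cache` stands in for the primitive compilation cache, and `CallbackWrapper`, `compile_for`, and `host_fn` are hypothetical names. It is not the `host_callback` implementation.

```python
import functools


class CallbackWrapper:
    """Hypothetical wrapper whose identity is derived from the wrapped function."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args):
        return self.fn(*args)

    def __hash__(self):
        return hash(self.fn)

    def __eq__(self, other):
        return isinstance(other, CallbackWrapper) and self.fn == other.fn


@functools.lru_cache(maxsize=None)
def compile_for(callback):
    # Stand-in for an expensive compilation keyed on the callback object.
    return object()


def host_fn(x):
    return x + 1


def call_with_lambda(fn):
    # A new lambda object is created on every call, so the cache key differs.
    return compile_for(lambda *args: fn(*args))


def call_with_wrapper(fn):
    # Wrappers around the same function hash and compare equal.
    return compile_for(CallbackWrapper(fn))


# Each lambda is a distinct key, so the two results are different objects.
print(call_with_lambda(host_fn) is call_with_lambda(host_fn))  # False

# Equal wrappers share one cache entry, so the same object is returned.
print(call_with_wrapper(host_fn) is call_with_wrapper(host_fn))  # True
```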