https://github.com/google/jax

5a39341 Update streaming to work with an HTML window PiperOrigin-RevId: 497091885 28 December 2022, 23:44:34 UTC
38521b8 Merge pull request #13717 from jakevdp:sparse-grad PiperOrigin-RevId: 498204646 28 December 2022, 18:08:28 UTC
27d7f77 Merge pull request #13813 from jakevdp:fix-yash-thing PiperOrigin-RevId: 498203100 28 December 2022, 18:00:04 UTC
2b82e7a [sparse] improve implementation of sparse.grad & sparse.value_and_grad 28 December 2022, 17:55:11 UTC
2cf0791 Allow more cases in _TRANSPOSE_TRICKS by ignoring leading 1s in the mesh shape. PiperOrigin-RevId: 498201820 28 December 2022, 17:52:24 UTC
ac88f30 fix signature of shard arg handler 28 December 2022, 17:38:56 UTC
a71ab80 Merge pull request #13804 from jakevdp:masked-arr-error PiperOrigin-RevId: 498199498 28 December 2022, 17:38:04 UTC
a8f796c Merge pull request #13776 from jakevdp:sparse-batching-test PiperOrigin-RevId: 498195672 28 December 2022, 17:16:22 UTC
bc1b22c Merge pull request #13803 from gnecula:tf_dim_64 PiperOrigin-RevId: 498194615 28 December 2022, 17:09:15 UTC
68ec9b5 Merge pull request #13812 from gnecula:tf_while PiperOrigin-RevId: 498192217 28 December 2022, 16:55:16 UTC
71ce600 [jax2tf] Ensure that dim_as_value returns int64 in x64 mode and int32 otherwise. Changes all the computations with dimensions to work in int64 if JAX_ENABLE_X64 and int32 otherwise. 28 December 2022, 16:18:33 UTC
9593741 [jax2tf] Fix lowering for tf.while_loop. Sometimes TF infers more specific shapes for the init_carry, and this has led to errors: "enters the loop with shape (1,), but has shape (None,) after one iteration". Unfortunately, I cannot construct a small test. 28 December 2022, 15:53:13 UTC
e85fd5c Merge pull request #13811 from gnecula:dim_ops PiperOrigin-RevId: 498177537 28 December 2022, 15:19:39 UTC
1e09602 [jax2tf] Improve handling of dimension polynomials used with JAX arrays. See new description in README.md. 28 December 2022, 10:37:42 UTC
ec02966 Merge pull request #13807 from jakevdp:fix-crossref PiperOrigin-RevId: 498052807 28 December 2022, 00:27:42 UTC
fe4c958 doc: fix host callback module crossref 27 December 2022, 23:59:32 UTC
5367693 Error on numpy masked array inputs. 27 December 2022, 23:42:49 UTC
cca83ae [sparse] add _CheckBatchingSparse utility 27 December 2022, 21:55:50 UTC
8d2cd5c Merge pull request #13805 from jakevdp:fix-rtd PiperOrigin-RevId: 498022216 27 December 2022, 21:15:19 UTC
dc862a9 psum: fix docstring formatting 27 December 2022, 20:55:48 UTC
60dbc3b Merge pull request #13775 from jakevdp:sparse-vmappable-metadata PiperOrigin-RevId: 497995723 27 December 2022, 18:43:36 UTC
b043e76 Merge pull request #13779 from eltociear:patch-3 PiperOrigin-RevId: 497995427 27 December 2022, 18:36:44 UTC
1fc9197 Simplify Array's shard_arg_handler by merging pmap and pjit/xmap paths PiperOrigin-RevId: 497991966 27 December 2022, 18:16:44 UTC
2cdce1d [jax2tf] Simplify _eval_shape for the constant dimensions. Before this change, _eval_shape would return a tf.Const for the constant dimensions in the shape. Now it returns integer constants. This change reduces the size of the graph for the common cases when the dimensions are constant. It should not be a necessary change, but on some TPU flavors the old behavior causes the TF2XLA bridge to return the error "XLA can't use dynamic begin values for slice" even when the dimension is actually a constant. PiperOrigin-RevId: 497986554 27 December 2022, 17:45:54 UTC
909354c Merge pull request #13802 from gnecula:clean_tests PiperOrigin-RevId: 497983945 27 December 2022, 17:27:30 UTC
a1480c4 Migrate JAX from producing MHLO to producing StableHLO. As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically: 1) MLIR lowerings now produce StableHLO ops instead of MHLO ops. 2) Fallback lowerings now produce StableHLO ops as well. 3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs). From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO): a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing. b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath. c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo". d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues. PiperOrigin-RevId: 497978733 27 December 2022, 16:53:20 UTC
0752655 [jax2tf] Move pjit test into the right TestCase class. This test uses tfxla.call_module directly, so it belongs in XlaCallModuleTest. 27 December 2022, 13:28:04 UTC
7d452ad Add support for dynamic shapes to the GPU threefry2x32 custom call. In the presence of dynamic shapes, the ThreeFry2x32Descriptor will contain the value n=-1, and the actual desired output length will be passed as an additional operand. If the shape is static, then the length will be passed as part of the descriptor. PiperOrigin-RevId: 497945778 27 December 2022, 12:48:26 UTC
0c8a4fb [jax2tf] Simplify irrelevant part of call_tf_test.py PiperOrigin-RevId: 497727816 26 December 2022, 04:26:59 UTC
398aaaa Add support for Ellipsis as an index for stateful operations. PiperOrigin-RevId: 497466879 24 December 2022, 06:46:50 UTC
eb875cd Added a pattern-match optimisation for inplace-select. PiperOrigin-RevId: 497425937 24 December 2022, 00:05:56 UTC
0c51ca2 Minor docstring fix for custom_gradient. PiperOrigin-RevId: 497402055 23 December 2022, 20:12:48 UTC
9fda20f Add examples to lax.psum to illustrate axis_index_groups better. PiperOrigin-RevId: 497401892 23 December 2022, 20:04:59 UTC
55d0a94 Fix typo in loops.py miscellanous -> miscellaneous 23 December 2022, 18:59:19 UTC
cee2779 [sparse] propagate metadata through vmappable 22 December 2022, 22:46:40 UTC
2f3d75a Remove pjit's dependency on maps to avoid circular imports when importing pjit in api.py. PiperOrigin-RevId: 497230514 22 December 2022, 21:35:23 UTC
d2b87be Merge pull request #13333 from harryjulian:bernoulli PiperOrigin-RevId: 497210080 22 December 2022, 19:47:09 UTC
c0d4ae0 Added scipy.stats.bernoulli cdf and ppf. 22 December 2022, 18:12:25 UTC
d8bf334 Merge pull request #13755 from jakevdp:sparse-dense-vmap PiperOrigin-RevId: 497187394 22 December 2022, 18:00:19 UTC
e661c3f [sparse] Handle general batch dimensions in bcoo todense/fromdense 22 December 2022, 17:08:30 UTC
57840dd Move functions into `api_util.py` and `dispatch.py` to remove a circular import error when pjit is imported in `api.py` for merging the `jit` and `pjit` frontend API. PiperOrigin-RevId: 497172760 22 December 2022, 16:42:05 UTC
4d4eba4 Fix scatter in CLIP mode with uint32 and uint64 indices. Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python or C constants. This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to failures when jax_array=False. Since that use case will go away soon we just enable the fix for when jax_array=True. PiperOrigin-RevId: 497171518 22 December 2022, 16:34:47 UTC
2b716f2 Fix scatter in CLIP mode with uint32 and uint64 indices. Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python or C constants. This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to failures when jax_array=False. Since that use case will go away soon we just enable the fix for when jax_array=True. PiperOrigin-RevId: 497079129 22 December 2022, 05:59:51 UTC
40f5279 Merge pull request #13761 from wookayin:fix-api-public PiperOrigin-RevId: 497060752 22 December 2022, 03:44:55 UTC
c89e662 jax.linear_util: remove unnecessary exported names. Follow-up to https://github.com/google/jax/pull/13735 PiperOrigin-RevId: 497052842 22 December 2022, 02:52:01 UTC
479f336 Make `jax.ensure_compile_time_eval` correctly exposed as a public API. This function was added as a public API (#7987), but py.type static checkers do not recognize it as a public API because of the alias name. `jax.eval_context` exists only for backward compatibility, so the correct import would be to import `ensure_compile_time_eval` directly from `jax._src.core`. 22 December 2022, 01:12:08 UTC
704b9fc Don't tuple high-arity computations on GPU. I can think of no reason we do this, and it seems strictly preferable to not have to form an on-device tuple since it requires an H2D copy. As best I can reconstruct it, the sequence of events that got us to the current state was: a) we had backend-independent logic to tuple computations with > 100 arguments. This logic was motivated only by TPU. b) we disabled tupling on CPU, because the XLA:CPU folks noticed it wasn't needed. c) we increased the argument limit on TPU to 2000, because the TPU compiler folks noticed it was too low. This left GPU with the old TPU behavior for no good reason. PiperOrigin-RevId: 496993803 21 December 2022, 21:30:34 UTC
d5830e0 Merge pull request #13754 from mattjj:dot-general-pp-rule PiperOrigin-RevId: 496987303 21 December 2022, 20:58:28 UTC
c2d9b5c tweak dot_general pretty-printing rule to suppress default params 21 December 2022, 18:29:01 UTC
e89b60e [sparse] Propagate SparseInfo to BCSR todense() and tree_(un)flatten(). PiperOrigin-RevId: 496945167 21 December 2022, 17:55:21 UTC
4f75ad6 Revert #13747 since config values already have paired env vars. PiperOrigin-RevId: 496935378 21 December 2022, 17:06:08 UTC
673838e Merge pull request #13678 from jakevdp:bcoo-dot-general-sampled PiperOrigin-RevId: 496931958 21 December 2022, 16:49:00 UTC
44e3e65 [jax2tf] Disable a failing shape polymorphism test PiperOrigin-RevId: 496927119 21 December 2022, 16:24:27 UTC
71f861a [sparse] improve tests for bcoo_dot_general_sampled 21 December 2022, 16:14:42 UTC
7095e07 Merge pull request #13704 from gnecula:gpu_select_and_gather PiperOrigin-RevId: 496919137 21 December 2022, 15:44:34 UTC
4075abc Merge pull request #13719 from jakevdp:sparse-doc PiperOrigin-RevId: 496902058 21 December 2022, 14:08:12 UTC
80998df Merge pull request #13675 from jakevdp:bcoo-dot-general-vjp PiperOrigin-RevId: 496902013 21 December 2022, 14:01:03 UTC
87aa3aa Merge pull request #13735 from jakevdp:private-linear-util PiperOrigin-RevId: 496896110 21 December 2022, 13:15:38 UTC
ce5320a Copybara import of the project: -- a74c74c25572eec23c28e08dbe67781a23be19fb by George Necula <gcnecula@gmail.com>: Fix scatter in CLIP mode with uint32 and uint64 indices. Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python or C constants. PiperOrigin-RevId: 496883024 21 December 2022, 11:46:27 UTC
76f92c4 Set default value for jax_platforms based on JAX_PLATFORMS env var. PiperOrigin-RevId: 496872386 21 December 2022, 10:38:14 UTC
8a409b9 Merge pull request #13746 from gnecula:scatter_clip PiperOrigin-RevId: 496854361 21 December 2022, 08:53:14 UTC
a74c74c Fix scatter in CLIP mode with uint32 and uint64 indices. Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python or C constants. 21 December 2022, 08:25:24 UTC
45a2116 [jax2tf] Disable known failures for shape polymorphism with native lowering PiperOrigin-RevId: 496849655 21 December 2022, 08:23:08 UTC
8c3f19a Merge pull request #13744 from gnecula:tf_autograph PiperOrigin-RevId: 496847423 21 December 2022, 08:08:52 UTC
5812957 [jax2tf] Add autograph=False everywhere we use tf.function 21 December 2022, 07:14:46 UTC
71230b6 Merge pull request #13728 from gnecula:tf_opaque PiperOrigin-RevId: 496837297 21 December 2022, 06:53:18 UTC
0a0ee78 [jax2tf] Fixes for handling of opaque dtypes for slice/update_slice/gather 21 December 2022, 06:48:26 UTC
5790960 Merge pull request #13743 from mattjj:tanh-tols PiperOrigin-RevId: 496824962 21 December 2022, 05:18:58 UTC
6ba0ef6 relax tanh test tols for upcoming xla change 21 December 2022, 05:06:09 UTC
39b14b1 Merge pull request #13693 from jakevdp:typing-ad-util PiperOrigin-RevId: 496783324 21 December 2022, 00:50:11 UTC
388ab7f Merge pull request #13734 from mattjj:print-saved-residuals-tweaks PiperOrigin-RevId: 496778515 21 December 2022, 00:24:27 UTC
4a6bbde Move jax.linear_util to jax._src.linear_util 20 December 2022, 22:49:27 UTC
d4fa1a4 Remove code that existed to support the now-gone classic HLO lowering path. PiperOrigin-RevId: 496741725 20 December 2022, 21:47:06 UTC
580fdb6 tweak print_saved_residuals 20 December 2022, 20:00:46 UTC
357d044 Change compilation_cache_test to compile MHLO instead of classic HLO. Support for classic HLO is being dropped from the .compile() API. In passing, also remove some obsolete version checks. The minimum xla_client API version is currently 109. PiperOrigin-RevId: 496708463 20 December 2022, 19:26:17 UTC
843bc43 [NumPy] Remove references to deprecated NumPy type aliases. This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str). NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy. PiperOrigin-RevId: 496691463 20 December 2022, 18:17:12 UTC
71b8968 Merge pull request #13727 from gnecula:tf_squeeze PiperOrigin-RevId: 496664614 20 December 2022, 16:13:00 UTC
a51b174 [jax2tf] Improved error checking for jnp.squeeze with shape polymorphism 20 December 2022, 15:58:32 UTC
ae3ac50 Merge pull request #13726 from gnecula:tf_shape_poly_test PiperOrigin-RevId: 496660786 20 December 2022, 15:55:10 UTC
ecbd719 Use stages API instead of calling .compile() in host_callback_test. Support for calling the xla_client compiler() API with classic HLO input is being removed, but we should just use the public JAX API anyway. PiperOrigin-RevId: 496655326 20 December 2022, 15:24:03 UTC
61505d1 [jax2tf] More improvements to shape_poly_test: * Small cleanup inside PolyHarness.run_test to use contextlib.ExitStack, and to run the TF eager mode conversion before the tf.function mode, because the latter is unfriendly to breakpoints and makes debugging harder. * Cleanup of the code to exclude some tests for native serialization. 20 December 2022, 13:32:04 UTC
9c4e2fa Make the device assignment of outfeed configurable PiperOrigin-RevId: 496574960 20 December 2022, 06:53:15 UTC
dfb82dd Merge pull request #13721 from jakevdp:bcoo-concatenate PiperOrigin-RevId: 496561475 20 December 2022, 05:17:27 UTC
049000d Merge pull request #13720 from jakevdp:doc-metadata PiperOrigin-RevId: 496561462 20 December 2022, 05:10:22 UTC
54871cc [sparse] add gradient test for bcoo_concatenate 19 December 2022, 23:26:15 UTC
aa34ea7 JAX github actions: Update Python versions in test matrix for better coverage PiperOrigin-RevId: 496501739 19 December 2022, 23:10:33 UTC
801213a DOC: specify jax sphinx extension as safe for parallel read/write 19 December 2022, 23:06:28 UTC
cff92a4 Merge pull request #13716 from ROCmSoftwarePlatform:rocm_disable_tests PiperOrigin-RevId: 496492880 19 December 2022, 22:31:53 UTC
cf4310b [sparse] update jax.experimental.sparse API listing 19 December 2022, 22:23:30 UTC
bc34af9 [sparse] Add bcsr dot_general PiperOrigin-RevId: 496489368 19 December 2022, 22:17:44 UTC
3391a5e [ROCm]: Disable some tests on ROCm platform 19 December 2022, 21:33:13 UTC
dbc3944 Remove more checks now that the minimum jaxlib version corresponds to xla_extension_version == 109. Also remove usage of `xc._version` and replace it with `xla_extension_version`. PiperOrigin-RevId: 496474494 19 December 2022, 21:15:07 UTC
4301a85 Merge pull request #13712 from google:nightly-311 PiperOrigin-RevId: 496454192 19 December 2022, 19:47:23 UTC
9b5c827 Merge pull request #13714 from jakevdp:sparse-test-fix PiperOrigin-RevId: 496444990 19 December 2022, 19:10:00 UTC
ad40b08 CI: use Python 3.11 for upstream-nightly action 19 December 2022, 18:47:09 UTC
9a3f5b6 Merge pull request #13713 from hawkinsp:minver PiperOrigin-RevId: 496438253 19 December 2022, 18:46:19 UTC
435f338 [sparse] fix typo in test_coo_fromdense 19 December 2022, 18:34:39 UTC
2c6c30d Bump the minimum jaxlib version to 0.4.1. Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39. 19 December 2022, 17:49:24 UTC
85f30c1 [sparse] implement more cases for vjp of bcoo_dot_general 19 December 2022, 17:16:06 UTC
63a9e71 [jax2tf] Update the select_and_gather_add lowering for dynamic shapes. This enables lowering of select_and_gather_add for GPU when we have dynamic shapes. The main change is to carry the abstract values and to use them for broadcast, using mlir.broadcast_in_dim, which can generate BroadcastInDimOp or DynamicBroadcastInDimOp. This requires some passing around of abstract values. 19 December 2022, 12:44:10 UTC
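Commit 843bc43 in this listing replaces the NumPy type aliases removed in NumPy 1.24 (np.bool, np.int, np.float, np.complex, np.object, np.str) with Python builtins. The snippet below is a minimal sketch of that substitution, assuming only NumPy is installed; the `REPLACEMENTS` mapping and the variable names are illustrative and are not taken from the JAX source.

```python
import numpy as np

# Illustrative mapping (hypothetical, not from the JAX repository): each alias
# removed in NumPy 1.24 was the Python builtin itself, so the builtin is a
# drop-in replacement wherever a dtype is expected.
REPLACEMENTS = {
    "np.bool": bool,
    "np.int": int,
    "np.float": float,
    "np.complex": complex,
    "np.object": object,
    "np.str": str,
}

for alias, builtin in REPLACEMENTS.items():
    # np.dtype() accepts the builtin directly and resolves it to NumPy's
    # default dtype for that kind (e.g. float -> float64).
    print(f"{alias:12} -> {builtin.__name__:8} dtype={np.dtype(builtin)}")

# Typical call sites after the migration:
mask = np.zeros(4, dtype=bool)    # previously dtype=np.bool
values = np.ones(4, dtype=float)  # previously dtype=np.float
```

Because the removed aliases were the same objects as the builtins, this rewrite should not change the resulting dtypes.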