5a39341 | Sholto Douglas | 22 December 2022, 07:40:29 UTC | Update streaming to work with an HTML window PiperOrigin-RevId: 497091885 | 28 December 2022, 23:44:34 UTC |
38521b8 | jax authors | 28 December 2022, 18:08:28 UTC | Merge pull request #13717 from jakevdp:sparse-grad PiperOrigin-RevId: 498204646 | 28 December 2022, 18:08:28 UTC |
27d7f77 | jax authors | 28 December 2022, 18:00:04 UTC | Merge pull request #13813 from jakevdp:fix-yash-thing PiperOrigin-RevId: 498203100 | 28 December 2022, 18:00:04 UTC |
2b82e7a | Jake VanderPlas | 28 December 2022, 17:55:11 UTC | [sparse] improve implementation of sparse.grad & sparse.value_and_grad | 28 December 2022, 17:55:11 UTC |
2cf0791 | Yuanzhong Xu | 28 December 2022, 17:51:56 UTC | Allow more cases in _TRANSPOSE_TRICKS by ignoring leading 1s in the mesh shape. PiperOrigin-RevId: 498201820 | 28 December 2022, 17:52:24 UTC |
ac88f30 | Jake VanderPlas | 28 December 2022, 17:38:56 UTC | fix signature of shard arg handler | 28 December 2022, 17:38:56 UTC |
a71ab80 | jax authors | 28 December 2022, 17:38:04 UTC | Merge pull request #13804 from jakevdp:masked-arr-error PiperOrigin-RevId: 498199498 | 28 December 2022, 17:38:04 UTC |
a8f796c | jax authors | 28 December 2022, 17:16:22 UTC | Merge pull request #13776 from jakevdp:sparse-batching-test PiperOrigin-RevId: 498195672 | 28 December 2022, 17:16:22 UTC |
bc1b22c | jax authors | 28 December 2022, 17:09:15 UTC | Merge pull request #13803 from gnecula:tf_dim_64 PiperOrigin-RevId: 498194615 | 28 December 2022, 17:09:15 UTC |
68ec9b5 | jax authors | 28 December 2022, 16:55:16 UTC | Merge pull request #13812 from gnecula:tf_while PiperOrigin-RevId: 498192217 | 28 December 2022, 16:55:16 UTC |
71ce600 | George Necula | 20 December 2022, 09:01:51 UTC | [jax2tf] Ensure that dim_as_value returns int64 in x64 mode and int32 otherwise Changes all the computations with dimensions to work in int64 if JAX_ENABLE_X64 and int32 otherwise. | 28 December 2022, 16:18:33 UTC |
9593741 | George Necula | 28 December 2022, 15:53:13 UTC | [jax2tf] Fix lowering for tf.while_loop Sometimes TF infers more specific shapes for the init_carry, and this has led to errors: "enters the loop with shape (1,), but has shape (None,) after one iteration" Unfortunately, I cannot construct a small test. | 28 December 2022, 15:53:13 UTC |
e85fd5c | jax authors | 28 December 2022, 15:19:39 UTC | Merge pull request #13811 from gnecula:dim_ops PiperOrigin-RevId: 498177537 | 28 December 2022, 15:19:39 UTC |
1e09602 | George Necula | 28 December 2022, 10:07:46 UTC | [jax2tf] Improve handling of dimension polynomials used with JAX arrays. See new description in README.md. | 28 December 2022, 10:37:42 UTC |
ec02966 | jax authors | 28 December 2022, 00:27:42 UTC | Merge pull request #13807 from jakevdp:fix-crossref PiperOrigin-RevId: 498052807 | 28 December 2022, 00:27:42 UTC |
fe4c958 | Jake VanderPlas | 27 December 2022, 23:59:32 UTC | doc: fix host callback module crossref | 27 December 2022, 23:59:32 UTC |
5367693 | Jake VanderPlas | 27 December 2022, 23:42:49 UTC | Error on numpy masked array inputs. | 27 December 2022, 23:42:49 UTC |
cca83ae | Jake VanderPlas | 23 December 2022, 00:15:43 UTC | [sparse] add _CheckBatchingSparse utility | 27 December 2022, 21:55:50 UTC |
8d2cd5c | jax authors | 27 December 2022, 21:15:19 UTC | Merge pull request #13805 from jakevdp:fix-rtd PiperOrigin-RevId: 498022216 | 27 December 2022, 21:15:19 UTC |
dc862a9 | Jake VanderPlas | 27 December 2022, 20:55:48 UTC | psum: fix docstring formatting | 27 December 2022, 20:55:48 UTC |
60dbc3b | jax authors | 27 December 2022, 18:43:36 UTC | Merge pull request #13775 from jakevdp:sparse-vmappable-metadata PiperOrigin-RevId: 497995723 | 27 December 2022, 18:43:36 UTC |
b043e76 | jax authors | 27 December 2022, 18:36:44 UTC | Merge pull request #13779 from eltociear:patch-3 PiperOrigin-RevId: 497995427 | 27 December 2022, 18:36:44 UTC |
1fc9197 | Yash Katariya | 27 December 2022, 18:16:08 UTC | Simplify Array's shard_arg_handler by merging pmap and pjit/xmap paths PiperOrigin-RevId: 497991966 | 27 December 2022, 18:16:44 UTC |
2cdce1d | George Necula | 27 December 2022, 17:45:17 UTC | [jax2tf] Simplify _eval_shape for the constant dimensions Before this change _eval_shape would return a tf.Const for the constant dimensions in the shape. Now it returns integer constants. This change reduces the size of the graph for the common cases when the dimensions are constant. The change should not be strictly necessary, but on some TPU flavors the TF2XLA bridge returns the error "XLA can't use dynamic begin values for slice" even when the dimension is actually a constant. PiperOrigin-RevId: 497986554 | 27 December 2022, 17:45:54 UTC |
909354c | jax authors | 27 December 2022, 17:27:30 UTC | Merge pull request #13802 from gnecula:clean_tests PiperOrigin-RevId: 497983945 | 27 December 2022, 17:27:30 UTC |
a1480c4 | Eugene Burmako | 27 December 2022, 16:52:39 UTC | Migrate JAX from producing MHLO to producing StableHLO As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically: 1) MLIR lowerings now produce StableHLO ops instead of MHLO ops. 2) Fallback lowerings now produce StableHLO ops as well. 3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs). From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO): a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing. b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath. c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo". d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues. PiperOrigin-RevId: 497978733 | 27 December 2022, 16:53:20 UTC |
0752655 | George Necula | 27 December 2022, 13:28:04 UTC | [jax2tf] Move pjit test into the right TestCase class This test uses tfxla.call_module directly, belongs to XlaCallModuleTest. | 27 December 2022, 13:28:04 UTC |
7d452ad | George Necula | 27 December 2022, 12:47:48 UTC | Add support for dynamic shapes to GPU threefry2x32 custom call. In presence of dynamic shapes the ThreeFry2x32Descriptor will contain the value n=-1, and the actual desired output length will be passed as an additional operand. If the shape is static then the length will be passed as part of the descriptor. PiperOrigin-RevId: 497945778 | 27 December 2022, 12:48:26 UTC |
0c8a4fb | George Necula | 26 December 2022, 04:26:26 UTC | [jax2tf] Simplify irrelevant part of call_tf_test.py PiperOrigin-RevId: 497727816 | 26 December 2022, 04:26:59 UTC |
398aaaa | jax authors | 24 December 2022, 06:46:21 UTC | Add support for Ellipsis as an index for stateful operations. PiperOrigin-RevId: 497466879 | 24 December 2022, 06:46:50 UTC |
eb875cd | jax authors | 24 December 2022, 00:05:18 UTC | Added a pattern-match optimisation for inplace-select. PiperOrigin-RevId: 497425937 | 24 December 2022, 00:05:56 UTC |
0c51ca2 | Qiao Zhang | 23 December 2022, 20:05:52 UTC | Minor docstring fix for custom_gradient. PiperOrigin-RevId: 497402055 | 23 December 2022, 20:12:48 UTC |
9fda20f | Qiao Zhang | 23 December 2022, 20:04:19 UTC | Add examples to lax.psum to illustrate axis_index_groups better. PiperOrigin-RevId: 497401892 | 23 December 2022, 20:04:59 UTC |
55d0a94 | Ikko Ashimine | 23 December 2022, 18:59:19 UTC | Fix typo in loops.py miscellanous -> miscellaneous | 23 December 2022, 18:59:19 UTC |
cee2779 | Jake VanderPlas | 22 December 2022, 22:46:40 UTC | [sparse] propagate metadata through vmappable | 22 December 2022, 22:46:40 UTC |
2f3d75a | Yash Katariya | 22 December 2022, 21:34:49 UTC | Remove dependency of maps from pjit to avoid circular imports when importing pjit in api.py. PiperOrigin-RevId: 497230514 | 22 December 2022, 21:35:23 UTC |
d2b87be | jax authors | 22 December 2022, 19:47:09 UTC | Merge pull request #13333 from harryjulian:bernoulli PiperOrigin-RevId: 497210080 | 22 December 2022, 19:47:09 UTC |
c0d4ae0 | harryjulian | 16 December 2022, 10:12:40 UTC | Added scipy.stats.bernoulli cdf and ppf. | 22 December 2022, 18:12:25 UTC |
d8bf334 | jax authors | 22 December 2022, 18:00:19 UTC | Merge pull request #13755 from jakevdp:sparse-dense-vmap PiperOrigin-RevId: 497187394 | 22 December 2022, 18:00:19 UTC |
e661c3f | Jake VanderPlas | 22 December 2022, 17:08:30 UTC | [sparse] Handle general batch dimensions in bcoo todense/fromdense | 22 December 2022, 17:08:30 UTC |
57840dd | Yash Katariya | 22 December 2022, 16:40:36 UTC | Move functions into `api_util.py` and `dispatch.py` to remove circular import error when pjit is imported in `api.py` for merging the `jit` and `pjit` frontend API. PiperOrigin-RevId: 497172760 | 22 December 2022, 16:42:05 UTC |
4d4eba4 | George Necula | 22 December 2022, 16:34:01 UTC | Fix scatter in CLIP mode with uint32 and uint64 indices Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python or C constants. This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to failures when jax_array=False. Since that use case will go away soon we just enable the fix for when jax_array=True. PiperOrigin-RevId: 497171518 | 22 December 2022, 16:34:47 UTC |
2b716f2 | George Necula | 22 December 2022, 05:59:15 UTC | Fix scatter in CLIP mode with uint32 and uint64 indices Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python or C constants. This is a second attempt, after https://github.com/google/jax/pull/13746 was rolled back due to failures when jax_array=False. Since that use case will go away soon we just enable the fix for when jax_array=True. PiperOrigin-RevId: 497079129 | 22 December 2022, 05:59:51 UTC |
40f5279 | jax authors | 22 December 2022, 03:44:55 UTC | Merge pull request #13761 from wookayin:fix-api-public PiperOrigin-RevId: 497060752 | 22 December 2022, 03:44:55 UTC |
c89e662 | Jake VanderPlas | 22 December 2022, 02:51:23 UTC | jax.linear_util: remove unnecessary exported names Follow-up to https://github.com/google/jax/pull/13735 PiperOrigin-RevId: 497052842 | 22 December 2022, 02:52:01 UTC |
479f336 | Jongwook Choi | 22 December 2022, 01:12:08 UTC | Make `jax.ensure_compile_time_eval` correctly exposed as a public API This function was added as a public API (#7987) but py.type static checkers do not recognize it as a public API because of the alias name. `jax.eval_context` exists only for backward compatibility, so the correct import would be to import `ensure_compile_time_eval` directly from `jax._src.core`. | 22 December 2022, 01:12:08 UTC |
704b9fc | Peter Hawkins | 21 December 2022, 21:29:52 UTC | Don't tuple high-arity computations on GPU. I can think of no reason we do this, and it seems strictly preferable to not have to form an on-device tuple since it requires a H2D copy. As best I can reconstruct it, the sequence of events that got us to the current state was: a) we had backend-independent logic to tuple computations with > 100 arguments. This logic was motivated only by TPU. b) we disabled tupling on CPU, because the XLA:CPU folks noticed it wasn't needed. c) we increased the argument limit on TPU to 2000, because the TPU compiler folks noticed it was too low. This left GPU with the old TPU behavior for no good reason. PiperOrigin-RevId: 496993803 | 21 December 2022, 21:30:34 UTC |
d5830e0 | jax authors | 21 December 2022, 20:58:28 UTC | Merge pull request #13754 from mattjj:dot-general-pp-rule PiperOrigin-RevId: 496987303 | 21 December 2022, 20:58:28 UTC |
c2d9b5c | Matthew Johnson | 21 December 2022, 18:16:18 UTC | tweak dot_general pretty-printing rule to suppress default params | 21 December 2022, 18:29:01 UTC |
e89b60e | Tianjian Lu | 21 December 2022, 17:54:38 UTC | [sparse] Propagate SparseInfo to BCSR todense() and tree_(un)flatten(). PiperOrigin-RevId: 496945167 | 21 December 2022, 17:55:21 UTC |
4f75ad6 | Tom Hennigan | 21 December 2022, 17:05:36 UTC | Revert #13747 since config values already have paired env vars. PiperOrigin-RevId: 496935378 | 21 December 2022, 17:06:08 UTC |
673838e | jax authors | 21 December 2022, 16:49:00 UTC | Merge pull request #13678 from jakevdp:bcoo-dot-general-sampled PiperOrigin-RevId: 496931958 | 21 December 2022, 16:49:00 UTC |
44e3e65 | George Necula | 21 December 2022, 16:23:46 UTC | [jax2tf] Disable a failing shape polymorphism test PiperOrigin-RevId: 496927119 | 21 December 2022, 16:24:27 UTC |
71f861a | Jake VanderPlas | 21 December 2022, 16:14:42 UTC | [sparse] improve tests for bcoo_dot_general_sampled | 21 December 2022, 16:14:42 UTC |
7095e07 | jax authors | 21 December 2022, 15:44:34 UTC | Merge pull request #13704 from gnecula:gpu_select_and_gather PiperOrigin-RevId: 496919137 | 21 December 2022, 15:44:34 UTC |
4075abc | jax authors | 21 December 2022, 14:08:12 UTC | Merge pull request #13719 from jakevdp:sparse-doc PiperOrigin-RevId: 496902058 | 21 December 2022, 14:08:12 UTC |
80998df | jax authors | 21 December 2022, 14:01:03 UTC | Merge pull request #13675 from jakevdp:bcoo-dot-general-vjp PiperOrigin-RevId: 496902013 | 21 December 2022, 14:01:03 UTC |
87aa3aa | jax authors | 21 December 2022, 13:15:38 UTC | Merge pull request #13735 from jakevdp:private-linear-util PiperOrigin-RevId: 496896110 | 21 December 2022, 13:15:38 UTC |
ce5320a | George Necula | 21 December 2022, 11:46:01 UTC | Copybara import of the project: -- a74c74c25572eec23c28e08dbe67781a23be19fb by George Necula <gcnecula@gmail.com>: Fix scatter in CLIP mode with uint32 and uint64 indices Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python or C constants. PiperOrigin-RevId: 496883024 | 21 December 2022, 11:46:27 UTC |
76f92c4 | Tom Hennigan | 21 December 2022, 10:37:33 UTC | Set default value for jax_platforms based on JAX_PLATFORMS env var. PiperOrigin-RevId: 496872386 | 21 December 2022, 10:38:14 UTC |
8a409b9 | jax authors | 21 December 2022, 08:53:14 UTC | Merge pull request #13746 from gnecula:scatter_clip PiperOrigin-RevId: 496854361 | 21 December 2022, 08:53:14 UTC |
a74c74c | George Necula | 21 December 2022, 08:22:12 UTC | Fix scatter in CLIP mode with uint32 and uint64 indices Clipping uses np.iinfo(indices.dtype).max and those values are too large to be converted to Python or C constants. | 21 December 2022, 08:25:24 UTC |
45a2116 | George Necula | 21 December 2022, 08:22:35 UTC | [jax2tf] Disable known failures for shape polymorphism with native lowering PiperOrigin-RevId: 496849655 | 21 December 2022, 08:23:08 UTC |
8c3f19a | jax authors | 21 December 2022, 08:08:52 UTC | Merge pull request #13744 from gnecula:tf_autograph PiperOrigin-RevId: 496847423 | 21 December 2022, 08:08:52 UTC |
5812957 | George Necula | 21 December 2022, 07:14:46 UTC | [jax2tf] Add autograph=False everywhere we use tf.function | 21 December 2022, 07:14:46 UTC |
71230b6 | jax authors | 21 December 2022, 06:53:18 UTC | Merge pull request #13728 from gnecula:tf_opaque PiperOrigin-RevId: 496837297 | 21 December 2022, 06:53:18 UTC |
0a0ee78 | George Necula | 20 December 2022, 14:20:14 UTC | [jax2tf] Fixes for handling of opaque dtypes for slice/update_slice/gather | 21 December 2022, 06:48:26 UTC |
5790960 | jax authors | 21 December 2022, 05:18:58 UTC | Merge pull request #13743 from mattjj:tanh-tols PiperOrigin-RevId: 496824962 | 21 December 2022, 05:18:58 UTC |
6ba0ef6 | Matthew Johnson | 21 December 2022, 05:06:09 UTC | relax tanh test tols for upcoming xla change | 21 December 2022, 05:06:09 UTC |
39b14b1 | jax authors | 21 December 2022, 00:50:11 UTC | Merge pull request #13693 from jakevdp:typing-ad-util PiperOrigin-RevId: 496783324 | 21 December 2022, 00:50:11 UTC |
388ab7f | jax authors | 21 December 2022, 00:24:27 UTC | Merge pull request #13734 from mattjj:print-saved-residuals-tweaks PiperOrigin-RevId: 496778515 | 21 December 2022, 00:24:27 UTC |
4a6bbde | Jake VanderPlas | 20 December 2022, 22:49:27 UTC | Move jax.linear_util to jax._src.linear_util | 20 December 2022, 22:49:27 UTC |
d4fa1a4 | Peter Hawkins | 20 December 2022, 21:46:29 UTC | Remove code that existed to support the now-gone classic HLO lowering path. PiperOrigin-RevId: 496741725 | 20 December 2022, 21:47:06 UTC |
580fdb6 | Matthew Johnson | 20 December 2022, 20:00:46 UTC | tweak print_saved_residuals | 20 December 2022, 20:00:46 UTC |
357d044 | Peter Hawkins | 20 December 2022, 19:25:42 UTC | Change compilation_cache_test to compile MHLO instead of classic HLO. Support for classic HLO is being dropped from the .compile() API. In passing, also remove some obsolete version checks. The minimum xla_client API version is currently 109. PiperOrigin-RevId: 496708463 | 20 December 2022, 19:26:17 UTC |
843bc43 | Peter Hawkins | 20 December 2022, 18:16:31 UTC | [NumPy] Remove references to deprecated NumPy type aliases. This change replaces references to a number of deprecated NumPy type aliases (np.bool, np.int, np.float, np.complex, np.object, np.str) with their recommended replacement (bool, int, float, complex, object, str). NumPy 1.24 drops the deprecated aliases, so we must remove uses before updating NumPy. PiperOrigin-RevId: 496691463 | 20 December 2022, 18:17:12 UTC |
71b8968 | jax authors | 20 December 2022, 16:13:00 UTC | Merge pull request #13727 from gnecula:tf_squeeze PiperOrigin-RevId: 496664614 | 20 December 2022, 16:13:00 UTC |
a51b174 | George Necula | 20 December 2022, 13:42:09 UTC | [jax2tf] Improved error checking for jnp.squeeze with shape polymorphism | 20 December 2022, 15:58:32 UTC |
ae3ac50 | jax authors | 20 December 2022, 15:55:10 UTC | Merge pull request #13726 from gnecula:tf_shape_poly_test PiperOrigin-RevId: 496660786 | 20 December 2022, 15:55:10 UTC |
ecbd719 | Peter Hawkins | 20 December 2022, 15:23:29 UTC | Use stages API instead of calling .compile() in host_callback_test. Support for calling the xla_client compiler() API with classic HLO input is being removed, but we should just use the public JAX API anyway. PiperOrigin-RevId: 496655326 | 20 December 2022, 15:24:03 UTC |
61505d1 | George Necula | 20 December 2022, 13:29:51 UTC | [jax2tf] More improvements to shape_poly_test * Small cleanup inside PolyHarness.run_test to use contextlib.ExitStack, and to run the TF eager mode conversion before the tf.function mode, because the latter is unfriendly to breakpoints and makes debugging harder. * Cleanup of the code to exclude some tests for native serialization. | 20 December 2022, 13:32:04 UTC |
9c4e2fa | Chang Lan | 20 December 2022, 06:52:46 UTC | Make the device assignment of outfeed configurable PiperOrigin-RevId: 496574960 | 20 December 2022, 06:53:15 UTC |
dfb82dd | jax authors | 20 December 2022, 05:17:27 UTC | Merge pull request #13721 from jakevdp:bcoo-concatenate PiperOrigin-RevId: 496561475 | 20 December 2022, 05:17:27 UTC |
049000d | jax authors | 20 December 2022, 05:10:22 UTC | Merge pull request #13720 from jakevdp:doc-metadata PiperOrigin-RevId: 496561462 | 20 December 2022, 05:10:22 UTC |
54871cc | Jake VanderPlas | 19 December 2022, 23:26:15 UTC | [sparse] add gradient test for bcoo_concatenate | 19 December 2022, 23:26:15 UTC |
aa34ea7 | Jake VanderPlas | 19 December 2022, 23:09:54 UTC | JAX github actions: Update Python versions in test matrix for better coverage PiperOrigin-RevId: 496501739 | 19 December 2022, 23:10:33 UTC |
801213a | Jake VanderPlas | 19 December 2022, 23:06:28 UTC | DOC: specify jax sphinx extension as safe for parallel read/write | 19 December 2022, 23:06:28 UTC |
cff92a4 | jax authors | 19 December 2022, 22:31:53 UTC | Merge pull request #13716 from ROCmSoftwarePlatform:rocm_disable_tests PiperOrigin-RevId: 496492880 | 19 December 2022, 22:31:53 UTC |
cf4310b | Jake VanderPlas | 19 December 2022, 22:23:30 UTC | [sparse] update jax.experimental.sparse API listing | 19 December 2022, 22:23:30 UTC |
bc34af9 | Tianjian Lu | 19 December 2022, 22:16:50 UTC | [sparse] Add bcsr dot_general PiperOrigin-RevId: 496489368 | 19 December 2022, 22:17:44 UTC |
3391a5e | Rahul Batra | 19 December 2022, 21:05:17 UTC | [ROCm]: Disable some tests on ROCm platform | 19 December 2022, 21:33:13 UTC |
dbc3944 | Yash Katariya | 19 December 2022, 21:13:15 UTC | Remove more checks now that the minimum jaxlib version corresponds to xla_extension_version == 109. Also remove usage of `xc._version` and replace it with `xla_extension_version`. PiperOrigin-RevId: 496474494 | 19 December 2022, 21:15:07 UTC |
4301a85 | jax authors | 19 December 2022, 19:47:23 UTC | Merge pull request #13712 from google:nightly-311 PiperOrigin-RevId: 496454192 | 19 December 2022, 19:47:23 UTC |
9b5c827 | jax authors | 19 December 2022, 19:10:00 UTC | Merge pull request #13714 from jakevdp:sparse-test-fix PiperOrigin-RevId: 496444990 | 19 December 2022, 19:10:00 UTC |
ad40b08 | Jake VanderPlas | 19 December 2022, 17:03:54 UTC | CI: use Python 3.11 for upstream-nightly action | 19 December 2022, 18:47:09 UTC |
9a3f5b6 | jax authors | 19 December 2022, 18:46:19 UTC | Merge pull request #13713 from hawkinsp:minver PiperOrigin-RevId: 496438253 | 19 December 2022, 18:46:19 UTC |
435f338 | Jake VanderPlas | 19 December 2022, 18:34:39 UTC | [sparse] fix typo in test_coo_fromdense | 19 December 2022, 18:34:39 UTC |
2c6c30d | Peter Hawkins | 19 December 2022, 17:38:24 UTC | Bump the minimum jaxlib version to 0.4.1. Jaxlib 0.4.1 has XLA client version 109 and MLIR API version 39. | 19 December 2022, 17:49:24 UTC |
85f30c1 | Jake VanderPlas | 19 December 2022, 17:16:06 UTC | [sparse] implement more cases for vjp of bcoo_dot_general | 19 December 2022, 17:16:06 UTC |
63a9e71 | George Necula | 17 December 2022, 20:43:58 UTC | [jax2tf] Update the select_and_gather_add lowering for dynamic shapes This enables lowering of select_and_gather_add for GPU when we have dynamic shapes. The main change is to carry the abstract values and to use them for broadcast, using mlir.broadcast_in_dim, which can generate BroadcastInDimOp or DynamicBroadcastInDimOp. This requires some passing around of abstract values. | 19 December 2022, 12:44:10 UTC |