https://github.com/google/jax

Revision Author Date Message Commit Date
48ad9a6 Start jax and jaxlib 0.4.11 release PiperOrigin-RevId: 536860076 31 May 2023, 23:48:52 UTC
525ba49 Merge pull request #16204 from skye:importlib_metadata_version PiperOrigin-RevId: 536823622 31 May 2023, 21:27:53 UTC
9682370 Add importlib_metadata to project requirements. This is necessary to ensure we can correctly detect PJRT plugins via entry_points without compatibility errors. Prior to this change, there was conditional logic to handle the case where importlib_metadata wasn't installed at all. However, it didn't handle the case where importlib_metadata is installed but not at a high enough version to support Python 3.10 compat. This change gets rid of that logic and just ensures the right version is installed. All of this logic can be removed if/when jax requires Python version >= 3.10. This also removes an unnecessary `requests` dep for the [tpu] install. 31 May 2023, 21:03:12 UTC
b35c20c Use xla_extension_version and remove some dead version check in xla_bridge_test.py. Min jaxlib requires xla_extension_version >= 144. PiperOrigin-RevId: 536810415 31 May 2023, 20:50:07 UTC
727c121 Merge pull request #16188 from nouiz:ci_jestimator PiperOrigin-RevId: 536810121 31 May 2023, 20:41:29 UTC
c587dac Merge pull request #16203 from skye:tpu_py_version2 PiperOrigin-RevId: 536776189 31 May 2023, 18:39:06 UTC
131d28b Use default Python version on Cloud TPU CI 31 May 2023, 18:04:39 UTC
6d6ba70 Disable the RunnTest.test_lstm1 test since it is fixed for cudnn >= 8.8 PiperOrigin-RevId: 536693061 31 May 2023, 13:21:01 UTC
758d68d Restore `call_tf_concrete_function_list` to previous state
In the following case of nested call:
```
inputs = np.array(range(6), dtype=np.float32).reshape(3, 2)

@jax.jit
def forward(x):
  return x + 1

# JAX -> TF
tf_fn = jax2tf.convert(forward, native_serialization=True)
call_tf_fn = jax2tf.call_tf(tf_fn)
tf_fn_too = jax2tf.convert(call_tf_fn, native_serialization=True)
tf_fn_too(inputs)  # FAIL
```
Without the fix, it fails with the following error:
```
jax/experimental/jax2tf/jax2tf.py", line 499, in _restore_context
    _thread_local_state.call_tf_concrete_function_list.clear()
AttributeError: 'NoneType' object has no attribute 'clear'
```
because we call `_restore_context` twice when executing `jax2tf.convert`ed functions, the first time we call `_restore_context`, `call_tf_concrete_function_list` is set to `None` instead of restoring it to the previous state, so the second time we call `_restore_context`, `call_tf_concrete_function_list.clear()` throws the above error since `call_tf_concrete_function_list` is `None`. PiperOrigin-RevId: 536650377 31 May 2023, 09:23:14 UTC
f884b4d Fix the `test_sharding_on_output_with_vmap` failure in Pathways which was getting a cache miss in pjit_call_impl. There was an inconsistency between how the global cache was used at the top level and in pjit_call_impl so standardize it via a helper function. In the test, check for re-compilation which is what that test was doing before cl/535630905 PiperOrigin-RevId: 536575987 31 May 2023, 02:51:48 UTC
3ad756f Merge pull request #16176 from gnecula:poly_constraints PiperOrigin-RevId: 536571493 31 May 2023, 02:16:52 UTC
9ad8c3b [shape_poly] Add static constraint checking to the computation of dim vars Previously we had one function `shape_poly.unify_avals_with_args` that was solving the dimension variables and was also used for generating the code to compute them. Now we separate the solving part, which is now using just symbolic expressions (`shape_poly.solve_dim_vars`), from the code generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`). We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or `dimexpr1 >= dimexpr2`, under which the solution for the dimension variables is valid. For now we implement the static checking of the shape constraints, e.g., when the dimension expressions are constant or TF EagerTensor. We do not yet have compile-time checking of the constraints. This matches the previous behavior. However, the code is now ready for implementing compile-time checking of the constraints that cannot be checked statically. 31 May 2023, 01:48:44 UTC
cdced24 WAR the bug in t5x dependency. It currently needs the dev version of jestimator. 30 May 2023, 19:55:21 UTC
acfeb9b Merge pull request #16169 from ZacCranko:data_parallel_example PiperOrigin-RevId: 536260245 30 May 2023, 01:39:44 UTC
a192b5e improve data parallel example fix example fix example fix example fix example fix example fix example 30 May 2023, 01:25:17 UTC
ae9160a Merge pull request #16159 from jakevdp:deprecations PiperOrigin-RevId: 536003451 28 May 2023, 14:27:45 UTC
7a87995 Deprecate jax.interpreters.xla.Buffer, device_put, xla_call_p 28 May 2023, 14:15:34 UTC
1279418 Link in CUDA runtime for triton in jaxlib PiperOrigin-RevId: 535708416 26 May 2023, 21:02:16 UTC
cb3b7ec [PJRT PLUGIN] Add num_processes to distributed.global_state. The number of processes is needed for multi-process GPU when plugin is used. PiperOrigin-RevId: 535696950 26 May 2023, 20:14:40 UTC
d62bc0f Fix the jax2tf failure in mypy: https://github.com/google/jax/actions/runs/5094063162/jobs/9157426652?pr=16155 PiperOrigin-RevId: 535692853 26 May 2023, 19:57:37 UTC
fe3fed3 Remove axis_resources from with_sharding_constraint since it has been 3 months since the deprecation as per the API deprecation policy. PiperOrigin-RevId: 535687618 26 May 2023, 19:35:16 UTC
25a9a97 Merge pull request #16151 from hawkinsp:cudnn PiperOrigin-RevId: 535642800 26 May 2023, 16:48:40 UTC
4f07471 Make pjit_call_impl go via C++ dispatch. This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top level pjit/jit function but rather go via pjit_p.bind directly which calls into _pjit_call_impl. PiperOrigin-RevId: 535630905 26 May 2023, 15:57:30 UTC
9508f3a Merge pull request #16148 from gnecula:export_poly PiperOrigin-RevId: 535628086 26 May 2023, 15:44:41 UTC
46a258b [shape_poly] Add partial support for call_exported with polymorphic shapes Until now the jax_export.call_exported did not allow calling functions that were exported with polymorphic shapes. We now add that support, including resolving the dimension variables of the called function in terms of the shapes at the call site (which themselves may include dimension variables), and then computing the output shape of the called function. The support is partial in that we can export a JAX function that calls an exported polymorphic function, but we cannot invoke it. This is because we do not yet have access to the shape refinement machinery that XlaCallModule uses. For now, we use XlaCallModule for invoking exported that includes shape polymorphism. 26 May 2023, 15:27:44 UTC
2858df2 Start the process of removing OpSharding from JAX and replacing it with HloSharding. This will allow for future optimizations of HloSharding to work seamlessly with JAX. Currently, no function producing HloSharding is being used. I will do that in follow up CLs. PiperOrigin-RevId: 535622806 26 May 2023, 15:19:14 UTC
69cf67f Bump the minimum CUDNN version for CUDA 12 wheels to 8.9. 26 May 2023, 14:04:34 UTC
ea37043 Switch to `STATUS_RETURNING` callback API. PiperOrigin-RevId: 535568707 26 May 2023, 10:15:44 UTC
7833528 Merge pull request #16143 from jakevdp:fix-shape-poly PiperOrigin-RevId: 535427698 25 May 2023, 23:31:09 UTC
ed10293 Add new `called_index` to custom_call `tf.backend_config` DictAttr. Here, `called_index` indicates the tf concrete function index in the `function_list` of the parent XLACallModule. PiperOrigin-RevId: 535417558 25 May 2023, 22:58:50 UTC
bbae2ed jax2tf: correctly handle opaque dtype in jax2tf pure() In TF tracers, "val" is the physical TF representation, while "aval" is the abstract value used during tracing, which is where additional JAX-specific information such as opaque dtype, weak_type, etc. should be included. Before opaque dtypes, val and aval always had the same shape and dtype. With opaque dtypes, this is no longer the case, which revealed this bug in the logic of jax2tf pure(). PiperOrigin-RevId: 535408671 25 May 2023, 22:32:47 UTC
8534f0b Merge pull request #16142 from froystig:outline-random-functions PiperOrigin-RevId: 535406588 25 May 2023, 22:25:18 UTC
b853ce9 jax2tf: make shape_poly_test pass with custom PRNG 25 May 2023, 22:16:46 UTC
3238b62 outline jitted `jax.random` functions We may want to continue to inline these in Jaxpr somehow, but it's useful to outline them in HLO for visualization and debugging. 25 May 2023, 22:01:04 UTC
14089fb Merge pull request #16138 from hawkinsp:cudnn PiperOrigin-RevId: 535367462 25 May 2023, 20:40:34 UTC
2b77902 Bump minimum CUDNN version in pip installation to 8.8. There are known wrong output bugs observed in JAX for earlier versions, in particular related to RNNs. 25 May 2023, 18:46:39 UTC
16368bc [XLA:Python] Clean up handling of unsupported types in buffer protocol. Rather than enumerating a list of types that don't work in the buffer protocol, call the format descriptor function and fail if it fails. Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants. PiperOrigin-RevId: 535315975 25 May 2023, 18:10:19 UTC
2155b91 Switch to using JAX status macros in jax-triton kernel call lib. PiperOrigin-RevId: 535300412 25 May 2023, 17:26:06 UTC
bc547aa Adds a note that pjit is equivalent to jit. PiperOrigin-RevId: 535296532 25 May 2023, 17:17:25 UTC
32026ad Disable random_test_with_custom_prng on CPU under msan. This test flakily times out in CI. PiperOrigin-RevId: 535293997 25 May 2023, 17:10:01 UTC
24928a5 Merge pull request #16117 from jakevdp:matrix-transpose PiperOrigin-RevId: 535292507 25 May 2023, 17:02:26 UTC
222b951 Use new matrix_transpose in linalg code 25 May 2023, 16:32:14 UTC
333ff4a Add jnp.matrix_transpose() and jax.Array.mT This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX. 25 May 2023, 16:02:05 UTC
e464dc8 Reland: [XLA:Python] Add buffer protocol support to jax.Array We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array. The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do. Fixes https://github.com/google/jax/issues/14713 PiperOrigin-RevId: 535248553 25 May 2023, 14:20:42 UTC
6b13d4e Add branch prediction to JAX status macros. PiperOrigin-RevId: 535233546 25 May 2023, 13:23:23 UTC
e25052c Use stablehlo.get_minimum_version in jax_export.py The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has been recently replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API. PiperOrigin-RevId: 535091927 25 May 2023, 04:15:16 UTC
8e397f7 [XLA:Client] Change replicate_last_dim to subgroup_types in HloSharding.iota_tile to cover arbitrary subgroups, adding necessary accessors. PiperOrigin-RevId: 535079635 25 May 2023, 03:26:28 UTC
5e82d6b Fix jax2tf_test regression failure. PiperOrigin-RevId: 535002015 24 May 2023, 22:27:57 UTC
4fb834b Use jaxlib version guard for triton instead of xla_extension_version PiperOrigin-RevId: 534974834 24 May 2023, 21:06:45 UTC
6a54ebd Fix the lu.clear_all_cache function by adding the memoized_fun to the global weakref set rather than the function local `fun_caches` weakrefDict. PiperOrigin-RevId: 534971855 24 May 2023, 20:58:51 UTC
557ca52 Add `cuda_pip` extra for jaxlib PiperOrigin-RevId: 534957585 24 May 2023, 20:19:27 UTC
bf8ed6a Move triton_kernel_call_lib to jaxlib PiperOrigin-RevId: 534934592 24 May 2023, 19:11:21 UTC
7de1677 Add (optional) ordered effects for `jax2tf.call_tf`
This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens. With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.
For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.constant dense<> : tensor<0xi1>
    %1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
    return %1#1 : tensor<f32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
     cond {
      %4 = stablehlo.constant dense<4> : tensor<i32>
      %5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
      %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
      %6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
      %7 = stablehlo.constant dense<1> : tensor<i32>
      %8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
      stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
    }
    %3 = stablehlo.constant dense<> : tensor<0xi1>
    return %3, %2#2 : tensor<0xi1>, tensor<f32>
  }
}
```
PiperOrigin-RevId: 534926215 24 May 2023, 18:48:35 UTC
399e4ee Copybara import of the project: -- 8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>: Remove several deprecated APIs COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244 PiperOrigin-RevId: 534897394 24 May 2023, 17:35:37 UTC
1831b3c Merge pull request #16105 from kmillikin:main PiperOrigin-RevId: 534854308 24 May 2023, 15:41:51 UTC
2f7cc7d Merge pull request #16109 from michaeldeistler:readme-fix PiperOrigin-RevId: 534844507 24 May 2023, 15:11:47 UTC
5f1952d fix typo 24 May 2023, 08:43:03 UTC
921fd22 Refer to the original `map`/`zip` classes via `builtins` Referring to them as simply `map` or `zip` will create recursive reimplementations (with no base case!) if the cell is reevaluated in the same runtime. 24 May 2023, 06:47:50 UTC
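A small sketch of why 921fd22 goes through `builtins`: if a length-checking wrapper captures the global name `zip` and the notebook cell that rebinds `zip` is run a second time, the captured name may already be the wrapper, and the wrapper ends up calling itself. The `safe_zip` below is an illustrative reimplementation under that assumption, not a copy of the JAX helper.

```python
import builtins

def safe_zip(*args):
    # Check that all inputs have the same length, then defer to the real zip.
    n = len(args[0])
    for arg in args[1:]:
        assert len(arg) == n, f"length mismatch: {[len(a) for a in args]}"
    # builtins.zip is always the original, no matter what the global
    # name `zip` has been rebound to in this module or notebook.
    return list(builtins.zip(*args))

# Rebinding the global name is safe to repeat, because safe_zip never
# looks up the global `zip`.
zip = safe_zip
zip = safe_zip  # simulates re-running the same notebook cell

pairs = zip([1, 2], ["a", "b"])
```

The same reasoning applies to `map`: calling `builtins.map` inside the wrapper keeps a base case even after the cell has been re-evaluated.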
d9e7a2a Merge pull request #16102 from jakevdp:deprecate-lax-prod PiperOrigin-RevId: 534618632 24 May 2023, 00:40:18 UTC
4cfa96e deprecate jax.lax.prod 24 May 2023, 00:33:50 UTC
2d525b8 Merge pull request #16103 from jakevdp:deprecation-stacklevel PiperOrigin-RevId: 534616543 24 May 2023, 00:32:17 UTC
016eae4 Allow disabling the parsing of GSPMDSharding -> NamedSharding. Because this is best effort, users writing code to handle GSPMDSharding should be able to deal only with the GSPMDSharding type. PiperOrigin-RevId: 534612265 24 May 2023, 00:16:56 UTC
fecf193 Merge pull request #16070 from hawkinsp:docs2 PiperOrigin-RevId: 534606219 23 May 2023, 23:56:02 UTC
4f1f5e4 [XLA:Client] Expose HloSharding pybind factories for iota tile/partial tile, replicated and manual sharding. PiperOrigin-RevId: 534600886 23 May 2023, 23:37:42 UTC
7f7f995 Export jax.lax.sharding_constraint_p PiperOrigin-RevId: 534566582 23 May 2023, 21:50:46 UTC
2623473 Make deprecation warnings warn at appropriate stacklevel 23 May 2023, 21:43:38 UTC
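What "appropriate stacklevel" means in 2623473 can be shown with the standard `warnings` module. This is a generic sketch, assuming a hypothetical deprecated function `old_api`; it does not reproduce JAX's deprecation machinery.

```python
import warnings

def old_api():
    # stacklevel=2 attributes the warning to the caller of old_api(),
    # rather than to this line inside the library.
    warnings.warn("old_api is deprecated; use new_api instead.",
                  DeprecationWarning, stacklevel=2)
    return 42

def user_code():
    return old_api()  # the recorded warning points at this line

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = user_code()
```

With the default `stacklevel=1` the reported location would be the `warnings.warn` call itself, which tells users nothing about where their own code uses the deprecated name.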
db87167 Migrate `exec_tools` back to `tools`. PiperOrigin-RevId: 534549617 23 May 2023, 21:00:34 UTC
16410a5 Merge pull request #16096 from mattjj:softmax-custom-jvp-2 PiperOrigin-RevId: 534542849 23 May 2023, 20:41:23 UTC
d42350f disable custom_jvp for softmax by default Follow-up on #15677, basically undoing it. Some training runs experienced mysterious failures after many steps. We may leave this disabled until we diagnose the cause of the failures. 23 May 2023, 18:56:50 UTC
685ea5b Merge pull request #15069 from mattjj:issue15068 PiperOrigin-RevId: 534476334 23 May 2023, 17:39:38 UTC
62fb0cd explicitly convert jnp.var scalar normalizer to float (from int) This way we don't pass a potentially-large (Python builtin) int value to an int32 JAX computation parameter and get an error. Fixes #15068 Co-authored by: Matthew Johnson <mattjj@google.com> 23 May 2023, 16:44:08 UTC
a7b8129 Merge pull request #16073 from stellaraccident:extplugin PiperOrigin-RevId: 534237189 23 May 2023, 00:34:51 UTC
13f5090 Merge pull request #16018 from ZacCranko:tree_reduce_is_leaf PiperOrigin-RevId: 534165099 22 May 2023, 20:31:04 UTC
69d6c1b Merge pull request #16086 from froystig:upgraded-key-ctor PiperOrigin-RevId: 534152508 22 May 2023, 19:46:41 UTC
85fb48a Merge pull request #15930 from canyon289:jax201 PiperOrigin-RevId: 534149169 22 May 2023, 19:34:48 UTC
b7b90e6 add new random `key` constructor This constructor unconditionally returns a typed key array, regardless of the value of `jax.config.enable_custom_prng`. We can switch to referring to it in randomness docs and tutorials as we complete the typed key upgrade. 22 May 2023, 18:35:10 UTC
2f215e5 Merge pull request #16046 from mattjj:generalize-dot-general-dtypes PiperOrigin-RevId: 534123892 22 May 2023, 18:11:47 UTC
61b106e allow lax.dot_general to accept different input dtypes This change brings the dot_general primitive more in line with the HLO primitive, as it is described in XLA's shape_inference.cc (but not in the StableHLO spec). In particular we allow different input dtypes. The main motivation is to support transposition in the presence of preferred_element_type (which can set the output dtype to be different from the inputs), e.g. to fix #10818. However, because XLA platforms/backends can't seem to codegen all the cases that are accepted by shape_inference.cc, in our lowering rules we generate ConvertElementTypes on the inputs in a platform-dependent way. 22 May 2023, 17:33:42 UTC
473fa7d Add building on JAX 22 May 2023, 17:05:39 UTC
ee14ca2 Add option jax_include_full_tracebacks_in_locations. If enabled, includes full stack traces in MLIR emitted by JAX. These cannot be consumed by XLA at the moment. PiperOrigin-RevId: 534060827 22 May 2023, 14:41:29 UTC
fb99b9e Remove some dead code `tuple_args` is annotated as `bool`, it should not be `None`. PiperOrigin-RevId: 534039682 22 May 2023, 12:59:39 UTC
b71829f Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh). `with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD. PiperOrigin-RevId: 533801596 21 May 2023, 06:00:35 UTC
e0b5003 Allow unconstrained dimensions when using NamedShardings. PiperOrigin-RevId: 533752415 20 May 2023, 23:28:12 UTC
56ca8af Make custom_partitioning support multiple return values. PiperOrigin-RevId: 533584581 19 May 2023, 23:58:54 UTC
f832710 Update comments 19 May 2023, 21:29:41 UTC
368e20e Appease flake8 19 May 2023, 21:18:56 UTC
221aa76 Extend plugin discovery to also include entry-points. This effectively implements a mix of option 2 and option 3 from https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ as a pragmatic way to cover all packaging cases. The namespace/path based iteration works for situations where code has not been packaged and is present on the PYTHONPATH, whereas the advertised entry-points work around setuptools/pkgutil issues that make it impossible to reliably iterate over installed modules in certain scenarios (noted for editable installs which use a custom finder that does not implement iter_modules()). A plugin entry-point can be advertised in setup.py (or equivalent pyproject.toml) with something like:
```
entry_points={
    "jax_plugins": [
        "openxla-cpu = jax_plugins.openxla_cpu",
    ],
}
```
19 May 2023, 21:10:08 UTC
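The two discovery routes described in 221aa76 can be sketched with `pkgutil` and `importlib.metadata`. This is an illustrative helper, not the code in `xla_bridge`: the function name `discover_plugins` and its return shape are assumptions; only the `jax_plugins` namespace and entry-point group come from the commit message.

```python
import importlib
import importlib.metadata
import pkgutil
import sys

def discover_plugins(group="jax_plugins", namespace="jax_plugins"):
    """Return {plugin name: module path} found via both discovery routes."""
    found = {}

    # Route 1: iterate the namespace package. This covers unpackaged code
    # that is simply present on PYTHONPATH.
    try:
        ns = importlib.import_module(namespace)
    except ImportError:
        ns = None
    if ns is not None and hasattr(ns, "__path__"):
        for info in pkgutil.iter_modules(ns.__path__, ns.__name__ + "."):
            found[info.name] = info.name

    # Route 2: advertised entry points. This covers installs whose finder
    # does not implement iter_modules(), such as some editable installs.
    if sys.version_info >= (3, 10):
        eps = importlib.metadata.entry_points(group=group)
    else:
        eps = importlib.metadata.entry_points().get(group, [])
    for ep in eps:
        found[ep.name] = ep.value

    return found

plugins = discover_plugins()
```

Running both routes and merging the results means a plugin is found whether it was installed with entry-point metadata or merely dropped onto the path.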
ee4bc2a Add a note that Kepler GPUs are not supported to the README. 19 May 2023, 15:50:45 UTC
1d20d2f Increase sharding of host_callback_test on TPU to fix CI flakiness. PiperOrigin-RevId: 533451822 19 May 2023, 14:44:53 UTC
acc527d Merge pull request #16068 from hawkinsp:b16066 PiperOrigin-RevId: 533448906 19 May 2023, 14:29:40 UTC
26f2711 Fix typo in config.py. Fixes #16066 19 May 2023, 14:07:12 UTC
bb775c7 Merge pull request #15871 from nouiz:doc PiperOrigin-RevId: 533434343 19 May 2023, 13:08:01 UTC
9da52e8 [PJRT PLUGIN] Provide a register_plugin method that plugins can use to register their backend factory. The plugin is expected to call jax._src.xla_bridge.register_plugin with its plugin_name, priority (defaults to 400), path to .so file, and optional create options in its initialize() method. The logic to register a plugin from ENV is not deleted, to facilitate development with ENV. PiperOrigin-RevId: 533280115 18 May 2023, 23:13:02 UTC
4a5c6f8 For nested pjits, cache the generation of StableHLO if it satisfies the key. This should help in improving the tracing time. PiperOrigin-RevId: 533263584 18 May 2023, 22:09:54 UTC
6034e87 Convert tuple to DeviceAssignment on the replicated compilation path. PiperOrigin-RevId: 533258935 18 May 2023, 21:55:29 UTC
0816929 Simplify custom_partitioning to use jax.ShapeDtypeStruct instead of passing separate arguments for shape and sharding. PiperOrigin-RevId: 533257532 18 May 2023, 21:48:07 UTC
39097df Add some preliminary support for int4/uint4 types to JAX. PiperOrigin-RevId: 533251630 18 May 2023, 21:27:33 UTC
8ca40b2 Try another sphinx fix. 18 May 2023, 20:21:50 UTC
646bdb8 Merge pull request #15952 from hawkinsp:bazel PiperOrigin-RevId: 533146188 18 May 2023, 16:03:11 UTC
8696bef Integrate StableHLO at openxla/stablehlo@14691ce Manual changes: * stablehlo/integrations/python/mlir/dialects/stablehlo.py: to keep around get_earliest_forward_compatible_version while it's still used in JAX. PiperOrigin-RevId: 533140501 18 May 2023, 15:42:26 UTC