48ad9a6 | Yash Katariya | 31 May 2023, 23:48:13 UTC | Start jax and jaxlib 0.4.11 release PiperOrigin-RevId: 536860076 | 31 May 2023, 23:48:52 UTC |
525ba49 | jax authors | 31 May 2023, 21:27:53 UTC | Merge pull request #16204 from skye:importlib_metadata_version PiperOrigin-RevId: 536823622 | 31 May 2023, 21:27:53 UTC |
9682370 | Skye Wanderman-Milne | 31 May 2023, 18:28:09 UTC | Add importlib_metadata to project requirements. This is necessary to ensure we can correctly detect PJRT plugins via entry_points without compatibility errors. Prior to this change, there was conditional logic to handle the case where importlib_metadata wasn't installed at all. However, it didn't handle the case where importlib_metadata is installed but at a version too old to provide Python 3.10 compatibility. This change removes that logic and just ensures the right version is installed. All of this logic can be removed if/when jax requires Python >= 3.10. This also removes an unnecessary `requests` dependency for the [tpu] install. | 31 May 2023, 21:03:12 UTC |
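A minimal sketch of the version-dependent lookup that commit 9682370 is concerned with. It assumes the `importlib_metadata` backport provides the Python 3.10-style `entry_points(group=...)` selection API on older interpreters (the exact minimum backport version is not stated in this log). The helper name `plugin_entry_points` and its default group are illustrative, not JAX's actual code.

```python
import sys

# On Python >= 3.10 the stdlib entry_points() accepts a `group=` keyword.
# On older Pythons we assume a recent-enough importlib_metadata backport.
if sys.version_info >= (3, 10):
    from importlib.metadata import entry_points
else:
    from importlib_metadata import entry_points  # backport; must be installed


def plugin_entry_points(group: str = "jax_plugins") -> list:
    """Return the entry points advertised under `group` (possibly empty)."""
    return list(entry_points(group=group))
```

Requiring a single minimum version lets the code call one API shape instead of branching on whether the backport is present and how old it is.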
b35c20c | Jieying Luo | 31 May 2023, 20:41:53 UTC | Use xla_extension_version and remove some dead version checks in xla_bridge_test.py. The minimum jaxlib version requires xla_extension_version >= 144. PiperOrigin-RevId: 536810415 | 31 May 2023, 20:50:07 UTC |
727c121 | jax authors | 31 May 2023, 20:41:29 UTC | Merge pull request #16188 from nouiz:ci_jestimator PiperOrigin-RevId: 536810121 | 31 May 2023, 20:41:29 UTC |
c587dac | jax authors | 31 May 2023, 18:39:06 UTC | Merge pull request #16203 from skye:tpu_py_version2 PiperOrigin-RevId: 536776189 | 31 May 2023, 18:39:06 UTC |
131d28b | Skye Wanderman-Milne | 02 May 2023, 22:45:27 UTC | Use default Python version on Cloud TPU CI | 31 May 2023, 18:04:39 UTC |
6d6ba70 | Yash Katariya | 31 May 2023, 13:20:26 UTC | Disable the RnnTest.test_lstm1 test since it is only fixed for cuDNN >= 8.8. PiperOrigin-RevId: 536693061 | 31 May 2023, 13:21:01 UTC |
758d68d | jax authors | 31 May 2023, 09:22:40 UTC | Restore `call_tf_concrete_function_list` to its previous state. In the following case of nested calls:
```
import numpy as np
import jax
from jax.experimental import jax2tf

inputs = np.array(range(6), dtype=np.float32).reshape(3, 2)

@jax.jit
def forward(x):
  return x + 1

# JAX -> TF
tf_fn = jax2tf.convert(forward, native_serialization=True)
call_tf_fn = jax2tf.call_tf(tf_fn)
tf_fn_too = jax2tf.convert(call_tf_fn, native_serialization=True)
tf_fn_too(inputs)  # FAIL
```
Without the fix, it fails with the following error:
```
jax/experimental/jax2tf/jax2tf.py", line 499, in _restore_context
    _thread_local_state.call_tf_concrete_function_list.clear()
AttributeError: 'NoneType' object has no attribute 'clear'
```
This happens because `_restore_context` is called twice when executing `jax2tf.convert`ed functions. The first call sets `call_tf_concrete_function_list` to `None` instead of restoring its previous state, so on the second call `call_tf_concrete_function_list.clear()` raises the above error. PiperOrigin-RevId: 536650377 | 31 May 2023, 09:23:14 UTC |
f884b4d | Yash Katariya | 31 May 2023, 02:51:06 UTC | Fix the `test_sharding_on_output_with_vmap` failure in Pathways, which was getting a cache miss in pjit_call_impl. The global cache was used inconsistently at the top level and in pjit_call_impl, so standardize its use via a helper function. In the test, check for recompilation, which is what the test was checking before cl/535630905. PiperOrigin-RevId: 536575987 | 31 May 2023, 02:51:48 UTC |
3ad756f | jax authors | 31 May 2023, 02:16:52 UTC | Merge pull request #16176 from gnecula:poly_constraints PiperOrigin-RevId: 536571493 | 31 May 2023, 02:16:52 UTC |
9ad8c3b | George Necula | 13 May 2023, 14:57:27 UTC | [shape_poly] Add static constraint checking to the computation of dim vars Previously we had one function `shape_poly.unify_avals_with_args` that was solving the dimension variables and was also used for generating the code to compute them. Now we separate the solving part, which is now using just symbolic expressions (`shape_poly.solve_dim_vars`), from the code generator for the dimension variables (`shape_poly.compute_dim_vars_from_arg_shapes`). We also add a notion of shape constraints, e.g., `dimexpr1 == dimexpr2` or `dimexpr1 >= dimexpr2`, under which the solution for the dimension variables is valid. For now we implement the static checking of the shape constraints, e.g., when the dimension expressions are constant or TF EagerTensor. We do not yet have compile-time checking of the constraints. This matches the previous behavior. However, the code is now ready for implementing compile-time checking of the constraints that cannot be checked statically. | 31 May 2023, 01:48:44 UTC |
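A toy sketch of the "solve first, then check constraints" split described in commit 9ad8c3b. It assumes a drastically simplified expression language: an int, a variable name, or a `(coefficient, variable)` pair. `solve_toy_dims` and `check_toy_constraints_statically` are hypothetical names and are unrelated to the real `shape_poly` functions. The `>= 1` check reflects the usual assumption that dimension variables are positive.

```python
from typing import Dict, List, Sequence, Tuple, Union

# Toy dimension expression (hypothetical, far simpler than jax's shape_poly):
#   int        -> a constant dimension
#   str        -> a dimension variable, e.g. "b"
#   (int, str) -> coefficient * variable, e.g. (2, "b") for 2*b
DimExpr = Union[int, str, Tuple[int, str]]


def evaluate(expr: DimExpr, env: Dict[str, int]) -> int:
    """Evaluate a toy expression once its variables are solved."""
    if isinstance(expr, int):
        return expr
    coef, name = (1, expr) if isinstance(expr, str) else expr
    if name not in env:
        raise ValueError(f"cannot solve dimension variable {name!r}")
    return coef * env[name]


def solve_toy_dims(specs: Sequence[DimExpr], shape: Sequence[int]):
    """Bind variables from bare-variable dims; turn every other dim into a constraint."""
    if len(specs) != len(shape):
        raise ValueError(f"rank mismatch: {len(specs)} specs vs shape {tuple(shape)}")
    solution: Dict[str, int] = {}
    constraints: List[Tuple[DimExpr, int]] = []  # each pair means: expr == dim
    for expr, dim in zip(specs, shape):
        if isinstance(expr, str) and expr not in solution:
            solution[expr] = dim
        else:
            constraints.append((expr, dim))
    return solution, constraints


def check_toy_constraints_statically(solution, constraints) -> None:
    """Static checking: every value is a known constant, so just compare."""
    for name, value in solution.items():
        if value < 1:  # assumed: dimension variables range over integers >= 1
            raise ValueError(f"{name} = {value} violates {name} >= 1")
    for expr, dim in constraints:
        value = evaluate(expr, solution)
        if value != dim:
            raise ValueError(f"{expr!r} == {dim} violated: got {value}")
```

Keeping the constraints as data, rather than checking them inside the solver, is what would later allow emitting compile-time checks for the ones that cannot be decided statically.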
cdced24 | Frederic Bastien | 30 May 2023, 19:55:21 UTC | Work around (WAR) the bug in the t5x dependency. It currently needs the dev version of jestimator. | 30 May 2023, 19:55:21 UTC |
acfeb9b | jax authors | 30 May 2023, 01:39:44 UTC | Merge pull request #16169 from ZacCranko:data_parallel_example PiperOrigin-RevId: 536260245 | 30 May 2023, 01:39:44 UTC |
a192b5e | Zac Cranko | 29 May 2023, 06:46:15 UTC | improve data parallel example fix example | 30 May 2023, 01:25:17 UTC |
ae9160a | jax authors | 28 May 2023, 14:27:45 UTC | Merge pull request #16159 from jakevdp:deprecations PiperOrigin-RevId: 536003451 | 28 May 2023, 14:27:45 UTC |
7a87995 | Jake VanderPlas | 28 May 2023, 14:15:34 UTC | Deprecate jax.interpreters.xla.Buffer, device_put, xla_call_p | 28 May 2023, 14:15:34 UTC |
1279418 | Sharad Vikram | 26 May 2023, 21:01:42 UTC | Link in CUDA runtime for triton in jaxlib PiperOrigin-RevId: 535708416 | 26 May 2023, 21:02:16 UTC |
cb3b7ec | Jieying Luo | 26 May 2023, 20:14:02 UTC | [PJRT PLUGIN] Add num_processes to distributed.global_state. The number of processes is needed for multi-process GPU when a plugin is used. PiperOrigin-RevId: 535696950 | 26 May 2023, 20:14:40 UTC |
d62bc0f | Yash Katariya | 26 May 2023, 19:57:01 UTC | Fix the jax2tf failure in mypy: https://github.com/google/jax/actions/runs/5094063162/jobs/9157426652?pr=16155 PiperOrigin-RevId: 535692853 | 26 May 2023, 19:57:37 UTC |
fe3fed3 | Yash Katariya | 26 May 2023, 19:34:32 UTC | Remove axis_resources from with_sharding_constraint since, per the API deprecation policy, 3 months have passed since its deprecation. PiperOrigin-RevId: 535687618 | 26 May 2023, 19:35:16 UTC |
25a9a97 | jax authors | 26 May 2023, 16:48:40 UTC | Merge pull request #16151 from hawkinsp:cudnn PiperOrigin-RevId: 535642800 | 26 May 2023, 16:48:40 UTC |
4f07471 | Yash Katariya | 26 May 2023, 15:56:56 UTC | Make pjit_call_impl go via C++ dispatch. This is required for APIs like `eval_jaxpr` and `jaxpr_as_fun` that don't call the top-level pjit/jit function but instead go through pjit_p.bind directly, which calls into _pjit_call_impl. PiperOrigin-RevId: 535630905 | 26 May 2023, 15:57:30 UTC |
9508f3a | jax authors | 26 May 2023, 15:44:41 UTC | Merge pull request #16148 from gnecula:export_poly PiperOrigin-RevId: 535628086 | 26 May 2023, 15:44:41 UTC |
46a258b | George Necula | 26 April 2023, 07:11:04 UTC | [shape_poly] Add partial support for call_exported with polymorphic shapes Until now, jax_export.call_exported did not allow calling functions that were exported with polymorphic shapes. We now add that support, including resolving the dimension variables of the called function in terms of the shapes at the call site (which themselves may include dimension variables), and then computing the output shape of the called function. The support is partial in that we can export a JAX function that calls an exported polymorphic function, but we cannot invoke it. This is because we do not yet have access to the shape refinement machinery that XlaCallModule uses. For now, we use XlaCallModule for invoking exported functions that include shape polymorphism. | 26 May 2023, 15:27:44 UTC |
2858df2 | Yash Katariya | 26 May 2023, 15:18:39 UTC | Start the process of removing OpSharding from JAX and replacing it with HloSharding. This will allow for future optimizations of HloSharding to work seamlessly with JAX. Currently, no function producing HloSharding is being used; I will do that in follow-up CLs. PiperOrigin-RevId: 535622806 | 26 May 2023, 15:19:14 UTC |
69cf67f | Peter Hawkins | 26 May 2023, 14:04:34 UTC | Bump the minimum CUDNN version for CUDA 12 wheels to 8.9. | 26 May 2023, 14:04:34 UTC |
ea37043 | Chris Jones | 26 May 2023, 10:15:02 UTC | Switch to `STATUS_RETURNING` callback API. PiperOrigin-RevId: 535568707 | 26 May 2023, 10:15:44 UTC |
7833528 | jax authors | 25 May 2023, 23:31:09 UTC | Merge pull request #16143 from jakevdp:fix-shape-poly PiperOrigin-RevId: 535427698 | 25 May 2023, 23:31:09 UTC |
ed10293 | John QiangZhang | 25 May 2023, 22:58:16 UTC | Add a new `called_index` field to the custom_call `tf.backend_config` DictAttr. Here, `called_index` indicates the index of the TF concrete function in the `function_list` of the parent XLACallModule. PiperOrigin-RevId: 535417558 | 25 May 2023, 22:58:50 UTC |
bbae2ed | Jake VanderPlas | 25 May 2023, 22:31:16 UTC | jax2tf: correctly handle opaque dtype in jax2tf pure() In TF tracers, "val" is the physical TF representation, while "aval" is the abstract value used during tracing, which is where additional JAX-specific information such as opaque dtype, weak_type, etc. should be included. Before opaque dtypes, val and aval always had the same shape and dtype. With opaque dtypes, this is no longer the case, which revealed this bug in the logic of jax2tf pure(). PiperOrigin-RevId: 535408671 | 25 May 2023, 22:32:47 UTC |
8534f0b | jax authors | 25 May 2023, 22:25:18 UTC | Merge pull request #16142 from froystig:outline-random-functions PiperOrigin-RevId: 535406588 | 25 May 2023, 22:25:18 UTC |
b853ce9 | Jake VanderPlas | 25 May 2023, 22:16:46 UTC | jax2tf: make shape_poly_test pass with custom PRNG | 25 May 2023, 22:16:46 UTC |
3238b62 | Roy Frostig | 25 May 2023, 01:00:54 UTC | outline jitted `jax.random` functions We may want to continue to inline these in Jaxpr somehow, but it's useful to outline them in HLO for visualization and debugging. | 25 May 2023, 22:01:04 UTC |
14089fb | jax authors | 25 May 2023, 20:40:34 UTC | Merge pull request #16138 from hawkinsp:cudnn PiperOrigin-RevId: 535367462 | 25 May 2023, 20:40:34 UTC |
2b77902 | Peter Hawkins | 25 May 2023, 18:46:39 UTC | Bump the minimum cuDNN version in the pip installation to 8.8. Earlier versions have known bugs that produce wrong outputs in JAX, in particular related to RNNs. | 25 May 2023, 18:46:39 UTC |
16368bc | Peter Hawkins | 25 May 2023, 18:09:41 UTC | [XLA:Python] Clean up handling of unsupported types in buffer protocol. Rather than enumerating a list of types that don't work in the buffer protocol, call the format descriptor function and fail if it fails. Simplify the format descriptor function to avoid allocating a format string; these can be compile-time constants. PiperOrigin-RevId: 535315975 | 25 May 2023, 18:10:19 UTC |
2155b91 | Chris Jones | 25 May 2023, 17:25:11 UTC | Switch to using JAX status macros in jax-triton kernel call lib. PiperOrigin-RevId: 535300412 | 25 May 2023, 17:26:06 UTC |
bc547aa | Mark Sandler | 25 May 2023, 17:13:50 UTC | Adds a note that pjit is equivalent to jit. PiperOrigin-RevId: 535296532 | 25 May 2023, 17:17:25 UTC |
32026ad | Peter Hawkins | 25 May 2023, 17:05:37 UTC | Disable random_test_with_custom_prng on CPU under msan. This test flakily times out in CI. PiperOrigin-RevId: 535293997 | 25 May 2023, 17:10:01 UTC |
24928a5 | jax authors | 25 May 2023, 17:02:26 UTC | Merge pull request #16117 from jakevdp:matrix-transpose PiperOrigin-RevId: 535292507 | 25 May 2023, 17:02:26 UTC |
222b951 | Jake VanderPlas | 25 May 2023, 16:32:14 UTC | Use new matrix_transpose in linalg code | 25 May 2023, 16:32:14 UTC |
333ff4a | Jake VanderPlas | 25 May 2023, 16:02:05 UTC | Add jnp.matrix_transpose() and jax.Array.mT This is an API proposed by the Python Array API Standard (https://data-apis.org/array-api/2022.12/). It's lightweight enough that there's hardly any downside to supporting it in JAX. | 25 May 2023, 16:02:05 UTC |
e464dc8 | Peter Hawkins | 25 May 2023, 14:19:56 UTC | Reland: [XLA:Python] Add buffer protocol support to jax.Array We supported the buffer protocol on the older DeviceArray class; port that support to jax.Array. The previous attempt was reverted because it led to a C++ CHECK failure if the buffer was deleted while an external Python reference was held. Change the CPU PJRT client to keep the underlying buffer alive as long as there are external references, which is what the contract of Delete() says it will do. Fixes https://github.com/google/jax/issues/14713 PiperOrigin-RevId: 535248553 | 25 May 2023, 14:20:42 UTC |
6b13d4e | Chris Jones | 25 May 2023, 13:22:49 UTC | Add branch prediction to JAX status macros. PiperOrigin-RevId: 535233546 | 25 May 2023, 13:23:23 UTC |
e25052c | Eugene Burmako | 25 May 2023, 04:14:40 UTC | Use stablehlo.get_minimum_version in jax_export.py The currently used stablehlo.get_earliest_forward_compatible_version was intended to be a short-term workaround, and it has been recently replaced by the long-term stablehlo.get_minimum_version API. This CL migrates to the long-term API. PiperOrigin-RevId: 535091927 | 25 May 2023, 04:15:16 UTC |
8e397f7 | Ce Zheng | 25 May 2023, 03:25:48 UTC | [XLA:Client] Change replicate_last_dim to subgroup_types in HloSharding.iota_tile to cover arbitrary subgroups, adding necessary accessors. PiperOrigin-RevId: 535079635 | 25 May 2023, 03:26:28 UTC |
5e82d6b | John QiangZhang | 24 May 2023, 22:27:17 UTC | Fix jax2tf_test regression failure. PiperOrigin-RevId: 535002015 | 24 May 2023, 22:27:57 UTC |
4fb834b | Sharad Vikram | 24 May 2023, 21:06:10 UTC | Use jaxlib version guard for triton instead of xla_extension_version PiperOrigin-RevId: 534974834 | 24 May 2023, 21:06:45 UTC |
6a54ebd | Yash Katariya | 24 May 2023, 20:58:17 UTC | Fix the lu.clear_all_cache function by adding the memoized_fun to the global weakref set rather than to the function-local `fun_caches` weakref dict. PiperOrigin-RevId: 534971855 | 24 May 2023, 20:58:51 UTC |
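A minimal sketch of the idea in commit 6a54ebd, assuming a simple dict-based memoizer. Each memoized function registers itself in one module-level `weakref.WeakSet`, so a single global clear can reach every cache. `memoize` and `clear_all_caches` are hypothetical names, not the `jax._src.linear_util` API.

```python
import functools
import weakref

# One global registry of memoized functions; weak references let a memoized
# function be garbage-collected without leaking through the registry.
_all_memoized = weakref.WeakSet()


def memoize(fun):
    cache = {}

    @functools.wraps(fun)
    def memoized_fun(*args):
        if args not in cache:
            cache[args] = fun(*args)
        return cache[args]

    memoized_fun.cache_clear = cache.clear
    _all_memoized.add(memoized_fun)  # register globally, not in a per-function dict
    return memoized_fun


def clear_all_caches() -> None:
    """Clear the cache of every memoized function that is still alive."""
    for f in list(_all_memoized):
        f.cache_clear()
```

If each function's registration lived only in a structure local to that function, a global clear would have no way to enumerate the caches, which is the gap the commit message describes.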
557ca52 | Sharad Vikram | 24 May 2023, 20:18:38 UTC | Add `cuda_pip` extra for jaxlib PiperOrigin-RevId: 534957585 | 24 May 2023, 20:19:27 UTC |
bf8ed6a | Sharad Vikram | 24 May 2023, 19:09:57 UTC | Move triton_kernel_call_lib to jaxlib PiperOrigin-RevId: 534934592 | 24 May 2023, 19:11:21 UTC |
7de1677 | jax authors | 24 May 2023, 18:47:58 UTC | Add (optional) ordered effects for `jax2tf.call_tf` This allows users to express nested TensorFlow computation that must be ordered during execution. It leverages the existing JAX effects system to model such side effects and lower them to use XLA tokens. With this change, `jax2tf.call_tf(ordered=True)` can be used to generate ordered TF calls. This has the following behavior:
* With `call_tf_graph=True`, this generates a custom call op with the following differences: (1) a `!stablehlo.token` argument/result is prepended to each custom call's argument/result list and (2) `tf.backend_config` has an additional `has_token_input_output = true` entry.
* Without `call_tf_graph=True`, this raises a `NotImplementedError()`.

For this, `jax_export.py` makes sure that dummy arguments/results added for ordered effects are not exposed to the public interface by passing constant values in a wrapper function. Because of this, adding ordered effects to jax2tf-ed computation no longer causes calling convention changes and can be safely allowed.
Example StableHLO produced from the added test:
```
module @jit_f_jax attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.constant dense<> : tensor<0xi1>
    %1:2 = call @_wrapped_jax_export_main(%0, %arg0) : (tensor<0xi1>, tensor<f32>) -> (tensor<0xi1>, tensor<f32>)
    return %1#1 : tensor<f32>
  }
  func.func private @_wrapped_jax_export_main(%arg0: tensor<0xi1> {jax.token = true}, %arg1: tensor<f32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"}) -> (tensor<0xi1> {jax.token = true}, tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.create_token : !stablehlo.token
    %1 = stablehlo.constant dense<0> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %1, %iterArg_1 = %arg1) : !stablehlo.token, tensor<i32>, tensor<f32>
     cond {
      %4 = stablehlo.constant dense<4> : tensor<i32>
      %5 = stablehlo.compare LT, %iterArg_0, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %5 : tensor<i1>
    } do {
      %4 = stablehlo.custom_call @tf.call_tf_function(%iterArg, %iterArg_1) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {caller_name = "__inference_callable_flat_tf_10", has_token_input_output = true}} : (!stablehlo.token, tensor<f32>) -> !stablehlo.token
      %5 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
      %6 = stablehlo.add %iterArg_1, %5 : tensor<f32>
      %7 = stablehlo.constant dense<1> : tensor<i32>
      %8 = stablehlo.add %iterArg_0, %7 : tensor<i32>
      stablehlo.return %4, %8, %6 : !stablehlo.token, tensor<i32>, tensor<f32>
    }
    %3 = stablehlo.constant dense<> : tensor<0xi1>
    return %3, %2#2 : tensor<0xi1>, tensor<f32>
  }
}
```
PiperOrigin-RevId: 534926215 | 24 May 2023, 18:48:35 UTC |
399e4ee | Jake Vanderplas | 24 May 2023, 17:35:01 UTC | Copybara import of the project: -- 8cf6a6acd151007935b0c3093df05ef036bb0244 by Jake VanderPlas <jakevdp@google.com>: Remove several deprecated APIs COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16110 from jakevdp:deprecations 8cf6a6acd151007935b0c3093df05ef036bb0244 PiperOrigin-RevId: 534897394 | 24 May 2023, 17:35:37 UTC |
1831b3c | jax authors | 24 May 2023, 15:41:51 UTC | Merge pull request #16105 from kmillikin:main PiperOrigin-RevId: 534854308 | 24 May 2023, 15:41:51 UTC |
2f7cc7d | jax authors | 24 May 2023, 15:11:47 UTC | Merge pull request #16109 from michaeldeistler:readme-fix PiperOrigin-RevId: 534844507 | 24 May 2023, 15:11:47 UTC |
5f1952d | Michael Deistler | 24 May 2023, 08:43:03 UTC | fix typo | 24 May 2023, 08:43:03 UTC |
921fd22 | Kevin Millikin | 24 May 2023, 06:47:50 UTC | Refer to the original `map`/`zip` classes via `builtins` Referring to them as simply `map` or `zip` will create recursive reimplementations (with no base case!) if the cell is reevaluated in the same runtime. | 24 May 2023, 06:47:50 UTC |
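A sketch of the approach in commit 921fd22, assuming a notebook cell that wraps `map`/`zip` to return lists. If the cell instead used an alias such as `map_ = map` before redefining `map`, re-running it would rebind `map_` to the wrapper itself, which then calls itself with no base case. `builtins.map` and `builtins.zip` always name the original classes.

```python
import builtins


# List-returning wrappers. Because they call builtins.map / builtins.zip,
# re-running this cell redefines them without making them call themselves.
def map(f, *xs):  # intentionally shadows the builtin
    return list(builtins.map(f, *xs))


def zip(*args):  # intentionally shadows the builtin
    return list(builtins.zip(*args))
```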
d9e7a2a | jax authors | 24 May 2023, 00:40:18 UTC | Merge pull request #16102 from jakevdp:deprecate-lax-prod PiperOrigin-RevId: 534618632 | 24 May 2023, 00:40:18 UTC |
4cfa96e | Jake VanderPlas | 23 May 2023, 21:39:38 UTC | deprecate jax.lax.prod | 24 May 2023, 00:33:50 UTC |
2d525b8 | jax authors | 24 May 2023, 00:32:17 UTC | Merge pull request #16103 from jakevdp:deprecation-stacklevel PiperOrigin-RevId: 534616543 | 24 May 2023, 00:32:17 UTC |
016eae4 | Parker Schuh | 24 May 2023, 00:16:17 UTC | Allow disabling the parsing of GSPMDSharding -> NamedSharding. Because this is best effort, users writing code to handle GSPMDSharding should be able to deal only with the GSPMDSharding type. PiperOrigin-RevId: 534612265 | 24 May 2023, 00:16:56 UTC |
fecf193 | jax authors | 23 May 2023, 23:56:02 UTC | Merge pull request #16070 from hawkinsp:docs2 PiperOrigin-RevId: 534606219 | 23 May 2023, 23:56:02 UTC |
4f1f5e4 | Ce Zheng | 23 May 2023, 23:37:02 UTC | [XLA:Client] Expose HloSharding pybind factories for iota tile/partial tile, replicated, and manual sharding. PiperOrigin-RevId: 534600886 | 23 May 2023, 23:37:42 UTC |
7f7f995 | Jake VanderPlas | 23 May 2023, 21:50:06 UTC | Export jax.lax.sharding_constraint_p PiperOrigin-RevId: 534566582 | 23 May 2023, 21:50:46 UTC |
2623473 | Jake VanderPlas | 23 May 2023, 21:43:38 UTC | Make deprecation warnings warn at appropriate stacklevel | 23 May 2023, 21:43:38 UTC |
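A minimal sketch of warning at the caller's stack level, using only the standard `warnings` module. `deprecated_function` and `new_function` are hypothetical names. `stacklevel=2` makes the warning point at the line that called the deprecated API rather than at the `warnings.warn` call inside it.

```python
import warnings


def deprecated_function():
    # stacklevel=2 attributes the warning to our caller's frame, so users see
    # the line in their own code that needs updating.
    warnings.warn(
        "deprecated_function is deprecated; use new_function instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return 42
```

If the deprecated name is reached through extra internal layers (for example, a module-level `__getattr__` shim), the stack level has to grow by one for each wrapper frame.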
db87167 | John Cater | 23 May 2023, 20:59:53 UTC | Migrate `exec_tools` back to `tools`. PiperOrigin-RevId: 534549617 | 23 May 2023, 21:00:34 UTC |
16410a5 | jax authors | 23 May 2023, 20:41:23 UTC | Merge pull request #16096 from mattjj:softmax-custom-jvp-2 PiperOrigin-RevId: 534542849 | 23 May 2023, 20:41:23 UTC |
d42350f | Matthew Johnson | 23 May 2023, 18:56:50 UTC | disable custom_jvp for softmax by default Follow-up on #15677, basically undoing it. Some training runs experienced mysterious failures after many steps. We may leave this disabled until we diagnose the cause of the failures. | 23 May 2023, 18:56:50 UTC |
685ea5b | jax authors | 23 May 2023, 17:39:38 UTC | Merge pull request #15069 from mattjj:issue15068 PiperOrigin-RevId: 534476334 | 23 May 2023, 17:39:38 UTC |
62fb0cd | Jake VanderPlas | 23 May 2023, 16:44:08 UTC | explicitly convert jnp.var scalar normalizer to float (from int) This way we don't pass a potentially-large (Python builtin) int value to an int32 JAX computation parameter and get an error. Fixes #15068 Co-authored by: Matthew Johnson <mattjj@google.com> | 23 May 2023, 16:44:08 UTC |
a7b8129 | jax authors | 23 May 2023, 00:34:51 UTC | Merge pull request #16073 from stellaraccident:extplugin PiperOrigin-RevId: 534237189 | 23 May 2023, 00:34:51 UTC |
13f5090 | jax authors | 22 May 2023, 20:31:04 UTC | Merge pull request #16018 from ZacCranko:tree_reduce_is_leaf PiperOrigin-RevId: 534165099 | 22 May 2023, 20:31:04 UTC |
69d6c1b | jax authors | 22 May 2023, 19:46:41 UTC | Merge pull request #16086 from froystig:upgraded-key-ctor PiperOrigin-RevId: 534152508 | 22 May 2023, 19:46:41 UTC |
85fb48a | jax authors | 22 May 2023, 19:34:48 UTC | Merge pull request #15930 from canyon289:jax201 PiperOrigin-RevId: 534149169 | 22 May 2023, 19:34:48 UTC |
b7b90e6 | Roy Frostig | 22 May 2023, 18:35:06 UTC | add new random `key` constructor This constructor unconditionally returns a typed key array, regardless of the value of `jax.config.enable_custom_prng`. We can switch to referring to it in randomness docs and tutorials as we complete the typed key upgrade. | 22 May 2023, 18:35:10 UTC |
2f215e5 | jax authors | 22 May 2023, 18:11:47 UTC | Merge pull request #16046 from mattjj:generalize-dot-general-dtypes PiperOrigin-RevId: 534123892 | 22 May 2023, 18:11:47 UTC |
61b106e | Matthew Johnson | 13 May 2023, 02:56:59 UTC | allow lax.dot_general to accept different input dtypes This change brings the dot_general primitive more in line with the HLO primitive, as it is described in XLA's shape_inference.cc (but not in the StableHLO spec). In particular we allow different input dtypes. The main motivation is to support transposition in the presence of preferred_element_type (which can set the output dtype to be different from the inputs), e.g. to fix #10818. However, because XLA platforms/backends can't seem to codegen all the cases that are accepted by shape_inference.cc, in our lowering rules we generate ConvertElementTypes on the inputs in a platform-dependent way. | 22 May 2023, 17:33:42 UTC |
473fa7d | Ravin Kumar | 22 May 2023, 17:03:12 UTC | Add building on JAX | 22 May 2023, 17:05:39 UTC |
ee14ca2 | Peter Hawkins | 22 May 2023, 14:40:52 UTC | Add option jax_include_full_tracebacks_in_locations. If enabled, includes full stack traces in MLIR emitted by JAX. These cannot be consumed by XLA at the moment. PiperOrigin-RevId: 534060827 | 22 May 2023, 14:41:29 UTC |
fb99b9e | kmillikin | 22 May 2023, 12:59:06 UTC | Remove some dead code `tuple_args` is annotated as `bool`, so it should not be `None`. PiperOrigin-RevId: 534039682 | 22 May 2023, 12:59:39 UTC |
b71829f | Yash Katariya | 21 May 2023, 05:59:52 UTC | Allow pjit.AUTO to be used with jax.jit. This introduces an API change which requires a mesh to be provided to pjit.AUTO(mesh). `with mesh:` is no longer required with pjit to use the auto spmd pass of GSPMD. PiperOrigin-RevId: 533801596 | 21 May 2023, 06:00:35 UTC |
e0b5003 | Sholto Douglas | 20 May 2023, 23:27:28 UTC | Allow unconstrained dimensions when using NamedShardings. PiperOrigin-RevId: 533752415 | 20 May 2023, 23:28:12 UTC |
56ca8af | Parker Schuh | 19 May 2023, 23:58:21 UTC | Make custom_partitioning support multiple return values. PiperOrigin-RevId: 533584581 | 19 May 2023, 23:58:54 UTC |
f832710 | Stella Laurenzo | 19 May 2023, 21:29:41 UTC | Update comments | 19 May 2023, 21:29:41 UTC |
368e20e | Stella Laurenzo | 19 May 2023, 21:18:56 UTC | Appease flake8 | 19 May 2023, 21:18:56 UTC |
221aa76 | Stella Laurenzo | 19 May 2023, 21:10:08 UTC | Extend plugin discovery to also include entry-points. This effectively implements a mix of option 2 and option 3 from https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ as a pragmatic way to cover all packaging cases. The namespace/path-based iteration works for situations where code has not been packaged and is present on the PYTHONPATH, whereas the advertised entry-points work around setuptools/pkgutil issues that make it impossible to reliably iterate over installed modules in certain scenarios (noted for editable installs, which use a custom finder that does not implement iter_modules()). A plugin entry-point can be advertised in setup.py (or equivalent pyproject.toml) with something like:
```
entry_points={
    "jax_plugins": [
        "openxla-cpu = jax_plugins.openxla_cpu",
    ],
}
```
| 19 May 2023, 21:10:08 UTC |
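A sketch of combining the two discovery mechanisms described in commit 221aa76, using only the standard library. It assumes Python >= 3.10 (or the `importlib_metadata` backport) for `entry_points(group=...)`. `discover_plugins` is a hypothetical helper, not JAX's actual loader, and it eagerly imports every plugin it finds.

```python
import importlib
import pkgutil
import sys

if sys.version_info >= (3, 10):
    from importlib.metadata import entry_points
else:
    from importlib_metadata import entry_points  # assumed backport


def discover_plugins(namespace: str = "jax_plugins") -> dict:
    """Map plugin name -> loaded module, via namespace iteration plus entry points."""
    plugins = {}
    # Option 2 (namespace package): finds unpackaged code on PYTHONPATH.
    try:
        ns_pkg = importlib.import_module(namespace)
    except ImportError:
        ns_pkg = None
    if ns_pkg is not None:
        for info in pkgutil.iter_modules(getattr(ns_pkg, "__path__", []),
                                         ns_pkg.__name__ + "."):
            plugins[info.name] = importlib.import_module(info.name)
    # Option 3 (entry points): covers installs, e.g. editable ones, whose
    # custom finders do not implement iter_modules().
    for ep in entry_points(group=namespace):
        plugins[ep.name] = ep.load()
    return plugins
```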
ee4bc2a | Peter Hawkins | 19 May 2023, 15:50:45 UTC | Add a note that Kepler GPUs are not supported to the README. | 19 May 2023, 15:50:45 UTC |
1d20d2f | Peter Hawkins | 19 May 2023, 14:44:18 UTC | Increase sharding of host_callback_test on TPU to fix CI flakiness. PiperOrigin-RevId: 533451822 | 19 May 2023, 14:44:53 UTC |
acc527d | jax authors | 19 May 2023, 14:29:40 UTC | Merge pull request #16068 from hawkinsp:b16066 PiperOrigin-RevId: 533448906 | 19 May 2023, 14:29:40 UTC |
26f2711 | Peter Hawkins | 19 May 2023, 14:07:12 UTC | Fix typo in config.py. Fixes #16066 | 19 May 2023, 14:07:12 UTC |
bb775c7 | jax authors | 19 May 2023, 13:08:01 UTC | Merge pull request #15871 from nouiz:doc PiperOrigin-RevId: 533434343 | 19 May 2023, 13:08:01 UTC |
9da52e8 | Jieying Luo | 18 May 2023, 23:12:23 UTC | [PJRT PLUGIN] Provide a register_plugin method that plugins can use to register their backend factory. A plugin is expected to call jax._src.xla_bridge.register_plugin in its initialize() method with its plugin_name, priority (defaults to 400), the path to its .so file, and optional create options. The logic to register a plugin from environment variables is kept to facilitate development using environment variables. PiperOrigin-RevId: 533280115 | 18 May 2023, 23:13:02 UTC |
4a5c6f8 | Yash Katariya | 18 May 2023, 22:09:00 UTC | For nested pjits, cache the generated StableHLO if it satisfies the cache key. This should help improve tracing time. PiperOrigin-RevId: 533263584 | 18 May 2023, 22:09:54 UTC |
6034e87 | jax authors | 18 May 2023, 21:52:37 UTC | Convert tuple to DeviceAssignment on the replicated compilation path. PiperOrigin-RevId: 533258935 | 18 May 2023, 21:55:29 UTC |
0816929 | Parker Schuh | 18 May 2023, 21:47:34 UTC | Simplify custom_partitioning to use jax.ShapeDtypeStruct instead of passing separate arguments for shape and sharding. PiperOrigin-RevId: 533257532 | 18 May 2023, 21:48:07 UTC |
39097df | Peter Hawkins | 18 May 2023, 21:26:42 UTC | Add some preliminary support for int4/uint4 types to JAX. PiperOrigin-RevId: 533251630 | 18 May 2023, 21:27:33 UTC |
8ca40b2 | Frederic Bastien | 18 May 2023, 20:21:50 UTC | Try another sphinx fix. | 18 May 2023, 20:21:50 UTC |
646bdb8 | jax authors | 18 May 2023, 16:03:11 UTC | Merge pull request #15952 from hawkinsp:bazel PiperOrigin-RevId: 533146188 | 18 May 2023, 16:03:11 UTC |
8696bef | Eugene Burmako | 18 May 2023, 15:41:42 UTC | Integrate StableHLO at openxla/stablehlo@14691ce Manual changes: * stablehlo/integrations/python/mlir/dialects/stablehlo.py: to keep around get_earliest_forward_compatible_version while it's still used in JAX. PiperOrigin-RevId: 533140501 | 18 May 2023, 15:42:26 UTC |