https://github.com/google/jax

Revision Message Commit Date
88a60b8 Merge pull request #16870 from skye:version PiperOrigin-RevId: 551636421 27 July 2023, 21:11:05 UTC
e132a0e Slightly downgrade xla version to avoid PJRT C API incompat 27 July 2023, 21:05:14 UTC
c75e85d Merge pull request #16869 from skye:version PiperOrigin-RevId: 551628701 27 July 2023, 20:47:46 UTC
a480aa8 Work around pytype error. An upcoming pytype release complains about unpacking a non-deterministic order iterable for this line of code. Work around pytype. PiperOrigin-RevId: 551627521 27 July 2023, 20:39:48 UTC
0b24b2b Update WORKSPACE and setup.py in preparation for jax/jaxlib 0.4.14 release 27 July 2023, 20:35:04 UTC
7c0ef86 Merge pull request #16433 from JGameCreation:patch-1 PiperOrigin-RevId: 551614086 27 July 2023, 19:52:28 UTC
76cda0a Update flags to use the ABSL typed flag API. Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary. For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API. Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`. This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR. PiperOrigin-RevId: 551604974 27 July 2023, 19:15:58 UTC
f35f226 Merge pull request #16865 from google:xla PiperOrigin-RevId: 551565217 27 July 2023, 17:07:24 UTC
0cd0253 Update XLA version to fix build failure. 27 July 2023, 12:38:51 UTC
a03d6e6 Move _tpu_ext.cc to jaxlib/mlir/_mlir_libs and set RPATH correctly _tpu_ext.so dynamically links in libjaxlib_mlir_capi.so (in jaxlib/mlir/_mlir_libs), so needs to include jaxlib/mlir/_mlir_libs in its RPATH or similar on other platforms. We achieve this by moving _tpu_ext.cc to jaxlib/mlir/_mlir_libs so it can use the same linkopts as other mlir targets that depend on libjaxlib_mlir_capi.so. In particular, we want this to work correctly across platforms, and it's not clear if Windows supports RPATH-like functionality beyond the current directory. PiperOrigin-RevId: 551372130 27 July 2023, 01:25:17 UTC
bcddc50 Merge pull request #16852 from hawkinsp:builddeps PiperOrigin-RevId: 551333094 26 July 2023, 22:33:01 UTC
5c39a0d Merge pull request #16844 from jakevdp:jnp-put PiperOrigin-RevId: 551331519 26 July 2023, 22:24:37 UTC
735637e Previously, using sparse.todense on a BCSR matrix with sparse.sparsify would raise `NotImplementedError: sparse rule for todense is not implemented`. Adding the sparse rule resolves this issue. PiperOrigin-RevId: 551291543 26 July 2023, 20:01:02 UTC
3c4527b Check `build` and `wheel` are installed before building `jaxlib`. 26 July 2023, 18:46:11 UTC
416814d Merge pull request #16826 from mattjj:issue16805 PiperOrigin-RevId: 551263673 26 July 2023, 18:20:31 UTC
9e69277 Merge pull request #16849 from skye:mosaic_test_build_fix PiperOrigin-RevId: 551235536 26 July 2023, 16:50:14 UTC
88c42da Add implementation of jnp.put 26 July 2023, 15:54:54 UTC
d0b65f2 Make //jax:tpu_custom_call respect --//jax:build_jaxlib=false Otherwise jaxlib is partially built and doesn't work properly. 26 July 2023, 15:50:42 UTC
1054fe5 Merge pull request #16846 from gnecula:poly_dot PiperOrigin-RevId: 551174686 26 July 2023, 12:21:54 UTC
3baa6e7 Enable building jaxlib w/ Mosaic PiperOrigin-RevId: 551159246 26 July 2023, 10:59:30 UTC
c9f9f28 [shape_poly] Fix handling of dot_general with different lhs_dtype and rhs_dtype Add primitives tests for the case of dot_general with different lhs_dtype and rhs_dtype. Then fix the lowering to work with dynamic shapes. 26 July 2023, 10:29:12 UTC
f66d3cf Merge pull request #16842 from jakevdp:dynamic-slice-unsigned PiperOrigin-RevId: 550981737 25 July 2023, 20:37:31 UTC
0dbda84 lax.dynamic_slice: avoid negative index correction for unsigned indices 25 July 2023, 20:09:09 UTC
def6190 Merge pull request #16833 from gnecula:poly_v8_1 PiperOrigin-RevId: 550929048 25 July 2023, 17:36:35 UTC
b116616 Merge pull request #16839 from jakevdp:fix-core-deprecation PiperOrigin-RevId: 550920851 25 July 2023, 17:12:04 UTC
3b6b988 fix deprecations in core.py 25 July 2023, 16:47:04 UTC
b6ed056 [jax2tf] Add support for serialization version 8. In this version the serialized module contains a StableHLO module boolean attribute `jax.uses_shape_polymorphism` that specifies whether the module uses shape polymorphism. If it doesn't, then we do not need to do shape refinement. Note that we are still keeping the default serialization version at 6, for forward compatibility. However, the serialization unit tests now run at version 8. Made Exported.mlir_module a method instead of a property, to make it more obvious that it is a derived artifact. 25 July 2023, 05:13:34 UTC
95b410c Merge pull request #16820 from gnecula:export_ver PiperOrigin-RevId: 550769937 25 July 2023, 04:58:07 UTC
4081035 [jax2tf] Added error for attempting to use wrong jax_serialization_version. Previously, the serialization would use the specified serialization version without checking whether it is supported by the serializer. This could result in invalid serializations. Also add some compatibility tests for all supported versions. 25 July 2023, 04:42:49 UTC
14a6089 Change `mhlo.is_same_data_across_replicas` from unit attr to bool attr Using bool attrs aligns better with StableHLO. Since [VHLO does not define unit attrs](https://github.com/openxla/stablehlo/blob/main/stablehlo/dialect/VhloAttrs.td), serializing StableHLO modules containing unit attrs fails. This becomes a problem when we want to serialize MHLO modules containing `mhlo.is_same_data_across_replicas` by converting them into StableHLO then VHLO. JAX emits `mhlo.is_same_data_across_replicas` as a bool attr only after a new jaxlib version since this requires the jaxlib to understand the new attr type. PiperOrigin-RevId: 550745955 25 July 2023, 02:50:33 UTC
7821516 Make _remake internal and add return type hints PiperOrigin-RevId: 550721261 25 July 2023, 00:36:36 UTC
727af17 Merge pull request #16829 from jakevdp:has-opaque PiperOrigin-RevId: 550686011 24 July 2023, 22:11:55 UTC
e1a1377 replace use of has_opaque_dtype 24 July 2023, 21:46:58 UTC
b4132b4 Copybara import of the project: -- b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b by Jake VanderPlas <jakevdp@google.com>: Rename opaque dtype to extended dtype. This includes three deprecations: - jax.core.is_opaque_dtype(dt) is deprecated in favor of jnp.issubdtype(dt, jax.dtypes.extended) - jax.core.has_opaque_dtype(x) is deprecated in favor of jnp.issubdtype(x.dtype, jax.dtypes.extended) - the allow_opaque_dtype argument to jax.core.canonicalize_dtype is now allow_extended_dtype Because jax.core is explicitly excluded from the API deprecation policy, these changes will not be subject to a standard 3-month deprecation period. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16824 from jakevdp:extended-dtype b243ea79ae7c9e2c2aa85e264b8dca8fc4c61b7b PiperOrigin-RevId: 550674205 24 July 2023, 21:38:20 UTC
c6fa3d9 Merge pull request #16822 from jakevdp:jupytext-version PiperOrigin-RevId: 550673985 24 July 2023, 21:29:37 UTC
9ddef5c make _dot_general_batch_rule handle python builtin numeric types 24 July 2023, 21:01:07 UTC
7bb8312 CI: update jupytext to v0.14.7 24 July 2023, 18:51:45 UTC
83d99bb Merge pull request #16813 from gnecula:poly_err_msg PiperOrigin-RevId: 550530238 24 July 2023, 12:52:01 UTC
deb8fdf [shape_poly] Improve error messages for shape assertions. Starting with serialization version 7 we introduce shape assertions that are checked at runtime. In the process of rolling out version 7 we encountered projects with failed shape assertions and it became clear that we need better error messages. See the changes here in tests and README.md for examples of the updated assertions. To produce these assertions we now pass multiple operands to the shape assertion, and we introduce a CachedShapeEvaluator to reduce the amount of duplicate code generated. 24 July 2023, 11:57:06 UTC
32cbc36 Integrate LLVM at llvm/llvm-project@571c1292b693 Updates LLVM usage to match [571c1292b693](https://github.com/llvm/llvm-project/commit/571c1292b693) PiperOrigin-RevId: 550071080 21 July 2023, 22:56:28 UTC
8a1a5fa Merge pull request #16781 from jakevdp:prng-dtypes PiperOrigin-RevId: 550068690 21 July 2023, 22:45:29 UTC
7d7a536 custom prng: introduce mechanism to identify key arrays by dtype 21 July 2023, 19:27:32 UTC
1b33a4e Merge pull request #16815 from hawkinsp:py39 PiperOrigin-RevId: 550014612 21 July 2023, 19:12:47 UTC
319ab98 Apply pyupgrade --py39-plus. Notable changes: * use PEP 585 type names * use PEP 604 type union syntax where `from __future__ import annotations` is present. * use f-strings in more places. * remove redundant arguments to open(). 21 July 2023, 18:49:44 UTC
4114e6c Improve the default value of `output_shape_dtype`. PiperOrigin-RevId: 549988693 21 July 2023, 17:45:02 UTC
c8f4650 Merge pull request #16721 from jakevdp:dot-mixed-precision PiperOrigin-RevId: 549986744 21 July 2023, 17:37:11 UTC
561c953 Lower jax.numpy.dot to mixed-precision dot_general 21 July 2023, 17:10:30 UTC
884bcf4 Introduce version 8 of XlaCallModule. Previously, XlaCallModule was running the shape refinement pass for all compilations, even if the module did not use shape polymorphism. Currently shape refinement changes the structure of the module, through inlining and constant folding all integer operations. This complicates debugging because the HLO dump is very different than the one from JAX native executions. Starting with version 8, we run shape refinement only if the module contains a boolean module attribute jax.uses_shape_polymorphism=true. I think it makes sense to put this flag as a module attribute, rather than as a TF op attribute, because the same processing will be needed when the module is executed from JAX. This attribute is not yet populated by the JAX exporter. As part of this change we moved the error check for the number of invocation arguments from RefineDynamicShapes to LoadAndPreprocessModule. This required adding a couple more arguments to the loader constructor. PiperOrigin-RevId: 549973693 21 July 2023, 16:51:19 UTC
90840e4 Merge pull request #16795 from jakevdp:refactor-opaque PiperOrigin-RevId: 549962048 21 July 2023, 16:02:59 UTC
684228e Add back tracebacks to checkify's Error without leaking tracers. The trick is to save the traceback as an XLA traceback, then turn it into a python traceback only when throwing the error. No locals are leaked in the process. PiperOrigin-RevId: 549957746 21 July 2023, 15:44:26 UTC
8f9a34a Merge pull request #16802 from gnecula:ser_versions PiperOrigin-RevId: 549839838 21 July 2023, 05:25:08 UTC
71e2d28 [jax2tf] Document the JAX serialization version numbers. 21 July 2023, 05:11:44 UTC
2ffa9bd Refactor opaque dtype implementation. This makes it closer to numpy, with dtypes.OpaqueDtype analogous to np.dtype, and dtypes.opaque analogous to np.numeric. This will let us replace the dtypes.is_opaque_dtype function with jnp.issubdtype(dtype, dtypes.opaque). 21 July 2023, 02:51:52 UTC
3d556b7 Add Mosaic to Jaxlib and expose bindings in `jax.experimental.mosaic` PiperOrigin-RevId: 549801858 21 July 2023, 01:28:51 UTC
dcad04d Add support for int fields to compiler_options. PiperOrigin-RevId: 549790380 21 July 2023, 00:37:19 UTC
60bb3bc Merge pull request #16797 from google:tpu_ci_pip_update PiperOrigin-RevId: 549728792 20 July 2023, 20:56:13 UTC
aaff43c Merge pull request #16807 from jakevdp:tree-deprecation PiperOrigin-RevId: 549728774 20 July 2023, 20:48:04 UTC
2691d7e Use standard framework for jax.tree* deprecation. 20 July 2023, 19:58:17 UTC
fe30d3f Move _array_shard_arg helpers from pxla into array. Refactoring only; this fixes a TODO. Add a canonicalize argument to pxla.shard_arg so we can call that API from array yet avoid double-canonicalization. PiperOrigin-RevId: 549658117 20 July 2023, 16:48:10 UTC
08366b2 Merge pull request #15679 from mattjj:issue15676 PiperOrigin-RevId: 549656910 20 July 2023, 16:40:03 UTC
65751bb make jvp(asarray, (1.,), (2.,)) produce Arrays fixes #15676 Co-authored-by: Matthew Johnson <mattjj@google.com> 20 July 2023, 16:21:55 UTC
fa5915b Delete pxla.make_sharded_device_array. This function is unused and not exported from JAX. PiperOrigin-RevId: 549606907 20 July 2023, 12:58:35 UTC
1ceddfc Merge pull request #16710 from gnecula:poly_max0 PiperOrigin-RevId: 549515427 20 July 2023, 04:40:17 UTC
e2a49ee Tweaks the utility function `_get_ppspec_from_executable` to get the shardings directly from the executable (instead of from its HLO modules). PiperOrigin-RevId: 549473458 20 July 2023, 00:38:59 UTC
c006e52 Merge pull request #16779 from jakevdp:random-gamma PiperOrigin-RevId: 549460066 19 July 2023, 23:39:05 UTC
54adc74 Merge pull request #16794 from lgeiger:py39-version-checks PiperOrigin-RevId: 549457819 19 July 2023, 23:29:18 UTC
6c90976 Cloud TPU CI: make sure we update test deps and upgrade protobuf version `profiler_test.py:ProfilerTest.test_remote_profiler` fails with the protobuf upgrade. However, I was seeing mysterious hangs without this, and in general I think we should be testing with up-to-date deps given that we don't pin. I'm gonna continue working on getting the Cloud TPU CI green. 19 July 2023, 23:17:47 UTC
7205160 Re-parameterize jax.random.gamma for better behavior at endpoints 19 July 2023, 23:15:03 UTC
cd951f4 Merge pull request #16793 from lgeiger:numpy-version-checks PiperOrigin-RevId: 549454320 19 July 2023, 23:13:55 UTC
6812d5c Remove unneeded Python 3.9+ version checks 19 July 2023, 22:37:30 UTC
de2c854 Remove obsolete numpy version checks 19 July 2023, 22:33:47 UTC
0c4c020 Include compile time along with executable in cache entry. In order to measure cache savings, we add compilation time to the cache entry along with the serialized executable. The compile time can then be retrieved on a cache hit. Testing: updated tests. PiperOrigin-RevId: 549439628 19 July 2023, 22:17:45 UTC
5ae3ac2 Add deprecation of `jax.stages.Compiled.compiler_ir` to the change log PiperOrigin-RevId: 549415191 19 July 2023, 20:48:55 UTC
cd2dc2f Error if memory_kind is not correct for the devices in Shardings during initialization. PiperOrigin-RevId: 549410478 19 July 2023, 20:32:39 UTC
7df3477 [JAX] Use MLIR argument locations instead of a bespoke jax.arg_info attribute. https://github.com/llvm/llvm-project/commit/514dddbeba643e32310c508a15d7b6ff42f2c461 allowed for specifying argument Locations in the MLIR Python bindings. We should use them, in the form of a Name location, rather than making up our own attribute. Example of new output:
```
In [1]: import jax

In [2]: ir = jax.jit(lambda x, y: x + y).lower(7, 3).compiler_ir()

In [3]: ir.operation.print(enable_debug_info=True)
#loc1 = loc("x")
#loc2 = loc("y")
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.sharding = "{replicated}"} loc("x"), %arg1: tensor<i32> {mhlo.sharding = "{replicated}"} loc("y")) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32> loc(#loc4)
    return %0 : tensor<i32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
#loc3 = loc("<ipython-input-2-ef5a568a0c1c>":1:0)
#loc4 = loc("jit(<lambda>)/jit(main)/add"(#loc3))
```
Note debug information must be enabled. PiperOrigin-RevId: 549325621 19 July 2023, 15:39:16 UTC
f94104f Skip PgleTest.testPassingFDOProfile if xla_extension_version < 169 PiperOrigin-RevId: 549322105 19 July 2023, 15:22:32 UTC
cdb4813 [JAX] Add support for multiple pytree registries. We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries, the same tree node can be registered in one registry but not another. One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially. PiperOrigin-RevId: 549301796 19 July 2023, 13:48:21 UTC
4fdc134 [shape_poly] Add support for max0 for symbolic dimensions. There are a few cases when JAX computes `max(v, 0)`, most notably when computing the sizes of strided access, dilated convolutions and padding, and for the size of jnp.arange. Until now these cases were supported for shape polymorphism only when we could tell statically that the size is >= 0. Here we add support to the symbolic expressions for a `non_negative` operator, which essentially implements `max(v, 0)`, and with this we can now support the general case for `jnp.arange`, with simpler code. We could add a general `max` operator, and we may do so in the future, but for now `non_negative` suffices. Note that this fixes a couple of bugs: * for core.dilated_dim we had the code "if d == 0 then 0 else ..." but this works only if we can tell statically that `d == 0`, and it produced wrong results when `d` was symbolic and could take the value 0. * for core.stride_dim we did not correctly handle the case when `d < window_size`. Handling the above fundamentally requires a `max(d, 0)` operation. 19 July 2023, 13:15:04 UTC
e643f98 [shape_poly] Reimplement the shape constraint checking using shape assertions. Most of the functionality is for the JAX native serialization case. This relies on newly added functionality to xla_extension.refine_polymorphic_shapes that handles custom calls @static_assertion. As a beneficial side-effect now we get shape constraint checking for jax2tf graph serialization when the resulting function is executed in graph mode. 19 July 2023, 06:56:33 UTC
f97dca7 Merge pull request #16752 from gnecula:bc_schur PiperOrigin-RevId: 549219190 19 July 2023, 06:48:32 UTC
f29544b [jax2tf] Add backwards compatibility for lax.linalg.schur on CPU 19 July 2023, 06:39:50 UTC
3f2bff5 Merge pull request #16751 from gnecula:bc_triangular PiperOrigin-RevId: 549216345 19 July 2023, 06:31:08 UTC
0dd45dd [jax2tf] Add backwards compatibility test for lax.triangular_solve on CPU 19 July 2023, 05:40:22 UTC
4b72163 Merge pull request #16775 from froystig:random-api-policy PiperOrigin-RevId: 549122781 18 July 2023, 22:06:56 UTC
b7686f4 Enable passing fdo_profile in compiler_options in pxla.py PiperOrigin-RevId: 549109629 18 July 2023, 21:18:28 UTC
9150b23 add `jax.prng` to uncovered modules list in API policy 18 July 2023, 21:13:25 UTC
9aa5307 API compatibility policy: expand on numerics and randomness 18 July 2023, 21:13:25 UTC
579808d Add memory_kind to NamedSharding, SingleDeviceSharding, PositionalSharding and GSPMDSharding. PiperOrigin-RevId: 548997870 18 July 2023, 14:39:36 UTC
59509dc Remove the jax_array config option, which does nothing. PiperOrigin-RevId: 548981491 18 July 2023, 13:16:06 UTC
8016fb3 Merge pull request #16769 from jakevdp:fix-jax-array PiperOrigin-RevId: 548835682 17 July 2023, 23:51:13 UTC
7415913 support np.array(x) where x is a custom pytree with __jax_array__ 17 July 2023, 20:33:17 UTC
68ea651 Merge pull request #16740 from jakevdp:spdot-general-args PiperOrigin-RevId: 548773744 17 July 2023, 19:52:33 UTC
909df91 Merge pull request #16754 from patrick-kidger:patch-5 PiperOrigin-RevId: 548759179 17 July 2023, 19:00:00 UTC
1b1e74f Merge pull request #16767 from jakevdp:mypy-fix PiperOrigin-RevId: 548752096 17 July 2023, 18:35:07 UTC
8bce54e Add type annotation to `jnp.tensordot` Just stopping pyright from complaining at me. 17 July 2023, 18:30:16 UTC
4bb54d3 mypy: suppress annotation-unchecked notes 17 July 2023, 18:18:48 UTC
f123ba9 Make _parsed_pspec a kw_only argument. This should be backwards compatible since you can pass an arg as a kwarg too. PiperOrigin-RevId: 548741701 17 July 2023, 17:59:48 UTC
d8bc033 Merge pull request #16759 from patrick-kidger:patch-6 PiperOrigin-RevId: 548737104 17 July 2023, 17:45:13 UTC
3a78571 Disable tests triggering a known bug in cuda-12. PiperOrigin-RevId: 548727901 17 July 2023, 17:26:12 UTC
d49b67a Disable tests that trigger a known bug in cublasDtrsmBatched in cuda-12 on sm_60. PiperOrigin-RevId: 548727690 17 July 2023, 17:17:21 UTC
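Commit 76cda0a moves JAX's flags to the ABSL-style pattern. In that pattern, defining a flag returns a typed `FlagHolder`, and the type checker can see the type of its `.value`. The sketch below illustrates the pattern using only the standard library. It is not the JAX or ABSL implementation: `FlagHolder` is a simplified stand-in, and `define_bool`, `define_int` and the flag names are hypothetical.

```python
from __future__ import annotations

from typing import Generic, TypeVar

T = TypeVar("T")


class FlagHolder(Generic[T]):
    """Simplified typed handle to one flag (stand-in for absl.flags.FlagHolder)."""

    def __init__(self, name: str, default: T) -> None:
        self.name = name
        self._value = default

    @property
    def value(self) -> T:
        # A type checker infers T here (e.g. bool), not Any from a config dict.
        return self._value

    def set(self, new_value: T) -> None:
        # Hypothetical setter, e.g. called after command-line parsing.
        self._value = new_value


# Hypothetical definition helpers: they hand the holder back to the caller
# instead of only registering the flag in one global dictionary.
def define_bool(name: str, default: bool) -> FlagHolder[bool]:
    return FlagHolder(name, default)


def define_int(name: str, default: int) -> FlagHolder[int]:
    return FlagHolder(name, default)


# Defined in the module that reads them, as the commit suggests.
EXAMPLE_DEBUG = define_bool("example_debug", False)
EXAMPLE_CACHE_SIZE = define_int("example_cache_size", 128)


def effective_cache_size() -> int:
    # EXAMPLE_CACHE_SIZE.value is statically an int, so this type-checks.
    # (Disabling the cache while debugging is just an illustrative policy.)
    return 0 if EXAMPLE_DEBUG.value else EXAMPLE_CACHE_SIZE.value
```

Each holder is an ordinary module-level object, so a flag can be defined next to the code that consumes it rather than in a central `config.py`.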
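Commit 3c4527b checks that `build` and `wheel` are installed before building jaxlib. One way to do such a check is to probe whether the two packages are importable. The sketch below assumes that approach. `missing_build_deps` and `check_build_deps` are hypothetical helpers, not the code in JAX's build script.

```python
import importlib.util
import sys


def missing_build_deps(modules=("build", "wheel")):
    """Return the names in `modules` that cannot be imported in this interpreter."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


def check_build_deps():
    """Exit with an actionable message before the build starts if deps are missing."""
    missing = missing_build_deps()
    if missing:
        sys.exit(
            "Missing Python packages required to build jaxlib: "
            + ", ".join(missing)
            + "\nInstall them with: pip install "
            + " ".join(missing)
        )
```

A build script would call `check_build_deps()` before doing any expensive work. That way a missing package is reported up front rather than partway through the build.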
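Commit 88c42da adds `jnp.put`, which follows NumPy's `put`. NumPy's `put` writes values into the flattened array. If `v` is shorter than `ind`, its values are repeated as needed. `mode` controls out-of-range indices: `'raise'`, `'wrap'` or `'clip'`. JAX arrays are immutable, so a JAX version must return a new array rather than mutate its input. The pure-Python sketch below works on flat lists and only illustrates these semantics. `put_functional` is a hypothetical name, and this is not JAX's implementation.

```python
def put_functional(a, ind, v, mode="raise"):
    """Return a copy of flat list `a` with out[ind[k]] = v[k % len(v)] (NumPy-style put)."""
    out = list(a)  # never mutate the input, mirroring immutable JAX arrays
    n = len(out)
    if not v:
        return out  # nothing to write
    for k, i in enumerate(ind):
        if mode == "wrap":
            i %= n  # out-of-range indices wrap around
        elif mode == "clip":
            i = min(max(i, 0), n - 1)  # clamp; negative indices become 0
        elif mode == "raise":
            if not -n <= i < n:
                raise IndexError(f"index {i} is out of bounds for size {n}")
        else:
            raise ValueError(f"unknown mode {mode!r}")
        out[i] = v[k % len(v)]  # values repeat if v is shorter than ind
    return out
```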
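Commit 319ab98 applies `pyupgrade --py39-plus`. The snippet below is an illustrative before/after of the annotation and string-formatting rewrites that the commit lists. `summarize` and `describe` are made-up functions, not code from JAX.

```python
from __future__ import annotations

# Before `pyupgrade --py39-plus` this might have read:
#
#   from typing import Dict, List, Optional
#
#   def summarize(names: List[str]) -> Optional[Dict[str, int]]: ...
#   def describe(names: List[str]) -> str:
#       return "count={}".format(len(names))


# After: PEP 585 builtin generics (list[str], dict[str, int]), a PEP 604 union
# (X | None, allowed here by the __future__ import), and an f-string.
def summarize(names: list[str]) -> dict[str, int] | None:
    """Map each name to its length, or return None for an empty list."""
    if not names:
        return None
    return {name: len(name) for name in names}


def describe(names: list[str]) -> str:
    return f"count={len(names)}"

# Likewise `open(path, "r")` becomes `open(path)`, since "r" is the default mode.
```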
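Commit 884bcf4 introduces a version-8 rule for when XlaCallModule runs shape refinement, which can be restated as a small predicate. This is a hypothetical sketch: the function name and the dict standing in for the module's attributes are assumptions. Only the attribute name `jax.uses_shape_polymorphism` and the version cutoff come from the commit message.

```python
def needs_shape_refinement(version: int, module_attributes: dict) -> bool:
    """Hypothetical restatement of when XlaCallModule refines shapes."""
    if version < 8:
        # Before version 8: shape refinement ran for every compilation.
        return True
    # From version 8: only when the module declares it uses shape polymorphism.
    return module_attributes.get("jax.uses_shape_polymorphism", False) is True
```

Skipping refinement for modules that are not polymorphic keeps the HLO dump close to what a native JAX execution produces. That is the debugging benefit the commit cites.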
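Commit 684228e keeps checkify's tracebacks without leaking locals. It stores a lightweight traceback and turns it into a Python traceback only when the error is thrown. The sketch below is a stdlib analogy of that idea, not checkify's code (which uses an XLA traceback). `traceback.extract_stack()` records file names, line numbers and function names without holding frame objects, so the caller's locals are not kept alive. `DeferredError` and `check_positive` are hypothetical names.

```python
import traceback


class DeferredError:
    """Remembers where a check failed; renders a readable traceback only when thrown."""

    def __init__(self, msg: str) -> None:
        self.msg = msg
        # A StackSummary holds file names, line numbers and function names,
        # not live frame objects, so the caller's locals are not kept alive.
        # Drop the last entry, which is this __init__ frame.
        self.stack = traceback.extract_stack()[:-1]

    def throw(self) -> None:
        where = "".join(traceback.format_list(self.stack))
        raise RuntimeError(f"{self.msg}\nCheck failed at:\n{where}")


def check_positive(x):
    """Return None if x > 0, otherwise a DeferredError to raise later."""
    return None if x > 0 else DeferredError(f"expected x > 0, got {x}")
```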
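Commit 2691d7e routes the `jax.tree*` deprecations through a shared framework. A common way to build such a framework is a module-level `__getattr__` (PEP 562) that emits a `DeprecationWarning` for listed names. The sketch below assumes that approach. `deprecation_getattr`, its `(message, replacement)` table format, and `mylib` are illustrative, not taken from JAX.

```python
import warnings


def deprecation_getattr(module_name, deprecations):
    """Build a module-level __getattr__ (PEP 562) for deprecated names.

    `deprecations` maps name -> (message, obj). If obj is None the name has
    been removed and access raises; otherwise access warns and returns obj.
    """

    def __getattr__(name):
        if name in deprecations:
            message, obj = deprecations[name]
            if obj is None:
                raise AttributeError(
                    f"module {module_name!r} has no attribute {name!r}. {message}"
                )
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return obj
        raise AttributeError(f"module {module_name!r} has no attribute {name!r}")

    return __getattr__


# Usage inside a hypothetical mylib/__init__.py:
#
#   from mylib._src.tree import map as _tree_map
#   __getattr__ = deprecation_getattr(__name__, {
#       "tree_map": ("mylib.tree_map is deprecated; use mylib.tree.map instead.", _tree_map),
#   })
```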
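Commit 0c4c020 stores the compile time alongside the serialized executable, so that a cache hit can report how much compilation it saved. The log does not describe the entry format. The layout below is an assumption for illustration only: an 8-byte little-endian millisecond count followed by the executable bytes. The helper names are also hypothetical.

```python
from __future__ import annotations

import struct
import time

# Assumed layout: compile time in milliseconds as an unsigned 64-bit
# little-endian integer, followed by the serialized executable bytes.
_HEADER = struct.Struct("<Q")


def make_cache_entry(serialized_executable: bytes, compile_time_secs: float) -> bytes:
    compile_time_ms = int(round(compile_time_secs * 1000))
    return _HEADER.pack(compile_time_ms) + serialized_executable


def read_cache_entry(entry: bytes) -> tuple[bytes, float]:
    """Return (serialized_executable, compile_time_secs)."""
    (compile_time_ms,) = _HEADER.unpack_from(entry, 0)
    return entry[_HEADER.size:], compile_time_ms / 1000.0


def get_or_compile(cache: dict, key: str, compile_fn) -> tuple[bytes, float | None]:
    """Return (executable, seconds saved); seconds saved is None on a cache miss."""
    if key in cache:
        # A hit avoids roughly the compile time recorded in the entry.
        return read_cache_entry(cache[key])
    start = time.monotonic()
    executable = compile_fn()
    cache[key] = make_cache_entry(executable, time.monotonic() - start)
    return executable, None
```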
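Commit cdb4813 adds support for multiple pytree registries. The toy model below uses two registries in which the same type is a container in one and an opaque leaf in the other. This mirrors the commit's motivating case: PRNG key arrays that `jit` should look inside while other transformations leave them alone. `PyTreeRegistry`, `KeyArray` and the registry variables are hypothetical; this is not the `jax.tree_util` API.

```python
from __future__ import annotations

from typing import Any, Callable

Flatten = Callable[[Any], "tuple[list, Any]"]  # node -> (children, aux data)


class PyTreeRegistry:
    """Toy registry mapping container types to flatten functions."""

    def __init__(self) -> None:
        self._flatten: dict[type, Flatten] = {
            list: lambda xs: (list(xs), None),
            tuple: lambda xs: (list(xs), None),
        }

    def register(self, typ: type, flatten: Flatten) -> None:
        self._flatten[typ] = flatten

    def leaves(self, tree: Any) -> list:
        """Return the leaves of `tree` as seen by this registry."""
        flatten = self._flatten.get(type(tree))
        if flatten is None:
            return [tree]  # types unknown to this registry are opaque leaves
        children, _ = flatten(tree)
        out: list = []
        for child in children:
            out.extend(self.leaves(child))
        return out


class KeyArray:
    """Stand-in for a custom PRNG key array type (hypothetical)."""

    def __init__(self, data):
        self.data = data


default_registry = PyTreeRegistry()
jit_registry = PyTreeRegistry()
# Only the jit-specific registry looks inside KeyArray.
jit_registry.register(KeyArray, lambda k: ([k.data], None))
```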
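Commit 4fdc134 adds a `non_negative` operator, which is `max(v, 0)`. With concrete integers, each size formula the commit mentions can be written as one expression wrapped in `non_negative`. This is what keeps the formulas valid when `d` is symbolic and may be 0, or smaller than the window. These are illustrative integer-only sketches, not JAX's `core.dilated_dim` or `core.stride_dim` code. They assume a dilation of at least 1, a positive window stride, and a nonzero `arange` step.

```python
def non_negative(v: int) -> int:
    """max(v, 0), the operator the commit adds for symbolic dimensions."""
    return max(v, 0)


def dilated_dim(d: int, dilation: int) -> int:
    """Size after inserting (dilation - 1) holes between d elements.

    The old form `0 if d == 0 else 1 + dilation * (d - 1)` needs d == 0 to be
    decidable; for dilation >= 1 this single expression already yields 0 at d == 0.
    """
    return non_negative(1 + dilation * (d - 1))


def stride_dim(d: int, window_size: int, window_stride: int) -> int:
    """Number of window positions; clamps to 0 when d < window_size."""
    return non_negative((d - window_size) // window_stride + 1)


def arange_size(start: int, stop: int, step: int) -> int:
    """len(range(start, stop, step)): ceil((stop - start) / step), clamped at 0."""
    return non_negative(-((start - stop) // step))
```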