https://github.com/google/jax

Revision Message Commit Date
8fb4e08 [XLA:Mosaic] Add type-convert-insertion pass and support shapecast a vector whose dtype is packed and sublane dim size is 1. PiperOrigin-RevId: 619787143 28 March 2024, 05:31:51 UTC
9e86aa5 Add custom call on output along with S(5) because XLA requires the custom call to show the transfer. Enable parameter streaming and weight offloading. PiperOrigin-RevId: 619711649 28 March 2024, 00:07:36 UTC
ec73c40 Do not deadlock the GPU if a pure_callback dispatches a GPU kernel PiperOrigin-RevId: 619656442 27 March 2024, 21:26:03 UTC
7f30ab5 Supports converting PjRt CPU arrays to Python arrays when the CPU PjRt arrays have non-default layouts. PiperOrigin-RevId: 619608909 27 March 2024, 19:14:39 UTC
66877c9 Allow allow_spmd_propagation_to_output to be generated for outputs annotated with pjit.AUTO PiperOrigin-RevId: 619608022 27 March 2024, 19:04:03 UTC
c0c918a [export] Increase minimum serialization version to 9. Stop supporting serialization with older versions. The current max serialization version 9 has been supported since October 27th, 2023 and has been the default since February 1, 2024. This change could break clients that set a specific JAX serialization version lower than 9. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions PiperOrigin-RevId: 619588685 27 March 2024, 18:06:23 UTC
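A minimal sketch of the guard this implies for clients that pin a serialization version. The helper `check_serialization_version` and the constant names are hypothetical and are not part of the JAX export API; only the value 9 comes from the message above.

```python
# Constants taken from the commit message: version 9 is now both the
# minimum and the current maximum supported serialization version.
MIN_SERIALIZATION_VERSION = 9
MAX_SERIALIZATION_VERSION = 9


def check_serialization_version(requested: int) -> int:
    """Return `requested` if it is supported, otherwise raise ValueError."""
    if not MIN_SERIALIZATION_VERSION <= requested <= MAX_SERIALIZATION_VERSION:
        raise ValueError(
            f"serialization version {requested} is not supported; expected "
            f"{MIN_SERIALIZATION_VERSION} <= version <= {MAX_SERIALIZATION_VERSION}"
        )
    return requested
```

Under this sketch, a client that pins version 8 fails the check instead of silently producing an older artifact.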
023930d Fix some load orderings for buildifier PiperOrigin-RevId: 619575196 27 March 2024, 17:28:57 UTC
b480fd5 Update XLA dependency to use revision http://github.com/openxla/xla/commit/aa4d4447ce3ff449d7accd5b8a41d1575ed58b75. PiperOrigin-RevId: 619429908 27 March 2024, 06:37:47 UTC
fa9f02b Reverts 0dde8f7f9607d09841ece7125dfc0773c3613fab PiperOrigin-RevId: 619416732 27 March 2024, 05:26:41 UTC
0dde8f7 Merge pull request #20445 from mattjj:scan-dont-traverse-body-jaxpr-in-lowering-2 PiperOrigin-RevId: 619387957 27 March 2024, 02:52:11 UTC
0b09762 Guard host transfers inside pure_callbacks from deadlocking the TPU. Also fix python/callback.cc to not swallow errors in numpy conversions. PiperOrigin-RevId: 619375128 27 March 2024, 01:36:39 UTC
9474b46 [scan] don't traverse body jaxpr in lowering This is an attempt to re-land #19819 aka cl/607570860 after a small number of performance regressions. As before, the main changes are: 1. simplify the scan impl that we trace through to get the lowering, and 2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body jaxpr we already have in hand. The main motivation was (2), but (1) seems like a useful win too. The way we achieve (2) is with a new trick: in our scan_impl function, which is only ever traced to a jaxpr, instead of calling `core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive `eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and that rule just generates a call into the jaxpr of interest. Therefore we will not traverse into the jaxpr just to rebuild it inline (as before). The code in #19819 was simpler in that it avoided reshapes, concats, and un-concats. But it caused at least one apparent performance regression (an XLA bug?) and it was unrelated to the original goal of reducing tracing time. So here we just land the trace time improvement. 27 March 2024, 00:17:58 UTC
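A toy model in plain Python of the difference described above, not JAX's internals: re-tracing a body equation by equation (the `core.jaxpr_as_fun` route) versus a primitive whose only rule is a staging rule that emits one call equation holding a reference to the existing body (the `eval_jaxpr_p` route). `Eqn`, `Builder`, `stage_inline`, and `stage_call` are hypothetical names.

```python
from dataclasses import dataclass, field


@dataclass
class Eqn:
    """One equation in a toy jaxpr: a primitive name plus parameters."""
    prim: str
    params: dict = field(default_factory=dict)


@dataclass
class Builder:
    """Toy trace that records emitted equations and counts body visits."""
    eqns: list = field(default_factory=list)
    visited: int = 0

    def emit(self, prim, **params):
        self.eqns.append(Eqn(prim, params))


def stage_inline(builder, body):
    """Old style: walk every body equation and re-emit it one by one."""
    for eqn in body:
        builder.visited += 1
        builder.emit(eqn.prim, **eqn.params)


def stage_call(builder, body):
    """Staging-rule style: emit a single call equation that references
    the body we already have, without traversing it."""
    builder.emit("eval_jaxpr", jaxpr=body)
```

With a three-equation body, `stage_inline` visits and re-emits all three equations, while `stage_call` emits one equation and visits none, so its cost does not grow with the size of the body.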
4a9c8d1 Removed obsolete call to libtpu_module.configure_library_path(). PiperOrigin-RevId: 619340619 26 March 2024, 23:06:41 UTC
6e0c955 Remove the canonicalization to GSPMDSharding internally in jit. This is not required anymore since the caches are split into tracing, lowering and compilation. The canonicalization doesn't provide any value anymore and only makes the internals more complicated. The canonicalization can be done by lowering to HloSharding in places where required and there are utilities to help with that. PiperOrigin-RevId: 619292757 26 March 2024, 20:28:45 UTC
f8d6669 Add commentary and clean the pallas bidirectional collective allgather matmul test. "test_pipeline_all_gather_matmul" is the best demo example of nested pallas pipelines, but it's hard to follow the logic in the existing test. A few changes were made there:
- rename things to avoid confusion between outer and inner loop prologues / epilogues.
- give clear names for the outer iteration space: (step, phase) to help clarify sequencing of compute and DMAs.
- simplify and lift out all async copy definitions and add commentary on their function.
- remove some incorrect comments about the rDMA schedule, and generally add a ton of commentary about when things happen in the outer pipeline.
- lift all the outer prologue work into an integrated prologue function.
- various other small things.
PiperOrigin-RevId: 619254981 26 March 2024, 18:31:38 UTC
76af8be Merge pull request #20436 from pearu:pearu/expm1-2 PiperOrigin-RevId: 619253998 26 March 2024, 18:21:55 UTC
c82f061 Update complex function accuracy tests for expm1 26 March 2024, 17:10:38 UTC
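The accuracy of complex `expm1` hinges mostly on small arguments, where the naive `exp(z) - 1` cancels. Below is a short sketch of a standard textbook rewriting that avoids this. It is not necessarily how `jax.numpy.expm1` is implemented, and `complex_expm1` is a hypothetical name. With z = x + iy, the real part e^x cos y - 1 is rewritten as expm1(x) cos y - 2 sin^2(y/2).

```python
import math


def complex_expm1(z: complex) -> complex:
    """exp(z) - 1 for complex z, avoiding cancellation near z = 0.

    Uses cos(y) - 1 = -2 sin^2(y / 2), so the real part becomes
    expm1(x) * cos(y) - 2 * sin(y / 2) ** 2 with no subtraction of
    nearly equal quantities when |z| is small.
    """
    x, y = z.real, z.imag
    real = math.expm1(x) * math.cos(y) - 2.0 * math.sin(0.5 * y) ** 2
    imag = math.exp(x) * math.sin(y)
    return complex(real, imag)
```

For z = 1e-20 + 1e-20j, the naive `cmath.exp(z) - 1` loses the real part entirely, while this form keeps it.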
18c885d Removed double-printing of TTIR in Pallas GPU lowering PiperOrigin-RevId: 619208376 26 March 2024, 16:11:39 UTC
c78054d Fix the pjit test failing on v5e PiperOrigin-RevId: 619207394 26 March 2024, 16:02:13 UTC
6afa523 Skip fused_attention_stablehlo_test.py. https://github.com/google/jax/issues/20438 PiperOrigin-RevId: 619204356 26 March 2024, 15:50:50 UTC
75db481 [callback] Fix io_callback for callbacks that return Python literals. The internal implementation of io_callback and friends currently uses .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal. Fixed the checks that the callback returns values of the expected shape and dtype, and added tests. Reverts 19e6156ccec0df7a900471df7840bc421da2898b PiperOrigin-RevId: 619156176 26 March 2024, 12:32:41 UTC
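A pure-Python sketch of a shape check that tolerates Python literals by treating them as rank-0 values. `result_shape` and `check_callback_result` are hypothetical helpers, and the dtype half of the check is omitted. An alternative approach is to convert results with `np.asarray` before inspecting them.

```python
def result_shape(value):
    """Shape of a callback result, treating Python scalars as rank 0.

    Reading `value.shape` directly raises AttributeError for literals
    such as `1.0`, which is the failure described in the commit message.
    """
    if isinstance(value, (bool, int, float, complex)):
        return ()
    shape = getattr(value, "shape", None)
    if shape is None:
        raise TypeError(f"cannot determine the shape of a {type(value).__name__}")
    return tuple(shape)


def check_callback_result(value, expected_shape):
    """Return `value` unchanged if its shape matches, else raise ValueError."""
    actual = result_shape(value)
    if actual != tuple(expected_shape):
        raise ValueError(
            f"callback returned shape {actual}, expected {tuple(expected_shape)}"
        )
    return value
```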
33cf53c [XLA:GPU] Add option to return FDO profile as textproto. PiperOrigin-RevId: 619105468 26 March 2024, 08:35:27 UTC
86a7086 Update XLA dependency to use revision http://github.com/openxla/xla/commit/25d813bb12c59d05d731011103f2e7d28483f5d2. PiperOrigin-RevId: 619073296 26 March 2024, 05:53:00 UTC
f93c320 Enable extra args with input output aliasing PiperOrigin-RevId: 619041158 26 March 2024, 03:05:33 UTC
69980a2 Use the information in allow_spmd_sharding_propagation_to_output and allow_spmd_sharding_propagation_to_parameters to determine which input and output tuple elements we are allowed to modify the shardings of. PiperOrigin-RevId: 619013275 26 March 2024, 00:46:52 UTC
c724eab Merge pull request #20257 from Cjkkkk:sdpa_training PiperOrigin-RevId: 618972494 25 March 2024, 22:09:52 UTC
f51d80e move checks to setup 25 March 2024, 21:11:11 UTC
e3bbd67 Avoid jax_explain_cache_misses unpacking error. PiperOrigin-RevId: 618931412 25 March 2024, 19:55:00 UTC
0be07e6 Remove support for CUDA 11. Pin minimal required versions for CUDA to 12.1. Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5 PiperOrigin-RevId: 618911489 25 March 2024, 18:46:39 UTC
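A minimal sketch of a minimum-version check against the 12.1 pin above. `check_cuda_version` and `parse_version` are hypothetical helpers, not JAX's own check. The sketch compares integer tuples rather than strings, because as strings "12.10" sorts before "12.9".

```python
MIN_CUDA_VERSION = (12, 1)  # minimum pinned by the commit message


def parse_version(text: str) -> tuple:
    """Turn a dotted version string such as '12.3' into (12, 3)."""
    return tuple(int(part) for part in text.split("."))


def check_cuda_version(installed: str) -> None:
    """Raise RuntimeError if `installed` is older than the pinned minimum."""
    if parse_version(installed)[:2] < MIN_CUDA_VERSION:
        raise RuntimeError(
            f"CUDA {installed} is not supported; CUDA "
            f"{'.'.join(map(str, MIN_CUDA_VERSION))} or newer is required"
        )
```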
19e6156 Reverts 2a4e1caac465bb4cb448e7370b23a822cce699e2 PiperOrigin-RevId: 618908505 25 March 2024, 18:36:13 UTC
dacf009 add 4 gpus guard for inference 25 March 2024, 17:19:48 UTC
25d01e9 [Take 2] Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit. Reverts cd79e71d85621a8d6dede9a710bdb2a29bb380fd PiperOrigin-RevId: 618878870 25 March 2024, 17:08:43 UTC
b9e699f Pallas TPU now broadcasts operands when lowering bitwise ops I also used this as an opportunity to bootstrap shared GPU/TPU tests. Closes #20135 PiperOrigin-RevId: 618858137 25 March 2024, 16:00:39 UTC
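Broadcasting operands before a bitwise op first means bringing them to a common shape. Below is a pure-Python sketch of the NumPy broadcasting rule, which `jax.numpy` follows, for computing that shape. `broadcast_shapes` here is a local helper and is not the Pallas TPU lowering code.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Compute the NumPy-style broadcast of several shapes."""
    result = []
    # Walk dimensions from the trailing end; missing leading dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes for broadcasting: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))
```

For example, operands of shapes (8, 1) and (1, 128) both broadcast to (8, 128).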
2a4e1ca [callback] Fix io_callback for callbacks that return Python literals. The internal implementation of io_callback and friends currently uses .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal. Fixed the checks that the callback returns values of the expected shape and dtype, and added tests. PiperOrigin-RevId: 618814787 25 March 2024, 12:53:52 UTC
3d8ffd4 Update XLA dependency to use revision http://github.com/openxla/xla/commit/807c6f8dd3891e458f4cc4cd9f3c06dfdfb52484. PiperOrigin-RevId: 618737249 25 March 2024, 06:12:38 UTC
0e5c44a Update XLA dependency to use revision http://github.com/openxla/xla/commit/ec483ea69f3e750cf96674f7f6fdab6604af7f8d. PiperOrigin-RevId: 618558052 24 March 2024, 06:15:12 UTC
1f81477 Adds an additional check for whether x is a Tracer, as looking up the sharding attribute on a tracer is expensive. PiperOrigin-RevId: 618489738 23 March 2024, 20:33:02 UTC
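A toy sketch of the ordering: run the cheap `isinstance` test before any attribute lookup. `Tracer` here is a stand-in class that counts its slow lookups, and `maybe_sharding` is a hypothetical helper. Neither is JAX's code.

```python
class Tracer:
    """Toy stand-in for a JAX tracer whose attribute lookups are costly."""

    lookups = 0  # how many times the slow `sharding` lookup ran

    @property
    def sharding(self):
        Tracer.lookups += 1
        raise AttributeError("toy tracer has no concrete sharding")


def maybe_sharding(x):
    """Return x.sharding for concrete values, None for tracers.

    The isinstance check is cheap, so it runs first and tracers never
    reach the (expensive) attribute lookup.
    """
    if isinstance(x, Tracer):
        return None
    return getattr(x, "sharding", None)
```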
30674ae Update XLA dependency to use revision http://github.com/openxla/xla/commit/8a74354fe12808dc04badcc38148d0e909defee1. PiperOrigin-RevId: 618379086 23 March 2024, 06:13:34 UTC
b709925 Update the arg description. `slice_index` attribute has been added for GPU. PiperOrigin-RevId: 618354455 23 March 2024, 03:10:18 UTC
6f0737b [Pallas TPU] Add support for dynamically sized (tile-aligned) DMAs This change adds limited support for dynamically sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a variable in your kernel. This slice can be used in a view that is used in a DMA. For example:
```python
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
  # Dynamic copy length, read from a scalar ref in SMEM.
  size = size_smem_ref[0]
  pltpu.async_copy(
      x_hbm_ref.at[pl.ds(0, size)],
      o_hbm_ref.at[pl.ds(0, size)],
      sem).wait()
```
We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA. We're augmenting the Slice class to support dynamic sizes, which are then plumbed down into indexing primitives like pl.load and pltpu.async_copy. However, values and Refs in Pallas kernels still have a static shape requirement, so we can't do dynamic loads/stores. As a result, we won't touch any abstract evaluation rules (since they will still all consume and return ShapedArrays). We can, however, do dynamically sized DMAs between statically shaped Refs. While this isn't arbitrary dynamic shapes, we hope this enables some interesting new kernels. PiperOrigin-RevId: 618322737 22 March 2024, 23:59:32 UTC
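The following is a plain-Python model of the semantics only, not of Pallas or the TPU DMA engine. Both buffers keep a static length, while the copied extent is dynamic and must be tile aligned. `TILE = 8` and `dynamic_copy` are hypothetical; actual TPU tile sizes may depend on dtype and memory layout.

```python
TILE = 8  # hypothetical tile size in rows


def dynamic_copy(src, dst, start, size):
    """Copy rows [start, start + size) from src into dst in place.

    len(src) and len(dst) never change (static shapes); only the number
    of rows copied is decided at run time.
    """
    if size % TILE or start % TILE:
        raise ValueError("start and size must be multiples of the tile size")
    if start < 0 or start + size > min(len(src), len(dst)):
        raise IndexError("dynamic slice out of bounds")
    dst[start:start + size] = src[start:start + size]
    return dst
```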
6ffd55c Fixing StableHLO python dependencies on stablehlo:reference_api PiperOrigin-RevId: 618294054 22 March 2024, 21:52:06 UTC
910a31d Reverts bed4f65438a62777ed100ecec2b0eb3f7cf87a0e PiperOrigin-RevId: 618249855 22 March 2024, 19:10:53 UTC
a853539 Remove Bazel workspace entries for nanobind and robin_map. The XLA repository has these, no need to duplicate them. PiperOrigin-RevId: 618248779 22 March 2024, 19:01:32 UTC
0b46341 Don't report origin_msg if any exception is raised in `self._origin_msg` PiperOrigin-RevId: 618237231 22 March 2024, 18:23:46 UTC
d7e5dde Remove _maybe_device_put because jax.device_put accepts `None` on the device parameter PiperOrigin-RevId: 618223250 22 March 2024, 17:39:57 UTC
5f467b9 Propagate sharding of inputs to full_like that are capable of carrying sharding as an attribute. Fixes https://github.com/google/jax/issues/20390 PiperOrigin-RevId: 618202319 22 March 2024, 16:30:37 UTC
bed4f65 Remove support for CUDA 11. Pin minimal required versions for CUDA to 12.1. PiperOrigin-RevId: 618195554 22 March 2024, 16:05:39 UTC
c82deb2 Merge pull request #20373 from pearu:pearu/complex-plane-tests-fix PiperOrigin-RevId: 618188699 22 March 2024, 15:39:53 UTC
6695a85 Merge pull request #20370 from jakevdp:key-reuse-flag PiperOrigin-RevId: 618179711 22 March 2024, 15:16:20 UTC
f6ed624 Merge pull request #20342 from ROCm:rocm-export_test-add-rocm-platform PiperOrigin-RevId: 618179680 22 March 2024, 15:06:43 UTC
44be575 Compute axis_index without creating an entire grid of device IDs. For large meshes, this numpy array can exceed the size of SMEM. We can perform the same calculation using just the grid shape. PiperOrigin-RevId: 618167202 22 March 2024, 14:12:52 UTC
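This sketch shows the arithmetic. It assumes that device IDs are laid out in row-major order over the grid, which is an assumption; the actual Pallas mesh layout may differ. `axis_index` and `axis_index_via_grid` are hypothetical helpers, not `jax.lax.axis_index`. The stride-based version needs only the grid shape, whereas the reference version enumerates every grid coordinate.

```python
from itertools import product
from math import prod


def axis_index(device_id, grid_shape, axis):
    """Coordinate of `device_id` along `axis` of a row-major device grid.

    O(ndim) work using only the grid shape; no grid of IDs is built.
    """
    stride = prod(grid_shape[axis + 1:])
    return (device_id // stride) % grid_shape[axis]


def axis_index_via_grid(device_id, grid_shape, axis):
    """Reference version that walks all prod(grid_shape) coordinates."""
    for d, coords in enumerate(product(*(range(n) for n in grid_shape))):
        if d == device_id:
            return coords[axis]
    raise ValueError("device_id out of range")
```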
2d65571 Really skip exp2 in Pallas GPU tests with older jaxlib PiperOrigin-RevId: 618149873 22 March 2024, 12:42:38 UTC
8949a63 [key reuse] rename flag to jax_debug_key_reuse 22 March 2024, 12:37:30 UTC
cd79e71 Reverts 0e092a77067dbbce33cfd6d54a46e743b779919b PiperOrigin-RevId: 618127324 22 March 2024, 10:46:09 UTC
c2d9528 Update XLA dependency to use revision http://github.com/openxla/xla/commit/c408b6320bbf8407b781487a2dbe329cb7488dac. PiperOrigin-RevId: 618071726 22 March 2024, 06:06:53 UTC
0e092a7 Expose `.layout` on jax.Array. Also add checks in the AOT path to make sure that the input `Array`'s layout matches the layout given to `jax.jit`. PiperOrigin-RevId: 618050680 22 March 2024, 04:02:40 UTC
73aadbf Adding expectations to the e2e test for mesh creation that allows splitting physical axes. PiperOrigin-RevId: 618028683 22 March 2024, 01:58:06 UTC
9523547 Add a fast path for Python scalars to shaped_abstractify. PiperOrigin-RevId: 618015741 22 March 2024, 01:05:10 UTC
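A toy sketch of a type-keyed fast path in front of a general abstractification routine. `abstractify_with_fast_path` and `_SCALAR_AVALS` are hypothetical names, and avals are modelled as `(dtype name, shape)` tuples. The dtypes listed are JAX's defaults for Python scalars when 64-bit mode is disabled; JAX also marks these as weakly typed, which is omitted here.

```python
# Hypothetical table: exact Python scalar type -> (dtype name, shape).
_SCALAR_AVALS = {
    bool: ("bool", ()),
    int: ("int32", ()),
    float: ("float32", ()),
    complex: ("complex64", ()),
}


def abstractify_with_fast_path(x, slow_path):
    """Return an aval for x, skipping slow_path for plain Python scalars.

    Lookup is by exact type, so bool (a subclass of int) gets its own
    entry and any other subclass falls through to slow_path.
    """
    aval = _SCALAR_AVALS.get(type(x))
    if aval is not None:
        return aval
    return slow_path(x)
```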
07e45c3 Merge pull request #20236 from jakevdp:key-reuse-stack PiperOrigin-RevId: 618014760 22 March 2024, 00:55:57 UTC
d57bb8c Raise a better error message when an invalid input is passed to a jit call. Before:
```
TypeError: Argument 'ShapeDtypeStruct(shape=(4, 2), dtype=int32)' of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.
```
After:
```
TypeError: Argument 'x['b']['c']' of shape int32[4,2] of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type.
```
The error is raised deep down the stack during `shard_arg`, so we raise an `InvalidInputException` and catch it in `_python_pjit_helper`, where we have the `arg_names` information. PiperOrigin-RevId: 618014044 22 March 2024, 00:46:32 UTC
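A pure-Python sketch of how per-leaf names such as `x['b']['c']` can be derived from a nested argument. `leaf_paths` and `first_invalid_arg` are hypothetical helpers, and the shape part of the message is omitted. Dict keys are visited in sorted order, as JAX's pytree flattening also does.

```python
def leaf_paths(tree, prefix):
    """Yield (name, leaf) pairs, naming each leaf by its access path.

    Handles dicts (keys in sorted order), lists and tuples; any other
    value is treated as a leaf.
    """
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from leaf_paths(tree[key], f"{prefix}[{key!r}]")
    elif isinstance(tree, (list, tuple)):
        for i, child in enumerate(tree):
            yield from leaf_paths(child, f"{prefix}[{i}]")
    else:
        yield prefix, tree


def first_invalid_arg(arg_name, tree, is_valid):
    """Return an error message naming the first invalid leaf, or None."""
    for name, leaf in leaf_paths(tree, arg_name):
        if not is_valid(leaf):
            return (f"Argument '{name}' of type {type(leaf)} "
                    "is not a valid JAX type.")
    return None
```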
7f7e0c0 [Mosaic] Support left shifting relayouts PiperOrigin-RevId: 618008857 22 March 2024, 00:20:30 UTC
8a2ba76 Enable Creating Device Mesh with Physical Axes Splits. PiperOrigin-RevId: 617994892 21 March 2024, 23:21:25 UTC
5532e55 [XLA:Python] Add a C++ implementation of flatten_one_level. Also add a copy of the default registry that doesn't have None registered as a leaf, which is slightly faster than using an is_leaf function. This is mostly just doing an old TODO. PiperOrigin-RevId: 617988496 21 March 2024, 22:57:23 UTC
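This pure-Python sketch shows what "flatten one level" means: split a single pytree node into its immediate children plus a function that rebuilds the node. It is not the C++ implementation. `flatten_one_level` here supports only `None`, dicts, and plain lists/tuples. Its `none_is_leaf` flag is a hypothetical parameter that mirrors the registry variant mentioned in the commit message, where `None` is a leaf rather than a node with no children.

```python
def flatten_one_level(tree, none_is_leaf=False):
    """Split one pytree node into (children, rebuild).

    Raises ValueError when `tree` is a leaf and has no children to split.
    """
    if tree is None and not none_is_leaf:
        return [], lambda children: None
    if isinstance(tree, dict):
        keys = sorted(tree)
        return [tree[k] for k in keys], lambda children: dict(zip(keys, children))
    if type(tree) in (list, tuple):
        node_type = type(tree)
        return list(tree), lambda children: node_type(children)
    raise ValueError(f"{tree!r} is a leaf and cannot be flattened one level")
```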
05e61ed Expose API to control whether to fuse input computation with Pallas kernel on a per-input basis. PiperOrigin-RevId: 617975104 21 March 2024, 22:07:18 UTC
fdb5015 Evaluate the correctness of JAX complex functions using mpmath as a reference 21 March 2024, 21:35:29 UTC
4c7351f Relax regex matching for DebugString output PiperOrigin-RevId: 617954803 21 March 2024, 20:56:25 UTC
7e60331 [key reuse] print information about key reuse location 21 March 2024, 20:34:26 UTC
2848cda Merge pull request #20341 from ROCm:rocm_add_hipStreamWaitEvent PiperOrigin-RevId: 617893634 21 March 2024, 17:41:38 UTC
291a5cd [PJRT][IFRT] Update PJRT, IFRT, and Py executable getters to return PjRtLayouts PiperOrigin-RevId: 617889924 21 March 2024, 17:30:57 UTC
383ae41 Attempt to eliminate flakiness for jax2tf test. PiperOrigin-RevId: 617878818 21 March 2024, 17:01:44 UTC
c945766 Skip PallasOpsTest.test_elementwise_exp2 on older jaxlib PiperOrigin-RevId: 617873650 21 March 2024, 16:46:13 UTC
0a0d65a Merge pull request #20359 from eltociear:patch-8 PiperOrigin-RevId: 617870272 21 March 2024, 16:35:12 UTC
54d8bde Don't tree_flatten in_shardings and out_shardings each time a jit() is traced. Do it once when the jit is constructed. (In general we do a bit too much switching back and forth between flattened and unflattened representations, and we'd probably do well just to keep things flattened.) PiperOrigin-RevId: 617859205 21 March 2024, 16:00:16 UTC
dd574cb Remove `_python_pjit` and make `_cpp_pjit` the only function wrapper. PiperOrigin-RevId: 617846352 21 March 2024, 15:18:42 UTC
74f4846 Allow disabling checking CUDA constraints at runtime. PiperOrigin-RevId: 617845847 21 March 2024, 15:08:51 UTC
79b1894 Only call inspect.signature once during the initial call to jit(). We call inspect.signature() once for debug information and once for argnum resolving. We can just call it once and reuse the result. PiperOrigin-RevId: 617824439 21 March 2024, 13:36:07 UTC
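A minimal sketch of computing the signature once and using it for both purposes named above. `JitInfo` is a hypothetical class, not JAX's `PjitInfo`. The argnum resolution skips keyword-only parameters because they have no positional index.

```python
import inspect


class JitInfo:
    """Hypothetical record built once when a function is wrapped by jit()."""

    def __init__(self, fun, static_argnames=()):
        # Call inspect.signature a single time and reuse the result.
        sig = inspect.signature(fun)
        # Use 1: parameter names for debug information and error messages.
        self.arg_names = tuple(sig.parameters)
        # Use 2: map static_argnames onto positional indices (argnums).
        positional = [
            p.name
            for p in sig.parameters.values()
            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        ]
        self.static_argnums = tuple(
            i for i, name in enumerate(positional) if name in static_argnames
        )
```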
8bdbaf7 Merge pull request #20357 from gnecula:deprecate_host_callback PiperOrigin-RevId: 617816818 21 March 2024, 12:59:19 UTC
d3e03ff Refactorings to the jit implementation. Notably:
* We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters, into separate arguments, but the result reads more cleanly to me.
No functional changes intended; this is just to improve readability. PiperOrigin-RevId: 617812557 21 March 2024, 12:37:32 UTC
2bd579b [pallas:triton] Implement lowering rule for broadcasted iota. PiperOrigin-RevId: 617795585 21 March 2024, 11:07:10 UTC
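`jax.lax.broadcasted_iota(dtype, shape, dimension)` produces an array in which each element equals its own coordinate along `dimension`. The following plain-Python model of those semantics, without the dtype argument, shows what the new lowering rule has to compute. `broadcasted_iota` here is a local sketch, not the Triton lowering.

```python
def broadcasted_iota(shape, dimension):
    """Nested lists of `shape` whose entries equal their index along `dimension`."""
    if not 0 <= dimension < len(shape):
        raise ValueError("dimension must index into shape")

    def build(prefix):
        # prefix holds the indices chosen so far for the leading dimensions.
        depth = len(prefix)
        if depth == len(shape):
            return prefix[dimension]
        return [build(prefix + (i,)) for i in range(shape[depth])]

    return build(())
```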
590bec2 [Mosaic] Enable the lightweight stable serialization format This should let us perform changes in the Mosaic IR without breaking backwards compatibility with already serialized artifacts. It will also decrease the potential flakiness of Mosaic version mismatches in libtpu/jaxlib nightly builds. PiperOrigin-RevId: 617777794 21 March 2024, 09:52:40 UTC
c5fa14b Complemented libdevice ops with the ones from the MLIR Math dialect This allows some ops, e.g. jnp.exp, to support half-precision inputs (#20239). PiperOrigin-RevId: 617766573 21 March 2024, 08:55:43 UTC
da1a2ac Update discharge.py minor fix 21 March 2024, 08:51:05 UTC
7d431ad Add support for slicing dynamically-shaped memrefs + DMAs between them This was a little difficult because our current dialect conversion setup assumes 1-1 type conversions. I think everything works out fine for as long as we never pass memrefs between basic blocks (i.e. for as long as we never have memrefs as loop carry or return them from conditionals). TODO: I still need to make sure that the changes to the TPU dialect are backwards-compatible. I am afraid that the signature change in MemRefSliceOp might not be. PiperOrigin-RevId: 617755035 21 March 2024, 07:56:51 UTC
ca59971 [host_callback] Deprecate the jax.experimental.host_callback module. 21 March 2024, 07:11:17 UTC
8920b54 Update XLA dependency to use revision http://github.com/openxla/xla/commit/120f128be262aa8cd558581f4a3733f0c080da3c. PiperOrigin-RevId: 617744460 21 March 2024, 07:01:49 UTC
d4948d8 Merge pull request #20347 from jakevdp:doc-key-concepts PiperOrigin-RevId: 617684691 21 March 2024, 01:37:06 UTC
d1235fa DOC: add key concepts doc This will replace the new content in thinking-in-jax within the new tutorial flow. 21 March 2024, 01:30:55 UTC
342887b Merge pull request #20306 from jakevdp:jax-101 PiperOrigin-RevId: 617682492 21 March 2024, 01:25:40 UTC
d6c07bd DOC: read-through and edit the new jax tutorials 21 March 2024, 01:18:08 UTC
d0819ae remove unnecessary if statement PiperOrigin-RevId: 617653292 20 March 2024, 23:15:15 UTC
d6f074b Improve documentation and types for api_util.resolve_argnums. Prefix some private helpers with a _. No functional changes intended. PiperOrigin-RevId: 617627335 20 March 2024, 21:33:15 UTC
3f13308 Add an experimental build-only continuous cross compile build for MacOS x86 PiperOrigin-RevId: 617611779 20 March 2024, 20:43:45 UTC
9862236 Use jnp.issubdtype instead of np.issubdtype PiperOrigin-RevId: 617610920 20 March 2024, 20:34:00 UTC
089651f Added missing BUILD dependencies for Pallas GPU lowering PiperOrigin-RevId: 617603433 20 March 2024, 20:08:58 UTC
efeaab6 Merge pull request #20329 from rajasekharporeddy:special_matrices PiperOrigin-RevId: 617577306 20 March 2024, 18:43:16 UTC
94027b5 Merge pull request #20343 from Sohl-Dickstein:formatting-nit PiperOrigin-RevId: 617577068 20 March 2024, 18:33:07 UTC
4f13f3a [pallas] Fix method name in `_{load,swap}_jvp` rule. PiperOrigin-RevId: 617568433 20 March 2024, 18:05:22 UTC
8da8d03 Merge pull request #20251 from james77777778:fix-quantile-keepdims PiperOrigin-RevId: 617558805 20 March 2024, 17:36:35 UTC
4d6a53f Add Hilbert matrix to jax.scipy.linalg 20 March 2024, 17:25:03 UTC
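The Hilbert matrix has entries H[i][j] = 1 / (i + j + 1) with 0-based indices. Below is a plain-Python sketch of that definition using exact fractions. The function added to `jax.scipy.linalg` presumably mirrors `scipy.linalg.hilbert(n)` and returns a floating-point array, so treat this only as a statement of the definition; `hilbert` here is a local helper.

```python
from fractions import Fraction


def hilbert(n):
    """n x n Hilbert matrix with exact entries 1 / (i + j + 1)."""
    return [[Fraction(1, i + j + 1) for j in range(n)] for i in range(n)]
```

Hilbert matrices are famously ill-conditioned, which makes them a common stress test for linear solvers.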
1eed31c Fix formatting of documentation 20 March 2024, 17:19:03 UTC