8fb4e08 | Jevin Jiang | 28 March 2024, 05:11:38 UTC | [XLA:Mosaic] Add type-convert-insertion pass and support shapecast a vector whose dtype is packed and sublane dim size is 1. PiperOrigin-RevId: 619787143 | 28 March 2024, 05:31:51 UTC |
9e86aa5 | Yash Katariya | 28 March 2024, 00:06:56 UTC | Add custom call on output along with S(5) because XLA requires the custom call to show the transfer. Enable parameter streaming and weight offloading PiperOrigin-RevId: 619711649 | 28 March 2024, 00:07:36 UTC |
ec73c40 | Sergei Lebedev | 27 March 2024, 21:25:19 UTC | Do not deadlock the GPU if a pure_callback dispatches a GPU kernel PiperOrigin-RevId: 619656442 | 27 March 2024, 21:26:03 UTC |
7f30ab5 | Yunlong Liu | 27 March 2024, 19:05:46 UTC | Supports PjRt CPU array converting py array when the CPU PjRt arrays have non-default layouts. PiperOrigin-RevId: 619608909 | 27 March 2024, 19:14:39 UTC |
66877c9 | jax authors | 27 March 2024, 19:03:18 UTC | Allow allow_spmd_propagation_to_output to be generated for outputs annotated with pjit.AUTO PiperOrigin-RevId: 619608022 | 27 March 2024, 19:04:03 UTC |
c0c918a | George Necula | 27 March 2024, 18:05:25 UTC | [export] Increase minimum serialization version to 9. Stop supporting serializing older version. The current max serialization version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024. This change could break clients that set a specific JAX serialization version lower than 9. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions PiperOrigin-RevId: 619588685 | 27 March 2024, 18:06:23 UTC |
023930d | Michael Hudgins | 27 March 2024, 17:26:38 UTC | Fix some load orderings for buildifier PiperOrigin-RevId: 619575196 | 27 March 2024, 17:28:57 UTC |
b480fd5 | jax authors | 27 March 2024, 06:37:05 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/aa4d4447ce3ff449d7accd5b8a41d1575ed58b75. PiperOrigin-RevId: 619429908 | 27 March 2024, 06:37:47 UTC |
fa9f02b | Matthew Johnson | 27 March 2024, 05:25:16 UTC | Reverts 0dde8f7f9607d09841ece7125dfc0773c3613fab PiperOrigin-RevId: 619416732 | 27 March 2024, 05:26:41 UTC |
0dde8f7 | jax authors | 27 March 2024, 02:52:11 UTC | Merge pull request #20445 from mattjj:scan-dont-traverse-body-jaxpr-in-lowering-2 PiperOrigin-RevId: 619387957 | 27 March 2024, 02:52:11 UTC |
0b09762 | Parker Schuh | 27 March 2024, 01:35:49 UTC | Guard host transfers inside pure_callbacks from deadlocking the TPU. Also fix python/callback.cc to not swallow errors in numpy conversions. PiperOrigin-RevId: 619375128 | 27 March 2024, 01:36:39 UTC |
9474b46 | Matthew Johnson | 15 February 2024, 07:01:02 UTC | [scan] don't traverse body jaxpr in lowering This is an attempt to re-land #19819 aka cl/607570860 after a small number of performance regressions. As before, the main changes are: 1. simplify the scan impl that we trace through to get the lowering, and 2. ensure that when tracing it to a jaxpr, we don't rebuild the scan body jaxpr we already have in hand. The main motivation was (2), but (1) seems like a useful win too. The way we achieve (2) is with a new trick: in our scan_impl function, which is only ever traced to a jaxpr, instead of calling `core.jaxpr_as_fun(jaxpr)(*args)` we call a new primitive `eval_jaxpr_p.bind(*args, jaxpr=jaxpr)`. This new primitive only has a staging rule defined for it (i.e. all we can do with it is stage it into a jaxpr), and that rule just generates a call into the jaxpr of interest. Therefore we will not traverse into the jaxpr just to rebuild it inline (as before). The code in #19819 was simpler in that it avoided reshapes, concats, and un-concats. But it caused at least one apparent performance regression (an XLA bug?) and it was unrelated to the original goal of reducing tracing time. So here we just land the trace time improvement. | 27 March 2024, 00:17:58 UTC |
4a9c8d1 | Jieying Luo | 26 March 2024, 23:05:50 UTC | Removed obsolete call to libtpu_module.configure_library_path(). PiperOrigin-RevId: 619340619 | 26 March 2024, 23:06:41 UTC |
6e0c955 | Yash Katariya | 26 March 2024, 20:28:03 UTC | Remove the canonicalization to GSPMDSharding internally in jit. This is not required anymore since the caches are split into tracing, lowering and compilation. The canonicalization doesn't provide any value anymore and only makes the internals more complicated. The canonicalization can be done by lowering to HloSharding in places where required and there are utilities to help with that. PiperOrigin-RevId: 619292757 | 26 March 2024, 20:28:45 UTC |
f8d6669 | Anselm Levskaya | 26 March 2024, 18:23:59 UTC | Add commentary and clean the pallas bidirectional collective allgather matmul test. "test_pipeline_all_gather_matmul" is the best demo example of nested pallas pipelines, but it's hard to follow the logic in the existing test. A few changes were made there: - rename things to avoid confusion between outer and inner loop prologues / epilogues. - give clear names for the outer iteration space: (step, phase) to help clarify sequencing of compute and DMAs. - simplify and lift out all async copy definitions and add commentary on their function - remove some incorrect comments about the rDMA schedule, and generally add a ton of commentary about when things happen in the outer pipeline. - lift all the outer prologue work into an integrated prologue function - various other small things. PiperOrigin-RevId: 619254981 | 26 March 2024, 18:31:38 UTC |
76af8be | jax authors | 26 March 2024, 18:21:55 UTC | Merge pull request #20436 from pearu:pearu/expm1-2 PiperOrigin-RevId: 619253998 | 26 March 2024, 18:21:55 UTC |
c82f061 | Pearu Peterson | 26 March 2024, 10:35:54 UTC | Update complex function accuracy tests for expm1 | 26 March 2024, 17:10:38 UTC |
18c885d | Sergei Lebedev | 26 March 2024, 16:04:11 UTC | Removed double-printing of TTIR in Pallas GPU lowering PiperOrigin-RevId: 619208376 | 26 March 2024, 16:11:39 UTC |
c78054d | Yash Katariya | 26 March 2024, 16:01:25 UTC | Fix the pjit test failing on v5e PiperOrigin-RevId: 619207394 | 26 March 2024, 16:02:13 UTC |
6afa523 | jax authors | 26 March 2024, 15:50:12 UTC | Skip fused_attention_stablehlo_test.py. https://github.com/google/jax/issues/20438 PiperOrigin-RevId: 619204356 | 26 March 2024, 15:50:50 UTC |
75db481 | George Necula | 26 March 2024, 12:32:00 UTC | [callback] Fix io_callback for callbacks that return Python literals. The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal. Fixed the checks that the callback returns values of expected shape and dtype, and added tests. Reverts 19e6156ccec0df7a900471df7840bc421da2898b PiperOrigin-RevId: 619156176 | 26 March 2024, 12:32:41 UTC |
33cf53c | George Karpenkov | 26 March 2024, 08:34:47 UTC | [XLA:GPU] Add option to return FDO profile as textproto. PiperOrigin-RevId: 619105468 | 26 March 2024, 08:35:27 UTC |
86a7086 | jax authors | 26 March 2024, 05:52:14 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/25d813bb12c59d05d731011103f2e7d28483f5d2. PiperOrigin-RevId: 619073296 | 26 March 2024, 05:53:00 UTC |
f93c320 | Sharad Vikram | 26 March 2024, 03:04:42 UTC | Enable extra args with input output aliasing PiperOrigin-RevId: 619041158 | 26 March 2024, 03:05:33 UTC |
69980a2 | jax authors | 26 March 2024, 00:45:59 UTC | Use the information in allow_spmd_sharding_propagation_to_output and allow_spmd_sharding_propagation_to_parameters to determine which input and output tuple elements we are allowed to modify the shardings of. PiperOrigin-RevId: 619013275 | 26 March 2024, 00:46:52 UTC |
c724eab | jax authors | 25 March 2024, 22:09:52 UTC | Merge pull request #20257 from Cjkkkk:sdpa_training PiperOrigin-RevId: 618972494 | 25 March 2024, 22:09:52 UTC |
f51d80e | Cjkkkk | 25 March 2024, 21:11:11 UTC | move checks to setup | 25 March 2024, 21:11:11 UTC |
e3bbd67 | jax authors | 25 March 2024, 19:54:06 UTC | Avoid jax_explain_cache_misses unpacking error. PiperOrigin-RevId: 618931412 | 25 March 2024, 19:55:00 UTC |
0be07e6 | jax authors | 25 March 2024, 18:44:40 UTC | Remove support for CUDA 11. Pin minimal required versions for CUDA to 12.1. Reverts 910a31d7b7510e3375718ab1ea0d38df7bd2c0d5 PiperOrigin-RevId: 618911489 | 25 March 2024, 18:46:39 UTC |
19e6156 | Qiao Zhang | 25 March 2024, 18:35:27 UTC | Reverts 2a4e1caac465bb4cb448e7370b23a822cce699e2 PiperOrigin-RevId: 618908505 | 25 March 2024, 18:36:13 UTC |
dacf009 | Cjkkkk | 25 March 2024, 17:19:48 UTC | add 4 gpus guard for inference | 25 March 2024, 17:19:48 UTC |
25d01e9 | Yash Katariya | 25 March 2024, 17:07:55 UTC | [Take 2] Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit. Reverts cd79e71d85621a8d6dede9a710bdb2a29bb380fd PiperOrigin-RevId: 618878870 | 25 March 2024, 17:08:43 UTC |
b9e699f | Sergei Lebedev | 25 March 2024, 15:59:59 UTC | Pallas TPU now broadcasts operands when lowering bitwise ops I also used this as an opportunity to bootstrap shared GPU/TPU tests. Closes #20135 PiperOrigin-RevId: 618858137 | 25 March 2024, 16:00:39 UTC |
2a4e1ca | George Necula | 25 March 2024, 12:53:11 UTC | [callback] Fix io_callback for callbacks that return Python literals. The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal. Fixed the checks that the callback returns values of expected shape and dtype, and added tests. PiperOrigin-RevId: 618814787 | 25 March 2024, 12:53:52 UTC |
3d8ffd4 | jax authors | 25 March 2024, 06:11:53 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/807c6f8dd3891e458f4cc4cd9f3c06dfdfb52484. PiperOrigin-RevId: 618737249 | 25 March 2024, 06:12:38 UTC |
0e5c44a | jax authors | 24 March 2024, 06:14:29 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/ec483ea69f3e750cf96674f7f6fdab6604af7f8d. PiperOrigin-RevId: 618558052 | 24 March 2024, 06:15:12 UTC |
1f81477 | Mark Sandler | 23 March 2024, 20:32:18 UTC | Adds an additional check for whether x is a Tracer, since looking up the sharding attribute on a tracer is expensive. PiperOrigin-RevId: 618489738 | 23 March 2024, 20:33:02 UTC |
30674ae | jax authors | 23 March 2024, 06:12:52 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/8a74354fe12808dc04badcc38148d0e909defee1. PiperOrigin-RevId: 618379086 | 23 March 2024, 06:13:34 UTC |
b709925 | Changhui Lin | 23 March 2024, 03:09:31 UTC | Update the arg description. `slice_index` attribute has been added for GPU. PiperOrigin-RevId: 618354455 | 23 March 2024, 03:10:18 UTC |
6f0737b | Sharad Vikram | 22 March 2024, 23:58:45 UTC | [Pallas TPU] Add support for dynamic sized (tile aligned) DMAs This change adds limited support for dynamic sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a variable in your kernel. This slice can be used in a view that is used in a DMA. For example: ```python def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): size = size_smem_ref[0] pltpu.async_copy( x_hbm_ref.at[pl.ds(0, size)], o_hbm_ref.at[pl.ds(0, size)], sem).wait() ``` We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA. We're augmenting the Slice class to support dynamic sizes, which are then plumbed down into indexing primitives like pl.load, and pltpu.async_copy. However, values and Refs in Pallas kernels still have a static shape requirement, so we can't do dynamic loads/stores. As a result, we won't touch any abstract evaluation rules (since they will still all consume and return ShapedArrays). We can, however, do dynamically sized DMAs between statically shaped Refs. While this isn't arbitrary dynamic shapes, we hope this enables some new interesting kernels. PiperOrigin-RevId: 618322737 | 22 March 2024, 23:59:32 UTC |
6ffd55c | Sandeep Dasgupta | 22 March 2024, 21:51:12 UTC | Fixing StableHLO python dependencies on stablehlo:reference_api PiperOrigin-RevId: 618294054 | 22 March 2024, 21:52:06 UTC |
910a31d | jax authors | 22 March 2024, 19:03:29 UTC | Reverts bed4f65438a62777ed100ecec2b0eb3f7cf87a0e PiperOrigin-RevId: 618249855 | 22 March 2024, 19:10:53 UTC |
a853539 | Peter Hawkins | 22 March 2024, 19:00:40 UTC | Remove Bazel workspace entries for nanobind and robin_map. The XLA repository has these, no need to duplicate them. PiperOrigin-RevId: 618248779 | 22 March 2024, 19:01:32 UTC |
0b46341 | Yash Katariya | 22 March 2024, 18:22:56 UTC | Don't report origin_msg if any exception is raised in `self._origin_msg` PiperOrigin-RevId: 618237231 | 22 March 2024, 18:23:46 UTC |
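The guard described in this commit can be sketched as a try/except around the origin-message computation, so that a failure while building the extra context never hides the original error. `ShardingInfo`, `_compute_origin`, and `error_message` are hypothetical names for illustration, not JAX's actual classes or methods.

```python
class ShardingInfo:
    """Hypothetical holder of an error's origin context (illustration only)."""

    def __init__(self, compute_origin):
        # compute_origin: zero-argument callable returning extra context text.
        self._compute_origin = compute_origin

    def _origin_msg(self):
        return self._compute_origin()

    def error_message(self, base):
        # Append the origin message only if computing it succeeds; any
        # exception raised while building it is swallowed so that the
        # caller still reports the original error text.
        try:
            origin = self._origin_msg()
        except Exception:
            origin = ""
        return base + origin
```

The design choice is to degrade to a shorter message rather than replace a user-facing error with an unrelated exception from the diagnostic code.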
d7e5dde | Yash Katariya | 22 March 2024, 17:39:08 UTC | Remove _maybe_device_put because jax.device_put accepts `None` on the device parameter PiperOrigin-RevId: 618223250 | 22 March 2024, 17:39:57 UTC |
5f467b9 | Yash Katariya | 22 March 2024, 16:29:47 UTC | Propagate sharding of inputs to full_like that are capable of carrying sharding as an attribute. Fixes https://github.com/google/jax/issues/20390 PiperOrigin-RevId: 618202319 | 22 March 2024, 16:30:37 UTC |
bed4f65 | jax authors | 22 March 2024, 16:04:48 UTC | Remove support for CUDA 11. Pin minimal required versions for CUDA to 12.1. PiperOrigin-RevId: 618195554 | 22 March 2024, 16:05:39 UTC |
c82deb2 | jax authors | 22 March 2024, 15:39:53 UTC | Merge pull request #20373 from pearu:pearu/complex-plane-tests-fix PiperOrigin-RevId: 618188699 | 22 March 2024, 15:39:53 UTC |
6695a85 | jax authors | 22 March 2024, 15:16:20 UTC | Merge pull request #20370 from jakevdp:key-reuse-flag PiperOrigin-RevId: 618179711 | 22 March 2024, 15:16:20 UTC |
f6ed624 | jax authors | 22 March 2024, 15:06:43 UTC | Merge pull request #20342 from ROCm:rocm-export_test-add-rocm-platform PiperOrigin-RevId: 618179680 | 22 March 2024, 15:06:43 UTC |
44be575 | jax authors | 22 March 2024, 14:11:45 UTC | Compute axis_index without creating an entire grid of device IDs. For large meshes, this numpy array can exceed the size of SMEM. We can perform the same calculation using just the grid shape. PiperOrigin-RevId: 618167202 | 22 March 2024, 14:12:52 UTC |
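The arithmetic behind this commit can be sketched in pure Python, assuming device IDs are numbered in row-major order over the grid shape (the actual Pallas lowering may lay devices out differently). `axis_index_from_grid` and `axis_index_from_shape` are hypothetical helpers: the first materializes every coordinate, the second needs only the shape.

```python
import itertools
import math


def axis_index_from_grid(device_id, grid_shape, axis):
    # Naive approach: build the coordinates of every device in the mesh,
    # then look up device_id. Memory grows with the total number of devices.
    coords = list(itertools.product(*(range(n) for n in grid_shape)))
    return coords[device_id][axis]


def axis_index_from_shape(device_id, grid_shape, axis):
    # Shape-only approach: in row-major order, the stride of `axis` is the
    # product of the sizes of all later axes (1 for the last axis).
    stride = math.prod(grid_shape[axis + 1:])
    return (device_id // stride) % grid_shape[axis]
```

For a `(2, 4, 2)` mesh the strides are `(8, 2, 1)`, so device 13 sits at coordinate 2 along axis 1 (`13 // 2 % 4`), and both helpers agree for every device without the second one ever allocating the grid.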
2d65571 | Sergei Lebedev | 22 March 2024, 12:41:51 UTC | Really skip exp2 in Pallas GPU tests with older jaxlib PiperOrigin-RevId: 618149873 | 22 March 2024, 12:42:38 UTC |
8949a63 | Jake VanderPlas | 21 March 2024, 17:47:16 UTC | [key reuse] rename flag to jax_debug_key_reuse | 22 March 2024, 12:37:30 UTC |
cd79e71 | jax authors | 22 March 2024, 10:45:15 UTC | Reverts 0e092a77067dbbce33cfd6d54a46e743b779919b PiperOrigin-RevId: 618127324 | 22 March 2024, 10:46:09 UTC |
c2d9528 | jax authors | 22 March 2024, 06:06:03 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/c408b6320bbf8407b781487a2dbe329cb7488dac. PiperOrigin-RevId: 618071726 | 22 March 2024, 06:06:53 UTC |
0e092a7 | Yash Katariya | 22 March 2024, 04:01:50 UTC | Expose `.layout` on jax.Array. Also add checks in the AOT path to make sure that the input `Array`'s layout matches the layout given to `jax.jit`. PiperOrigin-RevId: 618050680 | 22 March 2024, 04:02:40 UTC |
73aadbf | jax authors | 22 March 2024, 01:57:16 UTC | Adding expectations to e2e test with expectations for mesh creation allowing splitting physical axes. PiperOrigin-RevId: 618028683 | 22 March 2024, 01:58:06 UTC |
9523547 | Peter Hawkins | 22 March 2024, 00:54:08 UTC | Add a fast path for Python scalars to shaped_abstractify. PiperOrigin-RevId: 618015741 | 22 March 2024, 01:05:10 UTC |
07e45c3 | jax authors | 22 March 2024, 00:55:57 UTC | Merge pull request #20236 from jakevdp:key-reuse-stack PiperOrigin-RevId: 618014760 | 22 March 2024, 00:55:57 UTC |
d57bb8c | Yash Katariya | 22 March 2024, 00:45:44 UTC | Raise a better error message when an invalid input is passed to jit call. Before: ``` TypeError: Argument 'ShapeDtypeStruct(shape=(4, 2), dtype=int32)' of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type. ``` After: ``` TypeError: Argument 'x['b']['c']' of shape int32[4,2] of type <class 'jax._src.api.ShapeDtypeStruct'> is not a valid JAX type. ``` The error is raised deep down the stack during `shard_arg`, so we raise an `InvalidInputException` and catch it in `_python_pjit_helper` where we have the `arg_names` information. PiperOrigin-RevId: 618014044 | 22 March 2024, 00:46:32 UTC |
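The names in the new message, such as `x['b']['c']`, are paths into nested argument containers. The following is a minimal pure-Python sketch of how such path strings can be built for dicts, lists, and tuples. `leaf_paths` is a hypothetical helper; JAX itself derives these names from its own pytree key-path machinery, not from this code.

```python
def leaf_paths(name, tree):
    """Yield (path_string, leaf) pairs for nested dicts, lists, and tuples.

    Hypothetical helper for illustration; paths are rendered in the same
    style as the improved error message, e.g. "x['b']['c']" or "y[1][0]".
    """
    if isinstance(tree, dict):
        # Visit keys in sorted order, similar to how pytrees order dict keys.
        for key in sorted(tree):
            yield from leaf_paths(f"{name}[{key!r}]", tree[key])
    elif isinstance(tree, (list, tuple)):
        for i, child in enumerate(tree):
            yield from leaf_paths(f"{name}[{i}]", child)
    else:
        yield name, tree
```

With `leaf_paths("x", {"a": 1, "b": {"c": bad}})`, the offending leaf is reported as `x['b']['c']`, which is the kind of name the commit surfaces instead of the leaf's `repr`.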
7f7e0c0 | Tomás Longeri | 22 March 2024, 00:19:44 UTC | [Mosaic] Support left shifting relayouts PiperOrigin-RevId: 618008857 | 22 March 2024, 00:20:30 UTC |
8a2ba76 | jax authors | 21 March 2024, 23:20:37 UTC | Enable Creating Device Mesh with Physical Axes Splits. PiperOrigin-RevId: 617994892 | 21 March 2024, 23:21:25 UTC |
5532e55 | Peter Hawkins | 21 March 2024, 22:56:42 UTC | [XLA:Python] Add a C++ implementation of flatten_one_level. Also add a copy of the default registry that doesn't have None registered as a leaf, which is slightly faster than using an is_leaf function. This is mostly just doing an old TODO. PiperOrigin-RevId: 617988496 | 21 March 2024, 22:57:23 UTC |
05e61ed | jax authors | 21 March 2024, 22:06:17 UTC | Expose API to control whether to fuse input computation with Pallas kernel on per input basis. PiperOrigin-RevId: 617975104 | 21 March 2024, 22:07:18 UTC |
fdb5015 | Pearu Peterson | 21 March 2024, 21:35:29 UTC | Evaluate the correctness of JAX complex functions using mpmath as a reference | 21 March 2024, 21:35:29 UTC |
4c7351f | jax authors | 21 March 2024, 20:55:45 UTC | Relax regex matching for DebugString output PiperOrigin-RevId: 617954803 | 21 March 2024, 20:56:25 UTC |
7e60331 | Jake VanderPlas | 21 March 2024, 20:34:26 UTC | [key reuse] print information about key reuse location | 21 March 2024, 20:34:26 UTC |
2848cda | jax authors | 21 March 2024, 17:41:38 UTC | Merge pull request #20341 from ROCm:rocm_add_hipStreamWaitEvent PiperOrigin-RevId: 617893634 | 21 March 2024, 17:41:38 UTC |
291a5cd | Yue Sheng | 21 March 2024, 17:30:09 UTC | [PJRT][IFRT] Update PJRT, IFRT, and Py executable getters to return PjRtLayouts PiperOrigin-RevId: 617889924 | 21 March 2024, 17:30:57 UTC |
383ae41 | George Necula | 21 March 2024, 17:00:55 UTC | Attempt to eliminate flakiness for jax2tf test. PiperOrigin-RevId: 617878818 | 21 March 2024, 17:01:44 UTC |
c945766 | Sergei Lebedev | 21 March 2024, 16:45:17 UTC | Skip PallasOpsTest.test_elementwise_exp2 on older jaxlib PiperOrigin-RevId: 617873650 | 21 March 2024, 16:46:13 UTC |
0a0d65a | jax authors | 21 March 2024, 16:35:12 UTC | Merge pull request #20359 from eltociear:patch-8 PiperOrigin-RevId: 617870272 | 21 March 2024, 16:35:12 UTC |
54d8bde | Peter Hawkins | 21 March 2024, 15:59:28 UTC | Don't tree_flatten in_shardings and out_shardings each time a jit() is traced. Do it once when the jit is constructed. (In general we do a bit too much switching back and forth between flattened and unflattened representations, and we'd probably do well just to keep things flattened.) PiperOrigin-RevId: 617859205 | 21 March 2024, 16:00:16 UTC |
dd574cb | Yash Katariya | 21 March 2024, 15:09:37 UTC | Remove `_python_pjit` and make `_cpp_pjit` the only function wrapper. PiperOrigin-RevId: 617846352 | 21 March 2024, 15:18:42 UTC |
74f4846 | jax authors | 21 March 2024, 15:08:00 UTC | Allow disabling checking CUDA constraints at runtime. PiperOrigin-RevId: 617845847 | 21 March 2024, 15:08:51 UTC |
79b1894 | Peter Hawkins | 21 March 2024, 13:35:20 UTC | Only call inspect.signature once during the initial call to jit(). We call inspect.signature() once for debug information and once for argnum resolving. We can just call it once and reuse the result. PiperOrigin-RevId: 617824439 | 21 March 2024, 13:36:07 UTC |
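The idea in this commit, computing the signature once and reusing it for both purposes, can be sketched with the standard `inspect` module. `describe` and `resolve_static_argnames` are hypothetical names; JAX's real argnum resolution lives in `api_util` and handles more cases than this sketch.

```python
import inspect


def resolve_static_argnames(sig, static_argnames):
    # Map parameter names to positional indices using an existing Signature,
    # so no second inspect.signature() call is needed. Keyword-only
    # parameters have no positional index and are skipped.
    positional = [
        name for name, p in sig.parameters.items()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    return tuple(positional.index(n) for n in static_argnames if n in positional)


def describe(fun, static_argnames=()):
    # Call inspect.signature() exactly once and reuse the result both for
    # debug information (parameter names) and for argnum resolution.
    sig = inspect.signature(fun)
    debug_arg_names = tuple(sig.parameters)
    static_argnums = resolve_static_argnames(sig, static_argnames)
    return debug_arg_names, static_argnums
```

Since `inspect.signature` can be relatively costly on the first call to `jit()`, sharing one `Signature` object between the two consumers halves that work without changing behavior.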
8bdbaf7 | jax authors | 21 March 2024, 12:59:19 UTC | Merge pull request #20357 from gnecula:deprecate_host_callback PiperOrigin-RevId: 617816818 | 21 March 2024, 12:59:19 UTC |
d3e03ff | Peter Hawkins | 21 March 2024, 12:35:44 UTC | Refactorings to the jit implementation. Notably: * We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code. * Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback. * If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time, namely the resource environment and the two layout parameters into separate arguments, but the result reads more cleanly to me. No functional changes intended, this is just to improve readability. PiperOrigin-RevId: 617812557 | 21 March 2024, 12:37:32 UTC |
2bd579b | Chris Jones | 21 March 2024, 11:06:30 UTC | [pallas:triton] Implement lowering rule for broadcasted iota. PiperOrigin-RevId: 617795585 | 21 March 2024, 11:07:10 UTC |
590bec2 | Adam Paszke | 21 March 2024, 09:51:46 UTC | [Mosaic] Enable the lightweight stable serialization format This should let us perform changes in the Mosaic IR without breaking backwards compatibility with already serialized artifacts. It will also decrease the potential flakiness of Mosaic version mismatches in libtpu/jaxlib nightly builds. PiperOrigin-RevId: 617777794 | 21 March 2024, 09:52:40 UTC |
c5fa14b | Sergei Lebedev | 21 March 2024, 08:55:08 UTC | Complemented libdevice ops with the ones from the MLIR Math dialect This allows some ops, e.g. jnp.exp, to support half-precision inputs (#20239). PiperOrigin-RevId: 617766573 | 21 March 2024, 08:55:43 UTC |
da1a2ac | Ikko Eltociear Ashimine | 21 March 2024, 08:51:05 UTC | Update discharge.py minor fix | 21 March 2024, 08:51:05 UTC |
7d431ad | Adam Paszke | 21 March 2024, 07:55:58 UTC | Add support for slicing dynamically-shaped memrefs + DMAs between them This was a little difficult because our current dialect conversion setup assumes 1-1 type conversions. I think everything works out fine for as long as we never pass memrefs between basic blocks (i.e. for as long as we never have memrefs as loop carry or return them from conditionals). TODO: I still need to make sure that the changes to the TPU dialect are backwards-compatible. I am afraid that the signature change in MemRefSliceOp might not be. PiperOrigin-RevId: 617755035 | 21 March 2024, 07:56:51 UTC |
ca59971 | George Necula | 21 March 2024, 06:18:57 UTC | [host_callback] Deprecate the jax.experimental.host_callback module. | 21 March 2024, 07:11:17 UTC |
8920b54 | jax authors | 21 March 2024, 07:01:04 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/120f128be262aa8cd558581f4a3733f0c080da3c. PiperOrigin-RevId: 617744460 | 21 March 2024, 07:01:49 UTC |
d4948d8 | jax authors | 21 March 2024, 01:37:06 UTC | Merge pull request #20347 from jakevdp:doc-key-concepts PiperOrigin-RevId: 617684691 | 21 March 2024, 01:37:06 UTC |
d1235fa | Jake VanderPlas | 21 March 2024, 01:20:02 UTC | DOC: add key concepts doc This will replace the new content in thinking-in-jax within the new tutorial flow. | 21 March 2024, 01:30:55 UTC |
342887b | jax authors | 21 March 2024, 01:25:40 UTC | Merge pull request #20306 from jakevdp:jax-101 PiperOrigin-RevId: 617682492 | 21 March 2024, 01:25:40 UTC |
d6c07bd | Jake VanderPlas | 21 March 2024, 01:18:08 UTC | DOC: read-through and edit the new jax tutorials | 21 March 2024, 01:18:08 UTC |
d0819ae | jax authors | 20 March 2024, 23:14:22 UTC | remove unnecessary if statement PiperOrigin-RevId: 617653292 | 20 March 2024, 23:15:15 UTC |
d6f074b | Peter Hawkins | 20 March 2024, 21:32:25 UTC | Improve documentation and types for api_util.resolve_argnums. Prefix some private helpers with a _. No functional changes intended. PiperOrigin-RevId: 617627335 | 20 March 2024, 21:33:15 UTC |
3f13308 | Kanglan Tang | 20 March 2024, 20:36:27 UTC | Add an experimental build-only continuous cross compile build for MacOS x86 PiperOrigin-RevId: 617611779 | 20 March 2024, 20:43:45 UTC |
9862236 | Sergei Lebedev | 20 March 2024, 20:33:17 UTC | Use jnp.issubdtype instead of np.issubdtype PiperOrigin-RevId: 617610920 | 20 March 2024, 20:34:00 UTC |
089651f | Sergei Lebedev | 20 March 2024, 20:08:09 UTC | Added missing BUILD dependencies for Pallas GPU lowering PiperOrigin-RevId: 617603433 | 20 March 2024, 20:08:58 UTC |
efeaab6 | jax authors | 20 March 2024, 18:43:16 UTC | Merge pull request #20329 from rajasekharporeddy:special_matrices PiperOrigin-RevId: 617577306 | 20 March 2024, 18:43:16 UTC |
94027b5 | jax authors | 20 March 2024, 18:33:07 UTC | Merge pull request #20343 from Sohl-Dickstein:formatting-nit PiperOrigin-RevId: 617577068 | 20 March 2024, 18:33:07 UTC |
4f13f3a | Chris Jones | 20 March 2024, 18:04:33 UTC | [pallas] Fix method name in `_{load,swap}_jvp` rule. PiperOrigin-RevId: 617568433 | 20 March 2024, 18:05:22 UTC |
8da8d03 | jax authors | 20 March 2024, 17:36:35 UTC | Merge pull request #20251 from james77777778:fix-quantile-keepdims PiperOrigin-RevId: 617558805 | 20 March 2024, 17:36:35 UTC |
4d6a53f | rajasekharporeddy | 20 March 2024, 17:25:03 UTC | Add Hilbert matrix to jax.scipy.linalg | 20 March 2024, 17:25:03 UTC |
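The Hilbert matrix has entries H[i][j] = 1/(i + j + 1) for zero-based indices i and j. Below is a pure-Python sketch using exact rational arithmetic; `hilbert` here is an illustrative stand-in, not the JAX implementation added in this commit.

```python
from fractions import Fraction


def hilbert(n):
    # Exact n x n Hilbert matrix: entry (i, j) is 1 / (i + j + 1), zero-based.
    return [[Fraction(1, i + j + 1) for j in range(n)] for i in range(n)]
```

Exact fractions make the matrix's symmetry easy to check; in floating point the Hilbert matrix is a standard example of severe ill-conditioning as n grows.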
1eed31c | Jascha Sohl-Dickstein | 20 March 2024, 17:19:03 UTC | Fix formatting of documentation | 20 March 2024, 17:19:03 UTC |