swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f

Revision Author Date Message Commit Date
e27af48 Merge pull request #17989 from skye:version PiperOrigin-RevId: 571418823 06 October 2023, 20:21:49 UTC
d4a1bb9 Update setup.py and CHANGELOG for jax 0.4.18 release 06 October 2023, 20:13:33 UTC
f0e4ea2 Merge pull request #17987 from jakevdp:lax-dep PiperOrigin-RevId: 571401660 06 October 2023, 19:23:20 UTC
6681e64 Merge pull request #17962 from jakevdp:fix-bitwise-count PiperOrigin-RevId: 571398043 06 October 2023, 19:13:11 UTC
fae53d9 Merge pull request #17959 from jakevdp:lax-abs PiperOrigin-RevId: 571397881 06 October 2023, 19:02:02 UTC
5ac2513 Merge pull request #17988 from froystig:rng-spec-jex PiperOrigin-RevId: 571393758 06 October 2023, 18:51:52 UTC
31b680b Merge pull request #17981 from hawkinsp:docs2 PiperOrigin-RevId: 571393427 06 October 2023, 18:41:41 UTC
8c4d020 Improve CUDA install documentation. Mention NCCL as a dependency, since it will be required by the next jaxlib release. Mention LD_LIBRARY_PATH and PATH as how one overrides the CUDA installation for local installs. Fixes #17831 06 October 2023, 18:36:29 UTC
ce6a0c4 jax.lax: deprecate inadvertent exports & internal utilities 06 October 2023, 18:26:03 UTC
7b831ba test custom PRNG impl construction round trip 06 October 2023, 18:18:03 UTC
bd50c61 Merge pull request #17979 from hawkinsp:regex PiperOrigin-RevId: 571379959 06 October 2023, 18:03:15 UTC
2052673 Merge pull request #17978 from hawkinsp:fft PiperOrigin-RevId: 571379922 06 October 2023, 17:52:48 UTC
6dee6f2 Merge pull request #17983 from froystig:rng-spec PiperOrigin-RevId: 571374210 06 October 2023, 17:33:43 UTC
5158e25 identify PRNG schemes on key arrays, and recognize them in key constructors Specifically: * Introduce `jax.random.key_impl`, which accepts a key array and returns a hashable identifier of its PRNG implementation. * Accept this identifier optionally as the `impl` argument to `jax.random.key` and `wrap_key_data`. This now works: ```python k1 = jax.random.key(72, impl='threefry2x32') impl = jax.random.key_impl(k1) k2 = jax.random.key(72, impl=impl) assert arrays_equal(k1, k2) assert k1.dtype == k2.dtype ``` This change also sets up an internal PRNG registry and registers built-in implementations, to simplify various places where we essentially reconstruct such a registry from scratch (such as in tests). Co-authored-by: Jake Vanderplas <jakevdp@google.com> 06 October 2023, 17:15:08 UTC
4c37c79 Merge pull request #17971 from hawkinsp:tanarm PiperOrigin-RevId: 571344901 06 October 2023, 15:31:04 UTC
19b900d Merge pull request #17968 from jakevdp:random-test-util PiperOrigin-RevId: 571335206 06 October 2023, 14:47:38 UTC
a703ec3 Add support for tiled layouts in arith.extf + implement (2,128)->(8,128) relayout for 32-bit types PiperOrigin-RevId: 571333868 06 October 2023, 14:37:22 UTC
327359c [Mosaic] Strengthen the checks in the tpu.iota apply_vector_layout rule PiperOrigin-RevId: 571328377 06 October 2023, 14:08:28 UTC
5a27121 Fix NumPy warning when casting negative float to unsigned in tan() test on aarch64 under Python 3.12 06 October 2023, 12:45:39 UTC
5311746 Fix tests that mistakenly used assertRaises(..., msg=...) to match a message. 06 October 2023, 12:07:49 UTC
4e1b8fc Check dtypes in fft_p's abstract eval rule. In particular, this catches a bad error when a bfloat16 is passed to rfft. 06 October 2023, 12:04:01 UTC
b580c5d [export] Fix dot_general multi-platform lowering The previous lowering rule for dot_general was using ctx.module_context.platform to customize the lowering per platform. Now we set different lowering rules for different platforms, thus enabling the multi-platform lowering to generate the proper code. This fixes the dot_general multi_platform_export_tests. PiperOrigin-RevId: 571304531 06 October 2023, 11:57:50 UTC
8f08ee9 Update XLA dependency to use revision http://github.com/openxla/xla/commit/fabb55217b915a93b9c008de5a12168880430b04. PiperOrigin-RevId: 571255729 06 October 2023, 07:52:17 UTC
866500c [export] Ensure that we run shape refinement for modules that use multi-platform lowering For multi-platform lowering we use a constant platform index argument threaded through all function calls, and we use conditionals for the lowering of primitives that have multiple lowerings. In many cases, but not all, these conditionals are removed by constant folding prior to conversion to HLO, and the XLA compiler will only see the code for the compilation platform. However, in some cases these conditionals are not constant-folded and the XLA compiler will either see code for other platforms that it does not expect (the TPU tests failing before), or will simply generate slightly different code (e.g., the conv_general_dilated tests on CPU, where we saw numerical differences before). To address this, we ensure that we run shape refinement for modules that use multi-platform lowering. The shape refinement pass already handles inter-procedural constant folding for dimension value arguments. At the moment, the platform index argument is modelled as a dimension value during lowering, so it makes some sense to use the same shape refinement pass to clean it up before compilation. But a cleaner solution would be to separate the shape refinement pass into an interprocedural constant folding, followed by proper shape refinement. Then we'd introduce a separate attribute `jax.needs_constant_folding` in addition to `jax.uses_shape_polymorphism`. This change fixes the remaining failures in the multi_platform_export_test for TPU, and the conv_general_dilated test for CPU. PiperOrigin-RevId: 571254037 06 October 2023, 07:42:03 UTC
28e907c Merge pull request #17932 from mattjj:run-state-pjit PiperOrigin-RevId: 571220044 06 October 2023, 04:39:42 UTC
5715db4 [run_state] add pjit run_state discharge rule and basic test 06 October 2023, 04:14:00 UTC
d1c5cdc Merge pull request #17928 from mattjj:scan-state-nested-test PiperOrigin-RevId: 571212422 06 October 2023, 03:51:50 UTC
f4bb1c0 [Mosaic] apply_vector_layout C++ rewrite (17): for.op (from cl/568376871) PiperOrigin-RevId: 571155269 05 October 2023, 23:02:11 UTC
3d503e0 random_test: remove unnecessary test utilities 05 October 2023, 22:33:14 UTC
8f911e1 random_test: Split into two so that each target is small enough to fit within a medium timeout. PiperOrigin-RevId: 571146766 05 October 2023, 22:28:51 UTC
1512650 [JAX] Keep CPU host callbacks alive via IFRT, rather than by attaching them to the Python object. We need to keep callback objects alive as long as any running executables are alive. It is possible to discard the Python data structures for an executable before the runtime has finished running that executable, which can lead to a use after free. Instead, make the runtime keep host callbacks alive. PiperOrigin-RevId: 571141106 05 October 2023, 22:07:03 UTC
7c353c4 [Mosaic] apply_vector_layout C++ rewrite: Have each rule use its own builder instead of passing a builder in the context Having each rule manage its own builder insertion point means we don't have to worry about updating the insertion point when calling into a rule (which we sometimes do from other rules). It also avoids predefining the insert point to before/after the op. Predefining it doesn't make sense for rules that want to do in-place modifications (i.e. keep the original op) since they may need to place new ops before (for operands) the current op and/or after the current op. Another minor advantage is that it lets us use ImplicitLocOpBuilder when appropriate which is more concise. PiperOrigin-RevId: 571133260 05 October 2023, 21:38:20 UTC
70abf9a Merge pull request #17965 from tttttanya:Asynchronous-dispatch-jaxArray-update PiperOrigin-RevId: 571126113 05 October 2023, 21:11:54 UTC
cd18f8e jnp.bitwise_count: fix issue with unsigned ints 05 October 2023, 21:00:31 UTC
9e50b9d Asynchronous dispatch doc update regarding jax.Array migration 05 October 2023, 20:45:51 UTC
61bd34c [Mosaic] infer_memref_layout C++ rewrite PiperOrigin-RevId: 571111789 05 October 2023, 20:28:17 UTC
a2b70e3 Bump shard_count for shard_map_test to fix timeouts. PiperOrigin-RevId: 571109311 05 October 2023, 20:18:10 UTC
4fd6bf4 Re-add JIT to jax.numpy.bitwise_count PiperOrigin-RevId: 571107971 05 October 2023, 20:07:49 UTC
797f577 Expose mlir.ShapePolyLoweringState PiperOrigin-RevId: 571075542 05 October 2023, 18:21:09 UTC
ab4a8e3 [Mosaic] apply_vector_layout C++ rewrite: various bug fixes PiperOrigin-RevId: 571075082 05 October 2023, 18:10:49 UTC
295cecd Update XLA dependency to use revision http://github.com/openxla/xla/commit/e73af4223d463b33565640cff6ec617614fddcd3. PiperOrigin-RevId: 571067534 05 October 2023, 17:54:50 UTC
60029e7 lax.abs: better error for unsigned inputs 05 October 2023, 17:53:08 UTC
68c84a6 [Mosaic] apply_vector_layout: Use shape in generalizes check (in Python) - The addition to the check in the relayout loop in `apply_layout_op` should result in skipping some no-op relayouts - The assert in `disassemble` also needs to be updated because it won't hold now that relayout is skipped more (relayout guarantees the defining layout to be equal to the input layout) PiperOrigin-RevId: 571066259 05 October 2023, 17:44:22 UTC
633f68a [Mosaic] Fix a buggy vector.broadcast rule in apply_vector_layout The rule did not take tiling into account, assuming that it works with 32-bit data that has native tiling. Now, we should have appropriate checks in place, as well as some support for lane broadcasts of tiled values. PiperOrigin-RevId: 570956025 05 October 2023, 09:59:30 UTC
d8a81ba [Mosaic] Handle a larger class of broadcasts with 1-sized trailing dimensions PiperOrigin-RevId: 570947498 05 October 2023, 09:18:49 UTC
2e3a5d6 [Mosaic] Fix a bug where an elementwise op can incorrectly propagate a replicated layout PiperOrigin-RevId: 570942842 05 October 2023, 08:54:27 UTC
dd2fcf5 [Mosaic] apply_vector_layout C++ rewrite (16): vector.multi_reduction PiperOrigin-RevId: 570907928 05 October 2023, 05:34:26 UTC
201001b Merge pull request #17921 from gnecula:harness_random PiperOrigin-RevId: 570900835 05 October 2023, 04:50:38 UTC
07e4077 [Mosaic] apply_vector_layout C++ rewrite (15): vector.shape_cast PiperOrigin-RevId: 570891989 05 October 2023, 03:47:12 UTC
98aae41 [XLA:Mosaic] Fix layout attribute string parse in python. PiperOrigin-RevId: 570883549 05 October 2023, 03:09:50 UTC
1c37f50 sparse_test: Split into two so that each target is small enough to fit within a medium timeout. PiperOrigin-RevId: 570882867 05 October 2023, 02:59:03 UTC
465eb21 [pallas] Fix `allow_tf32` value in Triton `dot_general` lowering. `precision` is canonicalized as a tuple or `None`. PiperOrigin-RevId: 570879987 05 October 2023, 02:34:19 UTC
862d676 [Mosaic] apply_vector_layout C++ rewrite (14): tpu.matmul and vector.contract PiperOrigin-RevId: 570877941 05 October 2023, 02:20:00 UTC
991e6ef [Mosaic] apply_vector_layout C++ rewrite (13): scf.if, scf.yield PiperOrigin-RevId: 570845376 04 October 2023, 23:30:53 UTC
9e3d64a Merge pull request #17929 from hawkinsp:torchloader PiperOrigin-RevId: 570839721 04 October 2023, 23:11:24 UTC
6065464 Merge pull request #17938 from jakevdp:bitwise-count PiperOrigin-RevId: 570837601 04 October 2023, 23:00:49 UTC
7be8df9 Merge pull request #17933 from cgarciae:condition-numpy-version PiperOrigin-RevId: 570827957 04 October 2023, 22:27:44 UTC
7df2957 jnp.bitwise_count: call into lax.population_count 04 October 2023, 22:22:33 UTC
7498ffe condition numpy version based on python version 04 October 2023, 21:06:01 UTC
be7f210 Allow duration event listeners to take extra keyword arguments. PiperOrigin-RevId: 570783105 04 October 2023, 20:15:49 UTC
305efe0 random_test: reduce num_generated_cases to avoid timeouts PiperOrigin-RevId: 570781641 04 October 2023, 20:04:44 UTC
933d353 [Mosaic] apply_vector_layout C++ rewrite: generalization for relayouts that reduce sublanes Corresponds to cl/563570338. PiperOrigin-RevId: 570769785 04 October 2023, 19:26:53 UTC
389334d Merge pull request #17925 from jakevdp:rand-loggamma PiperOrigin-RevId: 570765482 04 October 2023, 19:12:40 UTC
d8a0227 Simplify the torch data loader collate function using tree_map. Fixes https://github.com/google/jax/issues/1004 04 October 2023, 18:59:06 UTC
29af93b [run_state] add scan nested test, tweak rule name to mention 'state' 04 October 2023, 18:48:42 UTC
f739a88 jax.random: fix NaN corner-case in loggamma 04 October 2023, 18:40:32 UTC
59d4f44 [Mosaic] apply_vector_layout: arith.constant: erase old op after replacing uses (for non-splat values) PiperOrigin-RevId: 570755366 04 October 2023, 18:38:02 UTC
01372fe Clarify that mesh_utils.create_device_mesh's contiguous_submeshes arg isn't necessary with jax.Array PiperOrigin-RevId: 570751299 04 October 2023, 18:24:23 UTC
923498f `_StateContextManager` now preserves the type of the value it stores. This change is a follow-on to google/jax#16866, which added an ABSL-like API for flags defined with `DEFINE_...`. Here we add a similar typed API for flags defined with `define_..._state`. See https://github.com/abseil/abseil-py/blob/37dad4d356ca9e13f1c533ad6309631b397a2b6b/absl/flags/_flagvalues.py#L1333. PiperOrigin-RevId: 570721827 04 October 2023, 16:49:19 UTC
2fe00f8 [Mosaic] Run verifier after infer_memref_layout PiperOrigin-RevId: 570720278 04 October 2023, 16:38:54 UTC
3cf822d [Mosaic] Allow arbitrary LHS tensor axis size for non-packed matmuls PiperOrigin-RevId: 570717660 04 October 2023, 16:28:10 UTC
c63880b [export] Improve primitive harnesses to use jax.random.key Many tests involving randomness in multi_platform_export_test were failing because the primitive harnesses use raw uint32 arrays as keys. Change them to use jax.random.keys. 04 October 2023, 13:26:11 UTC
efa987c [export] Add tests for multi-platform and cross-platform export Test for each JAX primitive harness that we can lower it for multiple platforms and then execute it on multiple platforms with the same results as the JAX native execution. This is a large test, covering between 5000-7000 harnesses, depending on the platform. Hundreds of harnesses fail this test. In future work we will address each of the failing harnesses in turn. PiperOrigin-RevId: 570661115 04 October 2023, 12:08:36 UTC
52500d9 [Pallas] Add pretty printing for dma_start DMAs now print as `dma_start a[...] -> b[...] c` PiperOrigin-RevId: 570587172 04 October 2023, 05:10:58 UTC
16a2283 Merge pull request #17911 from hawkinsp:py312 PiperOrigin-RevId: 570558531 04 October 2023, 02:01:14 UTC
58cb0d9 Use the released version of Python 3.12 in the Windows wheel builds. 04 October 2023, 01:51:37 UTC
6c09249 Merge pull request #17910 from hawkinsp:rocm PiperOrigin-RevId: 570555612 04 October 2023, 01:45:12 UTC
efc18e4 [JAX] Obtain NCCL via a stub, rather than linking it statically or dynamically. This shrinks the CUDA jaxlib wheel size by around 80MB. PiperOrigin-RevId: 570554454 04 October 2023, 01:33:58 UTC
578478b Remove references to the ROCm TensorFlow repository in AMD build instructions. 04 October 2023, 01:33:52 UTC
816ebf2 Merge pull request #17909 from skye:version PiperOrigin-RevId: 570551640 04 October 2023, 01:15:17 UTC
82b5838 Update versions and CHANGELOG after jax 0.4.17 release 04 October 2023, 00:54:35 UTC
f319a2b Merge pull request #17908 from ybaturina:update_test_instructions PiperOrigin-RevId: 570546044 04 October 2023, 00:53:35 UTC
2a46654 Merge pull request #17906 from google:tpu_ci_remove_workaround PiperOrigin-RevId: 570545829 04 October 2023, 00:43:33 UTC
6bc184e Merge pull request #17907 from jakevdp:typo PiperOrigin-RevId: 570502792 03 October 2023, 21:46:55 UTC
6cf4ce1 Merge pull request #17885 from jakevdp:bitwise-count PiperOrigin-RevId: 570491042 03 October 2023, 21:09:08 UTC
3fd204c fix typo in deprecation message 03 October 2023, 21:04:09 UTC
a142c59 [Pallas] Enable integer indexing for memrefs in `.at` PiperOrigin-RevId: 570488769 03 October 2023, 20:58:59 UTC
a09fdf6 Add jax.numpy.bitwise_count() 03 October 2023, 20:48:16 UTC
4cb5eee Merge pull request #17905 from skye:version PiperOrigin-RevId: 570481317 03 October 2023, 20:33:22 UTC
b9c602b Update XLA commit and versions for jax 0.4.17 release 03 October 2023, 20:20:25 UTC
3d84808 Merge pull request #17851 from mattjj:readme-tweak PiperOrigin-RevId: 570466125 03 October 2023, 19:40:53 UTC
dddbe43 Update README.md Co-authored-by: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> 03 October 2023, 19:33:41 UTC
24ad445 [Pallas] Add support for pytrees in scalar prefetch PiperOrigin-RevId: 570453699 03 October 2023, 18:52:34 UTC
ee8af09 Fix mock_gpu_test on OSS build. PiperOrigin-RevId: 570436380 03 October 2023, 17:55:20 UTC
c3e73c6 Merge pull request #17760 from superbobry:array-any PiperOrigin-RevId: 570400629 03 October 2023, 15:50:07 UTC
b2ac2de [XLA] Split --xla_detailed_logging_and_dumping debug flag into --xla_detailed_logging and --xla_enable_dumping. We want to suppress detailed logging (notably on TPU, which has pretty verbose detailed logging) separately from disabling HLO dumps. Even if we don't print detailed log information, it's quite surprising if an HLO module doesn't show up in the set of modules dumped by XLA. PiperOrigin-RevId: 570374492 03 October 2023, 13:59:47 UTC
72407f6 Update XLA dependency to use revision http://github.com/openxla/xla/commit/7a19856d74569fd1f765cd03bdee84e3b1fdc579. PiperOrigin-RevId: 570300307 03 October 2023, 07:37:15 UTC
17e259b fix typo: device_fun(c) -> device_fun(x) PiperOrigin-RevId: 570289287 03 October 2023, 06:38:48 UTC
1c796c0 [Pallas] Automatically turn mesh indices -> physical ids for remote DMAs PiperOrigin-RevId: 570221510 03 October 2023, 00:04:15 UTC
17d89ad Fix jax.device_put so it doesn't use tree_map for _check_sharding. This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays. PiperOrigin-RevId: 570189648 02 October 2023, 22:01:03 UTC