https://github.com/google/jax

4cb5eee Merge pull request #17905 from skye:version PiperOrigin-RevId: 570481317 03 October 2023, 20:33:22 UTC
b9c602b Update XLA commit and versions for jax 0.4.17 release 03 October 2023, 20:20:25 UTC
3d84808 Merge pull request #17851 from mattjj:readme-tweak PiperOrigin-RevId: 570466125 03 October 2023, 19:40:53 UTC
dddbe43 Update README.md Co-authored-by: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> 03 October 2023, 19:33:41 UTC
24ad445 [Pallas] Add support for pytrees in scalar prefetch PiperOrigin-RevId: 570453699 03 October 2023, 18:52:34 UTC
ee8af09 Fix mock_gpu_test on OSS build. PiperOrigin-RevId: 570436380 03 October 2023, 17:55:20 UTC
c3e73c6 Merge pull request #17760 from superbobry:array-any PiperOrigin-RevId: 570400629 03 October 2023, 15:50:07 UTC
b2ac2de [XLA] Split --xla_detailed_logging_and_dumping debug flag into --xla_detailed_logging and --xla_enable_dumping. We want to suppress detailed logging (notably on TPU, which has pretty verbose detailed logging) separately from disabling HLO dumps. Even if we don't print detailed log information, it's quite surprising if an HLO module doesn't show up in the set of modules dumped by XLA. PiperOrigin-RevId: 570374492 03 October 2023, 13:59:47 UTC
72407f6 Update XLA dependency to use revision http://github.com/openxla/xla/commit/7a19856d74569fd1f765cd03bdee84e3b1fdc579. PiperOrigin-RevId: 570300307 03 October 2023, 07:37:15 UTC
17e259b fix typo: device_fun(c) -> device_fun(x) PiperOrigin-RevId: 570289287 03 October 2023, 06:38:48 UTC
1c796c0 [Pallas] Automatically turn mesh indices -> physical ids for remote DMAs PiperOrigin-RevId: 570221510 03 October 2023, 00:04:15 UTC
17d89ad Fix jax.device_put so it doesn't use tree_map for _check_sharding. This causes it to unnecessarily attempt to unflatten the None return values from _check_sharding into the original tree structure, which is a problem for custom datatypes registered with jax.tree_util that don't accept None values in place of jax arrays. PiperOrigin-RevId: 570189648 02 October 2023, 22:01:03 UTC
61d2016 Merge pull request #17889 from jakevdp:jnp-axis-types PiperOrigin-RevId: 570168959 02 October 2023, 20:54:36 UTC
c12929b Add more API set up for Mock GPU client. Also clean up previous mock GPU client API. PiperOrigin-RevId: 570153877 02 October 2023, 20:11:29 UTC
d45fa22 Add tests to cover `PyTreeDef.flatten_up_to` error scenarios. Also improve coverage of `PyTreeDef.flatten_up_to` success scenarios. PiperOrigin-RevId: 570152827 02 October 2023, 20:00:55 UTC
2902b32 [typing] allow Sequence inputs in several jax.numpy functions 02 October 2023, 18:48:36 UTC
c60b4ae Merge pull request #17888 from jakevdp:changelog PiperOrigin-RevId: 570127490 02 October 2023, 18:38:21 UTC
9247a62 Add CHANGELOG entry for the jnp annotation change 02 October 2023, 18:31:28 UTC
228e192 [XLA: Python] Update PJRT plugin configuration mapping with bool type PiperOrigin-RevId: 570100527 02 October 2023, 17:08:34 UTC
adba2f0 Add type stubs for jax.numpy. This allows mypy/pytype to obtain accurate types for the public jax.numpy APIs, which is helpful to downstream users of JAX, if not JAX itself. PiperOrigin-RevId: 570058363 02 October 2023, 14:20:20 UTC
0fe420e Merge pull request #17882 from jakevdp:fix-nightly PiperOrigin-RevId: 570049158 02 October 2023, 13:37:02 UTC
c9851ac [Mosaic] Allow vector.shape_cast to (un)fold the sublane dim, for as long as it remains a multiple of sublane tiling. The old guards were overly restrictive, and we can actually treat a much larger class of reshapes as no-ops. PiperOrigin-RevId: 570049016 02 October 2023, 13:36:45 UTC
77d11e4 [Mosaic] Reimplement new relayout routines in C++ PiperOrigin-RevId: 570046852 02 October 2023, 13:24:56 UTC
ae65a7c Disable tests for int4 support on non-TPU platforms. An upcoming XLA change will reject programs containing int4 on CPU and GPU, because the XLA support is buggy and incomplete. Once XLA supports this, we can re-enable these tests. Issue https://github.com/google/jax/issues/17672 PiperOrigin-RevId: 570042917 02 October 2023, 13:04:22 UTC
568fa79 Update test to skip np.bitwise_count ufunc. Fixes https://github.com/google/jax/issues/17878 02 October 2023, 12:50:07 UTC
4471abe Merge pull request #17423 from gnecula:export_multi2 PiperOrigin-RevId: 569993648 02 October 2023, 09:28:52 UTC
90286f2 Update XLA dependency to use revision http://github.com/openxla/xla/commit/2233d818422b067750cafa7c5ceba8f292633d40. PiperOrigin-RevId: 569975863 02 October 2023, 07:58:05 UTC
4b5ed34 [jax_export] Add the next part of multi-platform lowering support. We change the lowering rule selection code to work when `ModuleContext.lowering_parameters.platforms` contains multiple strings, and emit conditional code to select the lowering based on the platform index argument. These changes will not affect the normal JAX lowering paths (when `ModuleContext.lowering_parameters.platforms` is `None`). They will also not affect the JAX native serialization paths for single-platform lowering. These changes should work for most primitives, with the exception of the few that actually access `ModuleContext.platform` inside the lowering rules (most primitives just register different rules for different platforms, which is taken into account by these changes). Previous PR in this series: #17316. 02 October 2023, 07:57:35 UTC
5ab05e4 MAINT Clean up leftover `Array = Any` aliases in jax/_src/**.py. I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype found more latent type errors, which require an understanding of raggedness and dynamic shapes internals to fix properly. 01 October 2023, 11:19:21 UTC
004e67e Update XLA dependency to use revision http://github.com/openxla/xla/commit/9938bdbbf34a4098856ef387e3f1c53398d3f1e1. PiperOrigin-RevId: 569829166 01 October 2023, 08:30:44 UTC
0ae2a63 Add temporary flag for forcing arg tuplization of lowered functions. PiperOrigin-RevId: 569814851 01 October 2023, 06:14:38 UTC
bf46b74 Update XLA dependency to use revision http://github.com/openxla/xla/commit/f62cc8f27a3bf5f628eee6f9e784b20bdd18ad63. PiperOrigin-RevId: 569686162 30 September 2023, 08:01:55 UTC
5379d3d [JAX] Fix a regression in cost_analysis API access for an alternative JAX backend PiperOrigin-RevId: 569664445 30 September 2023, 04:42:33 UTC
095c367 Merge pull request #17864 from hawkinsp:buildwheel PiperOrigin-RevId: 569576169 29 September 2023, 20:34:42 UTC
fa81596 Clean up build_wheel.py and build_gpu_plugin_wheel.py. * Use pathlib.Path object-oriented paths. * Change copy_files() helper to copy many files in one call. * Make copy_files() also make the output directory, if needed. * Format file with pyink --pyink-indentation=2 29 September 2023, 20:08:42 UTC
4d933eb Merge pull request #17860 from google:test_fix PiperOrigin-RevId: 569549919 29 September 2023, 18:53:02 UTC
72b1eb3 Bump NumpyLinalgTest.testEighRankDeficient tolerance. Otherwise it sometimes fails on Cloud TPU v5e. 29 September 2023, 18:43:33 UTC
ef6fd2e Bump test tolerance for sqrtm test. This test fails on ARM with a LAPACK built with gfortran 11. PiperOrigin-RevId: 569540626 29 September 2023, 18:15:48 UTC
59d86b2 Correct typo in dtype ValueError() call. PiperOrigin-RevId: 569527985 29 September 2023, 17:31:58 UTC
a32ed7e Bump shard_count for shard_map_test to fix the asan failures PiperOrigin-RevId: 569520202 29 September 2023, 17:02:38 UTC
e6a62fc [PJRT] Split the GpuId() platform constants into CudaId()/RocmId(). Similarly for the GpuName() constant. While most of the time we treat CUDA and ROCm GPUs identically, we sometimes want to distinguish between CUDA and ROCm (e.g., for DLPack exports) and it's helpful if this is encoded in the platform ID. PiperOrigin-RevId: 569513495 29 September 2023, 16:35:16 UTC
9f963d2 [Pallas:TPU] Use ExtUI to widen booleans to signed integer types. Otherwise `true` gets converted to `-1`, which is confusing. PiperOrigin-RevId: 569509184 29 September 2023, 16:16:12 UTC
3b0c789 tweak readme to call out broader numerical computing. At the same time, it makes sense to highlight large-scale machine learning as a focus. 29 September 2023, 15:59:27 UTC
d314cf0 Merge pull request #17841 from google:tpu_ci_disable_tcmalloc PiperOrigin-RevId: 569469475 29 September 2023, 12:55:48 UTC
88fac56 Update XLA dependency to use revision http://github.com/openxla/xla/commit/99a225c1b96999a27b7094241ec12b3ebf54e6de. PiperOrigin-RevId: 569418541 29 September 2023, 08:14:27 UTC
f94bbc1 Merge pull request #17827 from gnecula:lowering_params PiperOrigin-RevId: 569392664 29 September 2023, 06:08:28 UTC
552fef6 Introduce a LoweringParameters dataclass for easier plumbing. There are currently two parameters that are used to configure lowering: lowering_platform (for cross-platform lowering) and override_lowering_rules. Each of them is passed as a separate argument through several layers of internal lowering functions. This is tedious and error-prone. In fact, override_lowering_rules was not plumbed in all places, and because default arguments are used everywhere, this leads to silent errors. We foresee introducing other parameters for lowering: for multi-platform lowering and for controlling the lowering of effects. Here we pack all such parameters into a `mlir.LoweringParameters` dataclass and plumb that through. 29 September 2023, 05:23:05 UTC
3247db7 add tests for host offloading (plus operations) under a custom VJP Co-authored-by: Yash Katariya <yashkatariya@google.com> PiperOrigin-RevId: 569333314 29 September 2023, 00:19:21 UTC
ef241e5 Cloud TPU CI: don't use tcmalloc (temporary workaround for tcmalloc deadlock) 28 September 2023, 23:55:11 UTC
e6f8477 Merge pull request #17826 from superbobry:fix-sort-overloads PiperOrigin-RevId: 569306390 28 September 2023, 22:26:57 UTC
2eca5b3 Add a compile-time version test that verifies CUDA is version 11.8 or newer. Issue https://github.com/google/jax/issues/17829 PiperOrigin-RevId: 569302585 28 September 2023, 22:14:04 UTC
528b035 Fix typing for kwargs. PiperOrigin-RevId: 569300602 28 September 2023, 22:03:20 UTC
8bfe3b9 Roll back https://github.com/tensorflow/tensorflow/commit/f92a70a41e76db5d0829120f46d5b001f89decdf Reverts bb4382f0bce074ab081e1e02871e32ba331d1d46 PiperOrigin-RevId: 569292433 28 September 2023, 21:32:23 UTC
4b107f8 [Mosaic] apply_vector_layout C++ rewrite (12): func.return PiperOrigin-RevId: 569268028 28 September 2023, 20:07:59 UTC
ac27d28 [Mosaic] Add `sqrt` lowering rule. PiperOrigin-RevId: 569260464 28 September 2023, 19:39:52 UTC
fc569b4 [Mosaic] apply_vector_layout C++ rewrite (11): vector.broadcast PiperOrigin-RevId: 569246375 28 September 2023, 18:47:30 UTC
c490a06 Merge pull request #17828 from hawkinsp:tpu PiperOrigin-RevId: 569210363 28 September 2023, 16:47:23 UTC
d0baa1d Fix incorrect backend allowlist in array_interoperability_test. We intended to only enable this test on CPU and GPU, but we were missing a critical "not". 28 September 2023, 14:30:22 UTC
173a270 [Mosaic] Add retiling swizzles required for int8 matmuls. Ideally we would skip the swizzle entirely, but it is not always possible at the moment. PiperOrigin-RevId: 569149358 28 September 2023, 12:30:20 UTC
a8b8267 MAINT Reorder the overloads for `lax.sort`. `Array` is structurally a `Sequence[Array]`, so the first overload always matches under pytype, which defines `collections.abc.Sequence` as a `Protocol`. See https://github.com/google/pytype/blob/b8f91a37e59535ce28e8bacb60c6453cd5c0ecfa/pytype/stubs/builtins/typing.pytd#L149. 28 September 2023, 11:51:36 UTC
79d0a83 Allow event listeners to take extra keyword arguments. PiperOrigin-RevId: 569138957 28 September 2023, 11:38:43 UTC
2d068a1 Update XLA dependency to use revision http://github.com/openxla/xla/commit/0c71a63a8673302610bdc3d01640d91828a330cc. PiperOrigin-RevId: 569108876 28 September 2023, 09:12:36 UTC
a37c292 [Mosaic] apply_vector_layout C++ rewrite (10): vector.extract_strided_slice PiperOrigin-RevId: 569081032 28 September 2023, 06:48:26 UTC
fb90d3e [Mosaic] apply_vector_layout C++ rewrite (9): tpu.repeat PiperOrigin-RevId: 569078893 28 September 2023, 06:36:22 UTC
b1b81ec [Mosaic] apply_vector_layout C++ rewrite (8): tpu.gather, tpu.iota, tpu.trace PiperOrigin-RevId: 569069717 28 September 2023, 05:52:02 UTC
bb4382f Destruct objects owned by `WeakRefLRUCache::CacheEntry` out of band using `GlobalPyRefManager()`. This assumes less about whether the thread that destructs `CacheEntry` holds the GIL, which is difficult to reason about due to `xla::LRUCache`'s use of `std::shared_ptr<CacheEntry>`. The following changes have been made in JAX to accommodate the behavior differences from direct destruction to GC: * Since `PyLoadedExecutable`s cached in `WeakRefLRUCache` are now destructed out of band, `PyClient::LiveExecutables()` calls `GlobalPyRefManager()->CollectGarbage()` to make the returned information accurate and up to date. * `test_jit_reference_dropping` has been updated to call `gc.collect()` before verifying the live executable counts, since the destruction of executables owned by weak ref maps is now done out of band as part of `GlobalPyRefManager`'s GC. PiperOrigin-RevId: 569062402 28 September 2023, 05:15:22 UTC
a9dc3c1 [shard_map] internal change to shard_map CI testing PiperOrigin-RevId: 569036873 28 September 2023, 03:06:24 UTC
951298d Relax cuDNN version compatibility test to ignore patch versions. PiperOrigin-RevId: 569020492 28 September 2023, 01:40:05 UTC
5936079 Merge pull request #17792 from jakevdp:mean-cast-f16 PiperOrigin-RevId: 569019549 28 September 2023, 01:30:06 UTC
b18ca05 jnp.mean: for f16 inputs, accumulate in f32 28 September 2023, 01:27:19 UTC
6be860b Clean up some device opt-in/opt-outs in test suite. Use allowlists rather than denylists in a few places. PiperOrigin-RevId: 568968749 27 September 2023, 21:56:00 UTC
5384561 Disable nanobind leak checker in cuda/versions module. The leak checker appears to be sensitive to the destruction order during Python shutdown. PiperOrigin-RevId: 568962933 27 September 2023, 21:43:20 UTC
07fa9dc Fix cupti-related build failure under CUDA 11. cuptiGetErrorMessage was added in CUDA 12.2. PiperOrigin-RevId: 568962562 27 September 2023, 21:33:30 UTC
841aaf4 Merge pull request #17813 from hawkinsp:notes PiperOrigin-RevId: 568958227 27 September 2023, 21:18:33 UTC
c62c6fc [Mosaic] Add `sin` and `clamp` lowering rules and support multiple branches in `cond`. Add a pallas_call test using scan/cond. Improve the error message for lowering exceptions and add a `LoweringException` type. PiperOrigin-RevId: 568945255 27 September 2023, 20:33:43 UTC
b7dfde8 Add notes about the new CUDA version restrictions to the changelog and installation instructions. 27 September 2023, 19:56:47 UTC
87af945 [Mosaic] Add min/max lowering rules for Mosaic. PiperOrigin-RevId: 568929392 27 September 2023, 19:35:00 UTC
1885c49 Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test(). This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design. Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches(). Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test(). PiperOrigin-RevId: 568923117 27 September 2023, 19:10:43 UTC
9404518 [CUDA] Add code to jax initialization that verifies that the CUDA libraries that are found are at least as new as the versions against which JAX was built. This is intended to flag cases where the wrong CUDA libraries are used, either because: * the user self-installed CUDA and that installation is too old, or * the user used the pip package installation, but due to LD_LIBRARY_PATH overrides or similar we didn't end up using the pip-installed version. PiperOrigin-RevId: 568910422 27 September 2023, 18:28:40 UTC
a6a90bd Merge pull request #17808 from google:tpu_ci PiperOrigin-RevId: 568887730 27 September 2023, 17:16:46 UTC
b6bd74f [Cloud TPU CI] Update job names to be more descriptive and add description comment to yaml file. 27 September 2023, 16:56:55 UTC
ae81ac9 Reverts 1a9c94e6265b40c8f1de1e6d920208d648d70fdd PiperOrigin-RevId: 568838127 27 September 2023, 13:58:35 UTC
1a9c94e [export] Set the default export serialization version to 8. This version has been supported by XlaCallModule since July 21, 2023 and we are now past the forward-compatibility window. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions PiperOrigin-RevId: 568777006 27 September 2023, 08:28:03 UTC
350b35c Update XLA dependency to use revision http://github.com/openxla/xla/commit/6b5136043c41a88c72386787141645dc9a8dd441. PiperOrigin-RevId: 568775379 27 September 2023, 08:18:00 UTC
d120dd6 Remove PjRt C API check in `PjRtArray::Reshard` since `PjRtCApiBuffer::CopyToMemorySpace` has been supported. PiperOrigin-RevId: 568754787 27 September 2023, 06:26:26 UTC
6e92c6e Merge pull request #17800 from hawkinsp:build PiperOrigin-RevId: 568717123 27 September 2023, 02:40:00 UTC
6e5409c Add missing raise to build.py 27 September 2023, 02:11:01 UTC
82d9b7f Merge pull request #17799 from hawkinsp:build PiperOrigin-RevId: 568710854 27 September 2023, 02:03:05 UTC
cf28e2c Small improvements to build/build.py. Add a --verbose option that logs all shell() commands run by the script. Remove some Python 2 backward compatibility logic related to urllib and shutil. Enable debug logging on Windows wheel builds. Also include setuptools in the build requirements and test for its presence in build.py. 27 September 2023, 01:55:45 UTC
f441385 Merge pull request #17796 from hawkinsp:changelog PiperOrigin-RevId: 568689288 27 September 2023, 00:05:14 UTC
0e24b90 [PJRT C API] Register custom callback for `xla_python_gpu_callback` in plugin module. PiperOrigin-RevId: 568671822 26 September 2023, 22:54:10 UTC
870ba1d Merge pull request #17795 from hawkinsp:platforms PiperOrigin-RevId: 568664494 26 September 2023, 22:25:15 UTC
a2e1f1f Update changelog. Bump the minimum CUDA 12 pip package versions to the current releases. 26 September 2023, 22:21:51 UTC
83f12c5 Fix CI failures from https://github.com/google/jax/pull/17751 26 September 2023, 21:52:37 UTC
385cfc8 Use ModuleNotFoundError when importing the cuda_plugin_extension module to be more specific, so that other ImportErrors are not silenced. PiperOrigin-RevId: 568645824 26 September 2023, 21:14:48 UTC
1d2ccc7 Merge pull request #17791 from hawkinsp:platforms PiperOrigin-RevId: 568636577 26 September 2023, 20:44:16 UTC
56f0f5c Copybara import of the project: -- 4d6722e24bca0475f2a4f81a33e34a2369ab2969 by Peter Hawkins <phawkins@google.com>: Fix detection of whether a GPU-enabled jaxlib is installed. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17791 from hawkinsp:platforms 4d6722e24bca0475f2a4f81a33e34a2369ab2969 PiperOrigin-RevId: 568634101 26 September 2023, 20:33:26 UTC
a7a7066 Fix detection of whether a GPU-enabled jaxlib is installed. 26 September 2023, 20:20:40 UTC
3dbd364 Do not share the build command between building jaxlib and gpu plugin as their commands diverge. PiperOrigin-RevId: 568625959 26 September 2023, 20:03:00 UTC
cecd72e Merge pull request #17751 from hawkinsp:platforms PiperOrigin-RevId: 568622592 26 September 2023, 19:50:17 UTC
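Commit 9f963d2 switches Pallas:TPU to `ExtUI` for widening booleans because sign extension turns `true` into `-1`. The sketch below reproduces that arithmetic on plain Python integers. It assumes the usual two's-complement semantics of MLIR's `arith.extui` (zero extension) and `arith.extsi` (sign extension). It is an illustration only, not Mosaic or Pallas code, and the helper names `zero_extend` and `sign_extend` are hypothetical.

```python
# Model of widening an i1 (boolean) bit pattern to a wider signed integer.
# zero_extend mirrors the semantics of MLIR's arith.extui and sign_extend
# mirrors arith.extsi. Both helper names are hypothetical, not JAX/MLIR APIs.


def zero_extend(value: int, from_bits: int) -> int:
    """Pad with zero bits: the result is the unsigned value of the low bits."""
    return value & ((1 << from_bits) - 1)


def sign_extend(value: int, from_bits: int) -> int:
    """Replicate the top bit: read the low bits as a two's-complement value."""
    value = zero_extend(value, from_bits)
    sign_bit = 1 << (from_bits - 1)
    return (value ^ sign_bit) - sign_bit


if __name__ == "__main__":
    true_bit = 1  # an i1 `true` is the single bit 1
    print("extsi(true) ->", sign_extend(true_bit, 1))  # prints -1
    print("extui(true) ->", zero_extend(true_bit, 1))  # prints 1
```

A 1-bit value consists only of its sign bit. Sign extension therefore fills every wider bit with ones, which reads as `-1` in a signed type. Zero extension keeps `true` equal to `1`.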