https://github.com/google/jax

88408e1 Remove stale references to //jaxlib:setup.cfg in Bazel build. Fixes broken jaxlib wheel build. 03 September 2023, 19:18:25 UTC
01c068e [callback] Some test cleanup. Removes callback testing function and uses io_callback and pure_callback instead. This allows us to remove some tests from the PureCallbackTest class. Renames IoPythonCallbackTest -> IoCallbackTest and PurePythonCallbackTest -> PureCallbackTest. PiperOrigin-RevId: 562285255 03 September 2023, 04:51:07 UTC
da87d78 Update XLA dependency to use revision http://github.com/openxla/xla/commit/7bd5f555f0e86efb00005dbc9e38991d5fb9765b. PiperOrigin-RevId: 562181905 02 September 2023, 11:51:09 UTC
e9638da Fix compatibility problem with the build 1.0 release. It no longer appears possible to pass platform tags to bdist_wheel on the build command line without triggering a deprecation warning from bdist_wheel. However, we can just pass the platform tag via a generated setup.cfg instead. It does not appear necessary to pass a Python tag, since the default behavior is correct. PiperOrigin-RevId: 562073096 01 September 2023, 23:18:33 UTC
0ddbe76 Update XLA dependency to use revision http://github.com/openxla/xla/commit/7b9aa07fe9f83bb7af898092fdde3d4310b7906c. PiperOrigin-RevId: 561928858 01 September 2023, 12:35:23 UTC
73275aa [Mosaic] Add support for inserting a new lane dimension. This is often useful when a kernel uses statistics tensors that are constant across the minormost dimensions. Right now the only way to use them is to force XLA to insert the extra dimension before the kernel, but that turns out to be very inefficient. PiperOrigin-RevId: 561903222 01 September 2023, 10:02:02 UTC
efaea8e [callback] Enable device_index support in terms of callback sharding support. This is part of deprecating host_callback and moving to io_callback. PiperOrigin-RevId: 561856023 01 September 2023, 05:31:35 UTC
e0a6230 [host_callback] Delete unused code paths. This is part of deprecating host_callback and moving to io_callback. PiperOrigin-RevId: 561851494 01 September 2023, 05:08:23 UTC
70b58bb rolling forward shard_map transpose fixes. The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests! The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those! The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule. Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa PiperOrigin-RevId: 561805840 01 September 2023, 00:31:21 UTC
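Why a fan-out in_spec needs a psum in the transpose rule can be checked with plain linear algebra: if the forward pass copies one value onto n devices, its transpose must sum the n incoming cotangents. The sketch below is plain Python over lists, not shard_map's implementation; `broadcast` and `broadcast_transpose` are hypothetical names, and it verifies the transpose through the identity <f(x), c> == <x, f_T(c)>.

```python
# Conceptual sketch only: "fan-out" is modeled as copying one scalar to n devices.
# `broadcast` and `broadcast_transpose` are hypothetical helpers, not JAX APIs.

def broadcast(x: float, n: int) -> list[float]:
    """Forward map: replicate x onto n devices (a fan-out in_spec)."""
    return [x] * n

def broadcast_transpose(cts: list[float]) -> float:
    """Transpose of broadcast: sum the per-device cotangents (the psum)."""
    return sum(cts)

def dot(a: list[float], b: list[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))

x = 3.0
cts = [1.0, 2.0, 4.0]                    # one cotangent per device
lhs = dot(broadcast(x, len(cts)), cts)   # <f(x), c>
rhs = x * broadcast_transpose(cts)       # <x, f_T(c)>
assert lhs == rhs == 21.0
```

The division for fan-in-consensus out_specs mentioned in the commit is not modeled here.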
c38f670 Hash serialized CompileOptions for new cache key generation. The original cache key generation hashes individual fields of CompileOptions, ExecutableBuildOptions, and DebugOptions. This is not future-proof: when a field is added to any of these structures, the corresponding hash needs to be added to the cache key generation. The new cache key generation algorithm hashes the serialized representation of CompileOptions. Some DebugOptions do not affect the compilation result; these are excluded from the computation. If additional fields are identified, they can be added; such additions will reduce unnecessary cache misses. Testing: revised unit test. PiperOrigin-RevId: 561803875 01 September 2023, 00:21:57 UTC
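The design change can be sketched in plain Python, using JSON with sorted keys as a stand-in for protobuf serialization and SHA-256 as the hash. Both are assumptions (the commit names neither), and the option and debug field names below are hypothetical, chosen only to show that new fields are covered automatically while non-affecting debug fields are dropped.

```python
import hashlib
import json

# Hypothetical field names for illustration only; XLA's real CompileOptions
# and DebugOptions are protobufs with different and many more fields.
NON_AFFECTING_DEBUG_FIELDS = {"dump_dir", "log_verbosity"}

def cache_key(compile_options: dict) -> str:
    """Hash one serialized copy of the options, minus debug fields that do
    not change the compiled executable."""
    opts = dict(compile_options)  # shallow copy; the caller's dict is untouched
    opts["debug_options"] = {
        k: v for k, v in opts.get("debug_options", {}).items()
        if k not in NON_AFFECTING_DEBUG_FIELDS
    }
    # sort_keys makes the serialization deterministic; any newly added field
    # ends up in the serialized bytes without editing this function.
    serialized = json.dumps(opts, sort_keys=True).encode("utf-8")
    return hashlib.sha256(serialized).hexdigest()

base = {"num_replicas": 1, "debug_options": {"opt_level": 2, "dump_dir": "/tmp/a"}}
same = {"num_replicas": 1, "debug_options": {"opt_level": 2, "dump_dir": "/tmp/b"}}
other = {"num_replicas": 2, "debug_options": {"opt_level": 2, "dump_dir": "/tmp/a"}}
```

Here `cache_key(base) == cache_key(same)` because only an excluded field differs, while `other` gets a different key.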
da097e7 Don't hold a reference to _lapack.initialize(). Nanobind prints a warning if a reference to a nanobind-bound function is held at process atexit() time. But there's no particularly good reason we need to hold a function reference that long anyway in this case. PiperOrigin-RevId: 561795469 31 August 2023, 23:46:04 UTC
9f7a19a [jax_triton] Improve error message when shared memory exceeds that available on the GPU. PiperOrigin-RevId: 561792406 31 August 2023, 23:32:18 UTC
ccb8814 Make apply_primitive go via C++ fast dispatch. This leads to a ~30% faster dispatch time. Ideally, we should replace this with jit, but that has its own set of problems that I will look into later.
```
name                                                  old time/op  new time/op  delta
eager_unary_dispatch                                  40.3µs ± 2%  29.2µs ± 9%  -27.51%  (p=0.008 n=5+5)
eager_unary                                           40.6µs ± 0%  31.1µs ±11%  -23.41%  (p=0.016 n=4+5)
eager_binary_dispatch                                 49.6µs ± 0%  34.5µs ± 8%  -30.58%  (p=0.016 n=4+5)
eager_binary                                          50.2µs ± 1%  35.4µs ± 9%  -29.38%  (p=0.016 n=4+5)
bench_remat_eager_retracing_overheads                 13.0ms ± 1%  11.3ms ± 8%  -13.26%  (p=0.008 n=5+5)
bench_remat_eager_retracing_overheads_static_argnums  13.3ms ± 0%  12.3ms ± 6%   -7.34%  (p=0.016 n=4+5)
bench_repeated_static_indexing                         112ms ± 2%    82ms ± 5%  -26.46%  (p=0.008 n=5+5)
bench_repeated_static_slicing                         90.5ms ± 1%  68.3ms ± 5%  -24.54%  (p=0.008 n=5+5)
```
PiperOrigin-RevId: 561774696 31 August 2023, 22:25:11 UTC
6c1b4b9 [pallas] Remove kernel_name and kernel_regeneration_metadata from backend config. This information is now attached as attributes to the MLIR custom_call Op. When serializing, these additional attributes are attached to the metadata field of the HLO custom_call op. This is done so as not to interfere with the backward-compatibility guarantees on the serialization format of custom_call backend_config. PiperOrigin-RevId: 561771663 31 August 2023, 22:15:25 UTC
80f6151 Instrument metrics to track the cache hit rate of the original JAX compilation cache. Metrics: 1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true. 2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation. 3) '/jax/compilation_cache/cache_misses' to track the number of times the cache is missed and the compiled executable is written to the cache repository. Created a context manager to register/unregister event listeners. PiperOrigin-RevId: 561771262 31 August 2023, 22:05:23 UTC
8a04dfd rolling back shard_map transposition change to fix a bug. Reverts 437d7be73534403f39fbee9d6391be1c532933a1 PiperOrigin-RevId: 561730581 31 August 2023, 19:39:56 UTC
ccf8d89 Merge pull request #17389 from jakevdp:rand-dtypes PiperOrigin-RevId: 561712142 31 August 2023, 18:30:23 UTC
f0309b4 jax.random: warn on unsupported dtypes 31 August 2023, 17:56:05 UTC
24c3a9d [MOSAIC] apply_vector_layout C++ rewrite (2): No-op pass and flag to use it instead of Python. PiperOrigin-RevId: 561697585 31 August 2023, 17:42:03 UTC
faa7a68 Merge pull request #17388 from jakevdp:warning-type PiperOrigin-RevId: 561692672 31 August 2023, 17:25:26 UTC
9cda6be [Mosaic] Add debug flag to print module after each pass. PiperOrigin-RevId: 561683148 31 August 2023, 16:53:11 UTC
ce92442 Fix TPU custom call device kind checking. PiperOrigin-RevId: 561677171 31 August 2023, 16:30:44 UTC
4e9b643 typing: annotate NumpyComplexWarning 31 August 2023, 16:05:26 UTC
9f2fd4e Merge pull request #17316 from gnecula:export_multi_1 PiperOrigin-RevId: 561640360 31 August 2023, 13:47:03 UTC
6b4fbe9 Update XLA dependency to use revision http://github.com/openxla/xla/commit/449d45dbcc923bed79746ea11975aab4adc0316c. PiperOrigin-RevId: 561630195 31 August 2023, 12:54:20 UTC
440ef03 [export] Implement the calling convention for exporting with multi-platform lowering. This is a first step towards supporting multi-platform exported JAX modules. Such modules are usable on more than one platform, and take an additional first argument that encodes the actual compilation platform as an index into the sequence of platforms for which the module was lowered. More details about the calling convention are in the docstring for jax_export.Exported in this PR. The value of the platform index is set by `jax_export.call_exported` when calling from JAX, and in tf.XlaCallModule prior to compilation, when called from TensorFlow. This is already implemented in tf.XlaCallModule. This PR has some incomplete pieces:
* Currently we actually lower only for the first platform specified, and the platform argument is not used. There are a couple of implementation strategies for actual multi-platform lowering, both using the same calling convention. We could lower separately for each platform and put the results together with one top-level conditional. Alternatively, we can take advantage of the fact that few primitives have per-platform lowering; we could lower those using a conditional.
* We implement multi-platform lowering only for jax_export, not for regular JAX jit or AOT lowering. This ensures that this change is narrowly scoped and safe for most JAX usage.
* We abuse the `_experimental_lowering_platform` kwarg to `lower()` to pass a tuple of platforms when we want multi-platform lowering. We ought to rename it to `_experimental_lowering_platforms`, but that requires more plumbing.
* We take advantage of the fact that the lowering for the platform index is identical to that for dimension variables: add a new argument to inner functions and pass the values to callees. We implement platform index as a dimension variable.
* We do not yet have the connection with jax2tf.convert.
31 August 2023, 12:43:08 UTC
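The calling convention described above (an extra first argument that indexes into the sequence of lowering platforms) and the "one top-level conditional" strategy can be sketched in plain Python. This is a hypothetical model, not jax_export: each per-platform "lowering" is an ordinary callable, and `make_multi_platform` and `platform_index_for` are illustrative names.

```python
from typing import Callable, Sequence

# Hypothetical sketch: each "lowering" is a plain Python callable standing in
# for the code lowered for one platform.
def make_multi_platform(platforms: Sequence[str],
                        lowerings: dict[str, Callable[..., float]]
                        ) -> Callable[..., float]:
    """Build a callable whose first argument is an index into `platforms`."""
    branches = [lowerings[p] for p in platforms]

    def call(platform_index: int, *args: float) -> float:
        # One top-level conditional selects the per-platform branch.
        return branches[platform_index](*args)

    return call

def platform_index_for(platforms: Sequence[str], actual: str) -> int:
    """What a caller computes from the actual platform before calling."""
    return list(platforms).index(actual)

platforms = ("cpu", "tpu")
exported = make_multi_platform(
    platforms, {"cpu": lambda x: x + 1.0, "tpu": lambda x: x * 2.0})
result = exported(platform_index_for(platforms, "tpu"), 3.0)  # "tpu" branch
```

The caller-side index computation mirrors the commit's note that `jax_export.call_exported` (from JAX) or tf.XlaCallModule (from TensorFlow) supplies the platform index.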
65a2ae9 Remove always_lower now that trivial computations don't exist PiperOrigin-RevId: 561516209 31 August 2023, 02:31:30 UTC
c21b0db Merge pull request #17350 from jakevdp:dep-linear-util PiperOrigin-RevId: 561515645 31 August 2023, 02:21:21 UTC
ca39457 JEX: move jax.linear_util to jax.extend.linear_util 31 August 2023, 01:32:12 UTC
437d7be Merge pull request #17368 from mattjj:shmap-transpose PiperOrigin-RevId: 561476082 30 August 2023, 23:14:13 UTC
51fa4d2 Merge pull request #17353 from jakevdp:dep-prng PiperOrigin-RevId: 561476055 30 August 2023, 23:04:27 UTC
03c2377 Merge pull request #17355 from jakevdp:dep-tree-util PiperOrigin-RevId: 561464587 30 August 2023, 22:19:23 UTC
4b89d03 Deprecate the contents of jax.prng 30 August 2023, 22:13:32 UTC
fdd252f [shard-map] add rewrite for efficient transposition 30 August 2023, 22:08:11 UTC
1761f79 Merge pull request #17369 from skye:version PiperOrigin-RevId: 561461008 30 August 2023, 22:05:31 UTC
f71ba0a Update versions and changelog post 0.4.15 release 30 August 2023, 20:20:25 UTC
603c879 Run `_check_sharding` checks during `api.device_put` instead of in the impl rule so that we don't have to repeat these checks in each rule of device_put. The same is done for jit and with_sharding_constraint. PiperOrigin-RevId: 561380348 30 August 2023, 17:27:37 UTC
4268ed7 Merge pull request #17362 from skye:version PiperOrigin-RevId: 561323664 30 August 2023, 13:32:37 UTC
2c5e857 Update versions for jax/jaxlib 0.4.15 release 30 August 2023, 13:10:07 UTC
c1680aa Update XLA dependency to use revision http://github.com/openxla/xla/commit/61e12edca8c1cfd865550c261c634861eccf506a. PiperOrigin-RevId: 561315877 30 August 2023, 12:52:12 UTC
d02b59e [MOSAIC] apply_vector_layout C++ rewrite (1): VectorLayout functions. PiperOrigin-RevId: 561237760 30 August 2023, 05:57:13 UTC
6b57470 Abstract the array_serialization error message to a global variable so that it can be overridden. PiperOrigin-RevId: 561224461 30 August 2023, 04:46:26 UTC
e785f89 Raise a good error message when mesh is not provided to jax.jit when using spmd_axis_name parameter of jax.vmap PiperOrigin-RevId: 561217612 30 August 2023, 03:58:57 UTC
9be96c1 Deprecate a number of exports from jax.interpreters.xla. Custom HLO lowering rules for primitives should be updated to use MLIR StableHLO lowering rules via jax.interpreters.mlir. PiperOrigin-RevId: 561215967 30 August 2023, 03:47:59 UTC
ecee8f9 [JAX] Implement importing external dlpack-aware Python arrays. See https://dmlc.github.io/dlpack/latest/python_spec.html. This is the import path. The export path was implemented in https://github.com/openxla/xla/commit/0b3cbfe4bc7cd68dc20924bd878cdfb2faa1a169. This allows for creating jax.Arrays from external GPU arrays asynchronously. PiperOrigin-RevId: 561172624 29 August 2023, 23:39:31 UTC
e369445 Remove tests for jax.numpy.in1d, which is deprecated. PiperOrigin-RevId: 561161024 29 August 2023, 22:52:34 UTC
7c9cca8 jax.tree_util: use standard deprecation framework for deprecated items 29 August 2023, 22:14:16 UTC
9390024 Remove jax.interpreters.xla.register_collective_primitive. We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless. PiperOrigin-RevId: 561150259 29 August 2023, 22:10:05 UTC
289ccad Merge pull request #17351 from jakevdp:mypy-fix PiperOrigin-RevId: 561136741 29 August 2023, 21:20:07 UTC
5a578cb Add segment_ids support to pallas flash attention on GPU. PiperOrigin-RevId: 561130379 29 August 2023, 20:59:18 UTC
f1fc2ad Fix mypy error 29 August 2023, 20:25:12 UTC
6072d59 Any devices passed to jax.sharding.Mesh are required to be hashable. This applies to mock devices, user-specific devices, and jax.devices() alike. Fix the tests so that the mock devices are hashable. PiperOrigin-RevId: 561103167 29 August 2023, 19:20:54 UTC
ff5b480 Merge pull request #17337 from jakevdp:jax-extend-doc PiperOrigin-RevId: 561101464 29 August 2023, 19:11:23 UTC
f935c00 DOC: add missing docs for jax.random functions 29 August 2023, 18:20:25 UTC
d0a6813 Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call(). This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities. Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules. This function has two benefits over just building the stablehlo directly: a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults). Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper. PiperOrigin-RevId: 561042402 29 August 2023, 15:50:07 UTC
ba8c4c3 Update XLA dependency to use revision http://github.com/openxla/xla/commit/db4cc73e3fc233afdd30f0dd94192f513feb5c6a. PiperOrigin-RevId: 561015160 29 August 2023, 13:46:00 UTC
46ac9e2 Use the default CSR matmul algorithm. Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always. PiperOrigin-RevId: 560867049 29 August 2023, 00:49:01 UTC
046bcc0 [Mosaic] Add missing headers in linalg vectorization. PiperOrigin-RevId: 560865251 29 August 2023, 00:39:47 UTC
92128d4 Remove backward compatibility code related to pytree registries. We always have a default_registry now, so we don't need to protect code that uses it with conditionals. A number of type suppressions are also stale. PiperOrigin-RevId: 560849610 28 August 2023, 23:26:25 UTC
5bdbda5 Merge pull request #17334 from google:cache_log PiperOrigin-RevId: 560829247 28 August 2023, 22:14:05 UTC
a37e215 Don't drop out of C++ fast path if mesh pointers are not equal. This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches. PiperOrigin-RevId: 560828993 28 August 2023, 22:04:05 UTC
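Returning the same object for matching construction arguments is an interning pattern. The sketch below assumes a simplified constructor that takes the shape, axis names, and a flat device list separately (the real jax.sharding.Mesh takes an array of devices); `InternedMesh` is a hypothetical name. Because the devices form part of a dict key in this sketch, they must be hashable.

```python
from typing import Hashable, Sequence

class InternedMesh:
    """Hypothetical stand-in for a mesh type that returns the same object for
    identical (shape, axis_names, flat device list)."""

    _cache: dict = {}

    def __new__(cls, shape: tuple[int, ...], axis_names: tuple[str, ...],
                devices: Sequence[Hashable]):
        # Devices are part of the cache key, so they must be hashable.
        key = (tuple(shape), tuple(axis_names), tuple(devices))
        mesh = cls._cache.get(key)
        if mesh is None:
            mesh = super().__new__(cls)
            mesh.shape, mesh.axis_names, mesh.devices = key
            cls._cache[key] = mesh
        return mesh

m1 = InternedMesh((2,), ("x",), ["d0", "d1"])
m2 = InternedMesh((2,), ("x",), ["d0", "d1"])
```

Since `m1 is m2`, a fast path that compares mesh objects by identity keeps hitting instead of falling back to a slower equality check.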
0c7ce72 Merge pull request #17228 from jakevdp:doc-nb-exec PiperOrigin-RevId: 560811268 28 August 2023, 21:09:35 UTC
fbadec2 Merge pull request #17331 from jakevdp:print-docs PiperOrigin-RevId: 560811165 28 August 2023, 20:59:21 UTC
34010a9 Align dummy pointers passed to cusparse to 16 bytes. Fixes alignment errors from Cusparse 12.2. PiperOrigin-RevId: 560793586 28 August 2023, 19:56:27 UTC
c523e3b Turn down jax_persistent_cache_min_compile_time_secs logging from info to debug. It's very noisy otherwise, since jax usually produces many small computations that aren't cached. 28 August 2023, 19:49:25 UTC
b09bef7 Merge pull request #17300 from jakevdp:prng-test PiperOrigin-RevId: 560776573 28 August 2023, 18:52:17 UTC
2f62111 DOC: mention static values in debug.print docs 28 August 2023, 17:59:40 UTC
2f878a7 Tests: set jax_legacy_prng_key='error' 28 August 2023, 17:56:09 UTC
c3e624a [Mosaic] Fix assert PiperOrigin-RevId: 560756391 28 August 2023, 17:50:23 UTC
70206ee Give jax.numpy.array the type `Callable`. This is to prevent users from using it as the type of arrays in type annotations. PiperOrigin-RevId: 560754568 28 August 2023, 17:41:07 UTC
3ea141d Merge pull request #17113 from jakevdp:faster-ufuncs PiperOrigin-RevId: 560737512 28 August 2023, 16:46:54 UTC
cb7c7ad jnp.ufunc: add fast paths for add/prod reductions 28 August 2023, 15:30:23 UTC
f407298 Merge pull request #17322 from abdulasiraj:patch-1 PiperOrigin-RevId: 560695446 28 August 2023, 13:46:12 UTC
7b802dc Update XLA dependency to use revision http://github.com/openxla/xla/commit/4f675c5715d23a49933e481d09adcf5065edfca5. PiperOrigin-RevId: 560691873 28 August 2023, 13:30:55 UTC
599b35e Update __init__.py 28 August 2023, 05:13:16 UTC
d09c55a Update __init__.py to include the dlpack module import. The import for the dlpack module was missing from the `__init__.py` file; this adds it. 28 August 2023, 05:01:17 UTC
871c9f4 Merge pull request #17307 from froystig:wrap-key PiperOrigin-RevId: 560536131 27 August 2023, 19:58:50 UTC
66c0400 Update XLA dependency to use revision http://github.com/openxla/xla/commit/1c22df67539a85904b4f686bcbe69efd518a5698. PiperOrigin-RevId: 560496105 27 August 2023, 13:22:28 UTC
a69f134 add `jax.extend.random.wrap_key_data` 26 August 2023, 18:39:25 UTC
841baab Adds Pallas flash attention TPU kernel. Implementation based on https://arxiv.org/pdf/2205.14135.pdf. PiperOrigin-RevId: 560346791 26 August 2023, 15:03:48 UTC
08ca945 Update XLA dependency to use revision http://github.com/openxla/xla/commit/7a371ed44aba34f83d6d3d1159d2e6d0d327c603. PiperOrigin-RevId: 560339802 26 August 2023, 13:52:36 UTC
eee5c13 Donate only if result and arg memory kinds are not None. PiperOrigin-RevId: 560264967 26 August 2023, 03:58:39 UTC
9aaacc5 Merge pull request #17275 from jakevdp:ufunc-reduce-where PiperOrigin-RevId: 560251564 26 August 2023, 02:18:32 UTC
4e227a3 Merge pull request #17305 from mzgubic:patch-1 PiperOrigin-RevId: 560227591 25 August 2023, 23:48:55 UTC
c35bc81 Add an optional `sharding` argument to `pure_callback` and `io_callback`. This CL allows callers of `pure_callback` and `io_callback` to specify the device used to run host callbacks. This is to make it possible to move users of `jax.experimental.host_callback(..., device_index=i)` (deprecated) to `pure_callback` or `io_callback`. Instead of taking a device index referring to a device in the device assignment, the new API takes a sharding to match the look and feel of other JAX APIs. The current implementation supports `SingleDeviceSharding` only, since the way sharding is annotated in StableHLO makes it tricky to use anything other than `MAXIMAL` or `MANUAL`. But if we later decide to expand support for other types of sharding, we will be able to do it without changing the API or breaking existing users. This CL also fixes an issue where `pure_callback` and `io_callback` had different sharding semantics inside `SPMDAxisContext`. Specifically, `io_callback` used to emit `MAXIMAL` even for `SPMDAxisContext`, whereas `pure_callback` used `MANUAL` sharding. The latter seems to make more sense since `SPMDAxisContext` with all manual axes should come with per-device semantics. This CL makes both styles of callbacks use consistent sharding by factoring out the sharding calculation as a common function. PiperOrigin-RevId: 560203044 25 August 2023, 21:57:35 UTC
992e5e4 Fix typo in jnp.interp docstring. 25 August 2023, 21:39:15 UTC
7ec2a88 Merge pull request #17302 from google:mlir PiperOrigin-RevId: 560164637 25 August 2023, 19:22:58 UTC
ae70fb7 Add generated MLIR _enum_gen.py files to jaxlib wheel. Fixes CI failure after LLVM update. 25 August 2023, 19:16:15 UTC
970f4c9 Remove trivial execution from jax since it leads to 100x slower dispatch time. Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces fewer trivial computations, so there is no need for this to exist. In the future, if we want to support forwarding of inputs to outputs, there would need to be a different mechanism that the C++ dispatch path knows about.
```
name                  old time/op  new time/op  delta
jit_trivial_dispatch   246µs ± 3%     4µs ± 1%  -98.52%  (p=0.008 n=5+5)
jit_trivial            250µs ± 3%     5µs ± 1%  -98.19%  (p=0.008 n=5+5)
```
PiperOrigin-RevId: 560141018 25 August 2023, 17:59:48 UTC
c71eedf Merge pull request #17295 from jakevdp:lax-pow-jvp PiperOrigin-RevId: 560133324 25 August 2023, 17:35:44 UTC
6cec5d4 lax.pow: fix shape mismatch failure in jvp rule 25 August 2023, 17:05:55 UTC
3ea0a74 Merge pull request #17277 from hawkinsp:trapz PiperOrigin-RevId: 560111745 25 August 2023, 16:22:34 UTC
b11e245 Remove stale reference to coordination_service flag. Fixes https://github.com/google/jax/issues/17288 PiperOrigin-RevId: 560103075 25 August 2023, 15:47:53 UTC
975dae3 Deprecate jax.numpy.trapz. Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead. Fixes https://github.com/google/jax/issues/17244 25 August 2023, 15:04:13 UTC
a454081 Merge pull request #17289 from google:nanobind PiperOrigin-RevId: 560090600 25 August 2023, 14:53:04 UTC
9a5b808 Update nanobind version to 1.5.2. 25 August 2023, 14:45:32 UTC
ac8ea86 Fix accidental signature change to get_serialized_metadata() from nanobind PR. pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead. PiperOrigin-RevId: 560086072 25 August 2023, 14:31:31 UTC
87cec1a Update XLA dependency to use revision http://github.com/openxla/xla/commit/14c0e8565ab7774268e1c454c72a11535acc1aae. PiperOrigin-RevId: 560077922 25 August 2023, 13:51:19 UTC
eea2603 Add a proper jax config for memories so that we can iteratively develop and enable it. PiperOrigin-RevId: 559977015 25 August 2023, 05:23:55 UTC
0099327 get aval directly via attribute in key array shard arg handler. No need to go through `core.get_aval` here. PiperOrigin-RevId: 559945841 25 August 2023, 02:47:35 UTC