88408e1 | Peter Hawkins | 03 September 2023, 19:18:25 UTC | Remove stale references to //jaxlib:setup.cfg in Bazel build. Fixes broken jaxlib wheel build. | 03 September 2023, 19:18:25 UTC |
01c068e | George Necula | 03 September 2023, 04:50:32 UTC | [callback] Some test cleanup. Removes callback testing function and uses io_callback and pure_callback instead. This allows us to remove some tests from the PureCallbackTest class. Renames IoPythonCallbackTest -> IoCallbackTest and PurePythonCallbackTest -> PureCallbackTest. PiperOrigin-RevId: 562285255 | 03 September 2023, 04:51:07 UTC |
da87d78 | jax authors | 02 September 2023, 11:50:34 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/7bd5f555f0e86efb00005dbc9e38991d5fb9765b. PiperOrigin-RevId: 562181905 | 02 September 2023, 11:51:09 UTC |
e9638da | Peter Hawkins | 01 September 2023, 23:17:59 UTC | Fix a compatibility problem with the 1.0 release of the `build` package. It no longer appears possible to pass platform tags to bdist_wheel on the build command line without triggering a deprecation warning from bdist_wheel. However, we can just pass the platform tag via a generated setup.cfg instead. It does not appear necessary to pass a Python tag, since the default behavior is correct. PiperOrigin-RevId: 562073096 | 01 September 2023, 23:18:33 UTC |
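A minimal sketch of generating such a setup.cfg, assuming the build script writes a `[bdist_wheel]` section with a `plat_name` option. The helper name `make_setup_cfg` and the tag value are hypothetical, not taken from the jaxlib build. No Python tag is written, matching the note that the default is correct.

```python
import configparser
import io


def make_setup_cfg(platform_tag: str) -> str:
    """Return setup.cfg contents that pin the wheel's platform tag.

    Hypothetical helper: bdist_wheel reads `plat_name` from its own section
    of setup.cfg, so the tag does not need to be passed on the command line.
    """
    config = configparser.ConfigParser()
    config["bdist_wheel"] = {"plat_name": platform_tag}
    buf = io.StringIO()
    config.write(buf)
    return buf.getvalue()


cfg_text = make_setup_cfg("manylinux2014_x86_64")
print(cfg_text)
```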
0ddbe76 | jax authors | 01 September 2023, 12:34:47 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/7b9aa07fe9f83bb7af898092fdde3d4310b7906c. PiperOrigin-RevId: 561928858 | 01 September 2023, 12:35:23 UTC |
73275aa | Adam Paszke | 01 September 2023, 10:01:23 UTC | [Mosaic] Add support for inserting a new lane dimension This is often useful when a kernel uses statistics tensors that are constant across the minormost dimensions. Right now the only way to use them is to force XLA to insert the extra dimension before the kernel, but that turns out to be very inefficient. PiperOrigin-RevId: 561903222 | 01 September 2023, 10:02:02 UTC |
efaea8e | George Necula | 01 September 2023, 05:30:59 UTC | [callback] Enable device_index support in terms of callback sharding support. This is part of deprecating host_callback and moving to io_callback. PiperOrigin-RevId: 561856023 | 01 September 2023, 05:31:35 UTC |
e0a6230 | George Necula | 01 September 2023, 05:07:42 UTC | [host_callback] Delete unused code paths. This is part of deprecating host_callback and moving to io_callback. PiperOrigin-RevId: 561851494 | 01 September 2023, 05:08:23 UTC |
70b58bb | Matthew Johnson | 01 September 2023, 00:30:34 UTC | rolling forward shard_map transpose fixes The new efficient-transpose path, enabled by setting check_rep=True in the shard_map call, had kept working. But the change inadvertently broke the check_rep=False path. And because most tests set check_rep=True, we didn't notice it in the tests! The issue was that with check_rep=False, we need the shard_map transpose rule to insert psums corresponding to in_specs with fan-out, and correspondingly insert division for out_specs with fan-in-consensus. (With the new check_rep=True path that this change adds, those extra operations aren't necessary as the body itself transposes correctly.) But the PR accidentally removed those! The fix was simple: just track whether we've applied the efficient-transpose-body-rewrite (i.e. whether we're in the new body-is-transposable path or old need-extra-operations path) by adding a boolean parameter `rewrite` to the shard_map primitive, and if the rewrite hasn't been applied then include the explicit psum/div operations in the transpose rule. Reverts 8a04dfd830ff89f46e1fe3e866ee4fb2da9c90aa PiperOrigin-RevId: 561805840 | 01 September 2023, 00:31:21 UTC |
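A toy model of why the extra operations are needed when the body is not rewritten, assuming each of `n` devices is represented by one list entry. Broadcasting a replicated input to every device (fan-out) is a linear map whose transpose sums the per-device cotangents, which is the role of the psum. If a fan-in-consensus output is modeled as the mean of identical per-device copies, its transpose gives each device the cotangent divided by `n`. All function names are hypothetical; this is plain Python, not shard_map.

```python
def fan_out(x, n):
    """Replicate a scalar input onto n devices (models an in_spec with fan-out)."""
    return [x] * n


def fan_out_transpose(cts):
    """Transpose of fan_out: sum the per-device cotangents (a psum)."""
    return sum(cts)


def consensus(xs):
    """Model a fan-in-consensus output as the mean of the per-device values."""
    return sum(xs) / len(xs)


def consensus_transpose(ct, n):
    """Transpose of consensus: each device receives ct / n (the inserted division)."""
    return [ct / n] * n


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


# Both transposes satisfy the defining identity <A x, y> == <x, A^T y>.
n = 4
x, ys = 3.0, [1.0, 2.0, 3.0, 4.0]
lhs_fan = dot(fan_out(x, n), ys)
rhs_fan = x * fan_out_transpose(ys)
xs, ct = [1.0, 2.0, 3.0, 4.0], 5.0
lhs_cons = consensus(xs) * ct
rhs_cons = dot(xs, consensus_transpose(ct, n))
```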
c38f670 | jax authors | 01 September 2023, 00:21:17 UTC | Hash serialized CompileOptions for new cache key generation. The original cache key generation hashes individual fields of CompileOptions, ExecutableBuildOptions, and DebugOptions. This is not future-proof: when a field is added to any of these structures, the corresponding hash needs to be added to the cache key generation. The new cache key generation algorithm hashes the serialized representation of CompileOptions. Some DebugOptions do not affect the compilation result; exclude them from the computation. If additional fields are identified, they can be added; such additions will reduce unnecessary cache misses. Testing: revised unit test. PiperOrigin-RevId: 561803875 | 01 September 2023, 00:21:57 UTC |
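A minimal sketch of the new scheme, assuming the options are available as a plain dict and using JSON with sorted keys as a stand-in for protobuf serialization. The field names and the excluded-option set are hypothetical, not the real CompileOptions or DebugOptions fields.

```python
import hashlib
import json

# Hypothetical debug options assumed not to affect the compiled executable.
EXCLUDED_DEBUG_OPTIONS = {"dump_to", "dump_hlo_as_text"}


def cache_key(compile_options: dict) -> str:
    """Hash the whole serialized options, minus excluded debug fields."""
    opts = dict(compile_options)  # shallow copy; the caller's dict is untouched
    opts["debug_options"] = {
        k: v
        for k, v in opts.get("debug_options", {}).items()
        if k not in EXCLUDED_DEBUG_OPTIONS
    }
    serialized = json.dumps(opts, sort_keys=True).encode("utf-8")
    return hashlib.sha256(serialized).hexdigest()


a = {"num_replicas": 1, "debug_options": {"dump_to": "/tmp/a", "opt_level": 2}}
b = {"num_replicas": 1, "debug_options": {"dump_to": "/tmp/b", "opt_level": 2}}
c = {"num_replicas": 1, "new_field": True, "debug_options": {"opt_level": 2}}
```

Because the whole serialized structure is hashed, a newly added field (`new_field` above) changes the key without any change to the hashing code, while an excluded option such as the dump path does not cause a cache miss.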
da097e7 | Peter Hawkins | 31 August 2023, 23:45:14 UTC | Don't hold a reference to _lapack.initialize(). Nanobind prints a warning if a reference to a nanobind-bound function is held at process atexit() time. But there's no particularly good reason we need to hold a function reference that long anyway in this case. PiperOrigin-RevId: 561795469 | 31 August 2023, 23:46:04 UTC |
9f7a19a | Chris Jones | 31 August 2023, 23:31:42 UTC | [jax_triton] Improve error message when shared memory exceeds that available on the GPU. PiperOrigin-RevId: 561792406 | 31 August 2023, 23:32:18 UTC |
ccb8814 | Yash Katariya | 31 August 2023, 22:17:57 UTC | Make apply_primitive go via C++ fast dispatch. This leads to a ~30% faster dispatch time. Ideally, we should replace this with jit, but that has its own set of problems that I will look into later. ``` eager_unary_dispatch 40.3µs ± 2% 29.2µs ± 9% -27.51% (p=0.008 n=5+5) eager_unary 40.6µs ± 0% 31.1µs ±11% -23.41% (p=0.016 n=4+5) eager_binary_dispatch 49.6µs ± 0% 34.5µs ± 8% -30.58% (p=0.016 n=4+5) eager_binary 50.2µs ± 1% 35.4µs ± 9% -29.38% (p=0.016 n=4+5) bench_remat_eager_retracing_overheads 13.0ms ± 1% 11.3ms ± 8% -13.26% (p=0.008 n=5+5) bench_remat_eager_retracing_overheads_static_argnums 13.3ms ± 0% 12.3ms ± 6% -7.34% (p=0.016 n=4+5) bench_repeated_static_indexing 112ms ± 2% 82ms ± 5% -26.46% (p=0.008 n=5+5) bench_repeated_static_slicing 90.5ms ± 1% 68.3ms ± 5% -24.54% (p=0.008 n=5+5) ``` PiperOrigin-RevId: 561774696 | 31 August 2023, 22:25:11 UTC |
6c1b4b9 | jax authors | 31 August 2023, 22:06:12 UTC | [pallas] Remove kernel_name and kernel_regeneration_metadata from backend config This information is now attached as attributes to the MLIR custom_call Op. When serializing, these additional attributes are attached to the metadata field of the HLO custom_call op. This is done in order not to interfere with the backward-compatibility guarantees on the serialization format of custom_call backend_config. PiperOrigin-RevId: 561771663 | 31 August 2023, 22:15:25 UTC |
80f6151 | jax authors | 31 August 2023, 22:04:43 UTC | Instrument metrics to track the cache hit rate of the original JAX compilation cache. Metrics: 1) '/jax/compilation_cache/compile_requests_use_cache' to track the number of times `compile_or_get_cached` is called and `use_compilation_cache` is true. 2) '/jax/compilation_cache/cache_hits_original' to track the number of times the cached executable is successfully returned from a cache read using the original implementation. 3) '/jax/compilation_cache/cache_misses' to track the number of times the cache is missed and the compiled executable is written to the cache repository. Created a context manager to register/unregister event listeners. PiperOrigin-RevId: 561771262 | 31 August 2023, 22:05:23 UTC |
8a04dfd | Matthew Johnson | 31 August 2023, 19:39:13 UTC | rolling back shard_map transposition change to fix a bug Reverts 437d7be73534403f39fbee9d6391be1c532933a1 PiperOrigin-RevId: 561730581 | 31 August 2023, 19:39:56 UTC |
ccf8d89 | jax authors | 31 August 2023, 18:30:23 UTC | Merge pull request #17389 from jakevdp:rand-dtypes PiperOrigin-RevId: 561712142 | 31 August 2023, 18:30:23 UTC |
f0309b4 | Jake VanderPlas | 31 August 2023, 17:56:05 UTC | jax.random: warn on unsupported dtypes | 31 August 2023, 17:56:05 UTC |
24c3a9d | Tomás Longeri | 31 August 2023, 17:41:14 UTC | [MOSAIC] apply_vector_layout C++ rewrite (2) No-op pass and flag to use it instead of Python PiperOrigin-RevId: 561697585 | 31 August 2023, 17:42:03 UTC |
faa7a68 | jax authors | 31 August 2023, 17:25:26 UTC | Merge pull request #17388 from jakevdp:warning-type PiperOrigin-RevId: 561692672 | 31 August 2023, 17:25:26 UTC |
9cda6be | Jevin Jiang | 31 August 2023, 16:52:30 UTC | [Mosaic] Add debug flag to print module after each pass. PiperOrigin-RevId: 561683148 | 31 August 2023, 16:53:11 UTC |
ce92442 | Enrique Piqueras | 31 August 2023, 16:29:58 UTC | Fix TPU custom call device kind checking. PiperOrigin-RevId: 561677171 | 31 August 2023, 16:30:44 UTC |
4e9b643 | Jake VanderPlas | 31 August 2023, 16:05:26 UTC | typing: annotate NumpyComplexWarning | 31 August 2023, 16:05:26 UTC |
9f2fd4e | jax authors | 31 August 2023, 13:47:03 UTC | Merge pull request #17316 from gnecula:export_multi_1 PiperOrigin-RevId: 561640360 | 31 August 2023, 13:47:03 UTC |
6b4fbe9 | jax authors | 31 August 2023, 12:53:42 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/449d45dbcc923bed79746ea11975aab4adc0316c. PiperOrigin-RevId: 561630195 | 31 August 2023, 12:54:20 UTC |
440ef03 | George Necula | 27 August 2023, 11:27:34 UTC | [export] Implement the calling convention for exporting with multi-platform lowering This is a first step towards supporting multi-platform exported JAX modules. Such modules are usable on more than one platform, and take an additional first argument that encodes the actual compilation platform as an index into the sequence of platforms for which the module was lowered. More details about the calling convention are in the docstring for jax_export.Exported in this PR. The value of the platform index is set by `jax_export.call_exported` when calling from JAX, and in the tf.XlaCallModule prior to compilation, when called from TensorFlow. This is already implemented in tf.XlaCallModule. This PR has some incomplete pieces: * Currently we actually lower only for the first platform specified, and the platform argument is not used. There are a couple of implementation strategies for actual multi-platform lowering, both using the same calling convention. We could lower separately for each platform and put the results together with one top-level conditional. Alternatively, we can take advantage of the fact that few primitives have per-platform lowering; we could lower those using a conditional. * we implement multi-platform lowering only for jax_export, not for regular JAX jit or AOT lowering. This ensures that this change is narrowly scoped and safe for most JAX usage. * we abuse the `_experimental_lowering_platform` kwarg to `lower()` to pass a tuple of platforms when we want multi-platform lowering. We ought to rename it to `_experimental_lowering_platforms`, but that requires more plumbing. * we take advantage of the fact that the lowering for the platform index is identical to that for dimension variables: add a new argument to inner functions and pass the values to callees. We implement platform index as a dimension variable. * we do not yet have the connection with jax2tf.convert. | 31 August 2023, 12:43:08 UTC |
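A plain-Python sketch of the calling convention as described in the commit message: the exported function takes an extra leading argument, an index into the sequence of platforms it was lowered for, and the caller computes that index before invoking it. The function names, per-platform bodies and platform tuple are hypothetical. The sketch uses one top-level conditional, one of the two strategies the message mentions; it does not model how JAX or tf.XlaCallModule actually pass the index.

```python
from typing import Callable, Dict, Tuple


def make_multi_platform(platforms: Tuple[str, ...],
                        lowerings: Dict[str, Callable]) -> Callable:
    """Build a function whose first argument is the platform index."""
    def exported(platform_index: int, *args):
        # One top-level conditional selecting the per-platform body.
        platform = platforms[platform_index]
        return lowerings[platform](*args)
    return exported


def call_with_platform_index(exported: Callable, platforms: Tuple[str, ...],
                             current_platform: str, *args):
    """Caller side: encode the actual platform as an index, then call."""
    return exported(platforms.index(current_platform), *args)


platforms = ("cpu", "cuda")
f = make_multi_platform(platforms, {"cpu": lambda x: x + 1,
                                    "cuda": lambda x: x + 2})
result = call_with_platform_index(f, platforms, "cuda", 10)
```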
65a2ae9 | Yash Katariya | 31 August 2023, 02:24:49 UTC | Remove always_lower now that trivial computations don't exist PiperOrigin-RevId: 561516209 | 31 August 2023, 02:31:30 UTC |
c21b0db | jax authors | 31 August 2023, 02:21:21 UTC | Merge pull request #17350 from jakevdp:dep-linear-util PiperOrigin-RevId: 561515645 | 31 August 2023, 02:21:21 UTC |
ca39457 | Jake VanderPlas | 30 August 2023, 22:14:47 UTC | JEX: move jax.linear_util to jax.extend.linear_util | 31 August 2023, 01:32:12 UTC |
437d7be | jax authors | 30 August 2023, 23:14:13 UTC | Merge pull request #17368 from mattjj:shmap-transpose PiperOrigin-RevId: 561476082 | 30 August 2023, 23:14:13 UTC |
51fa4d2 | jax authors | 30 August 2023, 23:04:27 UTC | Merge pull request #17353 from jakevdp:dep-prng PiperOrigin-RevId: 561476055 | 30 August 2023, 23:04:27 UTC |
03c2377 | jax authors | 30 August 2023, 22:19:23 UTC | Merge pull request #17355 from jakevdp:dep-tree-util PiperOrigin-RevId: 561464587 | 30 August 2023, 22:19:23 UTC |
4b89d03 | Jake VanderPlas | 30 August 2023, 22:13:32 UTC | Deprecate the contents of jax.prng | 30 August 2023, 22:13:32 UTC |
fdd252f | Matthew Johnson | 10 May 2023, 14:48:24 UTC | [shard-map] add rewrite for efficient transposition | 30 August 2023, 22:08:11 UTC |
1761f79 | jax authors | 30 August 2023, 22:05:31 UTC | Merge pull request #17369 from skye:version PiperOrigin-RevId: 561461008 | 30 August 2023, 22:05:31 UTC |
f71ba0a | Skye Wanderman-Milne | 30 August 2023, 19:31:01 UTC | Update versions and changelog post 0.4.15 release | 30 August 2023, 20:20:25 UTC |
603c879 | Yash Katariya | 30 August 2023, 17:26:52 UTC | Run `_check_sharding` checks during `api.device_put` instead of in the impl rule so that we don't have to repeat these checks in each rule of device_put. The same is done for jit and with_sharding_constraint. PiperOrigin-RevId: 561380348 | 30 August 2023, 17:27:37 UTC |
4268ed7 | jax authors | 30 August 2023, 13:32:37 UTC | Merge pull request #17362 from skye:version PiperOrigin-RevId: 561323664 | 30 August 2023, 13:32:37 UTC |
2c5e857 | Skye Wanderman-Milne | 30 August 2023, 13:10:07 UTC | Update versions for jax/jaxlib 0.4.15 release | 30 August 2023, 13:10:07 UTC |
c1680aa | jax authors | 30 August 2023, 12:51:31 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/61e12edca8c1cfd865550c261c634861eccf506a. PiperOrigin-RevId: 561315877 | 30 August 2023, 12:52:12 UTC |
d02b59e | Tomás Longeri | 30 August 2023, 05:56:33 UTC | [MOSAIC] apply_vector_layout C++ rewrite (1) VectorLayout functions PiperOrigin-RevId: 561237760 | 30 August 2023, 05:57:13 UTC |
6b57470 | Yash Katariya | 30 August 2023, 04:45:50 UTC | Abstract the array_serialization error message to a global variable so that it can be overridden. PiperOrigin-RevId: 561224461 | 30 August 2023, 04:46:26 UTC |
e785f89 | Yash Katariya | 30 August 2023, 03:58:20 UTC | Raise a good error message when mesh is not provided to jax.jit when using spmd_axis_name parameter of jax.vmap PiperOrigin-RevId: 561217612 | 30 August 2023, 03:58:57 UTC |
9be96c1 | Peter Hawkins | 30 August 2023, 03:47:22 UTC | Deprecate a number of exports from jax.interpreters.xla. Custom HLO lowering rules for primitives should be updated to use MLIR StableHLO lowering rules via jax.interpreters.mlir. PiperOrigin-RevId: 561215967 | 30 August 2023, 03:47:59 UTC |
ecee8f9 | Skye Wanderman-Milne | 29 August 2023, 23:38:34 UTC | [JAX] Implement importing external dlpack-aware Python arrays. See https://dmlc.github.io/dlpack/latest/python_spec.html. This is the import path. The export path was implemented in https://github.com/openxla/xla/commit/0b3cbfe4bc7cd68dc20924bd878cdfb2faa1a169. This allows for creating jax.Arrays from external GPU arrays asynchronously. PiperOrigin-RevId: 561172624 | 29 August 2023, 23:39:31 UTC |
e369445 | Peter Hawkins | 29 August 2023, 22:51:57 UTC | Remove tests for jax.numpy.in1d, which is deprecated. PiperOrigin-RevId: 561161024 | 29 August 2023, 22:52:34 UTC |
7c9cca8 | Jake VanderPlas | 29 August 2023, 22:14:16 UTC | jax.tree_util: use standard deprecation framework for deprecated items | 29 August 2023, 22:14:16 UTC |
9390024 | Peter Hawkins | 29 August 2023, 22:09:33 UTC | Remove jax.interpreters.xla.register_collective_primitive. We aren't consuming this data any more. It existed only to compare against the set of multiprocess-allowed collectives, but we removed that list also. So this registry is completely pointless. PiperOrigin-RevId: 561150259 | 29 August 2023, 22:10:05 UTC |
289ccad | jax authors | 29 August 2023, 21:20:07 UTC | Merge pull request #17351 from jakevdp:mypy-fix PiperOrigin-RevId: 561136741 | 29 August 2023, 21:20:07 UTC |
5a578cb | Tao Wang | 29 August 2023, 20:58:37 UTC | Add segment_ids support to pallas flash attention on GPU. PiperOrigin-RevId: 561130379 | 29 August 2023, 20:59:18 UTC |
f1fc2ad | Jake VanderPlas | 29 August 2023, 20:25:12 UTC | Fix mypy error | 29 August 2023, 20:25:12 UTC |
6072d59 | Yash Katariya | 29 August 2023, 19:17:37 UTC | Any devices passed to jax.sharding.Mesh are required to be hashable. This applies to jax.devices() as well as to mock or user-specific devices. Fix the tests so that the mock devices are hashable. PiperOrigin-RevId: 561103167 | 29 August 2023, 19:20:54 UTC |
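A minimal sketch of a hashable mock device, assuming a test only needs an `id` and a `platform`; the class name `MockDevice` is hypothetical. `dataclass(frozen=True)` generates `__eq__` and `__hash__` from the fields, so equal instances hash equally and can be used as dictionary keys or set members.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MockDevice:
    """Hypothetical test device; frozen=True makes it hashable by value."""
    id: int
    platform: str = "cpu"


devices = [MockDevice(i) for i in range(4)]
lookup = {d: i for i, d in enumerate(devices)}  # requires hashable devices
```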
ff5b480 | jax authors | 29 August 2023, 19:11:23 UTC | Merge pull request #17337 from jakevdp:jax-extend-doc PiperOrigin-RevId: 561101464 | 29 August 2023, 19:11:23 UTC |
f935c00 | Jake VanderPlas | 28 August 2023, 20:55:33 UTC | DOC: add missing docs for jax.random functions | 29 August 2023, 18:20:25 UTC |
d0a6813 | Peter Hawkins | 29 August 2023, 15:49:30 UTC | Make mlir.custom_call() more general and expose it as jax.interpreters.mlir.custom_call(). This change is in preparation for deprecating the XlaBuilder APIs for building non-MLIR HLO. In general JAX would be best served by adding a more user-friendly "custom kernel" API that doesn't require the user to build IR directly, but for the moment the best we can do is migrate users to use MLIR/StableHLO utilities instead of classic HLO utilities. Since most users of custom kernels probably want to build a custom-call we can get most of the benefit by providing an ergonomic helper function for building the IR for custom calls that can be called by external primitive lowering rules. This function has two benefits over just building the stablehlo directly: a) it is a JAX API, and we can be more confident the API won't change because of upstream MLIR changes b) the Python API to build stablehlo.custom_call generated by the bindings isn't that easy to use (e.g. it doesn't have sensible defaults). Next step will be to deprecate XlaBuilder and encourage users to switch to lowering rules using this helper. PiperOrigin-RevId: 561042402 | 29 August 2023, 15:50:07 UTC |
ba8c4c3 | jax authors | 29 August 2023, 13:45:26 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/db4cc73e3fc233afdd30f0dd94192f513feb5c6a. PiperOrigin-RevId: 561015160 | 29 August 2023, 13:46:00 UTC |
46ac9e2 | Peter Hawkins | 29 August 2023, 00:48:19 UTC | Use the default CSR matmul algorithm. Previously we requested CUSPARSE_SPMM_CSR_ALG3 in an attempt to get deterministic results from cusparse SpMM CSR matmuls. In the past, Cusparse silently ignored this algorithm choice and used a different algorithm in cases where ALG3 was not supported, but cusparse 12.2.1 removed the silent fallback behavior. Since we're not actually getting deterministic behavior anyway in all cases, use the default algorithm always. PiperOrigin-RevId: 560867049 | 29 August 2023, 00:49:01 UTC |
046bcc0 | Jevin Jiang | 29 August 2023, 00:39:11 UTC | [Mosaic] Add missing headers in linalg vectorization. PiperOrigin-RevId: 560865251 | 29 August 2023, 00:39:47 UTC |
92128d4 | Peter Hawkins | 28 August 2023, 23:25:50 UTC | Remove backward compatibility code related to pytree registries. We always have a default_registry now, so we don't need to protect code that uses it with conditionals. A number of type suppressions are also stale. PiperOrigin-RevId: 560849610 | 28 August 2023, 23:26:25 UTC |
5bdbda5 | jax authors | 28 August 2023, 22:14:05 UTC | Merge pull request #17334 from google:cache_log PiperOrigin-RevId: 560829247 | 28 August 2023, 22:14:05 UTC |
a37e215 | Yash Katariya | 28 August 2023, 22:03:18 UTC | Don't drop out of C++ fast path if mesh pointers are not equal. This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches. PiperOrigin-RevId: 560828993 | 28 August 2023, 22:04:05 UTC |
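The object reuse can be sketched in plain Python with a class-level dictionary consulted in `__new__`, keyed on the device shape, the axis names and the flattened device list. `CachedMesh` is a hypothetical stand-in, not `jax.sharding.Mesh`; its key only works because the devices are hashable.

```python
class CachedMesh:
    """Hypothetical mesh that returns one shared object per distinct key."""

    _cache = {}

    def __new__(cls, shape, axis_names, flat_devices):
        key = (tuple(shape), tuple(axis_names), tuple(flat_devices))
        mesh = cls._cache.get(key)
        if mesh is None:
            mesh = super().__new__(cls)
            mesh.shape, mesh.axis_names, mesh.devices = key
            cls._cache[key] = mesh
        return mesh


m1 = CachedMesh((2, 2), ("x", "y"), [0, 1, 2, 3])
m2 = CachedMesh([2, 2], ["x", "y"], (0, 1, 2, 3))
m3 = CachedMesh((4,), ("x",), range(4))
```

Because equal meshes are now the same object, a pointer (identity) comparison such as `m1 is m2` succeeds, which is what lets a fast path avoid falling back to a slower equality check.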
0c7ce72 | jax authors | 28 August 2023, 21:09:35 UTC | Merge pull request #17228 from jakevdp:doc-nb-exec PiperOrigin-RevId: 560811268 | 28 August 2023, 21:09:35 UTC |
fbadec2 | jax authors | 28 August 2023, 20:59:21 UTC | Merge pull request #17331 from jakevdp:print-docs PiperOrigin-RevId: 560811165 | 28 August 2023, 20:59:21 UTC |
34010a9 | Peter Hawkins | 28 August 2023, 19:55:31 UTC | Align dummy pointers passed to cusparse to 16 bytes Fixes alignment errors from Cusparse 12.2. PiperOrigin-RevId: 560793586 | 28 August 2023, 19:56:27 UTC |
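Rounding an address or size up to a 16-byte boundary is the usual power-of-two bit trick. The helper below is a hypothetical illustration of that arithmetic, not code from the jaxlib cusparse kernels.

```python
def align_up(value: int, alignment: int = 16) -> int:
    """Round a non-negative value up to the next multiple of alignment."""
    # The mask trick is only valid when alignment is a power of two.
    assert alignment > 0 and alignment & (alignment - 1) == 0
    return (value + alignment - 1) & ~(alignment - 1)
```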
c523e3b | Skye Wanderman-Milne | 28 August 2023, 19:49:25 UTC | Turn down jax_persistent_cache_min_compile_time_secs logging from info to debug. It's very noisy otherwise, since jax usually produces many small computations that aren't cached. | 28 August 2023, 19:49:25 UTC |
b09bef7 | jax authors | 28 August 2023, 18:52:17 UTC | Merge pull request #17300 from jakevdp:prng-test PiperOrigin-RevId: 560776573 | 28 August 2023, 18:52:17 UTC |
2f62111 | Jake VanderPlas | 28 August 2023, 17:59:40 UTC | DOC: mention static values in debug.print docs | 28 August 2023, 17:59:40 UTC |
2f878a7 | Jake VanderPlas | 25 August 2023, 21:11:19 UTC | Tests: set jax_legacy_prng_key='error' | 28 August 2023, 17:56:09 UTC |
c3e624a | Tomás Longeri | 28 August 2023, 17:46:34 UTC | [Mosaic] Fix assert PiperOrigin-RevId: 560756391 | 28 August 2023, 17:50:23 UTC |
70206ee | Peter Hawkins | 28 August 2023, 17:40:33 UTC | Give jax.numpy.array the type `Callable`. This is to prevent users from using it as the type of arrays in type annotations. PiperOrigin-RevId: 560754568 | 28 August 2023, 17:41:07 UTC |
3ea141d | jax authors | 28 August 2023, 16:46:54 UTC | Merge pull request #17113 from jakevdp:faster-ufuncs PiperOrigin-RevId: 560737512 | 28 August 2023, 16:46:54 UTC |
cb7c7ad | Jake VanderPlas | 28 August 2023, 15:30:23 UTC | jnp.ufunc: add fast paths for add/prod reductions | 28 August 2023, 15:30:23 UTC |
f407298 | jax authors | 28 August 2023, 13:46:12 UTC | Merge pull request #17322 from abdulasiraj:patch-1 PiperOrigin-RevId: 560695446 | 28 August 2023, 13:46:12 UTC |
7b802dc | jax authors | 28 August 2023, 13:30:18 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/4f675c5715d23a49933e481d09adcf5065edfca5. PiperOrigin-RevId: 560691873 | 28 August 2023, 13:30:55 UTC |
599b35e | Muhammad Abdullah | 28 August 2023, 05:13:16 UTC | Update __init__.py | 28 August 2023, 05:13:16 UTC |
d09c55a | Muhammad Abdullah | 28 August 2023, 05:01:17 UTC | Update __init__.py to include the dlpack module import. The import for the dlpack module was missing from the `__init__.py` file; this change adds it. | 28 August 2023, 05:01:17 UTC |
871c9f4 | jax authors | 27 August 2023, 19:58:50 UTC | Merge pull request #17307 from froystig:wrap-key PiperOrigin-RevId: 560536131 | 27 August 2023, 19:58:50 UTC |
66c0400 | jax authors | 27 August 2023, 13:21:49 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/1c22df67539a85904b4f686bcbe69efd518a5698. PiperOrigin-RevId: 560496105 | 27 August 2023, 13:22:28 UTC |
a69f134 | Roy Frostig | 26 August 2023, 00:36:15 UTC | add `jax.extend.random.wrap_key_data` | 26 August 2023, 18:39:25 UTC |
841baab | jax authors | 26 August 2023, 15:03:04 UTC | Adds Pallas flash attention TPU kernel. Implementation based on https://arxiv.org/pdf/2205.14135.pdf. PiperOrigin-RevId: 560346791 | 26 August 2023, 15:03:48 UTC |
08ca945 | jax authors | 26 August 2023, 13:51:59 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/7a371ed44aba34f83d6d3d1159d2e6d0d327c603. PiperOrigin-RevId: 560339802 | 26 August 2023, 13:52:36 UTC |
eee5c13 | Yash Katariya | 26 August 2023, 03:57:54 UTC | Donate only if result and arg memory kinds are not None. PiperOrigin-RevId: 560264967 | 26 August 2023, 03:58:39 UTC |
9aaacc5 | jax authors | 26 August 2023, 02:18:32 UTC | Merge pull request #17275 from jakevdp:ufunc-reduce-where PiperOrigin-RevId: 560251564 | 26 August 2023, 02:18:32 UTC |
4e227a3 | jax authors | 25 August 2023, 23:48:55 UTC | Merge pull request #17305 from mzgubic:patch-1 PiperOrigin-RevId: 560227591 | 25 August 2023, 23:48:55 UTC |
c35bc81 | Junwhan Ahn | 25 August 2023, 21:57:00 UTC | Add an optional `sharding` argument to `pure_callback` and `io_callback` This CL allows callers of `pure_callback` and `io_callback` to be able to specify the device to be used to run host callbacks. This is to make it possible to move users of `jax.experimental.host_callback(..., device_index=i)` (deprecated) to `pure_callback` or `io_callback`. Instead of taking a device index referring to a device in the device assignment, the new API takes sharding to match the look and feel of other JAX APIs. The current implementation supports `SingleDeviceSharding` only since the way sharding is annotated in StableHLO makes it tricky to use anything other than `MAXIMAL` or `MANUAL`. But if we later decide to expand support for other types of sharding, we will be able to do it without changing the API or breaking existing users. This CL also fixes an issue where `pure_callback` and `io_callback` had different sharding semantics inside `SPMDAxisContext`. Specifically, `io_callback` used to emit `MAXIMAL` even for `SPMDAxisContext`, whereas `pure_callback` used `MANUAL` sharding. The latter seems to make more sense since `SPMDAxisContext` with all manual axes should come with per-device semantics. This CL made both styles of callbacks use consistent sharding by factoring out sharding calculation as a common function. PiperOrigin-RevId: 560203044 | 25 August 2023, 21:57:35 UTC |
992e5e4 | Miha Zgubic | 25 August 2023, 21:39:15 UTC | Fix typo in jnp.interp docstring. | 25 August 2023, 21:39:15 UTC |
7ec2a88 | jax authors | 25 August 2023, 19:22:58 UTC | Merge pull request #17302 from google:mlir PiperOrigin-RevId: 560164637 | 25 August 2023, 19:22:58 UTC |
ae70fb7 | Peter Hawkins | 25 August 2023, 19:16:15 UTC | Add generated MLIR _enum_gen.py files to jaxlib wheel. Fixes CI failure after LLVM update. | 25 August 2023, 19:16:15 UTC |
970f4c9 | Yash Katariya | 25 August 2023, 17:59:10 UTC | Remove trivial execution from jax since it leads to 100x slower dispatch time. Trivial computations were added for a pre-omnistaging world. After omnistaging, JAX produces fewer trivial computations, so there is no need for this to exist. In the future, if we want to support forwarding of inputs to outputs, there would need to be a different way which the C++ dispatch path knows about. ``` jit_trivial_dispatch 246µs ± 3% 4µs ± 1% -98.52% (p=0.008 n=5+5) jit_trivial 250µs ± 3% 5µs ± 1% -98.19% (p=0.008 n=5+5) ``` PiperOrigin-RevId: 560141018 | 25 August 2023, 17:59:48 UTC |
c71eedf | jax authors | 25 August 2023, 17:35:44 UTC | Merge pull request #17295 from jakevdp:lax-pow-jvp PiperOrigin-RevId: 560133324 | 25 August 2023, 17:35:44 UTC |
6cec5d4 | Jake VanderPlas | 25 August 2023, 17:05:55 UTC | lax.pow: fix shape mismatch failure in jvp rule | 25 August 2023, 17:05:55 UTC |
3ea0a74 | jax authors | 25 August 2023, 16:22:34 UTC | Merge pull request #17277 from hawkinsp:trapz PiperOrigin-RevId: 560111745 | 25 August 2023, 16:22:34 UTC |
b11e245 | Peter Hawkins | 25 August 2023, 15:47:11 UTC | Remove stale reference to coordination_service flag. Fixes https://github.com/google/jax/issues/17288 PiperOrigin-RevId: 560103075 | 25 August 2023, 15:47:53 UTC |
975dae3 | Peter Hawkins | 24 August 2023, 20:01:40 UTC | Deprecate jax.numpy.trapz. Expose the current implementation of jax.numpy.trapz as jax.scipy.integrate.trapezoid instead. Fixes https://github.com/google/jax/issues/17244 | 25 August 2023, 15:04:13 UTC |
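For reference, the trapezoidal rule behind `jax.scipy.integrate.trapezoid` can be written in plain Python as below. The signature loosely mirrors NumPy's `trapz(y, x=None, dx=1.0)`, but this sketch is one-dimensional only and is not the JAX implementation.

```python
def trapezoid(y, x=None, dx=1.0):
    """Integrate samples y with the trapezoidal rule (1-D only)."""
    if x is None:
        # Evenly spaced samples: every interval has width dx.
        return sum(dx * (y[i] + y[i + 1]) / 2 for i in range(len(y) - 1))
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    # Sample points given explicitly: interval widths come from x.
    return sum((x[i + 1] - x[i]) * (y[i] + y[i + 1]) / 2
               for i in range(len(y) - 1))
```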
a454081 | jax authors | 25 August 2023, 14:53:04 UTC | Merge pull request #17289 from google:nanobind PiperOrigin-RevId: 560090600 | 25 August 2023, 14:53:04 UTC |
9a5b808 | Peter Hawkins | 25 August 2023, 14:45:32 UTC | Update nanobind version to 1.5.2. | 25 August 2023, 14:45:32 UTC |
ac8ea86 | Peter Hawkins | 25 August 2023, 14:30:44 UTC | Fix accidental signature change to get_serialized_metadata() from nanobind PR. pybind11 accepts either Python strings or bytes as a std::string argument, whereas nanobind accepts only strings. Change the argument to nb::bytes instead. PiperOrigin-RevId: 560086072 | 25 August 2023, 14:31:31 UTC |
87cec1a | jax authors | 25 August 2023, 13:50:41 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/14c0e8565ab7774268e1c454c72a11535acc1aae. PiperOrigin-RevId: 560077922 | 25 August 2023, 13:51:19 UTC |
eea2603 | Yash Katariya | 25 August 2023, 05:23:13 UTC | Add a proper jax config for memories so that we can iteratively develop and enable it. PiperOrigin-RevId: 559977015 | 25 August 2023, 05:23:55 UTC |
0099327 | Roy Frostig | 25 August 2023, 02:46:50 UTC | get aval directly via attribute in key array shard arg handler No need to go through `core.get_aval` here. PiperOrigin-RevId: 559945841 | 25 August 2023, 02:47:35 UTC |