https://github.com/google/jax

Revision Author Date Message Commit Date
7783a5e Merge pull request #14343 from skye:cache_options_check PiperOrigin-RevId: 507933881 08 February 2023, 01:49:31 UTC
1228cbd Change the executable_build_options check in compilation_cache.py to be more robust. Prior to this change, the check would spuriously fire on Python 3.11 because it added a default `__getstate__` method to all objects. This change makes it so we only look at public fields and methods. 08 February 2023, 01:22:14 UTC
a46f31b Merge pull request #14342 from skye:version PiperOrigin-RevId: 507907767 07 February 2023, 23:51:49 UTC
8ab1585 Update WORKSPACE and setup.py for jax/jaxlib 0.4.3 release 07 February 2023, 23:45:28 UTC
eb13c05 Add option to run tests with persistent compilation cache enabled. This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests. PiperOrigin-RevId: 507898535 07 February 2023, 23:15:31 UTC
6860cb8 Move jax.interpreters.xla to jax._src.interpreters.xla. Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally. PiperOrigin-RevId: 507895040 07 February 2023, 23:01:32 UTC
9c827fb Merge pull request #14340 from ROCmSoftwarePlatform:rocm_reenable_linalg_sparse_tests PiperOrigin-RevId: 507886628 07 February 2023, 22:30:37 UTC
01a10a1 [ROCm] Re-enable some linalg and sparse tests 07 February 2023, 22:05:14 UTC
8eb00c5 Merge pull request #14335 from jakevdp:doc-transformations PiperOrigin-RevId: 507864667 07 February 2023, 21:09:37 UTC
a022a4e DOC: remove transformations.md It's currently unused, and the content duplicates what's in the README 07 February 2023, 20:32:11 UTC
98b75cf Prune accidental exports from jax.interpreters.pxla. These imports do not appear to have users outside JAX itself. PiperOrigin-RevId: 507835295 07 February 2023, 19:16:42 UTC
5cfd15b Merge pull request #14334 from jakevdp:fix-doc-conf PiperOrigin-RevId: 507811052 07 February 2023, 17:56:00 UTC
d4f422f Merge pull request #14303 from carlosgmartin:rankdata PiperOrigin-RevId: 507805953 07 February 2023, 17:37:04 UTC
3ab0633 DOC: simplify jax-101 patterns in conf.py 07 February 2023, 17:36:26 UTC
92eb131 Merge pull request #14319 from jakevdp:doc-contributing PiperOrigin-RevId: 507803923 07 February 2023, 17:29:08 UTC
c9d2186 Merge pull request #14332 from jakevdp:doc-pjit-stub PiperOrigin-RevId: 507800497 07 February 2023, 17:16:33 UTC
8251957 Added scipy.stats.rankdata 07 February 2023, 17:07:00 UTC
ef45db7 DOC: add stub for removed pjit tutorial 07 February 2023, 16:44:56 UTC
d0abb72 DOC: update contributing guide 07 February 2023, 16:06:45 UTC
219723c migrate internal dependencies from `jax.interpreters.ad` to `jax._src.interpreters.ad` ... in preparation for paring down `jax.interpreters.ad`'s exported symbols. Includes some import fixups along the way. PiperOrigin-RevId: 507684262 07 February 2023, 06:52:36 UTC
c252162 Make pjit's cache global just like `jit`'s cache. This will allow cache hits in C++ when `pjit(f)(jnp.arange(3.))` is executed twice. Also includes Peter's change to fix the cache hit behavior which was broken at HEAD with jit. PiperOrigin-RevId: 507662634 07 February 2023, 04:35:26 UTC
b03606f Merge pull request #14323 from mattjj:shmap-add-trivial-rules PiperOrigin-RevId: 507636946 07 February 2023, 02:23:01 UTC
4214cb1 Merge pull request #14321 from mattjj:shmap-axis-index PiperOrigin-RevId: 507630920 07 February 2023, 01:52:20 UTC
198bfe3 [shard_map] add a lot of trivial rules 07 February 2023, 01:45:47 UTC
6cef087 Don't write executables with host callbacks to persistent compilation cache. The persistent compilation cache can't de/serialize the callback functions (yet?). PiperOrigin-RevId: 507628297 07 February 2023, 01:37:32 UTC
2eb10d2 Correctly hash auto_spmd fields in compilation cache key. I'm in the process of adding test coverage for this (https://github.com/google/jax/pull/14314), which is how I found this! I manually verified with the new test coverage that it's fixed. PiperOrigin-RevId: 507624101 07 February 2023, 01:15:23 UTC
6db3f48 [shard_map] add rep rule for axis_index, trivial test 07 February 2023, 00:59:22 UTC
08ff7f4 Prune accidentally exported names from jax.interpreters.ad. PiperOrigin-RevId: 507584433 06 February 2023, 22:36:44 UTC
38a59a3 Move jax.interpreters.pxla to jax._src.interpreters.pxla. Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time. PiperOrigin-RevId: 507584264 06 February 2023, 22:29:10 UTC
3d9ae6b Add a .cost_analysis() on lowered but uncompiled computations. Allows users to call XLA's HLO cost analysis without using internal APIs. In practice plenty of users appear to be doing this using non-public APIs, so we may as well offer a supported API for it. PiperOrigin-RevId: 507560058 06 February 2023, 20:57:57 UTC
f37f00d Merge pull request #14274 from jakevdp:sparsify-bcsr PiperOrigin-RevId: 507533389 06 February 2023, 19:19:53 UTC
597c201 [sparse] support BCSR in sparsify transform 06 February 2023, 19:01:57 UTC
25d8eb0 Merge pull request #14280 from jakevdp:bcoo-broadcast-performance PiperOrigin-RevId: 507524731 06 February 2023, 18:49:59 UTC
8a69444 Bump minimum jaxlib_version to 0.4.2, i.e. xla_extension_version == 119 and mlir_api_version == 43 PiperOrigin-RevId: 507520956 06 February 2023, 18:37:33 UTC
953ad90 Merge pull request #14271 from jakevdp:sparse-conv PiperOrigin-RevId: 507511980 06 February 2023, 18:07:32 UTC
63e0e0f Merge pull request #14291 from sharadmv:fix-checkify-caching PiperOrigin-RevId: 507504176 06 February 2023, 17:39:07 UTC
a13a2c5 [JAX] Remove obsolete unit type declarations in jax.core. Remove obsolete unit test in host_callback. PiperOrigin-RevId: 507473737 06 February 2023, 15:33:14 UTC
fbbd442 Remove support for classic HLO computations in compilation cache. These are never used except in this unit test any more; we always use MLIR. PiperOrigin-RevId: 507473543 06 February 2023, 15:24:46 UTC
077ff29 [jax2tf] Fixes a bug in flax model testing. We should also strip commas from the example name otherwise we cannot pass it through the command-line. Also added some documentation for this. PiperOrigin-RevId: 507413528 06 February 2023, 09:42:00 UTC
a12679b If there is only 1 process in process_allgather then just pull it to host without going via pjit. PiperOrigin-RevId: 507318748 05 February 2023, 22:01:21 UTC
be67db3 Skip `testAutodiffCache` test if xla_extension_version < 123 PiperOrigin-RevId: 507292333 05 February 2023, 17:39:36 UTC
a30ba83 Fix the failure with the latest jax/jaxlib on PyPI PiperOrigin-RevId: 507208172 05 February 2023, 04:16:33 UTC
2567331 Only do the XLA sharding override check if `xla_extension_version >= 123` because the xla change for not overriding sharding is at HEAD. PiperOrigin-RevId: 507180051 04 February 2023, 23:51:26 UTC
973bdb2 Copy the jit docs and paste them inside the new jit fork. PiperOrigin-RevId: 507161252 04 February 2023, 20:34:35 UTC
134db08 Use `new_mesh_sharding_specs` since `mesh_sharding_specs` is deprecated PiperOrigin-RevId: 507159068 04 February 2023, 20:14:21 UTC
9ad22b1 Merge pull request #14290 from gnecula:poly_hashable PiperOrigin-RevId: 507137155 04 February 2023, 16:32:54 UTC
c231171 Fix checkify caching with nested call primitives 04 February 2023, 07:28:37 UTC
15be538 [shape_poly] Fix the hashing and equality of symbolic dimensions 04 February 2023, 06:30:44 UTC
f445c84 Add support for a list of `allow_spmd_sharding_propagation_to_output`. This gives us more flexibility to tell SPMD which shardings to override. PiperOrigin-RevId: 507035958 04 February 2023, 01:59:10 UTC
428713e [sparse] support all unbatched 1D convolutions 03 February 2023, 23:59:42 UTC
0affb3b Merge pull request #14283 from pschuh:static_argnums_custom_partitioning PiperOrigin-RevId: 507005561 03 February 2023, 23:14:08 UTC
428189f Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding. This change updates: * {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh * {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec * jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding. PiperOrigin-RevId: 506994892 03 February 2023, 22:28:45 UTC
136c11a Clear pjit's cache too in clear_backends() similar to jit. PiperOrigin-RevId: 506989563 03 February 2023, 22:08:07 UTC
def35b7 Remove scatter/gather dimension proto helpers. These are unused since the MHLO switch. PiperOrigin-RevId: 506969590 03 February 2023, 20:40:31 UTC
7526d0e Add static_argnums to custom_partitioning. Arguments specified by static_argnums cannot contain any jax tracers because they will be passed into the XLA compiler where the lowering information for these tracers is already lost. 03 February 2023, 19:41:17 UTC
fada1c2 [XLA] Add way to allow propagation to output only to a subset of root instruction tuple shardings. PiperOrigin-RevId: 506935285 03 February 2023, 18:22:33 UTC
613fd3c [sparse] improve performance of bcoo_broadcast_in_dim 03 February 2023, 18:16:41 UTC
5bc14fd Merge pull request #14277 from gnecula:poly_div PiperOrigin-RevId: 506905837 03 February 2023, 16:11:30 UTC
f147e82 [shape_poly] Add support for evaluating div/mod for DimExpr. We have added the ability to represent floordiv and mod to DimExpr. Here we add support for evaluating these dimensions for the native lowering. 03 February 2023, 15:44:26 UTC
b8d6efe Merge pull request #14273 from mattjj:shard-map PiperOrigin-RevId: 506820113 03 February 2023, 07:25:39 UTC
ff1e9b3 shard_map (shmap) prototype and JEP Co-authored-by: Sharad Vikram <sharadmv@google.com> Co-authored-by: Sholto Douglas <sholto@google.com> 03 February 2023, 07:01:30 UTC
f18b91d Merge pull request #14276 from mattjj:core-type-annotation-tweaks PiperOrigin-RevId: 506802661 03 February 2023, 05:15:54 UTC
644d3b6 minor tweaks to type annotations, specialize code on those types I noticed some slightly-too-general type annotations in core.py. By tightening them we could simplify the code too. (I think these were leftovers from pre-omnistaging...) 03 February 2023, 04:24:26 UTC
0f289ab Merge pull request #14174 from google:pjrt_test PiperOrigin-RevId: 506751529 03 February 2023, 00:23:26 UTC
30c9376 Merge pull request #14272 from jakevdp:conditions PiperOrigin-RevId: 506737218 02 February 2023, 23:22:24 UTC
c4ec299 Sharp bits: mention alternatives to lax.cond 02 February 2023, 21:19:26 UTC
5e51995 Merge pull request #14269 from hawkinsp:notimpl PiperOrigin-RevId: 506697948 02 February 2023, 20:55:47 UTC
6d2aea2 Merge pull request #14270 from hawkinsp:device PiperOrigin-RevId: 506697883 02 February 2023, 20:48:09 UTC
b730ed4 Remove placeholder functions for unimplemented NumPy functions. These don't seem necessary now JAX has fairly complete coverage of the NumPy API. Also removes the accidental export of _NOT_IMPLEMENTED in several modules. 02 February 2023, 18:00:18 UTC
74f1ab0 Export Device as jax.Device. Users are writing things like jax.lib.xla_client.Device in type annotations which is not a public API. Add a supported public name for the Device type. 02 February 2023, 17:58:15 UTC
365262b Reapply: move `jax.interpreters.ad` to `jax._src.interpreters.ad` Re-export roughly all of the same symbols via `jax.interpreters.ad` for now. This version of the PR includes the names jax.interpreters.ad.source_info_util and jax.interpreters.ad.config, which the Neural Tangents library is using. PiperOrigin-RevId: 506642132 02 February 2023, 17:29:05 UTC
795c14b Merge pull request #14252 from jakevdp:sparse-conv PiperOrigin-RevId: 506641181 02 February 2023, 17:21:26 UTC
04525e8 Revert: move `jax.interpreters.ad` to `jax._src.interpreters.ad` Re-export roughly all of the same symbols via `jax.interpreters.ad` for now. This change broke some tests. PiperOrigin-RevId: 506606721 02 February 2023, 14:52:47 UTC
e5b2c5e Remove the jit_pjit_api_merge disable for api_test now that it is passing PiperOrigin-RevId: 506508148 02 February 2023, 05:03:30 UTC
a79dea5 Merge pull request #14263 from mattjj:custom-jvp-nondiff-argnums-tracers PiperOrigin-RevId: 506506008 02 February 2023, 04:49:47 UTC
a7964a7 Merge pull request #14262 from froystig:issue14249 PiperOrigin-RevId: 506503680 02 February 2023, 04:35:44 UTC
cd615b6 skip custom_jvp/vjp tests which don't work with initial-style staging These tests, involving nondiff_argnums and/or closing over tracers, happen to work with final-style JIT but not our initial-style primitives. We shouldn't support this behavior anyway; there are good alternatives. 02 February 2023, 04:34:47 UTC
26b75ff add "linear solve batching via `jacrev`" test from github.com/google/jax/issues/14249 02 February 2023, 04:01:53 UTC
e199b35 Revert "Merge pull request #14113 from botev:main" This reverts commit 69d18cc7b58ae4ed82246605d66ed07a49fad676, reversing changes made to 13e875f8b8d8dd9152045c7e3b5045a9bb0d7db0. Reverting until we address https://github.com/google/jax/issues/14249 02 February 2023, 03:50:27 UTC
0e77af0 move `jax.interpreters.ad` to `jax._src.interpreters.ad` Re-export roughly all of the same symbols via `jax.interpreters.ad` for now. PiperOrigin-RevId: 506490796 02 February 2023, 03:46:47 UTC
038798e [sparse] add support for simple 1D convolutions 02 February 2023, 02:53:49 UTC
4d56def Merge pull request #14257 from jakevdp:sparse-rev PiperOrigin-RevId: 506483272 02 February 2023, 02:51:58 UTC
9d5132f [jax] Skip compilation cache test for older jaxlibs PiperOrigin-RevId: 506460144 02 February 2023, 00:53:19 UTC
7a5a63f Merge pull request #14250 from mattjj:checkify-retracing PiperOrigin-RevId: 506458253 02 February 2023, 00:44:56 UTC
4fa80b4 [sparse] implement sparse rule for lax.rev 01 February 2023, 23:43:47 UTC
782a34f Add more logging to serialization code to figure out exactly where we are during async checkpointing. PiperOrigin-RevId: 506438425 01 February 2023, 23:24:46 UTC
06e3d8c Merge pull request #14251 from jakevdp:sparse-len PiperOrigin-RevId: 506428591 01 February 2023, 22:53:47 UTC
1a858fe Merge pull request #14254 from jakevdp:cleanup-dtypes PiperOrigin-RevId: 506428355 01 February 2023, 22:46:24 UTC
0cd3dee Consolidate the experimental_get_compiler_ir eager and tf function path in jax2tf.call_tf. PiperOrigin-RevId: 506424270 01 February 2023, 22:31:14 UTC
c241ae6 add blank line, mainly to trigger/test source sync PiperOrigin-RevId: 506414439 01 February 2023, 21:56:29 UTC
c90a854 Merge pull request #14248 from jakevdp:dead-code PiperOrigin-RevId: 506405131 01 February 2023, 21:25:46 UTC
72dfb23 Remove jax.dtypes._jax_types 01 February 2023, 20:49:06 UTC
27c068e [sparse] implement __len__ on sparse objects 01 February 2023, 19:46:02 UTC
684846b checkify: cache jaxpr formation so we don't always retrace 01 February 2023, 18:19:47 UTC
0b5443c Clean up: remove unused helper functions 01 February 2023, 17:55:58 UTC
fcb9dfb Merge pull request #14236 from rwitten:rwitten_debug_docs PiperOrigin-RevId: 506154997 01 February 2023, 00:52:16 UTC
b0202b6 Merge pull request #14235 from jakevdp:take-doc PiperOrigin-RevId: 506151220 01 February 2023, 00:34:50 UTC
278ff25 Update docs that jax.debug is unsupported on Cloud TPUs 01 February 2023, 00:12:51 UTC
14a0fe0 DOC: improve documentation of OOB indices in jnp.take 31 January 2023, 23:59:06 UTC
957adbd Merge pull request #14234 from jakevdp:fix-doc PiperOrigin-RevId: 506134475 31 January 2023, 23:30:01 UTC