https://github.com/google/jax

0f30685 Remove StreamExecutor-based TPU runtime from Cloud TPU CI The old StreamExecutor-based backend is no longer supported as of https://github.com/google/jax/commit/3e50fea29edb6b78426dde511414429f2d2fddf8 09 August 2023, 17:05:46 UTC
be543f0 Merge pull request #17041 from mtsokol:update-ninf-usage PiperOrigin-RevId: 555185385 09 August 2023, 16:26:12 UTC
1fedf04 API: Remove NINF and PINF usages 09 August 2023, 12:16:33 UTC
c9cf6b4 Remove allowlist for multihost collectives. This allowlist used to prevent users from using collectives that didn't work correctly in multihost pmap(). But currently every collective in JAX (except for pgather(), which isn't public), is on the list. So the allowlist serves no purpose any more. PiperOrigin-RevId: 555124144 09 August 2023, 11:43:51 UTC
1bd5fd2 Add `serialize_with_paths` and `deserialize_with_paths` API to `GlobalAsyncCheckpointManager` PiperOrigin-RevId: 555050522 09 August 2023, 05:47:27 UTC
6615e23 Merge pull request #17032 from jakevdp:extended-doc PiperOrigin-RevId: 554999499 09 August 2023, 01:02:40 UTC
ca924cd Add visibility to jax2tf_internals. PiperOrigin-RevId: 554994907 09 August 2023, 00:40:01 UTC
1d27576 jax.dtypes.extended: fix docstring example 08 August 2023, 23:08:45 UTC
ca17b6c Move functions out of xla.py closer to their users. Refactoring only, no changes intended. The goal is to shrink xla.py down to only its HLO-compatibility role, and remove things that aren't related to HLO compatibility. Remove an unused top_k translation rule as well. PiperOrigin-RevId: 554946059 08 August 2023, 21:40:42 UTC
d01695c Change --jax_xla_profile_version definition to config. Changing the flag to a config permits more contained testing. This is in preparation for an upcoming change to incorporate AutoFDO profile versions in the cache key. Testing: test workload. PiperOrigin-RevId: 554942573 08 August 2023, 21:29:09 UTC
3e50fea Remove option to use StreamExecutor Cloud TPU client in JAX It's been over three months since the new PJRT C API client was enabled by default (https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-8-march-29-2023). PiperOrigin-RevId: 554935166 08 August 2023, 21:05:27 UTC
f05f197 Reverts changelist 554885148 PiperOrigin-RevId: 554930183 08 August 2023, 20:50:03 UTC
d8f7993 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17027 from jakevdp:dtypes-annotations a116a9c498a7b085f9b3fec93b37da12289f6e31 PiperOrigin-RevId: 554905739 08 August 2023, 20:38:44 UTC
6b07f5b Use core.valid_jaxtype() in xla.check_arg(). Fixes a TODO. PiperOrigin-RevId: 554885023 08 August 2023, 18:21:09 UTC
f37afb1 Add JAX primitives test suite in TFLite. PiperOrigin-RevId: 554883170 08 August 2023, 18:12:26 UTC
e58f1ba Move some utilities out of dispatch.py next to their users, add more types. Internal cleanups only, no user-visible changes intended. PiperOrigin-RevId: 554876522 08 August 2023, 17:52:11 UTC
afd56c1 Move jax.jaxpr_util to jax._src.jaxpr_util, and split it into a separate build target. Change jaxpr_util_test to be a py_test(), since there's no point testing it on every hardware configuration. PiperOrigin-RevId: 554861284 08 August 2023, 17:09:09 UTC
b024e01 Improve dispatch.py typing. Inline _xla_callable_uncached, which is trivial, into its only caller. Cleanup only, no user-visible changes intended. PiperOrigin-RevId: 554805210 08 August 2023, 13:34:34 UTC
dec2366 Add a failure test for when a tf.SavedModel is missing the XLACallModule function_list after loading. PiperOrigin-RevId: 554726455 08 August 2023, 07:50:50 UTC
d17adde [jax2tf] Adjust tolerance for jax2tf graph serialization GPU tests. Should fix test flakiness. PiperOrigin-RevId: 554704965 08 August 2023, 06:04:30 UTC
5efc681 Cleanup comments for define_{int,float}_state. There is no parameter enum_values in these functions. Probably a copy/paste issue from define_enum_state. PiperOrigin-RevId: 554644871 08 August 2023, 00:41:09 UTC
b804988 [XLA:Client] Make HloSharding::iota_tile actually produce V2 shardings. PiperOrigin-RevId: 554631780 07 August 2023, 23:46:53 UTC
c949869 Merge pull request #17015 from mattjj:hypothesis-skip PiperOrigin-RevId: 554625928 07 August 2023, 23:23:44 UTC
cdb946b [pallas] skip indexing tests when hypothesis not available Co-authored-by: Roy Frostig <frostig@google.com> 07 August 2023, 22:28:50 UTC
c879f65 [JAX] Remove the non-coordination service distributed service implementation from JAX. The coordination service has been the default for a long time, and has significant additional functionality. Remove the older code path to simplify the code. PiperOrigin-RevId: 554608165 07 August 2023, 22:17:25 UTC
22285e6 Merge pull request #16971 from apaszke:pallas-tpu-docs PiperOrigin-RevId: 554587987 07 August 2023, 21:10:07 UTC
bf8e550 [Pallas] Flatten in_specs for PrefetchScalarGridSpec PiperOrigin-RevId: 554574437 07 August 2023, 20:28:44 UTC
9d67fb7 Merge pull request #16987 from sharadmv:pallas-docs PiperOrigin-RevId: 554568088 07 August 2023, 20:15:21 UTC
a600020 Update ducc to commit: 2b2cead005e08d2632478e831d7f45da754162dc NOTE: this version of DUCC has a breaking change, where the fft.h header no longer contains the definitions of many fft functions - instead they exist within fft1d_impl.h and fftnd_impl.h. PiperOrigin-RevId: 554567641 07 August 2023, 20:06:43 UTC
079ecfb Allow device mesh handler to return None and use the default logic PiperOrigin-RevId: 554563151 07 August 2023, 19:50:43 UTC
9152414 Merge pull request #17007 from jakevdp:callbacks-doc PiperOrigin-RevId: 554559029 07 August 2023, 19:35:48 UTC
e219456 Merge pull request #16972 from mtsokol:update-np-exceptions-imports PiperOrigin-RevId: 554548376 07 August 2023, 18:58:59 UTC
2fa0bb0 Add `initialization_timeout` as a parameter to allow users to increase/decrease the init_timeout parameter. PiperOrigin-RevId: 554545535 07 August 2023, 18:49:41 UTC
e4701b2 DOC: callback doc is no longer work-in-progress 07 August 2023, 18:16:56 UTC
e4d2e20 Merge pull request #16960 from JiaYaobo:add_random_lognormal_and_triang PiperOrigin-RevId: 554533959 07 August 2023, 18:14:04 UTC
b3a36ef Merge pull request #16965 from jakeh-gc:fp8_fnuz PiperOrigin-RevId: 554523823 07 August 2023, 17:43:18 UTC
d183a2c ENH: Update numpy exceptions imports 07 August 2023, 17:08:41 UTC
a80d952 Remove unused exports from jax.interpreters.pxla. This change removes names exported from jax.interpreters.pxla for which I couldn't find any JAX-external references. PiperOrigin-RevId: 554483707 07 August 2023, 15:27:03 UTC
5fcd926 Merge pull request #16975 from hawkinsp:win PiperOrigin-RevId: 554479893 07 August 2023, 15:10:21 UTC
bd50474 Merge pull request #16861 from gnecula:test_naming PiperOrigin-RevId: 554461041 07 August 2023, 13:48:50 UTC
85f124c Add support for float8_e4m3fnuz and float8_e5m2fnuz. 07 August 2023, 10:48:53 UTC
f2dbde5 [jax2tf] Sanitize the parameterized test case names to be friendly to -k Sanitizes the name of tests so that a name matches the rules of an identifier for pytest -k and unittest -k test filtering. Sequences of problematic characters are replaced with a single "_". 07 August 2023, 09:08:21 UTC
8d80e25 [jax2tf] Turn on JAX native serialization by default. See changes to the README.md for mechanisms to override the default. PiperOrigin-RevId: 554390866 07 August 2023, 08:03:55 UTC
28af186 [Pallas] Fix rendering of math in quickstart 05 August 2023, 07:20:00 UTC
6d18445 add random.triangular and random.lognormal 05 August 2023, 04:27:53 UTC
364a245 [Memories] Make memory_kind private in C++ and then expose it as a property in python to keep things consistent with other Shardings PiperOrigin-RevId: 553962861 05 August 2023, 01:06:25 UTC
1ae37b4 Canonicalize to default memory in init of Shardings only on the backends that support memories right now. PiperOrigin-RevId: 553942534 04 August 2023, 23:27:15 UTC
73d1b26 [Pallas] Add better error message for unimplemented primitives PiperOrigin-RevId: 553939077 04 August 2023, 23:12:06 UTC
63e9eca [Pallas] Make indexing hypothesis test deterministic PiperOrigin-RevId: 553928039 04 August 2023, 22:26:57 UTC
9adf319 [Pallas] Add support for jax.nn.relu to Mosaic lowering PiperOrigin-RevId: 553923643 04 August 2023, 22:09:02 UTC
8f9fcdc Merge pull request #16981 from patrick-kidger:ndim-fix PiperOrigin-RevId: 553913984 04 August 2023, 21:32:21 UTC
07f052a [JAX] Remove jax.interpreters.pxla._create_pmap_sharding_spec. No deprecation period: jax.interpreters.pxla is not covered by compatibility policies. PiperOrigin-RevId: 553911673 04 August 2023, 21:23:44 UTC
aecd5c9 [Pallas] Add support for lax.fori_loop to Mosaic lowering PiperOrigin-RevId: 553901144 04 August 2023, 20:43:43 UTC
16c33df [Mosaic] Relax return type checks for vector.contract PiperOrigin-RevId: 553898552 04 August 2023, 20:34:12 UTC
808b0b2 Crash fix from ndim 04 August 2023, 20:32:12 UTC
6e873ab Merge pull request #16938 from jakevdp:spsolve-grad PiperOrigin-RevId: 553889073 04 August 2023, 19:59:23 UTC
574c53e Merge pull request #16979 from patrick-kidger:changelog-traceback PiperOrigin-RevId: 553886881 04 August 2023, 19:50:22 UTC
d6dad38 Documented the shortening of tracebacks 04 August 2023, 19:20:19 UTC
eb1ab2b Add autodiff rules for spsolve_p Fixes #16935 04 August 2023, 19:08:57 UTC
4fb8cdb [Memories] Add Memories support to jax.jit and jax.device_put! These are the following changes: * Add a temporary flag (`JAX_FETCH_MEMORY_KIND_ON_EXECUTABLE`) (should not be used by users but needed in C++ in pjrt-ifrt code) on whether to fetch memory kinds from executable. If it is set to True, the host runtime dep needs to be linked in and should also work in OSS (more work needs to happen for that). So only the test sets it to True for now until jax memories is under development. * Add with_memory_kind method on Sharding to allow for easier creation of shardings with different memory kind. * Add lowering rules for device_put and jax.jit. * For device_put, we always add the annotation that describes a transfer to a memory and a sharding annotation. * For jax.jit, if the argument is on host memory, it will have an extra attribute _xla_buffer_placement. * Handle the correct output sharding in pxla.py by extracting the memory kind from the executable. * Handle the caching of pjit caches by canonicalizing the memory_kinds so that `NS(mesh, pspec) == NS(mesh, pspec, memory_kind='tpu_hbm')`. Also canonicalize memory_kind in `__hash__` and `__eq__` of shardings. * This is to not change the StableHLO to include device placement annotations right now since the host aware passes are not enabled by default and the work is in progress to make it work everywhere. PiperOrigin-RevId: 553833344 04 August 2023, 16:44:24 UTC
622e7da Disable some tests that appear to crash on Windows. 04 August 2023, 15:24:58 UTC
d7940ee Always skip the BCOO matmul test on CUDA 12 We seem to be consistently hitting cuSPARSE bugs in this test, so disabling it for now. PiperOrigin-RevId: 553801002 04 August 2023, 14:29:49 UTC
98191da Add a guide for writing Pallas TPU kernels 04 August 2023, 14:27:58 UTC
467adf3 Merge pull request #16953 from sharadmv:pallas-docs PiperOrigin-RevId: 553633577 04 August 2023, 00:08:18 UTC
96e2d93 [Pallas] Add Pallas design doc 03 August 2023, 23:02:07 UTC
df4e0cb [Mosaic] Add support for relayout from (8, 128) to (1, 128) with 32-bit data. PiperOrigin-RevId: 553606429 03 August 2023, 22:24:15 UTC
eff2a9e Merge pull request #16927 from jeertmans:patch-1 PiperOrigin-RevId: 553582895 03 August 2023, 21:02:02 UTC
a8388e2 Merge pull request #16949 from patrick-kidger:simplified-traceback PiperOrigin-RevId: 553580442 03 August 2023, 20:53:33 UTC
7f068dc Merge pull request #16940 from sharadmv:pallas-docs PiperOrigin-RevId: 553573987 03 August 2023, 20:34:21 UTC
77d544f Merge pull request #16947 from hawkinsp:docs PiperOrigin-RevId: 553572087 03 August 2023, 20:25:48 UTC
6e55c20 chore(docs): improve `jax.lax.scan` Make the docstring a bit more explicit about what is t Co-authored-by: Jake Vanderplas <jakevdp@google.com> 03 August 2023, 20:18:07 UTC
7f1ef32 Add initial documentation for Pallas 03 August 2023, 19:30:19 UTC
4c3408b Add support for custom compilation target devices PiperOrigin-RevId: 553555190 03 August 2023, 19:24:30 UTC
5e276d0 Tracebacks no longer have JAX-internal frames prepended by default 03 August 2023, 18:38:38 UTC
26727ea Delete jax.interpreters.pxla.replicate(). pxla.replicate() can be replaced by jax.device_put_replicated(). No deprecation period because jax.interpreters APIs are not stable. PiperOrigin-RevId: 553502827 03 August 2023, 16:37:00 UTC
a22c477 Merge pull request #16933 from jakevdp:jnp-place PiperOrigin-RevId: 553499383 03 August 2023, 16:23:56 UTC
d4336c1 Improve the sharding documentation. * do some proofreading. * add PmapSharding and GSPMDSharding, which are both missing. 03 August 2023, 15:12:02 UTC
0228bf7 Fix MSAN errors in cache_key_test The device_assignment array was never initialized, causing MSAN errors. Replacing it with np.arange fixes the issue. PiperOrigin-RevId: 553469463 03 August 2023, 14:28:32 UTC
e2634a2 Fix insufficient test dependencies for Pallas tests The test only runs on GPU and depends on the GPU backend, but did not depend on Pallas GPU. PiperOrigin-RevId: 553468241 03 August 2023, 14:19:29 UTC
a184b5e [Mosaic] Add a reinterpret cast for memrefs This allows us to override the inferred tiling of the values, which makes it possible to e.g. preswizzle the data into a more efficient format before the kernel. PiperOrigin-RevId: 553402946 03 August 2023, 08:50:19 UTC
bd5a457 Implement jax.numpy.place with required inplace parameter 02 August 2023, 21:29:26 UTC
a0c1265 Merge pull request #16932 from jakevdp:where-annotations PiperOrigin-RevId: 553255574 02 August 2023, 21:22:36 UTC
5a5730d Fix type annotations for jnp.where 02 August 2023, 20:42:20 UTC
7708cf5 [JAX] Remove jax.interpreters.pxla._pmap_sharding_spec. Praxis appears to be the only user of this deprecated API, and only for JAX versions older than the current Praxis JAX version requirement. PiperOrigin-RevId: 553230995 02 August 2023, 19:58:03 UTC
a6a8f48 [JAX] Don't include ShardingSpecs or out_indices in the data passed to the C++ pmap() fast path. The pmap() fast path doesn't even look at the ShardingSpec or the out_indices since the jax.Sharding rework. PiperOrigin-RevId: 553206145 02 August 2023, 18:29:05 UTC
f498442 [jax][benchmark] Added clearing caches for benchmarking compilation time in sparse JAX benchmarks PiperOrigin-RevId: 553179605 02 August 2023, 17:07:54 UTC
391d45f [Pallas] Import ops in jax/experimental/pallas/ops/__init__.py PiperOrigin-RevId: 553031447 02 August 2023, 05:41:30 UTC
d872812 [Pallas] Upstream pallas to JAX PiperOrigin-RevId: 552963029 01 August 2023, 23:43:13 UTC
69cd3eb Merge pull request #16843 from mattjj:shmap-partial-auto PiperOrigin-RevId: 552936692 01 August 2023, 22:11:59 UTC
614bbcc Add internal jaxlib function for fetching the topology from a set of devices. We may want to make this topology serializable or usable as a cache key. PiperOrigin-RevId: 552931150 01 August 2023, 21:54:08 UTC
44da274 [shmap-partial-auto] start adding partial auto / partial manual to shmap Co-authored-by: Parker Schuh <parkers@google.com> 01 August 2023, 21:51:07 UTC
0116d19 Prune some exports from jax.experimental.pjit. jax.experimental.pjit is deprecated in its entirety (use "jit" instead), and experimental APIs have no stability promises. PiperOrigin-RevId: 552903601 01 August 2023, 20:27:17 UTC
2e042b6 Enable test for indexing with u8 indices. https://github.com/openxla/xla/commit/4e4eff35bf9a5f8ed54fd290391bd0612f49533e fixed the underlying XLA problem. Fixes https://github.com/google/jax/issues/6122 https://github.com/google/jax/issues/16836 PiperOrigin-RevId: 552880163 01 August 2023, 19:13:58 UTC
853c470 Improve the repr of NamedSharding and error message of device_put PiperOrigin-RevId: 552841710 01 August 2023, 17:17:20 UTC
4ddf6a9 Bump minimum_jaxlib_version to 0.4.14. `xla_extension_version` is 174 and `mlir_api_version` is 54 PiperOrigin-RevId: 552816893 01 August 2023, 15:53:28 UTC
109ed50 Don't depend on jax in mesh_utils to remove circular dependency. PiperOrigin-RevId: 552799864 01 August 2023, 14:45:42 UTC
716f4f8 [Mosaic] Allow users to opt out of window prefetching PiperOrigin-RevId: 552797922 01 August 2023, 14:36:59 UTC
b8019dc Merge pull request #16900 from gnecula:poly_dot2 PiperOrigin-RevId: 552684253 01 August 2023, 05:12:43 UTC
f049aee Merge pull request #16735 from patrick-kidger:better-debugger PiperOrigin-RevId: 552656115 01 August 2023, 02:18:11 UTC
c1d1be0 jax.debug.breakpoint now supports being DCE'd. Drive-by: fix #16186 31 July 2023, 22:36:43 UTC
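Two of the entries above, 1fedf04 ("API: Remove NINF and PINF usages") and d183a2c ("ENH: Update numpy exceptions imports"), track NumPy API changes. The log does not include the diffs, so the sketch below only assumes the usual migration pattern: spelling the infinities as `-np.inf` / `np.inf` instead of the `np.NINF` / `np.PINF` aliases (removed in NumPy 2.0), and importing warning classes from `numpy.exceptions` (NumPy 1.25 and later) with a fallback to the top-level namespace on older releases. `ComplexWarning` is used as a representative class, and `masked_max` and `strict_real` are hypothetical helpers for illustration, not code from the JAX commits.

```python
import warnings

import numpy as np

# Assumed migration pattern: np.NINF / np.PINF are gone in NumPy 2.0,
# while -np.inf / np.inf work on every NumPy version.
NEG_INF = -np.inf
POS_INF = np.inf

# numpy.exceptions was added in NumPy 1.25; older releases expose the
# same warning class at the top level of the numpy namespace.
try:
    from numpy.exceptions import ComplexWarning
except ImportError:  # NumPy < 1.25
    from numpy import ComplexWarning


def masked_max(x, mask):
    """Hypothetical helper: max of x where mask is True, -inf if nothing is selected."""
    x = np.asarray(x, dtype=np.float64)
    # Unselected entries become -inf so they never win the reduction.
    return np.max(np.where(mask, x, NEG_INF), initial=NEG_INF)


def strict_real(z):
    """Hypothetical helper: cast to float64, raising instead of dropping imaginary parts."""
    with warnings.catch_warnings():
        # Promote ComplexWarning to an exception for the duration of the cast.
        warnings.simplefilter("error", ComplexWarning)
        return np.asarray(z).astype(np.float64)


if __name__ == "__main__":
    print(masked_max([1.0, 5.0, 3.0], [True, False, True]))  # 3.0
    print(masked_max([1.0, 5.0], [False, False]))  # -inf
    print(strict_real([1.0, 2.0]))
```

The `initial=NEG_INF` argument keeps the reduction defined even when `x` is empty, where a plain `np.max` would raise.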