https://github.com/google/jax

Revision Author Date Message Commit Date
849f837 Merge pull request #12609 from skye:jax2tf_import PiperOrigin-RevId: 478104104 30 September 2022, 23:25:17 UTC
0a69c9a Fix jax2tf import so it works with both the latest tensorflow release (2.10.0) and tf-nightly 30 September 2022, 22:55:22 UTC
fb8558c Add jax_array coverage to debug_nans_test PiperOrigin-RevId: 478079509 30 September 2022, 21:21:32 UTC
ec41de2 Merge pull request #12603 from jbushago:patch-1 PiperOrigin-RevId: 478068347 30 September 2022, 20:33:03 UTC
ea77c45 Merge pull request #12602 from skye:colab PiperOrigin-RevId: 478063554 30 September 2022, 20:17:47 UTC
8a1e0ed Merge pull request #12594 from skye:cache_warnings PiperOrigin-RevId: 478063392 30 September 2022, 20:11:34 UTC
15e5f38 Make persistent compilation cache warn instead of raise an error on cache read/write failures Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising exception instead of warning. Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled. 30 September 2022, 18:38:22 UTC
2038988 Fix typo in faq.rst. Fixed a small typo in the FAQ: "inthe" -> "in the". 30 September 2022, 18:14:05 UTC
0cc4066 Pin default jax.tools.colab_tpu.setup_tpu driver version. Prior to this change, we were defaulting to the TPU nightly driver version. We should instead pin to the version associated with the default jaxlib version that Colab uses. 30 September 2022, 17:45:49 UTC
aafc77d Improve the checks done in `Array` and apply them to all `Sharding`s rather than just `XLACompatibleSharding`. Also check the symmetric difference of sharding and `_arrays` devices. PiperOrigin-RevId: 478017409 30 September 2022, 16:56:16 UTC
4e51d2d Roll back https://github.com/google/jax/pull/12588 because of test failures PiperOrigin-RevId: 477871341 30 September 2022, 01:30:58 UTC
d498bd1 Merge pull request #12588 from jakevdp:random-annotations PiperOrigin-RevId: 477855302 29 September 2022, 23:52:42 UTC
9ff570e Make debug_nans_test.py pass with jax_array=1. Both with enabled and disabled jax_array flag and --pdb_post_mortem, we fall to the same place. PiperOrigin-RevId: 477850567 29 September 2022, 23:29:58 UTC
3c7d927 Disable dynamic_api_test and custom_object_test.py with jax.Array. Enable them again when support for jax.Array is added. Also don't use xla_shape since it's deprecated. PiperOrigin-RevId: 477833061 29 September 2022, 22:09:55 UTC
aed46f3 [typing] use jax.Array annotations in random.py 29 September 2022, 21:29:02 UTC
eb0fa40 Fix `process_allgather` to work with `jax.Array`. PiperOrigin-RevId: 477793014 29 September 2022, 19:32:21 UTC
4f90af9 Remove unused jax_unique_mhlo_module_names flag. PiperOrigin-RevId: 477778135 29 September 2022, 18:32:22 UTC
a770db0 Merge pull request #12579 from jakevdp:gather-unique PiperOrigin-RevId: 477767679 29 September 2022, 17:56:14 UTC
7b49a3f Run tests in multiprocess_gpu_test only if the backend is GPU. PiperOrigin-RevId: 477750739 29 September 2022, 16:54:32 UTC
1bc161a random.permutation: use unique_indices=True for efficiency 29 September 2022, 16:34:03 UTC
d49c5c3 jnp.take: add optional arguments forwarded to lax.gather 29 September 2022, 16:33:38 UTC
137384d Update xla_sharding import path to new location We are moving the TensorFlow APIs outside of XLA and will remove the old path soon. PiperOrigin-RevId: 477701988 29 September 2022, 12:40:56 UTC
de5dd1a Merge pull request #12444 from LenaMartens:checkify-switch PiperOrigin-RevId: 477688623 29 September 2022, 11:18:18 UTC
0639ace Raise cond index into tracing context in case of effects. So even if the cond is not data dependent at all, it's included in the dynamic trace, and effects can be discharged. 29 September 2022, 10:36:04 UTC
48b8956 Merge pull request #12566 from mattjj:djax-slice-sick PiperOrigin-RevId: 477626935 29 September 2022, 04:23:38 UTC
163b7e2 Convert shardings in `jit` path to OpShardingSharding to avoid recompilation when semantically similar shardings are used in `jit`. PiperOrigin-RevId: 477626548 29 September 2022, 04:17:29 UTC
500f8b7 Add HLOSharding's repr to OpShardingSharding since it's more compact. PiperOrigin-RevId: 477587916 29 September 2022, 00:00:16 UTC
84768d2 Replace `jax.xla.DeviceArray` private type with the new public type `jax.Array`. PiperOrigin-RevId: 477582562 28 September 2022, 23:34:10 UTC
a8826e6 [dynamic-shapes] Add basic slicing support If e.g. `x : f32[10, n]` then we want to handle Python expressions like `x[0]`. To do that, we can use a generalized version of `dynamic_slice` which allows dynamic slice sizes (where the result shape depends on those slice sizes). Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> 28 September 2022, 22:55:51 UTC
33dbf0e Merge pull request #12565 from hawkinsp:release PiperOrigin-RevId: 477549228 28 September 2022, 21:14:30 UTC
b49e31a Update version numbers after release. 28 September 2022, 18:49:22 UTC
c89cb5d Use `Array` in `__repr__` instead of the class name which is `ArrayImpl`. PiperOrigin-RevId: 477465432 28 September 2022, 15:57:53 UTC
0282b4b Merge pull request #12538 from jakevdp:bundle-pyi PiperOrigin-RevId: 477453094 28 September 2022, 15:00:20 UTC
aafc70d Merge pull request #12556 from hawkinsp:rocm PiperOrigin-RevId: 477440001 28 September 2022, 13:50:19 UTC
5fe7a54 Merge pull request #12555 from hawkinsp:release PiperOrigin-RevId: 477439236 28 September 2022, 13:50:06 UTC
39eabe8 Merge pull request #12552 from hawkinsp:nccl PiperOrigin-RevId: 477439228 28 September 2022, 13:43:45 UTC
f7bafb3 Disable multiprocess_gpu_test that fails on ROCm. 28 September 2022, 13:40:57 UTC
8d86436 jax/jaxlib 0.3.20 release candidate. 28 September 2022, 13:33:52 UTC
eabb91e Fix test failure in GPU CI if NCCL_DEBUG is enabled. If NCCL_DEBUG is enabled, NCCL prints extra status information. Make test accept this. 28 September 2022, 13:06:04 UTC
96abd9a Merge pull request #12540 from sharadmv:cond-lowering-fix PiperOrigin-RevId: 477358889 28 September 2022, 05:33:12 UTC
96a85bd Make addressable_shards a property like local_shards PiperOrigin-RevId: 477358276 28 September 2022, 05:27:19 UTC
9489068 Merge pull request #12546 from mattjj:issue12542 PiperOrigin-RevId: 477356925 28 September 2022, 05:16:02 UTC
ddeaa8d Fix lowering bug in effectful batched cond and add tests 28 September 2022, 05:12:13 UTC
b4e1d0a Propagate `name` through ExecuteReplicated for `dispatch.check_special` PiperOrigin-RevId: 477351323 28 September 2022, 04:32:32 UTC
b175e11 [c++ jit] only set use_fastpath in cache_miss if all args are DeviceArrays fixes #12542 Co-authored-by: Peter Hawkins <phawkins@google.com> Co-authored-by: Kuangyuan Chen <chky@google.com> 28 September 2022, 03:51:07 UTC
933b6a2 Fix the bug where XLA doesn't provide shardings for all the outputs if all the elements in the output tuple have the same sharding. XLA decides to run the `FusionTupleDeduplicator` to put the sharding on ROOT instead of the tuple. PiperOrigin-RevId: 477343328 28 September 2022, 03:27:39 UTC
c8bff11 Add `addressable_` counterparts of `local_` to GDA to make it easier for users to move to Array as both will have the same API. PiperOrigin-RevId: 477332697 28 September 2022, 02:19:29 UTC
e4f2bff Disintegrate `Array` into DeviceBuffers inside GDA. This is required for backwards compatibility changes as users can create GDAs and pass that to pjit even when Array is switched on. PiperOrigin-RevId: 477297406 27 September 2022, 23:02:23 UTC
0919a67 Merge pull request #12534 from google:update-pypi PiperOrigin-RevId: 477260550 27 September 2022, 20:31:05 UTC
6e6fb10 setup: bundle *.pyi files with distribution 27 September 2022, 19:55:42 UTC
d028d93 Update version and changelog for jax 0.3.19 release 27 September 2022, 18:00:27 UTC
9e4114f Move `array.py` and `sharding.py` from `experimental/` to `_src/`. PiperOrigin-RevId: 477201711 27 September 2022, 17:06:52 UTC
0e11688 Merge pull request #12382 from jakevdp:reduction-dtype PiperOrigin-RevId: 477179725 27 September 2022, 15:38:46 UTC
1bcf8d6 Merge pull request #12497 from mattjj:djax-dag-fix1 PiperOrigin-RevId: 477038279 27 September 2022, 01:14:56 UTC
e42247b Merge pull request #12524 from sharadmv:lax-import-fix PiperOrigin-RevId: 477038211 27 September 2022, 01:08:45 UTC
389a2e5 Add a backwards compat path for `op_sharding.clone()` because it doesn't exist with the latest jaxlib on pypi PiperOrigin-RevId: 477034758 27 September 2022, 00:50:19 UTC
1e7ca8f fix bug in djax type signature inference logic Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> 27 September 2022, 00:48:25 UTC
1d895b2 Fix lax imports 27 September 2022, 00:32:44 UTC
cbf34cb Rename the concrete class `Array` to `ArrayImpl` PiperOrigin-RevId: 477017236 26 September 2022, 23:18:30 UTC
71bcabe [sparse] Add BCSR format template. PiperOrigin-RevId: 477013899 26 September 2022, 23:02:16 UTC
82636b0 Merge pull request #12523 from jakevdp:fix-build PiperOrigin-RevId: 477005157 26 September 2022, 22:23:02 UTC
6cae54f Fix bazel build alias 26 September 2022, 22:13:12 UTC
d63a944 Change jax_jit_test to be a jax_test() under Bazel that works across backends. Make it pass under TPU if x64 types are enabled. PiperOrigin-RevId: 476994286 26 September 2022, 21:38:35 UTC
265b39d Add pytype_srcs to main jax BUILD rule PiperOrigin-RevId: 476989241 26 September 2022, 21:18:13 UTC
ddd8581 Merge pull request #12480 from google:bug-template-gpu-smi PiperOrigin-RevId: 476979981 26 September 2022, 20:41:31 UTC
1860f6d [x64] add promote_integers argument to jnp.prod & jnp.sum 26 September 2022, 20:31:43 UTC
69d1a2c Merge pull request #12517 from skye:update-pypi PiperOrigin-RevId: 476969287 26 September 2022, 20:00:31 UTC
b2b60d9 Add `make_array_from_single_device_arrays` to prepare to rename of the concrete `Array` to `ArrayImpl`. PiperOrigin-RevId: 476965287 26 September 2022, 19:43:59 UTC
3c0d280 Update version and changelog for jax 0.3.18 release 26 September 2022, 19:43:39 UTC
2a7b319 add `nvidia-smi` question to bug template 26 September 2022, 18:06:29 UTC
e034432 Merge pull request #12513 from inoryy:patch-4 PiperOrigin-RevId: 476923412 26 September 2022, 17:04:14 UTC
7962b01 Merge pull request #12485 from LenaMartens:checkify-lower PiperOrigin-RevId: 476922387 26 September 2022, 16:53:40 UTC
27e3981 lowerable errors behind a config flag. 26 September 2022, 16:34:27 UTC
8bcf358 Remove unused _remat_static_argnums import. 26 September 2022, 16:14:09 UTC
78ecc14 Lowerable checks!! 26 September 2022, 15:54:18 UTC
28672cc Merge pull request #12496 from mattjj:improve-leak-checker-2 PiperOrigin-RevId: 476907407 26 September 2022, 15:50:13 UTC
9c66569 Merge pull request #12468 from LenaMartens:checkify-but-better PiperOrigin-RevId: 476901601 26 September 2022, 15:23:02 UTC
2df61b1 Merge pull request #12421 from jakevdp:jax-array PiperOrigin-RevId: 476898184 26 September 2022, 15:07:11 UTC
0cb233e Add initial jax.Array base class for instance checks & annotation 26 September 2022, 14:48:43 UTC
2180710 Merge pull request #12511 from hawkinsp:release PiperOrigin-RevId: 476889960 26 September 2022, 14:24:44 UTC
bcd36d8 Jax and jaxlib 0.3.18 release candidate. 26 September 2022, 14:10:57 UTC
53de057 Merge pull request #12510 from hawkinsp:context PiperOrigin-RevId: 476884674 26 September 2022, 13:58:46 UTC
f4bc663 Wrap multiprocess test popen() uses in a context manager. Ensures that resources from popen() are cleaned up. 26 September 2022, 13:48:56 UTC
ec15e83 - Wraps calls to lax.xeinsum and _einsum in a named call with their 'spec', the string specifying the computation. Makes xprof traces more interpretable. PiperOrigin-RevId: 476796185 26 September 2022, 03:54:17 UTC
7c85ca3 Only look at hlo_modules for output sharding if there is more than 1 device because if there is only 1 device, the spmd partitioner won't run. PiperOrigin-RevId: 476497929 24 September 2022, 00:31:33 UTC
8ee7129 Fix jnp.unwrap() test failures on GPU. A recent XLA change allows XLA to use excess precision on GPU, which caused CompileAndCheck to report noticeable numerical changes for bfloat16. In passing, also enable the comparison against NumPy test for bfloat16 by using a wrapper function. PiperOrigin-RevId: 476494989 24 September 2022, 00:11:51 UTC
d2fcfb6 Merge pull request #12407 from hirwa-nshuti:docs-fix PiperOrigin-RevId: 476467728 23 September 2022, 21:51:11 UTC
03abcc7 fix typo in test 23 September 2022, 21:43:24 UTC
e76aa77 Merge pull request #12437 from sudhakarsingh27:add_multi_host_pjit_tests PiperOrigin-RevId: 476451469 23 September 2022, 20:38:59 UTC
1fa0dda Return single device Arrays from `.device_buffer` and `.device_buffers`. PiperOrigin-RevId: 476449591 23 September 2022, 20:30:26 UTC
43bbce0 Merge pull request #12486 from hawkinsp:debugging PiperOrigin-RevId: 476445041 23 September 2022, 20:09:26 UTC
737327a Merge pull request #12490 from mattjj:improve-leak-checker PiperOrigin-RevId: 476442352 23 September 2022, 19:58:03 UTC
b6ef90f fix leak checker internal error The issue was that partial_eval.py's _memoize, used in custom_jvp, was made into an identity function by enabling config.jax_check_tracer_leaks (from references to the main trace (needed for the jvp_jaxpr thunk) and hence trigger the leak checker (which would see if any references to the main trace persisted after finishing tracing of the user function). But after #7345, the leak checker should only trigger when actual Tracers are leaked. So disabling the memoization when jax_check_tracer_leaks is no longer active shouldn't be necessary. (These PR numbers seem out of order! We're not sure why.) Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> 23 September 2022, 19:33:45 UTC
4dd0d85 add multihost pjit tests 23 September 2022, 19:11:56 UTC
a6b24b3 Add regression test for lax.rev simplification error PiperOrigin-RevId: 476430486 23 September 2022, 19:07:15 UTC
ecb27a9 Update the `_check_special` code to not use xla_shape since it's deprecated and does not work with Array. PiperOrigin-RevId: 476422732 23 September 2022, 18:40:32 UTC
d078f3f Merge pull request #12478 from sharadmv:sharding-docs PiperOrigin-RevId: 476420315 23 September 2022, 18:31:37 UTC
e8865c8 Merge pull request #12481 from kho:changelist/476272494 PiperOrigin-RevId: 476411483 23 September 2022, 17:55:10 UTC
c823151 Allow transpose axes to be negative to match (undocumented) NumPy behavior 23 September 2022, 17:18:23 UTC
0c08547 Modify CorrCoef test to not rely on floating point representation of 1/3 The operation computed an average over a dimension of size 3. Compilers turn this into a multiplication by 1/3, but 1/3 cannot be represented exactly. That made this test case rely on a very precise result from an unrepresentable calculation. PiperOrigin-RevId: 476391389 23 September 2022, 16:39:01 UTC
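Several commits above change user-facing `jax.numpy` behavior: 0cb233e (the `jax.Array` base class for instance checks), c823151 (negative transpose axes), 1860f6d (`promote_integers` on `jnp.sum`/`jnp.prod`) and d49c5c3 (`jnp.take` forwarding optional arguments to `lax.gather`). The following is a minimal sketch of those behaviors, assuming a JAX release that already includes these commits; the array `x` and the chosen indices are arbitrary illustration values, not taken from the commits themselves.

```python
import jax
import jax.numpy as jnp

# Arbitrary small int8 matrix: [[0, 1, 2], [3, 4, 5]]
x = jnp.arange(6, dtype=jnp.int8).reshape(2, 3)

# 0cb233e: jax.Array can be used for isinstance checks and annotations.
assert isinstance(x, jax.Array)

# c823151: transpose axes may be negative, matching NumPy.
xt = jnp.transpose(x, (-1, -2))  # equivalent to axes (1, 0)

# 1860f6d: promote_integers controls whether small integer sums are widened.
s_promoted = jnp.sum(x, promote_integers=True)  # widened to the default integer type
s_native = jnp.sum(x, promote_integers=False)   # keeps the input dtype, int8

# d49c5c3: jnp.take forwards extra keyword arguments to lax.gather.
# unique_indices=True asserts that the indices contain no duplicates.
taken = jnp.take(x, jnp.array([2, 0]), axis=1, unique_indices=True)

print(xt.shape, s_promoted.dtype, s_native.dtype, taken.tolist())
```

Because `unique_indices` is treated as a promise by `lax.gather` rather than something it checks, it should only be set when the indices really are distinct, which is the situation 1bc161a relies on for `random.permutation`.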