https://github.com/google/jax

Revision Message Commit Date
6927e5d Add 1-hour timeout to each Cloud TPU CI job. Sometimes they hang, and the default timeout is 6 hours, which is way too long. 29 March 2023, 21:23:45 UTC
f282c25 Add minimal pyproject.toml specifying build system Replaces #15274, Fixes #15256 PiperOrigin-RevId: 520367622 29 March 2023, 17:08:30 UTC
cfa330b Merge pull request #15283 from JiaYaobo:fix_wald_doc PiperOrigin-RevId: 520364879 29 March 2023, 17:01:27 UTC
2d94f76 Merge pull request #15278 from hawkinsp:cudainstall PiperOrigin-RevId: 520364354 29 March 2023, 16:53:58 UTC
fbc05ee Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago PiperOrigin-RevId: 520356179 29 March 2023, 16:23:22 UTC
a964ae7 Internal Code Change PiperOrigin-RevId: 520341781 29 March 2023, 15:23:56 UTC
7200d07 Merge pull request #15286 from hawkinsp:testjobs PiperOrigin-RevId: 520319910 29 March 2023, 13:39:04 UTC
d9b0f3c Recommend --local_test_jobs in bazel test command line on GPU. 29 March 2023, 13:28:53 UTC
07fc022 Merge pull request #15279 from hawkinsp:versions PiperOrigin-RevId: 520307157 29 March 2023, 12:25:50 UTC
3a4d0b3 remove scale in wald docstring 29 March 2023, 03:39:49 UTC
705b5cc Add version constraints to CUDA pip wheel dependencies. Fixes https://github.com/google/jax/issues/15267 29 March 2023, 01:55:32 UTC
775f404 Update the CUDA installation instructions. 29 March 2023, 01:46:07 UTC
c2d6fcc Split core.py and several files in an SCC with it into a separate Bazel build target. PiperOrigin-RevId: 520192610 29 March 2023, 01:31:13 UTC
8c4fed6 Merge pull request #15270 from skye:pjrt_c_api PiperOrigin-RevId: 520156646 28 March 2023, 22:48:27 UTC
473d1c3 Turn on PJRT C API by default. I forgot that the default setting is actually in jaxlib: https://github.com/openxla/xla/blob/fbe9a80fdb8c429e8a175962459da348cd560a50/xla/python/xla_client.py#L135 To be able to make this change as a jax-only release, I manually set the env var on Cloud TPU if it isn't already set. 28 March 2023, 22:28:13 UTC
5ae2e79 Silence some pytype errors. PiperOrigin-RevId: 520150523 28 March 2023, 22:24:48 UTC
4061bbb Merge pull request #15269 from skye:min_jaxlib_version PiperOrigin-RevId: 520127548 28 March 2023, 21:02:27 UTC
00acf45 Bump minimum jaxlib version from 0.4.6 to 0.4.7. Also removes a bunch of dead version guards (0.4.7 has xla_extension_version 144 and mlir_api_version 47) 28 March 2023, 20:43:01 UTC
014033b Merge pull request #15266 from mehdiataei:mehdiataei-patch-1 PiperOrigin-RevId: 520111030 28 March 2023, 20:09:56 UTC
bbec461 Merge pull request #15263 from jakevdp:deprecations PiperOrigin-RevId: 520110559 28 March 2023, 20:02:32 UTC
fc47137 Add deprecation warnings for several top-level jax imports 28 March 2023, 19:40:59 UTC
2f105bd Jax 0.4.7 has been released so assert that length of warnings is 1 in test_cache_read_warning PiperOrigin-RevId: 520098757 28 March 2023, 19:17:30 UTC
8d090a0 Fixed spelling error in msgs 28 March 2023, 18:48:18 UTC
2fbccc8 Merge pull request #15251 from jakevdp:mypy-deps PiperOrigin-RevId: 520079899 28 March 2023, 18:11:15 UTC
7442faa Remove MeshPspecSharding since it has been more than 3 months since it was deprecated (Nov 2, 2022). The replacement name is NamedSharding. PiperOrigin-RevId: 520072687 28 March 2023, 17:47:42 UTC
97c8ce3 Deprecate FROM_GDA and remove its support from pjit's code since jax.Array inside pjit has sharding inference capabilities by default. PiperOrigin-RevId: 520067392 28 March 2023, 17:33:00 UTC
6ace666 Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing. PiperOrigin-RevId: 520064049 28 March 2023, 17:25:07 UTC
2ac2dc6 Remove jax2tf experimental_native_lowering. Users should use native_serialization. PiperOrigin-RevId: 520063928 28 March 2023, 17:17:58 UTC
86c0b36 Remove Cuda 11.4 support. JAX from 0.4.8 release will support cuda 11.8 and cuda 12 PiperOrigin-RevId: 520056811 28 March 2023, 16:54:36 UTC
4106c35 Merge pull request #15258 from gnecula:dim_vars PiperOrigin-RevId: 520033493 28 March 2023, 15:21:00 UTC
a1538c7 [shape_poly] Refactor the computation of the dimension variables in native serialization Currently, JAX native serialization produces a module whose main function takes additional arguments for the values of the dimension variables. These values are then resolved in the XlaCallModule based on a dim_args_spec parameter. We move the code that computes the dimension variables from XlaCallModule to jax_export following pretty much the same technique. This simplifies XlaCallModule and especially its API (the dim_args_spec). So far this is just a refactoring with no semantic changes, but this will allow us to improve the support for dimension variables that occur in linear polynomials, e.g., "2*b" rather than just "b". 28 March 2023, 10:51:48 UTC
4533578 Merge pull request #15206 from jakevdp:expm-batch PiperOrigin-RevId: 519883991 28 March 2023, 01:21:27 UTC
03d1442 Merge pull request #15241 from jakevdp:instance-check PiperOrigin-RevId: 519868384 28 March 2023, 00:06:00 UTC
6f39237 Merge pull request #15243 from jakevdp:coo-sort-warning PiperOrigin-RevId: 519864866 27 March 2023, 23:50:24 UTC
ad0fc89 jax.scipy.linalg.expm: support batched inputs 27 March 2023, 23:39:48 UTC
61190cb CI: add numpy & scipy to mypy env 27 March 2023, 22:08:44 UTC
670fba3 Finish jax and jaxlib 0.4.7 release PiperOrigin-RevId: 519839723 27 March 2023, 22:06:38 UTC
10dc941 Add jaxlib version guard for rnn test PiperOrigin-RevId: 519833650 27 March 2023, 21:43:46 UTC
6cc1bf5 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval. Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters. PiperOrigin-RevId: 519813664 27 March 2023, 20:30:47 UTC
ae4f1fc Update the commit in workspace too PiperOrigin-RevId: 519797427 27 March 2023, 19:30:18 UTC
e9cac5e Prepare for jax and jaxlib 0.4.7 release PiperOrigin-RevId: 519785176 27 March 2023, 18:45:22 UTC
e21aee1 Add deprecation warning for FROM_GDA usage since that argument is not required anymore. PiperOrigin-RevId: 519781715 27 March 2023, 18:33:11 UTC
3c3fa04 Copy seq_lengths before creating descriptor PiperOrigin-RevId: 519771897 27 March 2023, 17:59:44 UTC
88c2898 Use pytype_strict_library() in Bazel build rules. PiperOrigin-RevId: 519757928 27 March 2023, 17:16:08 UTC
392bd93 [sparse] fix coo efficiency warning 27 March 2023, 17:15:43 UTC
ed9fa13 jax.typing: recommend instance check in Python 3.10 or newer 27 March 2023, 17:01:28 UTC
40fb646 Fix duplicate definition of 'cuda' extra in setup.py. PiperOrigin-RevId: 519750659 27 March 2023, 16:52:37 UTC
af4d494 Merge pull request #8633 from shawwn:2021-11-19/autodidax-fix-jaxpr-subcomp-return-type PiperOrigin-RevId: 519745476 27 March 2023, 16:52:20 UTC
10d51c7 Merge pull request #15218 from hawkinsp:mypy PiperOrigin-RevId: 519745465 27 March 2023, 16:37:54 UTC
41695cc Temporarily fix the compilation cache test which is failing on latest jaxlib release PiperOrigin-RevId: 519745099 27 March 2023, 16:37:37 UTC
2c4be6f Merge pull request #15226 from canyon289:patch-1 PiperOrigin-RevId: 519743393 27 March 2023, 16:30:24 UTC
d19e60e Merge pull request #15228 from canyon289:patch-2 PiperOrigin-RevId: 519742908 27 March 2023, 16:30:07 UTC
cf8c2b8 Delete benchmark and pmap_benchmark files as they are legacy and replaced with api_benchmark.py PiperOrigin-RevId: 519742866 27 March 2023, 16:22:57 UTC
d473e86 Merge pull request #13008 from hawkinsp:pipcuda PiperOrigin-RevId: 519740461 27 March 2023, 16:14:24 UTC
6715736 Merge pull request #15205 from yhtang:editable-jaxlib-build PiperOrigin-RevId: 519704474 27 March 2023, 13:33:31 UTC
f3613a1 Merge pull request #15234 from gnecula:get_dim_size PiperOrigin-RevId: 519691037 27 March 2023, 12:21:54 UTC
befb449 [shape_poly] Fixed bug with dimension variables in unused args JAX will aggressively drop module input arguments if they are not used. This can interfere with shape polymorphism, because it may result in dropping arguments from which we need to derive the values of shape variables. We fix this for now by disabling dropping arguments if there are dimension variables in the arguments shapes. A more precise technique would be to force keeping only of arguments that we need for deriving the dimension variables. However, that would be a much more involved change, for an uncertain benefit. 27 March 2023, 11:37:39 UTC
99facba [jax2tf] Turn an error into a warning with native serialization We want to allow using native_serialization_platforms even if the native_serialization is False. This is useful for code that is runnable with and without native serialization. PiperOrigin-RevId: 519649827 27 March 2023, 08:24:56 UTC
08a8a5e Fix hessian link 27 March 2023, 02:21:52 UTC
8c25495 Update user_guides.rst Fix minor typo 27 March 2023, 00:21:35 UTC
b62f114 Add support for using pip-installed CUDA wheels. Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels. Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency. 26 March 2023, 12:35:00 UTC
ec427f2 Split dtype argument from other arguments in special functions. This helps pytype to determine that the arguments are of different kinds, preventing type errors. PiperOrigin-RevId: 519401250 25 March 2023, 18:41:14 UTC
05319b5 Suppress mypy warnings about missing imports. 25 March 2023, 13:45:55 UTC
a5d3085 Add `src` argument to device_put as an experimental arg PiperOrigin-RevId: 519308082 25 March 2023, 04:10:26 UTC
5da4a5a Add SDA deprecation warning to pytest.ini PiperOrigin-RevId: 519281775 25 March 2023, 01:04:07 UTC
c572155 Merge pull request #15212 from google:pjrt_c_api_tests PiperOrigin-RevId: 519276265 25 March 2023, 00:27:40 UTC
ef5e4a4 Remove 'pjrt_c_api_unimplemented' pytest mark. Instead, we skip tests that the PJRT C API doesn't support. We had this tag for feature development so it was easy to broadly disable, but now we don't expect to need to do that. 24 March 2023, 23:14:54 UTC
6842e98 Migrate regularized_incomplete_beta_p off xla_fallback PiperOrigin-RevId: 519244597 24 March 2023, 21:53:20 UTC
ac44d2c Migrate besseli0e off xla_fallback PiperOrigin-RevId: 519241252 24 March 2023, 21:39:40 UTC
caaa0a2 add build option to create editable jaxlib Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn> Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com> 24 March 2023, 21:25:26 UTC
257ac6a If each host has the full value of the Array, allow fetching it to host. Fixes #15162 Benchmarks: ``` name old cpu/op new cpu/op delta np_asarray_8_devices 3.71ms ± 6% 3.32ms ± 7% -10.48% (p=0.008 n=5+5) name old time/op new time/op delta np_asarray_8_devices 3.86ms ± 6% 3.49ms ± 7% -9.72% (p=0.008 n=5+5) ``` PiperOrigin-RevId: 519222320 24 March 2023, 20:21:57 UTC
6ed66ad Delete remote TPU support. TPU VMs are the only supported way to use TPUs as of JAX 0.4.0. PiperOrigin-RevId: 519211267 24 March 2023, 19:33:33 UTC
fad4e6f [1/n] store embedded tf.graph to stablehlo.custom_call PiperOrigin-RevId: 519194911 24 March 2023, 18:27:24 UTC
21541e6 Guard ArrayImpl checks by xla_extension_version. PiperOrigin-RevId: 519191714 24 March 2023, 18:15:36 UTC
bc231ee After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host. PiperOrigin-RevId: 519170785 24 March 2023, 17:00:41 UTC
61a5686 Merge pull request #15175 from nouiz:nightly_ci_keep_alive PiperOrigin-RevId: 519167122 24 March 2023, 16:52:49 UTC
72bb4fe Merge pull request #15203 from nouiz:nightly_ci PiperOrigin-RevId: 519167117 24 March 2023, 16:45:21 UTC
8c75e27 Migrate random_gamma_grad off xla_fallback PiperOrigin-RevId: 519154537 24 March 2023, 15:49:40 UTC
8d1d522 Migrate igamma_grad_a_p off xla_fallback PiperOrigin-RevId: 519148548 24 March 2023, 15:21:22 UTC
229a4cf remove another dependency not currently needed. 24 March 2023, 15:04:27 UTC
4a9b094 Migrate igammac_p off xla_fallback path It is now decomposed into stablehlo ops. PiperOrigin-RevId: 519122775 24 March 2023, 12:58:38 UTC
8081031 [jaxlib] fix build w/ dependency on stablehlo_serialization PiperOrigin-RevId: 519120624 24 March 2023, 12:42:38 UTC
32e7128 Merge pull request #15192 from mattjj:issue15190 PiperOrigin-RevId: 519037959 24 March 2023, 03:48:33 UTC
1982a11 Merge pull request #15187 from mattjj:djax-revival PiperOrigin-RevId: 519036576 24 March 2023, 03:38:01 UTC
7743fcd [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit 24 March 2023, 03:20:01 UTC
793387e fix jax.Array.round() fixes #15190 24 March 2023, 03:16:23 UTC
4cb3b01 Remove PJRT C API bypass. Now that all functionality needed by frameworks is implemented, let's remove the possibility of not noticing missing functionality due to the bypass. PiperOrigin-RevId: 519018438 24 March 2023, 01:39:14 UTC
e36bd57 Merge pull request #14309 from hawkinsp:numpy PiperOrigin-RevId: 519015703 24 March 2023, 01:24:01 UTC
b7375b3 Increase minimum NumPy version to 1.21. Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21. 24 March 2023, 01:15:10 UTC
e9bc7ee Merge pull request #15184 from jakevdp:move-median PiperOrigin-RevId: 519003606 24 March 2023, 00:14:50 UTC
6f8885a lax_numpy: move quantile-based functions to reductions.py 23 March 2023, 23:39:20 UTC
f981243 Migrate igamma_p off xla_fallback We decompose it into a series or a call to igammac. PiperOrigin-RevId: 518993077 23 March 2023, 23:26:59 UTC
d777cf2 [jax2tf] A simple failing test on TPU with native serialization PiperOrigin-RevId: 518987577 23 March 2023, 23:04:53 UTC
171b22d Add padding option "SAME_LOWER" for ticket https://github.com/google/jax/pull/14990 PiperOrigin-RevId: 518984018 23 March 2023, 22:50:16 UTC
7df5245 [jax2tf] Create a jax_export library with JAX-only pieces for native serialization This is a pure refactor, no functionality should change. PiperOrigin-RevId: 518982222 23 March 2023, 22:42:45 UTC
6407fa6 Merge pull request #15178 from mattjj:improve-scan-errors PiperOrigin-RevId: 518977054 23 March 2023, 22:28:16 UTC
adbdaa4 Refactor special functions into their own module. We're going to want to decompose these using series and continued fraction representations, and for that we'll need control flow PiperOrigin-RevId: 518977008 23 March 2023, 22:21:15 UTC
1f7c305 Merge pull request #15172 from jakevdp:jax-array-refactor PiperOrigin-RevId: 518971494 23 March 2023, 22:00:15 UTC
ba2ff51 improve scan error messages 23 March 2023, 21:53:05 UTC
1136d0f [jax2tf] Minor addition to the documentation PiperOrigin-RevId: 518969936 23 March 2023, 21:52:01 UTC