6927e5d | Skye Wanderman-Milne | 29 March 2023, 21:23:45 UTC | Add 1-hour timeout to each Cloud TPU CI job. Sometimes they hang, and the default timeout is 6 hours, which is way too long. | 29 March 2023, 21:23:45 UTC |
f282c25 | Jake VanderPlas | 29 March 2023, 17:04:34 UTC | Add minimal pyproject.toml specifying build system Replaces #15274, Fixes #15256 PiperOrigin-RevId: 520367622 | 29 March 2023, 17:08:30 UTC |
cfa330b | jax authors | 29 March 2023, 17:01:27 UTC | Merge pull request #15283 from JiaYaobo:fix_wald_doc PiperOrigin-RevId: 520364879 | 29 March 2023, 17:01:27 UTC |
2d94f76 | jax authors | 29 March 2023, 16:53:58 UTC | Merge pull request #15278 from hawkinsp:cudainstall PiperOrigin-RevId: 520364354 | 29 March 2023, 16:53:58 UTC |
fbc05ee | Yash Katariya | 29 March 2023, 16:22:34 UTC | Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago PiperOrigin-RevId: 520356179 | 29 March 2023, 16:23:22 UTC |
a964ae7 | jax authors | 29 March 2023, 15:23:18 UTC | Internal Code Change PiperOrigin-RevId: 520341781 | 29 March 2023, 15:23:56 UTC |
7200d07 | jax authors | 29 March 2023, 13:39:04 UTC | Merge pull request #15286 from hawkinsp:testjobs PiperOrigin-RevId: 520319910 | 29 March 2023, 13:39:04 UTC |
d9b0f3c | Peter Hawkins | 29 March 2023, 13:28:53 UTC | Recommend --local_test_jobs in bazel test command line on GPU. | 29 March 2023, 13:28:53 UTC |
07fc022 | jax authors | 29 March 2023, 12:25:50 UTC | Merge pull request #15279 from hawkinsp:versions PiperOrigin-RevId: 520307157 | 29 March 2023, 12:25:50 UTC |
3a4d0b3 | jiayaobo | 29 March 2023, 03:39:49 UTC | remove scale in wald docstring | 29 March 2023, 03:39:49 UTC |
705b5cc | Peter Hawkins | 29 March 2023, 01:55:32 UTC | Add version constraints to CUDA pip wheel dependencies. Fixes https://github.com/google/jax/issues/15267 | 29 March 2023, 01:55:32 UTC |
775f404 | Peter Hawkins | 29 March 2023, 01:46:07 UTC | Update the CUDA installation instructions. | 29 March 2023, 01:46:07 UTC |
c2d6fcc | Peter Hawkins | 29 March 2023, 01:30:36 UTC | Split core.py and several files in an SCC with it into a separate Bazel build target. PiperOrigin-RevId: 520192610 | 29 March 2023, 01:31:13 UTC |
8c4fed6 | jax authors | 28 March 2023, 22:48:27 UTC | Merge pull request #15270 from skye:pjrt_c_api PiperOrigin-RevId: 520156646 | 28 March 2023, 22:48:27 UTC |
473d1c3 | Skye Wanderman-Milne | 28 March 2023, 20:42:51 UTC | Turn on PJRT C API by default. I forgot that the default setting is actually in jaxlib: https://github.com/openxla/xla/blob/fbe9a80fdb8c429e8a175962459da348cd560a50/xla/python/xla_client.py#L135 To be able to make this change as a jax-only release, I manually set the env var on Cloud TPU if it isn't already set. | 28 March 2023, 22:28:13 UTC |
5ae2e79 | Rebecca Chen | 28 March 2023, 22:23:54 UTC | Silence some pytype errors. PiperOrigin-RevId: 520150523 | 28 March 2023, 22:24:48 UTC |
4061bbb | jax authors | 28 March 2023, 21:02:27 UTC | Merge pull request #15269 from skye:min_jaxlib_version PiperOrigin-RevId: 520127548 | 28 March 2023, 21:02:27 UTC |
00acf45 | Skye Wanderman-Milne | 28 March 2023, 19:43:32 UTC | Bump minimum jaxlib version from 0.4.6 to 0.4.7. Also removes a bunch of dead version guards (0.4.7 has xla_extension_version 144 and mlir_api_version 47) | 28 March 2023, 20:43:01 UTC |
014033b | jax authors | 28 March 2023, 20:09:56 UTC | Merge pull request #15266 from mehdiataei:mehdiataei-patch-1 PiperOrigin-RevId: 520111030 | 28 March 2023, 20:09:56 UTC |
bbec461 | jax authors | 28 March 2023, 20:02:32 UTC | Merge pull request #15263 from jakevdp:deprecations PiperOrigin-RevId: 520110559 | 28 March 2023, 20:02:32 UTC |
fc47137 | Jake VanderPlas | 28 March 2023, 19:40:59 UTC | Add deprecation warnings for several top-level jax imports | 28 March 2023, 19:40:59 UTC |
2f105bd | Yash Katariya | 28 March 2023, 19:16:04 UTC | Jax 0.4.7 has been released so assert that length of warnings is 1 in test_cache_read_warning PiperOrigin-RevId: 520098757 | 28 March 2023, 19:17:30 UTC |
8d090a0 | mehdiataei | 28 March 2023, 18:48:18 UTC | Fixed spelling error in msgs | 28 March 2023, 18:48:18 UTC |
2fbccc8 | jax authors | 28 March 2023, 18:11:15 UTC | Merge pull request #15251 from jakevdp:mypy-deps PiperOrigin-RevId: 520079899 | 28 March 2023, 18:11:15 UTC |
7442faa | Yash Katariya | 28 March 2023, 17:46:28 UTC | Remove MeshPspecSharding since it has been more than 3 months since it was deprecated (Nov 2, 2022). The replacement name is NamedSharding. PiperOrigin-RevId: 520072687 | 28 March 2023, 17:47:42 UTC |
97c8ce3 | Yash Katariya | 28 March 2023, 17:29:01 UTC | Deprecate FROM_GDA and remove its support from pjit's code since jax.Array inside pjit has sharding inference capabilities by default. PiperOrigin-RevId: 520067392 | 28 March 2023, 17:33:00 UTC |
6ace666 | Colin Gaffney | 28 March 2023, 17:17:49 UTC | Set coordinator address to allow it to later be used to initialize OCDBT coordinator server. Allow user to pass ts.Context when serializing or deserializing. PiperOrigin-RevId: 520064049 | 28 March 2023, 17:25:07 UTC |
2ac2dc6 | George Necula | 28 March 2023, 17:17:19 UTC | Remove jax2tf experimental_native_lowering. Users should use native_serialization. PiperOrigin-RevId: 520063928 | 28 March 2023, 17:17:58 UTC |
86c0b36 | Yash Katariya | 28 March 2023, 16:53:57 UTC | Remove Cuda 11.4 support. JAX from 0.4.8 release will support cuda 11.8 and cuda 12 PiperOrigin-RevId: 520056811 | 28 March 2023, 16:54:36 UTC |
4106c35 | jax authors | 28 March 2023, 15:21:00 UTC | Merge pull request #15258 from gnecula:dim_vars PiperOrigin-RevId: 520033493 | 28 March 2023, 15:21:00 UTC |
a1538c7 | George Necula | 12 March 2023, 08:04:21 UTC | [shape_poly] Refactor the computation of the dimension variables in native serialization Currently, JAX native serialization produces a module whose main function takes additional arguments for the values of the dimension variables. These values are then resolved in the XlaCallModule based on a dim_args_spec parameter. We move the code that computes the dimension variables from XlaCallModule to jax_export following pretty much the same technique. This simplifies XlaCallModule and especially its API (the dim_args_spec). So far this is just a refactoring with no semantic changes, but this will allow us to improve the support for dimension variables that occur in linear polynomials, e.g., "2*b" rather than just "b". | 28 March 2023, 10:51:48 UTC |
4533578 | jax authors | 28 March 2023, 01:21:27 UTC | Merge pull request #15206 from jakevdp:expm-batch PiperOrigin-RevId: 519883991 | 28 March 2023, 01:21:27 UTC |
03d1442 | jax authors | 28 March 2023, 00:06:00 UTC | Merge pull request #15241 from jakevdp:instance-check PiperOrigin-RevId: 519868384 | 28 March 2023, 00:06:00 UTC |
6f39237 | jax authors | 27 March 2023, 23:50:24 UTC | Merge pull request #15243 from jakevdp:coo-sort-warning PiperOrigin-RevId: 519864866 | 27 March 2023, 23:50:24 UTC |
ad0fc89 | Jake VanderPlas | 27 March 2023, 23:39:48 UTC | jax.scipy.linalg.expm: support batched inputs | 27 March 2023, 23:39:48 UTC |
61190cb | Jake VanderPlas | 27 March 2023, 22:08:44 UTC | CI: add numpy & scipy to mypy env | 27 March 2023, 22:08:44 UTC |
670fba3 | Yash Katariya | 27 March 2023, 22:06:06 UTC | Finish jax and jaxlib 0.4.7 release PiperOrigin-RevId: 519839723 | 27 March 2023, 22:06:38 UTC |
10dc941 | Sharad Vikram | 27 March 2023, 21:43:05 UTC | Add jaxlib version guard for rnn test PiperOrigin-RevId: 519833650 | 27 March 2023, 21:43:46 UTC |
6cc1bf5 | Peter Hawkins | 27 March 2023, 20:29:59 UTC | Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval. Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters. PiperOrigin-RevId: 519813664 | 27 March 2023, 20:30:47 UTC |
ae4f1fc | Yash Katariya | 27 March 2023, 19:29:35 UTC | Update the commit in workspace too PiperOrigin-RevId: 519797427 | 27 March 2023, 19:30:18 UTC |
e9cac5e | Yash Katariya | 27 March 2023, 18:44:19 UTC | Prepare for jax and jaxlib 0.4.7 release PiperOrigin-RevId: 519785176 | 27 March 2023, 18:45:22 UTC |
e21aee1 | Yash Katariya | 27 March 2023, 18:32:30 UTC | Add deprecation warning for FROM_GDA usage since that argument is not required anymore. PiperOrigin-RevId: 519781715 | 27 March 2023, 18:33:11 UTC |
3c3fa04 | Sharad Vikram | 27 March 2023, 17:59:14 UTC | Copy seq_lengths before creating descriptor PiperOrigin-RevId: 519771897 | 27 March 2023, 17:59:44 UTC |
88c2898 | Peter Hawkins | 27 March 2023, 17:14:05 UTC | Use pytype_strict_library() in Bazel build rules. PiperOrigin-RevId: 519757928 | 27 March 2023, 17:16:08 UTC |
392bd93 | Jake VanderPlas | 27 March 2023, 17:15:43 UTC | [sparse] fix coo efficiency warning | 27 March 2023, 17:15:43 UTC |
ed9fa13 | Jake VanderPlas | 27 March 2023, 17:01:28 UTC | jax.typing: recommend instance check in Python 3.10 or newer | 27 March 2023, 17:01:28 UTC |
40fb646 | Peter Hawkins | 27 March 2023, 16:50:56 UTC | Fix duplicate definition of 'cuda' extra in setup.py. PiperOrigin-RevId: 519750659 | 27 March 2023, 16:52:37 UTC |
af4d494 | jax authors | 27 March 2023, 16:52:20 UTC | Merge pull request #8633 from shawwn:2021-11-19/autodidax-fix-jaxpr-subcomp-return-type PiperOrigin-RevId: 519745476 | 27 March 2023, 16:52:20 UTC |
10d51c7 | jax authors | 27 March 2023, 16:37:54 UTC | Merge pull request #15218 from hawkinsp:mypy PiperOrigin-RevId: 519745465 | 27 March 2023, 16:37:54 UTC |
41695cc | Yash Katariya | 27 March 2023, 16:30:42 UTC | Temporarily fix the compilation cache test which is failing on latest jaxlib release PiperOrigin-RevId: 519745099 | 27 March 2023, 16:37:37 UTC |
2c4be6f | jax authors | 27 March 2023, 16:30:24 UTC | Merge pull request #15226 from canyon289:patch-1 PiperOrigin-RevId: 519743393 | 27 March 2023, 16:30:24 UTC |
d19e60e | jax authors | 27 March 2023, 16:30:07 UTC | Merge pull request #15228 from canyon289:patch-2 PiperOrigin-RevId: 519742908 | 27 March 2023, 16:30:07 UTC |
cf8c2b8 | Yash Katariya | 27 March 2023, 16:22:15 UTC | Delete benchmark and pmap_benchmark files as they are legacy and replaced with api_benchmark.py PiperOrigin-RevId: 519742866 | 27 March 2023, 16:22:57 UTC |
d473e86 | jax authors | 27 March 2023, 16:14:24 UTC | Merge pull request #13008 from hawkinsp:pipcuda PiperOrigin-RevId: 519740461 | 27 March 2023, 16:14:24 UTC |
6715736 | jax authors | 27 March 2023, 13:33:31 UTC | Merge pull request #15205 from yhtang:editable-jaxlib-build PiperOrigin-RevId: 519704474 | 27 March 2023, 13:33:31 UTC |
f3613a1 | jax authors | 27 March 2023, 12:21:54 UTC | Merge pull request #15234 from gnecula:get_dim_size PiperOrigin-RevId: 519691037 | 27 March 2023, 12:21:54 UTC |
befb449 | George Necula | 27 March 2023, 11:12:10 UTC | [shape_poly] Fixed bug with dimension variables in unused args JAX will aggressively drop module input arguments if they are not used. This can interfere with shape polymorphism, because it may result in dropping arguments from which we need to derive the values of shape variables. We fix this for now by disabling argument dropping if there are dimension variables in the argument shapes. A more precise technique would be to force keeping only the arguments that we need for deriving the dimension variables. However, that would be a much more involved change, for an uncertain benefit. | 27 March 2023, 11:37:39 UTC |
99facba | George Necula | 27 March 2023, 08:24:20 UTC | [jax2tf] Turn an error into a warning with native serialization We want to allow using native_serialization_platforms even if native_serialization is False. This is useful for code that is runnable with and without native serialization. PiperOrigin-RevId: 519649827 | 27 March 2023, 08:24:56 UTC |
08a8a5e | Ravin Kumar | 27 March 2023, 02:21:52 UTC | Fix hessian link | 27 March 2023, 02:21:52 UTC |
8c25495 | Ravin Kumar | 27 March 2023, 00:21:35 UTC | Update user_guides.rst Fix minor typo | 27 March 2023, 00:21:35 UTC |
b62f114 | Peter Hawkins | 27 October 2022, 14:53:19 UTC | Add support for using pip-installed CUDA wheels. Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels. Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency. | 26 March 2023, 12:35:00 UTC |
ec427f2 | Peter Hawkins | 25 March 2023, 18:40:32 UTC | Split dtype argument from other arguments in special functions. This helps pytype to determine that the arguments are of different kinds, preventing type errors. PiperOrigin-RevId: 519401250 | 25 March 2023, 18:41:14 UTC |
05319b5 | Peter Hawkins | 25 March 2023, 13:45:55 UTC | Suppress mypy warnings about missing imports. | 25 March 2023, 13:45:55 UTC |
a5d3085 | Yash Katariya | 25 March 2023, 04:09:45 UTC | Add `src` argument to device_put as an experimental arg PiperOrigin-RevId: 519308082 | 25 March 2023, 04:10:26 UTC |
5da4a5a | Yash Katariya | 25 March 2023, 01:03:29 UTC | Add SDA deprecation warning to pytest.ini PiperOrigin-RevId: 519281775 | 25 March 2023, 01:04:07 UTC |
c572155 | jax authors | 25 March 2023, 00:27:40 UTC | Merge pull request #15212 from google:pjrt_c_api_tests PiperOrigin-RevId: 519276265 | 25 March 2023, 00:27:40 UTC |
ef5e4a4 | Skye Wanderman-Milne | 24 March 2023, 20:55:04 UTC | Remove 'pjrt_c_api_unimplemented' pytest mark. Instead, we skip tests that the PJRT C API doesn't support. We had this tag for feature development so it was easy to broadly disable, but now we don't expect to need to do that. | 24 March 2023, 23:14:54 UTC |
6842e98 | Anish Tondwalkar | 24 March 2023, 21:52:45 UTC | Migrate regularized_incomplete_beta_p off xla_fallback PiperOrigin-RevId: 519244597 | 24 March 2023, 21:53:20 UTC |
ac44d2c | Anish Tondwalkar | 24 March 2023, 21:39:05 UTC | Migrate besseli0e off xla_fallback PiperOrigin-RevId: 519241252 | 24 March 2023, 21:39:40 UTC |
caaa0a2 | Yu-Hang 'Maxin' Tang | 24 March 2023, 21:25:26 UTC | add build option to create editable jaxlib Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn> Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com> | 24 March 2023, 21:25:26 UTC |
257ac6a | Yash Katariya | 24 March 2023, 20:21:20 UTC | If each host has the full value of the Array, allow fetching it to host. Fixes #15162 Benchmarks: ``` name old cpu/op new cpu/op delta np_asarray_8_devices 3.71ms ± 6% 3.32ms ± 7% -10.48% (p=0.008 n=5+5) name old time/op new time/op delta np_asarray_8_devices 3.86ms ± 6% 3.49ms ± 7% -9.72% (p=0.008 n=5+5) ``` PiperOrigin-RevId: 519222320 | 24 March 2023, 20:21:57 UTC |
6ed66ad | Peter Hawkins | 24 March 2023, 19:32:53 UTC | Delete remote TPU support. TPU VMs are the only supported way to use TPUs as of JAX 0.4.0. PiperOrigin-RevId: 519211267 | 24 March 2023, 19:33:33 UTC |
fad4e6f | John QiangZhang | 24 March 2023, 18:26:44 UTC | [1/n] store embedded tf.graph to stablehlo.custom_call PiperOrigin-RevId: 519194911 | 24 March 2023, 18:27:24 UTC |
21541e6 | Parker Schuh | 24 March 2023, 18:14:59 UTC | Guard ArrayImpl checks by xla_extension_version. PiperOrigin-RevId: 519191714 | 24 March 2023, 18:15:36 UTC |
bc231ee | Yash Katariya | 24 March 2023, 16:59:55 UTC | After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host. PiperOrigin-RevId: 519170785 | 24 March 2023, 17:00:41 UTC |
61a5686 | jax authors | 24 March 2023, 16:52:49 UTC | Merge pull request #15175 from nouiz:nightly_ci_keep_alive PiperOrigin-RevId: 519167122 | 24 March 2023, 16:52:49 UTC |
72bb4fe | jax authors | 24 March 2023, 16:45:21 UTC | Merge pull request #15203 from nouiz:nightly_ci PiperOrigin-RevId: 519167117 | 24 March 2023, 16:45:21 UTC |
8c75e27 | Anish Tondwalkar | 24 March 2023, 15:48:55 UTC | Migrate random_gamma_grad off xla_fallback PiperOrigin-RevId: 519154537 | 24 March 2023, 15:49:40 UTC |
8d1d522 | Anish Tondwalkar | 24 March 2023, 15:20:46 UTC | Migrate igamma_grad_a_p off xla_fallback PiperOrigin-RevId: 519148548 | 24 March 2023, 15:21:22 UTC |
229a4cf | Frederic Bastien | 24 March 2023, 14:59:11 UTC | remove another dependency not currently needed. | 24 March 2023, 15:04:27 UTC |
4a9b094 | Anish Tondwalkar | 24 March 2023, 12:57:59 UTC | Migrate igammac_p off xla_fallback path It is now decomposed into stablehlo ops. PiperOrigin-RevId: 519122775 | 24 March 2023, 12:58:38 UTC |
8081031 | Anish Tondwalkar | 24 March 2023, 12:42:01 UTC | [jaxlib] fix build w/ dependency on stablehlo_serialization PiperOrigin-RevId: 519120624 | 24 March 2023, 12:42:38 UTC |
32e7128 | jax authors | 24 March 2023, 03:48:33 UTC | Merge pull request #15192 from mattjj:issue15190 PiperOrigin-RevId: 519037959 | 24 March 2023, 03:48:33 UTC |
1982a11 | jax authors | 24 March 2023, 03:38:01 UTC | Merge pull request #15187 from mattjj:djax-revival PiperOrigin-RevId: 519036576 | 24 March 2023, 03:38:01 UTC |
7743fcd | Matthew Johnson | 23 March 2023, 03:54:45 UTC | [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit | 24 March 2023, 03:20:01 UTC |
793387e | Matthew Johnson | 24 March 2023, 03:16:23 UTC | fix jax.Array.round() fixes #15190 | 24 March 2023, 03:16:23 UTC |
4cb3b01 | Skye Wanderman-Milne | 24 March 2023, 01:38:26 UTC | Remove PJRT C API bypass. Now that all functionality needed by frameworks is implemented, let's remove the possibility of not noticing missing functionality due to the bypass. PiperOrigin-RevId: 519018438 | 24 March 2023, 01:39:14 UTC |
e36bd57 | jax authors | 24 March 2023, 01:24:01 UTC | Merge pull request #14309 from hawkinsp:numpy PiperOrigin-RevId: 519015703 | 24 March 2023, 01:24:01 UTC |
b7375b3 | Peter Hawkins | 06 February 2023, 16:32:28 UTC | Increase minimum NumPy version to 1.21. Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21. | 24 March 2023, 01:15:10 UTC |
e9bc7ee | jax authors | 24 March 2023, 00:14:50 UTC | Merge pull request #15184 from jakevdp:move-median PiperOrigin-RevId: 519003606 | 24 March 2023, 00:14:50 UTC |
6f8885a | Jake VanderPlas | 23 March 2023, 23:39:20 UTC | lax_numpy: move quantile-based functions to reductions.py | 23 March 2023, 23:39:20 UTC |
f981243 | Anish Tondwalkar | 23 March 2023, 23:26:26 UTC | Migrate igamma_p off xla_fallback We decompose it into a series or a call to igammac. PiperOrigin-RevId: 518993077 | 23 March 2023, 23:26:59 UTC |
d777cf2 | George Necula | 23 March 2023, 23:04:23 UTC | [jax2tf] A simple failing test on TPU with native serialization PiperOrigin-RevId: 518987577 | 23 March 2023, 23:04:53 UTC |
171b22d | John QiangZhang | 23 March 2023, 22:49:44 UTC | Add padding option "SAME_LOWER" for ticket https://github.com/google/jax/pull/14990 PiperOrigin-RevId: 518984018 | 23 March 2023, 22:50:16 UTC |
7df5245 | George Necula | 23 March 2023, 22:42:05 UTC | [jax2tf] Create a jax_export library with JAX-only pieces for native serialization This is a pure refactor, no functionality should change. PiperOrigin-RevId: 518982222 | 23 March 2023, 22:42:45 UTC |
6407fa6 | jax authors | 23 March 2023, 22:28:16 UTC | Merge pull request #15178 from mattjj:improve-scan-errors PiperOrigin-RevId: 518977054 | 23 March 2023, 22:28:16 UTC |
adbdaa4 | Anish Tondwalkar | 23 March 2023, 22:20:37 UTC | Refactor special functions into their own module. We're going to want to decompose these using series and continued fraction representations, and for that we'll need control flow PiperOrigin-RevId: 518977008 | 23 March 2023, 22:21:15 UTC |
1f7c305 | jax authors | 23 March 2023, 22:00:15 UTC | Merge pull request #15172 from jakevdp:jax-array-refactor PiperOrigin-RevId: 518971494 | 23 March 2023, 22:00:15 UTC |
ba2ff51 | Matthew Johnson | 23 March 2023, 21:39:40 UTC | improve scan error messages | 23 March 2023, 21:53:05 UTC |
1136d0f | George Necula | 23 March 2023, 21:51:20 UTC | [jax2tf] Minor addition to the documentation PiperOrigin-RevId: 518969936 | 23 March 2023, 21:52:01 UTC |