https://github.com/google/jax

Revision Message Commit Date
f0c0689 Remove internal information 11 November 2022, 03:09:19 UTC
4a3b7f1 Change pickling for jax.sharding to not serialize device ids. PiperOrigin-RevId: 487700467 11 November 2022, 03:05:02 UTC
f9d7a6a Merge pull request #13197 from google:yashk2810-patch-17 PiperOrigin-RevId: 487687006 11 November 2022, 01:56:29 UTC
73935a5 Update jax_array_migration.md 11 November 2022, 01:23:16 UTC
c318f77 Merge pull request #13185 from tlu7:bcsr-from-scipy PiperOrigin-RevId: 487680265 11 November 2022, 01:18:44 UTC
aa66b93 Fix the docs build 11 November 2022, 01:08:57 UTC
b49a1bd Add jax.Array migration doc to OSS PiperOrigin-RevId: 487673643 11 November 2022, 00:46:30 UTC
311fb24 [sparse] Add BCSR from_scipy_sparse. Co-authored-by: Jake Vanderplas <vanderplas@google.com> 11 November 2022, 00:44:59 UTC
352b042 Add a GPU implementation of symmetric (Hermitian) tridiagonal reduction. Change the contract of lax.linalg.tridiagonal to return the d and e vectors as well. Since we only just added this function and have never released JAX with it we can make this change without breaking compatibility. Also fix wrong dtypes for d and e values in the CPU lapack sytrd wrapper. PiperOrigin-RevId: 487621469 10 November 2022, 21:16:21 UTC
74b136e Delete `jax_experimental_name_stack` flag PiperOrigin-RevId: 487601864 10 November 2022, 19:59:50 UTC
0ebb6b4 Merge pull request #13180 from jakevdp:bcoo-slice PiperOrigin-RevId: 487568853 10 November 2022, 18:04:35 UTC
cc41ee8 Mark scipy_signal_test and sparse_test `optonly` because they time out under debug mode. PiperOrigin-RevId: 487533356 10 November 2022, 15:38:58 UTC
71360ed Bump the shard count for TPU to avoid timeouts PiperOrigin-RevId: 487421018 10 November 2022, 04:32:12 UTC
e42e52d Rename test flag --num_generated_cases to --jax_num_generated_cases. parse_flags_with_absl() only parses flags that start with --jax_. Other flags are only parsed when absl.app's main function runs. But that's too late for test cases: test cases need to have the number of generated cases chosen at module initialization time. Hence the --num_generated_cases flag wasn't doing anything. Oops. By renaming it it works once again. It might make sense to stop using flags for the number of generated cases and only use environment variables. We defer that to a future change. Fix many test cases that were shown to be broken with a larger number of test cases enabled. PiperOrigin-RevId: 487406670 10 November 2022, 02:58:05 UTC
b36afc5 Merge pull request #13177 from jakevdp:bcoo-dynamic-slice PiperOrigin-RevId: 487390430 10 November 2022, 01:29:30 UTC
3731e44 Set default layout for Python callback PiperOrigin-RevId: 487388682 10 November 2022, 01:18:49 UTC
f9bbd58 Improve the error message when `@pjit` (with no {in_axis|out_axis}_resources) is used without jax.Array enabled. PiperOrigin-RevId: 487380328 10 November 2022, 00:38:00 UTC
4c4f2a3 [sparse] support strides in bcoo_slice 09 November 2022, 23:03:21 UTC
b41c594 Internal change PiperOrigin-RevId: 487351591 09 November 2022, 22:41:39 UTC
46d9cac [sparse] bcoo_dynamic_slice: remove unnecessary padding from output 09 November 2022, 21:56:18 UTC
fa0217b Merge pull request #13175 from jakevdp:bcoo-transpose PiperOrigin-RevId: 487333413 09 November 2022, 21:34:55 UTC
0c3e330 [sparse] fix shape bug in bcoo_transpose 09 November 2022, 20:53:13 UTC
8ac7422 [JAX] Disables large k test cases in ann_test. Will investigate probability properties for the corner cases in the future. PiperOrigin-RevId: 487302143 09 November 2022, 19:32:47 UTC
63e3152 Merge pull request #13160 from jakevdp:bcoo-squeeze PiperOrigin-RevId: 487280563 09 November 2022, 18:18:22 UTC
f697b8e Merge pull request #13166 from LenaMartens:checking-keys PiperOrigin-RevId: 487267267 09 November 2022, 17:31:39 UTC
053b8b5 Checkify: fix nan_checks+PRNGKeys - a PRNGKey is never NaN! Add a guard to the nan_error_rule to not call jnp.isnan on keys. 09 November 2022, 17:08:21 UTC
1cead77 Add support for Hessenberg and tridiagonal matrix reductions on CPU. * Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg. * Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction. * Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction. None of these primitives are differentiable at the moment. PiperOrigin-RevId: 487224934 09 November 2022, 14:23:55 UTC
30637d0 Merge pull request #13168 from hawkinsp:fixbuild PiperOrigin-RevId: 487219149 09 November 2022, 13:50:34 UTC
41c90a8 Add missing stablehlo dialect files to jaxlib build. Unbreaks the build. 09 November 2022, 13:37:49 UTC
3b1ddf2 [linalg] Add jax.scipy.special.bessel_jn (Bessel function of the first kind). PiperOrigin-RevId: 487146250 09 November 2022, 07:03:21 UTC
5599632 Introduce XlaLowering::stablehlo() and use it in associated APIs See tests/api_test.py for usage examples. At the moment, stablehlo() works by using the hlo-legalize-to-stablehlo pass, which takes MHLO natively produced by JAX and converts it into StableHLO. This is an intermediate step towards switching JAX to natively produce StableHLO. This CL adds both mhlo_to_stablehlo and stablehlo_to_mhlo to jaxlib, even though only the former is used at the moment. This is done in anticipation of switching JAX to natively produce StableHLO, where stablehlo_to_mhlo will be needed to provide backward compatibility for XlaLowering::mhlo(). We're adding stablehlo_to_mhlo now, so that in the future we don't have to update jaxlib again which will make deployment easier. PiperOrigin-RevId: 487144342 09 November 2022, 06:50:06 UTC
df963bd Remove flaky Array defragmentation test check PiperOrigin-RevId: 487120630 09 November 2022, 04:06:36 UTC
0cf220f Merge pull request #13162 from jakevdp:bcoo-reshape PiperOrigin-RevId: 487106270 09 November 2022, 02:37:45 UTC
4255697 [sparse] add bcoo_squeeze function 09 November 2022, 02:16:20 UTC
53344b8 Don't create copies by device_putting a host local jax.Array if the sharding matches with the input. PiperOrigin-RevId: 487090094 09 November 2022, 01:02:23 UTC
7d3b1d6 [sparse] fix bcoo_reshape under jit 09 November 2022, 01:00:25 UTC
0d2cd6d [jax] Fix manual defragment method to work with Arrays PiperOrigin-RevId: 487068409 08 November 2022, 23:32:30 UTC
5e1d7cd Merge pull request #13032 from jakevdp:sharding-attr PiperOrigin-RevId: 487061046 08 November 2022, 23:01:23 UTC
8fbf8da Declare Array.sharding & raise an error on tracers 08 November 2022, 22:20:46 UTC
af017d4 Merge pull request #13153 from jakevdp:bcoo-reshape PiperOrigin-RevId: 487046508 08 November 2022, 22:11:51 UTC
768076e Merge pull request #13157 from jakevdp:bcoo-astype PiperOrigin-RevId: 487046458 08 November 2022, 22:05:09 UTC
7c0d0e6 [sparse] add support for BCOO.astype method 08 November 2022, 21:30:22 UTC
af95663 [sparse] fix bcoo_reshape when n_sparse=0 08 November 2022, 20:00:24 UTC
96f6c1c Let is_user_frame ignore frames from stdlib. When using decorators, we found contextlib.py from stdlib sometimes becomes the most recent non-jax frame. But it's not a user frame. PiperOrigin-RevId: 486993924 08 November 2022, 18:50:08 UTC
500cd85 Merge pull request #13144 from LenaMartens:donate-no-more PiperOrigin-RevId: 486979733 08 November 2022, 17:57:44 UTC
3994ac3 Merge pull request #13145 from hawkinsp:pinv PiperOrigin-RevId: 486935918 08 November 2022, 14:39:54 UTC
ab8cde9 Add support for the hermitian option on jnp.linalg.pinv. Improve the pinv implementation to avoid computing an unnecessary reduction: svd sorts its singular values so we don't need to use amax() to find the largest one. Avoid explicitly forming the identity matrix in the pinv JVP. 08 November 2022, 13:53:00 UTC
e80c34d Don't donate arguments in jit/pmap/pjit when debug_nans=True. 08 November 2022, 13:33:59 UTC
1e7e8e8 Merge pull request #13147 from hawkinsp:eyes PiperOrigin-RevId: 486826532 08 November 2022, 03:25:15 UTC
85f43dd Merge pull request #13061 from nouiz:test_doc PiperOrigin-RevId: 486816419 08 November 2022, 02:23:41 UTC
eb9e8c2 Merge pull request #13117 from 8bitmp3:move-multihost-multiprocess-toc PiperOrigin-RevId: 486780726 07 November 2022, 23:31:10 UTC
e00f7e7 Merge pull request #13093 from PhilipVinc:patch-1 PiperOrigin-RevId: 486753425 07 November 2022, 21:49:52 UTC
2c1fe45 Add UnloadedMeshExecutable to represent a MeshExecutable that is not loaded on any physical devices for the purposes of serialization. This type is easier to serialize because it has not yet been converted into arg-handlers. Potential API: ``` str, in_tree, out_tree = lowered.compile_and_serialize() exec = jax.experimental.load_serialized(str, in_tree, out_tree, backend) exec // identical to lowered.compile(). ``` PiperOrigin-RevId: 486751141 07 November 2022, 21:40:40 UTC
793fb9b Fix issue in `check_tree`, so that custom_linear_solve supports has_aux=True when the vector and the aux are both pytrees. 07 November 2022, 20:29:43 UTC
2183059 Fix GDA error message formatting PiperOrigin-RevId: 486724647 07 November 2022, 19:55:28 UTC
845f8df Avoid forming identity matrix in SVD JVP. Set the default matmul precision in the SVD JVP, and use @ to express matmuls. Also fix a flaky test failure in QR test on Mac ARM. 07 November 2022, 18:55:45 UTC
b127b70 Remove static_argnums from AOT invocation. Static args are not needed when invoking an AOT computation. PiperOrigin-RevId: 486698420 07 November 2022, 18:21:57 UTC
da519f3 Check for `ArrayImpl` rather than `sharding` because this code is supposed to check for concrete Array until a `shard_like` primitive exists. PiperOrigin-RevId: 486689809 07 November 2022, 17:52:41 UTC
587885b Merge pull request #12077 from hawkinsp:docs PiperOrigin-RevId: 486681964 07 November 2022, 17:28:57 UTC
3da554d Merge pull request #13146 from jakevdp:fix-flake8 PiperOrigin-RevId: 486680771 07 November 2022, 17:17:35 UTC
b0e03fb Remove whitespace to fix flake8 07 November 2022, 17:10:05 UTC
cd84eb1 Add a number of missing function cross-references in the docs. 07 November 2022, 17:00:26 UTC
042595b Merge pull request #12890 from ROCmSoftwarePlatform:rocm_enable_multi_gpu_test PiperOrigin-RevId: 486675297 07 November 2022, 16:56:24 UTC
e9e014f Merge pull request #13140 from gnecula:clean_limitations PiperOrigin-RevId: 486610612 07 November 2022, 10:57:50 UTC
d9b1dc3 [jax2tf] Fixed jax2tf Limitations. Improved the documentation, and fixed a dot_general limitation for preferred_element_type on GPU. 07 November 2022, 09:26:39 UTC
8d59b0d Merge pull request #13108 from mattjj:djax-vmap2 PiperOrigin-RevId: 486534673 07 November 2022, 01:27:53 UTC
f2f2faa add a basic prototype of piles, behind jax_dynamic_shapes Co-authored-by: Adam Paszke <apaszke@google.com> Co-authored-by: Dougal Maclaurin <dougalm@google.com> 07 November 2022, 01:03:04 UTC
2a262e9 If the input to `host_local_array_to_global_array` is not fully addressable (i.e. not host local), return it as is. Also if the input to `global_array_to_host_local_array` is fully addressable (i.e. host local), return it as is. PiperOrigin-RevId: 486419066 06 November 2022, 03:16:14 UTC
3dbe101 Remove `device_indices` method which is redundant because of the existence of `devices_indices_map` and is slower because pinging a cache for every device is not free. PiperOrigin-RevId: 486405037 06 November 2022, 00:33:37 UTC
5b45357 Make Sharding instances picklable. PiperOrigin-RevId: 486386106 05 November 2022, 20:36:52 UTC
d9f3dc3 Improve the sharding mismatch error message by adding the arg to the message too. PiperOrigin-RevId: 486310996 05 November 2022, 07:15:15 UTC
0c36f84 Merge pull request #13127 from mattjj:issue13124 PiperOrigin-RevId: 486296679 05 November 2022, 04:44:17 UTC
e161d20 Improve the error message when the avals a function was AOT compiled with don't match the input avals when it's called. PiperOrigin-RevId: 486294881 05 November 2022, 04:25:46 UTC
190204f fix jax.random.logits shape argument fixes #13124 05 November 2022, 02:51:39 UTC
2932c1e Merge pull request #13122 from yejingxin:main PiperOrigin-RevId: 486253810 04 November 2022, 23:15:14 UTC
75b5ccc Merge pull request #13024 from treyra:patch-2 PiperOrigin-RevId: 486244906 04 November 2022, 22:31:34 UTC
e6c88f2 update pytest.ini to print warning message for compilation_cache_test 04 November 2022, 21:43:51 UTC
2ce7eb5 Merge pull request #13113 from jakevdp:mypy-fix PiperOrigin-RevId: 486205511 04 November 2022, 19:35:26 UTC
f056647 Add a section to mention the `unsafe_rbg` approach for reducing the size of the generated TF graph. PiperOrigin-RevId: 486203922 04 November 2022, 19:28:22 UTC
636cac8 Increase visibility in index Multi-Host Multi-Process Envs Guide 04 November 2022, 19:22:39 UTC
e3cd967 Add x86_64 dependency note to pip installation Currently non x86_64 linux architectures are not supported, see #7097 for request to change this. This can lead to installation confusion, as jax will install, but jaxlib will not. For example see #12307. This adds a note to the install sections for the relevant pip wheels. 04 November 2022, 19:11:11 UTC
1799cf9 Merge pull request #13116 from jakevdp:fix-doc-requirements PiperOrigin-RevId: 486190913 04 November 2022, 18:36:18 UTC
acdb545 CI: add absl-py to docs/requirements.txt 04 November 2022, 18:28:15 UTC
0691be6 [typing] update jaxlib & remove unnecessary ignore 04 November 2022, 18:02:33 UTC
2f1f3e4 Merge pull request #13114 from skye:runner_readme PiperOrigin-RevId: 486171372 04 November 2022, 17:19:52 UTC
52775c4 Add .github/workflows/self_hosted_runner_utils/README.md This was meant to be part of https://github.com/google/jax/pull/13000, oops 04 November 2022, 17:12:54 UTC
1d48c93 Finish the release of jax and jaxlib 0.3.24 PiperOrigin-RevId: 486162090 04 November 2022, 16:43:12 UTC
3db2a59 Merge pull request #13097 from jakevdp:actions-permissions PiperOrigin-RevId: 486160888 04 November 2022, 16:36:32 UTC
46368e4 [sparse] Update the guard of cusparse SpMM and SpMv algorithms to cusparse version 11.7.1 onwards. PiperOrigin-RevId: 486051658 04 November 2022, 04:39:52 UTC
974134f Merge pull request #13103 from mattjj:issue13099 PiperOrigin-RevId: 485993800 03 November 2022, 22:47:32 UTC
4033007 improve error when f_vjp gets more than one argument fixes #13099 03 November 2022, 22:20:10 UTC
2e384ce Prepare for release of jax and jaxlib 0.3.24 PiperOrigin-RevId: 485985460 03 November 2022, 22:13:23 UTC
2b4735f Merge pull request #13094 from mattjj:fix-random-docs-table PiperOrigin-RevId: 485979035 03 November 2022, 21:48:10 UTC
dba9fc0 Merge pull request #13000 from skye:self_hosted_tpu PiperOrigin-RevId: 485973703 03 November 2022, 21:27:38 UTC
8c22e34 Add GitHub Actions workflow that runs on a self-hosted TPU VM runner. This also includes some utilities for setting up the self-hosted runner. Googlers, see go/jax-self-hosted-runners for more setup info. The workflow is pretty basic currently. We can and should add more functionality later, such as email notifications. I kept it simple here for easier reviewing. Testing: - Sample workflow run in my fork: https://github.com/skye/jax/actions/runs/3333614180 - Sample PR attempt: (will add soon but I did verify validate_job.sh blocks pull_request workflows) 03 November 2022, 21:15:57 UTC
8057e28 CI: set explicit permissions for ci-build action 03 November 2022, 20:21:58 UTC
478bd3e fix comparison table in random docs: 1. rbg is not identical across cpu/gpu/tpu; 2. the unsafe_rbg column copied the jax.lax.rng_uniform column from the original table, but that wasn't right, as it should be identical to the rbg column; 3. for the last row mentioning identical across shardings, we should mention that's assuming the xla flag. Also removed some rows which are only interesting in comparing to `jax.lax.rng_uniform` (which is not safe with `scan` or `remat`). Co-authored-by: Roy Frostig <frostig@google.com> 03 November 2022, 20:10:54 UTC
f4be5ab Merge pull request #12219 from jakevdp:indexing-slice PiperOrigin-RevId: 485946084 03 November 2022, 19:44:28 UTC
532cd7e Skip the benchmarks properly via state.skip_with_error when not enough devices are present. PiperOrigin-RevId: 485931295 03 November 2022, 18:44:57 UTC
753562d Add benchmarks for repeated static indexing & slicing 03 November 2022, 18:41:37 UTC
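
Commit ab8cde9 in the list above adds a hermitian option to jnp.linalg.pinv. The following is a minimal sketch of how that option might be used, assuming a JAX version that includes that commit. The matrix values and the helper name hermitian_pinv_demo are illustrative assumptions and do not come from the commit itself.

```python
import jax.numpy as jnp


def hermitian_pinv_demo():
    """Compare the default and hermitian=True paths of jnp.linalg.pinv."""
    # Illustrative real symmetric (hence Hermitian) matrix; values are made up.
    a = jnp.array([[2.0, 1.0, 0.0],
                   [1.0, 3.0, 1.0],
                   [0.0, 1.0, 4.0]], dtype=jnp.float32)
    # Default path: makes no symmetry assumption about `a`.
    general = jnp.linalg.pinv(a)
    # hermitian=True lets pinv assume `a` equals its conjugate transpose.
    special = jnp.linalg.pinv(a, hermitian=True)
    return a, general, special


a, general, special = hermitian_pinv_demo()
# Largest elementwise difference between the two results.
print(float(jnp.max(jnp.abs(general - special))))
```

For a symmetric input like this one, both calls should agree up to floating-point error. The flag is best reserved for inputs known to be Hermitian, since the result is not expected to be meaningful otherwise.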