https://github.com/google/jax

Revision Author Date Message Commit Date
a4d4c39 Internal change. PiperOrigin-RevId: 482900887 25 October 2022, 03:05:44 UTC
70f659a Merge pull request #12957 from jakevdp:fix-lstsq PiperOrigin-RevId: 483493703 24 October 2022, 21:59:30 UTC
15b415b Merge pull request #12951 from jakevdp:annotate-lax-numpy PiperOrigin-RevId: 483490497 24 October 2022, 21:47:11 UTC
56d42c0 [typing] annotate next batch of lax_numpy 24 October 2022, 21:21:35 UTC
9ade89e jnp.linalg.lstsq: handle zero-size inputs 24 October 2022, 21:10:31 UTC
964988c Merge pull request #12953 from eltociear:patch-1 PiperOrigin-RevId: 483471278 24 October 2022, 20:34:16 UTC
28def73 Fix typo in 9419-jax-versioning.md overriden -> overridden 24 October 2022, 18:26:48 UTC
b892108 Merge pull request #12950 from jakevdp:fix-ci-error PiperOrigin-RevId: 483425722 24 October 2022, 17:48:43 UTC
8f2f9f4 Merge pull request #12646 from adrn:truncnorm PiperOrigin-RevId: 483425197 24 October 2022, 17:41:51 UTC
894093c Move jaxlib cpu kernels under jaxlib/cpu/. No functional changes intended. PiperOrigin-RevId: 483413031 24 October 2022, 17:02:56 UTC
48e680c CI: avoid raising error when wrapped function is None 24 October 2022, 15:57:53 UTC
67fa7c2 Typo fix. PiperOrigin-RevId: 483380789 24 October 2022, 14:53:50 UTC
5784d61 implement truncnorm in jax.scipy.stats
- fix some shape and type issues
- import into namespace
- imports into non-_src library
- working logpdf test
- cleanup
- working tests for cdf and sf after fixing select
- relax need for x to be in (a, b)
- ensure behavior with invalid input matches scipy
- remove enforcing valid parameters in tests
- added truncnorm to docs
- whoops alphabetical
- fix linter error
- fix circular import issue
22 October 2022, 19:48:20 UTC
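The truncated normal restricts a normal distribution to an interval and renormalizes it. Below is a minimal pure-Python sketch of its log-density, assuming SciPy's convention that `a` and `b` are standardized bounds. `truncnorm_logpdf` and its helpers are hypothetical names, not the `jax.scipy.stats.truncnorm` code, and points outside `[a, b]` get `-inf` rather than an error.

```python
import math


def _norm_logpdf(z: float) -> float:
    # Log-density of the standard normal distribution at z.
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def _norm_cdf(z: float) -> float:
    # Standard normal CDF expressed through the error function.
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def truncnorm_logpdf(x: float, a: float, b: float,
                     loc: float = 0.0, scale: float = 1.0) -> float:
    """Hypothetical sketch: log-density of a normal truncated to [a, b].

    As in SciPy, a and b are standardized, so the support is
    [loc + a * scale, loc + b * scale].
    """
    z = (x - loc) / scale
    if z < a or z > b:
        return -math.inf  # outside the support: zero density
    mass = _norm_cdf(b) - _norm_cdf(a)  # probability kept by the truncation
    return _norm_logpdf(z) - math.log(mass) - math.log(scale)


# Sanity check: the density should integrate to ~1 over [a, b] (midpoint rule).
a, b, n = -1.0, 2.0, 10_000
h = (b - a) / n
total = sum(math.exp(truncnorm_logpdf(a + (i + 0.5) * h, a, b)) * h
            for i in range(n))
print(total)  # very close to 1.0
```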
b07c586 [mhlo] Use 11 out of 12 new shared type inferences from StableHLO. The shape function of DotGeneralOp can't be integrated into MHLO yet: it only predicts the result shape and cannot predict the element type. However, the current Python binding infrastructure generates the constructor __init__() without the `return` as the first argument, which assumes the shape function can provide a fully inferred type (including an accurate element type). This leads to "inferred type does not match actual result type" errors in JAX and needs a future solution. This CL is the counterpart of https://github.com/openxla/stablehlo/pull/269.
Related Python __init__() interface changes (used by JAX):
- batch_norm_grad: not used by JAX
- batch_norm_inference: not used by JAX
- batch_norm_training: not used by JAX
- case: no change*
- dot_general: opened new b/253644255 to track the issue
- if: no change*
- map: no change*
- reduce: no change*
- reduce_window: no change*
- sort: no change*
- triangular_solve: updated in `linalg.py`
- while: no change*
* no change: the signature of __init__() for the op is not changed because the op has regions (see https://github.com/llvm/llvm-project/blob/main/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp#L577).
PiperOrigin-RevId: 482951512 22 October 2022, 03:34:04 UTC
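To make the shape-versus-element-type distinction concrete, here is a hypothetical sketch of a dot_general shape rule, assuming StableHLO's result layout (batch dimensions, then the lhs free dimensions, then the rhs free dimensions). `infer_dot_general_shape` is an illustrative name, not an MHLO or StableHLO API, and it deliberately says nothing about the result element type.

```python
from typing import List, Sequence, Tuple


def infer_dot_general_shape(
    lhs_shape: Sequence[int],
    rhs_shape: Sequence[int],
    lhs_batch: Sequence[int],
    rhs_batch: Sequence[int],
    lhs_contract: Sequence[int],
    rhs_contract: Sequence[int],
) -> Tuple[int, ...]:
    """Hypothetical sketch of dot_general result-shape inference.

    Only the shape is computed; the element type is not inferred.
    """
    if len(lhs_batch) != len(rhs_batch) or len(lhs_contract) != len(rhs_contract):
        raise ValueError("lhs/rhs dimension lists must pair up")
    # Paired batch and contracting dimensions must have equal sizes.
    pairs = list(zip(lhs_batch, rhs_batch)) + list(zip(lhs_contract, rhs_contract))
    for ld, rd in pairs:
        if lhs_shape[ld] != rhs_shape[rd]:
            raise ValueError(f"size mismatch: lhs dim {ld} vs rhs dim {rd}")
    batch: List[int] = [lhs_shape[d] for d in lhs_batch]
    # Free dimensions are those that are neither batched nor contracted.
    lhs_free = [n for d, n in enumerate(lhs_shape)
                if d not in lhs_batch and d not in lhs_contract]
    rhs_free = [n for d, n in enumerate(rhs_shape)
                if d not in rhs_batch and d not in rhs_contract]
    # Result layout: batch dims, then lhs free dims, then rhs free dims.
    return tuple(batch + lhs_free + rhs_free)


# Batched matmul: (8, 3, 4) x (8, 4, 5) -> (8, 3, 5)
print(infer_dot_general_shape((8, 3, 4), (8, 4, 5), [0], [0], [2], [1]))
```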
3be5ab2 Allow calling `initialize_cache` a second time if the path is the same. PiperOrigin-RevId: 482945880 22 October 2022, 02:54:09 UTC
9956ad2 Add more pjit tests and make some tests go via actual computations rather than trivial computation. PiperOrigin-RevId: 482919649 21 October 2022, 23:53:53 UTC
a4e3663 Merge pull request #12921 from jakevdp:lax-numpy-dtypes PiperOrigin-RevId: 482905407 21 October 2022, 22:42:40 UTC
64e996e Merge pull request #12925 from jakevdp:annotate-lax PiperOrigin-RevId: 482903592 21 October 2022, 22:36:22 UTC
4acc293 Merge pull request #12923 from jakevdp:nperseg-test PiperOrigin-RevId: 482902569 21 October 2022, 22:30:04 UTC
e219d55 Roll-back #12892 because CUSPARSE_SPMV_COO_ALG2 is not available in CUDA 11.1 PiperOrigin-RevId: 482897448 21 October 2022, 22:06:17 UTC
4e8fbd0 Add delete method to GlobalDeviceArray and ShardedBuffer. This ensures that all existing JAX buffer types have a `delete` method that can be used to free device buffer allocations eagerly. User code sometimes has lingering Python references due to cyclic dependencies and other reasons, yet users may know for sure that certain arrays will no longer be used after a certain point. Calling `foo_array.delete()` on a DeviceArray/ShardedDeviceArray/GlobalDeviceArray/Array allows users to force-free the device-side allocation to minimize device memory usage. PiperOrigin-RevId: 482892157 21 October 2022, 21:42:32 UTC
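Why an explicit `delete()` helps with lingering references can be shown with a hypothetical pure-Python stand-in: a `bytearray` plays the role of the device allocation, and a self-reference creates the kind of cycle that reference counting alone cannot free. `BufferSketch` is illustrative only; it is not a JAX class.

```python
import gc
import weakref
from typing import Optional


class BufferSketch:
    """Hypothetical stand-in for an array that owns a device allocation."""

    def __init__(self, nbytes: int) -> None:
        self._payload: Optional[bytearray] = bytearray(nbytes)  # "device memory"
        self.owner: Optional["BufferSketch"] = None  # used to form a cycle

    def delete(self) -> None:
        # Release the allocation now, even if Python references remain.
        self._payload = None

    def is_deleted(self) -> bool:
        return self._payload is None


buf = BufferSketch(1 << 20)
buf.owner = buf  # reference cycle: refcounting alone won't free buf
ref = weakref.ref(buf)

buf.delete()  # the 1 MiB payload is released immediately
print(ref() is buf, buf.is_deleted())  # object still alive, memory freed

del buf
gc.collect()  # only the cycle collector reclaims the object itself
print(ref() is None)
```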
ca7d05f [typing] fix incorrect type annotation on lax.argmax/argmin 21 October 2022, 21:37:59 UTC
37a0156 First pass at adding `GetOutputShardings` and `GetParameterShardings` on PjRTExecutable. PiperOrigin-RevId: 482878289 21 October 2022, 20:43:52 UTC
e297b13 Merge pull request #12899 from jakevdp:dead-code PiperOrigin-RevId: 482867884 21 October 2022, 20:02:47 UTC
4714a5c Add regression test for #12920 21 October 2022, 19:52:32 UTC
045d6e1 Merge pull request #12920 from yilei:nameerror PiperOrigin-RevId: 482864751 21 October 2022, 19:49:26 UTC
97b17af [typing] add type annotations to the first several lax_numpy functions 21 October 2022, 18:59:53 UTC
d63d038 Flatten the if/else block. 21 October 2022, 18:27:45 UTC
7c1bf0e Fix a NameError caused by #12754 21 October 2022, 17:22:36 UTC
503679d Merge pull request #12904 from froystig:dlpack64 PiperOrigin-RevId: 482798461 21 October 2022, 15:23:18 UTC
dd04953 [MLIR] Don't rely on hardcoded -1 for dynamic axis sizes. The magic number might change; use an accessor to get it. PiperOrigin-RevId: 482796475 21 October 2022, 15:13:02 UTC
2228efe Merge pull request #12876 from mattjj:djax-vmap3 PiperOrigin-RevId: 482693137 21 October 2022, 05:39:51 UTC
8e8ae84 fix 21 October 2022, 05:23:29 UTC
f76fc01 put back some unsafe_maps 21 October 2022, 04:56:00 UTC
8030b49 Merge pull request #12912 from mattjj:issue12909 PiperOrigin-RevId: 482685295 21 October 2022, 04:44:09 UTC
6a3d2a0 update docs to point to jax.nn.standardize Fixes #12909 21 October 2022, 04:25:37 UTC
2d45d8b Merge pull request #12800 from mattjj:ayaka PiperOrigin-RevId: 482676604 21 October 2022, 03:40:08 UTC
60b236c improve (and shorten!) pmap error messages about inconsistent axis sizes 21 October 2022, 01:31:40 UTC
1b5294e Merge pull request #12906 from jakevdp:stats-mode-doc PiperOrigin-RevId: 482632489 20 October 2022, 23:37:56 UTC
408953b fix jax2tf readme typo PiperOrigin-RevId: 482625385 20 October 2022, 23:09:00 UTC
6e4b135 Merge pull request #12896 from jakevdp:unused-imports PiperOrigin-RevId: 482623200 20 October 2022, 23:01:02 UTC
4aceb81 Add docs & changelog for jax.scipy.stats.mode 20 October 2022, 22:55:57 UTC
4af3509 canonicalize dtypes when loading arrays via `dlpack` 20 October 2022, 22:20:34 UTC
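With 64-bit mode disabled (the default), JAX stores 64-bit dtypes as their 32-bit counterparts. Below is a sketch of that mapping on dtype names, assuming DLPack imports follow the same table; `canonicalize_dtype_name` is hypothetical and is not `jax._src.dtypes.canonicalize_dtype`.

```python
# Hypothetical sketch of dtype canonicalization when x64 mode is disabled.
_X64_TO_X32 = {
    "int64": "int32",
    "uint64": "uint32",
    "float64": "float32",
    "complex128": "complex64",
}


def canonicalize_dtype_name(name: str, enable_x64: bool = False) -> str:
    """Return the dtype an imported array would actually be stored with."""
    if enable_x64:
        return name  # 64-bit types are kept as-is
    return _X64_TO_X32.get(name, name)  # other dtypes pass through unchanged


# A float64 tensor arriving via DLPack would be stored as float32 by default.
print(canonicalize_dtype_name("float64"))        # float32
print(canonicalize_dtype_name("float64", True))  # float64
print(canonicalize_dtype_name("bfloat16"))       # bfloat16
```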
7f89fd4 Cleanup: remove unused imports in private modules Also improve our flake8 filter rules to avoid ignoring these. 20 October 2022, 21:37:21 UTC
2801533 Merge pull request #12872 from jakevdp:annotate-map PiperOrigin-RevId: 482595321 20 October 2022, 21:22:36 UTC
f7f2351 Merge pull request #12900 from jakevdp:dlpack-test PiperOrigin-RevId: 482595065 20 October 2022, 21:15:43 UTC
dcd9143 Test: skip JaxToNumpy DLPack test on numpy 1.22; np.from_dlpack was not added until numpy 1.23. 20 October 2022, 21:03:22 UTC
d6d5339 Remove unused code from jax/_src/lax/linalg.py 20 October 2022, 20:21:40 UTC
4facacb Merge pull request #12894 from jakevdp:annotate-ufuncs PiperOrigin-RevId: 482578230 20 October 2022, 20:18:11 UTC
7093142 [sparse] Update the default cuSparse matvec algorithm in jaxlib. PiperOrigin-RevId: 482553550 20 October 2022, 18:49:09 UTC
6d30865 [typing] annotate jax.numpy ufuncs 20 October 2022, 18:22:04 UTC
2d563bf [sparse] Add BCSR primitive bcsr_extract. PiperOrigin-RevId: 482530210 20 October 2022, 17:30:33 UTC
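An extract primitive reads dense values at the positions named by a sparse index structure. A minimal sketch for plain, unbatched CSR with scalar blocks; the batch and block dimensions that make BCSR general are omitted, and `csr_extract` is a hypothetical helper rather than the `bcsr_extract` signature.

```python
from typing import List, Sequence


def csr_extract(indices: Sequence[int], indptr: Sequence[int],
                dense: Sequence[Sequence[float]]) -> List[float]:
    """Gather dense[row][col] for every stored position of a CSR pattern."""
    data: List[float] = []
    for row in range(len(indptr) - 1):
        # Column indices of `row` live in indices[indptr[row]:indptr[row + 1]].
        for k in range(indptr[row], indptr[row + 1]):
            data.append(dense[row][indices[k]])
    return data


dense = [[1.0, 0.0, 2.0],
         [0.0, 3.0, 0.0]]
indptr = [0, 2, 3]
indices = [0, 2, 1]
print(csr_extract(indices, indptr, dense))  # [1.0, 2.0, 3.0]
```

The pattern does not have to match the dense nonzeros: extract simply reads whatever value sits at each stored position.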
5d15757 [typing] annotate jax._src.util.safe_map 20 October 2022, 17:15:04 UTC
1a0affd Move `is_deleted()` to C++ so that we can check whether an Array is deleted without materializing `_arrays`. Also raise a better error message when operating on a deleted Array: instead of the current `NoneType has no len()`, the error now says `Array has been deleted`. PiperOrigin-RevId: 482497114 20 October 2022, 15:29:25 UTC
c1c8462 Merge pull request #12798 from skye:cache_min_instr_count PiperOrigin-RevId: 482349949 20 October 2022, 00:54:03 UTC
81eb3fc Add new config `jax_persistent_cache_min_instruction_count`. This can be used to limit the number of entries written to the persistent compilation cache. I set the default minimum threshold to 6 based on running the flax wmt example (https://github.com/google/flax/tree/main/examples/wmt) and logging the instruction counts and compilation times:

name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344

Also adds new int config functionality. Fixes #12583 20 October 2022, 00:17:24 UTC
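The effect of the threshold on the flax wmt measurements can be sketched as a filter. This assumes an entry is cached when its instruction count is at least `jax_persistent_cache_min_instruction_count` (the commit does not spell out whether the comparison is inclusive); `CompileRecord` and `entries_to_cache` are hypothetical names.

```python
from typing import List, NamedTuple, Sequence


class CompileRecord(NamedTuple):
    name: str
    instruction_count: int
    compile_time_secs: float


def entries_to_cache(records: Sequence[CompileRecord],
                     min_instruction_count: int = 6) -> List[CompileRecord]:
    # Assumption: a computation is written to the persistent cache only if
    # its instruction count is at least the configured minimum.
    return [r for r in records if r.instruction_count >= min_instruction_count]


# A subset of the measurements from the flax wmt run.
records = [
    CompileRecord("broadcast_in_dim", 2, 0.01633763313),
    CompileRecord("_reduce_sum", 5, 0.0210878849),
    CompileRecord("_multi_slice", 9, 0.09065580368),
    CompileRecord("_threefry_split", 232, 0.09037566185),
    CompileRecord("init", 60361, 26.35722399),
]
cached = entries_to_cache(records)
print([r.name for r in cached])  # ['_multi_slice', '_threefry_split', 'init']
```

With this rule, the many tiny two-instruction computations that compile in a few hundredths of a second never reach the cache.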
a1d303b [dynamic-shapes] fix nested vmap callable annotation logic 20 October 2022, 00:01:53 UTC
8536fae Merge pull request #12871 from ROCmSoftwarePlatform:rocm_docker_fix_for_pip PiperOrigin-RevId: 482279060 19 October 2022, 19:53:49 UTC
680096c [ROCm] Fix ROCm dockerfile to remove 2020-resolver opt in. The 2020-resolver opt in has been removed in pip 22.3 since it has now become the default. 19 October 2022, 18:11:04 UTC
76375ff Merge pull request #12863 from jakevdp:annotate-zip PiperOrigin-RevId: 482251072 19 October 2022, 18:04:06 UTC
524745f TMP: annotate util.safe_zip 19 October 2022, 17:29:53 UTC
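A length-checked `zip` with overload-based annotations lets a type checker infer precise tuple types for the common two-argument case. This is an illustrative sketch; JAX's own `safe_zip` may use a different signature or error type.

```python
from typing import Any, List, Sequence, Tuple, TypeVar, overload

T1 = TypeVar("T1")
T2 = TypeVar("T2")


@overload
def safe_zip(__a: Sequence[T1], __b: Sequence[T2]) -> List[Tuple[T1, T2]]: ...
@overload
def safe_zip(*args: Sequence[Any]) -> List[Tuple[Any, ...]]: ...


def safe_zip(*args):
    """Like zip(), but raises on length mismatch instead of truncating."""
    if not args:
        return []
    n = len(args[0])
    for arg in args[1:]:
        if len(arg) != n:
            raise ValueError(f"length mismatch: {[len(a) for a in args]}")
    return list(zip(*args))


pairs = safe_zip([1, 2], ["a", "b"])  # a type checker infers List[Tuple[int, str]]
print(pairs)  # [(1, 'a'), (2, 'b')]
```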
02dc25f [JAX] Redisable int8 convolution tests on GPU due to CI failures. PiperOrigin-RevId: 482191832 19 October 2022, 13:50:36 UTC
c7dcd09 Merge pull request #12820 from mattjj:simple-sharding2 PiperOrigin-RevId: 482101400 19 October 2022, 04:37:16 UTC
43098f9 initial commit of DevicesSharding (fka SimpleSharding) need to add tests! Co-authored-by: Yash Katariya <yashkatariya@google.com> Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> 19 October 2022, 04:10:24 UTC
3572bb2 [Rollback] Allow uncommitted single device PyArray in C++ pjit path. PiperOrigin-RevId: 482084898 19 October 2022, 02:42:10 UTC
8072699 Enable more GPU and TPU tests that pass at head. Increase precision of matmuls in LU decompositions, pseudo-inverse solves, and their gradients. It is unlikely users want to use low precision for these operations and high precision is probably the right default. PiperOrigin-RevId: 482071629 19 October 2022, 01:09:44 UTC
646416b Merge pull request #12854 from jakevdp:array-creation-annotation PiperOrigin-RevId: 482055094 18 October 2022, 23:43:29 UTC
c9a60f9 Only raise the warning if jax_array is enabled and the code is coming from `jit`. PiperOrigin-RevId: 482053435 18 October 2022, 23:35:43 UTC
5617a02 Remove JAX custom call implementation of batched triangular solve. XLA supports batched triangular solve on GPU and has since February 2022, which is older than the minimum jaxlib version. We can therefore delete our implementation and just use XLA's implementation. PiperOrigin-RevId: 482031830 18 October 2022, 22:04:14 UTC
d20b9fa Always use `.device_buffers` for jax.Array because `.device_buffer` can raise an error if there is more than 1 buffer present in the Array. PiperOrigin-RevId: 482028624 18 October 2022, 21:51:07 UTC
6aafb86 Merge pull request #12853 from jakevdp:annotate-index-tricks PiperOrigin-RevId: 482026886 18 October 2022, 21:44:57 UTC
a168c2d Merge pull request #12683 from ylamidon:add-scipy-stats-mode PiperOrigin-RevId: 482025648 18 October 2022, 21:38:44 UTC
9818d65 Merge pull request #12843 from LenaMartens:while-errors PiperOrigin-RevId: 482008689 18 October 2022, 20:46:14 UTC
eb20468 [typing] annotate jax.numpy array creation routines 18 October 2022, 20:46:07 UTC
8ba7911 Merge pull request #12851 from jakevdp:annotate-util PiperOrigin-RevId: 482008659 18 October 2022, 20:39:15 UTC
ccbc305 Add JAX equivalent of scipy.stats.mode 18 October 2022, 19:45:02 UTC
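SciPy's `mode` returns the most frequent value together with its count and, on ties, returns the smallest such value. A pure-Python sketch of that rule for a 1-D sequence; `mode_1d` is hypothetical and ignores the `axis` handling and array outputs of the real functions.

```python
from collections import Counter
from typing import Sequence, Tuple


def mode_1d(values: Sequence[float]) -> Tuple[float, int]:
    """Most common value and its count; ties go to the smallest value."""
    if not values:
        raise ValueError("mode of an empty sequence is undefined")
    counts = Counter(values)
    best_count = max(counts.values())
    # Among all values that reach the top count, pick the smallest.
    best_value = min(v for v, c in counts.items() if c == best_count)
    return best_value, best_count


print(mode_1d([3, 1, 2, 3, 1]))  # (1, 2): 1 and 3 tie, the smaller wins
```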
2fe71d2 [typing] annotate jax._src.numpy.index_tricks 18 October 2022, 19:25:33 UTC
66af016 Merge pull request #12844 from yejingxin:main PiperOrigin-RevId: 481970168 18 October 2022, 18:12:20 UTC
59374c1 skip some tests if runtime is stream_executor DETAILS: Ran on Cloud TPU v2-8 and found that some tests in debugging_primitives_test fail because the stream_executor runtime does not support host callbacks. Since host callbacks are only supported by TFRT, skip all those tests when the runtime type is stream_executor. TESTED: passed unit tests on both TPU v2-8 and CPU. 18 October 2022, 17:42:33 UTC
d64da3d Roll forward with fix: Remove the original Python function `fun_` from C++ PjitFunction, as destroying `fun_` may yield the thread in some cases, which causes an error while deleting the Python object of PjitFunction. PiperOrigin-RevId: 481950912 18 October 2022, 17:05:53 UTC
4af9795 Add a default implementation for calculating `devices_indices_map` to `XLACompatibleSharding` by lowering to OpSharding and then using its `devices_indices_map`. Why? So that users don't have to write this logic themselves once they have written the logic for calculating the op_sharding proto. PiperOrigin-RevId: 481946515 18 October 2022, 16:51:08 UTC
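For a fully tiled sharding, the device-to-index map follows from the tile grid and a flat device list, roughly the information an OpSharding tile assignment carries. A hypothetical sketch under stated assumptions (one device per tile, row-major device order, evenly divisible dimensions, no replication); `devices_indices_map_sketch` is not the `XLACompatibleSharding` implementation.

```python
import itertools
import math
from typing import Dict, List, Sequence, Tuple


def devices_indices_map_sketch(
    global_shape: Sequence[int],
    tiles_per_dim: Sequence[int],
    devices: Sequence[str],
) -> Dict[str, Tuple[slice, ...]]:
    """Map each device to the slice of the global array it holds."""
    if math.prod(tiles_per_dim) != len(devices):
        raise ValueError("expected exactly one device per tile")
    tile_shape: List[int] = []
    for size, tiles in zip(global_shape, tiles_per_dim):
        if size % tiles:
            raise ValueError(f"dimension of size {size} not divisible by {tiles}")
        tile_shape.append(size // tiles)
    # itertools.product yields tile coordinates in row-major order.
    coords = itertools.product(*(range(t) for t in tiles_per_dim))
    return {
        device: tuple(slice(c * s, (c + 1) * s) for c, s in zip(coord, tile_shape))
        for device, coord in zip(devices, coords)
    }


m = devices_indices_map_sketch((4, 6), (2, 3), [f"dev{i}" for i in range(6)])
print(m["dev4"])  # (slice(2, 4, None), slice(2, 4, None))
```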
d60ceea [typing] annotate util.unzip2 & util.unzip3 18 October 2022, 16:47:49 UTC
9749183 Merge pull request #12781 from jakevdp:jax-dtypes-types PiperOrigin-RevId: 481810269 18 October 2022, 03:50:30 UTC
ae8eb6f Merge pull request #12839 from jakevdp:update-pre-commit PiperOrigin-RevId: 481767929 17 October 2022, 23:27:53 UTC
b36717e Merge pull request #12810 from LenaMartens:less-consts2 PiperOrigin-RevId: 481762484 17 October 2022, 23:02:06 UTC
8ff293a Fix xla_bridge_test on TPU DETAILS: When running xla_bridge_test on TPU v2-8, it raises the following error about the unknown backend 'tpu'; this change sets jax_platforms to "" to eliminate the error:
```
FAILED tests/xla_bridge_test.py::GetBackendTest::test_backend_init_error - RuntimeError: Unable to initialize backend 'tpu': Unknown backend 'tpu' (set JAX_PLATFORMS='' to automatically choose an available backend)
```
TESTED: passes unit tests on both CPU and TPU PiperOrigin-RevId: 481758573 17 October 2022, 22:44:06 UTC
c2a00a0 Disallow checkify-of-vmap-of-while. 17 October 2022, 22:01:43 UTC
7a99e51 Merge pull request #12834 from google:dependabot/github_actions/styfle/cancel-workflow-action-0.11.0 PiperOrigin-RevId: 481734398 17 October 2022, 21:12:35 UTC
9961872 Merge pull request #12838 from dan-zheng:fix-typo PiperOrigin-RevId: 481734354 17 October 2022, 21:05:38 UTC
ed7a8bb [typing] annotate jax._src.dtypes 17 October 2022, 20:49:26 UTC
1ed18fa add allow_opaque_dtype to dtypes.canonicalize_dtype utility 17 October 2022, 20:47:42 UTC
fd2f590 Allow uncommitted single device PyArray in C++ pjit path. PiperOrigin-RevId: 481711690 17 October 2022, 19:35:30 UTC
87f1a2b CI: update mypy version in pre-commit config 17 October 2022, 18:25:14 UTC
cef5f20 Bump styfle/cancel-workflow-action from 0.10.1 to 0.11.0
Bumps [styfle/cancel-workflow-action](https://github.com/styfle/cancel-workflow-action) from 0.10.1 to 0.11.0.
- [Release notes](https://github.com/styfle/cancel-workflow-action/releases)
- [Commits](https://github.com/styfle/cancel-workflow-action/compare/0.10.1...0.11.0)
---
updated-dependencies:
- dependency-name: styfle/cancel-workflow-action
  dependency-type: direct:production
  update-type: version-update:semver-minor
...
Signed-off-by: dependabot[bot] <support@github.com> 17 October 2022, 17:18:44 UTC
504b3c1 roll forward with the fix: Make `params` arg in Compiled.call() position-only so that it does not conflict with the keyword args. PiperOrigin-RevId: 481666211 17 October 2022, 16:50:55 UTC
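The conflict being fixed is a general Python one: if `params` can be passed by keyword, a caller's own `params=` keyword argument collides with it. A sketch using PEP 570 positional-only syntax (Python 3.8+); `call_old` and `call_new` are hypothetical stand-ins for `Compiled.call`, whose real signature may differ.

```python
from typing import Any, Dict, Tuple


def call_old(params, *args, **kwargs):
    # `params` may also be bound by keyword, so a user kwarg named
    # "params" collides with it.
    return params, args, kwargs


def call_new(params, /, *args: Any,
             **kwargs: Any) -> Tuple[Any, Tuple[Any, ...], Dict[str, Any]]:
    # With "/", `params` is positional-only, so a user's `params=` keyword
    # is collected into **kwargs instead of clashing with it.
    return params, args, kwargs


compiled_params = {"compiled": True}  # hypothetical stand-in for executable state

try:
    call_old(compiled_params, params=1)
except TypeError as err:
    print("TypeError:", err)  # ... got multiple values for argument 'params'

print(call_new(compiled_params, params=1))  # ({'compiled': True}, (), {'params': 1})
```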
9b0c4e5 Fix typo. decice -> device 15 October 2022, 05:12:08 UTC
4cfa01f Improve the error message when users are trying to create SDAs and pass them into pjit/xmap when jax.Array is enabled. The error message tells them exactly what to do to fix the error. PiperOrigin-RevId: 481282762 15 October 2022, 02:43:32 UTC
69525cd [sparse] Make BCSR vmappable. PiperOrigin-RevId: 481257762 14 October 2022, 23:27:24 UTC
63be0c3 Guard the new channel_handle feature on mlir_api_version for backwards compatibility PiperOrigin-RevId: 481246613 14 October 2022, 22:28:30 UTC
8bd913c Merge pull request #12813 from jakevdp:clarify-typing PiperOrigin-RevId: 481239620 14 October 2022, 21:54:12 UTC