https://github.com/google/jax

41417ee Merge pull request #12752 from skye:workspace PiperOrigin-RevId: 480405935 11 October 2022, 18:12:26 UTC
63be220 Update WORKSPACE and setup.py for jaxlib 0.3.22 release 11 October 2022, 17:55:05 UTC
af98d8b Merge pull request #12746 from jakevdp:fix-changelog PiperOrigin-RevId: 480375907 11 October 2022, 16:21:28 UTC
05faf0f Remove deprecated functionality from jax.test_util PiperOrigin-RevId: 480360504 11 October 2022, 15:16:34 UTC
9c99498 Support MANUAL collectives in top-level xmaps It's a bit of a weird use-case, since MANUAL mode is meant for xmaps that are nested inside pjits, but it doesn't hurt us to support it. PiperOrigin-RevId: 480342531 11 October 2022, 13:47:30 UTC
bf78697 Merge pull request #12727 from jakevdp:scipy-linalg-types PiperOrigin-RevId: 480340339 11 October 2022, 13:34:13 UTC
9431294 changelog: add missing github commit links 11 October 2022, 13:24:38 UTC
e45df46 Clarify docs for `fori_loop`, noting that negative or custom increments are not supported. PiperOrigin-RevId: 480317277 11 October 2022, 11:41:10 UTC
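Because `fori_loop` only advances its index upward by one, a countdown or any other stride has to be derived inside the loop body. Here is a minimal sketch of that workaround. It assumes an illustrative remapping `10 - 2 * i`; the body function names are hypothetical.

```python
import jax.numpy as jnp
from jax import lax

def body(i, acc):
    # i takes the values lower, lower + 1, ..., upper - 1 (unit stride only)
    return acc + i

total = lax.fori_loop(0, 5, body, jnp.int32(0))  # 0 + 1 + 2 + 3 + 4

def countdown_body(i, acc):
    # Derive the "real" index from the unit-step counter: visits 10, 8, 6, 4, 2
    k = 10 - 2 * i
    return acc + k

countdown = lax.fori_loop(0, 5, countdown_body, jnp.int32(0))
```

The trip count is still `upper - lower`. Only the value used inside the body changes.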
eade368 Merge pull request #12324 from gnecula:tf_pjit PiperOrigin-RevId: 480310826 11 October 2022, 11:01:12 UTC
9c879ad [jax2tf] Implement jax2tf(pjit) for experimental_native_lowering This implementation is for the case jax2tf.convert(pjit(f_jax)), that is, the `pjit` appears at the top-level of the function to be lowered. 11 October 2022, 07:45:07 UTC
ff17d3d Add support for calculating the device_assignment when there are no inputs to `jit` and `pjit`. Also look at the shardings inside the jaxpr for `sharding_constraint_p` and `pjit_p` primitives since with `jax.Array`, each `with_sharding_constraint`/`pjit` inside a computation can contain a different sharding (so we need to check if the device_assignment is the same). Also the output is `committed` if there are jaxpr shardings inside the computation via `with_sharding_constraint`/`pjit` or if any of the inputs are committed or `output_sharding` is specified. Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 480256796 11 October 2022, 05:08:42 UTC
df5f7cb Rolling forward https://github.com/google/jax/pull/12707 after a rollback, which was needed because changes in relatively trivial jax.numpy shape validation code failed some downstream user tests. PiperOrigin-RevId: 480229237 11 October 2022, 01:51:37 UTC
9b3e864 Add weak_type attribute to `Array` since it exists on DA (but doesn't exist on SDA). PiperOrigin-RevId: 480223116 11 October 2022, 01:11:11 UTC
afe74b4 [typing] add type annotations to jax.scipy.linalg 10 October 2022, 23:54:29 UTC
ec5b1c9 Turn on cpp pjit py default PiperOrigin-RevId: 480185387 10 October 2022, 22:01:04 UTC
76d8c08 Fix the return type annotations of `device_buffer` and `device_buffers`, which return `ArrayImpl` instead of `DeviceArray`. PiperOrigin-RevId: 480181798 10 October 2022, 21:45:12 UTC
a3a2206 Fix compilation failure in lapack kernel under msan. a_size wasn't defined, but it would only be caught under memory sanitizer. PiperOrigin-RevId: 480176934 10 October 2022, 21:24:59 UTC
2246887 Add input-output aliasing annotations for LAPACK calls on CPU. PiperOrigin-RevId: 480156067 10 October 2022, 19:57:29 UTC
752c3ff Lift `lambda x: x` to the top level so that we don't recompile on every invocation of `process_allgather`. PiperOrigin-RevId: 480155482 10 October 2022, 19:51:17 UTC
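The underlying issue is that `jax.jit` keys its compilation cache on the function object. A new `lambda` on every call is therefore a new cache entry, and it gets retraced each time. The sketch below illustrates this with plain `jax.jit` rather than the actual `process_allgather` code. `_identity`, `gather_top_level` and `gather_inline` are hypothetical names, and the Python counter only increments while JAX is tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def _identity(x):
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX traces the function
    return x

_jitted_identity = jax.jit(_identity)  # created once, at module level

def gather_top_level(x):
    # Reuses one jitted callable, so calls after the first hit the cache.
    return _jitted_identity(x)

def gather_inline(x):
    # A fresh lambda per call is a new function object: cache miss, retrace.
    return jax.jit(lambda y: _identity(y))(x)

x = jnp.ones(3)
for _ in range(3):
    gather_top_level(x)
top_level_traces = trace_count

for _ in range(3):
    gather_inline(x)
inline_traces = trace_count - top_level_traces
```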
90e9abe Merge pull request #12722 from hawkinsp:tests PiperOrigin-RevId: 480149233 10 October 2022, 19:22:52 UTC
2ba0396 Add changes accidentally omitted from https://github.com/google/jax/pull/12717 10 October 2022, 19:11:58 UTC
34eb6ce [sparse] BCSR fromdense and todense. PiperOrigin-RevId: 480141918 10 October 2022, 18:54:22 UTC
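Below is a small round trip through the BCSR (batched compressed sparse row) format. It assumes the `jax.experimental.sparse.BCSR` class as it exists in current JAX releases, and the example matrix is arbitrary. For an unbatched 2-D matrix, `data` holds the nonzeros in row-major order, `indices` holds their column indices, and `indptr[r]:indptr[r + 1]` delimits row `r`.

```python
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[0., 1., 0.],
                   [2., 0., 3.]])

m = sparse.BCSR.fromdense(dense)
# m.data    -> nonzero values, row by row: [1., 2., 3.]
# m.indices -> column of each nonzero:     [1, 0, 2]
# m.indptr  -> row boundaries:             [0, 1, 3]

restored = m.todense()
```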
c657449 Copybara import of the project: -- d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>: Migrate more tests from jtu.cases_from_list to jtu.sample_product. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d PiperOrigin-RevId: 480136538 10 October 2022, 18:35:32 UTC
22cd505 Reapply: Use input-output aliasing for jaxlib GPU custom calls. Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need. It turns out some users are relying on the API contract of the custom calls within serialized HLO remaining stable. For the moment, we reapply only the Python changes. The C++ code is already tolerant of both aliased and unaliased outputs, and this gets us all the benefit of saving a copy. We can break backwards compatibility on the serialized HLO after users upgrade their saved HLO to the aliased version. PiperOrigin-RevId: 480134780 10 October 2022, 18:29:18 UTC
707b07c Merge pull request #12697 from jakevdp:lax-slicing-types PiperOrigin-RevId: 480131675 10 October 2022, 18:22:35 UTC
ad4dcc4 Merge pull request #12700 from jakevdp:scipy-types PiperOrigin-RevId: 480131410 10 October 2022, 18:15:46 UTC
124021d [typing] add annotations to jax.scipy.ndimage 10 October 2022, 16:11:13 UTC
76e4a1d [typing] add annotations to jax.scipy.fft 09 October 2022, 12:18:45 UTC
ae9f8ee [typing] annotate lax.slicing 09 October 2022, 11:20:46 UTC
9cabd22 Copybara import of the project: -- 6d2aaac2454117d54997243714c1a009827707ca by Matthew Johnson <mattjj@google.com>: implement bint arrays (opaque dtypes), add padding rules Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> PiperOrigin-RevId: 479883102 09 October 2022, 08:25:50 UTC
25c6ef7 Merge pull request #12707 from mattjj:djax-slice-sick4 PiperOrigin-RevId: 479876971 09 October 2022, 07:23:39 UTC
6d2aaac implement bint arrays (opaque dtypes), add padding rules Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com> 09 October 2022, 05:57:29 UTC
75b2d05 Make `is_fully_replicated` and `is_fully_addressable` properties rather than methods. Why? 1. Because it's easier to cache a property than a method with only the `self` argument. (See below for article) 2. There's no harm in making them properties because both return a bool without any side effects and are cached (so they're fast). Why cache `is_fully_addressable`? Because it's very expensive to calculate when you have 1000s of devices. PiperOrigin-RevId: 479850850 09 October 2022, 02:24:12 UTC
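The property-caching argument can be sketched with the standard library's `functools.cached_property`. The `Layout` class and its attributes are hypothetical stand-ins. This is not a claim about how `jax.Array` implements these properties internally.

```python
import functools

class Layout:
    """Hypothetical stand-in for a sharded array's device metadata."""

    def __init__(self, devices, local_devices):
        self._devices = list(devices)
        self._local_devices = list(local_devices)
        self.evaluations = 0  # how many times the expensive check actually ran

    @functools.cached_property
    def is_fully_addressable(self) -> bool:
        self.evaluations += 1
        # Linear in the number of devices; cached after the first access.
        local = set(self._local_devices)
        return all(d in local for d in self._devices)

layout = Layout(range(1000), range(1000))
first = layout.is_fully_addressable   # computed
second = layout.is_fully_addressable  # served from the instance __dict__

partial = Layout(range(4), range(2))
```

`cached_property` stores the result in the instance's `__dict__` on first access. A zero-argument method would instead need an explicit memoization wrapper.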
674038c Merge pull request #12705 from mattjj:fix-prng-key-array-device-put PiperOrigin-RevId: 479813689 08 October 2022, 18:39:05 UTC
0a0f492 make device_put(prngkeyarray, sharding) for Array Co-authored-by: Yash Katariya <yashkatariya@google.com> Co-authored-by: Roy Frostig <frostig@google.com> 07 October 2022, 23:50:16 UTC
2693afa Revert: Use input-output aliasing for jaxlib GPU custom calls. Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need. This seems to be causing some test failures downstream, so reverting this for the moment until I can debug them. PiperOrigin-RevId: 479670565 07 October 2022, 21:36:24 UTC
e8ba61d Merge pull request #12677 from mattjj:jit-pjit-lower-sharding PiperOrigin-RevId: 479669125 07 October 2022, 21:28:51 UTC
58cd837 Merge pull request #12675 from mattjj:device-put2 PiperOrigin-RevId: 479660808 07 October 2022, 20:49:57 UTC
33f18fe Merge pull request #12703 from ROCmSoftwarePlatform:rocm-ci-update2 PiperOrigin-RevId: 479652964 07 October 2022, 20:12:57 UTC
34f6646 Add default setting for TENSORFLOW_ROCM_COMMIT 07 October 2022, 19:57:53 UTC
93b839a Use input-output aliasing for jaxlib GPU custom calls. Previously we had no way to tell XLA that inputs and outputs of GPU custom calls must alias. This now works in XLA:GPU so we can just ask XLA to enforce the aliasing we need. PiperOrigin-RevId: 479642543 07 October 2022, 19:22:04 UTC
0995f79 Merge pull request #12698 from jakevdp:scipy-fft-todo PiperOrigin-RevId: 479615134 07 October 2022, 17:29:18 UTC
730f993 Update scipy version in jax.scipy.fft 07 October 2022, 16:57:39 UTC
fde444e Merge pull request #12695 from gnecula:tf_dimension PiperOrigin-RevId: 479602588 07 October 2022, 16:38:45 UTC
fb2141f [jax2tf] Allow the use of DimPolynomial with jnp.array and binary operations Prior to this the user had to explicitly call core.dimension_as_value whenever using a potentially polymorphic shape in the computation, e.g., x + core.dimension_as_value(x.shape[0]). Furthermore, jnp.array(x.shape[0]) would fail. Now, these operations are allowed implicitly, and the user can call `jnp.array(x.shape[0])`. This uses an internal extensibility mechanism called __jax_array__ that is experimental and probably not fully implemented. 07 October 2022, 14:58:41 UTC
363cc12 Merge pull request #12197 from ROCmSoftwarePlatform:fixedRocmUnitTestsSkip PiperOrigin-RevId: 479566021 07 October 2022, 13:36:11 UTC
7c7c94c Expand support for __jax_array__ in jnp.array. This relates to the long discussion in #4725 and #10065. 07 October 2022, 11:25:07 UTC
6c70e4d Merge pull request #12691 from mattjj:issue12688 PiperOrigin-RevId: 479507232 07 October 2022, 07:10:18 UTC
076a734 fix -O / PYTHONOPTIMIZE bug fixes #12688 I'm not sure how to write test cases for PYTHONOPTIMIZE=1 (without growing our whole test matrix), so I'm leaving this untested... 07 October 2022, 06:15:22 UTC
7a82536 [sparse] Bug fix in _validate_bcsr. PiperOrigin-RevId: 479452053 07 October 2022, 00:28:52 UTC
f465aee Merge pull request #12673 from jakevdp:lax-linalg-types PiperOrigin-RevId: 479444281 06 October 2022, 23:48:45 UTC
bcca6fb add test, small fixes Co-authored-by: Yash Katariya <yashkatariya@google.com> 06 October 2022, 23:45:34 UTC
9104536 [typing] add type annotations to lax.linalg functions 06 October 2022, 23:19:00 UTC
ce95eba make device_put work with Sharding 2nd arg Co-authored-by: Yash Katariya <yashkatariya@google.com> 06 October 2022, 23:14:15 UTC
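A minimal sketch of passing a `Sharding` as the second argument to `device_put`. It uses the `jax.sharding` module path of current JAX releases, which may differ from the path at the time of this commit, and places the array on the first available device.

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

x = jnp.arange(8.0)
device = jax.devices()[0]

# The second argument can be a Sharding rather than a bare Device.
y = jax.device_put(x, SingleDeviceSharding(device))
```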
f49d3d4 Rename Executable to LoadedExecutable within jax. PiperOrigin-RevId: 479423951 06 October 2022, 22:14:33 UTC
df4f399 Merge pull request #12678 from jakevdp:average-axis-tuple PiperOrigin-RevId: 479414002 06 October 2022, 21:38:13 UTC
ba2d803 Merge pull request #12674 from jakevdp:svd-type PiperOrigin-RevId: 479413950 06 October 2022, 21:31:33 UTC
1f3048b Merge pull request #12573 from ROCmSoftwarePlatform:rocm-ci-update PiperOrigin-RevId: 479369638 06 October 2022, 18:33:41 UTC
824903b Merge pull request #12686 from hawkinsp:sample PiperOrigin-RevId: 479354990 06 October 2022, 17:51:56 UTC
8107e36 Switch lax_numpy_indexing_test to use jtu.sample_product. 06 October 2022, 17:44:17 UTC
d174b3d Take `shardings` as a parameter to `deserialize` and `run_deserialization` instead of `mesh` and `pspecs`. PiperOrigin-RevId: 479346552 06 October 2022, 17:20:49 UTC
32ef3ba jnp.average: support tuple axis 06 October 2022, 17:20:46 UTC
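With tuple-axis support, `jnp.average` reduces over several axes at once, much like `jnp.mean`. The sketch below uses `x[i, j, k] = 12*i + 4*j + k`. The weighted case gives the `i = 1` slice three times the weight of `i = 0`, and both the array and the weights are illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)  # x[i, j, k] = 12*i + 4*j + k

# Unweighted: same as jnp.mean over axes 0 and 2 -> [7.5, 11.5, 15.5]
avg = jnp.average(x, axis=(0, 2))

# Weights with the same shape as x; slice i=1 counts three times as much.
w = jnp.broadcast_to(jnp.array([1.0, 3.0])[:, None, None], x.shape)
weighted = jnp.average(x, axis=(0, 2), weights=w)  # -> [10.5, 14.5, 18.5]
```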
d94327c Move promote_like_jnp to jax.test_util 06 October 2022, 17:20:26 UTC
219c574 Merge pull request #12676 from jakevdp:fix-concat-test PiperOrigin-RevId: 479318171 06 October 2022, 15:15:22 UTC
1b5e9e4 add input/output sharding to executable protocol 06 October 2022, 00:30:56 UTC
d65145f Merge pull request #12532 from jakevdp:reduction-dtype PiperOrigin-RevId: 479180617 06 October 2022, 00:23:30 UTC
3f08f85 Merge pull request #12666 from jakevdp:polynomial-typing PiperOrigin-RevId: 479180604 06 October 2022, 00:17:06 UTC
2525aa6 Merge pull request #12642 from RissyRan:platform PiperOrigin-RevId: 479176113 05 October 2022, 23:54:11 UTC
e8dc6d1 improve jit(f).lower(duck_args) and pjit(f).lower(duck_args) Co-authored-by: Yash Katariya <yashkatariya@google.com> 05 October 2022, 22:47:59 UTC
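"Duck" arguments here are objects that describe only shape and dtype, so a function can be lowered and compiled ahead of time without real data. A minimal sketch using `jax.ShapeDtypeStruct`, with an illustrative function `f`:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.dot(x, y) + 1.0

# Abstract stand-ins: shape and dtype only, no buffers.
x_spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
y_spec = jax.ShapeDtypeStruct((4,), jnp.float32)

lowered = jax.jit(f).lower(x_spec, y_spec)
compiled = lowered.compile()

# The compiled executable accepts concrete arrays matching the specs.
out = compiled(jnp.ones((3, 4), jnp.float32), jnp.ones((4,), jnp.float32))
```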
ff08109 test: fix LaxNumpyTest:testConcatenate 05 October 2022, 22:29:15 UTC
2de91d2 Handle FP8 types. PiperOrigin-RevId: 479148993 05 October 2022, 21:48:30 UTC
c1328a6 [typing] overloads for jnp.linalg.svd & jnp.linalg.qr 05 October 2022, 21:29:00 UTC
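The overloads matter because the return type of these functions depends on a keyword argument, and a type checker can only distinguish the cases through `typing.overload`. At runtime the difference looks like this, using an illustrative diagonal matrix:

```python
import jax.numpy as jnp

a = jnp.array([[3.0, 0.0],
               [0.0, 4.0]])

s = jnp.linalg.svd(a, compute_uv=False)  # a single array of singular values
u, s2, vh = jnp.linalg.svd(a)            # a (U, S, Vh) triple

r = jnp.linalg.qr(a, mode="r")           # only R
q, r2 = jnp.linalg.qr(a)                 # a (Q, R) pair
```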
9d12e21 Add `addressable_shards` to SDA and DA as a compatibility API to match with `jax.Array`. This will aid in transition to `jax.Array`. PiperOrigin-RevId: 479115126 05 October 2022, 19:35:17 UTC
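A minimal sketch of the `addressable_shards` API as it exists on `jax.Array` in current releases. Keeping only shards whose `replica_id` is 0 is one way to pick a single copy of each unique slice, for example when writing a checkpoint. On a single device the array has exactly one shard.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8.0)

# Each Shard records its device, the slice of the global array it covers,
# its replica id, and the per-device data.
shard_info = [(s.device, s.index, s.replica_id) for s in x.addressable_shards]

# One copy per unique slice: drop replicas with a nonzero replica_id.
unique = [s for s in x.addressable_shards if s.replica_id == 0]
```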
167004d Merge pull request #12654 from jakevdp:fft-typing PiperOrigin-RevId: 479106034 05 October 2022, 19:00:10 UTC
3ec5e45 Merge pull request #12660 from hawkinsp:testing PiperOrigin-RevId: 479098519 05 October 2022, 18:33:55 UTC
6a348f9 [typing] add types for jax.numpy.polynomial 05 October 2022, 18:23:45 UTC
f3ded0f Address comments for change log 05 October 2022, 18:16:49 UTC
dd0a455 Add sharding to `DeviceArray` and `ShardedDeviceArray` as a compatibility change to rollout `jax.Array`. Also expose `device_replica_id_map` since that is important API for checkpointing to find all the unique shards of an Array. You can also use this to calculate the unique indices in a sharding (which is what `sda.one_replica_buffer_indices` does) PiperOrigin-RevId: 479072520 05 October 2022, 16:58:41 UTC
4da72cf Add `host_local_array_to_global_array` and `global_array_to_host_local_array` for enabling transition to jax.Array. Also support `FROM_GDA` for `jax.Array` as a backwards-compatible change so that users can continue to use it until they transition to jax.Array. It's currently required because of usage like `in_axis_resources = (FROM_GDA, FROM_GDA, P('data'), None)`, and changing this on the users' side will require input from users, so JAX can support it temporarily since GDA and Array act similarly in pjit. PiperOrigin-RevId: 479035326 05 October 2022, 14:01:58 UTC
2c946b3 Migrate api_test, lax_numpy_test, and lax_vmap_test to jtu.sample_product. Gives a ~2x improvement in pytest --collect-only timing for lax_numpy_test. 05 October 2022, 13:46:19 UTC
4870fd3 Update message and change log 05 October 2022, 04:39:04 UTC
22e1547 Update set up message 05 October 2022, 04:39:04 UTC
ef55953 Add set up message for JAX_PLATFORMS 05 October 2022, 04:39:04 UTC
78a7e16 Set JAX_PLATFORMS=tpu,cpu on TPUs 05 October 2022, 04:39:04 UTC
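`JAX_PLATFORMS` takes a comma-separated list of platform names and has to be set before `jax` is imported. The sketch below uses `"cpu"` only, so it runs on any machine. On a TPU host the commit's value `"tpu,cpu"` would go in the same place.

```python
import os

# Must be set before importing jax; the backend is chosen at initialization.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

backend = jax.default_backend()
```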
015a12e Merge pull request #12657 from jakevdp:linalg-typing PiperOrigin-RevId: 478937233 05 October 2022, 02:05:44 UTC
9646199 [typing] add annotations to numpy.fft 04 October 2022, 22:52:54 UTC
78ed03c [typing] add annotations to jax.numpy.linalg 04 October 2022, 22:50:29 UTC
1206e10 Merge pull request #12658 from sharadmv:fix-profiler PiperOrigin-RevId: 478893609 04 October 2022, 22:09:49 UTC
83ef7d0 Fix collect_profile _src import 04 October 2022, 21:31:17 UTC
2b0542d Merge pull request #12640 from mattjj:djax-slice-sick PiperOrigin-RevId: 478860162 04 October 2022, 19:51:04 UTC
06a2c85 [dynamic-shapes] small fix to einsum (and indexing) 04 October 2022, 19:29:41 UTC
a60ca9f Test that array layout is preserved in Python callbacks PiperOrigin-RevId: 478852392 04 October 2022, 19:14:47 UTC
3f663e4 Merge pull request #12639 from skye:version PiperOrigin-RevId: 478850508 04 October 2022, 19:07:38 UTC
62774f5 Merge pull request #12641 from jakevdp:asarray-type PiperOrigin-RevId: 478829647 04 October 2022, 17:53:12 UTC
9e0b8ca Merge pull request #12653 from jakevdp:jacobian-wraps PiperOrigin-RevId: 478829535 04 October 2022, 17:47:01 UTC
0d93679 jax.jacobian: propagate function signature to transformed function 04 October 2022, 17:21:54 UTC
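One observable effect of propagating the signature is that `inspect.signature` and `__name__` on the transformed function report the original function's parameters and name. A minimal sketch, assuming a hypothetical `loss` function:

```python
import inspect

import jax
import jax.numpy as jnp

def loss(params, scale=2.0):
    """Hypothetical scalar-valued function of a vector."""
    return scale * jnp.sum(params ** 2)

jac_loss = jax.jacobian(loss)

sig = inspect.signature(jac_loss)  # follows __wrapped__ back to loss
grad_at = jac_loss(jnp.array([1.0, 2.0]))  # d/dp 2*sum(p**2) = 4p -> [4., 8.]
```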
3be2087 jnp.prod & jnp.sum: promote to default integer type rather than int64/uint64 04 October 2022, 17:08:30 UTC
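The practical effect is that small integer inputs are accumulated in the default integer width, which avoids silent `int8` overflow. The sketch below compares against `jax.dtypes.canonicalize_dtype`, which yields 32-bit types when 64-bit mode is off and 64-bit types when it is on. The input values are illustrative.

```python
import jax
import jax.numpy as jnp

x = jnp.array([100, 100, 100], dtype=jnp.int8)
total = jnp.sum(x)  # 300 fits only because the result is promoted

prod = jnp.prod(jnp.array([2, 3], dtype=jnp.uint8))

default_int = jax.dtypes.canonicalize_dtype(jnp.int64)
default_uint = jax.dtypes.canonicalize_dtype(jnp.uint64)
```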
069866e Add types to jax/_src/numpy/util.py 04 October 2022, 17:07:38 UTC
ae49d2e [sparse] Add conversions between BCSR and BCOO. PiperOrigin-RevId: 478816413 04 October 2022, 17:00:16 UTC
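A round-trip sketch between the two sparse formats. It assumes the `BCSR.from_bcoo` class method and the `BCSR.to_bcoo` method found in current `jax.experimental.sparse` releases, and uses an illustrative matrix.

```python
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[0., 5., 0.],
                   [7., 0., 0.]])

bcoo = sparse.BCOO.fromdense(dense)  # coordinate format
bcsr = sparse.BCSR.from_bcoo(bcoo)   # compressed-row format
back = bcsr.to_bcoo()                # and back again
```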
37f9db7 Create `Array`s from `__getitem__` and `__iter__`. This is done by `device_put`ting from the host to default device which is suboptimal. But there is a TODO to fix this! PiperOrigin-RevId: 478691051 04 October 2022, 05:29:03 UTC
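From the user's side, this means indexing and iterating both return `jax.Array` values. A minimal sketch with an illustrative array:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6).reshape(3, 2)

rows = list(x)  # __iter__ yields one jax.Array per leading-axis entry
first = x[0]    # __getitem__ also returns a jax.Array
```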