https://github.com/google/jax

Revision Message Commit Date
2e384ce Prepare for release of jax and jaxlib 0.3.24 PiperOrigin-RevId: 485985460 03 November 2022, 22:13:23 UTC
2b4735f Merge pull request #13094 from mattjj:fix-random-docs-table PiperOrigin-RevId: 485979035 03 November 2022, 21:48:10 UTC
dba9fc0 Merge pull request #13000 from skye:self_hosted_tpu PiperOrigin-RevId: 485973703 03 November 2022, 21:27:38 UTC
8c22e34 Add Github Actions workflow that runs on a self-hosted TPU VM runner. This also includes some utilities for setting up the self-hosted runner. Googlers, see go/jax-self-hosted-runners for more setup info. The workflow is pretty basic currently. We can and should add more functionality later, such as email notifications. I kept it simple here for easier reviewing. Testing: - Sample workflow run in my fork: https://github.com/skye/jax/actions/runs/3333614180 - Sample PR attempt: (will add soon but I did verify validate_job.sh blocks pull_request workflows) 03 November 2022, 21:15:57 UTC
478bd3e fix comparison table in random docs 1. rbg is not identical across cpu/gpu/tpu; 2. the unsafe_rbg column copied the jax.lax.rng_uniform column from the original table, but that wasn't right, as it should be identical to the rbg column; 3. for the last row mentioning identical across shardings, we should mention that's assuming the xla flag Also removed some rows which are only interesting in comparing to `jax.lax.rng_uniform` (which is not safe with `scan` or `remat`). Co-authored-by: Roy Frostig <frostig@google.com> 03 November 2022, 20:10:54 UTC
f4be5ab Merge pull request #12219 from jakevdp:indexing-slice PiperOrigin-RevId: 485946084 03 November 2022, 19:44:28 UTC
532cd7e Skip the benchmarks properly via state.skip_with_error when enough devices are not present. PiperOrigin-RevId: 485931295 03 November 2022, 18:44:57 UTC
753562d Add benchmarks for repeated static indexing & slicing 03 November 2022, 18:41:37 UTC
1627bc6 generate dynamic_slice rather than slice for simple indexing/slicing 03 November 2022, 18:40:51 UTC
91d134d Remove unsupported type combinations from dot primitive tests. PiperOrigin-RevId: 485850678 03 November 2022, 12:55:13 UTC
b0621e3 Fix the vmap rule for remat_p It treated constants like args, but failed to convert_constvars_jaxpr to adjust the calling convention. PiperOrigin-RevId: 485847686 03 November 2022, 12:34:09 UTC
c8767ff Merge pull request #13080 from jakevdp:fix-jet PiperOrigin-RevId: 485776556 03 November 2022, 04:57:43 UTC
243b931 BUG: fix jet rule for dynamic_slice 03 November 2022, 04:37:41 UTC
eb9ac0d Merge pull request #13078 from jakevdp:bcoo-indexing PiperOrigin-RevId: 485759390 03 November 2022, 03:01:57 UTC
cc5af7e Rename `ReshapeableDevicesSharding` to `PositionalSharding` and add an alias `NamedSharding` for `MeshPspecSharding`. `MeshPspecSharding` name will be replaced with `NamedSharding` in 3 months. PiperOrigin-RevId: 485753078 03 November 2022, 02:13:13 UTC
5448ea6 Merge pull request #13083 from jakevdp:jet-dynamic-update PiperOrigin-RevId: 485740508 03 November 2022, 00:44:09 UTC
8b81240 Add c++ pickling for all the jax Sharding types. Needed for AOT. PiperOrigin-RevId: 485737664 03 November 2022, 00:28:38 UTC
d573ce1 Clarify the limitation of running jax2tf converted pjit models on CPUs/GPUs. PiperOrigin-RevId: 485725486 02 November 2022, 23:31:38 UTC
e1af93a Enable state effect in `cond_p` (except in `grad` and `vmap`) PiperOrigin-RevId: 485719926 02 November 2022, 23:07:58 UTC
db7eea1 add jet rule for dynamic_update_slice 02 November 2022, 21:47:11 UTC
2dc8043 Increase the shard count of scipy_stats_test because it is timing out in OSS builds. PiperOrigin-RevId: 485659146 02 November 2022, 19:12:54 UTC
b467feb [JAX] Add RunTimeError to host_local_array_to_global_array PiperOrigin-RevId: 485657586 02 November 2022, 19:06:26 UTC
27bb347 Test dynamic_arg_shardings only for '==' equality not also the default pointer equality. Also add tests which check this behavior and make sure that we don't fall back to Python PiperOrigin-RevId: 485656967 02 November 2022, 18:59:25 UTC
ef63f75 Merge pull request #13039 from skye:cache_compile_time_heuristic PiperOrigin-RevId: 485644419 02 November 2022, 18:13:52 UTC
6ed9b14 [sparse] simplify BCOO indexing implementation 02 November 2022, 17:13:35 UTC
94ba43b Don't assume that vmap doesn't introduce constants Because it doesn't hold e.g. for the batcher of ppermute. PiperOrigin-RevId: 485601414 02 November 2022, 15:27:10 UTC
d77cccf Update docker images to add python 3.11 support. PiperOrigin-RevId: 485488215 02 November 2022, 03:35:38 UTC
93bcb59 Merge pull request #13065 from sharadmv:vis-colors PiperOrigin-RevId: 485467801 02 November 2022, 01:11:06 UTC
cc51710 Add new config `jax_persistent_cache_min_compile_time_secs`. This replaces `jax_persistent_cache_min_instruction_count` introduced in https://github.com/google/jax/pull/12798, since gating on the compile time seems strictly better than gating on the instruction count (except maybe that the instruction count is more deterministic, but I don't think that's a big deal). I defaulted to 1 second as the minimum threshold based on the same flax wmt example (https://github.com/google/flax/tree/main/examples/wmt) numbers from
name | instruction_count | compile_time_secs
---- | ----------------- | -----------------
`broadcast_in_dim` | 2 | 0.01633763313
`convert_element_type` | 2 | 0.01704716682
`reshape` | 2 | 0.01730203629
`_squareit` | 2 | 0.01730823517
`broadcast_in_dim` | 2 | 0.0182030201
`convert_element_type` | 2 | 0.01982188225
`concatenate` | 2 | 0.02102327347
`true_divide` | 2 | 0.02172231674
`broadcast_in_dim` | 2 | 0.02370619774
`broadcast_in_dim` | 2 | 0.02393102646
`broadcast_in_dim` | 2 | 0.02488565445
`broadcast_in_dim` | 2 | 0.03395628929
`broadcast_in_dim` | 2 | 0.03428125381
`broadcast_in_dim` | 2 | 0.0394551754
`shift_right_logical` | 2 | 0.06500506401
`<lambda>` | 3 | 0.01793265343
`_unstack` | 5 | 0.01975226402
`_reduce_sum` | 5 | 0.0210878849
`_reduce_sum` | 5 | 0.02416801453
`_multi_slice` | 9 | 0.09065580368
`_threefry_split` | 232 | 0.09037566185
`_threefry_split` | 232 | 0.09161829948
`<unnamed wrapped function>` | 2668 | 7.701903343
`<unnamed wrapped function>` | 3455 | 17.57672167
`<unnamed wrapped function>` | 46580 | 166.2570884
`init` | 60361 | 26.35722399
`<unnamed wrapped function>` | 78010 | 3.879326344
Also adds new float config functionality. 02 November 2022, 00:56:19 UTC
3bbd5f3 Add colors to sharding visualization 01 November 2022, 23:55:06 UTC
60b9746 Merge pull request #13015 from jurahul:master PiperOrigin-RevId: 485452923 01 November 2022, 23:52:21 UTC
e881d16 Raise an error if jax_array is not enabled when using jax.device_put with a Sharding as input. PiperOrigin-RevId: 485441762 01 November 2022, 23:08:28 UTC
ef0f64e [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib. PiperOrigin-RevId: 485441349 01 November 2022, 23:01:50 UTC
bb07028 Make device_put accept a prefix tree with Sharding leaves as the second argument PiperOrigin-RevId: 485419880 01 November 2022, 21:32:55 UTC
d207194 [sparse] fix a typo in n_sparse calculation. PiperOrigin-RevId: 485376576 01 November 2022, 18:46:10 UTC
52d29e1 Fix docstring in GDA PiperOrigin-RevId: 485325853 01 November 2022, 15:36:06 UTC
d08130b Create a table of contents for jax2tf documentation. PiperOrigin-RevId: 485247319 01 November 2022, 07:35:49 UTC
08fbbc3 Skip checkify checks for pjrt C API because of a failing pjit test when jax.Array is switched on. PiperOrigin-RevId: 485190494 01 November 2022, 00:44:21 UTC
5adfb08 Add `lax.cumlogsumexp` for cumulative logsumexp operations. PiperOrigin-RevId: 485158935 31 October 2022, 22:08:52 UTC
06c1d8e Rollback of: [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib. Still breaks CUDA 11.1 PiperOrigin-RevId: 485151807 31 October 2022, 21:38:47 UTC
ca1f58e Add a new `jax.spmd_mode` config for preventing unintentional hangs and incorrect results when users pass `jax.Array`s that span across multiple processes (i.e. not fully addressable) to `jit` or jnp operations (that are jitted by default). Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array. Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array since it's a new behavior for `jit` (previously jit only worked on single device arrays). * Over time (via docs) and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`. PiperOrigin-RevId: 485075693 31 October 2022, 16:51:42 UTC
f3ddd56 Merge pull request #13035 from jakevdp:jnp-put PiperOrigin-RevId: 485075125 31 October 2022, 16:45:17 UTC
71edfc7 Merge pull request #12670 from jakevdp:numpy-check-arraylike PiperOrigin-RevId: 485073609 31 October 2022, 16:38:37 UTC
2416d15 Call _check_arraylike for jnp.linalg & jnp.fft functions 31 October 2022, 16:19:53 UTC
32a0ea8 Add `global_shards` to `jax.Array` as it exists on GDA and is being used in various places. PiperOrigin-RevId: 485065876 31 October 2022, 16:08:03 UTC
5b0cb11 Merge pull request #13040 from froystig:rng-part-bitschange-extents PiperOrigin-RevId: 484736309 29 October 2022, 13:03:09 UTC
213d2c8 integrate new (partitionable, count-space-exhaustive) counts generation 29 October 2022, 07:05:49 UTC
ef4fd14 [XLA:CPU][XLA:GPU] Remove Sharding CustomCall when num_partitions==1, or return an error when num_partitions>1 but SPMD partitioning is not enabled. This avoids a segfault when executing a CPU/GPU binary with Sharding CustomCalls. PiperOrigin-RevId: 484692861 29 October 2022, 06:22:35 UTC
8dea82e Merge pull request #13022 from mattjj:leak-checker-improvements PiperOrigin-RevId: 484640693 28 October 2022, 23:05:43 UTC
9dac458 Merge pull request #11077 from DanPuzzuoli:ode_dt_max PiperOrigin-RevId: 484639938 28 October 2022, 23:05:25 UTC
03b2691 Merge pull request #13037 from mattjj:fix-ci-failure PiperOrigin-RevId: 484639728 28 October 2022, 22:58:25 UTC
0203534 fix ci failure by skipping tests on gpu 28 October 2022, 21:39:51 UTC
63bfb87 wip bits-changing partitionable rng based on iota raveling Co-authored-by: Matthew Johnson <mattjj@google.com> 28 October 2022, 21:17:34 UTC
8bde3a0 Point to ndarray.at from docstring of unimplemented jnp.put & jnp.place 28 October 2022, 21:13:36 UTC
8562e76 jax.numpy docstrings: remove empty parameters section 28 October 2022, 21:13:29 UTC
6ebf44a make leak checker errors explain why objects are alive Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com> Co-authored-by: Roy Frostig <frostig@google.com> 28 October 2022, 21:12:17 UTC
f9e7629 Merge pull request #13010 from jakevdp:annotate-lax-numpy PiperOrigin-RevId: 484572239 28 October 2022, 17:57:48 UTC
89b240b Merge pull request #13012 from mattjj:rng-part-overgenerate PiperOrigin-RevId: 484567918 28 October 2022, 17:41:35 UTC
54967e9 Improve effect support on internal backends. PiperOrigin-RevId: 484566725 28 October 2022, 17:34:58 UTC
15b489f [typing] annotate next section of lax_numpy.py 28 October 2022, 17:14:00 UTC
c8b9280 partitionable threefry PRNG random bits implementation the cost is 2x overgeneration of bits Co-authored-by: Matthew Johnson <mattjj@google.com> 28 October 2022, 17:07:14 UTC
2cc4fd2 Add the full shape to the error message too (in pjit_check_aval_sharding) to give users full information about what is going on. PiperOrigin-RevId: 484547088 28 October 2022, 16:14:11 UTC
57eb19f Add a warning to device.live_buffers() as it is going to be deprecated with jax.Array and instruct users to use jax.live_arrays() instead. PiperOrigin-RevId: 484533292 28 October 2022, 15:11:51 UTC
1816263 Merge pull request #13023 from jakevdp:jnp-window PiperOrigin-RevId: 484528621 28 October 2022, 14:42:50 UTC
42e75a2 Merge pull request #12989 from apaszke:xmap-better-axis-checks PiperOrigin-RevId: 484524492 28 October 2022, 14:22:24 UTC
c1c22d9 [jax2tf] Improve handling of XlaSharding ops. Since TF does not support the XlaSharding op in eager mode, we make two changes: raise an error when a function with sharded arguments or results is used in a TF eager context, and omit the XlaSharding op for REPLICATED sharding. The latter improves usability, allowing lowered functions with replicated sharding to be used in TF eager context. It also results in smaller graphs. PiperOrigin-RevId: 484446391 28 October 2022, 06:38:34 UTC
4a21d41 Use `d.process_index == d.client.process_index()` to find addressable devices from a list of addressable and non-addressable devices. For GPU and TPU, this is a no-op change. For CPU, this makes a difference in multiprocess computation because if you run a `jit` on a CPU when you are on process 1 (for example), the `process_index` on the CPU device is 0. This is wrong since the CPU device is local to process 1 but the fix in the CL is a workaround until a runtime fix lands. PiperOrigin-RevId: 484403748 28 October 2022, 01:44:09 UTC
5cfc708 Remove error-prone most_recent_entry() support from lu.cache. PiperOrigin-RevId: 484382188 27 October 2022, 23:41:44 UTC
0a79fa4 jax.numpy: implement window functions in terms of lax ops Including blackman, bartlett, hamming, hanning, kaiser. Why? Previously these were implemented by computing the output on host at trace-time and embedding the result as a large constant array. Computing the results via lax operations is more in the spirit of jax.numpy. 27 October 2022, 22:47:04 UTC
51242bc jax.numpy: implement window functions in terms of lax ops Including blackman, bartlett, hamming, hanning, kaiser. Why? Previously these were implemented by embedding large constants; this should be more performant. 27 October 2022, 22:08:16 UTC
66e75ed [sparse] Update the default CUSPARSE SPMV and SPMM algorithms in jaxlib. PiperOrigin-RevId: 484351696 27 October 2022, 21:34:44 UTC
fc8f40c Internal visibility change PiperOrigin-RevId: 484340424 27 October 2022, 20:49:16 UTC
0c119cb Merge pull request #13013 from ROCmSoftwarePlatform:rocm_update_amdgpu_archs PiperOrigin-RevId: 484334622 27 October 2022, 20:27:36 UTC
f699acd Internal change PiperOrigin-RevId: 484330913 27 October 2022, 20:12:38 UTC
3e38675 Update debugging_primitives_test to not use nontrivial floating point text comparisons PiperOrigin-RevId: 484325096 27 October 2022, 19:48:11 UTC
caac5f9 Internal changes for Array integration on alternative JAX backends. PiperOrigin-RevId: 484319105 27 October 2022, 19:20:24 UTC
4d94908 Add more shards to `for_loop_test_cpu` PiperOrigin-RevId: 484309176 27 October 2022, 18:42:10 UTC
4b1fd63 Re-enable skipped test Fixes #12927 PiperOrigin-RevId: 484304818 27 October 2022, 18:25:54 UTC
9f80402 Add a default `PmapSharding` option which matches exactly `pmap`'s device placement. PiperOrigin-RevId: 484289013 27 October 2022, 17:28:25 UTC
0fce5be Improve undefined axis checks Previously we checked for out axes being a superset of the defined axes, but that's just not the right relation. In particular, out_axes of {'a'} are not a superset of defined axes {'b'}, but axis 'a' is undefined. The correct check is to verify emptiness of their difference. 27 October 2022, 17:05:26 UTC
42b2cf2 Fix result type for `InspectSharding` custom call - Use the result type that will be valid for the sharding that will be attached to this custom call. 27 October 2022, 17:01:22 UTC
6634410 [ROCm] Added gfx90a and gfx1030. 27 October 2022, 16:57:06 UTC
978dcde MHLO Pretty Print - Enhance type printing for CopyOp, ClampOp, CstrReshapeOp, ComputeReshapeShapeOp, SelectOp. Based on: https://github.com/openxla/stablehlo/pull/37 PiperOrigin-RevId: 484271777 27 October 2022, 16:27:00 UTC
994e0ac Merge pull request #12979 from jakevdp:annotate-lax-numpy-3 PiperOrigin-RevId: 484271670 27 October 2022, 16:20:28 UTC
94ac5fb Merge pull request #12997 from hawkinsp:minjaxlib PiperOrigin-RevId: 484259964 27 October 2022, 15:29:23 UTC
320d531 Increase the minimum jaxlib version to 0.3.22. The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32. 27 October 2022, 14:24:11 UTC
40c85bd [jax2tf] Uses MHLO bytecode for XlaCallModule op. Most of the changes here have to do with the fact that it is harder to inspect the converted code, since the MHLO is not in text form. This means that some tests need to be adjusted, and we are dropping an error message when the converted code uses custom calls, since the detection was based on inspecting the text of the MHLO. PiperOrigin-RevId: 484220490 27 October 2022, 11:49:08 UTC
540835f Merge pull request #12998 from jakevdp:annotate-reductions PiperOrigin-RevId: 484126361 27 October 2022, 01:41:31 UTC
9abacbd Merge pull request #9079 from NeilGirdhar:annotate_tree PiperOrigin-RevId: 484114597 27 October 2022, 00:30:06 UTC
d2df0fa Merge pull request #12996 from mattjj:tweak-jnp-canonicalize-shape PiperOrigin-RevId: 484100902 26 October 2022, 23:24:37 UTC
bf21391 [JAX] Change the default pmap() ordering to match the ordering of jax.devices() for single-process TPU jobs. PiperOrigin-RevId: 484062717 26 October 2022, 20:56:07 UTC
709ffd7 [typing] annotate jax.numpy reduction operations 26 October 2022, 20:33:15 UTC
a08ced8 Merge pull request #12991 from jakevdp:fix-faq PiperOrigin-RevId: 484042682 26 October 2022, 19:39:00 UTC
7e34181 [dynamic-shapes] tweak jnp.canonicalize_shape logic The idea with jnp.canonicalize_shape is that it handles non-tuple shapes, i.e. intended to be scalar-like arguments like Python builtin ints or numpy scalar types or 0D arrays. To do that, it checks numpy.ndim(shape) == 0. But numpy.ndim might attempt to convert its argument to a numpy.ndarray, which breaks when the argument is a tuple with Tracers inside! Instead, let's just check if the argument is one of the canonical sequence types (list or tuple) and if so then not even call numpy.ndim. 26 October 2022, 19:01:49 UTC
0814770 Fix FP8 compilation failure in jaxlib stemming from the CUDA/ROCM merge. PiperOrigin-RevId: 484026031 26 October 2022, 18:40:14 UTC
bdde0f0 [mesh_utils] Support single-core 2D meshes PiperOrigin-RevId: 484026013 26 October 2022, 18:32:50 UTC
e9194b2 FAQ: fix JIT numerics discussion 26 October 2022, 18:30:17 UTC
9c0f876 [typing] annotate jnp.pad 26 October 2022, 18:09:52 UTC
db2c8c1 Merge pull request #12994 from hawkinsp:docfix PiperOrigin-RevId: 484015353 26 October 2022, 17:55:14 UTC
b742b04 Annotate tree_util 26 October 2022, 17:38:38 UTC
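Note on commit 0fce5be (Improve undefined axis checks): the message argues that a superset test between out axes and defined axes is the wrong relation, and that the correct check is whether their set difference is empty. The following is a minimal sketch of that reasoning in plain Python sets, not the actual xmap code; the function name `undefined_out_axes` is hypothetical and chosen only for illustration.

```python
def undefined_out_axes(out_axes, defined_axes):
    """Return the axis names used in out_axes that were never defined.

    Hypothetical helper: an empty result means every out axis is defined.
    """
    return set(out_axes) - set(defined_axes)


# The example from the commit message: out_axes {'a'}, defined axes {'b'}.
out_axes = {"a"}
defined_axes = {"b"}

# The superset relation is False here, so it says nothing about 'a'.
print(out_axes >= defined_axes)  # False

# The set difference directly exposes the undefined axis.
print(undefined_out_axes(out_axes, defined_axes))  # {'a'}
```

A caller would presumably raise an error whenever the returned set is non-empty, which also makes it easy to name the offending axes in the message.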