https://github.com/google/jax

93d5b63 Replace jax2tf.shape_poly with export.shape_poly. The shape_poly code has recently been moved and we want to remove the backwards-compatibility shims from jax2tf.shape_poly. PiperOrigin-RevId: 589025336 08 December 2023, 07:23:53 UTC
7af1c14 [Pallas/Mosaic] Lower `lax.fori_loop`s to *rolled* loops. Note that this is a breaking change! Current uses of `lax.fori_loop` inside kernels should now pass `unroll=True` (loops were previously unrolled by default, and this change switches that default). PiperOrigin-RevId: 589017485 08 December 2023, 06:55:12 UTC
1701b73 Update XLA dependency to use revision http://github.com/openxla/xla/commit/3997706f5b2361b9f72638fac1c69e8cd05bf368. PiperOrigin-RevId: 589013162 08 December 2023, 06:31:26 UTC
ffb115b Add annotation for donation. This configures XLA's AddBufferDonor directly via donated_args instead of first configuring input_output_aliases. This is best effort anyway, but it works on TPU. PiperOrigin-RevId: 588974683 08 December 2023, 02:46:30 UTC
1189d61 [Pallas] Fix batching rule for kernels with scratch inputs. Scratch inputs do not need a batching dimension. PiperOrigin-RevId: 588921137 07 December 2023, 23:10:12 UTC
e423347 Declare magic port number for jax.distributed.initialize in cloud TPU environments. PiperOrigin-RevId: 588920806 07 December 2023, 23:02:04 UTC
c2f8e18 Merge pull request #18862 from superbobry:pp-improvement PiperOrigin-RevId: 588906413 07 December 2023, 22:15:33 UTC
397a44e Merge pull request #18864 from 8bitmp3:jax-docs-pytrees-101 PiperOrigin-RevId: 588859487 07 December 2023, 19:40:19 UTC
78e0e6b Internal change PiperOrigin-RevId: 588856820 07 December 2023, 19:30:54 UTC
5eebb91 Upgrade JAX pytrees doc 07 December 2023, 19:18:04 UTC
3e68ad1 Make `GetParameterLayouts` return the correct thing when both CPU and TPU executables exist. PiperOrigin-RevId: 588852594 07 December 2023, 19:18:02 UTC
c4239fc Merge pull request #18797 from jakevdp:factorial PiperOrigin-RevId: 588809589 07 December 2023, 16:59:57 UTC
ea158d3 Print pjit name= before other params. The jaxpr sometimes gets pretty big, making it hard to see the name. 07 December 2023, 16:54:07 UTC
d5edbb9 Merge pull request #18860 from gnecula:poly_min_max PiperOrigin-RevId: 588749760 07 December 2023, 12:52:10 UTC
0a02d83 [shape_poly] Add simpler APIs max_dim and min_dim, improve >= 0. Add core.max_dim and core.min_dim as nicer wrappers around core.non_negative_dim. Also improve the completeness of the heuristics for deciding >= 0, and add more tests. 07 December 2023, 08:41:47 UTC
8541124 Update XLA dependency to use revision http://github.com/openxla/xla/commit/7fcd6896a04244f4f05a5daffc591bd45acbc90d. PiperOrigin-RevId: 588654203 07 December 2023, 06:23:43 UTC
6aa74cd Merge pull request #18854 from mattjj:shmap-special-rules PiperOrigin-RevId: 588635591 07 December 2023, 04:51:42 UTC
62abe11 Merge pull request #18856 from mattjj:mesh-assertion-error PiperOrigin-RevId: 588634361 07 December 2023, 04:43:43 UTC
045a9ef Disable all_gather_test on non-v5e TPUs. Also consolidate logic for selectively enabling TPU tests on TPU versions. PiperOrigin-RevId: 588597889 07 December 2023, 01:47:17 UTC
64cb53f improve an error message during Mesh creation 07 December 2023, 00:43:36 UTC
7fd6c8d Update XLA dependency to use revision http://github.com/openxla/xla/commit/5c250f76726e53edf9ed09698631e2d3faf1cef5. PiperOrigin-RevId: 588578604 07 December 2023, 00:29:27 UTC
f177670 Merge pull request #18853 from mattjj:debug-callback-type-error PiperOrigin-RevId: 588570193 06 December 2023, 23:59:49 UTC
5641d81 [shard-map] add lax.special rep checking rules 06 December 2023, 23:05:36 UTC
7608cce improve a debug.callback type error message for idiots (i am the idiot) 06 December 2023, 22:41:52 UTC
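For reference, here is a minimal sketch of correct `jax.debug.callback` usage, assuming a recent JAX install. The callable goes first, followed by the values it should receive. The helper `record` is hypothetical. Passing a non-callable first argument is expected to raise `TypeError`, which is the case the improved message is about.

```python
import jax
import jax.numpy as jnp

received = []

def record(x):
    # Hypothetical host-side callback: keep whatever value it is handed.
    received.append(x)

@jax.jit
def f(x):
    jax.debug.callback(record, x)  # callable first, then its arguments
    return x + 1

out = f(jnp.float32(1.0))
jax.effects_barrier()  # wait until pending callbacks have run

try:
    jax.debug.callback(42)  # not callable: should be rejected with TypeError
    raised = False
except TypeError:
    raised = True
```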
5bdc303 Merge pull request #18850 from jakevdp:nonzero-zerodim PiperOrigin-RevId: 588534749 06 December 2023, 21:58:30 UTC
1dd68c5 Merge pull request #18845 from jakevdp:jaxpr-repr PiperOrigin-RevId: 588522842 06 December 2023, 21:18:59 UTC
5196004 jnp.nonzero: deprecate zero-dimensional inputs 06 December 2023, 20:57:25 UTC
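One way to avoid the deprecated zero-dimensional case is to promote the input explicitly with `jnp.atleast_1d` before calling `jnp.nonzero`. This is a sketch of that workaround, not a prescribed migration path.

```python
import jax.numpy as jnp

x = jnp.array(5)  # zero-dimensional input
# Make the input 1-D explicitly instead of relying on 0-d handling.
(idx,) = jnp.nonzero(jnp.atleast_1d(x))
```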
fe6e195 Merge pull request #18824 from NeilGirdhar:precisionlike PiperOrigin-RevId: 588501503 06 December 2023, 20:03:21 UTC
9f85beb Expose PrecisionLike. This is used in client code such as: https://github.com/search?q=repo%3Agoogle%2Fflax%20%20PrecisionLike&type=code 06 December 2023, 19:41:22 UTC
bcf3255 Merge pull request #18846 from hawkinsp:jaxfix PiperOrigin-RevId: 588490406 06 December 2023, 19:25:39 UTC
5601873 Add register_jax_dialects to jaxlib wheel. Fixes build breakage. 06 December 2023, 19:07:04 UTC
3403631 [Pallas/Mosaic] Fixes for interpret mode on TPU: scratch space support; trivial lowering for trace_start/end. PiperOrigin-RevId: 588482689 06 December 2023, 19:03:05 UTC
c2a0530 jaxpr: improve printed repr when eqn has no return values 06 December 2023, 18:45:24 UTC
eba08ed Merge pull request #18844 from jakevdp:dep-device-buffer PiperOrigin-RevId: 588476984 06 December 2023, 18:45:19 UTC
35b8440 Deprecate arr.device_buffer and arr.device_buffers 06 December 2023, 18:20:29 UTC
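Code that read `arr.device_buffer` or `arr.device_buffers` can usually switch to `arr.addressable_shards`, which is the same replacement the test change in c4d2fc7 makes. The sketch below assumes a single-device, unsharded array. Each shard exposes its device, its index into the global array, and its data.

```python
import jax.numpy as jnp

arr = jnp.arange(4.0)

# Replacement for arr.device_buffers: one entry per locally addressable shard.
pieces = [(shard.device, shard.index, shard.data) for shard in arr.addressable_shards]
```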
9e3a8fa Merge pull request #18833 from jakevdp:linalg-shapes PiperOrigin-RevId: 588458694 06 December 2023, 17:49:26 UTC
4bdcb11 Merge pull request #18843 from hawkinsp:unbreak PiperOrigin-RevId: 588452358 06 December 2023, 17:29:56 UTC
45905fa jnp.solve: handle corner case in input shapes 06 December 2023, 17:28:31 UTC
78543f7 Add jax.extend.mlir. Some users of JAX want to use the MLIR dialects defined in jaxlib. In particular, these need to be used by custom lowering rules. Add a semi-public (jax.extend) API to access these, rather than having them use jax._src.lib.mlir. PiperOrigin-RevId: 588448489 06 December 2023, 17:16:43 UTC
1c80b36 Remove stale reference to _site_initialize_0 in wheel build script. 06 December 2023, 17:12:15 UTC
d95084d Use an explicit MLIR dialect registration, rather than _site_initialize_0. Remove some special case handling of the SCF dialect, use upstream utilities instead. PiperOrigin-RevId: 588433245 06 December 2023, 16:19:55 UTC
ad14478 Further enhance the seed. This choice passes more tests for more targets. PiperOrigin-RevId: 588395200 06 December 2023, 13:47:20 UTC
f691fe4 Update XLA dependency to use revision http://github.com/openxla/xla/commit/a6e6c1f6a53d4a23451c649110519c7ba8581bf9. PiperOrigin-RevId: 588298720 06 December 2023, 06:44:38 UTC
80e91fd Disabling running all_gather test on non-TPU platforms PiperOrigin-RevId: 588180749 05 December 2023, 21:57:19 UTC
dc469c3 Merge pull request #18676 from 8bitmp3:jax-docs-autodiff-101-201 PiperOrigin-RevId: 588172128 05 December 2023, 21:29:10 UTC
7090347 Merge pull request #18739 from jakevdp:ci-precommit PiperOrigin-RevId: 588161526 05 December 2023, 20:54:44 UTC
f25a51e Change bazel config name "rbe_cpu_linux_py39" to "rbe_cpu_linux_py3.9" to be consistent with cuda bazel build configs. Change the bazel configs for other Python versions similarly. PiperOrigin-RevId: 588100957 05 December 2023, 17:42:14 UTC
720ff42 [bazel] Add a macro if_building_jaxlib() to guard dependencies that should only be present if building jaxlib. Cleanup only, NFC intended. PiperOrigin-RevId: 588074047 05 December 2023, 16:05:17 UTC
9d35b90 [XLA:Mosaic] Support expanding lane dim in shapecast: (..., 128) -> (..., m * 128) and handle relayout from (1, 128) to (8, 128) for more general cases. PiperOrigin-RevId: 588024159 05 December 2023, 12:25:10 UTC
a31129a [Pallas/TPU] Add all gather kernel example PiperOrigin-RevId: 587963496 05 December 2023, 08:11:13 UTC
91ef37b Merge pull request #18818 from gnecula:fix_indexing_with_jax_array PiperOrigin-RevId: 587962024 05 December 2023, 08:03:09 UTC
ec46058 Fix indexing with slices when the slice elements are jax.Array. This fixes a bug introduced in #18679, for the case when some elements of the slice are `jax.Array`. We add a new test also. 05 December 2023, 07:02:50 UTC
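A small illustration of the case this fix covers: slice bounds that are `jax.Array` scalars rather than Python ints. This sketch assumes eager (non-jit) execution, where the bounds are concrete. Under `jit` they would be traced, and non-static slice bounds are not supported there.

```python
import jax.numpy as jnp

x = jnp.arange(10)
start, stop = jnp.array(2), jnp.array(5)
# Concrete integer jax.Array bounds behave like the equivalent Python ints.
y = x[start:stop]
```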
7a3e214 Update XLA dependency to use revision http://github.com/openxla/xla/commit/a83901622bca41cb58594e81f752a1609f4417df. PiperOrigin-RevId: 587936503 05 December 2023, 05:48:24 UTC
7fa0f46 [bazel] Add a BUILD file for jax/extend, and add more granular targets for individual pieces of extend. In general we'd like to use more granular BUILD targets rather than larger monolithic targets. If nothing else, they interact better with pytype. This change is in preparation for adding the JAX MLIR bindings to jax.extend, since they are something that JAX users sometimes need, especially for defining custom ops. PiperOrigin-RevId: 587893573 05 December 2023, 01:48:50 UTC
f29ec4e Change seed for a test. We got unlucky and hit a seed which happens to fail the KS test. PiperOrigin-RevId: 587885112 05 December 2023, 01:06:17 UTC
46bad2b Upgrade JAX Autodiff 101 05 December 2023, 00:00:13 UTC
a9bfbd3 Finish jax and jaxlib 0.4.21 release PiperOrigin-RevId: 587866580 04 December 2023, 23:51:58 UTC
193eb12 Use sys.version_info to guard Python 3.11 only code. This makes pytype happier since it understands sys.version_info, but didn't understand the previous hasattr() test. PiperOrigin-RevId: 587846307 04 December 2023, 22:48:46 UTC
aa27048 Merge pull request #18809 from jakevdp:fix-numpy-warning PiperOrigin-RevId: 587845552 04 December 2023, 22:40:51 UTC
150ab68 Merge pull request #18580 from 8bitmp3:jax-docs-new-installation PiperOrigin-RevId: 587840299 04 December 2023, 22:24:34 UTC
8b74b93 Test: fix casting warning in betainc test 04 December 2023, 22:20:14 UTC
b1a8bc6 Upgrade JAX Installation doc 04 December 2023, 21:42:13 UTC
c4d2fc7 Replace device_buffers with addressable_shards in test because device_buffers is deprecated PiperOrigin-RevId: 587825636 04 December 2023, 21:36:12 UTC
baa7756 Use scoped disable_jit() in dynamic_api_test. This test was leaving jit disabled, affecting other tests. PiperOrigin-RevId: 587803847 04 December 2023, 20:20:36 UTC
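The scoped form of `jax.disable_jit()` is a context manager, so jit is active again once the `with` block exits. This sketch makes that observable by counting how often the Python body runs. With jit disabled the body runs on every call. With jit active it runs once at trace time per input signature. The function `double` is hypothetical.

```python
import jax
import jax.numpy as jnp

calls = []

@jax.jit
def double(x):
    calls.append(1)  # executes per call without jit, per trace with jit
    return x * 2

with jax.disable_jit():
    double(jnp.float32(1.0))
    double(jnp.float32(2.0))

# Outside the context, jit is back on: traced once, then served from cache.
double(jnp.float32(3.0))
double(jnp.float32(4.0))
```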
b04fd31 Add option to pass in `unroll=True/False` into `scan` and `fori_loop`. PiperOrigin-RevId: 587795364 04 December 2023, 19:54:50 UTC
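A minimal sketch of the boolean `unroll` option on `lax.scan`. Here `unroll=True` fully unrolls the loop and `unroll=False` keeps it rolled. The integer form of `unroll` is still accepted. Both settings should produce the same running sum.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    carry = carry + x
    return carry, carry  # new carry, plus a per-step output (running sum)

xs = jnp.arange(5.0)
total_u, partial_u = lax.scan(step, jnp.float32(0.0), xs, unroll=True)
total_r, partial_r = lax.scan(step, jnp.float32(0.0), xs, unroll=False)
```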
5b97960 Disable sanitizer builds of shape_poly_test. These take a very long time and sometimes time out, so it's probably not worth running them in CI. PiperOrigin-RevId: 587768399 04 December 2023, 18:42:56 UTC
54e3b76 Add support for unrolling to `lax.fori_loop` PiperOrigin-RevId: 587767613 04 December 2023, 18:34:53 UTC
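A sketch of the `unroll` keyword on `lax.fori_loop`, which is the flag the Pallas/Mosaic entry 7af1c14 asks kernel authors to pass. This example runs outside Pallas. It assumes the loop bounds are static Python ints, since unrolling needs a known trip count.

```python
import jax.numpy as jnp
from jax import lax

def body(i, acc):
    return acc + i  # accumulate the loop index

rolled = lax.fori_loop(0, 10, body, jnp.int32(0))
unrolled = lax.fori_loop(0, 10, body, jnp.int32(0), unroll=True)
```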
5942e15 Prepare for 0.4.21 release PiperOrigin-RevId: 587767502 04 December 2023, 18:26:48 UTC
d91c13e Merge pull request #18795 from gnecula:test_export_grad PiperOrigin-RevId: 587730171 04 December 2023, 16:32:27 UTC
8a2d4a0 [export] Add and fix a test for exporting higher-order gradients with sharding. There was a test for export with gradients, we changed the test to (a) export 2nd order gradient also, and (b) to export both with a mesh context and without a mesh context (using NamedSharding). This test currently fails, only in the case when we do NOT have a mesh context, as explained below: When exporting gradient functions, we first export the primal functions and we use the in/out-shardings to construct shardings of the gradient function. Since Exported shardings now contain only HloSharding objects, and to lower the gradient function we must use `pjit(vjp(f)).lower()`, we construct GSPMDSharding objects using the current devices and the HloSharding object from the Exported primal. However, these objects do not have the `_original_sharding` attribute. Later in `pjit._resource_typing_pjit` we attempt to `parse_flatten_op_sharding` using the mesh context (which is empty). This fails. This PR contains one workaround, to skip `parse_flatten_op_sharding` if the physical mesh of the `resource_env` is empty. Another, probably better, solution is to ensure that `resource_env` is `None` when there is no mesh context. That seemed reasonable, but currently the code returns an empty mesh from the resource_env if there is no mesh context. Changing this would have effects in more parts of the code, so I have not done it here, but it may be worth doing. 04 December 2023, 16:07:23 UTC
1d95e79 Disable export_harnesses_multi_platform_test under sanitizers. This test appears to hit some sort of LLVM bug on Sapphire Rapids CPUs. PiperOrigin-RevId: 587719850 04 December 2023, 15:54:35 UTC
5e0993c Merge pull request #18794 from olupton:qualname PiperOrigin-RevId: 587696743 04 December 2023, 14:23:09 UTC
70d0f60 Add special.factorial function 04 December 2023, 14:13:14 UTC
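A minimal usage sketch, assuming the function is exposed as `jax.scipy.special.factorial` and mirrors the SciPy signature. With the default `exact=False`, results are floating point, so they are compared with a tolerance rather than exactly.

```python
import jax.numpy as jnp
from jax.scipy.special import factorial

n = jnp.array([0.0, 1.0, 3.0, 5.0])
vals = factorial(n)  # approximately [1, 1, 6, 120]
```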
3c0c6b7 Use qualified name if possible. 04 December 2023, 13:19:38 UTC
a137edc Update XLA dependency to use revision http://github.com/openxla/xla/commit/96c4f8749b6521ff3e2670d168c9095a5f6323c5. PiperOrigin-RevId: 587591208 04 December 2023, 05:18:37 UTC
bf08411 Update XLA dependency to use revision http://github.com/openxla/xla/commit/6f22d6a8d45ce00c2d6654fb5f99ca7f0f094b51. PiperOrigin-RevId: 587406410 03 December 2023, 05:40:59 UTC
61e79cd Merge pull request #18786 from gnecula:test_export_effects PiperOrigin-RevId: 587333913 02 December 2023, 18:49:08 UTC
b51b80e Merge pull request #18761 from gnecula:export_sharding PiperOrigin-RevId: 587330573 02 December 2023, 18:24:47 UTC
bd7c1aa [export] Improve the testing of exporting with effects. We now test that when we call an Exported from a computation that already uses effects, the effects from the calling computation are identified with the events from the called Exported. 02 December 2023, 18:18:22 UTC
3eb3e2d [export] Simplify the handling of shardings in Exported. Previously, Exported contained tuples of `XlaCompatibleSharding` for the input and output shardings. These shardings contain references to JAX devices, which is too much for exporting purposes and in fact it gets in the way when we want to serialize the Exported. We change Exported to carry `xla_client.HloSharding` instead, which conveniently can be serialized to proto. We use the value `None` to denote an unspecified sharding. We also add `nr_devices` and then for exporting purposes we can construct actual `XlaCompatibleSharding` when we need to. 02 December 2023, 18:10:24 UTC
b822801 Merge pull request #18783 from gnecula:fix_indexing PiperOrigin-RevId: 587328492 02 December 2023, 18:09:29 UTC
32fb1b4 Remove the ml_program MLIR dialect from jaxlib. Jax isn't using this, and in fact our code to build this wasn't including the C++ parts, so it was broken anyway. Remove it until someone actually needs it for something. PiperOrigin-RevId: 587323808 02 December 2023, 17:29:39 UTC
d2f6261 Fix bug in indexing with slices that overflow, and add tests. This bug was introduced in #18679, and was not caught in unit tests because we were not testing cases when the slice needs to be clamped. 02 December 2023, 14:47:06 UTC
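The cases this fix covers are slices whose bounds fall outside the array and must be clamped. This sketch checks that behavior against NumPy, which JAX follows for basic slicing: out-of-range starts and stops are clamped, and an empty range yields an empty result.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(10)
ref = np.arange(10)

tail = x[5:100]   # stop past the end is clamped to the length
head = x[-100:3]  # start before the beginning is clamped to 0
empty = x[8:2]    # start >= stop gives an empty slice
```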
0418370 Put the input creation via `jnp` outside the pjit cpp cache miss counting code. PiperOrigin-RevId: 587256278 02 December 2023, 08:31:24 UTC
965ba05 Rewrite `tf.aliasing_output` when wrapping the main StableHLO function. When the original main function has tokens and the wrapped main function does not, there are fewer outputs in the wrapped main than the original main. This is problematic for `tf.aliasing_output`, which is an argument attribute that stores result indexes to which the argument can alias. This CL makes the wrapper main creation rewrite `tf.aliasing_output` according to the new result indexes. The newly added test verifies that the aliasing indexes are correct across all supported serialization versions. Confirmed that the test fails without changes in export.py (versions 6, 7, and 8 fail and 9 passes). PiperOrigin-RevId: 587237842 02 December 2023, 06:42:37 UTC
1aab108 Update XLA dependency to use revision http://github.com/openxla/xla/commit/33c73ddc733b38245f1e2ca46f092119615fb386. PiperOrigin-RevId: 587237109 02 December 2023, 06:34:53 UTC
f0bc7e0 Reverts f0382a5838f4526d21631e804f6fe576bfc3f97e PiperOrigin-RevId: 587231484 02 December 2023, 06:06:33 UTC
86661c8 Re-enable passing tests on TPU. These have been working for a while. PiperOrigin-RevId: 587199000 02 December 2023, 03:26:58 UTC
e61f7a8 Merge pull request #18757 from jakevdp:astype PiperOrigin-RevId: 587167285 02 December 2023, 01:09:02 UTC
f5f885d Merge pull request #18774 from hawkinsp:nightlyci PiperOrigin-RevId: 587150147 02 December 2023, 00:05:58 UTC
595117b Add a test to check if arr.delete() is idempotent. PiperOrigin-RevId: 587121346 01 December 2023, 22:28:51 UTC
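The property under test is that deleting an array twice is safe. This sketch shows it with a plain single-device array: the second `delete()` call should be a no-op rather than an error, and `is_deleted()` reports the final state.

```python
import jax.numpy as jnp

x = jnp.arange(3)
x.delete()
x.delete()  # second call on an already-deleted array: expected no-op
```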
7274960 Remove nightly NVIDIA GPU multiprocess CI. This CI seems to be dead. 01 December 2023, 22:13:24 UTC
a999120 Improve error message when cudnn is not found. We infer a missing cudnn if cudnnGetVersion() returns 0, since the stub implementation in TSL will do that if the library isn't found (https://github.com/openxla/xla/blob/10a378f49978aa4ee4564ceb105d33694fd48202/third_party/tsl/tsl/cuda/cudnn_stub.cc#L58). PiperOrigin-RevId: 587056454 01 December 2023, 18:52:48 UTC
50c7223 Fix Windows build failure. The TPU extension didn't build because the MLIR Python binding code requires pybind11 to be included first on Windows, per https://github.com/llvm/llvm-project/blob/9584f5834499e6093797d4a28fde209f927ea556/mlir/include/mlir-c/Bindings/Python/Interop.h#L24 PiperOrigin-RevId: 587049246 01 December 2023, 18:31:53 UTC
95bc2ba Inline sigmoid, isfinite, and isnan in jaxprs. In the common case (real values) these are all single-expression jaxprs themselves, so putting them out of line just makes things more verbose. There's no reason to include stuff like this in a jaxpr:
```
cxd:bool[8,16] = pjit[
  jaxpr={ lambda ; cxe:f32[8,16]. let
      cxf:bool[8,16] = is_finite cxe
    in (cxf,) }
  name=isfinite
] cxc
```
PiperOrigin-RevId: 587047955 01 December 2023, 18:23:56 UTC
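To check the effect of this change, inspect the jaxpr for `jnp.isfinite` on a real-valued input. This sketch assumes a JAX version that includes the change. The `is_finite` primitive should then appear directly, with no nested call equation named `isfinite` wrapped around it.

```python
import jax
import jax.numpy as jnp

jaxpr = jax.make_jaxpr(jnp.isfinite)(jnp.ones((8, 16), jnp.float32))
text = str(jaxpr)  # printed jaxpr, useful for eyeballing the structure
```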
ada5fe5 Remove numpy-dispatch CI job & simplify build specification. The numpy-dispatch approach has been superseded by the Python Array API (tracked for JAX in https://github.com/google/jax/issues/18353). While we're here, we'll reduce the GitHub CI to only two jobs: the oldest and newest supported Python versions. Other versions can be covered by Kokoro. PiperOrigin-RevId: 587041291 01 December 2023, 18:03:15 UTC
b3c579e Merge pull request #18762 from gnecula:poly_getitem_next PiperOrigin-RevId: 587018677 01 December 2023, 16:33:55 UTC
65fca0e [shape_poly] Add heuristics for deciding >= 0. The rules for deciding inequalities of symbolic expressions are incomplete. Here we add two heuristics that help decide the bounds checking of indices computed for indexing with slices: To decide whether an expression that contains `non_negative(e)` is >= 0, it is sufficient to show that the expression is >= 0 if we replace the `non_negative(e)` with `0` and with `e`. To decide whether `floordiv(e, k)` is >= 0, when `k >= 0`, it is sufficient to show that `e` is >= 0. These are sufficient for the bounds checking that JAX is doing internally, but may not be for the cases when the user program does index computations using those operators. This enables us to re-enable the shape_poly indexing tests. 01 December 2023, 11:55:42 UTC
54fee48 Cast in/out shardings to tuple before passing to `Exported` ctor. PiperOrigin-RevId: 586951567 01 December 2023, 10:52:26 UTC
e60aa3b Merge pull request #18679 from gnecula:poly_getitem2 PiperOrigin-RevId: 586902301 01 December 2023, 07:07:31 UTC