https://github.com/google/jax

Revision Message Commit Date
3d82e2d Hacky fix to GDA ckpt for multiopt PiperOrigin-RevId: 460531663 12 July 2022, 21:17:11 UTC
a24100d Merge pull request #11447 from jakevdp:empty-impl PiperOrigin-RevId: 460543062 12 July 2022, 20:45:28 UTC
4d1c6df Merge pull request #11469 from jakevdp:fix-rem-jvp PiperOrigin-RevId: 460517781 12 July 2022, 18:53:27 UTC
67858ff Merge pull request #11459 from jakevdp:clip-deprecation PiperOrigin-RevId: 460496031 12 July 2022, 17:26:03 UTC
188cb3f Merge pull request #11453 from jakevdp:docutils-version PiperOrigin-RevId: 460492785 12 July 2022, 17:14:05 UTC
daf6e3b BUG: fix jvp rule for lax.rem 12 July 2022, 16:50:42 UTC
3eff9d1 Internal change PiperOrigin-RevId: 460434859 12 July 2022, 12:21:20 UTC
9b54660 [jax2tf] Use tf.nest.flatten instead of tf.keras.backend.flatten. PiperOrigin-RevId: 460395701 12 July 2022, 07:54:06 UTC
82e81a0 Define PyTreeDef on pytree module rather than its parent PiperOrigin-RevId: 460342835 12 July 2022, 00:57:36 UTC
153b6ae Fix limitations-of-call_tf GitHub link typo: call_tf was misspelled as "call-tf" in multiple places. PiperOrigin-RevId: 460335218 12 July 2022, 00:12:59 UTC
0bc8f8a * Check if the device assignment is the same across input and output shardings. * Allow mixed inputs only if the sharding matches what is specified in in_axis_resources. PiperOrigin-RevId: 460326054 11 July 2022, 23:27:11 UTC
833a6ed jnp.ndarray.clip: remove deprecated arguments 11 July 2022, 23:02:31 UTC
11896b6 Merge pull request #11429 from sharadmv:for-loop PiperOrigin-RevId: 460318883 11 July 2022, 22:52:42 UTC
d109c60 Merge pull request #11451 from jakevdp:doc-execute-promotion PiperOrigin-RevId: 460305621 11 July 2022, 21:51:08 UTC
9e16efa Integrate LLVM at llvm/llvm-project@71c9757474c3 Updates LLVM usage to match [71c9757474c3](https://github.com/llvm/llvm-project/commit/71c9757474c3) PiperOrigin-RevId: 460299215 11 July 2022, 21:21:09 UTC
d8d836b Fix execution issues in type_promotion doc 11 July 2022, 21:16:11 UTC
b405f17 CI: upgrade doc requirements 11 July 2022, 21:08:53 UTC
8f09606 Make `jaxpr_has_pmap` work for other primitives too PiperOrigin-RevId: 460286042 11 July 2022, 20:24:16 UTC
9d610e2 Add loop invariant residual fixpoint test 11 July 2022, 20:10:03 UTC
183b9e4 Merge pull request #11397 from jakevdp:diagonal-err PiperOrigin-RevId: 460236416 11 July 2022, 16:52:19 UTC
6c478ed Improve documentation of jnp.empty() and jnp.empty_like() 11 July 2022, 16:21:39 UTC
cc42c80 Merge pull request #11406 from jakevdp:bcoo-add-batchdim PiperOrigin-RevId: 460226570 11 July 2022, 16:06:42 UTC
df74907 Merge pull request #11440 from google:random-docstring-fix PiperOrigin-RevId: 460220463 11 July 2022, 15:36:00 UTC
e19d026 Merge pull request #11442 from hawkinsp:shards PiperOrigin-RevId: 460211938 11 July 2022, 14:52:06 UTC
64e0b5d Increase bazel sharding of GPU tests. Reduces the maximum time for some test shards to avoid flaky timeouts. 11 July 2022, 14:19:43 UTC
64eb46a Fix RST formatting in random.py docstring 11 July 2022, 09:51:35 UTC
b666f66 Rollback of HCB GPU custom call due to internal failures PiperOrigin-RevId: 460079787 10 July 2022, 20:05:27 UTC
5910cdc Print the repr of device PiperOrigin-RevId: 459986696 10 July 2022, 00:08:44 UTC
ed51c65 Merge pull request #11405 from mattjj:djax-vmap PiperOrigin-RevId: 459958155 09 July 2022, 17:38:39 UTC
5b82ba7 [dynamic-shapes] start basic vmap compatibility 09 July 2022, 17:03:40 UTC
09ba51f Move _get_array_mapping from gda.py to pxla.py PiperOrigin-RevId: 459891853 09 July 2022, 04:38:06 UTC
df993ea Merge pull request #11410 from sharadmv:for-loop PiperOrigin-RevId: 459879694 09 July 2022, 02:37:57 UTC
bff71b2 Add loop-invariant residual optimization for `for` 09 July 2022, 01:54:51 UTC
66ab792 Merge pull request #11383 from YouJiacheng:Enable-HCB-customCall-implementation-on-GPU PiperOrigin-RevId: 459872063 09 July 2022, 01:23:16 UTC
847c04f [mhlo] CosOp -> CosineOp Aligns the op class name with the mnemonic PiperOrigin-RevId: 459852934 08 July 2022, 23:04:14 UTC
5f7018a [mhlo] SinOp -> SineOp Aligns the op class name with the mnemonic PiperOrigin-RevId: 459830783 08 July 2022, 21:08:12 UTC
1bb1fe0 Remove workaround for rank-0 zarr chunk layout bug in TensorStore This has now been fixed in TensorStore. PiperOrigin-RevId: 459824051 08 July 2022, 20:35:29 UTC
dac310c Merge pull request #11421 from jakevdp:scalar-meta-nocopy PiperOrigin-RevId: 459823335 08 July 2022, 20:30:20 UTC
bb2c5f1 Resolve TODOs and add some more checks for the jax.Array path. PiperOrigin-RevId: 459808511 08 July 2022, 19:19:19 UTC
a2f2d1f [mhlo] ConvOp -> ConvolutionOp Aligns the op class name with the mnemonic PiperOrigin-RevId: 459808502 08 July 2022, 19:13:51 UTC
7c70783 Enable CustomCall implementation on GPU 08 July 2022, 18:29:08 UTC
e19df1a Use asarray rather than array in ScalarMeta Why? This will make it so that jnp.int32(x) and friends no longer insert a gratuitous copy_p operation in the jaxpr 08 July 2022, 18:16:40 UTC
5285a15 Merge pull request #11419 from hawkinsp:jaxlibcleanup PiperOrigin-RevId: 459780016 08 July 2022, 17:04:39 UTC
229ddec * Remove AUTO from MeshPspecSharding and treat it like _UNSPECIFIED singleton value. * Support partial mentions of AUTO which is supported by GDA currently and used in pax. Added tests for all of this. * As a consequence of this, I lifted the restriction on not providing `in_axis_resources` to pjit under `config.jax_array`. * Made all auto sharding tests parameterized to test both gda and array. PiperOrigin-RevId: 459776152 08 July 2022, 16:45:23 UTC
5a7bedc Increase shard_count for sparse_test_gpu to 20. https://github.com/google/jax/commit/1918d39765f6775de778625ecc4f197b2e104a6f updated the wrong test! This test is close to the timeout in the GPU CI and flakes sometimes. PiperOrigin-RevId: 459762867 08 July 2022, 15:30:26 UTC
41b015a Remove stale code from jax/_src/lib/__init__.py Remove inaccurate/stale __all__. Remove unused alias _xla_extension_version. 08 July 2022, 15:09:58 UTC
928f22c Merge pull request #11418 from hawkinsp:bzl PiperOrigin-RevId: 459754677 08 July 2022, 14:38:13 UTC
a48f4e1 Change Bazel test rules to generate per-backend test suites. 08 July 2022, 14:19:05 UTC
55dcbec Merge pull request #11407 from hawkinsp:minver PiperOrigin-RevId: 459740984 08 July 2022, 13:04:47 UTC
bc9c4b7 Adjust docs to account for what the actual current RNG behavior is PiperOrigin-RevId: 459712928 08 July 2022, 09:55:36 UTC
7ffedb5 Merge pull request #11400 from jakevdp:deprecate-treeutil PiperOrigin-RevId: 459681801 08 July 2022, 06:05:35 UTC
34fea3d Merge pull request #11408 from hawkinsp:sparseshard PiperOrigin-RevId: 459647331 08 July 2022, 01:23:20 UTC
1918d39 Increase number of shards for GPU sparse_test to 20. 08 July 2022, 01:14:25 UTC
a3e8ae4 Merge pull request #11388 from jakevdp:fix-bool-weak PiperOrigin-RevId: 459641851 08 July 2022, 00:43:17 UTC
0b4b0ba Update minimum jaxlib version to 0.3.14. 08 July 2022, 00:36:02 UTC
adcf30e [sparse] remove deprecated bcoo_add_batch_dim utility 07 July 2022, 23:57:36 UTC
44bd311 Merge pull request #11403 from jakevdp:sparse-unary PiperOrigin-RevId: 459634024 07 July 2022, 23:57:23 UTC
17de5e4 jnp.diagonal: raise explicit error if ndim < 2 07 July 2022, 23:36:40 UTC
56d61d3 BUG: ensure that boolean scalars are never marked weak 07 July 2022, 22:41:23 UTC
7da733f Change the internals of `with_sharding_constraint` to use the sharding instances. PiperOrigin-RevId: 459600050 07 July 2022, 21:22:10 UTC
2b4f72b [sparse] fix unary operations in presence of duplicate indices 07 July 2022, 20:49:50 UTC
fe1bbd5 Merge pull request #11399 from mattjj:lower-abstracted-axes PiperOrigin-RevId: 459585916 07 July 2022, 20:20:39 UTC
12a56c3 [dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...) 07 July 2022, 19:48:29 UTC
9d18f43 Do not normalize FFT by a constant "1" if no normalization is provided (i.e., norm is None). Without this, the compiled graph will still contain a node multiplying a complex number by a constant 1+0j (1 is cast to complex because the other term is complex as well). This is problematic when converting to TFLite using jax2tf, because multiplying complex numbers is not supported in TFLite. With this change, the multiplication is removed from the graph altogether. PiperOrigin-RevId: 459566727 07 July 2022, 18:54:39 UTC
ce08a9f Deprecate top-level aliases of jax.tree_util functions 07 July 2022, 18:41:46 UTC
a10f037 Avoid top-level aliases of jax.tree_util.* 07 July 2022, 18:41:02 UTC
57ed5dc Add a util to fetch the value of a GDA to host when it's single-controller. Error out in McJAX. PiperOrigin-RevId: 459555907 07 July 2022, 18:09:13 UTC
2314951 Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array`s are used. In a follow-up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit`-specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They have different ways to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` have different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keeping this CL small (LOL), I'll make those changes in a follow-up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck`, which I haven't figured out how to move to the sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. 
* `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs in how each sharding instance will handle it, I added a handler system for this too. Again, `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974 07 July 2022, 17:41:52 UTC
88c1e7d Flip after_neurips flag to True. PiperOrigin-RevId: 459541278 07 July 2022, 17:12:15 UTC
fb7e39b Merge pull request #11390 from hawkinsp:distributed_init PiperOrigin-RevId: 459518348 07 July 2022, 15:23:26 UTC
2b8fbe9 Merge pull request #11367 from apaszke:xmap-tracer-leak PiperOrigin-RevId: 459456785 07 July 2022, 09:01:51 UTC
5270cb1 Merge pull request #11387 from mattjj:djax-bint PiperOrigin-RevId: 459430960 07 July 2022, 06:00:59 UTC
98e71fe [dynamic-shapes] revive basic bounded int machinery, add tests 07 July 2022, 05:31:26 UTC
6274b9e Enable Python callbacks on TFRT TPU backend PiperOrigin-RevId: 459415455 07 July 2022, 03:52:50 UTC
5d379bb mhlo.rng op with distribution attr Aligns with the XLA kRng which takes a distribution as an attribute instead of having separate ops for each distribution. PiperOrigin-RevId: 459389874 07 July 2022, 01:03:02 UTC
bdbdecd Refactor distributed GPU device initialization. Avoid reregistering backend factories; instead simply have the usual factory function support distributed GPU. 07 July 2022, 00:45:19 UTC
89a6766 Merge pull request #11313 from mattjj:djax-revive-iree PiperOrigin-RevId: 459360223 06 July 2022, 22:34:05 UTC
6bb90fd [dynamic shapes] revive iree 06 July 2022, 22:01:16 UTC
638e435 Merge pull request #11381 from bartvm:main PiperOrigin-RevId: 459346579 06 July 2022, 21:40:42 UTC
95e7933 Add JAX_TEST_TARGETS and JAX_EXCLUDE_TEST_TARGETS environment variables to assist with skipping tests under Bazel. Add "multiaccelerator" test tags to mark tests that would meaningfully run with more than one accelerator (e.g., GPU). PiperOrigin-RevId: 459320212 06 July 2022, 19:51:43 UTC
354c684 [jax2tf] Update docs for supported convolution types. PiperOrigin-RevId: 459316769 06 July 2022, 19:36:29 UTC
de08344 Avoid casting input to _fft_helper. 06 July 2022, 18:29:54 UTC
4ed8255 Fix iree.py python integration for backend changes CPU / VMVX runtime is now called local-task. Updated to separate compiler, runtime, and backend naming for single specified configuration. PiperOrigin-RevId: 459298179 06 July 2022, 18:17:44 UTC
da5385f Merge pull request #11379 from hawkinsp:parallel PiperOrigin-RevId: 459276185 06 July 2022, 16:53:32 UTC
4443705 Add script for parallel accelerator testing under Bazel. 06 July 2022, 14:58:04 UTC
b5e6145 Merge pull request #11359 from hawkinsp:bazel PiperOrigin-RevId: 459234031 06 July 2022, 13:13:20 UTC
1c75eee Document how to run tests using Bazel. * Add a new --configure_only option to build.py to allow build.py to generate a .bazelrc without necessarily building jaxlib. * Add a bazel flag that makes the dependency of //jax on //jaxlib optional. If //jaxlib isn't built by bazel, then tests will implicitly use a preinstalled jaxlib. 06 July 2022, 12:30:35 UTC
5777c1e Add support for post_process of xmap in BatchTrace PiperOrigin-RevId: 459108183 05 July 2022, 19:07:26 UTC
0719f98 Merge pull request #11368 from gnecula:shape_poly_test_refactor PiperOrigin-RevId: 459057179 05 July 2022, 12:19:05 UTC
dc3d776 [shape_poly] Refactor tests to separate the vmap tests Introduce ShapePolyVmapPrimitivesTest to contain all the tests in which vmap results in batch-polymorphic code. Also fix some warnings about eig, eigh, and qr taking only keyword arguments. 05 July 2022, 12:01:19 UTC
7439e1b Properly count sublevels when tracing xmap body Otherwise it can lead to tracer leak errors. I'm not 100% sure how this works out, because the sublevel counting has changed since I read it previously. This replicates the changes applied to DynamicJaxprTrace.process_map since I last looked at it. 05 July 2022, 11:43:26 UTC
b2d7005 Merge pull request #9426 from gnecula:iree_poly PiperOrigin-RevId: 459042352 05 July 2022, 10:28:21 UTC
b6c9069 Fix mypy annotations 05 July 2022, 09:49:37 UTC
5983d38 [dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota. Also add more tests. 05 July 2022, 09:14:15 UTC
5d6f81c Merge pull request #11361 from hawkinsp:tri PiperOrigin-RevId: 458800381 03 July 2022, 22:32:47 UTC
5620264 Add missing dtype canonicalization to tridiagonal solve lowering. This meant that the tridiagonal solve test failed when X64 mode was disabled on GPU. 03 July 2022, 20:08:54 UTC
a4798c3 Merge pull request #11358 from nalzok:patch-1 PiperOrigin-RevId: 458786037 03 July 2022, 19:38:14 UTC
8f16270 Merge pull request #11360 from hawkinsp:tol PiperOrigin-RevId: 458786025 03 July 2022, 19:32:37 UTC
62a392a Relax test tolerances. These tests currently fail on M1 Mac. 03 July 2022, 19:13:28 UTC
2d063d3 Fix typos in omnistaging.md 03 July 2022, 19:02:30 UTC