https://github.com/google/jax

Revision Author Date Message Commit Date
b868cf7 Merge pull request #13616 from jakevdp:fix-sparse-error PiperOrigin-RevId: 494758906 12 December 2022, 18:03:54 UTC
4a9e9d5 Merge pull request #13584 from LenaMartens:debug-check PiperOrigin-RevId: 494754750 12 December 2022, 17:46:29 UTC
2e95990 [sparse] fix logic bug in GPU testcase 12 December 2022, 17:38:17 UTC
3db909e Checkify: Add documentation for adding run-time values to error message. 12 December 2022, 15:56:40 UTC
13c34f9 Move `with_sharding_constraint` out of experimental into `jax.lax` namespace. PiperOrigin-RevId: 494635809 12 December 2022, 06:55:21 UTC
94590e2 Merge pull request #13562 from gnecula:opaque_shape_poly PiperOrigin-RevId: 494632979 12 December 2022, 06:32:34 UTC
27f5bd0 Improves handling of opaque types for dynamic shapes. The immediate motivation for this is to support the lowering to StableHLO for programs with polymorphic shapes. This requires mixing dynamic shapes with opaque types. The general strategy is to push the actual selection of the MHLO ops down into the mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim) so that we have one place where we pick whether to use the Dynamic or static ops. These routines can also handle the opaque type. This will result in a recursive call to, e.g., mlir.slice_op, but the inner call will be using the physical avals, which should not be opaque anymore. While making this change I was confused by the fact that the custom KeyTyRules in prng.py have lowerings that return multiple MHLO ops (see https://github.com/google/jax/pull/11768#issuecomment-1342349102), so I changed the rules to return a single op. 12 December 2022, 04:19:04 UTC
2f1354e Add workaround for imprecise shape inference for DynamicGatherOp. This is needed for gather in the presence of dynamic shapes. PiperOrigin-RevId: 494613303 12 December 2022, 04:18:15 UTC
4af4234 Merge pull request #13602 from mattjj:tweak-array-notebook PiperOrigin-RevId: 494480882 11 December 2022, 06:45:30 UTC
1185c89 in jax.Array notebook, polish beginning and tweak title and some wording 11 December 2022, 06:16:54 UTC
ffb4711 Expose channel_id in AllToAllOp in both XLA builder and MHLO. PiperOrigin-RevId: 494334791 10 December 2022, 05:58:28 UTC
0a2d1cd Set bcoo_cusparse_lowering to False by default. This was causing out-of-bounds writes on some CUDA backends. PiperOrigin-RevId: 494280591 09 December 2022, 23:41:49 UTC
a8b90b3 [sparse] Fix a bug in BCSR tree_flatten. PiperOrigin-RevId: 494276415 09 December 2022, 23:22:58 UTC
d506770 Merge pull request #13579 from gnecula:roll_poly PiperOrigin-RevId: 494274916 09 December 2022, 23:16:03 UTC
0777ca6 Merge pull request #13591 from google:ci_v3-8 PiperOrigin-RevId: 494271663 09 December 2022, 23:01:33 UTC
8d4b50e [TPU CI] Run build matrix on v3-8 as well as v4-8. We're seeing failures on v3-8 that don't appear in the current v4-8 testing. v3-8 also exposes 8 devices (vs. 4 on v4-8), and some tests need 8 devices to run. I just added a v3-8 runner VM. Also adds a missing pip install command (I only caught this with a fresh runner, since it only needs to be installed once). 09 December 2022, 22:32:09 UTC
b4fbd83 Merge pull request #13587 from jakevdp:callback-doc PiperOrigin-RevId: 494261419 09 December 2022, 22:16:49 UTC
df02d70 DOC: add example of pure_callback with custom_jvp 09 December 2022, 20:43:04 UTC
f2c5d28 Merge pull request #13568 from hawkinsp:npy124 PiperOrigin-RevId: 494206861 09 December 2022, 18:35:07 UTC
11fbe55 [sparse] Add rand_bcsr to generate a random BCSR array. PiperOrigin-RevId: 494201364 09 December 2022, 18:10:28 UTC
02f96a2 Merge pull request #13569 from LenaMartens:typecheck PiperOrigin-RevId: 494167850 09 December 2022, 15:34:15 UTC
7fe466c Small fix to scan type-check error message. 09 December 2022, 11:41:41 UTC
86a70ab [jax2tf] Fix for jnp.roll with shape polymorphism. There was a partial fix before, in #13470, but it was incomplete and x64 mode was not handled properly. No tests are added here; the problem was discovered by running the tests with --jax2tf_default_experimental_native_lowering, which will soon become the default. 09 December 2022, 06:08:28 UTC
942aa7a [sparse] Move _dot_general_validated_shape to sparse util. PiperOrigin-RevId: 494031113 09 December 2022, 00:54:43 UTC
73de02d Make JAX tests pass under NumPy 1.24.0rc2.
* allow rc2 in numpy versions when parsed by tests.
* don't cast np.empty(), which can lead to cast errors.
* NumPy 1.24 now warns on overflowing scalar int to array casts in more places. 08 December 2022, 19:46:10 UTC
2c92037 Fail lower_jaxpr_to_module if the module fails verification. When working with George on https://github.com/google/jax/pull/13427, I discovered that modules with verifier errors can happily cross API boundaries and create confusion downstream. As discussed, this is unintentional: the expectation was that `ctx.module.operation.verify()` would throw an exception when verification fails. This CL addresses that and throws an exception accordingly. I'm not sure how to test this, given that passing a module with verifier errors to module_to_string indicates a logic error (i.e., such a module shouldn't have been produced by JAX in the first place). As a result, I didn't write any tests, but I'm happy to write them if there's a good way to do that. PiperOrigin-RevId: 493940591 08 December 2022, 18:55:49 UTC
02ba16e Merge pull request #13251 from yotarok:toeplitz PiperOrigin-RevId: 493931922 08 December 2022, 18:26:00 UTC
a618f27 Add device_ids and axis_names to the Mesh repr PiperOrigin-RevId: 493916858 08 December 2022, 17:29:55 UTC
da285b6 Merge pull request #13566 from hawkinsp:flake8 PiperOrigin-RevId: 493906873 08 December 2022, 16:48:36 UTC
aacf44a flake8 now rejects inline comments. See: https://flake8.pycqa.org/en/latest/user/configuration.html (search for "inline comments"). 08 December 2022, 16:22:28 UTC
1ade5f8 Add `jax.scipy.linalg.toeplitz`. 08 December 2022, 16:03:21 UTC
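For reference, SciPy's `scipy.linalg.toeplitz(c, r)` builds a matrix that is constant along each diagonal, with first column `c` and first row `r`. When `r` is omitted it defaults to the conjugate of `c`, and `r[0]` is ignored. Assuming the new `jax.scipy.linalg.toeplitz` follows the same convention (as `jax.scipy` wrappers generally aim to), the pure-Python sketch below illustrates the indexing rule. `toeplitz_list` is a hypothetical helper, not the JAX implementation.

```python
from typing import List, Optional, Sequence


def toeplitz_list(c: Sequence[complex], r: Optional[Sequence[complex]] = None) -> List[List[complex]]:
    """Hypothetical pure-Python Toeplitz builder following scipy.linalg.toeplitz's convention."""
    if r is None:
        # Default first row: complex conjugate of the first column.
        r = [x.conjugate() for x in c]
    # Entry (i, j) depends only on the offset i - j: on and below the diagonal
    # it comes from c, above it from r. The diagonal uses c[0], so r[0] is never read.
    return [[c[i - j] if i >= j else r[j - i] for j in range(len(r))]
            for i in range(len(c))]


print(toeplitz_list([1, 2, 3], [1, 4, 5, 6]))
# [[1, 4, 5, 6], [2, 1, 4, 5], [3, 2, 1, 4]]
```

Because every entry is looked up by its diagonal offset, a Toeplitz matrix is fully described by `len(c) + len(r) - 1` numbers.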
440b25b Merge pull request #13427 from gnecula:tf_native_poly PiperOrigin-RevId: 493800880 08 December 2022, 06:41:23 UTC
8fb344a [jax2tf] An alternative support for shape polymorphism for native serialization.
jax2tf already supports many cases of shape polymorphism, e.g., those where the shapes of all intermediates can be expressed as polynomials in the dimension variables in the input. We want to achieve the same coverage, or more, while using StableHLO as the lowering format, rather than tf.Graph.
For native serialization we will support two lowering implementations:
* one uses the growing support in JAX for dynamic shapes, of which shape polymorphism is a special case. This implementation is enabled with the --jax_dynamic_shapes flag. At the moment, the JAX dynamic shapes support is still incomplete and over 300 jax2tf shape polymorphism tests fail.
* a new one, added here, in which we form a Jaxpr using abstract values that express dimension sizes as dimension polynomials (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO. This implementation is enabled when --jax_dynamic_shapes is off. With this implementation only 50 jax2tf tests fail (to be fixed separately).
The key contribution here is to enable lowering a Jaxpr that contains dimension polynomials in some of the intermediate shapes. Many lowering rules already have some partial support for Jaxprs where the shapes contain `Var`s. To the extent possible, we try to write lowering rules that cover both cases of dynamic shapes: Vars or polynomials in shapes.
The lowering convention is that at the top level we collect the sorted list of dimension variable names in the inputs and store it in ModuleContext.dim_vars. All IR functions will take N additional prefix arguments of int32 type containing the values of the dimension variables. These are stored as a list of `ir.Value` in `LoweringContext.dim_var_values`.
Note that the Jaxprs are not changed to have extra Vars for the dimension variable values.
An alternative implementation could work by transforming the Jaxpr to replace dimension polynomials with Vars.
The key code pattern used in the lowering rules is::

  if not core.is_constant_shape(shape):  # Handles both Var and polynomials
    shape = mlir.eval_dynamic_shape(ctx, shape)
    return mhlo.DynamicXXX(..., shape)
  else:
    return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

  def eval_dynamic_shape(ctx, shape):
    if config.jax_dynamic_shapes:
      # Using Var
      return ... subst using ctx.axis_size_env ...
    else:
      # Using polynomials
      return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above, some lowering functions need to take a LoweringContext parameter, e.g., mlir.broadcast_mhlo.
I expect that the changes here will improve the --jax_dynamic_shapes coverage as well. 08 December 2022, 06:19:35 UTC
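To make the dim_vars / dim_var_values convention concrete, here is a self-contained sketch of what evaluating a dimension polynomial against the dimension-variable values amounts to. Everything in it is hypothetical. Polynomials are modelled as dicts from monomials (tuples of variable names) to integer coefficients, and `is_constant_shape` / `eval_shape` only imitate the roles of `core.is_constant_shape` and `mlir.eval_dynamic_shape`. In the real lowering the substitution produces IR values, not Python integers.

```python
from typing import Dict, Sequence, Tuple, Union

# Hypothetical representation: a monomial is a tuple of dimension-variable
# names (() is the constant term); a polynomial maps monomials to coefficients.
Poly = Dict[Tuple[str, ...], int]
Dim = Union[int, Poly]


def is_constant_shape(shape: Sequence[Dim]) -> bool:
    # A shape is static when every dimension is a plain integer.
    return all(isinstance(d, int) for d in shape)


def eval_poly(poly: Poly, env: Dict[str, int]) -> int:
    total = 0
    for monomial, coeff in poly.items():
        term = coeff
        for var in monomial:
            term *= env[var]
        total += term
    return total


def eval_shape(shape: Sequence[Dim], dim_vars: Sequence[str],
               dim_var_values: Sequence[int]) -> Tuple[int, ...]:
    # Pair the sorted dimension-variable names with their runtime values,
    # as the lowering pairs ModuleContext.dim_vars with the prefix arguments.
    env = dict(zip(dim_vars, dim_var_values))
    return tuple(d if isinstance(d, int) else eval_poly(d, env) for d in shape)


# Shape (2*b + 3, 4, b*h) with dimension variables b and h.
shape = [{("b",): 2, (): 3}, 4, {("b", "h"): 1}]
if not is_constant_shape(shape):
    print(eval_shape(shape, dim_vars=["b", "h"], dim_var_values=[5, 7]))  # (13, 4, 35)
```

Constant dimensions pass through untouched, which is why the lowering rules only take the dynamic path when at least one dimension is symbolic.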
0118f8d Prepare for jax and jaxlib 0.4.0 release PiperOrigin-RevId: 493733609 08 December 2022, 00:02:24 UTC
dd64760 Make `jnp.copy` work with all shardings especially `PmapSharding`. This fixes the problem where a `jax.Array` with a `PmapSharding` round tripped through host and returned a `jax.Array` with a `SingleDeviceSharding`. Now, `jnp.copy` works without going through a round trip via host and maintains the sharding of the input array across all the Shardings we have. PiperOrigin-RevId: 493728354 07 December 2022, 23:41:47 UTC
3159642 Merge pull request #13547 from jakevdp:numpy-msort PiperOrigin-RevId: 493649811 07 December 2022, 18:50:33 UTC
ce990cf Merge pull request #13546 from jakevdp:shard-types PiperOrigin-RevId: 493637391 07 December 2022, 18:11:53 UTC
09d1b6d Deprecate jnp.msort following deprecation of numpy.msort 07 December 2022, 18:08:18 UTC
777754d [typing] fix Shard/Sharding type TODO 07 December 2022, 17:44:03 UTC
794cec1 Merge pull request #13544 from jakevdp:fix-readme PiperOrigin-RevId: 493625570 07 December 2022, 17:28:25 UTC
79c8d66 README: fix badge URL 07 December 2022, 16:48:40 UTC
6e0c802 Fix type annotation for bcoo_update_layout. PiperOrigin-RevId: 493567424 07 December 2022, 12:25:25 UTC
ac72346 Ensure that the initial dynamic_trace_state is canonicalized. The non-canonical state meant that we were falling back to a more expensive comparison for the first jit-compiled function in the program. I doubt there will be any impact on real benchmarks, but this perturbs the results of running a single microbenchmark in isolation. PiperOrigin-RevId: 493489154 07 December 2022, 04:39:53 UTC
b168df0 Merge pull request #13539 from wookayin:fix/config-typing PiperOrigin-RevId: 493473725 07 December 2022, 03:10:02 UTC
cd22585 Fix a false-positive typing warning on jax.default_device
Consider the following code, where static type checkers can report an error:
```python
CPU = jax.devices('cpu')[0]
with jax.default_device(CPU):
  ...
#                       ^^^
```
Error message:
```
Pyright: Argument of type "Device" cannot be assigned to parameter "new_val" of type "NoDefault"
  "Device" is incompatible with "NoDefault" (reportGeneralTypeIssues)
```
This is because `_StateContextManager.__call__` does not have a type annotation on its parameter, unlike the attribute `_default_value`, which does. Adding an `Any` annotation to the parameter makes the error disappear. 07 December 2022, 02:05:35 UTC
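The Pyright message suggests that the parameter's default is a sentinel of type `NoDefault`, and that the checker inferred the parameter type from it. The sketch below is a simplified, hypothetical stand-in for `_StateContextManager`. The class name, its fields, and the `_NoDefault` sentinel are assumptions, not JAX's actual code. It shows where the explicit `Any` annotation goes.

```python
import contextlib
import threading
from typing import Any, Iterator


class _NoDefault:
    """Hypothetical sentinel type meaning "no value was passed"."""


no_default = _NoDefault()


class StateContextManager:
    """Hypothetical, simplified stand-in for a thread-local config context manager."""

    def __init__(self, default_value: Any) -> None:
        self._default_value: Any = default_value
        self._local = threading.local()

    @property
    def value(self) -> Any:
        # Thread-local override if one is active, otherwise the default.
        return getattr(self._local, "value", self._default_value)

    # The explicit `Any` annotation is the point of the fix: without it, a
    # checker may infer the parameter's type from the sentinel default.
    @contextlib.contextmanager
    def __call__(self, new_val: Any = no_default) -> Iterator[None]:
        if new_val is no_default:
            raise TypeError("a new value must be supplied")
        prev = getattr(self._local, "value", no_default)
        self._local.value = new_val
        try:
            yield
        finally:
            # Restore whatever was active before entering the block.
            if prev is no_default:
                del self._local.value
            else:
                self._local.value = prev


default_device = StateContextManager(default_value=None)
with default_device("cpu:0"):
    print(default_device.value)  # cpu:0
print(default_device.value)      # None
```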
61398ff Merge pull request #13536 from jakevdp:quickstart-timing PiperOrigin-RevId: 493414640 06 December 2022, 22:40:44 UTC
23b808f Merge pull request #13446 from google:maxfail PiperOrigin-RevId: 493414635 06 December 2022, 22:34:01 UTC
7b59ce2 DOC: pre-execute the quickstart notebook on GPU 06 December 2022, 21:24:02 UTC
1132098 [PJRT:C] Separate loading the PJRT plugin from creating the PJRT client.
- Add xla_client.maybe_load_pjrt_plugins, which may load PJRT plugins from a hardcoded set.
- Call xla_client.maybe_load_pjrt_plugins in xla_bridge before initializing backends.
- Add a binding of the Python method load_pjrt_plugin to LoadPjrtPlugin, which does dlopen and dlsym.
- Remove loading the PJRT plugin from tpu_initializer_helper.cc.
- Add an extra call to LoadPjrtPlugin when getting the PJRT_Api*, to stay backward compatible.
PiperOrigin-RevId: 493381393 06 December 2022, 20:29:38 UTC
1bbcec7 Update FAQ since buffer donation is implemented on CPU. PiperOrigin-RevId: 493372426 06 December 2022, 19:57:34 UTC
2fb9238 Merge pull request #13534 from jakevdp:fix-rand-test PiperOrigin-RevId: 493349509 06 December 2022, 18:40:35 UTC
7e3f674 random_test: skip singular covariance test on accelerators 06 December 2022, 17:26:23 UTC
edaa562 Merge pull request #13517 from froystig:rng-part-jax2tf PiperOrigin-RevId: 493317025 06 December 2022, 16:46:28 UTC
72f97f7 Merge pull request #13523 from jakevdp:sparse-unused PiperOrigin-RevId: 493312001 06 December 2022, 16:31:49 UTC
9f5a631 Merge pull request #13518 from jakevdp:multiv-normal PiperOrigin-RevId: 493311895 06 December 2022, 16:31:32 UTC
33a1b88 Mark arguments to ufuncs as positional-only. PiperOrigin-RevId: 493311821 06 December 2022, 16:24:11 UTC
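In Python, positional-only parameters are declared with a `/` marker (PEP 570, Python 3.8+). NumPy ufuncs likewise take their inputs only positionally. The hypothetical wrapper `add` below illustrates the effect: keyword-style calls for the inputs are rejected with a `TypeError`.

```python
def add(x1, x2, /):
    """Hypothetical ufunc-style wrapper: inputs are positional-only."""
    return x1 + x2


print(add(1, 2))  # 3

try:
    add(x1=1, x2=2)  # keyword arguments are not accepted for x1 and x2
except TypeError as err:
    print("rejected:", err)
```

One practical benefit is that the parameter names stop being part of the public API, so they can be renamed later without breaking callers.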
a790016 Merge pull request #12962 from hawkinsp:rocm PiperOrigin-RevId: 493308550 06 December 2022, 16:09:45 UTC
938b625 Remove unreliable call_tf_test PiperOrigin-RevId: 493225546 06 December 2022, 08:31:19 UTC
e7e9687 Allow pjit's C++ dispatch path to operate on uncommitted array only if it belongs on a single device. This will bring pjit's dispatch performance in line with `jit` to prepare for jit/pjit frontend merge. PiperOrigin-RevId: 493164446 06 December 2022, 02:09:59 UTC
2fc7bbc Cache the exec-time `tf.get_current_name_scope` in `_thread_local_state.exec_time_tf_name_scope` and add it as a prefix during tracing. Also move the test cases to the right code location. PiperOrigin-RevId: 493163018 06 December 2022, 02:02:44 UTC
516f0d0 Support negative axes in all_gather. Previously we didn't check for these and they caused crashes during MHLO verification. PiperOrigin-RevId: 493160581 06 December 2022, 01:48:50 UTC
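Supporting negative axes typically means normalizing them NumPy-style and rejecting out-of-range values up front, before they can reach MHLO verification. Below is a minimal sketch with a hypothetical helper `canonicalize_axis`. It is not JAX's internal code, and the caller supplies whichever rank the axis indexes into.

```python
def canonicalize_axis(axis: int, num_dims: int) -> int:
    """Hypothetical helper: map an axis in [-num_dims, num_dims) to [0, num_dims)."""
    if not -num_dims <= axis < num_dims:
        raise ValueError(f"axis {axis} is out of bounds for an array of rank {num_dims}")
    # Python's % already wraps negative values: -1 % 3 == 2.
    return axis % num_dims


print(canonicalize_axis(-1, 3))  # 2
print(canonicalize_axis(1, 3))   # 1
```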
189d0c0 Merge pull request #13525 from jakevdp:typing-extensions PiperOrigin-RevId: 493158064 06 December 2022, 01:35:04 UTC
9400555 Merge pull request #13526 from jakevdp:isinstance-array PiperOrigin-RevId: 493156840 06 December 2022, 01:27:55 UTC
88bfe90 Replace utility function with isinstance check 06 December 2022, 00:13:33 UTC
382a248 Merge pull request #13512 from jakevdp:fix-optimizers PiperOrigin-RevId: 493139751 06 December 2022, 00:11:59 UTC
4389216 Remove typing_extensions dependency 05 December 2022, 23:42:26 UTC
a42e9e2 [sparse] delete _bcoo_unbatch utility 05 December 2022, 22:35:34 UTC
23261d7 Merge pull request #13515 from jakevdp:sparse-typing PiperOrigin-RevId: 493113475 05 December 2022, 22:28:18 UTC
1f9c988 Use `_thread_local_state.__dict__.get()` instead of `getattr(_thread_local_state, ...)`.
`getattr` turns out to be a tiny bit slower than `.get()` on `__dict__` when the attribute is absent: `getattr` appears to build an error message that is thrown away if a default is present.
Improves the device_put benchmark:
```
name        old cpu/op   new cpu/op   delta
device_put  51.4µs ± 1%  48.9µs ± 3%  -4.87%  (p=0.000 n=8+9)

name        old time/op  new time/op  delta
device_put  51.4µs ± 1%  48.9µs ± 3%  -4.87%  (p=0.000 n=8+9)
```
PiperOrigin-RevId: 493108288 05 December 2022, 22:09:47 UTC
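The comparison can be reproduced with the standard library alone. The sketch below checks that both spellings return the same result for a missing and for a present attribute on a `threading.local`, then times them with `timeit`. The attribute names are arbitrary, and timings vary by machine and Python version, so no numbers are claimed here.

```python
import threading
import timeit

state = threading.local()
state.present = "value"


def via_getattr():
    # Missing attribute: raises AttributeError internally, then falls back to the default.
    return getattr(state, "absent", None)


def via_dict_get():
    # Missing key: a plain dict lookup on the current thread's attribute dict.
    return state.__dict__.get("absent", None)


# Both spellings agree for missing and present attributes.
assert via_getattr() is None and via_dict_get() is None
assert getattr(state, "present", None) == state.__dict__.get("present", None) == "value"

if __name__ == "__main__":
    for fn in (via_getattr, via_dict_get):
        seconds = timeit.timeit(fn, number=200_000)
        print(f"{fn.__name__}: {seconds:.4f}s for 200k lookups")
```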
c376836 [typing] annotate jax.experimental.sparse 05 December 2022, 22:05:25 UTC
45707a2 Merge pull request #13499 from jakevdp:error-docs PiperOrigin-RevId: 493092719 05 December 2022, 21:14:37 UTC
25d1a0b Add cudnn 86 (for cuda 11.8) so that I can release cuda 11.8 nightlies. PiperOrigin-RevId: 493086060 05 December 2022, 20:50:09 UTC
58d6a3b random.multivariate_normal: add note about singular covariance 05 December 2022, 20:43:05 UTC
29942e3 docs: add another example to the ConcretizationTypeError docs 05 December 2022, 19:24:54 UTC
431c51a rename `iota_32x2_shape` to `iota_2x32_shape` ... for consistency with other internal Threefry primitive names. 05 December 2022, 19:09:56 UTC
dad21d3 Merge pull request #13506 from simonbutt:bugfix/jax101-pytrees PiperOrigin-RevId: 493052164 05 December 2022, 18:45:21 UTC
b8b6e27 Add typehints and point to the correct endpoint of Mesh and PartitionSpec in the args section. PiperOrigin-RevId: 493035898 05 December 2022, 17:49:18 UTC
d317cfa Revert part of #13498 05 December 2022, 17:21:09 UTC
75af6b5 add a jax2tf translation rule for the shaped-iota primitive. This allows for jax2tf conversion of the partitionable Threefry RNG. 05 December 2022, 17:19:25 UTC
a3483db docstring for shaped iota primitive 05 December 2022, 17:15:27 UTC
401fbb6 Disable xmap_test on TPU under asan due to CI timeouts. PiperOrigin-RevId: 492994226 05 December 2022, 14:52:09 UTC
55d6daa Skip test_lstm on CPU and TPU for jax OSS build. PiperOrigin-RevId: 492722650 03 December 2022, 22:16:07 UTC
542b38a Updated jax.tree_leaves --> jax.tree_util.tree_leaves to remove deprecation notice in jax101-pytrees tutorial. Signed-off-by: Simon Butt <simonbutt123@gmail.com> 03 December 2022, 21:22:25 UTC
e814f70 Raise an error when a numpy input is passed with a non-trivial sharding. This can lead to weird behavior with pjit and XLA since host-local inputs are not allowed with pjit anymore. PiperOrigin-RevId: 492621424 03 December 2022, 04:47:45 UTC
02fab52 Add tests to check if pjit handles deleted array inputs gracefully and consistently. pjit dispatch paths should check for deleted array inputs when attempting to use them. These new tests ensure that the various pjit dispatch paths detect and handle them gracefully and consistently. Add a check to the PyArray argument handling to make the tests pass. PiperOrigin-RevId: 492605524 03 December 2022, 02:41:31 UTC
693047a Merge pull request #13498 from jakevdp:x64-other-tests PiperOrigin-RevId: 492593760 03 December 2022, 01:10:04 UTC
f22d4a8 Merge pull request #13490 from jakevdp:x64-check-grads PiperOrigin-RevId: 492592519 03 December 2022, 01:01:56 UTC
06755ad Reduce the buffer size used in ShardedDeviceArrayTest.testThreadsafeIndexing. testThreadsafeIndexing uses a fairly large buffer size. When many executions overlap under constrained host memory, as when testing with an alternative backend, this test may hit the maximum allowed memory use. This change halves the buffer size, which likely still keeps the test interesting while running more reliably on an alternative backend. PiperOrigin-RevId: 492588538 03 December 2022, 00:38:47 UTC
c2c3669 Remove long-deprecated method .block_host_until_ready(). PiperOrigin-RevId: 492571809 02 December 2022, 23:18:11 UTC
5e102c1 Implement .on_device_size_in_bytes() on jax.Array. This method was present on DeviceArray but missing from Array. PiperOrigin-RevId: 492571171 02 December 2022, 23:11:27 UTC
924894f [x64] make tests more type-safe 02 December 2022, 21:21:35 UTC
9e53de8 [x64] make check_grads() more type-safe 02 December 2022, 20:51:41 UTC
8a28ccd Merge pull request #13491 from jakevdp:x64-stax PiperOrigin-RevId: 492526309 02 December 2022, 20:06:11 UTC
f9b5312 Do not mirror JAX config options back to ABSL flags.
Currently, when JAX config values are configured via ABSL, we use the ABSL flags as the source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.
However, there's fundamentally no reason to mirror the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing only as a way to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes made after ABSL flags have been parsed must go through the `jax.config.update()` API.
This gives a decent improvement on the device_put benchmark:
```
name        old cpu/op   new cpu/op   delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)

name        old time/op  new time/op  delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)
```
PiperOrigin-RevId: 492519085 02 December 2022, 19:37:22 UTC
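The new behaviour, in which flag parsing only seeds the config and later reads never consult the flags, can be sketched without ABSL. The `Config` class, its `read`/`update` methods, and `parse_flags` below are hypothetical stand-ins, not JAX's or ABSL's API. The option name is only an example. The sketch illustrates the stated consequence: changing a flag after parsing is not seen, while `update()` is.

```python
from typing import Any, Dict


class Config:
    """Hypothetical config store: reads are plain dict lookups, never flag lookups."""

    def __init__(self, defaults: Dict[str, Any]) -> None:
        self.values = dict(defaults)

    def update(self, name: str, value: Any) -> None:
        if name not in self.values:
            raise AttributeError(f"unknown config option: {name}")
        self.values[name] = value

    def read(self, name: str) -> Any:
        return self.values[name]


def parse_flags(config: Config, flags: Dict[str, Any]) -> None:
    # One-way copy at parse time: flags seed the config, nothing is mirrored back.
    for name, value in flags.items():
        config.update(name, value)


flags = {"jax_enable_x64": True}
config = Config({"jax_enable_x64": False})
parse_flags(config, flags)

flags["jax_enable_x64"] = False         # changing the flag afterwards...
print(config.read("jax_enable_x64"))    # ...is not seen: True
config.update("jax_enable_x64", False)  # changes must go through update()
print(config.read("jax_enable_x64"))    # False
```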
8431e43 [x64] more type safety in stax_test.py 02 December 2022, 18:02:25 UTC
1027d55 Optimize core.find_top_trace. This function is quite important, since it runs at every JAX primitive bind, but it included a few redundant conditionals. PiperOrigin-RevId: 492481837 02 December 2022, 17:00:50 UTC
bbf22db Optimize core.find_top_trace. This function is quite important, since it runs at every JAX primitive bind, but it included a few redundant conditionals. PiperOrigin-RevId: 492460102 02 December 2022, 15:04:52 UTC
01377bc Merge pull request #13485 from jakevdp:x64-random PiperOrigin-RevId: 492367153 02 December 2022, 04:43:18 UTC
5927032 Merge pull request #13482 from jakevdp:x64-signal PiperOrigin-RevId: 492367133 02 December 2022, 04:36:41 UTC
e1d118c Merge pull request #13476 from jakevdp:x64-lax-numpy PiperOrigin-RevId: 492367125 02 December 2022, 04:29:46 UTC