b868cf7 | jax authors | 12 December 2022, 18:03:54 UTC | Merge pull request #13616 from jakevdp:fix-sparse-error PiperOrigin-RevId: 494758906 | 12 December 2022, 18:03:54 UTC |
4a9e9d5 | jax authors | 12 December 2022, 17:46:29 UTC | Merge pull request #13584 from LenaMartens:debug-check PiperOrigin-RevId: 494754750 | 12 December 2022, 17:46:29 UTC |
2e95990 | Jake VanderPlas | 12 December 2022, 17:38:17 UTC | [sparse] fix logic bug in GPU testcase | 12 December 2022, 17:38:17 UTC |
3db909e | lenamartens | 09 December 2022, 14:01:44 UTC | Checkify: Add documentation for adding run-time values to error message. | 12 December 2022, 15:56:40 UTC |
13c34f9 | Yash Katariya | 12 December 2022, 06:54:39 UTC | Move `with_sharding_constraint` out of experimental into `jax.lax` namespace. PiperOrigin-RevId: 494635809 | 12 December 2022, 06:55:21 UTC |
94590e2 | jax authors | 12 December 2022, 06:32:34 UTC | Merge pull request #13562 from gnecula:opaque_shape_poly PiperOrigin-RevId: 494632979 | 12 December 2022, 06:32:34 UTC |
27f5bd0 | George Necula | 04 December 2022, 06:37:24 UTC | Improves handling of opaque types for dynamic shapes The immediate motivation for this is to support the lowering to StableHLO for programs with polymorphic shapes. This requires mixing of dynamic shapes with opaque types. The general strategy is to push the actual selection of the MHLO ops down into the mlir module (e.g., mlir.slice_op, mlir.broadcast_in_dim) so that we have one place where we pick whether we use the Dynamic or static ops. These routines can also handle the opaque type. This will result in a recursive call to, e.g., mlir.slice_op, but the inner call will be using the physical avals, which should not be opaque anymore. While making this change I was confused by the fact that the custom KeyTyRules in prng.py have lowerings that return multiple MHLO ops. See https://github.com/google/jax/pull/11768#issuecomment-1342349102; I changed the rules to return a single op. | 12 December 2022, 04:19:04 UTC |
2f1354e | George Necula | 12 December 2022, 04:17:42 UTC | Add workaround for imprecise shape inference for DynamicGatherOp This is needed for gather in presence of dynamic shapes. PiperOrigin-RevId: 494613303 | 12 December 2022, 04:18:15 UTC |
4af4234 | jax authors | 11 December 2022, 06:45:30 UTC | Merge pull request #13602 from mattjj:tweak-array-notebook PiperOrigin-RevId: 494480882 | 11 December 2022, 06:45:30 UTC |
1185c89 | Matthew Johnson | 11 December 2022, 03:59:14 UTC | in jax.Array notebook, polish beginning and tweak title and some wording | 11 December 2022, 06:16:54 UTC |
ffb4711 | Anselm Levskaya | 10 December 2022, 05:57:46 UTC | Expose channel_id in AllToAllOp in both XLA builder and MHLO. PiperOrigin-RevId: 494334791 | 10 December 2022, 05:58:28 UTC |
0a2d1cd | Jake VanderPlas | 09 December 2022, 23:41:06 UTC | Set bcoo_cusparse_lowering to False by default This was causing out-of-bound writes on some CUDA backends PiperOrigin-RevId: 494280591 | 09 December 2022, 23:41:49 UTC |
a8b90b3 | Tianjian Lu | 09 December 2022, 23:22:15 UTC | [sparse] Fix a bug in BCSR tree_flatten. PiperOrigin-RevId: 494276415 | 09 December 2022, 23:22:58 UTC |
d506770 | jax authors | 09 December 2022, 23:16:03 UTC | Merge pull request #13579 from gnecula:roll_poly PiperOrigin-RevId: 494274916 | 09 December 2022, 23:16:03 UTC |
0777ca6 | jax authors | 09 December 2022, 23:01:33 UTC | Merge pull request #13591 from google:ci_v3-8 PiperOrigin-RevId: 494271663 | 09 December 2022, 23:01:33 UTC |
8d4b50e | Skye Wanderman-Milne | 09 December 2022, 22:11:18 UTC | [TPU CI] Run build matrix on v3-8 as well as v4-8 We're seeing failures on v3-8 that don't appear on the current v4-8 testing. v3-8 also exposes 8 devices (whereas v4-8 exposes 4), and some tests need 8 devices to run. I just added a v3-8 runner VM. Also adds a missing pip install command (I only caught this with a fresh runner since it only needs to be installed once). | 09 December 2022, 22:32:09 UTC |
b4fbd83 | jax authors | 09 December 2022, 22:16:49 UTC | Merge pull request #13587 from jakevdp:callback-doc PiperOrigin-RevId: 494261419 | 09 December 2022, 22:16:49 UTC |
df02d70 | Jake VanderPlas | 09 December 2022, 20:43:04 UTC | DOC: add example of pure_callback with custom_jvp | 09 December 2022, 20:43:04 UTC |
f2c5d28 | jax authors | 09 December 2022, 18:35:07 UTC | Merge pull request #13568 from hawkinsp:npy124 PiperOrigin-RevId: 494206861 | 09 December 2022, 18:35:07 UTC |
11fbe55 | Tianjian Lu | 09 December 2022, 18:09:48 UTC | [sparse] Add rand_bcsr to generate a random BCSR array. PiperOrigin-RevId: 494201364 | 09 December 2022, 18:10:28 UTC |
02f96a2 | jax authors | 09 December 2022, 15:34:15 UTC | Merge pull request #13569 from LenaMartens:typecheck PiperOrigin-RevId: 494167850 | 09 December 2022, 15:34:15 UTC |
7fe466c | lenamartens | 08 December 2022, 16:16:49 UTC | Small fix to scan type-check error message. | 09 December 2022, 11:41:41 UTC |
86a70ab | George Necula | 07 December 2022, 08:43:33 UTC | [jax2tf] Fix for jnp.roll with shape polymorphism There was a partial fix before, in #13470, but it was incomplete and the x64 mode was not handled properly. There are no tests added here; this was discovered by running the tests with --jax2tf_default_experimental_native_lowering, which will become default soon. | 09 December 2022, 06:08:28 UTC |
942aa7a | Tianjian Lu | 09 December 2022, 00:54:15 UTC | [sparse] Move _dot_general_validated_shape to sparse util. PiperOrigin-RevId: 494031113 | 09 December 2022, 00:54:43 UTC |
73de02d | Peter Hawkins | 08 December 2022, 19:40:56 UTC | Make JAX tests pass under NumPy 1.24.0rc2. * allow rc2 in numpy versions when parsed by tests. * don't cast np.empty(), which can lead to cast errors. * NumPy 1.24 now warns on overflowing scalar int to array casts in more places. | 08 December 2022, 19:46:10 UTC |
2c92037 | Eugene Burmako | 08 December 2022, 18:55:14 UTC | Fail lower_jaxpr_to_module if the module fails verification When working with George on https://github.com/google/jax/pull/13427, I discovered that modules with verifier errors can happily cross API boundaries and create confusion downstream. As discussed, this is unintentional - the expectation was that `ctx.module.operation.verify()` will throw an exception when verification fails. This CL addresses that and throws an exception accordingly. Not sure how to test this, given that passing a module with verifier errors to module_to_string indicates a logic error (i.e. such module shouldn't have been produced by JAX in the first place). As a result, I didn't write any tests, but I'm happy to write them if there's a good way to do that. PiperOrigin-RevId: 493940591 | 08 December 2022, 18:55:49 UTC |
02ba16e | jax authors | 08 December 2022, 18:26:00 UTC | Merge pull request #13251 from yotarok:toeplitz PiperOrigin-RevId: 493931922 | 08 December 2022, 18:26:00 UTC |
a618f27 | Yash Katariya | 08 December 2022, 17:29:07 UTC | Add device_ids and axis_names to the Mesh repr PiperOrigin-RevId: 493916858 | 08 December 2022, 17:29:55 UTC |
da285b6 | jax authors | 08 December 2022, 16:48:36 UTC | Merge pull request #13566 from hawkinsp:flake8 PiperOrigin-RevId: 493906873 | 08 December 2022, 16:48:36 UTC |
aacf44a | Peter Hawkins | 08 December 2022, 16:22:28 UTC | flake8 now rejects inline comments. See: https://flake8.pycqa.org/en/latest/user/configuration.html (search for "inline comments"). | 08 December 2022, 16:22:28 UTC |
1ade5f8 | Yotaro Kubo | 15 November 2022, 09:40:52 UTC | Add `jax.scipy.linalg.toeplitz`. | 08 December 2022, 16:03:21 UTC |
440b25b | jax authors | 08 December 2022, 06:41:23 UTC | Merge pull request #13427 from gnecula:tf_native_poly PiperOrigin-RevId: 493800880 | 08 December 2022, 06:41:23 UTC |
8fb344a | George Necula | 28 November 2022, 12:16:07 UTC | [jax2tf] An alternative support for shape polymorphism for native serialization. jax2tf already supports many cases of shape polymorphism, e.g., those where the shapes of all intermediates can be expressed as polynomials in the dimension variables in the input. We want to achieve the same coverage, or more, while using StableHLO as the lowering format, rather than tf.Graph. For native serialization we will support two lowering implementations: * one is using the growing support in JAX for dynamic shapes, of which shape polymorphism is a special case. This implementation is enabled with the --jax_dynamic_shapes flag. At the moment, the JAX dynamic shapes support is still incomplete and over 300 jax2tf shape polymorphism tests fail. * a new one (added here) in which we form a Jaxpr using abstract values that express dimension sizes as dimension polynomials (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO. This implementation is enabled when --jax_dynamic_shapes is off. With this implementation only 50 jax2tf tests fail (to be fixed separately). The key contribution here is to enable lowering a Jaxpr that contains dimension polynomials in some of the intermediate shapes. Many lowering rules already have some partial support for Jaxprs where the shapes contain `Var`s. To the extent possible, we try to write lowering rules that should cover both cases of dynamic shapes: Var or polynomials in shapes. The lowering convention is that at top level we collect the sorted list of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars. All IR functions will take N additional prefix arguments of int32 type containing the values of the dimension variables. This is stored as a list of `ir.Value` in `LoweringContext.dim_var_values`. Note that the Jaxprs are not changed to have extra Vars for the dimension variable values. 
An alternative implementation could work by transforming the Jaxpr to replace dimension polynomials into Vars. The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
      shape = mlir.eval_dynamic_shape(ctx, shape)
      return mhlo.DynamicXXX(..., shape)
    else:
      return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
      if config.jax_dynamic_shapes:
        # Using Var
        return ... subst using ctx.axis_size_env ...
      else:
        # Using polynomials
        return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a LoweringContext parameter, e.g., mlir.broadcast_mhlo. I expect that the changes here will improve the --jax_dynamic_shapes coverage as well. | 08 December 2022, 06:19:35 UTC |
0118f8d | Yash Katariya | 08 December 2022, 00:01:43 UTC | Prepare for jax and jaxlib 0.4.0 release PiperOrigin-RevId: 493733609 | 08 December 2022, 00:02:24 UTC |
dd64760 | Yash Katariya | 07 December 2022, 23:41:13 UTC | Make `jnp.copy` work with all shardings especially `PmapSharding`. This fixes the problem where a `jax.Array` with a `PmapSharding` round tripped through host and returned a `jax.Array` with a `SingleDeviceSharding`. Now, `jnp.copy` works without going through a round trip via host and maintains the sharding of the input array across all the Shardings we have. PiperOrigin-RevId: 493728354 | 07 December 2022, 23:41:47 UTC |
3159642 | jax authors | 07 December 2022, 18:50:33 UTC | Merge pull request #13547 from jakevdp:numpy-msort PiperOrigin-RevId: 493649811 | 07 December 2022, 18:50:33 UTC |
ce990cf | jax authors | 07 December 2022, 18:11:53 UTC | Merge pull request #13546 from jakevdp:shard-types PiperOrigin-RevId: 493637391 | 07 December 2022, 18:11:53 UTC |
09d1b6d | Jake VanderPlas | 07 December 2022, 18:08:18 UTC | Deprecate jnp.msort following deprecation of numpy.msort | 07 December 2022, 18:08:18 UTC |
777754d | Jake VanderPlas | 07 December 2022, 17:44:03 UTC | [typing] fix Shard/Sharding type TODO | 07 December 2022, 17:44:03 UTC |
794cec1 | jax authors | 07 December 2022, 17:28:25 UTC | Merge pull request #13544 from jakevdp:fix-readme PiperOrigin-RevId: 493625570 | 07 December 2022, 17:28:25 UTC |
79c8d66 | Jake VanderPlas | 07 December 2022, 16:48:40 UTC | README: fix badge URL | 07 December 2022, 16:48:40 UTC |
6e0c802 | Shaobo Hou | 07 December 2022, 12:24:29 UTC | Fix type annotation for bcoo_update_layout. PiperOrigin-RevId: 493567424 | 07 December 2022, 12:25:25 UTC |
ac72346 | Peter Hawkins | 07 December 2022, 04:39:21 UTC | Ensure that the initial dynamic_trace_state is canonicalized. The non-canonical state meant that we were falling back to a more expensive comparison for the first jit-compiled function in the program. I doubt there will be any impact on real benchmarks, but this perturbs the results of running a single microbenchmark in isolation. PiperOrigin-RevId: 493489154 | 07 December 2022, 04:39:53 UTC |
b168df0 | jax authors | 07 December 2022, 03:10:02 UTC | Merge pull request #13539 from wookayin:fix/config-typing PiperOrigin-RevId: 493473725 | 07 December 2022, 03:10:02 UTC |
cd22585 | Jongwook Choi | 06 December 2022, 23:39:54 UTC | Fix a false-positive typing warning on jax.default_device Consider the following code where static type checkers can report an error: ```python CPU = jax.devices('cpu')[0] with jax.default_device(CPU): ... # ^^^ ``` Error message: ``` Pyright: Argument of type "Device" cannot be assigned to parameter "new_val" of type "NoDefault" "Device" is incompatible with "NoDefault" (reportGeneralTypeIssues) ``` This is because `_StateContextManager.__call__` does not have a proper type annotation on the parameter, unlike the attribute `_default_value` which has a type annotation. Adding an `Any` annotation to the parameter makes the error disappear. | 07 December 2022, 02:05:35 UTC |
61398ff | jax authors | 06 December 2022, 22:40:44 UTC | Merge pull request #13536 from jakevdp:quickstart-timing PiperOrigin-RevId: 493414640 | 06 December 2022, 22:40:44 UTC |
23b808f | jax authors | 06 December 2022, 22:34:01 UTC | Merge pull request #13446 from google:maxfail PiperOrigin-RevId: 493414635 | 06 December 2022, 22:34:01 UTC |
7b59ce2 | Jake VanderPlas | 06 December 2022, 21:24:02 UTC | DOC: pre-execute the quickstart notebook on GPU | 06 December 2022, 21:24:02 UTC |
1132098 | Jieying Luo | 06 December 2022, 20:29:03 UTC | [PJRT:C] Separate loading PJRT plugin from creating PJRT client. - Add xla_client.maybe_load_pjrt_plugins, which may load PJRT plugins from a hardcoded set. - Call xla_client.maybe_load_pjrt_plugins in xla_bridge before initializing backends. - Add binding of python method load_pjrt_plugin to LoadPjrtPlugin which does dlopen and dlsym. - Remove loading PJRT plugin from tpu_initializer_helper.cc. - Add an extra call to LoadPjrtPlugin when getting the PJRT_Api* to be backward compatible. PiperOrigin-RevId: 493381393 | 06 December 2022, 20:29:38 UTC |
1bbcec7 | Peter Hawkins | 06 December 2022, 19:56:54 UTC | Update FAQ since buffer donation is implemented on CPU. PiperOrigin-RevId: 493372426 | 06 December 2022, 19:57:34 UTC |
2fb9238 | jax authors | 06 December 2022, 18:40:35 UTC | Merge pull request #13534 from jakevdp:fix-rand-test PiperOrigin-RevId: 493349509 | 06 December 2022, 18:40:35 UTC |
7e3f674 | Jake VanderPlas | 06 December 2022, 17:26:23 UTC | random_test: skip singular covariance test on accelerators | 06 December 2022, 17:26:23 UTC |
edaa562 | jax authors | 06 December 2022, 16:46:28 UTC | Merge pull request #13517 from froystig:rng-part-jax2tf PiperOrigin-RevId: 493317025 | 06 December 2022, 16:46:28 UTC |
72f97f7 | jax authors | 06 December 2022, 16:31:49 UTC | Merge pull request #13523 from jakevdp:sparse-unused PiperOrigin-RevId: 493312001 | 06 December 2022, 16:31:49 UTC |
9f5a631 | jax authors | 06 December 2022, 16:31:32 UTC | Merge pull request #13518 from jakevdp:multiv-normal PiperOrigin-RevId: 493311895 | 06 December 2022, 16:31:32 UTC |
33a1b88 | Peter Hawkins | 06 December 2022, 16:23:40 UTC | Mark arguments to ufuncs as positional-only. PiperOrigin-RevId: 493311821 | 06 December 2022, 16:24:11 UTC |
a790016 | jax authors | 06 December 2022, 16:09:45 UTC | Merge pull request #12962 from hawkinsp:rocm PiperOrigin-RevId: 493308550 | 06 December 2022, 16:09:45 UTC |
938b625 | George Necula | 06 December 2022, 08:30:37 UTC | Remove unreliable call_tf_test PiperOrigin-RevId: 493225546 | 06 December 2022, 08:31:19 UTC |
e7e9687 | Yash Katariya | 06 December 2022, 02:09:26 UTC | Allow pjit's C++ dispatch path to operate on uncommitted array only if it belongs on a single device. This will bring pjit's dispatch performance in line with `jit` to prepare for jit/pjit frontend merge. PiperOrigin-RevId: 493164446 | 06 December 2022, 02:09:59 UTC |
2fc7bbc | John QiangZhang | 06 December 2022, 02:02:05 UTC | Cache the exec time `tf.get_current_name_scope` into `_thread_local_state.exec_time_tf_name_scope` and add it as prefix during tracing. Also move the test cases to the right code location. PiperOrigin-RevId: 493163018 | 06 December 2022, 02:02:44 UTC |
516f0d0 | Peter Hawkins | 06 December 2022, 01:48:22 UTC | Support negative axes in all_gather. Previously we didn't check for these and they caused crashes during MHLO verification. PiperOrigin-RevId: 493160581 | 06 December 2022, 01:48:50 UTC |
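The negative-axis handling described in 516f0d0 can be illustrated with a minimal sketch of axis canonicalization. This is not the code from the commit: `canonicalize_axis` below is a hypothetical stand-alone helper written for illustration, and it only assumes the usual NumPy convention that an axis in the range [-ndim, ndim) refers to a dimension counted from the end when negative.

```python
def canonicalize_axis(axis: int, ndim: int) -> int:
    """Map an axis in [-ndim, ndim) to its non-negative equivalent.

    Hypothetical helper for illustration; not the function used in JAX.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for an array of dimension {ndim}")
    # Negative axes count from the end, so shift them into [0, ndim).
    return axis + ndim if axis < 0 else axis


# Checking the axis up front means an invalid value fails with a clear
# Python error instead of reaching a later verification stage.
print(canonicalize_axis(-1, 3))
print(canonicalize_axis(0, 3))
```

Validating and normalizing the axis before lowering is one way to avoid the kind of late verification crash the commit message mentions.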
189d0c0 | jax authors | 06 December 2022, 01:35:04 UTC | Merge pull request #13525 from jakevdp:typing-extensions PiperOrigin-RevId: 493158064 | 06 December 2022, 01:35:04 UTC |
9400555 | jax authors | 06 December 2022, 01:27:55 UTC | Merge pull request #13526 from jakevdp:isinstance-array PiperOrigin-RevId: 493156840 | 06 December 2022, 01:27:55 UTC |
88bfe90 | Jake VanderPlas | 06 December 2022, 00:13:33 UTC | Replace utility function with isinstance check | 06 December 2022, 00:13:33 UTC |
382a248 | jax authors | 06 December 2022, 00:11:59 UTC | Merge pull request #13512 from jakevdp:fix-optimizers PiperOrigin-RevId: 493139751 | 06 December 2022, 00:11:59 UTC |
4389216 | Jake VanderPlas | 05 December 2022, 23:42:26 UTC | Remove typing_extensions dependency | 05 December 2022, 23:42:26 UTC |
a42e9e2 | Jake VanderPlas | 05 December 2022, 22:35:34 UTC | [sparse] delete _bcoo_unbatch utility | 05 December 2022, 22:35:34 UTC |
23261d7 | jax authors | 05 December 2022, 22:28:18 UTC | Merge pull request #13515 from jakevdp:sparse-typing PiperOrigin-RevId: 493113475 | 05 December 2022, 22:28:18 UTC |
1f9c988 | Peter Hawkins | 05 December 2022, 22:09:06 UTC | Use `_thread_local_state.__dict__.get()` instead of `getattr(_thread_local_state, ...)`. `getattr` turns out to be a tiny bit slower than `get()` on `__dict__` in the case that the attribute is absent. `getattr` appears to form an error message that is thrown away if a default is present. Improves the device_put benchmark: ``` name old cpu/op new cpu/op delta device_put 51.4µs ± 1% 48.9µs ± 3% -4.87% (p=0.000 n=8+9) name old time/op new time/op delta device_put 51.4µs ± 1% 48.9µs ± 3% -4.87% (p=0.000 n=8+9) ``` PiperOrigin-RevId: 493108288 | 05 December 2022, 22:09:47 UTC |
c376836 | Jake VanderPlas | 05 December 2022, 22:05:25 UTC | [typing] annotate jax.experimental.sparse | 05 December 2022, 22:05:25 UTC |
45707a2 | jax authors | 05 December 2022, 21:14:37 UTC | Merge pull request #13499 from jakevdp:error-docs PiperOrigin-RevId: 493092719 | 05 December 2022, 21:14:37 UTC |
25d1a0b | Yash Katariya | 05 December 2022, 20:49:33 UTC | Add cudnn 86 (for cuda 11.8) so that I can release cuda 11.8 nightlies. PiperOrigin-RevId: 493086060 | 05 December 2022, 20:50:09 UTC |
58d6a3b | Jake VanderPlas | 05 December 2022, 20:43:05 UTC | random.multivariate_normal: add note about singular covariance | 05 December 2022, 20:43:05 UTC |
29942e3 | Jake VanderPlas | 02 December 2022, 23:55:33 UTC | docs: add another example to the ConcretizationTypeError docs | 05 December 2022, 19:24:54 UTC |
431c51a | Roy Frostig | 05 December 2022, 19:09:56 UTC | rename `iota_32x2_shape` to `iota_2x32_shape` ... for consistency with other internal Threefry primitive names. | 05 December 2022, 19:09:56 UTC |
dad21d3 | jax authors | 05 December 2022, 18:45:21 UTC | Merge pull request #13506 from simonbutt:bugfix/jax101-pytrees PiperOrigin-RevId: 493052164 | 05 December 2022, 18:45:21 UTC |
b8b6e27 | Yash Katariya | 05 December 2022, 17:48:46 UTC | Add typehints and point to the correct endpoint of Mesh and PartitionSpec in the args section. PiperOrigin-RevId: 493035898 | 05 December 2022, 17:49:18 UTC |
d317cfa | Jake VanderPlas | 05 December 2022, 17:21:09 UTC | Revert part of #13498 | 05 December 2022, 17:21:09 UTC |
75af6b5 | Roy Frostig | 05 December 2022, 17:16:19 UTC | add a jax2tf translation rule for the shaped-iota primitive This allows for jax2tf conversion of the partitionable Threefry RNG. | 05 December 2022, 17:19:25 UTC |
a3483db | Roy Frostig | 05 December 2022, 17:15:27 UTC | docstring for shaped iota primitive | 05 December 2022, 17:15:27 UTC |
401fbb6 | Peter Hawkins | 05 December 2022, 14:51:28 UTC | Disable xmap_test on TPU under asan due to CI timeouts. PiperOrigin-RevId: 492994226 | 05 December 2022, 14:52:09 UTC |
55d6daa | Qiao Zhang | 03 December 2022, 22:15:31 UTC | Skip test_lstm on CPU and TPU for jax OSS build. PiperOrigin-RevId: 492722650 | 03 December 2022, 22:16:07 UTC |
542b38a | Simon Butt | 03 December 2022, 21:22:25 UTC | Updated jax.tree_leaves --> jax.tree_util.tree_leaves to remove deprecation notice in jax101-pytrees tutorial Signed-off-by: Simon Butt <simonbutt123@gmail.com> | 03 December 2022, 21:22:25 UTC |
e814f70 | Yash Katariya | 03 December 2022, 04:47:01 UTC | Raise an error when a numpy input is passed with a non-trivial sharding. This can lead to weird behavior with pjit and XLA since host-local inputs are not allowed with pjit anymore. PiperOrigin-RevId: 492621424 | 03 December 2022, 04:47:45 UTC |
02fab52 | Hyeontaek Lim | 03 December 2022, 02:40:59 UTC | Add tests to check if pjit handles deleted array inputs gracefully and consistently pjit dispatch paths should check deleted array inputs when attempting to use them. These new tests ensure that various pjit dispatch paths detect and handle them gracefully and consistently. Add a check to the PyArray argument handling to make the tests pass. PiperOrigin-RevId: 492605524 | 03 December 2022, 02:41:31 UTC |
693047a | jax authors | 03 December 2022, 01:10:04 UTC | Merge pull request #13498 from jakevdp:x64-other-tests PiperOrigin-RevId: 492593760 | 03 December 2022, 01:10:04 UTC |
f22d4a8 | jax authors | 03 December 2022, 01:01:56 UTC | Merge pull request #13490 from jakevdp:x64-check-grads PiperOrigin-RevId: 492592519 | 03 December 2022, 01:01:56 UTC |
06755ad | Hyeontaek Lim | 03 December 2022, 00:38:04 UTC | Reduce the buffer size used in ShardedDeviceArrayTest.testThreadsafeIndexing testThreadsafeIndexing uses a fairly large buffer size. When overlapping many executions under constrained host memory for testing using an alternative backend, this test may hit the maximum allowed memory use. This change reduces the buffer size by half, which is likely still interesting and runs more reliably on an alternative backend. PiperOrigin-RevId: 492588538 | 03 December 2022, 00:38:47 UTC |
c2c3669 | Peter Hawkins | 02 December 2022, 23:13:55 UTC | Remove long-deprecated method .block_host_until_ready(). PiperOrigin-RevId: 492571809 | 02 December 2022, 23:18:11 UTC |
5e102c1 | Peter Hawkins | 02 December 2022, 23:10:56 UTC | Implement .on_device_size_in_bytes() on jax.Array. This is a method present in DeviceArray that is missing from Array. PiperOrigin-RevId: 492571171 | 02 December 2022, 23:11:27 UTC |
924894f | Jake VanderPlas | 02 December 2022, 21:20:30 UTC | [x64] make tests more type-safe | 02 December 2022, 21:21:35 UTC |
9e53de8 | Jake VanderPlas | 02 December 2022, 20:51:41 UTC | [x64] make check_grads() more type-safe | 02 December 2022, 20:51:41 UTC |
8a28ccd | jax authors | 02 December 2022, 20:06:11 UTC | Merge pull request #13491 from jakevdp:x64-stax PiperOrigin-RevId: 492526309 | 02 December 2022, 20:06:11 UTC |
f9b5312 | Peter Hawkins | 02 December 2022, 19:36:51 UTC | Do not mirror JAX config options back to ABSL flags. Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option. However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API. This gives a decent improvement on the device_put benchmark: ``` name old cpu/op new cpu/op delta device_put 79.5µs ± 6% 69.4µs ± 7% -12.73% (p=0.000 n=10+9) name old time/op new time/op delta device_put 79.5µs ± 6% 69.4µs ± 7% -12.73% (p=0.000 n=10+9) ``` PiperOrigin-RevId: 492519085 | 02 December 2022, 19:37:22 UTC |
8431e43 | Jake VanderPlas | 02 December 2022, 18:01:59 UTC | [x64] more type safety in stax_test.py | 02 December 2022, 18:02:25 UTC |
1027d55 | jax authors | 02 December 2022, 17:00:20 UTC | Optimize core.find_top_trace This function is quite important, since it runs at every JAX primitive bind, but it included a few redundant conditionals. PiperOrigin-RevId: 492481837 | 02 December 2022, 17:00:50 UTC |
bbf22db | Adam Paszke | 02 December 2022, 15:04:13 UTC | Optimize core.find_top_trace This function is quite important, since it runs at every JAX primitive bind, but it included a few redundant conditionals. PiperOrigin-RevId: 492460102 | 02 December 2022, 15:04:52 UTC |
01377bc | jax authors | 02 December 2022, 04:43:18 UTC | Merge pull request #13485 from jakevdp:x64-random PiperOrigin-RevId: 492367153 | 02 December 2022, 04:43:18 UTC |
5927032 | jax authors | 02 December 2022, 04:36:41 UTC | Merge pull request #13482 from jakevdp:x64-signal PiperOrigin-RevId: 492367133 | 02 December 2022, 04:36:41 UTC |
e1d118c | jax authors | 02 December 2022, 04:29:46 UTC | Merge pull request #13476 from jakevdp:x64-lax-numpy PiperOrigin-RevId: 492367125 | 02 December 2022, 04:29:46 UTC |