https://github.com/google/jax

Revision Author Date Message Commit Date
bdb01e8 Chain value error in flatten_axis_resources. PiperOrigin-RevId: 493596333 08 December 2022, 15:37:14 UTC
440b25b Merge pull request #13427 from gnecula:tf_native_poly PiperOrigin-RevId: 493800880 08 December 2022, 06:41:23 UTC
8fb344a [jax2tf] An alternative support for shape polymorphism for native serialization.

jax2tf already supports many cases of shape polymorphism, e.g., those where the shapes of all intermediates can be expressed as polynomials in the dimension variables in the input. We want to achieve the same coverage, or more, while using StableHLO as the lowering format, rather than tf.Graph.

For native serialization we will support two lowering implementations:

  * one is using the growing support in JAX for dynamic shapes, of which shape polymorphism is a special case. This implementation is enabled with the --jax_dynamic_shapes flag. At the moment, the JAX dynamic shapes support is still incomplete and over 300 jax2tf shape polymorphism tests fail.

  * a new one (added here) in which we form a Jaxpr using abstract values that express dimension sizes as dimension polynomials (as for the standard jax2tf). Then we lower the Jaxpr to StableHLO. This implementation is enabled when --jax_dynamic_shapes is off. With this implementation only 50 jax2tf tests fail (to be fixed separately).

The key contribution here is to enable lowering a Jaxpr that contains dimension polynomials in some of the intermediate shapes. Many lowering rules already have some partial support for Jaxprs where the shapes contain `Var`s. To the extent possible, we try to write lowering rules that should cover both cases of dynamic shapes: Var or polynomials in shapes.

The lowering convention is that at top level we collect the sorted list of dimension variable names in the inputs, and we store it in ModuleContext.dim_vars. All IR functions will take N additional prefix arguments of int32 type containing the values of the dimension variables. This is stored as a list of `ir.Value` in `LoweringContext.dim_var_values`.

Note that the Jaxprs are not changed to have extra Vars for the dimension variable values. An alternative implementation could work by transforming the Jaxpr to replace dimension polynomials with Vars.

The key code pattern used in the lowering rule is::

    if not core.is_constant_shape(shape):  # Handles both Var, and polynomials
      shape = mlir.eval_dynamic_shape(ctx, shape)
      return mhlo.DynamicXXX(..., shape)
    else:
      return mhlo.XXX(..., shape)

with `mlir.eval_dynamic_shape` handling both cases::

    def eval_dynamic_shape(ctx, shape):
      if config.jax_dynamic_shapes:  # Using Var
        return ... subst using ctx.axis_size_env ...
      else:  # Using polynomials
        return ... subst using ctx.module_context.dim_vars and ctx.dim_var_values

In order to support the above some lowering functions need to take a LoweringContext parameter, e.g., mlir.broadcast_mhlo.

I expect that the changes here will improve the --jax_dynamic_shapes coverage as well. 08 December 2022, 06:19:35 UTC
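The polynomial branch described in commit 8fb344a can be illustrated with a small standalone sketch. This is not the JAX implementation: the polynomial representation (a dict from monomials to coefficients) and the helper names below are assumptions made for illustration only. It shows the idea of substituting the values of the sorted dimension variables into the dimension polynomials of a shape.

```python
from math import prod

# Hypothetical representation (not JAX's): a dimension is either an int or a
# polynomial stored as {monomial: coefficient}, where a monomial is a tuple of
# dimension variable names. Example: 2*b*w + 1 -> {("b", "w"): 2, (): 1}


def is_constant_shape(shape):
    """Return True when every dimension is a plain integer."""
    return all(isinstance(d, int) for d in shape)


def eval_dim(dim, env):
    """Evaluate one dimension given a mapping from variable name to value."""
    if isinstance(dim, int):
        return dim
    # prod() of an empty monomial is 1, which yields the constant term.
    return sum(coeff * prod(env[v] for v in mono) for mono, coeff in dim.items())


def eval_dynamic_shape(shape, dim_vars, dim_var_values):
    """Substitute dimension variable values into every dimension of `shape`."""
    env = dict(zip(dim_vars, dim_var_values))
    return tuple(eval_dim(d, env) for d in shape)


# The dimension variable names are collected in sorted order, and the values
# are supplied in that same order.
dim_vars = sorted(["w", "b"])            # ["b", "w"]
shape = ({("b",): 1}, {("b", "w"): 2, (): 1}, 3)

result = eval_dynamic_shape(shape, dim_vars, [4, 5])   # b = 4, w = 5
print(result)
```

In the real lowering the values are `ir.Value`s computed at run time rather than Python integers; the sketch only mirrors the substitution step.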
0118f8d Prepare for jax and jaxlib 0.4.0 release PiperOrigin-RevId: 493733609 08 December 2022, 00:02:24 UTC
dd64760 Make `jnp.copy` work with all shardings especially `PmapSharding`. This fixes the problem where a `jax.Array` with a `PmapSharding` round tripped through host and returned a `jax.Array` with a `SingleDeviceSharding`. Now, `jnp.copy` works without going through a round trip via host and maintains the sharding of the input array across all the Shardings we have. PiperOrigin-RevId: 493728354 07 December 2022, 23:41:47 UTC
3159642 Merge pull request #13547 from jakevdp:numpy-msort PiperOrigin-RevId: 493649811 07 December 2022, 18:50:33 UTC
ce990cf Merge pull request #13546 from jakevdp:shard-types PiperOrigin-RevId: 493637391 07 December 2022, 18:11:53 UTC
09d1b6d Deprecate jnp.msort following deprecation of numpy.msort 07 December 2022, 18:08:18 UTC
777754d [typing] fix Shard/Sharding type TODO 07 December 2022, 17:44:03 UTC
794cec1 Merge pull request #13544 from jakevdp:fix-readme PiperOrigin-RevId: 493625570 07 December 2022, 17:28:25 UTC
79c8d66 README: fix badge URL 07 December 2022, 16:48:40 UTC
6e0c802 Fix type annotation for bcoo_update_layout. PiperOrigin-RevId: 493567424 07 December 2022, 12:25:25 UTC
ac72346 Ensure that the initial dynamic_trace_state is canonicalized. The non-canonical state meant that we were falling back to a more expensive comparison for the first jit-compiled function in the program. I doubt there will be any impact on real benchmarks, but this perturbs the results of running a single microbenchmark in isolation. PiperOrigin-RevId: 493489154 07 December 2022, 04:39:53 UTC
b168df0 Merge pull request #13539 from wookayin:fix/config-typing PiperOrigin-RevId: 493473725 07 December 2022, 03:10:02 UTC
cd22585 Fix a false-positive typing warning on jax.default_device

Consider the following code where static type checkers can report an error:

```python
CPU = jax.devices('cpu')[0]
with jax.default_device(CPU): ...
#                       ^^^
```

Error message:

```
Pyright: Argument of type "Device" cannot be assigned to parameter "new_val" of type "NoDefault"
  "Device" is incompatible with "NoDefault" (reportGeneralTypeIssues)
```

This is because `_StateContextManager.__call__` does not have a proper type annotation on the parameter, unlike the attribute `_default_value` which has a type annotation. Adding an `Any` annotation to the parameter makes the error disappear. 07 December 2022, 02:05:35 UTC
61398ff Merge pull request #13536 from jakevdp:quickstart-timing PiperOrigin-RevId: 493414640 06 December 2022, 22:40:44 UTC
23b808f Merge pull request #13446 from google:maxfail PiperOrigin-RevId: 493414635 06 December 2022, 22:34:01 UTC
7b59ce2 DOC: pre-execute the quickstart notebook on GPU 06 December 2022, 21:24:02 UTC
1132098 [PJRT:C] Separate loading PJRT plugin from creating PJRT client. - Add xla_client.maybe_load_pjrt_plugins, which may load PJRT plugins from a hardcoded set. - Call xla_client.maybe_load_pjrt_plugins in xla_bridge before initializing backends. - Add binding of python method load_pjrt_plugin to LoadPjrtPlugin which does dlopen and dlsym. - Remove loading PJRT plugin from tpu_initializer_helper.cc. - Add an extra call to LoadPjrtPlugin when getting the PJRT_Api* to be backward compatible. PiperOrigin-RevId: 493381393 06 December 2022, 20:29:38 UTC
1bbcec7 Update FAQ since buffer donation is implemented on CPU. PiperOrigin-RevId: 493372426 06 December 2022, 19:57:34 UTC
2fb9238 Merge pull request #13534 from jakevdp:fix-rand-test PiperOrigin-RevId: 493349509 06 December 2022, 18:40:35 UTC
7e3f674 random_test: skip singular covariance test on accelerators 06 December 2022, 17:26:23 UTC
edaa562 Merge pull request #13517 from froystig:rng-part-jax2tf PiperOrigin-RevId: 493317025 06 December 2022, 16:46:28 UTC
72f97f7 Merge pull request #13523 from jakevdp:sparse-unused PiperOrigin-RevId: 493312001 06 December 2022, 16:31:49 UTC
9f5a631 Merge pull request #13518 from jakevdp:multiv-normal PiperOrigin-RevId: 493311895 06 December 2022, 16:31:32 UTC
33a1b88 Mark arguments to ufuncs as positional-only. PiperOrigin-RevId: 493311821 06 December 2022, 16:24:11 UTC
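For readers unfamiliar with the term used in commit 33a1b88, the following sketch shows what a positional-only argument means in plain Python. The function `negative` is a hypothetical stand-in, not one of the JAX ufuncs; the `/` marker in the signature is standard Python syntax (3.8 and later).

```python
def negative(x, /):
    """Hypothetical unary function; `x` can only be passed positionally."""
    return -x


ok = negative(3)           # positional call works

try:
    negative(x=3)          # keyword call is rejected by the interpreter
    keyword_rejected = False
except TypeError:
    keyword_rejected = True

print(ok, keyword_rejected)
```

Marking arguments this way keeps parameter names out of the public contract, so they can be renamed later without breaking callers.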
a790016 Merge pull request #12962 from hawkinsp:rocm PiperOrigin-RevId: 493308550 06 December 2022, 16:09:45 UTC
938b625 Remove unreliable call_tf_test PiperOrigin-RevId: 493225546 06 December 2022, 08:31:19 UTC
e7e9687 Allow pjit's C++ dispatch path to operate on uncommitted array only if it belongs on a single device. This will bring pjit's dispatch performance in line with `jit` to prepare for jit/pjit frontend merge. PiperOrigin-RevId: 493164446 06 December 2022, 02:09:59 UTC
2fc7bbc Cache the exec time `tf.get_current_name_scope` into `_thread_local_state.exec_time_tf_name_scope` and add it as prefix during tracing. Also move the test cases to the right code location. PiperOrigin-RevId: 493163018 06 December 2022, 02:02:44 UTC
516f0d0 Support negative axes in all_gather. Previously we didn't check for these and they caused crashes during MHLO verification. PiperOrigin-RevId: 493160581 06 December 2022, 01:48:50 UTC
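A common way to support negative axes, as in commit 516f0d0, is to validate the axis and map it to its non-negative equivalent before it reaches later stages. The helper below is a generic sketch of that technique under the usual NumPy-style convention; its name and exact bounds are assumptions and it is not the code added by the commit.

```python
def canonicalize_axis(axis, ndim):
    """Map an axis in [-ndim, ndim) to the equivalent axis in [0, ndim).

    Hypothetical helper for illustration only.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for rank {ndim}")
    return axis + ndim if axis < 0 else axis


last = canonicalize_axis(-1, 3)    # last axis of a rank-3 array
first = canonicalize_axis(0, 3)    # non-negative axes pass through unchanged
print(last, first)
```

Rejecting out-of-range values early gives a clear Python error instead of a failure deep inside verification.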
189d0c0 Merge pull request #13525 from jakevdp:typing-extensions PiperOrigin-RevId: 493158064 06 December 2022, 01:35:04 UTC
9400555 Merge pull request #13526 from jakevdp:isinstance-array PiperOrigin-RevId: 493156840 06 December 2022, 01:27:55 UTC
88bfe90 Replace utility function with isinstance check 06 December 2022, 00:13:33 UTC
382a248 Merge pull request #13512 from jakevdp:fix-optimizers PiperOrigin-RevId: 493139751 06 December 2022, 00:11:59 UTC
4389216 Remove typing_extensions dependency 05 December 2022, 23:42:26 UTC
a42e9e2 [sparse] delete _bcoo_unbatch utility 05 December 2022, 22:35:34 UTC
23261d7 Merge pull request #13515 from jakevdp:sparse-typing PiperOrigin-RevId: 493113475 05 December 2022, 22:28:18 UTC
1f9c988 Use `_thread_local_state.__dict__.get()` instead of `getattr(_thread_local_state, ...)`.

`getattr` turns out to be a tiny bit slower than `get()` on `__dict__` in the case that the attribute is absent. `getattr` appears to form an error message that is thrown away if a default is present.

Improves the device_put benchmark:

```
name        old cpu/op   new cpu/op   delta
device_put  51.4µs ± 1%  48.9µs ± 3%  -4.87%  (p=0.000 n=8+9)

name        old time/op  new time/op  delta
device_put  51.4µs ± 1%  48.9µs ± 3%  -4.87%  (p=0.000 n=8+9)
```

PiperOrigin-RevId: 493108288 05 December 2022, 22:09:47 UTC
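The two lookup styles compared in commit 1f9c988 can be sketched with the standard library alone. The `_state` object and the `scope` attribute below are hypothetical stand-ins for `_thread_local_state` and its fields; the sketch only shows that both forms return the same result, and makes no timing claim.

```python
import threading

# Hypothetical stand-in for a module-level thread-local state object.
_state = threading.local()


def read_with_getattr(name, default=None):
    """Look up an attribute, falling back to `default` when it is absent."""
    return getattr(_state, name, default)


def read_with_dict_get(name, default=None):
    """Same lookup, done directly on the per-thread instance dictionary."""
    return _state.__dict__.get(name, default)


# Attribute absent: both forms return the default.
absent = (read_with_getattr("scope"), read_with_dict_get("scope"))

# Attribute present: both forms return the stored value.
_state.scope = "outer"
present = (read_with_getattr("scope"), read_with_dict_get("scope"))
print(absent, present)
```

The dictionary form only sees instance attributes, so it is a drop-in replacement only when the value is never provided by a class attribute or property.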
c376836 [typing] annotate jax.experimental.sparse 05 December 2022, 22:05:25 UTC
45707a2 Merge pull request #13499 from jakevdp:error-docs PiperOrigin-RevId: 493092719 05 December 2022, 21:14:37 UTC
25d1a0b Add cudnn 86 (for cuda 11.8) so that I can release cuda 11.8 nightlies. PiperOrigin-RevId: 493086060 05 December 2022, 20:50:09 UTC
58d6a3b random.multivariate_normal: add note about singular covariance 05 December 2022, 20:43:05 UTC
29942e3 docs: add another example to the ConcretizationTypeError docs 05 December 2022, 19:24:54 UTC
431c51a rename `iota_32x2_shape` to `iota_2x32_shape` ... for consistency with other internal Threefry primitive names. 05 December 2022, 19:09:56 UTC
dad21d3 Merge pull request #13506 from simonbutt:bugfix/jax101-pytrees PiperOrigin-RevId: 493052164 05 December 2022, 18:45:21 UTC
b8b6e27 Add typehints and point to the correct endpoint of Mesh and PartitionSpec in the args section. PiperOrigin-RevId: 493035898 05 December 2022, 17:49:18 UTC
d317cfa Revert part of #13498 05 December 2022, 17:21:09 UTC
75af6b5 add a jax2tf translation rule for the shaped-iota primitive This allows for jax2tf conversion of the partitionable Threefry RNG. 05 December 2022, 17:19:25 UTC
a3483db docstring for shaped iota primitive 05 December 2022, 17:15:27 UTC
401fbb6 Disable xmap_test on TPU under asan due to CI timeouts. PiperOrigin-RevId: 492994226 05 December 2022, 14:52:09 UTC
55d6daa Skip test_lstm on CPU and TPU for jax OSS build. PiperOrigin-RevId: 492722650 03 December 2022, 22:16:07 UTC
542b38a Updated jax.tree_leaves --> jax.tree_util.tree_leaves to remove deprecation notice in jax101-pytrees tutorial Signed-off-by: Simon Butt <simonbutt123@gmail.com> 03 December 2022, 21:22:25 UTC
e814f70 Raise an error when a numpy input is passed with a non-trivial sharding. This can lead to weird behavior with pjit and XLA since host-local inputs are not allowed with pjit anymore. PiperOrigin-RevId: 492621424 03 December 2022, 04:47:45 UTC
02fab52 Add tests to check if pjit handles deleted array inputs gracefully and consistently. pjit dispatch paths should check deleted array inputs when attempting to use them. These new tests ensure that various pjit dispatch paths detect and handle them gracefully and consistently. Add a check to the PyArray argument handling to make the tests pass. PiperOrigin-RevId: 492605524 03 December 2022, 02:41:31 UTC
693047a Merge pull request #13498 from jakevdp:x64-other-tests PiperOrigin-RevId: 492593760 03 December 2022, 01:10:04 UTC
f22d4a8 Merge pull request #13490 from jakevdp:x64-check-grads PiperOrigin-RevId: 492592519 03 December 2022, 01:01:56 UTC
06755ad Reduce the buffer size used in ShardedDeviceArrayTest.testThreadsafeIndexing. testThreadsafeIndexing uses a fairly large buffer size. When overlapping many executions under constrained host memory for testing using an alternative backend, this test may hit the maximum allowed memory use. This change reduces the buffer size by half, which is likely still interesting and runs more reliably on an alternative backend. PiperOrigin-RevId: 492588538 03 December 2022, 00:38:47 UTC
c2c3669 Remove long-deprecated method .block_host_until_ready(). PiperOrigin-RevId: 492571809 02 December 2022, 23:18:11 UTC
5e102c1 Implement .on_device_size_in_bytes() on jax.Array. This is a method present in DeviceArray that is missing from Array. PiperOrigin-RevId: 492571171 02 December 2022, 23:11:27 UTC
924894f [x64] make tests more type-safe 02 December 2022, 21:21:35 UTC
9e53de8 [x64] make check_grads() more type-safe 02 December 2022, 20:51:41 UTC
8a28ccd Merge pull request #13491 from jakevdp:x64-stax PiperOrigin-RevId: 492526309 02 December 2022, 20:06:11 UTC
f9b5312 Do not mirror JAX config options back to ABSL flags.

Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.

However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.

This gives a decent improvement on the device_put benchmark:

```
name        old cpu/op   new cpu/op   delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)

name        old time/op  new time/op  delta
device_put  79.5µs ± 6%  69.4µs ± 7%  -12.73%  (p=0.000 n=10+9)
```

PiperOrigin-RevId: 492519085 02 December 2022, 19:37:22 UTC
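The behavior change in commit f9b5312 can be modeled with a toy configuration object. Everything below is an assumption for illustration: the `Config` class, the `example_option` name, and the plain dict standing in for parsed ABSL flags are hypothetical and do not reflect JAX's internals. The sketch shows flags populating the config once, after which later flag changes are not observed and updates must go through the config's own `update` method.

```python
class Config:
    """Toy config that copies flag values once instead of mirroring them."""

    def __init__(self):
        self._values = {}

    def populate_from_flags(self, parsed_flags):
        # One-time copy; the config keeps no reference to the flag store.
        self._values.update(parsed_flags)

    def update(self, name, value):
        self._values[name] = value

    def read(self, name):
        # A plain dict lookup, with no round trip through the flag library.
        return self._values[name]


flags = {"example_option": False}      # stands in for parsed ABSL flags
config = Config()
config.populate_from_flags(flags)

flags["example_option"] = True         # changing the flag after parsing...
stale = config.read("example_option")  # ...is not reflected in the config

config.update("example_option", True)  # the supported way to change a value
fresh = config.read("example_option")
print(stale, fresh)
```

The trade-off is the one the commit message names: reads become cheap, at the cost of the flag store no longer being a live source of truth.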
8431e43 [x64] more type safety in stax_test.py 02 December 2022, 18:02:25 UTC
1027d55 Optimize core.find_top_trace This function is quite important, since it runs at every JAX primitive bind, but it included a few redundant conditionals. PiperOrigin-RevId: 492481837 02 December 2022, 17:00:50 UTC
bbf22db Optimize core.find_top_trace This function is quite important, since it runs at every JAX primitive bind, but it included a few redundant conditionals. PiperOrigin-RevId: 492460102 02 December 2022, 15:04:52 UTC
01377bc Merge pull request #13485 from jakevdp:x64-random PiperOrigin-RevId: 492367153 02 December 2022, 04:43:18 UTC
5927032 Merge pull request #13482 from jakevdp:x64-signal PiperOrigin-RevId: 492367133 02 December 2022, 04:36:41 UTC
e1d118c Merge pull request #13476 from jakevdp:x64-lax-numpy PiperOrigin-RevId: 492367125 02 December 2022, 04:29:46 UTC
ff6f215 Merge pull request #13481 from jakevdp:x64-line-search PiperOrigin-RevId: 492359730 02 December 2022, 03:40:58 UTC
934bc4e Move `PartitionSpec` and `Mesh` out of experimental and into the `sharding` namespace. The new API endpoint is `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. PiperOrigin-RevId: 492358238 02 December 2022, 03:28:32 UTC
ed9519d Merge pull request #13484 from google:index PiperOrigin-RevId: 492323132 02 December 2022, 00:02:50 UTC
fdf5894 [x64] make random_test more type-safe 01 December 2022, 23:51:37 UTC
a2870a1 Add jax.Array to the index page 01 December 2022, 23:44:51 UTC
a6c7fa8 Merge pull request #13483 from google:cross PiperOrigin-RevId: 492314228 01 December 2022, 23:23:19 UTC
70d5081 Add cross-linking for the migration guide and the parallelism with JAX tutorial Co-authored-by: Skye Wanderman-Milne <skyewm@google.com> 01 December 2022, 22:42:59 UTC
8f92d62 Merge pull request #13465 from jakevdp:x64-lax-numpy-test PiperOrigin-RevId: 492299616 01 December 2022, 22:24:51 UTC
37acc6e [x64] more type safety in scipy.optimize.line_search 01 December 2022, 22:04:39 UTC
60a45a4 Merge pull request #13480 from skye:doc_url PiperOrigin-RevId: 492293304 01 December 2022, 22:01:16 UTC
3cf2924 [x64] minor fixes for lax_numpy_test type safety 01 December 2022, 21:56:42 UTC
d25a96c [x64] more type safety in jax.scipy.signal 01 December 2022, 21:43:07 UTC
b3b7eb6 Merge pull request #13464 from jakevdp:x64-host-callback PiperOrigin-RevId: 492284287 01 December 2022, 21:25:35 UTC
82b442f [docs] Replace one more jax_Array.html reference I missed this in #13479, thanks @yashk2810 for flagging! 01 December 2022, 20:50:55 UTC
b2b4fcb Merge pull request #13479 from skye:doc_url PiperOrigin-RevId: 492273702 01 December 2022, 20:45:40 UTC
89c8df6 Merge pull request #13434 from jurahul:master PiperOrigin-RevId: 492262252 01 December 2022, 20:02:34 UTC
51db1cf [docs] Rename "JAX in Parallelism" files so the URL matches the title. 01 December 2022, 19:53:31 UTC
904398a [x64] better type safety for host_callback 01 December 2022, 19:47:07 UTC
b037feb [x64] more type safety for lax_numpy-related tests 01 December 2022, 19:18:02 UTC
6213228 Fix `vmap(jvp(pjit(f)))` when pjit doesn't have any axis_resources PiperOrigin-RevId: 492238366 01 December 2022, 18:42:26 UTC
5847efc Merge pull request #13458 from jakevdp:f-strings PiperOrigin-RevId: 492219655 01 December 2022, 17:38:20 UTC
26d9837 Switch to new-style f-strings 01 December 2022, 17:14:16 UTC
79190b1 Merge pull request #13463 from jakevdp:x64-tests-types PiperOrigin-RevId: 492211762 01 December 2022, 17:12:20 UTC
f0f1d01 Merge pull request #13470 from gnecula:tf_roll_poly PiperOrigin-RevId: 492211433 01 December 2022, 17:05:24 UTC
fcaf7f1 [jax2tf] Fix the handling of jnp.roll for polymorphic shapes 01 December 2022, 10:18:47 UTC
4ca05f4 [call_tf] Use the same platform for TF lowering as the embedding JAX computation. This requires some changes for abstract evaluation, when JAX does not use a specific platform. Also attempt to fix the case when the TF lowering fails because the TF computation uses a tf.Variable on a different device than the one used for lowering. PiperOrigin-RevId: 492112847 01 December 2022, 07:22:24 UTC
4443b86 Remove local imports of array.py. The remaining local imports are in pxla.py but I will chip away at them when we delete SDA and move some more APIs out of experimental. PiperOrigin-RevId: 492033543 30 November 2022, 23:26:03 UTC
f09fd8a [x64] minor test-only updates for better type safety 30 November 2022, 23:18:40 UTC
e835739 Remove an unnecessary include/ from pybind11 include paths. PiperOrigin-RevId: 492016679 30 November 2022, 22:20:02 UTC
cfee99e Merge pull request #13435 from jakevdp:unused-code PiperOrigin-RevId: 491985960 30 November 2022, 20:18:50 UTC