https://github.com/google/jax

Revision Message Commit Date
3745da2 [6/n] Introduce Module into FunctionDefLibrary. Serialize ModuleDefs to and load them from SavedModel using Python API. PiperOrigin-RevId: 520744762 06 April 2023, 00:22:44 UTC
c10cb17 Accelerate deprecation of jax.ShapedArray This is deprecated as of https://github.com/google/jax/pull/15263: most users will never need to use ShapedArray directly, and so having it exposed in the top-level public namespace causes undue confusion. PiperOrigin-RevId: 522168275 05 April 2023, 22:26:19 UTC
8c8f50f Fix tolerance and shard_count for experimental_rnn_test This should fix the current GPU test timeout. PiperOrigin-RevId: 522167894 05 April 2023, 22:19:19 UTC
f6da71c Merge pull request #15401 from mattjj:issue15400 PiperOrigin-RevId: 522115286 05 April 2023, 18:55:47 UTC
29ba2ca Report the argument path when encountering an overflow error for a Python value. PiperOrigin-RevId: 522106244 05 April 2023, 18:24:40 UTC
7b4e579 Merge pull request #15415 from jakevdp:sparse-add PiperOrigin-RevId: 522101591 05 April 2023, 18:07:50 UTC
05f32a7 [sparse] allow sparse-dense add when the output is the same size as dense input 05 April 2023, 17:39:43 UTC
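When the sum has the same shape as the dense operand, sparse + dense reduces to scattering the stored sparse values into a copy of the dense array. The pure-Python sketch below shows this with a hypothetical COO `(indices, data)` layout and a hypothetical helper `coo_add_dense`. It is not the `jax.experimental.sparse` implementation.

```python
from typing import List, Sequence, Tuple


def coo_add_dense(indices: Sequence[Tuple[int, int]],
                  data: Sequence[float],
                  dense: List[List[float]]) -> List[List[float]]:
    """Add a COO-style sparse matrix (hypothetical layout) to a dense matrix
    of the same shape; the result has the dense operand's shape."""
    out = [row[:] for row in dense]  # copy so the dense input is not mutated
    for (i, j), v in zip(indices, data):
        out[i][j] += v  # duplicate indices accumulate, as a sum should
    return out
```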
0e549ac Update unit test and document how to work around jit_compile=False on TPU for native_serialization. PiperOrigin-RevId: 522077249 05 April 2023, 16:44:34 UTC
ac4942d fix conj transpose on symbolic zero fixes #15400 05 April 2023, 03:45:21 UTC
bf50551 Explicitly import jax.custom_{batching,derivatives,transpose}. https://github.com/google/jax/pull/15391 had the unintentional side effect of causing these names not to be imported by default. Restore the status quo by importing them. PiperOrigin-RevId: 521898088 04 April 2023, 23:40:15 UTC
aab24fe Merge pull request #15396 from jakevdp:conv-elem-type PiperOrigin-RevId: 521895290 04 April 2023, 23:26:30 UTC
c2fe350 future-proof lax.convert_element_type In the future, np.array(large_value, 'int32') will error 04 April 2023, 22:57:32 UTC
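NumPy has deprecated converting out-of-bound Python integers such as `np.array(2**31, 'int32')`. The older behaviour was C-style wraparound. The sketch below makes that wraparound explicit in pure Python so the arithmetic is visible. The commit message does not say whether JAX's fix wraps, rejects, or otherwise handles such values, so treat `wrap_to_int32` and `fits_int32` as illustrative helpers only.

```python
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1


def fits_int32(value: int) -> bool:
    """True if a Python int is representable as a signed 32-bit integer."""
    return INT32_MIN <= value <= INT32_MAX


def wrap_to_int32(value: int) -> int:
    """Reduce a Python int modulo 2**32 into the signed int32 range
    (two's-complement wraparound), done explicitly instead of relying
    on NumPy's deprecated implicit overflow."""
    return (value + 2**31) % 2**32 - 2**31
```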
ffa9d01 DCE as early as possible so that `committed` is not dependent on DCE's vars PiperOrigin-RevId: 521879918 04 April 2023, 22:21:12 UTC
9095faa Remove PyBuffer type and its bindings. PiperOrigin-RevId: 521865179 04 April 2023, 21:24:23 UTC
6040580 stages should not eagerly load the executables by calling cpp_call. PiperOrigin-RevId: 521849296 04 April 2023, 20:25:24 UTC
75d0f65 Add cupti pip dependency, needed for GPU profiling. Issue https://github.com/google/jax/issues/15384 PiperOrigin-RevId: 521841461 04 April 2023, 19:55:36 UTC
c1f65fc Avoid imports from the public jax.* namespace in more places internally. This change is in preparation for more cycle breaking in the Bazel dependency graph. PiperOrigin-RevId: 521822756 04 April 2023, 18:41:40 UTC
3c1f3ab Merge pull request #15149 from sharadmv:runstate PiperOrigin-RevId: 521809360 04 April 2023, 17:56:25 UTC
efcd85d Merge pull request #15086 from shoyer:extrap PiperOrigin-RevId: 521796315 04 April 2023, 17:12:13 UTC
9bb3d86 Merge pull request #15390 from jakevdp:checkify-dynamic-slice PiperOrigin-RevId: 521790925 04 April 2023, 16:54:19 UTC
46297dc checkify: catch OOB errors in dynamic_slice This will allow checkify tests to continue working properly after #15377 04 April 2023, 15:16:59 UTC
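`lax.dynamic_slice` clamps start indices so the slice stays in bounds rather than raising. An out-of-bounds request therefore silently returns a shifted window. The 1-D pure-Python sketch below shows how a checker can detect this by comparing the requested start with the clamped one. `clamped_dynamic_slice` is a hypothetical helper, not checkify's code.

```python
from typing import List, Sequence, Tuple


def clamped_dynamic_slice(x: Sequence[int], start: int,
                          size: int) -> Tuple[List[int], bool]:
    """1-D dynamic slice that clamps `start` into [0, len(x) - size].

    Returns the slice and a flag that is True when clamping changed the
    requested start index, i.e. the request was out of bounds.
    """
    if not 0 <= size <= len(x):
        raise ValueError("slice size must fit inside the operand")
    clamped = min(max(start, 0), len(x) - size)
    return list(x[clamped:clamped + size]), clamped != start
```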
35bfdc6 [shape_poly] Add some support for shape polymorphism for FFT, and tests PiperOrigin-RevId: 521749241 04 April 2023, 13:45:57 UTC
5101184 Add initial implementation of a `run_state` primitive 04 April 2023, 04:32:32 UTC
8a6c929 Merge pull request #15289 from cgarciae:add-missing-api-references PiperOrigin-RevId: 521617419 04 April 2023, 01:23:04 UTC
b361f4c Merge pull request #15169 from cgarciae:fix-lstm PiperOrigin-RevId: 521616002 04 April 2023, 01:13:19 UTC
14b572f Remove _compile_replicated option from `compile` since it is not needed anymore and some other cosmetic fixes. PiperOrigin-RevId: 521604489 04 April 2023, 00:17:33 UTC
4009005 Support extrapolation in jnp.interp Fixes https://github.com/google/jax/issues/14858 03 April 2023, 22:31:14 UTC
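By default `np.interp`/`jnp.interp` return `fp[0]` below the sample range and `fp[-1]` above it. Extrapolation instead extends the first or last linear segment. The pure-Python sketch below shows that semantics, assuming `xp` is strictly increasing with at least two points. `interp_extrapolate` is hypothetical and is not the `jnp.interp` signature.

```python
from bisect import bisect_right
from typing import Sequence


def interp_extrapolate(x: float, xp: Sequence[float],
                       fp: Sequence[float]) -> float:
    """Piecewise-linear interpolation that extends the end segments
    instead of clamping outside [xp[0], xp[-1]]."""
    if len(xp) < 2 or len(xp) != len(fp):
        raise ValueError("need matching xp/fp with at least two points")
    # Segment [xp[i], xp[i + 1]] to use; clamping i makes points outside
    # the range reuse the first or last segment (i.e. extrapolate).
    i = min(max(bisect_right(xp, x) - 1, 0), len(xp) - 2)
    x0, x1 = xp[i], xp[i + 1]
    y0, y1 = fp[i], fp[i + 1]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
```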
aa12e35 handle seq_lengths in lstm_ref 03 April 2023, 22:22:54 UTC
c2b15a1 Break out aot_test from array_test (for serialization and other aot APIs). PiperOrigin-RevId: 521568985 03 April 2023, 21:47:53 UTC
78678ee Rename `count_pjit_cache_miss` to `count_pjit_cpp_cache_miss` because it is confusing which cache the old name is talking about, as pjit has many caches PiperOrigin-RevId: 521559652 03 April 2023, 21:15:02 UTC
6f2256a Improve the error message of device_indices_map when the sharding is not divisible by the shape rather than raising an opaque assertion error PiperOrigin-RevId: 521507810 03 April 2023, 18:05:25 UTC
05249ec [jax2tf] Add more sharding tests with shape polymorphism PiperOrigin-RevId: 521471546 03 April 2023, 15:54:58 UTC
ff313a3 [jax2tf] Skip "graph" mode primitive tests on TPUs. PiperOrigin-RevId: 521468145 03 April 2023, 15:39:36 UTC
d743d23 Convolution functions in TF, like tf.nn.depthwise_conv2d_v2, tf.nn.conv2d_transpose_v2 and tf.nn.conv2d_v2, all follow the same principle when it comes to padding (explained here: https://www.tensorflow.org/api_docs/python/tf/nn#notes_on_padding_2). This principle happens to match that of `lax.convolution.conv_general_dilated`. So, this CL safely uses the padding (list[tuple(int, int)]) to call the tf.nn.<conv> functions PiperOrigin-RevId: 521464565 03 April 2023, 15:22:02 UTC
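Both `lax.conv_general_dilated` and the TF convolution ops accept explicit `(low, high)` padding pairs, and TF's notes on padding define `"SAME"` so that the output size is `ceil(in / stride)`, with any odd leftover pixel placed on the high side. The pure-Python sketch below computes those pairs per spatial dimension. It ignores dilation (with dilation, replace the filter size `k` by the effective size `(k - 1) * d + 1`), and `same_padding`/`output_size` are hypothetical helpers.

```python
import math
from typing import List, Sequence, Tuple


def same_padding(in_sizes: Sequence[int], filter_sizes: Sequence[int],
                 strides: Sequence[int]) -> List[Tuple[int, int]]:
    """Explicit (low, high) pads per spatial dim equivalent to "SAME"."""
    pads = []
    for n, k, s in zip(in_sizes, filter_sizes, strides):
        out = math.ceil(n / s)                 # "SAME" output size
        total = max((out - 1) * s + k - n, 0)  # total padding needed
        lo = total // 2                        # odd leftover goes high
        pads.append((lo, total - lo))
    return pads


def output_size(n: int, k: int, s: int, pad: Tuple[int, int]) -> int:
    """Output length of a convolution over an explicitly padded input."""
    return (n + pad[0] + pad[1] - k) // s + 1
```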
607c7c1 Plumbing for dynamic shapes for custom calls. PiperOrigin-RevId: 521439418 03 April 2023, 13:16:12 UTC
0d32724 Merge pull request #15340 from gnecula:dim_vars3 PiperOrigin-RevId: 521424534 03 April 2023, 11:44:16 UTC
cd35e90 [shape_poly] Cleanup handling of dimension variables. We unify the way we compute with dimension variables (computing their values from the shape of the actual arguments, and also using those values to evaluate shapes that contain dimension variables). We remove DimExprValueMlir, and all computations with dimension variables and DimExpr are now done by JAX interpretation, followed by lowering to TF or StableHLO. 03 April 2023, 11:33:29 UTC
bf2c071 [jax2tf] Add test that compile_args[tuple_args] does not matter for serialization PiperOrigin-RevId: 521422653 03 April 2023, 11:32:34 UTC
2ce78ac [jax2tf] Add checks that we do not see unexpected lowered.compiler_args Some of those compile_args change the semantics and the calling convention for the lowered module. We want to be explicit about the ones that we are handling. PiperOrigin-RevId: 521419681 03 April 2023, 11:13:31 UTC
b0a6cdb Merge pull request #15341 from gnecula:tf_grad PiperOrigin-RevId: 521409693 03 April 2023, 10:15:36 UTC
cf599c7 Avoid re-constructing set. Expensive at scale. PiperOrigin-RevId: 521310375 02 April 2023, 21:42:35 UTC
e9fc02e [jax2tf] Cleanup of handling of tf.custom_gradient There are some incompatibilities between JAX and TF when it comes to gradients for functions that take non-float arguments, or whose arguments are unused. JAX uses float0 for those gradients, but this is not a type that TF recognizes. Furthermore, under tf.function context TF will pass `None` in place of cotangents for the outputs that have non-float type. Previously, the workarounds for these were in the JAX function that was converted to obtain the TF gradient. Here we move those workarounds into the TF-land of jax2tf. This will enable us to expand jax_export with handling of gradients. jax_export is pure JAX, and hence it is important to move the TF workarounds outside of the converted JAX functions. This is just a refactor. 02 April 2023, 19:07:44 UTC
b8dfb97 Integrate StableHLO at openxla/stablehlo@7a93924 PiperOrigin-RevId: 521293524 02 April 2023, 18:14:01 UTC
88f77bb [jax2tf] Removed call_tf tests that are no longer applicable. A recent change in TensorFlow makes copies of np.ndarray when they are turned into tf.constant. This means that call_tf can no longer guarantee no-copy. Removing those tests, and the paragraph in the documentation that describes this property. PiperOrigin-RevId: 521120090 01 April 2023, 10:07:13 UTC
2432ade Add Deprecation warning if gda_serialization is imported PiperOrigin-RevId: 521081821 01 April 2023, 04:28:07 UTC
d27a80d Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization PiperOrigin-RevId: 521055760 01 April 2023, 01:07:51 UTC
0b31e8b Remove dead code from pxla.py PiperOrigin-RevId: 521003815 31 March 2023, 20:51:49 UTC
db025df Stop importing old `tree_util` APIs conveniently and set explicit time for removal. PiperOrigin-RevId: 521003611 31 March 2023, 20:45:10 UTC
82fcfc3 Buffer -> Array in some pxla type annotations. PiperOrigin-RevId: 520975371 31 March 2023, 18:42:22 UTC
b37c741 accelerate deprecation of jax.curry PiperOrigin-RevId: 520958381 31 March 2023, 17:37:39 UTC
ffb8352 Merge pull request #15342 from jakevdp:doc-requirements PiperOrigin-RevId: 520955387 31 March 2023, 17:27:30 UTC
2841bd3 Merge pull request #15321 from jakevdp:remove-msort PiperOrigin-RevId: 520952178 31 March 2023, 17:16:18 UTC
6e00ba8 Enable more mesh shape assignment We now sort the mesh dims by size first. Smaller dims have fewer choices so they should be assigned first. PiperOrigin-RevId: 520942700 31 March 2023, 16:36:16 UTC
dfbbc25 Merge pull request #15317 from ROCmSoftwarePlatform:rocm_pmap_fix PiperOrigin-RevId: 520934992 31 March 2023, 16:05:07 UTC
abf1acf Replace references to jax.interpreters with jax._src.interpreters in JAX core. PiperOrigin-RevId: 520933067 31 March 2023, 15:58:00 UTC
182cc98 Merge pull request #15323 from NeilGirdhar:fix_rayleigh PiperOrigin-RevId: 520932851 31 March 2023, 15:50:40 UTC
9ec3ad1 DOC: pin newest sphinx-book-theme 31 March 2023, 15:42:34 UTC
749dc1b Remove deprecated function jnp.msort 31 March 2023, 15:24:36 UTC
0df2ddc Merge pull request #15232 from gnecula:tf_arange PiperOrigin-RevId: 520914838 31 March 2023, 14:11:19 UTC
c368c69 [shape_poly] Extend the handling of jnp.arange with shape polymorphism. Previously, only `arange(stop, dtype=...)` was handled in the presence of shape polymorphism. Here we extend it to also support `start` and `step`. There are still plenty of restrictions: * no floating point constants are allowed among start, stop and step * we must resolve statically whether step is positive or negative * we must resolve statically whether the distance between start and stop is negative or positive. 31 March 2023, 12:41:26 UTC
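The restrictions follow from the integer length formula `max(0, ceil((stop - start) / step))`. The direction of the ceiling depends on the sign of `step`, and the `max(0, ...)` depends on the sign of `stop - start`, so neither can be evaluated when those signs are symbolic. A plain-Python sketch of the concrete formula, not JAX's shape_poly code:

```python
def arange_length(start: int, stop: int, step: int) -> int:
    """Number of elements in arange(start, stop, step) for integer args."""
    if step == 0:
        raise ValueError("step must be nonzero")
    # ceil((stop - start) / step) written with floor division,
    # then clamped at zero for empty ranges.
    return max(0, -((start - stop) // step))
```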
76b922a Merge pull request #15337 from mattjj:axis-name-shadowing-2 PiperOrigin-RevId: 520838748 31 March 2023, 06:01:02 UTC
6a2b081 fix bug from #15335 by checking main_trace tag 31 March 2023, 05:35:03 UTC
12bcdeb Merge pull request #15335 from mattjj:axis-name-shadowing PiperOrigin-RevId: 520829991 31 March 2023, 04:56:42 UTC
211bc29 add assertions for axis name shadowing bugs 31 March 2023, 04:31:02 UTC
d383ab6 Merge pull request #15255 from eltociear:patch-6 PiperOrigin-RevId: 520814903 31 March 2023, 03:38:17 UTC
8e17da4 Merge pull request #15322 from jakevdp:pre-commit PiperOrigin-RevId: 520793950 31 March 2023, 01:26:27 UTC
248ffc2 Merge pull request #15329 from jakevdp:padfunc-protocol-2 PiperOrigin-RevId: 520793934 31 March 2023, 01:19:43 UTC
61064a1 Merge pull request #15331 from jakevdp:protocol PiperOrigin-RevId: 520793925 31 March 2023, 01:12:41 UTC
f8bff8d Merge pull request #15332 from mattjj:shmap-vmap-closure PiperOrigin-RevId: 520788236 31 March 2023, 00:43:02 UTC
0bb4685 expose `compiler_options` on `compile()` Co-authored-by: Roy Frostig <frostig@google.com> PiperOrigin-RevId: 520782460 31 March 2023, 00:14:26 UTC
7c3c46c [shard-map] handle closed-over vmap tracers 30 March 2023, 23:43:40 UTC
d58c970 Merge pull request #15327 from jakevdp:fix-user-guides PiperOrigin-RevId: 520767709 30 March 2023, 23:14:11 UTC
7d5047f Merge pull request #15328 from NeilGirdhar:fix_geom PiperOrigin-RevId: 520767692 30 March 2023, 23:07:12 UTC
6d006b5 [typing] use protocol for cumulative reductions 30 March 2023, 22:43:43 UTC
92386b8 [typing] use protocol for pad stat_func 30 March 2023, 22:07:47 UTC
78204f7 Fix broadcasting in jax.random.geometric 30 March 2023, 21:54:18 UTC
69c9660 Raise deprecation warnings for `{in|out}_axis_resources` for pjit and `axis_resources` for with_sharding_constraint PiperOrigin-RevId: 520748845 30 March 2023, 21:51:01 UTC
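A common pattern for deprecating a renamed keyword argument is to keep accepting the old name, emit a `DeprecationWarning`, and forward the value to the new name. The sketch below assumes the replacement for `in_axis_resources` is `in_shardings`. The function `resolve_in_shardings` is hypothetical and is not pjit's actual code.

```python
import warnings
from typing import Any, Optional


def resolve_in_shardings(in_shardings: Optional[Any] = None,
                         in_axis_resources: Optional[Any] = None) -> Any:
    """Hypothetical helper: map a deprecated keyword onto its new name."""
    if in_axis_resources is not None:
        if in_shardings is not None:
            raise ValueError("pass only one of in_shardings / in_axis_resources")
        # stacklevel=2 points the warning at the caller, not this helper.
        warnings.warn("in_axis_resources is deprecated; use in_shardings instead.",
                      DeprecationWarning, stacklevel=2)
        in_shardings = in_axis_resources
    return in_shardings
```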
36bf14b Remove some dead code. PiperOrigin-RevId: 520746309 30 March 2023, 21:41:26 UTC
ec63d69 DOC: fix headings in user_guides 30 March 2023, 21:39:25 UTC
ef29ff1 Merge pull request #15300 from skye:version PiperOrigin-RevId: 520744521 30 March 2023, 21:34:16 UTC
01c6863 Merge pull request #15301 from google:timeout PiperOrigin-RevId: 520742762 30 March 2023, 21:26:50 UTC
16ca0ca Relax the tolerance of testCauchyLogCdf PiperOrigin-RevId: 520741306 30 March 2023, 21:19:50 UTC
31eeaed Split mlir.py and xla.py into separate Bazel targets. PiperOrigin-RevId: 520737811 30 March 2023, 21:06:16 UTC
c978df5 Delete unused functions from dispatch.py and pjit.py PiperOrigin-RevId: 520730163 30 March 2023, 20:38:44 UTC
1d1b131 Fix broadcasting in jax.random.rayleigh 30 March 2023, 20:38:08 UTC
b30e6e7 CI: add EOF and debug precommit hooks 30 March 2023, 20:29:50 UTC
23451dc Merge pull request #15303 from jakevdp:lax-asarray PiperOrigin-RevId: 520717999 30 March 2023, 20:11:11 UTC
13e45c8 [ROCm]: Run pmap test on specific number of GPUs 30 March 2023, 18:34:47 UTC
8f72454 Add internal jax.lax.asarray utility 30 March 2023, 17:21:55 UTC
67a28ce Relax test tolerances for testLogisticPpf. Fixes a test failure in CI. PiperOrigin-RevId: 520649225 30 March 2023, 15:41:56 UTC
dedfc8d Merge pull request #15282 from JiaYaobo:geom_random PiperOrigin-RevId: 520635974 30 March 2023, 14:45:19 UTC
1fd6e01 Merge pull request #15287 from gnecula:tf_dim_vars PiperOrigin-RevId: 520633830 30 March 2023, 14:37:47 UTC
794769c Merge pull request #15302 from mattjj:pmap-pytree-prefix-errors PiperOrigin-RevId: 520632081 30 March 2023, 14:29:51 UTC
0a2e383 Merge pull request #15297 from jakevdp:finfo-props PiperOrigin-RevId: 520632058 30 March 2023, 14:22:28 UTC
081b86b [shape_poly] Improved computation of dimension variables for native serialization Previously for native serialization we could only support polymorphic_shapes where the specification was a simple dimension variable. E.g., we could not handle a specification where `polymorphic_shapes="2*b"` because there was no way to recover the value of `b` from the actual shape. (For non-native serialization we were supporting some limited equation solving.) The above is important, e.g., for the gradient of functions like `jnp.concatenate([x, x])`, where the output shape is `2*b`. This is possible because in #15258 we have brought the computation of the dimension variables into jax_export. What we do here is to even out the support for native serialization to have the same power as the non-native one. We do this by reusing the same `shape_poly.prepare_dim_var_env` that we use for non-native serialization. After we land this, we will refactor the shape environment to be cleaner. 30 March 2023, 13:51:24 UTC
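Recovering `b` from a dimension specified as `2*b` is a small equation-solving step. The sketch below handles only the single-variable linear case `coeff*b + offset`, and it also checks that the actual dimension really has that form. `solve_dim_var` is a hypothetical helper and is not `shape_poly.prepare_dim_var_env`.

```python
def solve_dim_var(coeff: int, offset: int, actual: int) -> int:
    """Solve coeff * b + offset == actual for an integer b >= 1.

    E.g. the polymorphic dim "2*b" (coeff=2, offset=0) with an actual
    dimension of 6 gives b = 3.
    """
    if coeff <= 0:
        raise ValueError("coefficient must be positive")
    quotient, remainder = divmod(actual - offset, coeff)
    if remainder != 0 or quotient < 1:
        raise ValueError(
            f"dimension {actual} is not of the form {coeff}*b + {offset} with b >= 1")
    return quotient
```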
47177e1 Split more targets out of the main JAX Bazel target. Namely: * abstract_arrays * ad_util * api_util * interpreters/partial_eval * lax_reference PiperOrigin-RevId: 520618715 30 March 2023, 13:12:45 UTC
81de5b7 improve pmap in_axes/out_axes pytree prefix error messages 29 March 2023, 23:56:40 UTC
3135fbc [JAX] Delete _DeviceArray and DeviceArray. PiperOrigin-RevId: 520453090 29 March 2023, 22:07:14 UTC
30a51b2 Update version and changelog after jax 0.4.8 release 29 March 2023, 21:27:09 UTC
6927e5d Add 1-hour timeout to each Cloud TPU CI job. Sometimes they hang, and the default timeout is 6 hours, which is way too long. 29 March 2023, 21:23:45 UTC