https://github.com/google/jax

b576417 Update README for jaxlib 0.1.50 release. (#3574) 26 June 2020, 17:20:59 UTC
66cea02 Fix test failures on GPU. (#3572) 26 June 2020, 16:50:22 UTC
8f86b13 Update docker script for CUDA 11. (#3571) 26 June 2020, 15:20:51 UTC
26c6c3a fix error when doing forward-mode of odeint (#3566) fixes #3558 26 June 2020, 03:57:34 UTC
4b1bb18 Fix logsumexp test (#3563) Previously, this test could generate an axis out of bounds error with large num_generated_cases. (discovered in the process of testing #3561) 26 June 2020, 03:08:14 UTC
062ce29 removed stale faq entries (#3565) 26 June 2020, 02:17:24 UTC
c2501d1 update version and changelog for pypi (#3564) 26 June 2020, 00:48:21 UTC
db80ca5 allow closures for odeint dynamics functions (#3562) * allow closures for odeint dynamics functions fixes #2718, #3557 * add tests for odeint dynamics closing over tracers 26 June 2020, 00:36:17 UTC
2a6fc31 Update XLA in preparation for a new jaxlib release (0.1.50). (#3560) 25 June 2020, 19:12:47 UTC
a141cc6 Make CUDA wheels manylinux2010 compliant, add CUDA 11, drop CUDA 9.2 (#3555) * Use dynamic loading to locate CUDA libraries in jaxlib. This should allow jaxlib CUDA wheels to be manylinux2010 compliant. * Tag CUDA jaxlib wheels as manylinux2010. Drop support for CUDA 9.2, add support for CUDA 11.0. * Reorder CUDA imports. 25 June 2020, 18:37:14 UTC
c9670d5 Fix lazy broadcast issue (#3536) 25 June 2020, 14:50:11 UTC
93054fa add remat to top-level docs (#3554) 25 June 2020, 14:26:26 UTC
32e419d Fix eigh JVP to ensure that both the primal and tangents of the eigen… (#3550) * Fix eigh JVP to ensure that both the primal and tangents of the eigenvalues are real. Add a test to jax.test_util.check_jvp that ensures the primals and tangents produced by a JVP rule have identical types. * Cast input to static indexing grad tests to a JAX array so the new type check passes. 25 June 2020, 12:14:54 UTC
696958d add remat docstring (#3542) * add remat docstring first part of addressing #3314 24 June 2020, 23:11:26 UTC
7e5407c Fix typos pytrees page on readthedocs (#3548) 24 June 2020, 22:00:24 UTC
a9c7413 Fix failure in PyTorch array interoperability test in non-x64 mode. (#3546) 24 June 2020, 20:15:59 UTC
e680304 Remove warning suppression for tuple and list arguments to reductions. (#3545) Fix callers. 24 June 2020, 19:59:31 UTC
677baa5 Clarify docstrings regarding usage of static arguments in jit and vmap. (#3484) The docstring for pmap does not currently mention that any "non-data" arguments need to be indicated in `static_broadcasted_argnums`. This commit updates the docs to parallel those for `jax.jit` which does explain this. Additionally, a remark is added to the `static_*` argument descriptions on both jit and pmap, so that this point can be understood without reading the whole docstring. 24 June 2020, 19:47:09 UTC
f036f5d Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation war… (#3543) * Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation warnings. * Pin a newer tf-nightly to fix jax2tf tests for NumPy 1.19.0 24 June 2020, 19:19:00 UTC
90894a2 Fix test failures due to RuntimeWarning in variance tests. (#3541) 24 June 2020, 18:41:20 UTC
28262da Restrict .compress tests to arrays only. (#3540) Fixes test failures with 100 generated cases. 24 June 2020, 18:31:37 UTC
a422f63 Merge quantile and nanquantile implementations. (#3539) Allows more code sharing, and adds support for other interpolation modes to nanquantile(). 24 June 2020, 18:25:47 UTC
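As context for the merge, the linear-interpolation rule behind quantile() can be sketched in a few lines of plain Python (a minimal illustration of the interpolation mode, not the merged jnp implementation; a nanquantile variant would first drop NaNs):

```python
import math

def quantile(xs, q):
    # Linear-interpolation quantile (numpy's default mode) of a sequence.
    s = sorted(xs)
    pos = (len(s) - 1) * q      # fractional index into the sorted data
    lo = math.floor(pos)
    hi = min(lo + 1, len(s) - 1)
    frac = pos - lo
    return s[lo] * (1.0 - frac) + s[hi] * frac
```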
a44bc0c Add np.diag_indices_from (#3500) 24 June 2020, 16:31:16 UTC
fede6d9 Add missing functions to documentation. (#3538) 24 June 2020, 16:16:56 UTC
c357aad Implement nan-quantile/median/percentile (#2397) * Implement nan-quantile/median/percentile * Fix flake8 error. Co-authored-by: Peter Hawkins <phawkins@google.com> 24 June 2020, 15:16:45 UTC
db2291b Fixed issue with the 'compress' method of 'jnp.ndarray' (#3537) * Fixed issue with the 'compress' method of 'jnp.ndarray' behaving differently than the free function * Added test 24 June 2020, 15:13:56 UTC
319eeaf Future warning about lists and tuples (#3369) 24 June 2020, 14:54:06 UTC
a6e3f99 Add np.unwrap (#3527) 24 June 2020, 13:22:35 UTC
edcd43a de-duplicate constants when lowering to XLA Co-authored-by: Matthew Johnson <mattjj@google.com> 24 June 2020, 02:03:20 UTC
9d173c6 Support `b` and `return_sign` in scipy.special.logsumexp (#3488) 23 June 2020, 22:36:45 UTC
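The identity behind the new `b` and `return_sign` options can be sketched in plain Python (an illustrative stand-in, not the jax.scipy implementation; the function name and return convention here are assumptions): log(Σᵢ bᵢ·exp(aᵢ)) is computed stably by factoring out max(a), and a negative total is reported through a separate sign.

```python
import math

def logsumexp(a, b=None):
    # Stable log(sum_i b_i * exp(a_i)): factor out m = max(a) so the
    # exponentials cannot overflow, then add m back in log space.
    if b is None:
        b = [1.0] * len(a)
    m = max(a)
    s = sum(bi * math.exp(ai - m) for ai, bi in zip(a, b))
    sign = 1.0 if s > 0 else (-1.0 if s < 0 else 0.0)
    # return_sign-style output: (log|s| + m, sign)
    return m + math.log(abs(s)), sign
```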
74ee2ef avoid value-based error check in random.choice (#3531) 23 June 2020, 21:03:36 UTC
a45e283 add back a full_lower, dropped in #3491 (#3530) 23 June 2020, 19:08:12 UTC
7527830 refactor call primitives, simpler param processing (#3491) 23 June 2020, 16:39:45 UTC
d5a5d30 lax.sort: allow any sequence of Arrays, not just tuples (#3367) 23 June 2020, 15:28:04 UTC
7cc2650 Fix lax_reference's round for edge case inputs (#3309) * Fix lax_reference's round for edge case inputs - round(8388609) would compute trunc(8388609 + 0.5) == 8388610. Fix this by not modifying sufficiently large inputs. - round(0.499999970198) would compute trunc(0.499999970198 + 0.5) == 1.0. Fix this by explicitly special-casing the first float before 0.5. * Add explicit casts to lax_reference implementation of `round`. Co-authored-by: Peter Hawkins <phawkins@google.com> 23 June 2020, 14:35:41 UTC
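The pitfall the commit describes can be reproduced, and avoided, in plain Python (a sketch of round-half-away-from-zero; `round_half_away` is an illustrative name, not the lax_reference function, and the demo uses the float64 predecessor of 0.5 rather than the float32 one from the commit message):

```python
import math

def round_half_away(x):
    # Round to nearest, halves away from zero, without computing x + 0.5:
    # trunc(x + 0.5) double-rounds, e.g. for the largest double below 0.5.
    a = abs(x)
    f = math.floor(a)
    r = f + 1.0 if a - f >= 0.5 else f
    return math.copysign(r, x)
```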
86fcfbf Fix memory leak when no axis is provided to pmap. (#3394) * Fix memory leak when no axis is provided to pmap. * Work around flake8 false positive. Co-authored-by: Matthew Johnson <mattjj@google.com> 23 June 2020, 13:29:58 UTC
490c853 Adds boolean support for bitwise not and unittests for boolean support on logical operations. (#3483) Co-authored-by: Thomas Keck <thomaskeck@google.com> 23 June 2020, 13:25:53 UTC
5ee6bc0 Remove unnecessary static_argnum in np.gradient (#3512) 23 June 2020, 03:46:41 UTC
0249492 fix an issue with newer versions of pytype (#3526) 23 June 2020, 03:04:07 UTC
33c455a Add jax.scipy.signal.detrend (#3516) 23 June 2020, 02:49:00 UTC
ca5b0b1 Cast int8 to bool for lax.not in jax2tf. (#3519) 23 June 2020, 02:44:09 UTC
046006e Fix typo: np.bool -> np.bool_ (#3525) Replaced np.bool (which is just bool) with np.bool_, which is numpy's Boolean type. 23 June 2020, 02:43:25 UTC
2f7108f remove the lower_fun default multiple_results=True (#3524) 23 June 2020, 00:50:33 UTC
8fe6da0 Add instructions for how to use the TensorBoard profiler to the profiling docs. (#3481) 23 June 2020, 00:33:23 UTC
7102322 Add private class wrapper for double-double arithmetic (#3521) 22 June 2020, 23:16:30 UTC
e5d4ca3 Fix typo understanding jaxprs page on readthedocs (#3513) 22 June 2020, 19:31:08 UTC
18798a3 Revert "Remove documentation presubmit action now that the RTD presubmit is enabled. (#3480)" (#3497) This (partially) reverts commit 68dcbdd1189cd938bf6023e4e1efaf64c71629aa. @hawkinsp points out this was running doctest, which the RTD presubmit doesn't do AFAIK. We still don't build the docs here, as the RTD presubmit takes care of that. Co-authored-by: Stephan Hoyer <shoyer@google.com> 22 June 2020, 18:02:41 UTC
a81e732 fix jet typo: jnp not np (#3507) 22 June 2020, 17:27:44 UTC
f8570ec Update JAX FAQ.rst, jax.device_put (#3496) 22 June 2020, 17:08:03 UTC
3fff837 pin numpy version in setup.py to avoid warnings (#3509) 22 June 2020, 15:12:41 UTC
7f9fc27 Another small fix of the link rendering in the Autodiff Cookbook - vmap transformation (#3502) 21 June 2020, 20:52:34 UTC
f0dff9c Fix a link rendering to Autograd's reverse-mode Jacobian method (#3501) 21 June 2020, 18:07:29 UTC
1ed60b3 support stacked doubling (#3490) 19 June 2020, 23:46:06 UTC
19f308b implement jax.random.choice (#3463) 19 June 2020, 23:04:42 UTC
8f4ba7e Allow specifying both `devices` and `axis_size` to pmap. (#3475) This allows providing custom device assignments to nested pmaps or pmap-of-sharded_jit when running on a multi-host platform. 19 June 2020, 22:51:12 UTC
a088c02 Bump jaxlib version to 0.1.49 and update WORKSPACE (#3495) 19 June 2020, 20:27:12 UTC
925d661 Fix test failure in jax2tf due to conflicting merges. (#3492) 19 June 2020, 14:14:53 UTC
927c209 Add random_gamma_grad and use in jax.random.gamma (#3281) 19 June 2020, 13:34:18 UTC
0da7b4d Improve dtype test coverage for random_test (#3254) 18 June 2020, 22:17:13 UTC
57d5a39 Disable the workaround to prevent expansion of type aliases (#3485) * Disable the workaround to prevent expansion of type aliases * Fix flake 18 June 2020, 18:01:40 UTC
3b4a123 [jax2tf] Control on which TF device the TF converted code runs. (#3482) * [jax2tf] Control on which TF device the TF converted code runs. The device should match the device on which the JAX code runs, otherwise the numerical comparisons don't make sense. * This code requires some build-file changes in Google3 to properly run on GPU and TPU. * Fix warnings 18 June 2020, 12:38:50 UTC
a05263f Avoid tuple psum in pmean (#3479) We have an optimization to avoid doing flops (or rather, communication) for `psum(1)` (instead we look up the axis size at trace time). But although `pmean(x)`, meaning `psum(x) / psum(1)`, is likely the most common user of `psum(1)`, it doesn't actually trigger this optimization right now because it's implemented as `psum((x, 1))` and `bind` lifts the 1 into the same `JaxprTrace` as `x` rather than letting the psum impl rule see it. The major reason for using tuple psum—providing a fixed order to avoid multihost GPU deadlocks—doesn't apply here because we don't expect the `psum(1)` to lower to an actual XLA AllReduce. 18 June 2020, 04:42:17 UTC
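The psum(1) shortcut described above can be illustrated with a toy plain-Python model of an all-reduce (the names and list-of-devices model are illustrative, not JAX's API):

```python
def psum(values):
    # Toy all-reduce sum across "devices" (here just a Python list).
    total = sum(values)
    return [total] * len(values)

def pmean(values, axis_size=None):
    # pmean(x) = psum(x) / psum(1); when the axis size is known at trace
    # time, the denominator needs no communication at all.
    n = axis_size if axis_size is not None else psum([1] * len(values))[0]
    return [t / n for t in psum(values)]
```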
68dcbdd Remove documentation presubmit action now that the RTD presubmit is enabled. (#3480) 18 June 2020, 00:15:25 UTC
8bc7820 Use myst to parse markdown docs, and convert a page from rst to markdown. (#3477) Also sets the minimum sphinx version to 2.1. 17 June 2020, 23:59:14 UTC
3290e16 Attach source info to Jaxpr equations. (#3421) * Attach source info to Jaxpr equations. Example:

```
In [1]: import jax, jax.numpy as jnp

In [2]: def f(x, y):
   ...:     z = jax.numpy.cos(x)
   ...:     z = z * jax.numpy.tanh(y)
   ...:     return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda  ; a b.
  let c = cos a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0  [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d  [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a  [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i  [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```

Co-authored-by: Matthew Johnson <mattjj@google.com> 17 June 2020, 23:35:36 UTC
3c6c0af disable new jax2tf TPU test (#3476) 17 June 2020, 22:09:54 UTC
5a74ebf Add experimental precision doubling transform (#3465) This PR adds an experimental precision doubling transform, following the basic approach outlined in Dekker 1971 ([pdf](http://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf)). When this transform is applied, the number of significant bits is approximately doubled compared to the base operation. Simple demo:

```python
In [1]: import jax.numpy as jnp

In [2]: from jax.experimental.doubledouble import doubledouble

In [3]: def f(a, b):
   ...:     return a + b - a
   ...:

In [4]: f(1E20, 1.0)  # float64 loses precision
Out[4]: 0.0

In [5]: g = doubledouble(f)(1E20, 1.0)
Out[5]: DeviceArray(1., dtype=float64)
```

This initial experiment supports basic arithmetic operators and inequalities. 17 June 2020, 21:59:40 UTC
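The building block of Dekker-style double-double arithmetic is an error-free addition; a minimal plain-Python sketch (float64 standing in for the base precision; this illustrates the technique, not the jax.experimental.doubledouble internals):

```python
def two_sum(a, b):
    # Knuth's error-free addition: returns (s, err) with s = fl(a + b)
    # and s + err exactly equal to the real-number sum a + b.
    s = a + b
    bb = s - a
    err = (a - (s - bb)) + (b - bb)
    return s, err
```

A double-double value is then a (hi, lo) pair propagated through such transformations, which is how the significand is effectively doubled.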
5344ec5 use original numpy for shape calculations (#3474) cf. #3453 17 June 2020, 17:05:28 UTC
5ee936f Add polyder numpy function (#3403) 17 June 2020, 16:43:50 UTC
4f21b93 [jax2tf] Added special case for tf.pad. (#3462) Fixed lax_reference.pad to handle lax.pad with negative edge padding. 17 June 2020, 08:57:21 UTC
ce782e6 Fixed a bug where TF silently used 64-bit operations in 32-bit mode. (#3471) We already cast the NumPy arrays to 32-bit, but this was not happening if the input was a tf.Variable or tf.constant. In the process I have added some invariant checks, run if core.skip_checks is False, which it always is in testing. Also added some sanity checking of arguments. 17 June 2020, 08:56:32 UTC
575216e add jet primitives, refactor tests (#3468) Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> 17 June 2020, 02:48:25 UTC
e9ce700 update tpu readme not to pin jax / jaxlib versions (#3467) 16 June 2020, 22:48:01 UTC
20f4ec6 fix a bug from #3459 (#3466) 16 June 2020, 22:46:51 UTC
df3b507 update jaxlib version in readme 16 June 2020, 21:19:03 UTC
a70ba92 jaxpr pretty-print: wrap equation RHS when the LHS is long 16 June 2020, 20:49:13 UTC
711c93d Merge pull request #3459 from google/simplify-remat-partial-eval simplify remat partial eval parameterization 16 June 2020, 19:20:59 UTC
dfdf05f deflake 16 June 2020, 18:56:42 UTC
005958e added reviewer suggestion 16 June 2020, 18:46:37 UTC
140c9ea Merge pull request #3396 from JuliusKunze/simplify-scan-mask Remove special case for scan masking 16 June 2020, 05:19:14 UTC
270e921 Merge pull request #3456 from jacobjinkelly/int_pow Jet rule for `integer_pow` 16 June 2020, 05:17:25 UTC
60760f2 Merge pull request #3455 from jacobjinkelly/erf_inv add erf_inv rule 16 June 2020, 05:16:04 UTC
c30d23f Merge pull request #3429 from BuddenD/changelist/316251368 Propagate raw __name__ of functions wrapped by jit and sharded_jit. 16 June 2020, 05:15:43 UTC
fe14aa3 Merge branch 'master' into changelist/316251368 16 June 2020, 05:15:34 UTC
08b83e8 Merge pull request #3426 from google/warn-jit-of-pmap print warning when doing jit-of-pmap 16 June 2020, 05:12:29 UTC
d4c6cb6 print warning when doing jit-of-pmap 16 June 2020, 04:37:30 UTC
8a901ba deflake 16 June 2020, 02:36:45 UTC
5d20c6f add scipy.special.zeta, scipy.special.polygamma (#3385) 16 June 2020, 02:21:26 UTC
05e0716 simplify remat partial eval parameterization The main win here is reducing the number of arguments for the function that parameterizes _remat_partial_eval (so it can be used both with remat and invertible ad features). I also included a fix to _remat_partial_eval that is needed in #3370, though I don't think it's needed on master. It was easier to include the fix now. Both these changes made rebasing #3370 easier! 16 June 2020, 01:42:53 UTC
f463598 add int pow rule 15 June 2020, 21:23:57 UTC
6bcf056 Fix jnp.arange(x) for fractional x (#3453) 15 June 2020, 20:02:59 UTC
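A fix like this hinges on computing the output length with ordinary host-side arithmetic (cf. "use original numpy for shape calculations" above); the length rule numpy.arange follows can be sketched as (illustrative helper name, not JAX code):

```python
import math

def arange_length(start, stop=None, step=1.0):
    # Number of elements numpy.arange produces: ceil((stop - start) / step),
    # clamped at zero. E.g. arange(5.5) yields 6 elements: 0.0 .. 5.0.
    if stop is None:
        start, stop = 0.0, start
    return max(0, math.ceil((stop - start) / step))
```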
3cf6b1d add erf inv rule (squashed fixups: erf_inv rule not working; works up to order 2; erf inv rule use np for now; actually use np for now) 15 June 2020, 19:40:31 UTC
37f4722 Merge pull request #3449 from google/issue3440 add systematic tests for vmap-of-pmap 15 June 2020, 18:07:06 UTC
12160fe deflake 15 June 2020, 16:15:56 UTC
fcfcffe add systematic tests for vmap-of-pmap fixes #3440 Also re-applies the fix in #3439 (i.e. it rolls back the rollback PR #3448) because we're now confident it's correct (and some internal tests are buggy). 15 June 2020, 16:10:40 UTC
ea9af1b Merge pull request #3448 from google/roll-back-3439 roll back of #3439 while we debug internal fails 15 June 2020, 14:59:55 UTC
12ce6e3 roll back of #3439 while we debug internal fails 15 June 2020, 14:32:42 UTC
4d40b20 Initial version of invertible AD implementation (#3232) This is a prototype implementation of the memory-efficient VJP method for invertible functions. The general idea is that thanks to invertibility, we don't have to memoize any intermediate primal values, but can simply reconstruct them in lock-step with the gradient computation. The API is such that the only thing a user has to do is decorate a function with `@invertible`, which will make AD apply a more efficient transpose than usual. The current version is expressive enough to support e.g. the Reversible ResNet, but there are still some caveats:
- The definition of an "invertible" function is one that produces a jaxpr that can be inverted correctly if only we iterate over its equations in reverse. This is a bit strict, because users generally don't have much control over that, and there are functions that produce jaxprs which will be treated as invertible under one topological ordering of equations, while they will be considered non-invertible under other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out that zero-argument pruning in JVPTrace makes it pretty much impossible to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- The invertible reverse-mode implementation (`rev_backward_pass`) assumes that all the VJPs of primal primitives are jittable (not sure if that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of `custom_ivjp` inefficient if it is being staged out. 15 June 2020, 10:35:06 UTC
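The reconstruct-in-lock-step idea can be illustrated with a reversible (RevNet-style) block in plain Python, where the inputs are recovered exactly from the outputs instead of being stored (`F` and `G` stand for arbitrary sub-functions; illustrative only, not the jax implementation):

```python
def rev_block(x1, x2, F, G):
    # Forward pass of a reversible block: the outputs determine the inputs.
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2

def rev_block_inverse(y1, y2, F, G):
    # Recover the inputs from the outputs, so no intermediate activations
    # need to be memoized for the backward pass.
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2
```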
3c78605 Propagate raw __name__ and __doc__ of functions wrapped by jit and sharded_jit. 15 June 2020, 09:32:09 UTC
2e3d439 [jax2tf] fix the too-early use of tf.constant (#3446) If I leave it at top-level, I get test failures about missing platform Host 15 June 2020, 09:14:09 UTC