swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f

Revision Message Commit Date
e69305e very sketchy vmap collectives Co-authored-by: Adam Paszke <apaszke@google.com> 30 June 2020, 17:01:45 UTC
1f1122c added basic pmap handling to gmap! Co-authored-by: Adam Paszke <apaszke@google.com> 30 June 2020, 16:25:15 UTC
72ad64e check gmap jaxpr by calling impl directly Co-authored-by: Adam Paszke <apaszke@google.com> 30 June 2020, 16:14:48 UTC
bca8a34 Add example 30 June 2020, 16:04:46 UTC
7f36d05 Initial version of gmap Co-authored-by: Matthew Johnson <mattjj@google.com> 30 June 2020, 15:19:47 UTC
df97557 omnistaging wip 27 June 2020, 18:01:54 UTC
496cde6 [jax2tf] Add special case for translation of lax.gather to tf.gather. (#3486) Also adds more tests for conversion for gather. 27 June 2020, 11:55:28 UTC
ccb640a lax.sort: stable by default 27 June 2020, 03:37:23 UTC
6eea0ce detailed assertion in profiler test 27 June 2020, 00:55:58 UTC
c7e97b7 limit jaxpr context in typechecker error messages 26 June 2020, 22:31:04 UTC
c54acbf introduce custom typecheck rules, implement them for cond and scan 26 June 2020, 22:31:04 UTC
6b3b42d raise a custom error in jaxpr checker 26 June 2020, 22:31:04 UTC
42cbe49 Correct a typo in the period of the PRNG. (#3578) 26 June 2020, 22:19:57 UTC
5116fd4 Add a heap profiler API and document it. (#3576) 26 June 2020, 21:09:09 UTC
63ff6cb Clarify docs on jax.lax.cond. (#3569) 26 June 2020, 18:44:50 UTC
11caa21 ensure lax.reduce monoid test uses original numpy (#3573) 26 June 2020, 18:44:16 UTC
99a43f2 Added missing is_stable argument to lax.sort (#3553) 26 June 2020, 17:40:00 UTC
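A minimal sketch of the stable-sort behaviour covered by the sort commits above, assuming a current JAX release where `lax.sort` accepts a sequence of operands plus the `is_stable` and `num_keys` keywords (`num_keys` is not mentioned in this log). With equal keys, the paired values keep their original relative order.

```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([2, 1, 2, 1])
vals = jnp.array([0, 1, 2, 3])

# Sort both operands by the first one only (num_keys=1);
# is_stable=True keeps elements with equal keys in input order.
sorted_keys, sorted_vals = lax.sort((keys, vals), num_keys=1, is_stable=True)
```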
b576417 Update README for jaxlib 0.1.50 release. (#3574) 26 June 2020, 17:20:59 UTC
66cea02 Fix test failures on GPU. (#3572) 26 June 2020, 16:50:22 UTC
8f86b13 Update docker script for CUDA 11. (#3571) 26 June 2020, 15:20:51 UTC
26c6c3a fix error when doing forward-mode of odeint (#3566) fixes #3558 26 June 2020, 03:57:34 UTC
4b1bb18 Fix logsumexp test (#3563) Previously, this test could generate an axis out of bounds error with large num_generated_cases. (discovered in the process of testing #3561) 26 June 2020, 03:08:14 UTC
062ce29 removed stale faq entries (#3565) 26 June 2020, 02:17:24 UTC
c2501d1 update version and changelog for pypi (#3564) 26 June 2020, 00:48:21 UTC
db80ca5 allow closures for odeint dynamics functions (#3562) * allow closures for odeint dynamics functions fixes #2718, #3557 * add tests for odeint dynamics closing over tracers 26 June 2020, 00:36:17 UTC
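A minimal sketch of what closure support allows, assuming `jax.experimental.ode.odeint` with its `odeint(func, y0, t)` signature; the function name `solve` is illustrative. The dynamics close over a traced parameter `k`, and `jax.grad` differentiates through the solve. For dy/dt = -k*y with y(0) = 1, y(1) = exp(-k), so the gradient should be close to -exp(-k).

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def solve(k):
    # The dynamics close over k instead of receiving it as an extra odeint argument.
    def dynamics(y, t):
        return -k * y
    ys = odeint(dynamics, jnp.array(1.0), jnp.array([0.0, 1.0]))
    return ys[-1]  # y(1) for dy/dt = -k*y, y(0) = 1

value = solve(0.5)
grad_k = jax.grad(solve)(0.5)
```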
2a6fc31 Update XLA in preparation for a new jaxlib release (0.1.50). (#3560) 25 June 2020, 19:12:47 UTC
a141cc6 Make CUDA wheels manylinux2010 compliant, add CUDA 11, drop CUDA 9.2 (#3555) * Use dynamic loading to locate CUDA libraries in jaxlib. This should allow jaxlib CUDA wheels to be manylinux2010 compliant. * Tag CUDA jaxlib wheels as manylinux2010. Drop support for CUDA 9.2, add support for CUDA 11.0. * Reorder CUDA imports. 25 June 2020, 18:37:14 UTC
c9670d5 Fix lazy broadcast issue (#3536) 25 June 2020, 14:50:11 UTC
93054fa add remat to top-level docs (#3554) 25 June 2020, 14:26:26 UTC
32e419d Fix eigh JVP to ensure that both the primal and tangents of the eigen… (#3550) * Fix eigh JVP to ensure that both the primal and tangents of the eigenvalues are real. Add a test to jax.test_util.check_jvp that ensures the primals and tangents produced by a JVP rule have identical types. * Cast input to static indexing grad tests to a JAX array so the new type check passes. 25 June 2020, 12:14:54 UTC
696958d add remat docstring (#3542) * add remat docstring first part of addressing #3314 24 June 2020, 23:11:26 UTC
7e5407c Fix typos in pytrees page on readthedocs (#3548) 24 June 2020, 22:00:24 UTC
a9c7413 Fix failure in PyTorch array interoperability test in non-x64 mode. (#3546) 24 June 2020, 20:15:59 UTC
e680304 Remove warning suppression for tuple and list arguments to reductions. (#3545) Fix callers. 24 June 2020, 19:59:31 UTC
677baa5 Clarify docstrings regarding usage of static arguments in jit and vmap. (#3484) The docstring for pmap does not currently mention that any "non-data" arguments need to be indicated in `static_broadcasted_argnums`. This commit updates the docs to parallel those for `jax.jit` which does explain this. Additionally, a remark is added to the `static_*` argument descriptions on both jit and pmap, so that this point can be understood without reading the whole docstring. 24 June 2020, 19:47:09 UTC
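To illustrate the point the updated docstrings make, a sketch assuming `jax.jit`'s `static_argnums` parameter (the function name `power` is illustrative). Here `n` drives Python control flow, so it has to be marked static.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def power(x, n):
    # n controls a Python loop, so it must be a concrete (static) value at trace time.
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result

y = power(jnp.array(2.0), 3)
```

Each distinct value of `n` triggers a separate compilation, which is the usual trade-off for static arguments.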
f036f5d Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation war… (#3543) * Avoid direct type/dtype comparisons to fix NumPy 1.19 deprecation warnings. * Pin a newer tf-nightly to fix jax2tf tests for NumPy 1.19.0 24 June 2020, 19:19:00 UTC
90894a2 Fix test failures due to RuntimeWarning in variance tests. (#3541) 24 June 2020, 18:41:20 UTC
28262da Restrict .compress tests to arrays only. (#3540) Fixes test failures with 100 generated cases. 24 June 2020, 18:31:37 UTC
a422f63 Merge quantile and nanquantile implementations. (#3539) Allows more code sharing, and adds support for other interpolation modes to nanquantile(). 24 June 2020, 18:25:47 UTC
a44bc0c Add np.diag_indices_from (#3500) 24 June 2020, 16:31:16 UTC
fede6d9 Add missing functions to documentation. (#3538) 24 June 2020, 16:16:56 UTC
c357aad Implement nan-quantile/median/percentile (#2397) * Implement nan-quantile/median/percentile * Fix flake8 error. Co-authored-by: Peter Hawkins <phawkins@google.com> 24 June 2020, 15:16:45 UTC
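A short usage sketch of the NaN-aware reductions added here, assuming the `jnp.nanmedian` and `jnp.nanpercentile` functions in current JAX: NaN entries are skipped rather than propagated.

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 3.0, 2.0])

nan_aware = jnp.nanmedian(x)      # median of the non-NaN values [1, 3, 2]
p50 = jnp.nanpercentile(x, 50.0)  # same value through the percentile interface
```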
db2291b Fixed issue with the 'compress' method of 'jnp.ndarray' (#3537) * Fixed issue with the 'compress' method of 'jnp.ndarray' behaving differently than the free function * Added test 24 June 2020, 15:13:56 UTC
319eeaf Future warning about lists and tuples (#3369) 24 June 2020, 14:54:06 UTC
a6e3f99 Add np.unwrap (#3527) 24 June 2020, 13:22:35 UTC
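A small usage sketch, assuming `jnp.unwrap` mirrors `numpy.unwrap` with its default discontinuity of pi: a phase sample that jumps by about 2*pi is shifted back onto a continuous track.

```python
import jax.numpy as jnp

two_pi = 2 * jnp.pi
# The last sample jumped by about 2*pi; unwrap removes that jump.
phase = jnp.array([0.0, 0.1, 0.2 + two_pi])
unwrapped = jnp.unwrap(phase)
```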
edcd43a de-duplicate constants when lowering to XLA Co-authored-by: Matthew Johnson <mattjj@google.com> 24 June 2020, 02:03:20 UTC
9d173c6 Support `b` and `return_sign` in scipy.special.logsumexp (#3488) 23 June 2020, 22:36:45 UTC
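A sketch of the two new arguments, assuming `jax.scipy.special.logsumexp` follows SciPy's semantics: `b` scales each exponential, and with `return_sign=True` the result is the pair (log|sum(b * exp(a))|, sign). Here the sum is 1*exp(0) - 2*exp(0) = -1, so the log-magnitude is 0 and the sign is -1.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([0.0, 0.0])
b = jnp.array([1.0, -2.0])
# log|sum(b * exp(a))| together with the sign of the sum
log_abs, sign = logsumexp(a, b=b, return_sign=True)
```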
74ee2ef avoid value-based error check in random.choice (#3531) 23 June 2020, 21:03:36 UTC
a45e283 add back a full_lower, dropped in #3491 (#3530) 23 June 2020, 19:08:12 UTC
7527830 refactor call primitives, simpler param processing (#3491) 23 June 2020, 16:39:45 UTC
d5a5d30 lax.sort: allow any sequence of Arrays, not just tuples (#3367) 23 June 2020, 15:28:04 UTC
7cc2650 Fix lax_reference's round for edge case inputs (#3309) * Fix lax_reference's round for edge case inputs - round(8388609) would compute trunc(8388609 + 0.5) == 8388610. Fix this by not modifying sufficiently large inputs. - round(0.499999970198) would compute trunc(0.499999970198 + 0.5) == 1.0. Fix this by explicitly special-casing the first float before 0.5. * Add explicit casts to lax_reference implementation of `round`. Co-authored-by: Peter Hawkins <phawkins@google.com> 23 June 2020, 14:35:41 UTC
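Both failure cases in this message can be reproduced with NumPy float32 arithmetic. The sketch below uses hypothetical helpers (`naive_round`, `round_half_away`), not lax_reference code, and swaps in a different technique from the special-casing the commit describes: it rounds half away from zero by inspecting the exact fractional part `x - trunc(x)` instead of adding 0.5, which avoids both intermediate roundings.

```python
import numpy as np

def naive_round(x):
    # Round half away from zero via trunc(|x| + 0.5); the addition itself can round.
    return np.sign(x) * np.trunc(np.abs(x) + np.float32(0.5))

def round_half_away(x):
    # x - trunc(x) is exact in floating point, so no intermediate rounding occurs.
    r = np.trunc(x)
    frac = x - r
    return np.where(np.abs(frac) >= 0.5, r + np.sign(x), r)

x = np.array([0.49999997, 8388609.0, 2.5, -2.5], dtype=np.float32)
naive = naive_round(x)
fixed = round_half_away(x)
```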
86fcfbf Fix memory leak when no axis is provided to pmap. (#3394) * Fix memory leak when no axis is provided to pmap. * Work around flake8 false positive. Co-authored-by: Matthew Johnson <mattjj@google.com> 23 June 2020, 13:29:58 UTC
490c853 Adds boolean support for bitwise not and unittests for boolean support on logical operations. (#3483) Co-authored-by: Thomas Keck <thomaskeck@google.com> 23 June 2020, 13:25:53 UTC
5ee6bc0 Remove unnecessary static_argnum in np.gradient (#3512) 23 June 2020, 03:46:41 UTC
0249492 fix an issue with newer versions of pytype (#3526) 23 June 2020, 03:04:07 UTC
33c455a Add jax.scipy.signal.detrend (#3516) 23 June 2020, 02:49:00 UTC
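A usage sketch, assuming `jax.scipy.signal.detrend` follows `scipy.signal.detrend` with `type='linear'` as the default: removing the least-squares line from a purely linear signal should leave values near zero, while `type='constant'` only subtracts the mean.

```python
import jax.numpy as jnp
from jax.scipy.signal import detrend

t = jnp.arange(8.0)
signal = 3.0 * t + 1.0

residual = detrend(signal)                   # remove the best-fit line
centered = detrend(signal, type="constant")  # remove the mean only
```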
ca5b0b1 Cast int8 to bool for lax.not in jax2tf. (#3519) 23 June 2020, 02:44:09 UTC
046006e Fix typo: np.bool -> np.bool_ (#3525) Replaced np.bool (which is just bool) with np.bool_, which is numpy's Boolean type. 23 June 2020, 02:43:25 UTC
2f7108f remove the lower_fun default multiple_results=True (#3524) 23 June 2020, 00:50:33 UTC
8fe6da0 Add instructions for how to use the TensorBoard profiler to the profiling docs. (#3481) 23 June 2020, 00:33:23 UTC
7102322 Add private class wrapper for double-double arithmetic (#3521) 22 June 2020, 23:16:30 UTC
e5d4ca3 Fix typo in understanding jaxprs page on readthedocs (#3513) 22 June 2020, 19:31:08 UTC
18798a3 Revert "Remove documentation presubmit action now that the RTD presubmit is enabled. (#3480)" (#3497) This (partially) reverts commit 68dcbdd1189cd938bf6023e4e1efaf64c71629aa. @hawkinsp points out this was running doctest, which the RTD presubmit doesn't do AFAIK. We still don't build the docs here, as the RTD presubmit takes care of that. Co-authored-by: Stephan Hoyer <shoyer@google.com> 22 June 2020, 18:02:41 UTC
a81e732 fix jet typo: jnp not np (#3507) 22 June 2020, 17:27:44 UTC
f8570ec Update JAX FAQ.rst, jax.device_put (#3496) 22 June 2020, 17:08:03 UTC
3fff837 pin numpy version in setup.py to avoid warnings (#3509) 22 June 2020, 15:12:41 UTC
7f9fc27 Another small fix of the link rendering in the Autodiff Cookbook - vmap transformation (#3502) 21 June 2020, 20:52:34 UTC
f0dff9c Fix a link rendering to Autograd's reverse-mode Jacobian method (#3501) 21 June 2020, 18:07:29 UTC
1ed60b3 support stacked doubling (#3490) 19 June 2020, 23:46:06 UTC
19f308b implement jax.random.choice (#3463) 19 June 2020, 23:04:42 UTC
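A sketch of the new sampler, assuming `jax.random.choice(key, a, shape, replace, p)` as in current JAX: drawing every element of range(5) without replacement yields a permutation, and a probability vector `p` that puts all its mass on one entry always selects that entry.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Without replacement, five draws from range(5) form a permutation of 0..4.
perm = jax.random.choice(key, 5, shape=(5,), replace=False)
# With replacement (the default), p gives the sampling probability of each entry.
weighted = jax.random.choice(key, jnp.array([10, 20, 30]), shape=(4,),
                             p=jnp.array([0.0, 0.0, 1.0]))
```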
8f4ba7e Allow specifying both `devices` and `axis_size` to pmap. (#3475) This allows providing custom device assignments to nested pmaps or pmap-of-sharded_jit when running on a multi-host platform. 19 June 2020, 22:51:12 UTC
a088c02 Bump jaxlib version to 0.1.49 and update WORKSPACE (#3495) 19 June 2020, 20:27:12 UTC
925d661 Fix test failure in jax2tf due to conflicting merges. (#3492) 19 June 2020, 14:14:53 UTC
927c209 Add random_gamma_grad and use in jax.random.gamma (#3281) 19 June 2020, 13:34:18 UTC
0da7b4d Improve dtype test coverage for random_test (#3254) 18 June 2020, 22:17:13 UTC
57d5a39 Disable the workaround to prevent expansion of type aliases (#3485) * Disable the workaround to prevent expansion of type aliases * Fix flake 18 June 2020, 18:01:40 UTC
3b4a123 [jax2tf] Control on which TF device the TF converted code runs. (#3482) * [jax2tf] Control on which TF device the TF converted code runs. The device should match the device on which the JAX code runs, otherwise the numerical comparisons don't make sense. * This code requires some build-file changes in Google3 to properly run on GPU and TPU. * Fix warnings 18 June 2020, 12:38:50 UTC
a05263f Avoid tuple psum in pmean (#3479) We have an optimization to avoid doing flops (or rather, communication) for `psum(1)` (instead we look up the axis size at trace time). But although `pmean(x)`, meaning `psum(x) / psum(1)`, is likely the most common user of `psum(1)`, it doesn't actually trigger this optimization right now because it's implemented as `psum((x, 1))` and `bind` lifts the 1 into the same `JaxprTrace` as `x` rather than letting the psum impl rule see it. The major reason for using tuple psum—providing a fixed order to avoid multihost GPU deadlocks—doesn't apply here because we don't expect the `psum(1)` to lower to an actual XLA AllReduce. 18 June 2020, 04:42:17 UTC
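A minimal sketch of the pattern this commit concerns, assuming `jax.pmap` with a named axis and `jax.lax.pmean`; it runs on however many local devices are available (one on a plain CPU install). Conceptually `pmean(x, 'i')` equals `psum(x, 'i') / psum(1, 'i')`, and the message above notes that `psum(1)` can be answered from the axis size at trace time.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n, dtype=jnp.float32)

# Each device holds one element; pmean averages across the mapped axis 'i'.
means = jax.pmap(lambda v: jax.lax.pmean(v, "i"), axis_name="i")(x)
# The same average written out as psum(x) / psum(1).
manual = jax.pmap(lambda v: jax.lax.psum(v, "i") / jax.lax.psum(1, "i"),
                  axis_name="i")(x)
```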
68dcbdd Remove documentation presubmit action now that the RTD presubmit is enabled. (#3480) 18 June 2020, 00:15:25 UTC
8bc7820 Use myst to parse markdown docs, and convert a page from rst to markdown. (#3477) Also sets the minimum sphinx version to 2.1. 17 June 2020, 23:59:14 UTC
3290e16 Attach source info to Jaxpr equations. (#3421) * Attach source info to Jaxpr equations. Example:
```
In [1]: import jax, jax.numpy as jnp

In [2]: def f(x, y):
   ...:     z = jax.numpy.cos(x)
   ...:     z = z * jax.numpy.tanh(y)
   ...:     return z + 2
   ...:

In [3]: jax.make_jaxpr(jax.value_and_grad(f))(7., 9.)
Out[3]:
{ lambda ; a b.
  let c = cos a [<ipython-input-2-5d59f71cb65d>:2 (f)]
      d = tanh b [<ipython-input-2-5d59f71cb65d>:3 (f)]
      e = mul c d [<ipython-input-2-5d59f71cb65d>:3 (f)]
      f = add e 2.0 [<ipython-input-2-5d59f71cb65d>:4 (f)]
      g = mul 1.0 d [<ipython-input-2-5d59f71cb65d>:3 (f)]
      h = neg g [<ipython-input-2-5d59f71cb65d>:2 (f)]
      i = sin a [<ipython-input-2-5d59f71cb65d>:2 (f)]
      j = mul h i [<ipython-input-2-5d59f71cb65d>:2 (f)]
  in (f, j) }

In [7]: print(jax.xla_computation(jax.value_and_grad(f))(7., 9.).as_hlo_module().to_string())
HloModule xla_computation_f__4.15

ENTRY %xla_computation_f__4.15 (parameter.1: f32[], parameter.2: f32[]) -> (f32[], f32[]) {
  %constant.3 = pred[] constant(false)
  %parameter.1 = f32[] parameter(0)
  %cosine.4 = f32[] cosine(f32[] %parameter.1), metadata={op_type="cos" op_name="xla_computation(f)/cos" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %parameter.2 = f32[] parameter(1)
  %tanh.5 = f32[] tanh(f32[] %parameter.2), metadata={op_type="tanh" op_name="xla_computation(f)/tanh" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.6 = f32[] multiply(f32[] %cosine.4, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %constant.7 = f32[] constant(2), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %add.8 = f32[] add(f32[] %multiply.6, f32[] %constant.7), metadata={op_type="add" op_name="xla_computation(f)/add" source_file="<ipython-input-2-5d59f71cb65d>" source_line=4}
  %constant.9 = f32[] constant(1), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %multiply.10 = f32[] multiply(f32[] %constant.9, f32[] %tanh.5), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=3}
  %negate.11 = f32[] negate(f32[] %multiply.10), metadata={op_type="neg" op_name="xla_computation(f)/neg" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %sine.12 = f32[] sine(f32[] %parameter.1), metadata={op_type="sin" op_name="xla_computation(f)/sin" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  %multiply.13 = f32[] multiply(f32[] %negate.11, f32[] %sine.12), metadata={op_type="mul" op_name="xla_computation(f)/mul" source_file="<ipython-input-2-5d59f71cb65d>" source_line=2}
  ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %add.8, f32[] %multiply.13)
}
```
Co-authored-by: Matthew Johnson <mattjj@google.com> 17 June 2020, 23:35:36 UTC
3c6c0af disable new jax2tf TPU test (#3476) 17 June 2020, 22:09:54 UTC
5a74ebf Add experimental precision doubling transform (#3465) This PR adds an experimental precision doubling transform, following the basic approach outlined in Dekker 1971 ([pdf](http://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf)). When this transform is applied, the number of significant bits is approximately doubled compared to the base operation. Simple demo:
```python
In [1]: import jax.numpy as jnp

In [2]: from jax.experimental.doubledouble import doubledouble

In [3]: def f(a, b):
   ...:     return a + b - a
   ...:

In [4]: f(1E20, 1.0)  # float64 loses precision
Out[4]: 0.0

In [5]: doubledouble(f)(1E20, 1.0)
Out[5]: DeviceArray(1., dtype=float64)
```
This initial experiment supports basic arithmetic operators and inequalities. 17 June 2020, 21:59:40 UTC
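The core trick behind Dekker-style double-double arithmetic is an error-free transformation that returns a sum together with its rounding error. A minimal sketch in plain Python floats (float64) using Dekker's Fast2Sum, which assumes abs(a) >= abs(b); the helper `fast_two_sum` is illustrative and is not the `doubledouble` module's code. It recovers the 1.0 that the naive `a + b - a` loses in the demo.

```python
def fast_two_sum(a: float, b: float) -> tuple[float, float]:
    """Dekker's Fast2Sum: for abs(a) >= abs(b), s + e equals a + b exactly."""
    s = a + b
    e = b - (s - a)  # rounding error of the floating-point addition
    return s, e

a, b = 1e20, 1.0
naive = a + b - a          # 1.0 is absorbed by rounding, giving 0.0
s, e = fast_two_sum(a, b)  # the lost part is kept in e
recovered = (s - a) + e    # subtract a from the high part, then add the error back
```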
5344ec5 use original numpy for shape calculations (#3474) cf. #3453 17 June 2020, 17:05:28 UTC
5ee936f Add polyder numpy function (#3403) 17 June 2020, 16:43:50 UTC
4f21b93 [jax2tf] Added special case for tf.pad. (#3462) Fixed lax_reference.pad to handle lax.pad with negative edge padding. 17 June 2020, 08:57:21 UTC
ce782e6 Fixed bug with silent TF using 64-bit operations in 32-bit mode. (#3471) We already cast the NumPy arrays to 32-bit, but this was not happening if the input was tf.Variable or tf.constant. In the process I have added some invariant checks, if core.skip_checks is False, which it always is in testing. Also added some sanity checking of arguments. 17 June 2020, 08:56:32 UTC
575216e add jet primitives, refactor tests (#3468) Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> 17 June 2020, 02:48:25 UTC
e9ce700 update tpu readme not to pin jax / jaxlib versions (#3467) 16 June 2020, 22:48:01 UTC
20f4ec6 fix a bug from #3459 (#3466) 16 June 2020, 22:46:51 UTC
df3b507 update jaxlib version in readme 16 June 2020, 21:19:03 UTC
a70ba92 jaxpr pretty-print: wrap equation RHS when the LHS is long 16 June 2020, 20:49:13 UTC
711c93d Merge pull request #3459 from google/simplify-remat-partial-eval simplify remat partial eval parameterization 16 June 2020, 19:20:59 UTC
dfdf05f deflake 16 June 2020, 18:56:42 UTC
005958e added reviewer suggestion 16 June 2020, 18:46:37 UTC
140c9ea Merge pull request #3396 from JuliusKunze/simplify-scan-mask Remove special case for scan masking 16 June 2020, 05:19:14 UTC
270e921 Merge pull request #3456 from jacobjinkelly/int_pow Jet rule for `integer_pow` 16 June 2020, 05:17:25 UTC
60760f2 Merge pull request #3455 from jacobjinkelly/erf_inv add erf_inv rule 16 June 2020, 05:16:04 UTC
c30d23f Merge pull request #3429 from BuddenD/changelist/316251368 Propagate raw __name__ of functions wrapped by jit and sharded_jit. 16 June 2020, 05:15:43 UTC