swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f

182c462 hoist constants to parameters in jit-compiled functions Co-authored-by: Matthew Johnson <mattjj@google.com> 29 May 2020, 00:26:45 UTC
e48a4e0 uses np.prod instead of jnp.prod for shapes (#3236) 28 May 2020, 20:16:56 UTC
1a43693 implement np.extract (#3234) 28 May 2020, 18:04:15 UTC
7c90023 Fix sign error in custom_jvp / custom_vjp. (#3213) (#3219) f(x, y) = sin(x) * y. df/dy should be sin(x) instead of -sin(x). 28 May 2020, 17:21:39 UTC
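For context, a minimal check of the derivative this commit corrects, using plain `jax.grad` rather than the custom_jvp/custom_vjp machinery the fix touches:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y

# df/dy at (x, y) is sin(x); the bug produced the wrong sign, -sin(x).
x, y = 1.0, 2.0
assert jnp.allclose(jax.grad(f, argnums=1)(x, y), jnp.sin(x))
```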
572928d fix custom_jvp_call_jaxpr transpose function (#3231) * make custom_jvp_call_jaxpr handle multilinear funs see #3226 * remove old comment 28 May 2020, 17:20:36 UTC
c1ccbdf Small cleanup for partial_eval (#3210) `partial_eval` uses some pretty tricky conventions for return values (see `partial_eval_wrapper`), but it forces all call sites to deal with untangling them. This commit inlines the postprocessing into `partial_eval`, greatly simplifying its usage. 28 May 2020, 15:39:13 UTC
5475e1d Fix uses of sets in einsum to avoid nondeterminism. (#3230) * Fix uses of sets in einsum to avoid nondeterminism. * Address review comments. 28 May 2020, 14:07:06 UTC
9b91e89 Add implementation of np.compress (#3227) 28 May 2020, 01:57:00 UTC
5fa6ab5 Implement np.argwhere & np.flatnonzero (#3223) 28 May 2020, 00:08:12 UTC
6e50aa9 Update installation directions in README to mention expected CUDA location. (#3190) See https://github.com/google/jax/issues/989 28 May 2020, 00:05:36 UTC
02b4fd3 Fix broadcast_shapes for polymorphic dims (#3216) (#3224) * Fix #3216 * Simplify 27 May 2020, 22:15:01 UTC
7d96aae Fix bug in pytype fix. (#3229) 27 May 2020, 20:44:38 UTC
41292a2 mention numpy & scipy convolve functions in gotchas doc. (#3214) 27 May 2020, 20:20:27 UTC
336a0d6 Fix pytype error. (#3226) * Fix pytype error. * Incorporate review comment. 27 May 2020, 20:13:31 UTC
94b4ccd Relax test tolerance on lax_scipy_test to fix a test failure on Skylake machines at LLVM head. (#3225) 27 May 2020, 19:34:35 UTC
c5010cd use new gensym in host_callback jaxpr rewriter 27 May 2020, 19:03:34 UTC
da9d8c9 update reference jaxprs in tests 27 May 2020, 19:03:34 UTC
1b020fe update core.gensym consumers, address rewrite TODOs in lax control flow rules 27 May 2020, 19:03:34 UTC
e80e963 jaxpr-dependent gensym to avoid var duplication 27 May 2020, 19:03:34 UTC
8f2d72e Simplify handling of non-linear equations in backward_pass and fix remat (#3162)

Previously, `backward_pass` had been generalized to handle non-linear computation in the body, but it could easily be confused into doing unnecessary work only to throw it away later. Additionally, it treated any call primitive embedded inside remat like remat itself, which is obviously wrong. This patch fixes both of those issues and simplifies a bunch of the code at the same time. `backward_pass` now has an invariant that it deals only with jaxprs containing linear equations, and becomes a simple transposing interpreter again.

**Background on JVP vs linearization**

Ok, so why does this change actually fix the problem? It is important to understand that the JVP and linearization transforms are actually two different things, even though we often identify them as one. Both take in a function of type `a -> b`, but their ranges are different! JVP returns a function of type `(a, T a) -> (b, T b)`, while linearization returns `a -> (b, T a --o T b)`. Note that the second type carries more information, because we get a guarantee that (1) `b` does not depend on `T a` and (2) the dependence of `T b` on `T a` is linear.

The reason why we usually treat them as equivalent is that they can be shown to be "isomorphic". If we take the output of linearization, we can make it a JVP-like function using the following combinator:

```haskell
jvp f = \a ta -> let (b, lf) = linearize f in (b, lf ta)
```

More importantly for JAX, which doesn't have a linearization interpreter, if we assume (1) and (2), linearization can be recovered in terms of jvp as well:

```haskell
linearize f = \a -> let fjvp = jvp f in partial_eval fjvp (Known a) Unknown
```

That is, if we have a mathematically correct JVP, then linearization is simply partial evaluation with all primal values marked as known and all tangents treated as yet-unknown values.

One important performance consideration is that for forward-mode AD we really want to use the JVP formulation, which can interleave the computation of primals and tangents, instead of sequencing them and increasing the memory cost. On the other hand, transposition (necessary for VJPs!) can only be applied to linear functions, so it can't possibly work on the output of JVP. It can really only be applied to the second output of the linearization transform. Hence we really care about both; but can we avoid having two very similar implementations of (approximately) the same thing? It seems that the answer is yes, because of the equivalence outlined above!

**If all this is so nice, then what's the problem?**

The problem is, of course, remat. Partial eval is able to thread the known/unknown information correctly through regular call primitives, but mind you, remat is no regular call primitive! Once we enter remat, we are no longer interested in treating _anything_ like a known value. After all, our goal here is to record an accurate trace of everything that has happened in the body of a remat, including the primal (known!) computation. This, however, presents a challenge for implementing linearization in terms of JVP, because inside the body of remat we break the assumption that known/unknown corresponds to the primal/tangent distinction. Its body, instead of representing the second output of linearization, simply contains the traced JVP code now...

One way to fix it would be to implement a proper linearization pass that would track the distinction between primal and tangent information while still allowing primal code to be staged out. @mattjj and I have even started hacking together an implementation of that. I had been trying to convince @mattjj that there is no other way to go about it, but I couldn't really convince him that this is the case. Then, once I wanted to write a semi-formal proof, I could no longer even convince myself!

Turns out that there is an alternative solution! What this patch does is stop requiring the output of the `linearize` function (defined as JVP + partial eval, as discussed above) to be a good linearization. It still is if you don't use remat in your code, but it breaks miserably once you do. However, as long as all the complications are contained solely in the `call_jaxpr` embedded inside a remat, we still have a chance to fix them! This is because the transposition interpreter never reaches into those bodies directly, but rather asks the call primitive to transpose itself.

Now, how do you transpose remat? We can't just reuse the code used for regular call primitives (which is what happens now, BTW), because unlike for them, the `call_jaxpr` doesn't represent a linear function! But it's not completely useless either --- it contains the traced JVP code. So, how do we get from there to a linear function? Partial eval! And if you think about it, that is exactly what we wanted --- we end up evaluating all the primal code in the body once again, while only staging out the tangent computation, to be passed into the transposing interpreter again. Fin.

27 May 2020, 18:22:40 UTC
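For illustration, a minimal sketch of the user-facing path this fixes: taking the gradient through `jax.remat`, which exercises the transpose-of-remat logic described above.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(jnp.sin(x))

# Under remat, the primal code in the checkpointed body is re-evaluated on
# the backward pass; only the staged-out tangent code gets transposed.
g_remat = jax.grad(jax.remat(f))
g_plain = jax.grad(f)
assert jnp.allclose(g_remat(1.0), g_plain(1.0))
```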
ec3b593 Added geometric distribution to scipy stats (#3205) 27 May 2020, 16:37:55 UTC
e3b046b Adjust complex64 tolerance for upcoming XLA change (#3218) It makes sense for complex64 tolerance to be the same as float32 tolerance here. While there, re-enable the test on GPU, which was blocked on a bug that's long gone. 27 May 2020, 14:29:00 UTC
1cc4719 Remove pe from name_stack and test. (#3209) 27 May 2020, 07:59:31 UTC
9f8a4ad remove stray print statement from #1529 27 May 2020, 03:01:36 UTC
6ffde80 Implement pmap of sharded_jit (#3144) * Implement pmap of sharded_jit * Update jax/interpreters/pxla.py Co-authored-by: James Bradbury <jekbradbury@google.com> * Address comments Co-authored-by: James Bradbury <jekbradbury@google.com> 26 May 2020, 21:26:53 UTC
e526109 Remove dtype warning for `np.quantile` (#3188) * drop the warning in index_to_gather * fix dtype issue at quantile * revert the change, the issue seems to be fixed 26 May 2020, 19:41:01 UTC
0f23002 Add a JAX flag to avoid most optimizations. (#3208) 26 May 2020, 19:21:22 UTC
a486f54 Add a summary explaining the usage and context for JAX PRNG design. (#2525) * Add a summary explaining the usage and context for JAX PRNG design. The current design_notes do not match the current JAX API, and it's a pretty long doc to read just to understand how to use it. Closes: #2087 * Change 'should' to be more precise. * Address comments. 26 May 2020, 07:38:28 UTC
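A minimal sketch of the usage pattern the summary documents: explicit PRNG keys that are split, never reused.

```python
import jax

key = jax.random.PRNGKey(0)           # explicit, splittable PRNG state
key, subkey = jax.random.split(key)   # derive a fresh key; never reuse one
x = jax.random.normal(subkey, (3,))
```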
f18f792 Fix error in code generation of batched while loops (#3207) Fixed the case when the value is a unit, which we do not batch. 26 May 2020, 07:22:33 UTC
0eace80 Fix experimental host callback on multi-host (#3200) * Fix experimental host callback on multi-host Hosts can only access the outfeed queue for local devices, while `api.devices` returns all devices in the system. * Update host_callback.py 25 May 2020, 05:12:58 UTC
f1ae216 Added argument check to all primitives. (#3197) * Added argument check to all primitives. The issue that inspired this is that `lax.tie_in` is easy to misuse: if the first argument is not a JAX type, it silently disappears. This means that `lax.tie_in((x, x), const)` is the same as `const` even though `x` is a tracer. Previously this error would be caught only if core.skip_checks == False, because then `bind` checks its arguments. I have essentially added an unconditional argument check to `bind`. In case this is considered too inefficient, we can add argument checking to individual primitives, e.g., tie_in. For most primitives, if a non-JAX array is passed, the `impl` rule would fire and `numpy` would report the error somehow, perhaps. * Merged find_top_trace with check_args This was previously merged as #2948 but reverted awaiting fixes in some user code. 24 May 2020, 16:12:37 UTC
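For illustration, a sketch of the misuse described above, assuming the 2020-era `lax.tie_in` API (since removed from JAX):

```python
import jax
from jax import lax

def f(x):
    const = 3.0
    tied = lax.tie_in(x, const)  # correct: first argument is a JAX value
    # lax.tie_in((x, x), const) used to silently return `const`, dropping the
    # dependence on `x`; with the unconditional bind check it is rejected.
    return tied * x

print(jax.jit(f)(2.0))
```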
afadb12 Improved tapping support for while: tap inside cond, vmap of while (#3195) * Improved tapping support for while: tap inside cond, vmap of while * Fix float64->float32 in tests 24 May 2020, 07:50:07 UTC
9d6744f Cleanup test_custom_root_scalar and re-enable it for TPUs (#3184) The test passes now on TPUs, thanks to the new ``lax.integer_pow`` primitive. 23 May 2020, 18:45:27 UTC
b493a7e Fix the handling of repeated vmap for id_tap (#3132) * Fix the handling of repeated vmap for id_tap * Updated the transforms to always be a tuple of tuples * Changed the transforms to be dictionaries 23 May 2020, 10:49:27 UTC
de800d2 fix #3165 if round half up behavior is desired (#3166) * fix issue 3165 if round half up behavior is desired * add test for round half * fix integer array input and add to test * fix truncated integer output * match input and output dtypes * fix asymmetric rounding and extend test * use lax for rounding 22 May 2020, 22:41:37 UTC
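A minimal round-half-up sketch (an assumed helper, not the PR's exact implementation); note that `jnp.round`, like `np.round`, rounds half to even.

```python
import jax.numpy as jnp

def round_half_up(x):
    # Round 0.5 up to 1.0, unlike round-half-to-even.
    return jnp.floor(x + 0.5)

x = jnp.asarray(0.5)
assert jnp.round(x) == 0.0       # banker's rounding: half goes to even
assert round_half_up(x) == 1.0   # round half up
```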
190f88d Update Common_Gotchas_in_JAX.ipynb (#3189) typo fix 22 May 2020, 21:12:44 UTC
96c20f3 Merge pull request #2734 from google/tycheck typecheck jaxprs 22 May 2020, 05:07:24 UTC
c293a10 work around mypy 22 May 2020, 03:54:02 UTC
69d7bcf except-and-raise during jaxpr checking, add jaxpr as context, and simplify type environment 22 May 2020, 03:02:30 UTC
8e61ce8 fix unitvar comparisons and move to class attributes 22 May 2020, 01:28:09 UTC
ecd8936 Address comments 21 May 2020, 21:50:16 UTC
77e4d8b Updates onp -> np in random, loops, jet and in the tests of stax and optix (#3182) 21 May 2020, 21:12:18 UTC
d8ede01 Update jax/interpreters/pxla.py Co-authored-by: James Bradbury <jekbradbury@google.com> 21 May 2020, 21:00:58 UTC
a3e0cd1 Fix pxla.shard_args bug (#3170) 21 May 2020, 20:52:03 UTC
5d12553 axis_index abstract eval rule 21 May 2020, 20:21:07 UTC
1a91662 return tuple in psum abstract eval rule 21 May 2020, 20:21:07 UTC
1d78081 use jax.numpy in jaxpr typecheck tests 21 May 2020, 20:21:07 UTC
7ff389b extend type transfer to all primitives, including call and map primitives 21 May 2020, 20:21:07 UTC
e2cc568 raise type errors consistently in jaxpr checker 21 May 2020, 20:21:07 UTC
6475f60 fix import in core_test 21 May 2020, 20:21:07 UTC
1e55603 avoid attempt to read literals from the typechecking environment 21 May 2020, 20:21:07 UTC
060bd8a tidy jaxpr typechecking error test 21 May 2020, 20:21:07 UTC
0f109d9 add jaxpr context to typechecker error message 21 May 2020, 20:21:07 UTC
3705252 have UnitVar subclass Var (caught by mypy) 21 May 2020, 20:21:07 UTC
42e7e20 update check_jaxpr doc 21 May 2020, 20:21:07 UTC
cc34ed2 check aval compatibility, not strict equality, when typechecking jaxpr equations 21 May 2020, 20:21:07 UTC
0c2c558 check that variables are typed equally throughout a jaxpr 21 May 2020, 20:21:07 UTC
8e70769 factor out jaxpr-check context and variable environment 21 May 2020, 20:21:07 UTC
1205f7a factor out jaxpr equation checks 21 May 2020, 20:21:07 UTC
94b1f63 raise TypeError for jaxpr typechecking errors 21 May 2020, 20:21:07 UTC
82a9af5 typecheck jaxpr equations 21 May 2020, 20:21:07 UTC
4cbd14c Update jax version to 0.1.68 (#3181) 21 May 2020, 17:38:04 UTC
bb2127c Future-proof view test against signaling NaNs (#3178) 21 May 2020, 16:20:59 UTC
c459280 Added `associative_scan`. (#2170) * Added `associative_scan`. * Fixed problem where base case of associative scan could fail * remove jax.numpy dependence in associative_scan Co-authored-by: Matthew Johnson <mattjj@google.com> 21 May 2020, 15:35:17 UTC
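A minimal usage sketch, assuming the API landed as `lax.associative_scan` taking an associative binary function:

```python
import jax.numpy as jnp
from jax import lax

# Cumulative sum expressed as a (parallelizable) associative scan.
xs = jnp.arange(1.0, 6.0)
print(lax.associative_scan(jnp.add, xs))  # [ 1.  3.  6. 10. 15.]
```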
5c1de28 revise vmap in_axes/out_axes leaf type error msg (#3179) from #3161 discussion 21 May 2020, 15:00:18 UTC
eb81d7e add dict in_axes example to vmap docstring (#3176) * add dict in_axes example to vmap docstring fixes #3161 * fix typo 21 May 2020, 13:47:02 UTC
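In the spirit of the docstring addition (this exact example is an assumption, not the one added): `in_axes` can be a pytree such as a dict, matching the argument structure leaf by leaf.

```python
import jax
import jax.numpy as jnp

def scale(params):
    return params['w'] * params['x']

# Map over axis 0 of 'x' while broadcasting the unmapped scalar 'w'.
batched = jax.vmap(scale, in_axes=({'w': None, 'x': 0},))
print(batched({'w': 2.0, 'x': jnp.arange(3.0)}))  # [0. 2. 4.]
```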
6e3c8b1 Fix arr.view() on TPU & improve tests (#3141) 21 May 2020, 13:40:24 UTC
ae9d175 fix while_loop cond function batching (#3174) * fix while_loop cond function batching fixes #3164 * add test for #3164 21 May 2020, 03:21:41 UTC
f9c978e improve docstring of jax.numpy.broadcast_to (#3173) thanks @joaogui1 ! 21 May 2020, 02:09:54 UTC
a4094f7 revise "Tracer with raw numpy" error message (#3160) * revise "Tracer with raw numpy" error message fixes #3133 * fix f-string typo * fix typo Co-authored-by: James Bradbury <jekbradbury@google.com> 21 May 2020, 02:09:44 UTC
12f26d3 Improve ``devices`` and related documentation (#3155) 20 May 2020, 21:40:28 UTC
8b32c5d Avoid running trivial jitted subcomputations in pe (#3169) 20 May 2020, 21:30:33 UTC
f349979 Relax tolerance of LaxVmapTest.testDot for float64 inputs. (#3167) 20 May 2020, 20:17:17 UTC
7d157c7 onp -> np (#3157) 20 May 2020, 04:43:48 UTC
42b425d fix disable_jit logic in lax.cond and lax.while_loop (#3156) * fix disable_jit logic in lax.cond, fixes #3093 * fix disable_jit logic in lax.while_loop, fix #2823 * add test for issue #3093 * add test for #2823 * add test for #2598 20 May 2020, 01:14:10 UTC
3141ff8 Add lax implementation of np.isin() and np.in1d() (#3087) 19 May 2020, 23:58:42 UTC
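A quick usage sketch, mirroring NumPy's semantics:

```python
import jax.numpy as jnp

print(jnp.isin(jnp.array([1, 2, 3]), jnp.array([2, 4])))
# [False  True False]
```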
ccb203c improve pmap unbound axis error, fixes #3120 (#3152) 19 May 2020, 22:51:07 UTC
b7ff305 fix broken TPU test (#3153) 19 May 2020, 22:50:54 UTC
bc47a32 make lax.psum promote bool -> int (#3150) * make lax.psum promote bool -> int, fixes #3123 * fix test bug * fix typo in test 19 May 2020, 22:41:03 UTC
83a339e add erf and erfc rules (#3051) refactor def comp 19 May 2020, 22:22:25 UTC
b6777c0 Correct calculation of loss and increase learning rate (#3113) 19 May 2020, 22:18:34 UTC
850f1af improve errors for complex derivs, fixes #3121 (#3149) 19 May 2020, 22:17:03 UTC
8fe2619 Expand type support for random uniform() & randint() (#3138) 19 May 2020, 21:19:00 UTC
73b76e9 Exported lax from jax/__init__.py (#3135) This allows lax functions to be used without separately importing jax.lax. 19 May 2020, 19:40:03 UTC
77e3132 Fix indentation for docstrings in jax.experimental.host_callback (#3119) 19 May 2020, 15:23:45 UTC
85fe5a2 Add gradients to the scatter_max and scatter_min operations. (#3111) This is being done to allow the creation of a differentiable segment_max. Segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255 Co-authored-by: Alex Davies <adavies@google.com> 19 May 2020, 06:06:32 UTC
8d0749f Fix a corner case in `repeat`. (#3117) * Fixes a corner case: `jnp.repeat(jnp.array(0), 1, axis=0)` throws an error, whereas `np.repeat(np.array(0), 1, axis=0) = np.array([0])`. * Add test for `np.repeat(np.array(0), 1, axis=0)`. 19 May 2020, 02:24:45 UTC
888c9c7 Implement pmap of sharded_jit 19 May 2020, 01:40:28 UTC
d53bab9 Improve sharded_jit error message and fix test (#3145) 19 May 2020, 01:24:33 UTC
083cdd3 Fix in pxla._inner_partitions (#3146) In cb77f2a22de49e85da93f43b7dc448aa238d5207, I switched to looking for sharding_constraint_p's name since sharding_constraint_p itself is defined in sharded_jit.py, but didn't quite get the update right. 19 May 2020, 01:18:06 UTC
cb77f2a Move some sharded_jit functionality into pxla.py (#3142) Specifically: * Move `_inner_partitions` * Move `get_num_partitions` * Move and slightly modify the logic for finding and validating inner partitions to a new function, `reconcile_num_partitions` * Move `_partitioned_sharding_spec` and rename to `partitioned_sharding_spec` This is in preparation for enabling pmap-of-sharded_jit, since pmap will need access to this functionality as well. 18 May 2020, 23:37:49 UTC
9152b76 Add with_sharding_constraint method to be used within sharded_jit. (#3100) See the with_sharding_constraint docstring for a description of what this method does. Depending on how we decide nested sharded_jits should work, an alternative implementation for with_sharding_constraint could be:

```python
def with_sharding_constraint(x, partitions):
  return sharded_jit(lambda x: x, in_parts=partitions, out_parts=partitions)
```

In this case, we could get rid of the with_sharding_constraint primitive, and possibly even the API. This implementation gets the job done for now without committing to a nested sharded_jit behavior, and is also much easier to take the gradient of than sharded_jit. 18 May 2020, 22:20:49 UTC
4603432 Fix gradient for jax.scipy.ndimage.map_coordinates (#3110) * Fix gradient for jax.scipy.ndimage.map_coordinates Fixes GH3024 * minor refactor for clarification 18 May 2020, 22:13:54 UTC
36e7fad Add a primitive integer_pow() for values raised to a fixed integer scalar. (#3140) * Add a primitive integer_pow() for values raised to a fixed integer scalar. Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal(). Fixes #3136

```
In [1]: from jax import grad, make_jaxpr

In [2]: def inv(x): return 1/x

In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.))
0.043945312

In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.)
Out[4]:
{ lambda  ; a.
  let b = integer_pow[ y=-7 ] a
      c = mul -6.0 b
      d = mul -120.0 c
  in (d,) }
```

* Use x ** 3 in gelu definition. 18 May 2020, 21:54:20 UTC
a9c1b38 Added link to README (#3139) 18 May 2020, 21:12:52 UTC
ed0e227 Add sharded_jit translation rule. (#3099) This is potentially dangerous, because it lets sharded_jit() be called inside other call primitives (e.g. jit, pmap), which isn't supported yet. I'm adding it now because I'm planning to implement pmap-of-sharded_jit soon, and it will help with testing a set_sharding API I'm also planning to add soon. 18 May 2020, 18:39:55 UTC
bc5a0b3 Remove some uses of jax.partial. (#3131) 18 May 2020, 14:19:03 UTC
b071b12 Fixed link in FAQ (#3129) 18 May 2020, 13:02:49 UTC
2be471c Remove redundant flag configuration from lax_scipy_sparse_test (#3128) 18 May 2020, 05:12:21 UTC
912b282 Implement np.histogram_bin_edges and np.histogram (#3081) 16 May 2020, 17:23:26 UTC
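A quick usage sketch, mirroring `np.histogram`'s signature:

```python
import jax.numpy as jnp

data = jnp.array([0.1, 0.4, 0.45, 0.9])
counts, edges = jnp.histogram(data, bins=4, range=(0.0, 1.0))
print(counts)  # [1 2 0 1]
print(edges)   # [0.   0.25 0.5  0.75 1.  ]
```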