swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f

Revision Message Commit Date
3e4210c Update jax version to 0.1.70 (#3383) 09 June 2020, 16:55:05 UTC
99401c5 Initial implementation of variadic lax.reduce() (#3342) 09 June 2020, 16:22:29 UTC
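A variadic reduce lets one reduction carry several operands at once. An argmax needs exactly this, because the value and its index must travel together. The sketch below assumes a current JAX release, where `lax.reduce` accepts tuples of operands and init values and calls the combining function on matching tuples. The helper name `max_and_argmax` is hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def max_and_argmax(x):
    """Hypothetical helper: max value and its index of a 1-D array in one reduce."""
    idx = jnp.arange(x.shape[0], dtype=jnp.int32)

    def combine(a, b):
        # a and b are (value, index) pairs. Keep the larger value; on ties keep
        # the smaller index so the result does not depend on reduction order.
        a_val, a_idx = a
        b_val, b_idx = b
        take_a = (a_val > b_val) | ((a_val == b_val) & (a_idx < b_idx))
        return (lax.select(take_a, a_val, b_val),
                lax.select(take_a, a_idx, b_idx))

    # The init values must be an identity for `combine`: -inf never wins on
    # value, and the largest int32 never wins a tie on index.
    init = (jnp.array(-jnp.inf, dtype=x.dtype),
            jnp.array(jnp.iinfo(jnp.int32).max, dtype=jnp.int32))
    return lax.reduce((x, idx), init, combine, (0,))


value, index = max_and_argmax(jnp.array([1.0, 5.0, 3.0, 5.0]))
```

The combining function is a lexicographic max over (value, lowest index). That makes it associative and commutative, which a reduction requires.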
d3ccf0a fix typo in docs 09 June 2020, 14:35:07 UTC
307701c Merge pull request #3331 from JuliusKunze/mask-split Allow mask(jnp.split) 09 June 2020, 14:18:51 UTC
b7175a3 Merge pull request #3322 from j-towns/mask-of-jit Support mask(jit) 09 June 2020, 14:16:52 UTC
67927b0 Fix imports for jax_to_tf tests (#3377) 09 June 2020, 09:28:03 UTC
c12541d Refactor the jax_to_tf tests to separate the primitive test harness (#3376) * Refactor the jax_to_tf tests to separate the primitive test harness from the test. The goal is to have a collection of test harnesses for the JAX primitives to be able to test various implementations (JAX, NumPy, TensorFlow). For now we use these harnesses only in the jax_to_tf tests, although we can later use them for lax_test. Demonstrate the use of the harness for lax.pad and lax.squeeze, both in tf_ops_test and lax_test. The plan is to add support for more primitives as we make progress testing jax_to_tf. * Expanded pad harness with negative pads 09 June 2020, 08:07:32 UTC
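The pad harness mentioned above exercises `lax.pad`. Its padding configuration is one `(low, high, interior)` triple per dimension, and a negative `low` or `high` removes elements instead of adding them. The values below assume current JAX semantics and can be checked by hand.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(1.0, 6.0)               # [1., 2., 3., 4., 5.]
zero = jnp.array(0.0, dtype=x.dtype)   # padding value matches the operand dtype

# low=-1 drops the first element, high=1 appends one padding value.
trimmed = lax.pad(x, zero, [(-1, 1, 0)])          # [2., 3., 4., 5., 0.]

# interior=1 inserts one padding value between adjacent elements.
interleaved = lax.pad(x[:3], zero, [(0, 0, 1)])   # [1., 0., 2., 0., 3.]
```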
6aa8f24 Fix remaining flakes and use exclude within setup.cfg (#3371) 09 June 2020, 05:58:03 UTC
563b65e minor fixes from #3222 (#3368) * clean up handling of aux data in jvp_subtrace_aux related to #3222 indirectly, in that we now won't try to do something crazy like `JVPTracer(trace, object(), zero)`, which #3222 didn't like * symbolic zeros tweak to work with div rule * fix a couple ad_util.Zero type checks * fix a docs bug * improve broadcast handling in batching.py We can avoid a circular import just by `import jax`! * fix another docs bug * yet another doc fix 08 June 2020, 22:11:49 UTC
4ed8cce fix mypy error 08 June 2020, 22:11:22 UTC
ee42800 yet another doc fix 08 June 2020, 22:06:00 UTC
2a6d3f4 fix another docs bug 08 June 2020, 21:13:01 UTC
04191ab improve broadcast handling in batching.py We can avoid a circular import just by `import jax`! 08 June 2020, 20:46:10 UTC
3b3e7a2 avoid some trivial operations in mask computations (#3364) I think #2800 accidentally removed `mul`, `pow`, and `prod` from mask.py, but those were there to avoid some trivial computations. In particular, on this program:
```python
# JAX 0.1.x-era API: `mask` was an experimental transformation, `np` is jax.numpy
from functools import partial
import jax.numpy as np
from jax import mask, make_jaxpr

@partial(mask, in_shapes=['n'], out_shape='')
def foo(x):
  return np.sum(x)

padded_x = np.array([0, 1, 2, 3, 999, 999])
print(make_jaxpr(foo)([padded_x], dict(n=3)))
```
after #2800 (and until this commit) we'd print
```
{ lambda c h ; a b.
  let d = mul b 1
      e = mul d 1
      f = add e 0
      g = lt c f
      i = select g a h
      j = reduce_sum[ axes=(0,) ] i
  in (j,) }
```
but before #2800 (and after this commit) we print
```
{ lambda c e ; a b.
  let d = lt c b
      f = select d a e
      g = reduce_sum[ axes=(0,) ] f
  in (g,) }
```
This might save a tiny bit of work, but also it means the make_jaxpr results are cleaner, and we like to show those! @j-towns spotted the fix cc @JuliusKunze 08 June 2020, 20:42:35 UTC
0a71697 fix a docs bug 08 June 2020, 20:30:00 UTC
866c17c fix a couple ad_util.Zero type checks 08 June 2020, 20:22:13 UTC
fb8b3a4 symbolic zeros tweak to work with div rule 08 June 2020, 20:16:19 UTC
982f867 clean up handling of aux data in jvp_subtrace_aux related to #3222 indirectly, in that we now won't try to do something crazy like `JVPTracer(trace, object(), zero)`, which #3222 didn't like 08 June 2020, 18:48:58 UTC
528207b fix flakes at head (#3361) 08 June 2020, 17:45:00 UTC
b1adef4 Fix polyadd and test (#3344) 08 June 2020, 17:06:20 UTC
65d95f1 A couple of ad_util.zero were missed in #3222 (#3363) 08 June 2020, 16:59:25 UTC
508821c fix a one-character issue from a bad merge (#3362) cf. #3222 08 June 2020, 16:47:32 UTC
dc41ff9 Update optix.py (#3355) 08 June 2020, 16:28:41 UTC
e36c72b Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) (#3222) 08 June 2020, 15:50:14 UTC
96d72fc Edit documentation. (#3358) Use :func:, :class: and :meth: when referring to Python objects. Use :ref: for hyperlinks. Fix some bad formatting. 08 June 2020, 14:37:50 UTC
701bc76 Update README to recommend Windows Subsystem for Linux (WSL) to Windows users. (#3356) 08 June 2020, 13:03:02 UTC
6b948e6 Generate type hints for parameters, even if the parameters aren't documented. (#3337) This change tells sphinx-autodoc-typehints to generate stub parameter documentation for arguments that have type hints. 08 June 2020, 12:56:10 UTC
7863f81 Commented-out the literal jaxpr checks in host_callback (#3351) * Commented-out the literal jaxpr checks in host_callback Will re-enable when we change the host_callback, or when we have better tools for updating goldens 07 June 2020, 11:45:15 UTC
fd886b1 make jnp.array faster (#3350) fixes #2919 07 June 2020, 04:44:14 UTC
2a10dbb deflake remainder of jax (#3343) 06 June 2020, 17:51:34 UTC
fb17172 Cleanup: deflake jax.experimental and jax.ops (#3329) 06 June 2020, 02:00:04 UTC
1a3c180 Don't expand type aliases when generating JAX documentation. (#3336) 05 June 2020, 21:21:18 UTC
b65abdf Fix missing sharded_jit_test import 05 June 2020, 19:03:37 UTC
0b5d885 Handle with_sharding_constraint inside arbitrary subjaxprs. (#3339) Also fixes a bug where a replicated sharding constraint would incorrectly trigger an error. 05 June 2020, 18:34:20 UTC
29740de Add np.polysub (#3319) 05 June 2020, 16:44:10 UTC
841f21f Enable SVD on TPU. (#3334) 05 June 2020, 16:21:30 UTC
e2a32f7 Rebase fixes 05 June 2020, 15:52:03 UTC
17472b9 Optimize zeros_like_shaped_array This function is used a lot more now, because `ad.instantiate_zeros` now goes through that and not `zeros_like_array`. 05 June 2020, 15:52:03 UTC
cfa3b78 Remove accidental prints 05 June 2020, 15:52:03 UTC
7a0e8ba Update notebooks 05 June 2020, 15:52:03 UTC
3f1d3a7 Remove example from ad.instantiate_zeros, fix vmap bug 05 June 2020, 15:52:01 UTC
c5d8707 Fix host_callback 05 June 2020, 15:51:30 UTC
adb442e Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) This is useful for remat transpose rule submitted in #3162 and e.g. allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it was unnecessarily declared as having multiple outputs). 05 June 2020, 15:51:30 UTC
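A toy model shows why a zero that carries its aval is useful. A symbolic zero that remembers its shape and dtype can be skipped during tangent arithmetic and still be materialized later without an example value (compare "Remove example from ad.instantiate_zeros" in this listing). This is a hypothetical pure-Python sketch. `Aval`, `Zero`, `add_tangents` and `instantiate` are illustrative names, not JAX's internal code, and arrays are modeled as flat lists.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Aval:
    """Hypothetical abstract value: just a shape and a dtype name."""
    shape: tuple
    dtype: str


@dataclass(frozen=True)
class Zero:
    """Hypothetical symbolic zero that records what it is a zero of."""
    aval: Aval


def add_tangents(x, y):
    # Adding a symbolic zero costs nothing: return the other operand.
    if isinstance(x, Zero):
        return y
    if isinstance(y, Zero):
        return x
    return [a + b for a, b in zip(x, y)]


def instantiate(t):
    # Materialize a symbolic zero only when a concrete value is required.
    # The carried aval supplies the size, so no example argument is needed.
    if isinstance(t, Zero):
        return [0.0] * math.prod(t.aval.shape)
    return t
```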
74d160f Don't keep primal arguments and results in the linearized jaxpr (#3233) Linearized functions are supposed to take tangent types to tangent types, and so all primal arguments are unused and primal results get replaced by units. 05 June 2020, 15:22:55 UTC
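From the user's side, the function returned by `jax.linearize` maps tangents to tangents only, and the primal output is returned separately. The small check below assumes a current JAX release.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * x


# y is f(2.0); f_lin maps an input tangent to an output tangent.
y, f_lin = jax.linearize(f, 2.0)

# For this f, f'(x) = x * cos(x) + sin(x), so f_lin(t) = t * f'(2.0).
t_out = f_lin(1.0)
```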
ea78222 Allow mask(split) 05 June 2020, 07:12:07 UTC
93646b5 Improve error when zero-sized arrays passed to convolve (#3325) * Improve error when zero-sized arrays passed to convolve * Apply suggestions from code review Co-authored-by: Matthew Johnson <mattjj@google.com> 05 June 2020, 03:28:19 UTC
969ed6d Initial implementation of polymul function (#3303) 05 June 2020, 03:27:29 UTC
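`polymul` follows NumPy's convention of storing coefficients from the highest degree down. Under that convention, the product's coefficients are the full convolution of the two inputs. Below is a dependency-free reference for that rule; the function name is ours, not JAX's.

```python
def polymul(a, b):
    """Reference polynomial product; coefficients are highest degree first."""
    out = [0] * (len(a) + len(b) - 1)
    # a[i] has degree len(a)-1-i and b[j] has degree len(b)-1-j, so their
    # product lands at index i + j of the highest-degree-first result.
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            out[i + j] += ai * bj
    return out


print(polymul([1, 1], [1, -1]))  # (x + 1)(x - 1) = x^2 - 1
```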
07aa795 Apply suggestions from code review Co-authored-by: Matthew Johnson <mattjj@google.com> 04 June 2020, 23:54:28 UTC
657863f BUG: fix column_stack and add tests (#3328) 04 June 2020, 23:34:16 UTC
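`column_stack` turns 1-D inputs into columns and concatenates 2-D inputs along axis 1, matching `numpy.column_stack`. The quick check below assumes a current JAX release.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# Each 1-D input becomes one column of the result.
cols = jnp.column_stack((a, b))                   # [[1, 4], [2, 5], [3, 6]]

# 2-D inputs are kept as-is and joined along axis 1.
mixed = jnp.column_stack((jnp.ones((3, 2)), b))   # shape (3, 3)
```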
a63b9cc Cleanup: deflake interpreters, lib, nn, third_party, and tools (#3327) 04 June 2020, 22:27:48 UTC
bc51e9c deflake jax/scipy/* and add to setup.cfg (#3316) 04 June 2020, 21:38:41 UTC
b187663 deflake jax/lax & add to flake8 check (#3310) 04 June 2020, 20:50:44 UTC
4544436 Improve error when zero-sized arrays passed to convolve 04 June 2020, 20:25:10 UTC
9c0a58a add float dtype checks to random.py (#3320) fixes #3317 04 June 2020, 17:13:15 UTC
38d4837 Fix x64 test 04 June 2020, 14:05:14 UTC
0f00327 Implement MaskTrace.post_process_call 04 June 2020, 11:59:29 UTC
dfe3462 Add device_put_p abstract_eval rule 04 June 2020, 11:28:38 UTC
c04dea1 Begin implementing mask(jit) 04 June 2020, 10:25:30 UTC
71f1c5c Refactoring of jax_to_tf tests: (#3262) (#3308) * Moved control-flow tests into their own file * Added a helper module tf_test_util, with a helper function ConvertAndCompare * Used self.assertAllClose instead of numpy.testing.assert_allclose because the former iterates over lists and tuples (and is standard in other JAX tests) * Used @parameterized.named_parameters for parameterized tests, for nicer test names. 04 June 2020, 06:54:06 UTC
afa9276 Implement jax_to_tf.scan (#3307) 04 June 2020, 06:41:45 UTC
c49bb75 update changelog with lax.switch 04 June 2020, 05:19:15 UTC
bd3cab9 update jaxpr doc to reflect lax.switch and indexed cond 04 June 2020, 05:19:15 UTC
6015a2a introduce lax.switch 04 June 2020, 05:19:15 UTC
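`lax.switch` generalizes the two-way `lax.cond` to N branches selected by an integer index. Only the chosen branch is executed, and an out-of-range index is clamped to the nearest valid branch. The sketch below assumes the current signature `lax.switch(index, branches, *operands)`; `apply_branch` is an illustrative wrapper.

```python
import jax
from jax import lax

branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]


def apply_branch(i, x):
    # Exactly one branch runs; i may be a traced integer under jit.
    return lax.switch(i, branches, x)


doubled = apply_branch(1, 3.0)             # branch 1
clamped = apply_branch(7, 3.0)             # clamped to branch 2
jitted = jax.jit(apply_branch)(0, 3.0)     # branch 0, index traced
```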
dc4c9f0 change cond primitive to an indexed conditional with multiple branch functions in the core: * bind and check cond primitive in indexed form * rewrite abstract evaluation rule * rewrite translation rule * rewrite partial evaluation rule * rewrite batching rule * rewrite JVP rule * rewrite transpose rule * update jaxpr typechecker * update pretty printer * update outfeed-usage check * update reference jaxpr in cond jaxpr test * update reference regexes in HLO test in experimental modules: * update host_callback rewriter * update loops expression builder * generalize tf_impl rule 04 June 2020, 05:19:15 UTC
4f5547d Don't AD through max-subtraction in softmax (#2260) * Don't AD through max-subtraction in softmax * Also stop-grad the max in logsumexp 04 June 2020, 00:00:54 UTC
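The reasoning behind that change is as follows. softmax(x) = exp(x - m) / sum(exp(x - m)) gives the same value for any shift m, so the derivative through m is exactly zero, and differentiating the max only adds work. The same holds for logsumexp(x) = m + log(sum(exp(x - m))), where the two contributions of m cancel. Below is a minimal sketch of the idea, not the library's implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax


def softmax(x):
    # Shift by the max for numerical stability; the shift cancels in the
    # ratio, so gradients need not flow through it.
    m = lax.stop_gradient(jnp.max(x))
    e = jnp.exp(x - m)
    return e / jnp.sum(e)


def logsumexp(x):
    # d/dm of m + log(sum(exp(x - m))) is 1 - 1 = 0, so stopping the
    # gradient on m leaves the true gradient (the softmax) unchanged.
    m = lax.stop_gradient(jnp.max(x))
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))
```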
5ad9fed Fix handling of infeed token inside sharded_jit (#3313) 03 June 2020, 22:23:49 UTC
c77c083 deflake jax.numpy and add to flake8 check (#3312) 03 June 2020, 21:18:48 UTC
d1dbf7c Implement mask for some primitives + jit. (#2922) * Implement mask for slice, conv, pad, transpose, where * Remove tentative mask(jit) * Add explanatory comment to dot_general masking rule * Rm reshape from select masking rule * Rm unnecessary check from lax slice abstract_eval rule * Revert to standard indentation in masking_test.py * Begin simplifying masking tests * Finish drafting masking check function * More progress simplifying tests * Add conv masking in batch dim * Finish fixing up tests * Revert to old API, making out_shape compulsory again * More efficient conv masking rule * Tidy up masking_test imports * Check that out tree is preserved by masking * fix flake errors Co-authored-by: Jamie Townsend <jamestownsend@google.com> Co-authored-by: Jamie Townsend <jamiehntownsend@gmail.com> Co-authored-by: Matthew Johnson <mattjj@google.com> 03 June 2020, 20:40:48 UTC
0db57cb Fix validation code in lax.conv (#3279) 03 June 2020, 17:33:19 UTC
b998044 Add np.polyadd (#3261) 03 June 2020, 17:26:35 UTC
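NumPy stores polynomial coefficients from the highest degree down. Adding polynomials of different degree therefore means left-padding the shorter coefficient list before adding elementwise. `np.polysub` (also in this listing) works the same way with the second operand negated. Below is a dependency-free reference for that rule; the function names are our own.

```python
def polyadd(a, b):
    """Reference polynomial sum; coefficients are highest degree first."""
    n = max(len(a), len(b))
    # Left-pad with zeros so equal-degree coefficients line up.
    a = [0] * (n - len(a)) + list(a)
    b = [0] * (n - len(b)) + list(b)
    return [x + y for x, y in zip(a, b)]


def polysub(a, b):
    """Reference polynomial difference a - b."""
    return polyadd(a, [-c for c in b])


print(polyadd([1, 2, 3], [4, 5]))  # (x^2 + 2x + 3) + (4x + 5) = x^2 + 6x + 8
```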
66ba734 Add note to docs describing how pytree arguments work. (#3284) Addresses #3095. I'm not sure if we wanna link to this from API docstrings. This also subsumes the original pytrees notebook. 03 June 2020, 16:46:00 UTC
f8bab4a update version and changelog for pypi 03 June 2020, 14:45:53 UTC
0e229e4 keep old name 'packed_state' of OptimizerState 03 June 2020, 14:32:44 UTC
b58eec5 make pmap axis checking an exception, hoist (#3239) 03 June 2020, 03:28:59 UTC
538691b remove `pack` from optimizers.py (#3305) It is vestigial, from a time when JaxTuples roamed free. 03 June 2020, 03:28:21 UTC
9ee4ef1 Cleanup: de-lint tests directory & add flake8 to travis (#3304) * Cleanup: fix lint errors in tests/*.py * Add flake8 step to travis * add setup.cfg 03 June 2020, 02:25:47 UTC
177e7cf moved check_jaxpr code around to match eval_jaxpr (#3240) * moved check_jaxpr code around to match eval_jaxpr This change is mostly stylistic; it brings check_jaxpr closer to eval_jaxpr (and the other jaxpr interpreters) in organization. There's a slight tweak to an error message which lets us save some slightly redundant code. * fixes and tweaks 03 June 2020, 02:10:55 UTC
8c0b27b tweak jax.experimental readme 03 June 2020, 02:00:54 UTC
c42a7f7 remove some trailing whitespace (#3287) 03 June 2020, 00:37:20 UTC
ea4277b Fix broken jnp.nancumsum() & jnp.nancumprod() and add tests (#3277) 02 June 2020, 23:45:44 UTC
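`nancumsum` and `nancumprod` are cumulative reductions that treat NaN as the identity of the operation: 0 for sums, 1 for products. The check below assumes a current JAX release. It compares both functions against the equivalent `where` plus `cumsum`/`cumprod` formulation.

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 2.0, 3.0])

s = jnp.nancumsum(x)   # NaN counts as 0: [1., 1., 3., 6.]
p = jnp.nancumprod(x)  # NaN counts as 1: [1., 1., 2., 6.]

# Equivalent formulation: replace NaNs with the identity, then accumulate.
s_ref = jnp.cumsum(jnp.where(jnp.isnan(x), 0.0, x))
p_ref = jnp.cumprod(jnp.where(jnp.isnan(x), 1.0, x))
```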
dc4761c Fix type promotion for real FFTs. (#3300) Only enable gradient test in x64 mode. 02 June 2020, 21:04:52 UTC
f642d11 Reverting jax_to_tf (scan and tests) (#3299) * Revert "Refactoring of jax_to_tf tests: (#3262)" This reverts commit 38bfcee753893ced209c9829f5675fded39b1911. * Revert "Implement jax_to_tf.scan (#3260)" This reverts commit d36429b5fd992cb16081f44dfd787f28c296e0a8. 02 June 2020, 19:59:37 UTC
4d0e2ff add jnp.diagflat() to docs (#3298) 02 June 2020, 19:08:25 UTC
6ddac1d Disabled host_callback infrastructure for the HLO interpreter backend, which doesn't support infeed/outfeed. (#3294) 02 June 2020, 17:16:13 UTC
1eb7f1b Use onp instead of np in ode_test (#3288) * Use onp instead of np in ode_test * other ode_test.py fixes Co-authored-by: Matthew Johnson <mattjj@google.com> 02 June 2020, 16:54:51 UTC
dd81a8d Fix some type errors in lax.py found by pytype. (#3292) 02 June 2020, 14:27:14 UTC
042df4e Fix pytype errors. (#3291) 02 June 2020, 14:26:43 UTC
a06b122 Add support for 64-bit FFTs. (#3290) 02 June 2020, 13:41:44 UTC
3909875 Improve speed of tracing dynamic_update_slice (#3247) * Improve tracing performance of _dynamic_slice_indices * More precisely preserve semantics of dynamic_slice_indices * Use safe_map in dynamic_slice_indices 02 June 2020, 13:37:32 UTC
38bfcee Refactoring of jax_to_tf tests: (#3262) * Moved control-flow tests into their own file * Added a helper module tf_test_util, with a helper function ConvertAndCompare * Used self.assertAllClose instead of numpy.testing.assert_allclose because the former iterates over lists and tuples (and is standard in other JAX tests) * Used @parameterized.named_parameters for parameterized tests, for nicer test names. 02 June 2020, 11:41:36 UTC
d36429b Implement jax_to_tf.scan (#3260) Also removed the enable_jit, which was needed only to work around the lack of control flow primitive support. 02 June 2020, 09:35:28 UTC
7a4b222 Added support for np.diagflat (#3259) 02 June 2020, 03:43:43 UTC
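`diagflat` flattens its input and places it on the k-th diagonal of a square matrix, as `numpy.diagflat` does. The short example below assumes a current JAX release.

```python
import jax.numpy as jnp

v = jnp.array([1, 2])

main = jnp.diagflat(v)         # [[1, 0], [0, 2]]
upper = jnp.diagflat(v, k=1)   # [[0, 1, 0], [0, 0, 2], [0, 0, 0]]

# Multi-dimensional input is flattened first: 4 values -> 4x4 matrix.
square = jnp.diagflat(jnp.array([[1, 2], [3, 4]]))
```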
34065df Add some type annotations to core and partial_eval. (#3251) 02 June 2020, 01:45:36 UTC
f1a7073 pmap(in_axes=None) of sharded_jit (#3257) * pmap(in_axes=None) of sharded_jit Co-authored-by: Skye Wanderman-Milne <skyewm@google.com> * address comments Co-authored-by: Skye Wanderman-Milne <skyewm@google.com> 01 June 2020, 23:50:22 UTC
972c7fd Fix bug where jnp.array returned a classic NumPy array, sometimes wit… (#3283) * Fix bug where jnp.array returned a classic NumPy array, sometimes with the wrong type. Unconditionally calls `device_put`, because `lax.convert_element_type` has a fast path that sometimes fails to lead to a `device_put`. Improve the test for `jnp.array` and its test harness. 01 June 2020, 23:29:26 UTC
00555b7 Remove duplicate test (#3275) 01 June 2020, 23:26:01 UTC
0eab560 Fix duplicated test name (#3273) 01 June 2020, 22:28:57 UTC
cf62419 Documentation fixes. (#3282) Improve some cross-references and poorly quoted text. 01 June 2020, 22:09:45 UTC
858f1e5 add missing core import in lax_test 01 June 2020, 21:47:14 UTC
fffdb2d Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280) * Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs. Default to check_dtypes=True. Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense. No functional changes intended. * Fix a number of lax reference implementations to preserve types. 01 June 2020, 21:19:23 UTC
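Making `check_dtypes`, `atol` and `rtol` keyword-only relies on Python's bare `*` marker in the signature. A call that passes them positionally then fails with `TypeError` instead of being silently misread. The helper below is hypothetical and much simpler than the real test utilities; it only illustrates the signature pattern and the `check_dtypes=True` default.

```python
import math


def assert_all_close(x, y, *, check_dtypes=True, atol=1e-6, rtol=1e-6):
    """Hypothetical helper: arguments after `*` can only be passed by keyword."""
    # With check_dtypes on (the default), 1 and 1.0 are considered different.
    if check_dtypes and type(x) is not type(y):
        raise AssertionError(f"type mismatch: {type(x).__name__} vs {type(y).__name__}")
    if not math.isclose(x, y, rel_tol=rtol, abs_tol=atol):
        raise AssertionError(f"{x!r} is not close to {y!r}")


assert_all_close(1.0, 1.0 + 1e-9)              # passes
assert_all_close(1, 1.0, check_dtypes=False)   # passes: only values compared
```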