81631f9 | Matthew Johnson | 30 July 2020, 05:27:34 UTC | update version and changelog for pypi | 30 July 2020, 05:27:34 UTC |
c8771e1 | Matthew Johnson | 30 July 2020, 04:44:03 UTC | add omnistaging flag placeholder (#3904) | 30 July 2020, 04:44:03 UTC |
c28a711 | Jake Vanderplas | 29 July 2020, 22:32:44 UTC | Cleanup: pass function name rather than function object (#3897) | 29 July 2020, 22:32:44 UTC |
b0ef483 | Stephan Hoyer | 29 July 2020, 22:31:02 UTC | Fixes to test_scipy_optimize.py for Google internal tests (#3902) | 29 July 2020, 22:31:02 UTC |
242b324 | Stephan Hoyer | 29 July 2020, 21:22:21 UTC | Add missing license headers (#3899) Oops! | 29 July 2020, 21:22:21 UTC |
02009e0 | Joshua George Albert | 29 July 2020, 21:14:40 UTC | BFGS algorithm (#3101) * BFGS algorithm Addressing https://github.com/google/jax/issues/1400 * addresses @shoyer comments of PR * skip dtype checks * backslash in docstring * increase closeness tol * increase closeness atol to 1.6e-6 * addresses jakevdp comments * same line search as scipy * same results format * same (and more) testing as in scipy for line search and bfgs * 2 spacing * documenting * analytic hessian non default but still available * NamedTuple classes * small fix in setup_method * small doc string addition * increase atol to 2e-5 for comparison * removed experimental analytic_hessian * using jnp.where for all binary replace operations * removed _nojit as this is what disable_jit does * fix indentation mangling * remove remaining _nojit * fixing more indentation mangling * segregate third_party test * use parametrise * use parametrise * minor nitpicking * fix some errors * use _CompileAndCheck * replace f_0 and g_0 for (ugly) scipy variable names * remove unused function * fix spacing * add args argument to minimize * adhere fmin_bfgs to scipy api * remove unused function * ignore F401 * look into unittest * fix unittest error * delete unused function * more adherence to scipy's api * add scipy's old_old_fval arg though unused * increase line_search default maxiter to 20 (10 not enough in some cases) * remove unused imports * add ord=norm to the initial convergence check * remove helper function * merge jax/master * Resolve a remnant conflict from merging master to solve ReadTheDocs issue. * Add an informative termination message and status number. * Revert changes to unrelated files * cleanup bfgs_minimize * cleanup minimize.py * Move minimize_bfgs.py to _bfgs.py * Move more modules around * improve docs * high precision einsum * Formatting in line search * fixup * Type checking * fix mypy failures * minor fixup Co-authored-by: Stephan Hoyer <shoyer@google.com> | 29 July 2020, 21:14:40 UTC |
659dd39 | Peter Hawkins | 29 July 2020, 20:22:12 UTC | Add MLPerf results link. (#3896) | 29 July 2020, 20:22:12 UTC |
e0a8d44 | Jake Vanderplas | 29 July 2020, 19:53:28 UTC | Add jnp.modf() & improve test coverage for related functions (#3894) | 29 July 2020, 19:53:28 UTC |
190b6af | Jake Vanderplas | 29 July 2020, 19:52:41 UTC | Improve searchsorted implementation (#3873) | 29 July 2020, 19:52:41 UTC |
c38bc36 | Jake Vanderplas | 29 July 2020, 03:54:30 UTC | jnp.linspace & friends: more carefully handle dtypes (#3859) | 29 July 2020, 03:54:30 UTC |
3098074 | Matthew Johnson | 29 July 2020, 02:46:00 UTC | refine population_count type check (#3887) * refine population_count type check fixes #3886 * allow signed/unsigned ints for population_count https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/service/shape_inference.cc;l=314?q=xla%20f:shape_inference.cc * make lax_reference.population_count handle signed | 29 July 2020, 02:46:00 UTC |
e28db33 | Jamie Townsend | 29 July 2020, 01:39:32 UTC | Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883 (#3888) * Add test for issue 3883 * Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883 | 29 July 2020, 01:39:32 UTC |
33faf6a | David Majnemer | 29 July 2020, 01:07:38 UTC | TPUs support half precision arithmetic (#3878) * TPUs support half precision arithmetic * update jax2tf tests to handle fp16 Co-authored-by: Matthew Johnson <mattjj@google.com> | 29 July 2020, 01:07:38 UTC |
7506a3e | Jamie Townsend | 28 July 2020, 14:27:07 UTC | Fix flaky generated_fun_test.py test (#3885) | 28 July 2020, 14:27:07 UTC |
d7733c3 | Jake Vanderplas | 28 July 2020, 13:55:10 UTC | Cleanup: canonicalize several dtypes to prevent noisy warnings (#3874) | 28 July 2020, 13:55:10 UTC |
dd7ab39 | Stephan Hoyer | 28 July 2020, 05:25:16 UTC | Fix formatting in the custom derivatives notebook (#3876) Sphinx is apparently quite picky about consistent use of headers: you can't skip a header level. We were getting warnings like "WARNING: Title level inconsistent" in the docs build, and sub-headers weren't showing up on this page after the first section. | 28 July 2020, 05:25:16 UTC |
616d63b | Matthew Johnson | 28 July 2020, 04:51:12 UTC | fix vmap error, fixes #3877 (#3879) | 28 July 2020, 04:51:12 UTC |
3aa37d3 | Vaibhav Srivastav | 27 July 2020, 21:27:36 UTC | Replicating sort_complex functionality from np.sort_complex to jax.numpy (#3870) | 27 July 2020, 21:27:36 UTC |
c9d8acd | Matthew Johnson | 27 July 2020, 05:38:14 UTC | put core trace state in a threading.local class (#3869) this is a refinement of the fix in #3845, so that we no longer need TraceState.set_state (and so that #3370 is easier to adapt) | 27 July 2020, 05:38:14 UTC |
ee7f035 | Jake Vanderplas | 26 July 2020, 15:58:37 UTC | jax.random: use correct x32/x64 default dtypes. (#3841) This is a no-op in the current package, but will make things cleaner during the x64 deprecation. | 26 July 2020, 15:58:37 UTC |
5e91965 | Matthew Johnson | 24 July 2020, 22:54:21 UTC | delete standard_parallel_primitive helper (#3858) Meant to include this in #3853; it's even in the PR message! | 24 July 2020, 22:54:21 UTC |
5c674cf | Jake Vanderplas | 24 July 2020, 22:03:48 UTC | Properly set X64 flag in github actions (#3854) This allows the github actions CI to actually exercise tests with jax_enable_x64. | 24 July 2020, 22:03:48 UTC |
57c2fcb | Matthew Johnson | 24 July 2020, 19:52:52 UTC | tweak parallel collective translation rule api (#3853) also remove standard_parallel_primitive helper function, which wasn't very helpful | 24 July 2020, 19:52:52 UTC |
b7bcfa6 | Stephan Hoyer | 24 July 2020, 18:05:40 UTC | Create a separate internal helper function for XLA compilation (#3852) XLA backends are written in C++, so method calls don't show up in Python profiling results from cProfile. Adding an explicit function call fixes that. This is helpful for interpreting profiling results, e.g., on the example from https://github.com/google/jax/issues/3847. Before: 70814996 function calls (69915267 primitive calls) in 112.804 seconds, ordered by internal time [columns: ncalls tottime percall cumtime percall filename:lineno(function)]: 1193 24.936 0.021 30.336 0.025 xla.py:227(xla_primitive_callable); 10524/1 16.342 0.002 112.991 112.991 xla.py:595(_xla_callable); 2014622/1843062 8.745 0.000 16.618 0.000 util.py:29(safe_map); 18145 3.662 0.000 4.218 0.000 source_info_util.py:27(user_frame); 196061/183909 1.604 0.000 24.647 0.000 partial_eval.py:150(default_process_primitive); 423499 1.569 0.000 1.569 0.000 {method 'reduce' of 'numpy.ufunc' objects}. After: 71147652 function calls (70235594 primitive calls) in 101.718 seconds, ordered by internal time [same columns]: 1294 38.894 0.030 38.894 0.030 xla.py:325(_backend_compile); 2017790/1844559 6.965 0.000 14.139 0.000 util.py:29(safe_map); 18146 3.317 0.000 3.839 0.000 source_info_util.py:27(user_frame); 196226/184073 1.467 0.000 21.889 0.000 partial_eval.py:150(default_process_primitive); 423771 1.419 0.000 1.419 0.000 {method 'reduce' of 'numpy.ufunc' objects}. We now clearly see that both `xla_primitive_callable` and `_xla_callable` are slow for the same reason: ~40 seconds is spent inside XLA compilation. | 24 July 2020, 18:05:40 UTC |
5ae6043 | Jake Vanderplas | 24 July 2020, 16:36:16 UTC | Cleanup test names in random_test.py (#3842) | 24 July 2020, 16:36:16 UTC |
53a4538 | Peter Hawkins | 24 July 2020, 15:52:32 UTC | Fix source_info crash in Jaxpr printing (#3849) | 24 July 2020, 15:52:32 UTC |
55cd827 | George Necula | 24 July 2020, 07:14:34 UTC | [jax2tf] Selectively disable some primitive tests on TPU. (#3835) Currently, all primitive tests are disabled on TPU, but only some cases fail. This PR simply disables the selected failing tests on TPU, so that we can enable the rest by default. Future PRs will address each failure individually. | 24 July 2020, 07:14:34 UTC |
8673262 | George Necula | 24 July 2020, 07:14:08 UTC | [jax2tf] Expand the support for jax.remat. (#3828) In the simplest forms, remat is already handled by the `process_call`. But when the `remat` has outputs computed from values captured from an outer environment, we need to also implement `post_process_call`. | 24 July 2020, 07:14:08 UTC |
e2424e3 | Matthew Johnson | 24 July 2020, 03:59:12 UTC | attempt to fix CI failure (from #3845 test?) (#3846) | 24 July 2020, 03:59:12 UTC |
cc9528d | Matthew Johnson | 24 July 2020, 02:49:04 UTC | fix thread locality bug in custom_derivatives (#3845) Fixes #3843 | 24 July 2020, 02:49:04 UTC |
67ad5eb | Matthew Johnson | 24 July 2020, 02:38:56 UTC | add result_shape option to xla_computation (#3844) | 24 July 2020, 02:38:56 UTC |
f6221a6 | Peter Hawkins | 23 July 2020, 20:17:55 UTC | Enable int{8,16} and uint{8,16} tests in lax_test and lax_numpy_test. (#3833) | 23 July 2020, 20:17:55 UTC |
2796032 | Jake Vanderplas | 23 July 2020, 20:08:06 UTC | Tweak Dockerfile to prevent build failure and add TODO (#3838) | 23 July 2020, 20:08:06 UTC |
9281092 | Jake Vanderplas | 23 July 2020, 17:28:44 UTC | Improve test coverage for jax.numpy sorting algorithms (#3836) | 23 July 2020, 17:28:44 UTC |
90b3532 | Peter Hawkins | 23 July 2020, 16:10:39 UTC | Update XLA. (#3834) | 23 July 2020, 16:10:39 UTC |
7da1cba | Peter Hawkins | 23 July 2020, 15:35:29 UTC | Remove fallback to 2-pass algorithm for argmin/argmax on TPU. (#3831) (The compiler problem that prompted the workaround seems to be fixed.) | 23 July 2020, 15:35:29 UTC |
1bc5896 | Jake Vanderplas | 23 July 2020, 02:39:51 UTC | update README for new jaxlib version (#3825) | 23 July 2020, 02:39:51 UTC |
17c21a7 | Jake Vanderplas | 22 July 2020, 22:21:30 UTC | update jaxlib version and changelog for pypi (#3824) | 22 July 2020, 22:21:30 UTC |
edff693 | Jake Vanderplas | 22 July 2020, 21:20:34 UTC | update jax version and changelog for pypi (#3823) | 22 July 2020, 21:20:34 UTC |
fddb28d | Jake Vanderplas | 22 July 2020, 19:50:54 UTC | Cleanup: fix type issues in lax_numpy.py (#3816) These changes are basically a no-op with the current default types, but fix issues if/when those types are changed to 32-bit in the future. | 22 July 2020, 19:50:54 UTC |
04a6238 | Jake Vanderplas | 22 July 2020, 19:48:49 UTC | implement jax.numpy.lexsort (#3812) | 22 July 2020, 19:48:49 UTC |
0a3a5bb | Du Phan | 22 July 2020, 16:17:06 UTC | address nan issue (#3777) | 22 July 2020, 16:17:06 UTC |
f574b11 | James Bradbury | 22 July 2020, 06:16:27 UTC | Support preconditions on scatter indices (#3147) * wire through precondition flags to XLA scatter * use scatter precondition flags in tests * fix DUS batching rule * make new arguments kw-only * onp -> np * fix jax2tf for new args * fix more test failures | 22 July 2020, 06:16:27 UTC |
74d363e | bion howard | 21 July 2020, 19:41:08 UTC | fix extremely minor typo (#3815) "ijnputs" -> "inputs" | 21 July 2020, 19:41:08 UTC |
3a3c8ea | Matthew Johnson | 21 July 2020, 13:48:55 UTC | tweak jnp.repeat not to use jnp on shapes (#3810) | 21 July 2020, 13:48:55 UTC |
e2bfc90 | Benjamin Chetioui | 21 July 2020, 13:38:12 UTC | [jax2tf] First draft of top_k conversion to tf. (#3785) * [jax2tf] First draft of top_k conversion to tf. We are able to use tf.math.top_k in most cases, but: - bfloat16 and complex numbers are unsupported - tf.math.top_k works properly with args of type int32/64, uint32/64 and float32/64. Other flavors of uints/ints and bool cannot be used with the operation by default, so they are promoted to their equivalent 32-bit representation, and the result is cast back to the original dtype. A current limitation is the handling of edge cases: TF and JAX order np.inf and np.nan differently, resulting in inconsistencies when these values are present. Note: this has been sanity-checked; the tests fail if we set sorted=False. | 21 July 2020, 13:38:12 UTC |
9773432 | Benjamin Chetioui | 21 July 2020, 13:36:35 UTC | [jax2tf] First draft of testing the QR conversion. (#3775) * [jax2tf] First draft of testing the QR conversion. QR decomposition is off by over 1e-6 in some instances, requiring custom atol and rtol values in testing code. There is an odd problem in which experimental compilation fails for complex types, although they are in principle supported. | 21 July 2020, 13:36:35 UTC |
71f80a5 | Jake Vanderplas | 21 July 2020, 00:18:08 UTC | Fix type mismatch in jet rule for abs (#3807) | 21 July 2020, 00:18:08 UTC |
a6e2d20 | Peter Hawkins | 20 July 2020, 21:27:24 UTC | Add support for base dilation and window dilation to reduce window op… (#3803) | 20 July 2020, 21:27:24 UTC |
ce14409 | Claudio Fantacci | 20 July 2020, 20:15:40 UTC | Fix `jax.image._resize` function (#3805) This PR fixes a bug in jax.image._resize where the local `method_id` variable may be used without being defined first. The bug is easily reproduced by passing a `ResizeMethod` instead of a `str` as the `method` parameter of `jax.image.resize`. In that case `method_id` is never defined, and the instruction `if method_id == ResizeMethod.NEAREST` raises an error. Currently, this can be bypassed by passing a `str` as `method`. The fix is simply to rename `method_id` to `method`, the name of the input parameter. | 20 July 2020, 20:15:40 UTC |
f6b3184 | Stephan Hoyer | 20 July 2020, 18:10:10 UTC | add clarification about jit inside indexing error message (#3804) | 20 July 2020, 18:10:10 UTC |
fe99a06 | Stephan Hoyer | 20 July 2020, 13:08:54 UTC | Error message and docstring updates RE: dynamic_slice (#3795) This should clarify the underlying issues from #1007 and #3794. It might be worth mentioning masking, but that's a little big for fitting into an error message. Maybe once the masking transformation is non-experimental or if we had a dedicated doc page. | 20 July 2020, 13:08:54 UTC |
dd3cb82 | Lena Martens | 20 July 2020, 12:59:13 UTC | Enable buffer donation for GPU. (#3800) | 20 July 2020, 12:59:13 UTC |
d433934 | John Aslanides | 18 July 2020, 20:36:23 UTC | Add a TypeVar to while_loop definition. (#3792) | 18 July 2020, 20:36:23 UTC |
fa2a027 | Roy Frostig | 17 July 2020, 22:01:12 UTC | revert #3674 | 17 July 2020, 22:44:51 UTC |
2df486e | Skye Wanderman-Milne | 17 July 2020, 22:11:26 UTC | Note in pmap docs that pmap compiles like jit. (#3787) | 17 July 2020, 22:11:26 UTC |
9a01d78 | Peter Hawkins | 17 July 2020, 21:38:05 UTC | Merge pull request #3782 from hawkinsp/dot Don't move batch dimensions to start in jnp.einsum. | 17 July 2020, 21:38:05 UTC |
3eda6e9 | Peter Hawkins | 17 July 2020, 21:36:29 UTC | Merge pull request #3783 from hawkinsp/matmul Avoid broadcasts in implementation of jnp.matmul | 17 July 2020, 21:36:29 UTC |
57c1822 | Peter Hawkins | 17 July 2020, 21:14:03 UTC | Merge pull request #3789 from hawkinsp/rw Improve reduce-window testing. | 17 July 2020, 21:14:03 UTC |
7c4d41f | Peter Hawkins | 17 July 2020, 21:13:42 UTC | Merge pull request #3788 from hawkinsp/convtol Relax test tolerance for conv_general_dilated gradients. | 17 July 2020, 21:13:42 UTC |
7e3433e | Peter Hawkins | 17 July 2020, 20:05:51 UTC | Improve reduce-window testing. | 17 July 2020, 20:05:51 UTC |
e04fdeb | Peter Hawkins | 17 July 2020, 19:07:17 UTC | Relax test tolerance for conv_general_dilated gradients. The new support for complex types leads to slightly higher errors. | 17 July 2020, 19:07:17 UTC |
7b3eff8 | Peter Hawkins | 17 July 2020, 18:55:36 UTC | Merge pull request #3786 from hawkinsp/einsumargs Use keyword arguments in einsum. | 17 July 2020, 18:55:36 UTC |
f7260a4 | Roy Frostig | 17 July 2020, 18:32:36 UTC | stack traces without jax-internal frames (#3674) | 17 July 2020, 18:32:36 UTC |
252a027 | Peter Hawkins | 17 July 2020, 18:15:52 UTC | Update numpy signatures test. | 17 July 2020, 18:15:52 UTC |
8fc1332 | Peter Hawkins | 17 July 2020, 17:58:02 UTC | Avoid broadcasting for batch dimensions in jnp.matmul. Instead, squeeze size 1 dimensions out of the matmul input, and transpose any non-batch non-contracting dimensions into the correct location. | 17 July 2020, 17:58:02 UTC |
8e986f9 | Peter Hawkins | 17 July 2020, 16:54:36 UTC | Use keyword arguments in einsum. Python 3 cleanup only, no functional changes intended. | 17 July 2020, 16:54:36 UTC |
6a37e26 | Peter Hawkins | 17 July 2020, 12:57:09 UTC | Don't move batch dimensions to start in jnp.einsum. We no longer require batch dimensions to be first, and this simplifies the code. | 17 July 2020, 12:57:09 UTC |
943c421 | Peter Hawkins | 17 July 2020, 12:55:23 UTC | Relax dimension ordering rules for dot_general. (#3778) JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions. Since batch dimensions are now allowed in arbitrary positions, it's not hard to generalize the dot_general batching rule to avoid performing any transposes (#2972). In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`. | 17 July 2020, 12:55:23 UTC |
e4d5ead | Peter Hawkins | 17 July 2020, 12:44:47 UTC | Use iteration over equations to test for "transpose" and "broadcast". | 17 July 2020, 12:44:47 UTC |
165f31e | Peter Hawkins | 17 July 2020, 12:38:33 UTC | Also test for transpose in dot vmap test. | 17 July 2020, 12:38:33 UTC |
dd4db64 | Roy Frostig | 17 July 2020, 00:11:52 UTC | docstring for api_boundary | 17 July 2020, 00:19:16 UTC |
733f6c4 | Roy Frostig | 16 July 2020, 20:32:01 UTC | disable false pytype error | 17 July 2020, 00:12:09 UTC |
6416ca0 | Roy Frostig | 02 July 2020, 19:18:01 UTC | append filtered stack traces to error messages raised under transformations | 17 July 2020, 00:12:09 UTC |
e2e73a8 | Peter Hawkins | 16 July 2020, 20:23:27 UTC | Relax dimension ordering rules for dot_general. JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions. Since batch dimensions are now allowed in arbitrary positions, it's not hard to generalize the dot_general batching rule to avoid performing any transposes (#2972). In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`. | 16 July 2020, 23:36:22 UTC |
3fb8874 | Jake Vanderplas | 16 July 2020, 22:53:16 UTC | [x64 deprecation] Create _np_array utility routine (#3779) | 16 July 2020, 22:53:16 UTC |
44fbce5 | Skye Wanderman-Milne | 16 July 2020, 22:22:40 UTC | Revert "Add `in_parts` and `out_parts` optional arguments to `jax.xla_computation`. (#3771)" (#3780) This reverts commit dbc3f83f6d14d491a06137f698aca92f7f3c572d. This is breaking some google-internal users of xla_computation. Reverting while I investigate. | 16 July 2020, 22:22:40 UTC |
bfe8e4f | Skye Wanderman-Milne | 16 July 2020, 16:19:41 UTC | Enable `all_to_all` in multi-host settings. (#3772) I tested this via `pswapaxes` and it seems to work. There may still be issues with all_to_all (e.g. https://github.com/google/jax/issues/1332) but it seems worth enabling the `pswapaxes` use case. | 16 July 2020, 16:19:41 UTC |
7b18a4e | Benjamin Chetioui | 16 July 2020, 13:44:20 UTC | [jax2tf] Fix interface of ConvertAndCompare function. (#3776) | 16 July 2020, 13:44:20 UTC |
15125b8 | George Necula | 16 July 2020, 13:06:34 UTC | [jax2tf] Fix bug with double XLA compile (#3765) * [jax2tf] Fix bug with double XLA compile Some converted ops require XLA compilation to work around bugs in TF where the behavior without XLA is incorrect. If those ops are then part of an outer tf.function(experimental_compile=True), we get a TF error. This change primarily detects that we are in a compilation context and, in that case, does not use the XLA compiler for those ops. This, however, changes the error behavior for dynamic_slice. Improved the testing to skip fewer tests and instead expect and check for errors. | 16 July 2020, 13:06:34 UTC |
dbc3f83 | Skye Wanderman-Milne | 15 July 2020, 21:56:58 UTC | Add `in_parts` and `out_parts` optional arguments to `jax.xla_computation`. (#3771) This allows partitioned computations in `xla_computation`, like those produced by `sharded_jit`. | 15 July 2020, 21:56:58 UTC |
05904fa | Jake Vanderplas | 15 July 2020, 20:17:38 UTC | Change onp/np to np/jnp in docs & notebooks (#3760) | 15 July 2020, 20:17:38 UTC |
150d028 | Claudio Fantacci | 15 July 2020, 18:19:40 UTC | Update scipy.ndimage.map_coordinates docstring (#3762) | 15 July 2020, 18:19:40 UTC |
34669d7 | Jonathan Godwin | 15 July 2020, 18:03:58 UTC | source sync (#3763) | 15 July 2020, 18:03:58 UTC |
8a62a9b | Roy Frostig | 15 July 2020, 18:00:50 UTC | block-unrolled scan primitive implementation (#3738) * block-unrolled scan implementation, via optional `_unroll` scan parameter * index statically in the inlined path of lax.scan * make `unroll` a required scan parameter, and test that it unrolls | 15 July 2020, 18:00:50 UTC |
23c279f | George Necula | 15 July 2020, 15:19:22 UTC | [jax2tf] Relax tolerance for JaxPrimitiveTest.test_unary_elementwise_lgamma_float32 (#3767) | 15 July 2020, 15:19:22 UTC |
d55c47d | SIben | 15 July 2020, 12:12:42 UTC | [jax2tf] Fix error in sorting with TF graphs. (#3764) | 15 July 2020, 12:12:42 UTC |
c380356 | George Necula | 15 July 2020, 06:49:51 UTC | [jax2tf] Refactor tests to increase coverage (#3700) * [jax2tf] Refactor tests to increase coverage. This change has several goals: (a) increase the test coverage for the elementwise primitives, (b) explicitly expose the situations where JAX and TF produce different results, e.g., inf vs. NaN, and (c) run all comparisons with and without tf.function, with and without experimental_compile=True. Previously the test code was just masking off non-finite values. This change uncovered quite a few unimplemented cases, e.g., with float16, bfloat16, and conversions that cannot be compiled. These are left as TODOs for now. * Disable the sort tests | 15 July 2020, 06:49:51 UTC |
68c8dc7 | SIben | 15 July 2020, 04:56:19 UTC | Added installation note for jax2tf with GPU support. (#3750) | 15 July 2020, 04:56:19 UTC |
f02d5b4 | Joan Puigcerver | 14 July 2020, 22:45:49 UTC | Support differentiation through jax.lax.all_to_all (#3733) * Support differentiation through jax.lax.all_to_all Credit to @levskaya for the solution. * Test gradient of all_to_all We are testing all_to_all through pswapaxes, since general all_to_all is problematic according to https://github.com/google/jax/issues/1332. * Removed trailing spaces | 14 July 2020, 22:45:49 UTC |
a7c2cde | Jake Vanderplas | 14 July 2020, 20:05:31 UTC | Cleanup: convert uses of `import numpy as onp` in library code (#3754) | 14 July 2020, 20:05:31 UTC |
512ed18 | Jake Vanderplas | 14 July 2020, 20:03:24 UTC | Cleanup: convert uses of 'import numpy as onp' in tests (#3756) | 14 July 2020, 20:03:24 UTC |
58aba9b | Peter Hawkins | 14 July 2020, 18:10:13 UTC | Fix low probability (4/1000) flake of batching_test on GPU. (#3752) | 14 July 2020, 18:10:13 UTC |
f6f9755 | Jonathan Godwin | 14 July 2020, 17:37:09 UTC | A jit-able version of np.repeat. (#3670) A new keyword argument has been added to np.repeat, total_repeat_length, that can optionally be supplied to make np.repeat jit-able. | 14 July 2020, 17:37:09 UTC |
0605321 | Jake Vanderplas | 14 July 2020, 17:24:42 UTC | Fix compilation bug in histogram_bin_edges (#3745) | 14 July 2020, 17:24:42 UTC |
e1f8f24 | Jake Vanderplas | 14 July 2020, 17:21:38 UTC | lax_numpy: rename arguments to match numpy (#3747) | 14 July 2020, 17:21:38 UTC |
7f9c2f1 | Peter Hawkins | 14 July 2020, 16:02:26 UTC | Make jnp.take work for empty slices of empty arrays. (#3751) | 14 July 2020, 16:02:26 UTC |
2b7a39f | Chase Roberts | 14 July 2020, 13:05:45 UTC | Add pshuffle to docs (#3742) | 14 July 2020, 13:05:45 UTC |
f78ccf1 | James Bradbury | 14 July 2020, 01:16:11 UTC | Fix tuple() in reduce_window padding (#3748) * Fix tuple() in reduce_window padding * Update lax.py | 14 July 2020, 01:16:11 UTC |
a017c10 | Peter Hawkins | 14 July 2020, 00:27:12 UTC | Implement nearest neighbor image resizes. (#3743) | 14 July 2020, 00:27:12 UTC |