https://github.com/google/jax

Revision Author Date Message Commit Date
81631f9 update version and changelog for pypi 30 July 2020, 05:27:34 UTC
c8771e1 add omnistaging flag placeholder (#3904) 30 July 2020, 04:44:03 UTC
c28a711 Cleanup: pass function name rather than function object (#3897) 29 July 2020, 22:32:44 UTC
b0ef483 Fixes to test_scipy_optimize.py for Google internal tests (#3902) 29 July 2020, 22:31:02 UTC
242b324 Add missing license headers (#3899) Oops! 29 July 2020, 21:22:21 UTC
02009e0 BFGS algorithm (#3101)
* BFGS algorithm Addressing https://github.com/google/jax/issues/1400
* addresses @shoyer comments of PR
* skip dtype checks
* backslash in docstring
* increase closeness tol
* increase closeness atol to 1.6e-6
* addresses jakevdp comments
* same line search as scipy
* same results format
* same (and more) testing as in scipy for line search and bfgs
* 2 spacing
* documenting
* analytic hessian non default but still available
* NamedTuple classes
* small fix in setup_method
* small doc string addition
* increase atol to 2e-5 for comparison
* removed experimental analytic_hessian
* using jnp.where for all binary replace operations
* removed _nojit as this is what disable_jit does
* fix indentation mangling
* remove remaining _nojit
* fixing more indentation mangling
* segregate third_party test
* use parametrise
* use parametrise
* minor nitpicking
* fix some errors
* use _CompileAndCheck
* replace f_0 and g_0 for (ugly) scipy variable names
* remove unused function
* fix spacing
* add args argument to minimize
* adhere fmin_bfgs to scipy api
* remove unused function
* ignore F401
* look into unittest
* fix unittest error
* delete unused function
* more adherence to scipy's api
* add scipy's old_old_fval arg though unused
* increase line_search default maxiter to 20 (10 not enough in some cases)
* remove unused imports
* add ord=norm to the initial convergence check
* remove helper function
* merge jax/master
* Resolve a remnant conflict from merging master to solve ReadTheDocs issue.
* Add an informative termination message and status number.
* Revert changes to unrelated files
* cleanup bfgs_minimize
* cleanup minimize.py
* Move minimize_bfgs.py to _bfgs.py
* Move more modules around
* improve docs
* high precision einsum
* Formatting in line search
* fixup
* Type checking
* fix mypy failures
* minor fixup
Co-authored-by: Stephan Hoyer <shoyer@google.com> 29 July 2020, 21:14:40 UTC
659dd39 Add MLPerf results link. (#3896) 29 July 2020, 20:22:12 UTC
e0a8d44 Add jnp.modf() & improve test coverage for related functions (#3894) 29 July 2020, 19:53:28 UTC
190b6af Improve searchsorted implementation (#3873) 29 July 2020, 19:52:41 UTC
c38bc36 jnp.linspace & friends: more carefully handle dtypes (#3859) 29 July 2020, 03:54:30 UTC
3098074 refine population_count type check (#3887) * refine population_count type check fixes #3886 * allow signed/unsigned ints for population_count https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/compiler/xla/service/shape_inference.cc;l=314?q=xla%20f:shape_inference.cc * make lax_reference.population_count handle signed 29 July 2020, 02:46:00 UTC
e28db33 Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883 (#3888) * Add test for issue 3883 * Fix dynamic_slice, dynamic_update_slice scalar batching, fixes #3883 29 July 2020, 01:39:32 UTC
33faf6a TPUs support half precision arithmetic (#3878) * TPUs support half precision arithmetic * update jax2tf tests to handle fp16 Co-authored-by: Matthew Johnson <mattjj@google.com> 29 July 2020, 01:07:38 UTC
7506a3e Fix flaky generated_fun_test.py test (#3885) 28 July 2020, 14:27:07 UTC
d7733c3 Cleanup: canonicalize several dtypes to prevent noisy warnings (#3874) 28 July 2020, 13:55:10 UTC
dd7ab39 Fix formatting in the custom derivatives notebook (#3876) Sphinx is apparently quite picky about consistent use of headers: you can't skip a header level. We were getting warnings like "WARNING: Title level inconsistent" in the docs build, and sub-headers weren't showing up on this page after the first section. 28 July 2020, 05:25:16 UTC
616d63b fix vmap error, fixes #3877 (#3879) 28 July 2020, 04:51:12 UTC
3aa37d3 Replicating sort_complex functionality from np.sort_complex to jax.numpy (#3870) 27 July 2020, 21:27:36 UTC
c9d8acd put core trace state in a threading.local class (#3869) this is a refinement of the fix in #3845, so that we no longer need TraceState.set_state (and so that #3370 is easier to adapt) 27 July 2020, 05:38:14 UTC
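As a rough illustration of the pattern this commit describes, the sketch below keeps per-thread state in a `threading.local` subclass. The names `ThreadLocalState` and `trace_stack` are hypothetical and are not taken from JAX's `core` module; only standard-library behavior is shown.

```python
import threading


class ThreadLocalState(threading.local):
    """Hypothetical stand-in for per-thread trace state (not JAX's actual class)."""

    def __init__(self):
        # threading.local runs __init__ once per thread on first access,
        # so every thread starts from its own fresh default.
        self.trace_stack = []


state = ThreadLocalState()


def worker(results):
    # This thread sees its own empty stack, regardless of the main thread.
    results.append(list(state.trace_stack))
    state.trace_stack.append("worker-trace")


state.trace_stack.append("main-trace")
seen_by_worker = []
thread = threading.Thread(target=worker, args=(seen_by_worker,))
thread.start()
thread.join()
```

Because the defaults live in `__init__`, no explicit "set state" call is needed when a new thread starts, which is presumably the kind of simplification the commit message refers to.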
ee7f035 jax.random: use correct x32/x64 default dtypes. (#3841) This is a no-op in the current package, but will make things cleaner during the x64 deprecation. 26 July 2020, 15:58:37 UTC
5e91965 delete standard_parallel_primitive helper (#3858) meant to include this in #3853, it's even in the PR message! 24 July 2020, 22:54:21 UTC
5c674cf Properly set X64 flag in github actions (#3854) This allows the github actions CI to actually exercise tests with jax_enable_x64. 24 July 2020, 22:03:48 UTC
57c2fcb tweak parallel collective translation rule api (#3853) also remove standard_parallel_primitive helper function, which wasn't very helpful 24 July 2020, 19:52:52 UTC
b7bcfa6 Create a separate internal helper function for XLA compilation (#3852) XLA backends are written in C++, so method calls don't show up in Python profiling results from cProfile. Adding an explicit function call fixes that. This is helpful for interpreting profiling results, e.g., on the example from https://github.com/google/jax/issues/3847.

Before:
70814996 function calls (69915267 primitive calls) in 112.804 seconds
Ordered by: internal time
ncalls           tottime  percall  cumtime  percall  filename:lineno(function)
1193              24.936    0.021   30.336    0.025  xla.py:227(xla_primitive_callable)
10524/1           16.342    0.002  112.991  112.991  xla.py:595(_xla_callable)
2014622/1843062    8.745    0.000   16.618    0.000  util.py:29(safe_map)
18145              3.662    0.000    4.218    0.000  source_info_util.py:27(user_frame)
196061/183909      1.604    0.000   24.647    0.000  partial_eval.py:150(default_process_primitive)
423499             1.569    0.000    1.569    0.000  {method 'reduce' of 'numpy.ufunc' objects}

After:
71147652 function calls (70235594 primitive calls) in 101.718 seconds
Ordered by: internal time
ncalls           tottime  percall  cumtime  percall  filename:lineno(function)
1294              38.894    0.030   38.894    0.030  xla.py:325(_backend_compile)
2017790/1844559    6.965    0.000   14.139    0.000  util.py:29(safe_map)
18146              3.317    0.000    3.839    0.000  source_info_util.py:27(user_frame)
196226/184073      1.467    0.000   21.889    0.000  partial_eval.py:150(default_process_primitive)
423771             1.419    0.000    1.419    0.000  {method 'reduce' of 'numpy.ufunc' objects}

We now clearly see that both `xla_primitive_callable` and `_xla_callable` are slow for the same reason and ~40 seconds is spent inside XLA compilation. 24 July 2020, 18:05:40 UTC
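The effect of wrapping a call in a named Python function can be sketched with the standard library alone. This is only an illustration: `_backend_compile` here is a hypothetical stand-in that sorts a list, not the JAX helper, and the sketch does not reproduce the C++ backend. It only shows that a dedicated wrapper gets its own row in the `cProfile` report.

```python
import cProfile
import io
import pstats


def _backend_compile(n):
    """Hypothetical wrapper standing in for a call into an extension module."""
    return sorted(range(n, 0, -1))


def build(n):
    # The expensive work is attributed to the wrapper, not to this caller.
    return _backend_compile(n)


profiler = cProfile.Profile()
profiler.enable()
build(100_000)
profiler.disable()

# Render the report into a string, ordered by internal time as in the commit.
buffer = io.StringIO()
pstats.Stats(profiler, stream=buffer).sort_stats("tottime").print_stats()
report = buffer.getvalue()
```

Printing `report` shows a `filename:lineno(_backend_compile)` row whose `tottime` is separate from that of `build`, which is the kind of attribution the "After" table above relies on.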
5ae6043 Cleanup test names in random_test.py (#3842) 24 July 2020, 16:36:16 UTC
53a4538 Fix source_info crash in Jaxpr printing (#3849) 24 July 2020, 15:52:32 UTC
55cd827 [jax2tf] Selectively disable some primitive tests on TPU. (#3835) Currently, all primitive tests are disabled on TPU, but only some cases fail. This PR simply disables the selected failing tests on TPU, so that we can enable the rest by default. Future PRs will address each failure individually. 24 July 2020, 07:14:34 UTC
8673262 [jax2tf] Expand the support for jax.remat. (#3828) In the simplest forms, remat is already handled by the `process_call`. But when the `remat` has outputs computed from values captured from an outer environment, we need to also implement `post_process_call`. 24 July 2020, 07:14:08 UTC
e2424e3 attempt to fix CI failure (from #3845 test?) (#3846) 24 July 2020, 03:59:12 UTC
cc9528d fix thread locality bug in custom_derivatives (#3845) * fix thread locality bug in custom_derivatives fixes #3843 24 July 2020, 02:49:04 UTC
67ad5eb add result_shape option to xla_computation (#3844) 24 July 2020, 02:38:56 UTC
f6221a6 Enable int{8,16} and uint{8,16} tests in lax_test and lax_numpy_test. (#3833) 23 July 2020, 20:17:55 UTC
2796032 Tweak Dockerfile to prevent build failure and add TODO (#3838) 23 July 2020, 20:08:06 UTC
9281092 Improve test coverage for jax.numpy sorting algorithms (#3836) 23 July 2020, 17:28:44 UTC
90b3532 Update XLA. (#3834) 23 July 2020, 16:10:39 UTC
7da1cba Remove fallback to 2-pass algorithm for argmin/argmax on TPU. (#3831) (The compiler problem that prompted the workaround seems to be fixed.) 23 July 2020, 15:35:29 UTC
1bc5896 update README for new jaxlib version (#3825) 23 July 2020, 02:39:51 UTC
17c21a7 update jaxlib version and changelog for pypi (#3824) 22 July 2020, 22:21:30 UTC
edff693 update jax version and changelog for pypi (#3823) 22 July 2020, 21:20:34 UTC
fddb28d Cleanup: fix type issues in lax_numpy.py (#3816) These changes are basically a no-op with the current default types, but fix issues if/when those types are changed to 32-bit in the future. 22 July 2020, 19:50:54 UTC
04a6238 implement jax.numpy.lexsort (#3812) 22 July 2020, 19:48:49 UTC
0a3a5bb address nan issue (#3777) 22 July 2020, 16:17:06 UTC
f574b11 Support preconditions on scatter indices (#3147) * wire through precondition flags to XLA scatter * use scatter precondition flags in tests * fix DUS batching rule * make new arguments kw-only * onp -> np * fix jax2tf for new args * fix more test failures 22 July 2020, 06:16:27 UTC
74d363e fix extremely minor typo (#3815) "ijnputs" -> "inputs" 21 July 2020, 19:41:08 UTC
3a3c8ea tweak jnp.repeat not to use jnp on shapes (#3810) 21 July 2020, 13:48:55 UTC
e2bfc90 [jax2tf] First draft of top_k conversion to tf. (#3785) * [jax2tf] First draft of top_k conversion to tf. We are able to use tf.math.top_k in most cases but: - bfloat16 and complex numbers are unsupported - tf.math.top_k works properly with args of type int32/64, uint32/64 and float32/64. Other flavors of uints/ints and bool cannot be used with the operation by default, and they are thus promoted to their equivalent 32 bit representation, and the result is cast back to the original dtype. A limitation right now is the handling of edge cases: TF and JAX order np.inf and np.nan differently, resulting in inconsistencies in the case where these values are present. Note: this has been sanity checked, the tests fail if we set sorted=False. 21 July 2020, 13:38:12 UTC
9773432 [jax2tf] First draft of testing the QR conversion. (#3775) * [jax2tf] First draft of testing the QR conversion. QR decomposition is off by over 1e-6 in some instances, requiring custom atol and rtol values in testing code. There is an odd problem in which experimental compilation fails for complex types, although they are in principle supported. 21 July 2020, 13:36:35 UTC
71f80a5 Fix type mismatch in jet rule for abs (#3807) 21 July 2020, 00:18:08 UTC
a6e2d20 Add support for base dilation and window dilation to reduce window op… (#3803) 20 July 2020, 21:27:24 UTC
ce14409 Fix `jax.image._resize` function (#3805) This PR fixes a bug in jax.image._resize where the local `method_id` variable may be used without being defined first. The bug is easy to reproduce by passing a `ResizeMethod` instead of a `str` as the `method` parameter of `jax.image.resize`. In that case `method_id` is never defined, and the statement `if method_id == ResizeMethod.NEAREST` raises an error. For now, this can be worked around by passing `method` as a `str`. The fix is simply to rename `method_id` to `method`, the name of the input parameter. 20 July 2020, 20:15:40 UTC
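The bug pattern described in this commit can be sketched in plain Python. The enum and both functions below are hypothetical stand-ins written for illustration; they are not the code from `jax.image`, and the string-to-enum lookup is an assumption.

```python
import enum


class ResizeMethod(enum.Enum):
    """Hypothetical stand-in, not the real jax.image enum."""
    NEAREST = 0
    LINEAR = 1


def describe_buggy(method):
    if isinstance(method, str):
        method_id = ResizeMethod[method.upper()]
    # When `method` is already a ResizeMethod, `method_id` was never assigned,
    # so the comparison below raises UnboundLocalError.
    if method_id == ResizeMethod.NEAREST:
        return "nearest"
    return "other"


def describe_fixed(method):
    if isinstance(method, str):
        # Rebind the parameter itself so both input types share one name.
        method = ResizeMethod[method.upper()]
    if method == ResizeMethod.NEAREST:
        return "nearest"
    return "other"
```

Calling `describe_buggy("nearest")` works, while `describe_buggy(ResizeMethod.NEAREST)` fails; `describe_fixed` accepts both forms.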
f6b3184 add clarification about jit inside indexing error message (#3804) 20 July 2020, 18:10:10 UTC
fe99a06 Error message and docstring updates RE: dynamic_slice (#3795) This should clarify the underlying issues from #1007 and #3794. It might be worth mentioning masking, but that's a little big for fitting into an error message. Maybe once the masking transformation is non-experimental or if we had a dedicated doc page. 20 July 2020, 13:08:54 UTC
dd3cb82 Enable buffer donation for GPU. (#3800) 20 July 2020, 12:59:13 UTC
d433934 Add a TypeVar to while_loop definition. (#3792) 18 July 2020, 20:36:23 UTC
fa2a027 revert #3674 17 July 2020, 22:44:51 UTC
2df486e Note in pmap docs that pmap compiles like jit. (#3787) 17 July 2020, 22:11:26 UTC
9a01d78 Merge pull request #3782 from hawkinsp/dot Don't move batch dimensions to start in jnp.einsum. 17 July 2020, 21:38:05 UTC
3eda6e9 Merge pull request #3783 from hawkinsp/matmul Avoid broadcasts in implementation of jnp.matmul 17 July 2020, 21:36:29 UTC
57c1822 Merge pull request #3789 from hawkinsp/rw Improve reduce-window testing. 17 July 2020, 21:14:03 UTC
7c4d41f Merge pull request #3788 from hawkinsp/convtol Relax test tolerance for conv_general_dilated gradients. 17 July 2020, 21:13:42 UTC
7e3433e Improve reduce-window testing. 17 July 2020, 20:05:51 UTC
e04fdeb Relax test tolerance for conv_general_dilated gradients. The new support for complex types leads to slightly higher errors. 17 July 2020, 19:07:17 UTC
7b3eff8 Merge pull request #3786 from hawkinsp/einsumargs Use keyword arguments in einsum. 17 July 2020, 18:55:36 UTC
f7260a4 stack traces without jax-internal frames (#3674) 17 July 2020, 18:32:36 UTC
252a027 Update numpy signatures test. 17 July 2020, 18:15:52 UTC
8fc1332 Avoid broadcasting for batch dimensions in jnp.matmul. Instead, squeeze size 1 dimensions out of the matmul input, and transpose any non-batch non-contracting dimensions into the correct location. 17 July 2020, 17:58:02 UTC
8e986f9 Use keyword arguments in einsum. Python 3 cleanup only, no functional changes intended. 17 July 2020, 16:54:36 UTC
6a37e26 Don't move batch dimensions to start in jnp.einsum. We no longer require batch dimensions to be first, and this simplifies the code. 17 July 2020, 12:57:09 UTC
943c421 Relax dimension ordering rules for dot_general. (#3778) JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions. Since batch dimensions are now allowed in arbitrary positions, it's not hard to generalize the dot_general batching rule to avoid performing any transposes (#2972). In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`. 17 July 2020, 12:55:23 UTC
e4d5ead Use iteration over equations to test for "transpose" and "broadcast". 17 July 2020, 12:44:47 UTC
165f31e Also test for transpose in dot vmap test. 17 July 2020, 12:38:33 UTC
dd4db64 docstring for api_boundary 17 July 2020, 00:19:16 UTC
733f6c4 disable false pytype error 17 July 2020, 00:12:09 UTC
6416ca0 append filtered stack traces to error messages raised under transformations 17 July 2020, 00:12:09 UTC
e2e73a8 Relax dimension ordering rules for dot_general. JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions. Since batch dimensions are now allowed in arbitrary positions, it's not hard to generalize the dot_general batching rule to avoid performing any transposes (#2972). In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`. 16 July 2020, 23:36:22 UTC
3fb8874 [x64 deprecation] Create _np_array utility routine (#3779) 16 July 2020, 22:53:16 UTC
44fbce5 Revert "Add `in_parts` and `out_parts` optional arguments `jax.xla_computation`. (#3771)" (#3780) This reverts commit dbc3f83f6d14d491a06137f698aca92f7f3c572d. This is breaking some google-internal users of xla_computation. Reverting while I investigate. 16 July 2020, 22:22:40 UTC
bfe8e4f Enable `all_to_all` in multi-host settings. (#3772) I tested this via `pswapaxes` and it seems to work. There may still be issues with all_to_all (e.g. https://github.com/google/jax/issues/1332) but it seems worth enabling the `pswapaxes` use case. 16 July 2020, 16:19:41 UTC
7b18a4e [jax2tf] Fix interface of ConvertAndCompare function. (#3776) 16 July 2020, 13:44:20 UTC
15125b8 [jax2tf] Fix bug with double XLA compile (#3765) * [jax2tf] Fix bug with double XLA compile Some converted ops require XLA compilation to work around bugs in TF where the behavior without XLA is incorrect. If those ops are then part of an outer tf.function(experimental_compile=True) then we get a TF error. This change primarily detects that we are in a compilation context and does not use the XLA compiler for ops. This, however, changes the error behavior for dynamic_slice. Improved the testing, to skip fewer tests but instead to expect and check for errors. 16 July 2020, 13:06:34 UTC
dbc3f83 Add `in_parts` and `out_parts` optional arguments `jax.xla_computation`. (#3771) This allows partitioned computations in `xla_computation`, like those produced by `sharded_jit`. 15 July 2020, 21:56:58 UTC
05904fa Change onp/np to np/jnp in docs & notebooks (#3760) 15 July 2020, 20:17:38 UTC
150d028 Update scipy.ndimage.map_coordinates docstring (#3762) 15 July 2020, 18:19:40 UTC
34669d7 source sync (#3763) 15 July 2020, 18:03:58 UTC
8a62a9b block-unrolled scan primitive implementation (#3738) * block-unrolled scan implementation, via optional `_unroll` scan parameter * index statically in the inlined path of lax.scan * make `unroll` a required scan parameter, and test that it unrolls 15 July 2020, 18:00:50 UTC
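The general idea of a block-unrolled scan can be sketched in pure Python. This is not JAX's implementation: the function name `scan`, the list-based inputs, and the way leftover elements are handled are all assumptions made for illustration.

```python
def scan(f, init, xs, unroll=1):
    """Reference scan: applies (carry, y) = f(carry, x) to each x in order."""
    carry, ys = init, []
    n = len(xs)
    num_blocks = n // unroll
    # Main loop: each outer iteration handles `unroll` consecutive elements.
    for block in range(num_blocks):
        start = block * unroll
        # In a traced setting this inner loop would be inlined into the body.
        for offset in range(unroll):
            carry, y = f(carry, xs[start + offset])
            ys.append(y)
    # Leftover elements that do not fill a whole block.
    for index in range(num_blocks * unroll, n):
        carry, y = f(carry, xs[index])
        ys.append(y)
    return carry, ys


def cumsum_step(carry, x):
    # Running sum: the new carry is also emitted as the per-step output.
    total = carry + x
    return total, total


rolled = scan(cumsum_step, 0, [1, 2, 3, 4, 5], unroll=1)
unrolled = scan(cumsum_step, 0, [1, 2, 3, 4, 5], unroll=2)
```

Both calls return the same carry and outputs; unrolling changes only how many steps run per outer loop iteration, trading a larger loop body for fewer iterations.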
23c279f [jax2tf] Relax tolerance for JaxPrimitiveTest.test_unary_elementwise_lgamma_float32 (#3767) 15 July 2020, 15:19:22 UTC
d55c47d [jax2tf] Fix error in sorting with TF graphs. (#3764) 15 July 2020, 12:12:42 UTC
c380356 [jax2tf] Refactor tests to increase coverage (#3700) * [jax2tf] Refactor tests to increase coverage. This change has several goals: (a) increase the test coverage for the elementwise primitives, (b) expose explicitly the situations where JAX and TF produce different results, e.g., inf vs NaN, and (c) run all comparisons with and without tf.function, with and without experimental_compile=True. Previously the test code was just masking off non-finite values. This change uncovered quite a few unimplemented cases, e.g., with float16, bfloat16, conversions that cannot be compiled. These are left as TODO for now. * Disable the sort tests 15 July 2020, 06:49:51 UTC
68c8dc7 Added installation note for jax2tf with GPU support. (#3750) 15 July 2020, 04:56:19 UTC
f02d5b4 Support differentiation through jax.lax.all_to_all (#3733) * Support differentiation through jax.lax.all_to_all Credit to @levskaya for the solution. * Test gradient of all_to_all We are testing all_to_all through pswapaxes, since general all_to_all is problematic according to https://github.com/google/jax/issues/1332. * Removed trailing spaces 14 July 2020, 22:45:49 UTC
a7c2cde Cleanup: convert uses of `import numpy as onp` in library code (#3754) 14 July 2020, 20:05:31 UTC
512ed18 Cleanup: convert uses of 'import numpy as onp' in tests (#3756) 14 July 2020, 20:03:24 UTC
58aba9b Fix low probability (4/1000) flake of batching_test on GPU. (#3752) 14 July 2020, 18:10:13 UTC
f6f9755 A jit-able version of np.repeat. (#3670) A new keyword argument has been added to np.repeat, total_repeat_length, that can optionally be supplied to make np.repeat jit-able. 14 July 2020, 17:37:09 UTC
0605321 Fix compilation bug in histogram_bin_edges (#3745) 14 July 2020, 17:24:42 UTC
e1f8f24 lax_numpy: rename arguments to match numpy (#3747) 14 July 2020, 17:21:38 UTC
7f9c2f1 Make jnp.take work for empty slices of empty arrays. (#3751) 14 July 2020, 16:02:26 UTC
2b7a39f Add pshuffle to docs (#3742) 14 July 2020, 13:05:45 UTC
f78ccf1 Fix tuple() in reduce_window padding (#3748) * Fix tuple() in reduce_window padding * Update lax.py 14 July 2020, 01:16:11 UTC
a017c10 Implement nearest neighbor image resizes. (#3743) 14 July 2020, 00:27:12 UTC