https://github.com/google/jax

edff693 update jax version and changelog for pypi (#3823) 22 July 2020, 21:20:34 UTC
fddb28d Cleanup: fix type issues in lax_numpy.py (#3816) These changes are basically a no-op with the current default types, but fix issues if/when those types are changed to 32-bit in the future. 22 July 2020, 19:50:54 UTC
04a6238 implement jax.numpy.lexsort (#3812) 22 July 2020, 19:48:49 UTC
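A minimal usage sketch of `jnp.lexsort`, assuming it follows `numpy.lexsort` semantics: the *last* key in the sequence is the primary sort key, and the result is the index permutation that sorts the keys. The key values are made up for illustration.

```python
import numpy as np
import jax.numpy as jnp

# Two sort keys; `primary` comes last in the tuple, so it dominates,
# and `secondary` only breaks ties within equal `primary` values.
secondary = jnp.array([3, 1, 2, 0])
primary = jnp.array([1, 0, 1, 0])

order = jnp.lexsort((secondary, primary))

# Cross-check against NumPy's reference implementation.
expected = np.lexsort((np.asarray(secondary), np.asarray(primary)))
```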
0a3a5bb address nan issue (#3777) 22 July 2020, 16:17:06 UTC
f574b11 Support preconditions on scatter indices (#3147) * wire through precondition flags to XLA scatter * use scatter precondition flags in tests * fix DUS batching rule * make new arguments kw-only * onp -> np * fix jax2tf for new args * fix more test failures 22 July 2020, 06:16:27 UTC
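In current JAX these precondition flags are exposed as keyword arguments (`indices_are_sorted`, `unique_indices`) on the `.at[...]` update methods; whether they had exactly these names at the time of this commit is an assumption. In the sketch below, the caller promises that the indices are sorted and free of duplicates, which lets XLA choose a cheaper scatter. If the promise is false, the result is unspecified.

```python
import jax
import jax.numpy as jnp

@jax.jit
def write_entries(x, idx, vals):
    # The caller guarantees that `idx` is sorted and has no duplicates;
    # these flags let XLA skip handling of unordered or colliding indices.
    return x.at[idx].set(vals, indices_are_sorted=True, unique_indices=True)

x = jnp.zeros(6)
idx = jnp.array([1, 3, 4])          # sorted and unique, as promised
vals = jnp.array([10.0, 20.0, 30.0])
out = write_entries(x, idx, vals)
```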
74d363e fix extremely minor typo (#3815) "ijnputs" -> "inputs" 21 July 2020, 19:41:08 UTC
3a3c8ea tweak jnp.repeat not to use jnp on shapes (#3810) tweak jnp.repeat not to use jnp on shapes 21 July 2020, 13:48:55 UTC
e2bfc90 [jax2tf] First draft of top_k conversion to tf. (#3785) * [jax2tf] First draft of top_k conversion to tf. We are able to use tf.math.top_k in most cases, but: - bfloat16 and complex numbers are unsupported; - tf.math.top_k works properly with args of type int32/64, uint32/64 and float32/64. Other flavors of uints/ints and bool cannot be used with the operation by default, so they are promoted to their equivalent 32-bit representation, and the result is cast back to the original dtype. A current limitation is the handling of edge cases: TF and JAX order np.inf and np.nan differently, which causes inconsistencies when these values are present. Note: this has been sanity-checked; the tests fail if we set sorted=False. 21 July 2020, 13:38:12 UTC
9773432 [jax2tf] First draft of testing the QR conversion. (#3775) * [jax2tf] First draft of testing the QR conversion. QR decomposition is off by over 1e-6 in some instances, requiring custom atol and rtol values in testing code. There is an odd problem in which experimental compilation fails for complex types, although they are in principle supported. 21 July 2020, 13:36:35 UTC
71f80a5 Fix type mismatch in jet rule for abs (#3807) 21 July 2020, 00:18:08 UTC
a6e2d20 Add support for base dilation and window dilation to reduce window op… (#3803) 20 July 2020, 21:27:24 UTC
ce14409 Fix `jax.image._resize` function (#3805) This PR fixes a bug in jax.image._resize where the local variable `method_id` may be used before it is defined. The bug is easy to reproduce by passing a `ResizeMethod` instead of a `str` as the `method` argument of `jax.image.resize`: `method_id` is then never defined, and the check `if method_id == ResizeMethod.NEAREST` raises an error. Currently, this can be worked around by passing `method` as a `str`. The fix renames `method_id` to `method`, the same name as the input parameter. 20 July 2020, 20:15:40 UTC
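With this fix, `method` can be given either as a string or as a `jax.image.ResizeMethod` member. The following small sketch, written against the current `jax.image` API, checks that both spellings produce the same resized image. The input image is arbitrary.

```python
import jax.numpy as jnp
from jax.image import ResizeMethod, resize

image = jnp.arange(16.0).reshape(4, 4)

# The same bilinear resize, with `method` passed as a string and as the enum.
by_name = resize(image, (8, 8), method="linear")
by_enum = resize(image, (8, 8), method=ResizeMethod.LINEAR)
```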
f6b3184 add clarification about jit inside indexing error message (#3804) 20 July 2020, 18:10:10 UTC
fe99a06 Error message and docstring updates RE: dynamic_slice (#3795) This should clarify the underlying issues from #1007 and #3794. It might be worth mentioning masking, but that's a little too big to fit into an error message. Maybe once the masking transformation is non-experimental, or if we had a dedicated doc page. 20 July 2020, 13:08:54 UTC
dd3cb82 Enable buffer donation for GPU. (#3800) 20 July 2020, 12:59:13 UTC
d433934 Add a TypeVar to while_loop definition. (#3792) 18 July 2020, 20:36:23 UTC
fa2a027 revert #3674 17 July 2020, 22:44:51 UTC
2df486e Note in pmap docs that pmap compiles like jit. (#3787) 17 July 2020, 22:11:26 UTC
9a01d78 Merge pull request #3782 from hawkinsp/dot Don't move batch dimensions to start in jnp.einsum. 17 July 2020, 21:38:05 UTC
3eda6e9 Merge pull request #3783 from hawkinsp/matmul Avoid broadcasts in implementation of jnp.matmul 17 July 2020, 21:36:29 UTC
57c1822 Merge pull request #3789 from hawkinsp/rw Improve reduce-window testing. 17 July 2020, 21:14:03 UTC
7c4d41f Merge pull request #3788 from hawkinsp/convtol Relax test tolerance for conv_general_dilated gradients. 17 July 2020, 21:13:42 UTC
7e3433e Improve reduce-window testing. 17 July 2020, 20:05:51 UTC
e04fdeb Relax test tolerance for conv_general_dilated gradients. The new support for complex types leads to slightly higher errors. 17 July 2020, 19:07:17 UTC
7b3eff8 Merge pull request #3786 from hawkinsp/einsumargs Use keyword arguments in einsum. 17 July 2020, 18:55:36 UTC
f7260a4 stack traces without jax-internal frames (#3674) 17 July 2020, 18:32:36 UTC
252a027 Update numpy signatures test. 17 July 2020, 18:15:52 UTC
8fc1332 Avoid broadcasting for batch dimensions in jnp.matmul. Instead, squeeze size 1 dimensions out of the matmul input, and transpose any non-batch non-contracting dimensions into the correct location. 17 July 2020, 17:58:02 UTC
8e986f9 Use keyword arguments in einsum. Python 3 cleanup only, no functional changes intended. 17 July 2020, 16:54:36 UTC
6a37e26 Don't move batch dimensions to start in jnp.einsum. We no longer require batch dimensions to be first, and this simplifies the code. 17 July 2020, 12:57:09 UTC
943c421 Relax dimension ordering rules for dot_general. (#3778) JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions. Since batch dimensions are now allowed in arbitrary positions, it's not hard to generalize the dot_general batching rule to avoid performing any transposes (#2972). In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`. 17 July 2020, 12:55:23 UTC
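This sketch uses the relaxed rules described above, which hold in current JAX, to show batch dimensions that are neither first nor contiguous in either operand. The output layout is (batch dims, lhs free dims, rhs free dims). The result is checked against a NumPy `einsum` reference. Shapes and data are arbitrary.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

rng = np.random.default_rng(0)
# lhs has shape (M=3, K=5, B=4): the batch dimension B is last.
lhs = jnp.asarray(rng.standard_normal((3, 5, 4)), dtype=jnp.float32)
# rhs has shape (K=5, B=4, N=6): the batch dimension B is in the middle.
rhs = jnp.asarray(rng.standard_normal((5, 4, 6)), dtype=jnp.float32)

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dims = (((1,), (0,)), ((2,), (1,)))
out = lax.dot_general(lhs, rhs, dims)   # shape (B, M, N) = (4, 3, 6)

ref = np.einsum("mkb,kbn->bmn", np.asarray(lhs), np.asarray(rhs))
```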
e4d5ead Use iteration over equations to test for "transpose" and "broadcast". 17 July 2020, 12:44:47 UTC
165f31e Also test for transpose in dot vmap test. 17 July 2020, 12:38:33 UTC
dd4db64 docstring for api_boundary 17 July 2020, 00:19:16 UTC
733f6c4 disable false pytype error 17 July 2020, 00:12:09 UTC
6416ca0 append filtered stack traces to error messages raised under transformations 17 July 2020, 00:12:09 UTC
e2e73a8 Relax dimension ordering rules for dot_general. JAX currently requires that batch dimensions appear first and contiguously in the arguments to dot_general. However, XLA does not require this; relax JAX's checks so that it also allows batch dimensions in arbitrary positions. Since batch dimensions are now allowed in arbitrary positions, it's not hard to generalize the dot_general batching rule to avoid performing any transposes (#2972). In passing, also move the bool/int dot expansion into the XLA translation rule. The expansion inside the `lax.dot_general()` wrapper predated the existence of (or at least my knowledge of) `xla.lower_fun()`. 16 July 2020, 23:36:22 UTC
3fb8874 [x64 deprecation] Create _np_array utility routine (#3779) 16 July 2020, 22:53:16 UTC
44fbce5 Revert "Add `in_parts` and `out_parts` optional arguments `jax.xla_computation`. (#3771)" (#3780) This reverts commit dbc3f83f6d14d491a06137f698aca92f7f3c572d. This is breaking some google-internal users of xla_computation. Reverting while I investigate. 16 July 2020, 22:22:40 UTC
bfe8e4f Enable `all_to_all` in multi-host settings. (#3772) I tested this via `pswapaxes` and it seems to work. There may still be issues with all_to_all (e.g. https://github.com/google/jax/issues/1332) but it seems worth enabling the `pswapaxes` use case. 16 July 2020, 16:19:41 UTC
7b18a4e [jax2tf] Fix interface of ConvertAndCompare function. (#3776) 16 July 2020, 13:44:20 UTC
15125b8 [jax2tf] Fix bug with double XLA compile (#3765) * [jax2tf] Fix bug with double XLA compile Some converted ops require XLA compilation to work around bugs in TF where the behavior without XLA is incorrect. If those ops are then part of an outer tf.function(experimental_compile=True), we get a TF error. This change primarily detects whether we are in a compilation context and, if so, does not use the XLA compiler for those ops. This, however, changes the error behavior for dynamic_slice. Improved the testing to skip fewer tests and instead expect and check for errors. 16 July 2020, 13:06:34 UTC
dbc3f83 Add `in_parts` and `out_parts` optional arguments to `jax.xla_computation`. (#3771) This allows partitioned computations in `xla_computation`, like those produced by `sharded_jit`. 15 July 2020, 21:56:58 UTC
05904fa Change onp/np to np/jnp in docs & notebooks (#3760) 15 July 2020, 20:17:38 UTC
150d028 Update scipy.ndimage.map_coordinates docstring (#3762) 15 July 2020, 18:19:40 UTC
34669d7 source sync (#3763) 15 July 2020, 18:03:58 UTC
8a62a9b block-unrolled scan primitive implementation (#3738) * block-unrolled scan implementation, via optional `_unroll` scan parameter * index statically in the inlined path of lax.scan * make `unroll` a required scan parameter, and test that it unrolls 15 July 2020, 18:00:50 UTC
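This commit concerns the scan primitive's `unroll` parameter. The sketch below assumes the user-facing `unroll` keyword of current `jax.lax.scan`. Unrolling changes how many loop bodies the compiled while loop contains per iteration, but the values it computes stay the same, as the cumulative-sum check shows.

```python
import jax
import jax.numpy as jnp

def cumsum_scan(xs, unroll):
    # Carry the running total and emit it at every step.
    def step(total, x):
        total = total + x
        return total, total
    _, ys = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs, unroll=unroll)
    return ys

xs = jnp.arange(10.0)
scan_jit = jax.jit(cumsum_scan, static_argnums=1)
ys_rolled = scan_jit(xs, 1)    # one loop body per iteration
ys_unrolled = scan_jit(xs, 4)  # four loop bodies per iteration (10 is not a multiple of 4)
```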
23c279f [jax2tf] Relax tolerance for JaxPrimitiveTest.test_unary_elementwise_lgamma_float32 (#3767) 15 July 2020, 15:19:22 UTC
d55c47d [jax2tf] Fix error in sorting with TF graphs. (#3764) 15 July 2020, 12:12:42 UTC
c380356 [jax2tf] Refactor tests to increase coverage (#3700) * [jax2tf] Refactor tests to increase coverage. This change has several goals: (a) increase the test coverage for the elementwise primitives, (b) explicitly expose the situations where JAX and TF produce different results, e.g., inf vs NaN, and (c) run all comparisons with and without tf.function, with and without experimental_compile=True. Previously the test code was just masking off non-finite values. This change uncovered quite a few unimplemented cases, e.g., with float16, bfloat16, and conversions that cannot be compiled. These are left as TODOs for now. * Disable the sort tests 15 July 2020, 06:49:51 UTC
68c8dc7 Added installation note for jax2tf with GPU support. (#3750) 15 July 2020, 04:56:19 UTC
f02d5b4 Support differentiation through jax.lax.all_to_all (#3733) * Support differentiation through jax.lax.all_to_all Credit to @levskaya for the solution. * Test gradient of all_to_all We are testing all_to_all through pswapaxes, since general all_to_all is problematic according to https://github.com/google/jax/issues/1332. * Removed trailing spaces 14 July 2020, 22:45:49 UTC
a7c2cde Cleanup: convert uses of `import numpy as onp` in library code (#3754) 14 July 2020, 20:05:31 UTC
512ed18 Cleanup: convert uses of 'import numpy as onp' in tests (#3756) 14 July 2020, 20:03:24 UTC
58aba9b Fix low probability (4/1000) flake of batching_test on GPU. (#3752) 14 July 2020, 18:10:13 UTC
f6f9755 A jit-able version of np.repeat. (#3670) A new keyword argument has been added to np.repeat, total_repeat_length, that can optionally be supplied to make np.repeat jit-able. 14 July 2020, 17:37:09 UTC
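Under `jit`, a traced `repeats` array means the output length is unknown at trace time. Supplying `total_repeat_length` fixes that length statically. A minimal sketch with made-up inputs, where the static length is chosen to equal `counts.sum()`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def expand(x, counts):
    # `counts` is traced, so jnp.repeat cannot infer the output size;
    # total_repeat_length provides it statically (here equal to counts.sum()).
    return jnp.repeat(x, counts, total_repeat_length=6)

out = expand(jnp.array([1, 2, 3]), jnp.array([1, 2, 3]))
```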
0605321 Fix compilation bug in histogram_bin_edges (#3745) 14 July 2020, 17:24:42 UTC
e1f8f24 lax_numpy: rename arguments to match numpy (#3747) 14 July 2020, 17:21:38 UTC
7f9c2f1 Make jnp.take work for empty slices of empty arrays. (#3751) 14 July 2020, 16:02:26 UTC
2b7a39f Add pshuffle to docs (#3742) 14 July 2020, 13:05:45 UTC
f78ccf1 Fix tuple() in reduce_window padding (#3748) * Fix tuple() in reduce_window padding * Update lax.py 14 July 2020, 01:16:11 UTC
a017c10 Implement nearest neighbor image resizes. (#3743) 14 July 2020, 00:27:12 UTC
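A small sketch of nearest-neighbor resizing with the current `jax.image.resize` API. For an exact 2x upsample, every input pixel should become a 2x2 block. The check below compares the result against repeated NumPy arrays. The input values are arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp

image = jnp.array([[1.0, 2.0],
                   [3.0, 4.0]])

# Upsample 2x in each dimension, sampling the nearest input pixel.
up = jax.image.resize(image, (4, 4), method="nearest")

# Each input pixel becomes a 2x2 block of the output.
expected = np.repeat(np.repeat(np.asarray(image), 2, axis=0), 2, axis=1)
```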
9a867ca Remove unused private function (#3744) 13 July 2020, 23:55:49 UTC
a6ab742 Improve np.intersect1d (#3739) 13 July 2020, 22:31:47 UTC
6017205 Add defensive tuple() in lax.reduce_window (#3741) 13 July 2020, 21:37:46 UTC
3c6cd5f Implement complex convolutions on CPU and GPU. (#3735) Lowers using Gauss's complex multiplication algorithm (which internally is also what the XLA:TPU implementation does.) 13 July 2020, 18:44:24 UTC
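This commit lowers complex convolutions with Gauss's three-multiplication trick. The sketch below reproduces the same idea in user-level JAX by applying `lax.conv` to real and imaginary parts. It illustrates the algorithm and is not the commit's actual XLA lowering. The helper name `complex_conv_gauss` is hypothetical. Because convolution is bilinear, the identity for (a + ib)(c + id) still holds when each product is a convolution.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def complex_conv_gauss(x, w):
    """Hypothetical helper: complex 'VALID' 1-D cross-correlation via 3 real convs.

    x: complex array of shape (N, C, W); w: complex kernel of shape (O, C, K).
    With x = a + ib and w = c + id:
      k1 = conv(a + b, c), k2 = conv(a, d - c), k3 = conv(b, c + d)
      real part = k1 - k3, imaginary part = k1 + k2
    """
    a, b = jnp.real(x), jnp.imag(x)
    c, d = jnp.real(w), jnp.imag(w)

    def conv(lhs, rhs):
        # lax.conv cross-correlates without flipping the kernel (NCW / OIW layout).
        return lax.conv(lhs, rhs, window_strides=(1,), padding="VALID")

    k1 = conv(a + b, c)
    k2 = conv(a, d - c)
    k3 = conv(b, c + d)
    return lax.complex(k1 - k3, k1 + k2)

rng = np.random.default_rng(0)
x = (rng.standard_normal((1, 2, 8)) + 1j * rng.standard_normal((1, 2, 8))).astype(np.complex64)
w = (rng.standard_normal((3, 2, 3)) + 1j * rng.standard_normal((3, 2, 3))).astype(np.complex64)
out = complex_conv_gauss(jnp.asarray(x), jnp.asarray(w))  # shape (1, 3, 6)
```

This uses three real convolutions instead of the four that the schoolbook product would need, at the cost of a few extra additions.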
6391cfe [jax2tf] First draft of converting sort_p using TF2XLA (#3713) Limitations: stable sort is skipped, as well as tests using > 2 arrays. num_keys is always set to 1, and boolean key-value sort is also skipped. * Skipped multiarray tests on GPU. TensorFlow is run on CPU, which makes the results incompatible with the XLA GPU implementation. 13 July 2020, 15:01:24 UTC
71253ac Generalize reduce-window padding to support (lo, hi) pairs. (#3728) * Generalize reduce-window padding to support (lo, hi) pairs, as XLA does. This turns out to simplify the code slightly, too. * Fix select_and_gather_add batching rule and test. * Fix documentation text to refer to ReduceWindowWithGeneralPadding. 13 July 2020, 13:49:52 UTC
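With general padding, each spatial dimension takes its own `(lo, hi)` pair instead of a single symmetric amount. The following minimal 1-D max-pool sketch, using the current `lax.reduce_window` signature, pads one element on the low side only:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 3.0, 2.0, 5.0])

# Window size 2, stride 1, padding (lo=1, hi=0): the operand is treated as
# [-inf, 1, 3, 2, 5] and the max is taken over each length-2 window.
out = lax.reduce_window(x, -jnp.inf, lax.max, (2,), (1,), ((1, 0),))
```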
a9da06c Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729) 13 July 2020, 13:43:19 UTC
0d81e98 Implement np.intersect1d (#3726) * Implement np.intersect1d * Add jitable helper to function * Fix argsort failing tests * Fix linter errors 13 July 2020, 05:32:41 UTC
9da9156 Change image resize implementation to use a matmul per dimension. (#3720) * Change image resize implementation to use a matmul per dimension. This should have better space scaling behaviors than the previous gather approach. In particular, it does not require temporary memory that scales with the batch size or number of features of an image. * Plumb precision option through to image resize API. 12 July 2020, 18:00:10 UTC
51ca57d check matmul inputs aren't scalar (#3725) Also, the dot_general shape rule should check that dimension numbers are in range. Fixes #3718 12 July 2020, 03:47:22 UTC
aa75209 Import profiler in jax/__init__.py (#3719) 12 July 2020, 03:44:16 UTC
eb67571 Merge pull request #3705 from NeilGirdhar/nestable_vjp Make vjp cotangent functions pytree-like 11 July 2020, 20:21:49 UTC
503e597 Make vjp cotangent functions pytree-like Fixes #3667 11 July 2020, 03:22:38 UTC
e073e25 Add a jax.profiler.StepTraceContext API. (#3591) * Add RootTraceContext. * Rename to StepTraceContext. 11 July 2020, 00:20:06 UTC
412b9d5 hfft and ihfft implementation (#3664) 10 July 2020, 17:34:59 UTC
60d8527 lexicographic sort_p: accept num_keys rather than comparator (#3715) 10 July 2020, 16:58:35 UTC
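A sketch of a lexicographic multi-operand sort with the `num_keys` argument of current `jax.lax.sort`. The first `num_keys` operands form the sort key, compared in order. The remaining operands are permuted along with them. The data are made up.

```python
import jax.numpy as jnp
from jax import lax

primary = jnp.array([1, 0, 1, 0])
secondary = jnp.array([2, 3, 1, 0])
payload = jnp.array([10, 11, 12, 13])

# Sort by (primary, secondary); `payload` is carried along, not compared.
p, s, v = lax.sort((primary, secondary, payload), num_keys=2)
```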
d2f9c46 Remove some non-inclusive language (#3710) 10 July 2020, 16:29:06 UTC
24c9ee6 fix flake error (#3716) 10 July 2020, 16:24:51 UTC
44eae61 Turn off INFO logging (again). (#3707) Something must have started logging earlier than before, causing INFO-level logging to be initialized before we disabled it. This change disables INFO logging sooner. 10 July 2020, 15:11:48 UTC
417de0d Add jit to jax.image.resize (#3714) * Add image/ directory to Bazel build. * Use a jit on jax.image.resize to reduce compilation time. Relax bfloat16 test tolerance. 10 July 2020, 14:32:13 UTC
b943b31 Add jax.image.resize. (#3703) * Add jax.image.resize. This is a port of `tf.image.resize()` and the `ScaleAndTranslate` operator. While I don't expect this implementation to be particularly fast, it is a useful generic implementation to which we can add optimized special cases as the need arises. 10 July 2020, 13:57:59 UTC
804e449 Generalize lax.sort to support lexicographic sorts. (#3709) 10 July 2020, 03:05:19 UTC
0a6b715 Add _NOT_IMPLEMENTED attribute to jax.numpy (fixes #3689) (#3698) 09 July 2020, 23:31:08 UTC
c1aeb8b Add simple JAX API microbenchmarks. (#3679) 09 July 2020, 17:02:23 UTC
b813ae3 Cleanup: record names in get_module_functions (#3697) 08 July 2020, 21:44:49 UTC
19adce5 Cleanup: use test_util dtypes where possible (#3695) * Cleanup: use test_util dtypes where possible * fix issue in fft test * fix duplicate test name issue 08 July 2020, 20:21:48 UTC
11b40fb Fix a link to TensorBoard's profiler in Profiling JAX Programs (#3692) 08 July 2020, 17:53:06 UTC
fdd7f0c Added concurrent id_tap tests, disabled for GPU (#3690) 08 July 2020, 13:08:54 UTC
6b471e2 Cleanup: define type lists in test_util & use in several test files. (#3616) 08 July 2020, 00:01:38 UTC
82dbaca Revert #3610 & #3684 (#3688) * Revert "linalg_test: define test matrices lazily (#3684)" This reverts commit 2be1baa41a170192c209c94b060d0d034d1de2c2. * Revert "Make LU gradient work for low-rank matrices (#3610)" This reverts commit 23deefa71838ceeab41977ac0ab781164c914a8c. 07 July 2020, 23:19:43 UTC
1034f29 fix bad pmap tests from #3675 (#3685) 07 July 2020, 21:48:54 UTC
e89860c fix pmap test on GPU/TPU (#3682) * fix pmap test on GPU/TPU * AssertionError -> ValueError 07 July 2020, 21:47:44 UTC
2be1baa linalg_test: define test matrices lazily (#3684) 07 July 2020, 21:47:24 UTC
9c3e6c3 AssertionError -> ValueError 07 July 2020, 20:21:44 UTC
4711589 fix pmap test on GPU/TPU 07 July 2020, 20:19:19 UTC
242b382 Remove a deprecated reference to testExamplesJaxprDoc in Understanding Jaxpr (#3680) 07 July 2020, 18:29:44 UTC
bf97e47 Make infeed_test and host_callback_test independent. (#3676) * Make infeed_test and host_callback_test independent. * the infeed_test will stop the outfeed receiver * Remove the use of --dist=loadfile. * Prevent logging on exit 07 July 2020, 08:03:30 UTC
d2ebb6e fix ppermute test bugs found by @jekbradbury (#3675) 07 July 2020, 07:30:08 UTC