4747205 | Skye Wanderman-Milne | 21 May 2020, 17:36:28 UTC | Update jax version to 0.1.68 | 21 May 2020, 17:36:28 UTC |
bb2127c | Jake Vanderplas | 21 May 2020, 16:20:59 UTC | Future-proof view test against signaling NaNs (#3178) | 21 May 2020, 16:20:59 UTC |
c459280 | Dan Piponi | 21 May 2020, 15:35:17 UTC | Added `associative_scan`. (#2170) * Added `associative_scan`. * Fixed problem where base case of associative scan could fail * remove jax.numpy dependence in associative_scan Co-authored-by: Matthew Johnson <mattjj@google.com> | 21 May 2020, 15:35:17 UTC |
5c1de28 | Matthew Johnson | 21 May 2020, 15:00:18 UTC | revise vmap in_axes/out_axes leaf type error msg (#3179) from #3161 discussion | 21 May 2020, 15:00:18 UTC |
eb81d7e | Matthew Johnson | 21 May 2020, 13:47:02 UTC | add dict in_axes example to vmap docstring (#3176) * add dict in_axes example to vmap docstring fixes #3161 * fix typo | 21 May 2020, 13:47:02 UTC |
6e3c8b1 | Jake Vanderplas | 21 May 2020, 13:40:24 UTC | Fix arr.view() on TPU & improve tests (#3141) | 21 May 2020, 13:40:24 UTC |
ae9d175 | Matthew Johnson | 21 May 2020, 03:21:41 UTC | fix while_loop cond function batching (#3174) * fix while_loop cond function batching fixes #3164 * add test for #3164 | 21 May 2020, 03:21:41 UTC |
f9c978e | Matthew Johnson | 21 May 2020, 02:09:54 UTC | improve docstring of jax.numpy.broadcast_to (#3173) thanks @joaogui1 ! | 21 May 2020, 02:09:54 UTC |
a4094f7 | Matthew Johnson | 21 May 2020, 02:09:44 UTC | revise "Tracer with raw numpy" error message (#3160) * revise "Tracer with raw numpy" error message fixes #3133 * fix f-string typo * fix typo Co-authored-by: James Bradbury <jekbradbury@google.com> | 21 May 2020, 02:09:44 UTC |
12f26d3 | Skye Wanderman-Milne | 20 May 2020, 21:40:28 UTC | Improve ``devices`` and related documentation (#3155) | 20 May 2020, 21:40:28 UTC |
8b32c5d | James Bradbury | 20 May 2020, 21:30:33 UTC | Avoid running trivial jitted subcomputations in pe (#3169) | 20 May 2020, 21:30:33 UTC |
f349979 | Peter Hawkins | 20 May 2020, 20:17:17 UTC | Relax tolerance of LaxVmapTest.testDot for float64 inputs. (#3167) | 20 May 2020, 20:17:17 UTC |
7d157c7 | joao guilherme | 20 May 2020, 04:43:48 UTC | onp -> np (#3157) | 20 May 2020, 04:43:48 UTC |
42b425d | Matthew Johnson | 20 May 2020, 01:14:10 UTC | fix disable_jit logic in lax.cond and lax.while_loop (#3156) * fix disable_jit logic in lax.cond, fixes #3093 * fix disable_jit logic in lax.while_loop, fix #2823 * add test for issue #3093 * add test for #2823 * add test for #2598 | 20 May 2020, 01:14:10 UTC |
3141ff8 | Jake Vanderplas | 19 May 2020, 23:58:42 UTC | Add lax implementation of np.isin() and np.in1d() (#3087) | 19 May 2020, 23:58:42 UTC |
ccb203c | Matthew Johnson | 19 May 2020, 22:51:07 UTC | improve pmap unbound axis error, fixes #3120 (#3152) | 19 May 2020, 22:51:07 UTC |
b7ff305 | Jake Vanderplas | 19 May 2020, 22:50:54 UTC | fix broken TPU test (#3153) | 19 May 2020, 22:50:54 UTC |
bc47a32 | Matthew Johnson | 19 May 2020, 22:41:03 UTC | make lax.psum promote bool -> int (#3150) * make lax.psum promote bool -> int, fixes #3123 * fix test bug * fix typo in test | 19 May 2020, 22:41:03 UTC |
83a339e | Jacob Kelly | 19 May 2020, 22:22:25 UTC | add erf and erfc rules (#3051) refactor def comp | 19 May 2020, 22:22:25 UTC |
b6777c0 | Sebastian Bischoff | 19 May 2020, 22:18:34 UTC | Correct calculation of loss and increase learning rate (#3113) | 19 May 2020, 22:18:34 UTC |
850f1af | Matthew Johnson | 19 May 2020, 22:17:03 UTC | improve errors for complex derivs, fixes #3121 (#3149) | 19 May 2020, 22:17:03 UTC |
8fe2619 | Jake Vanderplas | 19 May 2020, 21:19:00 UTC | Expand type support for random uniform() & randint() (#3138) | 19 May 2020, 21:19:00 UTC |
73b76e9 | Sergei Lebedev | 19 May 2020, 19:40:03 UTC | Exported lax from jax/__init__.py (#3135) This allows lax functions to be used without separately importing jax.lax. | 19 May 2020, 19:40:03 UTC |
77e3132 | Stephan Hoyer | 19 May 2020, 15:23:45 UTC | Fix indentation for docstrings in jax.experimental.host_callback (#3119) | 19 May 2020, 15:23:45 UTC |
85fe5a2 | alexdavies | 19 May 2020, 06:06:32 UTC | Add gradients to the scatter_max and scatter_min operations. (#3111) This is being done to allow the creation of a differentiable segment_max. Segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255 Co-authored-by: Alex Davies <adavies@google.com> | 19 May 2020, 06:06:32 UTC |
8d0749f | Ben Lee | 19 May 2020, 02:24:45 UTC | Fix a corner case in `repeat`. (#3117) * Fixes a corner case: `jnp.repeat(jnp.array(0), 1, axis=0)` throws an error, whereas `np.repeat(np.array(0), 1, axis=0) = np.array([0])`. * Add test for `np.repeat(np.array(0), 1, axis=0)`. | 19 May 2020, 02:24:45 UTC |
d53bab9 | Skye Wanderman-Milne | 19 May 2020, 01:24:33 UTC | Improve sharded_jit error message and fix test (#3145) | 19 May 2020, 01:24:33 UTC |
083cdd3 | Skye Wanderman-Milne | 19 May 2020, 01:18:06 UTC | Fix in pxla._inner_partitions (#3146) In cb77f2a22de49e85da93f43b7dc448aa238d5207, I switched to looking for sharding_constraint_p's name since sharding_constraint_p itself is defined in sharded_jit.py, but didn't quite get the update right. | 19 May 2020, 01:18:06 UTC |
cb77f2a | Skye Wanderman-Milne | 18 May 2020, 23:37:49 UTC | Move some sharded_jit functionality into pxla.py (#3142) Specifically: * Move `_inner_partitions` * Move `get_num_partitions` * Move and slightly modify the logic for finding and validating inner partitions to a new function, `reconcile_num_partitions` * Move `_partitioned_sharding_spec` and rename to `partitioned_sharding_spec` This is in preparation for enabling pmap-of-sharded_jit, since pmap will need access to this functionality as well. | 18 May 2020, 23:37:49 UTC |
9152b76 | Skye Wanderman-Milne | 18 May 2020, 22:20:49 UTC | Add with_sharding_constraint method to be used within sharded_jit. (#3100) See the with_sharding_constraint docstring for a description of what this method does. Depending on how we decide nested sharded_jits should work, an alternative implementation for with_sharding_constraint could be: ```python def with_sharding_constraint(x, partitions): return sharded_jit(lambda x: x, in_parts=partitions, out_parts=partitions)(x) ``` In this case, we could get rid of the with_sharding_constraint primitive, and possibly even the API. This implementation gets the job done for now without committing to a nested sharded_jit behavior, and is also much easier to take the gradient of than sharded_jit. | 18 May 2020, 22:20:49 UTC |
4603432 | Stephan Hoyer | 18 May 2020, 22:13:54 UTC | Fix gradient for jax.scipy.ndimage.map_coordinates (#3110) * Fix gradient for jax.scipy.ndimage.map_coordinates Fixes GH3024 * minor refactor for clarification | 18 May 2020, 22:13:54 UTC |
36e7fad | Peter Hawkins | 18 May 2020, 21:54:20 UTC | Add a primitive integer_pow() for values raised to a fixed integer scalar. (#3140) * Add a primitive integer_pow() for values raised to a fixed integer scalar. Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal(). Fixes #3136 ``` In [1]: from jax import grad, make_jaxpr In [2]: def inv(x): return 1/x In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.)) 0.043945312 In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.) Out[4]: { lambda ; a. let b = integer_pow[ y=-7 ] a c = mul -6.0 b d = mul -120.0 c in (d,) } In [5]: ``` * Use x ** 3 in gelu definition. | 18 May 2020, 21:54:20 UTC |
a9c1b38 | Sandu Ursu | 18 May 2020, 21:12:52 UTC | Added link to README (#3139) | 18 May 2020, 21:12:52 UTC |
ed0e227 | Skye Wanderman-Milne | 18 May 2020, 18:39:55 UTC | Add sharded_jit translation rule. (#3099) This is potentially dangerous, because it lets sharded_jit() be called inside other call primitives (e.g. jit, pmap), which isn't supported yet. I'm adding it now because I'm planning to implement pmap-of-sharded_jit soon, and it will help with testing a set_sharding API I'm also planning to add soon. | 18 May 2020, 18:39:55 UTC |
bc5a0b3 | Peter Hawkins | 18 May 2020, 14:19:03 UTC | Remove some uses of jax.partial. (#3131) | 18 May 2020, 14:19:03 UTC |
b071b12 | George Necula | 18 May 2020, 13:02:49 UTC | Fixed link in FAQ (#3129) | 18 May 2020, 13:02:49 UTC |
2be471c | Stephan Hoyer | 18 May 2020, 05:12:21 UTC | Remove redundant flag configuration from lax_scipy_sparse_test (#3128) | 18 May 2020, 05:12:21 UTC |
912b282 | Jake Vanderplas | 16 May 2020, 17:23:26 UTC | Implement np.histogram_bin_edges and np.histogram (#3081) | 16 May 2020, 17:23:26 UTC |
670fab5 | Jamie Townsend | 16 May 2020, 13:19:24 UTC | Test code in docs and api.py docstrings (#2994) Also remove jaxpr doc tests from api_test.py. | 16 May 2020, 13:19:24 UTC |
510af1d | Ed Schmerling | 16 May 2020, 03:51:53 UTC | Fix documentation for `nn.elu`, `nn.celu`, and `lax.expm1`. (#3116) | 16 May 2020, 03:51:53 UTC |
e675f80 | Jake Vanderplas | 16 May 2020, 02:09:43 UTC | Add support for 8- and 16-bit output in _random_bits (#3090) | 16 May 2020, 02:09:43 UTC |
12a9af8 | Jake Vanderplas | 15 May 2020, 21:29:02 UTC | Update random.logistic() to prevent infinities (#3048) | 15 May 2020, 21:29:02 UTC |
812df27 | Peter Hawkins | 15 May 2020, 19:51:07 UTC | Update uses of deprecated XLA methods. (#3109) | 15 May 2020, 19:51:07 UTC |
450b7ab | Roy Frostig | 15 May 2020, 02:38:43 UTC | Merge pull request #2993 from google/cond single-operand cond | 15 May 2020, 02:38:43 UTC |
b19166c | Roy Frostig | 15 May 2020, 01:01:34 UTC | silence a pytype error in jet tracing | 15 May 2020, 01:01:34 UTC |
28da2bc | Roy Frostig | 15 May 2020, 00:21:57 UTC | improve two-operand cond detection | 15 May 2020, 00:21:57 UTC |
77703b8 | Peter Hawkins | 14 May 2020, 23:17:44 UTC | Add support for sorting complex values, defaulting to a NumPy-style l… (#3096) * Add support for sorting complex values, defaulting to a NumPy-style lexicographic ordering. Implemented using a custom comparator, since the XLA-level default comparator doesn't impose an ordering for complex values. * Disable sort test on CPU and TPU. | 14 May 2020, 23:17:44 UTC |
85d7b30 | Peter Hawkins | 14 May 2020, 23:17:23 UTC | Allow non-classes as second argument of `issubclass`. (#3098) | 14 May 2020, 23:17:23 UTC |
dceb578 | Matthew Johnson | 14 May 2020, 23:06:20 UTC | stash order on jet master trace, fixes #3079 (#3097) | 14 May 2020, 23:06:20 UTC |
9d8ecc5 | Roy Frostig | 14 May 2020, 20:56:07 UTC | avoid committing to argument types so long as cond is overloaded | 14 May 2020, 20:56:07 UTC |
007cdf2 | Thomas Keck | 14 May 2020, 20:04:08 UTC | Adds additional epsilon to adam for numerical stability. (#3091) * Adds additional epsilon to adam for numerical stability. Meta-gradients through the adam optimizer diverge, because the derivative of the adam scaling with respect to the gradients gets an additional 1/sqrt(g) factor. This additional factor is unregularized without the second epsilon added in this commit. * Renames eps2 to eps_root and improves docstring Co-authored-by: Thomas Keck <thomaskeck@google.com> | 14 May 2020, 20:04:08 UTC |
83c7394 | Jake Vanderplas | 14 May 2020, 19:58:31 UTC | fix astype() test (#3072) | 14 May 2020, 19:58:31 UTC |
7c687b2 | Skye Wanderman-Milne | 14 May 2020, 18:38:08 UTC | sharded_jit cleanup (#3075) * Add sharding utilities to xla_bridge * Change partitions argument to in_parts and out_parts * Add unit tests * Reuse more pxla functionality * Remove stale translation rule * Fail on non-TPU platforms * Add docstring * And more! | 14 May 2020, 18:38:08 UTC |
73a55e0 | Roy Frostig | 14 May 2020, 16:16:49 UTC | changelog: single-operand cond | 14 May 2020, 16:16:49 UTC |
930be58 | Roy Frostig | 14 May 2020, 16:05:31 UTC | update reference jaxpr using cond in outfeed rewriting test | 14 May 2020, 16:05:31 UTC |
cb5b1d1 | Roy Frostig | 14 May 2020, 16:04:48 UTC | handle single-operand cond in token-threading rewriter | 14 May 2020, 16:04:48 UTC |
9f6774a | Roy Frostig | 14 May 2020, 16:02:29 UTC | update uses of cond in host callback test | 14 May 2020, 16:02:29 UTC |
4ce2aa2 | Peter Hawkins | 14 May 2020, 15:13:15 UTC | Make lax.sort support tuple arguments using a variadic sort. (#3085) * Make lax.sort support tuple arguments using a variadic sort. Change sort_jvp to use a gather of ids to compute the JVP rather than sorting repeatedly. Remove sort_key_val_p, since it is redundant with a variadic sort_p. * Fix mypy errors. * Change JVP rule to use NumPy indexing. Remove redundant case in batching rule. | 14 May 2020, 15:13:15 UTC |
6e3bfc3 | Roy Frostig | 12 May 2020, 02:18:31 UTC | comment on joining constants for conditional branch jaxprs | 14 May 2020, 04:14:41 UTC |
2950eb1 | Roy Frostig | 12 May 2020, 02:12:53 UTC | comment on joining staged jaxprs in partial evaluation of conditionals | 14 May 2020, 04:14:41 UTC |
34efdd9 | Roy Frostig | 12 May 2020, 00:55:37 UTC | style changes in lax.cond | 14 May 2020, 04:14:41 UTC |
e8f12b6 | Roy Frostig | 12 May 2020, 00:18:22 UTC | remove deprecation warning on five-argument cond | 14 May 2020, 04:14:41 UTC |
76612c6 | Roy Frostig | 12 May 2020, 00:11:21 UTC | test and fix cond when branch-staging is off | 14 May 2020, 04:14:41 UTC |
de03c99 | Roy Frostig | 11 May 2020, 23:48:30 UTC | update jaxpr doc and tests with single-operand cond | 14 May 2020, 04:14:41 UTC |
f90bd4f | Roy Frostig | 11 May 2020, 22:47:16 UTC | move cond operand to final argument position | 14 May 2020, 04:14:41 UTC |
cc6cea2 | Roy Frostig | 11 May 2020, 20:59:57 UTC | update reference jaxpr in cond-related jaxpr test | 14 May 2020, 04:14:41 UTC |
df62279 | Roy Frostig | 07 May 2020, 18:06:15 UTC | cache jaxprs formed by cond branch staging | 14 May 2020, 04:14:41 UTC |
efc1104 | Roy Frostig | 07 May 2020, 02:42:50 UTC | have loops module generate same-argument jaxprs for single-operand cond | 14 May 2020, 04:14:41 UTC |
30bb5fd | Roy Frostig | 07 May 2020, 02:03:53 UTC | transpose rule for single-operand cond primitive | 14 May 2020, 04:14:41 UTC |
37bb6f0 | Roy Frostig | 07 May 2020, 01:44:23 UTC | partial evaluation rule for single-operand cond primitive | 14 May 2020, 04:14:41 UTC |
139c2a9 | Roy Frostig | 06 May 2020, 02:44:34 UTC | fix cond-with-constants tests to use jax.numpy where needed | 14 May 2020, 04:14:41 UTC |
34d75c2 | Roy Frostig | 30 April 2020, 14:23:04 UTC | JVP rule for single-operand cond primitive | 14 May 2020, 04:14:41 UTC |
226b582 | Roy Frostig | 30 April 2020, 13:58:23 UTC | batching rule for single-operand cond primitive | 14 May 2020, 04:14:41 UTC |
fc4ab77 | Roy Frostig | 30 April 2020, 13:19:46 UTC | merge constvars when forming cond branch jaxprs | 14 May 2020, 04:14:41 UTC |
027a539 | Roy Frostig | 29 April 2020, 23:00:49 UTC | translation rule for single-operand cond primitive | 14 May 2020, 04:14:37 UTC |
97738be | Roy Frostig | 29 April 2020, 22:56:15 UTC | bind a single-operand cond primitive and update jaxpr typechecks | 14 May 2020, 04:13:18 UTC |
28e698e | Roy Frostig | 29 April 2020, 05:37:44 UTC | update uses of cond in lax control flow tests | 14 May 2020, 04:13:12 UTC |
9217641 | Roy Frostig | 24 April 2020, 00:41:19 UTC | take a single operand in lax.cond and deprecate the old calling convention | 14 May 2020, 04:11:25 UTC |
16cf845 | Skye Wanderman-Milne | 14 May 2020, 00:46:35 UTC | Update jaxlib version to 0.1.47 in README | 14 May 2020, 00:47:37 UTC |
6bd0160 | Jake Vanderplas | 14 May 2020, 00:31:02 UTC | disable arr.view() test on TPU (#3089) * disable arr.view() test on TPU * Update lax_numpy_test.py Use decorator. * Update lax_numpy_test.py Co-authored-by: Peter Hawkins <hawkinsp@cs.stanford.edu> | 14 May 2020, 00:31:02 UTC |
777636a | Jake Vanderplas | 13 May 2020, 21:36:46 UTC | promote integer inputs to float in jnp.median() and jnp.quantile() (#3082) | 13 May 2020, 21:36:46 UTC |
59aab01 | Jake Vanderplas | 13 May 2020, 19:48:16 UTC | Implement .view() method of jax.numpy arrays (#3073) | 13 May 2020, 19:48:16 UTC |
22d14fd | Peter Hawkins | 13 May 2020, 18:00:44 UTC | Remove workaround for Mac linear algebra bug that is fixed in the minimum jaxlib version. (#3080) | 13 May 2020, 18:00:44 UTC |
91d1e0d | Peter Hawkins | 13 May 2020, 14:59:31 UTC | Disable trapz test on TPU. (#3078) | 13 May 2020, 14:59:31 UTC |
e9d3394 | Sharad Vikram | 13 May 2020, 06:08:14 UTC | Make jnp.array convert empty list to DeviceArray (#3049) * Make jnp.array convert empty list to DeviceArray * Add additional tests for empty classes with __array__ Co-authored-by: Peter Hawkins <phawkins@google.com> | 13 May 2020, 06:08:14 UTC |
d2f84d6 | joao guilherme | 13 May 2020, 00:37:05 UTC | Change instances of onp to np and np to jnp (#3044) | 13 May 2020, 00:37:05 UTC |
cd966f2 | Peter Hawkins | 12 May 2020, 22:45:21 UTC | Disable check_type for trapz test due to test failures. (#3071) | 12 May 2020, 22:45:21 UTC |
abdf504 | Tom Hennigan | 12 May 2020, 22:03:22 UTC | Avoid recompilation of rolled loops in threefry2x32. (#3069) | 12 May 2020, 22:03:22 UTC |
11760ca | Skye Wanderman-Milne | 12 May 2020, 21:34:31 UTC | Refactor aval_to_result_handler to take unsharded aval. (#3067) This is in preparation for calling it from the sharded_jit. Currently aval_to_result_handler is specific to pmap, but this change makes it work for any kind of sharding. | 12 May 2020, 21:34:31 UTC |
4a84b91 | James Bradbury | 12 May 2020, 21:34:23 UTC | lower in_axes=None to XLA replication annotation (#3025) * lower in_axes=None to XLA replication annotation * ignore replicated value for tokens | 12 May 2020, 21:34:23 UTC |
88d3422 | Peter Hawkins | 12 May 2020, 21:19:09 UTC | Add special case for integer scalars to jax.numpy.power. (#3066) * Add special case for integer scalars to jax.numpy.power. | 12 May 2020, 21:19:09 UTC |
96fbfee | Jake Vanderplas | 12 May 2020, 21:12:03 UTC | Add lax implementation of np.trapz (#3042) | 12 May 2020, 21:12:03 UTC |
ef4debc | Skye Wanderman-Milne | 12 May 2020, 18:17:22 UTC | Update jax version to 0.1.67 (#3065) | 12 May 2020, 18:17:22 UTC |
0d97c3b | Skye Wanderman-Milne | 12 May 2020, 18:05:03 UTC | Import tpu_driver after xla_client (#3064) This is a workaround until we build a new jaxlib with https://github.com/tensorflow/tensorflow/commit/f4628678066c72309d3fd121af1aaf54d9905ca3 | 12 May 2020, 18:05:03 UTC |
8008aa9 | Peter Hawkins | 12 May 2020, 15:09:02 UTC | Fix error message in optimizers piecewise_constant. (#3061) | 12 May 2020, 15:09:02 UTC |
ccb8d45 | Yusuke Oda | 12 May 2020, 15:04:53 UTC | Uses jnp.square instead of power. (#3036) * Uses multiplication instead of power. * Uses jnp.square instead of mul and adds check if jnp.square is implemented by mul. | 12 May 2020, 15:04:53 UTC |
28bc4b7 | George Necula | 12 May 2020, 07:09:42 UTC | Adjusted lax.numpy.indices test for older versions of numpy (#3053) This test was failing on numpy 1.16.4 | 12 May 2020, 07:09:42 UTC |
cc9de87 | George Necula | 12 May 2020, 07:06:32 UTC | Disabled lstsq test due to numerical failures (#3054) | 12 May 2020, 07:06:32 UTC |
a2d6b1a | George Necula | 12 May 2020, 06:06:22 UTC | Fix typo in lstsq (#3052) | 12 May 2020, 06:06:22 UTC |
db71f3c | Jake Vanderplas | 11 May 2020, 21:53:17 UTC | Initial implementation of np.linalg.lstsq() via SVD (#2744) | 11 May 2020, 21:53:17 UTC |