https://github.com/google/jax

Revision  Message  Commit date
4747205 Update jax version to 0.1.68 21 May 2020, 17:36:28 UTC
bb2127c Future-proof view test against signaling NaNs (#3178) 21 May 2020, 16:20:59 UTC
c459280 Added `associative_scan`. (#2170) * Added `associative_scan`. * Fixed problem where base case of associative scan could fail * remove jax.numpy dependence in associative_scan Co-authored-by: Matthew Johnson <mattjj@google.com> 21 May 2020, 15:35:17 UTC
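The idea behind `associative_scan` can be sketched in plain Python. Because `fn` is associative, adjacent pairs can be combined first, the half-length result scanned recursively, and the even-indexed prefixes filled in afterwards. This is a minimal sequential sketch of that odd/even recursion, not the `lax` implementation. The helper name `associative_scan_sketch` is hypothetical, and each list comprehension stands in for a step that a vectorized implementation would run in parallel.

```python
import operator
from typing import Callable, List, TypeVar

T = TypeVar("T")


def associative_scan_sketch(fn: Callable[[T, T], T], elems: List[T]) -> List[T]:
    """Inclusive prefix scan of `elems` under an associative `fn` (hypothetical helper)."""
    n = len(elems)
    if n < 2:
        return list(elems)
    # Combine adjacent pairs (e0, e1), (e2, e3), ...; a trailing odd element is left out.
    pairs = [fn(elems[i], elems[i + 1]) for i in range(0, n - 1, 2)]
    # Scanning the pairs yields every prefix that ends at an odd index.
    odd_prefixes = associative_scan_sketch(fn, pairs)
    result = [elems[0]]
    for i in range(1, n):
        if i % 2 == 1:
            # The prefix ending at index 2k+1 is odd_prefixes[k].
            result.append(odd_prefixes[i // 2])
        else:
            # The prefix ending at index 2k extends the prefix ending at 2k-1.
            result.append(fn(odd_prefixes[i // 2 - 1], elems[i]))
    return result


print(associative_scan_sketch(operator.add, [1, 2, 3, 4, 5]))  # [1, 3, 6, 10, 15]
print(associative_scan_sketch(operator.add, ["a", "b", "c"]))  # ['a', 'ab', 'abc']
```

The string example shows that only associativity, not commutativity, is required of `fn`.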
5c1de28 revise vmap in_axes/out_axes leaf type error msg (#3179) from #3161 discussion 21 May 2020, 15:00:18 UTC
eb81d7e add dict in_axes example to vmap docstring (#3176) * add dict in_axes example to vmap docstring fixes #3161 * fix typo 21 May 2020, 13:47:02 UTC
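As an illustration of a dict-valued `in_axes`, the following sketch maps over one entry of a dictionary argument while broadcasting another. The function `scale` and its keys are hypothetical; the example assumes a current JAX install, where an `in_axes` entry may be a pytree that mirrors the structure of the corresponding argument.

```python
import jax
import jax.numpy as jnp


def scale(params):
    # `params` is a dict pytree; `w` is shared, `x` carries the batch dimension.
    return params["w"] * params["x"]


# in_axes is a tuple with one entry per positional argument; the dict entry
# mirrors `params`, mapping `x` over axis 0 and broadcasting `w` (axis None).
batched_scale = jax.vmap(scale, in_axes=({"w": None, "x": 0},))

out = batched_scale({"w": 2.0, "x": jnp.arange(3.0)})
print(out)  # [0. 2. 4.]
```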
6e3c8b1 Fix arr.view() on TPU & improve tests (#3141) 21 May 2020, 13:40:24 UTC
ae9d175 fix while_loop cond function batching (#3174) * fix while_loop cond function batching fixes #3164 * add test for #3164 21 May 2020, 03:21:41 UTC
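A case that exercises the batching of a `while_loop` condition is a loop whose predicate depends on a mapped argument, so each batch element needs a different trip count. The sketch below shows the expected semantics, assuming a current JAX install; the function names are hypothetical. Under `vmap` the loop keeps running until every element's predicate is false, and elements that finish early stop being updated.

```python
import jax
import jax.numpy as jnp
from jax import lax


def double_n_times(n):
    # Carry is (counter, value); the loop predicate depends on the batched `n`.
    def cond_fun(carry):
        i, _ = carry
        return i < n

    def body_fun(carry):
        i, value = carry
        return i + 1, value * 2

    _, value = lax.while_loop(cond_fun, body_fun, (0, 1))
    return value


result = jax.vmap(double_n_times)(jnp.array([1, 2, 3]))
print(result)  # [2 4 8]
```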
f9c978e improve docstring of jax.numpy.broadcast_to (#3173) thanks @joaogui1 ! 21 May 2020, 02:09:54 UTC
a4094f7 revise "Tracer with raw numpy" error message (#3160) * revise "Tracer with raw numpy" error message fixes #3133 * fix f-string typo * fix typo Co-authored-by: James Bradbury <jekbradbury@google.com> 21 May 2020, 02:09:44 UTC
12f26d3 Improve ``devices`` and related documentation (#3155) 20 May 2020, 21:40:28 UTC
8b32c5d Avoid running trivial jitted subcomputations in pe (#3169) 20 May 2020, 21:30:33 UTC
f349979 Relax tolerance of LaxVmapTest.testDot for float64 inputs. (#3167) 20 May 2020, 20:17:17 UTC
7d157c7 onp -> np (#3157) 20 May 2020, 04:43:48 UTC
42b425d fix disable_jit logic in lax.cond and lax.while_loop (#3156) * fix disable_jit logic in lax.cond, fixes #3093 * fix disable_jit logic in lax.while_loop, fix #2823 * add test for issue #3093 * add test for #2823 * add test for #2598 20 May 2020, 01:14:10 UTC
3141ff8 Add lax implementation of np.isin() and np.in1d() (#3087) 19 May 2020, 23:58:42 UTC
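One way to express `isin` with array-level primitives is an all-pairs comparison followed by a reduction, which needs an N×M intermediate but no sorting. The sketch below is an assumption about the general approach, not a copy of the code from #3087; `isin_sketch` is a hypothetical name, and the comparison with `jnp.isin` assumes a current JAX install.

```python
import jax.numpy as jnp


def isin_sketch(element, test_elements):
    """Hypothetical broadcasting isin: compare every element against every
    test value, then reduce with `any` over the test axis."""
    element = jnp.asarray(element)
    test_elements = jnp.ravel(jnp.asarray(test_elements))
    return (element[..., None] == test_elements).any(axis=-1)


x = jnp.array([[1, 2], [3, 4]])
print(isin_sketch(x, [2, 4]))           # [[False  True] [False  True]]
print(jnp.isin(x, jnp.array([2, 4])))   # same result from the library function
```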
ccb203c improve pmap unbound axis error, fixes #3120 (#3152) 19 May 2020, 22:51:07 UTC
b7ff305 fix broken TPU test (#3153) 19 May 2020, 22:50:54 UTC
bc47a32 make lax.psum promote bool -> int (#3150) * make lax.psum promote bool -> int, fixes #3123 * fix test bug * fix typo in test 19 May 2020, 22:41:03 UTC
83a339e add erf and erfc rules (#3051) refactor def comp 19 May 2020, 22:22:25 UTC
b6777c0 Correct calculation of loss and increase learning rate (#3113) 19 May 2020, 22:18:34 UTC
850f1af improve errors for complex derivs, fixes #3121 (#3149) 19 May 2020, 22:17:03 UTC
8fe2619 Expand type support for random uniform() & randint() (#3138) 19 May 2020, 21:19:00 UTC
73b76e9 Exported lax from jax/__init__.py (#3135) This allows lax functions to be used without separately importing jax.lax. 19 May 2020, 19:40:03 UTC
77e3132 Fix indentation for docstrings in jax.experimental.host_callback (#3119) 19 May 2020, 15:23:45 UTC
85fe5a2 Add gradients to the scatter_max and scatter_min operations. (#3111) This is being done to allow the creation of a differentiable segment_max. Segment_max is an important operation for GraphNets and is an open feature request at https://github.com/google/jax/issues/2255 Co-authored-by: Alex Davies <adavies@google.com> 19 May 2020, 06:06:32 UTC
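A segment max is a scatter-max into a buffer initialized with -inf, so once `scatter_max` has a gradient rule, the gradient of a segment max flows back to each segment's maximal entry. The sketch below uses the `.at[...].max()` indexed-update syntax of current JAX (an assumption about the reader's version); `segment_max_sketch` is a hypothetical helper, and the inputs deliberately avoid ties, whose gradient split is left to the library's own rule.

```python
import jax
import jax.numpy as jnp


def segment_max_sketch(data, segment_ids, num_segments):
    # Scatter-max each value into its segment, starting from -inf.
    init = jnp.full((num_segments,), -jnp.inf, dtype=data.dtype)
    return init.at[segment_ids].max(data)


data = jnp.array([1.0, 5.0, 3.0, 2.0])
segment_ids = jnp.array([0, 0, 1, 1])
print(segment_max_sketch(data, segment_ids, 2))  # [5. 3.]

# Differentiating through the scatter-max: only each segment's maximum gets gradient.
grads = jax.grad(lambda d: segment_max_sketch(d, segment_ids, 2).sum())(data)
print(grads)  # [0. 1. 1. 0.]
```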
8d0749f Fix a corner case in `repeat`. (#3117) * Fixes a corner case: `jnp.repeat(jnp.array(0), 1, axis=0)` throws an error, whereas `np.repeat(np.array(0), 1, axis=0) = np.array([0])`. * Add test for `np.repeat(np.array(0), 1, axis=0)`. 19 May 2020, 02:24:45 UTC
d53bab9 Improve sharded_jit error message and fix test (#3145) 19 May 2020, 01:24:33 UTC
083cdd3 Fix in pxla._inner_partitions (#3146) In cb77f2a22de49e85da93f43b7dc448aa238d5207, I switched to looking for sharding_constraint_p's name since sharding_constraint_p itself is defined in sharded_jit.py, but didn't quite get the update right. 19 May 2020, 01:18:06 UTC
cb77f2a Move some sharded_jit functionality into pxla.py (#3142) Specifically: * Move `_inner_partitions` * Move `get_num_partitions` * Move and slightly modify the logic for finding and validating inner partitions to a new function, `reconcile_num_partitions` * Move `_partitioned_sharding_spec` and rename to `partitioned_sharding_spec` This is in preparation for enabling pmap-of-sharded_jit, since pmap will need access to this functionality as well. 18 May 2020, 23:37:49 UTC
9152b76 Add with_sharding_constraint method to be used within sharded_jit. (#3100) See the with_sharding_constraint docstring for a description of what this method does. Depending on how we decide nested sharded_jits should work, an alternative implementation for with_sharding_constraint could be:
```python
def with_sharding_constraint(x, partitions):
  # Apply an identity function under sharded_jit so that x is constrained to `partitions`.
  return sharded_jit(lambda x: x, in_parts=partitions, out_parts=partitions)(x)
```
In this case, we could get rid of the with_sharding_constraint primitive, and possibly even the API. This implementation gets the job done for now without committing to a nested sharded_jit behavior, and is also much easier to take the gradient of than sharded_jit. 18 May 2020, 22:20:49 UTC
4603432 Fix gradient for jax.scipy.ndimage.map_coordinates (#3110) * Fix gradient for jax.scipy.ndimage.map_coordinates Fixes GH3024 * minor refactor for clarification 18 May 2020, 22:13:54 UTC
36e7fad Add a primitive integer_pow() for values raised to a fixed integer scalar. (#3140)
* Add a primitive integer_pow() for values raised to a fixed integer scalar. Use integer_pow() in the RHS JVP of div(). Also use it in square() and reciprocal(). Fixes #3136
```
In [1]: from jax import grad, make_jaxpr

In [2]: def inv(x): return 1/x

In [3]: print(grad(grad(grad(grad(grad(grad(inv))))))(4.))
0.043945312

In [4]: make_jaxpr(grad(grad(grad(grad(grad(grad(inv)))))))(4.)
Out[4]:
{ lambda  ; a.
  let b = integer_pow[ y=-7 ] a
      c = mul -6.0 b
      d = mul -120.0 c
  in (d,) }
```
* Use x ** 3 in gelu definition. 18 May 2020, 21:54:20 UTC
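To see why the sixth derivative of `1/x` reduces to a single `integer_pow[ y=-7 ]` times a constant, note that the JVP of `x**y` for a fixed integer `y` is `y * x**(y - 1)`, which is again an integer power; the jaxpr constants -6.0 and -120.0 multiply to 720. The plain-Python sketch below applies that rule six times and reproduces the value printed above. The helper names are hypothetical, and binary exponentiation is shown as one common way to lower such a power to multiplications, assumed here rather than taken from the JAX source.

```python
def integer_pow(x: float, y: int) -> float:
    """x**y for an integer y via binary exponentiation (square-and-multiply)."""
    if y == 0:
        return 1.0
    n = abs(y)
    acc = None
    base = x
    while n:
        if n & 1:
            acc = base if acc is None else acc * base
        n >>= 1
        if n:
            base = base * base
    # A negative exponent is the reciprocal of the positive power.
    return 1.0 / acc if y < 0 else acc


def nth_derivative_of_power(y: int, n: int):
    """Return (coeff, exponent) such that d^n/dx^n x**y == coeff * x**exponent."""
    coeff, exponent = 1, y
    for _ in range(n):
        # JVP rule: d/dx x**k = k * x**(k - 1)
        coeff *= exponent
        exponent -= 1
    return coeff, exponent


coeff, exponent = nth_derivative_of_power(-1, 6)
print(coeff, exponent)                     # 720 -7
print(coeff * integer_pow(4.0, exponent))  # 0.0439453125
```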
a9c1b38 Added link to README (#3139) 18 May 2020, 21:12:52 UTC
ed0e227 Add sharded_jit translation rule. (#3099) This is potentially dangerous, because it lets sharded_jit() be called inside other call primitives (e.g. jit, pmap), which isn't supported yet. I'm adding it now because I'm planning to implement pmap-of-sharded_jit soon, and it will help with testing a set_sharding API I'm also planning to add soon. 18 May 2020, 18:39:55 UTC
bc5a0b3 Remove some uses of jax.partial. (#3131) 18 May 2020, 14:19:03 UTC
b071b12 Fixed link in FAQ (#3129) 18 May 2020, 13:02:49 UTC
2be471c Remove redundant flag configuration from lax_scipy_sparse_test (#3128) 18 May 2020, 05:12:21 UTC
912b282 Implement np.histogram_bin_edges and np.histogram (#3081) 16 May 2020, 17:23:26 UTC
670fab5 Test code in docs and api.py docstrings (#2994) Also remove jaxpr doc tests from api_test.py. 16 May 2020, 13:19:24 UTC
510af1d Fix documentation for `nn.elu`, `nn.celu`, and `lax.expm1`. (#3116) 16 May 2020, 03:51:53 UTC
e675f80 Add support for 8- and 16-bit output in _random_bits (#3090) 16 May 2020, 02:09:43 UTC
12a9af8 Update random.logistic() to prevent infinities (#3048) 15 May 2020, 21:29:02 UTC
812df27 Update uses of deprecated XLA methods. (#3109) 15 May 2020, 19:51:07 UTC
450b7ab Merge pull request #2993 from google/cond single-operand cond 15 May 2020, 02:38:43 UTC
b19166c silence a pytype error in jet tracing 15 May 2020, 01:01:34 UTC
28da2bc improve two-operand cond detection 15 May 2020, 00:21:57 UTC
77703b8 Add support for sorting complex values, defaulting to a NumPy-style l… (#3096) * Add support for sorting complex values, defaulting to a NumPy-style lexicographic ordering. Implemented using a custom comparator, since the XLA-level default comparator doesn't impose an ordering for complex values. * Disable sort test on CPU and TPU. 14 May 2020, 23:17:44 UTC
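The ordering that the custom comparator imposes is the lexicographic rule NumPy uses for complex arrays: compare real parts first, then break ties on imaginary parts. A plain-Python sketch of that rule follows; the helper name is hypothetical, and it relies on Python's tuple comparison rather than an XLA comparator.

```python
def lexicographic_complex_sort(values):
    """Sort complex numbers NumPy-style: by real part, ties broken by imaginary part."""
    return sorted(values, key=lambda z: (z.real, z.imag))


print(lexicographic_complex_sort([1 + 2j, 0 + 5j, 1 - 1j, -3 + 0j]))
# [(-3+0j), 5j, (1-1j), (1+2j)]
```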
85d7b30 Allow non-classes as second argument of `issubclass`. (#3098) 14 May 2020, 23:17:23 UTC
dceb578 stash order on jet master trace, fixes #3079 (#3097) 14 May 2020, 23:06:20 UTC
9d8ecc5 avoid committing to argument types so long as cond is overloaded 14 May 2020, 20:56:07 UTC
007cdf2 Adds additional epsilon to adam for numerical stability. (#3091) * Adds additional epsilon to adam for numerical stability. Meta-gradients through the adam optimizer diverge, because the derivative of the adam scaling with respect to the gradients gets an additional 1/sqrt(g) factor. This additional factor is unregularized without the second epsilon added in this commit. * Renames eps2 to eps_root and improves docstring Co-authored-by: Thomas Keck <thomaskeck@google.com> 14 May 2020, 20:04:08 UTC
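The argument rests on the derivative of the square root: d/dv √v = 1/(2√v) is unbounded as v → 0, and an epsilon added outside the root does not change that derivative, whereas one added inside bounds it by 1/(2√eps_root). The scalar sketch below assumes the update has the form m̂ / (√(v̂ + eps_root) + eps), as in optax's `scale_by_adam`; the helper names are hypothetical and this is not the code from #3091.

```python
import math


def adam_step(g, m, v, t, step_size=1e-3, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """One scalar Adam step (hypothetical helper). Returns (delta, m, v).

    Assumes eps_root is added inside the square root and eps outside it.
    """
    m = b1 * m + (1 - b1) * g      # first-moment estimate
    v = b2 * v + (1 - b2) * g * g  # second-moment estimate
    m_hat = m / (1 - b1 ** t)      # bias corrections
    v_hat = v / (1 - b2 ** t)
    delta = -step_size * m_hat / (math.sqrt(v_hat + eps_root) + eps)
    return delta, m, v


def sqrt_denominator_slope(v_hat, eps_root):
    """d/dv_hat of sqrt(v_hat + eps_root); eps outside the root does not appear."""
    return 1.0 / (2.0 * math.sqrt(v_hat + eps_root))


delta, m, v = adam_step(g=1.0, m=0.0, v=0.0, t=1)
print(delta)                                       # approximately -1e-3
print(sqrt_denominator_slope(0.0, eps_root=1e-8))  # about 5000, i.e. bounded
# With eps_root=0.0 the slope at v_hat=0 divides by zero.
```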
83c7394 fix astype() test (#3072) 14 May 2020, 19:58:31 UTC
7c687b2 sharded_jit cleanup (#3075) * Add sharding utilities to xla_bridge * Change partitions argument to in_parts and out_parts * Add unit tests * Reuse more pxla functionality * Remove stale translation rule * Fail on non-TPU platforms * Add docstring * And more! 14 May 2020, 18:38:08 UTC
73a55e0 changelog: single-operand cond 14 May 2020, 16:16:49 UTC
930be58 update reference jaxpr using cond in outfeed rewriting test 14 May 2020, 16:05:31 UTC
cb5b1d1 handle single-operand cond in token-threading rewriter 14 May 2020, 16:04:48 UTC
9f6774a update uses of cond in host callback test 14 May 2020, 16:02:29 UTC
4ce2aa2 Make lax.sort support tuple arguments using a variadic sort. (#3085) * Make lax.sort support tuple arguments using a variadic sort. Change sort_jvp to use a gather of ids to compute the JVP rather than sorting repeatedly. Remove sort_key_val_p, since it is redundant with a variadic sort_p. * Fix mypy errors. * Change JVP rule to use NumPy indexing. Remove redundant case in batching rule. 14 May 2020, 15:13:15 UTC
6e3bfc3 comment on joining constants for conditional branch jaxprs 14 May 2020, 04:14:41 UTC
2950eb1 comment on joining staged jaxprs in partial evaluation of conditionals 14 May 2020, 04:14:41 UTC
34efdd9 style changes in lax.cond 14 May 2020, 04:14:41 UTC
e8f12b6 remove deprecation warning on five-argument cond 14 May 2020, 04:14:41 UTC
76612c6 test and fix cond when branch-staging is off 14 May 2020, 04:14:41 UTC
de03c99 update jaxpr doc and tests with single-operand cond 14 May 2020, 04:14:41 UTC
f90bd4f move cond operand to final argument position 14 May 2020, 04:14:41 UTC
cc6cea2 update reference jaxpr in cond-related jaxpr test 14 May 2020, 04:14:41 UTC
df62279 cache jaxprs formed by cond branch staging 14 May 2020, 04:14:41 UTC
efc1104 have loops module generate same-argument jaxprs for single-operand cond 14 May 2020, 04:14:41 UTC
30bb5fd transpose rule for single-operand cond primitive 14 May 2020, 04:14:41 UTC
37bb6f0 partial evaluation rule for single-operand cond primitive 14 May 2020, 04:14:41 UTC
139c2a9 fix cond-with-constants tests to use jax.numpy where needed 14 May 2020, 04:14:41 UTC
34d75c2 JVP rule for single-operand cond primitive 14 May 2020, 04:14:41 UTC
226b582 batching rule for single-operand cond primitive 14 May 2020, 04:14:41 UTC
fc4ab77 merge constvars when forming cond branch jaxprs 14 May 2020, 04:14:41 UTC
027a539 translation rule for single-operand cond primitive 14 May 2020, 04:14:37 UTC
97738be bind a single-operand cond primitive and update jaxpr typechecks 14 May 2020, 04:13:18 UTC
28e698e update uses of cond in lax control flow tests 14 May 2020, 04:13:12 UTC
9217641 take a single operand in lax.cond and deprecate the old calling convention 14 May 2020, 04:11:25 UTC
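Under the new convention, `lax.cond` takes one operand that is passed to whichever branch runs, replacing the deprecated five-argument form with a separate operand per branch. A minimal sketch, assuming a current JAX install (which also accepts several operands after the branch functions); the function name is hypothetical.

```python
from jax import lax


def double_if_positive_else_negate(x):
    # Single-operand form: both branches receive the same operand `x`.
    return lax.cond(x > 0, lambda v: v * 2.0, lambda v: -v, x)


print(double_if_positive_else_negate(3.0))   # 6.0
print(double_if_positive_else_negate(-2.0))  # 2.0
```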
16cf845 Update jaxlib version to 0.1.47 in README 14 May 2020, 00:47:37 UTC
6bd0160 disable arr.view() test on TPU (#3089) * disable arr.view() test on TPU * Update lax_numpy_test.py Use decorator. * Update lax_numpy_test.py Co-authored-by: Peter Hawkins <hawkinsp@cs.stanford.edu> 14 May 2020, 00:31:02 UTC
777636a promote integer inputs to float in jnp.median() and jnp.quantile() (#3082) 13 May 2020, 21:36:46 UTC
59aab01 Implement .view() method of jax.numpy arrays (#3073) 13 May 2020, 19:48:16 UTC
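`.view()` reinterprets the underlying bits under a new dtype rather than converting values. That is presumably also why round-tripping arbitrary integer bit patterns through a float dtype can be fragile when a pattern happens to be a signaling NaN, the concern behind #3178. A minimal sketch, assuming a current JAX install:

```python
import jax.numpy as jnp

# Reinterpret the 32 bits of float32 values as unsigned integers (no conversion).
x = jnp.array([1.0, -2.0], dtype=jnp.float32)
bits = x.view(jnp.uint32)
print([hex(b) for b in bits.tolist()])  # ['0x3f800000', '0xc0000000']

# Viewing back recovers the original values bit for bit.
roundtrip = bits.view(jnp.float32)
print(roundtrip.tolist())  # [1.0, -2.0]
```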
22d14fd Remove workaround for Mac linear algebra bug that is fixed in the minimum jaxlib version. (#3080) 13 May 2020, 18:00:44 UTC
91d1e0d Disable trapz test on TPU. (#3078) 13 May 2020, 14:59:31 UTC
e9d3394 Make jnp.array convert empty list to DeviceArray (#3049) * Make jnp.array convert empty list to DeviceArray * Add additional tests for empty classes with __array__ Co-authored-by: Peter Hawkins <phawkins@google.com> 13 May 2020, 06:08:14 UTC
d2f84d6 Change instances of onp to np and np to jnp (#3044) 13 May 2020, 00:37:05 UTC
cd966f2 Disable check_type for trapz test due to test failures. (#3071) 12 May 2020, 22:45:21 UTC
abdf504 Avoid recompilation of rolled loops in threefry2x32. (#3069) 12 May 2020, 22:03:22 UTC
11760ca Refactor aval_to_result_handler to take unsharded aval. (#3067) This is in preparation for calling it from sharded_jit. Currently aval_to_result_handler is specific to pmap, but this change makes it work for any kind of sharding. 12 May 2020, 21:34:31 UTC
4a84b91 lower in_axes=None to XLA replication annotation (#3025) * lower in_axes=None to XLA replication annotation * ignore replicated value for tokens 12 May 2020, 21:34:23 UTC
88d3422 Add special case for integer scalars to jax.numpy.power. (#3066) * Add special case for integer scalars to jax.numpy.power. 12 May 2020, 21:19:09 UTC
96fbfee Add lax implementation of np.trapz (#3042) 12 May 2020, 21:12:03 UTC
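For reference, the trapezoidal rule behind `np.trapz` sums, over consecutive samples, the interval width times the average of the two endpoint values; with no `x` given, every width is `dx`. The one-dimensional plain-Python sketch below illustrates the formula only; `trapz_sketch` is a hypothetical name, and it ignores the `axis` handling of the real function.

```python
def trapz_sketch(y, x=None, dx=1.0):
    """Trapezoidal rule: sum of (y[i] + y[i+1]) / 2 * (x[i+1] - x[i])."""
    if x is None:
        widths = [dx] * (len(y) - 1)
    else:
        widths = [b - a for a, b in zip(x[:-1], x[1:])]
    return sum(w * (a + b) / 2 for w, a, b in zip(widths, y[:-1], y[1:]))


print(trapz_sketch([1, 2, 3]))               # 4.0
print(trapz_sketch([1, 2, 3], x=[0, 1, 3]))  # 6.5
```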
ef4debc Update jax version to 0.1.67 (#3065) 12 May 2020, 18:17:22 UTC
0d97c3b Import tpu_driver after xla_client (#3064) This is a workaround until we build a new jaxlib with https://github.com/tensorflow/tensorflow/commit/f4628678066c72309d3fd121af1aaf54d9905ca3 12 May 2020, 18:05:03 UTC
8008aa9 Fix error message in optimizers piecewise_constant. (#3061) 12 May 2020, 15:09:02 UTC
ccb8d45 Uses jnp.square instead of power. (#3036) * Uses multiplication instead of power. * Uses jnp.square instead of mul and adds check if jnp.square is implemented by mul. 12 May 2020, 15:04:53 UTC
28bc4b7 Adjusted lax.numpy.indices test for older versions of numpy (#3053) This test was failing on numpy 1.16.4 12 May 2020, 07:09:42 UTC
cc9de87 Disabled lstsq test due to numerical failures (#3054) 12 May 2020, 07:06:32 UTC
a2d6b1a Fix typo in lstsq (#3052) 12 May 2020, 06:06:22 UTC
db71f3c Initial implementation of np.linalg.lstsq() via SVD (#2744) 11 May 2020, 21:53:17 UTC
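For reference, the SVD route to least squares: with a = U diag(s) Vᵀ, the minimum-norm solution is x = V diag(1/s) Uᵀ b, where singular values below a relative cutoff are treated as zero. The sketch below assumes `jnp.linalg.svd` returns singular values in descending order, as NumPy's does, and uses a NumPy-style default cutoff; `lstsq_svd` is a hypothetical name and not the function added in #2744.

```python
import jax.numpy as jnp


def lstsq_svd(a, b, rcond=None):
    """Minimum-norm least-squares solution of a @ x ≈ b via the SVD (a sketch)."""
    u, s, vt = jnp.linalg.svd(a, full_matrices=False)
    if rcond is None:
        # NumPy-style default: machine epsilon times the larger matrix dimension.
        rcond = jnp.finfo(s.dtype).eps * max(a.shape)
    # Invert only singular values above the cutoff; treat the rest as zero.
    s_inv = jnp.where(s > rcond * s[0], 1.0 / s, 0.0)
    return vt.T @ (s_inv * (u.T @ b))


# Fit y = c0 + c1 * t to points that lie exactly on y = 1 + 2t.
a = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 3.0, 5.0])
x = lstsq_svd(a, b)
print(x)  # approximately [1. 2.]
```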