1cf708e | Stephan Hoyer | 04 April 2020, 22:55:46 UTC | Support pytrees in jax.scipy.linalg.cg (#2600) * Support pytrees in jax.scipy.linalg.cg Ideally there would be an easier way to write this, but for now this will do. * Fixup test | 04 April 2020, 22:55:46 UTC |
2c4ced2 | Matthew Johnson | 04 April 2020, 03:18:21 UTC | Merge pull request #2601 from adarob/fix-err Fix error handling when an attempt is made to pmap a scalar. | 04 April 2020, 03:18:21 UTC |
48eb524 | Adam Roberts | 04 April 2020, 01:03:42 UTC | Fix error handling when an attempt is made to pmap a scalar. | 04 April 2020, 01:03:42 UTC |
c2f56fb | Matthew Johnson | 03 April 2020, 23:21:38 UTC | add notes to changelog | 03 April 2020, 23:21:38 UTC |
60de46a | Matthew Johnson | 03 April 2020, 22:47:41 UTC | Merge pull request #2591 from google/tracer-printing make tracers tree-pretty-print their contents | 03 April 2020, 22:47:41 UTC |
b67250e | Matthew Johnson | 03 April 2020, 22:21:45 UTC | Merge pull request #2599 from sharadmv/logit-fix Fix grad(logit) to use defjvps and enable it in tests | 03 April 2020, 22:21:45 UTC |
3c9fb35 | Sharad Vikram | 03 April 2020, 21:37:29 UTC | Fix dtype error | 03 April 2020, 21:37:29 UTC |
1b93bb5 | Stephan Hoyer | 03 April 2020, 20:37:11 UTC | Implement scipy.sparse.linalg.cg (second try) (#2566) * super minimal starter code * Update optimizers.py * implement flip with axis = None * Create sparse.py * fix some imports * Update sparse.py * add partial function & test * Update lax_scipy_sparse_test.py * Update lax_scipy_sparse_test.py * add a test case for sparse pd matrix & add bigger dim * address comments * fix info return & create matrix with rng_factory * Update lax_scipy_sparse_test.py * Update lax_scipy_sparse_test.py * Update sparse.py * Update sparse.py * Update sparse.py * Update lax_scipy_sparse_test.py * Update lax_scipy_sparse_test.py * cast jax arrays into numpy array for scipy compatibility * Update sparse.py * Update sparse.py * fix None issue, but algo is not working * fix return of build_and_solve and output of while_loop * fix condition func of while loop * clearer variable names * mismatch error * Update lax_scipy_sparse_test.py * Fixes to jax.experimental.sparse.cg * Fix tests for gradients * Add support for preconditioners to cg * Move cg into scipy, update docs * doc tweak Co-authored-by: Tuan Nguyen <anhtuan277@gmail.com> | 03 April 2020, 20:37:11 UTC |
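The commit above adds conjugate gradient with optional preconditioners. The sketch below is a pure-Python illustration of the standard preconditioned CG iteration. It is not the JAX implementation, and it does not use the `jax.scipy.sparse.linalg.cg` signature (which also returns an `info` value). The names `cg_sketch`, `tol`, and `precond` are hypothetical. It assumes a symmetric positive-definite operator supplied as a `matvec` function.

```python
from typing import Callable, List, Optional

Vector = List[float]
MatVec = Callable[[Vector], Vector]


def _dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def cg_sketch(matvec: MatVec, b: Vector, x0: Optional[Vector] = None,
              tol: float = 1e-10, maxiter: int = 100,
              precond: Optional[MatVec] = None) -> Vector:
    """Hypothetical preconditioned CG for a symmetric positive-definite operator."""
    x = list(x0) if x0 is not None else [0.0] * len(b)
    apply_m = precond if precond is not None else (lambda v: list(v))
    r = [bi - axi for bi, axi in zip(b, matvec(x))]  # initial residual b - A x
    z = apply_m(r)                                     # preconditioned residual
    p = list(z)                                        # first search direction
    rz = _dot(r, z)
    b_norm = _dot(b, b) ** 0.5 or 1.0                  # avoid a zero threshold when b == 0
    for _ in range(maxiter):
        if _dot(r, r) ** 0.5 <= tol * b_norm:          # relative residual test
            break
        ap = matvec(p)
        alpha = rz / _dot(p, ap)                       # step length along p
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, ap)]
        z = apply_m(r)
        rz_new = _dot(r, z)
        beta = rz_new / rz                             # keeps the new direction A-conjugate
        p = [zi + beta * pi for zi, pi in zip(z, p)]
        rz = rz_new
    return x


A = [[4.0, 1.0], [1.0, 3.0]]


def matvec(v: Vector) -> Vector:
    return [sum(a * vj for a, vj in zip(row, v)) for row in A]


def jacobi(v: Vector) -> Vector:
    return [v[0] / 4.0, v[1] / 3.0]  # diagonal (Jacobi) preconditioner


x = cg_sketch(matvec, [1.0, 2.0], precond=jacobi)  # exact solution is [1/11, 7/11]
```

Passing the operator as a function rather than a matrix is what lets CG work on sparse or implicit operators. The preconditioner is applied the same way.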
72783bb | Sharad Vikram | 03 April 2020, 20:27:02 UTC | Fix grad(logit) to use defjvps and enable it in tests | 03 April 2020, 20:27:02 UTC |
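The logit fix relies on the closed-form derivative d/dx logit(x) = 1 / (x (1 - x)). The pure-Python sketch below writes that rule as a JVP, returning the primal and the tangent, and checks it against a central finite difference. It is only an illustration of the math. It is not the `defjvps` registration used in JAX, and `logit_jvp` is a hypothetical name.

```python
import math


def logit(x: float) -> float:
    return math.log(x / (1.0 - x))


def logit_jvp(x: float, t: float):
    """Hypothetical JVP rule: (logit(x), t / (x (1 - x)))."""
    return logit(x), t / (x * (1.0 - x))


x0 = 0.3
primal, tangent = logit_jvp(x0, 1.0)
h = 1e-6
finite_diff = (logit(x0 + h) - logit(x0 - h)) / (2 * h)  # numerical check of the rule
```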
2b3beff | Peter Hawkins | 03 April 2020, 20:09:48 UTC | Make reduce_prod differentiable to arbitrary order. (#2597) * Make reduce_prod differentiable to arbitrary order. The previous strategy for computing the JVP of reduce_prod used a pair of reduce_window operations to form left and right products for each position. This PR instead builds an explicit reduction tree and differentiates through it, which, while not as efficient as using XLA's built-in reductions, has the advantage of being differentiable to arbitrary order. * Return the tree-reduction primals instead of returning the original primals in the JVP rule. | 03 April 2020, 20:09:48 UTC |
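The reduce_prod change builds an explicit reduction tree and differentiates through it. The pure-Python sketch below shows the idea: multiply adjacent pairs level by level, padding odd levels with 1. It then feeds in hypothetical forward-mode `Dual` numbers so that the product rule applies at every node of the tree. This is an illustration under those assumptions, not JAX's code.

```python
class Dual:
    """Hypothetical forward-mode number: a primal value paired with a tangent."""

    def __init__(self, primal: float, tangent: float = 0.0):
        self.primal = primal
        self.tangent = tangent

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: d(uv) = u dv + du v
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def tree_prod(xs):
    """Multiply xs by combining adjacent pairs, one tree level at a time."""
    level = list(xs)
    if not level:
        return 1
    while len(level) > 1:
        if len(level) % 2:
            level.append(1)  # pad odd-length levels with the multiplicative identity
        level = [level[i] * level[i + 1] for i in range(0, len(level), 2)]
    return level[0]


# Derivative of x0 * x1 * x2 with respect to x0 at (2, 3, 4) is 3 * 4 = 12.
out = tree_prod([Dual(2.0, 1.0), Dual(3.0), Dual(4.0)])
```

Each node is an ordinary multiplication, so the result can be differentiated again. The left/right reduce_window formulation the commit replaced did not have that property.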
824ac86 | Peter Hawkins | 03 April 2020, 19:39:56 UTC | Reimplement np.cumsum and np.cumprod in terms of a parallel prefix scan. (#2596) * Reimplement np.cumsum and np.cumprod in terms of a parallel prefix scan. Unlike the existing implementation based on lax.reduce_window, this implementation is O(n log n) instead of O(n^2) and is arbitrarily differentiable. Fixes #1212, #2418, #2542. May help with issue #2380. * Relax gradient test tolerance. | 03 April 2020, 19:39:56 UTC |
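The cumsum/cumprod rewrite uses a parallel prefix scan. The commit does not say which variant. The pure-Python sketch below assumes a Hillis-Steele style inclusive scan. It performs about log2(n) rounds, and each round combines every element with the one `step` positions to its left. That gives O(n log n) total work, matching the bound quoted above, with only logarithmic depth. `inclusive_scan` is a hypothetical name, and `op` must be associative.

```python
import operator
from typing import Callable, List, TypeVar

T = TypeVar("T")


def inclusive_scan(xs: List[T], op: Callable[[T, T], T]) -> List[T]:
    """Hillis-Steele inclusive scan; op must be associative."""
    out = list(xs)
    step = 1
    while step < len(out):
        # Every position reads only from the previous round, so a round is fully parallel.
        out = [out[i] if i < step else op(out[i - step], out[i]) for i in range(len(out))]
        step *= 2
    return out


cumsum = inclusive_scan([1, 2, 3, 4, 5], operator.add)
cumprod = inclusive_scan([1, 2, 3, 4, 5], operator.mul)
```

The left operand always covers the earlier elements. As a result, the sketch also works for associative operations that are not commutative.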
192e908 | Matthew Johnson | 03 April 2020, 06:53:39 UTC | Merge pull request #2561 from madisonmay/batch-norm-no-scale-or-center FIX: batch norm w/ no scale, center | 03 April 2020, 06:53:39 UTC |
6402419 | Matthew Johnson | 03 April 2020, 06:52:40 UTC | Merge pull request #2563 from google/callable-typechecks add callable typechecks to more api.py functions | 03 April 2020, 06:52:40 UTC |
25aeb80 | Matthew Johnson | 03 April 2020, 06:42:31 UTC | Merge pull request #2592 from google/travis-mypy add trace state check tearDown to JaxTestCase | 03 April 2020, 06:42:31 UTC |
3dee689 | Matthew Johnson | 03 April 2020, 06:28:57 UTC | Merge pull request #2593 from google/issue2578 add full lower to custom_jvp/vjp call bind | 03 April 2020, 06:28:57 UTC |
ba8225f | Matthew Johnson | 03 April 2020, 06:11:55 UTC | skip all parallelize tests (abandonware right now) | 03 April 2020, 06:11:55 UTC |
0e49133 | Matthew Johnson | 03 April 2020, 05:52:07 UTC | add full lower to custom_jvp/vjp call bind fixes #2578 | 03 April 2020, 05:54:29 UTC |
f2de1bf | Matthew Johnson | 03 April 2020, 05:01:43 UTC | add trace state check tearDown to JaxTestCase | 03 April 2020, 05:01:43 UTC |
297c902 | Matthew Johnson | 03 April 2020, 04:04:12 UTC | make tracers tree-pretty-print their contents | 03 April 2020, 04:04:12 UTC |
64a7d17 | Matthew Johnson | 03 April 2020, 04:02:13 UTC | Merge pull request #2587 from google/travis-mypy re-enable travis mypy testing (typo broke it) | 03 April 2020, 04:02:13 UTC |
5d3f1bd | Matthew Johnson | 03 April 2020, 03:14:12 UTC | tell mypy: using __init__ to reinitialize is OK | 03 April 2020, 03:14:12 UTC |
6d4987c | Matthew Johnson | 03 April 2020, 01:19:44 UTC | make core.trace_state resetting be thread-local | 03 April 2020, 01:19:44 UTC |
b78b7a0 | Matthew Johnson | 03 April 2020, 01:03:58 UTC | add global trace state checks to more tests | 03 April 2020, 01:03:58 UTC |
ab0a005 | Matthew Johnson | 03 April 2020, 00:18:47 UTC | check sublevel is reset in loops_test.py | 03 April 2020, 00:18:47 UTC |
c72abf6 | Matthew Johnson | 02 April 2020, 22:52:01 UTC | re-enable travis mypy testing (typo broke it) | 02 April 2020, 22:52:01 UTC |
5c0ac40 | Skye Wanderman-Milne | 02 April 2020, 19:54:58 UTC | Revert jax.numpy.matmul to pre-#2512 version. (#2584) https://github.com/google/jax/pull/2512 was causing some Google-internal tests to take longer. | 02 April 2020, 19:54:58 UTC |
0bdd0f6 | Matthew Johnson | 02 April 2020, 16:03:12 UTC | Merge pull request #2581 from google/jet-process-call post process call of jet! | 02 April 2020, 16:03:12 UTC |
84dc6cc | Matthew Johnson | 02 April 2020, 14:52:17 UTC | post process call of jet! Also included David's jet rule for lax.select. Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu> | 02 April 2020, 14:56:26 UTC |
32c45fb | George Necula | 02 April 2020, 11:10:50 UTC | Another attempt to disable new failing test on TPU | 02 April 2020, 11:11:45 UTC |
0c53ce9 | George Necula | 02 April 2020, 09:13:40 UTC | Disable test with float16 on TPU | 02 April 2020, 09:15:25 UTC |
1bb9aaa | Skye Wanderman-Milne | 01 April 2020, 23:33:37 UTC | Revert "Refactor shard_args ShardedDeviceArray slow path. (#2562)" (#2577) This reverts commit 2e87674aff0609d58eed07091d8e9fad8ee71d6c. This refactoring was broken, but we apparently have no test coverage! mypy/linting would have caught it though. | 01 April 2020, 23:33:37 UTC |
2e87674 | Skye Wanderman-Milne | 01 April 2020, 22:35:26 UTC | Refactor shard_args ShardedDeviceArray slow path. (#2562) | 01 April 2020, 22:35:26 UTC |
2d42c42 | Peter Hawkins | 01 April 2020, 21:18:17 UTC | Fix missing type promotion accidentally removed by #2512. (#2575) This is in fact covered by the existing tests, but we were unlucky and didn't hit them in the set of generated tests we selected. | 01 April 2020, 21:18:17 UTC |
8c4a938 | Tzu-Wei Sung | 01 April 2020, 19:29:48 UTC | Implement np.ldexp and np.frexp. (#1529) Co-authored-by: Peter Hawkins <phawkins@google.com> | 01 April 2020, 19:29:48 UTC |
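`frexp` and `ldexp` are inverses. `frexp(x)` splits a float into a mantissa m, with |m| in [0.5, 1) for nonzero finite x, and an integer exponent e, where x = m * 2**e. `ldexp(m, e)` rebuilds x. The sketch below shows this convention with Python's standard `math.frexp` and `math.ldexp`, which follow the same convention as NumPy's versions. `split_and_rebuild` is a hypothetical helper.

```python
import math


def split_and_rebuild(x: float):
    """Split x into mantissa * 2**exponent, then rebuild it with ldexp."""
    mantissa, exponent = math.frexp(x)
    return mantissa, exponent, math.ldexp(mantissa, exponent)


m, e, rebuilt = split_and_rebuild(-5.5)  # -5.5 == -0.6875 * 2**3
```

Scaling by a power of two is exact in binary floating point, so the round trip reproduces x bit for bit.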
79ce698 | Peter Hawkins | 01 April 2020, 16:35:12 UTC | Fix incorrect raise in np.clip and add a test. (#2571) | 01 April 2020, 16:35:12 UTC |
651316f | Xiayun Sun | 01 April 2020, 15:57:57 UTC | Fix issue 1465: fix jit in example (#1473) * fix jit in example * Avoid using static_argnums on a keyword argument; use a positional argument and a wrapper function for now. Co-authored-by: Peter Hawkins <phawkins@google.com> | 01 April 2020, 15:57:57 UTC |
86a4073 | Matthew Johnson | 01 April 2020, 15:25:32 UTC | enable beta test on float64 values (#1177) * enable beta test on float64 values cf. #1123 * Enable beta test on all platforms. It seems sufficiently fast now. Co-authored-by: Peter Hawkins <phawkins@google.com> | 01 April 2020, 15:25:32 UTC |
c758aff | Peter Hawkins | 01 April 2020, 14:54:47 UTC | Fix some missing cases of broadcasting in np.einsum. (#2512) * Fix some missing cases of broadcasting in np.einsum. In particular, np.einsum allows one side of a batch or contracting dimension to have size 1 even if the other side has a non-1 size. Implement np.matmul in terms of np.einsum. This allows us to reuse einsum's logic for performing broadcasting without explicitly broadcasting the LHS and RHS together. * Add regression test. Fixes #2189. | 01 April 2020, 14:54:47 UTC |
1d09b6b | Tom Hennigan | 01 April 2020, 13:58:17 UTC | Replace CHECK with assert. (#2569) | 01 April 2020, 13:58:17 UTC |
ca23be6 | Tom Hennigan | 28 March 2020, 13:14:40 UTC | Add `jax.tree_util.all_leaves(iterable)`. In Haiku (https://github.com/deepmind/dm-haiku) we have `FlatMapping` which is an immutable Mapping subclass maintaining a flat internal representation. Our goal is to allow very cheap flatten/unflatten since these objects are used to represent parameters/state and are often passed in and out of JAX functions that flatten their inputs (e.g. jit/pmap). One challenge we have is that on unflatten we need a fast way of testing whether the list of leaves provided are flat or not (since we want to cache both the flat structure and the leaves). Consider the following case: ```python d = FlatMapping.from_mapping({"a": 1}) # Caches the result of jax.tree_flatten. l, t = jax.tree_flatten(d) # Fine, leaves are flat. l = list(map(lambda x: (x, x), l)) # leaves are no longer flat. d2 = jax.tree_unflatten(t, l) # Needs to recompute structure. jax.tree_leaves(d2) # Should return [1, 1] not [(1, 1)] ``` Actual implementation here: https://github.com/deepmind/dm-haiku/blob/d37b486e09696ef34f7396c11b04074cb73a963c/haiku/_src/data_structures.py#L204-L208 This function allows an efficient way to do this using the JAX public API. | 01 April 2020, 07:56:01 UTC |
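The semantics described above are that `all_leaves(iterable)` is true exactly when every element is itself a leaf, meaning it flattens to a one-element list containing itself. The pure-Python sketch below illustrates this rule. It uses a hypothetical, much simplified flattener: lists, tuples, and dicts are nodes, `None` has no leaves, and everything else is a leaf. It is not JAX's registry-based implementation.

```python
from typing import Any, Iterable, List


def tree_leaves_sketch(tree: Any) -> List[Any]:
    """Hypothetical simplified flattening; not jax.tree_util.tree_leaves."""
    if tree is None:
        return []
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in tree_leaves_sketch(child)]
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in tree_leaves_sketch(tree[key])]
    return [tree]


def all_leaves_sketch(iterable: Iterable[Any]) -> bool:
    """True iff every element flattens to exactly itself."""
    for x in iterable:
        leaves = tree_leaves_sketch(x)
        # Identity (not ==) avoids elementwise comparison surprises for array-like leaves.
        if len(leaves) != 1 or leaves[0] is not x:
            return False
    return True


leaves = [1]
mapped = [(x, x) for x in leaves]  # as in the FlatMapping example: no longer flat
```

In the FlatMapping scenario, a cached structure can be reused when the check passes. When it fails, as it does for `mapped`, the structure has to be recomputed.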
025d874 | Rick Wierenga | 01 April 2020, 07:01:59 UTC | Fix error message | 01 April 2020, 07:54:37 UTC |
d2a827a | George Necula | 31 March 2020, 08:01:19 UTC | Ensure the global trace_state is restored on errors in loops This is an attempted fix for https://github.com/google/jax/issues/2507 | 01 April 2020, 07:23:14 UTC |
83b9575 | Matthew Johnson | 01 April 2020, 01:46:15 UTC | add callable typechecks to more api.py functions | 01 April 2020, 01:46:15 UTC |
e3a9a56 | Skye Wanderman-Milne | 01 April 2020, 00:47:54 UTC | Remove unused argument from pxla.shard_args (#2560) | 01 April 2020, 00:47:54 UTC |
4fa1534 | Madison May | 31 March 2020, 23:59:57 UTC | FIX: batch norm w/ no scale, center | 31 March 2020, 23:59:57 UTC |
c28c46e | Skye Wanderman-Milne | 31 March 2020, 22:52:41 UTC | Add ShardedDeviceArray indexing benchmark. (#2549) Example output: ``` ---------Benchmark summary for ShardedDeviceArray_indexing--------- indices_fn mean %std relative ------------------ -------- ------- ---------- integer_indices 0.16901 8.52522 1 integer_2D_indices 18.4918 0 109.412 ``` | 31 March 2020, 22:52:41 UTC |
bfbd0b8 | Peter Hawkins | 31 March 2020, 21:09:14 UTC | Move tuple_arguments onto Compile() instead of Execute(). (#2559) Update minimum jaxlib version to 0.1.43. | 31 March 2020, 21:09:14 UTC |
6781eca | Matthew Johnson | 31 March 2020, 20:51:45 UTC | Merge pull request #2558 from google/while-loop-of-pmap-bug don't hardcode array size in test | 31 March 2020, 20:51:45 UTC |
949d423 | Matthew Johnson | 31 March 2020, 20:07:59 UTC | improve error message for #2554 | 31 March 2020, 20:07:59 UTC |
29c581b | Matthew Johnson | 31 March 2020, 18:54:57 UTC | don't hardcode array size in test Fixes #2554 | 31 March 2020, 18:54:57 UTC |
bd1708c | Peter Hawkins | 31 March 2020, 14:02:38 UTC | Update changelog and README for jaxlib 0.1.43. (#2556) | 31 March 2020, 14:02:38 UTC |
59ed4ae | George Necula | 31 March 2020, 09:11:47 UTC | Disable test_while_loop_of_pmap on all platforms Issue: #2554 Disable the test so that we can continue the google3 tests | 31 March 2020, 09:12:35 UTC |
fd52fbf | George Necula | 31 March 2020, 08:36:47 UTC | Fix import in benchmarks This works on my machine as 'python benchmarks/pmap_benchmark.py'. It also follows the code in examples. This will need a copybara rule to change the import to 'from jax.benchmarks import benchmark' | 31 March 2020, 08:48:08 UTC |
a4ceae1 | Matthew Johnson | 31 March 2020, 05:12:38 UTC | fix link in custom derivatives tutorial notebook | 31 March 2020, 05:12:38 UTC |
27604c3 | Matthew Johnson | 31 March 2020, 05:11:35 UTC | fix typo in notebook | 31 March 2020, 05:11:35 UTC |
e017a92 | Matthew Johnson | 31 March 2020, 05:06:00 UTC | fix typo | 31 March 2020, 05:06:00 UTC |
7a9c550 | Matthew Johnson | 31 March 2020, 04:30:47 UTC | add sphinx-autodoc-typehints to travis install | 31 March 2020, 04:30:47 UTC |
e7be43d | Matthew Johnson | 31 March 2020, 04:09:12 UTC | update api.py docstrings for sphinx highlighting | 31 March 2020, 04:09:12 UTC |
cda505a | Matthew Johnson | 31 March 2020, 03:46:09 UTC | Merge pull request #2551 from google/enable-mpc-tests try re-enabling control tests that trigger #2507 | 31 March 2020, 03:46:09 UTC |
909fee6 | Matthew Johnson | 31 March 2020, 03:22:04 UTC | try adding sphinx-autodoc-typehints | 31 March 2020, 03:22:04 UTC |
b015e57 | Matthew Johnson | 31 March 2020, 03:12:33 UTC | try re-enabling control tests that trigger #2507 | 31 March 2020, 03:12:33 UTC |
1f03d48 | Matthew Johnson | 31 March 2020, 03:10:39 UTC | try resetting global tracer state in loops_test.py attempting to address #2507 | 31 March 2020, 03:10:39 UTC |
de37eae | Matthew Johnson | 31 March 2020, 03:10:16 UTC | Merge pull request #2550 from google/update-custom-derivatives-notebook update custom derivatives tutorial notebook | 31 March 2020, 03:10:16 UTC |
15009c9 | Matthew Johnson | 31 March 2020, 02:45:45 UTC | add docstring for defjvps, fix sphinx docs | 31 March 2020, 02:45:45 UTC |
bd726fc | Matthew Johnson | 31 March 2020, 02:37:11 UTC | update custom derivatives tutorial notebook * add clip_gradient example * add defjvps convenience wrapper | 31 March 2020, 02:37:11 UTC |
b7d2cbd | Matthew Johnson | 31 March 2020, 01:34:50 UTC | Merge pull request #2548 from google/while-loop-of-pmap-bug fix a while-loop-of-pmap bug (thanks @jaspersnoek) | 31 March 2020, 01:34:50 UTC |
375575a | Matthew Johnson | 31 March 2020, 00:53:47 UTC | skip new test on cpu unless num_devices > 1 | 31 March 2020, 00:53:47 UTC |
df3b5fe | Matthew Johnson | 31 March 2020, 00:48:07 UTC | fix a while-loop-of-pmap bug (thanks @jaspersnoek) | 31 March 2020, 00:48:07 UTC |
acb2414 | Peter Hawkins | 30 March 2020, 23:43:23 UTC | Update XLA to fix CUDA 9.2 build problem. (#2547) | 30 March 2020, 23:43:23 UTC |
0c7e134 | Peter Hawkins | 30 March 2020, 20:56:54 UTC | Update XLA and increment jaxlib version to 0.1.43. (#2546) | 30 March 2020, 20:56:54 UTC |
1275bc5 | Matthew Johnson | 30 March 2020, 20:50:32 UTC | don't pass `-n 1` to pytest in travis | 30 March 2020, 20:50:32 UTC |
b43051e | Matthew Johnson | 30 March 2020, 20:49:56 UTC | minor fix to custom_transforms | 30 March 2020, 20:49:56 UTC |
3d0c187 | Matthew Johnson | 30 March 2020, 20:38:05 UTC | Merge pull request #2543 from google/process-custom-jvp-default-implementation comments/defaults for process_custom_{jv,vj}p_call | 30 March 2020, 20:38:05 UTC |
ea9401d | Matthew Johnson | 30 March 2020, 19:47:13 UTC | Merge pull request #2541 from google/xla-computation-duck-typing allow duck-typing in xla_computation arguments | 30 March 2020, 19:47:13 UTC |
70a3f47 | Matthew Johnson | 30 March 2020, 18:57:03 UTC | comments/defaults for process_custom_{jv,vj}p_call | 30 March 2020, 19:02:25 UTC |
f766c5e | Matthew Johnson | 30 March 2020, 18:31:29 UTC | allow duck-typing in xla_computation arguments | 30 March 2020, 18:31:29 UTC |
9d8823c | Matthew Johnson | 30 March 2020, 07:41:04 UTC | add initial_style_staging to custom_transforms | 30 March 2020, 07:41:04 UTC |
a6a837a | Matthew Johnson | 30 March 2020, 07:35:45 UTC | add some stage_out=True indicators | 30 March 2020, 07:35:45 UTC |
fab6600 | Matthew Johnson | 30 March 2020, 06:31:07 UTC | Merge pull request #2539 from google/remat-fix remove remat context check, add initial staging | 30 March 2020, 06:31:07 UTC |
bdc0c3b | Matthew Johnson | 30 March 2020, 06:29:55 UTC | remove remat context check, add initial staging | 30 March 2020, 06:29:55 UTC |
e7f8503 | Matthew Johnson | 30 March 2020, 06:04:00 UTC | Merge pull request #2500 from google/custom-jvp-fix revise custom_jvp / custom_vjp rule jaxpr staging | 30 March 2020, 06:04:00 UTC |
74d358d | Matthew Johnson | 30 March 2020, 06:00:40 UTC | skip ode test on import error (internal) | 30 March 2020, 06:00:40 UTC |
b2b68e5 | Matthew Johnson | 30 March 2020, 03:51:51 UTC | fix bugs, add tests | 30 March 2020, 05:53:37 UTC |
7a4c4d5 | Matthew Johnson | 30 March 2020, 03:48:08 UTC | use custom_jvp for internal functions | 30 March 2020, 03:48:08 UTC |
6193e5e | Matthew Johnson | 28 March 2020, 21:15:46 UTC | revamp custom_jvp/vjp implementation to fix bugs Co-authored-by: Dougal Maclaurin <dougalm@google.com> | 30 March 2020, 02:35:01 UTC |
954334c | Matthew Johnson | 30 March 2020, 01:05:48 UTC | Merge pull request #2537 from duvenaud/morejets Adds lots of trivial jet rules, partly addresses #2431 | 30 March 2020, 01:05:48 UTC |
0f11f9c | Matthew Johnson | 30 March 2020, 00:02:19 UTC | Merge pull request #2538 from google/pfau workaround for pmap output PRED arrays on cpu/gpu | 30 March 2020, 00:02:19 UTC |
1762a86 | Matthew Johnson | 29 March 2020, 21:45:17 UTC | workaround for pmap output PRED arrays on cpu/gpu | 29 March 2020, 21:45:17 UTC |
305dd8c | Matthew Johnson | 29 March 2020, 21:43:42 UTC | Merge pull request #2536 from google/issue2534 add docstring / reference doc link for axis_index | 29 March 2020, 21:43:42 UTC |
fcc1e76 | Matthew Johnson | 29 March 2020, 20:56:26 UTC | add docstring / reference doc link for axis_index fixes #2534 | 29 March 2020, 20:56:26 UTC |
614d39d | Matthew Johnson | 29 March 2020, 20:37:49 UTC | Merge pull request #2527 from lucasb-eyer/patch-1 Make it more explicit that default JVP assumes ℝ | 29 March 2020, 20:37:49 UTC |
ead8011 | David Duvenaud | 29 March 2020, 20:28:17 UTC | Added lots of trivial jet rules. Co-Authored-By: jessebett <jessebett@gmail.com> Co-Authored-By: Jacob Kelly <jacob.kelly@mail.utoronto.ca> | 29 March 2020, 20:28:17 UTC |
67283a0 | Matthew Johnson | 28 March 2020, 20:52:40 UTC | add new custom_jvp tests from #2500 Co-authored-by: Dougal Maclaurin <dougalm@google.com> | 28 March 2020, 22:33:20 UTC |
bcc5191 | Matthew Johnson | 28 March 2020, 22:15:47 UTC | Merge pull request #2533 from google/core-trace-type-annotations add type annotations to core.py tracing machinery | 28 March 2020, 22:15:47 UTC |
f99720b | Matthew Johnson | 28 March 2020, 21:55:58 UTC | add type annotations to core.py tracing machinery also add .copy() method to core.trace_state global trace state | 28 March 2020, 21:58:35 UTC |
cdbcbb0 | Matthew Johnson | 28 March 2020, 19:31:40 UTC | Merge pull request #2532 from google/sharded-device-array-handlers add ShardedDeviceArray to ad vspace op handlers | 28 March 2020, 19:31:40 UTC |
1b59789 | Matthew Johnson | 28 March 2020, 18:56:12 UTC | add ShardedDeviceArray to ad vspace op handlers fixes #2529 (thanks, @dpfau !) | 28 March 2020, 18:56:12 UTC |
415cde5 | Lucas Beyer | 28 March 2020, 11:32:44 UTC | Make it more explicit that default JVP assumes ℝ It's just an attempt to make this implicit assumption explicit, as it only became clear to me after our discussion in chat, not after reading this. | 28 March 2020, 11:32:44 UTC |
f371bfc | Peter Hawkins | 28 March 2020, 01:24:26 UTC | Improve speed of LU decomposition on TPU. (#2526) Increase the block size, which helps with compilation time. Merge the two row permutations in the outer loop, which means we do row-at-a-time gathers. | 28 March 2020, 01:24:26 UTC |
6d0810a | Matthew Johnson | 27 March 2020, 19:46:01 UTC | Merge pull request #2501 from botev/isclose Making isclose handle correctly infinite and NaN values. | 27 March 2020, 19:46:01 UTC |
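The isclose fix concerns non-finite inputs. The naive test |a - b| <= atol + rtol * |b| breaks when b is infinite: the right-hand side becomes infinite, so even `isclose(1.0, inf)` would pass. The scalar pure-Python sketch below uses the NumPy-style defaults `rtol=1e-05` and `atol=1e-08` and handles those cases explicitly. NaNs match only when `equal_nan=True`, and infinities match only infinities of the same sign. `isclose_sketch` is a hypothetical name. It is not the `jax.numpy.isclose` implementation, which operates on arrays.

```python
import math


def isclose_sketch(a: float, b: float, rtol: float = 1e-05, atol: float = 1e-08,
                   equal_nan: bool = False) -> bool:
    """Hypothetical scalar isclose with explicit NaN and infinity handling."""
    if math.isnan(a) or math.isnan(b):
        return equal_nan and math.isnan(a) and math.isnan(b)
    if math.isinf(a) or math.isinf(b):
        return a == b  # only equal, same-signed infinities are close
    return abs(a - b) <= atol + rtol * abs(b)
```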