Revision Author Date Message Commit Date
1cf708e Support pytrees in jax.scipy.linalg.cg (#2600) * Support pytrees in jax.scipy.linalg.cg Ideally there would be an easier way to write this, but for now this will do. * Fixup test 04 April 2020, 22:55:46 UTC
2c4ced2 Merge pull request #2601 from adarob/fix-err Fix error handling when an attempt is made to pmap a scalar. 04 April 2020, 03:18:21 UTC
48eb524 Fix error handling when an attempt is made to pmap a scalar. 04 April 2020, 01:03:42 UTC
c2f56fb add notes to changelog 03 April 2020, 23:21:38 UTC
60de46a Merge pull request #2591 from google/tracer-printing make tracers tree-pretty-print their contents 03 April 2020, 22:47:41 UTC
b67250e Merge pull request #2599 from sharadmv/logit-fix Fix grad(logit) to use defjvps and enable it in tests 03 April 2020, 22:21:45 UTC
3c9fb35 Fix dtype error 03 April 2020, 21:37:29 UTC
1b93bb5 Implement scipy.sparse.linalg.cg (second try) (#2566) * super minimal starter code * Update optimizers.py * implement flip with axis = None * Create sparse.py * fix some imports * Update sparse.py * add partial function & test * Update lax_scipy_sparse_test.py * Update lax_scipy_sparse_test.py * add a test case for sparse pd matrix & add bigger dim * address comments * fix info return & create matrix with rng_factory * Update lax_scipy_sparse_test.py * Update lax_scipy_sparse_test.py * Update sparse.py * Update sparse.py * Update sparse.py * Update lax_scipy_sparse_test.py * Update lax_scipy_sparse_test.py * cast jax arrays into numpy array for scipy compatibility * Update sparse.py * Update sparse.py * fix None issue, but algo is not working * fix return of build_and_solve and output of while_loop * fix condition func of while loop * clearer variable names * mismatch error * Update lax_scipy_sparse_test.py * Fixes to jax.experimental.sparse.cg * Fix tests for gradients * Add support for preconditioners to cg * Move cg into scipy, update docs * doc tweak Co-authored-by: Tuan Nguyen <anhtuan277@gmail.com> 03 April 2020, 20:37:11 UTC
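The core iteration behind the `cg` solver added in this commit can be sketched in plain Python. This is a minimal, unpreconditioned sketch for a small dense symmetric positive-definite matrix stored as nested lists. The helper names `dot`, `matvec` and `cg` are illustrative assumptions. It is not JAX's implementation, which according to these commits also supports preconditioners and pytrees.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def matvec(A, x):
    """Dense matrix-vector product (stands in for a general linear operator)."""
    return [dot(row, x) for row in A]


def cg(A, b, tol=1e-10, maxiter=None):
    """Unpreconditioned conjugate gradient for symmetric positive-definite A."""
    n = len(b)
    maxiter = 10 * n if maxiter is None else maxiter
    x = [0.0] * n
    r = list(b)          # residual b - A x, with the initial guess x = 0
    p = list(r)          # first search direction is the residual
    rs_old = dot(r, r)
    for _ in range(maxiter):
        if rs_old ** 0.5 < tol:
            break        # residual norm is small enough
        Ap = matvec(A, p)
        alpha = rs_old / dot(p, Ap)              # step length along p
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        beta = rs_new / rs_old                   # keeps directions A-conjugate
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


print(cg([[4.0, 1.0], [1.0, 3.0]], [1.0, 2.0]))
```

For the 2x2 system above, two iterations reach x = [1/11, 7/11] up to rounding, consistent with CG converging in at most n steps in exact arithmetic.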
72783bb Fix grad(logit) to use defjvps and enable it in tests 03 April 2020, 20:27:02 UTC
2b3beff Make reduce_prod differentiable to arbitrary order. (#2597) * Make reduce_prod differentiable to arbitrary order. The previous strategy for computing the JVP of reduce_prod used a pair of reduce_window operations to form left and right products for each position. This PR instead builds an explicit reduction tree and differentiates through it, which, while not as efficient as using XLA's built-in reductions, has the advantage of being differentiable to arbitrary order. * Return the tree-reduction primals instead of returning the original primals in JVP rule. 03 April 2020, 20:09:48 UTC
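An explicit reduction tree is built from nothing but pairwise multiplications. Each multiplication has a simple derivative (the product rule), and that derivative is itself made of products, so a tree of them can be differentiated repeatedly. The plain-Python sketch below shows such a tree on a list. It is illustrative only and assumes padding odd levels with the multiplicative identity. It is not the JAX code, which reduces an array along an axis.

```python
def tree_prod(xs):
    """Multiply all elements through an explicit pairwise reduction tree.

    Each level multiplies neighbouring pairs, so a length-n input needs
    about log2(n) levels of elementwise products.
    """
    level = list(xs)
    if not level:
        return 1                      # empty product
    while len(level) > 1:
        if len(level) % 2:            # pad odd-length levels with the identity
            level.append(1)
        level = [level[i] * level[i + 1] for i in range(0, len(level), 2)]
    return level[0]


print(tree_prod([1, 2, 3, 4, 5]))
```

Here the first level computes (1*2), (3*4) and (5*1), and later levels combine those partial products until one value remains.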
824ac86 Reimplement np.cumsum and np.cumprod in terms of a parallel prefix scan. (#2596) * Reimplement np.cumsum and np.cumprod in terms of a parallel prefix scan. Unlike the existing implementation based on lax.reduce_window, this implementation is O(n log n) instead of O(n^2) and is arbitrarily differentiable. Fixes #1212, #2418, #2542. May help with issue #2380. * Relax gradient test tolerance. 03 April 2020, 19:39:56 UTC
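The O(n log n) bound can be illustrated with one classic parallel prefix scheme, the Hillis-Steele scan, written here in plain Python. This is an assumption for illustration and not necessarily the scan this commit uses. It also works on a Python list rather than an array axis.

```python
import operator


def prefix_scan(xs, op=operator.add):
    """Inclusive scan using the Hillis-Steele scheme.

    Each round combines every element with the one `offset` positions to its
    left, and offsets double each round. That gives ceil(log2 n) rounds of
    O(n) work, O(n log n) in total. `op` must be associative.
    """
    out = list(xs)
    offset = 1
    while offset < len(out):
        # Every position in a round can be computed independently (in parallel).
        out = [op(out[i - offset], out[i]) if i >= offset else out[i]
               for i in range(len(out))]
        offset *= 2
    return out


print(prefix_scan([1, 2, 3, 4]))                # cumsum-style
print(prefix_scan([1, 2, 3, 4], operator.mul))  # cumprod-style
```

Because the left operand always comes from earlier positions, the scan also respects the order of non-commutative but associative operations such as string concatenation.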
192e908 Merge pull request #2561 from madisonmay/batch-norm-no-scale-or-center FIX: batch norm w/ no scale, center 03 April 2020, 06:53:39 UTC
6402419 Merge pull request #2563 from google/callable-typechecks add callable typechecks to more api.py functions 03 April 2020, 06:52:40 UTC
25aeb80 Merge pull request #2592 from google/travis-mypy add trace state check tearDown to JaxTestCase 03 April 2020, 06:42:31 UTC
3dee689 Merge pull request #2593 from google/issue2578 add full lower to custom_jvp/vjp call bind 03 April 2020, 06:28:57 UTC
ba8225f skip all parallelize tests (abandonware right now) 03 April 2020, 06:11:55 UTC
0e49133 add full lower to custom_jvp/vjp call bind fixes #2578 03 April 2020, 05:54:29 UTC
f2de1bf add trace state check tearDown to JaxTestCase 03 April 2020, 05:01:43 UTC
297c902 make tracers tree-pretty-print their contents 03 April 2020, 04:04:12 UTC
64a7d17 Merge pull request #2587 from google/travis-mypy re-enable travis mypy testing (typo broke it) 03 April 2020, 04:02:13 UTC
5d3f1bd tell mypy: using __init__ to reinitialize is OK 03 April 2020, 03:14:12 UTC
6d4987c make core.trace_state resetting be thread-local 03 April 2020, 01:19:44 UTC
b78b7a0 add global trace state checks to more tests 03 April 2020, 01:03:58 UTC
ab0a005 check sublevel is reset in loops_test.py 03 April 2020, 00:18:47 UTC
c72abf6 re-enable travis mypy testing (typo broke it) 02 April 2020, 22:52:01 UTC
5c0ac40 Revert jax.numpy.matmul to pre-#2512 version. (#2584) https://github.com/google/jax/pull/2512 was causing some Google-internal tests to take longer. 02 April 2020, 19:54:58 UTC
0bdd0f6 Merge pull request #2581 from google/jet-process-call post process call of jet! 02 April 2020, 16:03:12 UTC
84dc6cc post process call of jet! Also included David's jet rule for lax.select. Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu> 02 April 2020, 14:56:26 UTC
32c45fb Another attempt to disable new failing test on TPU 02 April 2020, 11:11:45 UTC
0c53ce9 Disable test with float16 on TPU 02 April 2020, 09:15:25 UTC
1bb9aaa Revert "Refactor shard_args ShardedDeviceArray slow path. (#2562)" (#2577) This reverts commit 2e87674aff0609d58eed07091d8e9fad8ee71d6c. This refactoring was broken, but we apparently have no test coverage! mypy/linting would have caught it though. 01 April 2020, 23:33:37 UTC
2e87674 Refactor shard_args ShardedDeviceArray slow path. (#2562) 01 April 2020, 22:35:26 UTC
2d42c42 Fix missing type promotion accidentally removed by #2512. (#2575) This is in fact covered by the existing tests, but we were unlucky and didn't hit them in the set of generated tests we selected. 01 April 2020, 21:18:17 UTC
8c4a938 Implement np.ldexp and np.frexp. (#1529) Co-authored-by: Peter Hawkins <phawkins@google.com> 01 April 2020, 19:29:48 UTC
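The contract of `frexp` and `ldexp` can be shown with Python's standard-library `math.frexp` and `math.ldexp`. These follow the same mantissa/exponent convention that NumPy documents for `np.frexp` and `np.ldexp`. The `roundtrip` helper is a hypothetical name used only for this sketch.

```python
import math


def roundtrip(x):
    """Split x with frexp, then rebuild it with ldexp."""
    m, e = math.frexp(x)  # x == m * 2**e, with 0.5 <= |m| < 1 for nonzero x
    return m, e, math.ldexp(m, e)


for x in (8.0, -3.5, 0.15625):
    print(x, roundtrip(x))
# 8.0 (0.5, 4, 8.0)
# -3.5 (-0.875, 2, -3.5)
# 0.15625 (0.625, -2, 0.15625)
```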
79ce698 Fix incorrect raise in np.clip and add a test. (#2571) 01 April 2020, 16:35:12 UTC
651316f Fix issue 1465: fix jit in example (#1473) * fix jit in example * Avoid using static_argnums on a keyword argument; use a positional argument and a wrapper function for now. Co-authored-by: Peter Hawkins <phawkins@google.com> 01 April 2020, 15:57:57 UTC
86a4073 enable beta test on float64 values (#1177) * enable beta test on float64 values cf. #1123 * Enable beta test on all platforms. It seems sufficiently fast now. Co-authored-by: Peter Hawkins <phawkins@google.com> 01 April 2020, 15:25:32 UTC
c758aff Fix some missing cases of broadcasting in np.einsum. (#2512) * Fix some missing cases of broadcasting in np.einsum. In particular, np.einsum allows one side of a batch or contracting dimension to have size 1 even if the other side has a non-1 size. Implement np.matmul in terms of np.einsum. This allows us to reuse einsum's logic for performing broadcasting without explicitly broadcasting the LHS and RHS together. * Add regression test. Fixes #2189. 01 April 2020, 14:54:47 UTC
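The size-1 rule described in this commit can be illustrated with a small shape-checking helper. `resolve_label_sizes` is hypothetical and is not a JAX or NumPy function. It resolves one size per einsum label and allows operands to disagree only when one side has size 1. It handles explicit labels only, with no ellipsis.

```python
def resolve_label_sizes(subscripts, *shapes):
    """Return {label: size} for explicit einsum input subscripts like 'bij,bjk'."""
    inputs = subscripts.split("->")[0].split(",")
    if len(inputs) != len(shapes):
        raise ValueError("expected one shape per operand")
    sizes = {}
    for labels, shape in zip(inputs, shapes):
        if len(labels) != len(shape):
            raise ValueError(f"subscripts {labels!r} do not match shape {shape}")
        for label, n in zip(labels, shape):
            if label not in sizes or sizes[label] == 1:
                sizes[label] = n  # first sighting, or a size-1 entry that broadcasts
            elif n != 1 and n != sizes[label]:
                raise ValueError(
                    f"label {label!r}: sizes {sizes[label]} and {n} conflict")
    return sizes


# A batch dimension of size 1 on the LHS broadcasts against size 4 on the RHS.
print(resolve_label_sizes("bij,bjk->bik", (1, 2, 3), (4, 3, 5)))
```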
1d09b6b Replace CHECK with assert. (#2569) 01 April 2020, 13:58:17 UTC
ca23be6 Add `jax.tree_util.all_leaves(iterable)`. In Haiku (https://github.com/deepmind/dm-haiku) we have `FlatMapping` which is an immutable Mapping subclass maintaining a flat internal representation. Our goal is to allow very cheap flatten/unflatten since these objects are used to represent parameters/state and are often passed in and out of JAX functions that flatten their inputs (e.g. jit/pmap). One challenge we have is that on unflatten we need a fast way of testing whether the list of leaves provided are flat or not (since we want to cache both the flat structure and the leaves). Consider the following case:
```python
d = FlatMapping.from_mapping({"a": 1})  # Caches the result of jax.tree_flatten.
l, t = jax.tree_flatten(d)              # Fine, leaves are flat.
l = list(map(lambda x: (x, x), l))      # leaves are no longer flat.
d2 = jax.tree_unflatten(t, l)           # Needs to recompute structure.
jax.tree_leaves(d2)                     # Should return [1, 1] not [(1, 1)]
```
Actual implementation here: https://github.com/deepmind/dm-haiku/blob/d37b486e09696ef34f7396c11b04074cb73a963c/haiku/_src/data_structures.py#L204-L208 This function allows an efficient way to do this using the JAX public API. 01 April 2020, 07:56:01 UTC
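As a rough model of the check that `all_leaves` provides, here is a plain-Python sketch. It assumes that only `list`, `tuple`, `dict` and `None` count as pytree nodes. The real JAX function consults the full pytree registry, including custom nodes such as `FlatMapping`, so `is_leaf` and this `all_leaves` are simplified stand-ins.

```python
def is_leaf(x):
    """Simplified model: builtin containers and None are nodes, everything else is a leaf."""
    return x is not None and not isinstance(x, (list, tuple, dict))


def all_leaves(iterable):
    """Return True if every element of `iterable` is a leaf (no nested structure)."""
    return all(is_leaf(x) for x in iterable)


# Mirrors the FlatMapping case: cached leaves stay valid only if they are still flat.
print(all_leaves([1]))       # leaves from the original flatten
print(all_leaves([(1, 1)]))  # leaves replaced by tuples: structure must be recomputed
```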
025d874 Fix error message 01 April 2020, 07:54:37 UTC
d2a827a Ensure the global trace_state is restored on errors in loops This is an attempted fix for https://github.com/google/jax/issues/2507 01 April 2020, 07:23:14 UTC
83b9575 add callable typechecks to more api.py functions 01 April 2020, 01:46:15 UTC
e3a9a56 Remove unused argument from pxla.shard_args (#2560) 01 April 2020, 00:47:54 UTC
4fa1534 FIX: batch norm w/ no scale, center 31 March 2020, 23:59:57 UTC
c28c46e Add ShardedDeviceArray indexing benchmark. (#2549) Example output:
```
---------Benchmark summary for ShardedDeviceArray_indexing---------
indices_fn              mean     %std    relative
------------------  --------  -------  ----------
integer_indices      0.16901  8.52522           1
integer_2D_indices   18.4918        0     109.412
```
31 March 2020, 22:52:41 UTC
bfbd0b8 Move tuple_arguments onto Compile() instead of Execute(). (#2559) Update minimum jaxlib version to 0.1.43. 31 March 2020, 21:09:14 UTC
6781eca Merge pull request #2558 from google/while-loop-of-pmap-bug don't hardcode array size in test 31 March 2020, 20:51:45 UTC
949d423 improve error message for #2554 31 March 2020, 20:07:59 UTC
29c581b don't hardcode array size in test Fixes #2554 31 March 2020, 18:54:57 UTC
bd1708c Update changelog and README for jaxlib 0.1.43. (#2556) 31 March 2020, 14:02:38 UTC
59ed4ae Disable test_while_loop_of_pmap on all platforms Issue: #2554 Disable the test so that we can continue the google3 tests 31 March 2020, 09:12:35 UTC
fd52fbf Fix import in benchmarks This works on my machine as 'python benchmarks/pmap_benchmark.py'. It also follows the code in examples. This will need a copybara rule to change the import to 'from jax.benchmarks import benchmark' 31 March 2020, 08:48:08 UTC
a4ceae1 fix link in custom derivatives tutorial notebook 31 March 2020, 05:12:38 UTC
27604c3 fix typo in notebook 31 March 2020, 05:11:35 UTC
e017a92 fix typo 31 March 2020, 05:06:00 UTC
7a9c550 add sphinx-autodoc-typehints to travis install 31 March 2020, 04:30:47 UTC
e7be43d update api.py docstrings for sphinx highlighting 31 March 2020, 04:09:12 UTC
cda505a Merge pull request #2551 from google/enable-mpc-tests try re-enabling control tests that trigger #2507 31 March 2020, 03:46:09 UTC
909fee6 try adding sphinx-autodoc-typehints 31 March 2020, 03:22:04 UTC
b015e57 try re-enabling control tests that trigger #2507 31 March 2020, 03:12:33 UTC
1f03d48 try resetting global tracer state in loops_test.py attempting to address #2507 31 March 2020, 03:10:39 UTC
de37eae Merge pull request #2550 from google/update-custom-derivatives-notebook update custom derivatives tutorial notebook 31 March 2020, 03:10:16 UTC
15009c9 add docstring for defjvps, fix sphinx docs 31 March 2020, 02:45:45 UTC
bd726fc update custom derivatives tutorial notebook * add clip_gradient example * add defjvps convenience wrapper 31 March 2020, 02:37:11 UTC
b7d2cbd Merge pull request #2548 from google/while-loop-of-pmap-bug fix a while-loop-of-pmap bug (thanks @jaspersnoek) 31 March 2020, 01:34:50 UTC
375575a skip new test on cpu unless num_devices > 1 31 March 2020, 00:53:47 UTC
df3b5fe fix a while-loop-of-pmap bug (thanks @jaspersnoek) 31 March 2020, 00:48:07 UTC
acb2414 Update XLA to fix CUDA 9.2 build problem. (#2547) 30 March 2020, 23:43:23 UTC
0c7e134 Update XLA and increment jaxlib version to 0.1.43. (#2546) 30 March 2020, 20:56:54 UTC
1275bc5 don't pass `-n 1` to pytest in travis 30 March 2020, 20:50:32 UTC
b43051e minor fix to custom_transforms 30 March 2020, 20:49:56 UTC
3d0c187 Merge pull request #2543 from google/process-custom-jvp-default-implementation comments/defaults for process_custom_{jv,vj}p_call 30 March 2020, 20:38:05 UTC
ea9401d Merge pull request #2541 from google/xla-computation-duck-typing allow duck-typing in xla_computation arguments 30 March 2020, 19:47:13 UTC
70a3f47 comments/defaults for process_custom_{jv,vj}p_call 30 March 2020, 19:02:25 UTC
f766c5e allow duck-typing in xla_computation arguments 30 March 2020, 18:31:29 UTC
9d8823c add initial_style_staging to custom_transforms 30 March 2020, 07:41:04 UTC
a6a837a add some stage_out=True indicators 30 March 2020, 07:35:45 UTC
fab6600 Merge pull request #2539 from google/remat-fix remove remat context check, add initial staging 30 March 2020, 06:31:07 UTC
bdc0c3b remove remat context check, add initial staging 30 March 2020, 06:29:55 UTC
e7f8503 Merge pull request #2500 from google/custom-jvp-fix revise custom_jvp / custom_vjp rule jaxpr staging 30 March 2020, 06:04:00 UTC
74d358d skip ode test on import error (internal) 30 March 2020, 06:00:40 UTC
b2b68e5 fix bugs, add tests 30 March 2020, 05:53:37 UTC
7a4c4d5 use custom_jvp for internal functions 30 March 2020, 03:48:08 UTC
6193e5e revamp custom_jvp/vjp implementation to fix bugs Co-authored-by: Dougal Maclaurin <dougalm@google.com> 30 March 2020, 02:35:01 UTC
954334c Merge pull request #2537 from duvenaud/morejets Adds lots of trivial jet rules, partly addresses #2431 30 March 2020, 01:05:48 UTC
0f11f9c Merge pull request #2538 from google/pfau workaround for pmap output PRED arrays on cpu/gpu 30 March 2020, 00:02:19 UTC
1762a86 workaround for pmap output PRED arrays on cpu/gpu 29 March 2020, 21:45:17 UTC
305dd8c Merge pull request #2536 from google/issue2534 add docstring / reference doc link for axis_index 29 March 2020, 21:43:42 UTC
fcc1e76 add docstring / reference doc link for axis_index fixes #2534 29 March 2020, 20:56:26 UTC
614d39d Merge pull request #2527 from lucasb-eyer/patch-1 Make it more explicit that default JVP assumes |R 29 March 2020, 20:37:49 UTC
ead8011 Added lots of trivial jet rules. Co-Authored-By: jessebett <jessebett@gmail.com> Co-Authored-By: Jacob Kelly <jacob.kelly@mail.utoronto.ca> 29 March 2020, 20:28:17 UTC
67283a0 add new custom_jvp tests from #2500 Co-authored-by: Dougal Maclaurin <dougalm@google.com> 28 March 2020, 22:33:20 UTC
bcc5191 Merge pull request #2533 from google/core-trace-type-annotations add type annotations to core.py tracing machinery 28 March 2020, 22:15:47 UTC
f99720b add type annotations to core.py tracing machinery also add .copy() method to core.trace_state global trace state 28 March 2020, 21:58:35 UTC
cdbcbb0 Merge pull request #2532 from google/sharded-device-array-handlers add ShardedDeviceArray to ad vspace op handlers 28 March 2020, 19:31:40 UTC
1b59789 add ShardedDeviceArray to ad vspace op handlers fixes #2529 (thanks, @dpfau !) 28 March 2020, 18:56:12 UTC
415cde5 Make it more explicit that default JVP assumes |R. It's just an attempt to make this implicit assumption explicit, as it only became clear to me after our discussion in chat, not after reading this. 28 March 2020, 11:32:44 UTC
f371bfc Improve speed of LU decomposition on TPU. (#2526) Increase the block size, which helps with compilation time. Merge the two row permutations in the outer loop, which means we do row-at-a-time gathers. 28 March 2020, 01:24:26 UTC
6d0810a Merge pull request #2501 from botev/isclose Make isclose correctly handle infinite and NaN values. 27 March 2020, 19:46:01 UTC
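Infinities need their own branch because the naive test `abs(a - b) <= atol + rtol * abs(b)` evaluates `inf - inf` to NaN and returns False even for equal infinities. Below is a scalar, plain-Python sketch of the NumPy-style semantics this PR targets, using NumPy's documented defaults (`rtol=1e-05`, `atol=1e-08`, `equal_nan=False`). It is not JAX's array implementation.

```python
import math


def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Scalar sketch of NumPy-style isclose, including inf and NaN handling."""
    if math.isnan(a) or math.isnan(b):
        # NaN is never close to anything unless equal_nan is set and both are NaN.
        return equal_nan and math.isnan(a) and math.isnan(b)
    if math.isinf(a) or math.isinf(b):
        # Infinities are close only to an infinity of the same sign.
        return a == b
    return abs(a - b) <= atol + rtol * abs(b)


inf, nan = float("inf"), float("nan")
print(isclose(inf, inf), isclose(inf, -inf), isclose(nan, nan, equal_nan=True))
```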