fedf7fb | Skye Wanderman-Milne | 05 May 2020, 18:57:11 UTC | Update jax version to 0.1.66 (#2970) | 05 May 2020, 18:57:11 UTC |
9612c52 | Peter Hawkins | 05 May 2020, 18:10:57 UTC | Relax test tolerances, suppress warning messages. (#2967) | 05 May 2020, 18:10:57 UTC |
0eba939 | Skye Wanderman-Milne | 05 May 2020, 16:25:48 UTC | New and improved _shard_device_array function. (#2958) This gets the performance of sharding DeviceArray arguments to pmap roughly back to what it was prior to https://github.com/google/jax/commit/07571ae4dd3fceee580aa49c4490f99ce7f6b6de. It does so by re-introducing a _shard_device_array function that can handle arbitrary array slices. Benchmark results compared to https://github.com/google/jax/commit/87d959089f3406714c98e674c145b09156319ef3 (i.e. just prior to the regression):
```
---------Benchmark summary for pmap_shard_device_array---------
  nargs    nshards       mean      %std    relative    mean/baseline
-------  ---------  ---------  --------  ----------  ---------------
     10          8  0.0479975   12.0865     1                1.09631
    100          8  0.32916      5.7446     6.85786          1.10263
    500          8  1.5563       2.68041   32.4246           1.10066
    100          2  0.136431     8.33826    2.84245          1.15886
    100          4  0.198815     5.91716    4.1422           1.11409
    100          8  0.31788      4.80559    6.62285          1.06637
```
This still seems a bit slower than it was before, but gets most of the performance back. We can further optimize in future changes if needed. Fixes https://github.com/google/jax/pull/2958 (hopefully) | 05 May 2020, 16:25:48 UTC |
61a34f5 | Peter Hawkins | 05 May 2020, 16:18:55 UTC | Update README for jaxlib 0.1.46 release. (#2968) | 05 May 2020, 16:18:55 UTC |
dc234b6 | Joost Bastings | 05 May 2020, 08:11:10 UTC | Expose functools.reduce initializer argument to tree_util.tree_reduce (#2935) * Expose `functools.reduce` initializer argument to `tree_util.tree_reduce`. `functools.reduce` takes an optional `initializer` argument (default=None) which is currently not exposed by `tree_reduce`. This can be useful e.g. for computing an L2 penalty, where you would initialize with 0., and then sum the squared L2 norm of each parameter. Example:
```
def l2_sum(total, param):
  return total + jnp.sum(param**2)

tree_reduce(l2_sum, params, 0.)
```
* Only call functools.reduce with initializer when it is not None. * Change logic to check for number of args to allow None value as initializer * Rename seq to tree, and add tree_leaves * Change reduce to functools.reduce. * Make tree_reduce self-documenting * Replace jax.tree_leaves with tree_leaves * Update to use custom sentinel instead of optional position argument * jax.tree_leaves -> tree_leaves | 05 May 2020, 08:11:10 UTC |
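The sentinel approach described in dc234b6 can be sketched in plain Python. This is a hypothetical re-implementation, not JAX's code. `flatten_leaves` stands in for `tree_leaves` and only handles nested lists, tuples, and dicts, and `_NO_INITIALIZER` is an assumed sentinel name. Scalars stand in for arrays, so `jnp.sum(param**2)` becomes `param ** 2`.

```python
import functools
from typing import Any, Callable

# Hypothetical sentinel: tells "no initializer given" apart from an explicit None.
_NO_INITIALIZER = object()


def flatten_leaves(tree: Any) -> list:
    """Hypothetical stand-in for tree_leaves: flattens nested lists/tuples/dicts."""
    if isinstance(tree, (list, tuple)):
        return [leaf for sub in tree for leaf in flatten_leaves(sub)]
    if isinstance(tree, dict):
        # Visit keys in sorted order so the leaf order is deterministic.
        return [leaf for k in sorted(tree) for leaf in flatten_leaves(tree[k])]
    return [tree]


def tree_reduce(function: Callable[[Any, Any], Any], tree: Any,
                initializer: Any = _NO_INITIALIZER) -> Any:
    leaves = flatten_leaves(tree)
    if initializer is _NO_INITIALIZER:
        # No initializer: the first leaf becomes the starting accumulator.
        return functools.reduce(function, leaves)
    # Because of the sentinel, None is accepted as a real initializer value.
    return functools.reduce(function, leaves, initializer)


def l2_sum(total, param):
    # Scalar version of the commit's example: accumulate squared values.
    return total + param ** 2
```

The initializer matters for this example. Without it, the first leaf is used unchanged as the accumulator and is never squared. That is why the commit's example passes `0.`.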
e4d8cac | Julius Kunze | 05 May 2020, 03:12:43 UTC | Fix tests for random.categorical with multi-dimensional logits (#2955) | 05 May 2020, 03:12:43 UTC |
7116cc5 | Peter Hawkins | 05 May 2020, 03:00:20 UTC | Improve JAX test PRNG APIs to fix correlations between test cases. (#2957) * Improve JAX test PRNG APIs to fix correlations between test cases. In #2863, we observed that we were missing gradient problems: the random test cases being generated were too similar, because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name. This gives the following properties: * different test cases receive different seeds * PRNG seeding is deterministic and independent of execution order and sharding. * PRNG seeding is deterministic across runs. * Fix some failing tests. * Fix more test failures. Simplify ediff1d implementation and make it more permissive when casting. * Relax test tolerance of laplace CDF test. | 05 May 2020, 03:00:20 UTC |
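Below is a minimal sketch of the per-test seeding described in 7116cc5. JAX's version returns a `numpy.random.RandomState`. This sketch substitutes the stdlib `random.Random` so it runs without NumPy. Deriving the seed with `zlib.adler32` is an assumption. The requirement is a stable hash, because Python's built-in `hash()` of a string changes between processes.

```python
import random
import unittest
import zlib


def rng_for_test_id(test_id: str) -> random.Random:
    # str hashes are salted per process (PYTHONHASHSEED), so derive the seed
    # from a stable checksum to keep seeding identical across runs.
    return random.Random(zlib.adler32(test_id.encode("utf-8")))


class SeededTestCase(unittest.TestCase):
    """Hypothetical base class: each test gets its own deterministic generator."""

    def rng(self) -> random.Random:
        # self.id() looks like "module.ClassName.test_method" and is unique per
        # test, so seeds differ between tests but not between runs or orderings.
        return rng_for_test_id(self.id())
```

Each test builds a fresh generator from its own name and never draws from a shared global one. As a result, a test's random inputs do not depend on which tests ran before it or on which shard it runs in.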
3cd409e | Matthew Johnson | 05 May 2020, 02:44:22 UTC | add optional 'forward' argument to lax.scan (#2921) * add optional 'forward' argument to lax.scan * switch to reverse; revise disable-jit case * fix jaxpr.rst * fix loops.py Co-authored-by: James Bradbury <jekbradbury@gmail.com> | 05 May 2020, 02:44:22 UTC |
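Commit 3cd409e renamed the new flag to `reverse`. JAX's documentation describes `lax.scan(..., reverse=True)` as equivalent to reversing the leading axes of both `xs` and the stacked outputs. The pure-Python sketch below shows that semantics on lists. It is a hypothetical reference only, with no arrays, pytrees, or `length` argument.

```python
from typing import Any, Callable, List, Sequence, Tuple


def scan_reference(f: Callable[[Any, Any], Tuple[Any, Any]], init: Any,
                   xs: Sequence[Any], reverse: bool = False) -> Tuple[Any, List[Any]]:
    """List-based sketch of scan semantics."""
    carry = init
    ys = []
    # With reverse=True the loop visits xs from the last element to the first.
    order = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    for i in order:
        carry, y = f(carry, xs[i])
        ys.append(y)
    if reverse:
        # Outputs are stored back in xs order, so ys[i] pairs with xs[i].
        ys.reverse()
    return carry, ys
```

With a running-sum body, the forward scan produces prefix sums and the reversed scan produces suffix sums. For example, `[1, 2, 3]` gives `[1, 3, 6]` forward and `[6, 5, 3]` in reverse.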
3e52237 | yurodiviy | 05 May 2020, 01:29:55 UTC | Raise an error in np.var when array is complex and dtype is not (#2288) Co-authored-by: vlad <veryfakemail@ya.ru> | 05 May 2020, 01:29:55 UTC |
9174684 | Peter Hawkins | 05 May 2020, 01:08:34 UTC | Cache test_utils.format_shape_and_dtype_string. (#2959) A significant fraction of time when collecting test cases is spent building shape and dtype strings (which are usually similar and usually thrown away.) | 05 May 2020, 01:08:34 UTC |
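Commit 9174684 caches a string-formatting helper. A minimal sketch of that kind of caching uses `functools.lru_cache`. The output format `dtype[d0,d1,...]` and the use of a dtype name instead of a dtype object are assumptions for illustration. Every cached argument must be hashable, so shapes are passed as tuples.

```python
import functools
from typing import Tuple


@functools.lru_cache(maxsize=None)
def format_shape_and_dtype_string(shape: Tuple[int, ...], dtype_name: str) -> str:
    # Hypothetical format, e.g. "float32[2,3]". Arguments must be hashable,
    # hence a tuple shape and a dtype *name*.
    return "{}[{}]".format(dtype_name, ",".join(str(d) for d in shape))
```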
4c2c5ad | Tom Hennigan | 04 May 2020, 22:46:12 UTC | Add a note about jax.pmap when leading dim is smaller than num devices. (#2949) | 04 May 2020, 22:46:12 UTC |
c9c653a | joschkabraun | 04 May 2020, 20:55:47 UTC | Implementation of numpy.ediff1d (#2729) * Implementation of numpy.ediff1d * Added testing for numpy.ediff1d implementation * Made ediff1d jit-compatible * Implemented corrections: style and more testing * Adapted tests * changed tests * modified tests * Incorporated changes * Style changes * Added line between tests * Changed op_record test | 04 May 2020, 20:55:47 UTC |
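For reference, `numpy.ediff1d(ary, to_end=None, to_begin=None)` returns the differences between consecutive elements of the flattened input, with optional values prepended and appended. The list-based sketch below assumes the input is already 1-D. Unlike NumPy, it requires `to_begin`/`to_end` to be iterables rather than scalars.

```python
from typing import Iterable, List, Optional, Sequence


def ediff1d(ary: Sequence[float], to_end: Optional[Iterable[float]] = None,
            to_begin: Optional[Iterable[float]] = None) -> List[float]:
    """List-based sketch of numpy.ediff1d for an already-flattened input."""
    # Consecutive differences: ary[i + 1] - ary[i].
    diffs = [b - a for a, b in zip(ary[:-1], ary[1:])]
    begin = list(to_begin) if to_begin is not None else []
    end = list(to_end) if to_end is not None else []
    return begin + diffs + end
```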
6aab9e5 | Stephan Hoyer | 04 May 2020, 20:06:25 UTC | DOC: write a new docstring for jax.numpy.vectorize (#2944) * DOC: write a new docstring for jax.numpy.vectorize This version is customized entirely for JAX. * review and typo fixes | 04 May 2020, 20:06:25 UTC |
5a0bf46 | Stephan Hoyer | 04 May 2020, 19:37:29 UTC | DOC: add a table of contents for top level API docs (#2946) This makes them easier to scan. | 04 May 2020, 19:37:29 UTC |
72efa78 | Peter Hawkins | 04 May 2020, 18:50:08 UTC | Fix spurious rank promotion warning. (#2954) | 04 May 2020, 18:50:08 UTC |
d61d6f4 | Peter Hawkins | 04 May 2020, 18:34:08 UTC | Fix a number of flaky tests. (#2953) * relax some test tolerances. * disable 'random' preconditioner in CG test (#2951). * ensure that scatter and top-k tests don't create ties. | 04 May 2020, 18:34:08 UTC |
04102e5 | tamaranorman | 04 May 2020, 18:02:13 UTC | Allow ConvDimensionNumbers to be passed into conv_transpose (#2915) | 04 May 2020, 18:02:13 UTC |
4d236b5 | Peter Hawkins | 04 May 2020, 13:17:07 UTC | Update XLA to fix build failures. (#2950) | 04 May 2020, 13:17:07 UTC |
525235d | Roman Ring | 04 May 2020, 10:20:21 UTC | Fix a codeblock in the "understanding jaxpr" doc. (#2942) This fixes an issue where the codeblock didn't render properly on the website. | 04 May 2020, 10:20:21 UTC |
d315564 | George Necula | 04 May 2020, 08:30:28 UTC | Fixed a few more places where device commitment was lost. (#2913) * trivial jit computations were forcing commitment to the default device * a device_put with a device specification would not set the commitment if the data was already (uncommitted) on the specified device. * added tests for the above * once the above were fixed, LazyTest.test_zeros_ones_compilation started to fail because the `sticky` parameter to lazy_force_computation was changing. Fixed this by removing stickiness from the compilation key. * Expanded docstring for jax.device_put; expanded the device placement FAQ entry. | 04 May 2020, 08:30:28 UTC |
1cc6b7d | James Bradbury | 03 May 2020, 02:33:10 UTC | support axis argument in nn.glu (#2879) * support axis argument in nn.glu * also add basic correctness test * Update nn_test.py | 03 May 2020, 02:33:10 UTC |
9f7115e | Matthew Johnson | 02 May 2020, 19:02:43 UTC | reduce use of lax on static data (e.g. shapes) (#2933) * reduce use of lax on static data (e.g. shapes) * use f-string for error message | 02 May 2020, 19:02:43 UTC |
64f12a4 | Matthew Johnson | 02 May 2020, 17:25:53 UTC | improve docs and error message for odeint *args (#2931) cf. #2920 | 02 May 2020, 17:25:53 UTC |
a182578 | Peter Hawkins | 02 May 2020, 16:47:07 UTC | Update XLA. (#2932) Mention illegal instruction fix in changelog. | 02 May 2020, 16:47:07 UTC |
46ce80b | Stephan Hoyer | 02 May 2020, 15:24:59 UTC | jax.random.poisson (#2805) * jax.random.poisson The implementation for lam < 10 was directly copied from TensorFlow probability: https://github.com/tensorflow/probability/blob/v0.10.0-rc0/tensorflow_probability/python/internal/backend/numpy/random_generators.py#L155 I adapted the implementation for lam > 10 from TensorFlow: https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc The methods themselves match both TensorFlow and NumPy: https://github.com/numpy/numpy/blob/v1.18.3/numpy/random/src/distributions/distributions.c#L574 * add a check for even larger lambda * increment iter count * remove comment that makes no sense * Fix chi-squared tests in random_test.py As far as I can tell, the previous implementation of the chi-squared test for samples from discrete probability distributions was broken. It should have been asserting that the p-value was greater than 0.01, e.g., as illustrated here: http://hamelg.blogspot.com/2015/11/python-for-data-analysis-part-25-chi.html This hid a few other bugs, such as a miscalculation of expected frequencies. Fortunately, the existing random tests for Bernoulli and Categorical *mostly* still pass, with the exception of multi-dimensional logits for Categorical. Those tests are disabled by this PR. * Fix accept condition (based on correct chi-squared test) * Add moment checks for Poisson * Add batching test, more Poisson rates | 02 May 2020, 15:24:59 UTC |
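Commit 46ce80b uses a separate method for small rates (lam < 10). As an illustration, the sketch below implements the textbook small-rate algorithm, Knuth's multiplication method, using the stdlib `random.Random`. It is a swapped-in technique and not a transcription of the TensorFlow Probability code the commit copied. Its expected cost grows linearly with `lam`, which is one reason large rates need a different (rejection-based) method.

```python
import math
import random


def poisson_knuth(lam: float, rng: random.Random) -> int:
    """Knuth's method: multiply uniforms until the product falls to exp(-lam) or below.

    Returns the number of extra uniforms needed; only practical for small lam.
    """
    threshold = math.exp(-lam)
    k = 0
    product = rng.random()
    while product > threshold:
        k += 1
        product *= rng.random()
    return k
```

A moment check like the one the commit adds for Poisson compares the sample mean and variance against `lam`, because a Poisson distribution's mean and variance both equal its rate.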
ee38e1b | Peter Hawkins | 02 May 2020, 15:09:21 UTC | Update XLA. (#2929) Includes a fix that may help with issue #2906. | 02 May 2020, 15:09:21 UTC |
6425ca2 | Jake Vanderplas | 02 May 2020, 13:32:50 UTC | Merge pull request #2925 from jakevdp/shuffle Deprecate random.shuffle() and implement random.permutation() for multi-dim inputs | 02 May 2020, 13:32:50 UTC |
9802d73 | Peter Hawkins | 02 May 2020, 01:08:56 UTC | Update XLA. (#2927) | 02 May 2020, 01:08:56 UTC |
a821e67 | Jacob Kelly | 02 May 2020, 00:10:20 UTC | instantiate zeros (#2924) fix dtype remove TODO | 02 May 2020, 00:10:20 UTC |
d8d7140 | Jake VanderPlas | 01 May 2020, 22:18:24 UTC | Deprecate random.shuffle() and implement random.permutation() for multi-dimensional matrices. | 01 May 2020, 22:18:24 UTC |
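With NumPy-style `permutation` semantics, an integer `n` means "permute `range(n)`", and a multi-dimensional input is shuffled along its leading axis only, with rows moved as units. The sketch below illustrates this with the stdlib `random.Random` in place of a JAX PRNG key. The real `jax.random.permutation` takes a key as its first argument, and this function is hypothetical.

```python
import random
from typing import List, Sequence, Union


def permutation(rng: random.Random, x: Union[int, Sequence]) -> List:
    """Stdlib sketch of np.random.permutation-style semantics."""
    if isinstance(x, int):
        # An integer n means "return a permutation of range(n)".
        items = list(range(x))
    else:
        # Copy the leading axis only: for nested input, each row is moved as
        # a whole and its contents are left untouched.
        items = list(x)
    rng.shuffle(items)
    return items
```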
f8fa589 | Matthew Johnson | 01 May 2020, 21:40:24 UTC | revert previous change | 01 May 2020, 21:40:24 UTC |
2263899 | Matthew Johnson | 01 May 2020, 21:39:30 UTC | replace accidental use of jax.numpy.min w/ builtin | 01 May 2020, 21:39:30 UTC |
1cdd8f1 | James Bradbury | 01 May 2020, 21:37:13 UTC | Add support for in_axes=None (but not out_axes, or in_axes>0) to pmap (#2896) * allow in_axes=None for pmap in api.py * wire in_axes=None through parallel_callable * add test * fix error string * fixes * fixes * add test for nested pmap with in_axes * test pmap still defaults to (implicit) out_axes=0 | 01 May 2020, 21:37:13 UTC |
49a8901 | Matthew Johnson | 01 May 2020, 20:47:46 UTC | skip failing shapecheck tests cc @JuliusKunze | 01 May 2020, 20:48:07 UTC |
c00e9a2 | Julius Kunze | 01 May 2020, 19:34:29 UTC | Reapply #2017 (Allow shapecheck of PixelCNN++), fixing #2245 (#2800) * Unrevert "Allow shapecheck of PixelCNN++ (google#2017)" This reverts commit ceab1e3edf1e2395035173dc50f24ce6a27475f6. * Fix out-of-bound slices (#2245) * Minor * Add type annotations * Fix Poly.__rsub__ * any -> _any * tweaks, mostly comments/whitespace * separate polymorphic code path, patch _slice_sizes * put back some logic for handling Poly sizes * improve test_slice_indices * Remove to_index, replace with canonicalize_shape * Fix slicing with polymorphic start/stop * Test negative step for polymorphic slicing * Refactor polymorphic slicing * Simplify diff * Fix shapecheck(iota) Co-authored-by: Matthew Johnson <mattjj@google.com> | 01 May 2020, 19:34:29 UTC |
1b56428 | Peter Hawkins | 01 May 2020, 17:57:09 UTC | Fix test flakiness in autodiff tests for min/max type functions (#2918) * Fix test flakiness in autodiff tests for clamp, reduce, and reduce-window. We change the tests to avoid computing numerical gradients in the neighborhood of nondifferentiable points where, for example, the maximum element in a reduce-max changes. The numerical (finite-difference) approximation is only valid within an epsilon ball around a point, and close to a nondifferentiable point the approximation may not be valid. * Only test reduce-grad-mul for float types. | 01 May 2020, 17:57:09 UTC |
0736679 | Tom Hennigan | 01 May 2020, 17:00:38 UTC | Explicitly broadcast values in nn.one_hot and nn.initializers.orthogonal. (#2901) At head the following fails:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.config.update('jax_numpy_rank_promotion', 'raise')
>>> jax.nn.one_hot(jnp.ones([8]), 512)
...
ValueError: Operands could not be broadcast together for equal on shapes (8, 1) (512,) and with the config option jax_numpy_rank_promotion='raise'. For more information, see https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.
```
| 01 May 2020, 17:00:38 UTC |
279a077 | James Bradbury | 01 May 2020, 17:00:06 UTC | Avoid tuple allreduce lowering of psum on TPUs (#2914) Tuple-shaped allreduces aren't supported in an XLA:TPU optimization pass (see internal bug), but since our use of them on GPU is due to compiler nondeterminism that isn't present on TPU, it should be fine to avoid this bug by disabling tuple psum on TPU. | 01 May 2020, 17:00:06 UTC |
25e8280 | Peter Hawkins | 01 May 2020, 15:45:28 UTC | Relax some test tolerances. (#2917) | 01 May 2020, 15:45:28 UTC |
ac023bf | George Necula | 01 May 2020, 07:06:59 UTC | Fixed a few places where device stickiness was lost. Added FAQ (#2882) * Fixed a few places where device stickiness was lost. Added FAQ for device placement. I have also added a new test (multi_device_test.test_computation_follows_data), written more as part of the documentation. It is shorter than the old test_computation_follows_data (which is still there, renamed as test_computation_follows_data_old). I believe there is no extra coverage in test_computation_follows_data_old w.r.t. all the other tests we have. * Fix mypy annotations and updates based on comments * Undid some changes, will make another PR | 01 May 2020, 07:06:59 UTC |
2e9047d | George Necula | 01 May 2020, 06:16:31 UTC | Add flag to enable checking, and turn on checking in tests. (#2900) Fix an error in check_jaxpr. | 01 May 2020, 06:16:31 UTC |
e06bde8 | Matthew Johnson | 01 May 2020, 00:21:10 UTC | revise xla.device_put device logic (#2907) * revise xla.device_put device logic, fixes #2905 * remove test of behavior we don't want Previously, we were testing that for a DeviceArray x, writing jax.device_put(x) would evaluate to a DeviceArray *on the default device*. Instead, we should be happy with just returning the same DeviceArray without any movement. | 01 May 2020, 00:21:10 UTC |
a4deae3 | Roy Frostig | 30 April 2020, 22:56:59 UTC | err on empty operand dimension in numpy argmin and argmax see #2899 | 30 April 2020, 23:35:42 UTC |
bb60833 | Matthew Johnson | 30 April 2020, 23:13:36 UTC | update changelog | 30 April 2020, 23:13:36 UTC |
3aa953d | Skye Wanderman-Milne | 30 April 2020, 22:31:42 UTC | Update jax version to 0.1.65 (#2909) | 30 April 2020, 22:31:42 UTC |
815a92e | Skye Wanderman-Milne | 30 April 2020, 21:49:33 UTC | Remove assert from ShardedDeviceArray staging. (#2908) This would erroneously fail on Cloud TPU because the TPU client has its own buffer type. | 30 April 2020, 21:49:33 UTC |
3216f5c | Roy Frostig | 30 April 2020, 15:31:48 UTC | err on empty operand in numpy argmin and argmax fixes #2899 | 30 April 2020, 16:30:18 UTC |
8d4b685 | George Necula | 30 April 2020, 16:16:05 UTC | Fix typo in tests; caught on GPU and TPU (#2902) | 30 April 2020, 16:16:05 UTC |
b39da1f | George Necula | 30 April 2020, 07:16:14 UTC | Fix jit with device placement (#2883) In setups with multiple backends, a jit happens on the default backend, unless we give a `backend` parameter. This is true even if the inputs are committed to a device on the non-default backend, or if we pass a `device` parameter to jit. | 30 April 2020, 07:16:14 UTC |
1f7ebab | Jacob Kelly | 30 April 2020, 02:18:21 UTC | add jets for sines fns (#2892) refactor remove duplicate | 30 April 2020, 02:18:21 UTC |
c8d1700 | Peter Hawkins | 30 April 2020, 01:25:43 UTC | Make sure gather/scatter indices in lax gradient tests aren't out of bounds. (#2895) Out-of-bounds gathers are clamped to be in bounds, but out-of-bounds scatters are dropped entirely. This can cause gradient tests to fail because the two operations aren't duals of one another, as the gradient rules expect. | 30 April 2020, 01:25:43 UTC |
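The asymmetry described in c8d1700 can be shown on 1-D lists. Both functions are hypothetical stand-ins. They model only the stated behaviour: out-of-bounds gathers are clamped, and out-of-bounds scatter updates are dropped. A gradient test relies on the transpose identity `<G x, ct> == <x, G^T ct>`, which holds only when every index is in bounds.

```python
from typing import List, Sequence


def gather_clamped(operand: Sequence[float], indices: Sequence[int]) -> List[float]:
    # Out-of-bounds indices are clamped into [0, len(operand) - 1].
    last = len(operand) - 1
    return [operand[min(max(i, 0), last)] for i in indices]


def scatter_add_dropping(operand: Sequence[float], indices: Sequence[int],
                         updates: Sequence[float]) -> List[float]:
    # Updates whose index falls outside the operand are silently discarded.
    out = list(operand)
    for i, u in zip(indices, updates):
        if 0 <= i < len(out):
            out[i] += u
    return out


def inner_product(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(p * q for p, q in zip(a, b))
```

With indices `[0, 2]` on a length-3 operand, both sides of the identity agree. With `[0, 5]`, the gather reads element 2 (clamped), but the scatter discards the update for index 5. The two operations are then no longer transposes of each other.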
b8cbc95 | Peter Hawkins | 30 April 2020, 01:03:46 UTC | Fix lax_reference implementation of round() to match lax. (#2894) lax.round() is documented to round half away from zero, but np.round() rounds to nearest even. | 30 April 2020, 01:03:46 UTC |
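Below is a scalar sketch of the rounding mode that b8cbc95 attributes to `lax.round` (half away from zero). Python's built-in `round`, like `np.round`, rounds halves to the nearest even integer, so the two differ on exact halves. The sketch avoids `floor(abs(x) + 0.5)` because that addition can round up for values just below one half.

```python
import math


def round_half_away_from_zero(x: float) -> float:
    magnitude = abs(x)
    whole = math.floor(magnitude)
    # magnitude - whole is computed exactly for finite doubles, so the
    # comparison with 0.5 is not disturbed by floating-point rounding.
    if magnitude - whole >= 0.5:
        whole += 1
    return math.copysign(whole, x)
```

For example, `round_half_away_from_zero(2.5)` is `3.0` while `round(2.5)` is `2`.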
3e87e8f | Peter Hawkins | 29 April 2020, 18:59:23 UTC | Add relu6, hard_swish, and hard_sigmoid to docs. (#2886) | 29 April 2020, 18:59:23 UTC |
43efbe2 | James Bradbury | 29 April 2020, 18:31:36 UTC | Reset parameter replication default (#2880) * Reset parameter replication default * add tests | 29 April 2020, 18:31:36 UTC |
0557248 | Peter Hawkins | 29 April 2020, 18:14:49 UTC | Check for unsupported dtypes and issue a helpful error. (#2885) | 29 April 2020, 18:14:49 UTC |
52c69e8 | Martin Sotir | 29 April 2020, 07:16:49 UTC | Fix slices in Gated Linear Unit activation (#2341) | 29 April 2020, 07:16:49 UTC |
ef963f0 | Vaibhav Balloli | 29 April 2020, 07:07:18 UTC | Add ReLU6, Hard sigmoid, swish (#2709) | 29 April 2020, 07:07:18 UTC |
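The activations added in ef963f0 have short closed forms. The scalar sketch below uses the standard definitions, which I understand JAX's `jax.nn` versions to follow elementwise on arrays: `relu6(x) = min(max(x, 0), 6)`, `hard_sigmoid(x) = relu6(x + 3) / 6`, and `hard_swish(x) = x * hard_sigmoid(x)`.

```python
def relu6(x: float) -> float:
    # ReLU clipped at 6.
    return min(max(x, 0.0), 6.0)


def hard_sigmoid(x: float) -> float:
    # Piecewise-linear sigmoid approximation: 0 below -3, 1 above 3.
    return relu6(x + 3.0) / 6.0


def hard_swish(x: float) -> float:
    # Swish with the sigmoid replaced by hard_sigmoid.
    return x * hard_sigmoid(x)
```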
790d929 | Matthew Johnson | 28 April 2020, 23:41:26 UTC | iterate on jax.hessian docs (#2873) * iterate on jax.hessian docs * tweaks * add back note about block structure | 28 April 2020, 23:41:26 UTC |
3ee8a7b | Peter Hawkins | 28 April 2020, 23:12:32 UTC | Add nanargmin and nanargmax to documentation. (#2877) | 28 April 2020, 23:12:32 UTC |
697aa48 | Skye Wanderman-Milne | 28 April 2020, 23:02:30 UTC | Fix bug in ShardedDeviceArrayTest.testThreadsafeIndexing (#2875) | 28 April 2020, 23:02:30 UTC |
b0b6cd8 | Peter Hawkins | 28 April 2020, 22:44:00 UTC | Make dlpack code robust against upcoming XLA Python binding change. (#2876) | 28 April 2020, 22:44:00 UTC |
56f6294 | yurodiviy | 28 April 2020, 19:23:03 UTC | Implement nanargmin-max and add tests (#2398) Co-authored-by: vlad <veryfakemail@ya.ru> | 28 April 2020, 19:23:03 UTC |
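`nanargmin`/`nanargmax` return the index of the minimum/maximum while ignoring NaNs, and ties resolve to the first occurrence. The list-based sketch below follows NumPy in raising `ValueError` for an all-NaN input. A jit-compatible JAX implementation cannot raise on traced values, so its all-NaN behaviour may differ.

```python
import math
from typing import Sequence


def nanargmin(values: Sequence[float]) -> int:
    candidates = [(v, i) for i, v in enumerate(values) if not math.isnan(v)]
    if not candidates:
        raise ValueError("All-NaN slice encountered")
    # Tuples compare by value first, then by index, so the first minimum wins.
    return min(candidates)[1]


def nanargmax(values: Sequence[float]) -> int:
    # Negate the index so that, among equal maxima, the first occurrence wins.
    candidates = [(v, -i) for i, v in enumerate(values) if not math.isnan(v)]
    if not candidates:
        raise ValueError("All-NaN slice encountered")
    return -max(candidates)[1]
```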
2a0637a | Eduardo Pignatelli | 28 April 2020, 18:34:27 UTC | add spacing to numpy.gradient (#2545) | 28 April 2020, 18:34:27 UTC |
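For 1-D input with a uniform scalar spacing `h`, `numpy.gradient` (default `edge_order=1`) uses central differences `(f[i+1] - f[i-1]) / (2h)` in the interior and one-sided differences at the two ends. The sketch below covers only this scalar-spacing case. NumPy also accepts coordinate arrays, which this hypothetical helper does not.

```python
from typing import List, Sequence


def gradient_1d(f: Sequence[float], spacing: float = 1.0) -> List[float]:
    """Uniform-spacing sketch of numpy.gradient for 1-D input (edge_order=1)."""
    n = len(f)
    if n < 2:
        raise ValueError("need at least two samples")
    out = [(f[1] - f[0]) / spacing]  # one-sided difference at the left edge
    for i in range(1, n - 1):
        out.append((f[i + 1] - f[i - 1]) / (2.0 * spacing))  # central difference
    out.append((f[-1] - f[-2]) / spacing)  # one-sided difference at the right edge
    return out
```

Sampling `x**2` at `x = 0, 1, 2, 3, 4` gives `[1.0, 2.0, 4.0, 6.0, 7.0]`. The interior values match the exact derivative `2x`, and the edges show the first-order error.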
ae6a3fe | Peter Hawkins | 28 April 2020, 18:07:35 UTC | Document how jax.hessian and pytrees interact. (#2705) * Document how jax.hessian and pytrees interact. | 28 April 2020, 18:07:35 UTC |
e599a25 | Anselm Levskaya | 28 April 2020, 17:49:17 UTC | fix sort_key_val return type annotation, docstring | 28 April 2020, 17:49:17 UTC |
2611fd2 | Paige Bailey | 28 April 2020, 17:40:05 UTC | Updated README wrt. new features for Stax. (#2862) * Updated README wrt. new features for Stax. | 28 April 2020, 17:40:05 UTC |
f6e9060 | Jamie Townsend | 28 April 2020, 16:58:49 UTC | Qr complex jvp fix (#2872) * Fix qr jvp for complex input * Fix qr jvp for complex64 inputs when jax_enable_x64=True * Reenable complex jvp test for qr | 28 April 2020, 16:58:49 UTC |
e287f98 | Peter Hawkins | 28 April 2020, 16:01:54 UTC | Fix definition of qr primitive to return only the upper triangular part of r. (#2870) Issue #2863. | 28 April 2020, 16:01:54 UTC |
0dbbc27 | Peter Hawkins | 28 April 2020, 15:58:51 UTC | Clarify that `grad` requires arguments to be differentiated to be of inexact type. (#2712) | 28 April 2020, 15:58:51 UTC |
ca4e396 | Anselm Levskaya | 28 April 2020, 14:57:29 UTC | Merge pull request #2853 from levskaya/topkjvp Add top_k jvp and batching rules and tests | 28 April 2020, 14:57:29 UTC |
dddad2a | Anselm Levskaya | 19 April 2020, 18:49:15 UTC | Add top_k jvp and batching rules | 28 April 2020, 14:19:58 UTC |
c0023f4 | Peter Hawkins | 28 April 2020, 14:03:31 UTC | Change isinstance test in xla_bridge.py to not explicitly name xla_client.Backend. (#2868) Change in preparation for removing xla_client.Backend in favor of the underlying C++ classes. | 28 April 2020, 14:03:31 UTC |
5fe6b06 | Adam Paszke | 28 April 2020, 13:07:08 UTC | Correct the order of .format arguments in vjp wrapper (#2866) | 28 April 2020, 13:07:08 UTC |
75617be | Jamie Townsend | 28 April 2020, 05:32:52 UTC | Add population_count primitive to lax (#2753) * add population_count primitive (needs new jaxlib) fixes #2263 * Add popcount docs * Add population_count to lax_reference * Use int prng (since we're only testing uints) Co-authored-by: Matthew Johnson <mattjj@google.com> | 28 April 2020, 05:32:52 UTC |
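`population_count` counts the set bits of each integer element. Commit 75617be adds a `lax_reference` version for testing. A hypothetical pure-Python reference for one value of a fixed bit width might look like the sketch below. Negative inputs are masked to their two's-complement bit pattern. The sketch is not the actual `lax_reference` code.

```python
def population_count(x: int, bits: int = 32) -> int:
    """Number of set bits in the low `bits` bits of x (two's complement)."""
    # Masking maps negative Python ints to their fixed-width bit pattern.
    return bin(x & ((1 << bits) - 1)).count("1")
```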
2d96cfb | Chris Jones | 28 April 2020, 05:09:30 UTC | Remove unused `ispure` method (#2781) | 28 April 2020, 05:09:30 UTC |
cc0e9a3 | Jacob Kelly | 28 April 2020, 04:53:38 UTC | refactor ode tests, add scipy benchmark (#2824) * refactor ode tests, add scipy benchmark remove double import rename to scipy merge vmap test properly * clean up more global trace state after errors Co-authored-by: Matthew Johnson <mattjj@google.com> | 28 April 2020, 04:53:38 UTC |
e6df98d | Stephan Hoyer | 28 April 2020, 00:24:39 UTC | Fix chi-squared tests in random_test.py (#2847) As far as I can tell, the previous implementation of the chi-squared test for samples from discrete probability distributions was broken. It should have been asserting that the p-value was greater than 0.01, e.g., as illustrated here: http://hamelg.blogspot.com/2015/11/python-for-data-analysis-part-25-chi.html This hid a few other bugs, such as a miscalculation of expected frequencies. Fortunately, the existing random tests for Bernoulli and Categorical *mostly* still pass, with the exception of multi-dimensional logits for Categorical. Those tests are disabled by this PR. | 28 April 2020, 00:24:39 UTC |
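The corrected accept condition in e6df98d asserts that the p-value is *greater* than 0.01. Without SciPy, the same check can be written as "the chi-squared statistic is below the upper 1% critical value". The critical values below are standard table values rounded to three decimals, and the helper names are hypothetical. In practice `scipy.stats.chisquare` returns the statistic and the p-value directly. Expected frequencies are `n * p_i`, where `n` is the total sample count.

```python
from typing import Sequence

# Upper 1% critical values of the chi-squared distribution by degrees of
# freedom (standard table values, rounded to three decimals).
CHI2_CRITICAL_99 = {1: 6.635, 2: 9.210, 3: 11.345, 4: 13.277, 5: 15.086}


def chi_squared_statistic(observed: Sequence[int], probs: Sequence[float]) -> float:
    n = sum(observed)
    # Expected frequency for each category is total count times its probability.
    expected = [n * p for p in probs]
    return sum((o - e) ** 2 / e for o, e in zip(observed, expected))


def passes_chi_squared(observed: Sequence[int], probs: Sequence[float]) -> bool:
    """True when p-value > 0.01, i.e. the statistic is below the 1% critical value."""
    df = len(observed) - 1
    return chi_squared_statistic(observed, probs) < CHI2_CRITICAL_99[df]
```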
4b03343 | Skye Wanderman-Milne | 28 April 2020, 00:21:05 UTC | Add pmap_shard_device_array_benchmark. (#2864) Also renames pmap_shard_args_benchmark to pmap_shard_sharded_device_array_benchmark. | 28 April 2020, 00:21:05 UTC |
b277b55 | Jacob Kelly | 27 April 2020, 23:45:51 UTC | check step size is greater than zero (#2857) loosen tols for grad test set tol only for float64 | 27 April 2020, 23:45:51 UTC |
283393f | Jamie Townsend | 27 April 2020, 23:44:46 UTC | Update jaxpr.rst (#2859) * Update jaxpr doc * Make jaxpr.rst doctestable | 27 April 2020, 23:44:46 UTC |
5da74d4 | samuela | 27 April 2020, 18:56:48 UTC | Simplify _odeint_rev (#2832) | 27 April 2020, 18:56:48 UTC |
3736278 | Abhishek Sharma | 27 April 2020, 05:10:48 UTC | Add precision only arguments (#2850) * Make precision argument keyword only in jax.numpy * Fix private functions | 27 April 2020, 05:10:48 UTC |
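In Python, a parameter becomes keyword-only when it follows a bare `*` in the signature, so passing it positionally raises `TypeError`. The hypothetical `dot` below shows the pattern described in 3736278. In this pure-Python sketch `precision` is accepted but ignored, and the string value used in the usage example is illustrative only.

```python
from typing import Optional, Sequence


def dot(a: Sequence[float], b: Sequence[float], *, precision: Optional[str] = None) -> float:
    """Hypothetical dot product; the bare `*` makes `precision` keyword-only."""
    # `precision` is accepted but unused in this sketch.
    return sum(x * y for x, y in zip(a, b))
```

`dot([1, 2], [3, 4], precision="highest")` works, while `dot([1, 2], [3, 4], "highest")` raises `TypeError`. Callers therefore cannot pass a precision value into a slot intended for another argument by mistake.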
0a5cc09 | Matthew Johnson | 25 April 2020, 16:20:26 UTC | split testDetGradOfSingularMatrix into corank=1,2 (#2845) | 25 April 2020, 16:20:26 UTC |
ffaf417 | David Pfau | 25 April 2020, 15:32:27 UTC | Fix typo in docstring for _cofactor_solve (#2844) Found a small typo in the description of _cofactor_solve | 25 April 2020, 15:32:27 UTC |
02b3fc5 | David Pfau | 25 April 2020, 15:26:25 UTC | Custom derivative for np.linalg.det (#2809) * Add vjp and jvp rules for jnp.linalg.det * Add tests for new determinant gradients * Replace index_update with concatenate in cofactor_solve This avoids issues with index_update not having a transpose rule, removing one bug in the way of automatically converting the JVP into a VJP (still need to deal with the np.where). * Changes to cofactor_solve so it can be transposed This allows a single JVP rule to give both forward and backward derivatives * Update det grad tests All tests pass now - however second derivatives still do not work for nonsingular matrices. * Add explanation to docstring for _cofactor_solve * Fixed comment | 25 April 2020, 15:26:25 UTC |
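A cofactor-based derivative rule for the determinant rests on the standard identity below (Jacobi's formula), written with the adjugate rather than the inverse. This is the textbook identity. The claim here is only that it explains why such a rule can stay finite for singular matrices, not that `_cofactor_solve` computes exactly this. When rank(A) = n - 1, adj(A) is a nonzero rank-1 matrix. When rank(A) <= n - 2, adj(A) = 0. These are the corank 1 and corank 2 cases that 0a5cc09 tests separately.

```latex
\frac{\partial \det(A)}{\partial A_{ij}} = \operatorname{adj}(A)_{ji},
\qquad
d\,\det(A) = \operatorname{tr}\!\big(\operatorname{adj}(A)\,dA\big),
\qquad
\operatorname{adj}(A) = \det(A)\,A^{-1} \ \text{if } A \text{ is invertible.}
```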
4e020cc | Peter Hawkins | 25 April 2020, 15:01:06 UTC | Enable some tests that now pass. (#2841) | 25 April 2020, 15:01:06 UTC |
f9806b3 | Peter Hawkins | 25 April 2020, 14:19:28 UTC | Remove some tests for Jaxlib versions older than the minimum. (#2840) | 25 April 2020, 14:19:28 UTC |
a6093d8 | Peter Hawkins | 25 April 2020, 14:19:19 UTC | Replace uses of xla_client.Buffer.from_pyval() with backend.buffer_from_pyval(). (#2839) Change in preparation for deleting xla_client.Buffer. | 25 April 2020, 14:19:19 UTC |
9e9b348 | Peter Hawkins | 25 April 2020, 13:23:34 UTC | Update deprecated API usages in lapack.pyx. (#2838) | 25 April 2020, 13:23:34 UTC |
89e3840 | Matthew Johnson | 25 April 2020, 01:45:34 UTC | handle mapped_invars correctly in more places (#2828) fixes #2822 We didn't handle `pmap`'s `mapped_invars` correctly in all places in #1959. (I'm actually not sure if #1959 introduced the bug where things were working before, or just refactored it in terms of `mapped_invars`, though my guess is that because the information now contained in `mapped_invars` was implicitly contained in the pmapped jaxpr's `constvars` and `env_vars`, it was working correctly before #1959.) In particular, in #1959 we: 1. assumed the `mapped_invars` parameter of xla_pmap_p was only populated after partial_eval and set to None otherwise (i.e. staging out for a jit or a control flow primitive), 2. didn't update it correctly in JVPTrace.process_map (which adds new inputs corresponding to nonzero tangents, and hence `mapped_invars` must be grown), 3. didn't update it correctly in JaxprTrace.process_map (which adds residual inputs to the staged-out version of the primitive), 4. didn't forward it correctly in JaxprTrace.process_map anyway (we were setting it to all-true for the staged out eqn for all tracers regardless of what the original `mapped_invars` said), 5. removed the leading axes of all pvs in JaxprTrace.process_map regardless of whether the corresponding entry of `mapped_invars` was True or False. The reason we didn't notice 2 and 3 was that they only arise when doing control flow (e.g. scan or remat) of pmap involving closed-over tracers (apparently a rare case), since that's the case where we first form a jaxpr (populating `mapped_invars`) and then later have to apply transformations like AD and further partial eval (thus engaging JVPTrace.process_map and JaxprTrace.process_map with a populated `mapped_invars` parameter). It worked in other cases, e.g. 
when the pmap was not inside control flow or a remat, because in those cases we left `mapped_invars` set to None, indicating all-true of any length (so it didn't matter if we added inputs). This commit fixes those issues by 1. making `mapped_invars` non-optional, 2. handling `mapped_invars` correctly in * JaxprTrace.process_map * JVPTrace.process_map * ad.map_transpose (since having symbolic-zero cotangents effectively prunes inputs, and having undefined-primal args also prunes inputs) * ad._eval_subjaxpr_primals (since having undefined-primal args prunes inputs) 3. making the separate cases of calls and maps handled more explicitly by adding a new Primitive.map_primitive boolean attribute (analogous to Primitive.call_primitive), to be revised further in #2829. This is begging for a more coherent cleanup. For example, we reuse the same Primitive class but tag it with `call_primitive` or `map_primitive` (only one of which can be True); we should instead just have a separate Primitive class for these cases and track the type tag with built-in Python mechanisms. Moreover, `call_primitive=True` or `map_primitive=True` implies things about which `params` must be present (`call_jaxpr` and `mapped_invars`). I plan to follow up with those cleanups in #2829, but I wanted to get something working first. | 25 April 2020, 01:45:34 UTC |
8f90245 | Matthew Johnson | 25 April 2020, 01:19:24 UTC | only maximally stage out for some call primitives (#2834) fixes #2833 | 25 April 2020, 01:19:24 UTC |
77901e9 | Jon Malmaud | 24 April 2020, 20:43:04 UTC | Fix lax.rng_uniform. (#2830) | 24 April 2020, 20:43:04 UTC |
343e486 | Skye Wanderman-Milne | 24 April 2020, 20:11:53 UTC | Remove platform canonicalization from xla_bridge.py (#2815) | 24 April 2020, 20:11:53 UTC |
11d7fb0 | Matthew Johnson | 24 April 2020, 08:47:20 UTC | add more ode tests (#2819) | 24 April 2020, 08:47:20 UTC |
6ad2908 | Matthew Johnson | 24 April 2020, 08:21:27 UTC | add ode test file (#2818) * add ode test file * control test tolerances based on precision | 24 April 2020, 08:21:27 UTC |
23f4874 | samuela | 24 April 2020, 07:48:47 UTC | Fix time issues in odeint reverse mode (#2817) * Fix time issues in odeint reverse mode * Add regression test | 24 April 2020, 07:48:47 UTC |
e0d42e9 | MichaelMarien | 24 April 2020, 05:40:33 UTC | Feature/permutation (#1568) * added test for random.permutation * added permutation that wraps shuffle with behaviour of np.random.permutation * update docstring * need to shuffle also the integer range input * fixed test for permutation with integer * tweak handling of random.permutation scalar case * NotImplementedError for random.permutation on >1d pending resolution to #2066 * address reviewer comments: improve tests Co-authored-by: Matthew Johnson <mattjj@google.com> | 24 April 2020, 05:40:33 UTC |
fc4203c | Jacob Kelly | 24 April 2020, 05:07:35 UTC | implement jet rules by lowering to other primitives (#2816) merge jet_test add jet rules use lax.square | 24 April 2020, 05:07:35 UTC |
2518343 | Matthew Johnson | 24 April 2020, 01:07:51 UTC | use static_argnums in xla_computation (#2812) * use static_argnums in xla_computation fixes #1017 * add static_argnums to make_jaxpr * fix type error: handle int case | 24 April 2020, 01:07:51 UTC |
9b6976b | Skye Wanderman-Milne | 23 April 2020, 23:46:05 UTC | Pin mypy version in .travis.yml. (#2811) This is recommended in https://mypy.readthedocs.io/en/stable/existing_code.html#continuous-integration, to avoid unexpected upgrades introducing new type errors. | 23 April 2020, 23:46:05 UTC |