swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f

Revision | Message | Commit Date
15c8d4c update version and changelog for pypi 22 March 2020, 03:57:25 UTC
a669eda Merge pull request #2470 from GeorgOstrovski/changelist/301855313 Make jax.numpy.squeeze as strict as numpy.squeeze about axis shape. 22 March 2020, 02:56:03 UTC
7b0ee9a improve implementation of MVN logpdf (#2481)
Fixes #2314. I also added a bit more test coverage, but not a ton: scipy has different batch shape semantics and default arguments than I might expect, so I didn't bother to implement those (and left some test cases commented out).

I ran into this surprising scipy bug:

```python
In [1]: from scipy.stats import multivariate_normal

In [2]: import numpy as np

In [3]: args = [np.array(1., np.float32), np.array(2., np.float64), np.array(3., np.float64)]

In [4]: print([x.shape for x in args])
[(), (), ()]

In [5]: multivariate_normal.logpdf(*args)
Out[5]: -1.6349113442053944

In [6]: print([x.shape for x in args])
[(), (1,), (1, 1)]
```

Mutated arguments! But it depends on dtype promotion:

```python
In [7]: args = [np.array(1., np.float32), np.array(2., np.float32), np.array(3., np.float32)]

In [8]: print([x.shape for x in args])
[(), (), ()]

In [9]: multivariate_normal.logpdf(*args)
Out[9]: -1.6349113442053944

In [10]: print([x.shape for x in args])
[(), (), ()]
```
21 March 2020, 22:42:59 UTC
e456edf Merge pull request #2480 from google/solve-triangular-vectors make lax_linalg.solve_triangular allow vector rhs 21 March 2020, 19:22:44 UTC
93d3e34 make lax_linalg.solve_triangular allow vector rhs also add tests for jax.scipy.linalg.cho_solve 21 March 2020, 17:46:07 UTC
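A minimal sketch of what the relaxed interface allows (example matrices and values are my own, not from the commit): a 1-D right-hand side can now be passed directly, without reshaping it to an (n, 1) column, and the same vector works with jax.scipy.linalg.cho_solve.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular, cho_factor, cho_solve

A = jnp.array([[2.0, 0.0],
               [1.0, 3.0]])
b = jnp.array([2.0, 7.0])  # vector rhs, no trailing (n, 1) reshape needed

x = solve_triangular(A, b, lower=True)  # solves A x = b for lower-triangular A

# cho_solve round-trips through a Cholesky factorization of a SPD matrix
M = A @ A.T
c, low = cho_factor(M, lower=True)
y = cho_solve((c, low), b)
```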
fcdbe63 Trigger a Travis build (#2477) * Remove more unused imports * Fix warnings in travis.yml 21 March 2020, 16:38:46 UTC
e41a5ea Pytype2 (#2476) * Try to install pytype from pip 21 March 2020, 14:45:59 UTC
428377a Added type annotations and removed unused imports (#2472) * Added type annotations and removed unused imports * Adjusted type hints for pytype 21 March 2020, 12:54:30 UTC
9331fc5 Added pytype checking to Travis (#2475) 21 March 2020, 12:53:35 UTC
f7e6b26 Make jax.numpy.squeeze as strict as numpy.squeeze about axis shape. Raise error if an axis explicitly selected to be squeezed has shape != 1. 20 March 2020, 13:46:06 UTC
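A hedged illustration of the stricter behavior (values are mine, not from the commit): squeezing an explicitly named axis whose size is not 1 now raises (a ValueError, matching NumPy) instead of silently returning the input.

```python
import jax.numpy as jnp

x = jnp.ones((1, 3))
print(jnp.squeeze(x, axis=0).shape)  # (3,) -- axis 0 has size 1, so it is removed

try:
    jnp.squeeze(x, axis=1)           # axis 1 has size 3, so this now raises
except ValueError as e:
    print("error:", e)
```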
578e5cf Fix return type for vjp. (#2462) Fix vjp doc string. 19 March 2020, 18:54:04 UTC
592ba02 Fix nondeterminism in test ordering in Python 3.6. (#2460) Set orders aren't deterministic, and this made pytest-xdist complain. 19 March 2020, 18:53:24 UTC
7f8ce8f fix test errors from previous commit 19 March 2020, 18:33:00 UTC
1d0b7e2 make jaxpr pretty-print show multiple outputs 19 March 2020, 18:26:29 UTC
c3f8909 Update jaxlib version in README.md. (#2461) 19 March 2020, 16:41:19 UTC
d11a9ab Expose jax.lax.all_gather (#2449) * Expose jax.lax.all_gather * add all_gather to RTD 19 March 2020, 15:35:00 UTC
c7f211d Update JAX to use XLA hyperbolic functions. (#2415) 19 March 2020, 14:29:37 UTC
afdd1a7 Add more return types to api.py. (#2452) 19 March 2020, 14:28:29 UTC
cecfb37 Increment jaxlib version to 0.1.42. (#2457) Update XLA. 19 March 2020, 13:57:11 UTC
78c1f6b Increased tolerance for testScipySpecialFun (#2454) Prevent failures on TPU 19 March 2020, 07:54:37 UTC
cd7ab0a Changes to pmap_benchmark to make it runnable in Google (#2448) 19 March 2020, 05:56:59 UTC
2998a21 Updated Common Gotchas (#2435) * Minor update to docs; trigger readthedocs * Updated Common Gotchas notebook Handle errors explicitly, otherwise it is too hard to test the notebook by 'Run all' * Added a section about pure functions to Common Gotchas 19 March 2020, 05:55:43 UTC
68b32bf Add mypy type checking (#2430) * Add type annotations to make mypy pass. * Add mypy to .travis.yml. 18 March 2020, 21:06:05 UTC
cd248ba Fix xlog1py and xlogy not returning 0 when x == 0. (#2450) * Fix xlog1py and xlogy not returning 0 when x == 0. * Add tests for xlog1py and xlogy 18 March 2020, 21:05:28 UTC
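A small sketch of the fixed behavior (my own example, not from the PR): both functions return 0 when x == 0, even where the naive product would be 0 * -inf.

```python
from jax.scipy.special import xlogy, xlog1py

# xlogy(x, y) = x * log(y), defined to be 0 at x == 0
print(xlogy(0.0, 0.0))     # 0.0, not nan (naively 0 * log(0) = 0 * -inf)
print(xlog1py(0.0, -1.0))  # 0.0, not nan (naively 0 * log1p(-1) = 0 * -inf)
print(xlogy(2.0, 1.0))     # 0.0, the ordinary case: 2 * log(1)
```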
cbdf9a5 Drop support for Python 3.5. (#2445) 18 March 2020, 14:54:28 UTC
7f1e859 Merge pull request #2447 from google/remove-safe-mul remove safe_mul (undo #383, also cf. #1052) 18 March 2020, 06:42:37 UTC
26d5a68 Wrap pad_widths in a tuple to avoid cache misses (#2379) 18 March 2020, 05:41:02 UTC
f1d9130 remove safe_mul (undo #383, also cf. #1052) 18 March 2020, 05:07:53 UTC
75077a1 Add pmap_benchmark.py (#2409)
Example output:

```
$ TARGET_TOTAL_SECS=2 CUDA_VISIBLE_DEVICES= XLA_FLAGS=--xla_force_host_platform_device_count=500 python3 benchmarks/pmap_benchmark.py
2020-03-12 15:46:35.903121: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:313] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
/usr/local/google/home/skyewm/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
---------Benchmark results for pmap_shard_args_nargs=10_nshards=4---------
mean=0.034490 std=0.002890 %std=8.378140 total=2.000426 #iters=58 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=100_nshards=4---------
mean=0.091495 std=0.005935 %std=6.486871 total=2.012888 #iters=22 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=101_nshards=4---------
mean=0.113549 std=0.009080 %std=7.996712 total=2.043878 #iters=18 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=500_nshards=4---------
mean=0.356868 std=0.007960 %std=2.230518 total=2.141210 #iters=6 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=2---------
mean=0.022288 std=0.002946 %std=13.219607 total=2.005951 #iters=90 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=4---------
mean=0.035210 std=0.002024 %std=5.747389 total=2.006975 #iters=57 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=8---------
mean=0.048641 std=0.001578 %std=3.243398 total=2.042912 #iters=42 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=100---------
mean=0.257487 std=0.007190 %std=2.792452 total=2.059900 #iters=8 #warmup=1
---------Benchmark results for pmap_shard_args_nargs=10_nshards=500---------
mean=1.696294 std=0.005097 %std=0.300473 total=3.392588 #iters=2 #warmup=1
---------Benchmark summary for pmap_shard_args---------
  nargs    nshards       mean       %std    relative
-------  ---------  ---------  ---------  ----------
     10          4  0.0344901   8.37814     1
    100          4  0.0914949   6.48687     2.65279
    101          4  0.113549    7.99671     3.29221
    500          4  0.356868    2.23052    10.347
     10          2  0.0222883  13.2196      0.646224
     10          4  0.0352101   5.74739     1.02088
     10          8  0.0486408   3.2434      1.41028
     10        100  0.257487    2.79245     7.46555
     10        500  1.69629     0.300473   49.182
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4---------
mean=0.061780 std=0.004737 %std=7.668032 total=2.038743 #iters=33 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=100_nshards=4---------
mean=0.123264 std=0.005980 %std=4.851385 total=2.095494 #iters=17 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=500_nshards=4---------
mean=0.471524 std=0.024051 %std=5.100792 total=2.357622 #iters=5 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=2---------
mean=0.041546 std=0.004446 %std=10.700256 total=2.035745 #iters=49 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=4---------
mean=0.063768 std=0.002756 %std=4.322039 total=2.040561 #iters=32 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=8---------
mean=0.087285 std=0.005343 %std=6.121320 total=2.007556 #iters=23 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=100---------
mean=0.623440 std=0.004038 %std=0.647725 total=2.493759 #iters=4 #warmup=1
---------Benchmark results for pmap_shard_outputs_nouts=10_nshards=500---------
mean=4.096676 std=0.000000 %std=0.000000 total=4.096676 #iters=1 #warmup=1
---------Benchmark summary for pmap_shard_outputs---------
  nouts    nshards       mean      %std    relative
-------  ---------  ---------  --------  ----------
     10          4  0.0617801  7.66803      1
    100          4  0.123264   4.85139      1.99521
    500          4  0.471524   5.10079      7.6323
     10          2  0.0415458  10.7003      0.672479
     10          4  0.0637675  4.32204      1.03217
     10          8  0.087285   6.12132      1.41283
     10        100  0.62344    0.647725    10.0913
     10        500  4.09668    0           66.3106
```
17 March 2020, 21:31:25 UTC
75813f4 Update jax version number. (#2444) 17 March 2020, 21:15:59 UTC
db8bea4 Update changelog for jax 0.1.61 release. (#2443) 17 March 2020, 21:09:05 UTC
84140ea Ensure jax.host_ids() returns a stable ordering (#2442) ...in this case, by sorting them. 17 March 2020, 21:02:41 UTC
e46a002 Remove runtime tuple use from JAX. (#2441) Change in preparation for upcoming runtime changes related to buffer aliasing. 17 March 2020, 21:02:22 UTC
985d5f7 Fix Python 3.5 support. (#2439) * Fix Python 3.5 compatibility problems. 17 March 2020, 21:01:04 UTC
6b157ff Update jax version to 0.1.60. (#2437) 17 March 2020, 14:04:17 UTC
c4c770b Minor update to docs; trigger readthedocs (#2434) 17 March 2020, 08:24:17 UTC
e66e569 Minor update to docs; trigger readthedocs (#2433) 17 March 2020, 08:07:14 UTC
5f2d225 Fix typo in ShapeDtypeStruct (#2253) * fix ShapeDtypeSTruct dtype bug * move dtype conversion to constructor 17 March 2020, 06:45:17 UTC
0ddc2ec Fixed failing tests 17 March 2020, 05:51:01 UTC
3362591 Updated CHANGELOG 17 March 2020, 05:51:01 UTC
5cf82c7 Improved argument checking for lax.broadcast_in_dim * Added checking that the output shape has higher or equal rank to input * Added checking that the broadcast_dims are sorted (required by XLA) * Relaxed check that operand dimension size can be 1 * Added lax.broadcast_in_dim docstring 17 March 2020, 05:51:01 UTC
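For context, a minimal sketch of the constraints those checks enforce (shapes are my own illustration): the output rank must be at least the input rank, broadcast_dimensions must be sorted, and each operand dimension must either equal its target dimension or have size 1.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((3, 1))

# OK: rank grows from 2 to 3, broadcast_dimensions=(0, 2) is sorted,
# and each operand dim (3 and 1) either matches its target dim or is 1.
y = lax.broadcast_in_dim(x, shape=(3, 4, 5), broadcast_dimensions=(0, 2))
print(y.shape)  # (3, 4, 5)

# These would now fail the argument checks with a clear error:
#   lax.broadcast_in_dim(x, shape=(3,), broadcast_dimensions=(0, 1))       # output rank lower than input
#   lax.broadcast_in_dim(x, shape=(1, 4, 3), broadcast_dimensions=(2, 0))  # dimensions not sorted
```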
c0c3a4a Merge pull request #2401 from hawkinsp/ones Check for invalid shapes in broadcast_in_dim and fail gracefully. 17 March 2020, 02:46:52 UTC
36a46cb Merge pull request #2426 from mtthss/master Make gradient clipping by norm more numerically safe 17 March 2020, 02:45:30 UTC
ec9513f Advertise jaxlib 0.1.41. (#2432) Bump minimum jaxlib version to 0.1.41. 16 March 2020, 20:10:26 UTC
ed8dbd2 temporarily switch off #2414 changes 16 March 2020, 19:17:09 UTC
efaedf4 undo previous commit 16 March 2020, 19:13:25 UTC
5280793 fix custom_transforms + jit bug from #2416 16 March 2020, 17:23:24 UTC
1c202ac fix typo, unbreak pmap (from #2416) 16 March 2020, 16:20:34 UTC
388e78f Increment jaxlib version to 0.1.41. (#2428) Update XLA. 16 March 2020, 15:24:02 UTC
219d503 Don't show progress bar in build script if output is not a terminal. (#2429) 16 March 2020, 15:01:08 UTC
196fddb Make gradient clipping by norm more numerically safe 16 March 2020, 11:12:04 UTC
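The hazard being addressed is dividing by the global norm when that norm can be 0. A minimal sketch of one numerically safe pattern, not the actual optix code: guard the division with jnp.where so a zero norm never produces 0/0.

```python
import jax.numpy as jnp

def clip_by_global_norm_safe(grads, max_norm):
    # grads: a list of arrays; compute the global L2 norm over all of them.
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in grads))
    # Use a dummy denominator where the norm is 0, then keep the gradient
    # unscaled in that branch, so no 0/0 is ever formed.
    safe_norm = jnp.where(global_norm == 0.0, 1.0, global_norm)
    scale = jnp.where(global_norm > max_norm, max_norm / safe_norm, 1.0)
    return [g * scale for g in grads]

grads = [jnp.array([3.0, 4.0]), jnp.array([0.0])]  # global norm is 5
clipped = clip_by_global_norm_safe(grads, max_norm=1.0)
```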
6545cf3 Merge pull request #2424 from google/broadcast-shapecheck add lax.broadcast_in_dim shape check and test 16 March 2020, 05:22:24 UTC
7666c25 fix buggy broadcast_in_dim shapecheck test 16 March 2020, 04:32:56 UTC
a73bdd5 Merge pull request #2363 from jessebett/jet-prims Primitive Rules for Higher-Order Derivatives (jet) Scoreboard 16 March 2020, 04:05:26 UTC
94832f9 add lax.broadcast_in_dim shape check and test Operand dimensions must equal their corresponding dimensions in the broadcast shape. 16 March 2020, 03:30:44 UTC
a00e398 remove scipy dep, fix dtype issue 15 March 2020, 19:00:44 UTC
8d402d8 add copyright notice to jet.py 15 March 2020, 18:39:44 UTC
ae921c7 update changelog 15 March 2020, 18:15:13 UTC
a7b3be7 move jet into jax.experimental 15 March 2020, 18:10:56 UTC
92a0b3d add basic pytree support to jet 15 March 2020, 16:58:54 UTC
668a170 add jet tests, remove top-level files 15 March 2020, 04:22:10 UTC
840797d refactor reduce_max jet rule 15 March 2020, 01:42:51 UTC
2c53b94 add tests Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu> 15 March 2020, 01:42:51 UTC
b4d003d jet rule for log Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu> 15 March 2020, 01:42:51 UTC
30830df linear rule for sub Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> 15 March 2020, 01:42:51 UTC
dcebe50 jet for reduce_max Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu> Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu> 15 March 2020, 01:42:51 UTC
3bcf02a Add gather rule Co-authored-by: Matthew Johnson <mattjj@csail.mit.edu> Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu> 15 March 2020, 01:42:51 UTC
098aabe fix typo 15 March 2020, 01:42:51 UTC
ddd52c4 adding div and linear prims Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> 15 March 2020, 01:42:51 UTC
7adf9fe add more jet rules! Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: Jacob Kelly <jacob.jin.kelly@gmail.com> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu> 15 March 2020, 01:41:44 UTC
a21fdf8 more jet rules and tests Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> 15 March 2020, 01:41:44 UTC
e84a621 new jet implementation, with conv-based rules Co-authored-by: Jesse Bettencourt <jessebett@cs.toronto.edu> Co-authored-by: David Duvenaud <duvenaud@cs.toronto.edu> 15 March 2020, 01:41:44 UTC
47df7b9 change the xla representation of JAX's unit (#2416)
* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel / placeholder) was an empty tuple, but by changing the representation to something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are three functions in xla.py that define the representation. Here are versions that would keep the old XLA representation as an empty tuple:

```python
def _make_unit(c):
  return c.Tuple()

def _make_abstract_unit(_):
  return xc.Shape.tuple_shape(())

def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array. An alternative representation would be nothing at all: we don't need to generate XLA computations that have representations of JAX units. While that alternative is probably the best choice, it seemed like it would require a bit more refactoring/bookkeeping (e.g. to allow XLA computations to have a smaller number of outputs than the corresponding JAX function), and would also mean the XLA representation would be a step further removed from the jaxpr representation. So I stuck with a trivial array for now.

The mapping from JAX types to XLA types need not be invertible. However, XLA translation rules currently don't take as arguments the corresponding JAX types (abstract values), and there were a few cases where we relied on checking whether an argument's XLA type was that of an empty tuple so as to determine if we were effectively operating on a JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle add two units, and get lowered to an XLA addition on the unit representation. Previously, the translation rule for add_jaxvals_p checked the XLA type so that adding two empty tuples didn't produce any XLA operation; now it adds its inputs, and so if unit is represented as a trivial array we could be inserting trivial scalar adds where we had none before. However, if that case is ever possible, it doesn't come up in our tests (which I checked by keeping the representation as an empty tuple and then asserting an XLA tuple type is never seen by that translation rule).

* add comment about JAX<->XLA array types assumption
14 March 2020, 19:33:14 UTC
c41c4de lower fori_loop to scan when possible (#2414) When a fori_loop specialized on trip count is to be evaluated, it's preferable to generate a scan rather than a while_loop because the former is reverse-mode differentiable while the latter is not. Otherwise they're essentially the same; in particular, no extensive inputs/outputs arise unless reverse-mode autodiff is applied. Also fixes #2412. 13 March 2020, 22:15:55 UTC
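A small sketch of what this enables (my own example): when the trip count is known at trace time, a fori_loop can now be reverse-mode differentiated, because it is lowered to scan rather than while_loop.

```python
import jax
from jax import lax

def pow5(x):
    # The trip count (5) is a Python int, so the loop is specialized on it
    # and can be lowered to scan, which supports reverse-mode autodiff.
    return lax.fori_loop(0, 5, lambda i, acc: acc * x, 1.0)

print(pow5(2.0))            # 32.0
print(jax.grad(pow5)(2.0))  # 80.0 == 5 * 2**4
```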
271041b Update a regression test to test size-zero device to device transfers. (#2411) 13 March 2020, 17:35:18 UTC
7f0463e remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't transpose this jaxpr:

```
{ lambda ; a.
  let b = reduce_sum[ axes=(0,) ] a
  in b }
```

The problem was that the transpose of a reduce-sum is a broadcast, but because jaxprs didn't have shape information available, we didn't know what input shape to broadcast to! Our hack was to have the primitives that required shape information for transposition to acquire it into their parameters, so that we'd produce jaxprs like this one:

```
{ lambda ; a.
  let b = reduce_sum[ axes=(0,)
                      input_shape=(3,) ] a
  in b }
```

That's not only aesthetically unpleasant, but also it meant we were limiting an (unused) capability of the system: ideally we should be able to trace a reduce-sum jaxpr without specializing on shape information (e.g. at the Unshaped level) and only require shape specialization for transposition. (Good thing no one actually traces at Unshaped...)

But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that shape information (or whatever information with which the jaxpr was specialized out of Python) is in the jaxpr itself. So we could finally remove these shapes-in-params warts! That's exactly what this commit does!

Co-authored-by: Roy Frostig <frostig@google.com>
13 March 2020, 14:13:29 UTC
58feed2 jax.lax.nextafter test fix. (#2408) Fixes #2403. 12 March 2020, 22:53:47 UTC
cf41f76 Add np.linalg and np.fft functions to documentation. (#2407) 12 March 2020, 19:05:59 UTC
61b430e Added more documentation for how to fix notebook build failures (#2404) 12 March 2020, 09:59:30 UTC
620bf43 [remat] Change remat lowering to XLA::Conditional (#2391)
* [remat] Change remat lowering to XLA::Conditional

`jax.remat` creates rematerializing passes that don't have data dependencies on the actual loss-computing forward pass. This means that the XLA scheduler was free to schedule the remat forward pass before the loss-computing pass, defeating the goal of saving accelerator memory with `jax.remat`. In practice, it sometimes did for my workloads.

This change expresses the lowering of remat_call(f) as: Conditional(true, inputs, f, inputs, dummy_f). In the common case of `jax.grad(jax.remat(f))`, the contents of the lowered remat_call are both the forwards & backwards; that is, the incoming cotangents are part of the args. Additionally, Conditional (AFAIK) is un-inlineable in the sense that it doesn't execute until all its inputs (e.g. cotangents!) are available.

Downsides:
- AFAICT, we can no longer interleave computation in/outside the rematerialized block.
- Potentially, lower performance. I do not observe this in my tests.

* provide no replication info for subcomputation params
11 March 2020, 20:36:07 UTC
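For reference, a minimal usage sketch (not from the PR): `jax.remat` wraps a function so its intermediates are recomputed during the backward pass instead of being saved, and the Conditional lowering above is what keeps that recomputation from being scheduled early.

```python
import jax
import jax.numpy as jnp

@jax.remat
def block(x):
    # Intermediates of this block are rematerialized on the backward pass
    # rather than stored from the forward pass.
    return jnp.sin(jnp.sin(x))

def loss(x):
    return jnp.sum(block(x) ** 2)

x = jnp.linspace(0.0, 1.0, 8)
print(jax.grad(loss)(x))
```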
2dfeaeb Allow zero tolerance for jax.test_util.tolerance (#2393) Currently, if a user passes any falsy value to jax.test_util.tolerance, it is changed to the default value. This makes sense when the value passed is None, but not when the value passed is 0 (which indicates a desired tolerance of exactly 0). Disables failing tests for now. 11 March 2020, 20:19:46 UTC
cdf188a add raises-exception notebook cell metadata (#2402) 11 March 2020, 16:42:25 UTC
419961f Check for invalid shapes in broadcast_in_dim and fail gracefully. 11 March 2020, 13:57:20 UTC
ffa0340 Add jnp vs np out of bounds indexing to Sharp Bits nb (#2378) 11 March 2020, 00:53:43 UTC
cfbdb65 add register_pytree_node_class, fixes #2396 (#2400) Co-authored-by: Stephan Hoyer <shoyer@google.com> 10 March 2020, 22:01:18 UTC
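A minimal sketch of the new decorator (the class and its fields are my own illustration): a class that defines tree_flatten / tree_unflatten can register itself as a pytree node in one step.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # children (the leaves JAX maps over) and auxiliary static data
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

p = Point(jnp.array(1.0), jnp.array(2.0))
doubled = jax.tree_util.tree_map(lambda v: 2 * v, p)
print(doubled.x, doubled.y)  # 2.0 4.0
```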
ebbcbad allow vmap in_axes to be a list, fixes #2367 (#2395) 10 March 2020, 15:29:46 UTC
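A quick sketch of the relaxed argument (my own example): in_axes can now be given as a list rather than only a tuple or an int, which is convenient when it is built programmatically.

```python
import jax
import jax.numpy as jnp

def scaled_dot(w, x):
    return jnp.dot(w, x)

w = jnp.ones((3,))                   # shared across the batch
xs = jnp.arange(12.0).reshape(4, 3)  # batched along axis 0

# in_axes passed as a list instead of a tuple
batched = jax.vmap(scaled_dot, in_axes=[None, 0])
print(batched(w, xs))  # shape (4,)
```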
9fd69a0 Replace uses of ExecutePerReplica with ExecuteOnLocalDevices. (#2394) ExecutePerReplica is deprecated, and ExecuteOnLocalDevices is now available via the minimum jaxlib version. 10 March 2020, 14:51:09 UTC
cc53aa9 skip new optix test on tpu (cf. #2350) 10 March 2020, 13:59:54 UTC
5c3b478 Add a module to apply updates every k steps (and accumulate them otherwise) (#2350) 10 March 2020, 13:40:38 UTC
863576c jit lax_numpy.roll (#2392) This was making tracing slow for code with lots of rolls. 09 March 2020, 20:21:30 UTC
f3f0abb Fix exception causes all over the codebase (#2376) Co-authored-by: Peter Hawkins <phawkins@google.com> 09 March 2020, 20:06:12 UTC
ac6a313 Fix ONNX mnist example (#2374) * Fix ONNX mnist example * use np to compute the shape; rename jax.numpy as jnp 09 March 2020, 20:04:59 UTC
c52f32b Removed unused imports (#2385) Also disabled a couple more linalg tests that crash on my Mac 09 March 2020, 19:42:08 UTC
282225f Added some pytype annotations (#2386) Tried to catch all uses of linear_util.WrappedFun 09 March 2020, 19:41:01 UTC
c53ae2c automatic detection of wheel version (#2373) 09 March 2020, 18:43:45 UTC
411c8a4 Update lnp -> jnp to fix test failure after merge. (#2389) 09 March 2020, 14:17:42 UTC
8339511 Implement NumPy sorting routines. (#2318) Implement `np.msort`. Related issue: #2079 09 March 2020, 14:07:12 UTC
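A one-line usage sketch (the example array is mine): jnp.msort sorts along the first axis, equivalent to np.sort(a, axis=0).

```python
import jax.numpy as jnp

a = jnp.array([[3., 1.],
               [2., 4.],
               [1., 0.]])
print(jnp.msort(a))
# [[1. 0.]
#  [2. 1.]
#  [3. 4.]]
```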
25dd419 share rng and predicate between all foiled args (#2375) This doesn't appear to be a performance disadvantage, and it makes reading the HLO infinitely easier, as there's now only one pair of RNG/constants per remat call. 09 March 2020, 14:05:40 UTC
a5daafd update gitignore (#2361) 09 March 2020, 14:04:23 UTC
0080c89 Fix a few type annotations in `api.py`. (#2387) 09 March 2020, 13:35:21 UTC