swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f

Revision Author Date Message Commit Date
1094740 Merge pull request #5528 from google:update-pypi PiperOrigin-RevId: 354012572 27 January 2021, 04:47:13 UTC
a9e5fab Merge pull request #5526 from google:issue5522 PiperOrigin-RevId: 354012448 27 January 2021, 04:43:41 UTC
d2ae949 update version and changelog for pypi 27 January 2021, 04:19:09 UTC
ff66c55 add axis_env argument to make_jaxpr fixes #5522 27 January 2021, 04:11:05 UTC
1fd1faa Merge pull request #5523 from sharadmv:custom-gradient-tuple PiperOrigin-RevId: 353957596 26 January 2021, 23:00:32 UTC
6061b09 Allow `jax.custom_gradient` to return vjp with singleton return value 26 January 2021, 21:40:23 UTC
814c4ad Merge pull request #5490 from google:xeinsum PiperOrigin-RevId: 353929996 26 January 2021, 20:56:05 UTC
e2140bf Merge pull request #5426 from jakevdp:thinking-in-jax PiperOrigin-RevId: 353926472 26 January 2021, 20:40:50 UTC
eec2dbc Merge pull request #5512 from google:closure-convert-docs PiperOrigin-RevId: 353926248 26 January 2021, 20:37:24 UTC
737c162 relax test tolerance on tpu 26 January 2021, 20:35:43 UTC
967f3ac Add thinking_in_jax.ipynb 26 January 2021, 20:08:37 UTC
93c61e4 import closure_convert at top module level 26 January 2021, 17:44:48 UTC
f4cf710 Add support for sharded_jit conversion to jax2tf. PiperOrigin-RevId: 353834475 26 January 2021, 11:36:33 UTC
237f3ee Merge pull request #5509 from danieldjohnson:all_gather_to_unbatched PiperOrigin-RevId: 353784844 26 January 2021, 03:56:16 UTC
a8fc8b3 Merge pull request #5511 from lukaszlew:patch-2 PiperOrigin-RevId: 353781688 26 January 2021, 03:30:07 UTC
2e19b1d Merge pull request #5383 from zhangqiaorjc:jax_cpp PiperOrigin-RevId: 353780826 26 January 2021, 03:21:58 UTC
7e15524 Add example code to save JAX program and run using C++ runtime. 26 January 2021, 03:12:26 UTC
4adc436 include closure_convert in generated docs 26 January 2021, 01:42:46 UTC
9dccf56 Clarify tracking Clarify tracing a bit and use wording that does not suggest that JAX executed python program. 26 January 2021, 01:16:29 UTC
15b95e3 Use np.shape instead of assuming argument has a shape attr 25 January 2021, 23:11:38 UTC
84f55f8 Merge pull request #5508 from jakevdp:fix-sphinx PiperOrigin-RevId: 353733345 25 January 2021, 22:38:28 UTC
c6a1bba Add evaluation rule for all_gather. This should only be called when an all_gather runs on arguments that are not batch tracers, for instance when all_gather-ing a constant. 25 January 2021, 22:27:39 UTC
cfe934c Fix some doc build warnings 25 January 2021, 22:08:57 UTC
7865043 Improve batched collective rule for all_gather_p When an all_gather references a vmapped axis, there is a particularly simple way of implementing it: simply "forget" that the axis was mapped, and return the full array. Conveniently, this doesn't require any explicit broadcasting, and makes it possible to use out_axes=None with the results. 25 January 2021, 21:52:38 UTC
f4b5ff9 Merge pull request #5488 from jakevdp:x64-contextmanager PiperOrigin-RevId: 353722592 25 January 2021, 21:51:12 UTC
b545461 Add experimental context manager to enable/disable X64 mode 25 January 2021, 21:23:15 UTC
136358c Merge pull request #5480 from pschuh:wrap-outs PiperOrigin-RevId: 353670528 25 January 2021, 18:03:17 UTC
13b921b Merge pull request #5465 from apaszke:xmap-replica-nesting PiperOrigin-RevId: 353640436 25 January 2021, 15:16:25 UTC
137321f Restore support for nested xmaps when using the (standard) replica lowering Fortunately this wasn't too difficult, as most of the code was already there. The biggest issues were a lack of axis name substitution and inability to handle multiple resource assignments for a single logical axis. I'm planning to add a more comprehensive test suite for this, but I'd like to wait for the pdot test PR (#5459) to land first. It contains a bunch of utilities that would come in handy. 25 January 2021, 13:39:49 UTC
802c773 Merge pull request #5502 from jackd:testNorm-fix PiperOrigin-RevId: 353565580 25 January 2021, 03:58:49 UTC
f02f775 Merge pull request #5500 from terhorst:poisson-cdf PiperOrigin-RevId: 353565572 25 January 2021, 03:55:15 UTC
f288cf4 fixed testNorm 25 January 2021, 02:56:11 UTC
1524b82 add support for scipy.stats.poisson.cdf 24 January 2021, 16:15:31 UTC
b0f5ef4 Merge pull request #5492 from google:revive-revive-leak-checker PiperOrigin-RevId: 353440404 23 January 2021, 22:44:30 UTC
a7bfebe improve leak checker flag description 23 January 2021, 22:17:22 UTC
9787894 refactor batching transform logic, fix leak checks See PR description in #5492 for details. Co-authored-by: Peter Hawkins <phawkins@google.com> 23 January 2021, 04:17:03 UTC
203af45 revive the leak checker, as a debug mode Co-authored-by: James Bradbury <jekbradbury@google.com> 23 January 2021, 02:31:00 UTC
bc3cd12 Merge pull request #5495 from hawkinsp:tokens PiperOrigin-RevId: 353283976 22 January 2021, 19:34:43 UTC
dd34d48 Fix exception when tokens are used in AD. 22 January 2021, 16:00:31 UTC
9ccfc9f Merge pull request #5478 from jakevdp:fix-docstrings PiperOrigin-RevId: 353133338 22 January 2021, 00:53:06 UTC
7c67ec1 Merge pull request #5477 from google:more-escaped-tracer-error-message-tweaks PiperOrigin-RevId: 353115571 21 January 2021, 23:24:31 UTC
6d2f832 add xeinsum, an einsum for xmap (& einsum easter egg) Co-authored-by: Adam Paszke <apaszke@google.com> 21 January 2021, 22:47:35 UTC
4a7b7df Merge pull request #5491 from google:pdot-tests3 PiperOrigin-RevId: 353105977 21 January 2021, 22:38:16 UTC
c02d804 add systematic pdot tests, utility functions Run lots of tests with e.g. `env JAX_NUM_GENERATED_CASES=1000 python tests/xmap_test.py PDotTests` 21 January 2021, 22:06:30 UTC
7925a3a Merge pull request #5486 from jakevdp:finfo PiperOrigin-RevId: 353042636 21 January 2021, 18:00:01 UTC
60a87fd Merge pull request #5419 from jakevdp:int-power PiperOrigin-RevId: 353033127 21 January 2021, 17:14:34 UTC
f726709 Merge pull request #5446 from gsp-27:master PiperOrigin-RevId: 353021985 21 January 2021, 16:11:34 UTC
fc258c3 Make jnp.finfo a class and support bfloat16 21 January 2021, 16:05:36 UTC
988949f Move jax.bzl into correct location. Unbreaks build. PiperOrigin-RevId: 353009344 21 January 2021, 14:56:23 UTC
32de6ff Replace None with an object NoSharding. This is to make the change to a C++ ShardingSpec easier. See also https://github.com/google/jax/pull/5444 PiperOrigin-RevId: 352965689 21 January 2021, 09:15:25 UTC
a900620 Merge pull request #5482 from google:contributing-squash PiperOrigin-RevId: 352937260 21 January 2021, 05:15:05 UTC
df273c0 Refactor how the result processing code wraps the indexing information. Convert `array_result_handler` and `avals_to_results_handler` to use functools.partial and a callable class so that the sharding information can be introspected instead of opaque lambdas. 21 January 2021, 03:01:04 UTC
186c973 Added test for python scalar 21 January 2021, 02:47:18 UTC
160dfd3 Revert import path changes to examples/ and benchmarks/ PiperOrigin-RevId: 352911869 21 January 2021, 01:35:55 UTC
297250d add a request for PR-squashing in contributing.md 21 January 2021, 00:37:57 UTC
ffa05d1 Merge pull request #5481 from jakevdp:doc-formatting PiperOrigin-RevId: 352892385 20 January 2021, 23:47:03 UTC
84e91d5 add transformed fun src info to escaped tracer err This change adds to the error message when we hit an escaped tracer. In particular, it adds source info for the function that was transformed. This change currently only applies to escaped `DynamicJaxprTracer`s (arising from `jit`, `pmap`, `scan`, and other staging functions) and not other traces. A natural follow-up would be to attach this information to other traces. Co-authored-by: Lena Martens <lenamartens@google.com> 20 January 2021, 23:30:37 UTC
a0b12bb DOC: fix minor formatting issues 20 January 2021, 22:38:19 UTC
9bdc2ec Consolidate build macros into a single jax.bzl file. PiperOrigin-RevId: 352871429 20 January 2021, 22:06:22 UTC
fe9a3e7 util._wraps: correctly handle initial line indentation 20 January 2021, 21:40:26 UTC
929a684 Small cleanups to dependency structure. PiperOrigin-RevId: 352853244 20 January 2021, 20:43:28 UTC
e217a7b Merge pull request #5473 from hawkinsp:docs2 PiperOrigin-RevId: 352839545 20 January 2021, 19:43:47 UTC
2991b04 jnp.power: use integer_power path in all applicable cases 20 January 2021, 17:21:57 UTC
62e89cb Merge pull request #5213 from malmaud:changelist/346683050 PiperOrigin-RevId: 352800932 20 January 2021, 16:45:06 UTC
8420ee2 Merge pull request #5472 from hawkinsp:build PiperOrigin-RevId: 352797345 20 January 2021, 16:21:02 UTC
4fdd029 Clarify the CuDNN versions expected by jaxlib wheels. 20 January 2021, 15:00:19 UTC
d585638 Merge pull request #5470 from inailuig:fix-rocm-build PiperOrigin-RevId: 352783431 20 January 2021, 14:46:44 UTC
3bec4b3 Fix ABSL build dependencies of //jaxlib:handle_pool 20 January 2021, 14:38:55 UTC
5ced12f Merge pull request #5455 from chr1sj0nes:changelist/351545649 PiperOrigin-RevId: 352756095 20 January 2021, 11:03:40 UTC
57405b0 fix building on ROCm 20 January 2021, 09:50:04 UTC
619bfd5 Merge pull request #5468 from google:issue5440 PiperOrigin-RevId: 352715439 20 January 2021, 05:05:18 UTC
0f70451 add BatchTrace.process_custom_vjp_call It was an oversight not to include this! Notice we have BatchTrace.process_custom_jvp_call. In fact, we can use the same function! We just needed the simplest possible post-process-call which just peels and packages. fixes #5440 Co-authored-by: Roy Frostig <frostig@google.com> 20 January 2021, 03:08:23 UTC
34e798f Merge pull request #5466 from jakevdp:fix-integer-pow PiperOrigin-RevId: 352685155 20 January 2021, 01:01:47 UTC
0a89fc8 integer_pow: fix jvp rule for y=0 19 January 2021, 23:42:40 UTC
7acf521 integer_pow: fix translation rule for y=0 19 January 2021, 23:42:18 UTC
d641e6c Revert part of #5319: don't use `traceback_util.path_starts` in `source_info_util.user_frame`. The issue was that it incurs filesystem accesses while we're lowering to XLA, which can cause a significant slowdown. PiperOrigin-RevId: 352662434 19 January 2021, 23:01:52 UTC
9076008 lax.integer_pow(): always bind the primitive 19 January 2021, 19:36:39 UTC
c0c4843 Add support for 'preferred_element_type' keyword arg in `dot` and `dot_general`. XLA recently added support for this parameter to xops.DotGeneral. It's an optional parameter that controls the accumulation type used by the dot operation. This is useful for e.g. quantized ANNs, where you might want to do matrix multiplies with int8 tensors and get back an int32 tensor instead of an int8 tensor that suffers from severe overflow. Note it's not sufficient in this case to cast the inputs to 'dot' to int32 beforehand and rely on the default output dtype inference, since backend devices might have an accelerated path for int8*int8->int32 matmuls and we want that explicitly represented in the XLA. Note because XLA still doesn't support integer dots on the CPU backend, that use case can't be tested with a CPU-only test at the moment. 19 January 2021, 18:56:46 UTC
2d3f33d fix float32 19 January 2021, 14:49:28 UTC
9260f2c regex check for each test case 19 January 2021, 14:12:11 UTC
b4900ee Merge pull request #5453 from apaszke:xmap-vmap PiperOrigin-RevId: 352537923 19 January 2021, 12:16:28 UTC
4b48c7f Use XLA AllGather op for GPU (attempt 2). This is an expansion of the first, rolled-back attempt (https://github.com/google/jax/pull/5260), this time with auto-diff and batching rules that some users are relying on. My benchmarks suggest a speed-up of ~2-2.5x for larger inputs. 19 January 2021, 11:16:25 UTC
0b57a16 Merge pull request #5462 from google:error-message-tweak2 PiperOrigin-RevId: 352488512 19 January 2021, 05:02:35 UTC
47f7cd4 avoid printing double periods in error messages 19 January 2021, 04:37:12 UTC
bd9ac93 Added changes requested 19 January 2021, 04:21:33 UTC
772a6da Merge pull request #5319 from google:source-line-info-on-escaped-tracer-errors PiperOrigin-RevId: 352482449 19 January 2021, 03:56:19 UTC
886b26f add source line info to more escaped tracer errors This extra source info is still only on jaxpr staging tracers, but those seem to be the most common culprits. I moved the `_line_info` attribute to the base Tracer class in core.py in anticipation of populating it for more traces than just DynamicJaxprTrace, but I'll leave that extension to follow-up. I adapted the main escaped tracer error messages in core.py, and also slightly generalized and debugged source_info_util functions (thanks for explaining the path prefix bug, @froystig !). 19 January 2021, 03:00:04 UTC
c287180 Update jax/experimental/maps.py Co-authored-by: Matthew Johnson <mattjj@google.com> 18 January 2021, 18:40:46 UTC
6d2b307 Make it possible to vmap xmapped functions Or perhaps more importantly make it possible to nest xmaps that don't specify any `axis_resources`. The math is a little tricky, so I've added a fairly strong test that enumerates a wide range of potential ways of interleaving vmapped and xmapped axes in both inputs and the output. Thanks to that, I've actually caught one very subtle bug in the dynamic tracing rule for xmap (sorting by dimension names instead of positional axes). 18 January 2021, 15:35:47 UTC
eeb5c42 Merge pull request #5452 from gnecula:tf_limit2 PiperOrigin-RevId: 352400540 18 January 2021, 14:12:07 UTC
d77a941 [jax2tf] Reflect in the limitations that add is now implemented for uint32 in TF 18 January 2021, 13:19:01 UTC
040d268 Merge pull request #5451 from gnecula:tf_jit_compile PiperOrigin-RevId: 352392391 18 January 2021, 13:03:27 UTC
04628c9 Merge pull request #5435 from apaszke:xmap-parameterized PiperOrigin-RevId: 352389890 18 January 2021, 12:41:50 UTC
6d2b976 [jax2tf] Start using jit_compile instead of the deprecated experimental_compile 18 January 2021, 12:41:42 UTC
449c2bc Merge pull request #4804 from marcvanzee:patch-1 PiperOrigin-RevId: 352382502 18 January 2021, 11:44:38 UTC
6bf6349 Copybara import of the project: -- 781492e0120ec915f9fdc83479884908f59d113d by George Necula <gcnecula@gmail.com>: [jax2tf] Update limitations Some bugs were fixed on the TF-side, and we can remove some limitations. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5449 from gnecula:jax2tf_limit 781492e0120ec915f9fdc83479884908f59d113d PiperOrigin-RevId: 352381535 18 January 2021, 11:38:22 UTC
a1e182c Updates comment 18 January 2021, 10:39:37 UTC
166e129 Clean up xmap tests Move `pdot` tests to a separate class. Automatically run all regular xmap tests with both the `pmap`-style lowering and `sharded_jit`-style lowering. 18 January 2021, 10:21:32 UTC
67b5af9 Copybara import of the project: -- 9be685946252edc67c2c28261b100b9aee68614a by George Necula <gcnecula@gmail.com>: Change the translation rule for lax.nextafter_p to ensure broadcasting during translation. Previously, this was the only binary arithmetic primitive that did not have broadcasting during translation. Trying to use it with non-equal shapes resulted in the error: `RuntimeError: Internal: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/client/xla_builder.cc:748) non_scalar_shape.value().dimensions() == shape->dimensions() Unimplemented implicit broadcast.: This is a bug in JAX's shape-checking rules; please report it!` COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/5448 from gnecula:nextafter 9be685946252edc67c2c28261b100b9aee68614a PiperOrigin-RevId: 352367039 18 January 2021, 09:59:55 UTC
4dd50b8 removed whitespace causing flake failing 18 January 2021, 02:16:28 UTC
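
Commit df273c0 in the log above describes replacing opaque lambdas with `functools.partial` and a callable class so that sharding information can be introspected. The sketch below illustrates that general pattern in plain Python. It is not the JAX implementation: the names `make_handler_lambda`, `make_handler_partial`, `ResultHandler`, and `sharding_spec` are hypothetical, chosen only for illustration.

```python
import functools


def make_handler_lambda(sharding_spec):
    # Opaque variant: the spec is captured in a closure, so callers
    # have no convenient public attribute for reading it back.
    return lambda buffers: (sharding_spec, list(buffers))


def _handle_result(sharding_spec, buffers):
    # Shared implementation used by the partial-based variant.
    return (sharding_spec, list(buffers))


def make_handler_partial(sharding_spec):
    # functools.partial exposes the bound arguments via `.args`.
    return functools.partial(_handle_result, sharding_spec)


class ResultHandler:
    """Callable class variant: the spec is a plain attribute."""

    def __init__(self, sharding_spec):
        self.sharding_spec = sharding_spec

    def __call__(self, buffers):
        return (self.sharding_spec, list(buffers))


if __name__ == "__main__":
    spec = ("x", 2)  # hypothetical stand-in for a sharding description
    partial_handler = make_handler_partial(spec)
    class_handler = ResultHandler(spec)
    print(partial_handler.args)         # bound arguments of the partial
    print(class_handler.sharding_spec)  # attribute on the callable class
```

All three variants behave identically when called. The difference is that the partial and the callable class keep the bound data reachable by later code, which is the property the commit message gives as the motivation.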