https://github.com/google/jax

Revision Author Date Message Commit Date
6d95b6c Merge pull request #4573 from google:update-pypi PiperOrigin-RevId: 337028688 14 October 2020, 06:25:32 UTC
503bc36 fix typo in comment 14 October 2020, 06:02:49 UTC
4128240 update version and changelog for pypi 14 October 2020, 06:01:30 UTC
cb4a014 Merge pull request #4567 from skye:cuda111 PiperOrigin-RevId: 336992058 14 October 2020, 00:41:21 UTC
86e113c Merge pull request #4563 from jakevdp:pmap-error PiperOrigin-RevId: 336980117 13 October 2020, 23:34:58 UTC
dfe0526 Merge pull request #4570 from google:jaxpr-util-test PiperOrigin-RevId: 336979966 13 October 2020, 23:31:41 UTC
0daf4c0 assume less about source locations in jaxpr_util_test 13 October 2020, 22:48:04 UTC
484ec3e Internal change PiperOrigin-RevId: 336957861 13 October 2020, 21:36:45 UTC
01e113d Update jaxlib build scripts to build CUDA 11.1 wheels. 13 October 2020, 21:31:43 UTC
4624603 Improve pmap axis error in the presence of pytrees 13 October 2020, 18:22:59 UTC
83d0115 Merge pull request #4557 from jakevdp:fix-tests PiperOrigin-RevId: 336909706 13 October 2020, 17:49:04 UTC
76fe180 Merge pull request #4556 from jakevdp:index-fix PiperOrigin-RevId: 336900642 13 October 2020, 17:09:06 UTC
db7f1f4 Increase test coverage for indexing ops 13 October 2020, 17:08:20 UTC
9c48502 Merge pull request #4562 from bchetioui:expand_dot_general PiperOrigin-RevId: 336895192 13 October 2020, 16:48:32 UTC
5981fc3 Merge pull request #4530 from jakevdp:cleanup PiperOrigin-RevId: 336895154 13 October 2020, 16:45:13 UTC
be149f4 [jax2tf] Added conversion paths without einsum for dot_general. 13 October 2020, 16:13:34 UTC
f077139 Merge pull request #4451 from alexminnaar:master PiperOrigin-RevId: 336888044 13 October 2020, 16:08:28 UTC
4f7aec8 Merge pull request #4561 from hawkinsp:truncnorm PiperOrigin-RevId: 336870048 13 October 2020, 14:14:12 UTC
0b8eb92 Add stop_gradients around lax.nextafter to fix TFP gradient errors for jax.random.truncated_normal. 13 October 2020, 13:16:29 UTC
8b44131 Merge pull request #4522 from bchetioui:expand_conv_conversion PiperOrigin-RevId: 336833753 13 October 2020, 09:00:09 UTC
c674c19 [jax2tf] Add paths that do not use XLA in conv_general_dilated. This adds some amount of support for people who want to run convolutions without having XLA linked in. These paths can seemingly be converted for TFJS as well. Due to a so far unknown bug in some of the conversions, the paths are disabled by default and the "ENABLE_TF_CONVOLUTION" global variable in jax2tf.py must be explicitly toggled to use them. See the comment associated with ENABLE_TF_CONVOLUTION for context. 13 October 2020, 08:00:34 UTC
e064451 Merge pull request #4543 from google:remove-double-dtype-warning PiperOrigin-RevId: 336771011 12 October 2020, 23:40:55 UTC
dbebc9a remove double warning with asarray dtype='int' 12 October 2020, 22:00:35 UTC
3ca9bba BUG: fix indexing error 12 October 2020, 20:58:49 UTC
d1ca3b3 Merge pull request #4548 from hawkinsp:truncnorm PiperOrigin-RevId: 336735846 12 October 2020, 20:51:21 UTC
080007a Ensure values returned by jax.random.truncated_normal() are in range. A user observed -inf values being returned by truncated_normal(), which occur if the uniform random value passed to erfinv() is out of range, e.g., due to rounding. Do more of the computation using jax.random.uniform(), which promises correct behavior in the face of rounding. As an added security measure, also clamp the outputs of the function to the open interval. 12 October 2020, 20:27:54 UTC
5a3929e Merge pull request #4547 from jblespiau:changelist/336647290 PiperOrigin-RevId: 336728597 12 October 2020, 20:19:33 UTC
0660939 Enable the C++ jax.jit fast code-path by default. Here is the list of things I think we should do/revisit at some point in the future: 1. Most importantly, think about Lazy expressions. I think they are powerful, and they can likely enable some optimization impossible otherwise (e.g. stack a ShardedDeviceArray and split it). We likely want to add support for them in the C++ jax.jit. 2. Remove trivial computation support (currently falls back to Python), at least in the omnistaging code-path, as it will no longer be necessary with omnistaging, and the complexity does not justify the feature. 3. Not sure what to do about jax.jit(pmap). It currently falls back to Python when the Executable has more than one device. I understand it executes this on all cores (as a pmap), but will return the first result (from the first core). I am tempted to think this is a non-feature, that no one is looking for and that it can be achieved by doing pmap(f)(x)[0] instead. Having a single simple way of doing one thing is usually beneficial. 4. Not sure what to do about `DeviceConstant` support (currently falls back to Python). I do not understand yet why they exist and the problem they solve. 5. Revisit the stickiness for jax.jit(f, device=device)(x). I think it should fail if x is sticky to another device, to make the user aware of the copies. My humble opinion is that, with an efficient device_put, we should write jitted_f(device_put(x, device)). Improving performance is key for JAX users, and the way JAX lets you control what lives where is wonderful. Helping the user be aware of their copies is helping in this direction/Doing copies for the users hurts performance. (The support for `jax_debug_nans` could be improved.)
With the C++ jit:
└──╼ benchy //third_party/py/jax/benchmarks:api_benchmark
name                           time/op
jit_trivial_dispatch           35.2µs ± 2%
jit_trivial                    36.8µs ± 2%
jit_simple_dispatch            16.1µs ± 5%
jit_simple                     29.3µs ± 8%
jit_simple_many_args_dispatch  149µs ± 8%
jit_simple_many_args           153µs ± 9%
jit_dispatch_without_transfer  144µs ± 7%
jit_dispatch_with_transfer     150µs ± 5%
Without:
└──╼ benchy //third_party/py/jax/benchmarks:api_benchmark
name                           time/op
jit_trivial_dispatch           29.9µs ± 2%
jit_trivial                    31.8µs ± 4%
jit_simple_dispatch            59.7µs ± 1%
jit_simple                     66.9µs ± 5%
jit_simple_many_args_dispatch  368µs ± 3%
jit_simple_many_args           367µs ± 3%
jit_dispatch_without_transfer  346µs ± 9%
jit_dispatch_with_transfer     392µs ± 9%
See also https://github.com/google/jax/pull/4169 for context. PiperOrigin-RevId: 336716707 12 October 2020, 19:18:04 UTC
639cda2 Merge pull request #4554 from skye:workspace PiperOrigin-RevId: 336713864 12 October 2020, 19:05:27 UTC
b8cac03 Update XLA in WORKSPACE 12 October 2020, 18:57:59 UTC
b13775f Enrich the error messages with the bound names that are available. The user often does not know whether it's not the correct name, or whether it was not defined, etc. It's easier to get this information when debugging. 12 October 2020, 18:33:27 UTC
7c4b935 Merge pull request #4550 from jblespiau:changelist/336694640 PiperOrigin-RevId: 336705548 12 October 2020, 18:26:34 UTC
ee9ca56 Copybara import of the project: -- d70782e0a1c6568eab6ef1573fa277850cc0da97 by Jean-Baptiste Lespiau <jblespiau@google.com>: Gate some jax_jit test with a version check. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/4550 from jblespiau:changelist/336694640 d70782e0a1c6568eab6ef1573fa277850cc0da97 PiperOrigin-RevId: 336700683 12 October 2020, 18:04:58 UTC
5a097f5 Gate some jax_jit test with a version check. 12 October 2020, 18:04:19 UTC
c1e2595 Add support for jax_debug_nans and fix the last few glitches with the C++ jax.jit. - Sorting the keyword arguments must be done on the string, because we go through the Python path which uses flatten() which sorts them by string. - Some error with obj == obj which is the same as obj.is(obj) and not obj.equal(obj). - Moves all the Python tests to the C++ tests (which also run on the _python_jit). PiperOrigin-RevId: 336671123 12 October 2020, 15:50:13 UTC
8f64715 Merge pull request #4546 from hawkinsp:flakes PiperOrigin-RevId: 336653760 12 October 2020, 13:48:35 UTC
e1adbcd Fix batching_test flakiness on GPU. 12 October 2020, 13:35:39 UTC
9d8139e Merge pull request #4440 from google:long-line PiperOrigin-RevId: 336638751 12 October 2020, 11:49:55 UTC
7320fc7 Merge pull request #4542 from google:fix-jax2tf-typo PiperOrigin-RevId: 336638720 12 October 2020, 11:46:26 UTC
bcf3209 remove trailing whitespace in jax2tf readme lines 11 October 2020, 17:17:46 UTC
e1708f5 fix typo "differentiaion" 11 October 2020, 17:15:48 UTC
e6ab45e Merge pull request #4538 from google:fix-mypy PiperOrigin-RevId: 336507262 11 October 2020, 05:14:45 UTC
2839b2b ignore mypy error in jax2tf 11 October 2020, 05:01:53 UTC
20e11a0 Merge pull request #4509 from malmaud:pmin_multihost PiperOrigin-RevId: 336504945 11 October 2020, 04:34:44 UTC
0485fd8 adding histogram2d implementation 10 October 2020, 16:03:46 UTC
cff6c0b Improve errors for failed compilations w/ core.concrete_or_error 10 October 2020, 16:03:46 UTC
a97cb87 Merge pull request #4532 from gnecula:tf_omnistaging PiperOrigin-RevId: 336456477 10 October 2020, 15:00:02 UTC
9d9762b [jax2tf] Ensure jax2tf still works without omnistaging This refines changes from PR 4470 to ensure that jax2tf works even without omnistaging. Some tests would fail though, e.g., converting a function with no arguments 10 October 2020, 07:43:17 UTC
0d69a84 Merge pull request #4065 from skye:wheels PiperOrigin-RevId: 336368043 09 October 2020, 21:39:49 UTC
cacb017 Use local version identifiers to distribute cuda jaxlib wheels. This change: * Updates our jaxlib build scripts to add `+cudaXXX` to the wheel version, where XXX is the CUDA version number (e.g. `110`). nocuda builds remain unchanged and do not have this extra identifier. * Adds `generate_release_index.py`, which writes an html page that pip can use to find the cuda wheels. (I based this format off of PyTorch's wheel index). * Updates the README to use the new local version identifier + wheel index. The end result is that the command to install cuda wheels is now much simpler. I manually made copies of the latest jaxlib 0.1.55 wheels that have the local version identifiers, so the new installation commands already work (as well as the old ones, until the next jaxlib release using the new tooling). For now, I put the html index in the GCP bucket with the wheels. We can move it to a prettier URL if/when we have one. 09 October 2020, 20:47:54 UTC
ae3d310 Merge pull request #4524 from jakevdp:fix-mean PiperOrigin-RevId: 336350279 09 October 2020, 20:05:37 UTC
a57243c Cleanup: remove extraneous device_put 09 October 2020, 19:57:17 UTC
cd03bea Merge pull request #4526 from google:4510-2 PiperOrigin-RevId: 336332245 09 October 2020, 18:31:47 UTC
78fe038 fix x64 issue 09 October 2020, 18:01:58 UTC
0aeeb63 add test for scan input forwarding 09 October 2020, 17:35:47 UTC
5355776 make scan fwd raw extensive inputs as DeviceArray follow-up on #4517 09 October 2020, 17:26:28 UTC
fc69d45 Merge pull request #4470 from gnecula:tf_tracers PiperOrigin-RevId: 336314262 09 October 2020, 17:06:22 UTC
67a77c7 Fix extraneous dtype warning in jnp.mean 09 October 2020, 16:13:21 UTC
0213efd [jax2tf] Port jax2tf to use omnistaging The main change is that we use `core.new_base_main` to use an omnistaging-based tracer. This has the benefit that we can convert to TF even functions with no arguments (previously they would be constant-folded by JAX prior to the conversion). We also add an explicit error if the jax2tf.convert transformation is nested under other JAX transformations. 09 October 2020, 15:42:28 UTC
e194dff Merge pull request #4518 from gnecula:tf_debug PiperOrigin-RevId: 336298684 09 October 2020, 15:37:52 UTC
a1e67be [jax2tf] Fix the behavior for shift when the shift amount is negative or large. It turns out that TF shift behavior is "implementation defined" if the shift amount is negative or greater than or equal to the bitsize of the operand. This is different than XLA. We add conditionals to check and fix the corner cases, and we expand the tests. A better solution would be to use an XlaShift operation. This also revealed some bugs with shift_right on TPU (for JAX). The jax2tf behavior is what we believe is the correct behavior, which means that it differs from JAX results on TPU. 09 October 2020, 15:09:08 UTC
e6399cd Merge pull request #4517 from google:issue4510 PiperOrigin-RevId: 336230439 09 October 2020, 05:24:59 UTC
7691be9 make pytype happy 09 October 2020, 04:04:44 UTC
52fe026 optimize scan partial_eval to fix #4510 fixes #4510 09 October 2020, 03:34:34 UTC
d4da9cc Merge pull request #4514 from google:remove-unneeded-code PiperOrigin-RevId: 336191840 08 October 2020, 23:49:22 UTC
ab462ba Merge pull request #4515 from google:enable-conv-test PiperOrigin-RevId: 336186047 08 October 2020, 23:17:38 UTC
4994b82 Merge pull request #4512 from jakevdp:raise-to-shaped PiperOrigin-RevId: 336180373 08 October 2020, 22:46:50 UTC
1614e28 re-enable lax autodiff test xla:cpu bug was due to a change in llvm, now reverted 08 October 2020, 22:43:15 UTC
55be9dd remove mysterious code that is no longer needed 08 October 2020, 22:36:42 UTC
91492ca Merge pull request #4511 from google:debug-nans-test-move PiperOrigin-RevId: 336178046 08 October 2020, 22:33:02 UTC
ef25da8 lax_control_flow: retry via function rather than via loop 08 October 2020, 22:22:16 UTC
008eae6 Fallback to Python for float0 in the C++ jax.jit. PiperOrigin-RevId: 336167762 08 October 2020, 21:39:35 UTC
24de811 move a debug_nans test into debug_nans test file 08 October 2020, 20:34:56 UTC
ddcfd45 Merge pull request #4161 from jakevdp:raise-to-shaped PiperOrigin-RevId: 336150089 08 October 2020, 20:15:48 UTC
3be8188 Merge pull request #4483 from google:debug-nans-no-store-error PiperOrigin-RevId: 336149914 08 October 2020, 20:12:35 UTC
57ec9dd Merge pull request #4483 from google:debug-nans-no-store-error PiperOrigin-RevId: 336147286 08 October 2020, 20:01:26 UTC
09f2be1 wait for result in debug_nans_test 08 October 2020, 20:00:32 UTC
2b76bcc Merge pull request #4508 from zhangqiaorjc:expm_test_slow PiperOrigin-RevId: 336146064 08 October 2020, 19:53:33 UTC
4dd802c Merge branch 'master' into debug-nans-no-store-error 08 October 2020, 19:28:11 UTC
a2a9409 Mark pmin and pmax as multi-host supported. 08 October 2020, 19:12:15 UTC
afef644 Skip 2nd order grad in expm tests. 08 October 2020, 19:09:21 UTC
a2a6dca Merge pull request #4507 from hawkinsp:xla PiperOrigin-RevId: 336135890 08 October 2020, 19:06:15 UTC
a20b00a fix weak_type issues in while_loop 08 October 2020, 18:59:56 UTC
72be97c Update XLA. 08 October 2020, 18:59:06 UTC
0f0fa53 fix weak_type issues in while_loop 08 October 2020, 18:53:52 UTC
3ca6e5a fix weak type issues in scan 08 October 2020, 18:53:52 UTC
6393349 raise_to_shaped: preserve weak_type by default 08 October 2020, 18:53:52 UTC
ab48a9a Merge pull request #4498 from jblespiau:changelist/333713171 PiperOrigin-RevId: 336132107 08 October 2020, 18:49:06 UTC
145ac40 Merge pull request #4502 from hawkinsp:linalg PiperOrigin-RevId: 336117857 08 October 2020, 17:45:01 UTC
684a584 Merge pull request #4500 from LenaMartens:changelist/336082045 PiperOrigin-RevId: 336112168 08 October 2020, 17:20:34 UTC
e3d622c Recast int/bool tangents to float0 in custom_jvp/vjps (also in the initial_style path). 08 October 2020, 16:37:35 UTC
95c0fea Merge pull request #4501 from gnecula:tf_tests_cf1 PiperOrigin-RevId: 336096721 08 October 2020, 16:06:43 UTC
2fca917 Enable TPU linalg tests that now pass. 08 October 2020, 14:59:52 UTC
6fec497 Add a flag to control the C++ jax.jit behavior. 08 October 2020, 14:51:07 UTC
dc9168b [jax2tf] Ensure that in tests TF does not constant-fold in eager mode before compiling 08 October 2020, 14:44:00 UTC
8df116b Merge pull request #4475 from bchetioui:fix_static_tests PiperOrigin-RevId: 336080718 08 October 2020, 14:27:06 UTC
8299598 Fix failing tests on TPU by avoiding using scalars 08 October 2020, 13:37:20 UTC
83060cc Merge pull request #4491 from google:remove-test PiperOrigin-RevId: 336069850 08 October 2020, 13:06:56 UTC
e65115b Merge pull request #4489 from google:test-fixes PiperOrigin-RevId: 336003413 08 October 2020, 02:55:59 UTC
30b8cc5 remove an always-skipped test, redundant by now 08 October 2020, 02:49:06 UTC
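
Note on 080007a and 0b8eb92: the commit messages describe clamping the outputs of jax.random.truncated_normal() to the open interval, with lax.nextafter involved. The snippet below is only a rough sketch of that idea on plain Python floats, not the code from those commits. The function name `clamp_to_open_interval` is hypothetical, and `math.nextafter` (Python 3.9+) stands in for `lax.nextafter`.

```python
import math


def clamp_to_open_interval(x: float, lower: float, upper: float) -> float:
    """Clamp x so the result lies strictly inside (lower, upper).

    Hypothetical helper for illustration; assumes lower < upper and that
    at least one representable float lies strictly between them.
    """
    # Smallest float strictly greater than `lower`.
    lo = math.nextafter(lower, math.inf)
    # Largest float strictly less than `upper`.
    hi = math.nextafter(upper, -math.inf)
    # Ordinary closed-interval clamp against the tightened bounds.
    return min(max(x, lo), hi)


if __name__ == "__main__":
    # An out-of-range value such as -inf is pulled just inside the bounds.
    print(clamp_to_open_interval(float("-inf"), -2.0, 2.0))
    # A value already inside the interval is returned unchanged.
    print(clamp_to_open_interval(0.5, -2.0, 2.0))
```

Tightening the bounds by one representable step before clamping is what keeps the endpoints themselves out of the result, which a plain min/max against `lower` and `upper` would not do.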