6af4769 | Matthew Johnson | 15 September 2020, 15:00:47 UTC | update version and changelog for pypi (#4294) | 15 September 2020, 15:00:47 UTC |
cefa93f | Peter Hawkins | 15 September 2020, 13:04:54 UTC | Lower LU decomposition to a custom TPU implementation for float32 types. (#4291) | 15 September 2020, 13:04:54 UTC |
58a117f | Benjamin Chetioui | 15 September 2020, 08:45:15 UTC | Modifies eig_p and related operations to take advantage of the new jaxlib geev API (#4266) * Add options to compute L/R eigenvectors in geev. The new arguments are by default set to True to ensure backwards compatibility between jaxlib and jax. Reformulate eig-related operations based on the new geev API. * Addressed hawkinsp's comments from google/jax#3882. Co-authored-by: Skye Wanderman-Milne <skyewm@google.com> | 15 September 2020, 08:45:15 UTC |
a5c2c47 | Benjamin Chetioui | 15 September 2020, 08:40:07 UTC | [jax2tf] Added support for x64 for the remaining test files (#4282) * [jax2tf] Added support for x64 in other test files. This includes: - control_flow_ops_test.py - jax2tf_test.py - saved_model_test.py - stax_test | 15 September 2020, 08:40:07 UTC |
4e04d4e | Benjamin Chetioui | 15 September 2020, 06:35:35 UTC | [jax2tf] Build a primitive harness for test_type_promotion. (#4279) * [jax2tf] Build a primitive harness for test_type_promotion. We were previously generating the cases using `jtu.cases_from_list`, which by default dropped 2 test cases (JAX_NUM_GENERATED_CASES=10, number of generated cases = 12). * [jax2tf] Fix the generated test cases for test_type_promotion. | 15 September 2020, 06:35:35 UTC |
9040336 | Trevor Cai | 15 September 2020, 01:35:41 UTC | Allow device_get to pass Python scalars through unchanged (#4283) * Allow device_get to pass Python scalars through unchanged * address comment | 15 September 2020, 01:35:41 UTC |
7569e80 | Matthew Johnson | 14 September 2020, 19:31:51 UTC | revert #4277 (google failure) (#4281) * revert #4277 (google failure) Some downstream user is relying on the rank of stax's biases being 1. * only revert one change | 14 September 2020, 19:31:51 UTC |
6bd3216 | Stephan Hoyer | 14 September 2020, 09:47:28 UTC | Simplify the interface for host_callback.id_tap (#4101) * Simplify the internal interface for host_callback.id_tap This is a breaking change for `id_tap` users (but not `id_print` users). This makes it easier to use (and type check) ``tap_func``, because the expected signature is now ``tap_func(arg, transforms)`` vs ``tap_func(arg, *, transforms, **kwargs)``. Most of the test changes are just adding whitespace/indentation, but I've also slightly changed the way transformations are printed. | 14 September 2020, 09:47:28 UTC |
2ff3479 | Benjamin Chetioui | 14 September 2020, 09:34:31 UTC | [jax2tf] Fix tests when running with JAX_ENABLE_X64=1. (#4261) Fixed tests: - test_binary_elementwise - dynamic_update_slice - fft - population_count - test_unary_elementwise - top_k - select_and_gather_add | 14 September 2020, 09:34:31 UTC |
7e6d114 | Benjamin Chetioui | 14 September 2020, 08:35:43 UTC | [jax2tf] Add converted primitives without tests to the generated doc. (#4248) * [jax2tf] Add converted primitives without tests to the generated doc. * Ignore some primitives in the output of untested primitives. * Added control_flow_ops_test to template and updated primitives. * Removed svd from the list of missing tests. Was just included because I run the tests using JAX_SKIP_SLOW_TESTS=1, which didn't run the SVD tests. Patched the generated file manually. | 14 September 2020, 08:35:43 UTC |
fa827a5 | Benjamin Chetioui | 14 September 2020, 08:33:58 UTC | [jax2tf] Added the last comments from the jax2tf doc inside the correctness_stats code (#4249) In principle, all the relevant documentation that was in the doc has been moved to the new documentation & comments of categorize. | 14 September 2020, 08:33:58 UTC |
6e2fa39 | Peter Buchlovsky | 14 September 2020, 08:33:02 UTC | [jax2tf] Fix lax.div_p (#4263) | 14 September 2020, 08:33:02 UTC |
9fc4353 | Roman Novak | 14 September 2020, 02:21:17 UTC | Avoid rank promotion in stax biases (#4277) * Avoid rank promotion in stax biases * remove itertools | 14 September 2020, 02:21:17 UTC |
38b43ef | Roman Novak | 14 September 2020, 02:20:31 UTC | Avoid rank promotion in np.outer (#4276) | 14 September 2020, 02:20:31 UTC |
64bead2 | Alex Minnaar | 12 September 2020, 20:10:01 UTC | fixing typo (#4273) I assume "...one of more type parameters..." was intended to read "...one or more type parameters..." | 12 September 2020, 20:10:01 UTC |
f039f6d | Matthew Johnson | 12 September 2020, 05:40:12 UTC | thread backend in pxla.replicate (#4272) * thread backend in pxla.replicate fixes #4223 * add test for #4223 | 12 September 2020, 05:40:12 UTC |
83b4f3b | Jake Vanderplas | 11 September 2020, 23:49:18 UTC | Cleanup: use _canonicalize_axis() utility where possible (#4270) | 11 September 2020, 23:49:18 UTC |
ee9dccf | Skye Wanderman-Milne | 11 September 2020, 19:12:34 UTC | Move failing CPPJitTest test case to PythonJitTest (#4268) | 11 September 2020, 19:12:34 UTC |
6dc161c | Skye Wanderman-Milne | 11 September 2020, 18:16:54 UTC | Pin pygments version in RTD build. (#4267) This fixes our RTD failures, which were caused by RTD installing an older version of pygments: ``` jupyterlab-pygments 0.1.1 requires pygments<3,>=2.4.1, but you'll have pygments 2.3.1 which is incompatible. nbconvert 6.0.1 requires pygments>=2.4.1, but you'll have pygments 2.3.1 which is incompatible. ``` | 11 September 2020, 18:16:54 UTC |
ca1d8f4 | Alvaro | 11 September 2020, 17:51:42 UTC | Fixing weird behavior in segment_sum when num_segments is None (#4034) Co-authored-by: alvarosg <alvarosg@google.com> | 11 September 2020, 17:51:42 UTC |
3b7329c | Jake Vanderplas | 11 September 2020, 17:04:47 UTC | Call check_arraylike() in jax.numpy reductions (#4195) | 11 September 2020, 17:04:47 UTC |
40fb01b | Adam Paszke | 08 September 2020, 18:04:11 UTC | Extend axis env while translating the pmapped jaxpr to XLA This is normally unnecessary, because the XLA translation usually doesn't bind any of the primitives in the jaxpr, but this is not true in case of scan! Its translation rule reevaluates the jaxpr as a function, and if it contains collectives such as `axis_index` it can fail due to axis being missing. | 11 September 2020, 15:56:32 UTC |
8a18b10 | Jake Vanderplas | 11 September 2020, 15:47:05 UTC | implement jnp.apply_along_axis (#4253) | 11 September 2020, 15:47:05 UTC |
7e694bd | Qiao Zhang | 11 September 2020, 01:28:44 UTC | Update README to point to jaxlib-0.1.55. (#4256) | 11 September 2020, 01:28:44 UTC |
fc39332 | Skye Wanderman-Milne | 10 September 2020, 22:53:04 UTC | Clarify when jaxlib version should be bumped. (#4250) | 10 September 2020, 22:53:04 UTC |
2a33b3d | Jake Vanderplas | 10 September 2020, 18:23:29 UTC | fix documentation typo (#4252) | 10 September 2020, 18:23:29 UTC |
82af356 | Qiao Zhang | 10 September 2020, 17:10:33 UTC | Bump TF hash to get an upstream LLVM GCC fix. (#4251) | 10 September 2020, 17:10:33 UTC |
cf65f6b | Peter Hawkins | 10 September 2020, 15:16:35 UTC | Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241) The permutation is more efficiently computed during the decomposition on TPU, and the only use case that would not require us to compute it would be for evaluating determinants. | 10 September 2020, 15:16:35 UTC |
b67e42a | Peter Hawkins | 10 September 2020, 13:38:14 UTC | Revert "Revert "Delete batching.last. (#4148)" (#4160)" (#4242) This reverts commit 36846e0ed96cc613e419ac85d9c3d54a49aa9ebc. | 10 September 2020, 13:38:14 UTC |
26a53ae | Qiao Zhang | 10 September 2020, 10:58:28 UTC | Add comments for residuals from f_bwd. (#4244) | 10 September 2020, 10:58:28 UTC |
0962ceb | George Necula | 10 September 2020, 10:55:57 UTC | [jax2tf] Fix test failure on TPUs (#4247) | 10 September 2020, 10:55:57 UTC |
29f97af | Benjamin Chetioui | 10 September 2020, 09:59:44 UTC | [jax2tf] Cleanup test_unary_elementwise. (#4246) * [jax2tf] Cleanup test_unary_elementwise. | 10 September 2020, 09:59:44 UTC |
f0a3fd4 | George Necula | 10 September 2020, 09:19:22 UTC | [jax2tf] Moved the limitations for XlaSort to correctness_stats (#4237) | 10 September 2020, 09:19:22 UTC |
adb3448 | Qiao Zhang | 10 September 2020, 00:23:43 UTC | Reorder nocuda/cuda build to fail early. (#4243) | 10 September 2020, 00:23:43 UTC |
0b04439 | Qiao Zhang | 10 September 2020, 00:16:58 UTC | Update install_cuda script to specify cublas. (#4240) | 10 September 2020, 00:16:58 UTC |
a14133a | Qiao Zhang | 09 September 2020, 23:36:30 UTC | Update TF dep to a passing commit hash. (#4239) | 09 September 2020, 23:36:30 UTC |
3f8aaab | Adam Paszke | 08 September 2020, 16:10:35 UTC | Interrupt lu transformation generators whenever an exception occurs This fixes some errors that have been appearing in our CI from time to time. All transformations are implemented as generators, but they haven't been explicitly aborted when an exception has been raised. Instead, they only got closed when they got garbage collected, which could happen at an unspecified later time, potentially leading to a corruption of global state, which could have been modified after the exception was handled. Note that this implementation doesn't propagate the original exception into the argument transformations, and doesn't allow them to handle the error either. Such an extension would be possible, but throwing an exception into a generator mutates the exception object, clobbering the nice traceback that we would usually carry. One can work around those issues, but it feels really hacky and we don't need it right now anyway, so I figured we'll be better off with the simple thing for the time being. | 09 September 2020, 18:43:05 UTC |
70891f4 | Benjamin Chetioui | 09 September 2020, 14:48:00 UTC | [jax2tf] Add a template file for documentation generation. (#4219) * [jax2tf] Add a template file for documentation generation. The documentation now gives instructions about how to regenerate it, as well as when it was last generated. * Added a list of conversions that are not yet implemented. | 09 September 2020, 14:48:00 UTC |
053cd5a | Benjamin Chetioui | 09 September 2020, 12:43:36 UTC | [jax2tf] Clean up test_dynamic_slice. (#4236) * [jax2tf] Clean up test_dynamic_slice. With the XLA nested compilation bug fixed, this should now work fine. | 09 September 2020, 12:43:36 UTC |
bff24bd | Roman Ring | 09 September 2020, 12:02:45 UTC | Add axis_index_groups support to all_gather. (#4194) | 09 September 2020, 12:02:45 UTC |
f908f6f | Benjamin Chetioui | 09 September 2020, 10:20:59 UTC | [jax2tf] Updated test_pad to test all dtypes and remove old skipped test (#4235) | 09 September 2020, 10:20:59 UTC |
ee38e71 | George Necula | 09 September 2020, 08:34:22 UTC | [jax2tf] Clean up code for XlaGather, experimental_compile not necessary (#4030) * [jax2tf] Clean up code for XlaGather, experimental_compile not necessary Now that XlaGather has been fixed in XLA, we do not need to use experimental_compile workaround (which was not working anyway when put in a SavedModel). This fix requires a recent tf-nightly installation. | 09 September 2020, 08:34:22 UTC |
3cf7336 | Qiao Zhang | 09 September 2020, 04:28:34 UTC | Fix Dockerfile wheel installation issues. (#4232) | 09 September 2020, 04:28:34 UTC |
745d90d | Matthew Johnson | 09 September 2020, 04:14:25 UTC | improve lax.pad shape rule (#4234) It's now: * better tested * better at catching errors * faster * easier to read | 09 September 2020, 04:14:25 UTC |
cf2d15d | Skye Wanderman-Milne | 09 September 2020, 01:23:42 UTC | jaxlib build fixes. (#4066) 1. `wheel.pep425tags` has been removed as of https://github.com/pypa/setuptools/pull/1829. Use the new `packaging.tags` instead. 2. Add `--allow-downgrades` to cuda install command. I'm not sure this is always necessary, but I ran into it, I'm guessing due to a cached docker image. | 09 September 2020, 01:23:42 UTC |
4600dd7 | Qiao Zhang | 09 September 2020, 00:20:48 UTC | Update jaxlib version for dlpack fix. (#4231) | 09 September 2020, 00:20:48 UTC |
9a70be2 | Jake Vanderplas | 08 September 2020, 20:30:57 UTC | Add test for dtype coverage of jax.numpy ufuncs (#3913) | 08 September 2020, 20:30:57 UTC |
7f3078b | Matthew Johnson | 08 September 2020, 15:54:13 UTC | update version and changelog for pypi (#4224) | 08 September 2020, 15:54:13 UTC |
ed0d8c0 | Matthew Johnson | 08 September 2020, 15:27:41 UTC | tweak lax.py shape broadcasting logic (#4217) This new implementation is faster, and works for polymorphic shapes without weird tricks. (This new implementation is faster even if we remove the weird tricks for polymorphism.) | 08 September 2020, 15:27:41 UTC |
798a264 | Benjamin Chetioui | 08 September 2020, 08:32:53 UTC | [jax2tf] Fix bug in population count and move expect_tf_exception into correctness stats (#4214) The code was using `tf.bitcast` instead of `tf.cast`, but using `expect_tf_exception` in every case was hiding the errors. | 08 September 2020, 08:32:53 UTC |
e1340f3 | Benjamin Chetioui | 07 September 2020, 15:12:35 UTC | [jax2tf] Fix missing complex64 TPU corner case of scatter_{add,mul} (#4213) | 07 September 2020, 15:12:35 UTC |
0aed1f4 | Adam Paszke | 07 September 2020, 12:31:58 UTC | Add more context to the axis_frame error message. Some of the vmap and gmap collective tests have been failing on master and I can't seem to be able to reproduce them locally. Hopefully, if this happens again, this extra bit of information will be useful in debugging the problem. | 07 September 2020, 14:25:30 UTC |
4413bb8 | George Necula | 07 September 2020, 14:13:11 UTC | [jax2tf] Do not use jax.random.PRNGKey before in primitive harness (#4211) We cannot execute JAX functions before the program is initialized | 07 September 2020, 14:13:11 UTC |
be8ea14 | Benjamin Chetioui | 07 September 2020, 13:47:18 UTC | [jax2tf] Expand coverage of primitives by categorize. (#4209) * [jax2tf] Expand coverage of primitives by categorize. This commit adds handling logic for the limitations of: - qr - svd - select_and_gather_add - reduce_window/reduce_window_{min,max,sum} - add - mul - scatter/scatter_{min,max,mul,add} Also fixes a bug in a call to _infer_shape_jax, which wasn't compatible with boolean operands and went undetected due to the high-level handling of TF exceptions in higher-order primitives. | 07 September 2020, 13:47:18 UTC |
1e84cbe | George Necula | 07 September 2020, 11:41:50 UTC | [jax2tf] Fix random.split when jax_enable_x64 (#4208) Since we do the threefry with signed integers when converting to TF, we run into the type promotion 'uint32 - int32 = int64', which then results in lax.shift_right_logical(uint32, int64), which fails. | 07 September 2020, 11:41:50 UTC |
6c62935 | Benjamin Chetioui | 07 September 2020, 09:03:00 UTC | [jax2tf] Cleanup the correctness stats layout. (#4201) * [jax2tf] Cleanup the correctness stats layout. * Added Google license at the top of the file. * Cleanup: fix docstring for 80 char boundary. * Monkey patch/cleanup outside of the loop. * Removed tensorflow dependency. * Fixed the name of attributes of Limitation. | 07 September 2020, 09:03:00 UTC |
c6e6ee2 | George Necula | 07 September 2020, 08:26:52 UTC | [jax2tf] Use the JAX impl rule for threefry instead of writing our own (#4204) * performance is the same | 07 September 2020, 08:26:52 UTC |
96278e6 | AdrienCorenflos | 04 September 2020, 16:21:43 UTC | Add reverse flag in associative scan (#4181) Add optional 'reverse' argument in associative scan | 04 September 2020, 16:21:43 UTC |
bcf9777 | Benjamin Chetioui | 03 September 2020, 13:56:22 UTC | [jax2tf] Generator for the documentation of operations with limited support (WIP) (#4193) * [jax2tf] Draft of a generator for the documentation of operations with limited support. | 03 September 2020, 13:56:22 UTC |
abdd138 | George Necula | 03 September 2020, 11:24:04 UTC | [jax2tf] Flip the with_gradient=True; was flipped back by mistake (#4200) | 03 September 2020, 11:24:04 UTC |
5eac477 | George Necula | 03 September 2020, 11:18:35 UTC | [jax2tf] Implementation of random_gamma (#4192) * [jax2tf] implementation of random_gamma The simplest implementation is by converting the JAX own impl_rule, which rewrites gamma into other JAX primitives. On TPU with use_vmap=True the performance is the same for JAX and TF, provided we use tf.function(compile=True). | 03 September 2020, 11:18:35 UTC |
708d07d | Alex Riley | 02 September 2020, 23:13:17 UTC | Add jax.numpy.array_split (#4197) | 02 September 2020, 23:13:17 UTC |
04f9a7e | Matthew Johnson | 02 September 2020, 01:16:20 UTC | better jax.numpy.tile implementation (#4190) Use reshape, broadcast_to, reshape. | 02 September 2020, 01:16:20 UTC |
421550a | Jake Vanderplas | 01 September 2020, 22:48:40 UTC | copysign: promote to inexact to match numpy & support unsigned inputs (#4188) | 01 September 2020, 22:48:40 UTC |
0cdb1f7 | Benjamin Chetioui | 01 September 2020, 07:35:25 UTC | [jax2tf] Indicate the version of TF used in tests in README. (#4185) | 01 September 2020, 07:35:25 UTC |
bdd6545 | Jean-Baptiste Lespiau | 01 September 2020, 07:34:47 UTC | Add more features to the C++ jax.jit. (#4169) This mainly follows https://github.com/google/jax/pull/4089 by adding: - support for disable_jit from C++ - support for jax._cpp_jit on methods. - supporting applying @jax.jit on top-level functions, by delaying the retrieval of the device and backend. - concurrency support. I am not aware of any missing feature (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable.) See: - https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see cr/328899906 + benchmarks for how numbers were generated) - The results of the Jax tests when enabling this: http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 failures, 92 passes, but many failures share a common cause). | 01 September 2020, 07:34:47 UTC |
36368a2 | Jake Vanderplas | 31 August 2020, 21:11:49 UTC | jnp.abs(): support boolean inputs (#4186) | 31 August 2020, 21:11:49 UTC |
44bcf7e | Hamza Merzić | 31 August 2020, 14:00:34 UTC | Fix axis checking and remove extra print statement (#4184) A series of PRs renaming the frame entries have been submitted, one of them introducing a bug when using omnistaging. This PR fixes that and removes a print comment (assuming added for debugging purposes). | 31 August 2020, 14:00:34 UTC |
b6b1f5e | George Necula | 31 August 2020, 07:26:32 UTC | [jax2tf] Turn on with_gradient by default (#4180) As I was writing the demo I realized that it makes more sense for with_gradient to be set to True by default. I have also fixed a bug with tie_in in omnistaging. | 31 August 2020, 07:26:32 UTC |
634c625 | George Necula | 30 August 2020, 09:38:14 UTC | More renaming of master to main in JAX internals (#4179) | 30 August 2020, 09:38:14 UTC |
ffbfadd | Jake Vanderplas | 30 August 2020, 08:36:47 UTC | lax.associative_scan: fix docstring examples (#4172) * lax.associative_scan: fix docstring examples * add verbiage from #3583 | 30 August 2020, 08:36:47 UTC |
6b6789a | Matthew Johnson | 30 August 2020, 08:16:51 UTC | applied simple find+sed for 'master' -> 'main' (#4174) * applied simple find+sed for 'master' -> 'main' * Rename master->main in JAX API and internals (#4178) * Started with #4174 * Renamed Trace.master to Trace.main * Renamed core.new_master and core.new_base_master Co-authored-by: George Necula <gcnecula@gmail.com> | 30 August 2020, 08:16:51 UTC |
1a87fd3 | Benjamin Chetioui | 29 August 2020, 08:24:03 UTC | Implement a proper shape checking rule for gather. (#4166) * Implement a proper shape checking rule for gather. The implementation is based on the corresponding shape inference code in `tensorflow/compiler/xla/service/shape_inference.cc`. The tests added in `tests/lax_test.py` are similarly mirroring the corresponding tests in tensorflow, with slight adaptations for the particular setting of JAX. Fixes google/jax#2826, and in principle fixes google/jax#4154 and google/jax#3905. * Extracted common functions for gather/scatter shape checking rules. | 29 August 2020, 08:24:03 UTC |
a33f4dd | Adam Paszke | 28 August 2020, 18:03:39 UTC | Add support for axis_index inside vmap (#4168) Also, reorganize the code to put all `axis_index` related functions in `lax_parallel.py`, next to all other parallel collectives. | 28 August 2020, 18:03:39 UTC |
1dab791 | Jake Vanderplas | 28 August 2020, 16:07:30 UTC | Avoid calling jnp.sum() on list (#4163) | 28 August 2020, 16:07:30 UTC |
04f9ff7 | Benjamin Chetioui | 28 August 2020, 14:27:32 UTC | Addition of one more conclusive polynomial comparison case. (#4167) * Addition of one more conclusive polynomial comparison case. In the case when the difference between two polynomials is a constant, it is possible to conclusively compare them. This commit adds such a case to masking.Poly.__ge__. * Added a few relevant tests in tests.masking_test.test_Poly_compare. | 28 August 2020, 14:27:32 UTC |
7210d6f | Adam Paszke | 18 August 2020, 09:14:38 UTC | Add support for binding axis_name in gmap This allows executing collectives over the gmapped axes. This requires some extra manipulation of the gmapped jaxpr, since gmap exposes a single logical axis name, but evaluates the program using multiple "physical" axes. This also fixes some bugs around handling `multiple_returns` in vmap collective implementation. | 28 August 2020, 12:42:01 UTC |
e95d570 | Jean-Baptiste Lespiau | 27 August 2020, 14:02:13 UTC | Add benchmarks specifically for the dispatch time. (#4128) The goal is to distinguish the time it takes for `jitted_f` to return from the time it takes to return and wait for the result. We also add one to distinguish the time it takes to call the function with and without the argument transfer. Example results (name: time/op): jit_trivial_dispatch: 28.9µs ± 2%; jit_trivial: 31.5µs ± 5%; jit_simple_dispatch: 60.7µs ± 4%; jit_simple: 129µs ± 24%; jit_simple_many_args_disptch: 390µs ± 19%; jit_simple_many_args: 388µs ± 16%; jit_dispatch_without_transfer: 379µs ± 6%; jit_dispatch_with_transfer: 450µs ± 5% | 27 August 2020, 14:02:13 UTC |
36846e0 | George Necula | 27 August 2020, 09:45:48 UTC | Revert "Delete batching.last. (#4148)" (#4160) This reverts commit 4bf3d6e9cccc5de3834e37affae2012e6e3d3180. This commit fails internal tests. | 27 August 2020, 09:45:48 UTC |
a7faf09 | Benjamin Chetioui | 27 August 2020, 09:24:13 UTC | [jax2tf] Added conversion for scatter*_p primitives. (#4091) * [jax2tf] Added conversion for scatter*_p primitives. Limitations: the conversion works as well as the conversion of the underlying reduction functions (e.g. lax.scatter_max is not properly converted for the int8 dtype, because tf.math.maximum is not defined for int8 tensors); the conversion can not take advantage of the unique_indices parameter. This does not affect correctness, but may affect performance on certain platforms (as stated in the documentation of lax.scatter). * Put tf.function experimental compile wrapper back on scatter. * Removed unique_indices=True test cases * Remove non-deterministic test cases from the scatter harness. This commit also documents the reasons for ignoring these test cases and potential pitfalls, in case someone needs to perform these tests at a later time. | 27 August 2020, 09:24:13 UTC |
4d7396a | Benjamin Chetioui | 27 August 2020, 09:04:32 UTC | Implement a proper shape checking rule for scatter. (#4144) The implementation is based on the corresponding shape inference code in `tensorflow/compiler/xla/service/shape_inference.cc`. The tests added in `tests/lax_test.py` are similarly mirroring the corresponding tests in tensorflow, with slight adaptations for the particular setting of JAX. | 27 August 2020, 09:04:32 UTC |
80114e5 | Benjamin Chetioui | 27 August 2020, 07:47:19 UTC | Add a boolean to _check_shapelike to accept or reject shapes (#4108) * Add a boolean to _check_shapelike to accept or reject shapes corresponding to arrays of 0 elements. (Fixes google/jax#3972). * Added test for failures referenced in issue 3972. | 27 August 2020, 07:47:19 UTC |
1dc71b2 | Benjamin Chetioui | 27 August 2020, 07:46:32 UTC | [jax2tf] Add testing for add/mul/min/max conversion. (#4142) * [jax2tf] Add testing for add/mul/min/max conversion. Only certain types are supported for each of the operations above. This commit adds previously missing tests to make this explicit. | 27 August 2020, 07:46:32 UTC |
c76b84f | George Necula | 27 August 2020, 07:34:53 UTC | Revert "Increase tolerance for CPU test LaxBackedNumpyTests::testCorrCoef (#4080)" (#4151) This reverts commit 22b92c5122ab5af6f5e4560f9be08f5649ae7653. We revert this change because the LLVM bug that made us relax the test tolerance is now fixed. | 27 August 2020, 07:34:53 UTC |
57f49b6 | George Necula | 27 August 2020, 07:05:24 UTC | Fix bug in omnistaging_enabler (#4159) This code was failing with "KeyError: psum" for the tests "//third_party/py/flax/...". I suspect that the error is due to the ordering of the omnistaging enablers, changed in #4152. I am not sure of this fix, but this seemed to be enough for all the presubmit tests to pass and allow the copybara import. | 27 August 2020, 07:05:24 UTC |
417c9ff | George Necula | 27 August 2020, 06:41:16 UTC | Fix pytype error (#4158) | 27 August 2020, 06:41:16 UTC |
29073be | Jake Vanderplas | 27 August 2020, 04:13:33 UTC | cleanup: remove duplicate line (#4156) | 27 August 2020, 04:13:33 UTC |
f0fb7d0 | Tom Hennigan | 26 August 2020, 21:06:27 UTC | Use omnistaging env var even when not using absl flags for config. (#4152) | 26 August 2020, 21:06:27 UTC |
1d93991 | Matthew Johnson | 26 August 2020, 17:21:56 UTC | allow random.choice to accept ndarray input (#4145) * allow random.choice to accept ndarray `a` follow-up to #4137 to allow ndarray inputs to be passed * add jax.random.choice tests to cover ndarray input * don't use callables in test params it can mess with pytest-xdist because of hashing by id | 26 August 2020, 17:21:56 UTC |
01319fb | Peter Hawkins | 25 August 2020, 17:05:06 UTC | Speed up and clean up geomspace test. (#4149) * Speed up and clean up geomspace test. | 25 August 2020, 17:05:06 UTC |
4bf3d6e | Peter Hawkins | 25 August 2020, 16:53:18 UTC | Delete batching.last. (#4148) A -1 axis works just as well at head. | 25 August 2020, 16:53:18 UTC |
8c8060e | Peter Hawkins | 25 August 2020, 16:53:02 UTC | Remove workaround for illegal vmap out_axes. (#4147) | 25 August 2020, 16:53:02 UTC |
6d54eb5 | Jake Vanderplas | 25 August 2020, 12:47:43 UTC | Do not call asarray() on inputs of jax.random.choice (#4137) | 25 August 2020, 12:47:43 UTC |
f959219 | Jean-Baptiste Lespiau | 25 August 2020, 12:39:45 UTC | Rename collectives into "collective operations" for the pmap function. (#4136) It is just because it serves as the entry point, and this term leads to good Google results, such as https://en.wikipedia.org/wiki/Collective_operation, while the current "collectives" do not. | 25 August 2020, 12:39:45 UTC |
f4b05bc | Matthew Johnson | 25 August 2020, 12:38:41 UTC | make pe.abstract_eval_fun use omnistaging (#4139) | 25 August 2020, 12:38:41 UTC |
04173b3 | Matthew Johnson | 25 August 2020, 12:38:20 UTC | Merge pull request #4140 from sharadmv/patch-2 Remove frame check assertion in `extend_axis_env`. | 25 August 2020, 12:38:20 UTC |
774b5f6 | Sharad Vikram | 25 August 2020, 04:08:23 UTC | Remove frame check assertion in `extend_axis_env`. | 25 August 2020, 04:13:30 UTC |
e06a6ab | Peter Hawkins | 25 August 2020, 00:21:19 UTC | Add support for negative axes to vmap. (#4111) * Add support for negative axes to vmap. * Add workaround for out-of-range vmap axes. | 25 August 2020, 00:21:19 UTC |
603f0c1 | Philipp Thölke | 24 August 2020, 20:50:44 UTC | Fix scan carry types in gradient of complex ODE (#4130) * Cast t_bar from potentially complex to float in ode.py * Add test case for complex odeint (currently failing) * Wrap odeint into complex-to-real function in test case * fixup Co-authored-by: Stephan Hoyer <shoyer@google.com> | 24 August 2020, 20:50:44 UTC |
0cc3802 | Benjamin Chetioui | 24 August 2020, 19:29:22 UTC | Fix documentation of scatter_* operations. (#4138) * Fix documentation of scatter_* operations. This commit changes the documentation of the `unique_indices` parameter to scatter to better capture its intended meaning in XLA. | 24 August 2020, 19:29:22 UTC |
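The message for 04f9a7e (#4190) summarizes the new `jax.numpy.tile` only as "Use reshape, broadcast_to, reshape." A minimal sketch of that idea follows, assuming `reps` is an int or a tuple/list of ints; `tile_sketch` is a hypothetical helper for illustration, not the code that was merged. A singleton axis is inserted before every input axis of size `s`, that singleton is broadcast to the repeat count `r`, and each `(r, s)` pair is then merged into one axis of size `r * s`.

```python
import jax.numpy as jnp


def tile_sketch(a, reps):
    """Hypothetical tile via reshape -> broadcast_to -> reshape (illustration only)."""
    a = jnp.asarray(a)
    # Accept a scalar repeat count or a sequence of counts.
    reps = tuple(reps) if isinstance(reps, (tuple, list)) else (int(reps),)
    # Left-pad whichever of reps / a.shape is shorter with ones, as np.tile does.
    ndim = max(a.ndim, len(reps))
    reps = (1,) * (ndim - len(reps)) + reps
    shape = (1,) * (ndim - a.ndim) + tuple(a.shape)
    # Step 1: reshape to (1, s0, 1, s1, ...), a singleton before every axis.
    interleaved = tuple(d for s in shape for d in (1, s))
    # Step 2: broadcast each singleton to its repeat count: (r0, s0, r1, s1, ...).
    expanded = tuple(d for r, s in zip(reps, shape) for d in (r, s))
    out = jnp.broadcast_to(a.reshape(interleaved), expanded)
    # Step 3: merge each (r_i, s_i) pair into a single axis of size r_i * s_i.
    return out.reshape(tuple(r * s for r, s in zip(reps, shape)))


x = jnp.arange(6).reshape(2, 3)
print(tile_sketch(x, (2, 2)).shape)  # (4, 6)
```

The index arithmetic works because flattening a `(r, s)` pair maps position `(i, j)` to `i * s + j`, and the broadcast value there is `a[j]`, which is exactly the "repeat the whole axis `r` times" layout that tiling requires.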