https://github.com/google/jax

Revision Message Commit Date
6af4769 update version and changelog for pypi (#4294) 15 September 2020, 15:00:47 UTC
cefa93f Lower LU decomposition to a custom TPU implementation for float32 types. (#4291) 15 September 2020, 13:04:54 UTC
58a117f Modifies eig_p and related operations to take advantage of the new jaxlib geev API (#4266) * Add options to compute L/R eigenvectors in geev. The new arguments are by default set to True to ensure backwards compatibility between jaxlib and jax. Reformulate eig-related operations based on the new geev API. * Addressed hawkinsp's comments from google/jax#3882. Co-authored-by: Skye Wanderman-Milne <skyewm@google.com> 15 September 2020, 08:45:15 UTC
a5c2c47 [jax2tf] Added support for x64 for the remaining test files (#4282) * [jax2tf] Added support for x64 in other test files. This includes: - control_flow_ops_test.py - jax2tf_test.py - saved_model_test.py - stax_test 15 September 2020, 08:40:07 UTC
4e04d4e [jax2tf] Build a primitive harness for test_type_promotion. (#4279) * [jax2tf] Build a primitive harness for test_type_promotion. We were previously generating the cases using `jtu.cases_from_list`, which by default dropped 2 test cases (JAX_NUM_GENERATED_CASES=10, number of generated cases = 12). * [jax2tf] Fix the generated test cases for test_type_promotion. 15 September 2020, 06:35:35 UTC
9040336 Allow device_get to pass Python scalars through unchanged (#4283) * Allow device_get to pass Python scalars through unchanged * address comment 15 September 2020, 01:35:41 UTC
7569e80 revert #4277 (google failure) (#4281) * revert #4277 (google failure) Some downstream user is relying on the rank of stax's biases being 1. * only revert one change 14 September 2020, 19:31:51 UTC
6bd3216 Simplify the interface for host_callback.id_tap (#4101) * Simplify the internal interface for host_callback.id_tap This is a breaking change for `id_tap` users (but not `id_print` users). This makes it easier to use (and type check) ``tap_func``, because the expected signature is now ``tap_func(arg, transforms)`` vs ``tap_func(arg, *, transforms, **kwargs)``. Most of the test changes are just adding whitespace/indentation, but I've also slightly changed the way transformations are printed. 14 September 2020, 09:47:28 UTC
2ff3479 [jax2tf] Fix tests when running with JAX_ENABLE_X64=1. (#4261) Fixed tests: - test_binary_elementwise - dynamic_update_slice - fft - population_count - test_unary_elementwise - top_k - select_and_gather_add 14 September 2020, 09:34:31 UTC
7e6d114 [jax2tf] Add converted primitives without tests to the generated doc. (#4248) * [jax2tf] Add converted primitives without tests to the generated doc. * Ignore some primitives in the output of untested primitives. * Added control_flow_ops_test to template and updated primitives. * Removed svd from the list of missing tests. Was just included because I run the tests using JAX_SKIP_SLOW_TESTS=1, which didn't run the SVD tests. Patched the generated file manually. 14 September 2020, 08:35:43 UTC
fa827a5 [jax2tf] Added the last comments from the jax2tf doc inside the (#4249) correctness_stats code. In principle, all the relevant documentation that was in the doc has been moved to the new documentation & comments of categorize. 14 September 2020, 08:33:58 UTC
6e2fa39 [jax2tf] Fix lax.div_p (#4263) 14 September 2020, 08:33:02 UTC
9fc4353 Avoid rank promotion in stax biases (#4277) * Avoid rank promotion in stax biases * remove itertools 14 September 2020, 02:21:17 UTC
38b43ef Avoid rank promotion in np.outer (#4276) 14 September 2020, 02:20:31 UTC
64bead2 fixing typo (#4273) I assume "...one of more type parameters..." was intended to read "...one or more type parameters..." 12 September 2020, 20:10:01 UTC
f039f6d thread backend in pxla.replicate (#4272) * thread backend in pxla.replicate fixes #4223 * add test for #4223 12 September 2020, 05:40:12 UTC
83b4f3b Cleanup: use _canonicalize_axis() utility where possible (#4270) 11 September 2020, 23:49:18 UTC
ee9dccf Move failing CPPJitTest test case to PythonJitTest (#4268) 11 September 2020, 19:12:34 UTC
6dc161c Pin pygments version in RTD build. (#4267) This fixes our RTD failures, which were caused by RTD installing an older version of pygments: ``` jupyterlab-pygments 0.1.1 requires pygments<3,>=2.4.1, but you'll have pygments 2.3.1 which is incompatible. nbconvert 6.0.1 requires pygments>=2.4.1, but you'll have pygments 2.3.1 which is incompatible. ``` 11 September 2020, 18:16:54 UTC
ca1d8f4 Fixing weird behavior in segment_sum when num_segments is None (#4034) Co-authored-by: alvarosg <alvarosg@google.com> 11 September 2020, 17:51:42 UTC
3b7329c Call check_arraylike() in jax.numpy reductions (#4195) 11 September 2020, 17:04:47 UTC
40fb01b Extend axis env while translating the pmapped jaxpr to XLA This is normally unnecessary, because the XLA translation usually doesn't bind any of the primitives in the jaxpr, but this is not true in case of scan! Its translation rule reevaluates the jaxpr as a function, and if it contains collectives such as `axis_index` it can fail due to axis being missing. 11 September 2020, 15:56:32 UTC
8a18b10 implement jnp.apply_along_axis (#4253) 11 September 2020, 15:47:05 UTC
7e694bd Update README to point to jaxlib-0.1.55. (#4256) 11 September 2020, 01:28:44 UTC
fc39332 Clarify when jaxlib version should be bumped. (#4250) 10 September 2020, 22:53:04 UTC
2a33b3d fix documentation typo (#4252) 10 September 2020, 18:23:29 UTC
82af356 Bump TF hash to get an upstream LLVM GCC fix. (#4251) 10 September 2020, 17:10:33 UTC
cf65f6b Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241) The permutation is more efficiently computed during the decomposition on TPU, and the only use case that would not require us to compute it would be for evaluating determinants. 10 September 2020, 15:16:35 UTC
b67e42a Revert "Revert "Delete batching.last. (#4148)" (#4160)" (#4242) This reverts commit 36846e0ed96cc613e419ac85d9c3d54a49aa9ebc. 10 September 2020, 13:38:14 UTC
26a53ae Add comments for residuals from f_bwd. (#4244) 10 September 2020, 10:58:28 UTC
0962ceb [jax2tf] Fix test failure on TPUs (#4247) 10 September 2020, 10:55:57 UTC
29f97af [jax2tf] Cleanup test_unary_elementwise. (#4246) * [jax2tf] Cleanup test_unary_elementwise. 10 September 2020, 09:59:44 UTC
f0a3fd4 [jax2tf] Moved the limitations for XlaSort to correctness_stats (#4237) 10 September 2020, 09:19:22 UTC
adb3448 Reorder nocuda/cuda build to fail early. (#4243) 10 September 2020, 00:23:43 UTC
0b04439 Update install_cuda script to specify cublas. (#4240) 10 September 2020, 00:16:58 UTC
a14133a Update TF dep to a passing commit hash. (#4239) 09 September 2020, 23:36:30 UTC
3f8aaab Interrupt lu transformation generators whenever an exception occurs This fixes some errors that have been appearing in our CI from time to time. All transformations are implemented as generators, but they haven't been explicitly aborted when an exception has been raised. Instead, they only got closed when they got garbage collected, which could happen at an unspecified later time, potentially leading to a corruption of global state, which could have been modified after the exception was handled. Note that this implementation doesn't propagate the original exception into the argument transformations, and doesn't allow them to handle the error either. Such an extension would be possible, but throwing an exception into a generator mutates the exception object, clobbering the nice traceback that we would usually carry. One can work around those issues, but it feels really hacky and we don't need it right now anyway, so I figured we'll be better off with the simple thing for the time being. 09 September 2020, 18:43:05 UTC
70891f4 [jax2tf] Add a template file for documentation generation. (#4219) * [jax2tf] Add a template file for documentation generation. The documentation now gives instructions about how to regenerate it, as well as when it was last generated. * Added a list of conversions that are not yet implemented. 09 September 2020, 14:48:00 UTC
053cd5a [jax2tf] Clean up test_dynamic_slice. (#4236) * [jax2tf] Clean up test_dynamic_slice. With the XLA nested compilation bug fixed, this should now work fine. 09 September 2020, 12:43:36 UTC
bff24bd Add axis_index_groups support to all_gather. (#4194) 09 September 2020, 12:02:45 UTC
f908f6f [jax2tf] Updated test_pad to test all dtypes and remove old (#4235) skipped test. 09 September 2020, 10:20:59 UTC
ee38e71 [jax2tf] Clean up code for XlaGather, experimental_compile not necessary (#4030) * [jax2tf] Clean up code for XlaGather, experimental_compile not necessary Now that XlaGather has been fixed in XLA, we do not need to use experimental_compile workaround (which was not working anyway when put in a SavedModel). This fix requires a recent tf-nightly installation. 09 September 2020, 08:34:22 UTC
3cf7336 Fix Dockerfile wheel installation issues. (#4232) 09 September 2020, 04:28:34 UTC
745d90d improve lax.pad shape rule (#4234) It's now: * better tested * better at catching errors * faster * easier to read 09 September 2020, 04:14:25 UTC
cf2d15d jaxlib build fixes. (#4066) 1. `wheel.pep425tags` has been removed as of https://github.com/pypa/setuptools/pull/1829. Use the new `packaging.tags` instead. 2. Add `--allow-downgrades` to cuda install command. I'm not sure this is always necessary, but I ran into it, I'm guessing due to a cached docker image. 09 September 2020, 01:23:42 UTC
4600dd7 Update jaxlib version for dlpack fix. (#4231) 09 September 2020, 00:20:48 UTC
9a70be2 Add test for dtype coverage of jax.numpy ufuncs (#3913) 08 September 2020, 20:30:57 UTC
7f3078b update version and changelog for pypi (#4224) 08 September 2020, 15:54:13 UTC
ed0d8c0 tweak lax.py shape broadcasting logic (#4217) This new implementation is faster, and works for polymorphic shapes without weird tricks. (This new implementation is faster even if we remove the weird tricks for polymorphism.) 08 September 2020, 15:27:41 UTC
798a264 [jax2tf] Fix bug in population count and move expect_tf_exception (#4214) into correctness stats. The code was using `tf.bitcast` instead of `tf.cast`, but using `expect_tf_exception` in every case was hiding the errors. 08 September 2020, 08:32:53 UTC
e1340f3 [jax2tf] Fix missing complex64 TPU corner case of scatter_{add,mul} (#4213) 07 September 2020, 15:12:35 UTC
0aed1f4 Add more context to the axis_frame error message. Some of the vmap and gmap collective tests have been failing on master and I can't seem to be able to reproduce them locally. Hopefully, if this happens again, this extra bit of information will be useful in debugging the problem. 07 September 2020, 14:25:30 UTC
4413bb8 [jax2tf] Do not use jax.random.PRNGKey before in primitive harness (#4211) We cannot execute JAX functions before the program is initialized 07 September 2020, 14:13:11 UTC
be8ea14 [jax2tf] Expand coverage of primitives by categorize. (#4209) * [jax2tf] Expand coverage of primitives by categorize. This commit adds handling logic for the limitations of: - qr - svd - select_and_gather_add - reduce_window/reduce_window_{min,max,sum} - add - mul - scatter/scatter_{min,max,mul,add} Also fixes a bug in a call to _infer_shape_jax, which wasn't compatible with boolean operands and went undetected due to the high-level handling of TF exceptions in higher-order primitives. 07 September 2020, 13:47:18 UTC
1e84cbe [jax2tf] Fix random.split when jax_enable_x64 (#4208) Since we do the threefry with signed integers when converting to TF, we run into the type promotion 'uint32 - int32 = int64', which then results in lax.shift_right_logical(uint32, int64), which fails. 07 September 2020, 11:41:50 UTC
6c62935 [jax2tf] Cleanup the correctness stats layout. (#4201) * [jax2tf] Cleanup the correctness stats layout. * Added Google license at the top of the file. * Cleanup: fix docstring for 80 char boundary. * Monkey patch/cleanup outside of the loop. * Removed tensorflow dependency. * Fixed the name of attributes of Limitation. 07 September 2020, 09:03:00 UTC
c6e6ee2 [jax2tf] Use the JAX impl rule for threefry instead of writing our own (#4204) * performance is the same 07 September 2020, 08:26:52 UTC
96278e6 Add reverse flag in associative scan (#4181) Add optional 'reverse' argument in associative scan 04 September 2020, 16:21:43 UTC
bcf9777 [jax2tf] Generator for the documentation of operations with limited support (WIP) (#4193) * [jax2tf] Draft of a generator for the documentation of operations with limited support. 03 September 2020, 13:56:22 UTC
abdd138 [jax2tf] Flip the with_gradient=True; was flipped back by mistake (#4200) 03 September 2020, 11:24:04 UTC
5eac477 [jax2tf] Implementation of random_gamma (#4192) * [jax2tf] implementation of random_gamma The simplest implementation is by converting the JAX own impl_rule, which rewrites gamma into other JAX primitives. On TPU with use_vmap=True the performance is the same for JAX and TF, provided we use tf.function(compile=True). 03 September 2020, 11:18:35 UTC
708d07d Add jax.numpy.array_split (#4197) 02 September 2020, 23:13:17 UTC
04f9a7e better jax.numpy.tile implementation (#4190) Use reshape, broadcast_to, reshape. 02 September 2020, 01:16:20 UTC
421550a copysign: promote to inexact to match numpy & support unsigned inputs (#4188) 01 September 2020, 22:48:40 UTC
0cdb1f7 [jax2tf] Indicate the version of TF used in tests in README. (#4185) 01 September 2020, 07:35:25 UTC
bdd6545 Add more features to the C++ jax.jit. (#4169) This mainly follows https://github.com/google/jax/pull/4089 by adding: - support for disable_jit from C++ - support for jax._cpp_jit on methods. - supporting applying @jax.jit on top-level functions, by delaying the retrieval of the device and backend. - concurrency support. I am not aware of any feature missing (but I suspect there are still some differences due to the differences between xla_computation and _xla_callable.) See: - https://i.ibb.co/ZMvZ4nK/benchmark.png for the benchmarking comparison (see cr/328899906 + benchmarks for how numbers were generated) - The results of the Jax tests when enabling this: http://sponge2/4a67d132-209f-45c5-ab7b-83716d329ec2 (110 fails, 92 passes, but many common cause of failure). 01 September 2020, 07:34:47 UTC
36368a2 jnp.abs(): support boolean inputs (#4186) 31 August 2020, 21:11:49 UTC
44bcf7e Fix axis checking and remove extra print statement (#4184) A series of PRs renaming the frame entries have been submitted, one of them introducing a bug when using omnistaging. This PR fixes that and removes a print comment (assuming added for debugging purposes). 31 August 2020, 14:00:34 UTC
b6b1f5e [jax2tf] Turn on with_gradient by default (#4180) As I was writing the demo I realized that it makes more sense for with_gradient to be set to True by default. I have also fixed a bug with tie_in in omnistaging. 31 August 2020, 07:26:32 UTC
634c625 More renaming of master to main in JAX internals (#4179) 30 August 2020, 09:38:14 UTC
ffbfadd lax.associative_scan: fix docstring examples (#4172) * lax.associative_scan: fix docstring examples * add verbiage from #3583 30 August 2020, 08:36:47 UTC
6b6789a applied simple find+sed for 'master' -> 'main' (#4174) * applied simple find+sed for 'master' -> 'main' * Rename master->main in JAX API and internals (#4178) * Started with #4174 * Renamed Trace.master to Trace.main * Renamed core.new_master and core.new_base_master Co-authored-by: George Necula <gcnecula@gmail.com> 30 August 2020, 08:16:51 UTC
1a87fd3 Implement a proper shape checking rule for gather. (#4166) * Implement a proper shape checking rule for gather. The implementation is based on the corresponding shape inference code in `tensorflow/compiler/xla/service/shape_inference.cc`. The tests added in `tests/lax_test.py` are similarly mirroring the corresponding tests in tensorflow, with slight adaptations for the particular setting of JAX. Fixes google/jax#2826, and in principle fixes google/jax#4154 and google/jax#3905. * Extracted common functions for gather/scatter shape checking rules. 29 August 2020, 08:24:03 UTC
a33f4dd Add support for axis_index inside vmap (#4168) Also, reorganize the code to put all `axis_index` related functions in `lax_parallel.py`, next to all other parallel collectives. 28 August 2020, 18:03:39 UTC
1dab791 Avoid calling jnp.sum() on list (#4163) 28 August 2020, 16:07:30 UTC
04f9ff7 Addition of one more conclusive polynomial comparison case. (#4167) * Addition of one more conclusive polynomial comparison case. In the case when the difference between two polynomials is a constant, it is possible to conclusively compare them. This commit adds such a case to masking.Poly.__ge__. * Added a few relevant tests in tests.masking_test.test_Poly_compare. 28 August 2020, 14:27:32 UTC
7210d6f Add support for binding axis_name in gmap This allows executing collectives over the gmapped axes. This requires some extra manipulation of the gmapped jaxpr, since gmap exposes a single logical axis name, but evaluates the program using multiple "physical" axes. This also fixes some bugs around handling `multiple_returns` in vmap collective implementation. 28 August 2020, 12:42:01 UTC
e95d570 Add benchmarks for specifically the dispatch time. (#4128) The goal is to distinguish the time it takes for `jitted_f` to return, and the time it takes to return and wait for the result. We also add one to distinguish the time it takes to call the function with the argument transfer or without it. e.g.
name                            time/op
jit_trivial_dispatch            28.9µs ± 2%
jit_trivial                     31.5µs ± 5%
jit_simple_dispatch             60.7µs ± 4%
jit_simple                      129µs ±24%
jit_simple_many_args_disptch    390µs ±19%
jit_simple_many_args            388µs ±16%
jit_dispatch_without_transfer   379µs ± 6%
jit_dispatch_with_transfer      450µs ± 5%
27 August 2020, 14:02:13 UTC
36846e0 Revert "Delete batching.last. (#4148)" (#4160) This reverts commit 4bf3d6e9cccc5de3834e37affae2012e6e3d3180. This commit fails internal tests. 27 August 2020, 09:45:48 UTC
a7faf09 [jax2tf] Added conversion for scatter*_p primitives. (#4091) * [jax2tf] Added conversion for scatter*_p primitives. Limitations: the conversion works as well as the conversion of the underlying reduction functions (e.g. lax.scatter_max is not properly converted for the int8 dtype, because tf.math.maximum is not defined for int8 tensors); the conversion can not take advantage of the unique_indices parameter. This does not affect correctness, but may affect performance on certain platforms (as stated in the documentation of lax.scatter). * Put tf.function experimental compile wrapper back on scatter. * Removed unique_indices=True test cases * Remove non-deterministic test cases from the scatter harness. This commit also documents the reasons for ignoring these test cases and potential pitfalls, in case someone needs to perform these tests at a later time. 27 August 2020, 09:24:13 UTC
4d7396a Implement a proper shape checking rule for scatter. (#4144) The implementation is based on the corresponding shape inference code in `tensorflow/compiler/xla/service/shape_inference.cc`. The tests added in `tests/lax_test.py` are similarly mirroring the corresponding tests in tensorflow, with slight adaptations for the particular setting of JAX. 27 August 2020, 09:04:32 UTC
80114e5 Add a boolean to _check_shapelike to accept or reject shapes (#4108) * Add a boolean to _check_shapelike to accept or reject shapes corresponding to arrays of 0 elements. (Fixes google/jax#3972). * Added test for failures referenced in issue 3972. 27 August 2020, 07:47:19 UTC
1dc71b2 [jax2tf] Add testing for add/mul/min/max conversion. (#4142) * [jax2tf] Add testing for add/mul/min/max conversion. Only certain types are supported for each of the operations above. This commit adds previously missing tests to make this explicit. 27 August 2020, 07:46:32 UTC
c76b84f Revert "Increase tolerance for CPU test LaxBackedNumpyTests::testCorrCoef (#4080)" (#4151) This reverts commit 22b92c5122ab5af6f5e4560f9be08f5649ae7653. We revert this change because the LLVM bug that made us relax the test tolerance is now fixed. 27 August 2020, 07:34:53 UTC
57f49b6 Fix bug in omnistaging_enabler (#4159) This code was failing with "KeyError: psum" for the tests "//third_party/py/flax/...". I suspect that the error is due to the ordering of the omnistaging enablers, changed in #4152. I am not sure of this fix, but this seemed to be enough for all the presubmit tests to pass and allow the copybara import. 27 August 2020, 07:05:24 UTC
417c9ff Fix pytype error (#4158) 27 August 2020, 06:41:16 UTC
29073be cleanup: remove duplicate line (#4156) 27 August 2020, 04:13:33 UTC
f0fb7d0 Use omnistaging env var even when not using absl flags for config. (#4152) 26 August 2020, 21:06:27 UTC
1d93991 allow random.choice to accept ndarray input (#4145) * allow random.choice to accept ndarray `a` follow-up to #4137 to allow ndarray inputs to be passed * add jax.random.choice tests to cover ndarray input * don't use callables in test params it can mess with pytest-xdist because of hashing by id 26 August 2020, 17:21:56 UTC
01319fb Speed up and clean up geomspace test. (#4149) * Speed up and clean up geomspace test. 25 August 2020, 17:05:06 UTC
4bf3d6e Delete batching.last. (#4148) A -1 axis works just as well at head. 25 August 2020, 16:53:18 UTC
8c8060e Remove workaround for illegal vmap out_axes. (#4147) 25 August 2020, 16:53:02 UTC
6d54eb5 Do not call asarray() on inputs of jax.random.choice (#4137) 25 August 2020, 12:47:43 UTC
f959219 Rename collectives into "collective operations" for the pmap function. (#4136) It is just because it serves as the entry point, and this term leads to good Google results, such as https://en.wikipedia.org/wiki/Collective_operation, while the current "collectives" do not. 25 August 2020, 12:39:45 UTC
f4b05bc make pe.abstract_eval_fun use omnistaging (#4139) 25 August 2020, 12:38:41 UTC
04173b3 Merge pull request #4140 from sharadmv/patch-2 Remove frame check assertion in `extend_axis_env`. 25 August 2020, 12:38:20 UTC
774b5f6 Remove frame check assertion in `extend_axis_env`. 25 August 2020, 04:13:30 UTC
e06a6ab Add support for negative axes to vmap. (#4111) * Add support for negative axes to vmap. * Add workaround for out-of-range vmap axes. 25 August 2020, 00:21:19 UTC
603f0c1 Fix scan carry types in gradient of complex ODE (#4130) * Cast t_bar from potential complex to float in ode.py * Add test case for complex odeint (currently failing) * Wrap odeint into complex-to-real function in test case * fixup Co-authored-by: Stephan Hoyer <shoyer@google.com> 24 August 2020, 20:50:44 UTC
0cc3802 Fix documentation of scatter_* operations. (#4138) * Fix documentation of scatter_* operations. This commit changes the documentation of the `unique_indices` parameter to scatter to better capture its intended meaning in XLA. 24 August 2020, 19:29:22 UTC
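
Note on 3f8aaab: the message describes transformations written as generators that were only closed at garbage collection after an exception, so their cleanup could run at an unspecified later time. The following is a minimal stdlib-only sketch of that idea, not JAX's actual code. The names `transformation`, `run_with_explicit_close`, `failing_body`, and `log` are hypothetical and chosen for illustration. It closes each generator explicitly in a `finally` block, so cleanup finishes before the caller's exception handler runs. As in the commit message, the original exception is not thrown into the generators.

```python
log = []

def transformation(name):
    """Toy transformation written as a generator: set up, yield, tear down."""
    log.append(f"{name}: enter")
    try:
        yield
    finally:
        # Runs when the generator is closed or resumed to completion.
        log.append(f"{name}: exit")

def run_with_explicit_close(body):
    """Run body() inside two nested toy transformations (hypothetical helper)."""
    gens = [transformation("outer"), transformation("inner")]
    for g in gens:
        next(g)  # advance each generator to its yield
    try:
        return body()
    finally:
        # Close explicitly, innermost first, instead of waiting for the
        # garbage collector to finalize the suspended generators.
        for g in reversed(gens):
            g.close()

def failing_body():
    log.append("body: raise")
    raise ValueError("boom")

try:
    run_with_explicit_close(failing_body)
except ValueError:
    log.append("handler: caught")

print(log)
```

Both "exit" entries are logged before "handler: caught", which is the ordering that explicit closing is meant to guarantee.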