https://github.com/google/jax

Revision Author Date Message Commit Date
731e899 Add registrations and CPU kernels for comparisons on booleans PiperOrigin-RevId: 377534845 06 June 2021, 14:51:46 UTC
0c102db Merge pull request #6892 from gnecula:tf_cplx PiperOrigin-RevId: 377529860 04 June 2021, 16:40:43 UTC
d243258 [jax2tf] Implement inequalities and friends for complex numbers. This requires reusing JAX's lowering rule for comparisons of complex numbers, which uses lexicographic comparison. 04 June 2021, 14:56:44 UTC
de9f557 Merge pull request #6895 from gnecula:tf_uint64 PiperOrigin-RevId: 377510835 04 June 2021, 14:49:44 UTC
ede457f [jax2tf] Fix bug with max_int for uint64 04 June 2021, 12:29:54 UTC
293ca65 [jax2tf] Update limitations to account for tf.math improvements for trigonometric functions. PiperOrigin-RevId: 377436077 04 June 2021, 04:17:56 UTC
1912bd9 Merge pull request #6673 from hypercubestart:master PiperOrigin-RevId: 377417011 04 June 2021, 01:29:31 UTC
9d677e6 Merge pull request #6875 from gnecula:tf_x32 PiperOrigin-RevId: 377351830 03 June 2021, 19:51:57 UTC
5dc9df3 [JAX] Attach a priority to JAX backends. Use the backend with the highest priority when choosing a default backend. PiperOrigin-RevId: 377351657 03 June 2021, 19:48:24 UTC
98e4e40 Merge pull request #6879 from jakevdp:fix-tuple-returns PiperOrigin-RevId: 377324305 03 June 2021, 17:48:22 UTC
21dbe30 BUG: return JAX arrays rather than NumPy arrays in jnp.unravel_index 03 June 2021, 16:15:01 UTC
b2c7ae7 [JAX] Catch all exceptions from backend initialization. PiperOrigin-RevId: 377278098 03 June 2021, 13:49:56 UTC
c7a98b3 Fix a typo in shape checks for associative_scan. Fixes #6884. PiperOrigin-RevId: 377276183 03 June 2021, 13:37:31 UTC
39526a0 Merge pull request #6873 from ROCmSoftwarePlatform:fix_rocm_linalg PiperOrigin-RevId: 377273325 03 June 2021, 13:16:39 UTC
bca3d61 Insert xmap SPMD axes into pjit sharding annotations. This should let us emit good XLA annotations for `xmap(pjit)`. Previously we might have been overestimating the set of replicated mesh dimensions. PiperOrigin-RevId: 377259226 03 June 2021, 11:13:29 UTC
ecab743 Merge pull request #6877 from hawkinsp:tracebacks PiperOrigin-RevId: 377247694 03 June 2021, 09:47:21 UTC
c0d5256 Merge pull request #6880 from skye:cloud_tpu_readme PiperOrigin-RevId: 377246573 03 June 2021, 09:38:30 UTC
9c6a77e Merge pull request #6827 from gnecula:tf_poly2 PiperOrigin-RevId: 377244801 03 June 2021, 09:23:37 UTC
2ccda70 [jax2tf] Improved coverage of shape polymorphism by allowing dimension polynomials. Previously we allowed a dimension variable in lieu of a dimension. Now we allow multivariate dimension polynomials. These polynomials overload addition, subtraction, and multiplication. They also partially support equality and inequality checking. Equality and inequality are supported only when the operation result is the same for all valuations of the variables greater than 0. For example, `a == a`, `a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, `a >= a`. However, the following raise a `core.InconclusiveDimensionOperation`: `a == b`, `a >= 2`. Division is supported only when either there is no remainder or the divisor is a constant. This change allows us to support more general cases of `jnp.reshape(-1)`, such as those used in the internal implementation of `random_gamma`:
```
y = x.reshape((2, -1))
z = ... y ...
return z.reshape(x.shape)
```
03 June 2021, 07:58:06 UTC
1fc51ef Merge pull request #6883 from gnecula:tf_x32_more PiperOrigin-RevId: 377232091 03 June 2021, 07:40:11 UTC
d03d849 [jax2tf] Fix the 32/64-bit behavior to follow JAX rules. JAX and TensorFlow behave differently with respect to 32/64-bit computations. This PR cleans up the handling of types in jax2tf to ensure that jax2tf follows the same behavior as JAX. This means that f_jax(args) always does the computation with the same precision as jax2tf.convert(f_jax)(args). As a consequence, the result of the conversion may depend on the value of JAX_ENABLE_X64. See README.md for more details. 03 June 2021, 07:12:58 UTC
3e9b13b Expanded the type promotion documentation with a confusing case. I filed this as a bug, but I am assuming that it is not easy to fix, so I also changed the documentation. Bug: 6874 03 June 2021, 06:45:01 UTC
8a65ff8 Update cloud_tpu_colabs README to reflect Cloud TPU VM announcement. 02 June 2021, 23:55:52 UTC
3954ba3 Merge pull request #6878 from zhangqiaorjc:disable_gmres PiperOrigin-RevId: 377143472 02 June 2021, 21:21:55 UTC
ed96e53 Fix incorrect handling of axis_index_groups in parallel primitive fallbacks PiperOrigin-RevId: 377139424 02 June 2021, 21:03:47 UTC
a2209c9 Merge pull request #6608 from tlu7:scipy-special-lpmn PiperOrigin-RevId: 377138669 02 June 2021, 21:00:08 UTC
730ec1b Disable test_gmres_against_scipy due to LLVM changes. 02 June 2021, 20:56:08 UTC
2882286 Add a --jax_traceback_filtering flag to control the traceback filtering mode. Add a new traceback filtering mode that uses __tracebackhide__, and use it in IPython. 02 June 2021, 20:25:37 UTC
a02bf59 Adds associated Legendre functions of the first kind. Co-authored-by: Jake VanderPlas <jakevdp@google.com> 02 June 2021, 18:37:37 UTC
46cc654 Move jax.abstract_arrays to jax._src.abstract_arrays. PiperOrigin-RevId: 377044255 02 June 2021, 13:25:22 UTC
f2eab9c Merge pull request #6834 from k-w-w:patch-3 PiperOrigin-RevId: 377033195 02 June 2021, 12:01:21 UTC
6219b19 Update savedmodel_test.py. Minor change, just to trigger a copybara re-import. 02 June 2021, 10:28:48 UTC
ccc1ba7 Update savedmodel_test.py. I added a cast to `float32`. This is needed because of an obscure bug in JAX (#6874). 02 June 2021, 09:17:03 UTC
ab52da0 Merge pull request #6053 from Jakob-Unfried:master PiperOrigin-RevId: 377007642 02 June 2021, 08:26:16 UTC
5b3d203 Merge branch 'master' into master 02 June 2021, 08:05:47 UTC
9580fd1 [jax2tf] Update limitations to account for tf.math improvements for division. PiperOrigin-RevId: 376997282 02 June 2021, 07:03:59 UTC
8e6101c Merge pull request #6866 from gnecula:tf_pjit PiperOrigin-RevId: 376989780 02 June 2021, 05:50:12 UTC
012da54 add gpu to the rocsolver backend 02 June 2021, 04:03:06 UTC
dca61aa address comments and update documentation 02 June 2021, 02:14:33 UTC
87abab7 Merge pull request #6785 from GregCT:changelist/373551581 PiperOrigin-RevId: 376922795 01 June 2021, 21:47:15 UTC
edd203e Merge pull request #6726 from njunge94:auxiliary_solver_data PiperOrigin-RevId: 376899659 01 June 2021, 19:58:39 UTC
d098391 Fix wrong method name 01 June 2021, 19:42:55 UTC
7db0c56 [JAX] Change how JAX manages XLA platforms.
* Combine the concepts of "platform" and "backend". The main upshot of this is that the tpu_driver backend requires users to write `jit(..., backend="tpu_driver")` if mixing CPU and TPU execution; however, I doubt users are writing that, because mixing CPU and tpu_driver did not work before.
* Initialize all platforms at startup, rather than lazily initializing platforms on demand. This makes it easy to do things like "list the available platforms".
* Don't use two levels of caching. Cache backends only in xla_bridge.py, not xla_client.py.
PiperOrigin-RevId: 376883261 01 June 2021, 18:44:31 UTC
87332d0 Merge pull request #6861 from ho-oto:patch-1 PiperOrigin-RevId: 376821439 01 June 2021, 13:47:56 UTC
8c1e146 Merge pull request #6864 from gnecula:tf_einsum PiperOrigin-RevId: 376814238 01 June 2021, 12:59:47 UTC
9d2491c Merge pull request #6856 from lkhphuc:patch-1 PiperOrigin-RevId: 376813353 01 June 2021, 12:52:41 UTC
973171b [jax2tf] Add support for pjit. 01 June 2021, 11:32:59 UTC
4fb9715 Merge pull request #6855 from PhilipVinc:logsumexp PiperOrigin-RevId: 376786399 01 June 2021, 09:16:22 UTC
c07d54a [jax2tf] Add shape polymorphism support for jnp.einsum. The main problem was that jnp.einsum uses opt_einsum.contract_path to parse the specification string and compute the order of the contractions. This function needs the sizes of the operands and intermediate results, and fails if some dimensions are polymorphic. The (partial) solution here is to replace the operands with jax.ShapeDtypeStruct values that use a fixed size for all dimension variables, then call opt_einsum.contract_path and use that result only if there is a single contraction. We abort if there are multiple contractions. This behavior is sound: with multiple contractions, the chosen order might differ for different dimension sizes. 31 May 2021, 16:06:15 UTC
c0c8e0d make logsumexp work with complex numbers 31 May 2021, 14:01:57 UTC
160c3e9 rename v to vh 31 May 2021, 13:18:24 UTC
e0f285f Merge pull request #6839 from jakevdp:reshape-doc PiperOrigin-RevId: 376645137 31 May 2021, 09:30:04 UTC
f30a36d Typo. Update Common_Gotchas_in_JAX.md 29 May 2021, 16:22:40 UTC
44b1791 Copybara import of the project: -- 8226dfc8a4974b4c8031ee267fa5327e778140ee by Nicholas Junge <nicholas.junge@web.de>: Handle negative values for list-like sections in jnp.split PiperOrigin-RevId: 376302305 28 May 2021, 03:33:49 UTC
5065e1b Add missing typing.Optional type annotations to function parameters. PiperOrigin-RevId: 376300297 28 May 2021, 03:10:23 UTC
dded0e3 DOC: add notes to jax.numpy docstrings about returning copies rather than views 28 May 2021, 01:05:45 UTC
44fcd71 Merge pull request #6851 from njunge94:split-fix PiperOrigin-RevId: 376286728 28 May 2021, 00:55:04 UTC
8226dfc Handle negative values for list-like sections in jnp.split 27 May 2021, 16:25:18 UTC
4ad332e Merge pull request #6841 from colemanliyah:fix_fingerprint PiperOrigin-RevId: 376023432 26 May 2021, 21:02:05 UTC
b68f2c9 fixed fingerprint debugging message to be compatible with current min jaxlib version 26 May 2021, 19:03:28 UTC
7150a10 Merge pull request #6835 from jakevdp:sharp-bits PiperOrigin-RevId: 375988125 26 May 2021, 18:21:44 UTC
9deeb73 Merge pull request #6836 from jakevdp:numpy-doc PiperOrigin-RevId: 375987921 26 May 2021, 18:17:57 UTC
0da0caa Merge pull request #6840 from hawkinsp:ci PiperOrigin-RevId: 375987148 26 May 2021, 18:14:28 UTC
f07ccf0 Use short tracebacks in CI builds. Useful information is often hard to see in the GitHub UI with pytest's default traceback verbosity. 26 May 2021, 17:43:38 UTC
d844609 DOC: add section to Sharp Bits discussing implicit list conversions 26 May 2021, 16:03:42 UTC
ba422f2 Merge pull request #6825 from colemanliyah:master PiperOrigin-RevId: 375955559 26 May 2021, 16:00:20 UTC
03a1ee9 Update Jax linesearch to behave more like Scipy 26 May 2021, 11:49:56 UTC
62603fd Copybara import of the project:
-- 746a232632652233f649b15d94f3ed2fd0ccc1fb by George Necula <gcnecula@gmail.com>:
[jax2tf] Updates known limitations. This PR fixes several issues:
* It updates the documentation of the known limitations.
* Increases the numerical tolerance for conv_general_dilated on GPU, to address test flakiness.
* Adds a workaround for a TF bug that results in a crash when trying to extract the optimized HLO.
-- 4302101aed30a2c7625a2dd5acbe1ca17f9540e4 by George Necula <gcnecula@gmail.com>:
Added limitation for dot_general on GPU
-- 207f66a970b7f596e1b265c7aa91fa56e27e7d51 by George Necula <gcnecula@gmail.com>:
Added limitation for dot_general on GPU
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6837 from gnecula:tf_adjust_lim 207f66a970b7f596e1b265c7aa91fa56e27e7d51
PiperOrigin-RevId: 375910042 26 May 2021, 11:03:07 UTC
ada35ae [jax2tf] Update limitations to account for tf.math.igamma improvements PiperOrigin-RevId: 375888440 26 May 2021, 08:14:20 UTC
80a310f DOC: add note about array views in numpy docs 25 May 2021, 23:54:28 UTC
00e0302 Enable custom gradients SavedModel option in test 25 May 2021, 23:07:09 UTC
3b973ac Merge pull request #6822 from hawkinsp:take PiperOrigin-RevId: 375794394 25 May 2021, 21:22:07 UTC
0308527 Add auxiliary data support in custom_linear_solve 25 May 2021, 16:00:46 UTC
369ca13 add fingerprint to debugging log 24 May 2021, 23:31:25 UTC
d81a13c Merge pull request #6826 from akbir:update_bazel_mac_arm PiperOrigin-RevId: 375548691 24 May 2021, 20:46:21 UTC
dc81610 updated to official bazel.4.1.0 24 May 2021, 20:40:08 UTC
0573169 Merge pull request #6821 from hawkinsp:versions PiperOrigin-RevId: 375499903 24 May 2021, 17:09:14 UTC
dacf31f Check for NumPy and SciPy versions during jaxlib builds. 24 May 2021, 16:39:37 UTC
c2b0f72 Fix handling of empty dimensions in jnp.take(). 24 May 2021, 15:59:41 UTC
0702954 Merge pull request #6820 from hawkinsp:xla PiperOrigin-RevId: 375476354 24 May 2021, 15:03:28 UTC
a64d685 [jax2tf] Cleanup limitations for `rev` in light of improvements in TensorFlow. PiperOrigin-RevId: 375475438 24 May 2021, 14:56:59 UTC
e87173e Update XLA. 24 May 2021, 14:43:40 UTC
6743d77 Merge pull request #6819 from lgeiger:replace-pow PiperOrigin-RevId: 375461396 24 May 2021, 13:05:00 UTC
3a2e80e Replace `pow()` with `sqrt()` or `square()` where possible 24 May 2021, 09:43:35 UTC
7ea7cea Merge pull request #6816 from gnecula:bfloat16_random PiperOrigin-RevId: 375417596 24 May 2021, 05:58:35 UTC
89d208b Merge pull request #6807 from lgeiger:reuse-jvp-ans PiperOrigin-RevId: 375366940 23 May 2021, 18:13:57 UTC
70f0110 Fix dtypes.issubdtype when called with "bfloat16" (as a string). Fixes: #6813 23 May 2021, 16:32:45 UTC
74638a4 [jax2tf] Improve conversion of sign and abs, to account for TF limitations PiperOrigin-RevId: 375274010 22 May 2021, 18:03:01 UTC
6d77b9f [JAX:TFRT:CPU] Fix TfrtCpuClient::BufferFromHostBuffer bug when shape has non-default layout (e.g., from TPU). PiperOrigin-RevId: 375165974 21 May 2021, 21:36:35 UTC
56e9f7c Merge pull request #5813 from google:ad-oob-docs PiperOrigin-RevId: 375140429 21 May 2021, 19:30:33 UTC
b04d0c7 sync common gotchas notebook 21 May 2021, 19:24:28 UTC
8c0aa38 add note on autodiff of OOB indexing 21 May 2021, 19:23:33 UTC
b7f9189 Merge pull request #6808 from jakevdp:fix-broken-link PiperOrigin-RevId: 375137941 21 May 2021, 19:17:50 UTC
b624731 Fix broken link in custom derivatives notebook 21 May 2021, 18:32:21 UTC
387d629 Merge pull request #6806 from jakevdp:doc-cpu-warning PiperOrigin-RevId: 375116363 21 May 2021, 17:41:00 UTC
971eb86 Reduce redundant calculations of `tan`, `erfc` and `rsqrt` jvp 21 May 2021, 17:36:42 UTC
ff7c2af DOC: directly silence GPU warning 21 May 2021, 17:03:50 UTC
cbb2742 Merge pull request #6802 from lgeiger:np-resize PiperOrigin-RevId: 375107309 21 May 2021, 17:01:19 UTC
f523987 Merge pull request #6798 from saeta:bug-6791 PiperOrigin-RevId: 375098090 21 May 2021, 16:17:39 UTC
4905aae Merge pull request #6804 from gnecula:tf_cumsum PiperOrigin-RevId: 375097262 21 May 2021, 16:10:17 UTC