swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f

Revision Message Commit Date
bdc3fa0 Remove HLO_Pred as one of the options for HLO_DimensionValue Not sure what was the original idea behind this, but it's not evident to me how a boolean can be a valid element type of HLO_DimensionTensor. Let's see if this breaks anything. PiperOrigin-RevId: 447366636 09 May 2022, 03:09:39 UTC
78fafc5 Merge pull request #10612 from ROCmSoftwarePlatform:rocm_build_fixes PiperOrigin-RevId: 447111173 07 May 2022, 01:28:05 UTC
1093559 [linalg] Add matmul precision scope for svd. PiperOrigin-RevId: 447095391 06 May 2022, 23:33:50 UTC
bbdcec8 Fixes to enable JAX to build on ROCm. 06 May 2022, 22:57:51 UTC
7251956 Remove unused jaxlib build targets. //jaxlib:lapack is unused, and once it is gone we can merge //jaxlib:mhlo_helpers into //jaxlib. PiperOrigin-RevId: 447084931 06 May 2022, 22:34:14 UTC
244bc8c Merge pull request #10393 from lgeiger:speedup-indexing PiperOrigin-RevId: 447083562 06 May 2022, 22:26:53 UTC
883cf2b Refactor custom call building code in jaxlib to use a helper function. Refactoring only, no functional changes intended. This should fix a jaxlib build issue on Windows: we only have one constructor of layouts, and it explicitly requests an int64 type. Fixes https://github.com/google/jax/issues/10474 PiperOrigin-RevId: 447076192 06 May 2022, 21:51:24 UTC
cec717d Return wrapper type in pytree rule for TF's DictWrapper. Fixes #10586 PiperOrigin-RevId: 447063792 06 May 2022, 20:54:21 UTC
08c3c2e Split CUDA and HIP C++ code in jaxlib into separate directories. PiperOrigin-RevId: 447062506 06 May 2022, 20:48:00 UTC
79f0f6c Revert: Make concatenate allow concatenation on dynamic dimensions Using -1 as a dynamic dimension is an MHLO convention not a JAX-level convention. PiperOrigin-RevId: 447052267 06 May 2022, 20:00:16 UTC
dcfb546 Merge pull request #10607 from jakevdp:jit-faq PiperOrigin-RevId: 447049741 06 May 2022, 19:48:20 UTC
c4836aa DOC: add FAQ entry on jit-compiling methods 06 May 2022, 19:39:12 UTC
452de3f Merge pull request #10604 from google:yashk2810-patch-11 PiperOrigin-RevId: 447025482 06 May 2022, 18:01:39 UTC
40c61cf Update workspace for jaxlib release for testing 06 May 2022, 17:54:38 UTC
562e27d Merge remaining CUDA and ROCM Python code. Completes work started in https://github.com/google/jax/pull/10556 PiperOrigin-RevId: 447005344 06 May 2022, 16:35:01 UTC
ab7a60b Merge pull request #10536 from mattjj:scan-dce PiperOrigin-RevId: 446907899 06 May 2022, 06:03:59 UTC
04e4ffd gate scan dce rule on after_neurips flag 06 May 2022, 05:23:02 UTC
d0863a1 add scan dce rule tests, fix bugs 06 May 2022, 04:27:22 UTC
d57e364 [linalg] Update qdwh to prevent underflow in norm estimation. PiperOrigin-RevId: 446887070 06 May 2022, 03:12:32 UTC
851a159 Merge pull request #10594 from jakevdp:api-util PiperOrigin-RevId: 446865291 06 May 2022, 01:03:13 UTC
5d45458 api_util: make shaped_abstractify respect raise_to_shaped 06 May 2022, 00:20:00 UTC
212edd6 Merge pull request #10578 from sharadmv:enable-cond-effects PiperOrigin-RevId: 446840237 05 May 2022, 23:02:06 UTC
99a08ee Fix check for numpy bool indices 05 May 2022, 22:47:00 UTC
cace686 Simplify isinstance check 05 May 2022, 22:47:00 UTC
167c6a9 Expand constant indexing test to check `slice` 05 May 2022, 22:47:00 UTC
2e681ff Simplify `_normalize_index` 05 May 2022, 22:47:00 UTC
5e2dd9c Add jaxpr test to ensure that no normalization happens for constant indices 05 May 2022, 22:47:00 UTC
0c5b132 Speedup `_expand_bool_indices` when passing basic integer indices 05 May 2022, 22:47:00 UTC
c9d6e76 Do not call `concrete_aval` for basic integer index checks 05 May 2022, 22:47:00 UTC
38cee9f Remove unnecessary `asarray` call 05 May 2022, 22:47:00 UTC
10addfd Prevent unnecessary `expand_dims` 05 May 2022, 22:47:00 UTC
7655392 Speedup index computations for integer indexing 05 May 2022, 22:47:00 UTC
3941e76 Enable effects for cond Co-authored-by: Matthew Johnson <mattjj@google.com> 05 May 2022, 22:26:05 UTC
479b6a6 Merge pull request #10577 from sharadmv:while-cond-effects PiperOrigin-RevId: 446830634 05 May 2022, 22:20:26 UTC
5c76eeb Merge pull request #10593 from mattjj:fix-new-remat-batching PiperOrigin-RevId: 446828856 05 May 2022, 22:12:51 UTC
19b5332 Merge pull request #10566 from jakevdp:bcoo-reconfigure PiperOrigin-RevId: 446827297 05 May 2022, 22:06:35 UTC
bf4c3b6 [sparse] add bcoo_update_layout utility 05 May 2022, 21:53:38 UTC
4b2f595 Use tf.io.gfile.rmtree instead of tf.io.gfile.remove as `remove` does not work on some file systems. PiperOrigin-RevId: 446803505 05 May 2022, 20:39:31 UTC
b92c6b1 fix ad_checkpoint.checkpoint vmap rule 05 May 2022, 20:31:27 UTC
83dee8f Make jaxlib extension libraries Bazel deps of //jaxlib. Previously we depended on various .so files directly so they were pulled into the jaxlib wheel build, but it seems to work to add the libraries in question to //jaxlib and depend on that in the usual way. It appears if a py_library() is used as a data-dependency of another rule, Bazel includes any transitive C++ extension deps, and that's what we want. PiperOrigin-RevId: 446802592 05 May 2022, 20:30:06 UTC
8dac44d [mhlo] Add result type inference for mhlo.pad. PiperOrigin-RevId: 446780637 05 May 2022, 19:03:40 UTC
e169bcf Merge pull request #10590 from jakevdp:gmres-matmul PiperOrigin-RevId: 446779333 05 May 2022, 18:57:55 UTC
bc44b3d Merge pull request #10582 from sharadmv:debug-ad-rule PiperOrigin-RevId: 446764484 05 May 2022, 18:04:10 UTC
4618f9c Consolidate hip_prng and cuda_prng. The Python code in jaxlib to build AMD HIP (ROCM) and NVIDIA CUDA kernels is almost identical. Share that Python code rather than duplicating it. This change only updates the prng kernels; the idea would be to follow it with similar changes consolidating the other Python code in jaxlib between CUDA and HIP. PiperOrigin-RevId: 446761784 05 May 2022, 17:55:29 UTC
30fd817 jax.scipy.sparse.linalg: support sparse matrices as operators 05 May 2022, 17:33:08 UTC
4296cb1 Enable AD rules for `debug_print` 05 May 2022, 17:11:58 UTC
931bf36 [JAX] Split the "gpu" platform in internal JAX usage into separate "cuda" and "rocm" platforms. In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kinds of rules to coexist. [PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device. PiperOrigin-RevId: 446737518 05 May 2022, 16:33:06 UTC
ef18220 Updated documentation for approx_*_k to specify when the top k is guaranteed to be ordered. PiperOrigin-RevId: 446708610 05 May 2022, 14:12:15 UTC
a47c990 Merge pull request #10581 from sharadmv:debug-batching-rule PiperOrigin-RevId: 446640686 05 May 2022, 06:47:25 UTC
c1a8d7f Enable batching rule for debug_print Co-authored-by: Matthew Johnson <mattjj@google.com> 05 May 2022, 06:32:36 UTC
4703766 Enable effects for cond in while loop 05 May 2022, 05:59:41 UTC
24eb7d8 Merge pull request #10569 from sharadmv:enable-scan-ordered-effect PiperOrigin-RevId: 446628368 05 May 2022, 05:07:24 UTC
7b53abf Merge pull request #10579 from jakevdp:fix-readthedocs PiperOrigin-RevId: 446628291 05 May 2022, 05:02:08 UTC
d2e3d42 Updates values after jax and jaxlib 0.3.10 release PiperOrigin-RevId: 446623299 05 May 2022, 04:17:37 UTC
dc42d7b Enable print in while loops/scan 05 May 2022, 01:55:47 UTC
dd71e29 DOC: remove mentions of units 04 May 2022, 23:08:44 UTC
ceb3222 Merge pull request #10568 from jakevdp:sparse-doc PiperOrigin-RevId: 446573409 04 May 2022, 23:03:30 UTC
e9ad622 Merge pull request #10548 from pschuh:opt-barrier-hide PiperOrigin-RevId: 446573378 04 May 2022, 22:58:07 UTC
8ee8a75 Merge pull request #10490 from ajcr:add_jax_scipy_linalg_funm PiperOrigin-RevId: 446570023 04 May 2022, 22:42:39 UTC
38ce6d0 Update TF commit for release PiperOrigin-RevId: 446555288 04 May 2022, 21:42:50 UTC
7297115 Merge pull request #10546 from jakevdp:unravel-indices PiperOrigin-RevId: 446553390 04 May 2022, 21:37:25 UTC
12743e7 If optimization_barrier is not available, disable the jax2tf tests. 04 May 2022, 21:19:36 UTC
3c2d2b2 jnp.unravel_index: improve test coverage 04 May 2022, 20:32:16 UTC
58320e2 jnp.unravel_index: avoid overflow for large dimension sizes 04 May 2022, 20:12:29 UTC
a8c6742 Restrict Bazel visibility of //jaxlib:gpu_support PiperOrigin-RevId: 446524414 04 May 2022, 19:34:08 UTC
9e371c7 [sparse] incremental improvement to docs 04 May 2022, 19:25:16 UTC
03c8020 Merge pull request #10532 from mattjj:remove-units-final PiperOrigin-RevId: 446519086 04 May 2022, 19:09:51 UTC
8c6f916 Reapply https://github.com/google/jax/pull/10482 now the TF PR is ready for submission. Make a couple of style cleanups in passing. PiperOrigin-RevId: 446506301 04 May 2022, 18:18:58 UTC
9cd55a2 [remove-units] remove units 04 May 2022, 17:58:56 UTC
97b7fd7 Merge pull request #10564 from mattjj:gate-xla-version PiperOrigin-RevId: 446498170 04 May 2022, 17:51:01 UTC
0c5864a add xla_client._version checks for mhlo.ConstOp signature fix break from 0cf08d0c6841332240cae873e4b4cf9a9b313373 04 May 2022, 16:54:06 UTC
0cf08d0 Integrate LLVM at llvm/llvm-project@46cc04de341b Updates LLVM usage to match [46cc04de341b](https://github.com/llvm/llvm-project/commit/46cc04de341b) PiperOrigin-RevId: 446430294 04 May 2022, 12:31:41 UTC
5c838d2 Add an option when lowering to not remove unused arguments. This way, code using the output xla executable does not need to also drop the unused arguments, simplifying downstream code. PiperOrigin-RevId: 446391558 04 May 2022, 08:22:14 UTC
35c0862 [mhlo] Add result type inference for mhlo.dynamic-slice. PiperOrigin-RevId: 446366018 04 May 2022, 04:58:29 UTC
2498af3 Merge pull request #10557 from sharadmv:dont-leak PiperOrigin-RevId: 446361367 04 May 2022, 04:17:55 UTC
78a4e30 Don't leak the keepalive in debug_callback lowering 04 May 2022, 03:53:30 UTC
a9c0a97 Merge pull request #10553 from sharadmv:debug-print PiperOrigin-RevId: 446357443 04 May 2022, 03:45:54 UTC
2bb7e97 Merge pull request #10549 from sharadmv:attach-executable PiperOrigin-RevId: 446355430 04 May 2022, 03:29:22 UTC
7a641de Add initial debug print implementation Co-authored-by: Matthew Johnson <mattjj@google.com> 04 May 2022, 03:16:05 UTC
cfe9256 Fix the release bug where jaxlib version was not right. PiperOrigin-RevId: 446352145 04 May 2022, 03:08:24 UTC
ef982cf Attach keepalive to executable 04 May 2022, 03:08:03 UTC
c823025 Fix the checking of all lockfiles written. Previously the check was dependent on `gfile.listdir(lockfiles_dir)` in process 0. But other processes could run ahead and delete their lockfiles once done causing the check on process 0 to fail. Now process 0 keeps a count after a successful write and then checks against that count. PiperOrigin-RevId: 446343403 04 May 2022, 02:04:39 UTC
3bfcabb Merge pull request #10295 from sharadmv:jaxpr-effect-lowering PiperOrigin-RevId: 446317877 03 May 2022, 23:30:41 UTC
8031eee Add in runtime tokens for effectful jaxprs 03 May 2022, 22:55:07 UTC
ff1a3c4 jax and jaxlib release PiperOrigin-RevId: 446295827 03 May 2022, 21:52:40 UTC
c6343dd jax.scipy.linalg.schur: error on 16-bit floats Fixes https://github.com/google/jax/issues/10530 PiperOrigin-RevId: 446279906 03 May 2022, 20:47:44 UTC
37ea024 Temporarily revert: 5d742fc0c3313798ba01b45bc13f973ef950e2e7 by lipracer <lipracer@gmail.com>: Compatible with RngBitGeneratorOp builder modifications I will reapply this PR when the TF PR is merged. PiperOrigin-RevId: 446250797 03 May 2022, 18:46:44 UTC
02b9909 Make `_one_replica_buffer_indices` callable by itself. PiperOrigin-RevId: 446131175 03 May 2022, 07:57:13 UTC
f7296a4 [linalg] Update LU decomposition. PiperOrigin-RevId: 446122667 03 May 2022, 06:53:03 UTC
31be621 Update the serialization example PiperOrigin-RevId: 446121048 03 May 2022, 06:39:21 UTC
c721544 better error message in mesh_utils.py PiperOrigin-RevId: 446119369 03 May 2022, 06:24:56 UTC
de7a872 Take temp_checkpoint_dir and final_checkpoint_dir as the arguments to serialize instead of the __init__. This is because this manager will be defined at the top where the directories may not yet be known. PiperOrigin-RevId: 446104174 03 May 2022, 04:22:43 UTC
888e5c6 Update the version numbers after JAX release. PiperOrigin-RevId: 446092433 03 May 2022, 02:51:11 UTC
634f58c Enable a number of tests on GPU. In particular, pjit/xmap work on CPU these days. PiperOrigin-RevId: 446085110 03 May 2022, 01:57:27 UTC
b7293d5 Add fully asynchronous checkpointing. This will allow the training to proceed forward when the checkpoint is being committed. PiperOrigin-RevId: 446083057 03 May 2022, 01:43:54 UTC
939233e Merge pull request #10522 from mattjj:remove-units-partial-eval PiperOrigin-RevId: 446023568 02 May 2022, 21:02:11 UTC
372371c Add jax.scipy.linalg.funm 02 May 2022, 20:46:41 UTC
11ad045 [remove-units] remove units from partial_eval.py After last week's changes, units are no longer traced or introduced into jaxprs in any way, so we don't need to use them in partial evaluation. (Also there are some unrelated removals of dead code in maps.py.) 02 May 2022, 20:43:27 UTC
44006c7 Merge pull request #10526 from hawkinsp:cudasm PiperOrigin-RevId: 446017399 02 May 2022, 20:37:30 UTC
9fb9e12 Don't include PTX for older GPU generations. See: https://github.com/tensorflow/tensorflow/pull/55613 For a CUDA build at head with the default compute capabilities, reduces wheel size from 141MB to 112MB. Don't redundantly specify default compute capabilities in .bazelrc and in build.py. 02 May 2022, 20:27:37 UTC