Revision Message Commit Date
1a8a8a5 Fix example in `pjit` docstring 22 September 2022, 19:55:55 UTC
ba557d5 Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.". See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details. PiperOrigin-RevId: 476167538 22 September 2022, 19:27:19 UTC
d52de20 Disable tests that timeout in debug mode in CI PiperOrigin-RevId: 476157051 22 September 2022, 18:44:56 UTC
a157982 Make `jit(f).lower(*args)` go via lower_sharding_computation when `jax_array` is enabled. PiperOrigin-RevId: 476148608 22 September 2022, 18:13:33 UTC
640e15f Don't tuple arguments passed to XLA:CPU This is not needed and tuples are being avoided when possible for new code. This is tested by CPPJitTest.test_jit_with_many_args_works in jax/tests:api_test_cpu PiperOrigin-RevId: 476032228 22 September 2022, 08:29:14 UTC
405a231 Implement pjit fast path in cpp for jax.Array inputs PiperOrigin-RevId: 475988677 22 September 2022, 03:18:18 UTC
52476d1 Add addressable_data to Array (similar to GDA) to aid in transition and also in auto spmd partitioner mode, always convert to MeshPspecSharding. PiperOrigin-RevId: 475972534 22 September 2022, 01:19:35 UTC
a09ef8a Temporarily skip LaxBackedNumpyTests.testUnwrap on gpu to unblock jaxlib build PiperOrigin-RevId: 475970440 22 September 2022, 01:06:02 UTC
d41fa29 Merge pull request #12370 from jakevdp:lax-sort-overload PiperOrigin-RevId: 475907384 21 September 2022, 20:26:01 UTC
62d1783 Merge pull request #12451 from jakevdp:array-declarations PiperOrigin-RevId: 475906903 21 September 2022, 20:19:52 UTC
2dde633 [typing] add class-level declarations of Array members. This fixes some pytype errors associated with the changes in #12421 21 September 2022, 19:51:32 UTC
541aadc [XLA:GPU] Allow simplifying lowering-precision-conversions by default This might lead to the output having higher precision than specified by HLO. PiperOrigin-RevId: 475889141 21 September 2022, 19:04:45 UTC
b9e3843 Add many args benchmark for jax.Array PiperOrigin-RevId: 475853211 21 September 2022, 16:54:51 UTC
c7f2712 Flip default value of jax_unique_mhlo_module_names to False. This should help avoid unnecessary cache misses. PiperOrigin-RevId: 475852954 21 September 2022, 16:48:01 UTC
d0e1c3e Disable tests under sanitizers that are timing out in CI. PiperOrigin-RevId: 475839926 21 September 2022, 15:50:55 UTC
fd90f40 Merge pull request #12443 from cloudhan:fix-mlir-chlo-stablehlo-symbols PiperOrigin-RevId: 475808753 21 September 2022, 13:12:44 UTC
dc414c8 Switch from experimental_strict_action_env to incompatible_strict_action_env to avoid deprecation warning 21 September 2022, 09:26:21 UTC
3fa2c93 Fix linker error caused by chlo and stablehlo symbols not being exported in the MLIR DLL 21 September 2022, 09:26:21 UTC
6183727 Update pjit_test to skip GDA tests when Array is enabled. PiperOrigin-RevId: 475684445 20 September 2022, 23:38:43 UTC
310bcd5 Merge pull request #12389 from LenaMartens:check-while-2 PiperOrigin-RevId: 475606527 20 September 2022, 18:23:42 UTC
018e700 Checkify: support batched while. 20 September 2022, 16:59:46 UTC
60c5b32 Merge pull request #12418 from jakevdp:jep-update PiperOrigin-RevId: 475582299 20 September 2022, 16:59:22 UTC
e855a9c Merge pull request #12428 from jakevdp:tracer-methods PiperOrigin-RevId: 475580916 20 September 2022, 16:52:56 UTC
fce1099 Update JEP-12049 implementation discussion 20 September 2022, 16:44:29 UTC
7469804 Tracer: add missing __round__ and __reversed__ methods 20 September 2022, 16:09:23 UTC
fc2902c Make the gda and xmap sharding check work generally by checking the OpSharding protos. PiperOrigin-RevId: 475560097 20 September 2022, 15:24:47 UTC
24bc153 Merge pull request #12425 from sharadmv:vis-pmap PiperOrigin-RevId: 475449377 20 September 2022, 03:13:21 UTC
0276a6e Add support for pmap sharding 20 September 2022, 02:29:44 UTC
09d3ee1 Merge pull request #12424 from sharadmv:vis-fixes PiperOrigin-RevId: 475442910 20 September 2022, 02:27:19 UTC
f825a3c Limit console width for visualize_sharding 20 September 2022, 01:41:45 UTC
e41e8d9 Only copy_to_device if the indices match. Otherwise, reshard the array if it's uncommitted. This is important when you have 1 process per device. PiperOrigin-RevId: 475418561 19 September 2022, 23:59:14 UTC
441f400 Merge pull request #12386 from sharadmv:viz_sharding PiperOrigin-RevId: 475387460 19 September 2022, 21:36:21 UTC
7ffe16b [typing] overloaded type declaration for variadic lax.sort 19 September 2022, 20:40:28 UTC
2d8b228 Add function to visualize `Sharding`s 19 September 2022, 20:27:08 UTC
2c315b3 Merge pull request #12422 from jakevdp:fix-doc-build PiperOrigin-RevId: 475370004 19 September 2022, 20:26:04 UTC
65cb0ff Merge pull request #12419 from jakevdp:kill-subprocs PiperOrigin-RevId: 475368869 19 September 2022, 20:19:42 UTC
be65694 docs: avoid deprecated matplotlib axis creation 19 September 2022, 19:55:18 UTC
2936c8a multiprocess_gpu_test: kill open subprocesses to avoid warning 19 September 2022, 19:31:10 UTC
aa2f898 Merge pull request #12414 from hawkinsp:xla PiperOrigin-RevId: 475271732 19 September 2022, 13:25:42 UTC
bec9d2e Bump XLA version. 19 September 2022, 13:17:46 UTC
a24726d Remove fast_path_args from Array and add `id` checks to Sharding's `__eq__` method as a fast shortcut. Also the C++ pjit path should help optimize the dispatch path. PiperOrigin-RevId: 475163903 18 September 2022, 22:35:49 UTC
9d8363a Fix the bug where the indices returned from `_get_input_metadata` had length equal to the number of global devices when they should have had length equal to the number of local devices. `shard_args` only deals with local devices and indices. Also, enable multihost pjit_test.py with Array. PiperOrigin-RevId: 475044692 17 September 2022, 20:29:40 UTC
590b5b5 Add `Array` counterparts to the serialization_test.py and disable the GDA tests if jax_array is enabled. PiperOrigin-RevId: 474944400 17 September 2022, 01:37:50 UTC
e6bdb00 Skip remote_transfer_test because Array does not have the xla_shape method since it's deprecated. PiperOrigin-RevId: 474913967 16 September 2022, 22:26:12 UTC
dce93e4 Silence some pytype errors. PiperOrigin-RevId: 474898625 16 September 2022, 21:11:58 UTC
eec1b4a Set the sharding of uncommitted single device sharding Arrays correctly and fix some miscellaneous tests with Array too. Enable pjit_test and xmap_test with Array too (all of them are mechanical changes). PiperOrigin-RevId: 474858389 16 September 2022, 18:16:27 UTC
e010ae7 Pass device_assignment to ShardingContext instead of first_sharding which contains partitioning of an input too. It does not make sense to pass how an input is partitioned to ShardingContext because you can have `n` inputs all partitioned in a different way but all of them should have the same device_assignment. This follows SPMDAxisContext too. PiperOrigin-RevId: 474808207 16 September 2022, 14:18:50 UTC
45e48b3 Mark multiprocess_gpu_test as manual to skip it in OSS PiperOrigin-RevId: 474806518 16 September 2022, 14:08:08 UTC
35a5012 Fix and add test for weakref_lru_cache asan issue. PiperOrigin-RevId: 474684516 15 September 2022, 23:31:14 UTC
2a7dcb8 Merge pull request #12383 from jakevdp:valid-shape PiperOrigin-RevId: 474681872 15 September 2022, 23:17:15 UTC
a423dc7 tests: fix is_valid_shape() function 15 September 2022, 22:28:55 UTC
a7f4cb0 Merge pull request #12143 from wonhyeongseo:multinomial PiperOrigin-RevId: 474659934 15 September 2022, 21:37:38 UTC
28741b8 Some miscellaneous changes to make tests pass when jax.Array is enabled by default. 1. Add `device_buffer` and `device_buffers` fields to Array as a backwards compatible change for DA and SDA. 2. Support PartitionSpecs as input to in_axis_resources and out_axis_resources when jax_array is enabled as a backwards compatible change since all user code uses this currently. Create a MeshPspecSharding internally. 3. Some test changes to make them pass PiperOrigin-RevId: 474642889 15 September 2022, 20:27:40 UTC
3f6eb40 JAX implementation of scipy.stats.multinomial pmf & logpmf Co-authored-by: harryjulian <harry.julian@peak.ai> 15 September 2022, 20:21:44 UTC
311f85e Adjust tolerance for lu test in preparation for XLA switch from cublas to cublasLt PiperOrigin-RevId: 474618770 15 September 2022, 19:01:26 UTC
fa86ce4 Merge pull request #12369 from jakevdp:fix-dtype-case PiperOrigin-RevId: 474616448 15 September 2022, 19:01:12 UTC
8304825 Remove the optional OpSharding typehint from XLACompatibleSharding PiperOrigin-RevId: 474598756 15 September 2022, 17:40:19 UTC
60ec541 Enable the debugging_primitives pjit(xmap) case. Also don't check for sharding mismatch when the array is not committed. Check the device assignment only for committed arrays. PiperOrigin-RevId: 474598597 15 September 2022, 17:34:04 UTC
9791199 Merge pull request #12361 from sharadmv:for-unroll PiperOrigin-RevId: 474506368 15 September 2022, 09:04:29 UTC
2fb8695 [jax] Update expected error message in preparation for new XLA:GPU runtime PiperOrigin-RevId: 474458551 15 September 2022, 03:19:38 UTC
08c5753 Implement unrolling for for_loop 15 September 2022, 01:32:37 UTC
8a1c5b8 Merge pull request #12371 from jakevdp:scatter-typing PiperOrigin-RevId: 474433898 15 September 2022, 00:48:34 UTC
1bc9672 Merge pull request #12372 from jakevdp:array-issubclass PiperOrigin-RevId: 474422608 14 September 2022, 23:44:52 UTC
bdb264e jax.Array: add issubclass test analogous to existing DeviceArray test. 14 September 2022, 23:29:38 UTC
ed06838 [typing] clear up logic in scatter_update Static type checkers do not parse deeply enough to know that by line 182 bucket_size cannot be None; branching on an explicit None check is easier to follow (even for human readers) 14 September 2022, 22:32:52 UTC
fa760bf Merge pull request #12368 from mattjj:relu-comment PiperOrigin-RevId: 474406048 14 September 2022, 22:28:45 UTC
5829c6a Change case of typing.Dtype -> typing.DType This follows the convention used in numpy.typing.DType. 14 September 2022, 22:03:55 UTC
f3710ae add paper link about grad-relu-at-zero 14 September 2022, 21:16:01 UTC
1338864 Change TensorFlow to depend on StableHLO instead of vendoring it This makes handling TensorFlow's dependency on StableHLO consistent with handling other TensorFlow dependencies. For example, LLVM goes into //third_party/llvm, and so should StableHLO. Users of tensorflow/tensorflow (e.g. JAX) need to change Bazel builds, replacing `@org_tensorflow//tensorflow/compiler/xla/mlir_hlo/stablehlo` with `@stablehlo//`. Nothing else changes, e.g. C++ includes, C++ usage, Python bindings and Python usage all stay the same. Example: https://github.com/google/jax/pull/12174. Users of tensorflow/mlir-hlo are unaffected thanks to the awesome power of Copybara. There are minor changes in the StableHLO part of MLIR-HLO caused by the fact that the StableHLO repository and the vendored StableHLO inside tensorflow/tensorflow have diverged a little bit (e.g. Markdown formatting is slightly different between repositories because I didn't have the time to propagate these changes) and now they have been forced to converge, but these changes won't affect the behavior of either CMake or Bazel builds of MLIR-HLO. Moving forward, contributions to StableHLO will only be possible through openxla/stablehlo. This is because tensorflow/tensorflow no longer vendors StableHLO. (tensorflow/mlir-hlo still does, but it's read-only). PiperOrigin-RevId: 474360128 14 September 2022, 19:25:55 UTC
13a7034 Internal change PiperOrigin-RevId: 474331907 14 September 2022, 17:39:38 UTC
0a5d8e8 Make nested xmap work with Arrays and GDA (in single process). PiperOrigin-RevId: 474323667 14 September 2022, 17:09:27 UTC
c2ce0db Add a jit decorator to comparison operators. We had previously omitted jit because of https://github.com/google/jax/issues/6713, but concrete remat no longer exists. PiperOrigin-RevId: 474309905 14 September 2022, 16:19:22 UTC
71b0968 skip some for_loop test cases on gpu due to flaky timeouts PiperOrigin-RevId: 474168747 14 September 2022, 00:51:55 UTC
ef02568 Merge pull request #12347 from mattjj:djax-einsum PiperOrigin-RevId: 474156489 13 September 2022, 23:45:23 UTC
b27d8c1 Merge pull request #12342 from jakevdp:typing-test PiperOrigin-RevId: 474154750 13 September 2022, 23:36:49 UTC
2547e81 Use C++ Array in pmap path and move PmapSharding to cpp PiperOrigin-RevId: 474151089 13 September 2022, 23:19:18 UTC
49a6034 [dynamic-shapes] enable basic einsum support, following jax2tf shape polys 13 September 2022, 23:06:33 UTC
da90234 Delete soft_pmap as it has no users. Please use `pjit` or `xmap` if you do want soft_pmap. `jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided. PiperOrigin-RevId: 474145090 13 September 2022, 22:52:10 UTC
dc7db8d Merge pull request #12346 from jakevdp:jax-array-test PiperOrigin-RevId: 474142815 13 September 2022, 22:41:49 UTC
b7bc095 Merge pull request #12149 from sharadmv:for-nesting PiperOrigin-RevId: 474122048 13 September 2022, 21:18:42 UTC
8eb44fd jax_array_test: set config once & fix X64 failure 13 September 2022, 21:06:38 UTC
f26f1e8 Add support for closing over `Ref`s in nested for loops 13 September 2022, 20:32:44 UTC
b3c31eb Add typing_test.py 13 September 2022, 19:43:51 UTC
ad326b9 Use cases_from_list to subsample enumerated cases in for_loop_test PiperOrigin-RevId: 474093596 13 September 2022, 19:34:10 UTC
a2930e6 Merge pull request #11859 from jakevdp:jep-type-annotation PiperOrigin-RevId: 474062166 13 September 2022, 17:42:32 UTC
dc4922f Bump shards on for_loop_test PiperOrigin-RevId: 474038276 13 September 2022, 16:30:04 UTC
ee52173 Merge pull request #12300 from jakevdp:typing-simple PiperOrigin-RevId: 474037430 13 September 2022, 16:23:32 UTC
358363e JEP 12049: Type Annotation Roadmap 13 September 2022, 16:14:48 UTC
ec24005 Merge pull request #12334 from jakevdp:fix-ensure-index PiperOrigin-RevId: 473961635 13 September 2022, 09:16:21 UTC
c491aaa Merge pull request #12213 from jakevdp:environment-info PiperOrigin-RevId: 473866407 12 September 2022, 22:55:04 UTC
0fb462e Add jax.print_environment_info() 12 September 2022, 22:39:33 UTC
0063661 jax.test_util: add capture_stdout context manager 12 September 2022, 22:21:52 UTC
e5725f1 Split for_loop_test out of lax_control_flow_test PiperOrigin-RevId: 473848277 12 September 2022, 21:46:07 UTC
ef48533 Add a MLIR constant handler for GDA. PiperOrigin-RevId: 473848126 12 September 2022, 21:40:04 UTC
b439d76 Merge pull request #12332 from jakevdp:fix-index-take PiperOrigin-RevId: 473822050 12 September 2022, 19:54:27 UTC
45b71e5 ensure_index: raise better error for traced inputs 12 September 2022, 19:28:14 UTC
34eae3d jax.lax: ensure GatherDimensionNumbers contains tuples for hashability 12 September 2022, 19:10:17 UTC
3243e23 [sparse] Lower batch-mode bcoo_dot_general to cusparseSpMM. PiperOrigin-RevId: 473777597 12 September 2022, 17:09:41 UTC
cc72a20 use jax._src.typing in lax.py & a few other places 12 September 2022, 16:08:13 UTC
4fed097 jax._src/typing: add basic types 12 September 2022, 16:07:56 UTC