https://github.com/google/jax

Revision Author Date Message Commit Date
c414f25 Fix a type annotation. PiperOrigin-RevId: 595897224 05 January 2024, 06:50:39 UTC
1ed6a81 Fix a type annotation typo. PiperOrigin-RevId: 595891648 05 January 2024, 06:16:11 UTC
059fdaf Update XLA dependency to use revision http://github.com/openxla/xla/commit/8c7233c8bb54832d5993fd6430b2658dd0e3f418. PiperOrigin-RevId: 595874816 05 January 2024, 04:45:49 UTC
d6a4723 [Pallas/Mosaic] Change Mosaic semaphores to be MemRef types This is part 1 of a change that enables allocating arrays of semaphores. It does not add any new public facing functionality and only changes how semaphores are represented in Mosaic. PiperOrigin-RevId: 595848688 05 January 2024, 01:56:30 UTC
6dd5a69 Merge pull request #19201 from jakevdp:sort-kwargs PiperOrigin-RevId: 595818057 04 January 2024, 23:34:15 UTC
919233f Merge pull request #19204 from jakevdp:doc-requirements PiperOrigin-RevId: 595817526 04 January 2024, 23:25:26 UTC
ea66029 Introduce min entry size check for compilation cache. Currently, the persistent compilation cache has a time threshold: an entry is cached only if the compilation time exceeds the threshold. If compilation happens to take a while, but the resulting executable is small, there is nothing that prevents caching. This can result in a large number of small files in the cache. Introduce a size threshold. If the resulting executable's size (after serialization and compression) is less than this threshold, don't cache. This check is in addition to the compilation time check described above. Testing: new unit test, test workload. PiperOrigin-RevId: 595815611 04 January 2024, 23:17:05 UTC
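The two-threshold rule described above can be sketched in plain Python. This is a minimal illustration, not JAX's cache code: `CacheThresholds` and `should_cache` are hypothetical names, and the sketch assumes an entry is written only when it passes both checks (JAX exposes the time check as the `jax_persistent_cache_min_compile_time_secs` option).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CacheThresholds:
    # Hypothetical container for the two limits described in the commit.
    min_compile_time_secs: float
    min_entry_size_bytes: int


def should_cache(compile_time_secs: float, entry_size_bytes: int,
                 limits: CacheThresholds) -> bool:
    """Return True only if an executable passes both checks."""
    if compile_time_secs < limits.min_compile_time_secs:
        return False  # cheap to recompile: not worth a cache entry
    if entry_size_bytes < limits.min_entry_size_bytes:
        return False  # serialized + compressed entry too small: skip it
    return True


limits = CacheThresholds(min_compile_time_secs=1.0, min_entry_size_bytes=4096)
print(should_cache(5.0, 100_000, limits))  # True: slow compile, large entry
print(should_cache(5.0, 512, limits))      # False: slow compile, tiny entry
print(should_cache(0.1, 100_000, limits))  # False: fast compile
```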
ccecdba DOC: remove unnecessary doc build requirement 04 January 2024, 23:06:12 UTC
8b62516 [array api] add stable & descending params to jnp.sort & jnp.argsort 04 January 2024, 22:21:25 UTC
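A short illustration of the new keywords, assuming a JAX release that includes this change (`stable` defaults to `True`):

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2])

# descending=True sorts from largest to smallest.
desc = jnp.sort(x, descending=True)       # values: [3, 2, 1]
# argsort returns the indices that would produce that order.
order = jnp.argsort(x, descending=True)   # indices: [0, 2, 1]
# stable=True guarantees equal elements keep their original relative order.
asc = jnp.sort(x, stable=True)            # values: [1, 2, 3]
```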
8803774 Fix/disable two tests failing in Windows CI. PiperOrigin-RevId: 595793411 04 January 2024, 21:47:32 UTC
3ff4eb4 [XLA:Python] Raise an AttributeError if __cuda_array_interface__ is called on various invalid buffers, rather than a RuntimeError. This makes hasattr(x, "__cuda_array_interface__") fail gracefully. In passing, also move the implementation into py_array.cc, and use an allowlist of supported types rather than a denylist. Fixes https://github.com/google/jax/issues/19134 PiperOrigin-RevId: 595788328 04 January 2024, 21:26:56 UTC
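Why the exception type matters: Python's built-in `hasattr()` treats only `AttributeError` as "attribute absent" and lets every other exception propagate. The classes below are hypothetical stand-ins for buffers before and after this change, not jaxlib types.

```python
class OldStyleBuffer:
    """Hypothetical stand-in: raises RuntimeError, like the old behaviour."""

    @property
    def __cuda_array_interface__(self):
        raise RuntimeError("not a CUDA buffer")


class NewStyleBuffer:
    """Hypothetical stand-in: raises AttributeError, like the fixed behaviour."""

    @property
    def __cuda_array_interface__(self):
        raise AttributeError("not a CUDA buffer")


def probe(obj) -> str:
    try:
        return "present" if hasattr(obj, "__cuda_array_interface__") else "absent"
    except RuntimeError as err:
        # hasattr() swallows AttributeError only; other errors escape.
        return f"crashed: {type(err).__name__}"


print(probe(NewStyleBuffer()))  # absent
print(probe(OldStyleBuffer()))  # crashed: RuntimeError
```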
ebc7af9 Fix typo in pmap docstring Docstring states: > If the pmapped function is called with fewer positional arguments than indicated by **`static_argnums`** then an error is raised. However `static_argnums` is not an argument that exists - I believe this should be corrected to `static_broadcasted_argnums`. PiperOrigin-RevId: 595731210 04 January 2024, 17:50:00 UTC
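A minimal sketch of `static_broadcasted_argnums` in `jax.pmap`, assuming it runs on however many local devices are available (one CPU device is enough). `raise_to` is a hypothetical function; the static argument must be hashable and is passed unchanged to every device.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()


# Argument 1 (`power`) is static: it is not mapped over devices, is
# broadcast to every device as a plain Python value, and must be hashable.
@partial(jax.pmap, static_broadcasted_argnums=1)
def raise_to(x, power):
    return x ** power


# The leading axis of a mapped argument must equal the number of devices.
xs = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3)
out = raise_to(xs, 2)  # shape (n_dev, 3), elementwise squares
```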
326d1d2 jaxlib: avoid external build-time dependency on ml_dtypes Currently, the ml_dtypes C++ sources are included in the set of sources at jaxlib build time. This is unnecessary, and can lead to problematic version skew in some cases (e.g. nightly builds). PiperOrigin-RevId: 595725529 04 January 2024, 17:26:05 UTC
05a9906 Guard call to os.register_at_fork() to unbreak Windows CI. PiperOrigin-RevId: 595708288 04 January 2024, 16:10:32 UTC
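`os.register_at_fork` is only available on POSIX platforms, so an unguarded call fails on Windows. A minimal sketch of such a guard; the hook and the `_fork_count` state are hypothetical, not the code JAX registers.

```python
import os

_fork_count = 0  # hypothetical per-process state


def _reset_in_child() -> None:
    global _fork_count
    _fork_count = 0


def install_fork_hook() -> bool:
    """Register the hook only where the platform supports it."""
    # os.register_at_fork exists only on POSIX systems; on Windows the
    # attribute is missing and an unguarded call raises AttributeError.
    if hasattr(os, "register_at_fork"):
        os.register_at_fork(after_in_child=_reset_in_child)
        return True
    return False


registered = install_fork_hook()
print(registered)  # True on Linux/macOS, False on Windows
```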
d56116c Raise rather than return NotImplementedError in PmapSharding. PiperOrigin-RevId: 595662888 04 January 2024, 12:08:28 UTC
981b670 [Jax] Allow setting the Python traceback frame limit. PiperOrigin-RevId: 595607107 04 January 2024, 07:26:55 UTC
7bba80d Rename batch_dims variable in shard_alike PiperOrigin-RevId: 595578639 04 January 2024, 04:59:39 UTC
9112afb [Pallas/GPU] Add support for jnp.cumsum PiperOrigin-RevId: 595578503 04 January 2024, 04:51:25 UTC
0d3fd53 Update XLA dependency to use revision http://github.com/openxla/xla/commit/257f1592457a8036ea70f6020601018d67e943c5. PiperOrigin-RevId: 595574791 04 January 2024, 04:26:51 UTC
c06c292 Merge pull request #19186 from jakevdp:asarray-copy PiperOrigin-RevId: 595541132 04 January 2024, 01:09:19 UTC
68abe0d Add correct batching rule for shard_alike PiperOrigin-RevId: 595532031 04 January 2024, 00:24:27 UTC
23b9c2a Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies. PiperOrigin-RevId: 595529249 04 January 2024, 00:12:23 UTC
97fc213 [array API] support copy argument to jnp.asarray 03 January 2024, 23:20:27 UTC
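A small sketch of the `copy` argument, assuming a JAX version with this change: `copy=True` requests a fresh buffer, so later mutation of the NumPy source cannot leak into the JAX array.

```python
import numpy as np
import jax.numpy as jnp

host = np.arange(4, dtype=np.float32)

# copy=True asks for a new buffer rather than a possible view of `host`.
dev = jnp.asarray(host, copy=True)

host[0] = 100.0       # mutate the NumPy source afterwards
print(float(dev[0]))  # 0.0: the JAX array is unaffected
```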
f1be301 Merge pull request #19183 from jakevdp:dlpack-warning PiperOrigin-RevId: 595502005 03 January 2024, 22:20:08 UTC
f8b7169 Merge pull request #19182 from jakevdp:ndarray-item PiperOrigin-RevId: 595495441 03 January 2024, 21:57:50 UTC
df4e9c0 DOC: add warning about dlpack and buffer mutation 03 January 2024, 21:31:57 UTC
47e5c81 jnp.ndarray.item(): add args support 03 January 2024, 21:03:47 UTC
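With arguments, `item()` follows NumPy: a tuple of indices addresses one element, and a single integer is treated as an index into the flattened (C-order) array.

```python
import jax.numpy as jnp

a = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

print(a.item(1, 2))  # 5: row 1, column 2
print(a.item(4))     # 4: flat index 4 in C order
```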
b613679 [Jax] Speedup stacks traceback PiperOrigin-RevId: 595416135 03 January 2024, 17:01:23 UTC
afa2f1e [Pallas/Mosaic] Add support for nested `ref.at` PiperOrigin-RevId: 595289898 03 January 2024, 05:54:15 UTC
7912e05 Update XLA dependency to use revision http://github.com/openxla/xla/commit/f9171123c609ff392c6f32647369fb674a6706c0. PiperOrigin-RevId: 595273180 03 January 2024, 04:18:54 UTC
c06e186 Error on conversion of empty arrays to boolean. PiperOrigin-RevId: 595264332 03 January 2024, 03:26:45 UTC
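After this change, converting an empty array to `bool` raises instead of silently producing a value. The helper `truth_value` below is hypothetical, and it assumes the error type is `ValueError`, matching the error NumPy raises for an ambiguous multi-element truth value. Checking `x.size > 0` states the intent explicitly.

```python
import jax.numpy as jnp


def truth_value(x):
    """Return bool(x), or a string describing the ValueError it raised."""
    try:
        return bool(x)
    except ValueError as err:
        return f"ValueError: {err}"


print(truth_value(jnp.array([1.0])))  # True: one nonzero element
print(truth_value(jnp.array([])))     # ValueError: ... (empty array)
print(jnp.array([]).size > 0)         # False: explicit emptiness check
```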
985a042 Merge pull request #19174 from mattjj:shmap-custom-vjp-replication-error-message PiperOrigin-RevId: 595254128 03 January 2024, 02:15:08 UTC
836563f [Pallas] Refactor indexing primitives to use NDIndexer abstraction Some notes about this change: * This change upgrades the `RefView` abstraction to store multiple indexers. This allows doing things like `ref.at[0].at[0]` to recursively create a view of a `Ref`. `RefView`s therefore encapsulate multiple `NDIndexer`s. * This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p) but does *not* generalize their rules. Most of the rules will raise a NotImplementedError if you use multiple `NDIndexer`s. Adding support will be done in a future CL. * With the above in mind, this change only preserves existing public-facing APIs; adding actual support will involve updating the rules. PiperOrigin-RevId: 595229523 02 January 2024, 23:53:40 UTC
8c5e7b2 Merge pull request #19135 from eukub:concatenation-to-fstrings PiperOrigin-RevId: 595201881 02 January 2024, 21:53:05 UTC
12e57de [shard-map] improve error message when a custom_vjp bwd has extra psum 02 January 2024, 21:26:40 UTC
c0d4653 Delete sharding spec to HloSharding conversion since it's not used anymore. PiperOrigin-RevId: 595192496 02 January 2024, 21:13:23 UTC
fff5ea5 Remove deprecated unsafe_raw_array method from PRNG keys PiperOrigin-RevId: 595190146 02 January 2024, 21:03:21 UTC
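Code that relied on the removed method can read a key's underlying bits through the public `jax.random.key_data` function, and rebuild a typed key with `jax.random.wrap_key_data`. A minimal sketch using a new-style typed key:

```python
import jax

key = jax.random.key(0)          # new-style typed PRNG key
raw = jax.random.key_data(key)   # its underlying unsigned integer words
# With the default threefry2x32 implementation: shape (2,), dtype uint32.
print(raw.shape, raw.dtype)

# Round trip: wrap the raw words back into a typed key.
same_key = jax.random.wrap_key_data(raw)
```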
e6c8901 Generate Python bindings for the Triton MLIR dialect The bindings are not yet included in the jaxlib wheel. I will do that in a follow up PR. PiperOrigin-RevId: 595174466 02 January 2024, 19:55:05 UTC
15f4a8d Merge pull request #19166 from jakevdp:fix-binom PiperOrigin-RevId: 595168665 02 January 2024, 19:30:12 UTC
5192fca Delete `mesh_sharding_specs` from JAX PiperOrigin-RevId: 595164505 02 January 2024, 19:14:39 UTC
77258cd stats.binom.pmf: return zero for k > n 02 January 2024, 18:53:44 UTC
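For K ~ Binomial(n, p), the probability of k > n successes is zero. A quick check with `jax.scipy.stats.binom` (the arithmetic is float32, so in-range values are only approximately exact):

```python
from jax.scipy.stats import binom

n, p = 3, 0.5
print(float(binom.pmf(1, n, p)))  # approximately 0.375 = 3 * 0.5**3
print(float(binom.pmf(5, n, p)))  # 0.0: five successes in three trials is impossible
```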
697f17a Remove reliance on `ShardingSpec`s from `NamedSharding` to `HloSharding` conversion. PiperOrigin-RevId: 595151695 02 January 2024, 18:28:02 UTC
a678015 Implement a stable serialization API for Mosaic This lets us break a dependency on standard MLIR dialects while serializing the program into HLO. The scheme is simple: we make a lightweight lazy fork of existing dialects by mangling the dialect name and otherwise keeping the structure of the ops identical. This keeps serialization and deserialization simple, for as long as the upstream dialects don't change much. If they do, we have to increment our version counter and write rules that update the IR structure. Note that this scheme only protects us from changes such as changing the attributes annotating the ops (renaming, etc.). However, it doesn't protect us from changes to the attributes defined by a dialect. Still, as far as I can tell, the only attributes we depend on are enums (which are simply plain integer attributes, so we can remap their values) and affine maps (that are unlikely to change much, I hope). This does not actually wire up the pass yet, as we are currently reorganizing the Python/C++ boundary significantly. The integration should be completed once that work is done. PiperOrigin-RevId: 595128374 02 January 2024, 16:51:56 UTC
0419e01 [Mosaic] Add a pass to check operation invariants on-device This lets us easily catch things such as out-of-bounds loads or reference slices (leading to OOB DMAs or loads downstream). PiperOrigin-RevId: 595072511 02 January 2024, 11:19:35 UTC
1044c4f [Pallas] Add missing shape checks in the Pallas FlashAttention kernel for TPUs Missing shape checks can cause hard-to-understand runtime errors from OOB checks inserted by XLA. We weren't verifying that the attention bias and the segment ids have the shapes we were expecting. PiperOrigin-RevId: 595065315 02 January 2024, 10:37:25 UTC
b29e8b2 Update XLA dependency to use revision http://github.com/openxla/xla/commit/3262826f68c347eff74892aabdfbf00f5929dac0. PiperOrigin-RevId: 594837129 01 January 2024, 04:13:39 UTC
d5db08d Update XLA dependency to use revision http://github.com/openxla/xla/commit/fa350aa02d5150e794095348fd3027e9724e225f. PiperOrigin-RevId: 594681196 31 December 2023, 04:02:55 UTC
7961fb8 Update XLA dependency to use revision http://github.com/openxla/xla/commit/3d132585dfc5a34942b65847a92664e007b58d89. PiperOrigin-RevId: 594524128 30 December 2023, 03:04:21 UTC
4b76d03 Add shape of PositionalSharding to its repr PiperOrigin-RevId: 594489540 29 December 2023, 21:05:47 UTC
462ec6d Update XLA dependency to use revision http://github.com/openxla/xla/commit/7c3006be64c95ac72b2c9c22409a925e1b736c03. PiperOrigin-RevId: 594358108 29 December 2023, 04:00:47 UTC
5579ff2 changed concatenation of strings to f-strings to improve readability and unify with the rest of the code 28 December 2023, 16:38:13 UTC
d9aba6f Update XLA dependency to use revision http://github.com/openxla/xla/commit/46a9aaa3087ad12fbbe63286e15f8cfb50093070. PiperOrigin-RevId: 594143989 28 December 2023, 03:59:24 UTC
a74a782 Merge pull request #19129 from mattjj:cache-miss-explanations-3 PiperOrigin-RevId: 593928096 27 December 2023, 06:28:50 UTC
9112dce add jax.explain_cache_misses tracing cache miss explanations As part of making JAX's behavior more transparent, it must be clear not only when code is slow because it's spending all its time missing caches (and hence retracing/recompiling), but also _why_ it missed those caches. That is, just knowing (from e.g. setting jax_log_compiles) that code is retracing a lot doesn't tell the user what to do to fix things. But once the user knows that the cache misses are due to changing dtypes, or due to jit being passed a new callable object on every iteration of a loop, it's often clear what to do. And JAX can provide that information. The main idea here is that pointing out which parts of the cache key differ from previously seen keys can constitute a pretty good explanation. This PR adds an explanation mechanism. It can be enabled in a few different ways: * setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy; * setting the config option via `jax.config.update('jax_explain_cache_misses', True)`; * using the `jax._src.config.explain_cache_misses` context manager (not in public namespace yet); * when parsing command line flags with absl, using the `--jax_explain_cache_misses` flag. Co-authored-by: Yash Katariya <yashkatariya@google.com> 27 December 2023, 05:54:27 UTC
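A minimal sketch of the config-option route listed above. `double` is a hypothetical function; calling it with a new input dtype is one of the cache-miss causes the message mentions, and with the option enabled JAX should log an explanation for the retrace (the exact log text is not reproduced here).

```python
import jax
import jax.numpy as jnp

# Same effect as JAX_EXPLAIN_CACHE_MISSES=1 in the environment.
jax.config.update("jax_explain_cache_misses", True)


@jax.jit
def double(x):
    return x * 2


a = double(jnp.float32(1.0))  # first call: trace and compile
b = double(jnp.int32(1))      # new input dtype: cache miss, explanation logged

jax.config.update("jax_explain_cache_misses", False)  # restore the default
```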
a055e58 Update XLA dependency to use revision http://github.com/openxla/xla/commit/5a2bee1f06005dfc1812cb69e2f8967666320fb8. PiperOrigin-RevId: 593909713 27 December 2023, 04:28:34 UTC
9c2b21f Add shard_alike tf_impl rule PiperOrigin-RevId: 593824294 26 December 2023, 19:09:49 UTC
f9be08b Update XLA dependency to use revision http://github.com/openxla/xla/commit/11afd8745b08fa124c003f5527192dfd2056df73. PiperOrigin-RevId: 593687708 26 December 2023, 04:09:40 UTC
663c039 Merge pull request #19117 from mattjj:tangent-dtypes-scan-test PiperOrigin-RevId: 593486780 24 December 2023, 23:00:53 UTC
0ea9a71 add test illustrating scale dtypes in scan 24 December 2023, 21:50:28 UTC
7cd934b Merge pull request #19116 from mattjj:fix-tangent-dtypes PiperOrigin-RevId: 593479100 24 December 2023, 21:43:08 UTC
0d6007c Replace usage of make_array_from_callback in full_like with shard_alike PiperOrigin-RevId: 593468331 24 December 2023, 19:53:37 UTC
d80d7a7 revive tangent dtypes test broken by d5646db d5646db, the partial rollback of #19096, left a tangent dtypes test broken. this pr fixes it, though we'd still like to re-land #19096 eventually. 24 December 2023, 19:47:59 UTC
c83fd97 Fix jax mlir python dependency build after https://github.com/llvm/llvm-project/commit/537b2aa264c5a9879a80289c8d123b39e520eb15 PiperOrigin-RevId: 593370604 24 December 2023, 05:02:29 UTC
b413ee2 Update XLA dependency to use revision http://github.com/openxla/xla/commit/2b6813b32b4093c620990f82dccccf5ebc91b0e6. PiperOrigin-RevId: 593199523 23 December 2023, 03:55:40 UTC
d5646db partial rollback of #19096 due to internal breakage (relying on jax internals) PiperOrigin-RevId: 593175212 22 December 2023, 23:54:26 UTC
6524402 Merge pull request #19097 from mattjj:tangent-dtypes-3 PiperOrigin-RevId: 593147437 22 December 2023, 19:41:48 UTC
05da18a tweaks to enable adding custom tangent dtypes: * fix a bug in zeros_like_shaped_array and KeyTyRules.zero to ensure `scalar_zero` is actually a scalar * upgrade the adder handler for ShapedArray to delegate to an extended dtype rule for addition * convert_element_type shouldn't blanket-disallow extended dtypes; actually that can be a key operation for working with them! instead, add new `convert_from` and `convert_to` rules. instead of letting these rules perform arbitrary logic, for now they can just return a bool indicating whether the conversion is legit; if false, an error is raised, and if true, the existing convert_element_type lowering rule just generates a ConvertElementType HLO from one physical type to the other. this pr also adds a test for a custom tangent dtype of interest for plumbing quantization scales out of a backward pass 22 December 2023, 19:33:14 UTC
4c9241e Cache ClosedJaxpr creation to minimize cache_misses. ClosedJaxpr should always be created under a cache. PiperOrigin-RevId: 593023314 22 December 2023, 06:15:52 UTC
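Why construction should happen under a cache: a downstream cache keyed on an object that compares by identity misses every time a fresh but equivalent object is built. A toy sketch with `functools.lru_cache`; `Wrapper`, `make_wrapper` and `expensive_downstream` are hypothetical stand-ins, not JAX internals.

```python
import functools


class Wrapper:
    """Hypothetical object with default (identity-based) hashing and equality."""

    def __init__(self, payload):
        self.payload = payload


downstream_calls = 0


@functools.lru_cache(maxsize=None)
def expensive_downstream(w: Wrapper) -> int:
    # Cache keyed on the Wrapper object itself, i.e. on its identity.
    global downstream_calls
    downstream_calls += 1
    return len(w.payload)


@functools.lru_cache(maxsize=None)
def make_wrapper(payload: tuple) -> Wrapper:
    # Cached constructor: equal payloads yield the very same Wrapper.
    return Wrapper(payload)


payload = (1, 2, 3)
for _ in range(3):
    expensive_downstream(Wrapper(payload))       # fresh object: 3 misses
for _ in range(3):
    expensive_downstream(make_wrapper(payload))  # shared object: 1 miss, 2 hits
print(downstream_calls)  # 4
```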
ab79e4a Update XLA dependency to use revision http://github.com/openxla/xla/commit/a312bf9798a1293fed92cf050650f8f900d0bb4b. PiperOrigin-RevId: 592999710 22 December 2023, 03:39:56 UTC
b9aa589 Merge pull request #19088 from jakevdp:array-api-unique PiperOrigin-RevId: 592994080 22 December 2023, 02:52:32 UTC
32e1a0c Merge pull request #19096 from mattjj:no-more-add-jaxval-primitive PiperOrigin-RevId: 592985946 22 December 2023, 01:53:15 UTC
be3ca50 del add_any_p and zeros_like_p, replace aval-dispatched traceable 22 December 2023, 01:04:21 UTC
3021e90 Merge pull request #19093 from mattjj:not-matts-fault PiperOrigin-RevId: 592979026 22 December 2023, 01:03:27 UTC
33e059f Merge pull request #19034 from jakevdp:float8-min PiperOrigin-RevId: 592966489 21 December 2023, 23:54:47 UTC
5e957c6 array api: add unique_* interfaces 21 December 2023, 23:49:54 UTC
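Assuming a JAX release where these interfaces are exposed in `jax.numpy` (as in current releases), `unique_counts` returns a named tuple of the unique values and their counts, and `unique_values` returns only the values.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 3, 2, 3, 1])

res = jnp.unique_counts(x)  # named tuple with fields `values` and `counts`
counts = dict(zip(res.values.tolist(), res.counts.tolist()))
print(sorted(counts.items()))  # [(1, 2), (2, 1), (3, 3)]

vals = jnp.unique_values(x)  # unique values only
print(sorted(vals.tolist()))   # [1, 2, 3]
```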
0cf38cc Merge pull request #19086 from jakevdp:fix-lu PiperOrigin-RevId: 592966222 21 December 2023, 23:46:50 UTC
e31018b in partial_eval_custom rule for pjit, cache ClosedJaxpr creation Anywhere we call the ClosedJaxpr constructor, we had better be under a cache. We should audit the code... Never trust comments, especially when blame says mattjj wrote them Co-authored-by: Yash Katariya <yashkatariya@google.com> 21 December 2023, 23:35:45 UTC
c28aa2c Remove CPU support from compilation cache. CPU support was originally added to the compilation cache in anticipation of CPU acceleration compilation becoming available. Since it is not available and the --xla_cpu_use_xla_runtime flag has been deprecated, clean up the code and test. Testing: test workload, revised unit test. PiperOrigin-RevId: 592962316 21 December 2023, 23:25:56 UTC
95eaf55 linalg.lu: avoid NaNs in default lowering rule 21 December 2023, 22:37:47 UTC
14fe47c Merge pull request #19090 from jakevdp:unique-equal-nan PiperOrigin-RevId: 592941710 21 December 2023, 21:48:12 UTC
e3e26f2 jnp.unique: add support for the equal_nan keyword 21 December 2023, 20:37:09 UTC
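A small illustration of `equal_nan`: with `True` (the default) all NaNs collapse into one entry, while `False` keeps each NaN as a distinct value, since NaN != NaN.

```python
import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, jnp.nan])

merged = jnp.unique(x, equal_nan=True)     # default: NaNs merged -> 1.0, nan
separate = jnp.unique(x, equal_nan=False)  # each NaN kept -> 1.0, nan, nan
print(merged.shape, separate.shape)        # (2,) (3,)
```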
655b590 Merge pull request #19054 from jakevdp:array-api-trig-aliases PiperOrigin-RevId: 592900989 21 December 2023, 18:46:30 UTC
b07d176 Merge pull request #19084 from apaszke:mosaic-versioning PiperOrigin-RevId: 592895717 21 December 2023, 18:23:54 UTC
a3c8865 Merge pull request #19072 from jakevdp:vectorize PiperOrigin-RevId: 592893601 21 December 2023, 18:15:15 UTC
dcc5020 Make Pallas work with older jaxlib Recent MLIR changes have changed the names of vector.CombiningKind enum values, so we have to support both of them in the Pallas lowering (since many people might have an older jaxlib). 21 December 2023, 17:53:36 UTC
42ae843 Merge pull request #19077 from mattjj:no-cast-in-ad-util PiperOrigin-RevId: 592873263 21 December 2023, 16:53:10 UTC
699565a Removed unused `_dtype_to_xla_type_string` PiperOrigin-RevId: 592862584 21 December 2023, 16:04:57 UTC
6863569 remove use of cast in ad_util i hate cast 21 December 2023, 05:00:42 UTC
f64b80f Update XLA dependency to use revision http://github.com/openxla/xla/commit/70ba83748baf84463a1087d725af87ea13d57805. PiperOrigin-RevId: 592716358 21 December 2023, 03:20:28 UTC
9792e00 Cleanup _find_arg_mismatch logic PiperOrigin-RevId: 592697969 21 December 2023, 01:24:26 UTC
57d74d6 Always return a NamedSharding from eager shard_map PiperOrigin-RevId: 592689808 21 December 2023, 00:42:57 UTC
454ffcc Merge pull request #19008 from 8bitmp3:jax-docs-advanced-autodiff PiperOrigin-RevId: 592683399 21 December 2023, 00:13:48 UTC
6afb83a Upgrade JAX Advanced Autodiff 201 20 December 2023, 23:58:49 UTC
e089c9b Internal only changes. Reverts 4347950d9d018a254fd00bded54ae79df2e71556 PiperOrigin-RevId: 592679160 20 December 2023, 23:55:53 UTC
ad3d743 jnp.vectorize: support excluding arguments by keyword 20 December 2023, 23:38:19 UTC
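A sketch of excluding a keyword argument from vectorization. `apply_op` is hypothetical; its `op` argument is a Python string, which cannot be converted to an array, so it is excluded by name rather than broadcast.

```python
import jax.numpy as jnp


def apply_op(x, op):
    # `op` is a plain Python string chosen at call time.
    return jnp.sin(x) if op == "sin" else jnp.cos(x)


# Excluded arguments are passed through unchanged instead of being vectorized.
vapply = jnp.vectorize(apply_op, excluded={"op"})

x = jnp.zeros(3)
out = vapply(x, op="cos")  # cos(0) = 1 for every element
print(out.shape)           # (3,)
```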
965fefd Merge pull request #19071 from mattjj:tangent-dtypes PiperOrigin-RevId: 592664820 20 December 2023, 23:02:02 UTC
7ecd22c Exclude test_gpu_memory_allocation from pytest execution. PiperOrigin-RevId: 592664477 20 December 2023, 22:53:33 UTC
ec7d28c revise logic for tangent types of extended dtypes * remove the dead code KeyTangentTy * replace TyRules.make_tangent with TyRules.zero * removed ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it * fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type * fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (need to return cotangents for all inputs, we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple, see #19009 for a check which catches this and hence includes the same test change We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though! 20 December 2023, 22:24:52 UTC
6fcaf79 Merge pull request #19070 from mattjj:make-hypothesis-optional PiperOrigin-RevId: 592646659 20 December 2023, 21:40:32 UTC
d7cceb1 make hypothesis use in all_gather_test.py optional I think a31129a aka cl/587963496 accidentally made hypothesis a test dependency in tests/all_gather_test.py, rather than following our existing convention as in tests/state_test.py of making it optional. I think it was an accident because there's no discussion of adding hypothesis as a test dependency on the review for that PR/CL. This PR changes tests/all_gather_test.py to follow the convention for making hypothesis optional. 20 December 2023, 21:29:55 UTC
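A minimal sketch of the optional-dependency convention described here, assuming the usual `try`/`except ImportError` guard plus `unittest.skipIf`. It is not the literal code in tests/state_test.py, and the test class is hypothetical.

```python
import unittest

try:
    # Optional test dependency: its absence must not break test collection.
    from hypothesis import given, settings, strategies as st
    HAS_HYPOTHESIS = True
except ImportError:
    HAS_HYPOTHESIS = False


class OptionalHypothesisTest(unittest.TestCase):
    @unittest.skipIf(not HAS_HYPOTHESIS, "hypothesis is not installed")
    def test_list_length_property(self):
        # `given`, `settings` and `st` are only referenced when the test runs,
        # so this module still imports cleanly without hypothesis.
        @settings(max_examples=20, database=None)
        @given(st.integers(min_value=0, max_value=10))
        def check(n):
            self.assertEqual(len([0] * n), n)

        check()


def run_suite() -> unittest.TestResult:
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(OptionalHypothesisTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Without hypothesis installed, the test is reported as skipped rather than failed, so the suite still passes.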