c414f25 | jax authors | 05 January 2024, 06:48:03 UTC | Fix a type annotation. PiperOrigin-RevId: 595897224 | 05 January 2024, 06:50:39 UTC |
1ed6a81 | Qiao Zhang | 05 January 2024, 06:15:25 UTC | Fix a type annotation typo. PiperOrigin-RevId: 595891648 | 05 January 2024, 06:16:11 UTC |
059fdaf | jax authors | 05 January 2024, 04:45:07 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/8c7233c8bb54832d5993fd6430b2658dd0e3f418. PiperOrigin-RevId: 595874816 | 05 January 2024, 04:45:49 UTC |
d6a4723 | Sharad Vikram | 05 January 2024, 01:55:49 UTC | [Pallas/Mosaic] Change Mosaic semaphores to be MemRef types This is part 1 of a change that enables allocating arrays of semaphores. It does not add any new public facing functionality and only changes how semaphores are represented in Mosaic. PiperOrigin-RevId: 595848688 | 05 January 2024, 01:56:30 UTC |
6dd5a69 | jax authors | 04 January 2024, 23:34:15 UTC | Merge pull request #19201 from jakevdp:sort-kwargs PiperOrigin-RevId: 595818057 | 04 January 2024, 23:34:15 UTC |
919233f | jax authors | 04 January 2024, 23:25:26 UTC | Merge pull request #19204 from jakevdp:doc-requirements PiperOrigin-RevId: 595817526 | 04 January 2024, 23:25:26 UTC |
ea66029 | jax authors | 04 January 2024, 23:16:25 UTC | Introduce min entry size check for compilation cache. Currently, the persistent compilation cache has a time threshold: the entry is cached only if the compilation time is at least the threshold. If compilation happens to take a while, but the resulting executable is small, there is nothing that prevents caching. This can result in a large number of small files in the cache. Introduce a size threshold. If the resulting executable's size (after serialization and compression) is less than this threshold, don't cache. This check is in addition to the compilation time check described above. Testing: new unit test, test workload. PiperOrigin-RevId: 595815611 | 04 January 2024, 23:17:05 UTC |
ccecdba | Jake VanderPlas | 04 January 2024, 23:06:12 UTC | DOC: remove unnecessary doc build requirement | 04 January 2024, 23:06:12 UTC |
8b62516 | Jake VanderPlas | 04 January 2024, 22:21:25 UTC | [array api] add stable & descending params to jnp.sort & jnp.argsort | 04 January 2024, 22:21:25 UTC |
8803774 | Peter Hawkins | 04 January 2024, 21:46:50 UTC | Fix/disable two tests failing in Windows CI. PiperOrigin-RevId: 595793411 | 04 January 2024, 21:47:32 UTC |
3ff4eb4 | Peter Hawkins | 04 January 2024, 21:26:20 UTC | [XLA:Python] Raise an AttributeError if __cuda_array_interface__ is called on various invalid buffers, rather than a RuntimeError. This makes hasattr(x, "__cuda_array_interface__") fail gracefully. In passing, also move the implementation into py_array.cc, and use an allowlist of supported types rather than a denylist. Fixes https://github.com/google/jax/issues/19134 PiperOrigin-RevId: 595788328 | 04 January 2024, 21:26:56 UTC |
ebc7af9 | Tom Cobley | 04 January 2024, 17:49:14 UTC | Fix typo in pmap docstring Docstring states: > If the pmapped function is called with fewer positional arguments than indicated by **`static_argnums`** then an error is raised. However `static_argnums` is not an argument that exists - I believe this should be corrected to `static_broadcasted_argnums`. PiperOrigin-RevId: 595731210 | 04 January 2024, 17:50:00 UTC |
326d1d2 | Jake VanderPlas | 04 January 2024, 17:25:22 UTC | jaxlib: avoid external build-time dependency on ml_dtypes Currently, the ml_dtypes C++ sources are included in the set of sources at jaxlib build time. This is unnecessary, and can lead to problematic version skew in some cases (e.g. nightly builds). PiperOrigin-RevId: 595725529 | 04 January 2024, 17:26:05 UTC |
05a9906 | Peter Hawkins | 04 January 2024, 16:09:53 UTC | Guard call to os.register_at_fork() to unbreak Windows CI. PiperOrigin-RevId: 595708288 | 04 January 2024, 16:10:32 UTC |
d56116c | Tom Hennigan | 04 January 2024, 12:07:50 UTC | Raise rather than return NotImplementedError in PmapSharding. PiperOrigin-RevId: 595662888 | 04 January 2024, 12:08:28 UTC |
981b670 | jax authors | 04 January 2024, 07:26:22 UTC | [Jax] Allow setting the Python traceback frame limit. PiperOrigin-RevId: 595607107 | 04 January 2024, 07:26:55 UTC |
7bba80d | Yash Katariya | 04 January 2024, 04:51:41 UTC | Rename batch_dims variable in shard_alike PiperOrigin-RevId: 595578639 | 04 January 2024, 04:59:39 UTC |
9112afb | Sharad Vikram | 04 January 2024, 04:50:43 UTC | [Pallas/GPU] Add support for jnp.cumsum PiperOrigin-RevId: 595578503 | 04 January 2024, 04:51:25 UTC |
0d3fd53 | jax authors | 04 January 2024, 04:25:14 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/257f1592457a8036ea70f6020601018d67e943c5. PiperOrigin-RevId: 595574791 | 04 January 2024, 04:26:51 UTC |
c06c292 | jax authors | 04 January 2024, 01:09:19 UTC | Merge pull request #19186 from jakevdp:asarray-copy PiperOrigin-RevId: 595541132 | 04 January 2024, 01:09:19 UTC |
68abe0d | Yash Katariya | 04 January 2024, 00:23:49 UTC | Add correct batching rule for shard_alike PiperOrigin-RevId: 595532031 | 04 January 2024, 00:24:27 UTC |
23b9c2a | Parker Schuh | 04 January 2024, 00:11:44 UTC | Add the githash that the jaxlib was built at to __init__.py. This is to allow identifying the githash of nightlies. PiperOrigin-RevId: 595529249 | 04 January 2024, 00:12:23 UTC |
97fc213 | Jake VanderPlas | 03 January 2024, 23:20:27 UTC | [array API] support copy argument to jnp.asarray | 03 January 2024, 23:20:27 UTC |
f1be301 | jax authors | 03 January 2024, 22:20:08 UTC | Merge pull request #19183 from jakevdp:dlpack-warning PiperOrigin-RevId: 595502005 | 03 January 2024, 22:20:08 UTC |
f8b7169 | jax authors | 03 January 2024, 21:57:50 UTC | Merge pull request #19182 from jakevdp:ndarray-item PiperOrigin-RevId: 595495441 | 03 January 2024, 21:57:50 UTC |
df4e9c0 | Jake VanderPlas | 03 January 2024, 21:31:57 UTC | DOC: add warning about dlpack and buffer mutation | 03 January 2024, 21:31:57 UTC |
47e5c81 | Jake VanderPlas | 03 January 2024, 21:03:47 UTC | jnp.ndarray.item(): add args support | 03 January 2024, 21:03:47 UTC |
b613679 | jax authors | 03 January 2024, 17:00:29 UTC | [Jax] Speed up stack tracebacks PiperOrigin-RevId: 595416135 | 03 January 2024, 17:01:23 UTC |
afa2f1e | Sharad Vikram | 03 January 2024, 05:53:30 UTC | [Pallas/Mosaic] Add support for nested `ref.at` PiperOrigin-RevId: 595289898 | 03 January 2024, 05:54:15 UTC |
7912e05 | jax authors | 03 January 2024, 04:18:10 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/f9171123c609ff392c6f32647369fb674a6706c0. PiperOrigin-RevId: 595273180 | 03 January 2024, 04:18:54 UTC |
c06e186 | Jake VanderPlas | 03 January 2024, 03:26:09 UTC | Error on conversion of empty arrays to boolean. PiperOrigin-RevId: 595264332 | 03 January 2024, 03:26:45 UTC |
985a042 | jax authors | 03 January 2024, 02:15:08 UTC | Merge pull request #19174 from mattjj:shmap-custom-vjp-replication-error-message PiperOrigin-RevId: 595254128 | 03 January 2024, 02:15:08 UTC |
836563f | Sharad Vikram | 02 January 2024, 23:52:57 UTC | [Pallas] Refactor indexing primitives to use NDIndexer abstraction Some notes about this change: * This change upgrades the `RefView` abstraction to store multiple indexers. This allows doing things like `ref.at[0].at[0]` to recursively create a view of a `Ref`. `RefView`s therefore encapsulate multiple `NDIndexer`s. * This generalizes most of the indexing primitive APIs (i.e. get_p, swap_p, addupdate_p) but does *not* generalize their rules. Most of the rules will raise a NotImplementedError if you use multiple `NDIndexer`s. Adding support will be done in a future CL. * With the above in mind, this change only preserves existing public facing APIs and adding actual support will involve updating the rules. PiperOrigin-RevId: 595229523 | 02 January 2024, 23:53:40 UTC |
8c5e7b2 | jax authors | 02 January 2024, 21:53:05 UTC | Merge pull request #19135 from eukub:concatenation-to-fstrings PiperOrigin-RevId: 595201881 | 02 January 2024, 21:53:05 UTC |
12e57de | Matthew Johnson | 02 January 2024, 21:26:40 UTC | [shard-map] improve error message when a custom_vjp bwd has extra psum | 02 January 2024, 21:26:40 UTC |
c0d4653 | Yash Katariya | 02 January 2024, 21:12:44 UTC | Delete sharding spec to HloSharding conversion since it's not used anymore. PiperOrigin-RevId: 595192496 | 02 January 2024, 21:13:23 UTC |
fff5ea5 | Jake VanderPlas | 02 January 2024, 21:02:46 UTC | Remove deprecated unsafe_raw_array method from PRNG keys PiperOrigin-RevId: 595190146 | 02 January 2024, 21:03:21 UTC |
e6c8901 | Sergei Lebedev | 02 January 2024, 19:54:20 UTC | Generate Python bindings for the Triton MLIR dialect The bindings are not yet included in the jaxlib wheel. I will do that in a follow-up PR. PiperOrigin-RevId: 595174466 | 02 January 2024, 19:55:05 UTC |
15f4a8d | jax authors | 02 January 2024, 19:30:12 UTC | Merge pull request #19166 from jakevdp:fix-binom PiperOrigin-RevId: 595168665 | 02 January 2024, 19:30:12 UTC |
5192fca | Yash Katariya | 02 January 2024, 19:13:57 UTC | Delete `mesh_sharding_specs` from JAX PiperOrigin-RevId: 595164505 | 02 January 2024, 19:14:39 UTC |
77258cd | Jake VanderPlas | 02 January 2024, 18:53:44 UTC | stats.binom.pmf: return zero for k > n | 02 January 2024, 18:53:44 UTC |
697f17a | Yash Katariya | 02 January 2024, 18:27:23 UTC | Remove reliance on `ShardingSpec`s from `NamedSharding` to `HloSharding` conversion. PiperOrigin-RevId: 595151695 | 02 January 2024, 18:28:02 UTC |
a678015 | Adam Paszke | 02 January 2024, 16:51:20 UTC | Implement a stable serialization API for Mosaic This lets us break a dependency on standard MLIR dialects while serializing the program into HLO. The scheme is simple: we make a lightweight lazy fork of existing dialects by mangling the dialect name and otherwise keeping the structure of the ops identical. This keeps serialization and deserialization simple, for as long as the upstream dialects don't change much. If they do, we have to increment our version counter and write rules that update the IR structure. Note that this scheme only protects us from changes such as changing the attributes annotating the ops (renaming, etc.). However, it doesn't protect us from changes to the attributes defined by a dialect. Still, as far as I can tell, the only attributes we depend on are enums (which are simply plain integer attributes, so we can remap their values) and affine maps (that are unlikely to change much, I hope). This does not actually wire up the pass yet, as we are currently reorganizing the Python/C++ boundary significantly. The integration should be completed once that work is done. PiperOrigin-RevId: 595128374 | 02 January 2024, 16:51:56 UTC |
0419e01 | Adam Paszke | 02 January 2024, 11:18:57 UTC | [Mosaic] Add a pass to check operation invariants on-device This lets us easily catch things such as out-of-bounds loads or reference slices (leading to OOB DMAs or loads downstream). PiperOrigin-RevId: 595072511 | 02 January 2024, 11:19:35 UTC |
1044c4f | Adam Paszke | 02 January 2024, 10:36:46 UTC | [Pallas] Add missing shape checks in the Pallas FlashAttention kernel for TPUs Missing shape checks can cause hard-to-understand runtime errors caused by OOB checks inserted by XLA. We weren't verifying that the attention bias and the segment ids have the shapes we were expecting. PiperOrigin-RevId: 595065315 | 02 January 2024, 10:37:25 UTC |
b29e8b2 | jax authors | 01 January 2024, 04:13:01 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/3262826f68c347eff74892aabdfbf00f5929dac0. PiperOrigin-RevId: 594837129 | 01 January 2024, 04:13:39 UTC |
d5db08d | jax authors | 31 December 2023, 04:02:16 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/fa350aa02d5150e794095348fd3027e9724e225f. PiperOrigin-RevId: 594681196 | 31 December 2023, 04:02:55 UTC |
7961fb8 | jax authors | 30 December 2023, 03:03:44 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/3d132585dfc5a34942b65847a92664e007b58d89. PiperOrigin-RevId: 594524128 | 30 December 2023, 03:04:21 UTC |
4b76d03 | Yash Katariya | 29 December 2023, 21:05:08 UTC | Add shape of PositionalSharding to its repr PiperOrigin-RevId: 594489540 | 29 December 2023, 21:05:47 UTC |
462ec6d | jax authors | 29 December 2023, 04:00:08 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/7c3006be64c95ac72b2c9c22409a925e1b736c03. PiperOrigin-RevId: 594358108 | 29 December 2023, 04:00:47 UTC |
5579ff2 | eukub | 28 December 2023, 16:38:13 UTC | changed concatenation of strings to f-strings to improve readability and unify with the rest of the code | 28 December 2023, 16:38:13 UTC |
d9aba6f | jax authors | 28 December 2023, 03:58:53 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/46a9aaa3087ad12fbbe63286e15f8cfb50093070. PiperOrigin-RevId: 594143989 | 28 December 2023, 03:59:24 UTC |
a74a782 | jax authors | 27 December 2023, 06:28:50 UTC | Merge pull request #19129 from mattjj:cache-miss-explanations-3 PiperOrigin-RevId: 593928096 | 27 December 2023, 06:28:50 UTC |
9112dce | Matthew Johnson | 09 June 2023, 21:43:42 UTC | add jax.explain_cache_misses tracing cache miss explanations As part of making JAX's behavior more transparent, it must be clear not only when code is slow because it's spending all its time missing caches (and hence retracing/recompiling), but also _why_ it missed those caches. That is, just knowing (from e.g. setting jax_log_compiles) that code is retracing a lot doesn't tell the user what to do to fix things. But once the user knows that the cache misses are due to changing dtypes, or due to jit being passed a new callable object on every iteration of a loop, it's often clear what to do. And JAX can provide that information. The main idea here is that pointing out which parts of the cache key differ from previously-seen keys can constitute a pretty good explanation. This PR adds an explanation mechanism. It can be enabled in a few different ways: * setting the `JAX_EXPLAIN_CACHE_MISSES` shell environment variable to something truthy; * setting the config option `jax.config.update('jax_explain_cache_misses', True)`; * using the `jax._src.config.explain_cache_misses` context manager (not in public namespace yet); * when parsing command line flags with absl, using the `--jax_explain_cache_misses` flag. Co-authored-by: Yash Katariya <yashkatariya@google.com> | 27 December 2023, 05:54:27 UTC |
a055e58 | jax authors | 27 December 2023, 04:27:51 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/5a2bee1f06005dfc1812cb69e2f8967666320fb8. PiperOrigin-RevId: 593909713 | 27 December 2023, 04:28:34 UTC |
9c2b21f | Yash Katariya | 26 December 2023, 19:09:04 UTC | Add shard_alike tf_impl rule PiperOrigin-RevId: 593824294 | 26 December 2023, 19:09:49 UTC |
f9be08b | jax authors | 26 December 2023, 04:08:53 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/11afd8745b08fa124c003f5527192dfd2056df73. PiperOrigin-RevId: 593687708 | 26 December 2023, 04:09:40 UTC |
663c039 | jax authors | 24 December 2023, 23:00:53 UTC | Merge pull request #19117 from mattjj:tangent-dtypes-scan-test PiperOrigin-RevId: 593486780 | 24 December 2023, 23:00:53 UTC |
0ea9a71 | Matthew Johnson | 24 December 2023, 21:50:28 UTC | add test illustrating scale dtypes in scan | 24 December 2023, 21:50:28 UTC |
7cd934b | jax authors | 24 December 2023, 21:43:08 UTC | Merge pull request #19116 from mattjj:fix-tangent-dtypes PiperOrigin-RevId: 593479100 | 24 December 2023, 21:43:08 UTC |
0d6007c | Yash Katariya | 24 December 2023, 19:53:02 UTC | Replace usage of make_array_from_callback in full_like with shard_alike PiperOrigin-RevId: 593468331 | 24 December 2023, 19:53:37 UTC |
d80d7a7 | Matthew Johnson | 23 December 2023, 00:01:00 UTC | revive tangent dtypes test broken by d5646db d5646db, the partial rollback of #19096, left a tangent dtypes test broken. this pr fixes it, though we'd still like to re-land #19096 eventually. | 24 December 2023, 19:47:59 UTC |
c83fd97 | Christian Sigg | 24 December 2023, 05:01:53 UTC | Fix jax mlir python dependency build after https://github.com/llvm/llvm-project/commit/537b2aa264c5a9879a80289c8d123b39e520eb15 PiperOrigin-RevId: 593370604 | 24 December 2023, 05:02:29 UTC |
b413ee2 | jax authors | 23 December 2023, 03:55:04 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/2b6813b32b4093c620990f82dccccf5ebc91b0e6. PiperOrigin-RevId: 593199523 | 23 December 2023, 03:55:40 UTC |
d5646db | Matthew Johnson | 22 December 2023, 23:53:48 UTC | partial rollback of #19096 due to internal breakage (relying on jax internals) PiperOrigin-RevId: 593175212 | 22 December 2023, 23:54:26 UTC |
6524402 | jax authors | 22 December 2023, 19:41:48 UTC | Merge pull request #19097 from mattjj:tangent-dtypes-3 PiperOrigin-RevId: 593147437 | 22 December 2023, 19:41:48 UTC |
05da18a | Matthew Johnson | 22 December 2023, 01:43:31 UTC | tweaks to enable adding custom tangent dtypes: * fix a bug in zeros_like_shaped_array and KeyTyRules.zero to ensure `scalar_zero` is actually a scalar * upgrade the adder handler for ShapedArray to delegate to an extended dtype rule for addition * convert_element_type shouldn't blanket-disallow extended dtypes; actually that can be a key operation for working with them! instead, add new `convert_from` and `convert_to` rules. instead of letting these rules perform arbitrary logic, for now they can just return a bool indicating whether the conversion is legit; if false, an error is raised, and if true, the existing convert_element_type lowering rule just generates a ConvertElementType HLO from one physical type to the other. this pr also adds a test for a custom tangent dtype of interest for plumbing quantization scales out of a backward pass | 22 December 2023, 19:33:14 UTC |
4c9241e | Yash Katariya | 22 December 2023, 06:15:12 UTC | Cache ClosedJaxpr creation to minimize cache_misses. ClosedJaxpr should always be created under a cache. PiperOrigin-RevId: 593023314 | 22 December 2023, 06:15:52 UTC |
ab79e4a | jax authors | 22 December 2023, 03:38:56 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/a312bf9798a1293fed92cf050650f8f900d0bb4b. PiperOrigin-RevId: 592999710 | 22 December 2023, 03:39:56 UTC |
b9aa589 | jax authors | 22 December 2023, 02:52:32 UTC | Merge pull request #19088 from jakevdp:array-api-unique PiperOrigin-RevId: 592994080 | 22 December 2023, 02:52:32 UTC |
32e1a0c | jax authors | 22 December 2023, 01:53:15 UTC | Merge pull request #19096 from mattjj:no-more-add-jaxval-primitive PiperOrigin-RevId: 592985946 | 22 December 2023, 01:53:15 UTC |
be3ca50 | Matthew Johnson | 21 December 2023, 05:00:08 UTC | del add_any_p and zeros_like_p, replace aval-dispatched traceable | 22 December 2023, 01:04:21 UTC |
3021e90 | jax authors | 22 December 2023, 01:03:27 UTC | Merge pull request #19093 from mattjj:not-matts-fault PiperOrigin-RevId: 592979026 | 22 December 2023, 01:03:27 UTC |
33e059f | jax authors | 21 December 2023, 23:54:47 UTC | Merge pull request #19034 from jakevdp:float8-min PiperOrigin-RevId: 592966489 | 21 December 2023, 23:54:47 UTC |
5e957c6 | Jake VanderPlas | 21 December 2023, 23:49:26 UTC | array api: add unique_* interfaces | 21 December 2023, 23:49:54 UTC |
0cf38cc | jax authors | 21 December 2023, 23:46:50 UTC | Merge pull request #19086 from jakevdp:fix-lu PiperOrigin-RevId: 592966222 | 21 December 2023, 23:46:50 UTC |
e31018b | Matthew Johnson | 21 December 2023, 23:31:28 UTC | in partial_eval_custom rule for pjit, cache ClosedJaxpr creation Anywhere we call the ClosedJaxpr constructor, we had better be under a cache. We should audit the code... Never trust comments, especially when blame says mattjj wrote them Co-authored-by: Yash Katariya <yashkatariya@google.com> | 21 December 2023, 23:35:45 UTC |
c28aa2c | jax authors | 21 December 2023, 23:25:09 UTC | Remove CPU support from compilation cache. CPU support was originally added to the compilation cache in anticipation of the availability of CPU acceleration compilation. Since this is not available and the --xla_cpu_use_xla_runtime flag has been deprecated, clean up the code and test. Testing: test workload, revised unit test. PiperOrigin-RevId: 592962316 | 21 December 2023, 23:25:56 UTC |
95eaf55 | Jake VanderPlas | 21 December 2023, 22:37:47 UTC | linalg.lu: avoid NaNs in default lowering rule | 21 December 2023, 22:37:47 UTC |
14fe47c | jax authors | 21 December 2023, 21:48:12 UTC | Merge pull request #19090 from jakevdp:unique-equal-nan PiperOrigin-RevId: 592941710 | 21 December 2023, 21:48:12 UTC |
e3e26f2 | Jake VanderPlas | 21 December 2023, 20:37:09 UTC | jnp.unique: add support for the equal_nan keyword | 21 December 2023, 20:37:09 UTC |
655b590 | jax authors | 21 December 2023, 18:46:30 UTC | Merge pull request #19054 from jakevdp:array-api-trig-aliases PiperOrigin-RevId: 592900989 | 21 December 2023, 18:46:30 UTC |
b07d176 | jax authors | 21 December 2023, 18:23:54 UTC | Merge pull request #19084 from apaszke:mosaic-versioning PiperOrigin-RevId: 592895717 | 21 December 2023, 18:23:54 UTC |
a3c8865 | jax authors | 21 December 2023, 18:15:15 UTC | Merge pull request #19072 from jakevdp:vectorize PiperOrigin-RevId: 592893601 | 21 December 2023, 18:15:15 UTC |
dcc5020 | Adam Paszke | 21 December 2023, 17:53:36 UTC | Make Pallas work with older jaxlib Recent MLIR changes have changed the names of vector.CombiningKind enum values, so we have to support both of them in the Pallas lowering (since many people might have an older jaxlib). | 21 December 2023, 17:53:36 UTC |
42ae843 | jax authors | 21 December 2023, 16:53:10 UTC | Merge pull request #19077 from mattjj:no-cast-in-ad-util PiperOrigin-RevId: 592873263 | 21 December 2023, 16:53:10 UTC |
699565a | Sergei Lebedev | 21 December 2023, 16:04:17 UTC | Removed unused `_dtype_to_xla_type_string` PiperOrigin-RevId: 592862584 | 21 December 2023, 16:04:57 UTC |
6863569 | Matthew Johnson | 21 December 2023, 05:00:42 UTC | remove use of cast in ad_util i hate cast | 21 December 2023, 05:00:42 UTC |
f64b80f | jax authors | 21 December 2023, 03:19:50 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/70ba83748baf84463a1087d725af87ea13d57805. PiperOrigin-RevId: 592716358 | 21 December 2023, 03:20:28 UTC |
9792e00 | Yash Katariya | 21 December 2023, 01:23:49 UTC | Cleanup _find_arg_mismatch logic PiperOrigin-RevId: 592697969 | 21 December 2023, 01:24:26 UTC |
57d74d6 | Yash Katariya | 21 December 2023, 00:42:17 UTC | Always return a NamedSharding from eager shard_map PiperOrigin-RevId: 592689808 | 21 December 2023, 00:42:57 UTC |
454ffcc | jax authors | 21 December 2023, 00:13:48 UTC | Merge pull request #19008 from 8bitmp3:jax-docs-advanced-autodiff PiperOrigin-RevId: 592683399 | 21 December 2023, 00:13:48 UTC |
6afb83a | 8bitmp3 | 15 December 2023, 23:32:53 UTC | Upgrade JAX Advanced Autodiff 201 | 20 December 2023, 23:58:49 UTC |
e089c9b | David Majnemer | 20 December 2023, 23:55:14 UTC | Internal only changes. Reverts 4347950d9d018a254fd00bded54ae79df2e71556 PiperOrigin-RevId: 592679160 | 20 December 2023, 23:55:53 UTC |
ad3d743 | Jake VanderPlas | 20 December 2023, 23:38:19 UTC | jnp.vectorize: support excluding arguments by keyword | 20 December 2023, 23:38:19 UTC |
965fefd | jax authors | 20 December 2023, 23:02:02 UTC | Merge pull request #19071 from mattjj:tangent-dtypes PiperOrigin-RevId: 592664820 | 20 December 2023, 23:02:02 UTC |
7ecd22c | jax authors | 20 December 2023, 22:52:52 UTC | Exclude test_gpu_memory_allocation from pytest execution. PiperOrigin-RevId: 592664477 | 20 December 2023, 22:53:33 UTC |
ec7d28c | Matthew Johnson | 20 December 2023, 20:47:43 UTC | revise logic for tangent types of extended dtypes * remove the dead code KeyTangentTy * replace TyRules.make_tangent with TyRules.zero * remove ad.instantiate_zeros_aval, which was redundant with ad.instantiate_zeros ever since (1) we removed units and (2) we made Zero carry an aval on it * fix a bug in backward_pass where we instantiated a Zero at the primal type rather than the corresponding tangent type * fix _f_bwd in test_keyarray_custom_vjp, which had the wrong type (we need to return cotangents for all inputs, but we were returning a (float_tangent, key_tangent) pair instead of a (float_tangent, (float_tangent, key_tangent)) nested tuple; see #19009 for a check which catches this and hence includes the same test change) We probably also need a TyRules.add for any extended dtypes that can occur as tangent dtypes, but we currently don't have any tests that exercise that (because all extended dtype tangent types are currently float0). I have some follow-up work to add such a case though! | 20 December 2023, 22:24:52 UTC |
6fcaf79 | jax authors | 20 December 2023, 21:40:32 UTC | Merge pull request #19070 from mattjj:make-hypothesis-optional PiperOrigin-RevId: 592646659 | 20 December 2023, 21:40:32 UTC |
d7cceb1 | Matthew Johnson | 20 December 2023, 21:29:55 UTC | make hypothesis use in all_gather_test.py optional I think a31129a aka cl/587963496 accidentally made hypothesis a test dependency in tests/all_gather_test.py, rather than following our existing convention as in tests/state_test.py of making it optional. I think it was an accident because there's no discussion of adding hypothesis as a test dependency on the review for that PR/CL. This PR changes tests/all_gather_test.py to follow the convention for making hypothesis optional. | 20 December 2023, 21:29:55 UTC |
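Commit ea66029 describes a combined caching policy with two independent checks. An entry is skipped if it compiled faster than the time threshold. It is also skipped if its serialized and compressed size falls below the new size threshold. The sketch below illustrates that policy under stated assumptions; it is not JAX's implementation. The function `should_cache` and its parameter names are hypothetical, since the log entry does not name the actual config options. `zlib` stands in for whatever compression the cache really applies.

```python
import zlib


def should_cache(compile_time_secs: float,
                 serialized_executable: bytes,
                 min_compile_time_secs: float,
                 min_entry_size_bytes: int) -> bool:
    """Hypothetical sketch of the two-threshold policy described in ea66029."""
    # Time check (pre-existing): entries that compiled quickly are cheap to
    # rebuild, so they are not worth persisting.
    if compile_time_secs < min_compile_time_secs:
        return False
    # Size check (new): measure the entry after compression, as the commit
    # describes, and skip entries smaller than the threshold to avoid
    # filling the cache with many tiny files.
    compressed_size = len(zlib.compress(serialized_executable))
    if compressed_size < min_entry_size_bytes:
        return False
    return True


# A slow compile that yields a tiny executable is now rejected by the size check.
print(should_cache(5.0, b"\x00" * 10, 1.0, 1024))
```

Because the checks are independent, the size check rejects a long compile that produces a tiny executable on its own. That is the case the commit targets.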