79021a4 | Enrique Piqueras | 14 September 2023, 16:38:24 UTC | Type compiler options correctly. PiperOrigin-RevId: 565394366 | 14 September 2023, 16:47:15 UTC |
a2720ee | Yash Katariya | 14 September 2023, 16:22:21 UTC | Deprecate `jax.experimental.pjit.with_sharding_constraint`. The replacement is `jax.lax.with_sharding_constraint`, which has been available for a year. PiperOrigin-RevId: 565389746 | 14 September 2023, 16:23:03 UTC |
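A minimal sketch of the replacement call named in a2720ee, assuming a one-dimensional mesh over whatever devices are present (a single CPU device is enough). The mesh axis name `"x"`, the function name, and the array sizes are illustrative choices, not taken from the commit.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over the available devices (assumption: any backend).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
row_sharding = NamedSharding(mesh, PartitionSpec("x"))

@jax.jit
def double_and_constrain(a):
    # Ask the compiler to keep this intermediate sharded along mesh axis "x".
    return jax.lax.with_sharding_constraint(a * 2, row_sharding)

# The length is a multiple of the device count so the array divides evenly.
values = jnp.arange(4 * devices.size)
result = double_and_constrain(values)
```

Code that previously imported the function from `jax.experimental.pjit` should only need the import path changed.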
ed955ea | Brian Patton | 14 September 2023, 15:10:00 UTC | Fully unroll the `scan` in `jnp.searchsorted`, when method 'scan_unrolled' is specified. On GPU, XLA's 'scan' (fori_loop) implementation launches multiple calls to the body_fun GPU kernel, whereas a fully unrolled scan can be fused into a single kernel launch. Since we only require log-many steps, this is often quite practical, and can be a nice speedup. (from 4.5ms down to 1.5ms in my scenario.) PiperOrigin-RevId: 565371859 | 14 September 2023, 15:10:49 UTC |
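The "log-many steps" remark in ed955ea can be illustrated with a plain-Python binary search that always runs a fixed number of iterations, which is what makes full unrolling practical. This is a sketch of the idea only, not JAX's implementation; the function name is hypothetical.

```python
def searchsorted_fixed_steps(sorted_values, query):
    """Left insertion index via binary search with a fixed iteration count."""
    n = len(sorted_values)
    low, high = 0, n
    # Each step at least halves the interval [low, high), so
    # floor(log2(n)) + 1 steps (n.bit_length()) always suffice.
    num_steps = n.bit_length()
    for _ in range(num_steps):
        if low < high:  # once the interval is empty, remaining steps are no-ops
            mid = (low + high) // 2
            if sorted_values[mid] < query:
                low = mid + 1
            else:
                high = mid
    return low
```

Because the loop count depends only on the input length, the loop body can be repeated a known number of times instead of being driven by a runtime loop. In JAX itself the unrolled variant is requested with `jnp.searchsorted(a, v, method='scan_unrolled')`.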
bbfba9a | Peter Hawkins | 14 September 2023, 14:52:07 UTC | Remove code that disabled tests on "stream_executor" backends. These tests work on both GPU and the current (non-stream_executor) TPU runtime, so the conditions aren't needed any more. Tag a couple of tests as "multiaccelerator" since they appear to benefit from multiple devices. PiperOrigin-RevId: 565367453 | 14 September 2023, 14:52:43 UTC |
8acf597 | Adam Paszke | 14 September 2023, 09:50:08 UTC | [Mosaic] Add support for specifying estimated costs for Mosaic kernels PiperOrigin-RevId: 565310871 | 14 September 2023, 09:50:44 UTC |
cab68db | jax authors | 14 September 2023, 08:08:24 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/f33da8b05c619f9134b85dba8c0e19aa4e5422d7. PiperOrigin-RevId: 565291862 | 14 September 2023, 08:08:59 UTC |
838f59e | Tomás Longeri | 14 September 2023, 05:18:30 UTC | [Mosaic] apply_vector_layout C++ rewrite (4) Elementwise ops PiperOrigin-RevId: 565255860 | 14 September 2023, 05:19:10 UTC |
6869000 | jax authors | 13 September 2023, 23:40:02 UTC | Modify the name for JAX compilation cache adoption metric. Rename the adoption metric from '/jax/compilation_cache/tasks_using_original_cache' to '/jax/compilation_cache/tasks_using_cache' to record the number of tasks using compilation cache for both original and new implementations. We realized that one adoption metric is enough to monitor and compare the original and new cache adoption rates. To avoid confusion over the word 'original' in the metric name, we decided to change the metric name to better describe the purpose of this adoption metric. Testing: revised unit test. PiperOrigin-RevId: 565197027 | 13 September 2023, 23:44:10 UTC |
1f8cc44 | Roy Frostig | 13 September 2023, 23:33:21 UTC | deprecate `PRNGKeyArray.unsafe_raw_array` in favor of `jax.random.key_data`. The latter function is also better in that its behavior is invariant to `jit`, whereas the `unsafe_raw_array` method only works in eager mode. PiperOrigin-RevId: 565195381 | 13 September 2023, 23:33:56 UTC |
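A short sketch of the `jit` invariance described in 1f8cc44, assuming a JAX version that provides `jax.random.key` for creating typed keys. The helper name `key_data_under_jit` is hypothetical.

```python
import jax
import jax.numpy as jnp

# A new-style (typed) PRNG key.
key = jax.random.key(0)

# Eager extraction of the underlying uint32 data.
raw_eager = jax.random.key_data(key)

@jax.jit
def key_data_under_jit(k):
    # The same call works while tracing, unlike the deprecated method.
    return jax.random.key_data(k)

raw_jitted = key_data_under_jit(key)
```

Both calls should return the same `uint32` array, so code that needs the raw bits does not have to care whether it runs inside `jit`.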
cd2d419 | Sharad Vikram | 13 September 2023, 23:13:33 UTC | [Pallas] Add support for remote DMAs on TPU PiperOrigin-RevId: 565190637 | 13 September 2023, 23:14:10 UTC |
91fbf9d | Jieying Luo | 13 September 2023, 23:03:11 UTC | [PJRT C API] Set up jax xla cuda package. Add a build wheel, pyproject.toml and setup.py. The directory structure in jax repo is:
jax/
└── plugins/
    └── cuda/
        ├── __init__.py
        ├── pyproject.toml
        └── setup.py
Installed package structure is:
jax_plugins/
└── xla_cuda_cu12/
    ├── __init__.py
    └── xla_cuda_plugin.so
The major cuda version will be part of the package name. The plugin wheel can be built with command: python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --bazel_options="--override_repository=xla=$HOME/xla" PiperOrigin-RevId: 565187954 | 13 September 2023, 23:03:53 UTC |
a38a152 | Peter Hawkins | 13 September 2023, 22:34:39 UTC | Disable lobpcg_test on GPU. This test is failing in CI, disable it while we debug. PiperOrigin-RevId: 565180431 | 13 September 2023, 22:44:02 UTC |
ebc24c7 | Yash Katariya | 13 September 2023, 22:33:28 UTC | Pass sharded inputs to remat offloading tests. When we execute, these inputs will be useful for validating the correctness of the compiler passes. PiperOrigin-RevId: 565180089 | 13 September 2023, 22:43:40 UTC |
44bd916 | Sharad Vikram | 13 September 2023, 22:32:01 UTC | [Pallas] Add support for local DMAs on TPU backend PiperOrigin-RevId: 565179693 | 13 September 2023, 22:32:44 UTC |
11c2f16 | jax authors | 13 September 2023, 21:33:56 UTC | Merge pull request #17594 from jakevdp:dep-prngkey PiperOrigin-RevId: 565163390 | 13 September 2023, 21:33:56 UTC |
306c60d | Peter Hawkins | 13 September 2023, 21:09:28 UTC | Remove references to deprecated "tpu_se" build configuration. PiperOrigin-RevId: 565156675 | 13 September 2023, 21:10:30 UTC |
4e6c1b6 | Jake VanderPlas | 13 September 2023, 18:37:49 UTC | Deprecate random.KeyArray and random.PRNGKeyArray | 13 September 2023, 21:05:42 UTC |
270cc60 | Jake VanderPlas | 13 September 2023, 18:37:43 UTC | Update internal callers to avoid PRNGKeyArray | 13 September 2023, 21:05:42 UTC |
729752b | Peter Hawkins | 13 September 2023, 20:44:21 UTC | Disable XLA detailed logging and dumping for small computations. This significantly reduces the amount of logging from XLA on TPU. PiperOrigin-RevId: 565148809 | 13 September 2023, 20:45:00 UTC |
eeb32a7 | Jake VanderPlas | 13 September 2023, 20:21:02 UTC | Finish deprecation cycle for abstract_arrays.ShapedArray & abstract_arrays.raise_to_shaped PiperOrigin-RevId: 565142019 | 13 September 2023, 20:21:46 UTC |
22ff7bd | Jake VanderPlas | 13 September 2023, 19:07:01 UTC | Finish the deprecation cycle for jnp.alltrue, jnp.sometrue, jnp.product, jnp.cumproduct. These have been deprecated in JAX following similar deprecations in numpy v1.25.0 PiperOrigin-RevId: 565122288 | 13 September 2023, 19:07:36 UTC |
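For code affected by 22ff7bd, the replacements follow the NumPy naming: `all`, `any`, `prod`, and `cumprod`. A minimal sketch with illustrative inputs:

```python
import jax.numpy as jnp

flags = jnp.array([True, False, True])
values = jnp.array([1, 2, 3, 4])

all_true = jnp.all(flags)              # previously jnp.alltrue
any_true = jnp.any(flags)              # previously jnp.sometrue
total_product = jnp.prod(values)       # previously jnp.product
running_product = jnp.cumprod(values)  # previously jnp.cumproduct
```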
8340149 | Yash Katariya | 13 September 2023, 17:49:33 UTC | Check if the input which is donated is actually deleted along with the AOT check. PiperOrigin-RevId: 565098239 | 13 September 2023, 17:50:16 UTC |
6abefa1 | Roy Frostig | 13 September 2023, 16:43:14 UTC | fast dispatch for functions over typed PRNG key arrays. Before this change, JAX could dispatch compiled functions over new-style (typed) RNG key arrays, but it would always do so off of the fast (C++-based) dispatch path. In other words, switching from old-style `uint32` RNG keys to new-style keys would regress dispatch times. With this change, dispatch happens on the fast path again and performance regressions ought to be minimal. We currently maintain only one pytree registry, for all registered pytree node types. We want RNG key arrays to also be treated as pytree leaves everywhere *except* during dispatch. In other words: we want operations on (typed) RNG key arrays to appear in Jaxpr, but we want to unravel those arrays into their underlying `uint32` arrays only during dispatch. To do this, we add a new internal pytree registry that dispatch respects uniquely. This registry includes all items in the default registry, but also the RNG key array type. Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 565077758 | 13 September 2023, 16:43:58 UTC |
8495128 | jax authors | 13 September 2023, 16:03:13 UTC | Merge pull request #17584 from cabbagepatchman:cabbagepatchman-fix-sp PiperOrigin-RevId: 565066492 | 13 September 2023, 16:03:13 UTC |
5a15ba9 | jax authors | 13 September 2023, 09:04:49 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/551a620dafee38b6015028244a3b906d729e5a44. PiperOrigin-RevId: 564978874 | 13 September 2023, 09:10:36 UTC |
4ded017 | Adam Paszke | 13 September 2023, 09:00:03 UTC | [Mosaic] Fix an incorrectly implemented heuristic for selecting tile sizes. It did not properly account for the minimum tile size requirements on TPUv2 and v3. PiperOrigin-RevId: 564977666 | 13 September 2023, 09:00:40 UTC |
d7a6eed | Peter Hawkins | 13 September 2023, 05:07:50 UTC | Update jax2tf/tests/sharding_test.py for TPU runtime changes. Removes support for an older runtime (StreamExecutor) on TPU. PiperOrigin-RevId: 564927177 | 13 September 2023, 05:08:26 UTC |
c41d271 | Yash Katariya | 13 September 2023, 03:49:25 UTC | Add memories support to remat. This PR adds basic support to remat to allow transferring intermediates (activations) to destination memory in the forward pass. Currently JAX only supports the host memory kind, but the API allows transferring to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass. Introduce two singletons called `Recompute` and `Saveable`, and a NamedTuple (`Offloadable`), that each policy can return. Currently policies return a bool which, if True, means saveable, else recompute on the backward pass. This is a backwards compatible change, i.e. policies can still return a bool. A very basic offloadable policy can look like this: ``` def policy(prim, *avals, **params): return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host') ``` Co-authored-by: Matthew Johnson <mattjj@google.com> PiperOrigin-RevId: 564914301 | 13 September 2023, 03:50:05 UTC |
125343b | cabbagepatchman | 13 September 2023, 03:39:06 UTC | Update README.md | 13 September 2023, 03:39:06 UTC |
7c0abb1 | jax authors | 13 September 2023, 01:21:31 UTC | Merge pull request #17564 from andportnoy:aportnoy/scipy-spatial-test-increase-tolerance PiperOrigin-RevId: 564891104 | 13 September 2023, 01:21:31 UTC |
d4b564a | Jevin Jiang | 13 September 2023, 00:33:47 UTC | [Mosaic] Support relayout from (1,128) to (8,128) when dst.offset is (0, 0). PiperOrigin-RevId: 564882618 | 13 September 2023, 00:35:09 UTC |
d3950b9 | jax authors | 12 September 2023, 23:46:39 UTC | Merge pull request #17568 from zhenying-liu:testdotgeneral PiperOrigin-RevId: 564871912 | 12 September 2023, 23:46:39 UTC |
90da4f1 | Jane Liu | 12 September 2023, 21:30:40 UTC | Fix an A100 nightly unit test failure on testDotGeneral() by replacing TF32 with float32 | 12 September 2023, 23:31:22 UTC |
56791eb | Jake VanderPlas | 12 September 2023, 22:59:00 UTC | lax_test: adjust TPU tolerance for igamma & friends PiperOrigin-RevId: 564859109 | 12 September 2023, 22:59:41 UTC |
a26125c | Benjamin Kramer | 12 September 2023, 21:59:08 UTC | Integrate LLVM at llvm/llvm-project@c1796be93fe5 Updates LLVM usage to match [c1796be93fe5](https://github.com/llvm/llvm-project/commit/c1796be93fe5) PiperOrigin-RevId: 564842806 | 12 September 2023, 22:00:09 UTC |
801cbef | Jevin Jiang | 12 September 2023, 21:18:53 UTC | [Mosaic] Use strided load to load one entire row more efficiently. PiperOrigin-RevId: 564831610 | 12 September 2023, 21:19:35 UTC |
c617bcb | jax authors | 12 September 2023, 21:09:02 UTC | Merge pull request #17566 from jakevdp:prngarray-type PiperOrigin-RevId: 564828726 | 12 September 2023, 21:09:02 UTC |
9d86421 | Jevin Jiang | 12 September 2023, 20:46:53 UTC | [Mosaic] Use strided store to store one row. PiperOrigin-RevId: 564821813 | 12 September 2023, 20:56:58 UTC |
ea5f126 | Jake VanderPlas | 12 September 2023, 20:46:22 UTC | [custom prng] make PRNGKeyArray a subclass of jax.Array | 12 September 2023, 20:48:12 UTC |
b20b93e | jax authors | 12 September 2023, 20:46:16 UTC | Merge pull request #17565 from jakevdp:fix-array-type PiperOrigin-RevId: 564821273 | 12 September 2023, 20:46:16 UTC |
d44b038 | Jake VanderPlas | 12 September 2023, 20:21:48 UTC | [typing] fix a few array type declarations | 12 September 2023, 20:21:48 UTC |
34ea2b2 | Andrey Portnoy | 12 September 2023, 20:12:14 UTC | Increase comparison tolerance in SciPy spatial RotationMean subtest. The previous value leads to failures on A100 runners in github.com/NVIDIA/JAX-Toolbox CI: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/6144692887/job/16670611913#step:8:1014 The suspected reason is the use of TF32 math for matmuls: decorating the function with @jax.default_matmul_precision("float32") allows the test to pass. We thought it better to loosen the tolerance but preserve the original execution mode. The fully qualified test case is tests/scipy_spatial_test.py::LaxBackedScipySpatialTransformTests::testRotationMean0 | 12 September 2023, 20:12:14 UTC |
d0df18a | jax authors | 12 September 2023, 19:57:29 UTC | Merge pull request #17562 from jakevdp:version-import PiperOrigin-RevId: 564807097 | 12 September 2023, 19:57:29 UTC |
1800015 | Jake VanderPlas | 12 September 2023, 19:27:20 UTC | Import jax.version first | 12 September 2023, 19:27:20 UTC |
63e51fe | jax authors | 12 September 2023, 19:01:21 UTC | Merge pull request #17524 from jakevdp:pref-eltype-test PiperOrigin-RevId: 564792174 | 12 September 2023, 19:01:21 UTC |
dbb0e8f | Adam Paszke | 12 September 2023, 15:04:49 UTC | [Mosaic] Add a pass for instantiating memory spaces PiperOrigin-RevId: 564723473 | 12 September 2023, 15:05:26 UTC |
7dddb50 | Peter Hawkins | 12 September 2023, 14:14:26 UTC | [XLA:Python] Remove the use_tfrt flag from make_cpu_client(). use_tfrt=True has been the default for over a year, and the flag currently does nothing. PiperOrigin-RevId: 564712316 | 12 September 2023, 14:15:04 UTC |
462c0bd | jax authors | 12 September 2023, 09:58:47 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/cd7379d9af2d1a6f0b18edf47720ec9198b39fe3. PiperOrigin-RevId: 564660543 | 12 September 2023, 09:59:25 UTC |
2a7b8e6 | Yash Katariya | 12 September 2023, 01:40:17 UTC | Add `gpu_common_utils` to build_wheel to fix the gpu wheels build PiperOrigin-RevId: 564562958 | 12 September 2023, 01:40:55 UTC |
76a5dc3 | Yash Katariya | 12 September 2023, 00:41:18 UTC | Move memories_test.py to JAX PiperOrigin-RevId: 564551723 | 12 September 2023, 00:41:55 UTC |
3e06dc8 | Ruoxin Sang | 12 September 2023, 00:07:51 UTC | Update `jax_spmd_mode` flag docstring and remove unused `allow_pjit` option. PiperOrigin-RevId: 564543943 | 12 September 2023, 00:08:35 UTC |
d4adf00 | Qiao Zhang | 11 September 2023, 23:35:00 UTC | Add default jvp and transpose rule for jax.lax.reduce_precision. PiperOrigin-RevId: 564536160 | 11 September 2023, 23:35:44 UTC |
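With the rules from d4adf00 in place, `jax.lax.reduce_precision` can appear inside a function that is differentiated. The sketch below assumes the derivative behaves like the identity up to rounding; the function names and the input values (chosen to be exactly representable at the reduced precision) are illustrative.

```python
import jax
import jax.numpy as jnp

def round_to_bfloat16_precision(x):
    # 8 exponent bits and 7 mantissa bits match the bfloat16 format.
    return jax.lax.reduce_precision(x, exponent_bits=8, mantissa_bits=7)

def loss(x):
    # A toy loss that passes its input through the precision reduction.
    return jnp.sum(round_to_bfloat16_precision(x) ** 2)

x = jnp.array([1.0, 2.0, 0.5], dtype=jnp.float32)
grads = jax.grad(loss)(x)
```

For these inputs the gradient comes out as `2 * x`, since none of the values lose bits when rounded.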
997b35e | John QiangZhang | 11 September 2023, 23:13:09 UTC | Improve the gpu lowering error message if users forget to link the gpu library. PiperOrigin-RevId: 564530960 | 11 September 2023, 23:14:18 UTC |
6c3b42d | jax authors | 11 September 2023, 19:52:47 UTC | Add flags to exclude from cache-key generation. Some flags do not affect the compilation output. These should not be part of the cache key, otherwise changing them will change the key causing an unnecessary cache miss. Synchronize the exclusions between the command-line flags and DebugOptions. Add if-this-then-that lint checks to keep them in sync. PiperOrigin-RevId: 564474189 | 11 September 2023, 19:57:45 UTC |
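The reasoning in 6c3b42d can be sketched in plain Python: flags that do not change the compiled output are skipped when hashing, so toggling them does not cause a cache miss. The flag names and the `cache_key` helper are hypothetical and do not reflect JAX's actual cache-key code.

```python
import hashlib

# Hypothetical flags that affect only logging or dumping, not the compiled output.
FLAGS_EXCLUDED_FROM_KEY = frozenset({"dump_to", "log_level"})

def cache_key(module_text, flags):
    """Hash the module text plus every flag that can affect compilation."""
    hasher = hashlib.sha256()
    hasher.update(module_text.encode("utf-8"))
    # Sort so the key does not depend on dictionary insertion order.
    for name in sorted(flags):
        if name in FLAGS_EXCLUDED_FROM_KEY:
            continue  # irrelevant to the output, so leave it out of the key
        hasher.update(f"{name}={flags[name]!r};".encode("utf-8"))
    return hasher.hexdigest()
```

Keeping the exclusion set in one place mirrors the commit's concern that the command-line flag list and the `DebugOptions` list stay in sync.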
05d2432 | jax authors | 11 September 2023, 19:48:10 UTC | Merge pull request #17527 from ROCmSoftwarePlatform:rocm_build_updates_1 PiperOrigin-RevId: 564472739 | 11 September 2023, 19:48:10 UTC |
23e4f0b | jax authors | 11 September 2023, 19:07:48 UTC | Hash serialized topology description for new cache key generation. The original cache key generation hashes devices and backend. This is not future-proof: it does not work for accelerators other than TPUs. Change this to use the serialized version of PjRtTopologyDescription, which is supported for all accelerators. Note: CPU and PjRt C API are not supported yet; Stream Executor will not be supported. Testing: revised unit test. PiperOrigin-RevId: 564461564 | 11 September 2023, 19:08:26 UTC |
a36598b | Yash Katariya | 11 September 2023, 18:54:29 UTC | Set the jax_enable_memories flag to True. If all memory_kinds in the jaxpr are the default memory kind, then annotate_device_placement custom calls are not inserted. This allows for existing code to work without any changes. If non-default memory kind is present in the jaxpr, then we allow custom calls to be inserted. PiperOrigin-RevId: 564457393 | 11 September 2023, 18:55:09 UTC |
bfc12bd | jax authors | 11 September 2023, 10:35:42 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/96beae40a196d9d990695d430ca68b3015bda9bb. PiperOrigin-RevId: 564327668 | 11 September 2023, 10:36:30 UTC |
292deef | jax authors | 09 September 2023, 11:00:28 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/c227585959ec96a4527b1b9f9023f8d6bbe976b3. PiperOrigin-RevId: 563982388 | 09 September 2023, 11:01:14 UTC |
4091ac6 | Rahul Batra | 02 August 2023, 21:54:41 UTC | [ROCm]: Fix duplicate deps include | 08 September 2023, 22:56:59 UTC |
ef79c19 | Rahul Batra | 28 July 2023, 20:33:04 UTC | [ROCm]: Dockerfile and build script updates. Add hipblaslt in Dockerfile. Update Dockerfile to default to ROCm 5.6. Update CI scripts to handle multiple ROCm versions. | 08 September 2023, 22:56:59 UTC |
9289f32 | Jake VanderPlas | 08 September 2023, 20:07:37 UTC | Add missing preferred_element_type tests. Followup to https://github.com/google/jax/pull/17506 | 08 September 2023, 20:07:37 UTC |
1c99dd5 | jax authors | 08 September 2023, 20:07:36 UTC | Merge pull request #17502 from hawkinsp:fb PiperOrigin-RevId: 563832070 | 08 September 2023, 20:07:36 UTC |
dd05048 | jax authors | 08 September 2023, 18:58:50 UTC | Merge pull request #17521 from jakevdp:more-version-tests PiperOrigin-RevId: 563813319 | 08 September 2023, 18:58:50 UTC |
34ba4f5 | jax authors | 08 September 2023, 18:48:40 UTC | Merge pull request #17519 from jakevdp:git-stderr PiperOrigin-RevId: 563811588 | 08 September 2023, 18:48:40 UTC |
bc91f2d | Jake VanderPlas | 08 September 2023, 18:33:49 UTC | Add more extensive tests for version strings | 08 September 2023, 18:33:49 UTC |
bc1ca0b | jax authors | 08 September 2023, 17:44:23 UTC | Merge pull request #17515 from jakevdp:fix-can-cast PiperOrigin-RevId: 563793237 | 08 September 2023, 17:44:23 UTC |
db6104a | Jake VanderPlas | 08 September 2023, 17:15:55 UTC | Build: catch stderr when attempting git commands | 08 September 2023, 17:15:55 UTC |
4adbc1f | jax authors | 08 September 2023, 17:12:09 UTC | Merge pull request #17511 from hawkinsp:dlpack PiperOrigin-RevId: 563783931 | 08 September 2023, 17:12:09 UTC |
c351523 | Peter Hawkins | 08 September 2023, 16:57:46 UTC | Add -lm to Linux flatbuffer builds. | 08 September 2023, 16:57:46 UTC |
09a5970 | Tomás Longeri | 08 September 2023, 16:46:46 UTC | [Mosaic] Pass hardware generation to C++ apply vector layout pass PiperOrigin-RevId: 563777135 | 08 September 2023, 16:47:19 UTC |
3ce5cb6 | Peter Hawkins | 08 September 2023, 16:17:53 UTC | Remove use of get_default_device_assignment(). This is the only caller of this API in JAX, and it can be simplified. Change in preparation for removing get_default_device_assignment() from the Python bindings. PiperOrigin-RevId: 563770199 | 08 September 2023, 16:18:32 UTC |
9e9ea5a | Jake VanderPlas | 08 September 2023, 16:04:05 UTC | jax.dtypes: fix dtypes.safe_to_cast() when output dtype is weak | 08 September 2023, 16:04:05 UTC |
832820e | Adam Paszke | 08 September 2023, 15:07:30 UTC | CI changes PiperOrigin-RevId: 563754696 | 08 September 2023, 15:08:05 UTC |
592eb44 | Adam Paszke | 08 September 2023, 14:49:40 UTC | Fix Pallas tests that broke after recent changes PiperOrigin-RevId: 563750775 | 08 September 2023, 14:50:20 UTC |
3a4b60b | Peter Hawkins | 08 September 2023, 13:18:38 UTC | Fix dlpack type signatures to match Array API spec. Fixes https://github.com/google/jax/issues/17510 | 08 September 2023, 14:12:32 UTC |
601d67a | jax authors | 08 September 2023, 11:13:56 UTC | Update XLA dependency to use revision http://github.com/openxla/xla/commit/cb2f349863479be68e73052b0e03500c0470f40d. PiperOrigin-RevId: 563710408 | 08 September 2023, 11:14:45 UTC |
366a16f | Sharad Vikram | 08 September 2023, 04:33:12 UTC | [Pallas] Add support for allocating TPU semaphores and basic semaphore ops PiperOrigin-RevId: 563632709 | 08 September 2023, 04:33:53 UTC |
b8eccb1 | Yash Katariya | 08 September 2023, 02:43:30 UTC | Remove the date check from jaxlib and jax version checks since it causes problems when jaxlib runs ahead of jax in CI (depending on timezones). PiperOrigin-RevId: 563614108 | 08 September 2023, 02:52:35 UTC |
47bf243 | jax authors | 08 September 2023, 02:43:01 UTC | Merge pull request #17506 from jakevdp:pref-el-type PiperOrigin-RevId: 563613949 | 08 September 2023, 02:43:01 UTC |
cb114f2 | Sharad Vikram | 08 September 2023, 00:08:18 UTC | [Pallas] Refactor memory space handling PiperOrigin-RevId: 563586933 | 08 September 2023, 00:08:57 UTC |
d0c4c9b | Sharad Vikram | 07 September 2023, 23:40:27 UTC | [Pallas] Add support for scoped allocations to Pallas TPU PiperOrigin-RevId: 563580548 | 07 September 2023, 23:41:01 UTC |
311dc9c | jax authors | 07 September 2023, 23:23:03 UTC | Add truncated normal initializer to jax.nn PiperOrigin-RevId: 563576354 | 07 September 2023, 23:23:42 UTC |
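A minimal usage sketch for the initializer added in 311dc9c, assuming it is exposed as `jax.nn.initializers.truncated_normal` and truncates at two standard deviations by default. The shape and `stddev` value are illustrative.

```python
import jax
import jax.numpy as jnp

stddev = 0.02
# Build the initializer, then call it with a key, a shape, and a dtype.
init = jax.nn.initializers.truncated_normal(stddev=stddev)
key = jax.random.PRNGKey(0)
weights = init(key, (256, 128), jnp.float32)
```

Under the stated assumption, every entry of `weights` lies within `2 * stddev` of zero.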
bfd79b8 | jax authors | 07 September 2023, 23:00:01 UTC | Support relayout of tiles in register when the layout tiling changes. PiperOrigin-RevId: 563570338 | 07 September 2023, 23:00:44 UTC |
2451f34 | Jake VanderPlas | 07 September 2023, 22:16:22 UTC | jax.numpy: add preferred_element_type argument to matmul functions | 07 September 2023, 22:16:22 UTC |
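A short sketch of the argument added in 2451f34, using `int8` inputs with an `int32` result type. The operand shapes and values are illustrative.

```python
import jax.numpy as jnp

lhs = jnp.ones((2, 3), dtype=jnp.int8)
rhs = jnp.ones((3, 2), dtype=jnp.int8)

# Without the argument, the result keeps the promoted input dtype (int8 here).
narrow = jnp.matmul(lhs, rhs)

# With the argument, the result is produced in the requested wider dtype.
wide = jnp.matmul(lhs, rhs, preferred_element_type=jnp.int32)
```

Requesting a wider result type is mainly useful when products of narrow inputs could overflow or lose precision in the input dtype.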
a6eed40 | Tomás Longeri | 07 September 2023, 22:07:50 UTC | [MOSAIC] apply_vector_layout C++ rewrite (3) applyLayoutOp and relayout PiperOrigin-RevId: 563556815 | 07 September 2023, 22:08:30 UTC |
52ee404 | jax authors | 07 September 2023, 21:47:37 UTC | Merge pull request #17505 from hawkinsp:windows PiperOrigin-RevId: 563550978 | 07 September 2023, 21:47:37 UTC |
bda9292 | Parker Schuh | 07 September 2023, 21:22:41 UTC | Propagate ad.Zeros to the scan body function for jax.lax.scan for the output 'ys'. Example of what this fixes:
```
import jax
import numpy as np

def grad_fn(x):
  def scan_body(x, params):
    # The carry is passed through unchanged; the per-step output is its sum.
    return x, x.sum()
  pred, state = jax.lax.scan(scan_body, x, None, length=2)
  return pred.sum(), state

x = np.zeros((5, 10), dtype=np.float32)
loss_grad_fn = jax.value_and_grad(grad_fn, has_aux=True)
print(jax.make_jaxpr(loss_grad_fn)(x))
```
PiperOrigin-RevId: 563544684 | 07 September 2023, 21:36:53 UTC |
de5aee8 | Peter Hawkins | 07 September 2023, 21:36:04 UTC | Fix Bazel build failures from long command lines on Windows. Fixes https://github.com/google/jax/issues/14950 | 07 September 2023, 21:36:04 UTC |
dbf1325 | Peter Hawkins | 07 September 2023, 21:19:46 UTC | Copybara import of the project: -- 3905d6123bdc22f505934242363fda426c99c4cf by Peter Hawkins <phawkins@google.com>: Update flatbuffers. Use upstream flatbuffer bazel scripts, with a couple of small patches to fix: * https://github.com/google/flatbuffers/issues/8087 (remove npm references) * https://github.com/google/flatbuffers/pull/8088 (fix flatc build failure due to main() removal by linker) COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/17502 from hawkinsp:fb 3905d6123bdc22f505934242363fda426c99c4cf PiperOrigin-RevId: 563543954 | 07 September 2023, 21:27:25 UTC |
9566112 | jax authors | 07 September 2023, 21:17:03 UTC | Merge pull request #17501 from jakevdp:safe-to-cast PiperOrigin-RevId: 563543026 | 07 September 2023, 21:17:03 UTC |
8b700fa | Jevin Jiang | 07 September 2023, 20:49:05 UTC | [Mosaic] Support relayout from (1, 128) to (8, 128). PiperOrigin-RevId: 563534657 | 07 September 2023, 20:49:44 UTC |
8412781 | Jake VanderPlas | 07 September 2023, 19:18:14 UTC | Internal: add dtypes.safe_to_cast utility & use to generate indexing warning | 07 September 2023, 19:18:14 UTC |
a9410e5 | jax authors | 07 September 2023, 17:25:45 UTC | Merge pull request #17489 from patrick-kidger:patch-5 PiperOrigin-RevId: 563472425 | 07 September 2023, 17:25:45 UTC |
4e4f8b1 | jax authors | 07 September 2023, 17:25:22 UTC | Merge pull request #17491 from hawkinsp:relnotes PiperOrigin-RevId: 563472424 | 07 September 2023, 17:25:22 UTC |
f64235a | jax authors | 07 September 2023, 17:06:32 UTC | Merge pull request #17453 from jakevdp:fix-version-string PiperOrigin-RevId: 563466394 | 07 September 2023, 17:06:32 UTC |
6f3f0d5 | Jake VanderPlas | 07 September 2023, 15:45:48 UTC | build: write appropriate version strings to build artifacts | 07 September 2023, 15:45:48 UTC |
9b447aa | Peter Hawkins | 07 September 2023, 15:36:51 UTC | Relax test tolerance to fix BCSR sparse matmul test failure on P100 GPU. PiperOrigin-RevId: 563441383 | 07 September 2023, 15:37:31 UTC |
429422d | Peter Hawkins | 07 September 2023, 14:34:53 UTC | Reverts 5fcd9265b1e20c41d684659af3d52c41f25ae2f3 PiperOrigin-RevId: 563426073 | 07 September 2023, 14:35:44 UTC |
408c657 | Peter Hawkins | 07 September 2023, 13:35:25 UTC | Add a release note about a fixed Windows crash. | 07 September 2023, 13:35:25 UTC |