https://github.com/google/jax
Revision 6f79093cffdc5a14ab1cedaeef269482908b1897 authored by Jackson Stokes on 10 May 2024, 00:33:20 UTC, committed by jax authors on 10 May 2024, 00:34:21 UTC
Previously, output streaming took a top-down approach that indiscriminately checked whether a MoveToHost (MTH) custom call would trace down to an output marked with the host memory space. This did not work when a dynamic-update-slice existed between the MTH call and the output. This CL fixes the problem by handling output streaming before other MTH calls, and improves efficiency with a bottom-up approach, so that only a single path in the graph is traced.

PiperOrigin-RevId: 632318740
1 parent a9460f2
[XLA:TPU] Support output streaming and refactor TryOutputStreaming into a bottoms-up approach.
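The commit message above contrasts a top-down search with a bottom-up one. The following Python sketch illustrates the bottom-up idea on a toy instruction graph: start at an output placed in host memory and walk upward along a single operand chain until a MoveToHost custom call is found. Everything in it is hypothetical: `Instr`, `find_move_to_host`, the `PASSTHROUGH` opcode set, and the rule of following a dynamic-update-slice's update operand are assumptions made for illustration, not XLA's actual `TryOutputStreaming` implementation.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Instr:
    """Toy stand-in for an HLO instruction (hypothetical, not XLA's API)."""
    name: str
    opcode: str
    operands: list[Instr] = field(default_factory=list)
    target: Optional[str] = None  # custom-call target, e.g. "MoveToHost"


# Opcodes assumed (for this sketch only) to pass a value through unchanged.
PASSTHROUGH = {"bitcast", "copy", "reshape"}


def find_move_to_host(output: Instr) -> Optional[Instr]:
    """Walk upward from a host-memory output along one operand chain.

    Returns the MoveToHost custom call feeding the output, or None if the
    chain hits an opcode this sketch does not know how to look through.
    """
    node = output
    while True:
        if node.opcode == "custom-call" and node.target == "MoveToHost":
            return node
        if node.opcode == "dynamic-update-slice":
            # Assumption: follow the update operand (index 1), i.e. the
            # slice being written into the host-resident buffer.
            node = node.operands[1]
        elif node.opcode in PASSTHROUGH:
            node = node.operands[0]
        else:
            return None  # chain broken: nothing to stream for this output


# Usage: output = bitcast(dynamic-update-slice(buf, MoveToHost(x), i)).
x = Instr("x", "parameter")
mth = Instr("mth", "custom-call", [x], target="MoveToHost")
buf = Instr("buf", "parameter")
idx = Instr("i", "parameter")
dus = Instr("dus", "dynamic-update-slice", [buf, mth, idx])
out = Instr("out", "bitcast", [dus])
assert find_move_to_host(out) is mth
```

Because the walk follows one operand per step, its cost in this sketch is bounded by the length of that single path, whereas a top-down search that starts from every MoveToHost call and explores its users may visit many branches that never reach a host output. This is an illustrative reading of the commit message, not a measurement of XLA's behavior.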
File                      Mode        Size
.github/
benchmarks/
build/
cloud_tpu_colabs/
docs/
examples/
images/
jax/
jax_plugins/
jaxlib/
tests/
third_party/
.bazelrc                  -rw-r--r--  18.4 KB
.bazelversion             -rw-r--r--  6 bytes
.editorconfig             -rw-r--r--  292 bytes
.gitignore                -rw-r--r--  355 bytes
.pre-commit-config.yaml   -rw-r--r--  1.1 KB
.readthedocs.yml          -rw-r--r--  570 bytes
AUTHORS                   -rw-r--r--  313 bytes
CHANGELOG.md              -rw-r--r--  124.6 KB
CITATION.bib              -rw-r--r--  408 bytes
CONTRIBUTING.md           -rw-r--r--  150 bytes
LICENSE                   -rw-r--r--  11.1 KB
README.md                 -rw-r--r--  20.0 KB
WORKSPACE                 -rw-r--r--  568 bytes
conftest.py               -rw-r--r--  2.5 KB
platform_mappings         -rw-r--r--  371 bytes
pyproject.toml            -rw-r--r--  5.9 KB
setup.py                  -rw-r--r--  7.2 KB
