https://github.com/google/jax
Revision 6f0737b46f24828fbede93477e2e83d7ed7f39d3 authored by Sharad Vikram on 22 March 2024, 23:58:45 UTC, committed by jax authors on 22 March 2024, 23:59:32 UTC
This change adds limited support for dynamically sized DMAs. Specifically, you can use `pl.ds(start, size)` where `size` is a variable in your kernel. The resulting slice can index a Ref view (via `.at[...]`), and that view can then be used in a DMA. For example:

```python
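from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem):
  # `size` is read from a Ref inside the kernel, so it is only known at run time.
  size = size_smem_ref[0]
  # Start a DMA over the first `size` entries of the leading axis, then wait on it.
  pltpu.async_copy(
    x_hbm_ref.at[pl.ds(0, size)],
    o_hbm_ref.at[pl.ds(0, size)], sem).wait()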
```
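As a rough mental model of what the kernel above is meant to do, the copy can be mimicked with plain Python lists. This is only a sketch: `reference_dynamic_copy` is a hypothetical helper, not part of Pallas, and it assumes that entries beyond `size` in the destination are left untouched, which the commit message does not state explicitly.

```python
def reference_dynamic_copy(x, out, size):
  """Pure-Python model: copy the first `size` entries of `x` into `out`.

  Hypothetical helper for illustration only. `x` and `out` keep their fixed
  lengths (the "static shape"); only the amount copied varies.
  """
  if not 0 <= size <= min(len(x), len(out)):
    raise ValueError(f"size={size} does not fit the source and destination")
  out[:size] = x[:size]  # Entries at index >= size are assumed to stay as-is.
  return out


x = [1, 2, 3, 4]
out = [0, 0, 0, 0]
reference_dynamic_copy(x, out, 2)  # out is now [1, 2, 0, 0]
```

A model like this could serve as the expected value when testing a real kernel, provided the untouched-tail assumption holds for the hardware copy.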

We're doing this because we want to add limited support for dynamic shapes to Pallas, something that is usually hard to do in XLA.
We're augmenting the `Slice` class to support dynamic sizes, which are then plumbed down into indexing primitives like `pl.load` and `pltpu.async_copy`.
However, values and Refs in Pallas kernels still have a static shape requirement, so we can't do dynamic loads/stores. As a result, we won't touch any abstract evaluation rules (since they will still all consume and return ShapedArrays). We can, however, do dynamically sized DMAs between statically shaped Refs. While this isn't arbitrary dynamic shapes, we hope this enables some new interesting kernels.
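The distinction between a static and a dynamic slice size can be sketched in plain Python. Everything below is hypothetical and for illustration only: `SliceSketch`, `RuntimeValue`, and `can_bounds_check_at_trace_time` are made-up names, not the actual Pallas `Slice` class or its helpers. The idea is that the slice record carries either a Python `int` or a run-time value, and only the former can be validated while tracing.

```python
import dataclasses
from typing import Any


class RuntimeValue:
  """Hypothetical stand-in for a scalar that is only known inside the kernel."""


@dataclasses.dataclass(frozen=True)
class SliceSketch:
  """Illustrative slice record; NOT the actual Pallas `Slice` class."""
  start: Any  # Python int, or a RuntimeValue
  size: Any   # Python int, or a RuntimeValue

  @property
  def has_dynamic_size(self) -> bool:
    return not isinstance(self.size, int)


def can_bounds_check_at_trace_time(slc: SliceSketch, dim: int):
  """Returns True/False when the check is possible while tracing, else None."""
  if slc.has_dynamic_size or not isinstance(slc.start, int):
    return None  # Only known at run time; the Ref's own shape stays static.
  return 0 <= slc.start and slc.start + slc.size <= dim


static = SliceSketch(0, 8)
dynamic = SliceSketch(0, RuntimeValue())
```

In this sketch the Ref dimension `dim` is always a plain integer, mirroring the point above that Refs remain statically shaped even when the slice size is not.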

PiperOrigin-RevId: 618322737
1 parent 6ffd55c
[Pallas TPU] Add support for dynamic sized (tile aligned) DMAs
File Mode Size
.github
benchmarks
build
cloud_tpu_colabs
docs
examples
images
jax
jax_plugins
jaxlib
tests
third_party
.bazelrc -rw-r--r-- 20.8 KB
.bazelversion -rw-r--r-- 6 bytes
.editorconfig -rw-r--r-- 292 bytes
.gitignore -rw-r--r-- 355 bytes
.pre-commit-config.yaml -rw-r--r-- 1.1 KB
.readthedocs.yml -rw-r--r-- 569 bytes
AUTHORS -rw-r--r-- 313 bytes
CHANGELOG.md -rw-r--r-- 118.0 KB
CITATION.bib -rw-r--r-- 408 bytes
CONTRIBUTING.md -rw-r--r-- 150 bytes
LICENSE -rw-r--r-- 11.1 KB
README.md -rw-r--r-- 19.8 KB
WORKSPACE -rw-r--r-- 568 bytes
conftest.py -rw-r--r-- 2.5 KB
pyproject.toml -rw-r--r-- 5.7 KB
setup.py -rw-r--r-- 7.8 KB

