https://github.com/google/jax
Revision 2347f61173e7b3fef2150685bac046209d8f0546 authored by Dan Foreman-Mackey on 06 August 2024, 10:53:46 UTC, committed by jax authors on 06 August 2024, 11:59:47 UTC
After switching to the FFI for this custom call, it is straightforward to support dynamic shapes in all axes (including the non-batch dimensions). This does come with one potential performance caveat: when the non-batch dimensions are statically known at lowering time and equal (`m == n`), we lower to the cuBLAS "batched" implementation, which is significantly faster. When `m` and/or `n` are dynamic, we always lower to the cuSolver version.
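
The lowering-time choice described above can be sketched in plain Python. This is only an illustration of the rule as stated in this message, not the code in this change: the function name `choose_lu_kernel`, the use of `None` to mark a dynamic dimension, and the returned labels are all hypothetical.

```python
from typing import Optional


def choose_lu_kernel(m: Optional[int], n: Optional[int]) -> str:
    """Pick a backend label for a batch of m-by-n matrices.

    Hypothetical helper: a dimension that is not known at lowering time
    (i.e. dynamic) is passed as None.
    """
    both_static = m is not None and n is not None
    if both_static and m == n:
        # Square matrices with static shapes: take the faster batched path.
        return "cublas_batched"
    # Non-square, or at least one dynamic dimension: use the general path.
    return "cusolver"


print(choose_lu_kernel(4, 4))     # static and square
print(choose_lu_kernel(4, 3))     # static but not square
print(choose_lu_kernel(None, 4))  # one dimension is dynamic
```

Note that a dynamic dimension falls through to the cuSolver path even if it happens to equal the other dimension at run time, which is the performance caveat mentioned above.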

It would actually be straightforward to defer this decision (cuBLAS vs. cuSolver) into the kernel, and I found that doing so didn't have any obvious performance implications, but I haven't figured out how to do the same thing for the QR decomposition because the two versions return different outputs. I'll revisit this detail in a follow-up.

With this change, the GPU and CPU lowering rules are now nearly identical. In another follow-up, I will consolidate those implementations.

PiperOrigin-RevId: 659894314
1 parent 23da11b
Tip revision: 2347f61173e7b3fef2150685bac046209d8f0546 authored by Dan Foreman-Mackey on 06 August 2024, 10:53:46 UTC
Add support for using dynamic shapes with GPU LU decomposition.
File Mode Size
.github
benchmarks
build
cloud_tpu_colabs
docs
examples
images
jax
jax_plugins
jaxlib
tests
third_party
.bazelrc -rw-r--r-- 17.0 KB
.bazelversion -rw-r--r-- 6 bytes
.editorconfig -rw-r--r-- 292 bytes
.gitignore -rw-r--r-- 355 bytes
.pre-commit-config.yaml -rw-r--r-- 1.1 KB
.readthedocs.yml -rw-r--r-- 570 bytes
AUTHORS -rw-r--r-- 313 bytes
CHANGELOG.md -rw-r--r-- 134.1 KB
CITATION.bib -rw-r--r-- 408 bytes
CONTRIBUTING.md -rw-r--r-- 150 bytes
LICENSE -rw-r--r-- 11.1 KB
README.md -rw-r--r-- 20.1 KB
WORKSPACE -rw-r--r-- 1.8 KB
conftest.py -rw-r--r-- 2.5 KB
platform_mappings -rw-r--r-- 371 bytes
pyproject.toml -rw-r--r-- 4.6 KB
setup.py -rw-r--r-- 4.1 KB
