https://github.com/google/jax
Revision 231495166929be4a6ee3a0fd843858abeeca3694 authored by Yash Katariya on 07 July 2022, 17:41:27 UTC, committed by jax authors on 07 July 2022, 17:41:52 UTC
* All `in_axis_resources` and `out_axis_resources` are now instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.

* `out_shardings` are still instances of `MeshPspecSharding` even when `Array`s are used. In a follow-up CL, I will change `out_axis_resources` to accept `Sharding` instances.
  * This is also why you still need a mesh context manager when `config.jax_array` is enabled.
  * cl/458267790 is a WIP for this. It also adds a couple of checks to `MeshPspecSharding` when `AUTO` is used.

* Checking a sharding against an `aval` goes through a handler system that deals with the different sharding instances.
  * I created a `pjit`-specific system rather than putting this check on the sharding instances because each transformation checks the sharding differently. The best example is `pjit` versus `xmap`: they check whether an aval is properly sharded with respect to a given sharding in different ways, because they express sharding differently.

* `MeshPspecSharding` and `SingleDeviceSharding` implement `__hash__` and `__eq__`. This means we no longer have to pass around canonicalized pspecs in the new path to get cache hits; the `Sharding` instances handle that for us.

* `_pjit_lower` still depends on the mesh, which is the main reason I haven't removed `resource_env` from `params`. In the interest of keeping this CL small (LOL), I'll make those changes in a follow-up CL.
  * The private functions in `pxla.py` are also used by pathways and automap, so I'll have to modify those too.
  * It also relies on `pxla.resource_typecheck`, which I haven't yet figured out how to move to the sharding interface.

* `_to_xla_op_sharding` takes `axis_ctx` as an extra **optional** parameter, which is required for `with_sharding_constraint`.
  * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998

* `pjit`'s batching handlers add an extra dimension to the `axis_resources`. How that dimension is added depends on each transformation, and how it is handled differs across sharding instances, so I added a handler system for this too. Again, `xmap` and `pjit` differ a lot here, which is why I went with the handler approach.
  * `MeshPspecSharding` handles this by calling `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments where this is done.
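The per-type handler system for checking a sharding against an `aval` can be sketched roughly as follows. This is a minimal illustration, not the code from this CL: `ToyShapedArray`, `ToyMeshPspecSharding`, `ToySingleDeviceSharding`, `register_pjit_checker`, and `check_aval_sharding` are hypothetical names, and the divisibility rule is a simplified assumption about what a mesh/pspec check verifies. The point is that `pjit` keeps its own registry keyed by sharding type, so another transformation such as `xmap` could register different rules for the same sharding classes.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple

# Hypothetical, simplified stand-ins; these are NOT the real JAX classes.
@dataclass(frozen=True)
class ToyShapedArray:
    shape: Tuple[int, ...]

@dataclass(frozen=True)
class ToyMeshPspecSharding:
    mesh_axis_sizes: Tuple[Tuple[str, int], ...]  # e.g. (("x", 2), ("y", 4))
    spec: Tuple[Optional[str], ...]               # mesh axis (or None) per array dim

@dataclass(frozen=True)
class ToySingleDeviceSharding:
    device_id: int

# pjit-specific registry: sharding type -> checker. Another transformation
# (e.g. xmap) would keep its own registry with its own rules.
_pjit_aval_checkers: Dict[type, Callable[[Any, ToyShapedArray], None]] = {}

def register_pjit_checker(sharding_type: type):
    def wrap(fn):
        _pjit_aval_checkers[sharding_type] = fn
        return fn
    return wrap

@register_pjit_checker(ToyMeshPspecSharding)
def _check_mesh_pspec(sharding: ToyMeshPspecSharding, aval: ToyShapedArray) -> None:
    sizes = dict(sharding.mesh_axis_sizes)
    if len(sharding.spec) > len(aval.shape):
        raise ValueError(f"spec {sharding.spec} is longer than aval rank {len(aval.shape)}")
    for dim, axis in enumerate(sharding.spec):
        # Assumed rule: a sharded dimension must divide evenly over its mesh axis.
        if axis is not None and aval.shape[dim] % sizes[axis] != 0:
            raise ValueError(
                f"dim {dim} (size {aval.shape[dim]}) is not divisible by "
                f"mesh axis {axis!r} (size {sizes[axis]})")

@register_pjit_checker(ToySingleDeviceSharding)
def _check_single_device(sharding: ToySingleDeviceSharding, aval: ToyShapedArray) -> None:
    pass  # Any aval fits on a single device, so there is nothing to check.

def check_aval_sharding(sharding: Any, aval: ToyShapedArray) -> None:
    """Dispatch to the pjit checker registered for this sharding's type."""
    checker = _pjit_aval_checkers.get(type(sharding))
    if checker is None:
        raise TypeError(f"no pjit checker registered for {type(sharding).__name__}")
    checker(sharding, aval)

# Usage: passes, since dim 0 (size 8) divides evenly over mesh axis "x" (size 2).
check_aval_sharding(
    ToyMeshPspecSharding((("x", 2), ("y", 4)), ("x", None)), ToyShapedArray((8, 3)))
```

Keeping the checks outside the sharding classes means the classes stay transformation-agnostic; the trade-off is that every new sharding type needs a registration per transformation.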
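The batching handler that inserts a new axis into a mesh-based sharding's parsed partition spec might look roughly like the sketch below. It is a simplified assumption, not the CL's implementation: `ToyParsedPartitionSpec`, `ToyMeshSharding`, and `batch_sharding` are hypothetical, and each spec entry is modeled as the tuple of mesh axes a dimension is split over, with `()` meaning unsharded. The handler is looked up by sharding type, mirroring the idea that each sharding class decides how the extra batch dimension is represented.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple

# Each entry lists the mesh axes one array dimension is split over; () = unsharded.
PartitionEntry = Tuple[str, ...]

@dataclass(frozen=True)
class ToyParsedPartitionSpec:
    """Hypothetical, simplified stand-in for a parsed partition spec."""
    partitions: Tuple[PartitionEntry, ...]

    def insert_axis_partitions(self, dim: int, val: PartitionEntry) -> "ToyParsedPartitionSpec":
        parts = list(self.partitions)
        # Dimensions past the end of the spec are implicitly unsharded, so pad
        # with () until the new batch dimension can be inserted at `dim`.
        if dim > len(parts):
            parts.extend([()] * (dim - len(parts)))
        parts.insert(dim, tuple(val))
        return ToyParsedPartitionSpec(tuple(parts))

@dataclass(frozen=True)
class ToyMeshSharding:
    mesh_axes: Tuple[str, ...]
    parsed_pspec: ToyParsedPartitionSpec

# pjit-specific batching handlers, keyed by sharding type.
_pjit_batching_handlers: Dict[type, Callable] = {}

def _batch_toy_mesh_sharding(s: ToyMeshSharding, dim: int, val: PartitionEntry) -> ToyMeshSharding:
    return ToyMeshSharding(s.mesh_axes, s.parsed_pspec.insert_axis_partitions(dim, val))

_pjit_batching_handlers[ToyMeshSharding] = _batch_toy_mesh_sharding

def batch_sharding(s, dim: int, val: PartitionEntry = ()):
    """Return the sharding for the batched value, with a new axis at `dim`."""
    return _pjit_batching_handlers[type(s)](s, dim, val)

s = ToyMeshSharding(("x", "y"), ToyParsedPartitionSpec((("x",), ("y",))))
print(batch_sharding(s, 0).parsed_pspec.partitions)  # ((), ('x',), ('y',))
```

By default the sketch leaves the new batch dimension unsharded (`()`); passing a mesh axis as `val` would shard it instead.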

PiperOrigin-RevId: 459548974
1 parent 88c1e7d
Tip revision: 231495166929be4a6ee3a0fd843858abeeca3694 authored by Yash Katariya on 07 July 2022, 17:41:27 UTC
Convert everything in pjit to the `Sharding` interface. The following lists the changes in this CL:
| File | Mode | Size |
| --- | --- | --- |
| .github | | |
| benchmarks | | |
| build | | |
| cloud_tpu_colabs | | |
| docs | | |
| examples | | |
| images | | |
| jax | | |
| jaxlib | | |
| tests | | |
| third_party | | |
| .bazelrc | -rw-r--r-- | 14.0 KB |
| .bazelversion | -rw-r--r-- | 6 bytes |
| .gitignore | -rw-r--r-- | 291 bytes |
| .pre-commit-config.yaml | -rw-r--r-- | 563 bytes |
| .readthedocs.yml | -rw-r--r-- | 569 bytes |
| CHANGELOG.md | -rw-r--r-- | 62.9 KB |
| CITATION.bib | -rw-r--r-- | 408 bytes |
| CONTRIBUTING.md | -rw-r--r-- | 150 bytes |
| LICENSE | -rw-r--r-- | 11.1 KB |
| README.md | -rw-r--r-- | 23.3 KB |
| WORKSPACE | -rw-r--r-- | 1.7 KB |
| conftest.py | -rw-r--r-- | 856 bytes |
| mypy.ini | -rw-r--r-- | 618 bytes |
| pylintrc | -rw-r--r-- | 1.6 KB |
| pytest.ini | -rw-r--r-- | 1.4 KB |
| setup.cfg | -rw-r--r-- | 486 bytes |
| setup.py | -rw-r--r-- | 3.4 KB |
