https://github.com/google/jax
Revision 231495166929be4a6ee3a0fd843858abeeca3694 authored by Yash Katariya on 07 July 2022, 17:41:27 UTC, committed by jax authors on 07 July 2022, 17:41:52 UTC
* All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.
* `out_shardings` are still instances of `MeshPspecSharding` even if `Array`s are used. In a follow-up CL, I will change out_axis_resources to accept `Sharding` instances.
  * This is also the reason why you still need a mesh context manager when `config.jax_array` is enabled.
  * cl/458267790 is WIP for this. It also adds a couple of checks in MeshPspecSharding when `AUTO` is used.
* Checking a sharding against an `aval` uses a handler system to deal with the different sharding instances.
  * The reason for creating a `pjit`-specific system rather than putting this check on the sharding instances is that each transformation has a different way of checking the sharding. The best example of this is `pjit` and `xmap`: they check differently whether an aval is sharded properly with respect to the given sharding, because `pjit` and `xmap` have different ways to express sharding.
* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`, so we no longer have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us.
* `_pjit_lower` still depends on the mesh, which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keeping this CL small (LOL), I'll make those changes in a follow-up CL.
  * Also, the private functions in pxla.py are used by pathways and automap, so I'll have to modify those too.
  * It also uses `pxla.resource_typecheck`, which I haven't figured out how to move to the sharding interface.
* `_to_xla_op_sharding` takes `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
  * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998
* `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this depends on how each transformation adds the extra dimension, and also on how each sharding instance handles it, I added a handler system for this too. Again, `xmap` and `pjit` differ a lot here, which is why I went with the handler approach.
  * MeshPspecSharding handles this via `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments where this is done.

PiperOrigin-RevId: 459548974
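The "handler system" for checking a sharding against an `aval` can be pictured as a per-transformation registry that maps a sharding class to a checking function. The sketch below is a toy under stated assumptions: `ToyAval`, `ToyMeshSharding`, `ToySingleDeviceSharding`, `PJIT_CHECK_HANDLERS` and `check_aval_sharding` are hypothetical names for illustration only, not JAX's actual classes or functions.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple


@dataclass(frozen=True)
class ToyAval:
    # Hypothetical stand-in for an abstract value; only the shape matters here.
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class ToyMeshSharding:
    # Hypothetical mesh sharding: one entry per array dimension,
    # either a mesh axis name or None ("not partitioned").
    spec: Tuple[Optional[str], ...]


@dataclass(frozen=True)
class ToySingleDeviceSharding:
    # Hypothetical sharding that places the whole array on one device.
    device_id: int


def _check_mesh(aval: ToyAval, s: ToyMeshSharding) -> None:
    # A partition spec may not have more entries than the array has dimensions.
    if len(s.spec) > len(aval.shape):
        raise ValueError(
            f"spec {s.spec} has {len(s.spec)} entries but aval has rank "
            f"{len(aval.shape)}")


def _check_single_device(aval: ToyAval, s: ToySingleDeviceSharding) -> None:
    # Any shape fits on a single device, so there is nothing to check.
    return None


# One registry per transformation; another transformation with a different
# way of expressing sharding would keep its own, separate table.
PJIT_CHECK_HANDLERS: Dict[type, Callable] = {
    ToyMeshSharding: _check_mesh,
    ToySingleDeviceSharding: _check_single_device,
}


def check_aval_sharding(aval: ToyAval, sharding) -> None:
    # Dispatch on the exact sharding type; unregistered types are rejected.
    handler = PJIT_CHECK_HANDLERS.get(type(sharding))
    if handler is None:
        raise TypeError(f"no pjit handler for {type(sharding).__name__}")
    handler(aval, sharding)
```

For example, `check_aval_sharding(ToyAval((8, 4)), ToyMeshSharding(("x", None)))` passes, while a three-entry spec on a rank-2 aval raises `ValueError`. Keeping the table outside the sharding classes is what lets each transformation apply its own rules to the same sharding type.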
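The point about `__hash__` and `__eq__` is that a cache keyed on sharding objects hits whenever two shardings are equivalent, without the caller canonicalizing specs first. A minimal sketch, assuming (as for partition specs) that trailing `None` entries mean "unpartitioned"; `ToyMeshSharding` and `toy_lower` are hypothetical names, and `functools.lru_cache` stands in for a compilation cache.

```python
import functools
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyMeshSharding:
    # frozen=True (with the default eq=True) makes dataclasses generate
    # __eq__ and __hash__ from the fields.
    mesh_axes: Tuple[str, ...]
    spec: Tuple[Optional[str], ...]

    def __post_init__(self):
        # Canonicalize inside the object: strip trailing Nones so that
        # ("x", None) and ("x",) become the same spec.
        spec = tuple(self.spec)
        while spec and spec[-1] is None:
            spec = spec[:-1]
        object.__setattr__(self, "spec", spec)


@functools.lru_cache(maxsize=None)
def toy_lower(shape: Tuple[int, ...], sharding: ToyMeshSharding) -> str:
    # Stand-in for an expensive lowering step; the cache key relies on the
    # arguments' __hash__ and __eq__.
    return f"compiled{shape}@{sharding.spec}"


a = ToyMeshSharding(("x", "y"), ("x", None))
b = ToyMeshSharding(("x", "y"), ("x",))  # written differently, same meaning
toy_lower((8, 4), a)  # miss: compiles
toy_lower((8, 4), b)  # hit: a == b and hash(a) == hash(b)
info = toy_lower.cache_info()
```

After these two calls `info` reports one miss and one hit, because normalization happens once in the sharding object instead of at every call site.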
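The batching bullet says `pjit`'s batching rule adds an extra dimension to the sharding, and that `MeshPspecSharding` does this with `insert_axis_partitions` on the parsed partition spec. The toy function below only illustrates the idea on a plain tuple spec; `insert_batch_axis` is a hypothetical name and signature, not the function from this CL, and it assumes the new batch dimension is either unpartitioned (`None`) or mapped to a single named mesh axis.

```python
from typing import Optional, Tuple

Spec = Tuple[Optional[str], ...]


def insert_batch_axis(spec: Spec, rank: int, batch_dim: int,
                      batch_axis: Optional[str] = None) -> Spec:
    """Return the spec for the batched array.

    `spec` describes an unbatched array of rank `rank`; trailing entries may be
    omitted, meaning "unpartitioned". The batched array has rank `rank + 1`,
    with the new dimension at position `batch_dim`.
    """
    if len(spec) > rank:
        raise ValueError(f"spec {spec} is longer than rank {rank}")
    if not 0 <= batch_dim <= rank:
        raise ValueError(f"batch_dim {batch_dim} out of range for rank {rank}")
    # Pad to full rank first so that the insertion position is unambiguous
    # even when trailing entries were omitted.
    padded = tuple(spec) + (None,) * (rank - len(spec))
    return padded[:batch_dim] + (batch_axis,) + padded[batch_dim:]
```

For instance, `insert_batch_axis(("x",), rank=2, batch_dim=0)` gives `(None, "x", None)`, and `insert_batch_axis(("x", "y"), rank=2, batch_dim=1, batch_axis="b")` gives `("x", "b", "y")`. A different transformation could represent sharding per named axis rather than per dimension, which is why a per-transformation handler is used.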
1 parent 88c1e7d
Tip revision: 231495166929be4a6ee3a0fd843858abeeca3694 authored by Yash Katariya on 07 July 2022, 17:41:27 UTC
Convert everything in pjit to the `Sharding` interface. The following contains the things that have changed in this CL:
File | Mode | Size |
---|---|---|
.github | ||
benchmarks | ||
build | ||
cloud_tpu_colabs | ||
docs | ||
examples | ||
images | ||
jax | ||
jaxlib | ||
tests | ||
third_party | ||
.bazelrc | -rw-r--r-- | 14.0 KB |
.bazelversion | -rw-r--r-- | 6 bytes |
.gitignore | -rw-r--r-- | 291 bytes |
.pre-commit-config.yaml | -rw-r--r-- | 563 bytes |
.readthedocs.yml | -rw-r--r-- | 569 bytes |
CHANGELOG.md | -rw-r--r-- | 62.9 KB |
CITATION.bib | -rw-r--r-- | 408 bytes |
CONTRIBUTING.md | -rw-r--r-- | 150 bytes |
LICENSE | -rw-r--r-- | 11.1 KB |
README.md | -rw-r--r-- | 23.3 KB |
WORKSPACE | -rw-r--r-- | 1.7 KB |
conftest.py | -rw-r--r-- | 856 bytes |
mypy.ini | -rw-r--r-- | 618 bytes |
pylintrc | -rw-r--r-- | 1.6 KB |
pytest.ini | -rw-r--r-- | 1.4 KB |
setup.cfg | -rw-r--r-- | 486 bytes |
setup.py | -rw-r--r-- | 3.4 KB |