Revision 473d1c3685570e505bed58512afa05fb7d7a8935 authored by Skye Wanderman-Milne on 28 March 2023, 20:42:51 UTC, committed by Skye Wanderman-Milne on 28 March 2023, 22:28:13 UTC
I forgot that the default setting is actually in jaxlib:
https://github.com/openxla/xla/blob/fbe9a80fdb8c429e8a175962459da348cd560a50/xla/python/xla_client.py#L135

To be able to make this change as a jax-only release, I manually set
the env var on Cloud TPU if it isn't already set.
1 parent 4061bbb
History
File Mode Size
_src
example_libraries
experimental
image
interpreters
lax
lib
nn
numpy
ops
scipy
tools
BUILD -rw-r--r-- 11.4 KB
__init__.py -rw-r--r-- 9.1 KB
abstract_arrays.py -rw-r--r-- 763 bytes
ad_checkpoint.py -rw-r--r-- 711 bytes
api_util.py -rw-r--r-- 769 bytes
cloud_tpu_init.py -rw-r--r-- 633 bytes
collect_profile.py -rw-r--r-- 4.6 KB
config.py -rw-r--r-- 696 bytes
core.py -rw-r--r-- 7.0 KB
custom_batching.py -rw-r--r-- 657 bytes
custom_derivatives.py -rw-r--r-- 1.2 KB
custom_transpose.py -rw-r--r-- 644 bytes
debug.py -rw-r--r-- 935 bytes
distributed.py -rw-r--r-- 638 bytes
dlpack.py -rw-r--r-- 653 bytes
dtypes.py -rw-r--r-- 1.1 KB
errors.py -rw-r--r-- 1.1 KB
flatten_util.py -rw-r--r-- 645 bytes
jaxpr_util.py -rw-r--r-- 6.4 KB
linear_util.py -rw-r--r-- 1.0 KB
monitoring.py -rw-r--r-- 1.1 KB
prng.py -rw-r--r-- 997 bytes
profiler.py -rw-r--r-- 1.1 KB
py.typed -rw-r--r-- 0 bytes
random.py -rw-r--r-- 6.9 KB
sharding.py -rw-r--r-- 2.0 KB
stages.py -rw-r--r-- 1.2 KB
test_util.py -rw-r--r-- 835 bytes
tree_util.py -rw-r--r-- 3.0 KB
typing.py -rw-r--r-- 2.9 KB
util.py -rw-r--r-- 1.1 KB
version.py -rw-r--r-- 1003 bytes

back to top