# swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f
# Raw File
# Tip revision: 7fd93027854e902263dda0e06cab5177af8ec484 authored by Yash Katariya on 29 July 2024, 19:49:08 UTC
# Start JAX and jaxlib 0.4.31 release
# Tip revision: 7fd9302
# pyproject.toml
[build-system]
# PEP 517 build backend, plus the packages that must be installed into the
# build environment before the backend is invoked.
build-backend = "setuptools.build_meta"
requires = [
    "setuptools",
    "wheel",
]

[tool.mypy]
# Error codes mypy never reports anywhere in the project.
disable_error_code = [
    "attr-defined",
    "name-defined",
    "annotation-unchecked",
]
no_implicit_optional = true
show_error_codes = true
warn_redundant_casts = true
warn_unused_ignores = true

[[tool.mypy.overrides]]
# Modules whose imports mypy may fail to resolve without reporting an error.
# A "pkg.*" pattern matches pkg itself and every submodule, so subpackages of
# an already-listed package (e.g. jaxlib.mlir under "jaxlib.*", tensorflow.io
# under "tensorflow.*") need no entry of their own. Entries without ".*" name
# a single exact module. Kept in case-insensitive alphabetical order.
module = [
    "absl.*",
    "colorama.*",
    "etils.*",
    "filelock.*",
    "flatbuffers.*",
    "flax.*",
    "google.colab.*",
    "IPython.*",
    "jax.experimental.jax2tf.tests.back_compat_testdata",
    "jax.experimental.jax2tf.tests.flax_models",
    "jax_cuda12_plugin.*",
    "jaxlib.*",
    "jraph.*",
    "libtpu.*",
    "matplotlib.*",
    "numpy.*",
    "opt_einsum.*",
    "optax.*",
    "pygments.*",
    "pytest.*",
    "rich.*",
    "scipy.*",
    "setuptools.*",
    "tensorboard_plugin_profile.convert.*",
    "tensorflow.*",
    "tensorflowjs.*",
    "tensorstore.*",
    "web_pdb.*",
    "zstandard.*",
]
ignore_missing_imports = true

[tool.pytest.ini_options]
# Also collect doctests from reStructuredText files.
addopts = "--doctest-glob='*.rst'"
# NUMBER: floats in expected output only need to match to the precision that
# is written. NORMALIZE_WHITESPACE: runs of whitespace compare as equal.
doctest_optionflags = ["NUMBER", "NORMALIZE_WHITESPACE"]
markers = [
    "multiaccelerator: indicates that a test can make use of and possibly requires multiple accelerators",
    "SlurmMultiNodeGpuTest: mark a test for Slurm multinode GPU nightly CI",
]
# Every warning fails the test by default. pytest applies the LAST filter that
# matches a warning, so each "default" entry after "error" downgrades the
# warnings it matches back to ordinary, non-fatal warnings. The order of this
# list is significant: "error" must stay first.
filterwarnings = [
    "error",
    "default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'",
    "default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'",
    "default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",

    # TODO(jakevdp): remove when array_api_tests stabilize
    "default:.*not machine-readable.*:UserWarning",
    "default:Special cases found for .* but none were parsed.*:UserWarning",
    "default:.*is not JSON-serializable. Using the repr instead.",

    # Transitive warnings that reach us through TensorFlow dependencies.
    # TODO(slebedev): Remove once we bump the minimum TensorFlow version.
    "default:The key path API is deprecated .*",
    "default:jax.xla_computation is deprecated.*:DeprecationWarning",
]

[tool.pylint.master]
# Packages from which pylint may load C extensions in order to inspect their
# members.
extension-pkg-whitelist = ["numpy"]

[tool.pylint."messages control"]
# Pylint checks turned off project-wide. Every entry must be a message symbol
# that pylint recognizes: unknown symbols are themselves reported by pylint
# as bad/unknown option values.
disable = [
    "missing-docstring",
    "too-many-locals",
    "invalid-name",
    "redefined-outer-name",
    "redefined-builtin",
    "no-else-return",
    "fixme",
    "protected-access",
    "too-many-arguments",
    # Current symbol for the check formerly named "blacklisted-name".
    "disallowed-name",
    "too-few-public-methods",
    "unnecessary-lambda",
]
# Explicitly enabled: reports member accesses on C extension modules that
# pylint cannot verify statically.
enable = "c-extension-no-member"

[tool.pylint.format]
# One level of indentation as expected by pylint's bad-indentation check.
# JAX code is indented with two spaces (indent-width = 2 under [tool.ruff]).
# A single space here would make every correctly indented line count as two
# levels and be reported.
indent-string = "  "

[tool.ruff]
# Minimum supported Python version and the project's layout conventions.
target-version = "py310"
indent-width = 2
line-length = 88
# Opt in to ruff's preview rules and behaviour.
preview = true
# This replaces ruff's built-in exclude list rather than extending it.
exclude = [".git", "build", "__pycache__"]

[tool.ruff.lint]
select = [
    # flake8-bugbear rules whose codes start with B9
    "B9",
    # flake8-comprehensions (C4) and mccabe complexity (C90)
    "C",
    # Pyflakes
    "F",
    # pycodestyle warnings
    "W",
    # flake8-2020: misuse of sys.version / sys.version_info
    "YTT",
    # flake8-async
    "ASYNC",
    # pycodestyle: missing whitespace around an operator
    "E225",
    # pycodestyle: missing whitespace around a bitwise or shift operator
    "E227",
    # pycodestyle: missing whitespace around the modulo operator
    "E228",
]
ignore = [
    # C408: unnecessary dict()/list()/tuple() call (rewrite as a literal)
    "C408",
    # C417: unnecessary map() usage (rewrite as a comprehension/generator)
    "C417",
    # C901: function is too complex (McCabe cyclomatic complexity); the
    # threshold lives in [tool.ruff.lint.mccabe]
    "C901",
    # F841: local variable is assigned to but never used
    "F841",
    # B904: raise without `from` inside an except clause
    "B904",
    # B905: zip() without an explicit strict= parameter
    "B905",
]

[tool.ruff.lint.mccabe]
# Complexity threshold used by rule C901. It has no effect while C901 is
# listed in the ignore list of [tool.ruff.lint].
max-complexity = 18

[tool.ruff.lint.per-file-ignores]
# F401 (imported but unused) is ignored for files matching "__init__.py" and
# for the modules below. The top-level jax modules are listed one by one
# because the glob "jax/*.py" would also match the contents of jax/_src.
"__init__.py" = ["F401"]
"jax/abstract_arrays.py" = ["F401"]
"jax/ad_checkpoint.py" = ["F401"]
"jax/api_util.py" = ["F401"]
"jax/cloud_tpu_init.py" = ["F401"]
"jax/core.py" = ["F401"]
"jax/custom_batching.py" = ["F401"]
"jax/custom_derivatives.py" = ["F401"]
"jax/custom_transpose.py" = ["F401"]
"jax/debug.py" = ["F401"]
"jax/distributed.py" = ["F401"]
"jax/dlpack.py" = ["F401"]
"jax/dtypes.py" = ["F401"]
"jax/errors.py" = ["F401"]
"jax/experimental/*.py" = ["F401"]
"jax/extend/*.py" = ["F401"]
"jax/flatten_util.py" = ["F401"]
"jax/interpreters/ad.py" = ["F401"]
"jax/interpreters/batching.py" = ["F401"]
"jax/interpreters/mlir.py" = ["F401"]
"jax/interpreters/partial_eval.py" = ["F401"]
"jax/interpreters/pxla.py" = ["F401"]
"jax/interpreters/xla.py" = ["F401"]
"jax/lax/*.py" = ["F401"]
"jax/linear_util.py" = ["F401"]
"jax/monitoring.py" = ["F401"]
"jax/nn/*.py" = ["F401"]
"jax/numpy/*.py" = ["F401"]
"jax/prng.py" = ["F401"]
"jax/profiler.py" = ["F401"]
"jax/random.py" = ["F401"]
"jax/scipy/*.py" = ["F401"]
"jax/sharding.py" = ["F401"]
"jax/stages.py" = ["F401"]
"jax/test_util.py" = ["F401"]
"jax/tree_util.py" = ["F401"]
"jax/typing.py" = ["F401"]
"jax/util.py" = ["F401"]

# F811: redefinition of an unused name.
"docs/autodidax.py" = ["F811"]

# F821: undefined name.
"jax/numpy/__init__.pyi" = ["F821"]
# back to top