swh:1:snp:71afc25eb6e6e055a37a962e6b91010ec35e397f
Raw File
Tip revision: e27af48510b57e5d9f1ae17f256bd67fa9db7941 authored by jax authors on 06 October 2023, 20:21:49 UTC
Merge pull request #17989 from skye:version
Tip revision: e27af48
version.py
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is included as part of both jax and jaxlib. It is also
# eval()-ed by setup.py, so it should not have any dependencies.
from __future__ import annotations

import datetime
import os
import pathlib
import subprocess

# Base version string. Unpinned builds derive a ".dev" version from it (see
# _version_from_todays_date / _version_from_git_tree); release builds selected
# via JAX_RELEASE / JAXLIB_RELEASE use it verbatim.
_version = "0.4.18"
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail:
# _write_version() locates it by exact text match and requires that text to
# occur exactly twice in this file, so never quote it in comments either.
# While it is still None, _get_version_string() computes a dev version.
_release_version: str | None = None


def _get_version_string() -> str:
  """Return the version string reported by this package at import time.

  Build and source distributions of jax & jaxlib overwrite
  ``_release_version``; when it has been pinned it is returned verbatim.
  Otherwise a dev version is derived from the enclosing git checkout, and if
  that is unavailable, from today's date.
  """
  if _release_version is None:
    git_version = _version_from_git_tree(_version)
    return git_version if git_version else _version_from_todays_date(_version)
  return _release_version


def _version_from_todays_date(base_version: str) -> str:
  datestring = datetime.date.today().strftime("%Y%m%d")
  return f"{base_version}.dev{datestring}"


def _version_from_git_tree(base_version: str) -> str | None:
  try:
    root_directory = os.path.dirname(os.path.realpath(__file__))

    # Get date string from date of most recent git commit.
    p = subprocess.Popen(["git", "show", "-s", "--format=%at", "HEAD"],
                         cwd=root_directory,
                         stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, _ = p.communicate()
    timestamp = int(stdout.decode().strip())
    datestring = datetime.date.fromtimestamp(timestamp).strftime("%Y%m%d")
    assert datestring.isnumeric()

    # Get commit hash from most recent git commit.
    p = subprocess.Popen(["git", "describe", "--long", "--always"],
                         cwd=root_directory,
                         stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, _ = p.communicate()
    commit_hash = stdout.decode().strip().rsplit('-', 1)[-1]
    assert commit_hash.isalnum()
  except:
    return None
  else:
    return f"{base_version}.dev{datestring}+{commit_hash}"


def _get_version_for_build() -> str:
  """Determine the version at build time.

  The returned version string depends on which environment variables are set:
  - if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16"
  - if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
  - if none are set: version looks like "0.4.16.dev20230906+ge58560fdc"

  A variable counts as "set" when it is present with a non-empty value, so
  e.g. ``JAX_NIGHTLY=0`` still selects a nightly version. The nightly
  variables are checked first and therefore win when both nightly and release
  variables are set. If ``_release_version`` has already been pinned, it is
  returned unchanged regardless of the environment. When no variable is set
  and the git checkout cannot be inspected, the nightly (date-only) form is
  used as a fallback.
  """
  if _release_version is not None:
    return _release_version
  if os.environ.get('JAX_NIGHTLY') or os.environ.get('JAXLIB_NIGHTLY'):
    return _version_from_todays_date(_version)
  if os.environ.get('JAX_RELEASE') or os.environ.get('JAXLIB_RELEASE'):
    return _version
  return _version_from_git_tree(_version) or _version_from_todays_date(_version)


def _write_version(fname: str) -> None:
  """Used by setup.py to write the specified version info into the source tree."""
  release_version = _get_version_for_build()
  old_version_string = "_release_version: str | None = None"
  new_version_string = f"_release_version: str = {release_version!r}"
  fhandle = pathlib.Path(fname)
  contents = fhandle.read_text()
  # Expect two occurrences: one above, and one here.
  if contents.count(old_version_string) != 2:
    raise RuntimeError(f"Build: could not find {old_version_string!r} in {fname}")
  contents = contents.replace(old_version_string, new_version_string)
  fhandle.write_text(contents)


def _get_cmdclass(pkg_source_path):
  """Return setuptools command overrides that pin the version in build output.

  The returned mapping has ``"sdist"`` and ``"build_py"`` entries. Each one
  first runs the stock setuptools command and then, unless ``_release_version``
  is already set, calls ``_write_version`` on the copy of this file located
  under ``<output dir>/<pkg_source_path>/``.
  """
  from setuptools.command.build_py import build_py as build_py_orig  # pytype: disable=import-error
  from setuptools.command.sdist import sdist as sdist_orig  # pytype: disable=import-error

  def _stamp_version_under(output_dir):
    # Already-pinned sources need no rewriting.
    if _release_version is not None:
      return
    copied_version_file = os.path.join(output_dir, pkg_source_path,
                                       os.path.basename(__file__))
    _write_version(copied_version_file)

  class _build_py(build_py_orig):
    def run(self):
      super().run()
      _stamp_version_under(self.build_lib)

  class _sdist(sdist_orig):
    def make_release_tree(self, base_dir, files):
      super().make_release_tree(base_dir, files)
      _stamp_version_under(base_dir)

  return {"sdist": _sdist, "build_py": _build_py}


# Minimum jaxlib version string (presumably enforced against the installed
# jaxlib elsewhere; the check is not in this file).
_minimum_jaxlib_version = "0.4.14"
# This package's version, resolved once when the module is executed.
__version__ = _get_version_string()

def _version_as_tuple(version_str):
  return tuple(int(i) for i in version_str.split(".") if i.isdigit())

# Integer-tuple forms of the version strings; non-numeric components such as
# ".dev..." suffixes are dropped by _version_as_tuple.
_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)
__version_info__ = _version_as_tuple(__version__)
back to top