Revision b81e94e16826a1c8d996b0f590b8c3407a0e6a07 authored by Jesper Nielsen on 23 September 2022, 09:50:10 UTC, committed by Jesper Nielsen on 26 October 2022, 13:02:08 UTC
1 parent f66a708
Raw File
metadata.py
# Copyright 2022 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code for determining metadata for a run.
"""
import getpass
import sys
from dataclasses import asdict, dataclass, replace
from datetime import datetime, timezone
from pathlib import Path
from platform import python_version
from socket import gethostname
from typing import Any, Mapping, Sequence, Tuple

import numpy as np
import psutil
import tensorflow as tf
from cpuinfo import get_cpu_info
from git.exc import InvalidGitRepositoryError
from git.repo import Repo
from tensorflow.python.client import device_lib  # pylint: disable=no-name-in-module

# ``strftime`` / ``strptime`` pattern shared by ``_get_timestamp`` and ``parse_timestamp``:
# year, month, day, an underscore, hour, minute, second, a dot, then microseconds.
TIMESTAMP_FORMAT = "%Y%m%d_%H%M%S.%f"


@dataclass(frozen=True)
class BenchmarkMetadata:
    """
    Struct with metadata about a script execution.

    Instances are immutable (the dataclass is frozen). Use :meth:`create` to collect the values
    from the running process, and :meth:`to_json` / :meth:`from_json` to convert to and from a
    plain mapping.
    """

    # Name of the benchmark suite, as passed to `create`.
    suite_name: str
    # Snapshot of `sys.argv` taken when the metadata was created.
    argv: Sequence[str]
    # Result of `getpass.getuser()`.
    user: str
    # Result of `socket.gethostname()`.
    hostname: str
    # Creation time, as produced by `_get_timestamp()` (see `TIMESTAMP_FORMAT`).
    timestamp: str
    # Version strings of Python, TensorFlow and NumPy.
    py_ver: str
    tf_ver: str
    np_ver: str
    # `psutil.virtual_memory().total`.
    ram: int
    # CPU description, as returned by `_get_cpu_name_count_frequency()`.
    cpu_name: str
    cpu_count: int
    cpu_frequency: int  # Presumably in Hz -- confirm against py-cpuinfo's `hz_actual`.
    # Names of the GPUs reported by TensorFlow; empty if there are none.
    gpu_names: Sequence[str]
    # Git state, as returned by `_get_git_branch_and_commit()`.
    git_branch_name: str
    git_commit: str
    # `<script>_<suite>_<branch>_<timestamp>`, with every "/" replaced by "_".
    run_id: str

    @staticmethod
    def create(suite_name: str) -> "BenchmarkMetadata":
        """
        Collects information about the current environment and creates a new `BenchmarkMetadata`
        with it.

        :param suite_name: Name of the benchmark suite being run. Stored as-is, and also used as
            part of the generated `run_id`.
        :return: Metadata describing the current process, machine and git checkout.
        """
        # Copy `sys.argv`, so that later mutation of that global list cannot change the contents
        # of this (frozen) snapshot.
        argv = list(sys.argv)
        script_name = Path(argv[0]).stem
        timestamp = _get_timestamp()
        cpu_name, cpu_count, cpu_frequency = _get_cpu_name_count_frequency()
        git_branch_name, git_commit = _get_git_branch_and_commit()
        # Branch names may contain "/"; replace it so that the id contains no path separators.
        run_id = f"{script_name}_{suite_name}_{git_branch_name}_{timestamp}".replace("/", "_")
        return BenchmarkMetadata(
            suite_name=suite_name,
            argv=argv,
            user=getpass.getuser(),
            hostname=gethostname(),
            timestamp=timestamp,
            py_ver=python_version(),
            tf_ver=tf.__version__,
            np_ver=np.__version__,
            ram=psutil.virtual_memory().total,
            cpu_name=cpu_name,
            cpu_count=cpu_count,
            cpu_frequency=cpu_frequency,
            gpu_names=_get_gpu_names(),
            git_branch_name=git_branch_name,
            git_commit=git_commit,
            run_id=run_id,
        )

    def for_shard(self) -> "BenchmarkMetadata":
        """
        Update this metadata with current machine information.

        Used in sharding, to create a machine-specific metadata, while retaining the process
        timestamp and id.

        :return: A copy of this metadata where the machine-specific fields (`argv`, `user`,
            `hostname`, `ram`, the CPU fields and `gpu_names`) are taken from the current
            environment, and all other fields are kept from `self`.
        :raises AssertionError: If the suite name, the library versions or the git state of the
            current environment differ from the ones recorded in `self`.
        """
        new_metadata = BenchmarkMetadata.create(self.suite_name)
        kwargs = asdict(new_metadata)
        keep_new_fields = [
            "argv",
            "user",
            "hostname",
            "ram",
            "cpu_name",
            "cpu_count",
            "cpu_frequency",
            "gpu_names",
        ]
        keep_old_fields = ["timestamp", "run_id"]
        must_match_fields = [
            "suite_name",
            "py_ver",
            "tf_ver",
            "np_ver",
            "git_branch_name",
            "git_commit",
        ]
        for field in keep_old_fields:
            del kwargs[field]
        for field in must_match_fields:
            old_value = getattr(self, field)
            new_value = kwargs.pop(field)
            # Raise explicitly, instead of using `assert`: this validates the runtime environment,
            # and must not be stripped when Python runs with `-O`.
            if old_value != new_value:
                raise AssertionError(
                    f"Field {field} must match between new and old metadata."
                    f" Found {old_value} and {new_value}"
                )
        # Sanity check that the three lists above classify every field of the dataclass.
        if keep_new_fields != list(kwargs):
            raise AssertionError(
                f"Expected remaining fields {keep_new_fields}, but found {list(kwargs)}."
            )

        return replace(self, **kwargs)

    @property
    def gpu_name(self) -> str:
        """ Human-readable summary of `gpu_names`: the names joined by " + ", or "No GPU". """
        if not self.gpu_names:
            return "No GPU"
        return " + ".join(self.gpu_names)

    def to_json(self) -> Mapping[str, Any]:
        """ Convert this object to a mapping from field names to values. """
        return asdict(self)

    @staticmethod
    def from_json(data: Mapping[str, Any]) -> "BenchmarkMetadata":
        """
        Create a `BenchmarkMetadata` from a mapping, as returned by :meth:`to_json`.

        :param data: Mapping with exactly one entry for each field of this class.
        """
        # This only works as long as our field types are sufficiently simple. We may need to update
        # this in the future:
        return BenchmarkMetadata(**data)


def _get_timestamp() -> str:
    """ Get the current UTC date/time, formatted according to ``TIMESTAMP_FORMAT``. """
    # `datetime.utcnow()` is deprecated since Python 3.12. `TIMESTAMP_FORMAT` contains no
    # time zone directive, so the formatted string is the same as before.
    return datetime.now(timezone.utc).strftime(TIMESTAMP_FORMAT)


def parse_timestamp(timestamp: str) -> datetime:
    """
    Convert a timestamp string, as stored in :class:`BenchmarkMetadata`, into a ``datetime``.

    The string is parsed with ``TIMESTAMP_FORMAT``, which has no time zone directive, so the
    returned ``datetime`` is naive.
    """
    parsed = datetime.strptime(timestamp, TIMESTAMP_FORMAT)
    return parsed


def _get_cpu_name_count_frequency() -> Tuple[str, int, int]:
    """
    Query ``cpuinfo`` for a description of the CPU.

    :return: A ``(name, count, frequency)`` tuple, taken from the ``"brand_raw"`` and ``"count"``
        entries, and from the first element of the ``"hz_actual"`` entry.
    """
    info = get_cpu_info()
    # `hz_actual` is a pair; only its first element is used. The meaning of the second element is
    # not established here -- it is only checked to be zero.
    hz, unknown = info["hz_actual"]
    assert unknown == 0, "I have no idea what this is, but it seem to always be 0..."
    name = info["brand_raw"]
    count = info["count"]
    return name, count, hz


def _get_gpu_names() -> Sequence[str]:
    """
    Get a list of GPUs that are available to TensorFlow.

    The ``physical_device_desc`` of every local device of type ``"GPU"`` is parsed as a
    comma-separated list of ``key: value`` fields, and the value of each ``name`` field is
    collected.

    :return: The GPU names, in the order TensorFlow lists the devices. Empty if there are no GPUs.
    """
    result = []
    for device in device_lib.list_local_devices():
        if device.device_type != "GPU":
            continue
        for field in device.physical_device_desc.split(","):
            # `partition` splits on the first ":" only, so values that themselves contain a ":"
            # are kept intact.
            key, separator, value = field.partition(":")
            if not separator:
                # A fragment without ":" (for example from an empty descriptor) carries no
                # `key: value` pair. Skip it, instead of failing to unpack it.
                continue
            if key.strip() == "name":
                result.append(value.strip())
    return result


def _get_git_branch_and_commit() -> Tuple[str, str]:
    """
    Get a `(branch_name, commit_id)` tuple for the git repository containing this file.

    The repository is located by searching from the path of this file through its parent
    directories. If no repository is found the result is ``("NO_BRANCH", "NO_COMMIT")``. If the
    repository has uncommitted changes, a marker is appended to the commit id.
    """
    branch, revision = "NO_BRANCH", "NO_COMMIT"
    try:
        repository = Repo(__file__, search_parent_directories=True)
        try:
            branch = str(repository.active_branch)
        except TypeError:
            # No active branch could be determined (presumably a detached HEAD -- confirm against
            # GitPython); fall back to the placeholder branch name.
            pass
        revision = str(repository.commit())
        if repository.is_dirty():
            revision = revision + " + uncomitted changes"
    except InvalidGitRepositoryError:
        # Not inside a git repository; return whatever has been determined so far.
        pass
    return branch, revision
back to top