Raw File
buffer.py
# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from functools import partial
from typing import Dict, Iterable, List, Optional

from swh.core.utils import grouper
from swh.model.model import Content, BaseModel
from swh.storage import get_storage


class BufferingProxyStorage:
    """Storage proxy accumulating objects in memory before sending them, in
    batches, to the "main" (backend) storage.

    Objects are buffered per type and deduplicated on their identifying keys.
    A buffer is written to the backend once it reaches its ``min_batch_size``
    threshold, or when :meth:`flush` is called explicitly. Attributes not
    found on the proxy itself are delegated to the backend storage.

    Sample configuration use case for buffering storage:

    .. code-block:: yaml

        storage:
          cls: buffer
          args:
            storage:
              cls: remote
              args:
                url: http://storage.internal.staging.swh.network:5002/
            min_batch_size:
              content: 10000
              content_bytes: 100000000
              skipped_content: 10000
              directory: 5000
              revision: 1000
              release: 10000

    """

    def __init__(self, storage, min_batch_size=None):
        """Instantiate the backend storage and set up empty buffers.

        Args:
            storage: keyword arguments passed to ``get_storage`` to build the
              backend storage.
            min_batch_size: optional mapping overriding the default
              thresholds: an object count per object type, plus
              ``content_bytes`` for the total length of buffered contents.

        """
        self.storage = get_storage(**storage)

        if min_batch_size is None:
            min_batch_size = {}

        self.min_batch_size = {
            "content": min_batch_size.get("content", 10000),
            "content_bytes": min_batch_size.get("content_bytes", 100 * 1024 * 1024),
            "skipped_content": min_batch_size.get("skipped_content", 10000),
            "directory": min_batch_size.get("directory", 25000),
            "revision": min_batch_size.get("revision", 100000),
            "release": min_batch_size.get("release", 100000),
        }
        # A full flush() writes the buffers in this order.
        self.object_types = [
            "content",
            "skipped_content",
            "directory",
            "revision",
            "release",
        ]
        # One buffer per object type, mapping an object's key tuple to the
        # object itself; re-adding an object with the same key replaces it.
        self._objects = {k: {} for k in self.object_types}

    def __getattr__(self, key):
        """Resolve attributes not found on the proxy itself.

        ``<type>_add`` for a buffered object type without a dedicated method
        (directory, revision, release) is routed to :meth:`object_add`, keyed
        on ``id``. Anything else is delegated to the backend storage.

        """
        if key == "storage":
            # ``self.storage`` is not set (e.g. before __init__ ran): fail
            # here instead of recursing into __getattr__ forever below.
            raise AttributeError(key)
        if key.endswith("_add"):
            object_type = key.rsplit("_", 1)[0]
            if object_type in self.object_types:
                return partial(self.object_add, object_type=object_type, keys=["id"])
        return getattr(self.storage, key)

    def content_add(self, content: Iterable[Content]) -> Dict:
        """Enqueue contents to write to the storage.

        The following policies apply:

            - First, check whether the number of buffered contents reached the
              ``content`` threshold. If so, every buffer (not only contents)
              is flushed to the storage, as done by :meth:`object_add`.

            - If not, check whether the total size of buffered contents
              reached the ``content_bytes`` threshold. If so, flush the
              contents to the storage.

        Returns:
            The flush summary if a flush happened, an empty dict otherwise.

        """
        summary = self.object_add(
            content,
            object_type="content",
            keys=["sha1", "sha1_git", "sha256", "blake2s256"],
        )
        if not summary:
            total_size = sum(c.length for c in self._objects["content"].values())
            if total_size >= self.min_batch_size["content_bytes"]:
                return self.flush(["content"])

        return summary

    def skipped_content_add(self, content: Iterable[BaseModel]) -> Dict:
        """Enqueue skipped contents to write to the storage.

        Skipped contents are keyed on their four checksums, like contents.
        See :meth:`object_add` for the flushing policy and return value.

        """
        return self.object_add(
            content,
            object_type="skipped_content",
            keys=["sha1", "sha1_git", "sha256", "blake2s256"],
        )

    def flush(self, object_types: Optional[Iterable[str]] = None) -> Dict:
        """Write the buffered objects to the backend, then flush the backend.

        Args:
            object_types: object types to flush. ``None`` means every type
              handled by this proxy. The value is forwarded as-is (``None``
              or the given types) to the backend's ``flush``.

        Returns:
            Summary merging the counters returned by every backend ``*_add``
            call and by the backend's own ``flush``. Counters sharing a key
            are summed.

        """
        if object_types is not None:
            # Materialize once: the types are iterated here and forwarded to
            # the backend, so a one-shot iterator must not be exhausted.
            object_types = list(object_types)
        types_to_flush = self.object_types if object_types is None else object_types

        summary: Dict[str, int] = {}

        def update_summary(stats: Dict[str, int]) -> None:
            # Accumulate in place, so counters from earlier object types (or
            # batches) are kept rather than overwritten.
            for key, value in stats.items():
                summary[key] = summary.get(key, 0) + value

        for object_type in types_to_flush:
            buffer_ = self._objects[object_type]
            batches = grouper(buffer_.values(), n=self.min_batch_size[object_type])
            for batch in batches:
                add_fn = getattr(self.storage, "%s_add" % object_type)
                # grouper yields lazy iterables; send a concrete list.
                update_summary(add_fn(list(batch)))
            buffer_.clear()

        # Flush the backend only after our objects were pushed into it, so
        # that anything the backend buffers itself is written out too.
        update_summary(self.storage.flush(object_types))

        return summary

    def object_add(
        self, objects: Iterable[BaseModel], *, object_type: str, keys: List[str]
    ) -> Dict:
        """Enqueue objects to write to the storage.

        Objects are stored in the ``object_type`` buffer, keyed on the tuple
        of their ``keys`` attribute values, so duplicates collapse into a
        single entry. Once that buffer holds at least
        ``min_batch_size[object_type]`` objects, all buffers are flushed.

        Returns:
            The flush summary if a flush happened, an empty dict otherwise.

        """
        buffer_ = self._objects[object_type]
        threshold = self.min_batch_size[object_type]
        for obj in objects:
            buffer_[tuple(getattr(obj, key) for key in keys)] = obj
        if len(buffer_) >= threshold:
            # NOTE(review): this flushes every object type, not only
            # ``object_type``; presumably so objects reach the backend in
            # ``self.object_types`` order. Confirm before narrowing it.
            return self.flush()

        return {}

    def clear_buffers(self, object_types: Optional[Iterable[str]] = None) -> None:
        """Clear objects from current buffer.

        The same object types (every type handled by this proxy when
        ``None``) are then cleared in the backend storage as well.

        WARNING:

            data that has not been flushed to storage will be lost when this
            method is called. This should only be called when `flush` fails and
            you want to continue your processing.

        """
        # Materialize once: the types are iterated here and then forwarded to
        # the backend, so a one-shot iterator must not be exhausted.
        if object_types is None:
            object_types = list(self.object_types)
        else:
            object_types = list(object_types)

        for object_type in object_types:
            self._objects[object_type].clear()

        return self.storage.clear_buffers(object_types)
back to top