# Copyright (C) 2015-2017  The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import binascii
import datetime
import enum
import functools
import json
import os
import select
import threading

from contextlib import contextmanager

import psycopg2
import psycopg2.extras

from .db_utils import execute_values_generator

TMP_CONTENT_TABLE = 'tmp_content'


psycopg2.extras.register_uuid()


def stored_procedure(stored_proc):
    """decorator to execute remote stored procedure, specified as argument

    Generally, the body of the decorated function should be empty. If it is
    not, the stored procedure will be executed first; the function body then.

    """
    def wrap(meth):
        @functools.wraps(meth)
        def _meth(self, *args, **kwargs):
            cur = kwargs.get('cur', None)
            self._cursor(cur).execute('SELECT %s()' % stored_proc)
            meth(self, *args, **kwargs)
        return _meth
    return wrap
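
# Illustrative sketch (not part of the upstream module): a method decorated as
#
#     @stored_procedure('swh_mktemp_revision')
#     def mktemp_revision(self, cur=None): pass
#
# ends up issuing ``SELECT swh_mktemp_revision()`` on the cursor passed as
# ``cur`` (or on a fresh one), before running its -- usually empty -- body.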


def jsonize(value):
    """Convert a value to a psycopg2 JSON object if necessary"""
    if isinstance(value, dict):
        return psycopg2.extras.Json(value)

    return value


def entry_to_bytes(entry):
    """Convert an entry coming from the database to bytes"""
    if isinstance(entry, memoryview):
        return entry.tobytes()
    if isinstance(entry, list):
        return [entry_to_bytes(value) for value in entry]
    return entry


def line_to_bytes(line):
    """Convert a line coming from the database to bytes"""
    if not line:
        return line
    if isinstance(line, dict):
        return {k: entry_to_bytes(v) for k, v in line.items()}
    return line.__class__(entry_to_bytes(entry) for entry in line)


def cursor_to_bytes(cursor):
    """Yield all the data from a cursor as bytes"""
    yield from (line_to_bytes(line) for line in cursor)


def execute_values_to_bytes(*args, **kwargs):
    """Run execute_values_generator and convert its results to bytes"""
    for line in execute_values_generator(*args, **kwargs):
        yield line_to_bytes(line)
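
# For illustration: psycopg2 returns bytea columns as memoryview objects; the
# helpers above convert those to bytes recursively, e.g.
#
#     line_to_bytes((memoryview(b'\x01\x02'), 'visible'))
#     # -> (b'\x01\x02', 'visible')
#
# while leaving other values untouched.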


class BaseDb:
    """Base class for swh.storage.*Db.

    cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb

    """

    @classmethod
    def connect(cls, *args, **kwargs):
        """factory method to create a DB proxy

        Accepts all arguments of psycopg2.connect; only some specific
        possibilities are reported below.

        Args:
            connstring: libpq2 connection string

        """
        conn = psycopg2.connect(*args, **kwargs)
        return cls(conn)
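
    # A minimal usage sketch (connection parameters are hypothetical); both
    # forms are forwarded verbatim to psycopg2.connect:
    #
    #     db = BaseDb.connect('service=swh')      # libpq connection string
    #     db = BaseDb.connect(host='localhost', dbname='softwareheritage')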

    @classmethod
    def from_pool(cls, pool):
        return cls(pool.getconn(), pool=pool)

    def _cursor(self, cur_arg):
        """get a cursor: from cur_arg if given, or a fresh one otherwise

        meant to avoid boilerplate if/then/else in methods that proxy stored
        procedures

        """
        if cur_arg is not None:
            return cur_arg
        # elif self.cur is not None:
        #     return self.cur
        else:
            return self.conn.cursor()

    def __init__(self, conn, pool=None):
        """create a DB proxy

        Args:
            conn: psycopg2 connection to the SWH DB
            pool: psycopg2 pool of connections

        """
        self.conn = conn
        self.pool = pool

    def __del__(self):
        if self.pool:
            self.pool.putconn(self.conn)

    @contextmanager
    def transaction(self):
        """context manager to execute within a DB transaction

        Yields:
            a psycopg2 cursor

        """
        with self.conn.cursor() as cur:
            try:
                yield cur
                self.conn.commit()
            except Exception:
                if not self.conn.closed:
                    self.conn.rollback()
                raise
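
    # Sketch of intended usage (illustrative only): the cursor yielded by
    # ``transaction()`` commits on normal exit and rolls back on error, e.g.
    #
    #     with db.transaction() as cur:
    #         cur.execute('SELECT 1')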

    def copy_to(self, items, tblname, columns, cur=None, item_cb=None):
        """Copy items' entries to table tblname with columns information.

        Args:
            items (iterable): iterable of dictionaries holding the data to
              copy over to tblname
            tblname (str): destination table's name
            columns ([str]): keys to access data in items and also the
              column names in the destination table.
            item_cb (fn): optional function applied to each item before it
              is copied over

        """
        def escape(data):
            if data is None:
                return ''
            if isinstance(data, bytes):
                return '\\x%s' % binascii.hexlify(data).decode('ascii')
            elif isinstance(data, str):
                return '"%s"' % data.replace('"', '""')
            elif isinstance(data, datetime.datetime):
                # We escape twice to make sure the string generated by
                # isoformat gets escaped
                return escape(data.isoformat())
            elif isinstance(data, dict):
                return escape(json.dumps(data))
            elif isinstance(data, list):
                return escape("{%s}" % ','.join(escape(d) for d in data))
            elif isinstance(data, psycopg2.extras.Range):
                # We escape twice here too, so that we make sure
                # everything gets passed to copy properly
                return escape(
                    '%s%s,%s%s' % (
                        '[' if data.lower_inc else '(',
                        '-infinity' if data.lower_inf else escape(data.lower),
                        'infinity' if data.upper_inf else escape(data.upper),
                        ']' if data.upper_inc else ')',
                    )
                )
            elif isinstance(data, enum.IntEnum):
                return escape(int(data))
            else:
                # We don't escape here to make sure we pass literals properly
                return str(data)
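
        # A few escapings produced by the rules above (illustrative):
        #   None          -> ''                (empty CSV field, i.e. NULL)
        #   b'\x00\xff'   -> \x00ff            (bytea hex notation)
        #   'say "hi"'    -> "say ""hi"""      (CSV double-quote escaping)
        #   {'k': 1}      -> "{""k"": 1}"      (dict -> JSON, then quoted)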

        read_file, write_file = os.pipe()

        def writer():
            cursor = self._cursor(cur)
            with open(read_file, 'r') as f:
                cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % (
                    tblname, ', '.join(columns)), f)

        write_thread = threading.Thread(target=writer)
        write_thread.start()

        try:
            with open(write_file, 'w') as f:
                for d in items:
                    if item_cb is not None:
                        item_cb(d)
                    line = [escape(d.get(k)) for k in columns]
                    f.write(','.join(line))
                    f.write('\n')
        finally:
            # No problem bubbling up exceptions, but we still need to make sure
            # we finish copying, even though we're probably going to cancel the
            # transaction.
            write_thread.join()
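
    # Hypothetical example of the expected call (values are made up):
    #
    #     db.copy_to([{'sha1': b'\x01' * 20, 'length': 42}], 'tmp_content',
    #                ['sha1', 'length'], cur=cur)
    #
    # Each item is rendered as one CSV line and streamed to COPY through the
    # pipe set up above.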

    def mktemp(self, tblname, cur=None):
        self._cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,))


class Db(BaseDb):
    """Proxy to the SWH DB, with wrappers around stored procedures

    """
    def mktemp_dir_entry(self, entry_type, cur=None):
        self._cursor(cur).execute('SELECT swh_mktemp_dir_entry(%s)',
                                  (('directory_entry_%s' % entry_type),))

    @stored_procedure('swh_mktemp_revision')
    def mktemp_revision(self, cur=None): pass

    @stored_procedure('swh_mktemp_release')
    def mktemp_release(self, cur=None): pass

    @stored_procedure('swh_mktemp_snapshot_branch')
    def mktemp_snapshot_branch(self, cur=None): pass

    def register_listener(self, notify_queue, cur=None):
        """Register a listener for NOTIFY queue `notify_queue`"""
        self._cursor(cur).execute("LISTEN %s" % notify_queue)

    def listen_notifies(self, timeout):
        """Listen to notifications for `timeout` seconds"""
        if select.select([self.conn], [], [], timeout) == ([], [], []):
            return
        else:
            self.conn.poll()
            while self.conn.notifies:
                yield self.conn.notifies.pop(0)
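
    # Sketch of how the two methods above combine (queue name hypothetical):
    #
    #     db.register_listener('new_content')
    #     for notify in db.listen_notifies(timeout=10):
    #         ...  # psycopg2 Notify objects, one per NOTIFY received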

    @stored_procedure('swh_content_add')
    def content_add_from_temp(self, cur=None): pass

    @stored_procedure('swh_directory_add')
    def directory_add_from_temp(self, cur=None): pass

    @stored_procedure('swh_skipped_content_add')
    def skipped_content_add_from_temp(self, cur=None): pass

    @stored_procedure('swh_revision_add')
    def revision_add_from_temp(self, cur=None): pass

    @stored_procedure('swh_release_add')
    def release_add_from_temp(self, cur=None): pass

    def content_update_from_temp(self, keys_to_update, cur=None):
        cur = self._cursor(cur)
        cur.execute("""select swh_content_update(ARRAY[%s] :: text[])""" %
                    keys_to_update)

    content_get_metadata_keys = [
        'sha1', 'sha1_git', 'sha256', 'blake2s256', 'length', 'status']

    skipped_content_keys = [
        'sha1', 'sha1_git', 'sha256', 'blake2s256',
        'length', 'reason', 'status', 'origin']

    def content_get_metadata_from_sha1s(self, sha1s, cur=None):
        cur = self._cursor(cur)

        yield from execute_values_to_bytes(
            cur, """
            select t.sha1, %s from (values %%s) as t (sha1)
            left join content using (sha1)
            """ % ', '.join(self.content_get_metadata_keys[1:]),
            ((sha1,) for sha1 in sha1s),
        )

    content_hash_keys = ['sha1', 'sha1_git', 'sha256', 'blake2s256']

    def content_missing_from_list(self, contents, cur=None):
        cur = self._cursor(cur)

        keys = ', '.join(self.content_hash_keys)
        equality = ' AND '.join(
            ('t.%s = c.%s' % (key, key))
            for key in self.content_hash_keys
        )
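
        # e.g. equality == "t.sha1 = c.sha1 AND t.sha1_git = c.sha1_git AND ..."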

        yield from execute_values_to_bytes(
            cur, """
            SELECT %s
            FROM (VALUES %%s) as t(%s)
            WHERE NOT EXISTS (
                SELECT 1 FROM content c
                WHERE %s
            )
            """ % (keys, keys, equality),
            (tuple(c[key] for key in self.content_hash_keys) for c in contents)
        )

    def content_missing_per_sha1(self, sha1s, cur=None):
        cur = self._cursor(cur)

        yield from execute_values_to_bytes(cur, """
        SELECT t.sha1 FROM (VALUES %s) AS t(sha1)
        WHERE NOT EXISTS (
            SELECT 1 FROM content c WHERE c.sha1 = t.sha1
        )""", ((sha1,) for sha1 in sha1s))

    def skipped_content_missing_from_temp(self, cur=None):
        cur = self._cursor(cur)

        cur.execute("""SELECT sha1, sha1_git, sha256, blake2s256
                       FROM swh_skipped_content_missing()""")

        yield from cursor_to_bytes(cur)

    def snapshot_exists(self, snapshot_id, cur=None):
        """Check whether a snapshot with the given id exists"""
        cur = self._cursor(cur)

        cur.execute("""SELECT 1 FROM snapshot where id=%s""", (snapshot_id,))

        return bool(cur.fetchone())

    def snapshot_add(self, origin, visit, snapshot_id, cur=None):
        """Add a snapshot for origin/visit from the temporary table"""
        cur = self._cursor(cur)

        cur.execute("""SELECT swh_snapshot_add(%s, %s, %s)""",
                    (origin, visit, snapshot_id))

    snapshot_count_cols = ['target_type', 'count']

    def snapshot_count_branches(self, snapshot_id, cur=None):
        cur = self._cursor(cur)
        query = """\
           SELECT %s FROM swh_snapshot_count_branches(%%s)
        """ % ', '.join(self.snapshot_count_cols)

        cur.execute(query, (snapshot_id,))

        yield from cursor_to_bytes(cur)

    snapshot_get_cols = ['snapshot_id', 'name', 'target', 'target_type']

    def snapshot_get_by_id(self, snapshot_id, branches_from=b'',
                           branches_count=None, target_types=None,
                           cur=None):
        cur = self._cursor(cur)
        query = """\
           SELECT %s
           FROM swh_snapshot_get_by_id(%%s, %%s, %%s, %%s :: snapshot_target[])
        """ % ', '.join(self.snapshot_get_cols)

        cur.execute(query, (snapshot_id, branches_from, branches_count,
                            target_types))

        yield from cursor_to_bytes(cur)

    def snapshot_get_by_origin_visit(self, origin_id, visit_id, cur=None):
        cur = self._cursor(cur)
        query = """\
           SELECT swh_snapshot_get_by_origin_visit(%s, %s)
        """

        cur.execute(query, (origin_id, visit_id))
        ret = cur.fetchone()
        if ret:
            return line_to_bytes(ret)[0]

    content_find_cols = ['sha1', 'sha1_git', 'sha256', 'blake2s256', 'length',
                         'ctime', 'status']

    def content_find(self, sha1=None, sha1_git=None, sha256=None,
                     blake2s256=None, cur=None):
        """Find the content optionally on a combination of the following
        checksums sha1, sha1_git, sha256 or blake2s256.

        Args:
            sha1: sha1 content
            git_sha1: the sha1 computed `a la git` sha1 of the content
            sha256: sha256 content
            blake2s256: blake2s256 content

        Returns:
            The tuple (sha1, sha1_git, sha256, blake2s256) if found or None.

        """
        cur = self._cursor(cur)

        cur.execute("""SELECT %s
                       FROM swh_content_find(%%s, %%s, %%s, %%s)
                       LIMIT 1""" % ','.join(self.content_find_cols),
                    (sha1, sha1_git, sha256, blake2s256))

        content = line_to_bytes(cur.fetchone())
        if set(content) == {None}:
            return None
        else:
            return content

    def directory_missing_from_list(self, directories, cur=None):
        cur = self._cursor(cur)
        yield from execute_values_to_bytes(
            cur, """
            SELECT id FROM (VALUES %s) as t(id)
            WHERE NOT EXISTS (
                SELECT 1 FROM directory d WHERE d.id = t.id
            )
            """, ((id,) for id in directories))

    directory_ls_cols = ['dir_id', 'type', 'target', 'name', 'perms',
                         'status', 'sha1', 'sha1_git', 'sha256', 'length']

    def directory_walk_one(self, directory, cur=None):
        cur = self._cursor(cur)
        cols = ', '.join(self.directory_ls_cols)
        query = 'SELECT %s FROM swh_directory_walk_one(%%s)' % cols
        cur.execute(query, (directory,))
        yield from cursor_to_bytes(cur)

    def directory_walk(self, directory, cur=None):
        cur = self._cursor(cur)
        cols = ', '.join(self.directory_ls_cols)
        query = 'SELECT %s FROM swh_directory_walk(%%s)' % cols
        cur.execute(query, (directory,))
        yield from cursor_to_bytes(cur)

    def directory_entry_get_by_path(self, directory, paths, cur=None):
        """Retrieve a directory entry by path.

        """
        cur = self._cursor(cur)

        cols = ', '.join(self.directory_ls_cols)
        query = (
            'SELECT %s FROM swh_find_directory_entry_by_path(%%s, %%s)' % cols)
        cur.execute(query, (directory, paths))

        data = cur.fetchone()
        if set(data) == {None}:
            return None
        return line_to_bytes(data)

    def revision_missing_from_list(self, revisions, cur=None):
        cur = self._cursor(cur)

        yield from execute_values_to_bytes(
            cur, """
            SELECT id FROM (VALUES %s) as t(id)
            WHERE NOT EXISTS (
                SELECT 1 FROM revision r WHERE r.id = t.id
            )
            """, ((id,) for id in revisions))

    revision_add_cols = [
        'id', 'date', 'date_offset', 'date_neg_utc_offset', 'committer_date',
        'committer_date_offset', 'committer_date_neg_utc_offset', 'type',
        'directory', 'message', 'author_fullname', 'author_name',
        'author_email', 'committer_fullname', 'committer_name',
        'committer_email', 'metadata', 'synthetic',
    ]

    revision_get_cols = revision_add_cols + [
        'author_id', 'committer_id', 'parents']

    def origin_visit_add(self, origin, ts, cur=None):
        """Add a new origin_visit for origin origin at timestamp ts with
        status 'ongoing'.

        Args:
            origin: origin concerned by the visit
            ts: the date of the visit

        Returns:
            The number (per-origin counter) of the new visit

        """
        cur = self._cursor(cur)
        cur.execute('SELECT swh_origin_visit_add(%s, %s)',
                    (origin, ts))
        return cur.fetchone()[0]

    def origin_visit_update(self, origin, visit_id, status,
                            metadata, cur=None):
        """Update origin_visit's status."""
        cur = self._cursor(cur)
        update = """UPDATE origin_visit
                    SET status=%s, metadata=%s
                    WHERE origin=%s AND visit=%s"""
        cur.execute(update, (status, jsonize(metadata), origin, visit_id))

    origin_visit_get_cols = ['origin', 'visit', 'date', 'status', 'metadata',
                             'snapshot']

    def origin_visit_get_all(self, origin_id,
                             last_visit=None, limit=None, cur=None):
        """Retrieve all visits for origin with id origin_id.

        Args:
            origin_id: The id of the origin concerned

        Yields:
            The visits for that origin, in increasing visit order

        """
        cur = self._cursor(cur)

        if last_visit:
            extra_condition = 'and visit > %s'
            args = (origin_id, last_visit, limit)
        else:
            extra_condition = ''
            args = (origin_id, limit)

        query = """\
        SELECT %s,
            (select id from snapshot where object_id = snapshot_id) as snapshot
        FROM origin_visit
        WHERE origin=%%s %s
        order by visit asc
        limit %%s""" % (
            ', '.join(self.origin_visit_get_cols[:-1]), extra_condition
        )

        cur.execute(query, args)

        yield from cursor_to_bytes(cur)

    def origin_visit_get(self, origin_id, visit_id, cur=None):
        """Retrieve information on visit visit_id of origin origin_id.

        Args:
            origin_id: the origin concerned
            visit_id: The visit number for that origin

        Returns:
            The origin_visit information

        """
        cur = self._cursor(cur)

        query = """\
            SELECT %s,
                (select id from snapshot where object_id = snapshot_id)
                as snapshot
            FROM origin_visit
            WHERE origin = %%s AND visit = %%s
            """ % (', '.join(self.origin_visit_get_cols[:-1]))

        cur.execute(query, (origin_id, visit_id))
        r = cur.fetchall()
        if not r:
            return None
        return line_to_bytes(r[0])

    def origin_visit_get_latest_snapshot(self, origin_id,
                                         allowed_statuses=None,
                                         cur=None):
        """Retrieve the most recent origin_visit which references a snapshot

        Args:
            origin_id: the origin concerned
            allowed_statuses: the visit statuses allowed for the returned visit

        Returns:
            The origin_visit information, or None if no visit matches.
        """
        cur = self._cursor(cur)

        extra_clause = ""
        if allowed_statuses:
            extra_clause = cur.mogrify("AND status IN %s",
                                       (tuple(allowed_statuses),)).decode()
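            # e.g. extra_clause == "AND status IN ('full', 'partial')",
            # with the literals quoted by mogrify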

        query = """\
            SELECT %s,
                (select id from snapshot where object_id = snapshot_id)
                as snapshot
            FROM origin_visit
            WHERE
                origin = %%s AND snapshot_id is not null %s
            ORDER BY date DESC, visit DESC
            LIMIT 1
            """ % (', '.join(self.origin_visit_get_cols[:-1]), extra_clause)

        cur.execute(query, (origin_id,))
        r = cur.fetchone()
        if not r:
            return None
        return line_to_bytes(r)

    @staticmethod
    def mangle_query_key(key, main_table):
        if key == 'id':
            return 't.id'
        if key == 'parents':
            return '''
            ARRAY(
            SELECT rh.parent_id::bytea
            FROM revision_history rh
            WHERE rh.id = t.id
            ORDER BY rh.parent_rank
            )'''
        if '_' not in key:
            return '%s.%s' % (main_table, key)

        head, tail = key.split('_', 1)
        if (head in ('author', 'committer')
                and tail in ('name', 'email', 'id', 'fullname')):
            return '%s.%s' % (head, tail)

        return '%s.%s' % (main_table, key)
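
    # Examples of the mapping above (illustrative):
    #   mangle_query_key('id', 'revision')          -> 't.id'
    #   mangle_query_key('date', 'revision')        -> 'revision.date'
    #   mangle_query_key('author_name', 'revision') -> 'author.name'
    #   mangle_query_key('date_offset', 'revision') -> 'revision.date_offset'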

    def revision_get_from_list(self, revisions, cur=None):
        cur = self._cursor(cur)

        query_keys = ', '.join(
            self.mangle_query_key(k, 'revision')
            for k in self.revision_get_cols
        )

        yield from execute_values_to_bytes(
            cur, """
            SELECT %s FROM (VALUES %%s) as t(id)
            LEFT JOIN revision ON t.id = revision.id
            LEFT JOIN person author ON revision.author = author.id
            LEFT JOIN person committer ON revision.committer = committer.id
            """ % query_keys,
            ((id,) for id in revisions))

    def revision_log(self, root_revisions, limit=None, cur=None):
        cur = self._cursor(cur)

        query = """SELECT %s
                   FROM swh_revision_log(%%s, %%s)
                """ % ', '.join(self.revision_get_cols)

        cur.execute(query, (root_revisions, limit))
        yield from cursor_to_bytes(cur)

    revision_shortlog_cols = ['id', 'parents']

    def revision_shortlog(self, root_revisions, limit=None, cur=None):
        cur = self._cursor(cur)

        query = """SELECT %s
                   FROM swh_revision_list(%%s, %%s)
                """ % ', '.join(self.revision_shortlog_cols)

        cur.execute(query, (root_revisions, limit))
        yield from cursor_to_bytes(cur)

    def release_missing_from_list(self, releases, cur=None):
        cur = self._cursor(cur)
        yield from execute_values_to_bytes(
            cur, """
            SELECT id FROM (VALUES %s) as t(id)
            WHERE NOT EXISTS (
                SELECT 1 FROM release r WHERE r.id = t.id
            )
            """, ((id,) for id in releases))

    object_find_by_sha1_git_cols = ['sha1_git', 'type', 'id', 'object_id']

    def object_find_by_sha1_git(self, ids, cur=None):
        cur = self._cursor(cur)

        yield from execute_values_to_bytes(
            cur, """
            WITH t (id) AS (VALUES %s),
            known_objects as ((
                select
                  id as sha1_git,
                  'release'::object_type as type,
                  id,
                  object_id
                from release r
                where exists (select 1 from t where t.id = r.id)
            ) union all (
                select
                  id as sha1_git,
                  'revision'::object_type as type,
                  id,
                  object_id
                from revision r
                where exists (select 1 from t where t.id = r.id)
            ) union all (
                select
                  id as sha1_git,
                  'directory'::object_type as type,
                  id,
                  object_id
                from directory d
                where exists (select 1 from t where t.id = d.id)
            ) union all (
                select
                  sha1_git as sha1_git,
                  'content'::object_type as type,
                  sha1 as id,
                  object_id
                from content c
                where exists (select 1 from t where t.id = c.sha1_git)
            ))
            select t.id as sha1_git, k.type, k.id, k.object_id
            from t
            left join known_objects k on t.id = k.sha1_git
            """,
            ((id,) for id in ids)
        )

    def stat_counters(self, cur=None):
        cur = self._cursor(cur)
        cur.execute('SELECT * FROM swh_stat_counters()')
        yield from cur

    fetch_history_cols = ['origin', 'date', 'status', 'result', 'stdout',
                          'stderr', 'duration']

    def create_fetch_history(self, fetch_history, cur=None):
        """Create a fetch_history entry with the data in fetch_history"""
        cur = self._cursor(cur)
        query = '''INSERT INTO fetch_history (%s)
                   VALUES (%s) RETURNING id''' % (
            ','.join(self.fetch_history_cols),
            ','.join(['%s'] * len(self.fetch_history_cols))
        )
        cur.execute(query, [fetch_history.get(col) for col in
                            self.fetch_history_cols])

        return cur.fetchone()[0]

    def get_fetch_history(self, fetch_history_id, cur=None):
        """Get a fetch_history entry with the given id"""
        cur = self._cursor(cur)
        query = '''SELECT %s FROM fetch_history WHERE id=%%s''' % (
            ', '.join(self.fetch_history_cols),
        )
        cur.execute(query, (fetch_history_id,))

        data = cur.fetchone()

        if not data:
            return None

        ret = {'id': fetch_history_id}
        for i, col in enumerate(self.fetch_history_cols):
            ret[col] = data[i]

        return ret

    def update_fetch_history(self, fetch_history, cur=None):
        """Update the fetch_history entry from the data in fetch_history"""
        cur = self._cursor(cur)
        query = '''UPDATE fetch_history
                   SET %s
                   WHERE id=%%s''' % (
            ','.join('%s=%%s' % col for col in self.fetch_history_cols)
        )
        cur.execute(query, [jsonize(fetch_history.get(col)) for col in
                            self.fetch_history_cols + ['id']])

    def origin_add(self, type, url, cur=None):
        """Insert a new origin and return the new identifier."""
        insert = """INSERT INTO origin (type, url) values (%s, %s)
                    RETURNING id"""

        cur.execute(insert, (type, url))
        return cur.fetchone()[0]

    origin_cols = ['id', 'type', 'url']

    def origin_get_with(self, type, url, cur=None):
        """Retrieve the origin id from its type and url if found."""
        cur = self._cursor(cur)

        query = """SELECT %s
                   FROM origin
                   WHERE type=%%s AND url=%%s
                """ % ','.join(self.origin_cols)

        cur.execute(query, (type, url))
        data = cur.fetchone()
        if data:
            return line_to_bytes(data)
        return None

    def origin_get(self, id, cur=None):
        """Retrieve the origin per its identifier.

        """
        cur = self._cursor(cur)

        query = """SELECT %s
                   FROM origin WHERE id=%%s
                """ % ','.join(self.origin_cols)

        cur.execute(query, (id,))
        data = cur.fetchone()
        if data:
            return line_to_bytes(data)
        return None

    def origin_search(self, url_pattern, offset=0, limit=50,
                      regexp=False, with_visit=False, cur=None):
        """Search for origins whose urls contain a provided string pattern
        or match a provided regular expression.
        The search is performed in a case insensitive way.

        Args:
            url_pattern (str): the string pattern to search for in origin urls
            offset (int): number of found origins to skip before returning
                results
            limit (int): the maximum number of found origins to return
            regexp (bool): if True, consider the provided pattern as a regular
                expression and return origins whose urls match it
            with_visit (bool): if True, filter out origins with no visit

        """
        cur = self._cursor(cur)
        origin_cols = ','.join(self.origin_cols)
        query = """SELECT %s
                   FROM origin
                   WHERE """
        if with_visit:
            query += """
                   EXISTS (SELECT 1 from origin_visit WHERE origin=origin.id)
                   AND """
        query += """
                   url %s %%s
                   ORDER BY id
                   OFFSET %%s LIMIT %%s"""

        if not regexp:
            query = query % (origin_cols, 'ILIKE')
            query_params = ('%'+url_pattern+'%', offset, limit)
        else:
            query = query % (origin_cols, '~*')
            query_params = (url_pattern, offset, limit)

        cur.execute(query, query_params)
        yield from cursor_to_bytes(cur)

    person_cols = ['fullname', 'name', 'email']
    person_get_cols = person_cols + ['id']

    def person_add(self, person, cur=None):
        """Add a person identified by its name and email.

        Returns:
            The new person's id

        """
        cur = self._cursor(cur)

        query_new_person = '''\
        INSERT INTO person(%s)
        VALUES (%s)
        RETURNING id''' % (
            ', '.join(self.person_cols),
            ', '.join('%s' for i in range(len(self.person_cols)))
        )
        cur.execute(query_new_person,
                    [person[col] for col in self.person_cols])
        return cur.fetchone()[0]

    def person_get(self, ids, cur=None):
        """Retrieve the persons identified by the list of ids.

        """
        cur = self._cursor(cur)

        query = """SELECT %s
                   FROM person
                   WHERE id IN %%s""" % ', '.join(self.person_get_cols)

        cur.execute(query, (tuple(ids),))
        yield from cursor_to_bytes(cur)

    release_add_cols = [
        'id', 'target', 'target_type', 'date', 'date_offset',
        'date_neg_utc_offset', 'name', 'comment', 'synthetic',
        'author_fullname', 'author_name', 'author_email',
    ]
    release_get_cols = release_add_cols + ['author_id']

    def release_get_from_list(self, releases, cur=None):
        cur = self._cursor(cur)
        query_keys = ', '.join(
            self.mangle_query_key(k, 'release')
            for k in self.release_get_cols
        )

        yield from execute_values_to_bytes(
            cur, """
            SELECT %s FROM (VALUES %%s) as t(id)
            LEFT JOIN release ON t.id = release.id
            LEFT JOIN person author ON release.author = author.id
            """ % query_keys,
            ((id,) for id in releases))

    def origin_metadata_add(self, origin, ts, provider, tool,
                            metadata, cur=None):
        """ Add an origin_metadata for the origin at ts with provider, tool and
        metadata.

        Args:
            origin (int): the origin's id for which the metadata is added
            ts (datetime): time when the metadata was found
            provider (int): the metadata provider identifier
            tool (int): the tool's identifier used to extract metadata
            metadata (jsonb): the metadata retrieved at the time and location

        Returns:
            id (int): the origin_metadata unique id

        """
        cur = self._cursor(cur)
        insert = """INSERT INTO origin_metadata (origin_id, discovery_date,
                    provider_id, tool_id, metadata) values (%s, %s, %s, %s, %s)
                    RETURNING id"""
        cur.execute(insert, (origin, ts, provider, tool, jsonize(metadata)))

        return cur.fetchone()[0]

    origin_metadata_get_cols = ['id', 'origin_id', 'discovery_date',
                                'tool_id', 'metadata', 'provider_id',
                                'provider_name', 'provider_type',
                                'provider_url']

    def origin_metadata_get_by(self, origin_id, provider_type=None, cur=None):
        """Retrieve all origin_metadata entries for one origin_id

        """
        cur = self._cursor(cur)
        if not provider_type:
            query = '''SELECT %s
                       FROM swh_origin_metadata_get_by_origin(
                            %%s)''' % (','.join(
                                          self.origin_metadata_get_cols))

            cur.execute(query, (origin_id, ))

        else:
            query = '''SELECT %s
                       FROM swh_origin_metadata_get_by_provider_type(
                            %%s, %%s)''' % (','.join(
                                          self.origin_metadata_get_cols))

            cur.execute(query, (origin_id, provider_type))

        yield from cursor_to_bytes(cur)

    tool_cols = ['id', 'name', 'version', 'configuration']

    @stored_procedure('swh_mktemp_tool')
    def mktemp_tool(self, cur=None):
        pass

    def tool_add_from_temp(self, cur=None):
        cur = self._cursor(cur)
        cur.execute("SELECT %s from swh_tool_add()" % (
            ','.join(self.tool_cols), ))
        yield from cursor_to_bytes(cur)

    def tool_get(self, name, version, configuration, cur=None):
        cur = self._cursor(cur)
        cur.execute('''select %s
                       from tool
                       where name=%%s and
                             version=%%s and
                             configuration=%%s''' % (
                                 ','.join(self.tool_cols)),
                    (name, version, configuration))

        data = cur.fetchone()
        if not data:
            return None
        return line_to_bytes(data)

    metadata_provider_cols = ['id', 'provider_name', 'provider_type',
                              'provider_url', 'metadata']

    def metadata_provider_add(self, provider_name, provider_type,
                              provider_url, metadata, cur=None):
        """Insert a new provider and return the new identifier."""
        cur = self._cursor(cur)
        insert = """INSERT INTO metadata_provider (provider_name, provider_type,
                    provider_url, metadata) values (%s, %s, %s, %s)
                    RETURNING id"""

        cur.execute(insert, (provider_name, provider_type, provider_url,
                    jsonize(metadata)))
        return cur.fetchone()[0]

    def metadata_provider_get(self, provider_id, cur=None):
        cur = self._cursor(cur)
        cur.execute('''select %s
                       from metadata_provider
                       where id=%%s''' % (
                                 ','.join(self.metadata_provider_cols)),
                    (provider_id, ))

        data = cur.fetchone()
        if not data:
            return None
        return line_to_bytes(data)

    def metadata_provider_get_by(self, provider_name, provider_url,
                                 cur=None):
        cur = self._cursor(cur)
        cur.execute('''select %s
                       from metadata_provider
                       where provider_name=%%s and
                             provider_url=%%s''' % (
                                 ','.join(self.metadata_provider_cols)),
                    (provider_name, provider_url))

        data = cur.fetchone()
        if not data:
            return None
        return line_to_bytes(data)