# pytest_plugin.py
# Copyright (C) 2019-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import glob
from os import path, environ
from typing import Union
import pytest
import swh.storage
from pytest_postgresql import factories
from pytest_postgresql.janitor import DatabaseJanitor, psycopg2, Version
from swh.core.utils import numfile_sortkey as sortkey
from swh.storage import get_storage
from swh.storage.tests.storage_data import StorageData
# Directory named "sql" located next to the installed swh.storage package.
SQL_DIR = path.join(path.dirname(swh.storage.__file__), "sql")
# Import-time side effect: force LC_ALL for this process (and anything it
# spawns afterwards).
# NOTE(review): presumably so the PostgreSQL tooling started by the tests runs
# with a UTF-8 locale -- confirm.
environ["LC_ALL"] = "C.UTF-8"
# Glob pattern matching every .sql file in SQL_DIR; used as the default
# ``dump_files`` value by postgresql_fact and SwhDatabaseJanitor.
DUMP_FILES = path.join(SQL_DIR, "*.sql")
@pytest.fixture
def swh_storage_backend_config(postgresql_proc, swh_storage_postgresql):
    """Basic pg storage configuration with no journal collaborator
    (to avoid pulling optional dependency on clients of this fixture)

    Args:
        postgresql_proc: PostgreSQL process fixture; provides the host, port
            and user of the running server.
        swh_storage_postgresql: open connection to the initialized test
            database; requesting it guarantees the database exists and is
            reset before the configuration is used.

    Yields:
        A dict of keyword arguments suitable for ``get_storage``.
    """
    # Read the role and database name from the live fixtures instead of
    # hard-coding "postgres"/"tests": the database is created by
    # postgresql_fact with the process fixture's user and the configured
    # dbname, so hard-coded values break as soon as either is customized.
    dbname = swh_storage_postgresql.get_dsn_parameters()["dbname"]
    yield {
        "cls": "local",
        "db": "postgresql://{user}@{host}:{port}/{dbname}".format(
            host=postgresql_proc.host,
            port=postgresql_proc.port,
            user=postgresql_proc.user,
            dbname=dbname,
        ),
        "objstorage": {"cls": "memory", "args": {}},
        "check_config": {"check_write": True},
    }
@pytest.fixture
def swh_storage(swh_storage_backend_config):
    """Storage instance built by ``get_storage`` from the backend config.

    Args:
        swh_storage_backend_config: mapping expanded as keyword arguments to
            ``get_storage``.

    Returns:
        Whatever ``get_storage`` returns for that configuration.
    """
    storage = get_storage(**swh_storage_backend_config)
    return storage
# The postgresql_fact fixture factory below is mostly a copy of the code
# from pytest-postgresql. We need a custom version here to be able to
# specify our own version of the DatabaseJanitor.
def postgresql_fact(process_fixture_name, db_name=None, dump_files=DUMP_FILES):
    """Build a pytest fixture yielding a connection to a prepared database.

    Args:
        process_fixture_name: name of the fixture providing the PostgreSQL
            process (looked up lazily through ``request.getfixturevalue``).
        db_name: database to connect to; when falsy, the ``dbname`` entry of
            the pytest-postgresql configuration is used instead.
        dump_files: glob pattern of SQL files handed to SwhDatabaseJanitor.

    Returns:
        The fixture function.
    """

    @pytest.fixture
    def postgresql_factory(request):
        """
        Fixture factory for PostgreSQL.

        :param FixtureRequest request: fixture request object
        :rtype: psycopg2.connection
        :returns: postgresql client
        """
        config = factories.get_config(request)
        if not psycopg2:
            raise ImportError("No module named psycopg2. Please install it.")

        executor = request.getfixturevalue(process_fixture_name)
        host = executor.host
        port = executor.port
        user = executor.user
        options = executor.options
        target_db = db_name or config["dbname"]

        janitor = SwhDatabaseJanitor(
            user, host, port, target_db, executor.version, dump_files=dump_files,
        )
        # The janitor prepares the database on entry and runs its own
        # teardown on exit, around the lifetime of the yielded connection.
        with janitor:
            connection = psycopg2.connect(
                dbname=target_db, user=user, host=host, port=port, options=options,
            )
            yield connection
            connection.close()

    return postgresql_factory
# Connection fixture bound to the "postgresql_proc" process fixture, using the
# default database name and the default DUMP_FILES schema files.
swh_storage_postgresql = postgresql_fact("postgresql_proc")
# This version of the DatabaseJanitor implements a different setup/teardown
# behavior than the stock one: instead of dropping, creating and
# initializing the database for each test, it creates and initializes the db
# only once, then it truncates the tables. This is needed to have acceptable
# test performance.
class SwhDatabaseJanitor(DatabaseJanitor):
    """DatabaseJanitor that creates the database once and resets it afterwards.

    ``init`` creates the database and loads the SQL dump files only when the
    database does not exist yet; otherwise it truncates the tables and
    restarts the sequences. ``drop`` never removes the database: it only
    blocks new connections and terminates the existing ones.
    """

    def __init__(
        self,
        user: str,
        host: str,
        port: str,
        db_name: str,
        version: Union[str, float, Version],
        dump_files: str = DUMP_FILES,
    ) -> None:
        """Record connection settings and resolve the SQL files to load.

        Args:
            user: role used to connect.
            host: server host.
            port: server port.
            db_name: name of the database to manage.
            version: PostgreSQL version, forwarded to DatabaseJanitor.
            dump_files: glob pattern of SQL files; the matches are sorted
                with ``numfile_sortkey`` and stored in ``self.dump_files``.
        """
        super().__init__(user, host, port, db_name, version)
        self.dump_files = sorted(glob.glob(dump_files), key=sortkey)

    def _swh_db_connect(self):
        """Open a new psycopg2 connection to the managed database."""
        return psycopg2.connect(
            dbname=self.db_name, user=self.user, host=self.host, port=self.port,
        )

    def db_setup(self):
        """Execute every non-empty SQL dump file against the database."""
        cnx = self._swh_db_connect()
        try:
            # ``with cnx`` only scopes a transaction (commit on success,
            # rollback on error); it does not close the connection, hence
            # the explicit close() below.
            with cnx:
                with cnx.cursor() as cur:
                    for fname in self.dump_files:
                        with open(fname) as fobj:
                            # NOTE(review): "concurrently" is stripped
                            # presumably because CREATE INDEX CONCURRENTLY
                            # cannot run inside a transaction block -- confirm.
                            sql = fobj.read().replace("concurrently", "").strip()
                        if sql:
                            cur.execute(sql)
        finally:
            cnx.close()

    def db_reset(self):
        """Empty all tables of the public schema and restart its sequences.

        The ``dbversion`` table is left untouched.
        """
        cnx = self._swh_db_connect()
        try:
            # The transaction is committed when the ``with cnx`` block exits
            # normally; the connection itself is closed in ``finally``.
            with cnx:
                with cnx.cursor() as cur:
                    cur.execute(
                        "SELECT table_name FROM information_schema.tables "
                        "WHERE table_schema = %s",
                        ("public",),
                    )
                    tables = {table for (table,) in cur.fetchall()} - {"dbversion"}
                    for table in tables:
                        cur.execute("truncate table %s cascade" % table)

                    cur.execute(
                        "SELECT sequence_name FROM information_schema.sequences "
                        "WHERE sequence_schema = %s",
                        ("public",),
                    )
                    seqs = {seq for (seq,) in cur.fetchall()}
                    for seq in seqs:
                        cur.execute("ALTER SEQUENCE %s RESTART;" % seq)
        finally:
            cnx.close()

    def init(self):
        """Create and populate the database, or reset it if it already exists."""
        with self.cursor() as cur:
            cur.execute(
                "SELECT COUNT(1) FROM pg_database WHERE datname=%s;", (self.db_name,)
            )
            db_exists = cur.fetchone()[0] == 1
            if db_exists:
                # Re-allow connections, which drop() disables.
                cur.execute(
                    "UPDATE pg_database SET datallowconn=true " "WHERE datname = %s;",
                    (self.db_name,),
                )

        if db_exists:
            self.db_reset()
        else:
            with self.cursor() as cur:
                cur.execute('CREATE DATABASE "{}";'.format(self.db_name))
            self.db_setup()

    def drop(self):
        """Forbid new connections to the database and terminate existing ones.

        Despite its name, this does not drop the database, so the next
        ``init`` can reuse it.
        """
        with self.cursor() as cur:
            cur.execute(
                "UPDATE pg_database SET datallowconn=false " "WHERE datname = %s;",
                (self.db_name,),
            )
            cur.execute(
                "SELECT pg_terminate_backend(pg_stat_activity.pid)"
                "FROM pg_stat_activity "
                "WHERE pg_stat_activity.datname = %s;",
                (self.db_name,),
            )
@pytest.fixture
def sample_data() -> StorageData:
    """Provide a fresh ``StorageData`` holding pre-defined sample objects.

    Returns:
        A new StorageData instance. Its attributes are data model objects,
        exposed either as collections (contents, directories, revisions,
        releases, ...) or as single objects (content, directory, revision,
        release, ...).
    """
    data = StorageData()
    return data