Basic testing infrastructure

This commit is contained in:
Eden Kirin
2023-09-20 09:56:52 +02:00
parent f57c4d4491
commit 6109630ed1
12 changed files with 451 additions and 22 deletions

View File

View File

@ -0,0 +1,107 @@
from asyncio import current_task
from typing import AsyncGenerator, Callable, Generator
import pytest
from _pytest.fixtures import FixtureRequest
from _pytest.tmpdir import TempPathFactory
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_scoped_session
from sqlalchemy.orm import sessionmaker
import settings_test
from televend_core.databases.televend_repositories.mapper import start_televend_mappers
from televend_core.databases.televend_repositories.repository import (
TelevendRepositoryManager,
)
from televend_core.test_extras import factory_boy_utils
from televend_core.test_extras.database_utils import (
DatabaseConfig,
configure_database,
generate_async_engine,
generate_test_database_name,
xdist_lock_pytest,
)
TELEVEND_TEMPLATE_DATABASE_CONFIG = DatabaseConfig(
engine="postgresql",
user=settings_test.POSTGRES_TEST_USER,
password=settings_test.POSTGRES_TEST_PASSWORD,
name="test_televend",
host=settings_test.POSTGRES_HOST,
port=settings_test.POSTGRES_PORT,
)
@pytest.fixture(scope="session")
def televend_fixture_setup(worker_id: int, tmp_path_factory: TempPathFactory) -> None:
start_televend_mappers("selectin")
# get the temp directory shared by all workers
# When xdist is disabled (running with -n0 for example), then worker_id will return "master"
if worker_id == "master":
root_tmp_dir = tmp_path_factory.getbasetemp()
else:
root_tmp_dir = tmp_path_factory.getbasetemp().parent
file_path = root_tmp_dir / "televend_fixture.txt"
lock_path = file_path.with_suffix(".lock")
xdist_lock_pytest(
file_path=file_path,
lock_path=lock_path,
fn=lambda: configure_database(
connection_string=TELEVEND_TEMPLATE_DATABASE_CONFIG.to_connection_string(),
database_type="cloud",
),
)
@pytest.fixture(scope="function")
def televend_test_database_name(request: FixtureRequest) -> Generator[str, None, None]:
test_database_name = generate_test_database_name(test_name=request.node.originalname)
yield test_database_name
@pytest.fixture(scope="function")
async def async_televend_engine(
televend_fixture_setup: Callable, televend_test_database_name: str
) -> AsyncGenerator[AsyncEngine, None]:
async for async_engine in generate_async_engine(
template_database_config=TELEVEND_TEMPLATE_DATABASE_CONFIG,
test_database_name=televend_test_database_name,
):
yield async_engine
@pytest.fixture(scope="function")
async def async_televend_session(
async_televend_engine: AsyncEngine,
) -> AsyncGenerator[AsyncSession, None]:
# Prepare a new, clean async_session, one sync async_session for factory boy and one async for normal usage
sync_engine = create_engine(
url=async_televend_engine.url.set(drivername="postgresql"), pool_size=1, max_overflow=1
)
factory_boy_utils.TelevendSession.configure(bind=sync_engine)
televend_async_session_factory = sessionmaker(
bind=async_televend_engine, expire_on_commit=False, class_=AsyncSession
)
televend_async_session = async_scoped_session(
televend_async_session_factory, scopefunc=current_task
)
session = televend_async_session()
yield session
factory_boy_utils.TelevendSession.remove()
await session.rollback() # to avoid coroutine 'Transaction.rollback' was never awaited warning
await televend_async_session.remove()
@pytest.fixture(scope="function")
def televend_repository_manager(
async_televend_session: AsyncSession, async_televend_engine: AsyncEngine
) -> TelevendRepositoryManager:
return TelevendRepositoryManager(
async_session=async_televend_session,
async_engine=async_televend_engine,
)

View File

@ -0,0 +1,93 @@
from typing import Protocol
import asyncpg
import sqlalchemy
from asyncpg import Connection, DuplicateDatabaseError, InvalidCatalogNameError
from migrate import DatabaseConfig, migrate
class TestingSettingsInitOptions(Protocol):
DB_HOST: str
DB_PORT: int
DB_TEMPLATE_NAME: str
DB_NAME: str
DB_USER: str
DB_PASSWORD: str
DROP_DATABASE_BEFORE_TESTS: bool
DROP_DATABASE_AFTER_TESTS: bool
class TestingDatabaseSetup:
def __init__(self, options: TestingSettingsInitOptions):
self.options = options
db_connection_url = sqlalchemy.engine.URL.create(
drivername="postgresql",
username=self.options.DB_USER,
password=self.options.DB_PASSWORD,
host=self.options.DB_HOST,
port=self.options.DB_PORT,
)
self.connection_str = db_connection_url.render_as_string(hide_password=False)
async def _create_template_db(self, conn: Connection):
query = f"CREATE DATABASE {self.options.DB_TEMPLATE_NAME}"
try:
await conn.execute(query)
except DuplicateDatabaseError:
...
async def _drop_template_db(self, conn: Connection):
query = f"DROP DATABASE {self.options.DB_TEMPLATE_NAME}"
try:
await conn.execute(query)
except InvalidCatalogNameError:
...
async def _create_test_db(self, conn: Connection):
query = f"CREATE DATABASE {self.options.DB_NAME} TEMPLATE {self.options.DB_TEMPLATE_NAME}"
try:
await conn.execute(query)
except DuplicateDatabaseError:
...
async def _drop_test_db(self, conn: Connection):
query = f"DROP DATABASE {self.options.DB_NAME}"
try:
await conn.execute(query)
except InvalidCatalogNameError:
...
def _migrate_template_database(self):
conf = DatabaseConfig(
HOST=self.options.DB_HOST,
PORT=self.options.DB_PORT,
NAME=self.options.DB_TEMPLATE_NAME,
USER=self.options.DB_USER,
PASSWORD=self.options.DB_PASSWORD,
)
migrate(conf)
async def init_db(self):
conn = await asyncpg.connect(self.connection_str, database="postgres")
if self.options.DROP_DATABASE_BEFORE_TESTS:
await self._drop_template_db(conn)
await self._create_template_db(conn)
await self._create_test_db(conn)
await conn.close()
if self.options.DROP_DATABASE_BEFORE_TESTS:
self._migrate_template_database()
async def tear_down_db(self):
conn = await asyncpg.connect(self.connection_str, database="postgres")
await self._drop_test_db(conn)
if self.options.DROP_DATABASE_AFTER_TESTS:
await self._drop_template_db(conn)
await conn.close()