Files
2023-09-20 09:56:52 +02:00

108 lines
3.7 KiB
Python

from asyncio import current_task
from typing import AsyncGenerator, Callable, Generator
import pytest
from _pytest.fixtures import FixtureRequest
from _pytest.tmpdir import TempPathFactory
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_scoped_session
from sqlalchemy.orm import sessionmaker
import settings_test
from televend_core.databases.televend_repositories.mapper import start_televend_mappers
from televend_core.databases.televend_repositories.repository import (
TelevendRepositoryManager,
)
from televend_core.test_extras import factory_boy_utils
from televend_core.test_extras.database_utils import (
DatabaseConfig,
configure_database,
generate_async_engine,
generate_test_database_name,
xdist_lock_pytest,
)
# Connection settings for the Televend template database. The fixtures below pass
# this object to the database helpers as their template configuration.
TELEVEND_TEMPLATE_DATABASE_CONFIG = DatabaseConfig(
    engine="postgresql",
    host=settings_test.POSTGRES_HOST,
    port=settings_test.POSTGRES_PORT,
    name="test_televend",
    user=settings_test.POSTGRES_TEST_USER,
    password=settings_test.POSTGRES_TEST_PASSWORD,
)
@pytest.fixture(scope="session")
def televend_fixture_setup(worker_id: str, tmp_path_factory: TempPathFactory) -> None:
    """Start the Televend mappers and configure the template database for the session.

    ``configure_database`` is not invoked directly: it is handed to
    ``xdist_lock_pytest`` together with a marker file and a lock file that live in a
    temp directory shared by all workers (presumably so that only one pytest-xdist
    worker performs the configuration -- confirm against that helper).

    Args:
        worker_id: pytest-xdist worker identifier. This is a string (e.g. ``"gw0"``),
            or ``"master"`` when xdist is disabled (running with ``-n0`` for example).
        tmp_path_factory: pytest factory used to locate the base temp directory.
    """
    start_televend_mappers("selectin")

    # Get the temp directory shared by all workers.
    # When xdist is disabled, worker_id is "master" and the base temp directory is
    # used as-is; otherwise the parent of the worker's base temp directory is used.
    base_tmp_dir = tmp_path_factory.getbasetemp()
    root_tmp_dir = base_tmp_dir if worker_id == "master" else base_tmp_dir.parent

    file_path = root_tmp_dir / "televend_fixture.txt"
    lock_path = file_path.with_suffix(".lock")

    xdist_lock_pytest(
        file_path=file_path,
        lock_path=lock_path,
        fn=lambda: configure_database(
            connection_string=TELEVEND_TEMPLATE_DATABASE_CONFIG.to_connection_string(),
            database_type="cloud",
        ),
    )
@pytest.fixture(scope="function")
def televend_test_database_name(request: FixtureRequest) -> Generator[str, None, None]:
    """Yield a test database name generated from ``request.node.originalname``."""
    original_test_name = request.node.originalname
    yield generate_test_database_name(test_name=original_test_name)
@pytest.fixture(scope="function")
async def async_televend_engine(
    televend_fixture_setup: None, televend_test_database_name: str
) -> AsyncGenerator[AsyncEngine, None]:
    """Yield the async engine produced by ``generate_async_engine`` for this test.

    Args:
        televend_fixture_setup: Requested only so that the session-scoped setup runs
            before this fixture; that fixture returns ``None`` and the value is unused.
        televend_test_database_name: Name passed to ``generate_async_engine`` as the
            test database name.

    Yields:
        Each engine produced by ``generate_async_engine`` for the template
        configuration and the given test database name.
    """
    engine_source = generate_async_engine(
        template_database_config=TELEVEND_TEMPLATE_DATABASE_CONFIG,
        test_database_name=televend_test_database_name,
    )
    # Re-yield from the helper's async generator so that resuming this fixture at
    # teardown also resumes the helper past its own yield.
    async for engine in engine_source:
        yield engine
@pytest.fixture(scope="function")
async def async_televend_session(
    async_televend_engine: AsyncEngine,
) -> AsyncGenerator[AsyncSession, None]:
    """Yield an async session bound to the test engine and bind factory boy's session.

    Two sessions are prepared: a synchronous one for factory boy (configured on
    ``factory_boy_utils.TelevendSession``) and an asynchronous one for normal usage,
    which is the value yielded to the test.

    Args:
        async_televend_engine: Async engine the yielded session is bound to. Its URL,
            with the driver name replaced by ``"postgresql"``, is reused for the
            synchronous engine.

    Yields:
        The ``AsyncSession`` obtained from a task-scoped ``async_scoped_session``.
    """
    # Synchronous engine on the same database, used only by the factory boy session.
    sync_engine = create_engine(
        url=async_televend_engine.url.set(drivername="postgresql"), pool_size=1, max_overflow=1
    )
    factory_boy_utils.TelevendSession.configure(bind=sync_engine)

    televend_async_session_factory = sessionmaker(
        bind=async_televend_engine, expire_on_commit=False, class_=AsyncSession
    )
    televend_async_session = async_scoped_session(
        televend_async_session_factory, scopefunc=current_task
    )
    session = televend_async_session()

    yield session

    try:
        factory_boy_utils.TelevendSession.remove()
        await session.rollback()  # to avoid coroutine 'Transaction.rollback' was never awaited warning
        await televend_async_session.remove()
    finally:
        # A new synchronous engine is created for every test; dispose of it so its
        # pooled connections are closed even if the session cleanup above raised.
        sync_engine.dispose()
@pytest.fixture(scope="function")
def televend_repository_manager(
    async_televend_session: AsyncSession, async_televend_engine: AsyncEngine
) -> TelevendRepositoryManager:
    """Build a ``TelevendRepositoryManager`` from the test's async session and engine."""
    repository_manager = TelevendRepositoryManager(
        async_engine=async_televend_engine,
        async_session=async_televend_session,
    )
    return repository_manager