from asyncio import current_task
from typing import AsyncGenerator, Callable, Generator

import pytest
from _pytest.fixtures import FixtureRequest
from _pytest.tmpdir import TempPathFactory
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_scoped_session
from sqlalchemy.orm import sessionmaker

import settings_test
from televend_core.databases.televend_repositories.mapper import start_televend_mappers
from televend_core.databases.televend_repositories.repository import (
    TelevendRepositoryManager,
)
from televend_core.test_extras import factory_boy_utils
from televend_core.test_extras.database_utils import (
    DatabaseConfig,
    configure_database,
    generate_async_engine,
    generate_test_database_name,
    xdist_lock_pytest,
)

# Connection settings for the template database that per-test databases are
# derived from (see ``async_televend_engine``).
TELEVEND_TEMPLATE_DATABASE_CONFIG = DatabaseConfig(
    engine="postgresql",
    user=settings_test.POSTGRES_TEST_USER,
    password=settings_test.POSTGRES_TEST_PASSWORD,
    name="test_televend",
    host=settings_test.POSTGRES_HOST,
    port=settings_test.POSTGRES_PORT,
)


@pytest.fixture(scope="session")
def televend_fixture_setup(worker_id: str, tmp_path_factory: TempPathFactory) -> None:
    """Start the Televend ORM mappers and configure the template database.

    Runs once per test session (per xdist worker). The call to
    ``configure_database`` is wrapped in ``xdist_lock_pytest`` using a marker
    file and a lock file placed in a temp directory shared by all workers.
    NOTE(review): presumably this makes the configuration run only once
    across workers; the exact semantics live in ``database_utils``.

    :param worker_id: pytest-xdist worker id, e.g. ``"gw0"``, or ``"master"``
        when xdist is disabled (e.g. ``-n0``).
    :param tmp_path_factory: pytest's session-scoped temp path factory.
    """
    start_televend_mappers("selectin")

    # Get the temp directory shared by all workers. When xdist is disabled
    # (running with -n0 for example), worker_id is "master" and the base temp
    # directory is used directly; otherwise each worker's base temp is a
    # subdirectory of the shared one, hence ``.parent``.
    if worker_id == "master":
        root_tmp_dir = tmp_path_factory.getbasetemp()
    else:
        root_tmp_dir = tmp_path_factory.getbasetemp().parent

    file_path = root_tmp_dir / "televend_fixture.txt"
    lock_path = file_path.with_suffix(".lock")
    xdist_lock_pytest(
        file_path=file_path,
        lock_path=lock_path,
        fn=lambda: configure_database(
            connection_string=TELEVEND_TEMPLATE_DATABASE_CONFIG.to_connection_string(),
            database_type="cloud",
        ),
    )


@pytest.fixture(scope="function")
def televend_test_database_name(request: FixtureRequest) -> Generator[str, None, None]:
    """Yield a database name generated from the current test's original name.

    ``request.node.originalname`` is the test function name without any
    parametrization suffix.
    """
    test_database_name = generate_test_database_name(test_name=request.node.originalname)
    yield test_database_name
@pytest.fixture(scope="function")
async def async_televend_engine(
    televend_fixture_setup: None, televend_test_database_name: str
) -> AsyncGenerator[AsyncEngine, None]:
    """Yield an ``AsyncEngine`` for this test's database.

    The engine comes from ``generate_async_engine``, given the template database
    config and the per-test database name. Any setup and teardown of that database
    happens inside that helper.
    """
    async for async_engine in generate_async_engine(
        template_database_config=TELEVEND_TEMPLATE_DATABASE_CONFIG,
        test_database_name=televend_test_database_name,
    ):
        yield async_engine


@pytest.fixture(scope="function")
async def async_televend_session(
    async_televend_engine: AsyncEngine,
) -> AsyncGenerator[AsyncSession, None]:
    """Yield a clean ``AsyncSession`` bound to the per-test engine.

    The fixture also binds ``factory_boy_utils.TelevendSession`` to a
    synchronous engine that uses the same URL with the ``postgresql`` driver.
    That way factory_boy factories and the async session share one database.

    Teardown, which always runs:

    1. Remove the factory_boy scoped session.
    2. Roll back and remove the async scoped session.
    3. Dispose the per-test sync engine so its connection pool does not outlive the test.
    """
    # One sync engine/session for factory boy and one async session for normal usage.
    sync_engine = create_engine(
        url=async_televend_engine.url.set(drivername="postgresql"), pool_size=1, max_overflow=1
    )
    factory_boy_utils.TelevendSession.configure(bind=sync_engine)

    televend_async_session_factory = sessionmaker(
        bind=async_televend_engine, expire_on_commit=False, class_=AsyncSession
    )
    # Scope sessions per asyncio task.
    televend_async_session = async_scoped_session(
        televend_async_session_factory, scopefunc=current_task
    )
    session = televend_async_session()
    try:
        yield session
    finally:
        try:
            factory_boy_utils.TelevendSession.remove()
            # Awaited explicitly to avoid the
            # "coroutine 'Transaction.rollback' was never awaited" warning.
            await session.rollback()
            await televend_async_session.remove()
        finally:
            # The sync engine is created per test. Release its pooled
            # connections even if the cleanup above failed. NOTE(review):
            # leftover connections could presumably block removal of the
            # per-test database during engine teardown.
            sync_engine.dispose()


@pytest.fixture(scope="function")
def televend_repository_manager(
    async_televend_session: AsyncSession, async_televend_engine: AsyncEngine
) -> TelevendRepositoryManager:
    """Build a ``TelevendRepositoryManager`` from the per-test async session and engine."""
    return TelevendRepositoryManager(
        async_session=async_televend_session,
        async_engine=async_televend_engine,
    )