import asyncio
import logging
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict

import msgspec
import pytest_asyncio
import sqlalchemy
from _pytest.config import Config
from sqlalchemy import NullPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.lib import settings
from app.lib.sqlalchemy_plugin import _default
from app.lib.test_extras.db_setup import TestingDatabaseSetup

# A Guide To Database Unit Testing with Pytest and SQLAlchemy
# https://coderpad.io/blog/development/a-guide-to-database-unit-testing-with-pytest-and-sqlalchemy/

# Connection URL for the test database, built from ``settings.testing``.
db_connection_url = sqlalchemy.engine.URL.create(
    drivername="postgresql+asyncpg",
    username=settings.testing.DB_USER,
    password=settings.testing.DB_PASSWORD,
    host=settings.testing.DB_HOST,
    port=settings.testing.DB_PORT,
    database=settings.testing.DB_NAME,
)


def _pool_options() -> Dict[str, Any]:
    """Return the pool-related ``create_async_engine`` kwargs for the configured mode.

    ``NullPool`` does not accept the sizing arguments (``pool_size``,
    ``max_overflow``, ``pool_timeout``); passing them together with it makes
    ``create_async_engine`` raise ``TypeError``. The sizing arguments are
    therefore only supplied when pooling is enabled; otherwise the dialect's
    default pool class is used, as it was when ``poolclass=None`` was passed.
    """
    if settings.db.POOL_DISABLE:
        return {"poolclass": NullPool}
    return {
        "max_overflow": settings.db.POOL_MAX_OVERFLOW,
        "pool_size": settings.db.POOL_SIZE,
        "pool_timeout": settings.db.POOL_TIMEOUT,
    }


engine = create_async_engine(
    db_connection_url,
    echo=settings.db.ECHO,
    echo_pool=settings.db.ECHO_POOL,
    # NOTE(review): SQLAlchemy calls ``json_serializer(value)``; with asyncpg the
    # result is presumably expected to be ``str``. Confirm this msgspec Encoder
    # instance is callable and yields ``str`` (``Encoder.encode`` returns bytes).
    json_serializer=msgspec.json.Encoder(enc_hook=_default),
    **_pool_options(),
)

async_session_factory = async_sessionmaker(
    engine, expire_on_commit=False, class_=AsyncSession
)

# Session factory used by the ``db_session`` fixture. Autoflush is disabled, so
# pending changes are only sent to the database on an explicit flush/commit.
TestingAsyncSessionLocal = async_sessionmaker(
    engine,
    expire_on_commit=False,
    autoflush=False,
    class_=AsyncSession,
)


@pytest_asyncio.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` whose changes are rolled back after the test.

    With async sessions the outer transaction is managed on the connection
    object rather than on the session: a transaction is begun on a dedicated
    connection, the session is bound to that connection, and the transaction
    is rolled back at teardown so nothing written by the test persists.

    A SAVEPOINT is opened inside that transaction and re-opened whenever it
    ends, so code under test can call ``commit()``/``rollback()`` on the
    session without ending the outer transaction.

    After the connection has been released, the engine is disposed so no
    pooled connection is carried over into the next test.
    """
    async with engine.connect() as connection:
        trans = await connection.begin()
        async with TestingAsyncSessionLocal(bind=connection) as async_session:
            nested = await connection.begin_nested()

            @event.listens_for(async_session.sync_session, "after_transaction_end")
            def end_savepoint(session, transaction):
                # Re-open the SAVEPOINT once the previous one has ended (e.g.
                # after the code under test committed or rolled back).
                nonlocal nested
                if not nested.is_active:
                    nested = connection.sync_connection.begin_nested()

            yield async_session

        await trans.rollback()

    # Dispose only after ``engine.connect()`` has returned the connection to the
    # pool: disposing while it is still checked out orphans it from the new
    # pool instead of closing it on the current event loop.
    await engine.dispose(close=True)


# @pytest.fixture(scope="session")
# def event_loop():
#     """
#     Creates an instance of the default event loop for the test session.
#     """
#     policy = asyncio.get_event_loop_policy()
#     loop = policy.new_event_loop()
#     yield loop
#     loop.close()

pytest_plugins = (
    # "app.lib.test_extras.db_plugins",
)


def pytest_configure(config: Config) -> None:
    """pytest hook: run ``TestingDatabaseSetup.init_db`` before the test session.

    Args:
        config: The pytest configuration object (unused).
    """
    logging.info("Starting tests: %s", datetime.now(timezone.utc))
    db_setup = TestingDatabaseSetup(options=settings.testing)
    # Hooks are synchronous, so the async setup runs on its own event loop.
    asyncio.run(db_setup.init_db())
    print()


def pytest_unconfigure(config: Config) -> None:
    """pytest hook: run ``TestingDatabaseSetup.tear_down_db`` after the test session.

    Args:
        config: The pytest configuration object (unused).
    """
    logging.info("Ending tests: %s", datetime.now(timezone.utc))
    db_setup = TestingDatabaseSetup(options=settings.testing)
    # Hooks are synchronous, so the async teardown runs on its own event loop.
    asyncio.run(db_setup.tear_down_db())