import asyncio
import logging
from datetime import datetime, timezone
from typing import AsyncGenerator

import pytest
import pytest_asyncio
from _pytest.config import Config
from litestar.testing import AsyncTestClient
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.lib import settings
from app.lib.sqlalchemy_plugin import DBConnectionSettings, create_db_engine
from app.lib.test_extras.db_setup import TestingDatabaseSetup
from main import app

# A Guide To Database Unit Testing with Pytest and SQLAlchemy
# https://coderpad.io/blog/development/a-guide-to-database-unit-testing-with-pytest-and-sqlalchemy/

# Engine pointed at the testing database described by ``settings.testing``.
engine = create_db_engine(
    connection_settings=DBConnectionSettings(
        username=settings.testing.DB_USER,
        password=settings.testing.DB_PASSWORD,
        host=settings.testing.DB_HOST,
        port=settings.testing.DB_PORT,
        database=settings.testing.DB_NAME,
    )
)

# expire_on_commit=False keeps loaded attribute values after a commit, so tests
# can read them without an implicit refresh (which would need awaiting under
# asyncio).
async_session_factory = async_sessionmaker(
    engine, expire_on_commit=False, class_=AsyncSession
)


@pytest_asyncio.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` whose changes are rolled back after the test.

    With async sessions the transactions are driven on the connection object
    rather than on the session. This fixture:

    1. opens a connection and begins an outer transaction on it;
    2. binds a new session to that connection and opens a SAVEPOINT;
    3. re-opens the SAVEPOINT whenever a session transaction ends and the
       previous SAVEPOINT is no longer active;
    4. on teardown, rolls back the outer transaction so the test leaves no
       data behind, then disposes of the engine's connection pool.

    Yields:
        An ``AsyncSession`` bound to the per-test connection.
    """
    try:
        async with engine.connect() as connection:
            trans = await connection.begin()
            try:
                async with async_session_factory(bind=connection) as async_session:
                    nested = await connection.begin_nested()

                    @event.listens_for(async_session.sync_session, "after_transaction_end")
                    def end_savepoint(session, transaction):
                        nonlocal nested
                        if not nested.is_active:
                            # Event handlers run synchronously, so the fresh
                            # SAVEPOINT is started on the underlying sync
                            # connection.
                            nested = connection.sync_connection.begin_nested()

                    yield async_session
            finally:
                # Discard everything the test wrote, even if the session
                # failed to close cleanly.
                await trans.rollback()
    finally:
        # Drop pooled connections so none of them outlive this test's event
        # loop; runs even if rollback above raised.
        await engine.dispose(close=True)


pytest_plugins = ()


@pytest.fixture(scope="function")
def async_client() -> AsyncTestClient:
    """Return a Litestar ``AsyncTestClient`` wrapping ``main.app``.

    The client is returned without being entered. NOTE(review): tests
    presumably open it themselves (e.g. ``async with async_client:``) —
    confirm against the test modules before changing this to a yield fixture.
    """
    return AsyncTestClient(app=app)


def pytest_configure(config: Config) -> None:
    """pytest hook: run ``TestingDatabaseSetup.init_db()`` before the session starts."""
    logging.info("Starting tests: %s", datetime.now(timezone.utc))
    db_setup = TestingDatabaseSetup(options=settings.testing)
    # No event loop is running yet at configure time, so asyncio.run is safe.
    asyncio.run(db_setup.init_db())
    # Emits a blank line — presumably to separate setup output from pytest's
    # own report.
    print()


def pytest_unconfigure(config: Config) -> None:
    """pytest hook: run ``TestingDatabaseSetup.tear_down_db()`` after the session ends."""
    logging.info("Ending tests: %s", datetime.now(timezone.utc))
    db_setup = TestingDatabaseSetup(options=settings.testing)
    asyncio.run(db_setup.tear_down_db())