100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
import asyncio
import logging
from datetime import datetime, timezone
from typing import AsyncGenerator

import pytest
import pytest_asyncio
from _pytest.config import Config
from litestar.testing import AsyncTestClient
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.lib import settings
from app.lib.sqlalchemy_plugin import DBConnectionSettings, create_db_engine
from app.lib.test_extras.db_setup import (
    TestingDatabaseSetup,
    TestingSettingsInitOptions,
)
from main import app
|
|
|
|
# A Guide To Database Unit Testing with Pytest and SQLAlchemy
# https://coderpad.io/blog/development/a-guide-to-database-unit-testing-with-pytest-and-sqlalchemy/
|
|
|
|
|
|
# Engine used by the test fixtures; its connection parameters are read from
# ``settings.db``.
engine = create_db_engine(
    connection_settings=DBConnectionSettings(
        host=settings.db.HOST,
        port=settings.db.PORT,
        database=settings.db.NAME,
        username=settings.db.USER,
        password=settings.db.PASSWORD,
    ),
)
|
|
|
|
|
|
# Factory for test sessions. ``db_session`` overrides ``bind`` per session so
# each session runs on a single connection. ``expire_on_commit=False`` keeps
# already-loaded attributes readable after a commit.
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` whose writes are rolled back after the test.

    With async sessions, the enclosing transaction is driven on the
    connection object rather than on the session object:

    1. A connection is checked out and an outer transaction is begun on it.
    2. A session is bound to that connection and a SAVEPOINT is opened.
    3. An ``after_transaction_end`` listener re-opens the SAVEPOINT whenever
       it is no longer active, so the test keeps working inside it.
    4. On teardown the outer transaction is rolled back.

    The engine's pool is disposed after every test, including when teardown
    fails part-way.

    Detailed explanation of async transactional tests:
    <https://github.com/sqlalchemy/sqlalchemy/issues/5811>
    """
    try:
        async with engine.connect() as connection:
            trans = await connection.begin()
            async with async_session_factory(bind=connection) as async_session:
                nested = await connection.begin_nested()

                @event.listens_for(async_session.sync_session, "after_transaction_end")
                def end_savepoint(session, transaction):
                    nonlocal nested
                    # Listener runs in sync context, hence sync_connection.
                    if not nested.is_active:
                        nested = connection.sync_connection.begin_nested()

                yield async_session

            # Discard everything the test wrote.
            await trans.rollback()
    finally:
        # In ``finally`` so the pool is released even if closing the session
        # or connection raises, or the fixture generator is closed early.
        # NOTE(review): per-test disposal presumably exists because each test
        # may run on its own event loop; confirm before removing it.
        await engine.dispose(close=True)
|
|
|
|
|
|
# This conftest does not pull in any extra pytest plugin modules.
pytest_plugins = tuple()
|
|
|
|
|
|
# Options passed to ``TestingDatabaseSetup`` by ``pytest_configure`` and
# ``pytest_unconfigure``.
db_options = TestingSettingsInitOptions(
    # Connection parameters: the same ``settings.db`` values the engine uses.
    DB_NAME=settings.db.NAME,
    DB_HOST=settings.db.HOST,
    DB_PORT=settings.db.PORT,
    DB_USER=settings.db.USER,
    DB_PASSWORD=settings.db.PASSWORD,
    # Template database name and drop-before/after flags from settings.testing.
    DB_TEMPLATE_NAME=settings.testing.DB_TEMPLATE_NAME,
    DROP_DATABASE_BEFORE_TESTS=settings.testing.DROP_DATABASE_BEFORE_TESTS,
    DROP_DATABASE_AFTER_TESTS=settings.testing.DROP_DATABASE_AFTER_TESTS,
)
|
|
|
|
|
|
@pytest.fixture(scope="function")
def async_client() -> AsyncTestClient:
    """Provide a new Litestar ``AsyncTestClient`` wrapping ``app`` for each test.

    The client is returned without being entered. NOTE(review): tests
    presumably open it with ``async with`` themselves; confirm against callers.
    """
    client = AsyncTestClient(app=app)
    return client
|
|
|
|
|
|
def pytest_configure(config: Config) -> None:
    """Pytest hook: prepare the test database before the test session starts.

    Logs a timezone-aware UTC start timestamp, then runs
    ``TestingDatabaseSetup(db_options).init_db()`` to completion on a new
    event loop created by ``asyncio.run``.

    Args:
        config: The pytest config object (unused here).
    """
    # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
    # value; use an explicitly UTC-aware timestamp instead.
    logging.info("Starting tests: %s", datetime.now(timezone.utc))
    db_setup = TestingDatabaseSetup(db_options)
    asyncio.run(db_setup.init_db())
    # Emits a blank line; presumably to separate setup output from pytest's report.
    print()
|
|
|
|
|
|
def pytest_unconfigure(config: Config) -> None:
    """Pytest hook: tear down the test database after the test session ends.

    Logs a timezone-aware UTC end timestamp, then runs
    ``TestingDatabaseSetup(db_options).tear_down_db()`` to completion on a
    new event loop created by ``asyncio.run``.

    Args:
        config: The pytest config object (unused here).
    """
    # Same UTC-aware timestamp as pytest_configure (utcnow() is deprecated).
    logging.info("Ending tests: %s", datetime.now(timezone.utc))
    db_setup = TestingDatabaseSetup(db_options)
    asyncio.run(db_setup.tear_down_db())
|