diff --git a/app/lib/sqlalchemy_plugin.py b/app/lib/sqlalchemy_plugin.py index 75f1c2c..3d0c76b 100644 --- a/app/lib/sqlalchemy_plugin.py +++ b/app/lib/sqlalchemy_plugin.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from dataclasses import dataclass from typing import TYPE_CHECKING, cast from uuid import UUID @@ -73,15 +74,26 @@ def create_db_engine(connection_settings: DBConnectionSettings) -> AsyncEngine: ) -engine = create_db_engine( - connection_settings=DBConnectionSettings( - username=settings.db.USER, - password=settings.db.PASSWORD, - host=settings.db.HOST, - port=settings.db.PORT, - database=settings.db.NAME, +if "pytest" in sys.modules: + engine = create_db_engine( + connection_settings=DBConnectionSettings( + username=settings.testing.DB_USER, + password=settings.testing.DB_PASSWORD, + host=settings.testing.DB_HOST, + port=settings.testing.DB_PORT, + database=settings.testing.DB_NAME, + ) + ) +else: + engine = create_db_engine( + connection_settings=DBConnectionSettings( + username=settings.db.USER, + password=settings.db.PASSWORD, + host=settings.db.HOST, + port=settings.db.PORT, + database=settings.db.NAME, + ) ) -) """Configure via DatabaseSettings. diff --git a/tests/conftest.py b/tests/conftest.py index 891ef51..a03dae6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,55 +3,35 @@ import logging from datetime import datetime from typing import AsyncGenerator -import msgspec import pytest_asyncio -import sqlalchemy from _pytest.config import Config -from sqlalchemy import NullPool, event -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy import event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from app.lib import settings -from app.lib.sqlalchemy_plugin import _default +from app.lib.sqlalchemy_plugin import DBConnectionSettings, create_db_engine from app.lib.test_extras.db_setup import TestingDatabaseSetup # A Guide To Database Unit Testing with Pytest and SQLAlchemy # https://coderpad.io/blog/development/a-guide-to-database-unit-testing-with-pytest-and-sqlalchemy/ -db_connection_url = sqlalchemy.engine.URL.create( - drivername="postgresql+asyncpg", - username=settings.testing.DB_USER, - password=settings.testing.DB_PASSWORD, - host=settings.testing.DB_HOST, - port=settings.testing.DB_PORT, - database=settings.testing.DB_NAME, +engine = create_db_engine( + connection_settings=DBConnectionSettings( + username=settings.testing.DB_USER, + password=settings.testing.DB_PASSWORD, + host=settings.testing.DB_HOST, + port=settings.testing.DB_PORT, + database=settings.testing.DB_NAME, + ) ) -engine = create_async_engine( - db_connection_url, - echo=settings.db.ECHO, - echo_pool=settings.db.ECHO_POOL, - json_serializer=msgspec.json.Encoder(enc_hook=_default), - max_overflow=settings.db.POOL_MAX_OVERFLOW, - pool_size=settings.db.POOL_SIZE, - pool_timeout=settings.db.POOL_TIMEOUT, - poolclass=NullPool if settings.db.POOL_DISABLE else None, -) async_session_factory = async_sessionmaker( engine, expire_on_commit=False, class_=AsyncSession ) -TestingAsyncSessionLocal = async_sessionmaker( - engine, - expire_on_commit=False, - autoflush=False, - autocommit=False, - class_=AsyncSession, -) - - @pytest_asyncio.fixture(scope="function") async def db_session() -> AsyncGenerator[AsyncSession, None]: """The expectation with async_sessions is that the @@ -63,7 +43,7 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: async with engine.connect() as connection: trans = await connection.begin() - async with TestingAsyncSessionLocal(bind=connection) as async_session: + async with async_session_factory(bind=connection) as async_session: nested = await connection.begin_nested() @event.listens_for(async_session.sync_session, "after_transaction_end") @@ -80,20 +60,7 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: await engine.dispose(close=True) -# @pytest.fixture(scope="session") -# def event_loop(): -# """ -# Creates an instance of the default event loop for the test session. -# """ -# policy = asyncio.get_event_loop_policy() -# loop = policy.new_event_loop() -# yield loop -# loop.close() - - -pytest_plugins = ( - # "app.lib.test_extras.db_plugins", -) +pytest_plugins = () def pytest_configure(config: Config) -> None: