Working factories

This commit is contained in:
Eden Kirin
2023-09-20 21:48:24 +02:00
parent 6ccb660ccc
commit 7c55e39f32
2 changed files with 33 additions and 54 deletions

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, cast from typing import TYPE_CHECKING, cast
from uuid import UUID from uuid import UUID
@ -73,15 +74,26 @@ def create_db_engine(connection_settings: DBConnectionSettings) -> AsyncEngine:
) )
engine = create_db_engine( if "pytest" in sys.modules:
connection_settings=DBConnectionSettings( engine = create_db_engine(
username=settings.db.USER, connection_settings=DBConnectionSettings(
password=settings.db.PASSWORD, username=settings.testing.DB_USER,
host=settings.db.HOST, password=settings.testing.DB_PASSWORD,
port=settings.db.PORT, host=settings.testing.DB_HOST,
database=settings.db.NAME, port=settings.testing.DB_PORT,
database=settings.testing.DB_NAME,
)
)
else:
engine = create_db_engine(
connection_settings=DBConnectionSettings(
username=settings.db.USER,
password=settings.db.PASSWORD,
host=settings.db.HOST,
port=settings.db.PORT,
database=settings.db.NAME,
)
) )
)
"""Configure via DatabaseSettings. """Configure via DatabaseSettings.

View File

@ -3,55 +3,35 @@ import logging
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator from typing import AsyncGenerator
import msgspec
import pytest_asyncio import pytest_asyncio
import sqlalchemy
from _pytest.config import Config from _pytest.config import Config
from sqlalchemy import NullPool, event from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.lib import settings from app.lib import settings
from app.lib.sqlalchemy_plugin import _default from app.lib.sqlalchemy_plugin import DBConnectionSettings, create_db_engine
from app.lib.test_extras.db_setup import TestingDatabaseSetup from app.lib.test_extras.db_setup import TestingDatabaseSetup
# A Guide To Database Unit Testing with Pytest and SQLAlchemy # A Guide To Database Unit Testing with Pytest and SQLAlchemy
# https://coderpad.io/blog/development/a-guide-to-database-unit-testing-with-pytest-and-sqlalchemy/ # https://coderpad.io/blog/development/a-guide-to-database-unit-testing-with-pytest-and-sqlalchemy/
db_connection_url = sqlalchemy.engine.URL.create( engine = create_db_engine(
drivername="postgresql+asyncpg", connection_settings=DBConnectionSettings(
username=settings.testing.DB_USER, username=settings.testing.DB_USER,
password=settings.testing.DB_PASSWORD, password=settings.testing.DB_PASSWORD,
host=settings.testing.DB_HOST, host=settings.testing.DB_HOST,
port=settings.testing.DB_PORT, port=settings.testing.DB_PORT,
database=settings.testing.DB_NAME, database=settings.testing.DB_NAME,
)
) )
engine = create_async_engine(
db_connection_url,
echo=settings.db.ECHO,
echo_pool=settings.db.ECHO_POOL,
json_serializer=msgspec.json.Encoder(enc_hook=_default),
max_overflow=settings.db.POOL_MAX_OVERFLOW,
pool_size=settings.db.POOL_SIZE,
pool_timeout=settings.db.POOL_TIMEOUT,
poolclass=NullPool if settings.db.POOL_DISABLE else None,
)
async_session_factory = async_sessionmaker( async_session_factory = async_sessionmaker(
engine, expire_on_commit=False, class_=AsyncSession engine, expire_on_commit=False, class_=AsyncSession
) )
TestingAsyncSessionLocal = async_sessionmaker(
engine,
expire_on_commit=False,
autoflush=False,
autocommit=False,
class_=AsyncSession,
)
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]: async def db_session() -> AsyncGenerator[AsyncSession, None]:
"""The expectation with async_sessions is that the """The expectation with async_sessions is that the
@ -63,7 +43,7 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]:
async with engine.connect() as connection: async with engine.connect() as connection:
trans = await connection.begin() trans = await connection.begin()
async with TestingAsyncSessionLocal(bind=connection) as async_session: async with async_session_factory(bind=connection) as async_session:
nested = await connection.begin_nested() nested = await connection.begin_nested()
@event.listens_for(async_session.sync_session, "after_transaction_end") @event.listens_for(async_session.sync_session, "after_transaction_end")
@ -80,20 +60,7 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]:
await engine.dispose(close=True) await engine.dispose(close=True)
# @pytest.fixture(scope="session") pytest_plugins = ()
# def event_loop():
# """
# Creates an instance of the default event loop for the test session.
# """
# policy = asyncio.get_event_loop_policy()
# loop = policy.new_event_loop()
# yield loop
# loop.close()
pytest_plugins = (
# "app.lib.test_extras.db_plugins",
)
def pytest_configure(config: Config) -> None: def pytest_configure(config: Config) -> None: