diff --git a/.gitignore b/.gitignore index 423be66..e460b96 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ __pycache__ /.env +/.env.testing + diff --git a/app/lib/settings.py b/app/lib/settings.py index e3ceaa0..2f504c9 100644 --- a/app/lib/settings.py +++ b/app/lib/settings.py @@ -3,6 +3,7 @@ Take note of the environment variable prefixes required for each settings class, except `AppSettings`. """ +import sys from typing import Literal, Optional, Union __all__ = [ @@ -17,10 +18,11 @@ __all__ = [ from pydantic import Extra from pydantic_settings import BaseSettings +from const import ROOT_DIR + class BaseEnvSettings(BaseSettings): class Config: - env_file = ".env" env_file_encoding = "utf-8" extra = Extra.ignore @@ -89,12 +91,7 @@ class TestingSettings(BaseEnvSettings): env_prefix = "TESTS_" case_sensitive = True - DB_HOST: str = "localhost" - DB_PORT: int = 5432 DB_TEMPLATE_NAME: str = "db-template-name" - DB_NAME: str = "test_db-name" - DB_USER: str = "db-user" - DB_PASSWORD: str = "db-password" DROP_DATABASE_BEFORE_TESTS: bool = True DROP_DATABASE_AFTER_TESTS: bool = True @@ -117,11 +114,18 @@ class EmailSettings(BaseEnvSettings): case_sensitive = True -# `.parse_obj()` thing is a workaround for pyright and pydantic interplay, see: -# https://github.com/pydantic/pydantic/issues/3753#issuecomment-1087417884 -api = APISettings.parse_obj({}) -app = AppSettings.parse_obj({}) -db = DatabaseSettings.parse_obj({}) -openapi = OpenAPISettings.parse_obj({}) -server = ServerSettings.parse_obj({}) -testing = TestingSettings.parse_obj({}) +if "pytest" in sys.modules: + env_file = ROOT_DIR / ".env.testing" +else: + env_file = ROOT_DIR / ".env" + +params = { + "_env_file": env_file, +} + +api = APISettings(**params) +app = AppSettings(**params) +db = DatabaseSettings(**params) +openapi = OpenAPISettings(**params) +server = ServerSettings(**params) +testing = TestingSettings(**params) diff --git a/app/lib/sqlalchemy_plugin.py b/app/lib/sqlalchemy_plugin.py index 3d0c76b..87977f7 100644 --- a/app/lib/sqlalchemy_plugin.py +++ b/app/lib/sqlalchemy_plugin.py @@ -74,26 +74,15 @@ def create_db_engine(connection_settings: DBConnectionSettings) -> AsyncEngine: ) -if "pytest" in sys.modules: - engine = create_db_engine( - connection_settings=DBConnectionSettings( - username=settings.testing.DB_USER, - password=settings.testing.DB_PASSWORD, - host=settings.testing.DB_HOST, - port=settings.testing.DB_PORT, - database=settings.testing.DB_NAME, - ) - ) -else: - engine = create_db_engine( - connection_settings=DBConnectionSettings( - username=settings.db.USER, - password=settings.db.PASSWORD, - host=settings.db.HOST, - port=settings.db.PORT, - database=settings.db.NAME, - ) +engine = create_db_engine( + connection_settings=DBConnectionSettings( + username=settings.db.USER, + password=settings.db.PASSWORD, + host=settings.db.HOST, + port=settings.db.PORT, + database=settings.db.NAME, ) +) """Configure via DatabaseSettings. diff --git a/app/lib/test_extras/db_setup.py b/app/lib/test_extras/db_setup.py index 0b8fe63..af4622e 100644 --- a/app/lib/test_extras/db_setup.py +++ b/app/lib/test_extras/db_setup.py @@ -1,4 +1,4 @@ -from typing import Protocol +from dataclasses import dataclass import asyncpg import sqlalchemy @@ -7,13 +7,14 @@ from asyncpg import Connection, DuplicateDatabaseError, InvalidCatalogNameError from migrate import DatabaseConfig, migrate -class TestingSettingsInitOptions(Protocol): +@dataclass +class TestingSettingsInitOptions: DB_HOST: str DB_PORT: int - DB_TEMPLATE_NAME: str DB_NAME: str DB_USER: str DB_PASSWORD: str + DB_TEMPLATE_NAME: str DROP_DATABASE_BEFORE_TESTS: bool DROP_DATABASE_AFTER_TESTS: bool diff --git a/const.py b/const.py new file mode 100644 index 0000000..1002875 --- /dev/null +++ b/const.py @@ -0,0 +1,3 @@ +from pathlib import Path + +ROOT_DIR = Path(__file__).absolute().parent diff --git a/tests/conftest.py b/tests/conftest.py index 6cde9af..1dc5a7d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,10 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from app.lib import settings from app.lib.sqlalchemy_plugin import DBConnectionSettings, create_db_engine -from app.lib.test_extras.db_setup import TestingDatabaseSetup +from app.lib.test_extras.db_setup import ( + TestingDatabaseSetup, + TestingSettingsInitOptions, +) from main import app # A Guide To Database Unit Testing with Pytest and SQLAlchemy @@ -21,11 +24,11 @@ from main import app engine = create_db_engine( connection_settings=DBConnectionSettings( - username=settings.testing.DB_USER, - password=settings.testing.DB_PASSWORD, - host=settings.testing.DB_HOST, - port=settings.testing.DB_PORT, - database=settings.testing.DB_NAME, + username=settings.db.USER, + password=settings.db.PASSWORD, + host=settings.db.HOST, + port=settings.db.PORT, + database=settings.db.NAME, ) ) @@ -66,6 +69,18 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]: pytest_plugins = () +db_options = TestingSettingsInitOptions( + DB_HOST=settings.db.HOST, + DB_PORT=settings.db.PORT, + DB_NAME=settings.db.NAME, + DB_USER=settings.db.USER, + DB_PASSWORD=settings.db.PASSWORD, + DB_TEMPLATE_NAME=settings.testing.DB_TEMPLATE_NAME, + DROP_DATABASE_BEFORE_TESTS=settings.testing.DROP_DATABASE_BEFORE_TESTS, + DROP_DATABASE_AFTER_TESTS=settings.testing.DROP_DATABASE_AFTER_TESTS, +) + + @pytest.fixture(scope="function") def async_client() -> AsyncTestClient: return AsyncTestClient(app=app) @@ -73,12 +88,12 @@ def async_client() -> AsyncTestClient: def pytest_configure(config: Config) -> None: logging.info(f"Starting tests: {datetime.utcnow()}") - db_setup = TestingDatabaseSetup(options=settings.testing) + db_setup = TestingDatabaseSetup(db_options) asyncio.run(db_setup.init_db()) print() def pytest_unconfigure(config: Config) -> None: logging.info(f"Ending tests: {datetime.utcnow()}") - db_setup = TestingDatabaseSetup(options=settings.testing) + db_setup = TestingDatabaseSetup(db_options) asyncio.run(db_setup.tear_down_db())