"""Create, migrate and tear down the PostgreSQL databases used by a test run."""

import contextlib
from typing import AsyncIterator, Protocol

import asyncpg
import sqlalchemy
from asyncpg import Connection, DuplicateDatabaseError, InvalidCatalogNameError

from migrate import DatabaseConfig, migrate


class TestingSettingsInitOptions(Protocol):
    """Settings consumed by :class:`TestingDatabaseSetup`."""

    DB_HOST: str
    DB_PORT: int
    # Database that is migrated and then cloned into DB_NAME.
    DB_TEMPLATE_NAME: str
    # Database the tests run against; cloned from DB_TEMPLATE_NAME.
    DB_NAME: str
    DB_USER: str
    DB_PASSWORD: str
    # Drop existing template/test databases in init_db before recreating them.
    DROP_DATABASE_BEFORE_TESTS: bool
    # Also drop the template database in tear_down_db.
    DROP_DATABASE_AFTER_TESTS: bool


def _quote_ident(name: str) -> str:
    """Return *name* as a double-quoted PostgreSQL identifier.

    Database names cannot be sent as query parameters, so they must be
    interpolated into the SQL text. Quoting preserves case and allows
    characters such as ``-``. Embedded double quotes are escaped by
    doubling them.
    """
    return '"' + name.replace('"', '""') + '"'


class TestingDatabaseSetup:
    """Manage the template and test databases for a test session.

    The template database is created and migrated. The test database is then
    cloned from it with ``CREATE DATABASE ... TEMPLATE ...``.
    """

    def __init__(self, options: TestingSettingsInitOptions):
        self.options = options
        db_connection_url = sqlalchemy.engine.URL.create(
            drivername="postgresql",
            username=self.options.DB_USER,
            password=self.options.DB_PASSWORD,
            host=self.options.DB_HOST,
            port=self.options.DB_PORT,
        )
        # Server-level DSN without a database; the database is passed to
        # asyncpg.connect() separately.
        self.connection_str = db_connection_url.render_as_string(hide_password=False)

    @contextlib.asynccontextmanager
    async def _admin_connection(self) -> AsyncIterator[Connection]:
        """Yield a connection to the ``postgres`` database and always close it."""
        conn = await asyncpg.connect(self.connection_str, database="postgres")
        try:
            yield conn
        finally:
            await conn.close()

    async def _create_template_db(self, conn: Connection) -> bool:
        """Create the template database.

        Returns:
            True if the database was created by this call, False if it
            already existed.
        """
        query = f"CREATE DATABASE {_quote_ident(self.options.DB_TEMPLATE_NAME)}"
        try:
            await conn.execute(query)
        except DuplicateDatabaseError:
            return False
        return True

    async def _drop_template_db(self, conn: Connection):
        """Drop the template database; a missing database is ignored."""
        query = f"DROP DATABASE {_quote_ident(self.options.DB_TEMPLATE_NAME)}"
        with contextlib.suppress(InvalidCatalogNameError):
            await conn.execute(query)

    async def _create_test_db(self, conn: Connection):
        """Clone the test database from the template; an existing one is kept."""
        query = (
            f"CREATE DATABASE {_quote_ident(self.options.DB_NAME)} "
            f"TEMPLATE {_quote_ident(self.options.DB_TEMPLATE_NAME)}"
        )
        with contextlib.suppress(DuplicateDatabaseError):
            await conn.execute(query)

    async def _drop_test_db(self, conn: Connection):
        """Drop the test database; a missing database is ignored."""
        query = f"DROP DATABASE {_quote_ident(self.options.DB_NAME)}"
        with contextlib.suppress(InvalidCatalogNameError):
            await conn.execute(query)

    def _migrate_template_database(self):
        """Run ``migrate`` against the template database (synchronous call)."""
        conf = DatabaseConfig(
            HOST=self.options.DB_HOST,
            PORT=self.options.DB_PORT,
            NAME=self.options.DB_TEMPLATE_NAME,
            USER=self.options.DB_USER,
            PASSWORD=self.options.DB_PASSWORD,
        )
        migrate(conf)

    async def init_db(self):
        """Prepare the template database and clone the test database from it.

        When ``DROP_DATABASE_BEFORE_TESTS`` is set, existing test and template
        databases are dropped first. The template is migrated whenever this
        call creates it. Migration happens before cloning, because ``CREATE
        DATABASE ... TEMPLATE`` copies the template as it is at that moment.
        """
        async with self._admin_connection() as conn:
            if self.options.DROP_DATABASE_BEFORE_TESTS:
                # A test DB left over from an earlier run would otherwise be
                # reused instead of being re-cloned from the rebuilt template.
                await self._drop_test_db(conn)
                await self._drop_template_db(conn)
            if await self._create_template_db(conn):
                # NOTE(review): cloning fails if other sessions are still
                # connected to the template; this assumes migrate() closes
                # its connections before returning — confirm.
                self._migrate_template_database()
            await self._create_test_db(conn)

    async def tear_down_db(self):
        """Drop the test database, plus the template if configured to."""
        async with self._admin_connection() as conn:
            await self._drop_test_db(conn)
            if self.options.DROP_DATABASE_AFTER_TESTS:
                await self._drop_template_db(conn)