Files
test-litestar-addressbook/app/lib/test_extras/db_setup.py
2023-09-20 09:56:52 +02:00

94 lines
2.8 KiB
Python

from typing import Protocol
import asyncpg
import sqlalchemy
from asyncpg import Connection, DuplicateDatabaseError, InvalidCatalogNameError
from migrate import DatabaseConfig, migrate
class TestingSettingsInitOptions(Protocol):
    """Structural type describing the settings needed to build test databases.

    Any object exposing these attributes satisfies the protocol; no explicit
    subclassing is required.
    """

    # Where the PostgreSQL server lives.
    DB_HOST: str
    DB_PORT: int

    # Name of the database that is migrated once and then cloned, and the
    # name of the clone that the tests actually run against.
    DB_TEMPLATE_NAME: str
    DB_NAME: str

    # Credentials for connecting to the server.
    DB_USER: str
    DB_PASSWORD: str

    # Lifecycle switches controlling whether databases are dropped before
    # and/or after a test session.
    DROP_DATABASE_BEFORE_TESTS: bool
    DROP_DATABASE_AFTER_TESTS: bool
class TestingDatabaseSetup:
    """Create and tear down the PostgreSQL databases used by the test suite.

    A *template* database (``DB_TEMPLATE_NAME``) receives the migrations; the
    *test* database (``DB_NAME``) is cloned from it with
    ``CREATE DATABASE ... TEMPLATE ...``. All administrative statements are
    issued over a connection to the ``postgres`` maintenance database.
    """

    # The class name matches pytest's default ``Test*`` class pattern; without
    # this flag pytest tries to collect it wherever it is imported into a test
    # module (and warns, because the class defines ``__init__``).
    __test__ = False

    def __init__(self, options: TestingSettingsInitOptions):
        """Store *options* and build a server-level connection string.

        The URL has no database component; callers pass ``database=...`` to
        ``asyncpg.connect`` explicitly.
        """
        self.options = options
        db_connection_url = sqlalchemy.engine.URL.create(
            drivername="postgresql",
            username=self.options.DB_USER,
            password=self.options.DB_PASSWORD,
            host=self.options.DB_HOST,
            port=self.options.DB_PORT,
        )
        # hide_password=False: the DSN is handed to asyncpg, which needs the
        # real password rather than the masked "***" placeholder.
        self.connection_str = db_connection_url.render_as_string(hide_password=False)

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier.

        Database names cannot be bound as query parameters, so they are
        interpolated into the statement text. Quoting (and doubling any
        embedded quote) keeps names containing upper-case letters, hyphens or
        other special characters exactly as configured instead of letting
        PostgreSQL case-fold or reject them.
        """
        return '"' + name.replace('"', '""') + '"'

    async def _create_template_db(self, conn: Connection) -> bool:
        """Create the template database.

        Returns:
            True if the database was created by this call, False if it
            already existed.
        """
        query = f"CREATE DATABASE {self._quote_ident(self.options.DB_TEMPLATE_NAME)}"
        try:
            await conn.execute(query)
        except DuplicateDatabaseError:
            return False
        return True

    async def _drop_template_db(self, conn: Connection):
        """Drop the template database; a missing database is not an error."""
        query = f"DROP DATABASE {self._quote_ident(self.options.DB_TEMPLATE_NAME)}"
        try:
            await conn.execute(query)
        except InvalidCatalogNameError:
            pass

    async def _create_test_db(self, conn: Connection):
        """Clone the test database from the template; keep it if it exists."""
        query = (
            f"CREATE DATABASE {self._quote_ident(self.options.DB_NAME)} "
            f"TEMPLATE {self._quote_ident(self.options.DB_TEMPLATE_NAME)}"
        )
        try:
            await conn.execute(query)
        except DuplicateDatabaseError:
            pass

    async def _drop_test_db(self, conn: Connection):
        """Drop the test database; a missing database is not an error."""
        query = f"DROP DATABASE {self._quote_ident(self.options.DB_NAME)}"
        try:
            await conn.execute(query)
        except InvalidCatalogNameError:
            pass

    def _migrate_template_database(self):
        """Run ``migrate`` against the template database (synchronous call)."""
        conf = DatabaseConfig(
            HOST=self.options.DB_HOST,
            PORT=self.options.DB_PORT,
            NAME=self.options.DB_TEMPLATE_NAME,
            USER=self.options.DB_USER,
            PASSWORD=self.options.DB_PASSWORD,
        )
        migrate(conf)

    async def init_db(self):
        """Prepare the template and test databases before a test session.

        Order matters: the test database is a copy of the template, so the
        template must be migrated *before* the clone is made, otherwise the
        test database starts out empty.
        """
        conn = await asyncpg.connect(self.connection_str, database="postgres")
        try:
            if self.options.DROP_DATABASE_BEFORE_TESTS:
                # tear_down_db always drops the test DB, so it is disposable;
                # remove any copy left behind by an aborted run so it gets
                # re-cloned from the fresh template below.
                await self._drop_test_db(conn)
                await self._drop_template_db(conn)
            template_created = await self._create_template_db(conn)
            if template_created:
                # A freshly created template is empty. With
                # DROP_DATABASE_BEFORE_TESTS set it was just dropped, so this
                # branch always runs in that case.
                # NOTE(review): assumes migrate() releases its connections to
                # the template; PostgreSQL refuses to clone a database that
                # other sessions are connected to.
                self._migrate_template_database()
            await self._create_test_db(conn)
        finally:
            # Always release the admin connection, even if a statement fails.
            await conn.close()

    async def tear_down_db(self):
        """Drop the test database, and the template too if so configured."""
        conn = await asyncpg.connect(self.connection_str, database="postgres")
        try:
            await self._drop_test_db(conn)
            if self.options.DROP_DATABASE_AFTER_TESTS:
                await self._drop_template_db(conn)
        finally:
            await conn.close()