Basic testing infrastructure
This commit is contained in:
93
app/lib/test_extras/db_setup.py
Normal file
93
app/lib/test_extras/db_setup.py
Normal file
@ -0,0 +1,93 @@
|
||||
from typing import Protocol
|
||||
|
||||
import asyncpg
|
||||
import sqlalchemy
|
||||
from asyncpg import Connection, DuplicateDatabaseError, InvalidCatalogNameError
|
||||
|
||||
from migrate import DatabaseConfig, migrate
|
||||
|
||||
|
||||
class TestingSettingsInitOptions(Protocol):
|
||||
DB_HOST: str
|
||||
DB_PORT: int
|
||||
DB_TEMPLATE_NAME: str
|
||||
DB_NAME: str
|
||||
DB_USER: str
|
||||
DB_PASSWORD: str
|
||||
DROP_DATABASE_BEFORE_TESTS: bool
|
||||
DROP_DATABASE_AFTER_TESTS: bool
|
||||
|
||||
|
||||
class TestingDatabaseSetup:
|
||||
def __init__(self, options: TestingSettingsInitOptions):
|
||||
self.options = options
|
||||
|
||||
db_connection_url = sqlalchemy.engine.URL.create(
|
||||
drivername="postgresql",
|
||||
username=self.options.DB_USER,
|
||||
password=self.options.DB_PASSWORD,
|
||||
host=self.options.DB_HOST,
|
||||
port=self.options.DB_PORT,
|
||||
)
|
||||
self.connection_str = db_connection_url.render_as_string(hide_password=False)
|
||||
|
||||
async def _create_template_db(self, conn: Connection):
|
||||
query = f"CREATE DATABASE {self.options.DB_TEMPLATE_NAME}"
|
||||
try:
|
||||
await conn.execute(query)
|
||||
except DuplicateDatabaseError:
|
||||
...
|
||||
|
||||
async def _drop_template_db(self, conn: Connection):
|
||||
query = f"DROP DATABASE {self.options.DB_TEMPLATE_NAME}"
|
||||
try:
|
||||
await conn.execute(query)
|
||||
except InvalidCatalogNameError:
|
||||
...
|
||||
|
||||
async def _create_test_db(self, conn: Connection):
|
||||
query = f"CREATE DATABASE {self.options.DB_NAME} TEMPLATE {self.options.DB_TEMPLATE_NAME}"
|
||||
try:
|
||||
await conn.execute(query)
|
||||
except DuplicateDatabaseError:
|
||||
...
|
||||
|
||||
async def _drop_test_db(self, conn: Connection):
|
||||
query = f"DROP DATABASE {self.options.DB_NAME}"
|
||||
try:
|
||||
await conn.execute(query)
|
||||
except InvalidCatalogNameError:
|
||||
...
|
||||
|
||||
def _migrate_template_database(self):
|
||||
conf = DatabaseConfig(
|
||||
HOST=self.options.DB_HOST,
|
||||
PORT=self.options.DB_PORT,
|
||||
NAME=self.options.DB_TEMPLATE_NAME,
|
||||
USER=self.options.DB_USER,
|
||||
PASSWORD=self.options.DB_PASSWORD,
|
||||
)
|
||||
migrate(conf)
|
||||
|
||||
async def init_db(self):
|
||||
conn = await asyncpg.connect(self.connection_str, database="postgres")
|
||||
|
||||
if self.options.DROP_DATABASE_BEFORE_TESTS:
|
||||
await self._drop_template_db(conn)
|
||||
|
||||
await self._create_template_db(conn)
|
||||
await self._create_test_db(conn)
|
||||
await conn.close()
|
||||
|
||||
if self.options.DROP_DATABASE_BEFORE_TESTS:
|
||||
self._migrate_template_database()
|
||||
|
||||
async def tear_down_db(self):
|
||||
conn = await asyncpg.connect(self.connection_str, database="postgres")
|
||||
|
||||
await self._drop_test_db(conn)
|
||||
|
||||
if self.options.DROP_DATABASE_AFTER_TESTS:
|
||||
await self._drop_template_db(conn)
|
||||
|
||||
await conn.close()
|
||||
Reference in New Issue
Block a user