94 lines
2.8 KiB
Python
94 lines
2.8 KiB
Python
from typing import Protocol
|
|
|
|
import asyncpg
|
|
import sqlalchemy
|
|
from asyncpg import Connection, DuplicateDatabaseError, InvalidCatalogNameError
|
|
|
|
from migrate import DatabaseConfig, migrate
|
|
|
|
|
|
class TestingSettingsInitOptions(Protocol):
    """Structural type for the settings object passed to ``TestingDatabaseSetup``.

    Any object exposing these attributes (e.g. a settings class) satisfies
    the protocol; no inheritance is required.
    """

    # Server location.
    DB_HOST: str
    DB_PORT: int

    # Database names: the template that is migrated, and the test database
    # that is created from it.
    DB_TEMPLATE_NAME: str
    DB_NAME: str

    # Credentials used for every connection.
    DB_USER: str
    DB_PASSWORD: str

    # Lifecycle switches read by the setup/teardown helpers.
    DROP_DATABASE_BEFORE_TESTS: bool
    DROP_DATABASE_AFTER_TESTS: bool
|
|
|
|
|
|
class TestingDatabaseSetup:
    """Create and remove the PostgreSQL databases used by a test run.

    Two databases are managed: a template (``DB_TEMPLATE_NAME``) that
    receives the migrations, and the test database (``DB_NAME``) that is
    cloned from it with ``CREATE DATABASE ... TEMPLATE ...``. All
    administrative statements go over a connection to the ``postgres``
    database.
    """

    def __init__(self, options: TestingSettingsInitOptions):
        self.options = options

        # Server-level URL with no database component; the database to open
        # is passed separately to asyncpg.connect().
        db_connection_url = sqlalchemy.engine.URL.create(
            drivername="postgresql",
            username=self.options.DB_USER,
            password=self.options.DB_PASSWORD,
            host=self.options.DB_HOST,
            port=self.options.DB_PORT,
        )
        self.connection_str = db_connection_url.render_as_string(hide_password=False)

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Return *name* quoted as a PostgreSQL identifier.

        CREATE/DROP DATABASE cannot take bind parameters, so the name must be
        spliced into the SQL text. Quoting keeps its exact spelling (unquoted
        identifiers are folded to lower case, which would not match the
        case-sensitive name used when connecting) and escapes any embedded
        double quotes.
        """
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    async def _execute_tolerating(conn: Connection, query: str, tolerated: type) -> bool:
        """Execute *query* on *conn*.

        Returns True on success, False if the statement raised *tolerated*.
        Any other error propagates.
        """
        try:
            await conn.execute(query)
        except tolerated:
            return False
        return True

    async def _connect_admin(self) -> Connection:
        """Open a connection to the ``postgres`` database for DDL statements."""
        return await asyncpg.connect(self.connection_str, database="postgres")

    async def _create_template_db(self, conn: Connection) -> bool:
        """Create the template database.

        Returns True if it was created, False if it already existed.
        """
        query = f"CREATE DATABASE {self._quote_ident(self.options.DB_TEMPLATE_NAME)}"
        return await self._execute_tolerating(conn, query, DuplicateDatabaseError)

    async def _drop_template_db(self, conn: Connection) -> bool:
        """Drop the template database; return False if it did not exist."""
        query = f"DROP DATABASE {self._quote_ident(self.options.DB_TEMPLATE_NAME)}"
        return await self._execute_tolerating(conn, query, InvalidCatalogNameError)

    async def _create_test_db(self, conn: Connection) -> bool:
        """Clone the test database from the template.

        Returns False, leaving the existing database untouched, if a database
        named ``DB_NAME`` already exists.
        """
        query = (
            f"CREATE DATABASE {self._quote_ident(self.options.DB_NAME)} "
            f"TEMPLATE {self._quote_ident(self.options.DB_TEMPLATE_NAME)}"
        )
        return await self._execute_tolerating(conn, query, DuplicateDatabaseError)

    async def _drop_test_db(self, conn: Connection) -> bool:
        """Drop the test database; return False if it did not exist."""
        query = f"DROP DATABASE {self._quote_ident(self.options.DB_NAME)}"
        return await self._execute_tolerating(conn, query, InvalidCatalogNameError)

    def _migrate_template_database(self):
        """Run ``migrate`` against the template database."""
        conf = DatabaseConfig(
            HOST=self.options.DB_HOST,
            PORT=self.options.DB_PORT,
            NAME=self.options.DB_TEMPLATE_NAME,
            USER=self.options.DB_USER,
            PASSWORD=self.options.DB_PASSWORD,
        )
        # NOTE(review): this is a synchronous call and blocks the event loop
        # while it runs. It must also release its connections to the template
        # before returning, because PostgreSQL refuses to clone a template that
        # has other active sessions — confirm against `migrate`.
        migrate(conf)

    async def init_db(self):
        """Prepare the template and test databases for a test run.

        If ``DROP_DATABASE_BEFORE_TESTS`` is set, the template is dropped
        first. The template is then created if missing, migrated if this
        call created it, and only afterwards cloned into the test database.
        The clone therefore contains the migrated schema. An already existing
        test database is reused as-is.
        """
        conn = await self._connect_admin()
        try:
            if self.options.DROP_DATABASE_BEFORE_TESTS:
                await self._drop_template_db(conn)

            # A freshly created template is empty, so it must be migrated
            # before cloning: CREATE DATABASE ... TEMPLATE copies the template
            # as it is at that moment.
            if await self._create_template_db(conn):
                self._migrate_template_database()

            await self._create_test_db(conn)
        finally:
            await conn.close()

    async def tear_down_db(self):
        """Drop the test database, and the template too if
        ``DROP_DATABASE_AFTER_TESTS`` is set.

        Missing databases are ignored.
        """
        conn = await self._connect_admin()
        try:
            await self._drop_test_db(conn)

            if self.options.DROP_DATABASE_AFTER_TESTS:
                await self._drop_template_db(conn)
        finally:
            await conn.close()