182 lines
5.5 KiB
Python
182 lines
5.5 KiB
Python
from __future__ import annotations
|
|
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, cast
|
|
from uuid import UUID
|
|
|
|
import msgspec
|
|
import sqlalchemy
|
|
from litestar.contrib.sqlalchemy.plugins.init import SQLAlchemyInitPlugin
|
|
from litestar.contrib.sqlalchemy.plugins.init.config import SQLAlchemyAsyncConfig
|
|
from litestar.contrib.sqlalchemy.plugins.init.config.common import (
|
|
SESSION_SCOPE_KEY,
|
|
SESSION_TERMINUS_ASGI_EVENTS,
|
|
)
|
|
from litestar.utils import delete_litestar_scope_state, get_litestar_scope_state
|
|
from sqlalchemy import event
|
|
from sqlalchemy.ext.asyncio import (
|
|
AsyncEngine,
|
|
AsyncSession,
|
|
async_sessionmaker,
|
|
create_async_engine,
|
|
)
|
|
from sqlalchemy.pool import NullPool
|
|
|
|
from app.lib import settings
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import Any
|
|
|
|
from litestar.types.asgi_types import Message, Scope
|
|
|
|
__all__ = [
|
|
"async_session_factory",
|
|
"config",
|
|
"engine",
|
|
"plugin",
|
|
]
|
|
|
|
|
|
def _default(val: Any) -> str:
|
|
if isinstance(val, UUID):
|
|
return str(val)
|
|
raise TypeError()
|
|
|
|
|
|
@dataclass
|
|
class DBConnectionSettings:
|
|
username: str
|
|
password: str
|
|
host: str
|
|
port: int
|
|
database: str
|
|
|
|
|
|
def create_db_engine(connection_settings: DBConnectionSettings) -> AsyncEngine:
|
|
db_connection_url = sqlalchemy.engine.URL.create(
|
|
drivername="postgresql+asyncpg",
|
|
username=connection_settings.username,
|
|
password=connection_settings.password,
|
|
host=connection_settings.host,
|
|
port=connection_settings.port,
|
|
database=connection_settings.database,
|
|
)
|
|
return create_async_engine(
|
|
db_connection_url,
|
|
echo=settings.db.ECHO,
|
|
echo_pool=settings.db.ECHO_POOL,
|
|
json_serializer=msgspec.json.Encoder(enc_hook=_default),
|
|
max_overflow=settings.db.POOL_MAX_OVERFLOW,
|
|
pool_size=settings.db.POOL_SIZE,
|
|
pool_timeout=settings.db.POOL_TIMEOUT,
|
|
poolclass=NullPool if settings.db.POOL_DISABLE else None,
|
|
)
|
|
|
|
|
|
if "pytest" in sys.modules:
|
|
engine = create_db_engine(
|
|
connection_settings=DBConnectionSettings(
|
|
username=settings.testing.DB_USER,
|
|
password=settings.testing.DB_PASSWORD,
|
|
host=settings.testing.DB_HOST,
|
|
port=settings.testing.DB_PORT,
|
|
database=settings.testing.DB_NAME,
|
|
)
|
|
)
|
|
else:
|
|
engine = create_db_engine(
|
|
connection_settings=DBConnectionSettings(
|
|
username=settings.db.USER,
|
|
password=settings.db.PASSWORD,
|
|
host=settings.db.HOST,
|
|
port=settings.db.PORT,
|
|
database=settings.db.NAME,
|
|
)
|
|
)
|
|
|
|
|
|
"""Configure via DatabaseSettings.
|
|
|
|
Overrides default JSON
|
|
serializer to use `msgspec`. See [`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine]
|
|
for detailed instructions.
|
|
"""
|
|
async_session_factory = async_sessionmaker(
|
|
engine, expire_on_commit=False, class_=AsyncSession
|
|
)
|
|
"""Database session factory.
|
|
|
|
See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker].
|
|
"""
|
|
|
|
|
|
@event.listens_for(engine.sync_engine, "connect")
|
|
def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any:
|
|
"""Using orjson for serialization of the json column values means that the
|
|
output is binary, not `str` like `json.dumps` would output.
|
|
|
|
SQLAlchemy expects that the json serializer returns `str` and calls
|
|
`.encode()` on the value to turn it to bytes before writing to the
|
|
JSONB column. I'd need to either wrap `orjson.dumps` to return a
|
|
`str` so that SQLAlchemy could then convert it to binary, or do the
|
|
following, which changes the behaviour of the dialect to expect a
|
|
binary value from the serializer.
|
|
|
|
See Also:
|
|
https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934
|
|
"""
|
|
|
|
def encoder(bin_value: bytes) -> bytes:
|
|
# \x01 is the prefix for jsonb used by PostgreSQL.
|
|
# asyncpg requires it when format='binary'
|
|
return b"\x01" + bin_value
|
|
|
|
def decoder(bin_value: bytes) -> Any:
|
|
# the byte is the \x01 prefix for jsonb used by PostgreSQL.
|
|
# asyncpg returns it when format='binary'
|
|
return msgspec.json.decode(bin_value[1:])
|
|
|
|
dbapi_connection.await_(
|
|
dbapi_connection.driver_connection.set_type_codec(
|
|
"jsonb",
|
|
encoder=encoder,
|
|
decoder=decoder,
|
|
schema="pg_catalog",
|
|
format="binary",
|
|
)
|
|
)
|
|
|
|
|
|
async def before_send_handler(message: Message, scope: Scope) -> None:
|
|
"""Custom `before_send_handler` for SQLAlchemy plugin that inspects the
|
|
status of response and commits, or rolls back the database.
|
|
|
|
Args:
|
|
message: ASGI message
|
|
scope: ASGI scope
|
|
"""
|
|
session = cast(
|
|
"AsyncSession | None", get_litestar_scope_state(scope, SESSION_SCOPE_KEY)
|
|
)
|
|
try:
|
|
if session is not None and message["type"] == "http.response.start":
|
|
if 200 <= message["status"] < 300:
|
|
await session.commit()
|
|
else:
|
|
await session.rollback()
|
|
finally:
|
|
if session is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
|
|
await session.close()
|
|
delete_litestar_scope_state(scope, SESSION_SCOPE_KEY)
|
|
|
|
|
|
config = SQLAlchemyAsyncConfig(
|
|
session_dependency_key=settings.api.DB_SESSION_DEPENDENCY_KEY,
|
|
engine_instance=engine,
|
|
session_maker=async_session_factory,
|
|
before_send_handler=before_send_handler,
|
|
)
|
|
|
|
plugin = SQLAlchemyInitPlugin(config=config)
|