from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, cast
from uuid import UUID

import msgspec
from litestar.contrib.sqlalchemy.plugins.init import SQLAlchemyInitPlugin
from litestar.contrib.sqlalchemy.plugins.init.config import SQLAlchemyAsyncConfig
from litestar.contrib.sqlalchemy.plugins.init.config.common import (
    SESSION_SCOPE_KEY,
    SESSION_TERMINUS_ASGI_EVENTS,
)
from litestar.utils import delete_litestar_scope_state, get_litestar_scope_state
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool

if TYPE_CHECKING:
    from typing import Any

    from litestar.types.asgi_types import Message, Scope

__all__ = [
    "async_session_factory",
    "config",
    "engine",
    "plugin",
]


@dataclass
class DatabaseSettings:
    """Settings consumed when building the engine and the Litestar plugin."""

    # Connection URL handed to `create_async_engine()`. The default embeds
    # credentials, so override it outside of local development.
    URL: str = "postgresql+asyncpg://televend:televend@localhost:5433/televend"
    # Forwarded as `echo=` to the engine.
    ECHO: bool = True
    # Forwarded as `echo_pool=` to the engine.
    ECHO_POOL: bool | Literal["debug"] = False
    # When true the engine is created with `NullPool` and the three pool
    # sizing options below are not used.
    POOL_DISABLE: bool = False
    POOL_MAX_OVERFLOW: int = 10
    POOL_SIZE: int = 5
    POOL_TIMEOUT: int = 30
    # Name under which the session is injected into handlers by the plugin.
    DB_SESSION_DEPENDENCY_KEY: str = "db_session"


settings = DatabaseSettings()


def _default(val: Any) -> str:
    """Encoder hook for values `msgspec` is asked to serialize via `enc_hook`.

    Args:
        val: The value passed to the hook by the encoder.

    Returns:
        The string form of `val` when it is a `UUID`.

    Raises:
        TypeError: If `val` is not a `UUID`.
    """
    if isinstance(val, UUID):
        return str(val)
    raise TypeError(f"Object of type {type(val).__name__} is not JSON serializable")


def _pool_kwargs(db_settings: DatabaseSettings) -> dict[str, Any]:
    """Build the pool-related keyword arguments for `create_async_engine()`.

    `NullPool` does not take `pool_size`, `max_overflow` or `pool_timeout`, and
    the engine factory rejects keyword arguments that the chosen pool class
    does not consume, so those options are only sent when pooling is enabled.
    """
    if db_settings.POOL_DISABLE:
        return {"poolclass": NullPool}
    return {
        "max_overflow": db_settings.POOL_MAX_OVERFLOW,
        "pool_size": db_settings.POOL_SIZE,
        "pool_timeout": db_settings.POOL_TIMEOUT,
    }


engine = create_async_engine(
    settings.URL,
    echo=settings.ECHO,
    echo_pool=settings.ECHO_POOL,
    # SQLAlchemy calls `json_serializer(value)`, so pass the bound `encode`
    # method rather than the `Encoder` instance itself.
    json_serializer=msgspec.json.Encoder(enc_hook=_default).encode,
    **_pool_kwargs(settings),
)
"""Configure via DatabaseSettings. Overrides default JSON serializer to use `msgspec`.

See [`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine]
for detailed instructions.
"""

async_session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
"""Database session factory.

See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker].
"""


@event.listens_for(engine.sync_engine, "connect")
def _sqla_on_connect(dbapi_connection: Any, _: Any) -> None:
    """Register a binary `jsonb` codec on every new DBAPI connection.

    Using `msgspec` for serialization of the json column values means that the
    output is binary, not `str` like `json.dumps` would output. SQLAlchemy
    expects that the json serializer returns `str` and calls `.encode()` on the
    value to turn it to bytes before writing to the JSONB column. The
    alternatives are to wrap the encoder so that it returns a `str` which
    SQLAlchemy then converts to binary, or to do the following, which changes
    the behaviour of the dialect to expect a binary value from the serializer.

    Args:
        dbapi_connection: The adapted DBAPI connection supplied by the
            `connect` event.
        _: Second argument of the `connect` event; unused.

    See Also:
        https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934
    """

    def encoder(bin_value: bytes) -> bytes:
        # \x01 is the prefix for jsonb used by PostgreSQL.
        # asyncpg requires it when format='binary'
        return b"\x01" + bin_value

    def decoder(bin_value: bytes) -> Any:
        # the byte is the \x01 prefix for jsonb used by PostgreSQL.
        # asyncpg returns it when format='binary'
        return msgspec.json.decode(bin_value[1:])

    # `set_type_codec` is a coroutine; `await_` runs it from this synchronous
    # event hook.
    dbapi_connection.await_(
        dbapi_connection.driver_connection.set_type_codec(
            "jsonb",
            encoder=encoder,
            decoder=decoder,
            schema="pg_catalog",
            format="binary",
        )
    )


async def before_send_handler(message: Message, scope: Scope) -> None:
    """Custom `before_send_handler` for SQLAlchemy plugin that inspects the status of
    response and commits, or rolls back the database.

    On `http.response.start` the session stored in the scope is committed when
    the status is in the 2xx range and rolled back otherwise. When the message
    type is one of `SESSION_TERMINUS_ASGI_EVENTS` the session is closed and
    removed from the scope state, even if the commit or rollback raised.

    Args:
        message: ASGI message
        scope: ASGI scope
    """
    session = cast("AsyncSession | None", get_litestar_scope_state(scope, SESSION_SCOPE_KEY))
    if session is None:
        # Nothing was opened for this request, so there is nothing to finalize.
        return

    try:
        if message["type"] == "http.response.start":
            if 200 <= message["status"] < 300:
                await session.commit()
            else:
                await session.rollback()
    finally:
        if message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
            await session.close()
            delete_litestar_scope_state(scope, SESSION_SCOPE_KEY)


config = SQLAlchemyAsyncConfig(
    session_dependency_key=settings.DB_SESSION_DEPENDENCY_KEY,
    engine_instance=engine,
    session_maker=async_session_factory,
    before_send_handler=before_send_handler,
)

plugin = SQLAlchemyInitPlugin(config=config)