# Files: litestar-machines-test/app/lib/sqlalchemy_plugin.py
# Eden Kirin c7060c7ed3 Asset items
# 2023-08-27 23:13:24 +02:00
#
# 144 lines
# 4.5 KiB
# Python

from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast, Literal
from uuid import UUID
import msgspec
from litestar.contrib.sqlalchemy.plugins.init import SQLAlchemyInitPlugin
from litestar.contrib.sqlalchemy.plugins.init.config import SQLAlchemyAsyncConfig
from litestar.contrib.sqlalchemy.plugins.init.config.common import (
SESSION_SCOPE_KEY,
SESSION_TERMINUS_ASGI_EVENTS,
)
from litestar.utils import delete_litestar_scope_state, get_litestar_scope_state
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
@dataclass
class DatabaseSettings:
    """Database connection and pool settings read when this module builds its engine."""

    # SQLAlchemy connection URL handed to `create_async_engine()`.
    # NOTE(review): the default embeds credentials for what looks like a local
    # development database - confirm it is overridden in deployed environments.
    URL: str = "postgresql+asyncpg://televend:televend@localhost:5433/televend"
    # Forwarded as the engine's `echo` argument.
    ECHO: bool = True
    # Forwarded as the engine's `echo_pool` argument.
    ECHO_POOL: bool | Literal["debug"] = False
    # When true the engine is created with `NullPool`.
    POOL_DISABLE: bool = False
    # Forwarded as the engine's `max_overflow` argument.
    POOL_MAX_OVERFLOW: int = 10
    # Forwarded as the engine's `pool_size` argument.
    POOL_SIZE: int = 5
    # Forwarded as the engine's `pool_timeout` argument.
    POOL_TIMEOUT: int = 30
    # Key under which the plugin config registers the session dependency.
    DB_SESSION_DEPENDENCY_KEY: str = "db_session"


# Module-wide settings instance built from the defaults above.
settings = DatabaseSettings()
if TYPE_CHECKING:
from typing import Any
from litestar.types.asgi_types import Message, Scope
# Names exported by `from <this module> import *`: the engine, the session
# factory built on it, and the Litestar plugin config/instance.
__all__ = [
    "async_session_factory",
    "config",
    "engine",
    "plugin",
]
def _default(val: Any) -> str:
if isinstance(val, UUID):
return str(val)
raise TypeError()
# Pool sizing arguments are only accepted by queue-style pools; SQLAlchemy
# raises TypeError when they are combined with NullPool. Send either the
# NullPool class or the sizing options, never both.
_pool_options: dict[str, Any] = (
    {"poolclass": NullPool}
    if settings.POOL_DISABLE
    else {
        "max_overflow": settings.POOL_MAX_OVERFLOW,
        "pool_size": settings.POOL_SIZE,
        "pool_timeout": settings.POOL_TIMEOUT,
    }
)

engine = create_async_engine(
    settings.URL,
    echo=settings.ECHO,
    echo_pool=settings.ECHO_POOL,
    # SQLAlchemy invokes `json_serializer(value)`, so hand it the encoder's
    # bound `encode` method rather than the Encoder object itself.
    json_serializer=msgspec.json.Encoder(enc_hook=_default).encode,
    **_pool_options,
)
"""Configure via DatabaseSettings.

Overrides default JSON serializer to use `msgspec`. See
[`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine]
for detailed instructions.
"""
# Sessions are bound to the module engine; `expire_on_commit=False` is set so
# instances are not expired when the session commits.
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
"""Database session factory.
See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker].
"""
@event.listens_for(engine.sync_engine, "connect")
def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any:
    """Install a binary ``jsonb`` codec on every new DBAPI connection.

    The engine serializes JSON column values with `msgspec`, whose output is
    `bytes` rather than the `str` that `json.dumps` would produce. SQLAlchemy
    expects the JSON serializer to return `str` and calls `.encode()` on the
    value before writing it to the JSONB column. Instead of wrapping the
    serializer so that it returns `str`, this hook changes the behaviour of
    the dialect so that it accepts a binary value from the serializer.

    Args:
        dbapi_connection: The newly created DBAPI connection.
        _: Second argument of the "connect" event; unused.

    See Also:
        https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934
    """

    def _prepend_jsonb_prefix(payload: bytes) -> bytes:
        # \x01 is the prefix for jsonb used by PostgreSQL.
        # asyncpg requires it when format='binary'
        return b"\x01" + payload

    def _parse_without_prefix(payload: bytes) -> Any:
        # The leading byte is the \x01 jsonb prefix that asyncpg returns
        # when format='binary'; drop it before decoding.
        return msgspec.json.decode(payload[1:])

    register_codec = dbapi_connection.driver_connection.set_type_codec(
        "jsonb",
        schema="pg_catalog",
        format="binary",
        encoder=_prepend_jsonb_prefix,
        decoder=_parse_without_prefix,
    )
    # The event hook is synchronous, so the codec registration is driven
    # through the connection's `await_` helper.
    dbapi_connection.await_(register_codec)
async def before_send_handler(message: Message, scope: Scope) -> None:
    """Commit or roll back the request's session based on the response status.

    On an ``http.response.start`` message the session stored in the scope is
    committed for 2xx statuses and rolled back otherwise. When the message
    type is one of ``SESSION_TERMINUS_ASGI_EVENTS`` the session is closed and
    removed from the scope state, even if the commit/rollback raised.

    Args:
        message: ASGI message
        scope: ASGI scope
    """
    db_session = cast("AsyncSession | None", get_litestar_scope_state(scope, SESSION_SCOPE_KEY))
    if db_session is None:
        # No session was stored for this scope; nothing to finalize.
        return

    try:
        if message["type"] == "http.response.start":
            is_success = 200 <= message["status"] < 300
            if is_success:
                await db_session.commit()
            else:
                await db_session.rollback()
    finally:
        if message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
            await db_session.close()
            delete_litestar_scope_state(scope, SESSION_SCOPE_KEY)
# Plugin configuration: reuse the module-level engine and session factory, and
# install the commit/rollback `before_send_handler` defined in this module.
config = SQLAlchemyAsyncConfig(
    engine_instance=engine,
    session_maker=async_session_factory,
    session_dependency_key=settings.DB_SESSION_DEPENDENCY_KEY,
    before_send_handler=before_send_handler,
)

# Plugin instance built from the configuration above.
plugin = SQLAlchemyInitPlugin(config=config)