from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, cast, Literal
from uuid import UUID

import msgspec
from litestar.contrib.sqlalchemy.plugins.init import SQLAlchemyInitPlugin
from litestar.contrib.sqlalchemy.plugins.init.config import SQLAlchemyAsyncConfig
from litestar.contrib.sqlalchemy.plugins.init.config.common import (
    SESSION_SCOPE_KEY,
    SESSION_TERMINUS_ASGI_EVENTS,
)
from litestar.utils import delete_litestar_scope_state, get_litestar_scope_state
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool

DB_HOST = "localhost"
DB_PORT = 5432
DB_NAME = "addressbook"
DB_USER = "addressbook"
DB_PASSWORD = "addressbook"


@dataclass
class DatabaseSettings:
    URL: str = f"postgresql+asyncpg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
    ECHO: bool = True
    ECHO_POOL: bool | Literal["debug"] = False
    POOL_DISABLE: bool = False
    POOL_MAX_OVERFLOW: int = 10
    POOL_SIZE: int = 5
    POOL_TIMEOUT: int = 30
    DB_SESSION_DEPENDENCY_KEY: str = "db_session"


settings = DatabaseSettings()
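
# Illustrative only (not part of the original module): because DatabaseSettings is a
# plain dataclass, alternative environments can be described by overriding fields at
# construction time. The values below are made-up examples.
#
#   test_settings = DatabaseSettings(
#       URL="postgresql+asyncpg://app:app@db:5432/addressbook_test",
#       ECHO=False,
#       POOL_DISABLE=True,  # selects NullPool when this instance is used as `settings`
#   )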


if TYPE_CHECKING:
    from typing import Any

    from litestar.types.asgi_types import Message, Scope

__all__ = [
    "async_session_factory",
    "config",
    "engine",
    "plugin",
]


def _default(val: Any) -> str:
    # msgspec enc_hook: called for values the encoder cannot handle natively.
    if isinstance(val, UUID):
        return str(val)
    raise TypeError()


engine = create_async_engine(
    settings.URL,
    echo=settings.ECHO,
    echo_pool=settings.ECHO_POOL,
    json_serializer=msgspec.json.Encoder(enc_hook=_default).encode,
    max_overflow=settings.POOL_MAX_OVERFLOW,
    pool_size=settings.POOL_SIZE,
    pool_timeout=settings.POOL_TIMEOUT,
    poolclass=NullPool if settings.POOL_DISABLE else None,
)
"""Configured via `DatabaseSettings`.

Overrides the default JSON serializer to use `msgspec`. See
[`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine]
for detailed instructions.
"""
async_session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
"""Database session factory.

See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker].
"""


@event.listens_for(engine.sync_engine, "connect")
def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any:
    """Using msgspec for serialization of the json column values means that the
    output is binary, not `str` like `json.dumps` would output.

    SQLAlchemy expects that the json serializer returns `str` and calls
    `.encode()` on the value to turn it into bytes before writing to the
    JSONB column. We'd need to either wrap `msgspec.json.encode` to return a
    `str` so that SQLAlchemy could then convert it to binary, or do the
    following, which changes the behaviour of the dialect to expect a
    binary value from the serializer.

    See Also:
        https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934
    """

    def encoder(bin_value: bytes) -> bytes:
        # \x01 is the prefix for jsonb used by PostgreSQL.
        # asyncpg requires it when format='binary'.
        return b"\x01" + bin_value

    def decoder(bin_value: bytes) -> Any:
        # Strip the \x01 jsonb prefix that asyncpg returns when format='binary',
        # then decode the remaining JSON payload.
        return msgspec.json.decode(bin_value[1:])

    dbapi_connection.await_(
        dbapi_connection.driver_connection.set_type_codec(
            "jsonb",
            encoder=encoder,
            decoder=decoder,
            schema="pg_catalog",
            format="binary",
        )
    )
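
# Sketch of the codec round trip configured above (values are illustrative):
#
#   payload = msgspec.json.encode({"city": "Oslo"})  # b'{"city":"Oslo"}'
#   wire = b"\x01" + payload                         # what encoder() hands to asyncpg
#   msgspec.json.decode(wire[1:])                    # decoder() yields {'city': 'Oslo'}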


async def before_send_handler(message: Message, scope: Scope) -> None:
    """Custom `before_send_handler` for the SQLAlchemy plugin that inspects the
    status of the response and commits or rolls back the database session.

    Args:
        message: ASGI message
        scope: ASGI scope
    """
    session = cast("AsyncSession | None", get_litestar_scope_state(scope, SESSION_SCOPE_KEY))
    try:
        if session is not None and message["type"] == "http.response.start":
            if 200 <= message["status"] < 300:
                await session.commit()
            else:
                await session.rollback()
    finally:
        if session is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
            await session.close()
            delete_litestar_scope_state(scope, SESSION_SCOPE_KEY)


config = SQLAlchemyAsyncConfig(
    session_dependency_key=settings.DB_SESSION_DEPENDENCY_KEY,
    engine_instance=engine,
    session_maker=async_session_factory,
    before_send_handler=before_send_handler,
)

plugin = SQLAlchemyInitPlugin(config=config)
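
# Illustrative wiring sketch (not part of this module): the plugin provides a
# `db_session` dependency (per DB_SESSION_DEPENDENCY_KEY above) to route handlers.
# The handler below is hypothetical.
#
#   from litestar import Litestar, get
#   from sqlalchemy import text
#
#   @get("/health")
#   async def health_check(db_session: AsyncSession) -> dict[str, str]:
#       await db_session.execute(text("SELECT 1"))
#       return {"status": "ok"}
#
#   app = Litestar(route_handlers=[health_check], plugins=[plugin])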