Files
test-litestar-addressbook/app/lib/sqlalchemy_plugin.py
2023-09-20 09:56:52 +02:00

170 lines
5.1 KiB
Python

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, cast
from uuid import UUID

import msgspec
import sqlalchemy
from litestar.contrib.sqlalchemy.plugins.init import SQLAlchemyInitPlugin
from litestar.contrib.sqlalchemy.plugins.init.config import SQLAlchemyAsyncConfig
from litestar.contrib.sqlalchemy.plugins.init.config.common import (
    SESSION_SCOPE_KEY,
    SESSION_TERMINUS_ASGI_EVENTS,
)
from litestar.utils import delete_litestar_scope_state, get_litestar_scope_state
from sqlalchemy import event
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.pool import NullPool

from app.lib import settings

if TYPE_CHECKING:
    from typing import Any

    from litestar.types.asgi_types import Message, Scope
# Names exported by ``from <this module> import *``.
__all__ = ["async_session_factory", "config", "engine", "plugin"]
def _default(val: Any) -> str:
    """Fallback encoder hook that turns a ``UUID`` into its string form.

    Passed as the ``enc_hook`` of the ``msgspec`` JSON encoder that
    ``create_db_engine`` configures on the engine.

    Args:
        val: The object the encoder handed to the hook.

    Returns:
        ``str(val)`` when ``val`` is a ``UUID``.

    Raises:
        TypeError: If ``val`` is anything other than a ``UUID``. The message
            names the offending type so the failure can be diagnosed.
    """
    if isinstance(val, UUID):
        return str(val)
    raise TypeError(f"Cannot serialize object of type {type(val).__name__!r}")
@dataclass
class DBConnectionSettings:
    """Connection parameters consumed by ``create_db_engine``.

    Attributes:
        username: Database user name.
        password: Password for ``username``; excluded from ``repr()``.
        host: Database server host.
        port: Database server port.
        database: Name of the database to connect to.
    """

    username: str
    # Keep the credential out of the generated ``__repr__`` so it cannot end
    # up in logs or tracebacks; it still takes part in ``__init__``/``__eq__``.
    password: str = field(repr=False)
    host: str
    port: int
    database: str
def create_db_engine(connection_settings: DBConnectionSettings) -> AsyncEngine:
    """Create the async SQLAlchemy engine for a PostgreSQL database.

    The connection URL uses the ``postgresql+asyncpg`` driver. Echo and pool
    behaviour are read from ``settings.db``.

    Args:
        connection_settings: Credentials and location of the database.

    Returns:
        The configured ``AsyncEngine``.
    """
    db_connection_url = sqlalchemy.engine.URL.create(
        drivername="postgresql+asyncpg",
        username=connection_settings.username,
        password=connection_settings.password,
        host=connection_settings.host,
        port=connection_settings.port,
        database=connection_settings.database,
    )

    if settings.db.POOL_DISABLE:
        # NullPool has no sizing options: SQLAlchemy raises ``TypeError``
        # ("Invalid argument(s) ...") if pool_size / max_overflow /
        # pool_timeout are passed together with it.
        pool_options: dict[str, Any] = {"poolclass": NullPool}
    else:
        # No explicit poolclass: the dialect's default pool is used.
        pool_options = {
            "max_overflow": settings.db.POOL_MAX_OVERFLOW,
            "pool_size": settings.db.POOL_SIZE,
            "pool_timeout": settings.db.POOL_TIMEOUT,
        }

    return create_async_engine(
        db_connection_url,
        echo=settings.db.ECHO,
        echo_pool=settings.db.ECHO_POOL,
        # SQLAlchemy invokes the serializer as a plain function, so pass the
        # encoder's bound ``encode`` method rather than the Encoder object.
        json_serializer=msgspec.json.Encoder(enc_hook=_default).encode,
        **pool_options,
    )
engine = create_db_engine(
    DBConnectionSettings(
        database=settings.db.NAME,
        host=settings.db.HOST,
        port=settings.db.PORT,
        username=settings.db.USER,
        password=settings.db.PASSWORD,
    ),
)
"""Module-level async engine, configured from the `settings.db` values.

JSON serialization is handled by `msgspec` instead of the default
serializer. See
[`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine]
for the available options.
"""
async_session_factory = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
"""Factory for `AsyncSession` objects bound to `engine`.

Sessions are created with `expire_on_commit=False`.
See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker].
"""
@event.listens_for(engine.sync_engine, "connect")
def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any:
    """Install a binary `jsonb` codec on each newly opened connection.

    `msgspec` encodes JSON to `bytes`, not to `str` like `json.dumps` does.
    SQLAlchemy expects the JSON serializer to return `str` and calls
    `.encode()` on the result before writing it to a JSONB column. Instead
    of wrapping the serializer so that it returns `str`, this listener
    changes the behaviour of the dialect so that it accepts a binary value
    from the serializer.

    Args:
        dbapi_connection: Connection object passed by the "connect" event;
            its `driver_connection` receives the codec.
        _: Second argument passed by the event; unused.

    See Also:
        https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934
    """
    # Prefix byte PostgreSQL uses for jsonb; asyncpg requires it on the way
    # in and returns it on the way out when format='binary'.
    jsonb_prefix = b"\x01"

    def _add_prefix(payload: bytes) -> bytes:
        return jsonb_prefix + payload

    def _strip_prefix_and_decode(payload: bytes) -> Any:
        # Drop the leading prefix byte before parsing the JSON document.
        return msgspec.json.decode(payload[1:])

    register_codec = dbapi_connection.driver_connection.set_type_codec(
        "jsonb",
        encoder=_add_prefix,
        decoder=_strip_prefix_and_decode,
        schema="pg_catalog",
        format="binary",
    )
    dbapi_connection.await_(register_codec)
async def before_send_handler(message: Message, scope: Scope) -> None:
    """Commit or roll back the request's database session before sending.

    When the response starts, the session stored in the scope is committed
    for a 2xx status and rolled back for any other status. On a terminal
    ASGI event the session is closed and removed from the scope, even if the
    commit or rollback raised.

    Args:
        message: ASGI message
        scope: ASGI scope
    """
    db_session = cast(
        "AsyncSession | None", get_litestar_scope_state(scope, SESSION_SCOPE_KEY)
    )
    if db_session is None:
        # No session was opened for this scope, so there is nothing to do.
        return

    message_type = message["type"]
    try:
        if message_type == "http.response.start":
            is_success = 200 <= message["status"] < 300
            if is_success:
                await db_session.commit()
            else:
                await db_session.rollback()
    finally:
        if message_type in SESSION_TERMINUS_ASGI_EVENTS:
            await db_session.close()
            delete_litestar_scope_state(scope, SESSION_SCOPE_KEY)
# Plugin configuration: reuse this module's engine and session factory, and
# let `before_send_handler` decide between commit and rollback per response.
config = SQLAlchemyAsyncConfig(
    engine_instance=engine,
    session_maker=async_session_factory,
    session_dependency_key=settings.api.DB_SESSION_DEPENDENCY_KEY,
    before_send_handler=before_send_handler,
)

# Litestar plugin instance built from the configuration above.
plugin = SQLAlchemyInitPlugin(config=config)