From 4875837d4edcf75bbf7624af1a95646079d0e939 Mon Sep 17 00:00:00 2001 From: Eden Kirin Date: Thu, 14 Sep 2023 17:44:02 +0200 Subject: [PATCH] Basic users and cities api --- Makefile | 14 ++++ app/__init__.py | 0 app/controllers/__init__.py | 20 +++++ app/controllers/city.py | 63 +++++++++++++++ app/controllers/user.py | 63 +++++++++++++++ app/domain/__init__.py | 0 app/domain/city.py | 32 ++++++++ app/domain/enums.py | 0 app/domain/user.py | 32 ++++++++ app/dto/__init__.py | 0 app/lib/__init__.py | 0 app/lib/dependencies.py | 150 +++++++++++++++++++++++++++++++++++ app/lib/exceptions.py | 68 ++++++++++++++++ app/lib/filter_repository.py | 38 +++++++++ app/lib/filters.py | 10 +++ app/lib/responses.py | 26 ++++++ app/lib/service.py | 95 ++++++++++++++++++++++ app/lib/sqlalchemy_plugin.py | 150 +++++++++++++++++++++++++++++++++++ main.py | 28 +++++++ migrations/0001-initial.sql | 26 ++++++ 20 files changed, 815 insertions(+) create mode 100644 Makefile create mode 100644 app/__init__.py create mode 100644 app/controllers/__init__.py create mode 100644 app/controllers/city.py create mode 100644 app/controllers/user.py create mode 100644 app/domain/__init__.py create mode 100644 app/domain/city.py create mode 100644 app/domain/enums.py create mode 100644 app/domain/user.py create mode 100644 app/dto/__init__.py create mode 100644 app/lib/__init__.py create mode 100644 app/lib/dependencies.py create mode 100644 app/lib/exceptions.py create mode 100644 app/lib/filter_repository.py create mode 100644 app/lib/filters.py create mode 100644 app/lib/responses.py create mode 100644 app/lib/service.py create mode 100644 app/lib/sqlalchemy_plugin.py create mode 100644 main.py create mode 100644 migrations/0001-initial.sql diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0694b0a --- /dev/null +++ b/Makefile @@ -0,0 +1,14 @@ +ifeq ($(VIRTUAL_ENV),) + RUN_IN_ENV=poetry run +else + RUN_IN_ENV= +endif + +run: + @ $(RUN_IN_ENV) uvicorn \ + main:app \ + --reload \ + --reload-dir=app + +shell: + @ $(RUN_IN_ENV) python manage.py shell diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/controllers/__init__.py b/app/controllers/__init__.py new file mode 100644 index 0000000..99ca019 --- /dev/null +++ b/app/controllers/__init__.py @@ -0,0 +1,20 @@ +from litestar import Router + +__all__ = ["create_router"] + +from app.controllers.city import CityController +from app.controllers.user import UserController +from app.domain.city import City + + +def create_router() -> Router: + return Router( + path="/v1", + route_handlers=[ + CityController, + UserController, + ], + signature_namespace={ + "City": City, + }, + ) diff --git a/app/controllers/city.py b/app/controllers/city.py new file mode 100644 index 0000000..704f3f9 --- /dev/null +++ b/app/controllers/city.py @@ -0,0 +1,63 @@ +from typing import TYPE_CHECKING, Optional +from uuid import UUID + +from litestar import Controller, get +from litestar.contrib.repository.filters import LimitOffset, SearchFilter +from litestar.di import Provide +from sqlalchemy.ext.asyncio import AsyncSession + +from app.domain.city import ( + City, + CityReadDTO, + CityWriteDTO, + Repository, + Service, +) +from app.lib.responses import ObjectListResponse, ObjectResponse + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + +DETAIL_ROUTE = "/{city_id:uuid}" + + +def provides_service(db_session: AsyncSession) -> Service: + """Constructs repository and service objects for the request.""" + return Service(Repository(session=db_session)) + + +class CityController(Controller): + dto = CityWriteDTO + return_dto = CityReadDTO + path = "/cities" + dependencies = { + "service": Provide(provides_service, sync_to_thread=False), + } + tags = ["Cities"] + + @get() + async def get_cities( + self, service: Service, search: Optional[str] = None + ) -> ObjectListResponse[City]: + filters = [ + LimitOffset(limit=20, offset=0), + ] + + if search is not None: + filters.append( + SearchFilter( + field_name="caption", + value=search, + ), + ) + + content = await service.list(*filters) + return ObjectListResponse(content=content) + + @get(DETAIL_ROUTE) + async def get_city( + self, service: Service, city_id: UUID + ) -> ObjectResponse[City]: + content = await service.get(city_id) + return ObjectResponse(content=content) diff --git a/app/controllers/user.py b/app/controllers/user.py new file mode 100644 index 0000000..e22e5d9 --- /dev/null +++ b/app/controllers/user.py @@ -0,0 +1,63 @@ +from typing import TYPE_CHECKING, Optional +from uuid import UUID + +from litestar import Controller, get +from litestar.contrib.repository.filters import LimitOffset, SearchFilter +from litestar.di import Provide +from sqlalchemy.ext.asyncio import AsyncSession + +from app.domain.user import ( + User, + UserReadDTO, + UserWriteDTO, + Repository, + Service, +) +from app.lib.responses import ObjectListResponse, ObjectResponse + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + +DETAIL_ROUTE = "/{user_id:uuid}" + + +def provides_service(db_session: AsyncSession) -> Service: + """Constructs repository and service objects for the request.""" + return Service(Repository(session=db_session)) + + +class UserController(Controller): + dto = UserWriteDTO + return_dto = UserReadDTO + path = "/users" + dependencies = { + "service": Provide(provides_service, sync_to_thread=False), + } + tags = ["Users"] + + @get() + async def get_users( + self, service: Service, search: Optional[str] = None + ) -> ObjectListResponse[User]: + filters = [ + LimitOffset(limit=20, offset=0), + ] + + if search is not None: + filters.append( + SearchFilter( + field_name="caption", + value=search, + ), + ) + + content = await service.list(*filters) + return ObjectListResponse(content=content) + + @get(DETAIL_ROUTE) + async def get_user( + self, service: Service, user_id: UUID + ) -> ObjectResponse[User]: + content = await service.get(user_id) + return ObjectResponse(content=content) diff --git a/app/domain/__init__.py b/app/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/domain/city.py b/app/domain/city.py new file mode 100644 index 0000000..10c99d1 --- /dev/null +++ b/app/domain/city.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Annotated + +from litestar.contrib.sqlalchemy.base import UUIDBase +from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO +from litestar.dto import DTOConfig +from sqlalchemy.orm import Mapped + +from app.lib import service +from app.lib.filter_repository import FilterRepository + + +class City(UUIDBase): + __tablename__ = "cities" # type: ignore[assignment] + + name: Mapped[str] + postal_code: Mapped[str] + created_at: Mapped[datetime] + modified_at: Mapped[datetime] + + +class Repository(FilterRepository[City]): + model_type = City + + +class Service(service.Service[City]): + repository_type = Repository + + +write_config = DTOConfig(exclude={"id"}) +CityWriteDTO = SQLAlchemyDTO[Annotated[City, write_config]] +CityReadDTO = SQLAlchemyDTO[City] diff --git a/app/domain/enums.py b/app/domain/enums.py new file mode 100644 index 0000000..e69de29 diff --git a/app/domain/user.py b/app/domain/user.py new file mode 100644 index 0000000..e9b32d3 --- /dev/null +++ b/app/domain/user.py @@ -0,0 +1,32 @@ +from datetime import datetime +from typing import Annotated + +from litestar.contrib.sqlalchemy.base import UUIDBase +from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO +from litestar.dto import DTOConfig +from sqlalchemy.orm import Mapped + +from app.lib import service +from app.lib.filter_repository import FilterRepository + + +class User(UUIDBase): + __tablename__ = "users" # type: ignore[assignment] + + first_name: Mapped[str] + last_name: Mapped[str] + created_at: Mapped[datetime] + modified_at: Mapped[datetime] + + +class Repository(FilterRepository[User]): + model_type = User + + +class Service(service.Service[User]): + repository_type = Repository + + +write_config = DTOConfig(exclude={"id"}) +UserWriteDTO = SQLAlchemyDTO[Annotated[User, write_config]] +UserReadDTO = SQLAlchemyDTO[User] diff --git a/app/dto/__init__.py b/app/dto/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/lib/__init__.py b/app/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/lib/dependencies.py b/app/lib/dependencies.py new file mode 100644 index 0000000..36f4df2 --- /dev/null +++ b/app/lib/dependencies.py @@ -0,0 +1,150 @@ +from datetime import datetime +from uuid import UUID + +from litestar.contrib.repository.filters import BeforeAfter, CollectionFilter, FilterTypes, LimitOffset +from litestar.di import Provide +from litestar.params import Dependency, Parameter + +DEFAULT_PAGINATION_LIMIT = 20 + +__all__ = [ + "create_collection_dependencies", + "provide_created_filter", + "provide_filter_dependencies", + "provide_id_filter", + "provide_limit_offset_pagination", + "provide_updated_filter", +] + + +DTorNone = datetime | None + +CREATED_FILTER_DEPENDENCY_KEY = "created_filter" +FILTERS_DEPENDENCY_KEY = "filters" +ID_FILTER_DEPENDENCY_KEY = "id_filter" +LIMIT_OFFSET_DEPENDENCY_KEY = "limit_offset" +UPDATED_FILTER_DEPENDENCY_KEY = "updated_filter" + + +def provide_id_filter( + ids: list[UUID] | None = Parameter(query="ids", default=None, required=False) +) -> CollectionFilter[UUID]: + """Return type consumed by ``Repository.filter_in_collection()``. + + Parameters + ---------- + ids : list[UUID] | None + Parsed out of comma separated list of values in query params. + + Returns: + ------- + CollectionFilter[UUID] + """ + return CollectionFilter(field_name="id", values=ids or []) + + +def provide_created_filter( + before: DTorNone = Parameter(query="created-before", default=None, required=False), + after: DTorNone = Parameter(query="created-after", default=None, required=False), +) -> BeforeAfter: + """Return type consumed by `Repository.filter_on_datetime_field()`. + + Parameters + ---------- + before : datetime | None + Filter for records updated before this date/time. + after : datetime | None + Filter for records updated after this date/time. + """ + return BeforeAfter("created_at", before, after) + + +def provide_updated_filter( + before: DTorNone = Parameter(query="updated-before", default=None, required=False), + after: DTorNone = Parameter(query="updated-after", default=None, required=False), +) -> BeforeAfter: + """Return type consumed by `Repository.filter_on_datetime_field()`. + Parameters + ---------- + before : datetime | None + Filter for records updated before this date/time. + after : datetime | None + Filter for records updated after this date/time. + """ + return BeforeAfter("updated_at", before, after) + + +def provide_limit_offset_pagination( + page: int = Parameter(ge=1, default=1, required=False), + page_size: int = Parameter( + query="page-size", + ge=1, + default=DEFAULT_PAGINATION_LIMIT, + required=False, + ), +) -> LimitOffset: + """Return type consumed by `Repository.apply_limit_offset_pagination()`. + Parameters + ---------- + page : int + LIMIT to apply to select. + page_size : int + OFFSET to apply to select. + """ + return LimitOffset(page_size, page_size * (page - 1)) + + +def provide_filter_dependencies( + created_filter: BeforeAfter = Dependency(skip_validation=True), + updated_filter: BeforeAfter = Dependency(skip_validation=True), + id_filter: CollectionFilter = Dependency(skip_validation=True), + limit_offset: LimitOffset = Dependency(skip_validation=True), +) -> list[FilterTypes]: + """Common collection route filtering dependencies. Add all filters to any + route by including this function as a dependency, e.g: + + @get + def get_collection_handler(filters: Filters) -> ...: + ... + The dependency is provided at the application layer, so only need to inject the dependency where + necessary. + + Parameters + ---------- + id_filter : repository.CollectionFilter + Filter for scoping query to limited set of identities. + created_filter : repository.BeforeAfter + Filter for scoping query to instance creation date/time. + updated_filter : repository.BeforeAfter + Filter for scoping query to instance update date/time. + limit_offset : repository.LimitOffset + Filter for query pagination. + + Returns: + ------- + list[FilterTypes] + List of filters parsed from connection. + """ + return [ + created_filter, + id_filter, + limit_offset, + updated_filter, + ] + + +def create_collection_dependencies() -> dict[str, Provide]: + """Creates a dictionary of provides for pagination endpoints. + + Returns: + ------- + dict[str, Provide] + + """ + return { + LIMIT_OFFSET_DEPENDENCY_KEY: Provide(provide_limit_offset_pagination, sync_to_thread=False), + UPDATED_FILTER_DEPENDENCY_KEY: Provide(provide_updated_filter, sync_to_thread=False), + CREATED_FILTER_DEPENDENCY_KEY: Provide(provide_created_filter, sync_to_thread=False), + ID_FILTER_DEPENDENCY_KEY: Provide(provide_id_filter, sync_to_thread=False), + FILTERS_DEPENDENCY_KEY: Provide(provide_filter_dependencies, sync_to_thread=False), + } diff --git a/app/lib/exceptions.py b/app/lib/exceptions.py new file mode 100644 index 0000000..6b4a670 --- /dev/null +++ b/app/lib/exceptions.py @@ -0,0 +1,68 @@ +import logging +from typing import TYPE_CHECKING + +from litestar.contrib.repository.exceptions import ( + ConflictError as RepositoryConflictException, +) +from litestar.contrib.repository.exceptions import ( + NotFoundError as RepositoryNotFoundException, +) +from litestar.contrib.repository.exceptions import ( + RepositoryError as RepositoryException, +) +from litestar.exceptions import ( + HTTPException, + InternalServerException, + NotFoundException, +) +from litestar.middleware.exceptions.middleware import create_exception_response + +from .service import ServiceError + +if TYPE_CHECKING: + from litestar.connection import Request + from litestar.response import Response + +__all__ = [ + "repository_exception_to_http_response", + "service_exception_to_http_response", +] + +logger = logging.getLogger(__name__) + + +class ConflictException(HTTPException): + status_code = 409 + + +def repository_exception_to_http_response(request: "Request", exc: RepositoryException) -> "Response": + """Transform repository exceptions to HTTP exceptions. + + Args: + _: The request that experienced the exception. + exc: Exception raised during handling of the request. + + Returns: + Exception response appropriate to the type of original exception. + """ + http_exc: type[HTTPException] + if isinstance(exc, RepositoryNotFoundException): + http_exc = NotFoundException + elif isinstance(exc, RepositoryConflictException): + http_exc = ConflictException + else: + http_exc = InternalServerException + return create_exception_response(request, exc=http_exc()) + + +def service_exception_to_http_response(request: "Request", exc: ServiceError) -> "Response": + """Transform service exceptions to HTTP exceptions. + + Args: + _: The request that experienced the exception. + exc: Exception raised during handling of the request. + + Returns: + Exception response appropriate to the type of original exception. + """ + return create_exception_response(request, InternalServerException()) diff --git a/app/lib/filter_repository.py b/app/lib/filter_repository.py new file mode 100644 index 0000000..52b01fe --- /dev/null +++ b/app/lib/filter_repository.py @@ -0,0 +1,38 @@ +from typing import Optional, cast + +from litestar.contrib.repository import FilterTypes +from litestar.contrib.sqlalchemy.repository import ModelT, SQLAlchemyAsyncRepository +from litestar.contrib.sqlalchemy.repository.types import SelectT +from sqlalchemy import true, Column + +from app.lib.filters import ExactFilter + + +class FilterRepository(SQLAlchemyAsyncRepository[ModelT]): + alive_flag: Optional[str] = None + + def _get_column_by_name(self, name: str) -> Column | None: + return cast(Column, getattr(self.model_type, name, None)) + + def _apply_filters( + self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT + ) -> SelectT: + standard_filters = [] + for filter_ in filters: + if isinstance(filter_, ExactFilter): + statement = statement.where( + self._get_column_by_name(filter_.field_name) == filter_.value + ) + else: + standard_filters.append(filter_) + + statement = super()._apply_filters( + *standard_filters, apply_pagination=apply_pagination, statement=statement + ) + + if self.alive_flag: + statement = statement.where( + self._get_column_by_name(self.alive_flag) == true() + ) + + return statement diff --git a/app/lib/filters.py b/app/lib/filters.py new file mode 100644 index 0000000..ac6ff8d --- /dev/null +++ b/app/lib/filters.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass +from typing import Generic, TypeVar + +T = TypeVar("T") + + +@dataclass +class ExactFilter(Generic[T]): + field_name: str + value: T diff --git a/app/lib/responses.py b/app/lib/responses.py new file mode 100644 index 0000000..b6510d0 --- /dev/null +++ b/app/lib/responses.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Generic, TypeVar + +T = TypeVar("T") + + +@dataclass +class ObjectResponse(Generic[T]): + content: T + + +@dataclass +class ObjectListResponse(Generic[T]): + content: list[T] + + +@dataclass +class PaginationMeta: + page: int + page_count: int + + +@dataclass +class PaginatedObjectListResponse(Generic[T]): + content: list[T] + meta: PaginationMeta diff --git a/app/lib/service.py b/app/lib/service.py new file mode 100644 index 0000000..f47e5a2 --- /dev/null +++ b/app/lib/service.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic + +from litestar.contrib.sqlalchemy.repository import ModelT + +__all__ = ["Service", "ServiceError"] + + +if TYPE_CHECKING: + from litestar.contrib.repository import AbstractAsyncRepository, FilterTypes + + +class ServiceError(Exception): + """Base class for `Service` related exceptions.""" + + +class Service(Generic[ModelT]): + def __init__(self, repository: AbstractAsyncRepository[ModelT]) -> None: + """Generic Service object. + + Args: + repository: Instance conforming to `AbstractRepository` interface. + """ + self.repository = repository + + async def create(self, data: ModelT) -> ModelT: + """Wraps repository instance creation. + + Args: + data: Representation to be created. + + Returns: + Representation of created instance. + """ + return await self.repository.add(data) + + async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]: + """Wraps repository scalars operation. + + Args: + *filters: Collection route filters. + **kwargs: Keyword arguments for attribute based filtering. + + Returns: + The list of instances retrieved from the repository. + """ + return await self.repository.list(*filters, **kwargs) + + async def update(self, id_: Any, data: ModelT) -> ModelT: + """Wraps repository update operation. + + Args: + id_: Identifier of item to be updated. + data: Representation to be updated. + + Returns: + Updated representation. + """ + return await self.repository.update(data) + + async def upsert(self, id_: Any, data: ModelT) -> ModelT: + """Wraps repository upsert operation. + + Args: + id_: Identifier of the object for upsert. + data: Representation for upsert. + + Returns: + ------- + Updated or created representation. + """ + return await self.repository.upsert(data) + + async def get(self, id_: Any) -> ModelT: + """Wraps repository scalar operation. + + Args: + id_: Identifier of instance to be retrieved. + + Returns: + Representation of instance with identifier `id_`. + """ + return await self.repository.get(id_) + + async def delete(self, id_: Any) -> ModelT: + """Wraps repository delete operation. + + Args: + id_: Identifier of instance to be deleted. + + Returns: + Representation of the deleted instance. + """ + return await self.repository.delete(id_) diff --git a/app/lib/sqlalchemy_plugin.py b/app/lib/sqlalchemy_plugin.py new file mode 100644 index 0000000..3b30614 --- /dev/null +++ b/app/lib/sqlalchemy_plugin.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, cast, Literal +from uuid import UUID + +import msgspec +from litestar.contrib.sqlalchemy.plugins.init import SQLAlchemyInitPlugin +from litestar.contrib.sqlalchemy.plugins.init.config import SQLAlchemyAsyncConfig +from litestar.contrib.sqlalchemy.plugins.init.config.common import ( + SESSION_SCOPE_KEY, + SESSION_TERMINUS_ASGI_EVENTS, +) +from litestar.utils import delete_litestar_scope_state, get_litestar_scope_state +from sqlalchemy import event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import NullPool + + +DB_HOST = "localhost" +DB_PORT = 5432 +DB_NAME = "addressbook" +DB_USER = "addressbook" +DB_PASSWORD = "addressbook" + + +@dataclass +class DatabaseSettings: + URL: str = f"postgresql+asyncpg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" + ECHO: bool = True + ECHO_POOL: bool | Literal["debug"] = False + POOL_DISABLE: bool = False + POOL_MAX_OVERFLOW: int = 10 + POOL_SIZE: int = 5 + POOL_TIMEOUT: int = 30 + DB_SESSION_DEPENDENCY_KEY: str = "db_session" + + +settings = DatabaseSettings() + + +if TYPE_CHECKING: + from typing import Any + + from litestar.types.asgi_types import Message, Scope + +__all__ = [ + "async_session_factory", + "config", + "engine", + "plugin", +] + + +def _default(val: Any) -> str: + if isinstance(val, UUID): + return str(val) + raise TypeError() + + +engine = create_async_engine( + settings.URL, + echo=settings.ECHO, + echo_pool=settings.ECHO_POOL, + json_serializer=msgspec.json.Encoder(enc_hook=_default), + max_overflow=settings.POOL_MAX_OVERFLOW, + pool_size=settings.POOL_SIZE, + pool_timeout=settings.POOL_TIMEOUT, + poolclass=NullPool if settings.POOL_DISABLE else None, +) +"""Configure via DatabaseSettings. + +Overrides default JSON +serializer to use `msgspec`. See [`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine] +for detailed instructions. +""" +async_session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) +"""Database session factory. + +See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker]. +""" + + +@event.listens_for(engine.sync_engine, "connect") +def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any: + """Using orjson for serialization of the json column values means that the + output is binary, not `str` like `json.dumps` would output. + + SQLAlchemy expects that the json serializer returns `str` and calls + `.encode()` on the value to turn it to bytes before writing to the + JSONB column. I'd need to either wrap `orjson.dumps` to return a + `str` so that SQLAlchemy could then convert it to binary, or do the + following, which changes the behaviour of the dialect to expect a + binary value from the serializer. + + See Also: + https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934 + """ + + def encoder(bin_value: bytes) -> bytes: + # \x01 is the prefix for jsonb used by PostgreSQL. + # asyncpg requires it when format='binary' + return b"\x01" + bin_value + + def decoder(bin_value: bytes) -> Any: + # the byte is the \x01 prefix for jsonb used by PostgreSQL. + # asyncpg returns it when format='binary' + return msgspec.json.decode(bin_value[1:]) + + dbapi_connection.await_( + dbapi_connection.driver_connection.set_type_codec( + "jsonb", + encoder=encoder, + decoder=decoder, + schema="pg_catalog", + format="binary", + ) + ) + + +async def before_send_handler(message: Message, scope: Scope) -> None: + """Custom `before_send_handler` for SQLAlchemy plugin that inspects the + status of response and commits, or rolls back the database. + + Args: + message: ASGI message + _: + scope: ASGI scope + """ + session = cast("AsyncSession | None", get_litestar_scope_state(scope, SESSION_SCOPE_KEY)) + try: + if session is not None and message["type"] == "http.response.start": + if 200 <= message["status"] < 300: + await session.commit() + else: + await session.rollback() + finally: + if session is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS: + await session.close() + delete_litestar_scope_state(scope, SESSION_SCOPE_KEY) + + +config = SQLAlchemyAsyncConfig( + session_dependency_key=settings.DB_SESSION_DEPENDENCY_KEY, + engine_instance=engine, + session_maker=async_session_factory, + before_send_handler=before_send_handler, +) + +plugin = SQLAlchemyInitPlugin(config=config) diff --git a/main.py b/main.py new file mode 100644 index 0000000..6070b5f --- /dev/null +++ b/main.py @@ -0,0 +1,28 @@ +from typing import Any + +from litestar import Litestar +from litestar.contrib.repository.exceptions import ( + RepositoryError as RepositoryException, +) +from litestar.openapi import OpenAPIConfig + +from app.controllers import create_router +from app.lib import exceptions, sqlalchemy_plugin +from app.lib.service import ServiceError + +def create_app(**kwargs: Any) -> Litestar: + return Litestar( + route_handlers=[create_router()], + openapi_config=OpenAPIConfig(title="My API", version="1.0.0"), + # dependencies={"session": provide_transaction}, + plugins=[sqlalchemy_plugin.plugin], + exception_handlers={ + RepositoryException: exceptions.repository_exception_to_http_response, # type: ignore[dict-item] + ServiceError: exceptions.service_exception_to_http_response, # type: ignore[dict-item] + }, + debug=True, + **kwargs, + ) + + +app = create_app() diff --git a/migrations/0001-initial.sql b/migrations/0001-initial.sql new file mode 100644 index 0000000..afb47f4 --- /dev/null +++ b/migrations/0001-initial.sql @@ -0,0 +1,26 @@ +CREATE TABLE cities +( + id uuid DEFAULT gen_random_uuid() NOT NULL, + name varchar(50) NOT NULL, + postal_code varchar(10), + created_at timestamp WITH TIME ZONE DEFAULT NOW() NOT NULL, + modified_at timestamp WITH TIME ZONE DEFAULT NOW() NOT NULL +); + +CREATE UNIQUE INDEX cities_id_uindex + ON cities (id); + +CREATE TABLE users +( + id uuid DEFAULT gen_random_uuid() NOT NULL, + first_name varchar(50), + last_name varchar(50), + city_id uuid + CONSTRAINT users_cities_id_fk + REFERENCES cities (), + created_at timestamp WITH TIME ZONE DEFAULT NOW() NOT NULL, + modified_at timestamp WITH TIME ZONE DEFAULT NOW() NOT NULL +); + +CREATE UNIQUE INDEX users_id_uindex + ON users (id);