Basic users and cities api
This commit is contained in:
14
Makefile
Normal file
14
Makefile
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
ifeq ($(VIRTUAL_ENV),)
|
||||||
|
RUN_IN_ENV=poetry run
|
||||||
|
else
|
||||||
|
RUN_IN_ENV=
|
||||||
|
endif
|
||||||
|
|
||||||
|
run:
|
||||||
|
@ $(RUN_IN_ENV) uvicorn \
|
||||||
|
main:app \
|
||||||
|
--reload \
|
||||||
|
--reload-dir=app
|
||||||
|
|
||||||
|
shell:
|
||||||
|
@ $(RUN_IN_ENV) python manage.py shell
|
||||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
20
app/controllers/__init__.py
Normal file
20
app/controllers/__init__.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from litestar import Router
|
||||||
|
|
||||||
|
__all__ = ["create_router"]
|
||||||
|
|
||||||
|
from app.controllers.city import CityController
|
||||||
|
from app.controllers.user import UserController
|
||||||
|
from app.domain.city import City
|
||||||
|
|
||||||
|
|
||||||
|
def create_router() -> Router:
|
||||||
|
return Router(
|
||||||
|
path="/v1",
|
||||||
|
route_handlers=[
|
||||||
|
CityController,
|
||||||
|
UserController,
|
||||||
|
],
|
||||||
|
signature_namespace={
|
||||||
|
"City": City,
|
||||||
|
},
|
||||||
|
)
|
||||||
63
app/controllers/city.py
Normal file
63
app/controllers/city.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from litestar import Controller, get
|
||||||
|
from litestar.contrib.repository.filters import LimitOffset, SearchFilter
|
||||||
|
from litestar.di import Provide
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.domain.city import (
|
||||||
|
City,
|
||||||
|
CityReadDTO,
|
||||||
|
CityWriteDTO,
|
||||||
|
Repository,
|
||||||
|
Service,
|
||||||
|
)
|
||||||
|
from app.lib.responses import ObjectListResponse, ObjectResponse
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
DETAIL_ROUTE = "/{city_id:uuid}"
|
||||||
|
|
||||||
|
|
||||||
|
def provides_service(db_session: AsyncSession) -> Service:
|
||||||
|
"""Constructs repository and service objects for the request."""
|
||||||
|
return Service(Repository(session=db_session))
|
||||||
|
|
||||||
|
|
||||||
|
class CityController(Controller):
|
||||||
|
dto = CityWriteDTO
|
||||||
|
return_dto = CityReadDTO
|
||||||
|
path = "/cities"
|
||||||
|
dependencies = {
|
||||||
|
"service": Provide(provides_service, sync_to_thread=False),
|
||||||
|
}
|
||||||
|
tags = ["Cities"]
|
||||||
|
|
||||||
|
@get()
|
||||||
|
async def get_cities(
|
||||||
|
self, service: Service, search: Optional[str] = None
|
||||||
|
) -> ObjectListResponse[City]:
|
||||||
|
filters = [
|
||||||
|
LimitOffset(limit=20, offset=0),
|
||||||
|
]
|
||||||
|
|
||||||
|
if search is not None:
|
||||||
|
filters.append(
|
||||||
|
SearchFilter(
|
||||||
|
field_name="caption",
|
||||||
|
value=search,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
content = await service.list(*filters)
|
||||||
|
return ObjectListResponse(content=content)
|
||||||
|
|
||||||
|
@get(DETAIL_ROUTE)
|
||||||
|
async def get_city(
|
||||||
|
self, service: Service, city_id: UUID
|
||||||
|
) -> ObjectResponse[City]:
|
||||||
|
content = await service.get(city_id)
|
||||||
|
return ObjectResponse(content=content)
|
||||||
63
app/controllers/user.py
Normal file
63
app/controllers/user.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from litestar import Controller, get
|
||||||
|
from litestar.contrib.repository.filters import LimitOffset, SearchFilter
|
||||||
|
from litestar.di import Provide
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.domain.user import (
|
||||||
|
User,
|
||||||
|
UserReadDTO,
|
||||||
|
UserWriteDTO,
|
||||||
|
Repository,
|
||||||
|
Service,
|
||||||
|
)
|
||||||
|
from app.lib.responses import ObjectListResponse, ObjectResponse
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
DETAIL_ROUTE = "/{user_id:uuid}"
|
||||||
|
|
||||||
|
|
||||||
|
def provides_service(db_session: AsyncSession) -> Service:
|
||||||
|
"""Constructs repository and service objects for the request."""
|
||||||
|
return Service(Repository(session=db_session))
|
||||||
|
|
||||||
|
|
||||||
|
class UserController(Controller):
|
||||||
|
dto = UserWriteDTO
|
||||||
|
return_dto = UserReadDTO
|
||||||
|
path = "/users"
|
||||||
|
dependencies = {
|
||||||
|
"service": Provide(provides_service, sync_to_thread=False),
|
||||||
|
}
|
||||||
|
tags = ["Users"]
|
||||||
|
|
||||||
|
@get()
|
||||||
|
async def get_users(
|
||||||
|
self, service: Service, search: Optional[str] = None
|
||||||
|
) -> ObjectListResponse[User]:
|
||||||
|
filters = [
|
||||||
|
LimitOffset(limit=20, offset=0),
|
||||||
|
]
|
||||||
|
|
||||||
|
if search is not None:
|
||||||
|
filters.append(
|
||||||
|
SearchFilter(
|
||||||
|
field_name="caption",
|
||||||
|
value=search,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
content = await service.list(*filters)
|
||||||
|
return ObjectListResponse(content=content)
|
||||||
|
|
||||||
|
@get(DETAIL_ROUTE)
|
||||||
|
async def get_user(
|
||||||
|
self, service: Service, user_id: UUID
|
||||||
|
) -> ObjectResponse[User]:
|
||||||
|
content = await service.get(user_id)
|
||||||
|
return ObjectResponse(content=content)
|
||||||
0
app/domain/__init__.py
Normal file
0
app/domain/__init__.py
Normal file
32
app/domain/city.py
Normal file
32
app/domain/city.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from litestar.contrib.sqlalchemy.base import UUIDBase
|
||||||
|
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
||||||
|
from litestar.dto import DTOConfig
|
||||||
|
from sqlalchemy.orm import Mapped
|
||||||
|
|
||||||
|
from app.lib import service
|
||||||
|
from app.lib.filter_repository import FilterRepository
|
||||||
|
|
||||||
|
|
||||||
|
class City(UUIDBase):
|
||||||
|
__tablename__ = "cities" # type: ignore[assignment]
|
||||||
|
|
||||||
|
name: Mapped[str]
|
||||||
|
postal_code: Mapped[str]
|
||||||
|
created_at: Mapped[datetime]
|
||||||
|
modified_at: Mapped[datetime]
|
||||||
|
|
||||||
|
|
||||||
|
class Repository(FilterRepository[City]):
|
||||||
|
model_type = City
|
||||||
|
|
||||||
|
|
||||||
|
class Service(service.Service[City]):
|
||||||
|
repository_type = Repository
|
||||||
|
|
||||||
|
|
||||||
|
write_config = DTOConfig(exclude={"id"})
|
||||||
|
CityWriteDTO = SQLAlchemyDTO[Annotated[City, write_config]]
|
||||||
|
CityReadDTO = SQLAlchemyDTO[City]
|
||||||
0
app/domain/enums.py
Normal file
0
app/domain/enums.py
Normal file
32
app/domain/user.py
Normal file
32
app/domain/user.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from litestar.contrib.sqlalchemy.base import UUIDBase
|
||||||
|
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
||||||
|
from litestar.dto import DTOConfig
|
||||||
|
from sqlalchemy.orm import Mapped
|
||||||
|
|
||||||
|
from app.lib import service
|
||||||
|
from app.lib.filter_repository import FilterRepository
|
||||||
|
|
||||||
|
|
||||||
|
class User(UUIDBase):
|
||||||
|
__tablename__ = "users" # type: ignore[assignment]
|
||||||
|
|
||||||
|
first_name: Mapped[str]
|
||||||
|
last_name: Mapped[str]
|
||||||
|
created_at: Mapped[datetime]
|
||||||
|
modified_at: Mapped[datetime]
|
||||||
|
|
||||||
|
|
||||||
|
class Repository(FilterRepository[User]):
|
||||||
|
model_type = User
|
||||||
|
|
||||||
|
|
||||||
|
class Service(service.Service[User]):
|
||||||
|
repository_type = Repository
|
||||||
|
|
||||||
|
|
||||||
|
write_config = DTOConfig(exclude={"id"})
|
||||||
|
UserWriteDTO = SQLAlchemyDTO[Annotated[User, write_config]]
|
||||||
|
UserReadDTO = SQLAlchemyDTO[User]
|
||||||
0
app/dto/__init__.py
Normal file
0
app/dto/__init__.py
Normal file
0
app/lib/__init__.py
Normal file
0
app/lib/__init__.py
Normal file
150
app/lib/dependencies.py
Normal file
150
app/lib/dependencies.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from litestar.contrib.repository.filters import BeforeAfter, CollectionFilter, FilterTypes, LimitOffset
|
||||||
|
from litestar.di import Provide
|
||||||
|
from litestar.params import Dependency, Parameter
|
||||||
|
|
||||||
|
DEFAULT_PAGINATION_LIMIT = 20
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"create_collection_dependencies",
|
||||||
|
"provide_created_filter",
|
||||||
|
"provide_filter_dependencies",
|
||||||
|
"provide_id_filter",
|
||||||
|
"provide_limit_offset_pagination",
|
||||||
|
"provide_updated_filter",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
DTorNone = datetime | None
|
||||||
|
|
||||||
|
CREATED_FILTER_DEPENDENCY_KEY = "created_filter"
|
||||||
|
FILTERS_DEPENDENCY_KEY = "filters"
|
||||||
|
ID_FILTER_DEPENDENCY_KEY = "id_filter"
|
||||||
|
LIMIT_OFFSET_DEPENDENCY_KEY = "limit_offset"
|
||||||
|
UPDATED_FILTER_DEPENDENCY_KEY = "updated_filter"
|
||||||
|
|
||||||
|
|
||||||
|
def provide_id_filter(
|
||||||
|
ids: list[UUID] | None = Parameter(query="ids", default=None, required=False)
|
||||||
|
) -> CollectionFilter[UUID]:
|
||||||
|
"""Return type consumed by ``Repository.filter_in_collection()``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ids : list[UUID] | None
|
||||||
|
Parsed out of comma separated list of values in query params.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
-------
|
||||||
|
CollectionFilter[UUID]
|
||||||
|
"""
|
||||||
|
return CollectionFilter(field_name="id", values=ids or [])
|
||||||
|
|
||||||
|
|
||||||
|
def provide_created_filter(
|
||||||
|
before: DTorNone = Parameter(query="created-before", default=None, required=False),
|
||||||
|
after: DTorNone = Parameter(query="created-after", default=None, required=False),
|
||||||
|
) -> BeforeAfter:
|
||||||
|
"""Return type consumed by `Repository.filter_on_datetime_field()`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
before : datetime | None
|
||||||
|
Filter for records updated before this date/time.
|
||||||
|
after : datetime | None
|
||||||
|
Filter for records updated after this date/time.
|
||||||
|
"""
|
||||||
|
return BeforeAfter("created_at", before, after)
|
||||||
|
|
||||||
|
|
||||||
|
def provide_updated_filter(
|
||||||
|
before: DTorNone = Parameter(query="updated-before", default=None, required=False),
|
||||||
|
after: DTorNone = Parameter(query="updated-after", default=None, required=False),
|
||||||
|
) -> BeforeAfter:
|
||||||
|
"""Return type consumed by `Repository.filter_on_datetime_field()`.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
before : datetime | None
|
||||||
|
Filter for records updated before this date/time.
|
||||||
|
after : datetime | None
|
||||||
|
Filter for records updated after this date/time.
|
||||||
|
"""
|
||||||
|
return BeforeAfter("updated_at", before, after)
|
||||||
|
|
||||||
|
|
||||||
|
def provide_limit_offset_pagination(
|
||||||
|
page: int = Parameter(ge=1, default=1, required=False),
|
||||||
|
page_size: int = Parameter(
|
||||||
|
query="page-size",
|
||||||
|
ge=1,
|
||||||
|
default=DEFAULT_PAGINATION_LIMIT,
|
||||||
|
required=False,
|
||||||
|
),
|
||||||
|
) -> LimitOffset:
|
||||||
|
"""Return type consumed by `Repository.apply_limit_offset_pagination()`.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
page : int
|
||||||
|
LIMIT to apply to select.
|
||||||
|
page_size : int
|
||||||
|
OFFSET to apply to select.
|
||||||
|
"""
|
||||||
|
return LimitOffset(page_size, page_size * (page - 1))
|
||||||
|
|
||||||
|
|
||||||
|
def provide_filter_dependencies(
|
||||||
|
created_filter: BeforeAfter = Dependency(skip_validation=True),
|
||||||
|
updated_filter: BeforeAfter = Dependency(skip_validation=True),
|
||||||
|
id_filter: CollectionFilter = Dependency(skip_validation=True),
|
||||||
|
limit_offset: LimitOffset = Dependency(skip_validation=True),
|
||||||
|
) -> list[FilterTypes]:
|
||||||
|
"""Common collection route filtering dependencies. Add all filters to any
|
||||||
|
route by including this function as a dependency, e.g:
|
||||||
|
|
||||||
|
@get
|
||||||
|
def get_collection_handler(filters: Filters) -> ...:
|
||||||
|
...
|
||||||
|
The dependency is provided at the application layer, so only need to inject the dependency where
|
||||||
|
necessary.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
id_filter : repository.CollectionFilter
|
||||||
|
Filter for scoping query to limited set of identities.
|
||||||
|
created_filter : repository.BeforeAfter
|
||||||
|
Filter for scoping query to instance creation date/time.
|
||||||
|
updated_filter : repository.BeforeAfter
|
||||||
|
Filter for scoping query to instance update date/time.
|
||||||
|
limit_offset : repository.LimitOffset
|
||||||
|
Filter for query pagination.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
-------
|
||||||
|
list[FilterTypes]
|
||||||
|
List of filters parsed from connection.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
created_filter,
|
||||||
|
id_filter,
|
||||||
|
limit_offset,
|
||||||
|
updated_filter,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_collection_dependencies() -> dict[str, Provide]:
|
||||||
|
"""Creates a dictionary of provides for pagination endpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
-------
|
||||||
|
dict[str, Provide]
|
||||||
|
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
LIMIT_OFFSET_DEPENDENCY_KEY: Provide(provide_limit_offset_pagination, sync_to_thread=False),
|
||||||
|
UPDATED_FILTER_DEPENDENCY_KEY: Provide(provide_updated_filter, sync_to_thread=False),
|
||||||
|
CREATED_FILTER_DEPENDENCY_KEY: Provide(provide_created_filter, sync_to_thread=False),
|
||||||
|
ID_FILTER_DEPENDENCY_KEY: Provide(provide_id_filter, sync_to_thread=False),
|
||||||
|
FILTERS_DEPENDENCY_KEY: Provide(provide_filter_dependencies, sync_to_thread=False),
|
||||||
|
}
|
||||||
68
app/lib/exceptions.py
Normal file
68
app/lib/exceptions.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from litestar.contrib.repository.exceptions import (
|
||||||
|
ConflictError as RepositoryConflictException,
|
||||||
|
)
|
||||||
|
from litestar.contrib.repository.exceptions import (
|
||||||
|
NotFoundError as RepositoryNotFoundException,
|
||||||
|
)
|
||||||
|
from litestar.contrib.repository.exceptions import (
|
||||||
|
RepositoryError as RepositoryException,
|
||||||
|
)
|
||||||
|
from litestar.exceptions import (
|
||||||
|
HTTPException,
|
||||||
|
InternalServerException,
|
||||||
|
NotFoundException,
|
||||||
|
)
|
||||||
|
from litestar.middleware.exceptions.middleware import create_exception_response
|
||||||
|
|
||||||
|
from .service import ServiceError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litestar.connection import Request
|
||||||
|
from litestar.response import Response
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"repository_exception_to_http_response",
|
||||||
|
"service_exception_to_http_response",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConflictException(HTTPException):
|
||||||
|
status_code = 409
|
||||||
|
|
||||||
|
|
||||||
|
def repository_exception_to_http_response(request: "Request", exc: RepositoryException) -> "Response":
|
||||||
|
"""Transform repository exceptions to HTTP exceptions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_: The request that experienced the exception.
|
||||||
|
exc: Exception raised during handling of the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exception response appropriate to the type of original exception.
|
||||||
|
"""
|
||||||
|
http_exc: type[HTTPException]
|
||||||
|
if isinstance(exc, RepositoryNotFoundException):
|
||||||
|
http_exc = NotFoundException
|
||||||
|
elif isinstance(exc, RepositoryConflictException):
|
||||||
|
http_exc = ConflictException
|
||||||
|
else:
|
||||||
|
http_exc = InternalServerException
|
||||||
|
return create_exception_response(request, exc=http_exc())
|
||||||
|
|
||||||
|
|
||||||
|
def service_exception_to_http_response(request: "Request", exc: ServiceError) -> "Response":
|
||||||
|
"""Transform service exceptions to HTTP exceptions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_: The request that experienced the exception.
|
||||||
|
exc: Exception raised during handling of the request.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Exception response appropriate to the type of original exception.
|
||||||
|
"""
|
||||||
|
return create_exception_response(request, InternalServerException())
|
||||||
38
app/lib/filter_repository.py
Normal file
38
app/lib/filter_repository.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from litestar.contrib.repository import FilterTypes
|
||||||
|
from litestar.contrib.sqlalchemy.repository import ModelT, SQLAlchemyAsyncRepository
|
||||||
|
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
||||||
|
from sqlalchemy import true, Column
|
||||||
|
|
||||||
|
from app.lib.filters import ExactFilter
|
||||||
|
|
||||||
|
|
||||||
|
class FilterRepository(SQLAlchemyAsyncRepository[ModelT]):
|
||||||
|
alive_flag: Optional[str] = None
|
||||||
|
|
||||||
|
def _get_column_by_name(self, name: str) -> Column | None:
|
||||||
|
return cast(Column, getattr(self.model_type, name, None))
|
||||||
|
|
||||||
|
def _apply_filters(
|
||||||
|
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
||||||
|
) -> SelectT:
|
||||||
|
standard_filters = []
|
||||||
|
for filter_ in filters:
|
||||||
|
if isinstance(filter_, ExactFilter):
|
||||||
|
statement = statement.where(
|
||||||
|
self._get_column_by_name(filter_.field_name) == filter_.value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
standard_filters.append(filter_)
|
||||||
|
|
||||||
|
statement = super()._apply_filters(
|
||||||
|
*standard_filters, apply_pagination=apply_pagination, statement=statement
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.alive_flag:
|
||||||
|
statement = statement.where(
|
||||||
|
self._get_column_by_name(self.alive_flag) == true()
|
||||||
|
)
|
||||||
|
|
||||||
|
return statement
|
||||||
10
app/lib/filters.py
Normal file
10
app/lib/filters.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExactFilter(Generic[T]):
|
||||||
|
field_name: str
|
||||||
|
value: T
|
||||||
26
app/lib/responses.py
Normal file
26
app/lib/responses.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ObjectResponse(Generic[T]):
|
||||||
|
content: T
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ObjectListResponse(Generic[T]):
|
||||||
|
content: list[T]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PaginationMeta:
|
||||||
|
page: int
|
||||||
|
page_count: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PaginatedObjectListResponse(Generic[T]):
|
||||||
|
content: list[T]
|
||||||
|
meta: PaginationMeta
|
||||||
95
app/lib/service.py
Normal file
95
app/lib/service.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Generic
|
||||||
|
|
||||||
|
from litestar.contrib.sqlalchemy.repository import ModelT
|
||||||
|
|
||||||
|
__all__ = ["Service", "ServiceError"]
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from litestar.contrib.repository import AbstractAsyncRepository, FilterTypes
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceError(Exception):
|
||||||
|
"""Base class for `Service` related exceptions."""
|
||||||
|
|
||||||
|
|
||||||
|
class Service(Generic[ModelT]):
|
||||||
|
def __init__(self, repository: AbstractAsyncRepository[ModelT]) -> None:
|
||||||
|
"""Generic Service object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository: Instance conforming to `AbstractRepository` interface.
|
||||||
|
"""
|
||||||
|
self.repository = repository
|
||||||
|
|
||||||
|
async def create(self, data: ModelT) -> ModelT:
|
||||||
|
"""Wraps repository instance creation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Representation to be created.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Representation of created instance.
|
||||||
|
"""
|
||||||
|
return await self.repository.add(data)
|
||||||
|
|
||||||
|
async def list(self, *filters: FilterTypes, **kwargs: Any) -> list[ModelT]:
|
||||||
|
"""Wraps repository scalars operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*filters: Collection route filters.
|
||||||
|
**kwargs: Keyword arguments for attribute based filtering.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of instances retrieved from the repository.
|
||||||
|
"""
|
||||||
|
return await self.repository.list(*filters, **kwargs)
|
||||||
|
|
||||||
|
async def update(self, id_: Any, data: ModelT) -> ModelT:
|
||||||
|
"""Wraps repository update operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id_: Identifier of item to be updated.
|
||||||
|
data: Representation to be updated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated representation.
|
||||||
|
"""
|
||||||
|
return await self.repository.update(data)
|
||||||
|
|
||||||
|
async def upsert(self, id_: Any, data: ModelT) -> ModelT:
|
||||||
|
"""Wraps repository upsert operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id_: Identifier of the object for upsert.
|
||||||
|
data: Representation for upsert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
-------
|
||||||
|
Updated or created representation.
|
||||||
|
"""
|
||||||
|
return await self.repository.upsert(data)
|
||||||
|
|
||||||
|
async def get(self, id_: Any) -> ModelT:
|
||||||
|
"""Wraps repository scalar operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id_: Identifier of instance to be retrieved.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Representation of instance with identifier `id_`.
|
||||||
|
"""
|
||||||
|
return await self.repository.get(id_)
|
||||||
|
|
||||||
|
async def delete(self, id_: Any) -> ModelT:
|
||||||
|
"""Wraps repository delete operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
id_: Identifier of instance to be deleted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Representation of the deleted instance.
|
||||||
|
"""
|
||||||
|
return await self.repository.delete(id_)
|
||||||
150
app/lib/sqlalchemy_plugin.py
Normal file
150
app/lib/sqlalchemy_plugin.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, cast, Literal
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import msgspec
|
||||||
|
from litestar.contrib.sqlalchemy.plugins.init import SQLAlchemyInitPlugin
|
||||||
|
from litestar.contrib.sqlalchemy.plugins.init.config import SQLAlchemyAsyncConfig
|
||||||
|
from litestar.contrib.sqlalchemy.plugins.init.config.common import (
|
||||||
|
SESSION_SCOPE_KEY,
|
||||||
|
SESSION_TERMINUS_ASGI_EVENTS,
|
||||||
|
)
|
||||||
|
from litestar.utils import delete_litestar_scope_state, get_litestar_scope_state
|
||||||
|
from sqlalchemy import event
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.pool import NullPool
|
||||||
|
|
||||||
|
|
||||||
|
DB_HOST = "localhost"
|
||||||
|
DB_PORT = 5432
|
||||||
|
DB_NAME = "addressbook"
|
||||||
|
DB_USER = "addressbook"
|
||||||
|
DB_PASSWORD = "addressbook"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatabaseSettings:
|
||||||
|
URL: str = f"postgresql+asyncpg://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
|
||||||
|
ECHO: bool = True
|
||||||
|
ECHO_POOL: bool | Literal["debug"] = False
|
||||||
|
POOL_DISABLE: bool = False
|
||||||
|
POOL_MAX_OVERFLOW: int = 10
|
||||||
|
POOL_SIZE: int = 5
|
||||||
|
POOL_TIMEOUT: int = 30
|
||||||
|
DB_SESSION_DEPENDENCY_KEY: str = "db_session"
|
||||||
|
|
||||||
|
|
||||||
|
settings = DatabaseSettings()
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from litestar.types.asgi_types import Message, Scope
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"async_session_factory",
|
||||||
|
"config",
|
||||||
|
"engine",
|
||||||
|
"plugin",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _default(val: Any) -> str:
|
||||||
|
if isinstance(val, UUID):
|
||||||
|
return str(val)
|
||||||
|
raise TypeError()
|
||||||
|
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.URL,
|
||||||
|
echo=settings.ECHO,
|
||||||
|
echo_pool=settings.ECHO_POOL,
|
||||||
|
json_serializer=msgspec.json.Encoder(enc_hook=_default),
|
||||||
|
max_overflow=settings.POOL_MAX_OVERFLOW,
|
||||||
|
pool_size=settings.POOL_SIZE,
|
||||||
|
pool_timeout=settings.POOL_TIMEOUT,
|
||||||
|
poolclass=NullPool if settings.POOL_DISABLE else None,
|
||||||
|
)
|
||||||
|
"""Configure via DatabaseSettings.
|
||||||
|
|
||||||
|
Overrides default JSON
|
||||||
|
serializer to use `msgspec`. See [`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine]
|
||||||
|
for detailed instructions.
|
||||||
|
"""
|
||||||
|
async_session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
|
||||||
|
"""Database session factory.
|
||||||
|
|
||||||
|
See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker].
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@event.listens_for(engine.sync_engine, "connect")
|
||||||
|
def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any:
|
||||||
|
"""Using orjson for serialization of the json column values means that the
|
||||||
|
output is binary, not `str` like `json.dumps` would output.
|
||||||
|
|
||||||
|
SQLAlchemy expects that the json serializer returns `str` and calls
|
||||||
|
`.encode()` on the value to turn it to bytes before writing to the
|
||||||
|
JSONB column. I'd need to either wrap `orjson.dumps` to return a
|
||||||
|
`str` so that SQLAlchemy could then convert it to binary, or do the
|
||||||
|
following, which changes the behaviour of the dialect to expect a
|
||||||
|
binary value from the serializer.
|
||||||
|
|
||||||
|
See Also:
|
||||||
|
https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934
|
||||||
|
"""
|
||||||
|
|
||||||
|
def encoder(bin_value: bytes) -> bytes:
|
||||||
|
# \x01 is the prefix for jsonb used by PostgreSQL.
|
||||||
|
# asyncpg requires it when format='binary'
|
||||||
|
return b"\x01" + bin_value
|
||||||
|
|
||||||
|
def decoder(bin_value: bytes) -> Any:
|
||||||
|
# the byte is the \x01 prefix for jsonb used by PostgreSQL.
|
||||||
|
# asyncpg returns it when format='binary'
|
||||||
|
return msgspec.json.decode(bin_value[1:])
|
||||||
|
|
||||||
|
dbapi_connection.await_(
|
||||||
|
dbapi_connection.driver_connection.set_type_codec(
|
||||||
|
"jsonb",
|
||||||
|
encoder=encoder,
|
||||||
|
decoder=decoder,
|
||||||
|
schema="pg_catalog",
|
||||||
|
format="binary",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def before_send_handler(message: Message, scope: Scope) -> None:
|
||||||
|
"""Custom `before_send_handler` for SQLAlchemy plugin that inspects the
|
||||||
|
status of response and commits, or rolls back the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: ASGI message
|
||||||
|
_:
|
||||||
|
scope: ASGI scope
|
||||||
|
"""
|
||||||
|
session = cast("AsyncSession | None", get_litestar_scope_state(scope, SESSION_SCOPE_KEY))
|
||||||
|
try:
|
||||||
|
if session is not None and message["type"] == "http.response.start":
|
||||||
|
if 200 <= message["status"] < 300:
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
await session.rollback()
|
||||||
|
finally:
|
||||||
|
if session is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
|
||||||
|
await session.close()
|
||||||
|
delete_litestar_scope_state(scope, SESSION_SCOPE_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
config = SQLAlchemyAsyncConfig(
|
||||||
|
session_dependency_key=settings.DB_SESSION_DEPENDENCY_KEY,
|
||||||
|
engine_instance=engine,
|
||||||
|
session_maker=async_session_factory,
|
||||||
|
before_send_handler=before_send_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin = SQLAlchemyInitPlugin(config=config)
|
||||||
28
main.py
Normal file
28
main.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from litestar import Litestar
|
||||||
|
from litestar.contrib.repository.exceptions import (
|
||||||
|
RepositoryError as RepositoryException,
|
||||||
|
)
|
||||||
|
from litestar.openapi import OpenAPIConfig
|
||||||
|
|
||||||
|
from app.controllers import create_router
|
||||||
|
from app.lib import exceptions, sqlalchemy_plugin
|
||||||
|
from app.lib.service import ServiceError
|
||||||
|
|
||||||
|
def create_app(**kwargs: Any) -> Litestar:
|
||||||
|
return Litestar(
|
||||||
|
route_handlers=[create_router()],
|
||||||
|
openapi_config=OpenAPIConfig(title="My API", version="1.0.0"),
|
||||||
|
# dependencies={"session": provide_transaction},
|
||||||
|
plugins=[sqlalchemy_plugin.plugin],
|
||||||
|
exception_handlers={
|
||||||
|
RepositoryException: exceptions.repository_exception_to_http_response, # type: ignore[dict-item]
|
||||||
|
ServiceError: exceptions.service_exception_to_http_response, # type: ignore[dict-item]
|
||||||
|
},
|
||||||
|
debug=True,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
26
migrations/0001-initial.sql
Normal file
26
migrations/0001-initial.sql
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
CREATE TABLE cities
|
||||||
|
(
|
||||||
|
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||||
|
name varchar(50) NOT NULL,
|
||||||
|
postal_code varchar(10),
|
||||||
|
created_at timestamp WITH TIME ZONE DEFAULT NOW() NOT NULL,
|
||||||
|
modified_at timestamp WITH TIME ZONE DEFAULT NOW() NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX cities_id_uindex
|
||||||
|
ON cities (id);
|
||||||
|
|
||||||
|
CREATE TABLE users
|
||||||
|
(
|
||||||
|
id uuid DEFAULT gen_random_uuid() NOT NULL,
|
||||||
|
first_name varchar(50),
|
||||||
|
last_name varchar(50),
|
||||||
|
city_id uuid
|
||||||
|
CONSTRAINT users_cities_id_fk
|
||||||
|
REFERENCES cities (),
|
||||||
|
created_at timestamp WITH TIME ZONE DEFAULT NOW() NOT NULL,
|
||||||
|
modified_at timestamp WITH TIME ZONE DEFAULT NOW() NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX users_id_uindex
|
||||||
|
ON users (id);
|
||||||
Reference in New Issue
Block a user