diff --git a/app/controllers/__init__.py b/app/controllers/__init__.py index 799de57..17c7ef8 100644 --- a/app/controllers/__init__.py +++ b/app/controllers/__init__.py @@ -1,8 +1,10 @@ from litestar import Router +from app.controllers.asset_item import AssetItemController from app.controllers.company import CompanyController from app.controllers.fiscal_payment_mapping import FiscalPaymentMappingController from app.controllers.machine import MachineController +from app.domain.asset_item import AssetItem from app.domain.company import Company from app.domain.fiscal_payment_mapping import FiscalPaymentMapping from app.domain.machine import Machine @@ -17,10 +19,12 @@ def create_router() -> Router: CompanyController, MachineController, FiscalPaymentMappingController, + AssetItemController, ], signature_namespace={ "Company": Company, "Machine": Machine, "FiscalPaymentMapping": FiscalPaymentMapping, + "AssetItem": AssetItem, }, ) diff --git a/app/controllers/asset_item.py b/app/controllers/asset_item.py new file mode 100644 index 0000000..fee6248 --- /dev/null +++ b/app/controllers/asset_item.py @@ -0,0 +1,83 @@ +from typing import TYPE_CHECKING, Optional + +from litestar import Controller, get, post +from litestar.contrib.repository.filters import LimitOffset, SearchFilter +from litestar.di import Provide +from sqlalchemy.ext.asyncio import AsyncSession + +from app.domain.company import Company +from app.domain.asset_item import ( + AssetItem, + AssetItemReadDTO, + AssetItemWriteDTO, + Repository, + Service, +) +from app.lib.responses import ObjectListResponse, ObjectResponse + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + +DETAIL_ROUTE = "/{asset_item_id:int}" + + +async def provides_service(db_session: AsyncSession, company_id: int) -> Service: + """Constructs repository and service objects for the request.""" + from app.controllers.company import provides_service + + company_service = provides_service(db_session) + company = await company_service.get(company_id) + return Service(Repository(session=db_session, company=company)) + + +async def get_company(db_session: AsyncSession, company_id: int) -> Company: + from app.controllers.company import provides_service + + company_service = provides_service(db_session) + return await company_service.get(company_id) + + +class AssetItemController(Controller): + dto = AssetItemWriteDTO + return_dto = AssetItemReadDTO + path = "/company/{company_id:int}/asset-items" + dependencies = { + "service": Provide(provides_service, sync_to_thread=False), + } + tags = ["AssetItems"] + + @post() + async def create_asset_item( + self, data: AssetItem, service: Service + ) -> AssetItem: + return await service.create(data) + + @get() + async def get_asset_items( + self, + service: Service, + search: Optional[str] = None, + ) -> ObjectListResponse[AssetItem]: + filters = [ + LimitOffset(limit=20, offset=0), + ] + + if search is not None: + filters.append( + SearchFilter( + field_name="caption", + value=search, + ), + ) + + content = await service.list(*filters) + return ObjectListResponse(content=content) + + @get(DETAIL_ROUTE) + async def get_asset_item( + self, service: Service, asset_item_id: int + ) -> ObjectResponse[AssetItem]: + content = await service.get(asset_item_id) + return ObjectResponse(content=content) + diff --git a/app/database.py b/app/database.py deleted file mode 100644 index b885fa2..0000000 --- a/app/database.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import AsyncGenerator - -from litestar.contrib.sqlalchemy.plugins import EngineConfig, SQLAlchemyAsyncConfig -from litestar.exceptions import ClientException -from litestar.status_codes import HTTP_409_CONFLICT -from sqlalchemy import URL -from sqlalchemy.exc import IntegrityError -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - - -async def provide_transaction( - db_session: AsyncSession, -) -> AsyncGenerator[AsyncSession, None]: - try: - async with db_session.begin(): - yield db_session - except IntegrityError as exc: - raise ClientException( - status_code=HTTP_409_CONFLICT, - detail=str(exc), - ) from exc - - -sessionmaker = async_sessionmaker(expire_on_commit=False) - -db_connection_string = URL.create( - drivername="postgresql+asyncpg", - username="televend", - password="televend", - host="localhost", - port=5433, - database="televend", -) -db_config = SQLAlchemyAsyncConfig( - connection_string=db_connection_string.render_as_string(hide_password=False), - engine_config=EngineConfig( - echo=True, - ), -) diff --git a/app/domain/asset_item.py b/app/domain/asset_item.py new file mode 100644 index 0000000..a525815 --- /dev/null +++ b/app/domain/asset_item.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Optional + +import sqlalchemy +from litestar.contrib.sqlalchemy.base import BigIntBase +from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO +from litestar.dto import DTOConfig, MsgspecDTO +from msgspec import Struct, Meta +from sqlalchemy.orm import Mapped, mapped_column + +from app.domain.enums import AssetItemProductLineEnum, AssetItemStatusEnum +from app.lib import service +from app.lib.company_owned_repository import CompanyOwnedRepository + + +class AssetItem(BigIntBase): + __tablename__ = "asset_items" # type: ignore[assignment] + + company_id: Mapped[int] + product_line: Mapped[AssetItemProductLineEnum] = mapped_column( + sqlalchemy.Enum(AssetItemProductLineEnum, name="asset_product_line_enum") + ) + brand_id: Mapped[int] + model_id: Mapped[int] + serial_number: Mapped[str] + external_id: Mapped[Optional[str]] + alive: Mapped[bool] + status: Mapped[AssetItemStatusEnum] = mapped_column( + sqlalchemy.Enum(AssetItemStatusEnum, name="asset_status_enum") + ) + created_by_id: Mapped[Optional[int]] + created_at: Mapped[datetime] + last_modified_by_id: Mapped[Optional[int]] + last_modified_at: Mapped[datetime] + is_fiscal_device: Mapped[bool] + warehouse_id: Mapped[Optional[int]] + + +PositiveInt = Annotated[int, Meta(gt=0)] + + +class AssetItemWriteStruct(Struct): + company_id: PositiveInt + product_line: AssetItemProductLineEnum + brand_id: PositiveInt + model_id: PositiveInt + serial_number: Annotated[str, Meta(max_length=10)] + external_id: Annotated[str, Meta(max_length=10)] | None + alive: bool + status: AssetItemStatusEnum + created_by_id: PositiveInt | None + created_at: datetime + last_modified_by_id: PositiveInt | None + last_modified_at: datetime | None + is_fiscal_device: bool + warehouse_id: PositiveInt | None + + +class XXAssetItemWriteDTO(MsgspecDTO[AssetItemWriteStruct]): + ... + + +class Repository(CompanyOwnedRepository[AssetItem]): + model_type = AssetItem + alive_flag = "alive" + company_id_field = "company_id" + + +class Service(service.Service[AssetItem]): + repository_type = Repository + + +write_config = DTOConfig(exclude={"id"}) +AssetItemWriteDTO = SQLAlchemyDTO[Annotated[AssetItem, write_config]] +# AssetItemWriteDTO = MsgspecDTO[AssetItemWriteStruct] +AssetItemReadDTO = SQLAlchemyDTO[AssetItem] + diff --git a/app/domain/enums.py b/app/domain/enums.py new file mode 100644 index 0000000..e441449 --- /dev/null +++ b/app/domain/enums.py @@ -0,0 +1,43 @@ +from enum import Enum + + +class FiscalModuleEnum(str, Enum): + CROATIA = "CROATIA" + HUNGARY = "HUNGARY" + ITALY = "ITALY" + MONTENEGRO = "MONTENEGRO" + ROMANIA = "ROMANIA" + RUSSIA = "RUSSIA" + SERBIA = "SERBIA" + + +class PaymentTypeEnum(str, Enum): + CA = "CA" + DA = "DA" + DB = "DB" + DC = "DC" + DD = "DD" + PA4 = "PA4" + NEG = "NEG" + PA3 = "PA3" + TA = "TA" + WLT = "WLT" + + +class AssetItemProductLineEnum(str, Enum): + VENDING_MACHINE = "VENDING_MACHINE" + HORECA_MACHINE = "HORECA_MACHINE" + PROFESSIONAL_COFFEE_MACHINE = "PROFESSIONAL_COFFEE_MACHINE" + TELEMETRY_DEVICE = "TELEMETRY_DEVICE" + COIN_CHANGER = "COIN_CHANGER" + CASHLESS_PAYMENT_DEVICE = "CASHLESS_PAYMENT_DEVICE" + BANKNOTE_ACCEPTOR = "BANKNOTE_ACCEPTOR" + BOILER = "BOILER" + STEAMER = "STEAMER" + + +class AssetItemStatusEnum(str, Enum): + AVAILABLE = "AVAILABLE" + IN_USE = "IN_USE" + REPARATION = "REPARATION" + DISPOSED = "DISPOSED" diff --git a/app/domain/fiscal_payment_mapping.py b/app/domain/fiscal_payment_mapping.py index 871182a..6eb8436 100644 --- a/app/domain/fiscal_payment_mapping.py +++ b/app/domain/fiscal_payment_mapping.py @@ -1,4 +1,3 @@ -from enum import Enum from typing import Annotated, Optional import sqlalchemy @@ -7,33 +6,11 @@ from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO from litestar.dto import DTOConfig from sqlalchemy.orm import Mapped, mapped_column +from app.domain.enums import FiscalModuleEnum, PaymentTypeEnum from app.lib import service from app.lib.filter_repository import FilterRepository -class FiscalModuleEnum(str, Enum): - CROATIA = "CROATIA" - HUNGARY = "HUNGARY" - ITALY = "ITALY" - MONTENEGRO = "MONTENEGRO" - ROMANIA = "ROMANIA" - RUSSIA = "RUSSIA" - SERBIA = "SERBIA" - - -class PaymentTypeEnum(str, Enum): - CA = "CA" - DA = "DA" - DB = "DB" - DC = "DC" - DD = "DD" - PA4 = "PA4" - NEG = "NEG" - PA3 = "PA3" - TA = "TA" - WLT = "WLT" - - class FiscalPaymentMapping(BigIntBase): __tablename__ = "fiscal_payment_mapping" # type: ignore[assignment] diff --git a/app/lib/sqlalchemy_plugin.py b/app/lib/sqlalchemy_plugin.py new file mode 100644 index 0000000..3a5a31e --- /dev/null +++ b/app/lib/sqlalchemy_plugin.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, cast, Literal +from uuid import UUID + +import msgspec +from litestar.contrib.sqlalchemy.plugins.init import SQLAlchemyInitPlugin +from litestar.contrib.sqlalchemy.plugins.init.config import SQLAlchemyAsyncConfig +from litestar.contrib.sqlalchemy.plugins.init.config.common import ( + SESSION_SCOPE_KEY, + SESSION_TERMINUS_ASGI_EVENTS, +) +from litestar.utils import delete_litestar_scope_state, get_litestar_scope_state +from sqlalchemy import event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import NullPool + + +@dataclass +class DatabaseSettings: + URL: str = "postgresql+asyncpg://televend:televend@localhost:5433/televend" + ECHO: bool = True + ECHO_POOL: bool | Literal["debug"] = False + POOL_DISABLE: bool = False + POOL_MAX_OVERFLOW: int = 10 + POOL_SIZE: int = 5 + POOL_TIMEOUT: int = 30 + DB_SESSION_DEPENDENCY_KEY: str = "db_session" + + +settings = DatabaseSettings() + + +if TYPE_CHECKING: + from typing import Any + + from litestar.types.asgi_types import Message, Scope + +__all__ = [ + "async_session_factory", + "config", + "engine", + "plugin", +] + + +def _default(val: Any) -> str: + if isinstance(val, UUID): + return str(val) + raise TypeError() + + +engine = create_async_engine( + settings.URL, + echo=settings.ECHO, + echo_pool=settings.ECHO_POOL, + json_serializer=msgspec.json.Encoder(enc_hook=_default), + max_overflow=settings.POOL_MAX_OVERFLOW, + pool_size=settings.POOL_SIZE, + pool_timeout=settings.POOL_TIMEOUT, + poolclass=NullPool if settings.POOL_DISABLE else None, +) +"""Configure via DatabaseSettings. + +Overrides default JSON +serializer to use `msgspec`. See [`create_async_engine()`][sqlalchemy.ext.asyncio.create_async_engine] +for detailed instructions. +""" +async_session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) +"""Database session factory. + +See [`async_sessionmaker()`][sqlalchemy.ext.asyncio.async_sessionmaker]. +""" + + +@event.listens_for(engine.sync_engine, "connect") +def _sqla_on_connect(dbapi_connection: Any, _: Any) -> Any: + """Using orjson for serialization of the json column values means that the + output is binary, not `str` like `json.dumps` would output. + + SQLAlchemy expects that the json serializer returns `str` and calls + `.encode()` on the value to turn it to bytes before writing to the + JSONB column. I'd need to either wrap `orjson.dumps` to return a + `str` so that SQLAlchemy could then convert it to binary, or do the + following, which changes the behaviour of the dialect to expect a + binary value from the serializer. + + See Also: + https://github.com/sqlalchemy/sqlalchemy/blob/14bfbadfdf9260a1c40f63b31641b27fe9de12a0/lib/sqlalchemy/dialects/postgresql/asyncpg.py#L934 + """ + + def encoder(bin_value: bytes) -> bytes: + # \x01 is the prefix for jsonb used by PostgreSQL. + # asyncpg requires it when format='binary' + return b"\x01" + bin_value + + def decoder(bin_value: bytes) -> Any: + # the byte is the \x01 prefix for jsonb used by PostgreSQL. + # asyncpg returns it when format='binary' + return msgspec.json.decode(bin_value[1:]) + + dbapi_connection.await_( + dbapi_connection.driver_connection.set_type_codec( + "jsonb", + encoder=encoder, + decoder=decoder, + schema="pg_catalog", + format="binary", + ) + ) + + +async def before_send_handler(message: Message, scope: Scope) -> None: + """Custom `before_send_handler` for SQLAlchemy plugin that inspects the + status of response and commits, or rolls back the database. + + Args: + message: ASGI message + _: + scope: ASGI scope + """ + session = cast("AsyncSession | None", get_litestar_scope_state(scope, SESSION_SCOPE_KEY)) + try: + if session is not None and message["type"] == "http.response.start": + if 200 <= message["status"] < 300: + await session.commit() + else: + await session.rollback() + finally: + if session is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS: + await session.close() + delete_litestar_scope_state(scope, SESSION_SCOPE_KEY) + + +config = SQLAlchemyAsyncConfig( + session_dependency_key=settings.DB_SESSION_DEPENDENCY_KEY, + engine_instance=engine, + session_maker=async_session_factory, + before_send_handler=before_send_handler, +) + +plugin = SQLAlchemyInitPlugin(config=config) diff --git a/main.py b/main.py index c6e1718..9f4b31a 100644 --- a/main.py +++ b/main.py @@ -4,20 +4,18 @@ from litestar import Litestar, get from litestar.contrib.repository.exceptions import ( RepositoryError as RepositoryException, ) -from litestar.contrib.sqlalchemy.plugins import SQLAlchemyPlugin from litestar.openapi import OpenAPIConfig from app.controllers import create_router -from app.database import db_config, provide_transaction -from app.lib import exceptions +from app.lib import exceptions, sqlalchemy_plugin from app.lib.service import ServiceError def create_app(**kwargs: Any) -> Litestar: return Litestar( route_handlers=[create_router()], openapi_config=OpenAPIConfig(title="My API", version="1.0.0"), - dependencies={"session": provide_transaction}, - plugins=[SQLAlchemyPlugin(db_config)], + # dependencies={"session": provide_transaction}, + plugins=[sqlalchemy_plugin.plugin], exception_handlers={ RepositoryException: exceptions.repository_exception_to_http_response, # type: ignore[dict-item] ServiceError: exceptions.service_exception_to_http_response, # type: ignore[dict-item]