From 7b16c2f6067bc1d862903fd9d1957a0c6dcf789a Mon Sep 17 00:00:00 2001 From: Eden Kirin Date: Sun, 27 Aug 2023 13:35:29 +0200 Subject: [PATCH] Done --- app/controllers/machine.py | 21 +++++-- app/domain/company.py | 19 ++---- app/domain/fiscal_payment_mapping.py | 23 +------ app/domain/machine.py | 22 ++----- app/lib/company_owned_repository.py | 89 ++++++++++++++++++++++++++++ app/lib/filter_repository.py | 38 ++++++++++++ 6 files changed, 156 insertions(+), 56 deletions(-) create mode 100644 app/lib/company_owned_repository.py create mode 100644 app/lib/filter_repository.py diff --git a/app/controllers/machine.py b/app/controllers/machine.py index 1dfe34a..0092213 100644 --- a/app/controllers/machine.py +++ b/app/controllers/machine.py @@ -5,6 +5,7 @@ from litestar.contrib.repository.filters import LimitOffset, SearchFilter from litestar.di import Provide from sqlalchemy.ext.asyncio import AsyncSession +from app.domain.company import Company from app.domain.machine import ( Machine, MachineReadDTO, @@ -12,6 +13,7 @@ from app.domain.machine import ( Repository, Service, ) +from app.lib.filters import ExactFilter from app.lib.responses import ObjectListResponse, ObjectResponse if TYPE_CHECKING: @@ -21,15 +23,24 @@ if TYPE_CHECKING: DETAIL_ROUTE = "/{machine_id:int}" -def provides_service(db_session: AsyncSession) -> Service: +async def provides_service(db_session: AsyncSession, company_id: int) -> Service: """Constructs repository and service objects for the request.""" - return Service(Repository(session=db_session)) + from app.controllers.company import provides_service + company_service = provides_service(db_session) + company = await company_service.get(company_id) + return Service(Repository(session=db_session, company=company)) + + +async def get_company(db_session: AsyncSession, company_id: int) -> Company: + from app.controllers.company import provides_service + company_service = provides_service(db_session) + return await company_service.get(company_id) class MachineController(Controller): dto = MachineWriteDTO return_dto = MachineReadDTO - path = "/machines" + path = "/company/{company_id:int}/machines" dependencies = { "service": Provide(provides_service, sync_to_thread=False), } @@ -37,7 +48,9 @@ class MachineController(Controller): @get() async def get_machines( - self, service: Service, search: Optional[str] = None + self, + service: Service, + search: Optional[str] = None, ) -> ObjectListResponse[Machine]: filters = [ LimitOffset(limit=20, offset=0), diff --git a/app/domain/company.py b/app/domain/company.py index 3c862bf..6b3fa26 100644 --- a/app/domain/company.py +++ b/app/domain/company.py @@ -1,15 +1,12 @@ from typing import Annotated -from litestar.contrib.repository import FilterTypes from litestar.contrib.sqlalchemy.base import BigIntBase from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO -from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository -from litestar.contrib.sqlalchemy.repository.types import SelectT from litestar.dto import DTOConfig -from sqlalchemy import true from sqlalchemy.orm import Mapped from app.lib import service +from app.lib.filter_repository import FilterRepository class Company(BigIntBase): @@ -25,23 +22,15 @@ class Company(BigIntBase): alive: Mapped[bool] -class Repository(SQLAlchemyAsyncRepository[Company]): +class Repository(FilterRepository[Company]): model_type = Company - - def _apply_filters( - self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT - ) -> SelectT: - statement = super()._apply_filters( - *filters, apply_pagination=apply_pagination, statement=statement - ) - statement = statement.where(self.model_type.alive == true()) - return statement + alive_flag = "alive" class Service(service.Service[Company]): repository_type = Repository -write_config = DTOConfig() +write_config = DTOConfig(exclude={"id"}) CompanyWriteDTO = SQLAlchemyDTO[Annotated[Company, write_config]] CompanyReadDTO = SQLAlchemyDTO[Company] diff --git a/app/domain/fiscal_payment_mapping.py b/app/domain/fiscal_payment_mapping.py index f192ffa..871182a 100644 --- a/app/domain/fiscal_payment_mapping.py +++ b/app/domain/fiscal_payment_mapping.py @@ -2,17 +2,13 @@ from enum import Enum from typing import Annotated, Optional import sqlalchemy -from litestar.contrib.repository import FilterTypes from litestar.contrib.sqlalchemy.base import BigIntBase from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO -from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository -from litestar.contrib.sqlalchemy.repository.types import SelectT from litestar.dto import DTOConfig -from sqlalchemy import ColumnElement from sqlalchemy.orm import Mapped, mapped_column from app.lib import service -from app.lib.filters import ExactFilter +from app.lib.filter_repository import FilterRepository class FiscalModuleEnum(str, Enum): @@ -52,24 +48,9 @@ class FiscalPaymentMapping(BigIntBase): payment_device_code: Mapped[Optional[int]] -class Repository(SQLAlchemyAsyncRepository[FiscalPaymentMapping]): +class Repository(FilterRepository[FiscalPaymentMapping]): model_type = FiscalPaymentMapping - def _apply_filters( - self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT - ) -> SelectT: - standard_filters = [] - for filter_ in filters: - if isinstance(filter_, ExactFilter): - field: ColumnElement = getattr(self.model_type, filter_.field_name) - statement = statement.where(field == filter_.value) - else: - standard_filters.append(filter_) - - return super()._apply_filters( - *standard_filters, apply_pagination=apply_pagination, statement=statement - ) - class Service(service.Service[FiscalPaymentMapping]): repository_type = Repository diff --git a/app/domain/machine.py b/app/domain/machine.py index 9c1d76d..98025f7 100644 --- a/app/domain/machine.py +++ b/app/domain/machine.py @@ -1,15 +1,12 @@ from typing import Annotated -from litestar.contrib.repository import FilterTypes from litestar.contrib.sqlalchemy.base import BigIntBase from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO -from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository -from litestar.contrib.sqlalchemy.repository.types import SelectT from litestar.dto import DTOConfig -from sqlalchemy import true from sqlalchemy.orm import Mapped from app.lib import service +from app.lib.company_owned_repository import CompanyOwnedRepository class Machine(BigIntBase): @@ -20,26 +17,19 @@ class Machine(BigIntBase): alive: Mapped[bool] deleted: Mapped[bool] external_id: Mapped[str] + owner_id: Mapped[int] -class Repository(SQLAlchemyAsyncRepository[Machine]): +class Repository(CompanyOwnedRepository[Machine]): model_type = Machine - - def _apply_filters( - self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT - ) -> SelectT: - statement = super()._apply_filters( - *filters, apply_pagination=apply_pagination, statement=statement - ) - statement = statement.where(self.model_type.alive == true()) - return statement + alive_flag = "alive" + company_id_field = "owner_id" class Service(service.Service[Machine]): repository_type = Repository -# write_config = DTOConfig(exclude={"created_at", "updated_at", "nationality"}) -write_config = DTOConfig() +write_config = DTOConfig(exclude={"id"}) MachineWriteDTO = SQLAlchemyDTO[Annotated[Machine, write_config]] MachineReadDTO = SQLAlchemyDTO[Machine] diff --git a/app/lib/company_owned_repository.py b/app/lib/company_owned_repository.py new file mode 100644 index 0000000..589c1e4 --- /dev/null +++ b/app/lib/company_owned_repository.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from litestar.contrib.repository import FilterTypes +from litestar.contrib.sqlalchemy.repository import ModelT +from litestar.contrib.sqlalchemy.repository.types import SelectT +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import InstrumentedAttribute + +from app.lib.filter_repository import FilterRepository + +if TYPE_CHECKING: + from app.domain.company import Company + + +class CompanyOwnedRepository(FilterRepository[ModelT]): + company_id_field: str | None = None + + def __init__( + self, + *, + company: Company, + statement: SelectT | None = None, + session: AsyncSession, + auto_expunge: bool = False, + auto_refresh: bool = True, + auto_commit: bool = False, + **kwargs: Any, + ) -> None: + self.company = company + + super().__init__( + statement=statement, + session=session, + auto_expunge=auto_expunge, + auto_refresh=auto_refresh, + auto_commit=auto_commit, + **kwargs, + ) + + def _get_company_filter_statement(self, statement: SelectT | None) -> SelectT: + if not self.company_id_field: + raise AttributeError( + f"company_id_field must be set for {self.__class__.__name__}" + ) + + column = self._get_column_by_name(self.company_id_field) + if column is None: + raise AttributeError( + f"column {self.company_id_field} not found in {self.__class__.__name__}" + ) + + stmt = statement if statement is not None else self.statement + return stmt.where(column == self.company.id) + + def _apply_filters( + self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT + ) -> SelectT: + if not self.company_id_field: + raise AttributeError( + f"company_id_field must be set for {self.__class__.__name__}" + ) + + statement = super()._apply_filters( + *filters, apply_pagination=apply_pagination, statement=statement + ) + + statement = statement.where( + self._get_column_by_name(self.company_id_field) == self.company.id + ) + + return statement + + async def get( # type: ignore[override] + self, + item_id: Any, + auto_expunge: bool | None = None, + statement: SelectT | None = None, + id_attribute: str | InstrumentedAttribute | None = None, + ) -> ModelT: + statement = self._get_company_filter_statement(statement) + + return await super().get( + item_id=item_id, + auto_expunge=auto_expunge, + statement=statement, + id_attribute=id_attribute, + ) diff --git a/app/lib/filter_repository.py b/app/lib/filter_repository.py new file mode 100644 index 0000000..52b01fe --- /dev/null +++ b/app/lib/filter_repository.py @@ -0,0 +1,38 @@ +from typing import Optional, cast + +from litestar.contrib.repository import FilterTypes +from litestar.contrib.sqlalchemy.repository import ModelT, SQLAlchemyAsyncRepository +from litestar.contrib.sqlalchemy.repository.types import SelectT +from sqlalchemy import true, Column + +from app.lib.filters import ExactFilter + + +class FilterRepository(SQLAlchemyAsyncRepository[ModelT]): + alive_flag: Optional[str] = None + + def _get_column_by_name(self, name: str) -> Column | None: + return cast(Column, getattr(self.model_type, name, None)) + + def _apply_filters( + self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT + ) -> SelectT: + standard_filters = [] + for filter_ in filters: + if isinstance(filter_, ExactFilter): + statement = statement.where( + self._get_column_by_name(filter_.field_name) == filter_.value + ) + else: + standard_filters.append(filter_) + + statement = super()._apply_filters( + *standard_filters, apply_pagination=apply_pagination, statement=statement + ) + + if self.alive_flag: + statement = statement.where( + self._get_column_by_name(self.alive_flag) == true() + ) + + return statement