This commit is contained in:
Eden Kirin
2023-08-27 13:35:29 +02:00
parent 03c8aaa312
commit 7b16c2f606
6 changed files with 156 additions and 56 deletions

View File

@ -5,6 +5,7 @@ from litestar.contrib.repository.filters import LimitOffset, SearchFilter
from litestar.di import Provide from litestar.di import Provide
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.domain.company import Company
from app.domain.machine import ( from app.domain.machine import (
Machine, Machine,
MachineReadDTO, MachineReadDTO,
@ -12,6 +13,7 @@ from app.domain.machine import (
Repository, Repository,
Service, Service,
) )
from app.lib.filters import ExactFilter
from app.lib.responses import ObjectListResponse, ObjectResponse from app.lib.responses import ObjectListResponse, ObjectResponse
if TYPE_CHECKING: if TYPE_CHECKING:
@ -21,15 +23,24 @@ if TYPE_CHECKING:
DETAIL_ROUTE = "/{machine_id:int}" DETAIL_ROUTE = "/{machine_id:int}"
def provides_service(db_session: AsyncSession) -> Service: async def provides_service(db_session: AsyncSession, company_id: int) -> Service:
"""Constructs repository and service objects for the request.""" """Constructs repository and service objects for the request."""
return Service(Repository(session=db_session)) from app.controllers.company import provides_service
company_service = provides_service(db_session)
company = await company_service.get(company_id)
return Service(Repository(session=db_session, company=company))
async def get_company(db_session: AsyncSession, company_id: int) -> Company:
from app.controllers.company import provides_service
company_service = provides_service(db_session)
return await company_service.get(company_id)
class MachineController(Controller): class MachineController(Controller):
dto = MachineWriteDTO dto = MachineWriteDTO
return_dto = MachineReadDTO return_dto = MachineReadDTO
path = "/machines" path = "/company/{company_id:int}/machines"
dependencies = { dependencies = {
"service": Provide(provides_service, sync_to_thread=False), "service": Provide(provides_service, sync_to_thread=False),
} }
@ -37,7 +48,9 @@ class MachineController(Controller):
@get() @get()
async def get_machines( async def get_machines(
self, service: Service, search: Optional[str] = None self,
service: Service,
search: Optional[str] = None,
) -> ObjectListResponse[Machine]: ) -> ObjectListResponse[Machine]:
filters = [ filters = [
LimitOffset(limit=20, offset=0), LimitOffset(limit=20, offset=0),

View File

@ -1,15 +1,12 @@
from typing import Annotated from typing import Annotated
from litestar.contrib.repository import FilterTypes
from litestar.contrib.sqlalchemy.base import BigIntBase from litestar.contrib.sqlalchemy.base import BigIntBase
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
from litestar.contrib.sqlalchemy.repository.types import SelectT
from litestar.dto import DTOConfig from litestar.dto import DTOConfig
from sqlalchemy import true
from sqlalchemy.orm import Mapped from sqlalchemy.orm import Mapped
from app.lib import service from app.lib import service
from app.lib.filter_repository import FilterRepository
class Company(BigIntBase): class Company(BigIntBase):
@ -25,23 +22,15 @@ class Company(BigIntBase):
alive: Mapped[bool] alive: Mapped[bool]
class Repository(SQLAlchemyAsyncRepository[Company]): class Repository(FilterRepository[Company]):
model_type = Company model_type = Company
alive_flag = "alive"
def _apply_filters(
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
) -> SelectT:
statement = super()._apply_filters(
*filters, apply_pagination=apply_pagination, statement=statement
)
statement = statement.where(self.model_type.alive == true())
return statement
class Service(service.Service[Company]): class Service(service.Service[Company]):
repository_type = Repository repository_type = Repository
write_config = DTOConfig() write_config = DTOConfig(exclude={"id"})
CompanyWriteDTO = SQLAlchemyDTO[Annotated[Company, write_config]] CompanyWriteDTO = SQLAlchemyDTO[Annotated[Company, write_config]]
CompanyReadDTO = SQLAlchemyDTO[Company] CompanyReadDTO = SQLAlchemyDTO[Company]

View File

@ -2,17 +2,13 @@ from enum import Enum
from typing import Annotated, Optional from typing import Annotated, Optional
import sqlalchemy import sqlalchemy
from litestar.contrib.repository import FilterTypes
from litestar.contrib.sqlalchemy.base import BigIntBase from litestar.contrib.sqlalchemy.base import BigIntBase
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
from litestar.contrib.sqlalchemy.repository.types import SelectT
from litestar.dto import DTOConfig from litestar.dto import DTOConfig
from sqlalchemy import ColumnElement
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from app.lib import service from app.lib import service
from app.lib.filters import ExactFilter from app.lib.filter_repository import FilterRepository
class FiscalModuleEnum(str, Enum): class FiscalModuleEnum(str, Enum):
@ -52,24 +48,9 @@ class FiscalPaymentMapping(BigIntBase):
payment_device_code: Mapped[Optional[int]] payment_device_code: Mapped[Optional[int]]
class Repository(SQLAlchemyAsyncRepository[FiscalPaymentMapping]): class Repository(FilterRepository[FiscalPaymentMapping]):
model_type = FiscalPaymentMapping model_type = FiscalPaymentMapping
def _apply_filters(
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
) -> SelectT:
standard_filters = []
for filter_ in filters:
if isinstance(filter_, ExactFilter):
field: ColumnElement = getattr(self.model_type, filter_.field_name)
statement = statement.where(field == filter_.value)
else:
standard_filters.append(filter_)
return super()._apply_filters(
*standard_filters, apply_pagination=apply_pagination, statement=statement
)
class Service(service.Service[FiscalPaymentMapping]): class Service(service.Service[FiscalPaymentMapping]):
repository_type = Repository repository_type = Repository

View File

@ -1,15 +1,12 @@
from typing import Annotated from typing import Annotated
from litestar.contrib.repository import FilterTypes
from litestar.contrib.sqlalchemy.base import BigIntBase from litestar.contrib.sqlalchemy.base import BigIntBase
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
from litestar.contrib.sqlalchemy.repository.types import SelectT
from litestar.dto import DTOConfig from litestar.dto import DTOConfig
from sqlalchemy import true
from sqlalchemy.orm import Mapped from sqlalchemy.orm import Mapped
from app.lib import service from app.lib import service
from app.lib.company_owned_repository import CompanyOwnedRepository
class Machine(BigIntBase): class Machine(BigIntBase):
@ -20,26 +17,19 @@ class Machine(BigIntBase):
alive: Mapped[bool] alive: Mapped[bool]
deleted: Mapped[bool] deleted: Mapped[bool]
external_id: Mapped[str] external_id: Mapped[str]
owner_id: Mapped[int]
class Repository(SQLAlchemyAsyncRepository[Machine]): class Repository(CompanyOwnedRepository[Machine]):
model_type = Machine model_type = Machine
alive_flag = "alive"
def _apply_filters( company_id_field = "owner_id"
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
) -> SelectT:
statement = super()._apply_filters(
*filters, apply_pagination=apply_pagination, statement=statement
)
statement = statement.where(self.model_type.alive == true())
return statement
class Service(service.Service[Machine]): class Service(service.Service[Machine]):
repository_type = Repository repository_type = Repository
# write_config = DTOConfig(exclude={"created_at", "updated_at", "nationality"}) write_config = DTOConfig(exclude={"id"})
write_config = DTOConfig()
MachineWriteDTO = SQLAlchemyDTO[Annotated[Machine, write_config]] MachineWriteDTO = SQLAlchemyDTO[Annotated[Machine, write_config]]
MachineReadDTO = SQLAlchemyDTO[Machine] MachineReadDTO = SQLAlchemyDTO[Machine]

View File

@ -0,0 +1,89 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from litestar.contrib.repository import FilterTypes
from litestar.contrib.sqlalchemy.repository import ModelT
from litestar.contrib.sqlalchemy.repository.types import SelectT
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
from app.lib.filter_repository import FilterRepository
if TYPE_CHECKING:
from app.domain.company import Company
class CompanyOwnedRepository(FilterRepository[ModelT]):
company_id_field: str | None = None
def __init__(
self,
*,
company: Company,
statement: SelectT | None = None,
session: AsyncSession,
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
**kwargs: Any,
) -> None:
self.company = company
super().__init__(
statement=statement,
session=session,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
auto_commit=auto_commit,
**kwargs,
)
def _get_company_filter_statement(self, statement: SelectT | None) -> SelectT:
if not self.company_id_field:
raise AttributeError(
f"company_id_field must be set for {self.__class__.__name__}"
)
column = self._get_column_by_name(self.company_id_field)
if column is None:
raise AttributeError(
f"column {self.company_id_field} not found in {self.__class__.__name__}"
)
stmt = statement if statement is not None else self.statement
return stmt.where(column == self.company.id)
def _apply_filters(
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
) -> SelectT:
if not self.company_id_field:
raise AttributeError(
f"company_id_field must be set for {self.__class__.__name__}"
)
statement = super()._apply_filters(
*filters, apply_pagination=apply_pagination, statement=statement
)
statement = statement.where(
self._get_column_by_name(self.company_id_field) == self.company.id
)
return statement
async def get( # type: ignore[override]
self,
item_id: Any,
auto_expunge: bool | None = None,
statement: SelectT | None = None,
id_attribute: str | InstrumentedAttribute | None = None,
) -> ModelT:
statement = self._get_company_filter_statement(statement)
return await super().get(
item_id=item_id,
auto_expunge=auto_expunge,
statement=statement,
id_attribute=id_attribute,
)

View File

@ -0,0 +1,38 @@
from typing import Optional, cast
from litestar.contrib.repository import FilterTypes
from litestar.contrib.sqlalchemy.repository import ModelT, SQLAlchemyAsyncRepository
from litestar.contrib.sqlalchemy.repository.types import SelectT
from sqlalchemy import true, Column
from app.lib.filters import ExactFilter
class FilterRepository(SQLAlchemyAsyncRepository[ModelT]):
alive_flag: Optional[str] = None
def _get_column_by_name(self, name: str) -> Column | None:
return cast(Column, getattr(self.model_type, name, None))
def _apply_filters(
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
) -> SelectT:
standard_filters = []
for filter_ in filters:
if isinstance(filter_, ExactFilter):
statement = statement.where(
self._get_column_by_name(filter_.field_name) == filter_.value
)
else:
standard_filters.append(filter_)
statement = super()._apply_filters(
*standard_filters, apply_pagination=apply_pagination, statement=statement
)
if self.alive_flag:
statement = statement.where(
self._get_column_by_name(self.alive_flag) == true()
)
return statement