Done
This commit is contained in:
@ -5,6 +5,7 @@ from litestar.contrib.repository.filters import LimitOffset, SearchFilter
|
||||
from litestar.di import Provide
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.domain.company import Company
|
||||
from app.domain.machine import (
|
||||
Machine,
|
||||
MachineReadDTO,
|
||||
@ -12,6 +13,7 @@ from app.domain.machine import (
|
||||
Repository,
|
||||
Service,
|
||||
)
|
||||
from app.lib.filters import ExactFilter
|
||||
from app.lib.responses import ObjectListResponse, ObjectResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -21,15 +23,24 @@ if TYPE_CHECKING:
|
||||
DETAIL_ROUTE = "/{machine_id:int}"
|
||||
|
||||
|
||||
def provides_service(db_session: AsyncSession) -> Service:
|
||||
async def provides_service(db_session: AsyncSession, company_id: int) -> Service:
|
||||
"""Constructs repository and service objects for the request."""
|
||||
return Service(Repository(session=db_session))
|
||||
from app.controllers.company import provides_service
|
||||
company_service = provides_service(db_session)
|
||||
company = await company_service.get(company_id)
|
||||
return Service(Repository(session=db_session, company=company))
|
||||
|
||||
|
||||
async def get_company(db_session: AsyncSession, company_id: int) -> Company:
|
||||
from app.controllers.company import provides_service
|
||||
company_service = provides_service(db_session)
|
||||
return await company_service.get(company_id)
|
||||
|
||||
|
||||
class MachineController(Controller):
|
||||
dto = MachineWriteDTO
|
||||
return_dto = MachineReadDTO
|
||||
path = "/machines"
|
||||
path = "/company/{company_id:int}/machines"
|
||||
dependencies = {
|
||||
"service": Provide(provides_service, sync_to_thread=False),
|
||||
}
|
||||
@ -37,7 +48,9 @@ class MachineController(Controller):
|
||||
|
||||
@get()
|
||||
async def get_machines(
|
||||
self, service: Service, search: Optional[str] = None
|
||||
self,
|
||||
service: Service,
|
||||
search: Optional[str] = None,
|
||||
) -> ObjectListResponse[Machine]:
|
||||
filters = [
|
||||
LimitOffset(limit=20, offset=0),
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
from typing import Annotated
|
||||
|
||||
from litestar.contrib.repository import FilterTypes
|
||||
from litestar.contrib.sqlalchemy.base import BigIntBase
|
||||
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
||||
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
|
||||
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
||||
from litestar.dto import DTOConfig
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
from app.lib import service
|
||||
from app.lib.filter_repository import FilterRepository
|
||||
|
||||
|
||||
class Company(BigIntBase):
|
||||
@ -25,23 +22,15 @@ class Company(BigIntBase):
|
||||
alive: Mapped[bool]
|
||||
|
||||
|
||||
class Repository(SQLAlchemyAsyncRepository[Company]):
|
||||
class Repository(FilterRepository[Company]):
|
||||
model_type = Company
|
||||
|
||||
def _apply_filters(
|
||||
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
||||
) -> SelectT:
|
||||
statement = super()._apply_filters(
|
||||
*filters, apply_pagination=apply_pagination, statement=statement
|
||||
)
|
||||
statement = statement.where(self.model_type.alive == true())
|
||||
return statement
|
||||
alive_flag = "alive"
|
||||
|
||||
|
||||
class Service(service.Service[Company]):
|
||||
repository_type = Repository
|
||||
|
||||
|
||||
write_config = DTOConfig()
|
||||
write_config = DTOConfig(exclude={"id"})
|
||||
CompanyWriteDTO = SQLAlchemyDTO[Annotated[Company, write_config]]
|
||||
CompanyReadDTO = SQLAlchemyDTO[Company]
|
||||
|
||||
@ -2,17 +2,13 @@ from enum import Enum
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import sqlalchemy
|
||||
from litestar.contrib.repository import FilterTypes
|
||||
from litestar.contrib.sqlalchemy.base import BigIntBase
|
||||
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
||||
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
|
||||
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
||||
from litestar.dto import DTOConfig
|
||||
from sqlalchemy import ColumnElement
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.lib import service
|
||||
from app.lib.filters import ExactFilter
|
||||
from app.lib.filter_repository import FilterRepository
|
||||
|
||||
|
||||
class FiscalModuleEnum(str, Enum):
|
||||
@ -52,24 +48,9 @@ class FiscalPaymentMapping(BigIntBase):
|
||||
payment_device_code: Mapped[Optional[int]]
|
||||
|
||||
|
||||
class Repository(SQLAlchemyAsyncRepository[FiscalPaymentMapping]):
|
||||
class Repository(FilterRepository[FiscalPaymentMapping]):
|
||||
model_type = FiscalPaymentMapping
|
||||
|
||||
def _apply_filters(
|
||||
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
||||
) -> SelectT:
|
||||
standard_filters = []
|
||||
for filter_ in filters:
|
||||
if isinstance(filter_, ExactFilter):
|
||||
field: ColumnElement = getattr(self.model_type, filter_.field_name)
|
||||
statement = statement.where(field == filter_.value)
|
||||
else:
|
||||
standard_filters.append(filter_)
|
||||
|
||||
return super()._apply_filters(
|
||||
*standard_filters, apply_pagination=apply_pagination, statement=statement
|
||||
)
|
||||
|
||||
|
||||
class Service(service.Service[FiscalPaymentMapping]):
|
||||
repository_type = Repository
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
from typing import Annotated
|
||||
|
||||
from litestar.contrib.repository import FilterTypes
|
||||
from litestar.contrib.sqlalchemy.base import BigIntBase
|
||||
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
||||
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
|
||||
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
||||
from litestar.dto import DTOConfig
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
from app.lib import service
|
||||
from app.lib.company_owned_repository import CompanyOwnedRepository
|
||||
|
||||
|
||||
class Machine(BigIntBase):
|
||||
@ -20,26 +17,19 @@ class Machine(BigIntBase):
|
||||
alive: Mapped[bool]
|
||||
deleted: Mapped[bool]
|
||||
external_id: Mapped[str]
|
||||
owner_id: Mapped[int]
|
||||
|
||||
|
||||
class Repository(SQLAlchemyAsyncRepository[Machine]):
|
||||
class Repository(CompanyOwnedRepository[Machine]):
|
||||
model_type = Machine
|
||||
|
||||
def _apply_filters(
|
||||
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
||||
) -> SelectT:
|
||||
statement = super()._apply_filters(
|
||||
*filters, apply_pagination=apply_pagination, statement=statement
|
||||
)
|
||||
statement = statement.where(self.model_type.alive == true())
|
||||
return statement
|
||||
alive_flag = "alive"
|
||||
company_id_field = "owner_id"
|
||||
|
||||
|
||||
class Service(service.Service[Machine]):
|
||||
repository_type = Repository
|
||||
|
||||
|
||||
# write_config = DTOConfig(exclude={"created_at", "updated_at", "nationality"})
|
||||
write_config = DTOConfig()
|
||||
write_config = DTOConfig(exclude={"id"})
|
||||
MachineWriteDTO = SQLAlchemyDTO[Annotated[Machine, write_config]]
|
||||
MachineReadDTO = SQLAlchemyDTO[Machine]
|
||||
|
||||
89
app/lib/company_owned_repository.py
Normal file
89
app/lib/company_owned_repository.py
Normal file
@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from litestar.contrib.repository import FilterTypes
|
||||
from litestar.contrib.sqlalchemy.repository import ModelT
|
||||
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import InstrumentedAttribute
|
||||
|
||||
from app.lib.filter_repository import FilterRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.domain.company import Company
|
||||
|
||||
|
||||
class CompanyOwnedRepository(FilterRepository[ModelT]):
|
||||
company_id_field: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
company: Company,
|
||||
statement: SelectT | None = None,
|
||||
session: AsyncSession,
|
||||
auto_expunge: bool = False,
|
||||
auto_refresh: bool = True,
|
||||
auto_commit: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.company = company
|
||||
|
||||
super().__init__(
|
||||
statement=statement,
|
||||
session=session,
|
||||
auto_expunge=auto_expunge,
|
||||
auto_refresh=auto_refresh,
|
||||
auto_commit=auto_commit,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_company_filter_statement(self, statement: SelectT | None) -> SelectT:
|
||||
if not self.company_id_field:
|
||||
raise AttributeError(
|
||||
f"company_id_field must be set for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
column = self._get_column_by_name(self.company_id_field)
|
||||
if column is None:
|
||||
raise AttributeError(
|
||||
f"column {self.company_id_field} not found in {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
stmt = statement if statement is not None else self.statement
|
||||
return stmt.where(column == self.company.id)
|
||||
|
||||
def _apply_filters(
|
||||
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
||||
) -> SelectT:
|
||||
if not self.company_id_field:
|
||||
raise AttributeError(
|
||||
f"company_id_field must be set for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
statement = super()._apply_filters(
|
||||
*filters, apply_pagination=apply_pagination, statement=statement
|
||||
)
|
||||
|
||||
statement = statement.where(
|
||||
self._get_column_by_name(self.company_id_field) == self.company.id
|
||||
)
|
||||
|
||||
return statement
|
||||
|
||||
async def get( # type: ignore[override]
|
||||
self,
|
||||
item_id: Any,
|
||||
auto_expunge: bool | None = None,
|
||||
statement: SelectT | None = None,
|
||||
id_attribute: str | InstrumentedAttribute | None = None,
|
||||
) -> ModelT:
|
||||
statement = self._get_company_filter_statement(statement)
|
||||
|
||||
return await super().get(
|
||||
item_id=item_id,
|
||||
auto_expunge=auto_expunge,
|
||||
statement=statement,
|
||||
id_attribute=id_attribute,
|
||||
)
|
||||
38
app/lib/filter_repository.py
Normal file
38
app/lib/filter_repository.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Optional, cast
|
||||
|
||||
from litestar.contrib.repository import FilterTypes
|
||||
from litestar.contrib.sqlalchemy.repository import ModelT, SQLAlchemyAsyncRepository
|
||||
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
||||
from sqlalchemy import true, Column
|
||||
|
||||
from app.lib.filters import ExactFilter
|
||||
|
||||
|
||||
class FilterRepository(SQLAlchemyAsyncRepository[ModelT]):
|
||||
alive_flag: Optional[str] = None
|
||||
|
||||
def _get_column_by_name(self, name: str) -> Column | None:
|
||||
return cast(Column, getattr(self.model_type, name, None))
|
||||
|
||||
def _apply_filters(
|
||||
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
||||
) -> SelectT:
|
||||
standard_filters = []
|
||||
for filter_ in filters:
|
||||
if isinstance(filter_, ExactFilter):
|
||||
statement = statement.where(
|
||||
self._get_column_by_name(filter_.field_name) == filter_.value
|
||||
)
|
||||
else:
|
||||
standard_filters.append(filter_)
|
||||
|
||||
statement = super()._apply_filters(
|
||||
*standard_filters, apply_pagination=apply_pagination, statement=statement
|
||||
)
|
||||
|
||||
if self.alive_flag:
|
||||
statement = statement.where(
|
||||
self._get_column_by_name(self.alive_flag) == true()
|
||||
)
|
||||
|
||||
return statement
|
||||
Reference in New Issue
Block a user