Done
This commit is contained in:
@ -5,6 +5,7 @@ from litestar.contrib.repository.filters import LimitOffset, SearchFilter
|
|||||||
from litestar.di import Provide
|
from litestar.di import Provide
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.domain.company import Company
|
||||||
from app.domain.machine import (
|
from app.domain.machine import (
|
||||||
Machine,
|
Machine,
|
||||||
MachineReadDTO,
|
MachineReadDTO,
|
||||||
@ -12,6 +13,7 @@ from app.domain.machine import (
|
|||||||
Repository,
|
Repository,
|
||||||
Service,
|
Service,
|
||||||
)
|
)
|
||||||
|
from app.lib.filters import ExactFilter
|
||||||
from app.lib.responses import ObjectListResponse, ObjectResponse
|
from app.lib.responses import ObjectListResponse, ObjectResponse
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -21,15 +23,24 @@ if TYPE_CHECKING:
|
|||||||
DETAIL_ROUTE = "/{machine_id:int}"
|
DETAIL_ROUTE = "/{machine_id:int}"
|
||||||
|
|
||||||
|
|
||||||
def provides_service(db_session: AsyncSession) -> Service:
|
async def provides_service(db_session: AsyncSession, company_id: int) -> Service:
|
||||||
"""Constructs repository and service objects for the request."""
|
"""Constructs repository and service objects for the request."""
|
||||||
return Service(Repository(session=db_session))
|
from app.controllers.company import provides_service
|
||||||
|
company_service = provides_service(db_session)
|
||||||
|
company = await company_service.get(company_id)
|
||||||
|
return Service(Repository(session=db_session, company=company))
|
||||||
|
|
||||||
|
|
||||||
|
async def get_company(db_session: AsyncSession, company_id: int) -> Company:
|
||||||
|
from app.controllers.company import provides_service
|
||||||
|
company_service = provides_service(db_session)
|
||||||
|
return await company_service.get(company_id)
|
||||||
|
|
||||||
|
|
||||||
class MachineController(Controller):
|
class MachineController(Controller):
|
||||||
dto = MachineWriteDTO
|
dto = MachineWriteDTO
|
||||||
return_dto = MachineReadDTO
|
return_dto = MachineReadDTO
|
||||||
path = "/machines"
|
path = "/company/{company_id:int}/machines"
|
||||||
dependencies = {
|
dependencies = {
|
||||||
"service": Provide(provides_service, sync_to_thread=False),
|
"service": Provide(provides_service, sync_to_thread=False),
|
||||||
}
|
}
|
||||||
@ -37,7 +48,9 @@ class MachineController(Controller):
|
|||||||
|
|
||||||
@get()
|
@get()
|
||||||
async def get_machines(
|
async def get_machines(
|
||||||
self, service: Service, search: Optional[str] = None
|
self,
|
||||||
|
service: Service,
|
||||||
|
search: Optional[str] = None,
|
||||||
) -> ObjectListResponse[Machine]:
|
) -> ObjectListResponse[Machine]:
|
||||||
filters = [
|
filters = [
|
||||||
LimitOffset(limit=20, offset=0),
|
LimitOffset(limit=20, offset=0),
|
||||||
|
|||||||
@ -1,15 +1,12 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from litestar.contrib.repository import FilterTypes
|
|
||||||
from litestar.contrib.sqlalchemy.base import BigIntBase
|
from litestar.contrib.sqlalchemy.base import BigIntBase
|
||||||
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
||||||
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
|
|
||||||
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
|
||||||
from litestar.dto import DTOConfig
|
from litestar.dto import DTOConfig
|
||||||
from sqlalchemy import true
|
|
||||||
from sqlalchemy.orm import Mapped
|
from sqlalchemy.orm import Mapped
|
||||||
|
|
||||||
from app.lib import service
|
from app.lib import service
|
||||||
|
from app.lib.filter_repository import FilterRepository
|
||||||
|
|
||||||
|
|
||||||
class Company(BigIntBase):
|
class Company(BigIntBase):
|
||||||
@ -25,23 +22,15 @@ class Company(BigIntBase):
|
|||||||
alive: Mapped[bool]
|
alive: Mapped[bool]
|
||||||
|
|
||||||
|
|
||||||
class Repository(SQLAlchemyAsyncRepository[Company]):
|
class Repository(FilterRepository[Company]):
|
||||||
model_type = Company
|
model_type = Company
|
||||||
|
alive_flag = "alive"
|
||||||
def _apply_filters(
|
|
||||||
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
|
||||||
) -> SelectT:
|
|
||||||
statement = super()._apply_filters(
|
|
||||||
*filters, apply_pagination=apply_pagination, statement=statement
|
|
||||||
)
|
|
||||||
statement = statement.where(self.model_type.alive == true())
|
|
||||||
return statement
|
|
||||||
|
|
||||||
|
|
||||||
class Service(service.Service[Company]):
|
class Service(service.Service[Company]):
|
||||||
repository_type = Repository
|
repository_type = Repository
|
||||||
|
|
||||||
|
|
||||||
write_config = DTOConfig()
|
write_config = DTOConfig(exclude={"id"})
|
||||||
CompanyWriteDTO = SQLAlchemyDTO[Annotated[Company, write_config]]
|
CompanyWriteDTO = SQLAlchemyDTO[Annotated[Company, write_config]]
|
||||||
CompanyReadDTO = SQLAlchemyDTO[Company]
|
CompanyReadDTO = SQLAlchemyDTO[Company]
|
||||||
|
|||||||
@ -2,17 +2,13 @@ from enum import Enum
|
|||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from litestar.contrib.repository import FilterTypes
|
|
||||||
from litestar.contrib.sqlalchemy.base import BigIntBase
|
from litestar.contrib.sqlalchemy.base import BigIntBase
|
||||||
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
||||||
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
|
|
||||||
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
|
||||||
from litestar.dto import DTOConfig
|
from litestar.dto import DTOConfig
|
||||||
from sqlalchemy import ColumnElement
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from app.lib import service
|
from app.lib import service
|
||||||
from app.lib.filters import ExactFilter
|
from app.lib.filter_repository import FilterRepository
|
||||||
|
|
||||||
|
|
||||||
class FiscalModuleEnum(str, Enum):
|
class FiscalModuleEnum(str, Enum):
|
||||||
@ -52,24 +48,9 @@ class FiscalPaymentMapping(BigIntBase):
|
|||||||
payment_device_code: Mapped[Optional[int]]
|
payment_device_code: Mapped[Optional[int]]
|
||||||
|
|
||||||
|
|
||||||
class Repository(SQLAlchemyAsyncRepository[FiscalPaymentMapping]):
|
class Repository(FilterRepository[FiscalPaymentMapping]):
|
||||||
model_type = FiscalPaymentMapping
|
model_type = FiscalPaymentMapping
|
||||||
|
|
||||||
def _apply_filters(
|
|
||||||
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
|
||||||
) -> SelectT:
|
|
||||||
standard_filters = []
|
|
||||||
for filter_ in filters:
|
|
||||||
if isinstance(filter_, ExactFilter):
|
|
||||||
field: ColumnElement = getattr(self.model_type, filter_.field_name)
|
|
||||||
statement = statement.where(field == filter_.value)
|
|
||||||
else:
|
|
||||||
standard_filters.append(filter_)
|
|
||||||
|
|
||||||
return super()._apply_filters(
|
|
||||||
*standard_filters, apply_pagination=apply_pagination, statement=statement
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Service(service.Service[FiscalPaymentMapping]):
|
class Service(service.Service[FiscalPaymentMapping]):
|
||||||
repository_type = Repository
|
repository_type = Repository
|
||||||
|
|||||||
@ -1,15 +1,12 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from litestar.contrib.repository import FilterTypes
|
|
||||||
from litestar.contrib.sqlalchemy.base import BigIntBase
|
from litestar.contrib.sqlalchemy.base import BigIntBase
|
||||||
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
from litestar.contrib.sqlalchemy.dto import SQLAlchemyDTO
|
||||||
from litestar.contrib.sqlalchemy.repository import SQLAlchemyAsyncRepository
|
|
||||||
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
|
||||||
from litestar.dto import DTOConfig
|
from litestar.dto import DTOConfig
|
||||||
from sqlalchemy import true
|
|
||||||
from sqlalchemy.orm import Mapped
|
from sqlalchemy.orm import Mapped
|
||||||
|
|
||||||
from app.lib import service
|
from app.lib import service
|
||||||
|
from app.lib.company_owned_repository import CompanyOwnedRepository
|
||||||
|
|
||||||
|
|
||||||
class Machine(BigIntBase):
|
class Machine(BigIntBase):
|
||||||
@ -20,26 +17,19 @@ class Machine(BigIntBase):
|
|||||||
alive: Mapped[bool]
|
alive: Mapped[bool]
|
||||||
deleted: Mapped[bool]
|
deleted: Mapped[bool]
|
||||||
external_id: Mapped[str]
|
external_id: Mapped[str]
|
||||||
|
owner_id: Mapped[int]
|
||||||
|
|
||||||
|
|
||||||
class Repository(SQLAlchemyAsyncRepository[Machine]):
|
class Repository(CompanyOwnedRepository[Machine]):
|
||||||
model_type = Machine
|
model_type = Machine
|
||||||
|
alive_flag = "alive"
|
||||||
def _apply_filters(
|
company_id_field = "owner_id"
|
||||||
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
|
||||||
) -> SelectT:
|
|
||||||
statement = super()._apply_filters(
|
|
||||||
*filters, apply_pagination=apply_pagination, statement=statement
|
|
||||||
)
|
|
||||||
statement = statement.where(self.model_type.alive == true())
|
|
||||||
return statement
|
|
||||||
|
|
||||||
|
|
||||||
class Service(service.Service[Machine]):
|
class Service(service.Service[Machine]):
|
||||||
repository_type = Repository
|
repository_type = Repository
|
||||||
|
|
||||||
|
|
||||||
# write_config = DTOConfig(exclude={"created_at", "updated_at", "nationality"})
|
write_config = DTOConfig(exclude={"id"})
|
||||||
write_config = DTOConfig()
|
|
||||||
MachineWriteDTO = SQLAlchemyDTO[Annotated[Machine, write_config]]
|
MachineWriteDTO = SQLAlchemyDTO[Annotated[Machine, write_config]]
|
||||||
MachineReadDTO = SQLAlchemyDTO[Machine]
|
MachineReadDTO = SQLAlchemyDTO[Machine]
|
||||||
|
|||||||
89
app/lib/company_owned_repository.py
Normal file
89
app/lib/company_owned_repository.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from litestar.contrib.repository import FilterTypes
|
||||||
|
from litestar.contrib.sqlalchemy.repository import ModelT
|
||||||
|
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import InstrumentedAttribute
|
||||||
|
|
||||||
|
from app.lib.filter_repository import FilterRepository
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.domain.company import Company
|
||||||
|
|
||||||
|
|
||||||
|
class CompanyOwnedRepository(FilterRepository[ModelT]):
|
||||||
|
company_id_field: str | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
company: Company,
|
||||||
|
statement: SelectT | None = None,
|
||||||
|
session: AsyncSession,
|
||||||
|
auto_expunge: bool = False,
|
||||||
|
auto_refresh: bool = True,
|
||||||
|
auto_commit: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
self.company = company
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
statement=statement,
|
||||||
|
session=session,
|
||||||
|
auto_expunge=auto_expunge,
|
||||||
|
auto_refresh=auto_refresh,
|
||||||
|
auto_commit=auto_commit,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_company_filter_statement(self, statement: SelectT | None) -> SelectT:
|
||||||
|
if not self.company_id_field:
|
||||||
|
raise AttributeError(
|
||||||
|
f"company_id_field must be set for {self.__class__.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
column = self._get_column_by_name(self.company_id_field)
|
||||||
|
if column is None:
|
||||||
|
raise AttributeError(
|
||||||
|
f"column {self.company_id_field} not found in {self.__class__.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
stmt = statement if statement is not None else self.statement
|
||||||
|
return stmt.where(column == self.company.id)
|
||||||
|
|
||||||
|
def _apply_filters(
|
||||||
|
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
||||||
|
) -> SelectT:
|
||||||
|
if not self.company_id_field:
|
||||||
|
raise AttributeError(
|
||||||
|
f"company_id_field must be set for {self.__class__.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
statement = super()._apply_filters(
|
||||||
|
*filters, apply_pagination=apply_pagination, statement=statement
|
||||||
|
)
|
||||||
|
|
||||||
|
statement = statement.where(
|
||||||
|
self._get_column_by_name(self.company_id_field) == self.company.id
|
||||||
|
)
|
||||||
|
|
||||||
|
return statement
|
||||||
|
|
||||||
|
async def get( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
item_id: Any,
|
||||||
|
auto_expunge: bool | None = None,
|
||||||
|
statement: SelectT | None = None,
|
||||||
|
id_attribute: str | InstrumentedAttribute | None = None,
|
||||||
|
) -> ModelT:
|
||||||
|
statement = self._get_company_filter_statement(statement)
|
||||||
|
|
||||||
|
return await super().get(
|
||||||
|
item_id=item_id,
|
||||||
|
auto_expunge=auto_expunge,
|
||||||
|
statement=statement,
|
||||||
|
id_attribute=id_attribute,
|
||||||
|
)
|
||||||
38
app/lib/filter_repository.py
Normal file
38
app/lib/filter_repository.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from litestar.contrib.repository import FilterTypes
|
||||||
|
from litestar.contrib.sqlalchemy.repository import ModelT, SQLAlchemyAsyncRepository
|
||||||
|
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
||||||
|
from sqlalchemy import true, Column
|
||||||
|
|
||||||
|
from app.lib.filters import ExactFilter
|
||||||
|
|
||||||
|
|
||||||
|
class FilterRepository(SQLAlchemyAsyncRepository[ModelT]):
|
||||||
|
alive_flag: Optional[str] = None
|
||||||
|
|
||||||
|
def _get_column_by_name(self, name: str) -> Column | None:
|
||||||
|
return cast(Column, getattr(self.model_type, name, None))
|
||||||
|
|
||||||
|
def _apply_filters(
|
||||||
|
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
||||||
|
) -> SelectT:
|
||||||
|
standard_filters = []
|
||||||
|
for filter_ in filters:
|
||||||
|
if isinstance(filter_, ExactFilter):
|
||||||
|
statement = statement.where(
|
||||||
|
self._get_column_by_name(filter_.field_name) == filter_.value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
standard_filters.append(filter_)
|
||||||
|
|
||||||
|
statement = super()._apply_filters(
|
||||||
|
*standard_filters, apply_pagination=apply_pagination, statement=statement
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.alive_flag:
|
||||||
|
statement = statement.where(
|
||||||
|
self._get_column_by_name(self.alive_flag) == true()
|
||||||
|
)
|
||||||
|
|
||||||
|
return statement
|
||||||
Reference in New Issue
Block a user