90 lines
2.7 KiB
Python
90 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from litestar.contrib.repository import FilterTypes
|
|
from litestar.contrib.sqlalchemy.repository import ModelT
|
|
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import InstrumentedAttribute
|
|
|
|
from app.lib.filter_repository import FilterRepository
|
|
|
|
if TYPE_CHECKING:
|
|
from app.domain.company import Company
|
|
|
|
|
|
class CompanyOwnedRepository(FilterRepository[ModelT]):
    """Repository that scopes queries to the rows owned by a single company.

    ``get`` and ``_apply_filters`` add ``<company_id_field column> == company.id``
    to the WHERE clause of the statement they execute or build. Subclasses
    must set ``company_id_field`` to the name of the model column that holds
    the owning company's id.

    NOTE(review): only ``get`` and ``_apply_filters`` are overridden here; any
    other query method inherited from ``FilterRepository`` is not scoped by
    this class — confirm those paths cannot bypass the company filter.
    """

    # Name of the model column compared against ``self.company.id``. Must be
    # set by subclasses; leaving it ``None`` makes every scoped query raise
    # ``AttributeError``.
    company_id_field: str | None = None

    def __init__(
        self,
        *,
        company: Company,
        statement: SelectT | None = None,
        session: AsyncSession,
        auto_expunge: bool = False,
        auto_refresh: bool = True,
        auto_commit: bool = False,
        **kwargs: Any,
    ) -> None:
        """Bind the repository to ``company`` and initialise the base repository.

        Args:
            company: Company whose ``id`` is used to scope queries.
            statement: Optional base select statement, passed to the base class.
            session: Async SQLAlchemy session, passed to the base class.
            auto_expunge: Passed through unchanged to the base class.
            auto_refresh: Passed through unchanged to the base class.
            auto_commit: Passed through unchanged to the base class.
            **kwargs: Any further keyword arguments, passed through to the
                base class.
        """
        self.company = company

        super().__init__(
            statement=statement,
            session=session,
            auto_expunge=auto_expunge,
            auto_refresh=auto_refresh,
            auto_commit=auto_commit,
            **kwargs,
        )

    def _get_company_column(self) -> Any:
        """Resolve and validate the model column named by ``company_id_field``.

        Returns:
            The object returned by ``self._get_column_by_name`` for
            ``company_id_field`` (presumably the mapped column attribute —
            confirm against ``FilterRepository``).

        Raises:
            AttributeError: If ``company_id_field`` is unset/empty, or if
                ``_get_column_by_name`` returns ``None`` for it.
        """
        if not self.company_id_field:
            raise AttributeError(
                f"company_id_field must be set for {self.__class__.__name__}"
            )

        column = self._get_column_by_name(self.company_id_field)
        if column is None:
            raise AttributeError(
                f"column {self.company_id_field} not found in {self.__class__.__name__}"
            )

        return column

    def _get_company_filter_statement(self, statement: SelectT | None) -> SelectT:
        """Return ``statement`` restricted to rows owned by ``self.company``.

        Args:
            statement: Statement to restrict. ``None`` means the repository's
                default ``self.statement``.

        Returns:
            The statement with ``<company column> == self.company.id`` added to
            its WHERE clause.

        Raises:
            AttributeError: If the company column cannot be resolved
                (see ``_get_company_column``).
        """
        column = self._get_company_column()

        stmt = statement if statement is not None else self.statement
        return stmt.where(column == self.company.id)

    def _apply_filters(
        self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
    ) -> SelectT:
        """Apply the base-class filters, then restrict the result to ``self.company``.

        Args:
            *filters: Filters forwarded to the base-class ``_apply_filters``.
            apply_pagination: Forwarded to the base-class ``_apply_filters``.
            statement: Statement the filters are applied to.

        Returns:
            The filtered statement with ``<company column> == self.company.id``
            added to its WHERE clause.

        Raises:
            AttributeError: If the company column cannot be resolved. This is
                checked before delegating to the base class.
        """
        # Resolve and validate the column up front. If the lookup returned
        # ``None``, ``None == self.company.id`` would be a plain Python bool
        # rather than a column comparison, and the query would no longer be
        # scoped to this company.
        column = self._get_company_column()

        statement = super()._apply_filters(
            *filters, apply_pagination=apply_pagination, statement=statement
        )

        return statement.where(column == self.company.id)

    async def get(  # type: ignore[override]
        self,
        item_id: Any,
        auto_expunge: bool | None = None,
        statement: SelectT | None = None,
        id_attribute: str | InstrumentedAttribute | None = None,
    ) -> ModelT:
        """Fetch a single row by id, restricted to rows owned by ``self.company``.

        Args:
            item_id: Identifier of the row to fetch.
            auto_expunge: Forwarded to the base-class ``get``.
            statement: Base statement to scope. ``None`` means ``self.statement``.
            id_attribute: Forwarded to the base-class ``get``.

        Returns:
            Whatever the base-class ``get`` returns for the company-scoped
            statement.

        Raises:
            AttributeError: If the company column cannot be resolved.
            Any exception raised by the base-class ``get`` (presumably a
            not-found error when the row belongs to another company — confirm
            against the base class).
        """
        statement = self._get_company_filter_statement(statement)

        return await super().get(
            item_id=item_id,
            auto_expunge=auto_expunge,
            statement=statement,
            id_attribute=id_attribute,
        )
|