Done
This commit is contained in:
89
app/lib/company_owned_repository.py
Normal file
89
app/lib/company_owned_repository.py
Normal file
@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from litestar.contrib.repository import FilterTypes
|
||||
from litestar.contrib.sqlalchemy.repository import ModelT
|
||||
from litestar.contrib.sqlalchemy.repository.types import SelectT
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import InstrumentedAttribute
|
||||
|
||||
from app.lib.filter_repository import FilterRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.domain.company import Company
|
||||
|
||||
|
||||
class CompanyOwnedRepository(FilterRepository[ModelT]):
|
||||
company_id_field: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
company: Company,
|
||||
statement: SelectT | None = None,
|
||||
session: AsyncSession,
|
||||
auto_expunge: bool = False,
|
||||
auto_refresh: bool = True,
|
||||
auto_commit: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.company = company
|
||||
|
||||
super().__init__(
|
||||
statement=statement,
|
||||
session=session,
|
||||
auto_expunge=auto_expunge,
|
||||
auto_refresh=auto_refresh,
|
||||
auto_commit=auto_commit,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_company_filter_statement(self, statement: SelectT | None) -> SelectT:
|
||||
if not self.company_id_field:
|
||||
raise AttributeError(
|
||||
f"company_id_field must be set for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
column = self._get_column_by_name(self.company_id_field)
|
||||
if column is None:
|
||||
raise AttributeError(
|
||||
f"column {self.company_id_field} not found in {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
stmt = statement if statement is not None else self.statement
|
||||
return stmt.where(column == self.company.id)
|
||||
|
||||
def _apply_filters(
|
||||
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
|
||||
) -> SelectT:
|
||||
if not self.company_id_field:
|
||||
raise AttributeError(
|
||||
f"company_id_field must be set for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
statement = super()._apply_filters(
|
||||
*filters, apply_pagination=apply_pagination, statement=statement
|
||||
)
|
||||
|
||||
statement = statement.where(
|
||||
self._get_column_by_name(self.company_id_field) == self.company.id
|
||||
)
|
||||
|
||||
return statement
|
||||
|
||||
async def get( # type: ignore[override]
|
||||
self,
|
||||
item_id: Any,
|
||||
auto_expunge: bool | None = None,
|
||||
statement: SelectT | None = None,
|
||||
id_attribute: str | InstrumentedAttribute | None = None,
|
||||
) -> ModelT:
|
||||
statement = self._get_company_filter_statement(statement)
|
||||
|
||||
return await super().get(
|
||||
item_id=item_id,
|
||||
auto_expunge=auto_expunge,
|
||||
statement=statement,
|
||||
id_attribute=id_attribute,
|
||||
)
|
||||
Reference in New Issue
Block a user