This commit is contained in:
Eden Kirin
2023-08-27 13:35:29 +02:00
parent 03c8aaa312
commit 7b16c2f606
6 changed files with 156 additions and 56 deletions

View File

@ -0,0 +1,89 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from litestar.contrib.repository import FilterTypes
from litestar.contrib.sqlalchemy.repository import ModelT
from litestar.contrib.sqlalchemy.repository.types import SelectT
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
from app.lib.filter_repository import FilterRepository
if TYPE_CHECKING:
from app.domain.company import Company
class CompanyOwnedRepository(FilterRepository[ModelT]):
company_id_field: str | None = None
def __init__(
self,
*,
company: Company,
statement: SelectT | None = None,
session: AsyncSession,
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
**kwargs: Any,
) -> None:
self.company = company
super().__init__(
statement=statement,
session=session,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
auto_commit=auto_commit,
**kwargs,
)
def _get_company_filter_statement(self, statement: SelectT | None) -> SelectT:
if not self.company_id_field:
raise AttributeError(
f"company_id_field must be set for {self.__class__.__name__}"
)
column = self._get_column_by_name(self.company_id_field)
if column is None:
raise AttributeError(
f"column {self.company_id_field} not found in {self.__class__.__name__}"
)
stmt = statement if statement is not None else self.statement
return stmt.where(column == self.company.id)
def _apply_filters(
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
) -> SelectT:
if not self.company_id_field:
raise AttributeError(
f"company_id_field must be set for {self.__class__.__name__}"
)
statement = super()._apply_filters(
*filters, apply_pagination=apply_pagination, statement=statement
)
statement = statement.where(
self._get_column_by_name(self.company_id_field) == self.company.id
)
return statement
async def get( # type: ignore[override]
self,
item_id: Any,
auto_expunge: bool | None = None,
statement: SelectT | None = None,
id_attribute: str | InstrumentedAttribute | None = None,
) -> ModelT:
statement = self._get_company_filter_statement(statement)
return await super().get(
item_id=item_id,
auto_expunge=auto_expunge,
statement=statement,
id_attribute=id_attribute,
)