This commit is contained in:
Eden Kirin
2023-08-27 13:35:29 +02:00
parent 03c8aaa312
commit 7b16c2f606
6 changed files with 156 additions and 56 deletions

View File

@ -0,0 +1,89 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from litestar.contrib.repository import FilterTypes
from litestar.contrib.sqlalchemy.repository import ModelT
from litestar.contrib.sqlalchemy.repository.types import SelectT
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
from app.lib.filter_repository import FilterRepository
if TYPE_CHECKING:
from app.domain.company import Company
class CompanyOwnedRepository(FilterRepository[ModelT]):
company_id_field: str | None = None
def __init__(
self,
*,
company: Company,
statement: SelectT | None = None,
session: AsyncSession,
auto_expunge: bool = False,
auto_refresh: bool = True,
auto_commit: bool = False,
**kwargs: Any,
) -> None:
self.company = company
super().__init__(
statement=statement,
session=session,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
auto_commit=auto_commit,
**kwargs,
)
def _get_company_filter_statement(self, statement: SelectT | None) -> SelectT:
if not self.company_id_field:
raise AttributeError(
f"company_id_field must be set for {self.__class__.__name__}"
)
column = self._get_column_by_name(self.company_id_field)
if column is None:
raise AttributeError(
f"column {self.company_id_field} not found in {self.__class__.__name__}"
)
stmt = statement if statement is not None else self.statement
return stmt.where(column == self.company.id)
def _apply_filters(
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
) -> SelectT:
if not self.company_id_field:
raise AttributeError(
f"company_id_field must be set for {self.__class__.__name__}"
)
statement = super()._apply_filters(
*filters, apply_pagination=apply_pagination, statement=statement
)
statement = statement.where(
self._get_column_by_name(self.company_id_field) == self.company.id
)
return statement
async def get( # type: ignore[override]
self,
item_id: Any,
auto_expunge: bool | None = None,
statement: SelectT | None = None,
id_attribute: str | InstrumentedAttribute | None = None,
) -> ModelT:
statement = self._get_company_filter_statement(statement)
return await super().get(
item_id=item_id,
auto_expunge=auto_expunge,
statement=statement,
id_attribute=id_attribute,
)

View File

@ -0,0 +1,38 @@
from typing import Optional, cast
from litestar.contrib.repository import FilterTypes
from litestar.contrib.sqlalchemy.repository import ModelT, SQLAlchemyAsyncRepository
from litestar.contrib.sqlalchemy.repository.types import SelectT
from sqlalchemy import true, Column
from app.lib.filters import ExactFilter
class FilterRepository(SQLAlchemyAsyncRepository[ModelT]):
alive_flag: Optional[str] = None
def _get_column_by_name(self, name: str) -> Column | None:
return cast(Column, getattr(self.model_type, name, None))
def _apply_filters(
self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
) -> SelectT:
standard_filters = []
for filter_ in filters:
if isinstance(filter_, ExactFilter):
statement = statement.where(
self._get_column_by_name(filter_.field_name) == filter_.value
)
else:
standard_filters.append(filter_)
statement = super()._apply_filters(
*standard_filters, apply_pagination=apply_pagination, statement=statement
)
if self.alive_flag:
statement = statement.where(
self._get_column_by_name(self.alive_flag) == true()
)
return statement