from __future__ import annotations

from typing import TYPE_CHECKING, Any

from litestar.contrib.repository import FilterTypes
from litestar.contrib.sqlalchemy.repository import ModelT
from litestar.contrib.sqlalchemy.repository.types import SelectT
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute

from app.lib.filter_repository import FilterRepository

if TYPE_CHECKING:
    from app.domain.company import Company


class CompanyOwnedRepository(FilterRepository[ModelT]):
    """Repository whose queries are restricted to a single company.

    Subclasses must set ``company_id_field`` to the name of the model column
    that stores the owning company's id. ``get`` and ``_apply_filters`` add a
    ``<column> == self.company.id`` condition to the statements they build.
    """

    # Name of the model column holding the owning company's id. Subclasses must
    # override this; the company-scoped methods raise AttributeError otherwise.
    company_id_field: str | None = None

    def __init__(
        self,
        *,
        company: Company,
        statement: SelectT | None = None,
        session: AsyncSession,
        auto_expunge: bool = False,
        auto_refresh: bool = True,
        auto_commit: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the owning company and forward the remaining options.

        Args:
            company: Company whose ``id`` is used to scope queries.
            statement: Optional base select statement for the repository.
            session: Async SQLAlchemy session used by the repository.
            auto_expunge: Forwarded unchanged to the parent repository.
            auto_refresh: Forwarded unchanged to the parent repository.
            auto_commit: Forwarded unchanged to the parent repository.
            **kwargs: Any further keyword arguments, forwarded to the parent.
        """
        self.company = company
        super().__init__(
            statement=statement,
            session=session,
            auto_expunge=auto_expunge,
            auto_refresh=auto_refresh,
            auto_commit=auto_commit,
            **kwargs,
        )

    def _get_company_column(self) -> Any:
        """Resolve the column named by ``company_id_field``.

        Returns:
            Whatever ``_get_column_by_name`` returns for ``company_id_field``.

        Raises:
            AttributeError: If ``company_id_field`` is unset/empty, or if
                ``_get_column_by_name`` returns ``None`` for it.
        """
        if not self.company_id_field:
            raise AttributeError(
                f"company_id_field must be set for {self.__class__.__name__}"
            )
        column = self._get_column_by_name(self.company_id_field)
        if column is None:
            raise AttributeError(
                f"column {self.company_id_field} not found in {self.__class__.__name__}"
            )
        return column

    def _get_company_filter_statement(self, statement: SelectT | None) -> SelectT:
        """Return *statement* (or ``self.statement``) filtered to this company.

        Args:
            statement: Statement to restrict; ``None`` means use
                ``self.statement``.

        Raises:
            AttributeError: If the company column cannot be resolved.
        """
        column = self._get_company_column()
        stmt = statement if statement is not None else self.statement
        return stmt.where(column == self.company.id)

    def _apply_filters(
        self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT
    ) -> SelectT:
        """Apply the parent's filters, then restrict the result to this company.

        Args:
            *filters: Filters passed through to the parent implementation.
            apply_pagination: Passed through to the parent implementation.
            statement: Statement the filters are applied to.

        Raises:
            AttributeError: If the company column cannot be resolved.
        """
        # Resolve the column before touching the statement so misconfiguration
        # fails early. Previously an unresolvable column was compared as
        # ``None == company.id`` (a constant Python bool) instead of raising.
        column = self._get_company_column()
        statement = super()._apply_filters(
            *filters, apply_pagination=apply_pagination, statement=statement
        )
        return statement.where(column == self.company.id)

    async def get(  # type: ignore[override]
        self,
        item_id: Any,
        auto_expunge: bool | None = None,
        statement: SelectT | None = None,
        id_attribute: str | InstrumentedAttribute | None = None,
    ) -> ModelT:
        """Fetch one instance by id, restricted to the current company.

        The (optional) *statement* is first scoped to ``self.company`` and then
        handed to the parent ``get`` together with the other arguments.

        Args:
            item_id: Identifier of the instance to fetch.
            auto_expunge: Passed through to the parent ``get``.
            statement: Optional base statement; ``self.statement`` if ``None``.
            id_attribute: Passed through to the parent ``get``.

        Raises:
            AttributeError: If the company column cannot be resolved. Other
                errors (e.g. not found) come from the parent ``get``.
        """
        statement = self._get_company_filter_statement(statement)
        return await super().get(
            item_id=item_id,
            auto_expunge=auto_expunge,
            statement=statement,
            id_attribute=id_attribute,
        )