115 lines
3.4 KiB
Python
115 lines
3.4 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional, Type, cast
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from .errors import HeaderNotSetError, IndexOutOfHeaderBounds
|
|
|
|
|
|
class MappingStrategyBase(ABC):
    """
    Base class for mapping strategies.

    A mapping strategy turns the values of a single row into the params
    (kwargs) dict that is later used to create a model instance.
    """

    def __init__(self, model_cls: Type[BaseModel]) -> None:
        """Remember the target model class; no header is known yet."""
        self.header: Optional[List[str]] = None
        self.model_cls = model_cls

    def set_header(self, header: List[str]) -> None:
        """Store ``header`` (the same list object, not a copy) on the strategy."""
        self.header = header

    @abstractmethod
    def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]:
        """Build the initial model params dict from ``row_values``."""

    @classmethod
    def validate_csv_loader_configuration(
        cls: Type[MappingStrategyBase], csv_loader: object
    ) -> bool:
        """
        Check that ``csv_loader`` is configured suitably for this strategy.

        The base implementation places no requirements on the loader and
        always returns ``True``.
        """
        return True
|
|
|
|
|
|
class MappingStrategyByModelFieldOrder(MappingStrategyBase):
    """
    Implements 1:1 field assignment. Each row value is assigned to the model
    attribute at the same position, in the order the fields are defined in
    the model.
    """

    def __init__(self, model_cls: Type[BaseModel]) -> None:
        """Capture the model's field names, in definition order."""
        super().__init__(model_cls)
        # pydantic v2 exposes the field mapping as ``model_fields`` and
        # deprecates ``__fields__``; pydantic v1 only has ``__fields__``.
        # The isinstance guard falls back to ``__fields__`` whenever
        # ``model_fields`` is missing or is not the expected mapping.
        fields = getattr(self.model_cls, "model_fields", None)
        if not isinstance(fields, dict):
            fields = self.model_cls.__fields__
        self.field_names = fields.keys()

    def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]:
        """
        Pair field names with ``row_values`` positionally.

        ``zip`` stops at the shorter input: surplus row values are dropped and
        fields without a corresponding value are left out of the result.
        """
        # map model field names as dict keys
        return dict(zip(self.field_names, row_values))
|
|
|
|
|
|
@dataclass
class HeaderRemapField:
    """Pairs a header name with the model attribute it should be renamed to."""

    # Name of the column as it appears in the header.
    header_field: str
    # Name of the model attribute that replaces ``header_field`` as the key.
    model_attr: str
|
|
|
|
|
|
class MappingStrategyByHeader(MappingStrategyBase):
    """
    Implements by-header assignment. Header must be present.

    Each row value is keyed by the header name at the same position. Optional
    ``header_remap_fields`` rename header names to model attribute names.
    """

    def __init__(
        self,
        model_cls: Type[BaseModel],
        header_remap_fields: Optional[List[HeaderRemapField]] = None,
    ) -> None:
        """Store the model class and the optional header renames."""
        super().__init__(model_cls)
        # Replace the base class default of ``None`` with an empty list.
        self.header: List[str] = []
        self.header_remap = header_remap_fields

    @classmethod
    def validate_csv_loader_configuration(
        cls: Type[MappingStrategyByHeader], csv_loader: object
    ) -> bool:
        """
        Require the loader to have a header.

        Returns ``True`` when ``csv_loader.has_header`` is truthy and raises
        ``HeaderNotSetError`` otherwise.
        """
        # avoid circular imports and keep mypy happy
        from .csv_loader import CSVLoader

        csv_loader = cast(CSVLoader, csv_loader)

        if not csv_loader.has_header:
            raise HeaderNotSetError()

        return True

    @staticmethod
    def _remap_header_mapping(
        header_mapping: Dict[str, Any], header_remap: Optional[List[HeaderRemapField]]
    ) -> Dict[str, Any]:
        """
        Rename keys of ``header_mapping`` according to ``header_remap``.

        Every rename reads its value from the unmodified input mapping, so
        renames never feed into each other: swapping two columns
        (``a -> b`` together with ``b -> a``) or chaining (``a -> b``,
        ``b -> c``) keeps every column's own value. When the same
        ``header_field`` is listed more than once, the first entry wins.

        The input mapping is never mutated. It is returned as-is when there
        is nothing to remap; otherwise a new dict is returned.
        """
        if not header_remap:
            return header_mapping

        # header name -> model attribute; first entry per header name wins
        renames: Dict[str, str] = {}
        for remap_field in header_remap:
            renames.setdefault(remap_field.header_field, remap_field.model_attr)

        # Keep columns that are not renamed, in header order ...
        remapped = {
            name: value for name, value in header_mapping.items() if name not in renames
        }
        # ... then add the renamed ones, taking values from the original
        # mapping so an earlier rename cannot clobber a later one's source.
        for header_field, model_attr in renames.items():
            if header_field in header_mapping:
                remapped[model_attr] = header_mapping[header_field]
        return remapped

    def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]:
        """
        Key ``row_values`` by header name and apply the configured renames.

        Raises ``HeaderNotSetError`` when the header is empty or unset, and
        ``IndexOutOfHeaderBounds`` when the row has more values than the
        header has names. A row shorter than the header is accepted; the
        trailing header names are simply absent from the result.
        """
        # no header, nothing to key the values by
        if not self.header:
            raise HeaderNotSetError()

        # header too short, can't do
        if len(row_values) > len(self.header):
            raise IndexOutOfHeaderBounds()

        # map header values as dict keys
        header_mapping = dict(zip(self.header, row_values))

        return self._remap_header_mapping(
            header_mapping=header_mapping, header_remap=self.header_remap
        )
|