Files
csv-loader-generics-present…/lib/csv_loader/mapping_strategies.py
Eden Kirin c5a0008236 CSV Loader
2023-02-16 15:43:21 +01:00

115 lines
3.4 KiB
Python

from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel
from .errors import HeaderNotSetError, IndexOutOfHeaderBounds
class MappingStrategyBase(ABC):
    """
    Common base for all mapping strategies.

    A mapping strategy turns the values of a single row into the params
    (kwargs) dict which is later used to create a model instance.
    """

    def __init__(self, model_cls: Type[BaseModel]) -> None:
        # No header is known until set_header() is called.
        self.header: Optional[List[str]] = None
        self.model_cls = model_cls

    def set_header(self, header: List[str]) -> None:
        """Store the given header on the strategy."""
        self.header = header

    @abstractmethod
    def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]:
        """Build the initial model params dict from the values of one row."""

    @classmethod
    def validate_csv_loader_configuration(
        cls: Type[MappingStrategyBase], csv_loader: object
    ) -> bool:
        """
        Check whether the given loader can be used with this strategy.

        The base implementation accepts any loader and always returns True.
        """
        return True
class MappingStrategyByModelFieldOrder(MappingStrategyBase):
    """
    Positional (1:1) mapping strategy.

    The n-th value of a row is assigned to the n-th model attribute, following
    the order in which the attributes are defined in the model.
    """

    def __init__(self, model_cls: Type[BaseModel]) -> None:
        super().__init__(model_cls)
        model_fields = self.model_cls.__fields__
        # Keys view over the model fields; used as dict keys for every row.
        self.field_names = model_fields.keys()

    def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]:
        """
        Pair model field names with row values by position.

        Pairing stops at the shorter of the two sequences, so surplus values
        or surplus field names are left out of the result.
        """
        return {
            field_name: value
            for field_name, value in zip(self.field_names, row_values)
        }
@dataclass
class HeaderRemapField:
    """Pairs a header field name with the model attribute name it maps to."""

    # Name of the column as it appears in the header.
    header_field: str
    # Name of the model attribute that should receive the column's value.
    model_attr: str
class MappingStrategyByHeader(MappingStrategyBase):
    """
    Implements by-header assignment. Header must be present.

    Each row value is keyed by the header name at the same position. Optional
    ``HeaderRemapField`` entries then rename header names to model attribute
    names.
    """

    def __init__(
        self,
        model_cls: Type[BaseModel],
        header_remap_fields: Optional[List[HeaderRemapField]] = None,
    ) -> None:
        super().__init__(model_cls)
        # Start from an empty header instead of the base class' None.
        self.header: List[str] = []
        self.header_remap = header_remap_fields

    @classmethod
    def validate_csv_loader_configuration(
        cls: Type[MappingStrategyByHeader], csv_loader: object
    ) -> bool:
        """
        Check that the loader is configured to provide a header.

        Raises:
            HeaderNotSetError: if ``csv_loader.has_header`` is falsy.
        """
        # avoid circular imports and keep mypy happy
        from .csv_loader import CSVLoader

        csv_loader = cast(CSVLoader, csv_loader)
        if not csv_loader.has_header:
            raise HeaderNotSetError()
        return True

    @staticmethod
    def _remap_header_mapping(
        header_mapping: Dict[str, Any], header_remap: Optional[List[HeaderRemapField]]
    ) -> Dict[str, Any]:
        """
        Rename keys of ``header_mapping`` according to ``header_remap``.

        Every value is read from the original mapping, so remaps never feed
        into each other: swapping two names (a -> b, b -> a) or chaining them
        (a -> b, b -> c) keeps all values instead of overwriting one with
        another. The input mapping is not modified; it is returned as-is when
        there is nothing to remap.
        """
        if not header_remap:
            return header_mapping

        # Only the first remap listed for a given header field is applied.
        renames: Dict[str, str] = {}
        for remap_field in header_remap:
            renames.setdefault(remap_field.header_field, remap_field.model_attr)

        # Untouched columns first ...
        remapped = {
            key: value for key, value in header_mapping.items() if key not in renames
        }
        # ... then renamed ones, in remap order, so a renamed value takes
        # precedence over an untouched column that has the same name.
        for header_field, model_attr in renames.items():
            if header_field in header_mapping:
                remapped[model_attr] = header_mapping[header_field]
        return remapped

    def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]:
        """
        Build the model params dict by keying row values with header names.

        Raises:
            HeaderNotSetError: if no (or an empty) header has been set.
            IndexOutOfHeaderBounds: if the row has more values than the
                header has names.
        """
        # header not set? stop! hammer time!
        if not self.header:
            raise HeaderNotSetError()

        # header too short, can't do
        if len(row_values) > len(self.header):
            raise IndexOutOfHeaderBounds()

        # map header values as dict keys; a row shorter than the header simply
        # leaves the trailing header names out
        header_mapping = dict(zip(self.header, row_values))
        return self._remap_header_mapping(
            header_mapping=header_mapping, header_remap=self.header_remap
        )