CSV Loader

This commit is contained in:
Eden Kirin
2023-02-16 15:43:21 +01:00
parent 5caef41e80
commit c5a0008236
10 changed files with 603 additions and 0 deletions

View File

@ -0,0 +1,27 @@
from .csv_loader import (
BoolValuePair,
CSVFieldDuplicate,
CSVLoader,
CSVLoaderResult,
CSVRow,
CSVRowDefaultConfig,
CSVRows,
)
from .mapping_strategies import (
HeaderRemapField,
MappingStrategyByHeader,
MappingStrategyByModelFieldOrder,
)
# Public names of this package: everything listed here is imported above
# from the .csv_loader and .mapping_strategies submodules.
__all__ = [
"BoolValuePair",
"CSVLoader",
"CSVLoaderResult",
"CSVRow",
"CSVRows",
"CSVFieldDuplicate",
"CSVRowDefaultConfig",
"MappingStrategyByHeader",
"MappingStrategyByModelFieldOrder",
"HeaderRemapField",
]

View File

@ -0,0 +1,244 @@
from __future__ import annotations
import collections
import json
from dataclasses import dataclass
from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, Type, TypeVar
from pydantic import BaseModel, ValidationError, validator
from pydantic.fields import ModelField
from .errors import CSVValidationError, MappingStrategyError
from .mapping_strategies import MappingStrategyBase, MappingStrategyByModelFieldOrder
# Type of the row source consumed by the loader: any iterable that yields
# one list of string cell values per CSV line.
CSVReaderType = Iterable[List[str]]
@dataclass
class BoolValuePair:
    """Pair of raw CSV strings that stand for the boolean values.

    Attributes:
        true: Text that is interpreted as ``True``.
        false: Text that is interpreted as ``False``.
    """

    true: str
    false: str
@dataclass
class CSVFieldDuplicate:
    """Record describing one field value that occurs in more than one row.

    Attributes:
        value: The repeated field value.
        duplicate_rows: Indexes of the rows that carry ``value``.
    """

    value: Any
    duplicate_rows: List[int]
class CSVRowDefaultConfig:
    """Default settings consulted by the ``CSVRow`` string pre-validator."""

    anystr_strip_whitespace: bool = True
    """Standard pydantic config flag; enabled here by default so string
    values get their surrounding whitespace stripped."""

    empty_optional_str_fields_to_none: Tuple = ("__all__",)
    """Names of optional string fields whose empty value is turned into None.
    The magic default "__all__" applies the conversion to every such field."""

    bool_value_pair: BoolValuePair = BoolValuePair(true="1", false="0")
    """Raw texts recognised as true and false for bool fields. Any other text
    is parsed as None."""
class CSVRow(BaseModel):
    """
    Base model for one CSV row, with dedicated pre-processing of string input.

    An empty value for a field whose annotated type is not ``str`` becomes None,
    which lets basic types (int, float) fall back to None when no value was
    supplied. Such fields are expected to be annotated as Optional; if they are
    not, pydantic raises its regular validation error.

    Further options live in the inner Config class.
    """

    class Config(CSVRowDefaultConfig):
        """
        Takes its defaults from CSVRowDefaultConfig. A custom Config declared on a
        CSVRow subclass should inherit `CSVRowDefaultConfig` as well.
        """

    @validator("*", pre=True)
    def prepare_str_value(
        cls: CSVRow, value: Any, field: ModelField  # noqa: ANN401
    ) -> Optional[Any]:
        """Normalise a raw string before pydantic validates the field.

        Non-string input is returned untouched. Strings are optionally stripped,
        bool fields are resolved through the configured true/false pair, and
        empty strings are mapped to None where the configuration asks for it.
        """
        # anything that is not text is left for pydantic's own validation
        if not isinstance(value, str):
            return value

        settings = cls.Config
        text = value.strip() if settings.anystr_strip_whitespace else value

        # bool fields only understand the configured pair; other text is None
        if field.type_ is bool:
            pair = settings.bool_value_pair
            if text == pair.true:
                return True
            return False if text == pair.false else None

        # non-empty text passes through as is
        if text:
            return text

        # from here on the text is empty: non-str fields always become None
        if field.type_ is not str:
            return None

        # empty str field: None only when it is optional and selected in config
        selected = settings.empty_optional_str_fields_to_none
        is_selected = "__all__" in selected or field.name in selected
        if is_selected and not field.required:
            return None
        return text
# Type variable for the row model handled by the generic classes below;
# restricted to pydantic BaseModel subclasses.
CSVLoaderModelType = TypeVar("CSVLoaderModelType", bound=BaseModel)
class CSVRows(List[CSVLoaderModelType]):
    """Generic list of parsed CSV rows (pydantic models) with per-field helpers."""

    def get_field_values(self, field_name: str) -> List[Any]:
        """Return the value of ``field_name`` from every row, in row order.

        Raises:
            AttributeError: If a row has no attribute named ``field_name``.
        """
        return [getattr(row, field_name) for row in self]

    def get_field_values_unique(self, field_name: str) -> List[Any]:
        """Return the distinct values of ``field_name``, without duplicates.

        Values are listed in order of first occurrence, so the result is
        deterministic across runs. Callers should still not rely on any
        particular order beyond that. Values must be hashable.
        """
        # dict keys de-duplicate exactly like a set (hash + equality) but keep
        # insertion order, which avoids hash-randomised output ordering.
        return list(dict.fromkeys(self.get_field_values(field_name)))

    def get_field_duplicates(self, field_name: str) -> List[CSVFieldDuplicate]:
        """Return one entry per value of ``field_name`` found in several rows.

        Each entry holds the value and the indexes (positions in this list) of
        all rows carrying it. Values must be hashable.
        """
        rows_by_value: Dict[Any, List[int]] = collections.defaultdict(list)
        for row_index, value in enumerate(self.get_field_values(field_name)):
            rows_by_value[value].append(row_index)

        return [
            CSVFieldDuplicate(value=value, duplicate_rows=found_in_rows)
            for value, found_in_rows in rows_by_value.items()
            if len(found_in_rows) > 1
        ]

    def dict_list(self) -> List[Dict]:
        """Return every row converted to a dict via the model's ``dict()``."""
        return [row.dict() for row in self]

    def json(self) -> str:
        """Return a JSON array string built from ``dict_list()``."""
        return json.dumps(self.dict_list())
class CSVLoaderResult(Generic[CSVLoaderModelType]):
    """Outcome of a CSVLoader run: parsed models, collected errors and the header row."""

    def __init__(self) -> None:
        # everything starts out empty; the loader fills these in while reading
        self.header: List[str] = []
        self.errors: List[CSVValidationError] = []
        self.rows: CSVRows[CSVLoaderModelType] = CSVRows()

    def has_errors(self) -> bool:
        """Tell whether at least one error has been recorded."""
        return bool(self.errors)
class CSVLoader(Generic[CSVLoaderModelType]):
    """
    Generic CSV file parser.

    Uses standard csv reader to fetch csv rows, validate against provided
    pydantic model and returns list of created models together with
    aggregated error list.

    Example:
        with open("data.csv") as csv_file:
            reader = csv.reader(csv_file, delimiter=",")

            csv_loader = CSVLoader[MyRowModel](
                reader=reader,
                output_model_cls=MyRowModel,
                has_header=True,
                aggregate_errors=True,
            )
            result = csv_loader.read_rows()

            if result.has_errors():
                print("Errors:")
                for error in result.errors:
                    print(error)

            print("Created models:")
            for row in result.rows:
                print(row.index, row.organization_id)

    See tests/adapters/tools/test_csv_loader.py for more examples.
    """

    def __init__(
        self,
        reader: CSVReaderType,
        output_model_cls: Type[CSVLoaderModelType],
        has_header: Optional[bool] = True,
        aggregate_errors: Optional[bool] = False,
        mapping_strategy: Optional[MappingStrategyBase] = None,
    ) -> None:
        """Store the configuration and let the mapping strategy validate it.

        Args:
            reader: Iterable yielding one list of strings per CSV line.
            output_model_cls: Model class instantiated for every data row.
            has_header: When truthy, the first line is consumed as header.
            aggregate_errors: When truthy, row errors are collected in the
                result instead of aborting on the first one.
            mapping_strategy: Strategy that turns row values into model kwargs.
                Defaults to MappingStrategyByModelFieldOrder for the model.

        Raises:
            Whatever the strategy's ``validate_csv_loader_configuration`` raises
            for an unsupported configuration.
        """
        self.reader = reader
        self.output_model_cls = output_model_cls
        self.has_header = has_header
        self.aggregate_errors = aggregate_errors

        if mapping_strategy:
            self.mapping_strategy = mapping_strategy
        else:
            self.mapping_strategy = MappingStrategyByModelFieldOrder(
                model_cls=self.output_model_cls,
            )

        self.mapping_strategy.validate_csv_loader_configuration(csv_loader=self)

    def read_rows(self) -> CSVLoaderResult[CSVLoaderModelType]:
        """Consume the reader and build one model per non-empty data row.

        Returns:
            Result holding the created models, the stripped header (if any) and,
            when error aggregation is on, the errors of the rejected rows.

        Raises:
            CSVValidationError: On the first failing row when error aggregation
                is off. The original exception is attached as ``__cause__``.
                ``line_number`` is the zero-based position of the row in the
                reader, header line included.
        """
        result = CSVLoaderResult[CSVLoaderModelType]()

        for line_number, row in enumerate(self.reader):
            # first line is the header, if configured: remember it and move on
            if self.has_header and line_number == 0:
                header = [field.strip() for field in row]
                result.header = header
                self.mapping_strategy.set_header(header)
                continue

            # skip empty lines
            if not row:
                continue

            try:
                # create model kwargs params using mapping strategy
                model_create_kwargs = self.mapping_strategy.create_model_param_dict(
                    row_values=row,
                )
                # create output model from row data
                row_model = self.output_model_cls(**model_create_kwargs)
            except (MappingStrategyError, ValidationError) as ex:
                error = CSVValidationError(
                    line_number=line_number,
                    original_error=ex,
                )
                if not self.aggregate_errors:
                    # stop reading; chain explicitly so the traceback shows the
                    # original failure as the direct cause, not as an error
                    # that happened while handling another one
                    raise error from ex
                # aggregating: remember the error and carry on with next row
                result.errors.append(error)
            else:
                result.rows.append(row_model)

        return result

    def row_index_to_line_number(self, row_index: int) -> int:
        """Convert a row index to a line number by adding one when a header exists.

        Note: empty lines and rejected rows are not stored in the result rows, so
        the returned number only matches the reader position when none were
        skipped before ``row_index``.
        """
        return row_index if not self.has_header else row_index + 1

21
lib/csv_loader/errors.py Normal file
View File

@ -0,0 +1,21 @@
class CSVValidationError(Exception):
    """Extended validation exception class containing additional attributes.

    Attributes:
        line_number: Line number associated with the failing row.
        original_error: Exception that caused the row to be rejected.
    """

    def __init__(self, line_number: int, original_error: Exception) -> None:
        # Hand both values to Exception so ``args`` is populated even when the
        # error is created with keyword arguments. With empty ``args`` the
        # repr is uninformative and copy/pickle cannot rebuild the instance.
        super().__init__(line_number, original_error)
        self.line_number = line_number
        self.original_error = original_error

    def __str__(self) -> str:
        """Return the line number followed by the original error message."""
        return f"Error at line {self.line_number}: {self.original_error}"
class MappingStrategyError(Exception):
    """Base class for errors raised while mapping row values to model params.

    Subclasses describe themselves through the class-level ``detail`` text,
    which is used as the message when the error is raised without arguments.
    """

    detail: str = ""

    def __str__(self) -> str:
        # An explicit message passed to the constructor wins. Otherwise fall
        # back to ``detail`` so the error is never rendered as an empty string.
        return super().__str__() or self.detail


class HeaderNotSetError(MappingStrategyError):
    """Raised when header-based mapping is used without a header."""

    detail = "Header must be set in order to use MappingStrategyByHeader"


class IndexOutOfHeaderBounds(MappingStrategyError):
    """Raised when a row holds more values than the header has fields."""

    detail = "Row value index out of header bounds"

View File

@ -0,0 +1,114 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel
from .errors import HeaderNotSetError, IndexOutOfHeaderBounds
class MappingStrategyBase(ABC):
    """
    Contract for mapping strategies: objects that turn the values of one row
    into the params (kwargs) dict used afterwards to create the model.
    """

    def __init__(self, model_cls: Type[BaseModel]) -> None:
        """Remember the target model class; no header is known at this point."""
        self.header: Optional[List[str]] = None
        self.model_cls = model_cls

    def set_header(self, header: List[str]) -> None:
        """Store the header row for later use by the strategy."""
        self.header = header

    @abstractmethod
    def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]:
        """Build the initial model params dict from the given row values."""

    @classmethod
    def validate_csv_loader_configuration(
        cls: Type[MappingStrategyBase], csv_loader: object
    ) -> bool:
        """Check the loader configuration; this base implementation accepts any loader."""
        return True
class MappingStrategyByModelFieldOrder(MappingStrategyBase):
    """
    Positional 1:1 mapping: the n-th row value goes to the n-th model field,
    following the order in which the fields are declared on the model.
    """

    def __init__(self, model_cls: Type[BaseModel]) -> None:
        """Capture the model's field names to use as param keys."""
        super().__init__(model_cls)
        model_fields = self.model_cls.__fields__
        self.field_names = model_fields.keys()

    def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]:
        """Pair field names with row values by position and return them as a dict."""
        # zip stops at the shorter side, so surplus names or values are ignored
        name_value_pairs = zip(self.field_names, row_values)
        return dict(name_value_pairs)
@dataclass
class HeaderRemapField:
    """Rename rule applied to a header-keyed mapping.

    Attributes:
        header_field: Key as it appears in the CSV header.
        model_attr: Key that replaces ``header_field`` in the params dict.
    """

    header_field: str
    model_attr: str
class MappingStrategyByHeader(MappingStrategyBase):
    """Maps row values to params keyed by header names. A header is mandatory."""

    def __init__(
        self,
        model_cls: Type[BaseModel],
        header_remap_fields: Optional[List[HeaderRemapField]] = None,
    ) -> None:
        """Set up the strategy with an empty header and optional rename rules."""
        super().__init__(model_cls)
        self.header_remap = header_remap_fields
        self.header: List[str] = []

    @classmethod
    def validate_csv_loader_configuration(
        cls: Type[MappingStrategyByHeader], csv_loader: object
    ) -> bool:
        """Accept only loaders configured with a header.

        Raises:
            HeaderNotSetError: If the loader's ``has_header`` is falsy.
        """
        # imported here to avoid circular imports; the cast keeps mypy happy
        from .csv_loader import CSVLoader

        loader = cast(CSVLoader, csv_loader)
        if loader.has_header:
            return True
        raise HeaderNotSetError()

    @staticmethod
    def _remap_header_mapping(
        header_mapping: Dict[str, Any], header_remap: Optional[List[HeaderRemapField]]
    ) -> Dict[str, Any]:
        """Apply the rename rules to a copy of the mapping.

        Without rules the input mapping itself is returned. Rules whose
        ``header_field`` is absent from the mapping are ignored.
        """
        if not header_remap:
            return header_mapping

        renamed = header_mapping.copy()
        for rule in header_remap:
            old_key = rule.header_field
            if old_key not in renamed:
                continue
            # move the value from the header key over to the model attribute key
            renamed[rule.model_attr] = renamed.pop(old_key)
        return renamed

    def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]:
        """Build the params dict by pairing header names with the row values.

        Raises:
            HeaderNotSetError: If no (or an empty) header has been set.
            IndexOutOfHeaderBounds: If the row has more values than the header.
        """
        header = self.header

        # nothing to key the values by
        if not header:
            raise HeaderNotSetError()

        # more values than header fields cannot be mapped
        if len(row_values) > len(header):
            raise IndexOutOfHeaderBounds()

        # header names become the keys, then rename rules are applied
        keyed_by_header = dict(zip(header, row_values))
        return self._remap_header_mapping(
            header_mapping=keyed_by_header, header_remap=self.header_remap
        )