From c5a000823630a6e079690f753a3d154fd680dde6 Mon Sep 17 00:00:00 2001
From: Eden Kirin
Date: Thu, 16 Feb 2023 15:43:21 +0100
Subject: [PATCH] CSV Loader

---
 csv/assets-header-remap.csv          |  11 ++
 csv/assets.csv                       |  11 ++
 lib/csv_loader/__init__.py           |  27 +++
 lib/csv_loader/csv_loader.py         | 244 +++++++++++++++++++++++++++
 lib/csv_loader/errors.py             |  21 +++
 lib/csv_loader/mapping_strategies.py | 114 +++++++++++++
 loader_1_simple.py                   |  48 ++++++
 loader_2_header.py                   |  53 ++++++
 loader_3_header_mapper.py            |  65 +++++++
 loader_3_remote_reader.py            |   9 +
 10 files changed, 603 insertions(+)
 create mode 100644 csv/assets-header-remap.csv
 create mode 100644 csv/assets.csv
 create mode 100644 lib/csv_loader/__init__.py
 create mode 100644 lib/csv_loader/csv_loader.py
 create mode 100644 lib/csv_loader/errors.py
 create mode 100644 lib/csv_loader/mapping_strategies.py
 create mode 100644 loader_1_simple.py
 create mode 100644 loader_2_header.py
 create mode 100644 loader_3_header_mapper.py
 create mode 100644 loader_3_remote_reader.py

diff --git a/csv/assets-header-remap.csv b/csv/assets-header-remap.csv
new file mode 100644
index 0000000..210f42f
--- /dev/null
+++ b/csv/assets-header-remap.csv
@@ -0,0 +1,11 @@
+Asset Id;Action;Serial Number;Brand Id;Model Id;Is Active
+"sXkRAbsHYsEnxGJduyjs";0;509409163778441649684954421415;28959;44593;YES
+"WJDPKajxANBwbBhEfwAi";1;365228020946043885618697526758;556484;6192;NO
+"CjndFIBkAvZgaVEClxzy";2;800106731693683173886690767546;486;70009;YES
+"dKLMQOSlrxUoTFJNNIgL";0;484621303136600603740664328753;67869480;59839;YES
+"YcDuvcJosGeMbHgTdGRw";2;832537342998006368585679647288;4;215774;YES
+"vujkSbLRTzBIfkmUZXjy";0;112483609942168822151288639122;83;4;YES
+"DWJowiNmTNUNxTzGTFFr";0;466033505889282671434228388950;249778;480;NO
+"HxhsZuGTHgqllERCSWau";0;178809011962241227784938272343;66219;2821473;YES
+"WtdXUpLHDjnUuGTQSmqu";0;903201497982085302070856779353;220327928;4;NO
+"pchtaEJmdmsrwxBOviBc";0;641199031801878398775345747952;268676370;69;YES
diff --git a/csv/assets.csv b/csv/assets.csv
new file mode 100644
index 0000000..93c44d1
--- /dev/null
+++ b/csv/assets.csv
@@ -0,0 +1,11 @@
+"asset_id";"asset_action";"serial_number";"brand_id";"model_id";is_active
+"sXkRAbsHYsEnxGJduyjs";0;509409163778441649684954421415;28959;44593;YES
+"WJDPKajxANBwbBhEfwAi";1;365228020946043885618697526758;556484;6192;NO
+"CjndFIBkAvZgaVEClxzy";2;800106731693683173886690767546;486;70009;YES
+"dKLMQOSlrxUoTFJNNIgL";0;484621303136600603740664328753;67869480;59839;YES
+"YcDuvcJosGeMbHgTdGRw";2;832537342998006368585679647288;4;215774;YES
+"vujkSbLRTzBIfkmUZXjy";0;112483609942168822151288639122;83;4;YES
+"DWJowiNmTNUNxTzGTFFr";0;466033505889282671434228388950;249778;480;NO
+"HxhsZuGTHgqllERCSWau";0;178809011962241227784938272343;66219;2821473;YES
+"WtdXUpLHDjnUuGTQSmqu";0;903201497982085302070856779353;220327928;4;NO
+"pchtaEJmdmsrwxBOviBc";0;641199031801878398775345747952;268676370;69;YES
diff --git a/lib/csv_loader/__init__.py b/lib/csv_loader/__init__.py
new file mode 100644
index 0000000..5388908
--- /dev/null
+++ b/lib/csv_loader/__init__.py
@@ -0,0 +1,27 @@
+from .csv_loader import (
+    BoolValuePair,
+    CSVFieldDuplicate,
+    CSVLoader,
+    CSVLoaderResult,
+    CSVRow,
+    CSVRowDefaultConfig,
+    CSVRows,
+)
+from .mapping_strategies import (
+    HeaderRemapField,
+    MappingStrategyByHeader,
+    MappingStrategyByModelFieldOrder,
+)
+
+__all__ = [
+    "BoolValuePair",
+    "CSVLoader",
+    "CSVLoaderResult",
+    "CSVRow",
+    "CSVRows",
+    "CSVFieldDuplicate",
+    "CSVRowDefaultConfig",
+    "MappingStrategyByHeader",
"MappingStrategyByModelFieldOrder", + "HeaderRemapField", +] diff --git a/lib/csv_loader/csv_loader.py b/lib/csv_loader/csv_loader.py new file mode 100644 index 0000000..1e3e756 --- /dev/null +++ b/lib/csv_loader/csv_loader.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import collections +import json +from dataclasses import dataclass +from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, Type, TypeVar + +from pydantic import BaseModel, ValidationError, validator +from pydantic.fields import ModelField + +from .errors import CSVValidationError, MappingStrategyError +from .mapping_strategies import MappingStrategyBase, MappingStrategyByModelFieldOrder + +CSVReaderType = Iterable[List[str]] + + +@dataclass +class BoolValuePair: + true: str + false: str + + +@dataclass +class CSVFieldDuplicate: + value: Any + duplicate_rows: List[int] + + +class CSVRowDefaultConfig: + anystr_strip_whitespace: bool = True + """Standard pydantic config flag, set default to True.""" + empty_optional_str_fields_to_none: Tuple = ("__all__",) + """List of optional string fields which will be converted to None, if empty. + Default magic value is "__all__" to convert all fields.""" + bool_value_pair: BoolValuePair = BoolValuePair(true="1", false="0") + """Possible boolean values for true and false. If the actual value + is not in defined pair, value will be parsed as None.""" + + +class CSVRow(BaseModel): + """ + Represents a model base for a single CSV row and implements special handling for string values. + If given field value is empty, but annotated type is not string, it will be converted to None. + This is useful for basic types (int, float), to be converted to None if value is not provided. + It's assumed those fields are annotated as Optional, otherwise pydantic will raise expected + validation error. + See Config inner class for more options. + """ + + class Config(CSVRowDefaultConfig): + """ + Defaults from CSVRowDefaultConfig will be used. If you're defining your own Config in + custom CSVRow class, make sure it inherits `CSVRowDefaultConfig` + """ + + @validator("*", pre=True) + def prepare_str_value( + cls: CSVRow, value: Any, field: ModelField # noqa: ANN401 + ) -> Optional[Any]: + # not a string? just return value, pydantic validator will do the rest + if not isinstance(value, str): + return value + # strip whitespace if config say so + if cls.Config.anystr_strip_whitespace: + value = value.strip() + + # special handling for bool values + if field.type_ is bool: + if value == cls.Config.bool_value_pair.true: + return True + if value == cls.Config.bool_value_pair.false: + return False + return None + + # no special handling for non-empty strings + if len(value) > 0: + return value + # empty value and annotated field type is not string? return None + if field.type_ is not str: + return None + # if string field is annotated as optional with 0 length, set it to None + if ( + "__all__" in cls.Config.empty_optional_str_fields_to_none + or field.name in cls.Config.empty_optional_str_fields_to_none + ) and not field.required: + return None + return value + + +CSVLoaderModelType = TypeVar("CSVLoaderModelType", bound=BaseModel) + + +class CSVRows(List[CSVLoaderModelType]): + """Generic parsed CSV rows containing pydantic models.""" + + def get_field_values(self, field_name: str) -> List[Any]: + """Get list of all values from models for named field. 
+        Field value order is preserved."""
+        return [getattr(row, field_name) for row in self]
+
+    def get_field_values_unique(self, field_name: str) -> List[Any]:
+        """Get list of all unique values from models for the named field.
+        Field value order is not preserved."""
+        return list(set(self.get_field_values(field_name)))
+
+    def get_field_duplicates(self, field_name: str) -> List[CSVFieldDuplicate]:
+        """Get list of values of the named field that appear in more than one row."""
+        check_dict: Dict[Any, List] = collections.defaultdict(list)
+        all_field_values = self.get_field_values(field_name)
+
+        for row_index, value in enumerate(all_field_values):
+            check_dict[value].append(row_index)
+
+        result: List[CSVFieldDuplicate] = [
+            CSVFieldDuplicate(
+                value=value,
+                duplicate_rows=found_in_rows,
+            )
+            for value, found_in_rows in check_dict.items()
+            if len(found_in_rows) > 1
+        ]
+
+        return result
+
+    def dict_list(self) -> List[Dict]:
+        """Get list of all rows converted to dict."""
+        return [row.dict() for row in self]
+
+    def json(self) -> str:
+        """Get JSON representation of all rows."""
+        return json.dumps(self.dict_list())
+
+
+class CSVLoaderResult(Generic[CSVLoaderModelType]):
+    """Generic CSVLoader result. Contains parsed pydantic models, aggregated errors and the header content."""
+
+    def __init__(self) -> None:
+        self.rows: CSVRows[CSVLoaderModelType] = CSVRows()
+        self.errors: List[CSVValidationError] = []
+        self.header: List[str] = []
+
+    def has_errors(self) -> bool:
+        return len(self.errors) > 0
+
+
+class CSVLoader(Generic[CSVLoaderModelType]):
+    """
+    Generic CSV file parser.
+    Uses a standard csv reader to fetch CSV rows, validates each row against the
+    provided pydantic model and returns the list of created models together with
+    an aggregated error list.
+
+    Example:
+
+        with open("data.csv") as csv_file:
+            reader = csv.reader(csv_file, delimiter=",")
+
+            csv_loader = CSVLoader[MyRowModel](
+                reader=reader,
+                output_model_cls=MyRowModel,
+                has_header=True,
+                aggregate_errors=True,
+            )
+
+            result = csv_loader.read_rows()
+            if result.has_errors():
+                print("Errors:")
+                for error in result.errors:
+                    print(error)
+
+            print("Created models:")
+            for row in result.rows:
+                print(row.index, row.organization_id)
+
+    See tests/adapters/tools/test_csv_loader.py for more examples.
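+
+    A custom mapping strategy can also be supplied (the default maps row values
+    to model fields by declaration order); for example, to map by header names
+    instead:
+
+        mapping_strategy = MappingStrategyByHeader(model_cls=MyRowModel)
+        csv_loader = CSVLoader[MyRowModel](
+            reader=reader,
+            output_model_cls=MyRowModel,
+            has_header=True,
+            mapping_strategy=mapping_strategy,
+        )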
+ """ + + def __init__( + self, + reader: CSVReaderType, + output_model_cls: Type[CSVLoaderModelType], + has_header: Optional[bool] = True, + aggregate_errors: Optional[bool] = False, + mapping_strategy: Optional[MappingStrategyBase] = None, + ) -> None: + self.reader = reader + self.output_model_cls = output_model_cls + self.has_header = has_header + self.aggregate_errors = aggregate_errors + + if mapping_strategy: + self.mapping_strategy = mapping_strategy + else: + self.mapping_strategy = MappingStrategyByModelFieldOrder( + model_cls=self.output_model_cls, + ) + + self.mapping_strategy.validate_csv_loader_configuration(csv_loader=self) + + def read_rows(self) -> CSVLoaderResult[CSVLoaderModelType]: + result = CSVLoaderResult[CSVLoaderModelType]() + + for line_number, row in enumerate(self.reader): + # skip header, if configured and first line + if self.has_header and line_number == 0: + # strip header field names + header = [field.strip() for field in row] + result.header = header + self.mapping_strategy.set_header(header) + continue + + # skip empty lines + if not row: + continue + + row_model = None + try: + # create model kwargs params using mapping strategy + model_create_kwargs = self.mapping_strategy.create_model_param_dict( + row_values=row, + ) + # create output model from row data + row_model = self.output_model_cls(**model_create_kwargs) + except (MappingStrategyError, ValidationError) as ex: + # create extended error object + error = CSVValidationError( + line_number=line_number, + original_error=ex, + ) + if self.aggregate_errors: + # if we're aggregating errors, just add exception to the list + result.errors.append(error) + else: + # else just raise error and stop reading rows + raise error + + # row_model will be None if creation fails and error aggregation is active + if row_model is not None: + result.rows.append(row_model) + + return result + + def row_index_to_line_number(self, row_index: int) -> int: + return row_index if not self.has_header else row_index + 1 diff --git a/lib/csv_loader/errors.py b/lib/csv_loader/errors.py new file mode 100644 index 0000000..ccb28ba --- /dev/null +++ b/lib/csv_loader/errors.py @@ -0,0 +1,21 @@ +class CSVValidationError(Exception): + """Extended validation exception class containing additional attributes.""" + + def __init__(self, line_number: int, original_error: Exception) -> None: + self.line_number = line_number + self.original_error = original_error + + def __str__(self) -> str: + return f"Error at line {self.line_number}: {self.original_error}" + + +class MappingStrategyError(Exception): + ... + + +class HeaderNotSetError(MappingStrategyError): + detail = "Header must be set in order to use MappingStrategyByHeader" + + +class IndexOutOfHeaderBounds(MappingStrategyError): + detail = "Row value index out of header bounds" diff --git a/lib/csv_loader/mapping_strategies.py b/lib/csv_loader/mapping_strategies.py new file mode 100644 index 0000000..d477cd4 --- /dev/null +++ b/lib/csv_loader/mapping_strategies.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type, cast + +from pydantic import BaseModel + +from .errors import HeaderNotSetError, IndexOutOfHeaderBounds + + +class MappingStrategyBase(ABC): + """ + Mapping strategy implements mechanism of creating params (kwargs) dict from + row values which is later used in model creation. 
+ """ + + def __init__(self, model_cls: Type[BaseModel]) -> None: + self.model_cls = model_cls + self.header: Optional[List[str]] = None + + def set_header(self, header: List[str]) -> None: + self.header = header + + @abstractmethod + def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]: + """Create initial model params dict.""" + + @classmethod + def validate_csv_loader_configuration( + cls: Type[MappingStrategyBase], csv_loader: object + ) -> bool: + return True + + +class MappingStrategyByModelFieldOrder(MappingStrategyBase): + """ + Implements 1:1 field assignment. Each row value is assigned to model attribute + in order in which is defined in model. + """ + + def __init__(self, model_cls: Type[BaseModel]) -> None: + super().__init__(model_cls) + self.field_names = self.model_cls.__fields__.keys() + + def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]: + # map model field names as dict keys + return dict(zip(self.field_names, row_values)) + + +@dataclass +class HeaderRemapField: + header_field: str + model_attr: str + + +class MappingStrategyByHeader(MappingStrategyBase): + """Implements by-header assignment. Header must be present.""" + + def __init__( + self, + model_cls: Type[BaseModel], + header_remap_fields: Optional[List[HeaderRemapField]] = None, + ) -> None: + super().__init__(model_cls) + self.header: List[str] = [] + self.header_remap = header_remap_fields + + @classmethod + def validate_csv_loader_configuration( + cls: Type[MappingStrategyByHeader], csv_loader: object + ) -> bool: + # avoid circular imports and keep mypy happy + from .csv_loader import CSVLoader + + csv_loader = cast(CSVLoader, csv_loader) + + if not csv_loader.has_header: + raise HeaderNotSetError() + + return True + + @staticmethod + def _remap_header_mapping( + header_mapping: Dict[str, Any], header_remap: Optional[List[HeaderRemapField]] + ) -> Dict[str, Any]: + if not header_remap: + return header_mapping + + header_mapping = header_mapping.copy() + + for remap_field in header_remap: + if remap_field.header_field in header_mapping: + header_mapping[remap_field.model_attr] = header_mapping.pop( + remap_field.header_field + ) + return header_mapping + + def create_model_param_dict(self, row_values: List[Any]) -> Dict[str, Any]: + # header not set? stop! hammer time! 
+        if not self.header:
+            raise HeaderNotSetError()
+
+        # more row values than header fields, can't map
+        if len(row_values) > len(self.header):
+            raise IndexOutOfHeaderBounds()
+
+        # map header values as dict keys
+        header_mapping = dict(zip(self.header, row_values))
+
+        header_mapping = self._remap_header_mapping(
+            header_mapping=header_mapping, header_remap=self.header_remap
+        )
+        return header_mapping
diff --git a/loader_1_simple.py b/loader_1_simple.py
new file mode 100644
index 0000000..64de371
--- /dev/null
+++ b/loader_1_simple.py
@@ -0,0 +1,48 @@
+import csv
+from enum import IntEnum
+from lib.csv_loader import CSVLoader, CSVRow
+from lib.csv_loader.csv_loader import BoolValuePair, CSVRowDefaultConfig
+
+
+class ActionEnum(IntEnum):
+    INSERT = 0
+    UPDATE = 1
+    DELETE = 2
+
+
+class AssetRow(CSVRow):
+    asset_id: str
+    asset_action: ActionEnum
+    serial_number: str
+    brand_id: int
+    model_id: int
+    is_active: bool
+
+    class Config(CSVRowDefaultConfig):
+        bool_value_pair = BoolValuePair(true="YES", false="NO")
+
+
+def main():
+    with open("csv/assets.csv", "r") as f:
+        reader = csv.reader(f, delimiter=";")
+        csv_loader = CSVLoader[AssetRow](
+            reader=reader,
+            output_model_cls=AssetRow,
+            has_header=True,
+            aggregate_errors=True,
+        )
+
+        result = csv_loader.read_rows()
+
+        print("Results:")
+        for row in result.rows:
+            print(row)
+
+        if result.has_errors():
+            print("Errors:")
+            for error in result.errors:
+                print(f"Line: {error.line_number}: {error.original_error}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/loader_2_header.py b/loader_2_header.py
new file mode 100644
index 0000000..46e61f9
--- /dev/null
+++ b/loader_2_header.py
@@ -0,0 +1,53 @@
+import csv
+from enum import IntEnum
+from lib.csv_loader import CSVLoader, CSVRow
+from lib.csv_loader.csv_loader import BoolValuePair, CSVRowDefaultConfig
+from lib.csv_loader.mapping_strategies import MappingStrategyByHeader
+
+
+class ActionEnum(IntEnum):
+    INSERT = 0
+    UPDATE = 1
+    DELETE = 2
+
+
+class AssetRow(CSVRow):
+    asset_action: ActionEnum
+    asset_id: str
+    brand_id: int
+    model_id: int
+    serial_number: str
+    is_active: bool
+
+    class Config(CSVRowDefaultConfig):
+        bool_value_pair = BoolValuePair(true="YES", false="NO")
+
+
+def main():
+    with open("csv/assets.csv", "r") as f:
+        reader = csv.reader(f, delimiter=";")
+
+        mapping_strategy = MappingStrategyByHeader(model_cls=AssetRow)
+
+        csv_loader = CSVLoader[AssetRow](
+            reader=reader,
+            output_model_cls=AssetRow,
+            has_header=True,
+            aggregate_errors=True,
+            mapping_strategy=mapping_strategy,
+        )
+
+        result = csv_loader.read_rows()
+
+        print("Results:")
+        for row in result.rows:
+            print(row)
+
+        if result.has_errors():
+            print("Errors:")
+            for error in result.errors:
+                print(f"Line: {error.line_number}: {error.original_error}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/loader_3_header_mapper.py b/loader_3_header_mapper.py
new file mode 100644
index 0000000..2fbcc28
--- /dev/null
+++ b/loader_3_header_mapper.py
@@ -0,0 +1,65 @@
+import csv
+from enum import IntEnum
+from lib.csv_loader import CSVLoader, CSVRow
+from lib.csv_loader.csv_loader import BoolValuePair, CSVRowDefaultConfig
+from lib.csv_loader.mapping_strategies import HeaderRemapField, MappingStrategyByHeader
+
+
+class ActionEnum(IntEnum):
+    INSERT = 0
+    UPDATE = 1
+    DELETE = 2
+
+
+class AssetRow(CSVRow):
+    asset_action: ActionEnum
+    asset_id: str
+    brand_id: int
+    model_id: int
+    serial_number: str
+    is_active: bool
+
+    class Config(CSVRowDefaultConfig):
+        bool_value_pair = BoolValuePair(true="YES", false="NO")
+
+
+def main():
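+    # The remap CSV uses human-readable headers ("Asset Id", "Action", ...),
+    # so each header field is remapped to its model attribute explicitly.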
+ with open("csv/assets-header-remap.csv", "r") as f: + reader = csv.reader(f, delimiter=";") + + mapping_strategy = MappingStrategyByHeader( + model_cls=AssetRow, + header_remap_fields=[ + HeaderRemapField(header_field="Asset Id", model_attr="asset_id"), + HeaderRemapField(header_field="Action", model_attr="asset_action"), + HeaderRemapField( + header_field="Serial Number", model_attr="serial_number" + ), + HeaderRemapField(header_field="Brand Id", model_attr="brand_id"), + HeaderRemapField(header_field="Model Id", model_attr="model_id"), + HeaderRemapField(header_field="Is Active", model_attr="is_active"), + ], + ) + + csv_loader = CSVLoader[AssetRow]( + reader=reader, + output_model_cls=AssetRow, + has_header=True, + aggregate_errors=True, + mapping_strategy=mapping_strategy, + ) + + result = csv_loader.read_rows() + + "Results:" + for row in result.rows: + print(row) + + if result.has_errors(): + print("Errors:") + for error in result.errors: + print(f"Line: {error.line_number}: {error.original_error}") + + +if __name__ == "__main__": + main() diff --git a/loader_3_remote_reader.py b/loader_3_remote_reader.py new file mode 100644 index 0000000..9299ce7 --- /dev/null +++ b/loader_3_remote_reader.py @@ -0,0 +1,9 @@ +from lib.csv_loader import CSVLoader + + +def main(): + ... + + +if __name__ == "__main__": + main()