From 8f4ea5b6486a208b2851288bf1efca70d080881f Mon Sep 17 00:00:00 2001 From: Eden Kirin Date: Wed, 19 Jun 2024 19:02:40 +0200 Subject: [PATCH] Good operators --- app/main.go | 13 +- app/repository/smartfilter/filterfield.go | 99 ++++++++-- app/repository/smartfilter/operators.go | 22 +++ app/repository/smartfilter/smartfilter.go | 215 +++++++++++++++++++--- 4 files changed, 301 insertions(+), 48 deletions(-) create mode 100644 app/repository/smartfilter/operators.go diff --git a/app/main.go b/app/main.go index ad9e728..b9b4a31 100644 --- a/app/main.go +++ b/app/main.go @@ -9,7 +9,6 @@ import ( "repo-pattern/app/repository" "repo-pattern/app/repository/smartfilter" - "github.com/google/uuid" "gorm.io/gorm" ) @@ -22,11 +21,17 @@ func doMagic(db *gorm.DB) { var err error query := db - id, _ := uuid.FromBytes([]byte("6dc096ab-5c03-427e-b808-c669f7446131")) + // id, _ := uuid.FromBytes([]byte("6dc096ab-5c03-427e-b808-c669f7446131")) + // serialNumber := "222" + // serialNumberContains := "323" + issuer := "FINA" f := smartfilter.SmartCertFilter[models.Cert]{ - Alive: &TRUE, - Id: &id, + Alive: &FALSE, + // Id: &id, + // SerialNumber: &serialNumber, + // SerialNumberContains: &serialNumberContains, + IssuerContains: &issuer, } query, err = f.ToQuery(query) diff --git a/app/repository/smartfilter/filterfield.go b/app/repository/smartfilter/filterfield.go index 73b4878..bc58810 100644 --- a/app/repository/smartfilter/filterfield.go +++ b/app/repository/smartfilter/filterfield.go @@ -3,45 +3,94 @@ package smartfilter import ( "fmt" "reflect" + "strconv" "github.com/google/uuid" ) -type valueGetterFunc func(ff *FilterField, v reflect.Value) error - type FilterField struct { Name string Operator Operator - boolValue bool - intValue int64 - uintValue uint64 - floatValue float64 - strValue string + valueKind reflect.Kind + boolValue *bool + intValue *int64 + uintValue *uint64 + floatValue *float64 + strValue *string } +func (ff *FilterField) getValue(v reflect.Value) string { + fn := typeGetter(v.Type()) + fn(ff, v) + + switch ff.valueKind { + case reflect.Bool: + if *ff.boolValue { + return "TRUE" + } else { + return "FALSE" + } + case reflect.Int: + return strconv.FormatInt(*ff.intValue, 10) + case reflect.Uint: + return strconv.FormatUint(*ff.uintValue, 10) + case reflect.Float32: + return "some float 32" + case reflect.Float64: + return "some float 64" + case reflect.String: + return *ff.strValue + } + return "???" +} + +func (ff *FilterField) getValueWithOperator(v reflect.Value) string { + value := ff.getValue(v) + + switch ff.valueKind { + case reflect.Bool: + return fmt.Sprintf("IS %s", value) + case reflect.Int, reflect.Uint, reflect.Float32, reflect.Float64, reflect.String: + return fmt.Sprintf("= %s", value) + } + return "???" +} + +type valueGetterFunc func(ff *FilterField, v reflect.Value) error + func boolValueGetter(ff *FilterField, v reflect.Value) error { - ff.boolValue = v.Bool() + value := v.Bool() + ff.boolValue = &value + ff.valueKind = reflect.Bool return nil } func intValueGetter(ff *FilterField, v reflect.Value) error { - ff.intValue = v.Int() + value := v.Int() + ff.intValue = &value + ff.valueKind = reflect.Int return nil } func uintValueGetter(ff *FilterField, v reflect.Value) error { - ff.uintValue = v.Uint() + value := v.Uint() + ff.uintValue = &value + ff.valueKind = reflect.Uint return nil } func floatValueGetter(ff *FilterField, v reflect.Value) error { - ff.floatValue = v.Float() + value := v.Float() + ff.floatValue = &value + ff.valueKind = reflect.Float64 return nil } func strValueGetter(ff *FilterField, v reflect.Value) error { - ff.strValue = v.String() + value := v.String() + ff.strValue = &value + ff.valueKind = reflect.String return nil } @@ -50,7 +99,9 @@ func uuidValueGetter(ff *FilterField, v reflect.Value) error { if err != nil { return err } - ff.strValue = uid.String() + value := uid.String() + ff.strValue = &value + ff.valueKind = reflect.String return nil } @@ -58,6 +109,10 @@ func unsupportedValueGetter(ff *FilterField, v reflect.Value) error { return fmt.Errorf("unsupported type: %v", v.Type()) } +func typeGetter(t reflect.Type) valueGetterFunc { + return newTypeGetter(t, true) +} + func newTypeGetter(t reflect.Type, allowAddr bool) valueGetterFunc { // If we have a non-pointer value whose type implements // Marshaler with a value receiver, then we're better off taking @@ -97,9 +152,23 @@ func newTypeGetter(t reflect.Type, allowAddr bool) valueGetterFunc { // return newSliceEncoder(t) // case reflect.Array: // return newArrayEncoder(t) - // case reflect.Pointer: - // return newPtrEncoder(t) + case reflect.Pointer: + return newPtrValueGetter(t) default: return unsupportedValueGetter } } + +type ptrValueGetter struct { + elemEnc valueGetterFunc +} + +func (pvg ptrValueGetter) getValue(ff *FilterField, v reflect.Value) error { + pvg.elemEnc(ff, v.Elem()) + return nil +} + +func newPtrValueGetter(t reflect.Type) valueGetterFunc { + enc := ptrValueGetter{elemEnc: typeGetter(t.Elem())} + return enc.getValue +} diff --git a/app/repository/smartfilter/operators.go b/app/repository/smartfilter/operators.go new file mode 100644 index 0000000..24e0c61 --- /dev/null +++ b/app/repository/smartfilter/operators.go @@ -0,0 +1,22 @@ +package smartfilter + +type Operator string + +const ( + OperatorEQ Operator = "EQ" + OperatorNE Operator = "NE" + OperatorGT Operator = "GT" + OperatorGE Operator = "GE" + OperatorLT Operator = "LT" + OperatorLE Operator = "LE" + OperatorLIKE Operator = "LIKE" + OperatorILIKE Operator = "ILIKE" + OperatorIN Operator = "IN" +) + +var OPERATORS = []Operator{ + OperatorEQ, OperatorNE, + OperatorGT, OperatorGE, OperatorLT, OperatorLE, + OperatorLIKE, OperatorILIKE, + OperatorIN, +} diff --git a/app/repository/smartfilter/smartfilter.go b/app/repository/smartfilter/smartfilter.go index c076bfb..94049be 100644 --- a/app/repository/smartfilter/smartfilter.go +++ b/app/repository/smartfilter/smartfilter.go @@ -14,21 +14,28 @@ import ( const TAG_NAME = "filterfield" const TAG_VALUE_SEPARATOR = "," -type Operator string +type handlerFunc func(query *gorm.DB, tableName string, filterField *FilterField) (*gorm.DB, error) -const ( - OperatorEQ Operator = "EQ" - OperatorIN Operator = "IN" -) - -var OPERATORS = []Operator{OperatorEQ, OperatorIN} +var operatorHandlers = map[Operator]handlerFunc{ + OperatorEQ: handleOperatorEQ, + OperatorNE: handleOperatorNE, + OperatorGT: handleOperatorGT, + OperatorGE: handleOperatorGE, + OperatorLT: handleOperatorLT, + OperatorLE: handleOperatorLE, + OperatorLIKE: handleOperatorLIKE, + OperatorILIKE: handleOperatorILIKE, +} type SmartCertFilter[T schema.Tabler] struct { - Model T - Alive *bool `filterfield:"alive,EQ"` - Id *uuid.UUID `filterfield:"id,EQ"` - Ids *[]uuid.UUID `filterfield:"id,IN"` - CompanyId *uuid.UUID `filterfield:"company_id,EQ"` + Model T + Alive *bool `filterfield:"alive,EQ"` + SerialNumber *string `filterfield:"serial_number,NE"` + SerialNumberContains *string `filterfield:"serial_number,LIKE"` + IssuerContains *string `filterfield:"issuer,ILIKE"` + Id *uuid.UUID `filterfield:"id,EQ"` + Ids *[]uuid.UUID `filterfield:"id,IN"` + CompanyId *uuid.UUID `filterfield:"company_id,EQ"` } func (f SmartCertFilter[T]) ToQuery(query *gorm.DB) (*gorm.DB, error) { @@ -60,36 +67,185 @@ func (f SmartCertFilter[T]) ToQuery(query *gorm.DB) (*gorm.DB, error) { t := fieldReflect.Type() fmt.Printf(">>> %+v --- %+v\n", field, t) - filterField, err := getFilterField(tagValue) + filterField, err := newFilterField(tagValue) if err != nil { return nil, fmt.Errorf("%s.%s: %s", modelName, field.Name, err) } - // fmt.Printf( - // "tagValue: %s, Name: %s, Operator: %s\n", - // tagValue, - // filterField.Name, - // filterField.Operator, - // ) + strValue := filterField.getValue(fieldReflect) + fmt.Printf(">>> filterField: %+v ==== %s\n", filterField, strValue) - switch filterField.Operator { - case OperatorEQ: - query = applyFilterEQ[string](query, tableName, filterField) + operatorHandler, ok := operatorHandlers[filterField.Operator] + if !ok { + return nil, fmt.Errorf("no handler for operator %s", filterField.Operator) + } + + query, err = operatorHandler(query, tableName, filterField) + if err != nil { + return nil, err } } - // query = query.Where("certificates.alive=?", true) - return query, nil } -func applyFilterEQ[T int | bool | string](query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { - // query = query.Where(fmt.Sprint("%s.%s = ?", tableName, filterField.Name), ) - - return query +func handleOperatorEQ(query *gorm.DB, tableName string, filterField *FilterField) (*gorm.DB, error) { + switch filterField.valueKind { + case reflect.Bool: + query = applyFilterEQ(query, tableName, filterField, *filterField.boolValue) + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + query = applyFilterEQ(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + query = applyFilterEQ(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + query = applyFilterEQ(query, tableName, filterField, *filterField.floatValue) + case reflect.String: + query = applyFilterEQ(query, tableName, filterField, *filterField.strValue) + default: + return nil, fmt.Errorf("invalid field type for operator %s", filterField.Operator) + } + return query, nil } -func getFilterField(tagValue string) (*FilterField, error) { +func applyFilterEQ[T bool | int64 | uint64 | float64 | string]( + query *gorm.DB, tableName string, filterField *FilterField, value T, +) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s = ?", tableName, filterField.Name), value) +} + +func handleOperatorNE(query *gorm.DB, tableName string, filterField *FilterField) (*gorm.DB, error) { + switch filterField.valueKind { + case reflect.Bool: + query = applyFilterNE(query, tableName, filterField, *filterField.boolValue) + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + query = applyFilterNE(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + query = applyFilterNE(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + query = applyFilterNE(query, tableName, filterField, *filterField.floatValue) + case reflect.String: + query = applyFilterNE(query, tableName, filterField, *filterField.strValue) + default: + return nil, fmt.Errorf("invalid field type for operator %s", filterField.Operator) + } + return query, nil +} + +func applyFilterNE[T bool | int64 | uint64 | float64 | string]( + query *gorm.DB, tableName string, filterField *FilterField, value T, +) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s != ?", tableName, filterField.Name), value) +} + +func handleOperatorLIKE(query *gorm.DB, tableName string, filterField *FilterField) (*gorm.DB, error) { + switch filterField.valueKind { + case reflect.String: + query = applyFilterLIKE(query, tableName, filterField, *filterField.strValue) + default: + return nil, fmt.Errorf("invalid field type for operator %s", filterField.Operator) + } + return query, nil +} + +func applyFilterLIKE(query *gorm.DB, tableName string, filterField *FilterField, value string) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s LIKE ?", tableName, filterField.Name), fmt.Sprintf("%%%s%%", value)) +} + +func handleOperatorILIKE(query *gorm.DB, tableName string, filterField *FilterField) (*gorm.DB, error) { + switch filterField.valueKind { + case reflect.String: + query = applyFilterILIKE(query, tableName, filterField, *filterField.strValue) + default: + return nil, fmt.Errorf("invalid field type for operator %s", filterField.Operator) + } + return query, nil +} + +func applyFilterILIKE(query *gorm.DB, tableName string, filterField *FilterField, value string) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s ILIKE ?", tableName, filterField.Name), fmt.Sprintf("%%%s%%", value)) +} + +func handleOperatorGT(query *gorm.DB, tableName string, filterField *FilterField) (*gorm.DB, error) { + switch filterField.valueKind { + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + query = applyFilterGT(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + query = applyFilterGT(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + query = applyFilterGT(query, tableName, filterField, *filterField.floatValue) + default: + return nil, fmt.Errorf("invalid field type for operator %s", filterField.Operator) + } + return query, nil +} + +func applyFilterGT[T bool | int64 | uint64 | float64 | string]( + query *gorm.DB, tableName string, filterField *FilterField, value T, +) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s > ?", tableName, filterField.Name), value) +} + +func handleOperatorGE(query *gorm.DB, tableName string, filterField *FilterField) (*gorm.DB, error) { + switch filterField.valueKind { + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + query = applyFilterGE(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + query = applyFilterGE(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + query = applyFilterGE(query, tableName, filterField, *filterField.floatValue) + default: + return nil, fmt.Errorf("invalid field type for operator %s", filterField.Operator) + } + return query, nil +} + +func applyFilterGE[T bool | int64 | uint64 | float64 | string]( + query *gorm.DB, tableName string, filterField *FilterField, value T, +) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s >= ?", tableName, filterField.Name), value) +} + +func handleOperatorLT(query *gorm.DB, tableName string, filterField *FilterField) (*gorm.DB, error) { + switch filterField.valueKind { + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + query = applyFilterLT(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + query = applyFilterLT(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + query = applyFilterLT(query, tableName, filterField, *filterField.floatValue) + default: + return nil, fmt.Errorf("invalid field type for operator %s", filterField.Operator) + } + return query, nil +} + +func applyFilterLT[T bool | int64 | uint64 | float64 | string]( + query *gorm.DB, tableName string, filterField *FilterField, value T, +) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s < ?", tableName, filterField.Name), value) +} + +func handleOperatorLE(query *gorm.DB, tableName string, filterField *FilterField) (*gorm.DB, error) { + switch filterField.valueKind { + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + query = applyFilterLE(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + query = applyFilterLE(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + query = applyFilterLE(query, tableName, filterField, *filterField.floatValue) + default: + return nil, fmt.Errorf("invalid field type for operator %s", filterField.Operator) + } + return query, nil +} + +func applyFilterLE[T bool | int64 | uint64 | float64 | string]( + query *gorm.DB, tableName string, filterField *FilterField, value T, +) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s <= ?", tableName, filterField.Name), value) +} + +func newFilterField(tagValue string) (*FilterField, error) { values := strings.Split(tagValue, TAG_VALUE_SEPARATOR) if len(values) != 2 { return nil, fmt.Errorf("incorrect number of tag values: %s", tagValue) @@ -101,6 +257,7 @@ func getFilterField(tagValue string) (*FilterField, error) { } f := FilterField{ + Name: values[0], Operator: operator, } return &f, nil