From 14bc29d7e39e3c9d6806e232f12434c2a8679179 Mon Sep 17 00:00:00 2001 From: Eden Kirin Date: Sat, 22 Jun 2024 13:34:27 +0200 Subject: [PATCH 1/4] Working version --- app/inheritance/inheritance.go | 74 ++++++++++ app/inheritance/methods.go | 43 ++++++ app/main.go | 76 +++++++++-- app/repository/methods.go | 84 ++++++++++++ app/repository/repository.go | 29 ++++ app/repository/smartfilter/filterfield.go | 50 ++++--- app/repository/smartfilter/filters.go | 6 + app/repository/smartfilter/handlers.go | 125 +++++++++++++++++ app/repository/smartfilter/smartfilter.go | 156 +++++----------------- 9 files changed, 496 insertions(+), 147 deletions(-) create mode 100644 app/inheritance/inheritance.go create mode 100644 app/inheritance/methods.go create mode 100644 app/repository/methods.go create mode 100644 app/repository/repository.go create mode 100644 app/repository/smartfilter/handlers.go diff --git a/app/inheritance/inheritance.go b/app/inheritance/inheritance.go new file mode 100644 index 0000000..2f0ca61 --- /dev/null +++ b/app/inheritance/inheritance.go @@ -0,0 +1,74 @@ +package inheritance + +import "fmt" + +type Model struct{} + +type MyModel struct { + Model +} + +type MethodInitInterface interface { + Init(dbConn int) +} + +type RepoBase[T interface{}] struct { + DbConn int + GetMethod[T] + ListMethod[T] + methods []MethodInitInterface +} + +func (b *RepoBase[T]) InitMethods(dbConn int) { + for _, method := range b.methods { + method.Init(dbConn) + } +} + +type CRUDRepo[T interface{}] struct { + RepoBase[T] + SaveMethod[T] +} + +func (m *CRUDRepo[T]) Init(dbConn int) { + m.methods = []MethodInitInterface{&m.GetMethod, &m.ListMethod, &m.SaveMethod} + m.InitMethods(dbConn) +} + +func DoInheritanceTest() { + repo := RepoBase[MyModel]{ + DbConn: 111, + // GetMethod: GetMethod{ + // DbConn: 666, + // }, + // ListMethod: ListMethod{ + // DbConn: 777, + // }, + } + repo.GetMethod.Init(888) + repo.ListMethod.Init(888) + + repo.GetMethod.Get() + repo.List() + + fmt.Printf("outside Base: %d\n", repo.DbConn) + fmt.Printf("outside GetMethod: %d\n", repo.GetMethod.DbConn) + fmt.Printf("outside ListMethod: %d\n", repo.ListMethod.DbConn) + + fmt.Println("----------------") + + crudRepo := CRUDRepo[MyModel]{} + crudRepo.Init(999) + + crudRepo.Get() + crudRepo.List() + crudRepo.Save() + + fmt.Printf("outside GetMethod: %d\n", crudRepo.GetMethod.DbConn) + fmt.Printf("outside ListMethod: %d\n", crudRepo.ListMethod.DbConn) + fmt.Printf("outside SaveMethod: %d\n", crudRepo.SaveMethod.DbConn) + + // repo.DbConn = 123 + // repo.SomeGetVar = 456 + // repo.DoSomething() +} diff --git a/app/inheritance/methods.go b/app/inheritance/methods.go new file mode 100644 index 0000000..3a46116 --- /dev/null +++ b/app/inheritance/methods.go @@ -0,0 +1,43 @@ +package inheritance + +import "fmt" + +type GetMethod[T interface{}] struct { + SomeGetVar int + DbConn int +} + +func (m *GetMethod[T]) Init(dbConn int) { + m.DbConn = dbConn +} + +func (m GetMethod[T]) Get() T { + var model T + fmt.Printf("Get DbConn: %d\n", m.DbConn) + return model +} + +type ListMethod[T interface{}] struct { + SomeListVar int + DbConn int +} + +func (m *ListMethod[T]) Init(dbConn int) { + m.DbConn = dbConn +} + +func (m ListMethod[T]) List() { + fmt.Printf("List DbConn: %d\n", m.DbConn) +} + +type SaveMethod[T interface{}] struct { + DbConn int +} + +func (m *SaveMethod[T]) Init(dbConn int) { + m.DbConn = dbConn +} + +func (m SaveMethod[T]) Save() { + fmt.Printf("List DbConn: %d\n", m.DbConn) +} diff --git a/app/main.go b/app/main.go index 1693148..e0c86b5 100644 --- a/app/main.go +++ b/app/main.go @@ -25,20 +25,22 @@ func doMagic(db *gorm.DB) { // id := "6dc096ab-5c03-427e-b808-c669f7446131" // serialNumber := "222" // serialNumberContains := "323" - // issuer := "FINA" + issuer := "FINA" location, _ := time.LoadLocation("UTC") createdTime := time.Date(2024, 5, 26, 16, 8, 0, 0, location) + ids := []string{"eb2bcac6-5173-4dbb-93b7-e7c03b924a03", "db9fb837-3483-4736-819d-f427dc8cda23", "1fece5e7-8e8d-4828-8298-3b1f07fd29ff"} - f := smartfilter.SmartCertFilter[models.Cert]{ - // Alive: &FALSE, - // Id: &id, + filter := smartfilter.CertFilter{ + Alive: &FALSE, + // Id: &id, // SerialNumber: &serialNumber, // SerialNumberContains: &serialNumberContains, - // IssuerContains: &issuer, - CreatedAt_Lt: &createdTime, + Ids: &ids, + IssuerContains: &issuer, + CreatedAt_Lt: &createdTime, } - query, err = f.ToQuery(query) + query, err = smartfilter.ToQuery(models.Cert{}, filter, query) if err != nil { panic(err) } @@ -51,6 +53,58 @@ func doMagic(db *gorm.DB) { } } +func doList(db *gorm.DB) { + repo := repository.RepoBase[models.Cert]{} + repo.Init(db) + + filter := smartfilter.CertFilter{ + Alive: &TRUE, + } + + certs, err := repo.List(filter) + if err != nil { + panic(err) + } + + for n, cert := range *certs { + fmt.Printf(">> [%d] %+v %s (alive %t)\n", n, cert.Id, cert.CreatedAt, cert.Alive) + } +} + +func doGet(db *gorm.DB) { + repo := repository.RepoBase[models.Cert]{} + repo.Init(db) + + id := "db9fb837-3483-4736-819d-f427dc8cda23" + + filter := smartfilter.CertFilter{ + Id: &id, + } + + cert, err := repo.Get(filter) + if err != nil { + panic(err) + } + fmt.Printf(">> %+v %s (alive %t)\n", cert.Id, cert.CreatedAt, cert.Alive) +} + +func doExists(db *gorm.DB) { + repo := repository.RepoBase[models.Cert]{} + repo.Init(db) + + id := "db9fb837-3483-4736-819d-f427dc8cda23" + + filter := smartfilter.CertFilter{ + Id: &id, + } + + exists, err := repo.Exists(filter) + if err != nil { + panic(err) + } + fmt.Printf(">> EXISTS: %t\n", exists) +} + func main() { cfg.Init() logging.Init() @@ -59,7 +113,9 @@ func main() { db := db.InitDB() repository.Dao = repository.CreateDAO(db) - doMagic(db) - - fmt.Println("Running...") + // doMagic(db) + // doList(db) + // doGet(db) + doExists(db) + // inheritance.DoInheritanceTest() } diff --git a/app/repository/methods.go b/app/repository/methods.go new file mode 100644 index 0000000..bf1c68a --- /dev/null +++ b/app/repository/methods.go @@ -0,0 +1,84 @@ +package repository + +import ( + "errors" + "repo-pattern/app/repository/smartfilter" + + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +type ListMethod[T schema.Tabler] struct { + DbConn *gorm.DB +} + +func (m *ListMethod[T]) Init(dbConn *gorm.DB) { + m.DbConn = dbConn +} + +func (m ListMethod[T]) List(filter interface{}) (*[]T, error) { + var ( + model T + models []T + ) + + query, err := smartfilter.ToQuery(model, filter, m.DbConn) + if err != nil { + return nil, err + } + + query.Find(&models) + return &models, nil +} + +type GetMethod[T schema.Tabler] struct { + DbConn *gorm.DB +} + +func (m *GetMethod[T]) Init(dbConn *gorm.DB) { + m.DbConn = dbConn +} + +func (m GetMethod[T]) Get(filter interface{}) (*T, error) { + var ( + model T + ) + + query, err := smartfilter.ToQuery(model, filter, m.DbConn) + if err != nil { + return nil, err + } + + result := query.First(&model) + if result.Error != nil { + return nil, result.Error + } + + return &model, nil +} + +type ExistsMethod[T schema.Tabler] struct { + DbConn *gorm.DB +} + +func (m *ExistsMethod[T]) Init(dbConn *gorm.DB) { + m.DbConn = dbConn +} + +func (m ExistsMethod[T]) Exists(filter interface{}) (bool, error) { + var ( + model T + ) + + query := m.DbConn.Model(model) + + query, err := smartfilter.ToQuery(model, filter, query) + if err != nil { + return false, err + } + + result := query.Select("*").First(&model) + + exists := !errors.Is(result.Error, gorm.ErrRecordNotFound) && result.Error == nil + return exists, nil +} diff --git a/app/repository/repository.go b/app/repository/repository.go new file mode 100644 index 0000000..16a30bd --- /dev/null +++ b/app/repository/repository.go @@ -0,0 +1,29 @@ +package repository + +import ( + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +type MethodInitInterface interface { + Init(dbConn *gorm.DB) +} + +type RepoBase[T schema.Tabler] struct { + DbConn int + ListMethod[T] + GetMethod[T] + ExistsMethod[T] + methods []MethodInitInterface +} + +func (b *RepoBase[T]) InitMethods(dbConn *gorm.DB) { + for _, method := range b.methods { + method.Init(dbConn) + } +} + +func (m *RepoBase[T]) Init(dbConn *gorm.DB) { + m.methods = []MethodInitInterface{&m.ListMethod, &m.GetMethod, &m.ExistsMethod} + m.InitMethods(dbConn) +} diff --git a/app/repository/smartfilter/filterfield.go b/app/repository/smartfilter/filterfield.go index f82874d..96455c0 100644 --- a/app/repository/smartfilter/filterfield.go +++ b/app/repository/smartfilter/filterfield.go @@ -16,7 +16,6 @@ type FilterField struct { uintValue *uint64 floatValue *float64 strValue *string - timeValue *time.Time } func (ff *FilterField) setValueFromReflection(v reflect.Value) { @@ -24,18 +23,6 @@ func (ff *FilterField) setValueFromReflection(v reflect.Value) { fn(ff, v) } -// func (ff *FilterField) getValueWithOperator(v reflect.Value) string { -// value := ff.getValue(v) - -// switch ff.valueKind { -// case reflect.Bool: -// return fmt.Sprintf("IS %s", value) -// case reflect.Int, reflect.Uint, reflect.Float32, reflect.Float64, reflect.String: -// return fmt.Sprintf("= %s", value) -// } -// return "???" -// } - type valueGetterFunc func(ff *FilterField, v reflect.Value) error func boolValueGetter(ff *FilterField, v reflect.Value) error { @@ -132,10 +119,10 @@ func newTypeGetter(t reflect.Type, allowAddr bool) valueGetterFunc { } // case reflect.Map: // return newMapEncoder(t) - // case reflect.Slice: - // return newSliceEncoder(t) - // case reflect.Array: - // return newArrayEncoder(t) + case reflect.Slice: + return newSliceGetter(t) + case reflect.Array: + return newArrayGetter(t) case reflect.Pointer: return newPtrValueGetter(t) } @@ -155,3 +142,32 @@ func newPtrValueGetter(t reflect.Type) valueGetterFunc { enc := ptrValueGetter{elemEnc: typeGetter(t.Elem())} return enc.getValue } + +type arrayGetter struct { + elemEnc valueGetterFunc +} + +func (ag arrayGetter) getValue(ff *FilterField, v reflect.Value) error { + ag.elemEnc(ff, v.Elem()) + return nil +} + +func newArrayGetter(t reflect.Type) valueGetterFunc { + enc := arrayGetter{elemEnc: typeGetter(t.Elem())} + return enc.getValue +} + +type sliceGetter struct { + elemEnc valueGetterFunc +} + +func (sg sliceGetter) getValue(ff *FilterField, v reflect.Value) error { + fmt.Printf("%+v\n", v.Slice(0, 1)) + sg.elemEnc(ff, v.Elem()) + return nil +} + +func newSliceGetter(t reflect.Type) valueGetterFunc { + enc := sliceGetter{elemEnc: typeGetter(t.Elem())} + return enc.getValue +} diff --git a/app/repository/smartfilter/filters.go b/app/repository/smartfilter/filters.go index 39a6acd..7cc77e1 100644 --- a/app/repository/smartfilter/filters.go +++ b/app/repository/smartfilter/filters.go @@ -49,3 +49,9 @@ func applyFilterLE[T bool | int64 | uint64 | float64 | string]( ) *gorm.DB { return query.Where(fmt.Sprintf("%s.%s <= ?", tableName, filterField.Name), value) } + +func applyFilterIN[T bool | int64 | uint64 | float64 | string]( + query *gorm.DB, tableName string, filterField *FilterField, value T, +) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s IN ?", tableName, filterField.Name), value) +} diff --git a/app/repository/smartfilter/handlers.go b/app/repository/smartfilter/handlers.go new file mode 100644 index 0000000..68ce695 --- /dev/null +++ b/app/repository/smartfilter/handlers.go @@ -0,0 +1,125 @@ +package smartfilter + +import ( + "reflect" + + "gorm.io/gorm" +) + +func handleOperatorEQ(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.Bool: + return applyFilterEQ(query, tableName, filterField, *filterField.boolValue) + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return applyFilterEQ(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return applyFilterEQ(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + return applyFilterEQ(query, tableName, filterField, *filterField.floatValue) + case reflect.String: + return applyFilterEQ(query, tableName, filterField, *filterField.strValue) + } + return nil +} + +func handleOperatorNE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.Bool: + return applyFilterNE(query, tableName, filterField, *filterField.boolValue) + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return applyFilterNE(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return applyFilterNE(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + return applyFilterNE(query, tableName, filterField, *filterField.floatValue) + case reflect.String: + return applyFilterNE(query, tableName, filterField, *filterField.strValue) + } + return nil +} + +func handleOperatorLIKE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.String: + return applyFilterLIKE(query, tableName, filterField, *filterField.strValue) + } + return nil +} + +func handleOperatorILIKE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.String: + return applyFilterILIKE(query, tableName, filterField, *filterField.strValue) + } + return nil +} + +func handleOperatorGT(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return applyFilterGT(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return applyFilterGT(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + return applyFilterGT(query, tableName, filterField, *filterField.floatValue) + case reflect.String: + return applyFilterGT(query, tableName, filterField, *filterField.strValue) + } + return nil +} + +func handleOperatorGE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return applyFilterGE(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return applyFilterGE(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + return applyFilterGE(query, tableName, filterField, *filterField.floatValue) + case reflect.String: + return applyFilterGE(query, tableName, filterField, *filterField.strValue) + } + return nil +} + +func handleOperatorLT(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return applyFilterLT(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return applyFilterLT(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + return applyFilterLT(query, tableName, filterField, *filterField.floatValue) + case reflect.String: + return applyFilterLT(query, tableName, filterField, *filterField.strValue) + } + return nil +} + +func handleOperatorLE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return applyFilterLE(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return applyFilterLE(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + return applyFilterLE(query, tableName, filterField, *filterField.floatValue) + case reflect.String: + return applyFilterLE(query, tableName, filterField, *filterField.strValue) + } + return nil +} + +func handleOperatorIN(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return applyFilterIN(query, tableName, filterField, *filterField.intValue) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return applyFilterIN(query, tableName, filterField, *filterField.uintValue) + case reflect.Float32, reflect.Float64: + return applyFilterIN(query, tableName, filterField, *filterField.floatValue) + case reflect.String: + return applyFilterIN(query, tableName, filterField, *filterField.strValue) + } + return nil +} diff --git a/app/repository/smartfilter/smartfilter.go b/app/repository/smartfilter/smartfilter.go index cf87c7d..53e68cd 100644 --- a/app/repository/smartfilter/smartfilter.go +++ b/app/repository/smartfilter/smartfilter.go @@ -25,10 +25,10 @@ var operatorHandlers = map[Operator]handlerFunc{ OperatorLE: handleOperatorLE, OperatorLIKE: handleOperatorLIKE, OperatorILIKE: handleOperatorILIKE, + OperatorIN: handleOperatorIN, } -type SmartCertFilter[T schema.Tabler] struct { - Model T +type CertFilter struct { Alive *bool `filterfield:"alive,EQ"` SerialNumber *string `filterfield:"serial_number,NE"` SerialNumberContains *string `filterfield:"serial_number,LIKE"` @@ -38,15 +38,17 @@ type SmartCertFilter[T schema.Tabler] struct { CreatedAt_Lt *time.Time `filterfield:"created_at,LT"` } -func (f SmartCertFilter[T]) ToQuery(query *gorm.DB) (*gorm.DB, error) { - tableName := f.Model.TableName() +type ReflectedStructField struct { + name string + value reflect.Value + tagValue string +} - fmt.Printf("Table name: %s\n", tableName) - // fmt.Printf("%+v\n", f) +func getFilterFields(filter interface{}) []ReflectedStructField { + res := make([]ReflectedStructField, 0) - st := reflect.TypeOf(f) - modelName := st.Name() - reflectValue := reflect.ValueOf(f) + st := reflect.TypeOf(filter) + reflectValue := reflect.ValueOf(filter) for i := 0; i < st.NumField(); i++ { field := st.Field(i) @@ -57,23 +59,41 @@ func (f SmartCertFilter[T]) ToQuery(query *gorm.DB) (*gorm.DB, error) { continue } - fieldReflect := reflectValue.FieldByName(field.Name) + // get field value + fieldValue := reflectValue.FieldByName(field.Name) // skip field if value is nil - if fieldReflect.IsNil() { + if fieldValue.IsNil() { continue } - t := fieldReflect.Type() - fmt.Printf(">>> %+v --- %+v\n", field, t) + res = append(res, ReflectedStructField{ + name: field.Name, + tagValue: tagValue, + value: fieldValue, + }) + } + return res +} - filterField, err := newFilterField(tagValue) +func ToQuery(model schema.Tabler, filter interface{}, query *gorm.DB) (*gorm.DB, error) { + st := reflect.TypeOf(filter) + + tableName := model.TableName() + modelName := st.Name() + + fmt.Printf("Table name: %s\n", tableName) + fmt.Printf("Model name: %s\n", modelName) + + fields := getFilterFields(filter) + for _, field := range fields { + filterField, err := newFilterField(field.tagValue) if err != nil { - return nil, fmt.Errorf("%s.%s: %s", modelName, field.Name, err) + return nil, fmt.Errorf("%s.%s: %s", modelName, field.name, err) } // must be called! - filterField.setValueFromReflection(fieldReflect) + filterField.setValueFromReflection(field.value) operatorHandler, ok := operatorHandlers[filterField.Operator] if !ok { @@ -106,107 +126,3 @@ func newFilterField(tagValue string) (*FilterField, error) { } return &f, nil } - -func handleOperatorEQ(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { - switch filterField.valueKind { - case reflect.Bool: - return applyFilterEQ(query, tableName, filterField, *filterField.boolValue) - case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: - return applyFilterEQ(query, tableName, filterField, *filterField.intValue) - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return applyFilterEQ(query, tableName, filterField, *filterField.uintValue) - case reflect.Float32, reflect.Float64: - return applyFilterEQ(query, tableName, filterField, *filterField.floatValue) - case reflect.String: - return applyFilterEQ(query, tableName, filterField, *filterField.strValue) - } - return nil -} - -func handleOperatorNE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { - switch filterField.valueKind { - case reflect.Bool: - return applyFilterNE(query, tableName, filterField, *filterField.boolValue) - case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: - return applyFilterNE(query, tableName, filterField, *filterField.intValue) - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return applyFilterNE(query, tableName, filterField, *filterField.uintValue) - case reflect.Float32, reflect.Float64: - return applyFilterNE(query, tableName, filterField, *filterField.floatValue) - case reflect.String: - return applyFilterNE(query, tableName, filterField, *filterField.strValue) - } - return nil -} - -func handleOperatorLIKE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { - switch filterField.valueKind { - case reflect.String: - return applyFilterLIKE(query, tableName, filterField, *filterField.strValue) - } - return nil -} - -func handleOperatorILIKE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { - switch filterField.valueKind { - case reflect.String: - return applyFilterILIKE(query, tableName, filterField, *filterField.strValue) - } - return nil -} - -func handleOperatorGT(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { - switch filterField.valueKind { - case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: - return applyFilterGT(query, tableName, filterField, *filterField.intValue) - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return applyFilterGT(query, tableName, filterField, *filterField.uintValue) - case reflect.Float32, reflect.Float64: - return applyFilterGT(query, tableName, filterField, *filterField.floatValue) - case reflect.String: - return applyFilterGT(query, tableName, filterField, *filterField.strValue) - } - return nil -} - -func handleOperatorGE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { - switch filterField.valueKind { - case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: - return applyFilterGE(query, tableName, filterField, *filterField.intValue) - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return applyFilterGE(query, tableName, filterField, *filterField.uintValue) - case reflect.Float32, reflect.Float64: - return applyFilterGE(query, tableName, filterField, *filterField.floatValue) - case reflect.String: - return applyFilterGE(query, tableName, filterField, *filterField.strValue) - } - return nil -} - -func handleOperatorLT(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { - switch filterField.valueKind { - case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: - return applyFilterLT(query, tableName, filterField, *filterField.intValue) - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return applyFilterLT(query, tableName, filterField, *filterField.uintValue) - case reflect.Float32, reflect.Float64: - return applyFilterLT(query, tableName, filterField, *filterField.floatValue) - case reflect.String: - return applyFilterLT(query, tableName, filterField, *filterField.strValue) - } - return nil -} - -func handleOperatorLE(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { - switch filterField.valueKind { - case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: - return applyFilterLE(query, tableName, filterField, *filterField.intValue) - case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return applyFilterLE(query, tableName, filterField, *filterField.uintValue) - case reflect.Float32, reflect.Float64: - return applyFilterLE(query, tableName, filterField, *filterField.floatValue) - case reflect.String: - return applyFilterLE(query, tableName, filterField, *filterField.strValue) - } - return nil -} From b67bc14e5e437f3436114e7b7fe92f405e7c12e0 Mon Sep 17 00:00:00 2001 From: Eden Kirin Date: Sat, 22 Jun 2024 21:59:39 +0200 Subject: [PATCH 2/4] Working arrays --- app/inheritance/inheritance.go | 74 ------------- app/inheritance/methods.go | 43 -------- app/main.go | 19 ++-- app/repository/smartfilter/filterfield.go | 126 ++++++++++++++++++---- app/repository/smartfilter/filters.go | 10 +- app/repository/smartfilter/handlers.go | 26 ++++- app/repository/smartfilter/operators.go | 21 ++-- app/repository/smartfilter/smartfilter.go | 22 ++-- 8 files changed, 171 insertions(+), 170 deletions(-) delete mode 100644 app/inheritance/inheritance.go delete mode 100644 app/inheritance/methods.go diff --git a/app/inheritance/inheritance.go b/app/inheritance/inheritance.go deleted file mode 100644 index 2f0ca61..0000000 --- a/app/inheritance/inheritance.go +++ /dev/null @@ -1,74 +0,0 @@ -package inheritance - -import "fmt" - -type Model struct{} - -type MyModel struct { - Model -} - -type MethodInitInterface interface { - Init(dbConn int) -} - -type RepoBase[T interface{}] struct { - DbConn int - GetMethod[T] - ListMethod[T] - methods []MethodInitInterface -} - -func (b *RepoBase[T]) InitMethods(dbConn int) { - for _, method := range b.methods { - method.Init(dbConn) - } -} - -type CRUDRepo[T interface{}] struct { - RepoBase[T] - SaveMethod[T] -} - -func (m *CRUDRepo[T]) Init(dbConn int) { - m.methods = []MethodInitInterface{&m.GetMethod, &m.ListMethod, &m.SaveMethod} - m.InitMethods(dbConn) -} - -func DoInheritanceTest() { - repo := RepoBase[MyModel]{ - DbConn: 111, - // GetMethod: GetMethod{ - // DbConn: 666, - // }, - // ListMethod: ListMethod{ - // DbConn: 777, - // }, - } - repo.GetMethod.Init(888) - repo.ListMethod.Init(888) - - repo.GetMethod.Get() - repo.List() - - fmt.Printf("outside Base: %d\n", repo.DbConn) - fmt.Printf("outside GetMethod: %d\n", repo.GetMethod.DbConn) - fmt.Printf("outside ListMethod: %d\n", repo.ListMethod.DbConn) - - fmt.Println("----------------") - - crudRepo := CRUDRepo[MyModel]{} - crudRepo.Init(999) - - crudRepo.Get() - crudRepo.List() - crudRepo.Save() - - fmt.Printf("outside GetMethod: %d\n", crudRepo.GetMethod.DbConn) - fmt.Printf("outside ListMethod: %d\n", crudRepo.ListMethod.DbConn) - fmt.Printf("outside SaveMethod: %d\n", crudRepo.SaveMethod.DbConn) - - // repo.DbConn = 123 - // repo.SomeGetVar = 456 - // repo.DoSomething() -} diff --git a/app/inheritance/methods.go b/app/inheritance/methods.go deleted file mode 100644 index 3a46116..0000000 --- a/app/inheritance/methods.go +++ /dev/null @@ -1,43 +0,0 @@ -package inheritance - -import "fmt" - -type GetMethod[T interface{}] struct { - SomeGetVar int - DbConn int -} - -func (m *GetMethod[T]) Init(dbConn int) { - m.DbConn = dbConn -} - -func (m GetMethod[T]) Get() T { - var model T - fmt.Printf("Get DbConn: %d\n", m.DbConn) - return model -} - -type ListMethod[T interface{}] struct { - SomeListVar int - DbConn int -} - -func (m *ListMethod[T]) Init(dbConn int) { - m.DbConn = dbConn -} - -func (m ListMethod[T]) List() { - fmt.Printf("List DbConn: %d\n", m.DbConn) -} - -type SaveMethod[T interface{}] struct { - DbConn int -} - -func (m *SaveMethod[T]) Init(dbConn int) { - m.DbConn = dbConn -} - -func (m SaveMethod[T]) Save() { - fmt.Printf("List DbConn: %d\n", m.DbConn) -} diff --git a/app/main.go b/app/main.go index e0c86b5..30c07a4 100644 --- a/app/main.go +++ b/app/main.go @@ -8,7 +8,6 @@ import ( "repo-pattern/app/models" "repo-pattern/app/repository" "repo-pattern/app/repository/smartfilter" - "time" "gorm.io/gorm" ) @@ -25,19 +24,19 @@ func doMagic(db *gorm.DB) { // id := "6dc096ab-5c03-427e-b808-c669f7446131" // serialNumber := "222" // serialNumberContains := "323" - issuer := "FINA" - location, _ := time.LoadLocation("UTC") - createdTime := time.Date(2024, 5, 26, 16, 8, 0, 0, location) + // issuer := "FINA" + // location, _ := time.LoadLocation("UTC") + // createdTime := time.Date(2024, 5, 26, 16, 8, 0, 0, location) ids := []string{"eb2bcac6-5173-4dbb-93b7-e7c03b924a03", "db9fb837-3483-4736-819d-f427dc8cda23", "1fece5e7-8e8d-4828-8298-3b1f07fd29ff"} filter := smartfilter.CertFilter{ - Alive: &FALSE, + // Alive: &FALSE, // Id: &id, // SerialNumber: &serialNumber, // SerialNumberContains: &serialNumberContains, - Ids: &ids, - IssuerContains: &issuer, - CreatedAt_Lt: &createdTime, + Ids: &ids, + // IssuerContains: &issuer, + // CreatedAt_Lt: &createdTime, } query, err = smartfilter.ToQuery(models.Cert{}, filter, query) @@ -113,9 +112,9 @@ func main() { db := db.InitDB() repository.Dao = repository.CreateDAO(db) - // doMagic(db) + doMagic(db) // doList(db) // doGet(db) - doExists(db) + // doExists(db) // inheritance.DoInheritanceTest() } diff --git a/app/repository/smartfilter/filterfield.go b/app/repository/smartfilter/filterfield.go index 96455c0..714d433 100644 --- a/app/repository/smartfilter/filterfield.go +++ b/app/repository/smartfilter/filterfield.go @@ -10,12 +10,17 @@ type FilterField struct { Name string Operator Operator - valueKind reflect.Kind - boolValue *bool - intValue *int64 - uintValue *uint64 - floatValue *float64 - strValue *string + valueKind reflect.Kind + boolValue *bool + intValue *int64 + uintValue *uint64 + floatValue *float64 + strValue *string + boolValues *[]bool + intValues *[]int64 + uintValues *[]uint64 + floatValues *[]float64 + strValues *[]string } func (ff *FilterField) setValueFromReflection(v reflect.Value) { @@ -23,6 +28,71 @@ func (ff *FilterField) setValueFromReflection(v reflect.Value) { fn(ff, v) } +func (ff *FilterField) appendStr(value string) { + var valueArray []string + + if ff.strValues == nil { + valueArray = make([]string, 0) + } else { + valueArray = *ff.strValues + } + valueArray = append(valueArray, value) + ff.strValues = &valueArray + ff.valueKind = reflect.String +} + +func (ff *FilterField) appendBool(value bool) { + var valueArray []bool + + if ff.boolValues == nil { + valueArray = make([]bool, 0) + } else { + valueArray = *ff.boolValues + } + valueArray = append(valueArray, value) + ff.boolValues = &valueArray + ff.valueKind = reflect.Bool +} + +func (ff *FilterField) appendInt(value int64) { + var valueArray []int64 + + if ff.boolValues == nil { + valueArray = make([]int64, 0) + } else { + valueArray = *ff.intValues + } + valueArray = append(valueArray, value) + ff.intValues = &valueArray + ff.valueKind = reflect.Int +} + +func (ff *FilterField) appendUint(value uint64) { + var valueArray []uint64 + + if ff.boolValues == nil { + valueArray = make([]uint64, 0) + } else { + valueArray = *ff.uintValues + } + valueArray = append(valueArray, value) + ff.uintValues = &valueArray + ff.valueKind = reflect.Int +} + +func (ff *FilterField) appendFloat(value float64) { + var valueArray []float64 + + if ff.boolValues == nil { + valueArray = make([]float64, 0) + } else { + valueArray = *ff.floatValues + } + valueArray = append(valueArray, value) + ff.floatValues = &valueArray + ff.valueKind = reflect.Int +} + type valueGetterFunc func(ff *FilterField, v reflect.Value) error func boolValueGetter(ff *FilterField, v reflect.Value) error { @@ -121,8 +191,8 @@ func newTypeGetter(t reflect.Type, allowAddr bool) valueGetterFunc { // return newMapEncoder(t) case reflect.Slice: return newSliceGetter(t) - case reflect.Array: - return newArrayGetter(t) + // case reflect.Array: + // return newArrayGetter(t) case reflect.Pointer: return newPtrValueGetter(t) } @@ -130,44 +200,64 @@ func newTypeGetter(t reflect.Type, allowAddr bool) valueGetterFunc { } type ptrValueGetter struct { - elemEnc valueGetterFunc + elemGetter valueGetterFunc } func (pvg ptrValueGetter) getValue(ff *FilterField, v reflect.Value) error { - pvg.elemEnc(ff, v.Elem()) + pvg.elemGetter(ff, v.Elem()) return nil } func newPtrValueGetter(t reflect.Type) valueGetterFunc { - enc := ptrValueGetter{elemEnc: typeGetter(t.Elem())} + enc := ptrValueGetter{elemGetter: typeGetter(t.Elem())} return enc.getValue } type arrayGetter struct { - elemEnc valueGetterFunc + elemGetter valueGetterFunc } func (ag arrayGetter) getValue(ff *FilterField, v reflect.Value) error { - ag.elemEnc(ff, v.Elem()) + ag.elemGetter(ff, v.Elem()) return nil } func newArrayGetter(t reflect.Type) valueGetterFunc { - enc := arrayGetter{elemEnc: typeGetter(t.Elem())} + enc := arrayGetter{elemGetter: typeGetter(t.Elem())} return enc.getValue } type sliceGetter struct { - elemEnc valueGetterFunc + elemGetter valueGetterFunc } func (sg sliceGetter) getValue(ff *FilterField, v reflect.Value) error { - fmt.Printf("%+v\n", v.Slice(0, 1)) - sg.elemEnc(ff, v.Elem()) + for n := range v.Len() { + element := v.Index(n) + fmt.Printf("ELEMENT: %+v\n", element) + + switch element.Kind() { + case reflect.Bool: + ff.appendBool(element.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + ff.appendInt(element.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + ff.appendUint(element.Uint()) + case reflect.Float32, reflect.Float64: + ff.appendFloat(element.Float()) + case reflect.String: + ff.appendStr(element.String()) + } + } + fmt.Println(v.Len()) + fmt.Printf(">>> getValue %+v\n", v) + fmt.Printf(">>> ff %+v\n", ff.strValues) + // fmt.Printf("%+v\n", v.Slice(0, 1)) + // sg.elemGetter(ff, v.Elem()) return nil } func newSliceGetter(t reflect.Type) valueGetterFunc { - enc := sliceGetter{elemEnc: typeGetter(t.Elem())} + enc := sliceGetter{elemGetter: newArrayGetter(t)} return enc.getValue } diff --git a/app/repository/smartfilter/filters.go b/app/repository/smartfilter/filters.go index 7cc77e1..5a1cb40 100644 --- a/app/repository/smartfilter/filters.go +++ b/app/repository/smartfilter/filters.go @@ -51,7 +51,13 @@ func applyFilterLE[T bool | int64 | uint64 | float64 | string]( } func applyFilterIN[T bool | int64 | uint64 | float64 | string]( - query *gorm.DB, tableName string, filterField *FilterField, value T, + query *gorm.DB, tableName string, filterField *FilterField, value *[]T, ) *gorm.DB { - return query.Where(fmt.Sprintf("%s.%s IN ?", tableName, filterField.Name), value) + return query.Where(fmt.Sprintf("%s.%s IN (?)", tableName, filterField.Name), *value) +} + +func applyFilterNOT_IN[T bool | int64 | uint64 | float64 | string]( + query *gorm.DB, tableName string, filterField *FilterField, value *[]T, +) *gorm.DB { + return query.Where(fmt.Sprintf("%s.%s NOT IN (?)", tableName, filterField.Name), *value) } diff --git a/app/repository/smartfilter/handlers.go b/app/repository/smartfilter/handlers.go index 68ce695..332b1ff 100644 --- a/app/repository/smartfilter/handlers.go +++ b/app/repository/smartfilter/handlers.go @@ -112,14 +112,32 @@ func handleOperatorLE(query *gorm.DB, tableName string, filterField *FilterField func handleOperatorIN(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { switch filterField.valueKind { + case reflect.Bool: + return applyFilterIN(query, tableName, filterField, filterField.boolValues) case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: - return applyFilterIN(query, tableName, filterField, *filterField.intValue) + return applyFilterIN(query, tableName, filterField, filterField.intValues) case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return applyFilterIN(query, tableName, filterField, *filterField.uintValue) + return applyFilterIN(query, tableName, filterField, filterField.uintValues) case reflect.Float32, reflect.Float64: - return applyFilterIN(query, tableName, filterField, *filterField.floatValue) + return applyFilterIN(query, tableName, filterField, filterField.floatValues) case reflect.String: - return applyFilterIN(query, tableName, filterField, *filterField.strValue) + return applyFilterIN(query, tableName, filterField, filterField.strValues) + } + return nil +} + +func handleOperatorNOT_IN(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB { + switch filterField.valueKind { + case reflect.Bool: + return applyFilterNOT_IN(query, tableName, filterField, filterField.boolValues) + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return applyFilterNOT_IN(query, tableName, filterField, filterField.intValues) + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return applyFilterNOT_IN(query, tableName, filterField, filterField.uintValues) + case reflect.Float32, reflect.Float64: + return applyFilterNOT_IN(query, tableName, filterField, filterField.floatValues) + case reflect.String: + return applyFilterNOT_IN(query, tableName, filterField, filterField.strValues) } return nil } diff --git a/app/repository/smartfilter/operators.go b/app/repository/smartfilter/operators.go index 24e0c61..10be9b7 100644 --- a/app/repository/smartfilter/operators.go +++ b/app/repository/smartfilter/operators.go @@ -3,20 +3,21 @@ package smartfilter type Operator string const ( - OperatorEQ Operator = "EQ" - OperatorNE Operator = "NE" - OperatorGT Operator = "GT" - OperatorGE Operator = "GE" - OperatorLT Operator = "LT" - OperatorLE Operator = "LE" - OperatorLIKE Operator = "LIKE" - OperatorILIKE Operator = "ILIKE" - OperatorIN Operator = "IN" + OperatorEQ Operator = "EQ" + OperatorNE Operator = "NE" + OperatorGT Operator = "GT" + OperatorGE Operator = "GE" + OperatorLT Operator = "LT" + OperatorLE Operator = "LE" + OperatorLIKE Operator = "LIKE" + OperatorILIKE Operator = "ILIKE" + OperatorIN Operator = "IN" + OperatorNOT_IN Operator = "NOT_IN" ) var OPERATORS = []Operator{ OperatorEQ, OperatorNE, OperatorGT, OperatorGE, OperatorLT, OperatorLE, OperatorLIKE, OperatorILIKE, - OperatorIN, + OperatorIN, OperatorNOT_IN, } diff --git a/app/repository/smartfilter/smartfilter.go b/app/repository/smartfilter/smartfilter.go index 53e68cd..5349696 100644 --- a/app/repository/smartfilter/smartfilter.go +++ b/app/repository/smartfilter/smartfilter.go @@ -17,15 +17,16 @@ const TAG_VALUE_SEPARATOR = "," type handlerFunc func(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB var operatorHandlers = map[Operator]handlerFunc{ - OperatorEQ: handleOperatorEQ, - OperatorNE: handleOperatorNE, - OperatorGT: handleOperatorGT, - OperatorGE: handleOperatorGE, - OperatorLT: handleOperatorLT, - OperatorLE: handleOperatorLE, - OperatorLIKE: handleOperatorLIKE, - OperatorILIKE: handleOperatorILIKE, - OperatorIN: handleOperatorIN, + OperatorEQ: handleOperatorEQ, + OperatorNE: handleOperatorNE, + OperatorGT: handleOperatorGT, + OperatorGE: handleOperatorGE, + OperatorLT: handleOperatorLT, + OperatorLE: handleOperatorLE, + OperatorLIKE: handleOperatorLIKE, + OperatorILIKE: handleOperatorILIKE, + OperatorIN: handleOperatorIN, + OperatorNOT_IN: handleOperatorNOT_IN, } type CertFilter struct { @@ -35,6 +36,7 @@ type CertFilter struct { IssuerContains *string `filterfield:"issuer,ILIKE"` Id *string `filterfield:"id,EQ"` Ids *[]string `filterfield:"id,IN"` + IdsNot *[]string `filterfield:"id,NOT_IN"` CreatedAt_Lt *time.Time `filterfield:"created_at,LT"` } @@ -73,6 +75,8 @@ func getFilterFields(filter interface{}) []ReflectedStructField { value: fieldValue, }) } + fmt.Println("-------------- RES --------------") + fmt.Printf("%+v\n", res) return res } From 35082b1f6aca09ee6ac330b38e504626b8da46e2 Mon Sep 17 00:00:00 2001 From: Eden Kirin Date: Sat, 22 Jun 2024 23:17:37 +0200 Subject: [PATCH 3/4] Tests --- app/repository/smartfilter/handlers_test.go | 279 ++++++++++++++++++++ go.mod | 4 +- go.sum | 7 + 3 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 app/repository/smartfilter/handlers_test.go diff --git a/app/repository/smartfilter/handlers_test.go b/app/repository/smartfilter/handlers_test.go new file mode 100644 index 0000000..9e16a12 --- /dev/null +++ b/app/repository/smartfilter/handlers_test.go @@ -0,0 +1,279 @@ +package smartfilter + +import ( + "log" + "reflect" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/go-playground/assert" + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +func NewMockDB() (*gorm.DB, sqlmock.Sqlmock) { + db, mock, err := sqlmock.New() + if err != nil { + log.Fatalf("An error '%s' was not expected when opening a stub database connection", err) + } + + gormDB, err := gorm.Open(mysql.New(mysql.Config{ + Conn: db, + SkipInitializeWithVersion: true, + }), &gorm.Config{}) + + if err != nil { + log.Fatalf("An error '%s' was not expected when opening gorm database", err) + } + + return gormDB, mock +} + +type MyModel struct { + Id int + Value string +} + +func TestHandleOperatorEQ(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorEQ + + t.Run("Test handleOperatorEQ bool true", func(t *testing.T) { + var value bool = true + filterField := FilterField{ + Name: "my_field", + boolValue: &value, + valueKind: reflect.Bool, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = true ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorEQ bool false", func(t *testing.T) { + var value bool = false + filterField := FilterField{ + Name: "my_field", + boolValue: &value, + valueKind: reflect.Bool, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = false ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorEQ int64", func(t *testing.T) { + var value int64 = -123456 + filterField := FilterField{ + Name: "my_field", + intValue: &value, + valueKind: reflect.Int64, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = -123456 ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorEQ uint64", func(t *testing.T) { + var value uint64 = 123456 + filterField := FilterField{ + Name: "my_field", + uintValue: &value, + valueKind: reflect.Uint64, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = 123456 ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorEQ float", func(t *testing.T) { + var value float64 = -123456.789 + filterField := FilterField{ + Name: "my_field", + floatValue: &value, + valueKind: reflect.Float64, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = -123456.789 ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorEQ string", func(t *testing.T) { + var value string = "Some Value" + filterField := FilterField{ + Name: "my_field", + strValue: &value, + valueKind: reflect.String, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", sql) + }) +} + +func TestHandleOperatorNE(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorNE + + t.Run("Test handleOperatorNE bool true", func(t *testing.T) { + var value bool = true + filterField := FilterField{ + Name: "my_field", + boolValue: &value, + valueKind: reflect.Bool, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != true ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorNE bool false", func(t *testing.T) { + var value bool = false + filterField := FilterField{ + Name: "my_field", + boolValue: &value, + valueKind: reflect.Bool, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != false ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorNE int64", func(t *testing.T) { + var value int64 = -123456 + filterField := FilterField{ + Name: "my_field", + intValue: &value, + valueKind: reflect.Int64, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != -123456 ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorNE uint64", func(t *testing.T) { + var value uint64 = 123456 + filterField := FilterField{ + Name: "my_field", + uintValue: &value, + valueKind: reflect.Uint64, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != 123456 ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorNE float", func(t *testing.T) { + var value float64 = -123456.789 + filterField := FilterField{ + Name: "my_field", + floatValue: &value, + valueKind: reflect.Float64, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != -123456.789 ORDER BY `my_models`.`id` LIMIT 1", sql) + }) + + t.Run("Test handleOperatorNE string", func(t *testing.T) { + var value string = "Some Value" + filterField := FilterField{ + Name: "my_field", + strValue: &value, + valueKind: reflect.String, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", sql) + }) +} + +func TestHandleOperatorLIKE(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorLIKE + + t.Run("Test handleOperatorLIKE", func(t *testing.T) { + var value string = "Some Value" + filterField := FilterField{ + Name: "my_field", + strValue: &value, + valueKind: reflect.String, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field LIKE '%Some Value%' ORDER BY `my_models`.`id` LIMIT 1", sql) + }) +} + +func TestHandleOperatorILIKE(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorILIKE + + t.Run("Test handleOperatorLIKE", func(t *testing.T) { + var value string = "Some Value" + filterField := FilterField{ + Name: "my_field", + strValue: &value, + valueKind: reflect.String, + } + + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field ILIKE '%Some Value%' ORDER BY `my_models`.`id` LIMIT 1", sql) + }) +} diff --git a/go.mod b/go.mod index 7dc71c0..57f463f 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/google/uuid v1.6.0 github.com/kelseyhightower/envconfig v1.4.0 github.com/mozillazg/go-slugify v0.2.0 - github.com/stretchr/testify v1.8.1 + github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/datatypes v1.2.1 @@ -16,7 +16,9 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-playground/assert v1.2.1 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect diff --git a/go.sum b/go.sum index 2780cf9..6b16d44 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,12 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-playground/assert v1.2.1 h1:ad06XqC+TOv0nJWnbULSlh3ehp5uLuQEojZY5Tq8RgI= +github.com/go-playground/assert v1.2.1/go.mod h1:Lgy+k19nOB/wQG/fVSQ7rra5qYugmytMQqvQ2dgjWn8= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= @@ -26,6 +30,7 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -51,6 +56,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= From 6e31c8eb980304ed001807ea27e582f28d311dae Mon Sep 17 00:00:00 2001 From: Eden Kirin Date: Sat, 22 Jun 2024 23:48:42 +0200 Subject: [PATCH 4/4] Tests --- app/repository/smartfilter/handlers_test.go | 755 ++++++++++++++------ 1 file changed, 543 insertions(+), 212 deletions(-) diff --git a/app/repository/smartfilter/handlers_test.go b/app/repository/smartfilter/handlers_test.go index 9e16a12..4efc07a 100644 --- a/app/repository/smartfilter/handlers_test.go +++ b/app/repository/smartfilter/handlers_test.go @@ -34,246 +34,577 @@ type MyModel struct { Value string } +func (m MyModel) TableName() string { + return "my_models" +} + +type TestCase struct { + name string + filterField FilterField + expected string +} + +var ( + boolTrue bool = true + boolFalse bool = false + int64Value int64 = -123456 + uint64Value uint64 = 123456 + floatValue float64 = -123456.789 + strValue string = "Some Value" + + boolValues = []bool{true, false} + int64Values = []int64{-123456, 1, 123456} + uint64Values = []uint64{123456, 1234567, 1234568} + floatValues = []float64{-123456.789, -1, 123456.789} + strValues = []string{"First Value", "Second Value", "Third Value"} +) + func TestHandleOperatorEQ(t *testing.T) { db, _ := NewMockDB() testFunc := handleOperatorEQ - t.Run("Test handleOperatorEQ bool true", func(t *testing.T) { - var value bool = true - filterField := FilterField{ - Name: "my_field", - boolValue: &value, - valueKind: reflect.Bool, - } + testCases := []TestCase{ + { + name: "handleOperatorEQ bool true", + filterField: FilterField{ + Name: "my_field", + boolValue: &boolTrue, + valueKind: reflect.Bool, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field = true ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorEQ bool false", + filterField: FilterField{ + Name: "my_field", + boolValue: &boolFalse, + valueKind: reflect.Bool, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field = false ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorEQ int64", + filterField: FilterField{ + Name: "my_field", + intValue: &int64Value, + valueKind: reflect.Int64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field = -123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorEQ uint64", + filterField: FilterField{ + Name: "my_field", + uintValue: &uint64Value, + valueKind: reflect.Uint64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field = 123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorEQ float", + filterField: FilterField{ + Name: "my_field", + floatValue: &floatValue, + valueKind: reflect.Float64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field = -123456.789 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorEQ string", + filterField: FilterField{ + Name: "my_field", + strValue: &strValue, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field = 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", + }, + } - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = true ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorEQ bool false", func(t *testing.T) { - var value bool = false - filterField := FilterField{ - Name: "my_field", - boolValue: &value, - valueKind: reflect.Bool, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = false ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorEQ int64", func(t *testing.T) { - var value int64 = -123456 - filterField := FilterField{ - Name: "my_field", - intValue: &value, - valueKind: reflect.Int64, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = -123456 ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorEQ uint64", func(t *testing.T) { - var value uint64 = 123456 - filterField := FilterField{ - Name: "my_field", - uintValue: &value, - valueKind: reflect.Uint64, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = 123456 ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorEQ float", func(t *testing.T) { - var value float64 = -123456.789 - filterField := FilterField{ - Name: "my_field", - floatValue: &value, - valueKind: reflect.Float64, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = -123456.789 ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorEQ string", func(t *testing.T) { - var value string = "Some Value" - filterField := FilterField{ - Name: "my_field", - strValue: &value, - valueKind: reflect.String, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field = 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", sql) - }) + } } func TestHandleOperatorNE(t *testing.T) { db, _ := NewMockDB() testFunc := handleOperatorNE - t.Run("Test handleOperatorNE bool true", func(t *testing.T) { - var value bool = true - filterField := FilterField{ - Name: "my_field", - boolValue: &value, - valueKind: reflect.Bool, - } + testCases := []TestCase{ + { + name: "handleOperatorNE bool true", + filterField: FilterField{ + Name: "my_field", + boolValue: &boolTrue, + valueKind: reflect.Bool, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field != true ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorNE bool false", + filterField: FilterField{ + Name: "my_field", + boolValue: &boolFalse, + valueKind: reflect.Bool, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field != false ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorNE int64", + filterField: FilterField{ + Name: "my_field", + intValue: &int64Value, + valueKind: reflect.Int64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field != -123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorNE uint64", + filterField: FilterField{ + Name: "my_field", + uintValue: &uint64Value, + valueKind: reflect.Uint64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field != 123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorNE float", + filterField: FilterField{ + Name: "my_field", + floatValue: &floatValue, + valueKind: reflect.Float64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field != -123456.789 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorNE string", + filterField: FilterField{ + Name: "my_field", + strValue: &strValue, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field != 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", + }, + } - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != true ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorNE bool false", func(t *testing.T) { - var value bool = false - filterField := FilterField{ - Name: "my_field", - boolValue: &value, - valueKind: reflect.Bool, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != false ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorNE int64", func(t *testing.T) { - var value int64 = -123456 - filterField := FilterField{ - Name: "my_field", - intValue: &value, - valueKind: reflect.Int64, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != -123456 ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorNE uint64", func(t *testing.T) { - var value uint64 = 123456 - filterField := FilterField{ - Name: "my_field", - uintValue: &value, - valueKind: reflect.Uint64, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != 123456 ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorNE float", func(t *testing.T) { - var value float64 = -123456.789 - filterField := FilterField{ - Name: "my_field", - floatValue: &value, - valueKind: reflect.Float64, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != -123456.789 ORDER BY `my_models`.`id` LIMIT 1", sql) - }) - - t.Run("Test handleOperatorNE string", func(t *testing.T) { - var value string = "Some Value" - filterField := FilterField{ - Name: "my_field", - strValue: &value, - valueKind: reflect.String, - } - - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) - }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field != 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", sql) - }) + } } func TestHandleOperatorLIKE(t *testing.T) { db, _ := NewMockDB() testFunc := handleOperatorLIKE - t.Run("Test handleOperatorLIKE", func(t *testing.T) { - var value string = "Some Value" - filterField := FilterField{ - Name: "my_field", - strValue: &value, - valueKind: reflect.String, - } + testCases := []TestCase{ + { + name: "handleOperatorLIKE", + filterField: FilterField{ + Name: "my_field", + strValue: &strValue, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field LIKE '%Some Value%' ORDER BY `my_models`.`id` LIMIT 1", + }, + } - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field LIKE '%Some Value%' ORDER BY `my_models`.`id` LIMIT 1", sql) - }) + } } func TestHandleOperatorILIKE(t *testing.T) { db, _ := NewMockDB() testFunc := handleOperatorILIKE - t.Run("Test handleOperatorLIKE", func(t *testing.T) { - var value string = "Some Value" - filterField := FilterField{ - Name: "my_field", - strValue: &value, - valueKind: reflect.String, - } + testCases := []TestCase{ + { + name: "handleOperatorILIKE", + filterField: FilterField{ + Name: "my_field", + strValue: &strValue, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field ILIKE '%Some Value%' ORDER BY `my_models`.`id` LIMIT 1", + }, + } - sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { - query := tx.Model(&MyModel{}) - query = testFunc(query, "my_table", &filterField) - return query.First(&MyModel{}) + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) }) - assert.Equal(t, "SELECT * FROM `my_models` WHERE my_table.my_field ILIKE '%Some Value%' ORDER BY `my_models`.`id` LIMIT 1", sql) - }) + } +} + +func TestHandleOperatorGT(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorGT + + testCases := []TestCase{ + { + name: "handleOperatorGT int64", + filterField: FilterField{ + Name: "my_field", + intValue: &int64Value, + valueKind: reflect.Int64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field > -123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorGT uint64", + filterField: FilterField{ + Name: "my_field", + uintValue: &uint64Value, + valueKind: reflect.Uint64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field > 123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorGT float", + filterField: FilterField{ + Name: "my_field", + floatValue: &floatValue, + valueKind: reflect.Float64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field > -123456.789 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorGT string", + filterField: FilterField{ + Name: "my_field", + strValue: &strValue, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field > 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) + }) + } +} + +func TestHandleOperatorGE(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorGE + + testCases := []TestCase{ + { + name: "handleOperatorGE int64", + filterField: FilterField{ + Name: "my_field", + intValue: &int64Value, + valueKind: reflect.Int64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field >= -123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorGE uint64", + filterField: FilterField{ + Name: "my_field", + uintValue: &uint64Value, + valueKind: reflect.Uint64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field >= 123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorGE float", + filterField: FilterField{ + Name: "my_field", + floatValue: &floatValue, + valueKind: reflect.Float64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field >= -123456.789 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorGE string", + filterField: FilterField{ + Name: "my_field", + strValue: &strValue, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field >= 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) + }) + } +} + +func TestHandleOperatorLT(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorLT + + testCases := []TestCase{ + { + name: "handleOperatorLT int64", + filterField: FilterField{ + Name: "my_field", + intValue: &int64Value, + valueKind: reflect.Int64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field < -123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorLT uint64", + filterField: FilterField{ + Name: "my_field", + uintValue: &uint64Value, + valueKind: reflect.Uint64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field < 123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorLT float", + filterField: FilterField{ + Name: "my_field", + floatValue: &floatValue, + valueKind: reflect.Float64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field < -123456.789 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorLT string", + filterField: FilterField{ + Name: "my_field", + strValue: &strValue, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field < 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) + }) + } +} + +func TestHandleOperatorLE(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorLE + + testCases := []TestCase{ + { + name: "handleOperatorLE int64", + filterField: FilterField{ + Name: "my_field", + intValue: &int64Value, + valueKind: reflect.Int64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field <= -123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorLE uint64", + filterField: FilterField{ + Name: "my_field", + uintValue: &uint64Value, + valueKind: reflect.Uint64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field <= 123456 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorLE float", + filterField: FilterField{ + Name: "my_field", + floatValue: &floatValue, + valueKind: reflect.Float64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field <= -123456.789 ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorLE string", + filterField: FilterField{ + Name: "my_field", + strValue: &strValue, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field <= 'Some Value' ORDER BY `my_models`.`id` LIMIT 1", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) + }) + } +} + +func TestHandleOperatorIN(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorIN + + testCases := []TestCase{ + { + name: "handleOperatorIN bool", + filterField: FilterField{ + Name: "my_field", + boolValues: &boolValues, + valueKind: reflect.Bool, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field IN (true,false) ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorIN int64", + filterField: FilterField{ + Name: "my_field", + intValues: &int64Values, + valueKind: reflect.Int64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field IN (-123456,1,123456) ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorIN uint64", + filterField: FilterField{ + Name: "my_field", + uintValues: &uint64Values, + valueKind: reflect.Uint64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field IN (123456,1234567,1234568) ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorIN float", + filterField: FilterField{ + Name: "my_field", + floatValues: &floatValues, + valueKind: reflect.Float64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field IN (-123456.789,-1,123456.789) ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorIN string", + filterField: FilterField{ + Name: "my_field", + strValues: &strValues, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field IN ('First Value','Second Value','Third Value') ORDER BY `my_models`.`id` LIMIT 1", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) + }) + } +} + +func TestHandleOperatorNOT_IN(t *testing.T) { + db, _ := NewMockDB() + testFunc := handleOperatorNOT_IN + + testCases := []TestCase{ + { + name: "handleOperatorNOT_IN bool", + filterField: FilterField{ + Name: "my_field", + boolValues: &boolValues, + valueKind: reflect.Bool, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field NOT IN (true,false) ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorNOT_IN int64", + filterField: FilterField{ + Name: "my_field", + intValues: &int64Values, + valueKind: reflect.Int64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field NOT IN (-123456,1,123456) ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorNOT_IN uint64", + filterField: FilterField{ + Name: "my_field", + uintValues: &uint64Values, + valueKind: reflect.Uint64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field NOT IN (123456,1234567,1234568) ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorNOT_IN float", + filterField: FilterField{ + Name: "my_field", + floatValues: &floatValues, + valueKind: reflect.Float64, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field NOT IN (-123456.789,-1,123456.789) ORDER BY `my_models`.`id` LIMIT 1", + }, + { + name: "handleOperatorNOT_IN string", + filterField: FilterField{ + Name: "my_field", + strValues: &strValues, + valueKind: reflect.String, + }, + expected: "SELECT * FROM `my_models` WHERE my_table.my_field NOT IN ('First Value','Second Value','Third Value') ORDER BY `my_models`.`id` LIMIT 1", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { + query := tx.Model(&MyModel{}) + query = testFunc(query, "my_table", &testCase.filterField) + return query.First(&MyModel{}) + }) + assert.Equal(t, testCase.expected, sql) + }) + } }