diff --git a/app/main.go b/app/main.go index 02d38c3..7df139f 100644 --- a/app/main.go +++ b/app/main.go @@ -29,7 +29,18 @@ type CertFilter struct { IdsNot *[]string `filterfield:"field=id;operator=NOT_IN"` CreatedAt_Lt *time.Time `filterfield:"field=created_at;operator=LT"` Timestamps *[]time.Time `filterfield:"field=created_at;operator=IN"` - // CompanyIsActive *bool `filterfield:"joins=companies;field=is_active;operator=EQ"` + CompanyIsActive *bool +} + +func (f CertFilter) ApplyQuery(query *gorm.DB) *gorm.DB { + if f.CompanyIsActive != nil { + query = query.Joins( + fmt.Sprintf( + "JOIN companies ON certificates.company_id = companies.id WHERE companies.is_active = %t", + *f.CompanyIsActive, + )) + } + return query } type CompanyFilter struct { @@ -99,7 +110,7 @@ func doListWithJoins(db *gorm.DB) { repo.Init(db, nil) filter := CertFilter{ - Alive: &TRUE, + CompanyIsActive: &FALSE, } certs, err := repo.List(filter, nil) @@ -203,11 +214,11 @@ func main() { db := db.InitDB() // doMagic(db) - doList(db) - // doListWithJoins(db) + // doList(db) + doListWithJoins(db) // doCount(db) // doSave(db) - doDelete(db) + // doDelete(db) // doGet(db) // doExists(db) // inheritance.DoInheritanceTest() diff --git a/app/repository/smartfilter/filterfield.go b/app/repository/smartfilter/filterfield.go index 8cc2c06..a4cf4b8 100644 --- a/app/repository/smartfilter/filterfield.go +++ b/app/repository/smartfilter/filterfield.go @@ -11,7 +11,6 @@ import ( type FilterField struct { Name string Operator Operator - Joins []string valueKind reflect.Kind boolValue *bool diff --git a/app/repository/smartfilter/smartfilter.go b/app/repository/smartfilter/smartfilter.go index 84bd8d9..606e821 100644 --- a/app/repository/smartfilter/smartfilter.go +++ b/app/repository/smartfilter/smartfilter.go @@ -12,10 +12,15 @@ import ( const TAG_NAME = "filterfield" const TAG_PAIRS_SEPARATOR = ";" +const TAG_LIST_SEPARATOR = "," const TAG_KEYVALUE_SEPARATOR = "=" type handlerFunc func(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB +type QueryApplier interface { + ApplyQuery(query *gorm.DB) *gorm.DB +} + var operatorHandlers = map[Operator]handlerFunc{ OperatorEQ: handleOperatorEQ, OperatorNE: handleOperatorNE, @@ -67,15 +72,20 @@ func getFilterFields(filter interface{}) []ReflectedStructField { return res } +func getQueryApplierInterface(filter interface{}) QueryApplier { + queryApplier, ok := filter.(QueryApplier) + if ok { + return queryApplier + } + return nil +} + func ToQuery(model schema.Tabler, filter interface{}, query *gorm.DB) (*gorm.DB, error) { st := reflect.TypeOf(filter) tableName := model.TableName() modelName := st.Name() - fmt.Printf("Table name: %s\n", tableName) - fmt.Printf("Model name: %s\n", modelName) - fields := getFilterFields(filter) for _, field := range fields { filterField, err := newFilterField(field.tagValue) @@ -97,22 +107,36 @@ func ToQuery(model schema.Tabler, filter interface{}, query *gorm.DB) (*gorm.DB, } } + // apply custom filters, if interface exists + queryApplier := getQueryApplierInterface(filter) + if queryApplier != nil { + query = queryApplier.ApplyQuery(query) + } + return query, nil } +func splitTrim(value string, separator string) []string { + var out []string = []string{} + for _, s := range strings.Split(value, separator) { + if len(s) == 0 { + continue + } + out = append(out, strings.TrimSpace(s)) + } + return out +} + func newFilterField(tagValue string) (*FilterField, error) { filterField := FilterField{} - tagValue = strings.TrimSpace(tagValue) - pairs := strings.Split(tagValue, TAG_PAIRS_SEPARATOR) - - for _, pair := range pairs { - kvs := strings.Split(pair, TAG_KEYVALUE_SEPARATOR) + for _, pair := range splitTrim(tagValue, TAG_PAIRS_SEPARATOR) { + kvs := splitTrim(pair, TAG_KEYVALUE_SEPARATOR) if len(kvs) != 2 { return nil, fmt.Errorf("invalid tag value: %s", strings.TrimSpace(pair)) } - key := strings.TrimSpace(kvs[0]) - value := strings.TrimSpace(kvs[1]) + key := kvs[0] + value := kvs[1] switch key { case "field": diff --git a/app/repository/smartfilter/smartfilter_test.go b/app/repository/smartfilter/smartfilter_test.go index c9834d6..554e4a2 100644 --- a/app/repository/smartfilter/smartfilter_test.go +++ b/app/repository/smartfilter/smartfilter_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func TestGetFilterFields(t *testing.T) { @@ -203,3 +204,32 @@ func TestFilterField(t *testing.T) { assert.EqualError(t, err, "missing operator in tag: field=field_1") }) } + +type filterWithoutQueryApplier struct{} + +type filterWithQueryApplier struct{} + +func (f filterWithQueryApplier) ApplyQuery(query *gorm.DB) *gorm.DB { + return query +} + +func TestSmartfilterApplyQuery(t *testing.T) { + + t.Run("Get query applier interface - without interface", func(t *testing.T) { + f := filterWithoutQueryApplier{} + queryApplier := getQueryApplierInterface(f) + assert.Nil(t, queryApplier) + }) + + t.Run("Get query applier interface - with interface", func(t *testing.T) { + f := filterWithQueryApplier{} + queryApplier := getQueryApplierInterface(f) + assert.NotNil(t, queryApplier) + }) + + // t.Run("Get query applier interface - call interface function", func(t *testing.T) { + // f := filterWithQueryApplier{} + // queryApplier := getQueryApplierInterface(f) + // assert.NotNil(t, queryApplier) + // }) +}