diff --git a/app/main.go b/app/main.go index 15d2f69..7df139f 100644 --- a/app/main.go +++ b/app/main.go @@ -20,19 +20,31 @@ var ( ) type CertFilter struct { - Alive *bool `filterfield:"field=alive,operator=EQ"` - SerialNumber *string `filterfield:"field=serial_number,operator=NE"` - SerialNumberContains *string `filterfield:"field=serial_number,operator=LIKE"` - IssuerContains *string `filterfield:"field=issuer,operator=ILIKE"` - Id *uuid.UUID `filterfield:"field=id,operator=EQ"` - Ids *[]uuid.UUID `filterfield:"field=id,operator=IN"` - IdsNot *[]string `filterfield:"field=id,operator=NOT_IN"` - CreatedAt_Lt *time.Time `filterfield:"field=created_at,operator=LT"` - Timestamps *[]time.Time `filterfield:"field=created_at,operator=IN"` + Alive *bool `filterfield:"field=alive;operator=EQ"` + SerialNumber *string `filterfield:"field=serial_number;operator=NE"` + SerialNumberContains *string `filterfield:"field=serial_number;operator=LIKE"` + IssuerContains *string `filterfield:"field=issuer;operator=ILIKE"` + Id *uuid.UUID `filterfield:"field=id;operator=EQ"` + Ids *[]uuid.UUID `filterfield:"field=id;operator=IN"` + IdsNot *[]string `filterfield:"field=id;operator=NOT_IN"` + CreatedAt_Lt *time.Time `filterfield:"field=created_at;operator=LT"` + Timestamps *[]time.Time `filterfield:"field=created_at;operator=IN"` + CompanyIsActive *bool +} + +func (f CertFilter) ApplyQuery(query *gorm.DB) *gorm.DB { + if f.CompanyIsActive != nil { + query = query.Joins( + fmt.Sprintf( + "JOIN companies ON certificates.company_id = companies.id WHERE companies.is_active = %t", + *f.CompanyIsActive, + )) + } + return query } type CompanyFilter struct { - IsActive *bool `filterfield:"field=is_active,operator=EQ"` + IsActive *bool `filterfield:"field=is_active;operator=EQ"` } func doMagic(db *gorm.DB) { @@ -93,6 +105,24 @@ func doList(db *gorm.DB) { } } +func doListWithJoins(db *gorm.DB) { + repo := repository.RepoBase[models.Cert]{} + repo.Init(db, nil) + + filter := CertFilter{ + CompanyIsActive: &FALSE, + } + + certs, err := repo.List(filter, nil) + if err != nil { + panic(err) + } + + for n, cert := range *certs { + fmt.Printf(">> [%d] %+v %s (alive %t)\n", n, cert.Id, cert.CreatedAt, cert.Alive) + } +} + func doCount(db *gorm.DB) { repo := repository.RepoBase[models.Cert]{} repo.Init(db, nil) @@ -185,9 +215,10 @@ func main() { // doMagic(db) // doList(db) + doListWithJoins(db) // doCount(db) // doSave(db) - doDelete(db) + // doDelete(db) // doGet(db) // doExists(db) // inheritance.DoInheritanceTest() diff --git a/app/repository/method_list_test.go b/app/repository/method_list_test.go index 16b1195..cd2341a 100644 --- a/app/repository/method_list_test.go +++ b/app/repository/method_list_test.go @@ -43,10 +43,10 @@ func (m MyModel) TableName() string { } type MyModelFilter struct { - Id *uuid.UUID `filterfield:"field=id,operator=EQ"` - Ids *[]uuid.UUID `filterfield:"field=id,operator=IN"` - Value *string `filterfield:"field=value,operator=EQ"` - CntGT *int `filterfield:"field=cnt,operator=GT"` + Id *uuid.UUID `filterfield:"field=id;operator=EQ"` + Ids *[]uuid.UUID `filterfield:"field=id;operator=IN"` + Value *string `filterfield:"field=value;operator=EQ"` + CntGT *int `filterfield:"field=cnt;operator=GT"` } func TestListMethod(t *testing.T) { diff --git a/app/repository/smartfilter/smartfilter.go b/app/repository/smartfilter/smartfilter.go index b77ba4d..606e821 100644 --- a/app/repository/smartfilter/smartfilter.go +++ b/app/repository/smartfilter/smartfilter.go @@ -11,11 +11,16 @@ import ( ) const TAG_NAME = "filterfield" -const TAG_PAIRS_SEPARATOR = "," +const TAG_PAIRS_SEPARATOR = ";" +const TAG_LIST_SEPARATOR = "," const TAG_KEYVALUE_SEPARATOR = "=" type handlerFunc func(query *gorm.DB, tableName string, filterField *FilterField) *gorm.DB +type QueryApplier interface { + ApplyQuery(query *gorm.DB) *gorm.DB +} + var operatorHandlers = map[Operator]handlerFunc{ OperatorEQ: handleOperatorEQ, OperatorNE: handleOperatorNE, @@ -67,15 +72,20 @@ func getFilterFields(filter interface{}) []ReflectedStructField { return res } +func getQueryApplierInterface(filter interface{}) QueryApplier { + queryApplier, ok := filter.(QueryApplier) + if ok { + return queryApplier + } + return nil +} + func ToQuery(model schema.Tabler, filter interface{}, query *gorm.DB) (*gorm.DB, error) { st := reflect.TypeOf(filter) tableName := model.TableName() modelName := st.Name() - fmt.Printf("Table name: %s\n", tableName) - fmt.Printf("Model name: %s\n", modelName) - fields := getFilterFields(filter) for _, field := range fields { filterField, err := newFilterField(field.tagValue) @@ -97,22 +107,36 @@ func ToQuery(model schema.Tabler, filter interface{}, query *gorm.DB) (*gorm.DB, } } + // apply custom filters, if interface exists + queryApplier := getQueryApplierInterface(filter) + if queryApplier != nil { + query = queryApplier.ApplyQuery(query) + } + return query, nil } +func splitTrim(value string, separator string) []string { + var out []string = []string{} + for _, s := range strings.Split(value, separator) { + if len(s) == 0 { + continue + } + out = append(out, strings.TrimSpace(s)) + } + return out +} + func newFilterField(tagValue string) (*FilterField, error) { filterField := FilterField{} - tagValue = strings.TrimSpace(tagValue) - pairs := strings.Split(tagValue, TAG_PAIRS_SEPARATOR) - - for _, pair := range pairs { - kvs := strings.Split(pair, TAG_KEYVALUE_SEPARATOR) + for _, pair := range splitTrim(tagValue, TAG_PAIRS_SEPARATOR) { + kvs := splitTrim(pair, TAG_KEYVALUE_SEPARATOR) if len(kvs) != 2 { return nil, fmt.Errorf("invalid tag value: %s", strings.TrimSpace(pair)) } - key := strings.TrimSpace(kvs[0]) - value := strings.TrimSpace(kvs[1]) + key := kvs[0] + value := kvs[1] switch key { case "field": diff --git a/app/repository/smartfilter/smartfilter_test.go b/app/repository/smartfilter/smartfilter_test.go index 13ab166..0a54e7d 100644 --- a/app/repository/smartfilter/smartfilter_test.go +++ b/app/repository/smartfilter/smartfilter_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "gorm.io/gorm" ) func TestGetFilterFields(t *testing.T) { @@ -101,11 +102,11 @@ func TestGetFilterFields(t *testing.T) { t.Run("Skip nil fields", func(t *testing.T) { type TestFilter struct { - Alive *bool `filterfield:"alive,EQ"` - Id *int64 `filterfield:"id,EQ"` - Ids *[]uint `filterfield:"id,IN"` - IdsNot *[]uint `filterfield:"id,NOT_IN"` - FirstName *string `filterfield:"first_name,EQ"` + Alive *bool `filterfield:"alive;EQ"` + Id *int64 `filterfield:"id;EQ"` + Ids *[]uint `filterfield:"id;IN"` + IdsNot *[]uint `filterfield:"id;NOT_IN"` + FirstName *string `filterfield:"first_name;EQ"` } filter := TestFilter{} result := getFilterFields(filter) @@ -119,7 +120,7 @@ func TestGetFilterFields(t *testing.T) { ) type TestFilter struct { Alive *bool - Id *int64 `funnytag:"created_at,LT"` + Id *int64 `funnytag:"created_at;LT"` } filter := TestFilter{ Alive: &alive, @@ -140,7 +141,7 @@ func TestFilterField(t *testing.T) { testCases := []TagParseTestCase{ { name: "Parse without spaces", - tagValue: "field=field_1,operator=EQ", + tagValue: "field=field_1;operator=EQ", expected: FilterField{ Name: "field_1", Operator: OperatorEQ, @@ -148,7 +149,7 @@ func TestFilterField(t *testing.T) { }, { name: "Parse spaces between pairs", - tagValue: " field=field_2 , operator=LT ", + tagValue: " field=field_2 ; operator=LT ", expected: FilterField{ Name: "field_2", Operator: OperatorLT, @@ -156,7 +157,7 @@ func TestFilterField(t *testing.T) { }, { name: "Parse spaces between around keys and values", - tagValue: "operator = LIKE , field = field_3", + tagValue: "operator = LIKE ; field = field_3", expected: FilterField{ Name: "field_3", Operator: OperatorLIKE, @@ -174,19 +175,19 @@ func TestFilterField(t *testing.T) { } t.Run("Fail on invalid tag value", func(t *testing.T) { - filterField, err := newFilterField("field=field_1=fail, operator=EQ") + filterField, err := newFilterField("field=field_1=fail; operator=EQ") assert.Nil(t, filterField) assert.EqualError(t, err, "invalid tag value: field=field_1=fail") }) t.Run("Fail on invalid operator", func(t *testing.T) { - filterField, err := newFilterField("field=field_1, operator=FAIL") + filterField, err := newFilterField("field=field_1; operator=FAIL") assert.Nil(t, filterField) assert.EqualError(t, err, "unknown operator: FAIL") }) t.Run("Fail on invalid value key", func(t *testing.T) { - filterField, err := newFilterField("failkey=field_1, operator=FAIL") + filterField, err := newFilterField("failkey=field_1; operator=FAIL") assert.Nil(t, filterField) assert.EqualError(t, err, "invalid value key: failkey") }) @@ -203,3 +204,26 @@ func TestFilterField(t *testing.T) { assert.EqualError(t, err, "missing operator in tag: field=field_1") }) } + +type filterWithoutQueryApplier struct{} + +type filterWithQueryApplier struct{} + +func (f filterWithQueryApplier) ApplyQuery(query *gorm.DB) *gorm.DB { + return query +} + +func TestSmartfilterApplyQuery(t *testing.T) { + + t.Run("Get query applier interface - without interface", func(t *testing.T) { + f := filterWithoutQueryApplier{} + queryApplier := getQueryApplierInterface(f) + assert.Nil(t, queryApplier) + }) + + t.Run("Get query applier interface - with interface", func(t *testing.T) { + f := filterWithQueryApplier{} + queryApplier := getQueryApplierInterface(f) + assert.NotNil(t, queryApplier) + }) +}