From bd510d958bffb22dc8ce5464675c7fce88ace888 Mon Sep 17 00:00:00 2001 From: Eden Kirin Date: Mon, 24 Jun 2024 01:08:39 +0200 Subject: [PATCH] Only option --- app/main.go | 4 +- app/repository/method_get.go | 10 +++- app/repository/method_list.go | 58 ++++--------------- app/repository/options.go | 61 ++++++++++++++++++++ app/repository/repository_test.go | 96 +++++++++++++++++++++---------- 5 files changed, 147 insertions(+), 82 deletions(-) create mode 100644 app/repository/options.go diff --git a/app/main.go b/app/main.go index 043e78c..8d2834e 100644 --- a/app/main.go +++ b/app/main.go @@ -72,7 +72,7 @@ func doList(db *gorm.DB) { Alive: &TRUE, } - certs, err := repo.List(filter, nil, nil) + certs, err := repo.List(filter, nil) if err != nil { panic(err) } @@ -92,7 +92,7 @@ func doGet(db *gorm.DB) { Id: &id, } - cert, err := repo.Get(filter) + cert, err := repo.Get(filter, nil) if err != nil { panic(err) } diff --git a/app/repository/method_get.go b/app/repository/method_get.go index 700f4b1..95635e3 100644 --- a/app/repository/method_get.go +++ b/app/repository/method_get.go @@ -7,6 +7,10 @@ import ( "gorm.io/gorm/schema" ) +type GetOptions struct { + Only *[]string +} + type GetMethod[T schema.Tabler] struct { DbConn *gorm.DB } @@ -15,7 +19,7 @@ func (m *GetMethod[T]) Init(dbConn *gorm.DB) { m.DbConn = dbConn } -func (m GetMethod[T]) Get(filter interface{}) (*T, error) { +func (m GetMethod[T]) Get(filter interface{}, options *GetOptions) (*T, error) { var ( model T ) @@ -25,6 +29,10 @@ func (m GetMethod[T]) Get(filter interface{}) (*T, error) { return nil, err } + if options != nil { + query = applyOptionOnly(query, options.Only) + } + result := query.First(&model) if result.Error != nil { return nil, result.Error diff --git a/app/repository/method_list.go b/app/repository/method_list.go index b39d948..821d054 100644 --- a/app/repository/method_list.go +++ b/app/repository/method_list.go @@ -1,28 +1,16 @@ package repository import ( - "fmt" "repo-pattern/app/repository/smartfilter" "gorm.io/gorm" "gorm.io/gorm/schema" ) -type Pagination struct { - Offset int - Limit int -} - -type OrderDirection string - -const ( - OrderASC OrderDirection = "ASC" - OrderDESC OrderDirection = "DESC" -) - -type Order struct { - Field string - Direction OrderDirection +type ListOptions struct { + Only *[]string + Ordering *[]Order + Pagination *Pagination } type ListMethod[T schema.Tabler] struct { @@ -33,36 +21,7 @@ func (m *ListMethod[T]) Init(dbConn *gorm.DB) { m.DbConn = dbConn } -func applyOrdering(query *gorm.DB, ordering *[]Order) *gorm.DB { - if ordering == nil || len(*ordering) == 0 { - return query - } - - for _, order := range *ordering { - if order.Direction == OrderASC { - query = query.Order(order.Field) - } else { - query = query.Order(fmt.Sprintf("%s %s", order.Field, order.Direction)) - } - } - return query -} - -func applyPagination(query *gorm.DB, pagination *Pagination) *gorm.DB { - if pagination == nil { - return query - } - - if pagination.Limit != 0 { - query = query.Limit(pagination.Limit) - } - if pagination.Offset != 0 { - query = query.Offset(pagination.Offset) - } - return query -} - -func (m ListMethod[T]) List(filter interface{}, ordering *[]Order, pagination *Pagination) (*[]T, error) { +func (m ListMethod[T]) List(filter interface{}, options *ListOptions) (*[]T, error) { var ( model T models []T @@ -73,8 +32,11 @@ func (m ListMethod[T]) List(filter interface{}, ordering *[]Order, pagination *P return nil, err } - query = applyOrdering(query, ordering) - query = applyPagination(query, pagination) + if options != nil { + query = applyOptionOnly(query, options.Only) + query = applyOptionOrdering(query, options.Ordering) + query = applyOptionPagination(query, options.Pagination) + } query.Find(&models) return &models, nil diff --git a/app/repository/options.go b/app/repository/options.go new file mode 100644 index 0000000..d824d7f --- /dev/null +++ b/app/repository/options.go @@ -0,0 +1,61 @@ +package repository + +import ( + "fmt" + + "gorm.io/gorm" +) + +type Pagination struct { + Offset int + Limit int +} + +type OrderDirection string + +const ( + OrderASC OrderDirection = "ASC" + OrderDESC OrderDirection = "DESC" +) + +type Order struct { + Field string + Direction OrderDirection +} + +func applyOptionOnly(query *gorm.DB, only *[]string) *gorm.DB { + if only == nil || len(*only) == 0 { + return query + } + query = query.Select(*only) + return query +} + +func applyOptionOrdering(query *gorm.DB, ordering *[]Order) *gorm.DB { + if ordering == nil || len(*ordering) == 0 { + return query + } + + for _, order := range *ordering { + if len(order.Direction) == 0 || order.Direction == OrderASC { + query = query.Order(order.Field) + } else { + query = query.Order(fmt.Sprintf("%s %s", order.Field, order.Direction)) + } + } + return query +} + +func applyOptionPagination(query *gorm.DB, pagination *Pagination) *gorm.DB { + if pagination == nil { + return query + } + + if pagination.Limit != 0 { + query = query.Limit(pagination.Limit) + } + if pagination.Offset != 0 { + query = query.Offset(pagination.Offset) + } + return query +} diff --git a/app/repository/repository_test.go b/app/repository/repository_test.go index aadc283..690640f 100644 --- a/app/repository/repository_test.go +++ b/app/repository/repository_test.go @@ -56,21 +56,23 @@ func TestListMethod(t *testing.T) { repo.Init(db) filter := MyModelFilter{} - ordering := []Order{ - { - Field: "id", - Direction: OrderASC, - }, - { - Field: "count", - Direction: OrderDESC, + options := ListOptions{ + Ordering: &[]Order{ + { + Field: "id", + Direction: OrderASC, + }, + { + Field: "count", + Direction: OrderDESC, + }, }, } sql := "SELECT * FROM my_models ORDER BY id,count DESC" mock.ExpectQuery(fmt.Sprintf("^%s$", regexp.QuoteMeta(sql))) - _, err := repo.List(filter, &ordering, nil) + _, err := repo.List(filter, &options) assert.Nil(t, err) if err := mock.ExpectationsWereMet(); err != nil { @@ -86,16 +88,17 @@ func TestListMethod(t *testing.T) { repo.Init(db) filter := MyModelFilter{} - pagination := Pagination{ - Limit: 111, - Offset: 0, + options := ListOptions{ + Pagination: &Pagination{ + Limit: 111, + }, } sql := "SELECT * FROM my_models LIMIT $1" mock.ExpectQuery(fmt.Sprintf("^%s$", regexp.QuoteMeta(sql))). - WithArgs(pagination.Limit) + WithArgs(options.Pagination.Limit) - _, err := repo.List(filter, nil, &pagination) + _, err := repo.List(filter, &options) assert.Nil(t, err) if err := mock.ExpectationsWereMet(); err != nil { @@ -111,16 +114,17 @@ func TestListMethod(t *testing.T) { repo.Init(db) filter := MyModelFilter{} - pagination := Pagination{ - Limit: 0, - Offset: 222, + options := ListOptions{ + Pagination: &Pagination{ + Offset: 222, + }, } sql := "SELECT * FROM my_models OFFSET $1" mock.ExpectQuery(fmt.Sprintf("^%s$", regexp.QuoteMeta(sql))). - WithArgs(pagination.Offset) + WithArgs(options.Pagination.Offset) - _, err := repo.List(filter, nil, &pagination) + _, err := repo.List(filter, &options) assert.Nil(t, err) if err := mock.ExpectationsWereMet(); err != nil { @@ -136,16 +140,18 @@ func TestListMethod(t *testing.T) { repo.Init(db) filter := MyModelFilter{} - pagination := Pagination{ - Limit: 111, - Offset: 222, + options := ListOptions{ + Pagination: &Pagination{ + Limit: 111, + Offset: 222, + }, } sql := "SELECT * FROM my_models LIMIT $1 OFFSET $2" mock.ExpectQuery(fmt.Sprintf("^%s$", regexp.QuoteMeta(sql))). - WithArgs(pagination.Limit, pagination.Offset) + WithArgs(options.Pagination.Limit, options.Pagination.Offset) - _, err := repo.List(filter, nil, &pagination) + _, err := repo.List(filter, &options) assert.Nil(t, err) if err := mock.ExpectationsWereMet(); err != nil { @@ -169,7 +175,7 @@ func TestListMethod(t *testing.T) { mock.ExpectQuery(fmt.Sprintf("^%s$", regexp.QuoteMeta(sql))). WithArgs(id) - _, err := repo.List(filter, nil, nil) + _, err := repo.List(filter, nil) assert.Nil(t, err) if err := mock.ExpectationsWereMet(); err != nil { @@ -197,7 +203,7 @@ func TestListMethod(t *testing.T) { mock.ExpectQuery(fmt.Sprintf("^%s$", regexp.QuoteMeta(sql))). WithArgs(id, value, count) - _, err := repo.List(filter, nil, nil) + _, err := repo.List(filter, nil) assert.Nil(t, err) if err := mock.ExpectationsWereMet(); err != nil { @@ -220,16 +226,44 @@ func TestListMethod(t *testing.T) { Value: &value, Count: &count, } - pagination := Pagination{ - Offset: 111, - Limit: 222, + options := ListOptions{ + Pagination: &Pagination{ + Offset: 111, + Limit: 222, + }, } sql := "SELECT * FROM my_models WHERE my_models.id = $1 AND my_models.value = $2 AND my_models.count > $3 LIMIT $4 OFFSET $5" mock.ExpectQuery(fmt.Sprintf("^%s$", regexp.QuoteMeta(sql))). - WithArgs(id, value, count, pagination.Limit, pagination.Offset) + WithArgs(id, value, count, options.Pagination.Limit, options.Pagination.Offset) - _, err := repo.List(filter, nil, &pagination) + _, err := repo.List(filter, &options) + assert.Nil(t, err) + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + }) + + t.Run("Only id and count", func(t *testing.T) { + sqldb, db, mock := NewMockDB() + defer sqldb.Close() + + repo := RepoBase[MyModel]{} + repo.Init(db) + + filter := MyModelFilter{} + options := ListOptions{ + Only: &[]string{ + "id", + "count", + }, + } + + sql := "SELECT id,count FROM my_models" + mock.ExpectQuery(fmt.Sprintf("^%s$", regexp.QuoteMeta(sql))) + + _, err := repo.List(filter, &options) assert.Nil(t, err) if err := mock.ExpectationsWereMet(); err != nil {