-
Notifications
You must be signed in to change notification settings - Fork 3
feat: add sorting to sql adapter #145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0b48b63
bdf9c0a
c38344b
c86c808
edb279e
8f1e293
1a05bc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ import ( | |
| "errors" | ||
| "fmt" | ||
| "log/slog" | ||
| "maps" | ||
| "reflect" | ||
| "strings" | ||
| "sync" | ||
|
|
@@ -217,6 +218,7 @@ func (s *SQLAdapter) Delete(item any, filter map[string]any, params ...map[strin | |
| func (s *SQLAdapter) executePaginatedQuery( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing godoc, specifically b/c the function has default behavior. |
||
| dest any, | ||
| sortKey string, | ||
| sortDirection SortingDirection, | ||
| limit int, | ||
| cursor string, | ||
| builder queryBuilder, | ||
|
|
@@ -231,10 +233,14 @@ func (s *SQLAdapter) executePaginatedQuery( | |
| } | ||
| q := s.DB.Model(dest).Scopes(builder) | ||
|
|
||
| q = q.Limit(limit + 1).Order(fmt.Sprintf("%s ASC", sortKey)) | ||
| q = q.Limit(limit + 1).Order(fmt.Sprintf("%s %s", sortKey, sortDirection)) | ||
|
|
||
| if cursorValue != "" { | ||
| q = q.Where(fmt.Sprintf("%s > ?", sortKey), cursorValue) | ||
| cursorOp := ">" | ||
| if sortDirection == Descending { | ||
| cursorOp = "<" | ||
Lutherwaves marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| q = q.Where(fmt.Sprintf("%s %s ?", sortKey, cursorOp), cursorValue) | ||
| } | ||
|
|
||
| if result := q.Find(dest); result.Error != nil { | ||
|
|
@@ -263,7 +269,8 @@ func (s *SQLAdapter) executePaginatedQuery( | |
| } | ||
|
|
||
| func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) { | ||
| return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { | ||
| sortDirection := extractSortDirection(extractParams(params...)) | ||
| return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB { | ||
| if len(filter) > 0 { | ||
| query, bindings := s.buildQuery(filter) | ||
| return q.Where(query, bindings) | ||
|
|
@@ -273,8 +280,9 @@ func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit | |
| } | ||
|
|
||
| func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, cursor string, params ...map[string]any) (string, error) { | ||
| sortDirection := extractSortDirection(extractParams(params...)) | ||
| if query == "" { | ||
| return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { | ||
| return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB { | ||
| return q | ||
| }) | ||
| } | ||
|
|
@@ -296,7 +304,7 @@ func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, c | |
|
|
||
| slog.Debug(fmt.Sprintf(`Where clause: %s, with params %s`, whereClause, queryParams)) | ||
|
|
||
| return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { | ||
| return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB { | ||
| if whereClause != "" { | ||
| return q.Where(whereClause, queryParams...) | ||
| } | ||
|
|
@@ -324,6 +332,26 @@ func (s *SQLAdapter) Query(dest any, statement string, limit int, cursor string, | |
| return "", fmt.Errorf("not implemented yet") | ||
| } | ||
|
|
||
| func extractParams(params ...map[string]any) map[string]any { | ||
| flatParams := make(map[string]any) | ||
| for _, param := range params { | ||
| maps.Copy(flatParams, param) | ||
| } | ||
| return flatParams | ||
| } | ||
|
|
||
| func extractSortDirection(paramMap map[string]any) SortingDirection { | ||
| if dir, exists := paramMap[SortDirectionKey]; exists { | ||
| if dirStr, ok := dir.(string); ok { | ||
| switch SortingDirection(strings.ToUpper(dirStr)) { | ||
| case Descending: | ||
| return Descending | ||
| } | ||
| } | ||
| } | ||
| return Ascending | ||
| } | ||
|
Comment on lines
+343
to
+353
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMHO, we should be explicit on what we match and return. and return an error otherwise.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, let me add this |
||
|
|
||
| func (s *SQLAdapter) buildQuery(filter map[string]any) (string, map[string]any) { | ||
| clauses := []string{} | ||
| bindings := make(map[string]any) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| package storage | ||
|
|
||
| import ( | ||
| "maps" | ||
| "testing" | ||
| ) | ||
|
|
||
| func TestExtractSortDirection(t *testing.T) { | ||
| tests := []struct { | ||
| name string | ||
| input map[string]any | ||
| expected SortingDirection | ||
| }{ | ||
| {"default when missing", map[string]any{}, Ascending}, | ||
| {"ASC explicit", map[string]any{SortDirectionKey: "ASC"}, Ascending}, | ||
| {"DESC", map[string]any{SortDirectionKey: "DESC"}, Descending}, | ||
| {"lowercase desc", map[string]any{SortDirectionKey: "desc"}, Descending}, | ||
| {"invalid falls back to ASC", map[string]any{SortDirectionKey: "SIDEWAYS"}, Ascending}, | ||
| } | ||
| for _, tt := range tests { | ||
| t.Run(tt.name, func(t *testing.T) { | ||
| got := extractSortDirection(tt.input) | ||
| if got != tt.expected { | ||
| t.Errorf("extractSortDirection(%v) = %q; want %q", tt.input, got, tt.expected) | ||
| } | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| func TestExtractParams(t *testing.T) { | ||
| tests := []struct { | ||
| name string | ||
| input []map[string]any | ||
| expected map[string]any | ||
| }{ | ||
| {"empty input", []map[string]any{}, map[string]any{}}, | ||
| {"single map", []map[string]any{{"a": 1}}, map[string]any{"a": 1}}, | ||
| {"two maps merged", []map[string]any{{"a": 1}, {"b": 2}}, map[string]any{"a": 1, "b": 2}}, | ||
| {"later map wins on collision", []map[string]any{{"a": 1}, {"a": 2}}, map[string]any{"a": 2}}, | ||
| } | ||
| for _, tt := range tests { | ||
| t.Run(tt.name, func(t *testing.T) { | ||
| got := extractParams(tt.input...) | ||
| if !maps.Equal(got, tt.expected) { | ||
| t.Errorf("extractParams(%v) = %v; want %v", tt.input, got, tt.expected) | ||
| } | ||
| }) | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did these two methods moved from cosmos to sql? If they are shared shouldn't they be in a shared location like storage.go maybe?