diff --git a/storage/cosmosdb.go b/storage/cosmosdb.go index 3474b6f..0a0b618 100644 --- a/storage/cosmosdb.go +++ b/storage/cosmosdb.go @@ -164,7 +164,7 @@ func (s *CosmosDBAdapter) GetLatestMigration() (int, error) { func (s *CosmosDBAdapter) Create(item any, params ...map[string]any) error { // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(item) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -230,7 +230,7 @@ func (s *CosmosDBAdapter) Get(dest any, filter map[string]any, params ...map[str } // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(dest) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -302,7 +302,7 @@ func (s *CosmosDBAdapter) Update(item any, filter map[string]any, params ...map[ } // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(item) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -388,7 +388,7 @@ func (s *CosmosDBAdapter) Delete(item any, filter map[string]any, params ...map[ } // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(item) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -428,8 +428,8 @@ func (s *CosmosDBAdapter) Delete(item any, filter map[string]any, params ...map[ func (s *CosmosDBAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) { // Extract sort direction from params - paramMap := s.extractParams(params...) - sortDirection := s.extractSortDirection(paramMap) + paramMap := extractParams(params...) + sortDirection := extractSortDirection(paramMap) return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, filter, params...) } @@ -441,8 +441,8 @@ func (s *CosmosDBAdapter) Search(dest any, sortKey string, query string, limit i // For custom queries, use the Query method instead // Extract sort direction from params - paramMap := s.extractParams(params...) - sortDirection := s.extractSortDirection(paramMap) + paramMap := extractParams(params...) + sortDirection := extractSortDirection(paramMap) // Use executePaginatedQuery with empty filter (the query parameter is ignored for CosmosDB) return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, map[string]any{}, params...) @@ -517,14 +517,14 @@ func (s *CosmosDBAdapter) Query(dest any, statement string, limit int, cursor st func (s *CosmosDBAdapter) executePaginatedQuery( dest any, sortKey string, - sortDirection string, + sortDirection SortingDirection, limit int, cursor string, filter map[string]any, params ...map[string]any, ) (string, error) { // Extract provider-specific parameters - paramMap := s.extractParams(params...) + paramMap := extractParams(params...) containerName := s.getContainerName(dest) containerClient, err := s.databaseClient.NewContainer(containerName) @@ -706,31 +706,6 @@ func (s *CosmosDBAdapter) executeQuery( return page, err } -// extractParams merges all provided parameter maps into a single map -func (s *CosmosDBAdapter) extractParams(params ...map[string]any) map[string]any { - paramMap := make(map[string]any) - for _, param := range params { - for k, v := range param { - paramMap[k] = v - } - } - return paramMap -} - -// extractSortDirection extracts and validates sort direction from params -func (s *CosmosDBAdapter) extractSortDirection(paramMap map[string]any) string { - sortDirection := "ASC" // Default to ASC - if dir, exists := paramMap["sort_direction"]; exists { - if dirStr, ok := dir.(string); ok { - sortDirection = strings.ToUpper(dirStr) - if sortDirection != "ASC" && sortDirection != "DESC" { - sortDirection = "ASC" // Fallback to ASC for invalid values - } - } - } - return sortDirection -} - // buildFilter constructs WHERE clause conditions from filter map func (s *CosmosDBAdapter) buildFilter(filter map[string]any, paramIndex *int) (string, []azcosmos.QueryParameter) { conditions := []string{} diff --git a/storage/sql.go b/storage/sql.go index 8313e11..ea98546 100644 --- a/storage/sql.go +++ b/storage/sql.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "maps" "reflect" "strings" "sync" @@ -217,6 +218,7 @@ func (s *SQLAdapter) Delete(item any, filter map[string]any, params ...map[strin func (s *SQLAdapter) executePaginatedQuery( dest any, sortKey string, + sortDirection SortingDirection, limit int, cursor string, builder queryBuilder, @@ -231,10 +233,14 @@ func (s *SQLAdapter) executePaginatedQuery( } q := s.DB.Model(dest).Scopes(builder) - q = q.Limit(limit + 1).Order(fmt.Sprintf("%s ASC", sortKey)) + q = q.Limit(limit + 1).Order(fmt.Sprintf("%s %s", sortKey, sortDirection)) if cursorValue != "" { - q = q.Where(fmt.Sprintf("%s > ?", sortKey), cursorValue) + cursorOp := ">" + if sortDirection == Descending { + cursorOp = "<" + } + q = q.Where(fmt.Sprintf("%s %s ?", sortKey, cursorOp), cursorValue) } if result := q.Find(dest); result.Error != nil { @@ -263,7 +269,8 @@ func (s *SQLAdapter) executePaginatedQuery( } func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) { - return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { + sortDirection := extractSortDirection(extractParams(params...)) + return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB { if len(filter) > 0 { query, bindings := s.buildQuery(filter) return q.Where(query, bindings) @@ -273,8 +280,9 @@ func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit } func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, cursor string, params ...map[string]any) (string, error) { + sortDirection := extractSortDirection(extractParams(params...)) if query == "" { - return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { + return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB { return q }) } @@ -296,7 +304,7 @@ func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, c slog.Debug(fmt.Sprintf(`Where clause: %s, with params %s`, whereClause, queryParams)) - return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { + return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB { if whereClause != "" { return q.Where(whereClause, queryParams...) } @@ -324,6 +332,26 @@ func (s *SQLAdapter) Query(dest any, statement string, limit int, cursor string, return "", fmt.Errorf("not implemented yet") } +func extractParams(params ...map[string]any) map[string]any { + flatParams := make(map[string]any) + for _, param := range params { + maps.Copy(flatParams, param) + } + return flatParams +} + +func extractSortDirection(paramMap map[string]any) SortingDirection { + if dir, exists := paramMap[SortDirectionKey]; exists { + if dirStr, ok := dir.(string); ok { + switch SortingDirection(strings.ToUpper(dirStr)) { + case Descending: + return Descending + } + } + } + return Ascending +} + func (s *SQLAdapter) buildQuery(filter map[string]any) (string, map[string]any) { clauses := []string{} bindings := make(map[string]any) diff --git a/storage/sql_test.go b/storage/sql_test.go new file mode 100644 index 0000000..98fea16 --- /dev/null +++ b/storage/sql_test.go @@ -0,0 +1,49 @@ +package storage + +import ( + "maps" + "testing" +) + +func TestExtractSortDirection(t *testing.T) { + tests := []struct { + name string + input map[string]any + expected SortingDirection + }{ + {"default when missing", map[string]any{}, Ascending}, + {"ASC explicit", map[string]any{SortDirectionKey: "ASC"}, Ascending}, + {"DESC", map[string]any{SortDirectionKey: "DESC"}, Descending}, + {"lowercase desc", map[string]any{SortDirectionKey: "desc"}, Descending}, + {"invalid falls back to ASC", map[string]any{SortDirectionKey: "SIDEWAYS"}, Ascending}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractSortDirection(tt.input) + if got != tt.expected { + t.Errorf("extractSortDirection(%v) = %q; want %q", tt.input, got, tt.expected) + } + }) + } +} + +func TestExtractParams(t *testing.T) { + tests := []struct { + name string + input []map[string]any + expected map[string]any + }{ + {"empty input", []map[string]any{}, map[string]any{}}, + {"single map", []map[string]any{{"a": 1}}, map[string]any{"a": 1}}, + {"two maps merged", []map[string]any{{"a": 1}, {"b": 2}}, map[string]any{"a": 1, "b": 2}}, + {"later map wins on collision", []map[string]any{{"a": 1}, {"a": 2}}, map[string]any{"a": 2}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractParams(tt.input...) + if !maps.Equal(got, tt.expected) { + t.Errorf("extractParams(%v) = %v; want %v", tt.input, got, tt.expected) + } + }) + } +} diff --git a/storage/storage.go b/storage/storage.go index 2414a3f..8bdaaf8 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -47,6 +47,15 @@ const ( COSMOSDB_PROVIDER StorageProviders = "cosmosdb" ) +type SortingDirection string + +const ( + Ascending SortingDirection = "ASC" + Descending SortingDirection = "DESC" +) + +const SortDirectionKey = "sort_direction" + func (s StorageAdapterFactory) GetInstance(adapterType StorageAdapterType, config any) (StorageAdapter, error) { if config == nil { config = make(map[string]string)