Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 10 additions & 35 deletions storage/cosmosdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (s *CosmosDBAdapter) GetLatestMigration() (int, error) {

func (s *CosmosDBAdapter) Create(item any, params ...map[string]any) error {
// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(item)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -230,7 +230,7 @@ func (s *CosmosDBAdapter) Get(dest any, filter map[string]any, params ...map[str
}

// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(dest)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -302,7 +302,7 @@ func (s *CosmosDBAdapter) Update(item any, filter map[string]any, params ...map[
}

// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(item)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -388,7 +388,7 @@ func (s *CosmosDBAdapter) Delete(item any, filter map[string]any, params ...map[
}

// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(item)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -428,8 +428,8 @@ func (s *CosmosDBAdapter) Delete(item any, filter map[string]any, params ...map[

func (s *CosmosDBAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) {
// Extract sort direction from params
paramMap := s.extractParams(params...)
sortDirection := s.extractSortDirection(paramMap)
paramMap := extractParams(params...)
sortDirection := extractSortDirection(paramMap)

return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, filter, params...)
}
Expand All @@ -441,8 +441,8 @@ func (s *CosmosDBAdapter) Search(dest any, sortKey string, query string, limit i
// For custom queries, use the Query method instead

// Extract sort direction from params
paramMap := s.extractParams(params...)
sortDirection := s.extractSortDirection(paramMap)
paramMap := extractParams(params...)
sortDirection := extractSortDirection(paramMap)

// Use executePaginatedQuery with empty filter (the query parameter is ignored for CosmosDB)
return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, map[string]any{}, params...)
Expand Down Expand Up @@ -517,14 +517,14 @@ func (s *CosmosDBAdapter) Query(dest any, statement string, limit int, cursor st
func (s *CosmosDBAdapter) executePaginatedQuery(
dest any,
sortKey string,
sortDirection string,
sortDirection SortingDirection,
limit int,
cursor string,
filter map[string]any,
params ...map[string]any,
) (string, error) {
// Extract provider-specific parameters
paramMap := s.extractParams(params...)
paramMap := extractParams(params...)

containerName := s.getContainerName(dest)
containerClient, err := s.databaseClient.NewContainer(containerName)
Expand Down Expand Up @@ -706,31 +706,6 @@ func (s *CosmosDBAdapter) executeQuery(
return page, err
}

// extractParams merges all provided parameter maps into a single map
func (s *CosmosDBAdapter) extractParams(params ...map[string]any) map[string]any {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did these two methods moved from cosmos to sql? If they are shared shouldn't they be in a shared location like storage.go maybe?

paramMap := make(map[string]any)
for _, param := range params {
for k, v := range param {
paramMap[k] = v
}
}
return paramMap
}

// extractSortDirection extracts and validates sort direction from params
func (s *CosmosDBAdapter) extractSortDirection(paramMap map[string]any) string {
sortDirection := "ASC" // Default to ASC
if dir, exists := paramMap["sort_direction"]; exists {
if dirStr, ok := dir.(string); ok {
sortDirection = strings.ToUpper(dirStr)
if sortDirection != "ASC" && sortDirection != "DESC" {
sortDirection = "ASC" // Fallback to ASC for invalid values
}
}
}
return sortDirection
}

// buildFilter constructs WHERE clause conditions from filter map
func (s *CosmosDBAdapter) buildFilter(filter map[string]any, paramIndex *int) (string, []azcosmos.QueryParameter) {
conditions := []string{}
Expand Down
38 changes: 33 additions & 5 deletions storage/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"log/slog"
"maps"
"reflect"
"strings"
"sync"
Expand Down Expand Up @@ -217,6 +218,7 @@ func (s *SQLAdapter) Delete(item any, filter map[string]any, params ...map[strin
func (s *SQLAdapter) executePaginatedQuery(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing godoc, specifically b/c the function has default behavior.
Also, I think that switch-case would be more appropriate than the current defaultValue-if clause

dest any,
sortKey string,
sortDirection SortingDirection,
limit int,
cursor string,
builder queryBuilder,
Expand All @@ -231,10 +233,14 @@ func (s *SQLAdapter) executePaginatedQuery(
}
q := s.DB.Model(dest).Scopes(builder)

q = q.Limit(limit + 1).Order(fmt.Sprintf("%s ASC", sortKey))
q = q.Limit(limit + 1).Order(fmt.Sprintf("%s %s", sortKey, sortDirection))

if cursorValue != "" {
q = q.Where(fmt.Sprintf("%s > ?", sortKey), cursorValue)
cursorOp := ">"
if sortDirection == Descending {
cursorOp = "<"
}
q = q.Where(fmt.Sprintf("%s %s ?", sortKey, cursorOp), cursorValue)
}

if result := q.Find(dest); result.Error != nil {
Expand Down Expand Up @@ -263,7 +269,8 @@ func (s *SQLAdapter) executePaginatedQuery(
}

func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) {
return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB {
sortDirection := extractSortDirection(extractParams(params...))
return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB {
if len(filter) > 0 {
query, bindings := s.buildQuery(filter)
return q.Where(query, bindings)
Expand All @@ -273,8 +280,9 @@ func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit
}

func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, cursor string, params ...map[string]any) (string, error) {
sortDirection := extractSortDirection(extractParams(params...))
if query == "" {
return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB {
return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB {
return q
})
}
Expand All @@ -296,7 +304,7 @@ func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, c

slog.Debug(fmt.Sprintf(`Where clause: %s, with params %s`, whereClause, queryParams))

return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB {
return s.executePaginatedQuery(dest, sortKey, sortDirection, limit, cursor, func(q *gorm.DB) *gorm.DB {
if whereClause != "" {
return q.Where(whereClause, queryParams...)
}
Expand Down Expand Up @@ -324,6 +332,26 @@ func (s *SQLAdapter) Query(dest any, statement string, limit int, cursor string,
return "", fmt.Errorf("not implemented yet")
}

func extractParams(params ...map[string]any) map[string]any {
flatParams := make(map[string]any)
for _, param := range params {
maps.Copy(flatParams, param)
}
return flatParams
}

func extractSortDirection(paramMap map[string]any) SortingDirection {
if dir, exists := paramMap[SortDirectionKey]; exists {
if dirStr, ok := dir.(string); ok {
switch SortingDirection(strings.ToUpper(dirStr)) {
case Descending:
return Descending
}
}
}
return Ascending
}
Comment on lines +343 to +353
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO, we should be explicit on what we match and return. and return an error otherwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, let me add this


func (s *SQLAdapter) buildQuery(filter map[string]any) (string, map[string]any) {
clauses := []string{}
bindings := make(map[string]any)
Expand Down
49 changes: 49 additions & 0 deletions storage/sql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package storage

import (
"maps"
"testing"
)

func TestExtractSortDirection(t *testing.T) {
tests := []struct {
name string
input map[string]any
expected SortingDirection
}{
{"default when missing", map[string]any{}, Ascending},
{"ASC explicit", map[string]any{SortDirectionKey: "ASC"}, Ascending},
{"DESC", map[string]any{SortDirectionKey: "DESC"}, Descending},
{"lowercase desc", map[string]any{SortDirectionKey: "desc"}, Descending},
{"invalid falls back to ASC", map[string]any{SortDirectionKey: "SIDEWAYS"}, Ascending},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractSortDirection(tt.input)
if got != tt.expected {
t.Errorf("extractSortDirection(%v) = %q; want %q", tt.input, got, tt.expected)
}
})
}
}

func TestExtractParams(t *testing.T) {
tests := []struct {
name string
input []map[string]any
expected map[string]any
}{
{"empty input", []map[string]any{}, map[string]any{}},
{"single map", []map[string]any{{"a": 1}}, map[string]any{"a": 1}},
{"two maps merged", []map[string]any{{"a": 1}, {"b": 2}}, map[string]any{"a": 1, "b": 2}},
{"later map wins on collision", []map[string]any{{"a": 1}, {"a": 2}}, map[string]any{"a": 2}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractParams(tt.input...)
if !maps.Equal(got, tt.expected) {
t.Errorf("extractParams(%v) = %v; want %v", tt.input, got, tt.expected)
}
})
}
}
9 changes: 9 additions & 0 deletions storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ const (
COSMOSDB_PROVIDER StorageProviders = "cosmosdb"
)

type SortingDirection string

const (
Ascending SortingDirection = "ASC"
Descending SortingDirection = "DESC"
)

const SortDirectionKey = "sort_direction"

func (s StorageAdapterFactory) GetInstance(adapterType StorageAdapterType, config any) (StorageAdapter, error) {
if config == nil {
config = make(map[string]string)
Expand Down