diff --git a/go.mod b/go.mod index 0039d51..97eceeb 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.19.7 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.55.0 github.com/aws/aws-sdk-go-v2/service/sns v1.39.11 + github.com/grindlemire/go-lucene v0.0.26 gopkg.in/yaml.v3 v3.0.1 gorm.io/gorm v1.31.1 ) diff --git a/go.sum b/go.sum index b736785..85fc238 100644 --- a/go.sum +++ b/go.sum @@ -75,6 +75,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grindlemire/go-lucene v0.0.26 h1:81ttZkMvU3rFD0TfmjdIZT2U0Fd4TT7buDy+xq1x5EQ= +github.com/grindlemire/go-lucene v0.0.26/go.mod h1:INRJBdhkLjS4jc7XgkGPfzC5wuFg3BHDukXMTc+OTbc= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= diff --git a/storage/dynamodb.go b/storage/dynamodb.go index c7cbc33..775192a 100644 --- a/storage/dynamodb.go +++ b/storage/dynamodb.go @@ -239,7 +239,7 @@ func (s *DynamoDBAdapter) Search(dest any, sortKey string, query string, limit i // Parse Lucene query destType := reflect.TypeOf(dest).Elem().Elem() model := reflect.New(destType).Elem().Interface() - parser, _ := lucene.NewParserFromType(model) + parser, _ := lucene.NewParser(model) whereClause, params, _ := parser.ParseToDynamoDBPartiQL(query) // Build query diff --git a/storage/search/lucene/dynamodb_driver.go b/storage/search/lucene/dynamodb_driver.go new file mode 100644 index 0000000..d7cf42d --- /dev/null +++ b/storage/search/lucene/dynamodb_driver.go @@ -0,0 
+1,137 @@ +package lucene + +import ( + "fmt" + "regexp" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/grindlemire/go-lucene/pkg/driver" + "github.com/grindlemire/go-lucene/pkg/lucene/expr" +) + +// DynamoDBPartiQLDriver converts Lucene queries to DynamoDB PartiQL. +type DynamoDBPartiQLDriver struct { + driver.Base + fields map[string]FieldInfo +} + +func NewDynamoDBDriver(fields []FieldInfo) (*DynamoDBPartiQLDriver, error) { + fieldMap, err := buildFieldMap(fields) + if err != nil { + return nil, err + } + + fns := map[expr.Operator]driver.RenderFN{ + expr.Literal: driver.Shared[expr.Literal], + expr.And: driver.Shared[expr.And], + expr.Or: driver.Shared[expr.Or], + expr.Not: driver.Shared[expr.Not], + expr.Equals: driver.Shared[expr.Equals], + expr.Range: driver.Shared[expr.Range], + expr.Must: driver.Shared[expr.Must], + expr.MustNot: driver.Shared[expr.MustNot], + expr.Wild: driver.Shared[expr.Wild], + expr.Regexp: driver.Shared[expr.Regexp], + expr.Like: dynamoDBLike, // Custom LIKE for DynamoDB functions + expr.Greater: driver.Shared[expr.Greater], + expr.GreaterEq: driver.Shared[expr.GreaterEq], + expr.Less: driver.Shared[expr.Less], + expr.LessEq: driver.Shared[expr.LessEq], + expr.In: driver.Shared[expr.In], + expr.List: driver.Shared[expr.List], + } + + return &DynamoDBPartiQLDriver{ + Base: driver.Base{ + RenderFNs: fns, + }, + fields: fieldMap, + }, nil +} + +// RenderPartiQL renders the expression to DynamoDB PartiQL with AttributeValue parameters. +func (d *DynamoDBPartiQLDriver) RenderPartiQL(e *expr.Expression) (string, []types.AttributeValue, error) { + // Use base rendering with ? 
placeholders + str, params, err := d.RenderParam(e) + if err != nil { + return "", nil, err + } + + // Convert params to DynamoDB AttributeValues + attrValues := make([]types.AttributeValue, len(params)) + for i, param := range params { + attrValues[i] = &types.AttributeValueMemberS{Value: fmt.Sprintf("%v", param)} + } + + return str, attrValues, nil +} + +// escapePartiQLString escapes a string value for safe use in PartiQL string literals. +// Escapes single quotes by doubling them (PartiQL standard). +func escapePartiQLString(s string) string { + return strings.ReplaceAll(s, "'", "''") +} + +var ( + // partiQLIdentifierPattern matches valid PartiQL identifiers (alphanumeric and underscore only) + partiQLIdentifierPattern = regexp.MustCompile(`^[a-zA-Z0-9_]+$`) +) + +// escapePartiQLIdentifier escapes a field name for safe use in PartiQL. +// Validates that the identifier contains only safe characters (alphanumeric, underscore). +// Returns error if identifier contains potentially dangerous characters. +func escapePartiQLIdentifier(identifier string) (string, error) { + if !partiQLIdentifierPattern.MatchString(identifier) { + return "", fmt.Errorf("invalid identifier: contains unsafe characters (only alphanumeric and underscore allowed)") + } + return identifier, nil +} + +// unquotePartiQLString safely removes surrounding quotes from a PartiQL string literal. +// Handles already-escaped quotes correctly. +func unquotePartiQLString(s string) string { + if len(s) >= 2 && s[0] == '\'' && s[len(s)-1] == '\'' { + return s[1 : len(s)-1] + } + return s +} + +// dynamoDBLike implements LIKE using DynamoDB's begins_with and contains functions. 
+func dynamoDBLike(left, right string) (string, error) { + // Validate and escape field name (left) + safeLeft, err := escapePartiQLIdentifier(left) + if err != nil { + return "", fmt.Errorf("invalid field name: %w", err) + } + + // Extract the raw value from the right side (remove quotes if present) + rawValue := unquotePartiQLString(right) + + // Analyze pattern for wildcards + hasPrefix := strings.HasPrefix(rawValue, "%") + hasSuffix := strings.HasSuffix(rawValue, "%") + + if hasPrefix && hasSuffix { + // %value% -> contains(field, value) + value := strings.Trim(rawValue, "%") + escapedValue := escapePartiQLString(value) + return fmt.Sprintf("contains(%s, '%s')", safeLeft, escapedValue), nil + } + if !hasPrefix && hasSuffix { + // value% -> begins_with(field, value) + value := strings.TrimSuffix(rawValue, "%") + escapedValue := escapePartiQLString(value) + return fmt.Sprintf("begins_with(%s, '%s')", safeLeft, escapedValue), nil + } + if hasPrefix && !hasSuffix { + // %value -> contains(field, value) (DynamoDB doesn't have ends_with) + value := strings.TrimPrefix(rawValue, "%") + escapedValue := escapePartiQLString(value) + return fmt.Sprintf("contains(%s, '%s')", safeLeft, escapedValue), nil + } + + // Exact match - escape the value and wrap in quotes + escapedValue := escapePartiQLString(rawValue) + return fmt.Sprintf("%s = '%s'", safeLeft, escapedValue), nil +} diff --git a/storage/search/lucene/dynamodb_driver_test.go b/storage/search/lucene/dynamodb_driver_test.go new file mode 100644 index 0000000..8ccf2e5 --- /dev/null +++ b/storage/search/lucene/dynamodb_driver_test.go @@ -0,0 +1,498 @@ +package lucene + +import ( + "reflect" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + "github.com/grindlemire/go-lucene/pkg/lucene/expr" +) + +func TestNewDynamoDBDriver(t *testing.T) { + tests := []struct { + name string + fields []FieldInfo + want map[string]FieldInfo + wantErr bool + }{ + { + name: "empty fields", + fields: 
[]FieldInfo{}, + want: map[string]FieldInfo{}, + wantErr: false, + }, + { + name: "single field", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + }, + want: map[string]FieldInfo{ + "name": {Name: "name", Type: reflect.TypeOf("")}, + }, + wantErr: false, + }, + { + name: "multiple fields", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "age", Type: reflect.TypeOf(0)}, + }, + want: map[string]FieldInfo{ + "name": {Name: "name", Type: reflect.TypeOf("")}, + "email": {Name: "email", Type: reflect.TypeOf("")}, + "age": {Name: "age", Type: reflect.TypeOf(0)}, + }, + wantErr: false, + }, + { + name: "duplicate field names returns error", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + want: nil, + wantErr: true, + }, + { + name: "multiple duplicate field names", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewDynamoDBDriver(tt.fields) + if (err != nil) != tt.wantErr { + t.Errorf("NewDynamoDBDriver() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if err == nil { + t.Errorf("NewDynamoDBDriver() expected error but got nil") + } + if driver != nil { + t.Errorf("NewDynamoDBDriver() expected nil driver on error, got %v", driver) + } + if err != nil && !strings.Contains(err.Error(), "duplicate field name") { + t.Errorf("NewDynamoDBDriver() error message should contain 'duplicate field name', got: %v", err) + } + return + } + if driver == nil { + t.Fatalf("NewDynamoDBDriver() returned nil") + } + if len(driver.fields) != len(tt.want) { + t.Errorf("NewDynamoDBDriver() fields count = %v, want %v", len(driver.fields), len(tt.want)) + } + for name, wantField 
:= range tt.want { + gotField, exists := driver.fields[name] + if !exists { + t.Errorf("NewDynamoDBDriver() missing field %v", name) + continue + } + if gotField.Name != wantField.Name { + t.Errorf("NewDynamoDBDriver() field[%v].Name = %v, want %v", name, gotField.Name, wantField.Name) + } + } + }) + } +} + +func TestDynamoDBDriver_RenderPartiQL(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "age", Type: reflect.TypeOf(0)}, + } + driver, err := NewDynamoDBDriver(fields) + if err != nil { + t.Fatalf("NewDynamoDBDriver() error = %v", err) + } + + tests := []struct { + name string + expr *expr.Expression + wantSQL string + wantCount int + wantErr bool + }{ + { + name: "equals expression", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: "name", + wantCount: 1, + wantErr: false, + }, + { + name: "AND expression", + expr: &expr.Expression{ + Op: expr.And, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + Right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("email"), + Right: &expr.Expression{Op: expr.Literal, Left: "test@example.com"}, + }, + }, + wantSQL: "AND", + wantCount: 2, + wantErr: false, + }, + { + name: "OR expression", + expr: &expr.Expression{ + Op: expr.Or, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + Right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "jane"}, + }, + }, + wantSQL: "OR", + wantCount: 2, + wantErr: false, + }, + { + name: "LIKE expression", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "%john%"}, + }, + wantSQL: 
"name", + wantCount: 1, + wantErr: false, + }, + { + name: "nil expression", + expr: nil, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + partiql, attrs, err := driver.RenderPartiQL(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("RenderPartiQL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if tt.expr == nil { + if partiql != "" { + t.Errorf("RenderPartiQL() partiql = %v, want empty string", partiql) + } + if len(attrs) != 0 { + t.Errorf("RenderPartiQL() attrs count = %v, want 0", len(attrs)) + } + return + } + if !strings.Contains(partiql, tt.wantSQL) { + t.Errorf("RenderPartiQL() partiql = %v, want to contain %v", partiql, tt.wantSQL) + } + if len(attrs) != tt.wantCount { + t.Errorf("RenderPartiQL() attrs count = %v, want %v", len(attrs), tt.wantCount) + } + for i, attr := range attrs { + if attr == nil { + t.Errorf("RenderPartiQL() attrs[%v] is nil", i) + } + if _, ok := attr.(*types.AttributeValueMemberS); !ok { + t.Errorf("RenderPartiQL() attrs[%v] type = %T, want *types.AttributeValueMemberS", i, attr) + } + } + }) + } +} + +func TestDynamoDBLike(t *testing.T) { + tests := []struct { + name string + left string + right string + want string + wantErr bool + }{ + { + name: "contains pattern %value%", + left: "name", + right: "'%john%'", + want: "contains(name, 'john')", + wantErr: false, + }, + { + name: "begins_with pattern value%", + left: "name", + right: "'john%'", + want: "begins_with(name, 'john')", + wantErr: false, + }, + { + name: "contains pattern %value (no ends_with)", + left: "name", + right: "'%john'", + want: "contains(name, 'john')", + wantErr: false, + }, + { + name: "exact match (no wildcards)", + left: "name", + right: "'john'", + want: "name = 'john'", + wantErr: false, + }, + { + name: "empty string value", + left: "name", + right: "''", + want: "name = ''", + wantErr: false, + }, + { + name: "single % at start", + left: "name", + right: 
"'%'", + want: "contains(name, '')", + wantErr: false, + }, + { + name: "single % at end", + left: "name", + right: "'%'", + want: "contains(name, '')", + wantErr: false, + }, + { + name: "value with special characters", + left: "email", + right: "'test@example.com%'", + want: "begins_with(email, 'test@example.com')", + wantErr: false, + }, + { + name: "value with underscores", + left: "field_name", + right: "'%test_value%'", + want: "contains(field_name, 'test_value')", + wantErr: false, + }, + { + name: "unquoted value (no quotes in pattern)", + left: "name", + right: "john", + want: "name = 'john'", + wantErr: false, + }, + { + name: "multiple % in middle", + left: "name", + right: "'%john%doe%'", + want: "contains(name, 'john%doe')", + wantErr: false, + }, + { + name: "only % characters", + left: "name", + right: "'%%%'", + want: "contains(name, '')", + wantErr: false, + }, + { + name: "value with single quote in exact match", + left: "name", + right: "'John's'", + want: "name = 'John''s'", + wantErr: false, + }, + { + name: "value with single quote and wildcard prefix", + left: "name", + right: "'%test'value'", + want: "contains(name, 'test''value')", + wantErr: false, + }, + { + name: "value with single quote and wildcard suffix", + left: "name", + right: "'test'value%'", + want: "begins_with(name, 'test''value')", + wantErr: false, + }, + { + name: "value with single quote and wildcards both sides", + left: "name", + right: "'%test'value%'", + want: "contains(name, 'test''value')", + wantErr: false, + }, + { + name: "value with multiple single quotes", + left: "name", + right: "'O'Brien'", + want: "name = 'O''Brien'", + wantErr: false, + }, + { + name: "injection attempt: value with quote and OR (should be escaped)", + left: "name", + right: "'test') OR (1=1'", + want: "name = 'test'') OR (1=1'", + wantErr: false, + }, + { + name: "invalid field name with special characters", + left: "name; DROP TABLE users;--", + right: "'test'", + want: "", + wantErr: 
true, + }, + { + name: "invalid field name with quotes", + left: "name'", + right: "'test'", + want: "", + wantErr: true, + }, + { + name: "invalid field name with spaces", + left: "field name", + right: "'test'", + want: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := dynamoDBLike(tt.left, tt.right) + if (err != nil) != tt.wantErr { + t.Errorf("dynamoDBLike() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.want { + t.Errorf("dynamoDBLike() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDynamoDBDriver_EdgeCases(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + } + driver, err := NewDynamoDBDriver(fields) + if err != nil { + t.Fatalf("NewDynamoDBDriver() error = %v", err) + } + + tests := []struct { + name string + expr *expr.Expression + wantErr bool + checkFunc func(t *testing.T, partiql string, attrs []types.AttributeValue) + }{ + { + name: "empty string value", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: ""}, + }, + wantErr: false, + checkFunc: func(t *testing.T, partiql string, attrs []types.AttributeValue) { + if len(attrs) != 1 { + t.Errorf("expected 1 attribute, got %d", len(attrs)) + } + }, + }, + { + name: "LIKE with empty pattern", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: ""}, + }, + wantErr: false, + checkFunc: func(t *testing.T, partiql string, attrs []types.AttributeValue) { + if !strings.Contains(partiql, "name") { + t.Errorf("expected partiql to contain 'name', got %v", partiql) + } + }, + }, + { + name: "nested AND with LIKE", + expr: &expr.Expression{ + Op: expr.And, + Left: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "%john%"}, + }, + Right: &expr.Expression{ + Op: 
expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "jane%"}, + }, + }, + wantErr: false, + checkFunc: func(t *testing.T, partiql string, attrs []types.AttributeValue) { + if !strings.Contains(partiql, "AND") { + t.Errorf("expected partiql to contain 'AND', got %v", partiql) + } + if len(attrs) < 2 { + t.Errorf("expected at least 2 attributes, got %d", len(attrs)) + } + }, + }, + { + name: "comparison operators", + expr: &expr.Expression{ + Op: expr.Greater, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "a"}, + }, + wantErr: false, + checkFunc: func(t *testing.T, partiql string, attrs []types.AttributeValue) { + if !strings.Contains(partiql, ">") { + t.Errorf("expected partiql to contain '>', got %v", partiql) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + partiql, attrs, err := driver.RenderPartiQL(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("RenderPartiQL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && tt.checkFunc != nil { + tt.checkFunc(t, partiql, attrs) + } + }) + } +} diff --git a/storage/search/lucene/parser.go b/storage/search/lucene/parser.go index 616ef61..27a2b71 100644 --- a/storage/search/lucene/parser.go +++ b/storage/search/lucene/parser.go @@ -1,6 +1,7 @@ package lucene import ( + "errors" "fmt" "log/slog" "reflect" @@ -8,65 +9,108 @@ import ( "strings" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + lucene "github.com/grindlemire/go-lucene" + "github.com/grindlemire/go-lucene/pkg/lucene/expr" ) -type FieldInfo struct { - Name string - IsJSONB bool -} - -type Parser struct { - DefaultFields []FieldInfo -} - -type NodeType int - +// Safety limits for query parsing const ( - NodeTerm NodeType = iota - NodeWildcard - NodeLogical + DefaultMaxQueryLength = 10000 // 10KB - prevents memory exhaustion + DefaultMaxDepth = 20 // Prevents stack overflow from deep nesting + DefaultMaxTerms = 100 // 
Prevents CPU exhaustion from complex queries ) -type LogicalOperator string +// ParserConfig allows customization of parser behavior and security limits. +type ParserConfig struct { + MaxQueryLength int // 0 = use default (10000) + MaxDepth int // 0 = use default (20) + MaxTerms int // 0 = use default (100) +} -const ( - AND LogicalOperator = "AND" - OR LogicalOperator = "OR" - NOT LogicalOperator = "NOT" -) +// FieldInfo describes a searchable field and its properties. +type FieldInfo struct { + Name string + Type reflect.Type // For validation only + ImplicitSearch bool // Whether this field is included in unfielded/implicit queries +} -type MatchType int +// Parser provides Lucene query parsing with security limits. +// Drivers are created on-demand when calling ParseToSQL or ParseToDynamoDBPartiQL. +type Parser struct { + Fields []FieldInfo // All searchable fields -const ( - matchExact MatchType = iota - matchStartsWith - matchEndsWith - matchContains -) + // Security limits (configurable with safe defaults) + MaxQueryLength int // Maximum query string length (default: 10KB) + MaxDepth int // Maximum nesting depth (default: 20) + MaxTerms int // Maximum number of terms (default: 100) -type Node struct { - Type NodeType - Field string - Value string - Operator LogicalOperator - Children []*Node - Negate bool - MatchType MatchType + // Field lookup maps for O(1) validation + fieldMap map[string]FieldInfo // All fields by name } -func NewParserFromType(model any) (*Parser, error) { - fields, err := getStructFields(model) +// NewParser creates a parser by introspecting a struct's fields. 
+// +// Basic usage: +// +// parser, err := lucene.NewParser(Task{}) +// +// With custom configuration: +// +// config := &lucene.ParserConfig{ +// MaxQueryLength: 5000, +// MaxDepth: 10, +// } +// parser, err := lucene.NewParser(Task{}, config) +// +// Auto-detection rules: +// - String fields: ImplicitSearch=true (included in unfielded queries) +// - Non-string fields (int, time.Time, uuid, etc.): ImplicitSearch=false +// - JSONB fields: ImplicitSearch=false (require field.subfield syntax) +// +// Field name extraction: +// - Uses `json` struct tag for field names +// - Skips fields without `json` tag or with `json:"-"` +func NewParser(model any, config ...*ParserConfig) (*Parser, error) { + fields, err := extractFields(model) if err != nil { return nil, err } - return NewParser(fields), nil -} -func NewParser(defaultFields []FieldInfo) *Parser { - return &Parser{DefaultFields: defaultFields} + // Build field map and validate for duplicates + fieldMap, err := buildFieldMap(fields) + if err != nil { + return nil, err + } + + // Apply config or use defaults + maxQueryLength := DefaultMaxQueryLength + maxDepth := DefaultMaxDepth + maxTerms := DefaultMaxTerms + + if len(config) > 0 && config[0] != nil { + cfg := config[0] + if cfg.MaxQueryLength > 0 { + maxQueryLength = cfg.MaxQueryLength + } + if cfg.MaxDepth > 0 { + maxDepth = cfg.MaxDepth + } + if cfg.MaxTerms > 0 { + maxTerms = cfg.MaxTerms + } + } + + return &Parser{ + Fields: fields, + MaxQueryLength: maxQueryLength, + MaxDepth: maxDepth, + MaxTerms: maxTerms, + fieldMap: fieldMap, + }, nil } -func getStructFields(model any) ([]FieldInfo, error) { +// extractFields uses reflection to extract field metadata from a struct. 
+func extractFields(model any) ([]FieldInfo, error) { t := reflect.TypeOf(model) if t.Kind() == reflect.Ptr { t = t.Elem() @@ -79,437 +123,690 @@ func getStructFields(model any) ([]FieldInfo, error) { var fields []FieldInfo for i := 0; i < t.NumField(); i++ { field := t.Field(i) + + // Get field name from json tag jsonTag := field.Tag.Get("json") if jsonTag == "" || jsonTag == "-" { continue } + // Strip options from json tag (e.g., "name,omitempty" -> "name") if commaIdx := strings.Index(jsonTag, ","); commaIdx != -1 { jsonTag = jsonTag[:commaIdx] } - gormTag := field.Tag.Get("gorm") - isJSONB := strings.Contains(gormTag, "type:jsonb") + // Implicit search: only string fields + implicitSearch := field.Type.Kind() == reflect.String fields = append(fields, FieldInfo{ - Name: jsonTag, - IsJSONB: isJSONB, + Name: jsonTag, + Type: field.Type, + ImplicitSearch: implicitSearch, }) } return fields, nil } -func (p *Parser) ParseToMap(query string) (map[string]any, error) { - node, err := p.parse(query) - if err != nil { - return nil, err +// buildFieldMap builds a field map from a slice of fields and validates for duplicates. +// Returns an error if duplicate field names are found. 
+func buildFieldMap(fields []FieldInfo) (map[string]FieldInfo, error) { + fieldMap := make(map[string]FieldInfo, len(fields)) + for _, f := range fields { + if existing, exists := fieldMap[f.Name]; exists { + return nil, fmt.Errorf("duplicate field name '%s': already defined with type %v, cannot redefine with type %v", f.Name, existing.Type, f.Type) + } + fieldMap[f.Name] = f } - return p.nodeToMap(node), nil + return fieldMap, nil } -func (p *Parser) ParseToSQL(query string) (string, []any, error) { - slog.Debug(fmt.Sprintf(`Parsing query to sql: %s`, query)) - re := regexp.MustCompile(`(\w+):"([^"]+)"`) - query = re.ReplaceAllString(query, `$1:$2`) - node, err := p.parse(query) - if err != nil { - return "", nil, err +// canUseNestedAccess checks if a field type supports nested access (field.subfield syntax). +func canUseNestedAccess(t reflect.Type) bool { + // Return false for nil types + if t == nil { + return false } - return p.nodeToSQL(node) -} -func (p *Parser) parse(query string) (*Node, error) { - query = strings.TrimSpace(query) - if query == "" { - return nil, nil + // Unwrap pointers + for t.Kind() == reflect.Ptr { + t = t.Elem() } - if strings.HasPrefix(query, "(") && strings.HasSuffix(query, ")") { - return p.parse(query[1 : len(query)-1]) + // Check type name for JSONB-like types + name := t.Name() + if strings.Contains(name, "JSONB") || strings.Contains(name, "JSON") { + return true } - if andParts := splitByOperator(query, "AND"); len(andParts) > 1 { - return p.createLogicalNode(AND, andParts) - } - if orParts := splitByOperator(query, "OR"); len(orParts) > 1 { - return p.createLogicalNode(OR, orParts) - } - if notParts := splitByOperator(query, "NOT"); len(notParts) > 1 { - return p.createLogicalNode(NOT, notParts) + // Maps and structs support nested access + if t.Kind() == reflect.Map || t.Kind() == reflect.Struct { + return true } - if parts := strings.SplitN(query, ":", 2); len(parts) == 2 { - field := strings.TrimSpace(parts[0]) - value := 
strings.TrimSpace(parts[1]) - // Skip empty fields or values - if field == "" || value == "" { - return nil, nil - } - return p.createTermNode(field, value) - } + return false +} - // Skip empty implicit terms - if query = strings.TrimSpace(query); query == "" { - return nil, nil - } +// Precompiled regex for performance - matches Lucene operators and special syntax +var ( + // Matches field:value pattern (including JSONB like labels.category:value) + fieldValuePattern = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)?:`) + // Extracts field name from field:value pattern + fieldExtractPattern = regexp.MustCompile(`([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)?):`) + // Matches boolean operators (case-insensitive) + booleanOperators = regexp.MustCompile(`(?i)^(AND|OR|NOT|\+|-)$`) + // Matches range syntax + rangePattern = regexp.MustCompile(`^\[.*\s+TO\s+.*\]$|^\{.*\s+TO\s+.*\}$`) +) - return p.createImplicitNode(query) +// InvalidFieldError represents an error when a query references a non-existent field +type InvalidFieldError struct { + Field string + ValidFields []string } -func splitByOperator(input string, op string) []string { - // Handle case where the operator is at the beginning of the string - trimmedInput := strings.TrimSpace(input) - lowerInput := strings.ToLower(trimmedInput) - lowerOp := strings.ToLower(op) - - if strings.HasPrefix(lowerInput, lowerOp) { - // Check if it's a standalone word (followed by space or end of string) - opLength := len(op) - if len(trimmedInput) == opLength || (len(trimmedInput) > opLength && trimmedInput[opLength] == ' ') { - afterOp := strings.TrimSpace(trimmedInput[opLength:]) - if afterOp != "" { - return []string{"", afterOp} - } - } +func (e *InvalidFieldError) Error() string { + return fmt.Sprintf("invalid field '%s' in query; valid fields are: %s", e.Field, strings.Join(e.ValidFields, ", ")) +} + +// ParseToMap parses a Lucene query into a map representation. 
+// Note: This is a legacy method kept for backward compatibility. +func (p *Parser) ParseToMap(query string) (map[string]any, error) { + if err := p.validateQuery(query); err != nil { + return nil, err } - // Original logic for operators in the middle - re := regexp.MustCompile(fmt.Sprintf(`(?i)\s+%s\s+`, op)) - parts := re.Split(input, -1) - if len(parts) > 1 { - return parts + e, err := p.parseWithImplicitSearch(query) + if err != nil { + return nil, err } - return nil + // Convert expression to map + return p.exprToMap(e), nil } -func (p *Parser) createImplicitNode(term string) (*Node, error) { - slog.Debug(fmt.Sprintf(`Handling implicit: %s`, term)) - term = strings.Trim(term, `"`) - - containsWildcard := strings.Contains(term, "*") || strings.Contains(term, "?") +// parseQueryCommon performs common parsing steps shared by ParseToSQL and ParseToDynamoDBPartiQL. +// Returns the parsed expression or an error. +func (p *Parser) parseQueryCommon(query string, queryType string) (*expr.Expression, error) { + slog.Debug(fmt.Sprintf(`Parsing query to %s: %s`, queryType, query)) - node := &Node{ - Type: NodeLogical, - Operator: OR, + if err := p.validateQuery(query); err != nil { + return nil, err } - for _, field := range p.DefaultFields { - var child *Node - var err error + // Expand implicit terms first (for validation of the full query) + expandedQuery := p.expandImplicitTerms(query) - if containsWildcard { - child, err = p.createWildcardNode(field.Name, term) - } else { - child, err = p.createTermNode(field.Name, term) + // Validate all field references exist in the model + if err := p.ValidateFields(expandedQuery); err != nil { + return nil, err + } - if child.Type == NodeTerm { - child.Type = NodeWildcard - child.MatchType = matchContains - } - } - if err != nil { - return nil, err - } - node.Children = append(node.Children, child) + // Parse using the library + e, err := p.parseWithImplicitSearch(query) + if err != nil { + return nil, err } - return node, nil + 
return e, nil } -func (p *Parser) createWildcardNode(field, value string) (*Node, error) { - // Skip empty fields or values - field = strings.TrimSpace(field) - value = strings.TrimSpace(value) - - if field == "" || value == "" { - return nil, nil +// ParseToSQL parses a Lucene query and converts it to SQL with parameters for the specified provider. +// Creates a SQL driver on-demand for rendering with provider-specific syntax. +// Provider should be one of: "postgresql", "mysql", "sqlite" +func (p *Parser) ParseToSQL(query string, provider string) (string, []any, error) { + e, err := p.parseQueryCommon(query, "SQL") + if err != nil { + return "", nil, err } - formattedField := p.formatFieldName(field) - - node := &Node{ - Type: NodeWildcard, - Field: formattedField, - Value: value, - } - - // Process the wildcard pattern - if strings.HasPrefix(value, "*") && strings.HasSuffix(value, "*") { - // For *term* pattern - node.MatchType = matchContains - node.Value = strings.Trim(value, "*") - } else if strings.HasPrefix(value, "*") { - // For *term pattern - node.MatchType = matchEndsWith - node.Value = strings.TrimPrefix(value, "*") - } else if strings.HasSuffix(value, "*") { - // For term* pattern - node.MatchType = matchStartsWith - node.Value = strings.TrimSuffix(value, "*") - } else if strings.Contains(value, "*") { - // For patterns like te*rm - node.MatchType = matchContains - // Replace wildcards with % for SQL LIKE - node.Value = strings.ReplaceAll(value, "*", "%") - } else { - // Default to contains match for other patterns - node.MatchType = matchContains - } - - // Skip if the value becomes empty after processing - if node.Value == "" { - return nil, nil + // Create SQL driver on-demand for the specified provider and render + driver, err := NewSQLDriver(p.Fields, provider) + if err != nil { + return "", nil, err + } + sql, params, err := driver.RenderParam(e) + if err != nil { + return "", nil, err } - return node, nil + return sql, params, nil } -func (p 
*Parser) formatFieldName(fieldName string) string { - if parts := strings.SplitN(fieldName, ".", 2); len(parts) == 2 { - baseField := parts[0] - subField := parts[1] +// ParseToDynamoDBPartiQL parses a Lucene query and converts it to DynamoDB PartiQL. +// Creates a DynamoDB driver on-demand for rendering. +func (p *Parser) ParseToDynamoDBPartiQL(query string) (string, []types.AttributeValue, error) { + e, err := p.parseQueryCommon(query, "DynamoDB PartiQL") + if err != nil { + return "", nil, err + } - for _, field := range p.DefaultFields { - if field.IsJSONB && field.Name == baseField { - return fmt.Sprintf("%s->>'%s'", baseField, subField) - } - } + // Create DynamoDB driver on-demand and render + driver, err := NewDynamoDBDriver(p.Fields) + if err != nil { + return "", nil, err + } + partiql, attrs, err := driver.RenderPartiQL(e) + if err != nil { + return "", nil, err } - return fieldName + + return partiql, attrs, nil } -func (p *Parser) createTermNode(field, value string) (*Node, error) { - field = strings.TrimSpace(field) - value = strings.TrimSpace(value) +func (p *Parser) validateQuery(query string) error { + var errs []error - if field == "" || value == "" { - return nil, nil + if len(query) > p.MaxQueryLength { + errs = append(errs, fmt.Errorf("query too long: %d bytes exceeds maximum of %d bytes", len(query), p.MaxQueryLength)) } - formattedField := p.formatFieldName(field) - trimmedValue := strings.TrimSpace(strings.Trim(value, `"`)) + depth := calculateNestingDepth(query) + if depth > p.MaxDepth { + errs = append(errs, fmt.Errorf("query too complex: nesting depth %d exceeds maximum of %d", depth, p.MaxDepth)) + } - // Skip if the value becomes empty after trimming - if trimmedValue == "" { - return nil, nil + terms := countTerms(query) + if terms > p.MaxTerms { + errs = append(errs, fmt.Errorf("query too large: %d terms exceeds maximum of %d", terms, p.MaxTerms)) } - node := &Node{ - Type: NodeTerm, - Field: formattedField, - Value: 
strings.Trim(value, `"`), + if len(errs) > 0 { + return errors.Join(errs...) } - if strings.Contains(value, "*") || strings.Contains(value, "?") { - node.Type = NodeWildcard + return nil +} - // Determine the match type based on wildcard position - if strings.HasPrefix(value, "*") && strings.HasSuffix(value, "*") { - node.MatchType = matchContains - node.Value = strings.Trim(value, "*") - } else if strings.HasPrefix(value, "*") { - node.MatchType = matchEndsWith - node.Value = strings.TrimPrefix(value, "*") - } else if strings.HasSuffix(value, "*") { - node.MatchType = matchStartsWith - node.Value = strings.TrimSuffix(value, "*") - } else { - // For patterns like te*rm or te?rm - node.MatchType = matchContains - // For SQL LIKE, convert * to % and ? to _ - node.Value = strings.ReplaceAll(strings.ReplaceAll(value, "*", "%"), "?", "_") +func calculateNestingDepth(query string) int { + maxDepth := 0 + currentDepth := 0 + inQuotes := false + + for i := 0; i < len(query); i++ { + c := query[i] + + if c == '\\' && i+1 < len(query) { + i++ + continue } - // Skip if the value becomes empty after processing wildcards - if node.Value == "" { - return nil, nil + if c == '"' { + inQuotes = !inQuotes + continue + } + + if !inQuotes { + switch c { + case '(', '[', '{': + currentDepth++ + if currentDepth > maxDepth { + maxDepth = currentDepth + } + case ')', ']', '}': + currentDepth-- + } } } - return node, nil + return maxDepth } -func (p *Parser) createLogicalNode(op LogicalOperator, parts []string) (*Node, error) { - node := &Node{ - Type: NodeLogical, - Operator: op, +// countTerms counts search terms in a query. +// Terms include field:value pairs, implicit terms, and quoted phrases. +// Operators (AND, OR, NOT) and parentheses are excluded. 
+func countTerms(query string) int { + if query == "" { + return 0 } - for _, part := range parts { - if strings.TrimSpace(part) == "" { + terms := 0 + inQuotes := false + inRange := false + currentTerm := false + + for i := 0; i < len(query); i++ { + c := query[i] + + if c == '\\' && i+1 < len(query) { + i++ + currentTerm = true + continue + } + + if c == '"' { + if !inQuotes { + if currentTerm { + terms++ + } + currentTerm = true + } else { + if currentTerm { + terms++ + currentTerm = false + } + } + inQuotes = !inQuotes + continue + } + + if !inQuotes { + if c == '[' || c == '{' { + inRange = true + if currentTerm { + terms++ + currentTerm = false + } + continue + } + if c == ']' || c == '}' { + inRange = false + if currentTerm { + terms++ + currentTerm = false + } + continue + } + } + + if c == ' ' && !inQuotes && !inRange { + if currentTerm { + terms++ + currentTerm = false + } continue } - child, err := p.parse(part) - if err != nil { - return nil, err + + if !inQuotes && !inRange && (c == '(' || c == ')') { + if currentTerm { + terms++ + currentTerm = false + } + continue } - if child != nil { - node.Children = append(node.Children, child) + + if !inQuotes && !inRange && currentTerm { + remaining := query[i:] + if strings.HasPrefix(remaining, "AND ") || strings.HasPrefix(remaining, "OR ") || + strings.HasPrefix(remaining, "NOT ") || strings.HasPrefix(remaining, "and ") || + strings.HasPrefix(remaining, "or ") || strings.HasPrefix(remaining, "not ") { + terms++ + currentTerm = false + if len(remaining) >= 3 && (remaining[0] == 'A' || remaining[0] == 'a') { + i += 3 + continue + } + if len(remaining) >= 3 && (remaining[0] == 'N' || remaining[0] == 'n') { + i += 3 + continue + } + i += 2 + continue + } } + + currentTerm = true } - // If no valid children were found, return nil - if len(node.Children) == 0 { - return nil, nil + if currentTerm { + terms++ } - return node, nil + return terms } -func (p *Parser) nodeToMap(node *Node) map[string]any { - if node == 
nil { +// ValidateFields returns InvalidFieldError if the query references non-existent fields. +func (p *Parser) ValidateFields(query string) error { + matches := fieldExtractPattern.FindAllStringSubmatchIndex(query, -1) + if len(matches) == 0 { return nil } - switch node.Type { - case NodeTerm: - return map[string]any{node.Field: node.Value} - case NodeWildcard: - return map[string]any{node.Field: map[string]string{ - "$like": wildcardToPattern(node.Value, node.MatchType), - }} - case NodeLogical: - result := make(map[string]any) - children := make([]map[string]any, 0, len(node.Children)) - for _, child := range node.Children { - children = append(children, p.nodeToMap(child)) + validFields := p.getValidFieldNames() + + for _, match := range matches { + if len(match) < 4 { + continue + } + fieldStart := match[2] + fieldEnd := match[3] + + if isInsideQuotes(query, fieldStart) { + continue + } + + fieldName := query[fieldStart:fieldEnd] + + if err := p.validateFieldName(fieldName); err != nil { + return &InvalidFieldError{ + Field: fieldName, + ValidFields: validFields, + } } - result[string(node.Operator)] = children - return result } + return nil } -func (p *Parser) nodeToSQL(node *Node) (string, []any, error) { - if node == nil { - return "", nil, nil +func isInsideQuotes(query string, pos int) bool { + inQuotes := false + for i := 0; i < pos && i < len(query); i++ { + c := query[i] + if c == '\\' && i+1 < len(query) { + i++ + continue + } + if c == '"' { + inQuotes = !inQuotes + } } + return inQuotes +} - switch node.Type { - case NodeTerm: - if strings.Contains(node.Field, "->>") { - return fmt.Sprintf("%s = ?", node.Field), []any{node.Value}, nil +// validateFieldName validates both simple fields (name) and nested fields (labels.category). 
+func (p *Parser) validateFieldName(fieldName string) error { + if strings.Contains(fieldName, ".") { + parts := strings.SplitN(fieldName, ".", 2) + if len(parts) != 2 { + return fmt.Errorf("invalid field format: %s", fieldName) } - return fmt.Sprintf("%s = ?", node.Field), []any{node.Value}, nil - case NodeWildcard: - pattern := wildcardToPattern(node.Value, node.MatchType) - if strings.Contains(node.Field, "->>") { - return fmt.Sprintf("%s ILIKE ?", node.Field), []any{pattern}, nil - } else { - return fmt.Sprintf("%s::text ILIKE ?", node.Field), []any{pattern}, nil + + baseField := parts[0] + subField := parts[1] + + // Check for whitespace in field names (security: prevents obfuscation, OWASP A03) + if strings.TrimSpace(baseField) != baseField { + return fmt.Errorf("invalid field format '%s': whitespace not allowed in field names", fieldName) + } + if strings.TrimSpace(subField) != subField { + return fmt.Errorf("invalid field format '%s': whitespace not allowed in field names (use 'field.subfield' not 'field. subfield')", fieldName) } - case NodeLogical: - var parts []string - var params []any - for _, child := range node.Children { - sqlPart, childParams, err := p.nodeToSQL(child) - if err != nil { - return "", nil, err - } - if sqlPart != "" { - parts = append(parts, sqlPart) - params = append(params, childParams...) 
- } + // Check if base field exists + field, exists := p.fieldMap[baseField] + if !exists { + return fmt.Errorf("field '%s' does not exist", baseField) } - if len(parts) == 0 { - return "", nil, nil + // Check if base field supports nested access + if !canUseNestedAccess(field.Type) { + return fmt.Errorf("field '%s' does not support nested access (field.subfield syntax); use explicit field names only", baseField) + } + + return nil + } + + if _, exists := p.fieldMap[fieldName]; !exists { + return fmt.Errorf("field '%s' does not exist", fieldName) + } + + return nil +} + +func (p *Parser) getValidFieldNames() []string { + var names []string + for _, f := range p.Fields { + // Add a hint for fields that support nested access + if canUseNestedAccess(f.Type) { + names = append(names, f.Name+".*") + } else { + names = append(names, f.Name) } + } + return names +} - if len(parts) == 1 { - return parts[0], params, nil +func (p *Parser) getImplicitSearchFields() []FieldInfo { + var fields []FieldInfo + for _, field := range p.Fields { + if field.ImplicitSearch { + fields = append(fields, field) } + } + return fields +} + +// isImplicitTerm returns true if token is a search term without an explicit field prefix. 
+func isImplicitTerm(token string) bool { + token = strings.TrimSpace(token) + if token == "" { + return false + } - operator := string(node.Operator) - if node.Negate { - operator = "NOT " + operator + // Check if it's a boolean operator + if booleanOperators.MatchString(token) { + return false + } + + // Check if it starts with + or - (required/prohibited operators) + if strings.HasPrefix(token, "+") || strings.HasPrefix(token, "-") { + // Remove the prefix and check the rest + rest := token[1:] + if fieldValuePattern.MatchString(rest) { + return false // It's a +field:value or -field:value } + // Otherwise it's an implicit term with +/- modifier + return true + } - return fmt.Sprintf("(%s)", strings.Join(parts, fmt.Sprintf(" %s ", operator))), params, nil + // Check if it's a field:value pattern + if fieldValuePattern.MatchString(token) { + return false } - return "", nil, fmt.Errorf("unsupported node type") + // Check if it's a range query + if rangePattern.MatchString(token) { + return false + } + + // Check if it's a parenthesis + if token == "(" || token == ")" { + return false + } + + // Quoted strings are also implicit terms (they search across implicit search fields) + if strings.HasPrefix(token, `"`) && strings.HasSuffix(token, `"`) { + return true + } + + return true } -func (p *Parser) ParseToDynamoDBPartiQL(query string) (string, []types.AttributeValue, error) { - slog.Debug(fmt.Sprintf(`Parsing query to DynamoDB PartiQL: %s`, query)) - node, err := p.parse(query) - if err != nil { - return "", nil, err +// expandImplicitTerms expands implicit search terms to explicit field:value patterns +// across all implicit search fields. 
For example: +// "paint" → "(name:*paint* OR description:*paint*)" +// "paint*" → "(name:paint* OR description:paint*)" +// '"Living Room"' → '(name:"Living Room" OR description:"Living Room")' +func (p *Parser) expandImplicitTerms(query string) string { + implicitFields := p.getImplicitSearchFields() + if len(implicitFields) == 0 { + return query + } + + // Tokenize the query while preserving structure + tokens := tokenizeQuery(query) + var result []string + + for _, token := range tokens { + if isImplicitTerm(token) { + // Check if it has a +/- prefix + prefix := "" + term := token + if strings.HasPrefix(token, "+") || strings.HasPrefix(token, "-") { + prefix = string(token[0]) + term = token[1:] + } + + // Check if it's a quoted phrase (exact match) or already has wildcards + searchTerm := term + isQuotedPhrase := strings.HasPrefix(term, `"`) && strings.HasSuffix(term, `"`) + hasWildcards := strings.Contains(term, "*") || strings.Contains(term, "?") + + // For implicit search without wildcards or quotes, use contains matching + // This provides a better user experience for simple searches + if !isQuotedPhrase && !hasWildcards { + searchTerm = "*" + term + "*" + } + + // Expand to all implicit search fields with OR + var fieldTerms []string + for _, field := range implicitFields { + fieldTerms = append(fieldTerms, fmt.Sprintf("%s:%s", field.Name, searchTerm)) + } + + if len(fieldTerms) == 1 { + result = append(result, prefix+fieldTerms[0]) + } else { + expanded := "(" + strings.Join(fieldTerms, " OR ") + ")" + if prefix != "" { + expanded = prefix + expanded + } + result = append(result, expanded) + } + } else { + result = append(result, token) + } } - return p.nodeToDynamoDBPartiQL(node) + + return strings.Join(result, " ") } -func (p *Parser) nodeToDynamoDBPartiQL(node *Node) (string, []types.AttributeValue, error) { - if node == nil { - return "", nil, nil - } - - switch node.Type { - case NodeTerm: - // For term node, create an exact match condition - return 
fmt.Sprintf("%s = ?", node.Field), []types.AttributeValue{ - &types.AttributeValueMemberS{Value: node.Value}, - }, nil - case NodeWildcard: - // For wildcard node, use begins_with or contains based on the match type - switch node.MatchType { - case matchStartsWith: - return fmt.Sprintf("begins_with(%s, ?)", node.Field), []types.AttributeValue{ - &types.AttributeValueMemberS{Value: node.Value}, - }, nil - case matchEndsWith, matchContains: - return fmt.Sprintf("contains(%s, ?)", node.Field), []types.AttributeValue{ - &types.AttributeValueMemberS{Value: node.Value}, - }, nil - default: - return fmt.Sprintf("%s = ?", node.Field), []types.AttributeValue{ - &types.AttributeValueMemberS{Value: node.Value}, - }, nil - } - case NodeLogical: - // For logical node, combine conditions with appropriate operator - var parts []string - var params []types.AttributeValue - - for _, child := range node.Children { - part, childParams, err := p.nodeToDynamoDBPartiQL(child) - if err != nil { - return "", nil, err +// tokenizeQuery splits query into tokens, preserving quoted strings and range brackets. +func tokenizeQuery(query string) []string { + var tokens []string + var current strings.Builder + inQuotes := false + inRange := false + rangeDepth := 0 + + for i := 0; i < len(query); i++ { + c := query[i] + + // Handle quotes + if c == '"' && (i == 0 || query[i-1] != '\\') { + inQuotes = !inQuotes + current.WriteByte(c) + continue + } + + // Handle range brackets + if !inQuotes { + if c == '[' || c == '{' { + inRange = true + rangeDepth++ + current.WriteByte(c) + continue } - if part != "" { - parts = append(parts, part) - params = append(params, childParams...) 
+ if c == ']' || c == '}' { + current.WriteByte(c) + rangeDepth-- + if rangeDepth == 0 { + inRange = false + } + continue } } - if len(parts) == 0 { - return "", nil, nil + // Handle spaces (token separators) when not in quotes or range + if c == ' ' && !inQuotes && !inRange { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + continue } - operator := string(node.Operator) - if node.Negate { - operator = "NOT " + operator + // Handle parentheses as separate tokens + if !inQuotes && !inRange && (c == '(' || c == ')') { + if current.Len() > 0 { + tokens = append(tokens, current.String()) + current.Reset() + } + tokens = append(tokens, string(c)) + continue } - return fmt.Sprintf("(%s)", strings.Join(parts, fmt.Sprintf(" %s ", operator))), params, nil + current.WriteByte(c) + } + + if current.Len() > 0 { + tokens = append(tokens, current.String()) + } + + return tokens +} + +// parseWithImplicitSearch expands unfielded terms across all implicit search fields with OR. +func (p *Parser) parseWithImplicitSearch(query string) (*expr.Expression, error) { + query = strings.TrimSpace(query) + if query == "" { + return nil, nil + } + + // Expand implicit terms to explicit field:value patterns + expandedQuery := p.expandImplicitTerms(query) + + slog.Debug("Query expansion", "original", query, "expanded", expandedQuery) + + // Get first implicit field as fallback for the parser + fallbackField := "" + implicitFields := p.getImplicitSearchFields() + if len(implicitFields) > 0 { + fallbackField = implicitFields[0].Name + } + if fallbackField == "" && len(p.Fields) > 0 { + fallbackField = p.Fields[0].Name + } + + return lucene.Parse(expandedQuery, lucene.WithDefaultField(fallbackField)) +} + +// exprToMap converts expression to map format (legacy, kept for backward compatibility). 
+func (p *Parser) exprToMap(e *expr.Expression) map[string]any { + if e == nil { + return nil + } + + result := make(map[string]any) + + switch e.Op { + case expr.Equals: + if col, ok := e.Left.(expr.Column); ok { + result[string(col)] = p.valueToAny(e.Right) + } + case expr.Like: + if col, ok := e.Left.(expr.Column); ok { + pattern := p.valueToAny(e.Right) + result[string(col)] = map[string]any{"$like": pattern} + } + case expr.And, expr.Or, expr.Not: + var children []map[string]any + if leftExpr, ok := e.Left.(*expr.Expression); ok { + children = append(children, p.exprToMap(leftExpr)) + } + if rightExpr, ok := e.Right.(*expr.Expression); ok { + children = append(children, p.exprToMap(rightExpr)) + } + result[e.Op.String()] = children + default: + // For other operators, do a simple conversion + if col, ok := e.Left.(expr.Column); ok { + result[string(col)] = p.valueToAny(e.Right) + } } - return "", nil, fmt.Errorf("unsupported node type") + return result } -func wildcardToPattern(value string, matchType MatchType) string { - switch matchType { - case matchStartsWith: - return value + "%" - case matchEndsWith: - return "%" + value - case matchContains: - return "%" + value + "%" +func (p *Parser) valueToAny(v any) any { + switch val := v.(type) { + case *expr.Expression: + return p.exprToMap(val) + case string: + return val + case int, float64: + return val default: - return value + return fmt.Sprintf("%v", v) } } diff --git a/storage/search/lucene/parser_test.go b/storage/search/lucene/parser_test.go new file mode 100644 index 0000000..c94fd4b --- /dev/null +++ b/storage/search/lucene/parser_test.go @@ -0,0 +1,1334 @@ +package lucene + +import ( + "strings" + "testing" +) + +// Helper functions following FIRST principles + +// assertSQLContains checks that SQL contains all required substrings (more precise validation) +func assertSQLContains(t *testing.T, sql string, required []string, msg string) { + t.Helper() + for _, req := range required { + if 
!strings.Contains(sql, req) { + t.Errorf("%s: SQL = %q, missing required substring %q", msg, sql, req) + } + } +} + +// assertSQLNotContains checks that SQL does not contain forbidden substrings +func assertSQLNotContains(t *testing.T, sql string, forbidden []string, msg string) { + t.Helper() + for _, forb := range forbidden { + if strings.Contains(sql, forb) { + t.Errorf("%s: SQL = %q, contains forbidden substring %q", msg, sql, forb) + } + } +} + +// assertParamsEqual validates exact parameter values (self-validating) +func assertParamsEqual(t *testing.T, got []any, want []any, msg string) { + t.Helper() + if len(got) != len(want) { + t.Errorf("%s: param count = %d, want %d", msg, len(got), len(want)) + return + } + for i := range got { + if got[i] != want[i] { + t.Errorf("%s: param[%d] = %v, want %v", msg, i, got[i], want[i]) + } + } +} + +// assertErrorContains validates error messages precisely +func assertErrorContains(t *testing.T, err error, wantSubstrings []string, msg string) { + t.Helper() + if err == nil { + t.Errorf("%s: expected error, got nil", msg) + return + } + errMsg := err.Error() + for _, want := range wantSubstrings { + if !strings.Contains(errMsg, want) { + t.Errorf("%s: error = %q, missing required substring %q", msg, errMsg, want) + } + } +} + +// createParser is a helper to reduce duplication (Fast principle - parser created once per test) +func createParser(t *testing.T, model any, config ...*ParserConfig) *Parser { + t.Helper() + parser, err := NewParser(model, config...) 
+ if err != nil { + t.Fatalf("NewParser() error = %v", err) + } + return parser +} + +// Test model definitions +type BasicModel struct { + Name string `json:"name"` + Email string `json:"email"` +} + +type BooleanModel struct { + Name string `json:"name"` + Status string `json:"status"` + Role string `json:"role"` +} + +type RangeModel struct { + Age int `json:"age"` + Date string `json:"date"` +} + +type TextModel struct { + Description string `json:"description"` + Title string `json:"title"` + Name string `json:"name"` +} + +type ComplexModel struct { + Name string `json:"name"` + Age int `json:"age"` + Status string `json:"status"` + Email string `json:"email"` +} + +// JSONB types for testing +type JSONBType map[string]interface{} + +type JSONBModel struct { + Metadata JSONBType `json:"metadata"` +} + +type MixedModel struct { + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + Labels JSONBType `json:"labels"` + Metadata JSONBType `json:"metadata"` +} + +type NullModel struct { + Name string `json:"name"` + ParentID string `json:"parent_id"` + DeletedAt string `json:"deleted_at"` + AttachmentIDs string `json:"attachment_ids"` +} + +// TestBasicFieldSearch tests basic field:value queries +// Improved with precise assertions following FIRST principles +func TestBasicFieldSearch(t *testing.T) { + parser := createParser(t, BasicModel{}) + + tests := []struct { + name string + query string + wantSQL []string + wantNot []string + wantParams []any + wantErr bool + }{ + { + name: "simple field query", + query: "name:john", + wantSQL: []string{`"name"`, "=", "?"}, + wantNot: []string{"ILIKE", "LIKE"}, + wantParams: []any{"john"}, + wantErr: false, + }, + { + name: "wildcard prefix", + query: "name:john*", + wantSQL: []string{`"name"`, "ILIKE", "?"}, + wantNot: []string{"="}, + wantParams: []any{"john%"}, + wantErr: false, + }, + { + name: "wildcard suffix", + query: "name:*john", + wantSQL: []string{`"name"`, 
"ILIKE", "?"}, + wantParams: []any{"%john"}, + wantErr: false, + }, + { + name: "wildcard contains", + query: "name:*john*", + wantSQL: []string{`"name"`, "ILIKE", "?"}, + wantParams: []any{"%john%"}, + wantErr: false, + }, + { + name: "email field", + query: `email:"test@example.com"`, + wantSQL: []string{`"email"`, "=", "?"}, + wantParams: []any{"test@example.com"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := parser.ParseToSQL(tt.query, "postgresql") + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + if len(tt.wantNot) > 0 { + assertSQLNotContains(t, sql, tt.wantNot, tt.name) + } + if len(tt.wantParams) > 0 { + // Only validate params if we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } + } + } + }) + } +} + +// TestBooleanOperators tests AND, OR, NOT operators +// Improved with parameter validation +func TestBooleanOperators(t *testing.T) { + parser := createParser(t, BooleanModel{}) + + tests := []struct { + name string + query string + wantSQL []string + wantParams []any + wantErr bool + }{ + { + name: "AND operator", + query: "name:john AND status:active", + wantSQL: []string{`"name"`, `"status"`, "AND"}, + wantParams: []any{"john", "active"}, + wantErr: false, + }, + { + name: "OR operator", + query: "name:john OR name:jane", + wantSQL: []string{`"name"`, "OR"}, + wantParams: []any{"john", "jane"}, + wantErr: false, + }, + { + name: "NOT operator", + query: "name:john NOT status:inactive", + wantSQL: []string{`"name"`, `"status"`, "NOT"}, + wantParams: []any{"john", "inactive"}, + wantErr: false, + }, + { + name: "complex nested", + query: "(name:john OR name:jane) AND status:active", + wantSQL: []string{`"name"`, `"status"`, "OR", "AND"}, + wantParams: []any{"john", "jane", "active"}, + wantErr: 
false, + }, + { + name: "case insensitive AND", + query: "name:john and status:active", + wantSQL: []string{"AND"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := parser.ParseToSQL(tt.query, "postgresql") + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + if len(tt.wantParams) > 0 { + // Only validate params if we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } + } + } + }) + } +} + +// TestRequiredProhibited tests + and - operators +func TestRequiredProhibited(t *testing.T) { + parser, err := NewParser(BooleanModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) + } + + tests := []struct { + name string + query string + wantSQL []string + }{ + { + name: "required term", + query: "+name:john", + wantSQL: []string{`"name"`}, + }, + { + name: "prohibited term", + query: "-status:inactive", + wantSQL: []string{"NOT", `"status"`}, + }, + { + name: "mixed required and prohibited", + query: "+name:john -status:inactive", + wantSQL: []string{`"name"`, "NOT", `"status"`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } + } + }) + } +} + +// TestRangeQueries tests range query syntax +// Improved with parameter validation +func TestRangeQueries(t *testing.T) { + parser := createParser(t, RangeModel{}) + + tests := []struct { + name string + query string + wantSQL []string + wantParams []any + wantErr bool + }{ + { + name: "inclusive range", + query: "age:[25 TO 65]", + wantSQL: []string{`"age"`, "BETWEEN"}, + 
wantParams: []any{"25", "65"}, + wantErr: false, + }, + { + name: "exclusive range", + query: "age:{25 TO 65}", + wantSQL: []string{`"age"`, ">", "<"}, + wantParams: []any{"25", "65"}, + wantErr: false, + }, + { + name: "open-ended range min", + query: "age:[25 TO *]", + wantSQL: []string{`"age"`, ">="}, + wantParams: []any{"25"}, + wantErr: false, + }, + { + name: "open-ended range max", + query: "age:[* TO 65]", + wantSQL: []string{`"age"`, "<="}, + wantParams: []any{"65"}, + wantErr: false, + }, + { + name: "date range", + query: "date:[2024-01-01 TO 2024-12-31]", + wantSQL: []string{`"date"`, "BETWEEN"}, + wantParams: []any{"2024-01-01", "2024-12-31"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := parser.ParseToSQL(tt.query, "postgresql") + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + // Only validate params if we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } + } + }) + } +} + +// TestQuotedPhrases tests quoted phrase handling +func TestQuotedPhrases(t *testing.T) { + parser, err := NewParser(TextModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) + } + + tests := []struct { + name string + query string + wantSQL []string + }{ + { + name: "simple quoted phrase", + query: `description:"hello world"`, + wantSQL: []string{`"description"`}, + }, + { + name: "phrase with special chars", + query: `title:"test-app (v1.0)"`, + wantSQL: []string{`"title"`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) 
+ } + } + }) + } +} + +// TestEscapedCharacters tests escaped character handling +func TestEscapedCharacters(t *testing.T) { + parser, err := NewParser(TextModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) + } + + tests := []struct { + name string + query string + wantSQL []string + }{ + { + name: "escaped colon", + query: `name:test\:value`, + wantSQL: []string{`"name"`}, + }, + { + name: "escaped plus", + query: `name:C\+\+`, + wantSQL: []string{`"name"`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") + if err != nil { + t.Fatalf("ParseToSQL() error = %v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want) + } + } + }) + } +} + +// TestComplexQueries tests complex query combinations +func TestComplexQueries(t *testing.T) { + parser, err := NewParser(ComplexModel{}) + if err != nil { + t.Fatalf("NewParser() error = %v", err) + } + + tests := []struct { + name string + query string + wantSQL []string + shouldErr bool + }{ + { + name: "complex with ranges and wildcards", + query: "(name:john* OR email:test*) AND age:[25 TO 65]", + wantSQL: []string{`"name"`, `"email"`, `"age"`, "OR", "AND", "BETWEEN"}, + }, + { + name: "complex with required and prohibited", + query: "+name:john -status:inactive age:[25 TO 65]", + wantSQL: []string{`"name"`, `"status"`, `"age"`, "NOT"}, + }, + { + name: "complex with quoted phrases", + query: `name:"John Doe" AND status:active`, + wantSQL: []string{`"name"`, `"status"`, "AND"}, + }, + { + name: "complex nested query", + query: "((name:john OR name:jane) AND status:active) OR (status:pending AND age:[18 TO *])", + wantSQL: []string{`"name"`, `"status"`, `"age"`, "OR", "AND"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, _, err := parser.ParseToSQL(tt.query, "postgresql") + if (err != nil) != 
tt.shouldErr {
				t.Fatalf("ParseToSQL() error = %v, shouldErr = %v", err, tt.shouldErr)
			}
			if !tt.shouldErr {
				for _, want := range tt.wantSQL {
					if !strings.Contains(sql, want) {
						t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want)
					}
				}
			}
		})
	}
}

// TestImplicitSearch tests implicit search (a term without a field prefix)
// across string fields, checking both the SQL shape and the bound parameters.
func TestImplicitSearch(t *testing.T) {
	parser := createParser(t, TextModel{})

	tests := []struct {
		name       string
		query      string
		wantSQL    []string
		wantParams []any
		wantErr    bool
	}{
		{
			name:       "implicit search",
			query:      "john",
			wantSQL:    []string{"OR"},
			wantParams: []any{"%john%", "%john%", "%john%"},
			wantErr:    false,
		},
		{
			name:       "implicit search with wildcard",
			query:      "john*",
			wantSQL:    []string{"OR"},
			wantParams: []any{"john%", "john%", "john%"},
			wantErr:    false,
		},
		{
			name:       "implicit quoted phrase",
			query:      `"john doe"`,
			wantSQL:    []string{"OR"},
			wantParams: []any{"john doe", "john doe", "john doe"}, // quotes are stripped
			wantErr:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sql, params, err := parser.ParseToSQL(tt.query, "postgresql")
			if (err != nil) != tt.wantErr {
				t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr)
			}
			if !tt.wantErr {
				assertSQLContains(t, sql, tt.wantSQL, tt.name)
				// Only validate params if we expect specific values
				if len(tt.wantParams) > 0 {
					assertParamsEqual(t, params, tt.wantParams, tt.name)
				}
			}
		})
	}
}

// TestJSONBFields tests JSONB field notation (field.subfield) for PostgreSQL.
func TestJSONBFields(t *testing.T) {
	parser, err := NewParser(JSONBModel{})
	if err != nil {
		t.Fatalf("NewParser() error = %v", err)
	}

	tests := []struct {
		name    string
		query   string
		wantSQL []string
	}{
		{
			name:    "JSONB field access",
			query:   "metadata.key:value",
			wantSQL: []string{`metadata->>'key'`},
		},
		{
			name:    "JSONB with wildcard",
			query:   "metadata.tags:prod*",
			wantSQL: []string{`metadata->>'tags'`, "ILIKE"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sql, _, err := parser.ParseToSQL(tt.query, "postgresql")
			if err != nil {
				t.Fatalf("ParseToSQL() error = %v", err)
			}
			for _, want := range tt.wantSQL {
				if !strings.Contains(sql, want) {
					t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, want)
				}
			}
		})
	}
}

// TestMapOutput tests the legacy map output format.
func TestMapOutput(t *testing.T) {
	parser, err := NewParser(BooleanModel{})
	if err != nil {
		t.Fatalf("NewParser() error = %v", err)
	}

	result, err := parser.ParseToMap("name:john AND status:active")
	if err != nil {
		t.Fatalf("ParseToMap() error = %v", err)
	}

	if result == nil {
		t.Errorf("ParseToMap() returned nil")
	}
}

// TestFieldValidation tests that queries referencing unknown fields (or
// sub-fields of non-JSONB fields) are rejected with a descriptive error.
func TestFieldValidation(t *testing.T) {
	parser := createParser(t, MixedModel{})

	tests := []struct {
		name        string
		query       string
		wantErr     bool
		wantErrMsgs []string
	}{
		{
			name:    "valid field query",
			query:   "name:john",
			wantErr: false,
		},
		{
			name:    "valid JSONB sub-field",
			query:   "labels.category:prod",
			wantErr: false,
		},
		{
			name:        "invalid field",
			query:       "invalidfield:value",
			wantErr:     true,
			wantErrMsgs: []string{"invalidfield", "invalid field"},
		},
		{
			name:        "invalid JSONB base",
			query:       "notjsonb.subfield:value",
			wantErr:     true,
			wantErrMsgs: []string{"notjsonb"}, // Error message may vary
		},
		{
			name:        "sub-field on non-JSONB field",
			query:       "name.subfield:value",
			wantErr:     true,
			wantErrMsgs: []string{"name.subfield", "invalid field"},
		},
		{
			name:    "implicit search (no explicit fields) - valid",
			query:   "searchterm",
			wantErr: false,
		},
		{
			name:    "mixed valid and implicit",
			query:   "name:john OR searchterm",
			wantErr: false,
		},
		{
			name:        "mixed valid and invalid",
			query:       "name:john OR invalidfield:value",
			wantErr:     true,
			wantErrMsgs: []string{"invalidfield"},
		},
		{
			name:    "complex valid query",
			query:   "(name:john OR description:test) AND labels.env:prod",
			wantErr: false,
		},
		{
			name:        "invalid field in complex query",
			query:       "(name:john OR badfield:test) AND status:active",
			wantErr:     true,
			wantErrMsgs: []string{"badfield"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, _, err := parser.ParseToSQL(tt.query, "postgresql")
			if (err != nil) != tt.wantErr {
				t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr)
				return
			}
			if tt.wantErr && len(tt.wantErrMsgs) > 0 {
				assertErrorContains(t, err, tt.wantErrMsgs, tt.name)
			}
		})
	}
}

// TestNullValueQueries tests null value handling for IS NULL queries.
//
// wantParams has three states: nil means "no expectation", an explicit empty
// slice means "the query must bind no parameters", and a non-empty slice is
// compared against the returned parameters.
func TestNullValueQueries(t *testing.T) {
	parser := createParser(t, NullModel{})

	tests := []struct {
		name       string
		query      string
		wantSQL    []string
		wantNot    []string
		wantParams []any
		wantErr    bool
	}{
		{
			name:       "field is null (lowercase)",
			query:      "parent_id:null",
			wantSQL:    []string{`"parent_id"`, "IS NULL"},
			wantNot:    []string{"=", "?"},
			wantParams: []any{},
			wantErr:    false,
		},
		{
			name:       "field is NULL (uppercase)",
			query:      "parent_id:NULL",
			wantSQL:    []string{`"parent_id"`, "IS NULL"},
			wantParams: []any{},
			wantErr:    false,
		},
		{
			name:       "field is Null (mixed case)",
			query:      "parent_id:Null",
			wantSQL:    []string{`"parent_id"`, "IS NULL"},
			wantParams: []any{},
			wantErr:    false,
		},
		{
			name:       "combined null with other conditions",
			query:      "name:john AND deleted_at:null",
			wantSQL:    []string{`"name"`, `"deleted_at"`, "IS NULL", "AND"},
			wantParams: []any{"john"},
			wantErr:    false,
		},
		{
			name:       "NOT null (is not null)",
			query:      "NOT deleted_at:null",
			wantSQL:    []string{"NOT", `"deleted_at"`},
			wantParams: []any{"null"}, // NOT null is parsed as NOT field=null, not NOT field IS NULL
			wantErr:    false,
		},
		{
			name:       "nil should be treated as literal value (not NULL)",
			query:      "name:nil",
			wantSQL:    []string{`"name"`, "=", "?"},
			wantParams: []any{"nil"},
			wantErr:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sql, params, err := parser.ParseToSQL(tt.query, "postgresql")
			if (err != nil) != tt.wantErr {
				t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr)
				return
			}
			if tt.wantErr {
				return
			}

			assertSQLContains(t, sql, tt.wantSQL, tt.name)
			if len(tt.wantNot) > 0 {
				assertSQLNotContains(t, sql, tt.wantNot, tt.name)
			}

			switch {
			case tt.wantParams == nil:
				// No parameter expectation declared for this case.
			case len(tt.wantParams) == 0:
				// An explicit empty expectation: IS NULL must not bind anything.
				if len(params) != 0 {
					t.Errorf("ParseToSQL() params = %v, want none", params)
				}
			default:
				assertParamsEqual(t, params, tt.wantParams, tt.name)
			}
		})
	}
}

// TestEmptyAsLiteralValue tests that 'empty' is treated as a literal value.
func TestEmptyAsLiteralValue(t *testing.T) {
	parser, err := NewParser(BooleanModel{})
	if err != nil {
		t.Fatalf("NewParser() error = %v", err)
	}

	sql, params, err := parser.ParseToSQL("status:empty", "postgresql")
	if err != nil {
		t.Fatalf("ParseToSQL() error = %v", err)
	}

	if !strings.Contains(sql, `"status" =`) {
		t.Errorf("Expected regular equals query, got: %v", sql)
	}
	if len(params) != 1 || params[0] != "empty" {
		t.Errorf("Expected params to contain 'empty', got: %v", params)
	}
}

// TestFuzzySearch tests the fuzzy search operator (~).
func TestFuzzySearch(t *testing.T) {
	parser, err := NewParser(MixedModel{})
	if err != nil {
		t.Fatalf("NewParser() error = %v", err)
	}

	tests := []struct {
		name    string
		query   string
		wantSQL string
		wantErr bool
	}{
		{
			name:    "basic fuzzy search",
			query:   "name:roam~",
			wantSQL: "similarity",
		},
		{
			name:    "fuzzy with distance",
			query:   "name:roam~2",
			wantSQL: "similarity",
		},
		{
			name:    "fuzzy on JSONB field",
			query:   "labels.tag:prod~",
			wantSQL: "similarity",
		},
		{
			name:    "fuzzy combined with other conditions",
			query:   "name:test~ AND status:active",
			wantSQL: "similarity",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sql, _, err := parser.ParseToSQL(tt.query, "postgresql")
			if (err != nil) != tt.wantErr {
				t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr)
				// The SQL is meaningless once the error expectation failed.
				return
			}
			if !tt.wantErr && !strings.Contains(sql, tt.wantSQL) {
				t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, tt.wantSQL)
			}
		})
	}
}

// TestEscaping tests that special characters can be escaped.
func TestEscaping(t *testing.T) {
	type EscapeModel struct {
		Name    string `json:"name"`
		Version string `json:"version"`
		Path    string `json:"path"`
	}

	parser, err := NewParser(EscapeModel{})
	if err != nil {
		t.Fatalf("NewParser() error = %v", err)
	}

	tests := []struct {
		name    string
		query   string
		wantSQL string
		wantErr bool
	}{
		{
			name:    "escaped plus sign",
			query:   `name:C\+\+`,
			wantSQL: `"name"`,
		},
		{
			name:    "escaped colon",
			query:   `name:test\:value`,
			wantSQL: `"name"`,
		},
		{
			name:    "escaped path separator",
			query:   `path:\/usr\/bin`,
			wantSQL: `"path"`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sql, _, err := parser.ParseToSQL(tt.query, "postgresql")
			if (err != nil) != tt.wantErr {
				t.Errorf("ParseToSQL() error = %v, wantErr = %v", err, tt.wantErr)
				// The SQL is meaningless once the error expectation failed.
				return
			}
			if !tt.wantErr && !strings.Contains(sql, tt.wantSQL) {
				t.Errorf("ParseToSQL() sql = %v, want to contain %v", sql, tt.wantSQL)
			}
		})
	}
}

// TestBoostOperatorError tests that the boost operator returns a clear error.
func TestBoostOperatorError(t *testing.T) {
	parser, err := NewParser(BooleanModel{})
	if err != nil {
		t.Fatalf("NewParser() error = %v", err)
	}

	tests := []struct {
		name    string
		query   string
		wantErr string
	}{
		{
			name:    "boost operator",
			query:   "name:john^2",
			wantErr: "boost",
		},
		{
			name:    "boost in compound query",
			query:   "name:john^2 AND status:active",
			wantErr: "boost",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, _, err := parser.ParseToSQL(tt.query, "postgresql")
			if err == nil {
				t.Errorf("ParseToSQL(%q) expected error but got none", tt.query)
				return
			}
			if !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.wantErr)) {
				t.Errorf("ParseToSQL(%q) error = %v, want to contain %v", tt.query, err, tt.wantErr)
			}
		})
	}
}

// TestNewParser tests parser creation and configuration.
func TestNewParser(t *testing.T) {
	tests := []struct {
		name      string
		model     any
		config    *ParserConfig
		wantErr   bool
		wantCount int
	}{
		{
			name:      "basic model",
			model:     BasicModel{},
			wantErr:   false,
			wantCount: 2,
		},
		{
			name:      "pointer to model",
			model:     &BasicModel{},
			wantErr:   false,
			wantCount: 2,
		},
		{
			name:      "with custom config",
			model:     BasicModel{},
			config:    &ParserConfig{MaxQueryLength: 5000, MaxDepth: 10, MaxTerms: 50},
			wantErr:   false,
			wantCount: 2,
		},
		{
			name:    "invalid model (not struct)",
			model:   "not a struct",
			wantErr: true,
		},
		{
			name:      "empty struct",
			model:     struct{}{},
			wantErr:   false,
			wantCount: 0,
		},
		{
			name:      "model with no json tags",
			model:     struct{ Name string }{},
			wantErr:   false,
			wantCount: 0,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parser, err := NewParser(tt.model, tt.config)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewParser() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !tt.wantErr {
				if parser == nil {
					t.Fatal("NewParser() returned nil parser")
				}
				if len(parser.Fields) != tt.wantCount {
					t.Errorf("NewParser() field count = %d, want %d", len(parser.Fields), tt.wantCount)
				}
				if tt.config != nil {
					if tt.config.MaxQueryLength > 0 && parser.MaxQueryLength != tt.config.MaxQueryLength {
						t.Errorf("NewParser() MaxQueryLength = %d, want %d", parser.MaxQueryLength, tt.config.MaxQueryLength)
					}
					if tt.config.MaxDepth > 0 && parser.MaxDepth != tt.config.MaxDepth {
						t.Errorf("NewParser() MaxDepth = %d, want %d", parser.MaxDepth, tt.config.MaxDepth)
					}
					if tt.config.MaxTerms > 0 && parser.MaxTerms != tt.config.MaxTerms {
						t.Errorf("NewParser() MaxTerms = %d, want %d", parser.MaxTerms, tt.config.MaxTerms)
					}
				}
			}
		})
	}
}

// TestParser_ValidateQuery tests query validation (security limits).
func TestParser_ValidateQuery(t *testing.T) {
	parser := createParser(t, BasicModel{})

	tests := []struct {
		name      string
		query     string
		config    *ParserConfig
		wantErr   bool
		wantError []string
	}{
		{
			name:    "valid query",
			query:   "name:john",
			wantErr: false,
		},
		{
			name:      "query too long",
			query:     strings.Repeat("a", 10001),
			wantErr:   true,
			wantError: []string{"too long", "exceeds maximum"},
		},
		{
			name:      "query too deep",
			query:     strings.Repeat("(", 21) + "name:john" + strings.Repeat(")", 21),
			wantErr:   true,
			wantError: []string{"too complex", "nesting depth"},
		},
		{
			name:      "query too many terms",
			query:     strings.Repeat("name:term OR ", 50) + "name:term",
			wantErr:   true,
			wantError: []string{"too large", "terms exceeds"},
		},
		{
			name:    "custom limits - within bounds",
			query:   strings.Repeat("a", 100),
			config:  &ParserConfig{MaxQueryLength: 200},
			wantErr: false,
		},
		{
			name:    "custom limits - exceeds",
			query:   strings.Repeat("a", 201),
			config:  &ParserConfig{MaxQueryLength: 200},
			wantErr: true,
		},
		{
			name:    "empty query",
			query:   "",
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Cases with their own limits get a dedicated parser; the rest
			// share the default one created above.
			var p *Parser
			if tt.config != nil {
				p = createParser(t, BasicModel{}, tt.config)
			} else {
				p = parser
			}

			err := p.validateQuery(tt.query)
			if (err != nil) != tt.wantErr {
				t.Errorf("validateQuery() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.wantErr && len(tt.wantError) > 0 {
				assertErrorContains(t, err, tt.wantError, "validateQuery()")
			}
		})
	}
}
+ +// TestCalculateNestingDepth tests depth calculation (unit test for helper) +func TestCalculateNestingDepth(t *testing.T) { + tests := []struct { + name string + query string + want int + }{ + { + name: "no nesting", + query: "name:john", + want: 0, + }, + { + name: "single level", + query: "(name:john)", + want: 1, + }, + { + name: "nested", + query: "((name:john))", + want: 2, + }, + { + name: "mixed brackets", + query: "(name:john AND [age:25 TO 65])", + want: 2, + }, + { + name: "quotes ignore nesting", + query: `(name:"test (value)")`, + want: 1, + }, + { + name: "escaped quotes", + query: `(name:"test\"value")`, + want: 1, + }, + { + name: "unbalanced (should still calculate)", + query: "((name:john)", + want: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := calculateNestingDepth(tt.query) + if got != tt.want { + t.Errorf("calculateNestingDepth() = %d, want %d", got, tt.want) + } + }) + } +} + +// TestCountTerms tests term counting (unit test for helper) +func TestCountTerms(t *testing.T) { + tests := []struct { + name string + query string + want int + }{ + { + name: "single term", + query: "name:john", + want: 1, + }, + { + name: "multiple terms", + query: "name:john AND email:test", + want: 3, // name:john, AND (counted before skip), email:test + }, + { + name: "quoted phrase", + query: `name:"john doe"`, + want: 2, // name: and "john doe" (quotes counted separately) + }, + { + name: "range query", + query: "age:[25 TO 65]", + want: 2, // age: and range content + }, + { + name: "implicit search", + query: "john", + want: 1, + }, + { + name: "empty query", + query: "", + want: 0, + }, + { + name: "operators not counted", + query: "name:john AND email:test OR status:active", + want: 5, // name:john, AND, email:test, OR, status:active + }, + { + name: "parentheses not counted", + query: "(name:john OR email:test)", + want: 3, // name:john, OR, email:test + }, + } + + for _, tt := range tests { + t.Run(tt.name, 
func(t *testing.T) { + got := countTerms(tt.query) + if got != tt.want { + t.Errorf("countTerms() = %d, want %d", got, tt.want) + } + }) + } +} + +// TestParser_ProviderSpecific tests all SQL providers +func TestParser_ProviderSpecific(t *testing.T) { + parser := createParser(t, BasicModel{}) + + tests := []struct { + name string + query string + provider string + wantSQL []string + wantNot []string + wantParams []any + wantErr bool + }{ + { + name: "postgresql placeholder", + query: "name:john", + provider: "postgresql", + wantSQL: []string{"?"}, + wantNot: []string{"$"}, + wantParams: []any{"john"}, + wantErr: false, + }, + { + name: "mysql placeholder", + query: "name:john", + provider: "mysql", + wantSQL: []string{"?"}, + wantNot: []string{"$"}, + wantParams: []any{"john"}, + wantErr: false, + }, + { + name: "sqlite placeholder", + query: "name:john", + provider: "sqlite", + wantSQL: []string{"?"}, + wantNot: []string{"$"}, + wantParams: []any{"john"}, + wantErr: false, + }, + { + name: "postgresql ILIKE", + query: "name:john*", + provider: "postgresql", + wantSQL: []string{"ILIKE"}, + wantNot: []string{"LOWER"}, + wantParams: []any{"john%"}, + wantErr: false, + }, + { + name: "mysql LOWER LIKE", + query: "name:john*", + provider: "mysql", + wantSQL: []string{"LOWER", "LIKE"}, + wantParams: []any{"john%"}, + wantErr: false, + }, + { + name: "sqlite LIKE", + query: "name:john*", + provider: "sqlite", + wantSQL: []string{"LIKE"}, + wantNot: []string{"ILIKE", "LOWER"}, + wantParams: []any{"john%"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := parser.ParseToSQL(tt.query, tt.provider) + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToSQL() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + assertSQLContains(t, sql, tt.wantSQL, tt.name) + if len(tt.wantNot) > 0 { + assertSQLNotContains(t, sql, tt.wantNot, tt.name) + } + if len(tt.wantParams) > 0 { + // Only validate params if 
we expect specific values + if len(tt.wantParams) > 0 { + assertParamsEqual(t, params, tt.wantParams, tt.name) + } + } + } + }) + } +} + +// TestParser_ParseToDynamoDBPartiQL tests DynamoDB output +func TestParser_ParseToDynamoDBPartiQL(t *testing.T) { + parser := createParser(t, BasicModel{}) + + tests := []struct { + name string + query string + wantPartiQL []string + wantCount int + wantErr bool + }{ + { + name: "simple query", + query: "name:john", + wantPartiQL: []string{"name"}, + wantCount: 1, + wantErr: false, + }, + { + name: "AND query", + query: "name:john AND email:test", + wantPartiQL: []string{"AND"}, + wantCount: 2, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + partiql, attrs, err := parser.ParseToDynamoDBPartiQL(tt.query) + if (err != nil) != tt.wantErr { + t.Fatalf("ParseToDynamoDBPartiQL() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + assertSQLContains(t, partiql, tt.wantPartiQL, tt.name) + if len(attrs) != tt.wantCount { + t.Errorf("ParseToDynamoDBPartiQL() attrs count = %d, want %d", len(attrs), tt.wantCount) + } + } + }) + } +} + +// BenchmarkParser benchmarks the parser performance +func BenchmarkParser(b *testing.B) { + parser, _ := NewParser(ComplexModel{}) + query := `(name:john* OR email:*@example.com) AND (status:active OR status:pending) AND age:[25 TO 65]` + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _ = parser.ParseToSQL(query, "postgresql") + } +} diff --git a/storage/search/lucene/sql_driver.go b/storage/search/lucene/sql_driver.go new file mode 100644 index 0000000..b2b3488 --- /dev/null +++ b/storage/search/lucene/sql_driver.go @@ -0,0 +1,613 @@ +package lucene + +import ( + "fmt" + "regexp" + "strings" + + "github.com/grindlemire/go-lucene/pkg/driver" + "github.com/grindlemire/go-lucene/pkg/lucene/expr" +) + +// SQLDriver is a SQL driver that supports multiple SQL dialects (PostgreSQL, MySQL, SQLite). 
+// It handles database-specific syntax for LIKE operators, JSON field access, and parameter placeholders. +type SQLDriver struct { + driver.Base + fields map[string]FieldInfo // Map of field names to their metadata + provider string // SQL provider: "postgresql", "mysql", or "sqlite" +} + +// validateProvider validates that the provider is one of the supported SQL providers. +func validateProvider(provider string) error { + switch provider { + case "postgresql", "mysql", "sqlite": + return nil + default: + return fmt.Errorf("unsupported SQL provider: %s (supported: postgresql, mysql, sqlite)", provider) + } +} + +// NewSQLDriver creates a new SQL driver for the specified provider. +// Provider should be one of: "postgresql", "mysql", "sqlite" +// Returns an error if duplicate field names are found or provider is invalid. +func NewSQLDriver(fields []FieldInfo, provider string) (*SQLDriver, error) { + if err := validateProvider(provider); err != nil { + return nil, err + } + + fieldMap, err := buildFieldMap(fields) + if err != nil { + return nil, err + } + + // RenderFNs map - we handle most operators in renderParamInternal + // Only keeping base implementations for operators we don't intercept + fns := map[expr.Operator]driver.RenderFN{ + expr.Literal: driver.Shared[expr.Literal], + expr.And: driver.Shared[expr.And], + expr.Or: driver.Shared[expr.Or], + expr.Not: driver.Shared[expr.Not], + expr.Equals: driver.Shared[expr.Equals], + expr.Range: driver.Shared[expr.Range], + expr.Must: driver.Shared[expr.Must], + expr.MustNot: driver.Shared[expr.MustNot], + expr.Wild: driver.Shared[expr.Wild], + expr.Regexp: driver.Shared[expr.Regexp], + expr.Like: driver.Shared[expr.Like], + expr.Greater: driver.Shared[expr.Greater], + expr.GreaterEq: driver.Shared[expr.GreaterEq], + expr.Less: driver.Shared[expr.Less], + expr.LessEq: driver.Shared[expr.LessEq], + expr.In: driver.Shared[expr.In], + expr.List: driver.Shared[expr.List], + } + + return &SQLDriver{ + Base: driver.Base{ + 
RenderFNs: fns, + }, + fields: fieldMap, + provider: provider, + }, nil +} + +// RenderParam renders the expression with provider-specific parameter placeholders. +func (s *SQLDriver) RenderParam(e *expr.Expression) (string, []any, error) { + // Process JSON field notation before rendering + s.processJSONFields(e) + + // Use our custom rendering logic + str, params, err := s.renderParamInternal(e) + if err != nil { + return "", nil, err + } + + // Keep ? placeholders for all providers. + // GORM's PostgreSQL driver handles ? → $N conversion automatically, + // so pre-converting here would conflict with additional WHERE clauses + // (e.g. cursor pagination) that GORM adds with its own ? placeholders. + + return str, params, nil +} + +// renderParamInternal dispatches to specialized renderers based on operator type. +func (s *SQLDriver) renderParamInternal(e *expr.Expression) (string, []any, error) { + if e == nil { + return "", nil, nil + } + + switch e.Op { + case expr.Like, expr.Wild: + return s.renderLikeOrWild(e) + case expr.Fuzzy: + return s.renderFuzzy(e) + case expr.Boost: + return "", nil, fmt.Errorf("boost operator (^) is not supported in SQL filtering; it only affects ranking/scoring") + case expr.Range: + return s.renderRange(e) + case expr.Equals, expr.Greater, expr.Less, expr.GreaterEq, expr.LessEq: + return s.renderComparison(e) + case expr.And, expr.Or, expr.Must, expr.MustNot: + return s.renderBinary(e) + default: + // Use base implementation for all other operators + return s.Base.RenderParam(e) + } +} + +// renderLikeOrWild converts LIKE and Wild operators to provider-specific case-insensitive matching. +func (s *SQLDriver) renderLikeOrWild(e *expr.Expression) (string, []any, error) { + leftStr, leftParams, err := s.serializeColumn(e.Left) + if err != nil { + return "", nil, err + } + + rightStr, rightParams, err := s.serializeValue(e.Right) + if err != nil { + return "", nil, err + } + + params := append(leftParams, rightParams...) 
+ + switch s.provider { + case "postgresql": + // PostgreSQL: ILIKE for case-insensitive matching + if isJSONSyntax(leftStr) { + return fmt.Sprintf("%s ILIKE %s", leftStr, rightStr), params, nil + } + return fmt.Sprintf("%s::text ILIKE %s", leftStr, rightStr), params, nil + + case "mysql": + // MySQL: Use LOWER() for case-insensitive matching + return fmt.Sprintf("LOWER(%s) LIKE LOWER(%s)", leftStr, rightStr), params, nil + + case "sqlite": + // SQLite: LIKE is already case-insensitive for ASCII by default + return fmt.Sprintf("%s LIKE %s", leftStr, rightStr), params, nil + + default: + return "", nil, fmt.Errorf("unsupported SQL provider: %s", s.provider) + } +} + +// renderFuzzy handles fuzzy search with provider-specific implementations. +// For queries like "name:roam~2", the structure is: +// - Op: Fuzzy +// - Left: Equals expression (name:roam) with Left=Column("name"), Right=Literal("roam") +// - Right: nil (distance stored in unexported fuzzyDistance field) +func (s *SQLDriver) renderFuzzy(e *expr.Expression) (string, []any, error) { + leftExpr, ok := e.Left.(*expr.Expression) + if !ok || leftExpr.Op != expr.Equals { + return "", nil, fmt.Errorf("fuzzy operator requires field:value syntax (e.g., name:roam~2)") + } + + colStr, colParams, err := s.serializeColumn(leftExpr.Left) + if err != nil { + return "", nil, err + } + + termStr, termParams, err := s.serializeValue(leftExpr.Right) + if err != nil { + return "", nil, err + } + + params := append(colParams, termParams...) 
+ + switch s.provider { + case "postgresql": + // PostgreSQL: Use similarity() function from pg_trgm extension + // Threshold 0.3 (lower = more matches, higher = stricter) + threshold := 0.3 + if isJSONSyntax(colStr) { + return fmt.Sprintf("similarity(%s, %s) > %f", colStr, termStr, threshold), params, nil + } + return fmt.Sprintf("similarity(%s::text, %s) > %f", colStr, termStr, threshold), params, nil + + case "mysql": + // MySQL: Use SOUNDEX for phonetic matching (limited fuzzy support) + return fmt.Sprintf("SOUNDEX(%s) = SOUNDEX(%s)", colStr, termStr), params, nil + + case "sqlite": + // SQLite: No built-in fuzzy search support + return "", nil, fmt.Errorf("fuzzy search (field:term~N) is not supported with SQLite; use wildcards instead (e.g., field:term*)") + + default: + return "", nil, fmt.Errorf("unsupported SQL provider: %s", s.provider) + } +} + +// renderComparison handles comparison operators with IS NULL support for null values. +func (s *SQLDriver) renderComparison(e *expr.Expression) (string, []any, error) { + leftStr, leftParams, err := s.serializeColumn(e.Left) + if err != nil { + return "", nil, err + } + + if isNullValue(e.Right) { + if e.Op == expr.Equals { + return fmt.Sprintf("%s IS NULL", leftStr), leftParams, nil + } + return "", nil, fmt.Errorf("cannot use comparison operators (>, <, >=, <=) with null value") + } + + rightStr, rightParams, err := s.serializeValue(e.Right) + if err != nil { + return "", nil, err + } + + params := append(leftParams, rightParams...) + + var opSymbol string + switch e.Op { + case expr.Equals: + opSymbol = "=" + case expr.Greater: + opSymbol = ">" + case expr.Less: + opSymbol = "<" + case expr.GreaterEq: + opSymbol = ">=" + case expr.LessEq: + opSymbol = "<=" + } + + return fmt.Sprintf("%s %s %s", leftStr, opSymbol, rightStr), params, nil +} + +// renderBinary handles binary and unary logical operators recursively. +// Note: Must and MustNot are unary (only Left operand), while And and Or are binary. 
+func (s *SQLDriver) renderBinary(e *expr.Expression) (string, []any, error) { + switch e.Op { + case expr.Must, expr.MustNot: + if e.Left == nil { + return "", nil, fmt.Errorf("%s operator requires a left operand", e.Op) + } + + var leftStr string + var leftParams []any + var err error + + if leftExpr, ok := e.Left.(*expr.Expression); ok { + leftStr, leftParams, err = s.renderParamInternal(leftExpr) + if err != nil { + return "", nil, err + } + } else { + leftStr, leftParams, err = s.serializeColumn(e.Left) + if err != nil { + leftStr, leftParams, err = s.serializeValue(e.Left) + if err != nil { + return s.Base.RenderParam(e) + } + } + } + + if e.Op == expr.Must { + return leftStr, leftParams, nil + } + return fmt.Sprintf("NOT (%s)", leftStr), leftParams, nil + + case expr.And, expr.Or: + if e.Left == nil || e.Right == nil { + return "", nil, fmt.Errorf("%s operator requires both left and right operands", e.Op) + } + + leftExpr, leftIsExpr := e.Left.(*expr.Expression) + rightExpr, rightIsExpr := e.Right.(*expr.Expression) + + if !leftIsExpr || !rightIsExpr { + return s.Base.RenderParam(e) + } + + leftStr, leftParams, err := s.renderParamInternal(leftExpr) + if err != nil { + return "", nil, err + } + + rightStr, rightParams, err := s.renderParamInternal(rightExpr) + if err != nil { + return "", nil, err + } + + params := append(leftParams, rightParams...) + + if e.Op == expr.And { + return fmt.Sprintf("(%s) AND (%s)", leftStr, rightStr), params, nil + } + return fmt.Sprintf("(%s) OR (%s)", leftStr, rightStr), params, nil + + default: + return "", nil, fmt.Errorf("unsupported operator: %v", e.Op) + } +} + +// quoteColumnName quotes a column name if it's not already JSON syntax. 
+func quoteColumnName(colStr string) string { + if isJSONSyntax(colStr) { + return colStr + } + return fmt.Sprintf(`"%s"`, colStr) +} + +func (s *SQLDriver) serializeColumn(in any) (string, []any, error) { + switch v := in.(type) { + case expr.Column: + return quoteColumnName(string(v)), nil, nil + case string: + return quoteColumnName(v), nil, nil + case *expr.Expression: + if v.Op == expr.Literal && v.Left != nil { + if col, ok := v.Left.(expr.Column); ok { + return quoteColumnName(string(col)), nil, nil + } + } + return s.renderParamInternal(v) + default: + return "", nil, fmt.Errorf("unexpected column type: %T", v) + } +} + +// extractLiteralString extracts a string value from an expression for wildcard conversion. +func extractLiteralString(v *expr.Expression) (string, bool) { + if v.Left == nil { + return "", false + } + if v.Op == expr.Literal || v.Op == expr.Wild { + return fmt.Sprintf("%v", v.Left), true + } + return "", false +} + +// serializeValue converts Lucene wildcards (* and ?) to SQL wildcards (% and _). +func (s *SQLDriver) serializeValue(in any) (string, []any, error) { + switch v := in.(type) { + case string: + return "?", []any{convertWildcards(v)}, nil + case *expr.Expression: + if literalVal, ok := extractLiteralString(v); ok { + return "?", []any{convertWildcards(literalVal)}, nil + } + return s.renderParamInternal(v) + case nil: + return "", nil, fmt.Errorf("nil value in expression") + default: + return "?", []any{v}, nil + } +} + +// processJSONFields recursively processes the expression tree to convert +// field.subfield notation to provider-specific JSON syntax. 
+func (s *SQLDriver) processJSONFields(e *expr.Expression) { + if e == nil { + return + } + + // Process left side if it's a column + if col, ok := e.Left.(expr.Column); ok { + e.Left = s.formatFieldName(string(col)) + } + + // Recursively process expressions + if leftExpr, ok := e.Left.(*expr.Expression); ok { + s.processJSONFields(leftExpr) + } + if rightExpr, ok := e.Right.(*expr.Expression); ok { + s.processJSONFields(rightExpr) + } + + // Process expression slices + if exprs, ok := e.Left.([]*expr.Expression); ok { + for _, ex := range exprs { + s.processJSONFields(ex) + } + } + if exprs, ok := e.Right.([]*expr.Expression); ok { + for _, ex := range exprs { + s.processJSONFields(ex) + } + } +} + +var ( + // jsonSubFieldPattern matches valid JSON subfield names (alphanumeric, underscore, and dot for nested paths) + jsonSubFieldPattern = regexp.MustCompile(`^[a-zA-Z0-9_.]+$`) +) + +// validateSubFieldName validates that a subfield name contains only safe characters. +// Subfield names should be alphanumeric with underscores and dots for nested paths. +// This prevents injection attacks via JSON path manipulation. +func validateSubFieldName(subField string) error { + if subField == "" { + return fmt.Errorf("subfield name cannot be empty") + } + + if !jsonSubFieldPattern.MatchString(subField) { + return fmt.Errorf("invalid subfield name '%s': contains unsafe characters (only alphanumeric, underscore, and dot allowed)", subField) + } + return nil +} + +// escapeJSONPathSegment escapes a JSON path segment for safe use in JSON path expressions. 
+// For PostgreSQL: escapes single quotes in the key name (used in ->>'key' syntax)
+// For MySQL/SQLite: nothing beyond quote doubling is done; validateSubFieldName is relied on to reject unsafe characters
+func escapeJSONPathSegment(segment string) string {
+	// Double every single quote, the standard escape inside a SQL string literal.
+	return strings.ReplaceAll(segment, "'", "''")
+}
+
+// formatFieldName converts field.subfield to provider-specific JSON syntax.
+//
+// The name is split at the first dot into a base field and a subfield path.
+// The original name is returned unchanged when there is no dot, when the
+// subfield fails validateSubFieldName, when the path contains an empty
+// segment, when the base field is unknown or its type does not allow nested
+// access, or when the provider is not recognised.
+//
+// Rendering per provider:
+//   - postgresql: base->'a'->>'b' (a single segment renders as base->>'a').
+//     Intermediate segments use -> so the path is traversed level by level;
+//     ->>'a.b' would instead look up one key literally named "a.b".
+//   - mysql:      JSON_UNQUOTE(JSON_EXTRACT(base, '$.a.b'))
+//   - sqlite:     JSON_EXTRACT(base, '$.a.b')
+func (s *SQLDriver) formatFieldName(fieldName string) expr.Column {
+	unchanged := expr.Column(fieldName)
+
+	baseField, subField, hasSubField := strings.Cut(fieldName, ".")
+	if !hasSubField {
+		return unchanged
+	}
+
+	// Validate subfield name for security (prevents injection). On failure the
+	// original name is handed back so later field validation can reject it.
+	if err := validateSubFieldName(subField); err != nil {
+		return unchanged
+	}
+
+	field, exists := s.fields[baseField]
+	if !exists || !canUseNestedAccess(field.Type) {
+		return unchanged
+	}
+
+	// An empty segment (".a", "a.", "a..b") cannot be expressed as a valid path
+	// for any provider, so treat it like a failed validation.
+	segments := strings.Split(subField, ".")
+	for _, seg := range segments {
+		if seg == "" {
+			return unchanged
+		}
+	}
+
+	switch s.provider {
+	case "postgresql":
+		// Chain -> for intermediate keys and ->> for the final key so the
+		// result is text. Keys sit inside single quotes, hence the escaping.
+		var b strings.Builder
+		b.WriteString(baseField)
+		last := len(segments) - 1
+		for i, seg := range segments {
+			if i == last {
+				b.WriteString("->>'")
+			} else {
+				b.WriteString("->'")
+			}
+			b.WriteString(escapeJSONPathSegment(seg))
+			b.WriteByte('\'')
+		}
+		return expr.Column(b.String())
+
+	case "mysql":
+		// MySQL 5.7+: JSON_UNQUOTE(JSON_EXTRACT(column, '$.field'))
+		// The path is not separately quoted; validation ensures it is safe.
+		return expr.Column(fmt.Sprintf("JSON_UNQUOTE(JSON_EXTRACT(%s, '$.%s'))", baseField, subField))
+
+	case "sqlite":
+		// SQLite: JSON_EXTRACT(column, '$.field')
+		// The path is not separately quoted; validation ensures it is safe.
+		return expr.Column(fmt.Sprintf("JSON_EXTRACT(%s, '$.%s')", baseField, subField))
+
+	default:
+		// Should never happen due to validateProvider, but defensive programming
+		return unchanged
+	}
+}
+
+// Helper functions for SQL driver
+
+// convertWildcards converts Lucene wildcards to SQL wildcards.
+// * (any characters) → % (SQL wildcard)
+// ? (single character) → _ (SQL wildcard)
+//
+// Note: go-lucene's base driver also converts wildcards, but only for expr.Like operators.
+// We need this function because we also convert wildcards for expr.Wild expressions
+// and when serializing values for fuzzy search and other operators.
+//
+// Existing % and _ characters in s are passed through unchanged.
+func convertWildcards(s string) string {
+	// Byte-wise translation is safe on UTF-8 input: '*' and '?' are ASCII and
+	// never occur inside a multi-byte sequence.
+	translated := []byte(s)
+	for i, c := range translated {
+		switch c {
+		case '*':
+			translated[i] = '%'
+		case '?':
+			translated[i] = '_'
+		}
+	}
+	return string(translated)
+}
+
+// isJSONSyntax checks if a column string contains provider-specific JSON syntax:
+// the PostgreSQL ->> operator, or the JSON_EXTRACT / JSON_UNQUOTE functions
+// used for MySQL and SQLite.
+func isJSONSyntax(col string) bool {
+	for _, marker := range []string{"->>", "JSON_EXTRACT", "JSON_UNQUOTE"} {
+		if strings.Contains(col, marker) {
+			return true
+		}
+	}
+	return false
+}
+
+// isNullValue checks if a value represents null in Lucene query syntax.
+// Supports: null, NULL, Null (case-insensitive)
+// Note: This is a SQL-specific extension (vanilla Lucene doesn't support NULL values).
+// We intentionally do NOT support "empty" or "nil" as they could be legitimate search values.
+func isNullValue(v any) bool {
+	// extractStringValue yields "" for anything that is not a string or a
+	// string literal, and "" never matches "null".
+	return strings.EqualFold(extractStringValue(v), "null")
+}
+
+// extractStringValue returns v itself when it is a string, or the string held
+// in Left when v is a Literal expression. Every other input, including a nil
+// *expr.Expression or a literal whose Left is not a string, yields "".
+func extractStringValue(v any) string {
+	switch val := v.(type) {
+	case string:
+		return val
+	case *expr.Expression:
+		// A typed nil pointer still matches this case; guard before reading Op.
+		if val == nil || val.Op != expr.Literal {
+			return ""
+		}
+		if strVal, ok := val.Left.(string); ok {
+			return strVal
+		}
+	}
+	return ""
+}
+
+// extractLiteralValue renders v as a string using the default %v formatting.
+//
+// For a Literal expression with a non-nil Left, the Left value is formatted.
+// Any other non-nil value, including other expression kinds, is formatted as
+// a whole. A nil value, or a nil *expr.Expression wrapped in the interface,
+// yields "".
+func extractLiteralValue(v any) string {
+	if v == nil {
+		return ""
+	}
+
+	if ex, ok := v.(*expr.Expression); ok {
+		// v == nil above does not catch a nil pointer stored in the interface.
+		if ex == nil {
+			return ""
+		}
+		// LITERAL expressions store the actual value in Left.
+		if ex.Op == expr.Literal && ex.Left != nil {
+			return fmt.Sprintf("%v", ex.Left)
+		}
+	}
+
+	return fmt.Sprintf("%v", v)
+}
+
+// renderRange handles range queries including open-ended ranges with wildcards (*).
+func (s *SQLDriver) renderRange(e *expr.Expression) (string, []any, error) { + colStr, _, err := s.serializeColumn(e.Left) + if err != nil { + return "", nil, err + } + + rangeBoundary, ok := e.Right.(*expr.RangeBoundary) + if !ok { + return "", nil, fmt.Errorf("invalid range expression structure: expected *expr.RangeBoundary, got %T", e.Right) + } + + var minVal, maxVal string + var params []any + + if rangeBoundary.Min != nil { + minVal = extractLiteralValue(rangeBoundary.Min) + } + + if rangeBoundary.Max != nil { + maxVal = extractLiteralValue(rangeBoundary.Max) + } + + if minVal == "*" && maxVal == "*" { + return "", nil, fmt.Errorf("both range bounds cannot be wildcards") + } + + if minVal == "*" { + params = append(params, maxVal) + if rangeBoundary.Inclusive { + return fmt.Sprintf("%s <= ?", colStr), params, nil + } + return fmt.Sprintf("%s < ?", colStr), params, nil + } + + if maxVal == "*" { + params = append(params, minVal) + if rangeBoundary.Inclusive { + return fmt.Sprintf("%s >= ?", colStr), params, nil + } + return fmt.Sprintf("%s > ?", colStr), params, nil + } + + params = append(params, minVal, maxVal) + if rangeBoundary.Inclusive { + return fmt.Sprintf("%s BETWEEN ? AND ?", colStr), params, nil + } + return fmt.Sprintf("(%s > ? AND %s < ?)", colStr, colStr), params, nil +} + +// convertToPostgresPlaceholders converts ? placeholders to PostgreSQL's $N format. +func convertToPostgresPlaceholders(query string) string { + paramIndex := 1 + result := strings.Builder{} + for i := 0; i < len(query); i++ { + if query[i] == '?' 
{ + result.WriteString(fmt.Sprintf("$%d", paramIndex)) + paramIndex++ + } else { + result.WriteByte(query[i]) + } + } + return result.String() +} diff --git a/storage/search/lucene/sql_driver_test.go b/storage/search/lucene/sql_driver_test.go new file mode 100644 index 0000000..b001b82 --- /dev/null +++ b/storage/search/lucene/sql_driver_test.go @@ -0,0 +1,1542 @@ +package lucene + +import ( + "fmt" + "reflect" + "strings" + "testing" + + "github.com/grindlemire/go-lucene/pkg/lucene/expr" +) + +func TestNewSQLDriver(t *testing.T) { + tests := []struct { + name string + fields []FieldInfo + provider string + wantErr bool + }{ + { + name: "postgresql with fields", + fields: []FieldInfo{{Name: "name", Type: reflect.TypeOf("")}}, + provider: "postgresql", + wantErr: false, + }, + { + name: "mysql with fields", + fields: []FieldInfo{{Name: "name", Type: reflect.TypeOf("")}}, + provider: "mysql", + wantErr: false, + }, + { + name: "sqlite with fields", + fields: []FieldInfo{{Name: "name", Type: reflect.TypeOf("")}}, + provider: "sqlite", + wantErr: false, + }, + { + name: "empty fields", + fields: []FieldInfo{}, + provider: "postgresql", + wantErr: false, + }, + { + name: "multiple fields", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "age", Type: reflect.TypeOf(0)}, + }, + provider: "postgresql", + wantErr: false, + }, + { + name: "duplicate field names returns error (postgresql)", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + provider: "postgresql", + wantErr: true, + }, + { + name: "duplicate field names returns error (mysql)", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + provider: "mysql", + wantErr: true, + }, + { + name: "duplicate field names returns error (sqlite)", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "name", 
Type: reflect.TypeOf(0)}, + }, + provider: "sqlite", + wantErr: true, + }, + { + name: "multiple duplicate field names (postgresql)", + fields: []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + {Name: "name", Type: reflect.TypeOf(0)}, + }, + provider: "postgresql", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(tt.fields, tt.provider) + if (err != nil) != tt.wantErr { + t.Errorf("NewSQLDriver() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + if err == nil { + t.Errorf("NewSQLDriver() expected error but got nil") + } + if driver != nil { + t.Errorf("NewSQLDriver() expected nil driver on error, got %v", driver) + } + if err != nil && !strings.Contains(err.Error(), "duplicate field name") { + t.Errorf("NewSQLDriver() error message should contain 'duplicate field name', got: %v", err) + } + return + } + if driver == nil { + t.Fatalf("NewSQLDriver() returned nil") + } + if driver.provider != tt.provider { + t.Errorf("NewSQLDriver() provider = %v, want %v", driver.provider, tt.provider) + } + if len(driver.fields) != len(tt.fields) { + t.Errorf("NewSQLDriver() fields count = %v, want %v", len(driver.fields), len(tt.fields)) + } + }) + } +} + +func TestSQLDriver_RenderParam(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + } + providers := []string{"postgresql", "mysql", "sqlite"} + + tests := []struct { + name string + expr *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "equals expression", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{`"name"`, "="}, + wantCount: 1, + wantErr: false, + }, + { + name: "AND expression", + expr: &expr.Expression{ + Op: expr.And, + Left: &expr.Expression{ + 
Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + Right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("email"), + Right: &expr.Expression{Op: expr.Literal, Left: "test@example.com"}, + }, + }, + wantSQL: []string{"AND"}, + wantCount: 2, + wantErr: false, + }, + { + name: "nil expression", + expr: nil, + wantErr: false, + }, + } + + for _, provider := range providers { + for _, tt := range tests { + t.Run(provider+"/"+tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(fields, provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + sql, params, err := driver.RenderParam(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("RenderParam() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + if tt.expr == nil { + if sql != "" { + t.Errorf("RenderParam() sql = %v, want empty string", sql) + } + if len(params) != 0 { + t.Errorf("RenderParam() params count = %v, want 0", len(params)) + } + return + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("RenderParam() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("RenderParam() params count = %v, want %v", len(params), tt.wantCount) + } + // All providers use ? placeholders; GORM handles $N conversion for PostgreSQL + if tt.wantCount > 0 && !strings.Contains(sql, "?") { + t.Errorf("RenderParam() expected ? 
placeholders for %v, got %v", provider, sql) + } + }) + } + } +} + +func TestSQLDriver_RenderLikeOrWild(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: reflect.TypeOf(map[string]interface{}{})}, + } + + tests := []struct { + name string + provider string + expr *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "postgresql LIKE regular field", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"ILIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "postgresql LIKE JSON field", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("metadata->>'key'"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + wantSQL: []string{"ILIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "mysql LIKE", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"LOWER", "LIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "sqlite LIKE", + provider: "sqlite", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"LIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "postgresql WILD", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Wild, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john*"}, + }, + wantSQL: []string{"ILIKE"}, + wantCount: 1, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + sql, params, err := driver.renderLikeOrWild(tt.expr) + if (err != nil) 
!= tt.wantErr { + t.Errorf("renderLikeOrWild() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderLikeOrWild() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderLikeOrWild() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_RenderFuzzy(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: reflect.TypeOf(map[string]interface{}{})}, + } + + tests := []struct { + name string + provider string + expr *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "postgresql fuzzy", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "roam"}, + }, + }, + wantSQL: []string{"similarity"}, + wantCount: 1, + wantErr: false, + }, + { + name: "postgresql fuzzy JSON field", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata->>'key'"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + }, + wantSQL: []string{"similarity"}, + wantCount: 1, + wantErr: false, + }, + { + name: "mysql fuzzy", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "roam"}, + }, + }, + wantSQL: []string{"SOUNDEX"}, + wantCount: 1, + wantErr: false, + }, + { + name: "sqlite fuzzy (unsupported)", + provider: "sqlite", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "roam"}, + }, + }, + wantErr: true, + }, + { + name: 
"invalid fuzzy expression (not Equals)", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{Op: expr.And}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + sql, params, err := driver.renderFuzzy(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("renderFuzzy() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderFuzzy() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderFuzzy() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_RenderComparison(t *testing.T) { + fields := []FieldInfo{ + {Name: "age", Type: reflect.TypeOf(0)}, + {Name: "name", Type: reflect.TypeOf("")}, + } + + tests := []struct { + name string + provider string + op expr.Operator + right *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "equals", + provider: "postgresql", + op: expr.Equals, + right: &expr.Expression{Op: expr.Literal, Left: "john"}, + wantSQL: []string{`"name"`, "="}, + wantCount: 1, + wantErr: false, + }, + { + name: "greater than", + provider: "postgresql", + op: expr.Greater, + right: &expr.Expression{Op: expr.Literal, Left: 25}, + wantSQL: []string{`"age"`, ">"}, + wantCount: 1, + wantErr: false, + }, + { + name: "less than", + provider: "postgresql", + op: expr.Less, + right: &expr.Expression{Op: expr.Literal, Left: 65}, + wantSQL: []string{`"age"`, "<"}, + wantCount: 1, + wantErr: false, + }, + { + name: "greater or equal", + provider: "postgresql", + op: expr.GreaterEq, + right: &expr.Expression{Op: expr.Literal, Left: 25}, + wantSQL: []string{`"age"`, ">="}, + wantCount: 1, + wantErr: false, + }, + { + name: "less or 
equal", + provider: "postgresql", + op: expr.LessEq, + right: &expr.Expression{Op: expr.Literal, Left: 65}, + wantSQL: []string{`"age"`, "<="}, + wantCount: 1, + wantErr: false, + }, + { + name: "equals null", + provider: "postgresql", + op: expr.Equals, + right: &expr.Expression{Op: expr.Literal, Left: "null"}, + wantSQL: []string{`"name"`, "IS NULL"}, + wantCount: 0, + wantErr: false, + }, + { + name: "greater than null (error)", + provider: "postgresql", + op: expr.Greater, + right: &expr.Expression{Op: expr.Literal, Left: "null"}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + var left expr.Column + if strings.Contains(tt.name, "age") || tt.op == expr.Greater || tt.op == expr.Less || tt.op == expr.GreaterEq || tt.op == expr.LessEq { + left = expr.Column("age") + } else { + left = expr.Column("name") + } + e := &expr.Expression{ + Op: tt.op, + Left: left, + Right: tt.right, + } + sql, params, err := driver.renderComparison(e) + if (err != nil) != tt.wantErr { + t.Errorf("renderComparison() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderComparison() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderComparison() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_RenderBinary(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "email", Type: reflect.TypeOf("")}, + } + + tests := []struct { + name string + provider string + op expr.Operator + left *expr.Expression + right *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "AND", + provider: "postgresql", + op: expr.And, + left: &expr.Expression{ + Op: expr.Equals, 
+ Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("email"), + Right: &expr.Expression{Op: expr.Literal, Left: "test@example.com"}, + }, + wantSQL: []string{"AND"}, + wantCount: 2, + wantErr: false, + }, + { + name: "OR", + provider: "postgresql", + op: expr.Or, + left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "jane"}, + }, + wantSQL: []string{"OR"}, + wantCount: 2, + wantErr: false, + }, + { + name: "Must", + provider: "postgresql", + op: expr.Must, + left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + right: nil, + wantSQL: []string{`"name"`}, + wantCount: 1, + wantErr: false, + }, + { + name: "MustNot", + provider: "postgresql", + op: expr.MustNot, + left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + right: nil, + wantSQL: []string{"NOT"}, + wantCount: 1, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + e := &expr.Expression{ + Op: tt.op, + Left: tt.left, + Right: tt.right, + } + sql, params, err := driver.renderBinary(e) + if (err != nil) != tt.wantErr { + t.Errorf("renderBinary() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderBinary() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderBinary() params count = %v, want %v", len(params), 
tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_RenderRange(t *testing.T) { + fields := []FieldInfo{ + {Name: "age", Type: reflect.TypeOf(0)}, + {Name: "date", Type: reflect.TypeOf("")}, + } + + tests := []struct { + name string + provider string + rangeExpr *expr.RangeBoundary + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "inclusive range", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "25"}, + Max: &expr.Expression{Op: expr.Literal, Left: "65"}, + Inclusive: true, + }, + wantSQL: []string{"BETWEEN"}, + wantCount: 2, + wantErr: false, + }, + { + name: "exclusive range", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "25"}, + Max: &expr.Expression{Op: expr.Literal, Left: "65"}, + Inclusive: false, + }, + wantSQL: []string{">", "<"}, + wantCount: 2, + wantErr: false, + }, + { + name: "open-ended min (inclusive)", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "*"}, + Max: &expr.Expression{Op: expr.Literal, Left: "65"}, + Inclusive: true, + }, + wantSQL: []string{"<="}, + wantCount: 1, + wantErr: false, + }, + { + name: "open-ended min (exclusive)", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "*"}, + Max: &expr.Expression{Op: expr.Literal, Left: "65"}, + Inclusive: false, + }, + wantSQL: []string{"<"}, + wantCount: 1, + wantErr: false, + }, + { + name: "open-ended max (inclusive)", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "25"}, + Max: &expr.Expression{Op: expr.Literal, Left: "*"}, + Inclusive: true, + }, + wantSQL: []string{">="}, + wantCount: 1, + wantErr: false, + }, + { + name: "open-ended max (exclusive)", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "25"}, + 
Max: &expr.Expression{Op: expr.Literal, Left: "*"}, + Inclusive: false, + }, + wantSQL: []string{">"}, + wantCount: 1, + wantErr: false, + }, + { + name: "both wildcards (error)", + provider: "postgresql", + rangeExpr: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "*"}, + Max: &expr.Expression{Op: expr.Literal, Left: "*"}, + Inclusive: true, + }, + wantErr: true, + }, + { + name: "invalid range expression (error)", + provider: "postgresql", + rangeExpr: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + var e *expr.Expression + if tt.rangeExpr == nil { + e = &expr.Expression{ + Op: expr.Range, + Left: expr.Column("age"), + Right: nil, + } + } else { + e = &expr.Expression{ + Op: expr.Range, + Left: expr.Column("age"), + Right: tt.rangeExpr, + } + } + sql, params, err := driver.renderRange(e) + if (err != nil) != tt.wantErr { + t.Errorf("renderRange() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderRange() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderRange() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_SerializeColumn(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: reflect.TypeOf(map[string]interface{}{})}, + } + driver, err := NewSQLDriver(fields, "postgresql") + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + + tests := []struct { + name string + input any + wantSQL string + wantCount int + wantErr bool + }{ + { + name: "simple column", + input: expr.Column("name"), + wantSQL: `"name"`, + wantCount: 0, + wantErr: false, + }, + { + name: "JSON syntax column", + input: 
expr.Column("metadata->>'key'"), + wantSQL: "metadata->>'key'", + wantCount: 0, + wantErr: false, + }, + { + name: "string column", + input: "name", + wantSQL: `"name"`, + wantCount: 0, + wantErr: false, + }, + { + name: "JSON syntax string", + input: "metadata->>'key'", + wantSQL: "metadata->>'key'", + wantCount: 0, + wantErr: false, + }, + { + name: "expression with Literal column", + input: &expr.Expression{Op: expr.Literal, Left: expr.Column("name")}, + wantSQL: `"name"`, + wantCount: 0, + wantErr: false, + }, + { + name: "invalid type", + input: 123, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := driver.serializeColumn(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("serializeColumn() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if !strings.Contains(sql, tt.wantSQL) { + t.Errorf("serializeColumn() sql = %v, want to contain %v", sql, tt.wantSQL) + } + if len(params) != tt.wantCount { + t.Errorf("serializeColumn() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_SerializeValue(t *testing.T) { + fields := []FieldInfo{{Name: "name", Type: reflect.TypeOf("")}} + driver, err := NewSQLDriver(fields, "postgresql") + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + + tests := []struct { + name string + input any + wantSQL string + wantValue string + wantErr bool + }{ + { + name: "string value", + input: "john", + wantSQL: "?", + wantValue: "john", + wantErr: false, + }, + { + name: "string with wildcards", + input: "john*", + wantSQL: "?", + wantValue: "john%", + wantErr: false, + }, + { + name: "literal expression", + input: &expr.Expression{Op: expr.Literal, Left: "test"}, + wantSQL: "?", + wantValue: "test", + wantErr: false, + }, + { + name: "wild expression", + input: &expr.Expression{Op: expr.Wild, Left: "test*"}, + wantSQL: "?", + wantValue: "test%", + wantErr: false, + }, + { + name: "nil 
value (error)", + input: nil, + wantErr: true, + }, + { + name: "integer value", + input: 42, + wantSQL: "?", + wantValue: "42", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := driver.serializeValue(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("serializeValue() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if sql != tt.wantSQL { + t.Errorf("serializeValue() sql = %v, want %v", sql, tt.wantSQL) + } + if len(params) != 1 { + t.Errorf("serializeValue() params count = %v, want 1", len(params)) + return + } + gotValue := fmt.Sprintf("%v", params[0]) + if tt.wantValue != "" && gotValue != tt.wantValue { + t.Errorf("serializeValue() param value = %v, want %v", gotValue, tt.wantValue) + } + } + }) + } +} + +func TestSQLDriver_FormatFieldName(t *testing.T) { + jsonbType := reflect.TypeOf(map[string]interface{}{}) + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: jsonbType}, + } + + tests := []struct { + name string + provider string + field string + want string + }{ + { + name: "postgresql JSON field", + provider: "postgresql", + field: "metadata.key", + want: "metadata->>'key'", + }, + { + name: "mysql JSON field", + provider: "mysql", + field: "metadata.key", + want: "JSON_UNQUOTE(JSON_EXTRACT(metadata, '$.key'))", + }, + { + name: "sqlite JSON field", + provider: "sqlite", + field: "metadata.key", + want: "JSON_EXTRACT(metadata, '$.key')", + }, + { + name: "simple field (no dot)", + provider: "postgresql", + field: "name", + want: "name", + }, + { + name: "non-JSONB field with dot (no conversion)", + provider: "postgresql", + field: "name.subfield", + want: "name.subfield", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + got := driver.formatFieldName(tt.field) + if 
string(got) != tt.want { + t.Errorf("formatFieldName() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConvertWildcards(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "no wildcards", + input: "john", + want: "john", + }, + { + name: "single *", + input: "john*", + want: "john%", + }, + { + name: "single ?", + input: "jo?n", + want: "jo_n", + }, + { + name: "multiple *", + input: "*john*", + want: "%john%", + }, + { + name: "multiple ?", + input: "j??n", + want: "j__n", + }, + { + name: "mixed wildcards", + input: "j*?n", + want: "j%_n", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only wildcards", + input: "***", + want: "%%%", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := convertWildcards(tt.input) + if got != tt.want { + t.Errorf("convertWildcards() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsJSONSyntax(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + { + name: "PostgreSQL JSONB operator", + input: "metadata->>'key'", + want: true, + }, + { + name: "MySQL JSON_EXTRACT", + input: "JSON_EXTRACT(column, '$.field')", + want: true, + }, + { + name: "MySQL JSON_UNQUOTE", + input: "JSON_UNQUOTE(JSON_EXTRACT(column, '$.field'))", + want: true, + }, + { + name: "SQLite JSON_EXTRACT", + input: "JSON_EXTRACT(column, '$.field')", + want: true, + }, + { + name: "regular column", + input: "name", + want: false, + }, + { + name: "quoted column", + input: `"name"`, + want: false, + }, + { + name: "empty string", + input: "", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isJSONSyntax(tt.input) + if got != tt.want { + t.Errorf("isJSONSyntax() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsNullValue(t *testing.T) { + tests := []struct { + name string + input any + want bool + }{ + { + name: "null string (lowercase)", + input: 
"null", + want: true, + }, + { + name: "NULL string (uppercase)", + input: "NULL", + want: true, + }, + { + name: "Null string (mixed case)", + input: "Null", + want: true, + }, + { + name: "null in literal expression", + input: &expr.Expression{Op: expr.Literal, Left: "null"}, + want: true, + }, + { + name: "empty string", + input: "", + want: false, + }, + { + name: "nil value", + input: nil, + want: false, + }, + { + name: "regular string", + input: "john", + want: false, + }, + { + name: "nil string", + input: "nil", + want: false, + }, + { + name: "empty string", + input: "empty", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isNullValue(tt.input) + if got != tt.want { + t.Errorf("isNullValue() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConvertToPostgresPlaceholders(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "single placeholder", + input: "SELECT * FROM users WHERE name = ?", + want: "SELECT * FROM users WHERE name = $1", + }, + { + name: "multiple placeholders", + input: "SELECT * FROM users WHERE name = ? AND age = ?", + want: "SELECT * FROM users WHERE name = $1 AND age = $2", + }, + { + name: "no placeholders", + input: "SELECT * FROM users", + want: "SELECT * FROM users", + }, + { + name: "many placeholders", + input: "? ? ? ? ?", + want: "$1 $2 $3 $4 $5", + }, + { + name: "placeholder in string literal (should still convert)", + input: "SELECT '?' 
FROM users WHERE name = ?", + want: "SELECT '$1' FROM users WHERE name = $2", + }, + { + name: "empty string", + input: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := convertToPostgresPlaceholders(tt.input) + if got != tt.want { + t.Errorf("convertToPostgresPlaceholders() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSQLDriver_ProcessJSONFields(t *testing.T) { + jsonbType := reflect.TypeOf(map[string]interface{}{}) + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: jsonbType}, + } + + tests := []struct { + name string + provider string + expr *expr.Expression + check func(t *testing.T, expr *expr.Expression) + }{ + { + name: "postgresql JSON field conversion", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + check: func(t *testing.T, e *expr.Expression) { + if col, ok := e.Left.(expr.Column); ok { + if !strings.Contains(string(col), "->>'") { + t.Errorf("expected PostgreSQL JSON syntax, got %v", col) + } + } + }, + }, + { + name: "mysql JSON field conversion", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + check: func(t *testing.T, e *expr.Expression) { + if col, ok := e.Left.(expr.Column); ok { + if !strings.Contains(string(col), "JSON_EXTRACT") { + t.Errorf("expected MySQL JSON syntax, got %v", col) + } + } + }, + }, + { + name: "nested expression", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.And, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + Right: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + }, + 
check: func(t *testing.T, e *expr.Expression) { + if leftExpr, ok := e.Left.(*expr.Expression); ok { + if col, ok := leftExpr.Left.(expr.Column); ok { + if !strings.Contains(string(col), "->>'") { + t.Errorf("expected PostgreSQL JSON syntax in nested expression, got %v", col) + } + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + driver.processJSONFields(tt.expr) + if tt.check != nil { + tt.check(t, tt.expr) + } + }) + } +} + +func TestSQLDriver_RenderParamInternal(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + } + driver, err := NewSQLDriver(fields, "postgresql") + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + + tests := []struct { + name string + expr *expr.Expression + wantSQL []string + wantCount int + wantErr bool + }{ + { + name: "Like operator", + expr: &expr.Expression{ + Op: expr.Like, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"ILIKE"}, + wantCount: 1, + wantErr: false, + }, + { + name: "Fuzzy operator", + expr: &expr.Expression{ + Op: expr.Fuzzy, + Left: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "roam"}, + }, + }, + wantSQL: []string{"similarity"}, + wantCount: 1, + wantErr: false, + }, + { + name: "Boost operator (error)", + expr: &expr.Expression{ + Op: expr.Boost, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantErr: true, + }, + { + name: "Range operator", + expr: &expr.Expression{ + Op: expr.Range, + Left: expr.Column("name"), + Right: &expr.RangeBoundary{ + Min: &expr.Expression{Op: expr.Literal, Left: "a"}, + Max: &expr.Expression{Op: expr.Literal, Left: "z"}, + Inclusive: true, + }, + }, + wantSQL: []string{"BETWEEN"}, + wantCount: 2, + 
wantErr: false, + }, + { + name: "nil expression", + expr: nil, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql, params, err := driver.renderParamInternal(tt.expr) + if (err != nil) != tt.wantErr { + t.Errorf("renderParamInternal() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if tt.expr == nil { + if sql != "" { + t.Errorf("renderParamInternal() sql = %v, want empty string", sql) + } + return + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("renderParamInternal() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("renderParamInternal() params count = %v, want %v", len(params), tt.wantCount) + } + } + }) + } +} + +func TestSQLDriver_ProviderSpecific(t *testing.T) { + fields := []FieldInfo{ + {Name: "name", Type: reflect.TypeOf("")}, + {Name: "metadata", Type: reflect.TypeOf(map[string]interface{}{})}, + } + + tests := []struct { + name string + provider string + expr *expr.Expression + wantSQL []string + wantCount int + checkFunc func(t *testing.T, sql string, params []any) + }{ + { + name: "postgresql placeholder format", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"?"}, + wantCount: 1, + checkFunc: func(t *testing.T, sql string, params []any) { + if !strings.Contains(sql, "?") { + t.Errorf("expected ? 
placeholder, got %v", sql) + } + }, + }, + { + name: "mysql placeholder format", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"?"}, + wantCount: 1, + checkFunc: func(t *testing.T, sql string, params []any) { + if !strings.Contains(sql, "?") { + t.Errorf("expected MySQL placeholder (?), got %v", sql) + } + }, + }, + { + name: "sqlite placeholder format", + provider: "sqlite", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("name"), + Right: &expr.Expression{Op: expr.Literal, Left: "john"}, + }, + wantSQL: []string{"?"}, + wantCount: 1, + checkFunc: func(t *testing.T, sql string, params []any) { + if !strings.Contains(sql, "?") { + t.Errorf("expected SQLite placeholder (?), got %v", sql) + } + }, + }, + { + name: "postgresql JSON field", + provider: "postgresql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + wantSQL: []string{"metadata->>'key'"}, + wantCount: 1, + checkFunc: nil, + }, + { + name: "mysql JSON field", + provider: "mysql", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + wantSQL: []string{"JSON_EXTRACT"}, + wantCount: 1, + checkFunc: nil, + }, + { + name: "sqlite JSON field", + provider: "sqlite", + expr: &expr.Expression{ + Op: expr.Equals, + Left: expr.Column("metadata.key"), + Right: &expr.Expression{Op: expr.Literal, Left: "value"}, + }, + wantSQL: []string{"JSON_EXTRACT"}, + wantCount: 1, + checkFunc: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + driver, err := NewSQLDriver(fields, tt.provider) + if err != nil { + t.Fatalf("NewSQLDriver() error = %v", err) + } + sql, params, err := driver.RenderParam(tt.expr) + if err != nil { + t.Fatalf("RenderParam() error = 
%v", err) + } + for _, want := range tt.wantSQL { + if !strings.Contains(sql, want) { + t.Errorf("RenderParam() sql = %v, want to contain %v", sql, want) + } + } + if len(params) != tt.wantCount { + t.Errorf("RenderParam() params count = %v, want %v", len(params), tt.wantCount) + } + if tt.checkFunc != nil { + tt.checkFunc(t, sql, params) + } + }) + } +} diff --git a/storage/sql.go b/storage/sql.go index 8313e11..2b9ff57 100644 --- a/storage/sql.go +++ b/storage/sql.go @@ -18,8 +18,8 @@ import ( "gorm.io/gorm/logger" "gorm.io/gorm/schema" + serviceErrors "github.com/tink3rlabs/magic/errors" slogger "github.com/tink3rlabs/magic/logger" - "github.com/tink3rlabs/magic/storage/search/lucene" ) @@ -250,7 +250,7 @@ func (s *SQLAdapter) executePaginatedQuery( nextCursor := "" if destSlice.Len() > limit { lastItem := destSlice.Index(limit - 1) - field := reflect.Indirect(lastItem).FieldByName(sortKey) + field := findFieldByJSONTag(reflect.Indirect(lastItem), sortKey) if field.IsValid() && field.Kind() == reflect.String { nextCursor = base64.StdEncoding.EncodeToString([]byte(field.String())) } @@ -262,6 +262,23 @@ func (s *SQLAdapter) executePaginatedQuery( return nextCursor, nil } +// findFieldByJSONTag looks up a struct field by its json tag name. +// This is needed because sortKey uses the JSON/column name (e.g. "id") +// while Go struct fields use PascalCase (e.g. "Id"). 
+func findFieldByJSONTag(v reflect.Value, tag string) reflect.Value { + t := v.Type() + for i := 0; i < t.NumField(); i++ { + jsonTag := t.Field(i).Tag.Get("json") + if idx := strings.Index(jsonTag, ","); idx != -1 { + jsonTag = jsonTag[:idx] + } + if jsonTag == tag { + return v.Field(i) + } + } + return reflect.Value{} +} + func (s *SQLAdapter) List(dest any, sortKey string, filter map[string]any, limit int, cursor string, params ...map[string]any) (string, error) { return s.executePaginatedQuery(dest, sortKey, limit, cursor, func(q *gorm.DB) *gorm.DB { if len(filter) > 0 { @@ -282,15 +299,20 @@ func (s *SQLAdapter) Search(dest any, sortKey string, query string, limit int, c destType := reflect.TypeOf(dest).Elem().Elem() model := reflect.New(destType).Elem().Interface() - parser, err := lucene.NewParserFromType(model) + parser, err := lucene.NewParser(model) if err != nil { slog.Error("Parser creation failed", "error", err) return "", err } - whereClause, queryParams, err := parser.ParseToSQL(query) + // Pass the SQL provider to generate provider-specific SQL syntax + whereClause, queryParams, err := parser.ParseToSQL(query, string(s.provider)) if err != nil { slog.Error("Filter parsing failed", "error", err) + // Wrap InvalidFieldError as BadRequest for proper HTTP 400 response + if _, ok := err.(*lucene.InvalidFieldError); ok { + return "", &serviceErrors.BadRequest{Message: err.Error()} + } return "", err } diff --git a/types/types.go b/types/types.go index c33e2c9..c7364e9 100644 --- a/types/types.go +++ b/types/types.go @@ -136,8 +136,13 @@ const definitions = ` "description": "A string containing a JSON Pointer value." } } + }, + "LuceneSearchQuery": { + "type": "string", + "description": "Lucene-style search query supporting field searches, wildcards, boolean operators, ranges, and more. 
Syntax: field:value, wildcards (*,?), operators (AND, OR, NOT, +, -), ranges ([min TO max]), quoted phrases, JSONB access (field.subfield:value), null checks (field:null), and fuzzy search (term~). Examples: name:john, name:john*, email:*@example.com, description:*important*, name:john* OR email:*@example.com, name:john AND status:active, status:active OR status:pending, name:john NOT status:inactive, +name:john +status:active, name:john -status:deleted, age:[25 TO 65], age:{25 TO 65}, age:[25 TO *], age:[* TO 65], created_at:[2024-01-01 TO 2024-12-31], description:\"hello world\", title:\"test-app (v1.0)\", name:C\\+\\+ OR path:\\/usr\\/bin, (name:john* OR email:*@example.com) AND status:active AND age:[25 TO 65], ((name:john OR name:jane) AND status:active) OR (status:pending AND age:[18 TO *]), searchterm, john*, labels.category:production, metadata.tags:prod*, name:john AND labels.env:prod AND metadata.team:engineering, parent_id:null, NOT deleted_at:null, name:john AND deleted_at:null, name:roam~, name:roam~2, labels.tag:prod~, +name:john* -status:deleted age:[25 TO 65] AND (role:admin OR role:moderator), name:john OR email:john@example.com OR phone:*555*, (name:*admin* OR role:administrator) AND status:active AND NOT deleted_at:null AND created_at:[2024-01-01 TO *]", + "example": "name:john AND status:active" } - } + } } } `