diff --git a/dotnet/src/VectorData/PgVector/PostgresCollection.cs b/dotnet/src/VectorData/PgVector/PostgresCollection.cs index a0c2273fabc8..f9061387378b 100644 --- a/dotnet/src/VectorData/PgVector/PostgresCollection.cs +++ b/dotnet/src/VectorData/PgVector/PostgresCollection.cs @@ -24,7 +24,7 @@ namespace Microsoft.SemanticKernel.Connectors.PgVector; /// The type of the key. /// The type of the record. #pragma warning disable CA1711 // Identifiers should not have incorrect suffix -public class PostgresCollection : VectorStoreCollection +public class PostgresCollection : VectorStoreCollection, IKeywordHybridSearchable #pragma warning restore CA1711 // Identifiers should not have incorrect suffix where TKey : notnull where TRecord : class @@ -52,6 +52,9 @@ public class PostgresCollection : VectorStoreCollectionThe default options for vector search. private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + /// The default options for hybrid search. + private static readonly HybridSearchOptions s_defaultHybridSearchOptions = new(); + /// /// Initializes a new instance of the class. /// @@ -396,43 +399,7 @@ public override async IAsyncEnumerable> SearchAsync< } var vectorProperty = this._model.GetVectorPropertyOrSingle(options); - - object vector = searchValue switch - { - // Dense float32 - ReadOnlyMemory r => r, - float[] f => new ReadOnlyMemory(f), - Embedding e => e.Vector, - _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator> generator - => await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false), - -#if NET - // Dense float16 - ReadOnlyMemory r => r, - Half[] f => new ReadOnlyMemory(f), - Embedding e => e.Vector, - _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator> generator - => await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false), -#endif - - // Dense Binary - BitArray b => b, - BinaryEmbedding e => e.Vector, - _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator generator - => await generator.GenerateAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false), - - // Sparse - SparseVector sv => sv, - // TODO: Add a PG-specific SparseVectorEmbedding type - - _ => vectorProperty.EmbeddingGenerator is null - ? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), PostgresModelBuilder.SupportedVectorTypes)) - : throw new InvalidOperationException(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType(typeof(TInput), vectorProperty.EmbeddingGenerator.GetType())) - }; - - var pgVector = PostgresPropertyMapping.MapVectorForStorageModel(vector); - - Verify.NotNull(pgVector); + var pgVector = await this.ConvertSearchInputToVectorAsync(searchValue, vectorProperty, cancellationToken).ConfigureAwait(false); // Simulating skip/offset logic locally, since OFFSET can work only with LIMIT in combination // and LIMIT is not supported in vector search extension, instead of LIMIT - "k" parameter is used. @@ -460,6 +427,51 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator + public async IAsyncEnumerable> HybridSearchAsync( + TInput searchValue, + ICollection keywords, + int top, + HybridSearchOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + Verify.NotNull(searchValue); + Verify.NotNull(keywords); + Verify.NotLessThan(top, 1); + + options ??= s_defaultHybridSearchOptions; + if (options.IncludeVectors && this._model.EmbeddingGenerationRequired) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + var vectorProperty = this._model.GetVectorPropertyOrSingle(new() { VectorProperty = options.VectorProperty }); + var textProperty = this._model.GetFullTextDataPropertyOrSingle(options.AdditionalProperty); + var pgVector = await this.ConvertSearchInputToVectorAsync(searchValue, vectorProperty, cancellationToken).ConfigureAwait(false); + + using var connection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + using var command = connection.CreateCommand(); + PostgresSqlBuilder.BuildHybridSearchCommand(command, this._schema, this.Name, this._model, vectorProperty, textProperty, pgVector, keywords, +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + options.OldFilter, +#pragma warning restore CS0618 // VectorSearchFilter is obsolete + options.Filter, options.Skip, options.IncludeVectors, top, options.ScoreThreshold); + + using var reader = await connection.ExecuteWithErrorHandlingAsync( + this._collectionMetadata, + "HybridSearch", + () => command.ExecuteReaderAsync(cancellationToken), + cancellationToken).ConfigureAwait(false); + + while (await reader.ReadWithErrorHandlingAsync(this._collectionMetadata, "HybridSearch", cancellationToken).ConfigureAwait(false)) + { + yield return new VectorSearchResult( + this._mapper.MapFromStorageToDataModel(reader, options.IncludeVectors), + reader.GetDouble(reader.GetOrdinal(PostgresConstants.DistanceColumnName))); + } + } + #endregion Search /// @@ -513,11 +525,11 @@ private async Task InternalCreateCollectionAsync(bool ifNotExists, CancellationT batch.BatchCommands.Add( new NpgsqlBatchCommand(PostgresSqlBuilder.BuildCreateTableSql(this._schema, this.Name, this._model, pgVersion, ifNotExists))); - foreach (var (column, kind, function, isVector) in PostgresPropertyMapping.GetIndexInfo(this._model.Properties)) + foreach (var (column, kind, function, isVector, isFullText, fullTextLanguage) in PostgresPropertyMapping.GetIndexInfo(this._model.Properties)) { batch.BatchCommands.Add( new NpgsqlBatchCommand( - PostgresSqlBuilder.BuildCreateIndexSql(this._schema, this.Name, column, kind, function, isVector, ifNotExists))); + PostgresSqlBuilder.BuildCreateIndexSql(this._schema, this.Name, column, kind, function, isVector, isFullText, fullTextLanguage, ifNotExists))); } await batch.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); @@ -535,4 +547,48 @@ private Task RunOperationAsync(string operationName, Func> operati this._collectionMetadata, operationName, operation); + + /// + /// Converts a search input value to a PostgreSQL vector representation, generating embeddings if necessary. + /// + private async Task ConvertSearchInputToVectorAsync(TInput searchValue, VectorPropertyModel vectorProperty, CancellationToken cancellationToken) + where TInput : notnull + { + object vector = searchValue switch + { + // Dense float32 + ReadOnlyMemory r => r, + float[] f => new ReadOnlyMemory(f), + Embedding e => e.Vector, + _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator> generator + => await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false), + +#if NET + // Dense float16 + ReadOnlyMemory r => r, + Half[] f => new ReadOnlyMemory(f), + Embedding e => e.Vector, + _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator> generator + => await generator.GenerateVectorAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false), +#endif + + // Dense Binary + BitArray b => b, + BinaryEmbedding e => e.Vector, + _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator generator + => await generator.GenerateAsync(searchValue, cancellationToken: cancellationToken).ConfigureAwait(false), + + // Sparse + SparseVector sv => sv, + // TODO: Add a PG-specific SparseVectorEmbedding type + + _ => vectorProperty.EmbeddingGenerator is null + ? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), PostgresModelBuilder.SupportedVectorTypes)) + : throw new InvalidOperationException(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType(typeof(TInput), vectorProperty.EmbeddingGenerator.GetType())) + }; + + var pgVector = PostgresPropertyMapping.MapVectorForStorageModel(vector); + Verify.NotNull(pgVector); + return pgVector; + } } diff --git a/dotnet/src/VectorData/PgVector/PostgresConstants.cs b/dotnet/src/VectorData/PgVector/PostgresConstants.cs index 75054a6a7560..631f92bf45e0 100644 --- a/dotnet/src/VectorData/PgVector/PostgresConstants.cs +++ b/dotnet/src/VectorData/PgVector/PostgresConstants.cs @@ -24,6 +24,9 @@ internal static class PostgresConstants /// The default distance function. public const string DefaultDistanceFunction = DistanceFunction.CosineDistance; + /// The default full-text search language for PostgreSQL. + public const string DefaultFullTextSearchLanguage = "english"; + public static readonly Dictionary IndexMaxDimensions = new() { { IndexKind.Hnsw, 2000 }, diff --git a/dotnet/src/VectorData/PgVector/PostgresPropertyExtensions.cs b/dotnet/src/VectorData/PgVector/PostgresPropertyExtensions.cs new file mode 100644 index 000000000000..e259aa181fa3 --- /dev/null +++ b/dotnet/src/VectorData/PgVector/PostgresPropertyExtensions.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ProviderServices; + +namespace Microsoft.SemanticKernel.Connectors.PgVector; + +/// +/// Extension methods for configuring PostgreSQL-specific properties on vector store property definitions. +/// +public static class PostgresPropertyExtensions +{ + private const string FullTextSearchLanguageKey = "Postgres:FullTextSearchLanguage"; + + /// + /// Sets the PostgreSQL full-text search language for a data property. + /// + /// The data property to configure. + /// The PostgreSQL text search language name (e.g., "english", "spanish", "german"). + /// The same property instance for method chaining. + /// + /// This language is used with PostgreSQL's to_tsvector and plainto_tsquery functions + /// when creating GIN indexes and performing full-text search operations. + /// Common language options include: "simple", "english", "spanish", "german", "french", etc. + /// See PostgreSQL documentation for the full list of available text search configurations. + /// + public static VectorStoreDataProperty WithFullTextSearchLanguage(this VectorStoreDataProperty property, string? language) + { + property.ProviderAnnotations ??= []; + property.ProviderAnnotations[FullTextSearchLanguageKey] = language; + return property; + } + + /// + /// Gets the PostgreSQL full-text search language configured for a data property. + /// + /// The data property to read from. + /// The configured language, or if not set. + public static string? GetFullTextSearchLanguage(this VectorStoreDataProperty property) + => property.ProviderAnnotations?.TryGetValue(FullTextSearchLanguageKey, out var value) == true + ? value as string + : null; + + /// + /// Gets the PostgreSQL full-text search language configured for a data property model. + /// + /// The data property model to read from. + /// The configured language, or the default language ("english") if not set. + internal static string GetFullTextSearchLanguageOrDefault(this DataPropertyModel property) + => property.ProviderAnnotations?.TryGetValue(FullTextSearchLanguageKey, out var value) == true && value is string language + ? language + : PostgresConstants.DefaultFullTextSearchLanguage; +} diff --git a/dotnet/src/VectorData/PgVector/PostgresPropertyMapping.cs b/dotnet/src/VectorData/PgVector/PostgresPropertyMapping.cs index ee6981fcfb6a..1515f332dd68 100644 --- a/dotnet/src/VectorData/PgVector/PostgresPropertyMapping.cs +++ b/dotnet/src/VectorData/PgVector/PostgresPropertyMapping.cs @@ -152,13 +152,13 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) /// Returns information about indexes to create, validating that the dimensions of the vector are supported. /// /// The properties of the vector store record. - /// A list of tuples containing the column name, index kind, and distance function for each property. + /// A list of tuples containing the column name, index kind, distance function, and full-text language for each property. /// /// The default index kind is "Flat", which prevents the creation of an index. /// - public static List<(string column, string kind, string function, bool isVector)> GetIndexInfo(IReadOnlyList properties) + public static List<(string column, string kind, string function, bool isVector, bool isFullText, string? fullTextLanguage)> GetIndexInfo(IReadOnlyList properties) { - var vectorIndexesToCreate = new List<(string column, string kind, string function, bool isVector)>(); + var vectorIndexesToCreate = new List<(string column, string kind, string function, bool isVector, bool isFullText, string? fullTextLanguage)>(); foreach (var property in properties) { switch (property) @@ -185,7 +185,7 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) ); } - vectorIndexesToCreate.Add((vectorProperty.StorageName, indexKind, distanceFunction, isVector: true)); + vectorIndexesToCreate.Add((vectorProperty.StorageName, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null)); } break; @@ -193,7 +193,13 @@ public static NpgsqlParameter GetNpgsqlParameter(object? value) case DataPropertyModel dataProperty: if (dataProperty.IsIndexed) { - vectorIndexesToCreate.Add((dataProperty.StorageName, "", "", isVector: false)); + vectorIndexesToCreate.Add((dataProperty.StorageName, kind: "", function: "", isVector: false, isFullText: false, fullTextLanguage: null)); + } + + if (dataProperty.IsFullTextIndexed) + { + var language = dataProperty.GetFullTextSearchLanguageOrDefault(); + vectorIndexesToCreate.Add((dataProperty.StorageName, kind: "", function: "", isVector: false, isFullText: true, fullTextLanguage: language)); } break; diff --git a/dotnet/src/VectorData/PgVector/PostgresServiceCollectionExtensions.cs b/dotnet/src/VectorData/PgVector/PostgresServiceCollectionExtensions.cs index 390e4beb7122..a248d4709c0b 100644 --- a/dotnet/src/VectorData/PgVector/PostgresServiceCollectionExtensions.cs +++ b/dotnet/src/VectorData/PgVector/PostgresServiceCollectionExtensions.cs @@ -286,8 +286,8 @@ private static void AddAbstractions(IServiceCollection services, services.Add(new ServiceDescriptor(typeof(IVectorSearchable), serviceKey, static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); - // Once HybridSearch supports get implemented by PostgresCollection, - // we need to add IKeywordHybridSearchable abstraction here as well. + services.Add(new ServiceDescriptor(typeof(IKeywordHybridSearchable), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); } private static PostgresVectorStoreOptions? GetStoreOptions(IServiceProvider sp, Func? optionsProvider) diff --git a/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs b/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs index 5340560f2d3d..2997a305d733 100644 --- a/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs +++ b/dotnet/src/VectorData/PgVector/PostgresSqlBuilder.cs @@ -119,7 +119,7 @@ internal static string BuildCreateTableSql(string schema, string tableName, Coll } /// - internal static string BuildCreateIndexSql(string schema, string tableName, string columnName, string indexKind, string distanceFunction, bool isVector, bool ifNotExists) + internal static string BuildCreateIndexSql(string schema, string tableName, string columnName, string indexKind, string distanceFunction, bool isVector, bool isFullText, string? fullTextLanguage, bool ifNotExists) { var indexName = $"{tableName}_{columnName}_index"; @@ -130,6 +130,15 @@ internal static string BuildCreateIndexSql(string schema, string tableName, stri sql.Append("IF NOT EXISTS "); } + if (isFullText) + { + // Create a GIN index for full-text search + var language = fullTextLanguage ?? PostgresConstants.DefaultFullTextSearchLanguage; + sql.AppendIdentifier(indexName).Append(" ON ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName) + .Append(" USING GIN (to_tsvector(").AppendLiteral(language).Append(", ").AppendIdentifier(columnName).Append("))"); + return sql.ToString(); + } + if (!isVector) { sql.AppendIdentifier(indexName).Append(" ON ").AppendIdentifier(schema).Append('.').AppendIdentifier(tableName) @@ -422,6 +431,61 @@ internal static void BuildDeleteBatchCommand(NpgsqlCommand command, string internal static StringBuilder AppendIdentifier(this StringBuilder sb, string identifier) => sb.Append('"').Append(identifier.Replace("\"", "\"\"")).Append('"'); + /// + /// Appends a properly quoted and escaped PostgreSQL string literal to the StringBuilder. + /// In PostgreSQL, string literals are quoted with single quotes, and embedded single quotes are escaped by doubling them. + /// + internal static StringBuilder AppendLiteral(this StringBuilder sb, string value) + => sb.Append('\'').Append(value.Replace("'", "''")).Append('\''); + + /// + /// Gets the PostgreSQL distance operator for the specified distance function. + /// + private static string GetDistanceOperator(string? distanceFunction) + => distanceFunction switch + { + DistanceFunction.EuclideanDistance or null => "<->", + DistanceFunction.CosineDistance or DistanceFunction.CosineSimilarity => "<=>", + DistanceFunction.ManhattanDistance => "<+>", + DistanceFunction.DotProductSimilarity => "<#>", + DistanceFunction.HammingDistance => "<~>", + _ => throw new NotSupportedException($"Distance function {distanceFunction} is not supported.") + }; + + /// + /// Generates filter clause from either legacy or new filter, returning condition and parameters. + /// +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + private static (string Clause, List Parameters) GenerateFilterClause( + CollectionModel model, + VectorSearchFilter? legacyFilter, + Expression>? newFilter, + int startParamIndex) + => (oldFilter: legacyFilter, newFilter) switch + { + (not null, not null) => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), + (not null, null) => GenerateLegacyFilterWhereClause(model, legacyFilter, startParamIndex), + (null, not null) => GenerateNewFilterWhereClause(model, newFilter, startParamIndex), + _ => (Clause: string.Empty, Parameters: new List()) + }; + + /// + /// Generates filter condition (without WHERE keyword) from either legacy or new filter. + /// + private static (string Condition, List Parameters) GenerateFilterCondition( + CollectionModel model, + VectorSearchFilter? legacyFilter, + Expression>? newFilter, + int startParamIndex) + => (oldFilter: legacyFilter, newFilter) switch + { + (not null, not null) => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), + (not null, null) => GenerateLegacyFilterCondition(model, legacyFilter, startParamIndex), + (null, not null) => GenerateNewFilterCondition(model, newFilter, startParamIndex), + _ => (Condition: string.Empty, Parameters: new List()) + }; +#pragma warning restore CS0618 // VectorSearchFilter is obsolete + #pragma warning disable CS0618 // VectorSearchFilter is obsolete /// internal static void BuildGetNearestMatchCommand( @@ -441,27 +505,10 @@ internal static void BuildGetNearestMatchCommand( } var distanceFunction = vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction; - var distanceOp = distanceFunction switch - { - DistanceFunction.EuclideanDistance or null => "<->", - DistanceFunction.CosineDistance or DistanceFunction.CosineSimilarity => "<=>", - DistanceFunction.ManhattanDistance => "<+>", - DistanceFunction.DotProductSimilarity => "<#>", - DistanceFunction.HammingDistance => "<~>", - - _ => throw new NotSupportedException($"Distance function {vectorProperty.DistanceFunction} is not supported.") - }; + var distanceOp = GetDistanceOperator(distanceFunction); // Start where clause params at 2, vector takes param 1. -#pragma warning disable CS0618 // VectorSearchFilter is obsolete - var (where, parameters) = (oldFilter: legacyFilter, newFilter) switch - { - (not null, not null) => throw new ArgumentException("Either Filter or OldFilter can be specified, but not both"), - (not null, null) => GenerateLegacyFilterWhereClause(model, legacyFilter, startParamIndex: 2), - (null, not null) => GenerateNewFilterWhereClause(model, newFilter, startParamIndex: 2), - _ => (Clause: string.Empty, Parameters: []) - }; -#pragma warning restore CS0618 // VectorSearchFilter is obsolete + var (where, parameters) = GenerateFilterClause(model, legacyFilter, newFilter, startParamIndex: 2); StringBuilder sql = new(); sql.Append("SELECT ").Append(columns).Append(", ").AppendIdentifier(vectorProperty.StorageName) @@ -604,16 +651,28 @@ internal static void BuildSelectWhereCommand( } internal static (string Clause, List Parameters) GenerateNewFilterWhereClause(CollectionModel model, LambdaExpression newFilter, int startParamIndex) + { + var (condition, parameters) = GenerateNewFilterCondition(model, newFilter, startParamIndex); + return (string.IsNullOrEmpty(condition) ? string.Empty : $"WHERE {condition}", parameters); + } + + internal static (string Condition, List Parameters) GenerateNewFilterCondition(CollectionModel model, LambdaExpression newFilter, int startParamIndex) { PostgresFilterTranslator translator = new(model, newFilter, startParamIndex); - translator.Translate(appendWhere: true); + translator.Translate(appendWhere: false); return (translator.Clause.ToString(), translator.ParameterValues); } #pragma warning disable CS0618 // VectorSearchFilter is obsolete internal static (string Clause, List Parameters) GenerateLegacyFilterWhereClause(CollectionModel model, VectorSearchFilter legacyFilter, int startParamIndex) { - StringBuilder whereClause = new("WHERE "); + var (condition, parameters) = GenerateLegacyFilterCondition(model, legacyFilter, startParamIndex); + return ($"WHERE {condition}", parameters); + } + + internal static (string Condition, List Parameters) GenerateLegacyFilterCondition(CollectionModel model, VectorSearchFilter legacyFilter, int startParamIndex) + { + StringBuilder condition = new(); var parameters = new List(); var paramIndex = startParamIndex; @@ -623,7 +682,7 @@ internal static (string Clause, List Parameters) GenerateLegacyFilterWhe { if (!first) { - whereClause.Append(" AND "); + condition.Append(" AND "); } first = false; @@ -632,7 +691,7 @@ internal static (string Clause, List Parameters) GenerateLegacyFilterWhe var property = model.Properties.FirstOrDefault(p => p.ModelName == equalTo.FieldName); if (property == null) { throw new ArgumentException($"Property {equalTo.FieldName} not found in record definition."); } - whereClause.AppendIdentifier(property.StorageName).Append(" = $").Append(paramIndex); + condition.AppendIdentifier(property.StorageName).Append(" = $").Append(paramIndex); parameters.Add(equalTo.Value); paramIndex++; } @@ -646,7 +705,7 @@ internal static (string Clause, List Parameters) GenerateLegacyFilterWhe throw new ArgumentException($"Property {anyTagEqualTo.FieldName} must be of type List to use AnyTagEqualTo filter."); } - whereClause.AppendIdentifier(property.StorageName).Append(" @> ARRAY[$").Append(paramIndex).Append("::TEXT]"); + condition.AppendIdentifier(property.StorageName).Append(" @> ARRAY[$").Append(paramIndex).Append("::TEXT]"); parameters.Add(anyTagEqualTo.Value); paramIndex++; } @@ -656,7 +715,144 @@ internal static (string Clause, List Parameters) GenerateLegacyFilterWhe } } - return (whereClause.ToString(), parameters); + return (condition.ToString(), parameters); + } +#pragma warning restore CS0618 // VectorSearchFilter is obsolete + +#pragma warning disable CS0618 // VectorSearchFilter is obsolete + /// + /// Builds a hybrid search command that combines vector similarity search with full-text keyword search using RRF (Reciprocal Rank Fusion). + /// + internal static void BuildHybridSearchCommand( + NpgsqlCommand command, string schema, string tableName, CollectionModel model, + VectorPropertyModel vectorProperty, DataPropertyModel textProperty, + object vectorValue, ICollection keywords, + VectorSearchFilter? legacyFilter, Expression>? newFilter, + int? skip, bool includeVectors, int top, double? scoreThreshold = null) + { + // RRF constant - higher values give more weight to lower-ranked results + const int RrfConstant = 60; + + // Build column list with proper escaping (for the final SELECT) + StringBuilder columns = new(); + for (var i = 0; i < model.Properties.Count; i++) + { + if (!includeVectors && model.Properties[i] is VectorPropertyModel) + { + continue; + } + + if (columns.Length > 0) + { + columns.Append(", "); + } + + columns.Append("t.").AppendIdentifier(model.Properties[i].StorageName); + } + + var distanceOp = GetDistanceOperator(vectorProperty.DistanceFunction ?? PostgresConstants.DefaultDistanceFunction); + + // Parameters: $1 = keywords, $2 = vector, $3 = RRF constant + // Additional parameters start at $4 for filters + var (filterCondition, filterParameters) = GenerateFilterCondition(model, legacyFilter, newFilter, startParamIndex: 4); + + // Build the full table name + var fullTableName = new StringBuilder() + .AppendIdentifier(schema).Append('.').AppendIdentifier(tableName) + .ToString(); + + // Use a larger internal limit for the CTEs to get better ranking, then limit final results + var internalLimit = (top + (skip ?? 0)) * 2; + if (internalLimit < 20) + { + internalLimit = 20; + } + + // Get the full-text search language from the text property + var language = textProperty.GetFullTextSearchLanguageOrDefault(); + + StringBuilder sql = new(); + sql.AppendLine() + .Append("WITH semantic_search AS (").AppendLine() + .Append(" SELECT ").AppendIdentifier(model.KeyProperty.StorageName).Append(", RANK () OVER (ORDER BY ").AppendIdentifier(vectorProperty.StorageName) + .Append(' ').Append(distanceOp).Append(" $2) AS rank").AppendLine() + .Append(" FROM ").Append(fullTableName); + + // Add filter for semantic search + if (!string.IsNullOrEmpty(filterCondition)) + { + sql.AppendLine().Append(" WHERE ").Append(filterCondition); + } + + sql.AppendLine() + .Append(" ORDER BY ").AppendIdentifier(vectorProperty.StorageName).Append(' ').Append(distanceOp).Append(" $2").AppendLine() + .Append(" LIMIT ").Append(internalLimit).AppendLine() + .Append("),").AppendLine() + .Append("keyword_search AS (").AppendLine() + .Append(" SELECT ").AppendIdentifier(model.KeyProperty.StorageName).Append(", RANK () OVER (ORDER BY ts_rank_cd(to_tsvector(").AppendLiteral(language).Append(", ").AppendIdentifier(textProperty.StorageName).Append("), query) DESC) AS rank").AppendLine() + .Append(" FROM ").Append(fullTableName).Append(", plainto_tsquery(").AppendLiteral(language).Append(", $1) query").AppendLine() + .Append(" WHERE to_tsvector(").AppendLiteral(language).Append(", ").AppendIdentifier(textProperty.StorageName).Append(") @@ query"); + + // Add filter for keyword search (using AND since we already have a WHERE clause) + if (!string.IsNullOrEmpty(filterCondition)) + { + sql.Append(" AND ").Append(filterCondition); + } + + sql.AppendLine() + .Append(" ORDER BY ts_rank_cd(to_tsvector(").AppendLiteral(language).Append(", ").AppendIdentifier(textProperty.StorageName).Append("), query) DESC").AppendLine() + .Append(" LIMIT ").Append(internalLimit).AppendLine() + .Append(')').AppendLine() + .Append("SELECT ").Append(columns).Append(',').AppendLine() + .Append(" COALESCE(1.0 / ($3 + semantic_search.rank), 0.0) +").AppendLine() + .Append(" COALESCE(1.0 / ($3 + keyword_search.rank), 0.0) AS ").AppendIdentifier(PostgresConstants.DistanceColumnName).AppendLine() + .Append("FROM semantic_search").AppendLine() + .Append("FULL OUTER JOIN keyword_search ON semantic_search.").AppendIdentifier(model.KeyProperty.StorageName).Append(" = keyword_search.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine() + .Append("JOIN ").Append(fullTableName).Append(" t ON t.").AppendIdentifier(model.KeyProperty.StorageName).Append(" = COALESCE(semantic_search.").AppendIdentifier(model.KeyProperty.StorageName).Append(", keyword_search.").AppendIdentifier(model.KeyProperty.StorageName).Append(')').AppendLine() + .Append("ORDER BY ").AppendIdentifier(PostgresConstants.DistanceColumnName).Append(" DESC"); + + var commandText = sql.ToString(); + + // Apply score threshold filter if specified (higher RRF scores = better match) + if (scoreThreshold.HasValue) + { + var scoreThresholdParamIndex = filterParameters.Count + 4; + StringBuilder outerSql = new(); + outerSql.Append("SELECT * FROM (").Append(commandText).Append(") AS scored WHERE ") + .AppendIdentifier(PostgresConstants.DistanceColumnName).Append(" >= $").Append(scoreThresholdParamIndex); + commandText = outerSql.ToString(); + } + + // Apply LIMIT and OFFSET + StringBuilder finalSql = new(); + finalSql.Append("SELECT * FROM (").Append(commandText).Append(") AS results LIMIT ").Append(top); + if (skip > 0) + { + finalSql.Append(" OFFSET ").Append(skip.Value); + } + + command.CommandText = finalSql.ToString(); + + Debug.Assert(command.Parameters.Count == 0); + + // $1 = keywords (joined as single string for plainto_tsquery) + command.Parameters.Add(new NpgsqlParameter { Value = string.Join(" ", keywords) }); + // $2 = vector + command.Parameters.Add(new NpgsqlParameter { Value = vectorValue }); + // $3 = RRF constant + command.Parameters.Add(new NpgsqlParameter { Value = RrfConstant }); + + // Filter parameters starting at $4 + foreach (var parameter in filterParameters) + { + command.Parameters.Add(new NpgsqlParameter { Value = parameter }); + } + + // Score threshold parameter + if (scoreThreshold.HasValue) + { + command.Parameters.Add(new NpgsqlParameter { Value = scoreThreshold.Value }); + } } #pragma warning restore CS0618 // VectorSearchFilter is obsolete } diff --git a/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/CollectionModelBuilder.cs b/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/CollectionModelBuilder.cs index 79af64cae8c5..de8456e70fb4 100644 --- a/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/CollectionModelBuilder.cs +++ b/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/CollectionModelBuilder.cs @@ -256,12 +256,8 @@ protected virtual void ProcessRecordDefinition(VectorStoreCollectionDefinition d { // Property wasn't found attribute-annotated on the CLR type, so we need to add it. - var propertyType = definitionProperty.Type; - if (propertyType is null) - { - throw new InvalidOperationException(VectorDataStrings.MissingTypeOnPropertyDefinition(definitionProperty)); - } - + var propertyType = definitionProperty.Type + ?? throw new InvalidOperationException(VectorDataStrings.MissingTypeOnPropertyDefinition(definitionProperty)); switch (definitionProperty) { case VectorStoreKeyProperty definitionKeyProperty: @@ -289,6 +285,12 @@ protected virtual void ProcessRecordDefinition(VectorStoreCollectionDefinition d this.SetPropertyStorageName(property, definitionProperty.StorageName, type); + // Copy provider-specific properties if present + if (definitionProperty.ProviderAnnotations is not null) + { + property.ProviderAnnotations = new Dictionary(definitionProperty.ProviderAnnotations); + } + switch (definitionProperty) { case VectorStoreKeyProperty definitionKeyProperty: diff --git a/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/PropertyModel.cs b/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/PropertyModel.cs index 4e728d02ad14..668ba50d7f47 100644 --- a/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/PropertyModel.cs +++ b/dotnet/src/VectorData/VectorData.Abstractions/ProviderServices/PropertyModel.cs @@ -54,6 +54,14 @@ public string StorageName /// public PropertyInfo? PropertyInfo { get; set; } + /// + /// Gets or sets a dictionary of provider-specific annotations for this property. + /// + /// + /// This allows setting database-specific configuration options that aren't universal across all vector stores. + /// + public Dictionary? ProviderAnnotations { get; set; } + /// /// Reads the property from the given , returning the value as an . /// diff --git a/dotnet/src/VectorData/VectorData.Abstractions/RecordDefinition/VectorStoreProperty.cs b/dotnet/src/VectorData/VectorData.Abstractions/RecordDefinition/VectorStoreProperty.cs index aa82d9f58605..d3f2b8faced9 100644 --- a/dotnet/src/VectorData/VectorData.Abstractions/RecordDefinition/VectorStoreProperty.cs +++ b/dotnet/src/VectorData/VectorData.Abstractions/RecordDefinition/VectorStoreProperty.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; namespace Microsoft.Extensions.VectorData; @@ -33,6 +34,9 @@ private protected VectorStoreProperty(VectorStoreProperty source) this.Name = source.Name; this.StorageName = source.StorageName; this.Type = source.Type; + this.ProviderAnnotations = source.ProviderAnnotations is not null + ? new Dictionary(source.ProviderAnnotations) + : null; } /// @@ -55,4 +59,13 @@ private protected VectorStoreProperty(VectorStoreProperty source) /// Gets or sets the type of the property. /// public Type? Type { get; set; } + + /// + /// Gets or sets a dictionary of provider-specific annotations for this property. + /// + /// + /// This allows setting database-specific configuration options that aren't universal across all vector stores. + /// Use provider-specific extension methods to set and get values in a strongly-typed manner. + /// + public Dictionary? ProviderAnnotations { get; set; } } diff --git a/dotnet/test/VectorData/PgVector.ConformanceTests/PostgresHybridSearchTests.cs b/dotnet/test/VectorData/PgVector.ConformanceTests/PostgresHybridSearchTests.cs new file mode 100644 index 000000000000..34d5fa6887a9 --- /dev/null +++ b/dotnet/test/VectorData/PgVector.ConformanceTests/PostgresHybridSearchTests.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. + +using PgVector.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace PgVector.ConformanceTests; + +public class PostgresHybridSearchTests( + PostgresHybridSearchTests.VectorAndStringFixture vectorAndStringFixture, + PostgresHybridSearchTests.MultiTextFixture multiTextFixture) + : HybridSearchTests(vectorAndStringFixture, multiTextFixture), + IClassFixture, + IClassFixture +{ + public new class VectorAndStringFixture : HybridSearchTests.VectorAndStringFixture + { + public override TestStore TestStore => PostgresTestStore.Instance; + } + + public new class MultiTextFixture : HybridSearchTests.MultiTextFixture + { + public override TestStore TestStore => PostgresTestStore.Instance; + } +} diff --git a/dotnet/test/VectorData/PgVector.UnitTests/PostgresPropertyMappingTests.cs b/dotnet/test/VectorData/PgVector.UnitTests/PostgresPropertyMappingTests.cs index 09b6541842b3..61547fbb4e7d 100644 --- a/dotnet/test/VectorData/PgVector.UnitTests/PostgresPropertyMappingTests.cs +++ b/dotnet/test/VectorData/PgVector.UnitTests/PostgresPropertyMappingTests.cs @@ -104,23 +104,26 @@ public void GetIndexInfoReturnsCorrectValues() // Assert Assert.Equal(3, indexInfo.Count); - foreach (var (columnName, indexKind, distanceFunction, isVector) in indexInfo) + foreach (var (columnName, indexKind, distanceFunction, isVector, isFullText, fullTextLanguage) in indexInfo) { if (columnName == "vector1") { Assert.True(isVector); + Assert.False(isFullText); Assert.Equal(IndexKind.Hnsw, indexKind); Assert.Equal(DistanceFunction.CosineDistance, distanceFunction); } else if (columnName == "vector3") { Assert.True(isVector); + Assert.False(isFullText); Assert.Equal(IndexKind.Hnsw, indexKind); Assert.Equal(DistanceFunction.ManhattanDistance, distanceFunction); } else if (columnName == "data1") { Assert.False(isVector); + Assert.False(isFullText); } else { diff --git a/dotnet/test/VectorData/PgVector.UnitTests/PostgresSqlBuilderTests.cs b/dotnet/test/VectorData/PgVector.UnitTests/PostgresSqlBuilderTests.cs index 43ae2533173c..95d401b3c360 100644 --- a/dotnet/test/VectorData/PgVector.UnitTests/PostgresSqlBuilderTests.cs +++ b/dotnet/test/VectorData/PgVector.UnitTests/PostgresSqlBuilderTests.cs @@ -88,12 +88,12 @@ public void TestBuildCreateIndexCommand(string indexKind, string distanceFunctio if (indexKind != IndexKind.Hnsw) { - Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexSql("public", "testcollection", vectorColumn, indexKind, distanceFunction, true, ifNotExists)); - Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexSql("public", "testcollection", vectorColumn, indexKind, distanceFunction, true, ifNotExists)); + Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexSql("public", "testcollection", vectorColumn, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists)); + Assert.Throws(() => PostgresSqlBuilder.BuildCreateIndexSql("public", "testcollection", vectorColumn, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists)); return; } - var sql = PostgresSqlBuilder.BuildCreateIndexSql("public", "1testcollection", vectorColumn, indexKind, distanceFunction, true, ifNotExists); + var sql = PostgresSqlBuilder.BuildCreateIndexSql("public", "1testcollection", vectorColumn, indexKind, distanceFunction, isVector: true, isFullText: false, fullTextLanguage: null, ifNotExists); // Check for expected properties; integration tests will validate the actual SQL. Assert.Contains("CREATE INDEX ", sql); @@ -136,7 +136,7 @@ public void TestBuildCreateIndexCommand(string indexKind, string distanceFunctio [InlineData(false)] public void TestBuildCreateNonVectorIndexCommand(bool ifNotExists) { - var sql = PostgresSqlBuilder.BuildCreateIndexSql("schema", "tableName", "columnName", indexKind: "", distanceFunction: "", isVector: false, ifNotExists); + var sql = PostgresSqlBuilder.BuildCreateIndexSql("schema", "tableName", "columnName", indexKind: "", distanceFunction: "", isVector: false, isFullText: false, fullTextLanguage: null, ifNotExists); var expectedCommandText = ifNotExists ? "CREATE INDEX IF NOT EXISTS \"tableName_columnName_index\" ON \"schema\".\"tableName\" (\"columnName\")" @@ -145,6 +145,30 @@ public void TestBuildCreateNonVectorIndexCommand(bool ifNotExists) Assert.Equal(expectedCommandText, sql); } + [Theory] + [InlineData(null, "english")] // Default language + [InlineData("spanish", "spanish")] + [InlineData("german", "german")] + public void TestBuildCreateFullTextIndexCommand(string? configuredLanguage, string expectedLanguage) + { + var sql = PostgresSqlBuilder.BuildCreateIndexSql("schema", "tableName", "content", indexKind: "", distanceFunction: "", isVector: false, isFullText: true, fullTextLanguage: configuredLanguage, ifNotExists: true); + + var expectedCommandText = $"CREATE INDEX IF NOT EXISTS \"tableName_content_index\" ON \"schema\".\"tableName\" USING GIN (to_tsvector('{expectedLanguage}', \"content\"))"; + + Assert.Equal(expectedCommandText, sql); + } + + [Fact] + public void TestBuildCreateFullTextIndexCommand_EscapesSingleQuotes() + { + // Verify that single quotes in the language name are properly escaped to prevent SQL injection + var sql = PostgresSqlBuilder.BuildCreateIndexSql("schema", "tableName", "content", indexKind: "", distanceFunction: "", isVector: false, isFullText: true, fullTextLanguage: "test'injection", ifNotExists: true); + + var expectedCommandText = "CREATE INDEX IF NOT EXISTS \"tableName_content_index\" ON \"schema\".\"tableName\" USING GIN (to_tsvector('test''injection', \"content\"))"; + + Assert.Equal(expectedCommandText, sql); + } + [Fact] public void TestBuildDropTableCommand() {