diff --git a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoModelBuilder.cs b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoModelBuilder.cs index 245d33986e25..43ca3cfdb1d8 100644 --- a/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoModelBuilder.cs +++ b/dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoModelBuilder.cs @@ -7,6 +7,7 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.VectorData; using Microsoft.Extensions.VectorData.ProviderServices; +using MongoDB.Bson; using MongoDB.Bson.Serialization.Attributes; namespace Microsoft.SemanticKernel.Connectors.MongoDB; @@ -43,9 +44,9 @@ protected override void ProcessTypeProperties(Type type, VectorStoreCollectionDe protected override bool IsKeyPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) { - supportedTypes = "string"; + supportedTypes = "string, Guid, ObjectId"; - return type == typeof(string); + return type == typeof(string) || type == typeof(Guid) || type == typeof(ObjectId); } protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) diff --git a/dotnet/src/VectorData/MongoDB/MongoCollection.cs b/dotnet/src/VectorData/MongoDB/MongoCollection.cs index 59f6428a0815..6937533a0899 100644 --- a/dotnet/src/VectorData/MongoDB/MongoCollection.cs +++ b/dotnet/src/VectorData/MongoDB/MongoCollection.cs @@ -75,6 +75,9 @@ public class MongoCollection : VectorStoreCollectionNumber of nearest neighbors to use during the vector search. private readonly int? _numCandidates; + /// Types of keys permitted. + private readonly Type[] _validKeyTypes = [typeof(string), typeof(Guid), typeof(ObjectId)]; + /// /// Initializes a new instance of the class. /// @@ -103,9 +106,9 @@ internal MongoCollection(IMongoDatabase mongoDatabase, string name, Func public override async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) { - var stringKey = this.GetStringKey(key); + Verify.NotNull(key); - await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(stringKey), cancellationToken)) + await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(key), cancellationToken)) .ConfigureAwait(false); } @@ -181,7 +184,7 @@ public override Task EnsureCollectionDeletedAsync(CancellationToken cancellation /// public override async Task GetAsync(TKey key, RecordRetrievalOptions? options = null, CancellationToken cancellationToken = default) { - var stringKey = this.GetStringKey(key); + Verify.NotNull(key); var includeVectors = options?.IncludeVectors ?? false; if (includeVectors && this._model.EmbeddingGenerationRequired) @@ -190,7 +193,7 @@ public override Task EnsureCollectionDeletedAsync(CancellationToken cancellation } using var cursor = await this - .FindAsync(this.GetFilterById(stringKey), top: 1, skip: null, includeVectors, sortDefinition: null, cancellationToken) + .FindAsync(this.GetFilterById(key), top: 1, skip: null, includeVectors, sortDefinition: null, cancellationToken) .ConfigureAwait(false); var record = await cursor.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false); @@ -267,7 +270,7 @@ private async Task UpsertCoreAsync(TRecord record, int recordIndex, IReadOnlyLis var replaceOptions = new ReplaceOptions { IsUpsert = true }; var storageModel = this._mapper.MapFromDataToStorageModel(record, recordIndex, generatedEmbeddings); - var key = storageModel[MongoConstants.MongoReservedKeyPropertyName].AsString; + var key = storageModel[MongoConstants.MongoReservedKeyPropertyName]; await this.RunOperationAsync(OperationName, async () => await this._mongoCollection @@ -673,11 +676,11 @@ private async IAsyncEnumerable> EnumerateAndMapSearc } } - private FilterDefinition GetFilterById(string id) + private FilterDefinition GetFilterById(object id) => Builders.Filter.Eq(document => document[MongoConstants.MongoReservedKeyPropertyName], id); - private FilterDefinition GetFilterByIds(IEnumerable ids) - => Builders.Filter.In(document => document[MongoConstants.MongoReservedKeyPropertyName].AsString, ids); + private FilterDefinition GetFilterByIds(IEnumerable ids) + => Builders.Filter.In(document => document[MongoConstants.MongoReservedKeyPropertyName], ids); private async Task InternalCollectionExistsAsync(CancellationToken cancellationToken) { @@ -723,16 +726,5 @@ private async Task RunOperationWithRetryAsync( operation, cancellationToken).ConfigureAwait(false); - private string GetStringKey(TKey key) - { - Verify.NotNull(key); - - var stringKey = key as string ?? throw new UnreachableException("string key should have been validated during model building"); - - Verify.NotNullOrWhiteSpace(stringKey, nameof(key)); - - return stringKey; - } - #endregion } diff --git a/dotnet/src/VectorData/MongoDB/MongoVectorStore.cs b/dotnet/src/VectorData/MongoDB/MongoVectorStore.cs index a276b25a5108..008e0d77819d 100644 --- a/dotnet/src/VectorData/MongoDB/MongoVectorStore.cs +++ b/dotnet/src/VectorData/MongoDB/MongoVectorStore.cs @@ -54,7 +54,7 @@ public MongoVectorStore(IMongoDatabase mongoDatabase, MongoVectorStoreOptions? o #pragma warning disable IDE0090 // Use 'new(...)' /// [RequiresDynamicCode("This overload of GetCollection() is incompatible with NativeAOT. For dynamic mapping via Dictionary, call GetDynamicCollection() instead.")] - [RequiresUnreferencedCode("This overload of GetCollecttion() is incompatible with trimming. For dynamic mapping via Dictionary, call GetDynamicCollection() instead.")] + [RequiresUnreferencedCode("This overload of GetCollection() is incompatible with trimming. For dynamic mapping via Dictionary, call GetDynamicCollection() instead.")] #if NET8_0_OR_GREATER public override MongoCollection GetCollection(string name, VectorStoreCollectionDefinition? definition = null) #else diff --git a/dotnet/test/VectorData/MongoDB.ConformanceTests/CRUD/MongoBatchConformanceTests.cs b/dotnet/test/VectorData/MongoDB.ConformanceTests/CRUD/MongoBatchConformanceTests.cs index 95cad010e6e4..307a79479f6a 100644 --- a/dotnet/test/VectorData/MongoDB.ConformanceTests/CRUD/MongoBatchConformanceTests.cs +++ b/dotnet/test/VectorData/MongoDB.ConformanceTests/CRUD/MongoBatchConformanceTests.cs @@ -1,12 +1,23 @@ // Copyright (c) Microsoft. All rights reserved. +using MongoDB.Bson; using MongoDB.ConformanceTests.Support; using VectorData.ConformanceTests.CRUD; using Xunit; namespace MongoDB.ConformanceTests.CRUD; -public class MongoBatchConformanceTests(MongoSimpleModelFixture fixture) - : BatchConformanceTests(fixture), IClassFixture +public class MongoBatchConformanceTests_String(MongoSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture> +{ +} + +public class MongoBatchConformanceTests_Guid(MongoSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture> +{ +} + +public class MongoBatchConformanceTests_ObjectId(MongoSimpleModelFixture fixture) + : BatchConformanceTests(fixture), IClassFixture> { } diff --git a/dotnet/test/VectorData/MongoDB.ConformanceTests/CRUD/MongoRecordConformanceTests.cs b/dotnet/test/VectorData/MongoDB.ConformanceTests/CRUD/MongoRecordConformanceTests.cs index 393e1ac69903..171676cceb46 100644 --- a/dotnet/test/VectorData/MongoDB.ConformanceTests/CRUD/MongoRecordConformanceTests.cs +++ b/dotnet/test/VectorData/MongoDB.ConformanceTests/CRUD/MongoRecordConformanceTests.cs @@ -1,12 +1,23 @@ // Copyright (c) Microsoft. All rights reserved. +using MongoDB.Bson; using MongoDB.ConformanceTests.Support; using VectorData.ConformanceTests.CRUD; using Xunit; namespace MongoDB.ConformanceTests.CRUD; -public class MongoRecordConformanceTests(MongoSimpleModelFixture fixture) - : RecordConformanceTests(fixture), IClassFixture +public class MongoRecordConformanceTests_String(MongoSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture> +{ +} + +public class MongoRecordConformanceTests_Guid(MongoSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture> +{ +} + +public class MongoRecordConformanceTests_ObjectId(MongoSimpleModelFixture fixture) + : RecordConformanceTests(fixture), IClassFixture> { } diff --git a/dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoSimpleModelFixture.cs b/dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoSimpleModelFixture.cs index 82da1a539532..ec264726f40c 100644 --- a/dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoSimpleModelFixture.cs +++ b/dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoSimpleModelFixture.cs @@ -4,7 +4,8 @@ namespace MongoDB.ConformanceTests.Support; -public class MongoSimpleModelFixture : SimpleModelFixture +public class MongoSimpleModelFixture : SimpleModelFixture + where TKey : notnull { public override TestStore TestStore => MongoTestStore.Instance; } diff --git a/dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoTestStore.cs b/dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoTestStore.cs index 15bea231a726..0696f4b2728f 100644 --- a/dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoTestStore.cs +++ b/dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoTestStore.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using Microsoft.SemanticKernel.Connectors.MongoDB; +using MongoDB.Bson; using MongoDB.Driver; using Testcontainers.MongoDb; using VectorData.ConformanceTests.Support; @@ -58,6 +59,18 @@ private async Task StartMongoDbContainerAsync() }; } + private static readonly string? s_baseObjectId = ObjectId.GenerateNewId().ToString().Substring(0, 14); + + public override TKey GenerateKey(int value) + { + if (typeof(TKey) == typeof(ObjectId)) + { + return (TKey)(object)ObjectId.Parse(s_baseObjectId + value.ToString("0000000000")); + } + + return base.GenerateKey(value); + } + protected override async Task StopAsync() { if (this._container != null) diff --git a/dotnet/test/VectorData/MongoDB.UnitTests/MongoVectorStoreTests.cs b/dotnet/test/VectorData/MongoDB.UnitTests/MongoVectorStoreTests.cs index 69ab57ef6bd2..ee2754342390 100644 --- a/dotnet/test/VectorData/MongoDB.UnitTests/MongoVectorStoreTests.cs +++ b/dotnet/test/VectorData/MongoDB.UnitTests/MongoVectorStoreTests.cs @@ -26,7 +26,7 @@ public void GetCollectionWithNotSupportedKeyThrowsException() using var sut = new MongoVectorStore(this._mockMongoDatabase.Object); // Act & Assert - Assert.Throws(() => sut.GetCollection("collection")); + Assert.Throws(() => sut.GetCollection("collection")); } [Fact]