Skip to content

Commit 8e343fb

Browse files
damiengSergeyMenshykhwestey-m
authored
.Net: Add support for Guid and ObjectID keys in the MongoDB Connector (#12827)
### Motivation and Context The MongoDB connector currently only supports string key types. ### Description Adds support for Guid and ObjectId key types to the MongoDB Connector. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄 --------- Co-authored-by: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Co-authored-by: westey <164392973+westey-m@users.noreply.github.com>
1 parent 3147b45 commit 8e343fb

8 files changed

Lines changed: 59 additions & 30 deletions

File tree

dotnet/src/InternalUtilities/connectors/Memory/MongoDB/MongoModelBuilder.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.Extensions.AI;
88
using Microsoft.Extensions.VectorData;
99
using Microsoft.Extensions.VectorData.ProviderServices;
10+
using MongoDB.Bson;
1011
using MongoDB.Bson.Serialization.Attributes;
1112

1213
namespace Microsoft.SemanticKernel.Connectors.MongoDB;
@@ -43,9 +44,9 @@ protected override void ProcessTypeProperties(Type type, VectorStoreCollectionDe
4344

4445
protected override bool IsKeyPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes)
4546
{
46-
supportedTypes = "string";
47+
supportedTypes = "string, Guid, ObjectId";
4748

48-
return type == typeof(string);
49+
return type == typeof(string) || type == typeof(Guid) || type == typeof(ObjectId);
4950
}
5051

5152
protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes)

dotnet/src/VectorData/MongoDB/MongoCollection.cs

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ public class MongoCollection<TKey, TRecord> : VectorStoreCollection<TKey, TRecor
7575
/// <summary>Number of nearest neighbors to use during the vector search.</summary>
7676
private readonly int? _numCandidates;
7777

78+
/// <summary>Types of keys permitted.</summary>
79+
private readonly Type[] _validKeyTypes = [typeof(string), typeof(Guid), typeof(ObjectId)];
80+
7881
/// <summary>
7982
/// Initializes a new instance of the <see cref="MongoCollection{TKey, TRecord}"/> class.
8083
/// </summary>
@@ -103,9 +106,9 @@ internal MongoCollection(IMongoDatabase mongoDatabase, string name, Func<MongoCo
103106
Verify.NotNull(mongoDatabase);
104107
Verify.NotNullOrWhiteSpace(name);
105108

106-
if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(object))
109+
if (!this._validKeyTypes.Contains(typeof(TKey)) && typeof(TKey) != typeof(object))
107110
{
108-
throw new NotSupportedException("Only string keys are supported.");
111+
throw new NotSupportedException("Only string, Guid and ObjectID keys are supported.");
109112
}
110113

111114
options ??= MongoCollectionOptions.Default;
@@ -157,9 +160,9 @@ await this.RunOperationWithRetryAsync(
157160
/// <inheritdoc />
158161
public override async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default)
159162
{
160-
var stringKey = this.GetStringKey(key);
163+
Verify.NotNull(key);
161164

162-
await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(stringKey), cancellationToken))
165+
await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(key), cancellationToken))
163166
.ConfigureAwait(false);
164167
}
165168

@@ -181,7 +184,7 @@ public override Task EnsureCollectionDeletedAsync(CancellationToken cancellation
181184
/// <inheritdoc />
182185
public override async Task<TRecord?> GetAsync(TKey key, RecordRetrievalOptions? options = null, CancellationToken cancellationToken = default)
183186
{
184-
var stringKey = this.GetStringKey(key);
187+
Verify.NotNull(key);
185188

186189
var includeVectors = options?.IncludeVectors ?? false;
187190
if (includeVectors && this._model.EmbeddingGenerationRequired)
@@ -190,7 +193,7 @@ public override Task EnsureCollectionDeletedAsync(CancellationToken cancellation
190193
}
191194

192195
using var cursor = await this
193-
.FindAsync(this.GetFilterById(stringKey), top: 1, skip: null, includeVectors, sortDefinition: null, cancellationToken)
196+
.FindAsync(this.GetFilterById(key), top: 1, skip: null, includeVectors, sortDefinition: null, cancellationToken)
194197
.ConfigureAwait(false);
195198

196199
var record = await cursor.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false);
@@ -267,7 +270,7 @@ private async Task UpsertCoreAsync(TRecord record, int recordIndex, IReadOnlyLis
267270
var replaceOptions = new ReplaceOptions { IsUpsert = true };
268271
var storageModel = this._mapper.MapFromDataToStorageModel(record, recordIndex, generatedEmbeddings);
269272

270-
var key = storageModel[MongoConstants.MongoReservedKeyPropertyName].AsString;
273+
var key = storageModel[MongoConstants.MongoReservedKeyPropertyName];
271274

272275
await this.RunOperationAsync(OperationName, async () =>
273276
await this._mongoCollection
@@ -673,11 +676,11 @@ private async IAsyncEnumerable<VectorSearchResult<TRecord>> EnumerateAndMapSearc
673676
}
674677
}
675678

676-
private FilterDefinition<BsonDocument> GetFilterById(string id)
679+
private FilterDefinition<BsonDocument> GetFilterById(object id)
677680
=> Builders<BsonDocument>.Filter.Eq(document => document[MongoConstants.MongoReservedKeyPropertyName], id);
678681

679-
private FilterDefinition<BsonDocument> GetFilterByIds(IEnumerable<string> ids)
680-
=> Builders<BsonDocument>.Filter.In(document => document[MongoConstants.MongoReservedKeyPropertyName].AsString, ids);
682+
private FilterDefinition<BsonDocument> GetFilterByIds(IEnumerable<object> ids)
683+
=> Builders<BsonDocument>.Filter.In(document => document[MongoConstants.MongoReservedKeyPropertyName], ids);
681684

682685
private async Task<bool> InternalCollectionExistsAsync(CancellationToken cancellationToken)
683686
{
@@ -723,16 +726,5 @@ private async Task<T> RunOperationWithRetryAsync<T>(
723726
operation,
724727
cancellationToken).ConfigureAwait(false);
725728

726-
private string GetStringKey(TKey key)
727-
{
728-
Verify.NotNull(key);
729-
730-
var stringKey = key as string ?? throw new UnreachableException("string key should have been validated during model building");
731-
732-
Verify.NotNullOrWhiteSpace(stringKey, nameof(key));
733-
734-
return stringKey;
735-
}
736-
737729
#endregion
738730
}

dotnet/src/VectorData/MongoDB/MongoVectorStore.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public MongoVectorStore(IMongoDatabase mongoDatabase, MongoVectorStoreOptions? o
5454
#pragma warning disable IDE0090 // Use 'new(...)'
5555
/// <inheritdoc />
5656
[RequiresDynamicCode("This overload of GetCollection() is incompatible with NativeAOT. For dynamic mapping via Dictionary<string, object?>, call GetDynamicCollection() instead.")]
57-
[RequiresUnreferencedCode("This overload of GetCollecttion() is incompatible with trimming. For dynamic mapping via Dictionary<string, object?>, call GetDynamicCollection() instead.")]
57+
[RequiresUnreferencedCode("This overload of GetCollection() is incompatible with trimming. For dynamic mapping via Dictionary<string, object?>, call GetDynamicCollection() instead.")]
5858
#if NET8_0_OR_GREATER
5959
public override MongoCollection<TKey, TRecord> GetCollection<TKey, TRecord>(string name, VectorStoreCollectionDefinition? definition = null)
6060
#else
Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using MongoDB.Bson;
34
using MongoDB.ConformanceTests.Support;
45
using VectorData.ConformanceTests.CRUD;
56
using Xunit;
67

78
namespace MongoDB.ConformanceTests.CRUD;
89

9-
public class MongoBatchConformanceTests(MongoSimpleModelFixture fixture)
10-
: BatchConformanceTests<string>(fixture), IClassFixture<MongoSimpleModelFixture>
10+
public class MongoBatchConformanceTests_String(MongoSimpleModelFixture<string> fixture)
11+
: BatchConformanceTests<string>(fixture), IClassFixture<MongoSimpleModelFixture<string>>
12+
{
13+
}
14+
15+
public class MongoBatchConformanceTests_Guid(MongoSimpleModelFixture<Guid> fixture)
16+
: BatchConformanceTests<Guid>(fixture), IClassFixture<MongoSimpleModelFixture<Guid>>
17+
{
18+
}
19+
20+
public class MongoBatchConformanceTests_ObjectId(MongoSimpleModelFixture<ObjectId> fixture)
21+
: BatchConformanceTests<ObjectId>(fixture), IClassFixture<MongoSimpleModelFixture<ObjectId>>
1122
{
1223
}
Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using MongoDB.Bson;
34
using MongoDB.ConformanceTests.Support;
45
using VectorData.ConformanceTests.CRUD;
56
using Xunit;
67

78
namespace MongoDB.ConformanceTests.CRUD;
89

9-
public class MongoRecordConformanceTests(MongoSimpleModelFixture fixture)
10-
: RecordConformanceTests<string>(fixture), IClassFixture<MongoSimpleModelFixture>
10+
public class MongoRecordConformanceTests_String(MongoSimpleModelFixture<string> fixture)
11+
: RecordConformanceTests<string>(fixture), IClassFixture<MongoSimpleModelFixture<string>>
12+
{
13+
}
14+
15+
public class MongoRecordConformanceTests_Guid(MongoSimpleModelFixture<Guid> fixture)
16+
: RecordConformanceTests<Guid>(fixture), IClassFixture<MongoSimpleModelFixture<Guid>>
17+
{
18+
}
19+
20+
public class MongoRecordConformanceTests_ObjectId(MongoSimpleModelFixture<ObjectId> fixture)
21+
: RecordConformanceTests<ObjectId>(fixture), IClassFixture<MongoSimpleModelFixture<ObjectId>>
1122
{
1223
}

dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoSimpleModelFixture.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
namespace MongoDB.ConformanceTests.Support;
66

7-
public class MongoSimpleModelFixture : SimpleModelFixture<string>
7+
public class MongoSimpleModelFixture<TKey> : SimpleModelFixture<TKey>
8+
where TKey : notnull
89
{
910
public override TestStore TestStore => MongoTestStore.Instance;
1011
}

dotnet/test/VectorData/MongoDB.ConformanceTests/Support/MongoTestStore.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using Microsoft.SemanticKernel.Connectors.MongoDB;
4+
using MongoDB.Bson;
45
using MongoDB.Driver;
56
using Testcontainers.MongoDb;
67
using VectorData.ConformanceTests.Support;
@@ -58,6 +59,18 @@ private async Task<MongoClientSettings> StartMongoDbContainerAsync()
5859
};
5960
}
6061

62+
private static readonly string? s_baseObjectId = ObjectId.GenerateNewId().ToString().Substring(0, 14);
63+
64+
public override TKey GenerateKey<TKey>(int value)
65+
{
66+
if (typeof(TKey) == typeof(ObjectId))
67+
{
68+
return (TKey)(object)ObjectId.Parse(s_baseObjectId + value.ToString("0000000000"));
69+
}
70+
71+
return base.GenerateKey<TKey>(value);
72+
}
73+
6174
protected override async Task StopAsync()
6275
{
6376
if (this._container != null)

dotnet/test/VectorData/MongoDB.UnitTests/MongoVectorStoreTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public void GetCollectionWithNotSupportedKeyThrowsException()
2626
using var sut = new MongoVectorStore(this._mockMongoDatabase.Object);
2727

2828
// Act & Assert
29-
Assert.Throws<NotSupportedException>(() => sut.GetCollection<Guid, MongoHotelModel>("collection"));
29+
Assert.Throws<NotSupportedException>(() => sut.GetCollection<long, MongoHotelModel>("collection"));
3030
}
3131

3232
[Fact]

0 commit comments

Comments
 (0)