Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.Extensions.AI;
using Microsoft.Extensions.VectorData;
using Microsoft.Extensions.VectorData.ProviderServices;
using MongoDB.Bson;
using MongoDB.Bson.Serialization.Attributes;

namespace Microsoft.SemanticKernel.Connectors.MongoDB;
Expand Down Expand Up @@ -43,9 +44,9 @@ protected override void ProcessTypeProperties(Type type, VectorStoreCollectionDe

protected override bool IsKeyPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes)
{
supportedTypes = "string";
supportedTypes = "string, Guid, ObjectId";

return type == typeof(string);
return type == typeof(string) || type == typeof(Guid) || type == typeof(ObjectId);
}

protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes)
Expand Down
34 changes: 13 additions & 21 deletions dotnet/src/VectorData/MongoDB/MongoCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ public class MongoCollection<TKey, TRecord> : VectorStoreCollection<TKey, TRecor
/// <summary>Number of nearest neighbors to use during the vector search.</summary>
private readonly int? _numCandidates;

/// <summary>Types of keys permitted.</summary>
private readonly Type[] _validKeyTypes = [typeof(string), typeof(Guid), typeof(ObjectId)];

/// <summary>
/// Initializes a new instance of the <see cref="MongoCollection{TKey, TRecord}"/> class.
/// </summary>
Expand Down Expand Up @@ -103,9 +106,9 @@ internal MongoCollection(IMongoDatabase mongoDatabase, string name, Func<MongoCo
Verify.NotNull(mongoDatabase);
Verify.NotNullOrWhiteSpace(name);

if (typeof(TKey) != typeof(string) && typeof(TKey) != typeof(object))
if (!this._validKeyTypes.Contains(typeof(TKey)) && typeof(TKey) != typeof(object))
{
throw new NotSupportedException("Only string keys are supported.");
throw new NotSupportedException("Only string, Guid and ObjectID keys are supported.");
}

options ??= MongoCollectionOptions.Default;
Expand Down Expand Up @@ -157,9 +160,9 @@ await this.RunOperationWithRetryAsync(
/// <inheritdoc />
public override async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default)
{
var stringKey = this.GetStringKey(key);
Verify.NotNull(key);

await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(stringKey), cancellationToken))
await this.RunOperationAsync("DeleteOne", () => this._mongoCollection.DeleteOneAsync(this.GetFilterById(key), cancellationToken))
.ConfigureAwait(false);
}

Expand All @@ -181,7 +184,7 @@ public override Task EnsureCollectionDeletedAsync(CancellationToken cancellation
/// <inheritdoc />
public override async Task<TRecord?> GetAsync(TKey key, RecordRetrievalOptions? options = null, CancellationToken cancellationToken = default)
{
var stringKey = this.GetStringKey(key);
Verify.NotNull(key);

var includeVectors = options?.IncludeVectors ?? false;
if (includeVectors && this._model.EmbeddingGenerationRequired)
Expand All @@ -190,7 +193,7 @@ public override Task EnsureCollectionDeletedAsync(CancellationToken cancellation
}

using var cursor = await this
.FindAsync(this.GetFilterById(stringKey), top: 1, skip: null, includeVectors, sortDefinition: null, cancellationToken)
.FindAsync(this.GetFilterById(key), top: 1, skip: null, includeVectors, sortDefinition: null, cancellationToken)
.ConfigureAwait(false);

var record = await cursor.SingleOrDefaultAsync(cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -267,7 +270,7 @@ private async Task UpsertCoreAsync(TRecord record, int recordIndex, IReadOnlyLis
var replaceOptions = new ReplaceOptions { IsUpsert = true };
var storageModel = this._mapper.MapFromDataToStorageModel(record, recordIndex, generatedEmbeddings);

var key = storageModel[MongoConstants.MongoReservedKeyPropertyName].AsString;
var key = storageModel[MongoConstants.MongoReservedKeyPropertyName];

await this.RunOperationAsync(OperationName, async () =>
await this._mongoCollection
Expand Down Expand Up @@ -673,11 +676,11 @@ private async IAsyncEnumerable<VectorSearchResult<TRecord>> EnumerateAndMapSearc
}
}

private FilterDefinition<BsonDocument> GetFilterById(string id)
private FilterDefinition<BsonDocument> GetFilterById(object id)
=> Builders<BsonDocument>.Filter.Eq(document => document[MongoConstants.MongoReservedKeyPropertyName], id);

private FilterDefinition<BsonDocument> GetFilterByIds(IEnumerable<string> ids)
=> Builders<BsonDocument>.Filter.In(document => document[MongoConstants.MongoReservedKeyPropertyName].AsString, ids);
private FilterDefinition<BsonDocument> GetFilterByIds(IEnumerable<object> ids)
=> Builders<BsonDocument>.Filter.In(document => document[MongoConstants.MongoReservedKeyPropertyName], ids);

private async Task<bool> InternalCollectionExistsAsync(CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -723,16 +726,5 @@ private async Task<T> RunOperationWithRetryAsync<T>(
operation,
cancellationToken).ConfigureAwait(false);

private string GetStringKey(TKey key)
{
Verify.NotNull(key);

var stringKey = key as string ?? throw new UnreachableException("string key should have been validated during model building");

Verify.NotNullOrWhiteSpace(stringKey, nameof(key));

return stringKey;
}

#endregion
}
2 changes: 1 addition & 1 deletion dotnet/src/VectorData/MongoDB/MongoVectorStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public MongoVectorStore(IMongoDatabase mongoDatabase, MongoVectorStoreOptions? o
#pragma warning disable IDE0090 // Use 'new(...)'
/// <inheritdoc />
[RequiresDynamicCode("This overload of GetCollection() is incompatible with NativeAOT. For dynamic mapping via Dictionary<string, object?>, call GetDynamicCollection() instead.")]
[RequiresUnreferencedCode("This overload of GetCollecttion() is incompatible with trimming. For dynamic mapping via Dictionary<string, object?>, call GetDynamicCollection() instead.")]
[RequiresUnreferencedCode("This overload of GetCollection() is incompatible with trimming. For dynamic mapping via Dictionary<string, object?>, call GetDynamicCollection() instead.")]
#if NET8_0_OR_GREATER
public override MongoCollection<TKey, TRecord> GetCollection<TKey, TRecord>(string name, VectorStoreCollectionDefinition? definition = null)
#else
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
// Copyright (c) Microsoft. All rights reserved.

using MongoDB.Bson;
using MongoDB.ConformanceTests.Support;
using VectorData.ConformanceTests.CRUD;
using Xunit;

namespace MongoDB.ConformanceTests.CRUD;

public class MongoBatchConformanceTests(MongoSimpleModelFixture fixture)
: BatchConformanceTests<string>(fixture), IClassFixture<MongoSimpleModelFixture>
public class MongoBatchConformanceTests_String(MongoSimpleModelFixture<string> fixture)
: BatchConformanceTests<string>(fixture), IClassFixture<MongoSimpleModelFixture<string>>
{
}

public class MongoBatchConformanceTests_Guid(MongoSimpleModelFixture<Guid> fixture)
: BatchConformanceTests<Guid>(fixture), IClassFixture<MongoSimpleModelFixture<Guid>>
{
}

public class MongoBatchConformanceTests_ObjectId(MongoSimpleModelFixture<ObjectId> fixture)
: BatchConformanceTests<ObjectId>(fixture), IClassFixture<MongoSimpleModelFixture<ObjectId>>
{
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
// Copyright (c) Microsoft. All rights reserved.

using MongoDB.Bson;
using MongoDB.ConformanceTests.Support;
using VectorData.ConformanceTests.CRUD;
using Xunit;

namespace MongoDB.ConformanceTests.CRUD;

public class MongoRecordConformanceTests(MongoSimpleModelFixture fixture)
: RecordConformanceTests<string>(fixture), IClassFixture<MongoSimpleModelFixture>
public class MongoRecordConformanceTests_String(MongoSimpleModelFixture<string> fixture)
: RecordConformanceTests<string>(fixture), IClassFixture<MongoSimpleModelFixture<string>>
{
}

public class MongoRecordConformanceTests_Guid(MongoSimpleModelFixture<Guid> fixture)
: RecordConformanceTests<Guid>(fixture), IClassFixture<MongoSimpleModelFixture<Guid>>
{
}

public class MongoRecordConformanceTests_ObjectId(MongoSimpleModelFixture<ObjectId> fixture)
: RecordConformanceTests<ObjectId>(fixture), IClassFixture<MongoSimpleModelFixture<ObjectId>>
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

namespace MongoDB.ConformanceTests.Support;

public class MongoSimpleModelFixture : SimpleModelFixture<string>
public class MongoSimpleModelFixture<TKey> : SimpleModelFixture<TKey>
where TKey : notnull
{
public override TestStore TestStore => MongoTestStore.Instance;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.SemanticKernel.Connectors.MongoDB;
using MongoDB.Bson;
using MongoDB.Driver;
using Testcontainers.MongoDb;
using VectorData.ConformanceTests.Support;
Expand Down Expand Up @@ -58,6 +59,18 @@ private async Task<MongoClientSettings> StartMongoDbContainerAsync()
};
}

private static readonly string? s_baseObjectId = ObjectId.GenerateNewId().ToString().Substring(0, 14);

public override TKey GenerateKey<TKey>(int value)
{
if (typeof(TKey) == typeof(ObjectId))
{
return (TKey)(object)ObjectId.Parse(s_baseObjectId + value.ToString("0000000000"));
}

return base.GenerateKey<TKey>(value);
}

protected override async Task StopAsync()
{
if (this._container != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public void GetCollectionWithNotSupportedKeyThrowsException()
using var sut = new MongoVectorStore(this._mockMongoDatabase.Object);

// Act & Assert
Assert.Throws<NotSupportedException>(() => sut.GetCollection<Guid, MongoHotelModel>("collection"));
Assert.Throws<NotSupportedException>(() => sut.GetCollection<long, MongoHotelModel>("collection"));
}

[Fact]
Expand Down
Loading