diff --git a/src/common/connectionManager.ts b/src/common/connectionManager.ts index 1094f845..9145441e 100644 --- a/src/common/connectionManager.ts +++ b/src/common/connectionManager.ts @@ -32,9 +32,33 @@ export interface ConnectionState { connectedAtlasCluster?: AtlasClusterConnectionInfo; } -export interface ConnectionStateConnected extends ConnectionState { - tag: "connected"; - serviceProvider: NodeDriverServiceProvider; +export class ConnectionStateConnected implements ConnectionState { + public tag = "connected" as const; + + constructor( + public serviceProvider: NodeDriverServiceProvider, + public connectionStringAuthType?: ConnectionStringAuthType, + public connectedAtlasCluster?: AtlasClusterConnectionInfo + ) {} + + private _isSearchSupported?: boolean; + + public async isSearchSupported(): Promise { + if (this._isSearchSupported === undefined) { + try { + const dummyDatabase = `search-index-test-db-${Date.now()}`; + const dummyCollection = `search-index-test-coll-${Date.now()}`; + // If a cluster supports search indexes, the call below will succeed + // with a cursor otherwise will throw an Error + await this.serviceProvider.getSearchIndexes(dummyDatabase, dummyCollection); + this._isSearchSupported = true; + } catch { + this._isSearchSupported = false; + } + } + + return this._isSearchSupported; + } } export interface ConnectionStateConnecting extends ConnectionState { @@ -199,12 +223,10 @@ export class MCPConnectionManager extends ConnectionManager { }); } - return this.changeState("connection-success", { - tag: "connected", - connectedAtlasCluster: settings.atlas, - serviceProvider: await serviceProvider, - connectionStringAuthType, - }); + return this.changeState( + "connection-success", + new ConnectionStateConnected(await serviceProvider, connectionStringAuthType, settings.atlas) + ); } catch (error: unknown) { const errorReason = error instanceof Error ? error.message : `${error as string}`; this.changeState("connection-error", { @@ -270,11 +292,14 @@ export class MCPConnectionManager extends ConnectionManager { this.currentConnectionState.tag === "connecting" && this.currentConnectionState.connectionStringAuthType?.startsWith("oidc") ) { - this.changeState("connection-success", { - ...this.currentConnectionState, - tag: "connected", - serviceProvider: await this.currentConnectionState.serviceProvider, - }); + this.changeState( + "connection-success", + new ConnectionStateConnected( + await this.currentConnectionState.serviceProvider, + this.currentConnectionState.connectionStringAuthType, + this.currentConnectionState.connectedAtlasCluster + ) + ); } this.logger.info({ diff --git a/src/common/session.ts b/src/common/session.ts index 4ec536f4..6a30cbf0 100644 --- a/src/common/session.ts +++ b/src/common/session.ts @@ -141,6 +141,15 @@ export class Session extends EventEmitter { return this.connectionManager.currentConnectionState.tag === "connected"; } + get isConnectedToMongot(): Promise { + const state = this.connectionManager.currentConnectionState; + if (state.tag === "connected") { + return state.isSearchSupported(); + } + + return Promise.resolve(false); + } + get serviceProvider(): NodeDriverServiceProvider { if (this.isConnectedToMongoDB) { const state = this.connectionManager.currentConnectionState as ConnectionStateConnected; @@ -153,17 +162,4 @@ export class Session extends EventEmitter { get connectedAtlasCluster(): AtlasClusterConnectionInfo | undefined { return this.connectionManager.currentConnectionState.connectedAtlasCluster; } - - async isSearchIndexSupported(): Promise { - try { - const dummyDatabase = `search-index-test-db-${Date.now()}`; - const dummyCollection = `search-index-test-coll-${Date.now()}`; - // If a cluster supports search indexes, the call below will succeed - // with a cursor otherwise will throw an Error - await this.serviceProvider.getSearchIndexes(dummyDatabase, dummyCollection); - return true; - } catch { - return false; - } - } } diff --git a/src/resources/common/debug.ts b/src/resources/common/debug.ts index f76030b5..6f381b93 100644 --- a/src/resources/common/debug.ts +++ b/src/resources/common/debug.ts @@ -61,7 +61,7 @@ export class DebugResource extends ReactiveResource< switch (this.current.tag) { case "connected": { - const searchIndexesSupported = await this.session.isSearchIndexSupported(); + const searchIndexesSupported = await this.session.isConnectedToMongot; result += `The user is connected to the MongoDB cluster${searchIndexesSupported ? " with support for search indexes" : " without any support for search indexes"}.`; break; } diff --git a/src/tools/mongodb/create/createIndex.ts b/src/tools/mongodb/create/createIndex.ts index d87b9df0..305ab576 100644 --- a/src/tools/mongodb/create/createIndex.ts +++ b/src/tools/mongodb/create/createIndex.ts @@ -1,16 +1,87 @@ import { z } from "zod"; import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; -import type { ToolArgs, OperationType } from "../../tool.js"; +import type { ToolCategory } from "../../tool.js"; +import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js"; import type { IndexDirection } from "mongodb"; +const vectorSearchIndexDefinition = z.object({ + type: z.literal("vectorSearch"), + fields: z + .array( + z.discriminatedUnion("type", [ + z + .object({ + type: z.literal("filter"), + path: z + .string() + .describe( + "Name of the field to index. For nested fields, use dot notation to specify path to embedded fields" + ), + }) + .strict() + .describe("Definition for a field that will be used for pre-filtering results."), + z + .object({ + type: z.literal("vector"), + path: z + .string() + .describe( + "Name of the field to index. For nested fields, use dot notation to specify path to embedded fields" + ), + numDimensions: z + .number() + .min(1) + .max(8192) + .describe( + "Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time" + ), + similarity: z + .enum(["cosine", "euclidean", "dotProduct"]) + .default("cosine") + .describe( + "Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields." + ), + quantization: z + .enum(["none", "scalar", "binary"]) + .optional() + .default("none") + .describe( + "Type of automatic vector quantization for your vectors. Use this setting only if your embeddings are float or double vectors." + ), + }) + .strict() + .describe("Definition for a field that contains vector embeddings."), + ]) + ) + .nonempty() + .refine((fields) => fields.some((f) => f.type === "vector"), { + message: "At least one vector field must be defined", + }) + .describe( + "Definitions for the vector and filter fields to index, one definition per document. You must specify `vector` for fields that contain vector embeddings and `filter` for additional fields to filter on. At least one vector-type field definition is required." + ), +}); + export class CreateIndexTool extends MongoDBToolBase { public name = "create-index"; protected description = "Create an index for a collection"; protected argsShape = { ...DbOperationArgs, - keys: z.object({}).catchall(z.custom()).describe("The index definition"), name: z.string().optional().describe("The name of the index"), + definition: z + .array( + z.discriminatedUnion("type", [ + z.object({ + type: z.literal("classic"), + keys: z.object({}).catchall(z.custom()).describe("The index definition"), + }), + ...(this.isFeatureFlagEnabled(FeatureFlags.VectorSearch) ? [vectorSearchIndexDefinition] : []), + ]) + ) + .describe( + "The index definition. Use 'classic' for standard indexes and 'vectorSearch' for vector search indexes" + ), }; public operationType: OperationType = "create"; @@ -18,16 +89,59 @@ export class CreateIndexTool extends MongoDBToolBase { protected async execute({ database, collection, - keys, name, + definition: definitions, }: ToolArgs): Promise { const provider = await this.ensureConnected(); - const indexes = await provider.createIndexes(database, collection, [ - { - key: keys, - name, - }, - ]); + let indexes: string[] = []; + const definition = definitions[0]; + if (!definition) { + throw new Error("Index definition not provided. Expected one of the following: `classic`, `vectorSearch`"); + } + + switch (definition.type) { + case "classic": + indexes = await provider.createIndexes(database, collection, [ + { + key: definition.keys, + name, + }, + ]); + break; + case "vectorSearch": + { + const isVectorSearchSupported = await this.session.isConnectedToMongot; + if (!isVectorSearchSupported) { + // TODO: remove hacky casts once we merge the local dev tools + const isLocalAtlasAvailable = + (this.server?.tools.filter((t) => t.category === ("atlas-local" as unknown as ToolCategory)) + .length ?? 0) > 0; + + const CTA = isLocalAtlasAvailable ? "`atlas-local` tools" : "Atlas CLI"; + return { + content: [ + { + text: `The connected MongoDB deployment does not support vector search indexes. Either connect to a MongoDB Atlas cluster or use the ${CTA} to create and manage a local Atlas deployment.`, + type: "text", + }, + ], + isError: true, + }; + } + + indexes = await provider.createSearchIndexes(database, collection, [ + { + name, + definition: { + fields: definition.fields, + }, + type: "vectorSearch", + }, + ]); + } + + break; + } return { content: [ diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index ded994ab..2b901036 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -13,7 +13,7 @@ export const DbOperationArgs = { }; export abstract class MongoDBToolBase extends ToolBase { - private server?: Server; + protected server?: Server; public category: ToolCategory = "mongodb"; protected async ensureConnected(): Promise { diff --git a/src/tools/tool.ts b/src/tools/tool.ts index d609e78a..bb7e872c 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -15,6 +15,10 @@ export type ToolCallbackArgs = Parameters = Parameters>[1]; +export const enum FeatureFlags { + VectorSearch = "vectorSearch", +} + /** * The type of operation the tool performs. This is used when evaluating if a tool is allowed to run based on * the config's `disabledTools` and `readOnly` settings. @@ -314,6 +318,16 @@ export abstract class ToolBase { this.telemetry.emitEvents([event]); } + + // TODO: Move this to a separate file + protected isFeatureFlagEnabled(flag: FeatureFlags): boolean { + switch (flag) { + case FeatureFlags.VectorSearch: + return this.config.voyageApiKey !== ""; + default: + return false; + } + } } /** diff --git a/tests/accuracy/createIndex.test.ts b/tests/accuracy/createIndex.test.ts index 08326ce3..becd5b46 100644 --- a/tests/accuracy/createIndex.test.ts +++ b/tests/accuracy/createIndex.test.ts @@ -1,6 +1,20 @@ +import { afterAll, beforeAll } from "vitest"; import { describeAccuracyTests } from "./sdk/describeAccuracyTests.js"; import { Matcher } from "./sdk/matcher.js"; +let originalApiKey: string | undefined; +beforeAll(() => { + originalApiKey = process.env.MDB_VOYAGE_API_KEY; + + // We just need a valid key when registering the tool, the actual value is not important + if (!originalApiKey) { + process.env.MDB_VOYAGE_API_KEY = "valid-key"; + } +}); +afterAll(() => { + process.env.MDB_VOYAGE_API_KEY = originalApiKey; +}); + describeAccuracyTests([ { prompt: "Create an index that covers the following query on 'mflix.movies' namespace - { \"release_year\": 1992 }", @@ -11,9 +25,14 @@ describeAccuracyTests([ database: "mflix", collection: "movies", name: Matcher.anyOf(Matcher.undefined, Matcher.string()), - keys: { - release_year: 1, - }, + definition: [ + { + type: "classic", + keys: { + release_year: 1, + }, + }, + ], }, }, ], @@ -27,9 +46,104 @@ describeAccuracyTests([ database: "mflix", collection: "movies", name: Matcher.anyOf(Matcher.undefined, Matcher.string()), - keys: { - title: "text", - }, + definition: [ + { + type: "classic", + keys: { + title: "text", + }, + }, + ], + }, + }, + ], + }, + { + prompt: "Create a vector search index on 'mydb.movies' namespace on the 'plotSummary' field. The index should use 1024 dimensions.", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mydb", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: "vectorSearch", + fields: [ + { + type: "vector", + path: "plotSummary", + numDimensions: 1024, + }, + ], + }, + ], + }, + }, + ], + }, + { + prompt: "Create a vector search index on 'mydb.movies' namespace with on the 'plotSummary' field and 'genre' field, both of which contain vector embeddings. Pick a sensible number of dimensions for a voyage 3.5 model.", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mydb", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: "vectorSearch", + fields: [ + { + type: "vector", + path: "plotSummary", + numDimensions: Matcher.number( + (value) => value % 2 === 0 && value >= 256 && value <= 8192 + ), + similarity: Matcher.anyOf(Matcher.undefined, Matcher.string()), + }, + { + type: "vector", + path: "genre", + numDimensions: Matcher.number( + (value) => value % 2 === 0 && value >= 256 && value <= 8192 + ), + similarity: Matcher.anyOf(Matcher.undefined, Matcher.string()), + }, + ], + }, + ], + }, + }, + ], + }, + { + prompt: "Create a vector search index on 'mydb.movies' namespace where the 'plotSummary' field is indexed as a 1024-dimensional vector and the 'releaseDate' field is indexed as a regular field.", + expectedToolCalls: [ + { + toolName: "create-index", + parameters: { + database: "mydb", + collection: "movies", + name: Matcher.anyOf(Matcher.undefined, Matcher.string()), + definition: [ + { + type: "vectorSearch", + fields: [ + { + type: "vector", + path: "plotSummary", + numDimensions: 1024, + }, + { + type: "filter", + path: "releaseDate", + }, + ], + }, + ], }, }, ], diff --git a/tests/accuracy/dropIndex.test.ts b/tests/accuracy/dropIndex.test.ts index 48023af5..82e76075 100644 --- a/tests/accuracy/dropIndex.test.ts +++ b/tests/accuracy/dropIndex.test.ts @@ -40,9 +40,14 @@ describeAccuracyTests([ database: "mflix", collection: "movies", name: Matcher.anyOf(Matcher.undefined, Matcher.string()), - keys: { - title: "text", - }, + definition: [ + { + keys: { + title: "text", + }, + type: "classic", + }, + ], }, }, { diff --git a/tests/accuracy/sdk/accuracyTestingClient.ts b/tests/accuracy/sdk/accuracyTestingClient.ts index 3e5b89b7..48cba3b2 100644 --- a/tests/accuracy/sdk/accuracyTestingClient.ts +++ b/tests/accuracy/sdk/accuracyTestingClient.ts @@ -82,7 +82,8 @@ export class AccuracyTestingClient { static async initializeClient( mdbConnectionString: string, atlasApiClientId?: string, - atlasApiClientSecret?: string + atlasApiClientSecret?: string, + voyageApiKey?: string ): Promise { const args = [ MCP_SERVER_CLI_SCRIPT, @@ -90,6 +91,7 @@ export class AccuracyTestingClient { mdbConnectionString, ...(atlasApiClientId ? ["--apiClientId", atlasApiClientId] : []), ...(atlasApiClientSecret ? ["--apiClientSecret", atlasApiClientSecret] : []), + ...(voyageApiKey ? ["--voyageApiKey", voyageApiKey] : []), ]; const clientTransport = new StdioClientTransport({ diff --git a/tests/accuracy/sdk/describeAccuracyTests.ts b/tests/accuracy/sdk/describeAccuracyTests.ts index df35e3a0..4c39e962 100644 --- a/tests/accuracy/sdk/describeAccuracyTests.ts +++ b/tests/accuracy/sdk/describeAccuracyTests.ts @@ -68,6 +68,7 @@ export function describeAccuracyTests(accuracyTestConfigs: AccuracyTestConfig[]) const atlasApiClientId = process.env.MDB_MCP_API_CLIENT_ID; const atlasApiClientSecret = process.env.MDB_MCP_API_CLIENT_SECRET; + const voyageApiKey = process.env.MDB_VOYAGE_API_KEY; let commitSHA: string; let accuracyResultStorage: AccuracyResultStorage; @@ -85,7 +86,8 @@ export function describeAccuracyTests(accuracyTestConfigs: AccuracyTestConfig[]) testMCPClient = await AccuracyTestingClient.initializeClient( mdbIntegration.connectionString(), atlasApiClientId, - atlasApiClientSecret + atlasApiClientSecret, + voyageApiKey ); agent = getVercelToolCallingAgent(); }); diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 0f510bec..bde3c622 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -29,13 +29,22 @@ export const driverOptions = setupDriverConfig({ export const defaultDriverOptions: DriverOptions = { ...driverOptions }; -interface ParameterInfo { +interface Parameter { name: string; - type: string; description: string; required: boolean; } +interface SingleValueParameter extends Parameter { + type: string; +} + +interface AnyOfParameter extends Parameter { + anyOf: { type: string }[]; +} + +type ParameterInfo = SingleValueParameter | AnyOfParameter; + type ToolInfo = Awaited>["tools"][number]; export interface IntegrationTest { @@ -219,18 +228,38 @@ export function getParameters(tool: ToolInfo): ParameterInfo[] { return Object.entries(tool.inputSchema.properties) .sort((a, b) => a[0].localeCompare(b[0])) - .map(([key, value]) => { - expect(value).toHaveProperty("type"); + .map(([name, value]) => { expect(value).toHaveProperty("description"); - const typedValue = value as { type: string; description: string }; - expect(typeof typedValue.type).toBe("string"); - expect(typeof typedValue.description).toBe("string"); + const description = (value as { description: string }).description; + const required = (tool.inputSchema.required as string[])?.includes(name) ?? false; + expect(typeof description).toBe("string"); + + if (value && typeof value === "object" && "anyOf" in value) { + const typedOptions = new Array<{ type: string }>(); + for (const option of value.anyOf as { type: string }[]) { + expect(option).toHaveProperty("type"); + + typedOptions.push({ type: option.type }); + } + + return { + name, + anyOf: typedOptions, + description: description, + required, + }; + } + + expect(value).toHaveProperty("type"); + + const type = (value as { type: string }).type; + expect(typeof type).toBe("string"); return { - name: key, - type: typedValue.type, - description: typedValue.description, - required: (tool.inputSchema.required as string[])?.includes(key) ?? false, + name, + type, + description, + required, }; }); } diff --git a/tests/integration/tools/mongodb/create/createIndex.test.ts b/tests/integration/tools/mongodb/create/createIndex.test.ts index 3c789be8..82ea826d 100644 --- a/tests/integration/tools/mongodb/create/createIndex.test.ts +++ b/tests/integration/tools/mongodb/create/createIndex.test.ts @@ -1,4 +1,4 @@ -import { describeWithMongoDB, validateAutoConnectBehavior } from "../mongodbHelpers.js"; +import { describeWithMongoDB, validateAutoConnectBehavior, waitUntilSearchIsReady } from "../mongodbHelpers.js"; import { getResponseContent, @@ -6,199 +6,507 @@ import { validateToolMetadata, validateThrowsForInvalidArguments, expectDefined, + defaultTestConfig, } from "../../../helpers.js"; -import type { IndexDirection } from "mongodb"; -import { expect, it } from "vitest"; - -describeWithMongoDB("createIndex tool", (integration) => { - validateToolMetadata(integration, "create-index", "Create an index for a collection", [ - ...databaseCollectionParameters, - { - name: "keys", - type: "object", - description: "The index definition", - required: true, - }, - { - name: "name", - type: "string", - description: "The name of the index", - required: false, - }, - ]); - - validateThrowsForInvalidArguments(integration, "create-index", [ - {}, - { collection: "bar", database: 123, keys: { foo: 1 } }, - { collection: [], database: "test", keys: { foo: 1 } }, - { collection: "bar", database: "test", keys: { foo: 1 }, name: 123 }, - { collection: "bar", database: "test", keys: "foo", name: "my-index" }, - ]); - - const validateIndex = async (collection: string, expected: { name: string; key: object }[]): Promise => { - const mongoClient = integration.mongoClient(); - const collections = await mongoClient.db(integration.randomDbName()).listCollections().toArray(); - expect(collections).toHaveLength(1); - expect(collections[0]?.name).toEqual("coll1"); - const indexes = await mongoClient.db(integration.randomDbName()).collection(collection).indexes(); - expect(indexes).toHaveLength(expected.length + 1); - expect(indexes[0]?.name).toEqual("_id_"); - for (const index of expected) { - const foundIndex = indexes.find((i) => i.name === index.name); - expectDefined(foundIndex); - expect(foundIndex.key).toEqual(index.key); - } - }; - - it("creates the namespace if necessary", async () => { - await integration.connectMcpClient(); - const response = await integration.mcpClient().callTool({ - name: "create-index", - arguments: { - database: integration.randomDbName(), - collection: "coll1", - keys: { prop1: 1 }, - name: "my-index", - }, - }); +import { ObjectId, type IndexDirection } from "mongodb"; +import { beforeEach, describe, expect, it } from "vitest"; +import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; - const content = getResponseContent(response.content); - expect(content).toEqual( - `Created the index "my-index" on collection "coll1" in database "${integration.randomDbName()}"` - ); +describeWithMongoDB("createIndex tool when search is not enabled", (integration) => { + it("doesn't allow creating vector search indexes", async () => { + expect(integration.mcpServer().userConfig.voyageApiKey).toEqual(""); - await validateIndex("coll1", [{ name: "my-index", key: { prop1: 1 } }]); - }); + const { tools } = await integration.mcpClient().listTools(); + const createIndexTool = tools.find((tool) => tool.name === "create-index"); + const definitionProperty = createIndexTool?.inputSchema.properties?.definition as { + type: string; + items: { anyOf: Array<{ properties: Record> }> }; + }; + expectDefined(definitionProperty); - it("generates a name if not provided", async () => { - await integration.connectMcpClient(); - const response = await integration.mcpClient().callTool({ - name: "create-index", - arguments: { database: integration.randomDbName(), collection: "coll1", keys: { prop1: 1 } }, - }); + expect(definitionProperty.type).toEqual("array"); - const content = getResponseContent(response.content); - expect(content).toEqual( - `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` - ); - await validateIndex("coll1", [{ name: "prop1_1", key: { prop1: 1 } }]); + // Because search is not enabled, the only available index definition is 'classic' + // We expect 1 option in the anyOf array where type is "classic" + expect(definitionProperty.items.anyOf).toHaveLength(1); + expect(definitionProperty.items.anyOf?.[0]?.properties?.type).toEqual({ type: "string", const: "classic" }); + expect(definitionProperty.items.anyOf?.[0]?.properties?.keys).toBeDefined(); }); +}); - it("can create multiple indexes in the same collection", async () => { - await integration.connectMcpClient(); - let response = await integration.mcpClient().callTool({ - name: "create-index", - arguments: { database: integration.randomDbName(), collection: "coll1", keys: { prop1: 1 } }, - }); +describeWithMongoDB( + "createIndex tool when search is enabled", + (integration) => { + it("allows creating vector search indexes", async () => { + expect(integration.mcpServer().userConfig.voyageApiKey).not.toEqual(""); + + const { tools } = await integration.mcpClient().listTools(); + const createIndexTool = tools.find((tool) => tool.name === "create-index"); + const definitionProperty = createIndexTool?.inputSchema.properties?.definition as { + type: string; + items: { anyOf: Array<{ properties: Record> }> }; + }; + expectDefined(definitionProperty); + + expect(definitionProperty.type).toEqual("array"); + + // Because search is now enabled, we should see both "classic" and "vectorSearch" options in + // the anyOf array. + expect(definitionProperty.items.anyOf).toHaveLength(2); + expect(definitionProperty.items.anyOf?.[0]?.properties?.type).toEqual({ type: "string", const: "classic" }); + expect(definitionProperty.items.anyOf?.[0]?.properties?.keys).toBeDefined(); + expect(definitionProperty.items.anyOf?.[1]?.properties?.type).toEqual({ + type: "string", + const: "vectorSearch", + }); + expect(definitionProperty.items.anyOf?.[1]?.properties?.fields).toBeDefined(); - expect(getResponseContent(response.content)).toEqual( - `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` - ); + const fields = definitionProperty.items.anyOf?.[1]?.properties?.fields as { + type: string; + items: { anyOf: Array<{ type: string; properties: Record> }> }; + }; - response = await integration.mcpClient().callTool({ - name: "create-index", - arguments: { database: integration.randomDbName(), collection: "coll1", keys: { prop2: -1 } }, + expect(fields.type).toEqual("array"); + expect(fields.items.anyOf).toHaveLength(2); + expect(fields.items.anyOf?.[0]?.type).toEqual("object"); + expect(fields.items.anyOf?.[0]?.properties?.type).toEqual({ type: "string", const: "filter" }); + expectDefined(fields.items.anyOf?.[0]?.properties?.path); + + expect(fields.items.anyOf?.[1]?.type).toEqual("object"); + expect(fields.items.anyOf?.[1]?.properties?.type).toEqual({ type: "string", const: "vector" }); + expectDefined(fields.items.anyOf?.[1]?.properties?.path); + expectDefined(fields.items.anyOf?.[1]?.properties?.quantization); + expectDefined(fields.items.anyOf?.[1]?.properties?.numDimensions); + expectDefined(fields.items.anyOf?.[1]?.properties?.similarity); }); + }, + { + getUserConfig: () => { + return { + ...defaultTestConfig, + voyageApiKey: "valid_key", + }; + }, + } +); - expect(getResponseContent(response.content)).toEqual( - `Created the index "prop2_-1" on collection "coll1" in database "${integration.randomDbName()}"` - ); +describeWithMongoDB( + "createIndex tool with classic indexes", + (integration) => { + validateToolMetadata(integration, "create-index", "Create an index for a collection", [ + ...databaseCollectionParameters, + { + name: "definition", + type: "array", + description: + "The index definition. Use 'classic' for standard indexes and 'vectorSearch' for vector search indexes", + required: true, + }, + { + name: "name", + type: "string", + description: "The name of the index", + required: false, + }, + ]); - await validateIndex("coll1", [ - { name: "prop1_1", key: { prop1: 1 } }, - { name: "prop2_-1", key: { prop2: -1 } }, + validateThrowsForInvalidArguments(integration, "create-index", [ + {}, + { collection: "bar", database: 123, definition: [{ type: "classic", keys: { foo: 1 } }] }, + { collection: [], database: "test", definition: [{ type: "classic", keys: { foo: 1 } }] }, + { collection: "bar", database: "test", definition: [{ type: "classic", keys: { foo: 1 } }], name: 123 }, + { + collection: "bar", + database: "test", + definition: [{ type: "unknown", keys: { foo: 1 } }], + name: "my-index", + }, + { + collection: "bar", + database: "test", + definition: [{ type: "vectorSearch", fields: { foo: 1 } }], + }, + { + collection: "bar", + database: "test", + definition: [{ type: "vectorSearch", fields: [] }], + }, + { + collection: "bar", + database: "test", + definition: [{ type: "vectorSearch", fields: [{ type: "vector", path: "foo" }] }], + }, + { + collection: "bar", + database: "test", + definition: [{ type: "vectorSearch", fields: [{ type: "filter", path: "foo" }] }], + }, + { + collection: "bar", + database: "test", + definition: [ + { + type: "vectorSearch", + fields: [ + { type: "vector", path: "foo", numDimensions: 128 }, + { type: "filter", path: "bar", numDimensions: 128 }, + ], + }, + ], + }, ]); - }); - it("can create multiple indexes on the same property", async () => { - await integration.connectMcpClient(); - let response = await integration.mcpClient().callTool({ - name: "create-index", - arguments: { database: integration.randomDbName(), collection: "coll1", keys: { prop1: 1 } }, + const validateIndex = async (collection: string, expected: { name: string; key: object }[]): Promise => { + const mongoClient = integration.mongoClient(); + const collections = await mongoClient.db(integration.randomDbName()).listCollections().toArray(); + expect(collections).toHaveLength(1); + expect(collections[0]?.name).toEqual("coll1"); + const indexes = await mongoClient.db(integration.randomDbName()).collection(collection).indexes(); + expect(indexes).toHaveLength(expected.length + 1); + expect(indexes[0]?.name).toEqual("_id_"); + for (const index of expected) { + const foundIndex = indexes.find((i) => i.name === index.name); + expectDefined(foundIndex); + expect(foundIndex.key).toEqual(index.key); + } + }; + + it("creates the namespace if necessary", async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection: "coll1", + definition: [ + { + type: "classic", + keys: { prop1: 1 }, + }, + ], + name: "my-index", + }, + }); + + const content = getResponseContent(response.content); + expect(content).toEqual( + `Created the index "my-index" on collection "coll1" in database "${integration.randomDbName()}"` + ); + + await validateIndex("coll1", [{ name: "my-index", key: { prop1: 1 } }]); }); - expect(getResponseContent(response.content)).toEqual( - `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` - ); + it("generates a name if not provided", async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection: "coll1", + definition: [{ type: "classic", keys: { prop1: 1 } }], + }, + }); - response = await integration.mcpClient().callTool({ - name: "create-index", - arguments: { database: integration.randomDbName(), collection: "coll1", keys: { prop1: -1 } }, + const content = getResponseContent(response.content); + expect(content).toEqual( + `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` + ); + await validateIndex("coll1", [{ name: "prop1_1", key: { prop1: 1 } }]); }); - expect(getResponseContent(response.content)).toEqual( - `Created the index "prop1_-1" on collection "coll1" in database "${integration.randomDbName()}"` - ); + it("can create multiple indexes in the same collection", async () => { + await integration.connectMcpClient(); + let response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection: "coll1", + definition: [{ type: "classic", keys: { prop1: 1 } }], + }, + }); - await validateIndex("coll1", [ - { name: "prop1_1", key: { prop1: 1 } }, - { name: "prop1_-1", key: { prop1: -1 } }, - ]); - }); + expect(getResponseContent(response.content)).toEqual( + `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` + ); - it("doesn't duplicate indexes", async () => { - await integration.connectMcpClient(); - let response = await integration.mcpClient().callTool({ - name: "create-index", - arguments: { database: integration.randomDbName(), collection: "coll1", keys: { prop1: 1 } }, - }); + response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection: "coll1", + definition: [{ type: "classic", keys: { prop2: -1 } }], + }, + }); - expect(getResponseContent(response.content)).toEqual( - `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` - ); + expect(getResponseContent(response.content)).toEqual( + `Created the index "prop2_-1" on collection "coll1" in database "${integration.randomDbName()}"` + ); - response = await integration.mcpClient().callTool({ - name: "create-index", - arguments: { database: integration.randomDbName(), collection: "coll1", keys: { prop1: 1 } }, + await validateIndex("coll1", [ + { name: "prop1_1", key: { prop1: 1 } }, + { name: "prop2_-1", key: { prop2: -1 } }, + ]); }); - expect(getResponseContent(response.content)).toEqual( - `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` - ); + it("can create multiple indexes on the same property", async () => { + await integration.connectMcpClient(); + let response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection: "coll1", + definition: [{ type: "classic", keys: { prop1: 1 } }], + }, + }); + + expect(getResponseContent(response.content)).toEqual( + `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` + ); - await validateIndex("coll1", [{ name: "prop1_1", key: { prop1: 1 } }]); - }); + response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection: "coll1", + definition: [{ type: "classic", keys: { prop1: -1 } }], + }, + }); - const testCases: { name: string; direction: IndexDirection }[] = [ - { name: "descending", direction: -1 }, - { name: "ascending", direction: 1 }, - { name: "hashed", direction: "hashed" }, - { name: "text", direction: "text" }, - { name: "geoHaystack", direction: "2dsphere" }, - { name: "geo2d", direction: "2d" }, - ]; - - for (const { name, direction } of testCases) { - it(`creates ${name} index`, async () => { + expect(getResponseContent(response.content)).toEqual( + `Created the index "prop1_-1" on collection "coll1" in database "${integration.randomDbName()}"` + ); + + await validateIndex("coll1", [ + { name: "prop1_1", key: { prop1: 1 } }, + { name: "prop1_-1", key: { prop1: -1 } }, + ]); + }); + + it("doesn't duplicate indexes", async () => { await integration.connectMcpClient(); - const response = await integration.mcpClient().callTool({ + let response = await integration.mcpClient().callTool({ name: "create-index", - arguments: { database: integration.randomDbName(), collection: "coll1", keys: { prop1: direction } }, + arguments: { + database: integration.randomDbName(), + collection: "coll1", + definition: [{ type: "classic", keys: { prop1: 1 } }], + }, }); expect(getResponseContent(response.content)).toEqual( - `Created the index "prop1_${direction}" on collection "coll1" in database "${integration.randomDbName()}"` + `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` ); - let expectedKey: object = { prop1: direction }; - if (direction === "text") { - expectedKey = { - _fts: "text", - _ftsx: 1, - }; - } - await validateIndex("coll1", [{ name: `prop1_${direction}`, key: expectedKey }]); + response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection: "coll1", + definition: [{ type: "classic", keys: { prop1: 1 } }], + }, + }); + + expect(getResponseContent(response.content)).toEqual( + `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"` + ); + + await validateIndex("coll1", [{ name: "prop1_1", key: { prop1: 1 } }]); }); + + it("fails to create a vector search index", async () => { + await integration.connectMcpClient(); + const collection = new ObjectId().toString(); + await integration + .mcpServer() + .session.serviceProvider.createCollection(integration.randomDbName(), collection); + + const response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection, + name: "vector_1_vector", + definition: [ + { + type: "vectorSearch", + fields: [ + { type: "vector", path: "vector_1", numDimensions: 4 }, + { type: "filter", path: "category" }, + ], + }, + ], + }, + }); + + const content = getResponseContent(response.content); + expect(content).toContain("The connected MongoDB deployment does not support vector search indexes."); + expect(response.isError).toBe(true); + }); + + const testCases: { name: string; direction: IndexDirection }[] = [ + { name: "descending", direction: -1 }, + { name: "ascending", direction: 1 }, + { name: "hashed", direction: "hashed" }, + { name: "text", direction: "text" }, + { name: "geoHaystack", direction: "2dsphere" }, + { name: "geo2d", direction: "2d" }, + ]; + + for (const { name, direction } of testCases) { + it(`creates ${name} index`, async () => { + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection: "coll1", + definition: [{ type: "classic", keys: { prop1: direction } }], + }, + }); + + expect(getResponseContent(response.content)).toEqual( + `Created the index "prop1_${direction}" on collection "coll1" in database "${integration.randomDbName()}"` + ); + + let expectedKey: object = { prop1: direction }; + if (direction === "text") { + expectedKey = { + _fts: "text", + _ftsx: 1, + }; + } + await validateIndex("coll1", [{ name: `prop1_${direction}`, key: expectedKey }]); + }); + } + + validateAutoConnectBehavior(integration, "create-index", () => { + return { + args: { + database: integration.randomDbName(), + collection: "coll1", + definition: [{ type: "classic", keys: { prop1: 1 } }], + }, + expectedResponse: `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"`, + }; + }); + }, + { + getUserConfig: () => { + return { + ...defaultTestConfig, + voyageApiKey: "valid_key", + }; + }, } +); - validateAutoConnectBehavior(integration, "create-index", () => { - return { - args: { - database: integration.randomDbName(), - collection: "coll1", - keys: { prop1: 1 }, - }, - expectedResponse: `Created the index "prop1_1" on collection "coll1" in database "${integration.randomDbName()}"`, - }; - }); -}); +describeWithMongoDB( + "createIndex tool with vector search indexes", + (integration) => { + let provider: NodeDriverServiceProvider; + + beforeEach(async ({ signal }) => { + await integration.connectMcpClient(); + provider = integration.mcpServer().session.serviceProvider; + await waitUntilSearchIsReady(provider, signal); + }); + + describe("when the collection does not exist", () => { + it("throws an error", async () => { + const response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection: "foo", + definition: [ + { + type: "vectorSearch", + fields: [ + { type: "vector", path: "vector_1", numDimensions: 4 }, + { type: "filter", path: "category" }, + ], + }, + ], + }, + }); + + const content = getResponseContent(response.content); + expect(content).toContain(`Collection '${integration.randomDbName()}.foo' does not exist`); + }); + }); + + describe("when the database does not exist", () => { + it("throws an error", async () => { + const response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: "nonexistent_db", + collection: "foo", + definition: [ + { + type: "vectorSearch", + fields: [{ type: "vector", path: "vector_1", numDimensions: 4 }], + }, + ], + }, + }); + + const content = getResponseContent(response.content); + expect(content).toContain(`Collection 'nonexistent_db.foo' does not exist`); + }); + }); + + describe("when the collection exists", () => { + it("creates the index", async () => { + const collection = new ObjectId().toString(); + await provider.createCollection(integration.randomDbName(), collection); + const response = await integration.mcpClient().callTool({ + name: "create-index", + arguments: { + database: integration.randomDbName(), + collection, + name: "vector_1_vector", + definition: [ + { + type: "vectorSearch", + fields: [ + { type: "vector", path: "vector_1", numDimensions: 4 }, + { type: "filter", path: "category" }, + ], + }, + ], + }, + }); + + const content = getResponseContent(response.content); + expect(content).toEqual( + `Created the index "vector_1_vector" on collection "${collection}" in database "${integration.randomDbName()}"` + ); + + const indexes = await provider.getSearchIndexes(integration.randomDbName(), collection); + expect(indexes).toHaveLength(1); + expect(indexes[0]?.name).toEqual("vector_1_vector"); + expect(indexes[0]?.type).toEqual("vectorSearch"); + expect(indexes[0]?.status).toEqual("PENDING"); + expect(indexes[0]?.queryable).toEqual(false); + expect(indexes[0]?.latestDefinition).toEqual({ + fields: [ + { type: "vector", path: "vector_1", numDimensions: 4, similarity: "cosine" }, + { type: "filter", path: "category" }, + ], + }); + }); + }); + }, + { + getUserConfig: () => { + return { + ...defaultTestConfig, + voyageApiKey: "valid_key", + }; + }, + downloadOptions: { + search: true, + }, + } +); diff --git a/tests/integration/tools/mongodb/mongodbHelpers.ts b/tests/integration/tools/mongodb/mongodbHelpers.ts index 7c6da487..57959864 100644 --- a/tests/integration/tools/mongodb/mongodbHelpers.ts +++ b/tests/integration/tools/mongodb/mongodbHelpers.ts @@ -10,12 +10,14 @@ import { defaultTestConfig, defaultDriverOptions, getDataFromUntrustedContent, + sleep, } from "../../helpers.js"; import type { UserConfig, DriverOptions } from "../../../../src/common/config.js"; import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it } from "vitest"; import { EJSON } from "bson"; import { MongoDBClusterProcess } from "./mongodbClusterProcess.js"; import type { MongoClusterConfiguration } from "./mongodbClusterProcess.js"; +import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import type { createMockElicitInput, MockClientCapabilities } from "../../../utils/elicitationMocks.js"; const __dirname = path.dirname(fileURLToPath(import.meta.url)); @@ -278,3 +280,57 @@ export async function getServerVersion(integration: MongoDBIntegrationTestCase): const serverStatus = await client.db("admin").admin().serverStatus(); return serverStatus.version as string; } + +const SEARCH_RETRIES = 200; + +export async function waitUntilSearchIsReady( + provider: NodeDriverServiceProvider, + abortSignal: AbortSignal +): Promise { + let lastError: unknown = null; + + for (let i = 0; i < SEARCH_RETRIES && !abortSignal.aborted; i++) { + try { + await provider.insertOne("tmp", "test", { field1: "yay" }); + await provider.createSearchIndexes("tmp", "test", [{ definition: { mappings: { dynamic: true } } }]); + await provider.dropCollection("tmp", "test"); + return; + } catch (err) { + lastError = err; + await sleep(100); + } + } + + throw new Error(`Search Management Index is not ready.\nlastError: ${JSON.stringify(lastError)}`); +} + +export async function waitUntilSearchIndexIsQueryable( + provider: NodeDriverServiceProvider, + database: string, + collection: string, + indexName: string, + abortSignal: AbortSignal +): Promise { + let lastIndexStatus: unknown = null; + let lastError: unknown = null; + + for (let i = 0; i < SEARCH_RETRIES && !abortSignal.aborted; i++) { + try { + const [indexStatus] = await provider.getSearchIndexes(database, collection, indexName); + lastIndexStatus = indexStatus; + + if (indexStatus?.queryable === true) { + return; + } + } catch (err) { + lastError = err; + await sleep(100); + } + } + + throw new Error( + `Index ${indexName} in ${database}.${collection} is not ready: +lastIndexStatus: ${JSON.stringify(lastIndexStatus)} +lastError: ${JSON.stringify(lastError)}` + ); +} diff --git a/tests/integration/tools/mongodb/search/listSearchIndexes.test.ts b/tests/integration/tools/mongodb/search/listSearchIndexes.test.ts index fa69fa72..477f9fae 100644 --- a/tests/integration/tools/mongodb/search/listSearchIndexes.test.ts +++ b/tests/integration/tools/mongodb/search/listSearchIndexes.test.ts @@ -1,4 +1,9 @@ -import { describeWithMongoDB, getSingleDocFromUntrustedContent } from "../mongodbHelpers.js"; +import { + describeWithMongoDB, + getSingleDocFromUntrustedContent, + waitUntilSearchIndexIsQueryable, + waitUntilSearchIsReady, +} from "../mongodbHelpers.js"; import { describe, it, expect, beforeEach } from "vitest"; import { getResponseContent, @@ -6,13 +11,11 @@ import { validateToolMetadata, validateThrowsForInvalidArguments, databaseCollectionInvalidArgs, - sleep, getDataFromUntrustedContent, } from "../../../helpers.js"; import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import type { SearchIndexStatus } from "../../../../../src/tools/mongodb/search/listSearchIndexes.js"; -const SEARCH_RETRIES = 200; const SEARCH_TIMEOUT = 20_000; describeWithMongoDB("list search indexes tool in local MongoDB", (integration) => { @@ -98,7 +101,7 @@ describeWithMongoDB( "returns the list of existing indexes and detects if they are queryable", { timeout: SEARCH_TIMEOUT }, async ({ signal }) => { - await waitUntilIndexIsQueryable(provider, "any", "foo", "default", signal); + await waitUntilSearchIndexIsQueryable(provider, "any", "foo", "default", signal); const response = await integration.mcpClient().callTool({ name: "list-search-indexes", @@ -121,51 +124,3 @@ describeWithMongoDB( downloadOptions: { search: true }, } ); - -async function waitUntilSearchIsReady(provider: NodeDriverServiceProvider, abortSignal: AbortSignal): Promise { - let lastError: unknown = null; - - for (let i = 0; i < SEARCH_RETRIES && !abortSignal.aborted; i++) { - try { - await provider.insertOne("tmp", "test", { field1: "yay" }); - await provider.createSearchIndexes("tmp", "test", [{ definition: { mappings: { dynamic: true } } }]); - return; - } catch (err) { - lastError = err; - await sleep(100); - } - } - - throw new Error(`Search Management Index is not ready.\nlastError: ${JSON.stringify(lastError)}`); -} - -async function waitUntilIndexIsQueryable( - provider: NodeDriverServiceProvider, - database: string, - collection: string, - indexName: string, - abortSignal: AbortSignal -): Promise { - let lastIndexStatus: unknown = null; - let lastError: unknown = null; - - for (let i = 0; i < SEARCH_RETRIES && !abortSignal.aborted; i++) { - try { - const [indexStatus] = await provider.getSearchIndexes(database, collection, indexName); - lastIndexStatus = indexStatus; - - if (indexStatus?.queryable === true) { - return; - } - } catch (err) { - lastError = err; - await sleep(100); - } - } - - throw new Error( - `Index ${indexName} in ${database}.${collection} is not ready: -lastIndexStatus: ${JSON.stringify(lastIndexStatus)} -lastError: ${JSON.stringify(lastError)}` - ); -} diff --git a/tests/unit/common/session.test.ts b/tests/unit/common/session.test.ts index ea6ac348..40ae560d 100644 --- a/tests/unit/common/session.test.ts +++ b/tests/unit/common/session.test.ts @@ -134,7 +134,7 @@ describe("Session", () => { await session.connectToMongoDB({ connectionString: "mongodb://localhost:27017", }); - expect(await session.isSearchIndexSupported()).toEqual(true); + expect(await session.isConnectedToMongot).toEqual(true); }); it("should return false if listing search indexes fail with search error", async () => { @@ -142,7 +142,7 @@ describe("Session", () => { await session.connectToMongoDB({ connectionString: "mongodb://localhost:27017", }); - expect(await session.isSearchIndexSupported()).toEqual(false); + expect(await session.isConnectedToMongot).toEqual(false); }); }); }); diff --git a/tests/unit/resources/common/debug.test.ts b/tests/unit/resources/common/debug.test.ts index 3a4c68e2..c7cf8061 100644 --- a/tests/unit/resources/common/debug.test.ts +++ b/tests/unit/resources/common/debug.test.ts @@ -103,7 +103,7 @@ describe("debug resource", () => { }); it("should notify if a cluster supports search indexes", async () => { - session.isSearchIndexSupported = vi.fn().mockResolvedValue(true); + vi.spyOn(session, "isConnectedToMongot", "get").mockImplementation(() => Promise.resolve(true)); debugResource.reduceApply("connect", undefined); const output = await debugResource.toOutput();