From f0d504dba6dbc4fbe4f743b4b7eed678d86e1665 Mon Sep 17 00:00:00 2001 From: Pascal Kienast Date: Sat, 14 Mar 2026 15:24:44 +0100 Subject: [PATCH] fix: improve exact-match search and forget reliability --- client.ts | 101 +++++++++++++++++++++++++++++++--- hooks/capture.ts | 114 +++++++++++++++++++++++++++++++++----- tests/capture.test.ts | 124 ++++++++++++++++++++++++++++++++++++++++++ tests/client.test.ts | 95 ++++++++++++++++++++++++++++++++ tsconfig.json | 1 + 5 files changed, 413 insertions(+), 22 deletions(-) create mode 100644 tests/capture.test.ts create mode 100644 tests/client.test.ts diff --git a/client.ts b/client.ts index fbb71a4..e67f4ee 100644 --- a/client.ts +++ b/client.ts @@ -32,6 +32,81 @@ function limitText(text: string, max: number): string { return text.length > max ? `${text.slice(0, max)}…` : text } +export function normalizeMemoryText(text: string | undefined): string { + return sanitizeContent(text ?? "") + .trim() + .replace(/\s+/g, " ") + .toLowerCase() +} + +function getQueryTokens(normalizedQuery: string): string[] { + if (!normalizedQuery) return [] + return normalizedQuery.split(/\s+/).filter(Boolean) +} + +function getTokenCoverage( + normalizedText: string, + queryTokens: string[], +): number { + if (queryTokens.length === 0) return 0 + const matched = queryTokens.filter((token) => + normalizedText.includes(token), + ).length + return matched / queryTokens.length +} + +export function findExactContentMatch( + results: SearchResult[], + query: string, +): SearchResult | undefined { + const normalizedQuery = normalizeMemoryText(query) + if (!normalizedQuery) return undefined + + return results.find( + (result) => + normalizeMemoryText(result.content || result.memory) === normalizedQuery, + ) +} + +export function rerankSearchResults( + query: string, + results: SearchResult[], +): SearchResult[] { + const normalizedQuery = normalizeMemoryText(query) + if (!normalizedQuery || results.length <= 1) return results + + const queryTokens = getQueryTokens(normalizedQuery) + + return [...results] + .map((result, index) => { + const normalizedContent = normalizeMemoryText( + result.content || result.memory, + ) + return { + index, + result, + exact: normalizedContent === normalizedQuery ? 1 : 0, + contains: normalizedContent.includes(normalizedQuery) ? 1 : 0, + tokenCoverage: getTokenCoverage(normalizedContent, queryTokens), + similarity: result.similarity ?? 0, + contentLength: normalizedContent.length || Number.MAX_SAFE_INTEGER, + } + }) + .sort((a, b) => { + if (b.exact !== a.exact) return b.exact - a.exact + if (b.contains !== a.contains) return b.contains - a.contains + if (b.tokenCoverage !== a.tokenCoverage) { + return b.tokenCoverage - a.tokenCoverage + } + if (b.similarity !== a.similarity) return b.similarity - a.similarity + if (a.contentLength !== b.contentLength) { + return a.contentLength - b.contentLength + } + return a.index - b.index + }) + .map(({ result }) => result) +} + export class SupermemoryClient { private client: Supermemory private containerTag: string @@ -90,27 +165,31 @@ export class SupermemoryClient { limit = 5, containerTag?: string, ): Promise { + const cleanedQuery = sanitizeContent(query) const tag = containerTag ?? this.containerTag log.debugRequest("search.memories", { - query, + query: cleanedQuery, limit, containerTag: tag, }) const response = await this.client.search.memories({ - q: query, + q: cleanedQuery, containerTag: tag, limit, + rerank: true, + rewriteQuery: false, }) - const results: SearchResult[] = (response.results ?? []).map((r) => ({ + const rawResults: SearchResult[] = (response.results ?? []).map((r) => ({ id: r.id, content: r.memory ?? "", memory: r.memory, similarity: r.similarity, metadata: r.metadata ?? undefined, })) + const results = rerankSearchResults(cleanedQuery, rawResults) log.debugResponse("search.memories", { count: results.length }) return results @@ -168,17 +247,25 @@ export class SupermemoryClient { query: string, containerTag?: string, ): Promise<{ success: boolean; message: string }> { - log.debugRequest("forgetByQuery", { query, containerTag }) + const cleanedQuery = sanitizeContent(query) + log.debugRequest("forgetByQuery", { query: cleanedQuery, containerTag }) - const results = await this.search(query, 5, containerTag) + const results = await this.search(cleanedQuery, 10, containerTag) if (results.length === 0) { return { success: false, message: "No matching memory found to forget." } } - const target = results[0] - await this.deleteMemory(target.id, containerTag) + const target = findExactContentMatch(results, cleanedQuery) ?? results[0] + const deleted = await this.deleteMemory(target.id, containerTag) const preview = limitText(target.content || target.memory || "", 100) + if (!deleted.forgotten) { + return { + success: false, + message: `Unable to confirm forgetting: "${preview}"`, + } + } + return { success: true, message: `Forgot: "${preview}"` } } diff --git a/hooks/capture.ts b/hooks/capture.ts index 9bfd930..bc589ed 100644 --- a/hooks/capture.ts +++ b/hooks/capture.ts @@ -4,6 +4,17 @@ import { log } from "../logger.ts" import { buildDocumentId } from "../memory.ts" const SKIPPED_PROVIDERS = ["exec-event", "cron-event", "heartbeat"] +const MEMORY_TOOL_PREFIX = "supermemory_" +const MEMORY_COMMAND_PREFIXES = ["/remember", "/recall"] +const MEMORY_TOOL_RESPONSE_PATTERNS = [ + /^Stored:\s*"/i, + /^Forgot:\s*"/i, + /^Found \d+ memories:/i, + /^No relevant memories found\.?$/i, + /^No matching memory found to forget\.?$/i, + /^Memory forgotten\.?$/i, + /^Provide a query or memoryId to forget\.?$/i, +] function getLastTurn(messages: unknown[]): unknown[] { let lastUserIdx = -1 @@ -21,6 +32,89 @@ function getLastTurn(messages: unknown[]): unknown[] { return lastUserIdx >= 0 ? messages.slice(lastUserIdx) : messages } +function collectTextParts(content: unknown): string[] { + const parts: string[] = [] + + if (typeof content === "string") { + parts.push(content) + return parts + } + + if (!Array.isArray(content)) return parts + + for (const block of content) { + if (!block || typeof block !== "object") continue + const b = block as Record + if (b.type === "text" && typeof b.text === "string") { + parts.push(b.text) + } + } + + return parts +} + +function messageReferencesSupermemoryTool( + msgObj: Record, +): boolean { + for (const key of ["name", "toolName"]) { + if ( + typeof msgObj[key] === "string" && + (msgObj[key] as string).startsWith(MEMORY_TOOL_PREFIX) + ) { + return true + } + } + + const content = msgObj.content + if (!Array.isArray(content)) return false + + for (const block of content) { + if (!block || typeof block !== "object") continue + const b = block as Record + if (typeof b.name === "string" && b.name.startsWith(MEMORY_TOOL_PREFIX)) { + return true + } + if ( + typeof b.toolName === "string" && + b.toolName.startsWith(MEMORY_TOOL_PREFIX) + ) { + return true + } + } + + return false +} + +export function isSupermemoryManagementTurn(messages: unknown[]): boolean { + for (const msg of messages) { + if (!msg || typeof msg !== "object") continue + const msgObj = msg as Record + + if (messageReferencesSupermemoryTool(msgObj)) { + return true + } + + for (const text of collectTextParts(msgObj.content)) { + const trimmed = text.trim() + const lower = trimmed.toLowerCase() + if ( + MEMORY_COMMAND_PREFIXES.some( + (prefix) => lower === prefix || lower.startsWith(`${prefix} `), + ) + ) { + return true + } + if ( + MEMORY_TOOL_RESPONSE_PATTERNS.some((pattern) => pattern.test(trimmed)) + ) { + return true + } + } + } + + return false +} + export function buildCaptureHandler( client: SupermemoryClient, cfg: SupermemoryConfig, @@ -46,6 +140,10 @@ export function buildCaptureHandler( return const lastTurn = getLastTurn(event.messages) + if (isSupermemoryManagementTurn(lastTurn)) { + log.debug("capture: skipping supermemory management turn") + return + } const texts: string[] = [] for (const msg of lastTurn) { @@ -54,21 +152,7 @@ export function buildCaptureHandler( const role = msgObj.role if (role !== "user" && role !== "assistant") continue - const content = msgObj.content - - const parts: string[] = [] - - if (typeof content === "string") { - parts.push(content) - } else if (Array.isArray(content)) { - for (const block of content) { - if (!block || typeof block !== "object") continue - const b = block as Record - if (b.type === "text" && typeof b.text === "string") { - parts.push(b.text) - } - } - } + const parts = collectTextParts(msgObj.content) if (parts.length > 0) { texts.push(`[role: ${role}]\n${parts.join("\n")}\n[${role}:end]`) diff --git a/tests/capture.test.ts b/tests/capture.test.ts new file mode 100644 index 0000000..f448e05 --- /dev/null +++ b/tests/capture.test.ts @@ -0,0 +1,124 @@ +import { describe, expect, it } from "bun:test" +import type { SupermemoryClient } from "../client.ts" +import type { SupermemoryConfig } from "../config.ts" +import { + buildCaptureHandler, + isSupermemoryManagementTurn, +} from "../hooks/capture.ts" + +const cfg: SupermemoryConfig = { + apiKey: undefined, + containerTag: "test_container", + autoRecall: true, + autoCapture: true, + maxRecallResults: 10, + profileFrequency: 50, + captureMode: "all", + entityContext: "test context", + debug: false, + enableCustomContainerTags: false, + customContainers: [], + customContainerInstructions: "", +} + +type CaptureArgs = [ + content: string, + metadata?: Record, + customId?: string, + containerTag?: string, + entityContext?: string, +] + +type CaptureClient = Pick + +describe("capture hook", () => { + it("detects turns that manage supermemory directly", () => { + const turn = [ + { role: "user", content: "Please forget this memory." }, + { + role: "assistant", + content: [ + { + type: "tool_call", + name: "supermemory_forget", + arguments: { query: "foo" }, + }, + ], + }, + { role: "assistant", content: 'Forgot: "foo"' }, + ] + + expect(isSupermemoryManagementTurn(turn)).toBe(true) + }) + + it("skips capturing turns that used supermemory tools", async () => { + const calls: CaptureArgs[] = [] + const client: CaptureClient = { + addMemory: async (...args) => { + calls.push(args) + return { id: "memory_1" } + }, + } + const handler = buildCaptureHandler( + client as unknown as SupermemoryClient, + cfg, + () => "session-123", + ) + + await handler( + { + success: true, + messages: [ + { role: "user", content: "Please forget this memory." }, + { + role: "assistant", + content: [ + { + type: "tool_call", + name: "supermemory_forget", + arguments: { query: "foo" }, + }, + ], + }, + { role: "assistant", content: 'Forgot: "foo"' }, + ], + }, + { messageProvider: "discord" }, + ) + + expect(calls.length).toBe(0) + }) + + it("still captures normal conversational turns", async () => { + const calls: CaptureArgs[] = [] + const client: CaptureClient = { + addMemory: async (...args) => { + calls.push(args) + return { id: "memory_1" } + }, + } + const handler = buildCaptureHandler( + client as unknown as SupermemoryClient, + cfg, + () => "session-123", + ) + + await handler( + { + success: true, + messages: [ + { role: "user", content: "My favorite editor is Helix." }, + { role: "assistant", content: "Got it." }, + ], + }, + { messageProvider: "discord" }, + ) + + expect(calls.length).toBe(1) + const [content, metadata, customId] = calls[0] + expect(content).toContain("My favorite editor is Helix.") + expect(content).toContain("Got it.") + expect(metadata).toEqual(expect.objectContaining({ source: "openclaw" })) + expect(customId).toBe("session_session_123") + }) +}) diff --git a/tests/client.test.ts b/tests/client.test.ts new file mode 100644 index 0000000..de8612d --- /dev/null +++ b/tests/client.test.ts @@ -0,0 +1,95 @@ +import { describe, expect, it } from "bun:test" +import type { SearchResult } from "../client.ts" +import { SupermemoryClient } from "../client.ts" + +const API_KEY = "sm_12345678901234567890" + +describe("SupermemoryClient", () => { + it("ranks exact literal matches ahead of higher-similarity fuzzy matches", async () => { + const query = "OPENCLAW_SUPERMEMORY_HEALTHCHECK_2026-03-14_1459" + const client = new SupermemoryClient(API_KEY, "test_container") + + Reflect.set(client as object, "client", { + search: { + memories: async () => ({ + results: [ + { + id: "allowlist", + memory: "openclaw-supermemory allowlist note", + similarity: 0.97, + metadata: null, + }, + { + id: "exact", + memory: query, + similarity: 0.25, + metadata: null, + }, + ], + }), + }, + }) + + const results = await client.search(query) + + expect(results[0]?.id).toBe("exact") + }) + + it("prefers an exact textual match when forgetting by query", async () => { + const query = "OPENCLAW_SUPERMEMORY_HEALTHCHECK_2026-03-14_1459" + const client = new SupermemoryClient(API_KEY, "test_container") + const deletedIds: string[] = [] + + Reflect.set( + client as object, + "search", + async (): Promise => [ + { + id: "allowlist", + content: "openclaw-supermemory allowlist note", + similarity: 0.99, + }, + { + id: "exact", + content: query, + similarity: 0.21, + }, + ], + ) + Reflect.set(client as object, "deleteMemory", async (id: string) => { + deletedIds.push(id) + return { id, forgotten: true } + }) + + const result = await client.forgetByQuery(query) + + expect(deletedIds).toEqual(["exact"]) + expect(result).toEqual({ success: true, message: `Forgot: "${query}"` }) + }) + + it("does not claim success when the delete result cannot be confirmed", async () => { + const query = "OPENCLAW_SUPERMEMORY_HEALTHCHECK_2026-03-14_1459" + const client = new SupermemoryClient(API_KEY, "test_container") + + Reflect.set( + client as object, + "search", + async (): Promise => [ + { + id: "exact", + content: query, + similarity: 0.42, + }, + ], + ) + Reflect.set(client as object, "deleteMemory", async (id: string) => ({ + id, + forgotten: false, + })) + + const result = await client.forgetByQuery(query) + + expect(result.success).toBe(false) + expect(result.message).toContain("Unable to confirm forgetting") + }) +}) diff --git a/tsconfig.json b/tsconfig.json index f55f8b4..2e47325 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -17,6 +17,7 @@ "tools/*.ts", "hooks/*.ts", "commands/*.ts", + "tests/*.ts", "types/*.d.ts", "lib/*.d.ts" ],