diff --git a/libs/checkpoint-postgres/README.md b/libs/checkpoint-postgres/README.md
index 75ecb95fa..0fc9ca1e9 100644
--- a/libs/checkpoint-postgres/README.md
+++ b/libs/checkpoint-postgres/README.md
@@ -19,7 +19,10 @@ const readConfig = {
   }
 };
 
-const checkpointer = PostgresSaver.fromConnString("postgresql://...");
+// you can optionally pass a configuration object as the second parameter
+const checkpointer = PostgresSaver.fromConnString("postgresql://...", {
+  schema: "schema_name" // defaults to "public"
+});
 
 // You must call .setup() the first time you use the checkpointer:
 await checkpointer.setup();
diff --git a/libs/checkpoint-postgres/src/index.ts b/libs/checkpoint-postgres/src/index.ts
index 31e82166a..0fe921307 100644
--- a/libs/checkpoint-postgres/src/index.ts
+++ b/libs/checkpoint-postgres/src/index.ts
@@ -12,15 +12,21 @@ import {
 } from "@langchain/langgraph-checkpoint";
 import pg from "pg";
 
-import { MIGRATIONS } from "./migrations.js";
+import { getMigrations } from "./migrations.js";
 import {
-  INSERT_CHECKPOINT_WRITES_SQL,
-  SELECT_SQL,
-  UPSERT_CHECKPOINT_BLOBS_SQL,
-  UPSERT_CHECKPOINT_WRITES_SQL,
-  UPSERT_CHECKPOINTS_SQL,
+  type SQL_STATEMENTS,
+  getSQLStatements,
+  getTablesWithSchema,
 } from "./sql.js";
 
+export interface PostgresSaverOptions {
+  schema?: string;
+}
+
+interface DefaultPostgresSaverOptions extends PostgresSaverOptions {
+  schema: "public" | string;
+}
+
 const { Pool } = pg;
 
 /**
@@ -35,7 +41,11 @@ const { Pool } = pg;
  * import { createReactAgent } from "@langchain/langgraph/prebuilt";
  *
  * const checkpointer = PostgresSaver.fromConnString(
- *   "postgresql://user:password@localhost:5432/db"
+ *   "postgresql://user:password@localhost:5432/db",
+ *   // optional configuration object
+ *   {
+ *     schema: "custom_schema" // defaults to "public"
+ *   },
  * );
  *
  * // NOTE: you need to call .setup() the first time you're using your checkpointer
@@ -60,18 +70,48 @@ const { Pool } = pg;
  */
 export class PostgresSaver extends BaseCheckpointSaver {
   private pool: pg.Pool;
 
+  private readonly options: DefaultPostgresSaverOptions = {
+    schema: "public",
+  };
+
+  private readonly SQL_STATEMENTS: SQL_STATEMENTS;
+
   protected isSetup: boolean;
 
-  constructor(pool: pg.Pool, serde?: SerializerProtocol) {
+  constructor(
+    pool: pg.Pool,
+    serde?: SerializerProtocol,
+    options?: PostgresSaverOptions
+  ) {
     super(serde);
     this.pool = pool;
     this.isSetup = false;
+    this.options = {
+      ...this.options,
+      ...options,
+    };
+    this.SQL_STATEMENTS = getSQLStatements(this.options.schema);
   }
 
-  static fromConnString(connString: string): PostgresSaver {
+  /**
+   * Creates a new instance of PostgresSaver from a connection string.
+   *
+   * @param {string} connString - The connection string to connect to the Postgres database.
+   * @param {PostgresSaverOptions} [options] - Optional configuration object.
+   * @returns {PostgresSaver} A new instance of PostgresSaver.
+   *
+   * @example
+   * const connString = "postgresql://user:password@localhost:5432/db";
+   * const checkpointer = PostgresSaver.fromConnString(connString, {
+   *   schema: "custom_schema" // defaults to "public"
+   * });
+   * await checkpointer.setup();
+   */
+  static fromConnString(
+    connString: string,
+    options: PostgresSaverOptions = {}
+  ): PostgresSaver {
     const pool = new Pool({ connectionString: connString });
-    return new PostgresSaver(pool);
+    return new PostgresSaver(pool, undefined, options);
   }
 
@@ -83,11 +123,13 @@ export class PostgresSaver extends BaseCheckpointSaver {
    */
   async setup(): Promise<void> {
     const client = await this.pool.connect();
+    const SCHEMA_TABLES = getTablesWithSchema(this.options.schema);
     try {
+      await client.query(`CREATE SCHEMA IF NOT EXISTS ${this.options.schema}`);
       let version = -1;
       try {
         const result = await client.query(
-          "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
+          `SELECT v FROM ${SCHEMA_TABLES.checkpoint_migrations} ORDER BY v DESC LIMIT 1`
         );
         if (result.rows.length > 0) {
           version = result.rows[0].v;
@@ -97,7 +139,7 @@
         // Assume table doesn't exist if there's an error
         if (
           error?.message.includes(
-            'relation "checkpoint_migrations" does not exist'
+            `relation "${SCHEMA_TABLES.checkpoint_migrations}" does not exist`
           )
         ) {
           version = -1;
@@ -106,10 +148,11 @@
         }
       }
 
+      const MIGRATIONS = getMigrations(this.options.schema);
       for (let v = version + 1; v < MIGRATIONS.length; v += 1) {
         await client.query(MIGRATIONS[v]);
         await client.query(
-          "INSERT INTO checkpoint_migrations (v) VALUES ($1)",
+          `INSERT INTO ${SCHEMA_TABLES.checkpoint_migrations} (v) VALUES ($1)`,
           [v]
         );
       }
@@ -317,7 +360,10 @@
       args = [thread_id, checkpoint_ns];
     }
 
-    const result = await this.pool.query(SELECT_SQL + where, args);
+    const result = await this.pool.query(
+      this.SQL_STATEMENTS.SELECT_SQL + where,
+      args
+    );
 
     const [row] = result.rows;
 
@@ -370,7 +416,7 @@
   ): AsyncGenerator<CheckpointTuple> {
     const { filter, before, limit } = options ?? {};
     const [where, args] = this._searchWhere(config, filter, before);
-    let query = `${SELECT_SQL}${where} ORDER BY checkpoint_id DESC`;
+    let query = `${this.SQL_STATEMENTS.SELECT_SQL}${where} ORDER BY checkpoint_id DESC`;
     if (limit !== undefined) {
       // eslint-disable-next-line @typescript-eslint/no-explicit-any
       query += ` LIMIT ${parseInt(limit as any, 10)}`; // sanitize via parseInt, as limit could be an externally provided value
@@ -449,9 +495,12 @@
         newVersions
       );
       for (const serializedBlob of serializedBlobs) {
-        await client.query(UPSERT_CHECKPOINT_BLOBS_SQL, serializedBlob);
+        await client.query(
+          this.SQL_STATEMENTS.UPSERT_CHECKPOINT_BLOBS_SQL,
+          serializedBlob
+        );
       }
-      await client.query(UPSERT_CHECKPOINTS_SQL, [
+      await client.query(this.SQL_STATEMENTS.UPSERT_CHECKPOINTS_SQL, [
         thread_id,
         checkpoint_ns,
         checkpoint.id,
@@ -483,8 +532,8 @@
     taskId: string
   ): Promise<void> {
     const query = writes.every((w) => w[0] in WRITES_IDX_MAP)
-      ? UPSERT_CHECKPOINT_WRITES_SQL
-      : INSERT_CHECKPOINT_WRITES_SQL;
+      ? this.SQL_STATEMENTS.UPSERT_CHECKPOINT_WRITES_SQL
+      : this.SQL_STATEMENTS.INSERT_CHECKPOINT_WRITES_SQL;
 
     const dumpedWrites = this._dumpWrites(
       config.configurable?.thread_id,
diff --git a/libs/checkpoint-postgres/src/migrations.ts b/libs/checkpoint-postgres/src/migrations.ts
index b5daf41aa..71df75a7e 100644
--- a/libs/checkpoint-postgres/src/migrations.ts
+++ b/libs/checkpoint-postgres/src/migrations.ts
@@ -1,40 +1,45 @@
+import { getTablesWithSchema } from "./sql.js";
+
 /**
- * To add a new migration, add a new string to the MIGRATIONS list.
+ * To add a new migration, add a new string to the list returned by the getMigrations function.
  * The position of the migration in the list is the version number.
  */
-export const MIGRATIONS = [
-  `CREATE TABLE IF NOT EXISTS checkpoint_migrations (
-    v INTEGER PRIMARY KEY
-);`,
-  `CREATE TABLE IF NOT EXISTS checkpoints (
-    thread_id TEXT NOT NULL,
-    checkpoint_ns TEXT NOT NULL DEFAULT '',
-    checkpoint_id TEXT NOT NULL,
-    parent_checkpoint_id TEXT,
-    type TEXT,
-    checkpoint JSONB NOT NULL,
-    metadata JSONB NOT NULL DEFAULT '{}',
-    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
-);`,
-  `CREATE TABLE IF NOT EXISTS checkpoint_blobs (
-    thread_id TEXT NOT NULL,
-    checkpoint_ns TEXT NOT NULL DEFAULT '',
-    channel TEXT NOT NULL,
-    version TEXT NOT NULL,
-    type TEXT NOT NULL,
-    blob BYTEA,
-    PRIMARY KEY (thread_id, checkpoint_ns, channel, version)
-);`,
-  `CREATE TABLE IF NOT EXISTS checkpoint_writes (
-    thread_id TEXT NOT NULL,
-    checkpoint_ns TEXT NOT NULL DEFAULT '',
-    checkpoint_id TEXT NOT NULL,
-    task_id TEXT NOT NULL,
-    idx INTEGER NOT NULL,
-    channel TEXT NOT NULL,
-    type TEXT,
-    blob BYTEA NOT NULL,
-    PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
-);`,
-  "ALTER TABLE checkpoint_blobs ALTER COLUMN blob DROP not null;",
-];
+export const getMigrations = (schema: string) => {
+  const SCHEMA_TABLES = getTablesWithSchema(schema);
+  return [
+    `CREATE TABLE IF NOT EXISTS ${SCHEMA_TABLES.checkpoint_migrations} (
+      v INTEGER PRIMARY KEY
+    );`,
+    `CREATE TABLE IF NOT EXISTS ${SCHEMA_TABLES.checkpoints} (
+      thread_id TEXT NOT NULL,
+      checkpoint_ns TEXT NOT NULL DEFAULT '',
+      checkpoint_id TEXT NOT NULL,
+      parent_checkpoint_id TEXT,
+      type TEXT,
+      checkpoint JSONB NOT NULL,
+      metadata JSONB NOT NULL DEFAULT '{}',
+      PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
+    );`,
+    `CREATE TABLE IF NOT EXISTS ${SCHEMA_TABLES.checkpoint_blobs} (
+      thread_id TEXT NOT NULL,
+      checkpoint_ns TEXT NOT NULL DEFAULT '',
+      channel TEXT NOT NULL,
+      version TEXT NOT NULL,
+      type TEXT NOT NULL,
+      blob BYTEA,
+      PRIMARY KEY (thread_id, checkpoint_ns, channel, version)
+    );`,
+    `CREATE TABLE IF NOT EXISTS ${SCHEMA_TABLES.checkpoint_writes} (
+      thread_id TEXT NOT NULL,
+      checkpoint_ns TEXT NOT NULL DEFAULT '',
+      checkpoint_id TEXT NOT NULL,
+      task_id TEXT NOT NULL,
+      idx INTEGER NOT NULL,
+      channel TEXT NOT NULL,
+      type TEXT,
+      blob BYTEA NOT NULL,
+      PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
+    );`,
+    `ALTER TABLE ${SCHEMA_TABLES.checkpoint_blobs} ALTER COLUMN blob DROP not null;`,
+  ];
+};
diff --git a/libs/checkpoint-postgres/src/sql.ts b/libs/checkpoint-postgres/src/sql.ts
index 24ac8824d..9a1ca7925 100644
--- a/libs/checkpoint-postgres/src/sql.ts
+++ b/libs/checkpoint-postgres/src/sql.ts
@@ -1,7 +1,37 @@
 import { TASKS } from "@langchain/langgraph-checkpoint";
 
-export const SELECT_SQL = `
-select
+export interface SQL_STATEMENTS {
+  SELECT_SQL: string;
+  UPSERT_CHECKPOINT_BLOBS_SQL: string;
+  UPSERT_CHECKPOINTS_SQL: string;
+  UPSERT_CHECKPOINT_WRITES_SQL: string;
+  INSERT_CHECKPOINT_WRITES_SQL: string;
+}
+
+interface TABLES {
+  checkpoints: string;
+  checkpoint_blobs: string;
+  checkpoint_writes: string;
+  checkpoint_migrations: string;
+}
+
+export const getTablesWithSchema = (schema: string): TABLES => {
+  const tables = [
+    "checkpoints",
+    "checkpoint_blobs",
+    "checkpoint_migrations",
+    "checkpoint_writes",
+  ];
+  return tables.reduce((acc, table) => {
+    acc[table as keyof TABLES] = `${schema}.${table}`;
+    return acc;
+  }, {} as TABLES);
+};
+
+export const getSQLStatements = (schema: string): SQL_STATEMENTS => {
+  const SCHEMA_TABLES = getTablesWithSchema(schema);
+  return {
+    SELECT_SQL: `select
   thread_id,
   checkpoint,
   checkpoint_ns,
@@ -9,58 +39,56 @@ select
   parent_checkpoint_id,
   metadata,
   (
-    select array_agg(array[bl.channel::bytea, bl.type::bytea, bl.blob])
-    from jsonb_each_text(checkpoint -> 'channel_versions')
-    inner join checkpoint_blobs bl
-      on bl.thread_id = checkpoints.thread_id
-      and bl.checkpoint_ns = checkpoints.checkpoint_ns
-      and bl.channel = jsonb_each_text.key
-      and bl.version = jsonb_each_text.value
+      select array_agg(array[bl.channel::bytea, bl.type::bytea, bl.blob])
+      from jsonb_each_text(checkpoint -> 'channel_versions')
+      inner join ${SCHEMA_TABLES.checkpoint_blobs} bl
+        on bl.thread_id = cp.thread_id
+        and bl.checkpoint_ns = cp.checkpoint_ns
+        and bl.channel = jsonb_each_text.key
+        and bl.version = jsonb_each_text.value
   ) as channel_values,
   (
-    select
-      array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
-    from checkpoint_writes cw
-    where cw.thread_id = checkpoints.thread_id
-      and cw.checkpoint_ns = checkpoints.checkpoint_ns
-      and cw.checkpoint_id = checkpoints.checkpoint_id
+      select
+        array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
+      from ${SCHEMA_TABLES.checkpoint_writes} cw
+      where cw.thread_id = cp.thread_id
+        and cw.checkpoint_ns = cp.checkpoint_ns
+        and cw.checkpoint_id = cp.checkpoint_id
   ) as pending_writes,
   (
-    select array_agg(array[cw.type::bytea, cw.blob] order by cw.idx)
-    from checkpoint_writes cw
-    where cw.thread_id = checkpoints.thread_id
-      and cw.checkpoint_ns = checkpoints.checkpoint_ns
-      and cw.checkpoint_id = checkpoints.parent_checkpoint_id
-      and cw.channel = '${TASKS}'
+      select array_agg(array[cw.type::bytea, cw.blob] order by cw.idx)
+      from ${SCHEMA_TABLES.checkpoint_writes} cw
+      where cw.thread_id = cp.thread_id
+        and cw.checkpoint_ns = cp.checkpoint_ns
+        and cw.checkpoint_id = cp.parent_checkpoint_id
+        and cw.channel = '${TASKS}'
   ) as pending_sends
-from checkpoints `;
+    from ${SCHEMA_TABLES.checkpoints} cp `, // <-- the trailing space is necessary for combining with WHERE clauses
 
-export const UPSERT_CHECKPOINT_BLOBS_SQL = `
-  INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, blob)
-  VALUES ($1, $2, $3, $4, $5, $6)
-  ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING
-`;
+    UPSERT_CHECKPOINT_BLOBS_SQL: `INSERT INTO ${SCHEMA_TABLES.checkpoint_blobs} (thread_id, checkpoint_ns, channel, version, type, blob)
+    VALUES ($1, $2, $3, $4, $5, $6)
+    ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING
+    `,
 
-export const UPSERT_CHECKPOINTS_SQL = `
-  INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata)
-  VALUES ($1, $2, $3, $4, $5, $6)
-  ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
-  DO UPDATE SET
-    checkpoint = EXCLUDED.checkpoint,
-    metadata = EXCLUDED.metadata;
-`;
+    UPSERT_CHECKPOINTS_SQL: `INSERT INTO ${SCHEMA_TABLES.checkpoints} (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata)
+    VALUES ($1, $2, $3, $4, $5, $6)
+    ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id)
+    DO UPDATE SET
+      checkpoint = EXCLUDED.checkpoint,
+      metadata = EXCLUDED.metadata;
+    `,
 
-export const UPSERT_CHECKPOINT_WRITES_SQL = `
-  INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
-  VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
-  ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET
-    channel = EXCLUDED.channel,
-    type = EXCLUDED.type,
-    blob = EXCLUDED.blob;
-`;
+    UPSERT_CHECKPOINT_WRITES_SQL: `INSERT INTO ${SCHEMA_TABLES.checkpoint_writes} (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
+    VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
+    ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET
+      channel = EXCLUDED.channel,
+      type = EXCLUDED.type,
+      blob = EXCLUDED.blob;
+    `,
 
-export const INSERT_CHECKPOINT_WRITES_SQL = `
-  INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
-  VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
-  ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING
-`;
+    INSERT_CHECKPOINT_WRITES_SQL: `INSERT INTO ${SCHEMA_TABLES.checkpoint_writes} (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob)
+    VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
+    ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING
+    `,
+  };
+};
diff --git a/libs/checkpoint-postgres/src/tests/checkpoints.int.test.ts b/libs/checkpoint-postgres/src/tests/checkpoints.int.test.ts
index fa08d6036..68389ebbf 100644
--- a/libs/checkpoint-postgres/src/tests/checkpoints.int.test.ts
+++ b/libs/checkpoint-postgres/src/tests/checkpoints.int.test.ts
@@ -48,14 +48,22 @@ const checkpoint2: Checkpoint = {
   pending_sends: [],
 };
 
-const postgresSavers: PostgresSaver[] = [];
+const TEST_POSTGRES_URL = process.env.TEST_POSTGRES_URL;
+if (!TEST_POSTGRES_URL) {
+  throw new Error("TEST_POSTGRES_URL environment variable is required");
+}
 
-describe("PostgresSaver", () => {
+let postgresSavers: PostgresSaver[] = [];
+
+describe.each([
+  { schema: undefined, description: "the default schema" },
+  { schema: "custom_schema", description: "a custom schema" },
+])("PostgresSaver with $description", ({ schema }) => {
   let postgresSaver: PostgresSaver;
 
   beforeEach(async () => {
     const pool = new Pool({
-      connectionString: process.env.TEST_POSTGRES_URL,
+      connectionString: TEST_POSTGRES_URL,
     });
     // Generate a unique database name
     const dbName = `lg_test_db_${Date.now()}_${Math.floor(
@@ -65,12 +73,15 @@
     try {
       // Create a new database
       await pool.query(`CREATE DATABASE ${dbName}`);
+      console.log(`Created database: ${dbName}`);
 
       // Connect to the new database
-      const dbConnectionString = `${process.env.TEST_POSTGRES_URL?.split("/")
+      const dbConnectionString = `${TEST_POSTGRES_URL?.split("/")
        .slice(0, -1)
        .join("/")}/${dbName}`;
-      postgresSaver = PostgresSaver.fromConnString(dbConnectionString);
+      postgresSaver = PostgresSaver.fromConnString(dbConnectionString, {
+        schema,
+      });
       postgresSavers.push(postgresSaver);
       await postgresSaver.setup();
     } finally {
@@ -80,9 +91,11 @@
   afterAll(async () => {
     await Promise.all(postgresSavers.map((saver) => saver.end()));
+    // clear the ended savers to clean up for the next test
+    postgresSavers = [];
    // Drop all test databases
     const pool = new Pool({
-      connectionString: process.env.TEST_POSTGRES_URL,
+      connectionString: TEST_POSTGRES_URL,
     });
     try {
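For anyone trying this change locally, here is a minimal usage sketch of the new `schema` option. The connection string, schema name, and thread id are placeholders, and it assumes the `@langchain/langgraph-checkpoint-postgres` package as built from this branch:

```ts
import { PostgresSaver } from "@langchain/langgraph-checkpoint-postgres";

// Placeholder connection string and schema name; substitute your own.
const checkpointer = PostgresSaver.fromConnString(
  "postgresql://user:password@localhost:5432/db",
  { schema: "custom_schema" } // omit to keep the previous behavior ("public")
);

// setup() now also runs CREATE SCHEMA IF NOT EXISTS custom_schema
// before applying the table migrations there.
await checkpointer.setup();

// Reads and writes go through the schema-qualified statements built by
// getSQLStatements(), e.g. custom_schema.checkpoints.
const tuple = await checkpointer.getTuple({
  configurable: { thread_id: "example-thread" },
});
console.log(tuple);
```

Since the schema name is interpolated directly into `CREATE SCHEMA IF NOT EXISTS` and the table-qualified statements rather than passed as a bind parameter, it should come from trusted configuration, not user input.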