Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(checkpoint-postgres): Add support for providing a custom schema during initialization #838

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3645b0f
updated migrations to accept a custom schema value
nick-w-nick Feb 1, 2025
ad1c263
updated sql queries to accept a custom schema value
nick-w-nick Feb 1, 2025
41f0eb6
updated setup method to use the custom schema provided
nick-w-nick Feb 1, 2025
2975b4f
updated checkpoint saver methods to use the custom schema queries
nick-w-nick Feb 1, 2025
614a1eb
added log for database creation
nick-w-nick Feb 1, 2025
99f176c
added custom schema param
nick-w-nick Feb 1, 2025
fb0c0df
added clarifying comment regarding trailing space to prevent it from …
nick-w-nick Feb 1, 2025
292a765
updated test to run both with and without a custom schema to test def…
nick-w-nick Feb 1, 2025
e13d6d7
updated readme to include how to provide a custom schema
nick-w-nick Feb 1, 2025
98556b6
added jsdoc for the fromConnString method
nick-w-nick Feb 1, 2025
1f9bf14
moved duplicated env var to a central variable + added error to catch…
nick-w-nick Feb 1, 2025
e6d8498
updated test descriptions
nick-w-nick Feb 1, 2025
70992c0
lint fixes
nick-w-nick Feb 1, 2025
1b266ff
updated references to table names to use a generated constant to ensu…
nick-w-nick Feb 1, 2025
299822a
updated function name
nick-w-nick Feb 2, 2025
f1b7eac
lint fixes
nick-w-nick Feb 2, 2025
d59efbb
added schema to jsdoc example
nick-w-nick Feb 2, 2025
b4c4590
updated second parameter to be an options object rather than set to a…
nick-w-nick Feb 2, 2025
40fc1cd
updated test to use config object
nick-w-nick Feb 2, 2025
38892e6
updated readme to include options object
nick-w-nick Feb 2, 2025
9721ecf
updated js doc
nick-w-nick Feb 2, 2025
a5ee62e
lint fixes
nick-w-nick Feb 2, 2025
edf3387
replaced utility function with default options interface
nick-w-nick Feb 2, 2025
a69640a
updated jsdoc
nick-w-nick Feb 2, 2025
e69f58e
Merge branch 'main' into feature/add-schema-support-to-postgres
nick-w-nick Feb 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion libs/checkpoint-postgres/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ const readConfig = {
}
};

const checkpointer = PostgresSaver.fromConnString("postgresql://...");
// you can optionally pass a configuration object as the second parameter
const checkpointer = PostgresSaver.fromConnString("postgresql://...", {
schema: "schema_name" // defaults to "public"
});

// You must call .setup() the first time you use the checkpointer:
await checkpointer.setup();
Expand Down
87 changes: 68 additions & 19 deletions libs/checkpoint-postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,21 @@ import {
} from "@langchain/langgraph-checkpoint";
import pg from "pg";

import { MIGRATIONS } from "./migrations.js";
import { getMigrations } from "./migrations.js";
import {
INSERT_CHECKPOINT_WRITES_SQL,
SELECT_SQL,
UPSERT_CHECKPOINT_BLOBS_SQL,
UPSERT_CHECKPOINT_WRITES_SQL,
UPSERT_CHECKPOINTS_SQL,
type SQL_STATEMENTS,
getSQLStatements,
getTablesWithSchema,
} from "./sql.js";

export interface PostgresSaverOptions {
schema?: string;
}

interface DefaultPostgresSaverOptions extends PostgresSaverOptions {
schema: "public" | string;
}

const { Pool } = pg;

/**
Expand All @@ -35,7 +41,11 @@ const { Pool } = pg;
* import { createReactAgent } from "@langchain/langgraph/prebuilt";
*
* const checkpointer = PostgresSaver.fromConnString(
* "postgresql://user:password@localhost:5432/db"
* "postgresql://user:password@localhost:5432/db",
* // optional configuration object
* {
* schema: "custom_schema" // defaults to "public"
* },
* );
*
* // NOTE: you need to call .setup() the first time you're using your checkpointer
Expand All @@ -60,18 +70,48 @@ const { Pool } = pg;
*/
export class PostgresSaver extends BaseCheckpointSaver {
private pool: pg.Pool;
private readonly options: DefaultPostgresSaverOptions = {
schema: "public",
};
private readonly SQL_STATEMENTS: SQL_STATEMENTS;

protected isSetup: boolean;

constructor(pool: pg.Pool, serde?: SerializerProtocol) {
constructor(
pool: pg.Pool,
serde?: SerializerProtocol,
options?: PostgresSaverOptions
) {
super(serde);
this.pool = pool;
this.isSetup = false;
this.options = {
...this.options,
...options,
};
this.SQL_STATEMENTS = getSQLStatements(this.options.schema);
}

static fromConnString(connString: string): PostgresSaver {
/**
* Creates a new instance of PostgresSaver from a connection string.
*
* @param {string} connString - The connection string to connect to the Postgres database.
* @param {PostgresSaverOptions} [options] - Optional configuration object.
* @returns {PostgresSaver} A new instance of PostgresSaver.
*
* @example
* const connString = "postgresql://user:password@localhost:5432/db";
* const checkpointer = PostgresSaver.fromConnString(connString, {
* schema: "custom_schema" // defaults to "public"
* });
* await checkpointer.setup();
*/
static fromConnString(
connString: string,
options: PostgresSaverOptions = {}
): PostgresSaver {
const pool = new Pool({ connectionString: connString });
return new PostgresSaver(pool);
return new PostgresSaver(pool, undefined, options);
}

/**
Expand All @@ -83,11 +123,13 @@ export class PostgresSaver extends BaseCheckpointSaver {
*/
async setup(): Promise<void> {
const client = await this.pool.connect();
const SCHEMA_TABLES = getTablesWithSchema(this.options.schema);
try {
await client.query(`CREATE SCHEMA IF NOT EXISTS ${this.options.schema}`);
let version = -1;
try {
const result = await client.query(
"SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
`SELECT v FROM ${SCHEMA_TABLES.checkpoint_migrations} ORDER BY v DESC LIMIT 1`
);
if (result.rows.length > 0) {
version = result.rows[0].v;
Expand All @@ -97,7 +139,7 @@ export class PostgresSaver extends BaseCheckpointSaver {
// Assume table doesn't exist if there's an error
if (
error?.message.includes(
'relation "checkpoint_migrations" does not exist'
`relation "${SCHEMA_TABLES.checkpoint_migrations}" does not exist`
)
) {
version = -1;
Expand All @@ -106,10 +148,11 @@ export class PostgresSaver extends BaseCheckpointSaver {
}
}

const MIGRATIONS = getMigrations(this.options.schema);
for (let v = version + 1; v < MIGRATIONS.length; v += 1) {
await client.query(MIGRATIONS[v]);
await client.query(
"INSERT INTO checkpoint_migrations (v) VALUES ($1)",
`INSERT INTO ${SCHEMA_TABLES.checkpoint_migrations} (v) VALUES ($1)`,
[v]
);
}
Expand Down Expand Up @@ -317,7 +360,10 @@ export class PostgresSaver extends BaseCheckpointSaver {
args = [thread_id, checkpoint_ns];
}

const result = await this.pool.query(SELECT_SQL + where, args);
const result = await this.pool.query(
this.SQL_STATEMENTS.SELECT_SQL + where,
args
);

const [row] = result.rows;

Expand Down Expand Up @@ -370,7 +416,7 @@ export class PostgresSaver extends BaseCheckpointSaver {
): AsyncGenerator<CheckpointTuple> {
const { filter, before, limit } = options ?? {};
const [where, args] = this._searchWhere(config, filter, before);
let query = `${SELECT_SQL}${where} ORDER BY checkpoint_id DESC`;
let query = `${this.SQL_STATEMENTS.SELECT_SQL}${where} ORDER BY checkpoint_id DESC`;
if (limit !== undefined) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
query += ` LIMIT ${parseInt(limit as any, 10)}`; // sanitize via parseInt, as limit could be an externally provided value
Expand Down Expand Up @@ -449,9 +495,12 @@ export class PostgresSaver extends BaseCheckpointSaver {
newVersions
);
for (const serializedBlob of serializedBlobs) {
await client.query(UPSERT_CHECKPOINT_BLOBS_SQL, serializedBlob);
await client.query(
this.SQL_STATEMENTS.UPSERT_CHECKPOINT_BLOBS_SQL,
serializedBlob
);
}
await client.query(UPSERT_CHECKPOINTS_SQL, [
await client.query(this.SQL_STATEMENTS.UPSERT_CHECKPOINTS_SQL, [
thread_id,
checkpoint_ns,
checkpoint.id,
Expand Down Expand Up @@ -483,8 +532,8 @@ export class PostgresSaver extends BaseCheckpointSaver {
taskId: string
): Promise<void> {
const query = writes.every((w) => w[0] in WRITES_IDX_MAP)
? UPSERT_CHECKPOINT_WRITES_SQL
: INSERT_CHECKPOINT_WRITES_SQL;
? this.SQL_STATEMENTS.UPSERT_CHECKPOINT_WRITES_SQL
: this.SQL_STATEMENTS.INSERT_CHECKPOINT_WRITES_SQL;

const dumpedWrites = this._dumpWrites(
config.configurable?.thread_id,
Expand Down
79 changes: 42 additions & 37 deletions libs/checkpoint-postgres/src/migrations.ts
Original file line number Diff line number Diff line change
@@ -1,40 +1,45 @@
import { getTablesWithSchema } from "./sql.js";

/**
* To add a new migration, add a new string to the MIGRATIONS list.
* To add a new migration, add a new string to the list returned by the getMigrations function.
* The position of the migration in the list is the version number.
*/
export const MIGRATIONS = [
`CREATE TABLE IF NOT EXISTS checkpoint_migrations (
v INTEGER PRIMARY KEY
);`,
`CREATE TABLE IF NOT EXISTS checkpoints (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint JSONB NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}',
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);`,
`CREATE TABLE IF NOT EXISTS checkpoint_blobs (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
channel TEXT NOT NULL,
version TEXT NOT NULL,
type TEXT NOT NULL,
blob BYTEA,
PRIMARY KEY (thread_id, checkpoint_ns, channel, version)
);`,
`CREATE TABLE IF NOT EXISTS checkpoint_writes (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
blob BYTEA NOT NULL,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);`,
"ALTER TABLE checkpoint_blobs ALTER COLUMN blob DROP not null;",
];
export const getMigrations = (schema: string) => {
const SCHEMA_TABLES = getTablesWithSchema(schema);
return [
`CREATE TABLE IF NOT EXISTS ${SCHEMA_TABLES.checkpoint_migrations} (
v INTEGER PRIMARY KEY
);`,
`CREATE TABLE IF NOT EXISTS ${SCHEMA_TABLES.checkpoints} (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
parent_checkpoint_id TEXT,
type TEXT,
checkpoint JSONB NOT NULL,
metadata JSONB NOT NULL DEFAULT '{}',
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id)
);`,
`CREATE TABLE IF NOT EXISTS ${SCHEMA_TABLES.checkpoint_blobs} (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
channel TEXT NOT NULL,
version TEXT NOT NULL,
type TEXT NOT NULL,
blob BYTEA,
PRIMARY KEY (thread_id, checkpoint_ns, channel, version)
);`,
`CREATE TABLE IF NOT EXISTS ${SCHEMA_TABLES.checkpoint_writes} (
thread_id TEXT NOT NULL,
checkpoint_ns TEXT NOT NULL DEFAULT '',
checkpoint_id TEXT NOT NULL,
task_id TEXT NOT NULL,
idx INTEGER NOT NULL,
channel TEXT NOT NULL,
type TEXT,
blob BYTEA NOT NULL,
PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx)
);`,
`ALTER TABLE ${SCHEMA_TABLES.checkpoint_blobs} ALTER COLUMN blob DROP not null;`,
];
};
Loading