diff --git a/pyproject.toml b/pyproject.toml
index 8dcb37a5..412bef43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,8 @@ test = [
     "mypy==1.11.2",
     "pytest-asyncio==0.24.0",
     "pytest==8.3.3",
-    "pytest-cov==6.0.0"
+    "pytest-cov==6.0.0",
+    "langgraph==0.2.74"
 ]
 
 langgraph = ["langgraph-checkpoint"]
diff --git a/src/langchain_google_cloud_sql_pg/async_checkpoint.py b/src/langchain_google_cloud_sql_pg/async_checkpoint.py
index 9a4b8adf..e2f36b68 100644
--- a/src/langchain_google_cloud_sql_pg/async_checkpoint.py
+++ b/src/langchain_google_cloud_sql_pg/async_checkpoint.py
@@ -430,3 +430,88 @@ async def alist(
                     ),
                     pending_writes=self._load_writes(value["pending_writes"]),
                 )
+
+    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
+        """Asynchronously fetch a single checkpoint tuple for the thread named in the configuration.
+
+        Args:
+            config (RunnableConfig): Must contain configurable["thread_id"]; "checkpoint_ns" defaults to "". When get_checkpoint_id(config) yields an id that exact checkpoint is fetched, otherwise the row with the greatest checkpoint_id for the thread/namespace.
+
+        Returns:
+            Optional[CheckpointTuple]: The matching checkpoint with its metadata ({} when stored as NULL), parent config (None without a parent id) and pending writes, or None if no row matches.
+        """
+
+        SELECT = f"""
+        SELECT
+            thread_id,
+            checkpoint,
+            checkpoint_ns,
+            checkpoint_id,
+            parent_checkpoint_id,
+            metadata,
+            type,
+            (
+                SELECT array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx)
+                FROM "{self.schema_name}"."{self.table_name_writes}" cw
+                where cw.thread_id = c.thread_id
+                    AND cw.checkpoint_ns = c.checkpoint_ns
+                    AND cw.checkpoint_id = c.checkpoint_id
+            ) AS pending_writes,
+            (
+                SELECT array_agg(array[cw.type::bytea, cw.blob] order by cw.task_path, cw.task_id, cw.idx)
+                FROM "{self.schema_name}"."{self.table_name_writes}" cw
+                WHERE cw.thread_id = c.thread_id
+                    AND cw.checkpoint_ns = c.checkpoint_ns
+                    AND cw.checkpoint_id = c.parent_checkpoint_id
+                    AND cw.channel = '{TASKS}'
+            ) AS pending_sends
+        FROM "{self.schema_name}"."{self.table_name}" c
+        """  # NOTE(review): pending_sends is selected but never read in this method - confirm whether it should be applied to the checkpoint
+
+        thread_id = config["configurable"]["thread_id"]
+        checkpoint_id = get_checkpoint_id(config)  # falsy -> fall back to the newest checkpoint (else branch)
+        checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
+        if checkpoint_id:
+            args = {
+                "thread_id": thread_id,
+                "checkpoint_ns": checkpoint_ns,
+                "checkpoint_id": checkpoint_id,
+            }
+            where = "WHERE thread_id = :thread_id AND checkpoint_ns = :checkpoint_ns AND checkpoint_id = :checkpoint_id"
+        else:
+            args = {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}
+            where = "WHERE thread_id = :thread_id AND checkpoint_ns = :checkpoint_ns ORDER BY checkpoint_id DESC LIMIT 1"
+
+        async with self.pool.connect() as conn:
+            result = await conn.execute(text(SELECT + where), args)  # values travel as bound parameters
+            row = result.fetchone()
+            if not row:
+                return None
+            value = row._mapping  # lets the columns below be read by name
+            return CheckpointTuple(
+                config={
+                    "configurable": {
+                        "thread_id": value["thread_id"],
+                        "checkpoint_ns": value["checkpoint_ns"],
+                        "checkpoint_id": value["checkpoint_id"],
+                    }
+                },
+                checkpoint=self.serde.loads_typed((value["type"], value["checkpoint"])),
+                metadata=(
+                    self.jsonplus_serde.loads(value["metadata"])
+                    if value["metadata"] is not None
+                    else {}
+                ),
+                parent_config=(
+                    {
+                        "configurable": {
+                            "thread_id": value["thread_id"],
+                            "checkpoint_ns": value["checkpoint_ns"],
+                            "checkpoint_id": value["parent_checkpoint_id"],
+                        }
+                    }
+                    if value["parent_checkpoint_id"]
+                    else None
+                ),
+                pending_writes=self._load_writes(value["pending_writes"]),
+            )
diff --git a/src/langchain_google_cloud_sql_pg/engine.py b/src/langchain_google_cloud_sql_pg/engine.py
index f19fe9ba..0e779c99 100644
--- a/src/langchain_google_cloud_sql_pg/engine.py
+++ b/src/langchain_google_cloud_sql_pg/engine.py
@@ -40,7 +40,6 @@
 USER_AGENT = "langchain-google-cloud-sql-pg-python/" + __version__
 
 CHECKPOINTS_TABLE = "checkpoints"
-CHECKPOINT_WRITES_TABLE = "checkpoint_writes"
 
 
 async def _get_iam_principal_email(
@@ -759,6 +758,8 @@ async def _ainit_checkpoint_table(
         Args:
             schema_name (str): The schema name to store the checkpoint tables.
                 Default: "public".
+            table_name (str): The PgSQL database table name.
+                Default: "checkpoints".
 
         Returns:
             None
@@ -800,6 +801,8 @@ async def ainit_checkpoint_table(
         Args:
             schema_name (str): The schema name to store checkpoint tables.
                 Default: "public".
+            table_name (str): The PgSQL database table name.
+                Default: "checkpoints".
 
         Returns:
             None
@@ -819,6 +822,8 @@ def init_checkpoint_table(
         Args:
             schema_name (str): The schema name to store checkpoint tables.
                 Default: "public".
+            table_name (str): The PgSQL database table name.
+                Default: "checkpoints".
 
         Returns:
             None
diff --git a/tests/test_async_checkpoint.py b/tests/test_async_checkpoint.py
index 310c4098..42b727fb 100644
--- a/tests/test_async_checkpoint.py
+++ b/tests/test_async_checkpoint.py
@@ -39,6 +39,12 @@
     empty_checkpoint,
 )
 from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
+from langgraph.prebuilt import (
+    ToolNode,
+    ValidationNode,
+    create_react_agent,
+    tools_condition,
+)
 from sqlalchemy import text
 from sqlalchemy.engine.row import RowMapping
 
@@ -47,6 +53,7 @@
 
 write_config: RunnableConfig = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
 read_config: RunnableConfig = {"configurable": {"thread_id": "1"}}
+thread_agent_config: RunnableConfig = {"configurable": {"thread_id": "123"}}
 
 project_id = os.environ["PROJECT_ID"]
 region = os.environ["REGION"]
@@ -313,6 +320,80 @@ async def test_checkpoint_alist(
     } == {"", "inner"}
 
 
+class FakeToolCallingModel(BaseChatModel):
+    tool_calls: Optional[list[list[ToolCall]]] = None
+    index: int = 0
+    tool_style: Literal["openai", "anthropic"] = "openai"
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Return one AIMessage whose content is the "-"-joined string contents of messages, carrying the next canned tool calls."""
+        messages_string = "-".join(
+            [str(m.content) for m in messages if isinstance(m.content, str)]
+        )
+        tool_calls = (
+            self.tool_calls[self.index % len(self.tool_calls)]
+            if self.tool_calls
+            else []
+        )
+        message = AIMessage(
+            content=messages_string,
+            id=str(self.index),
+            tool_calls=tool_calls.copy(),
+        )
+        self.index += 1  # next call gets a new message id and the next tool-call set
+        return ChatResult(generations=[ChatGeneration(message=message)])
+
+    @property
+    def _llm_type(self) -> str:
+        return "fake-tool-call-model"
+
+
+@pytest.mark.asyncio
+async def test_checkpoint_aget_tuple(
+    checkpointer: AsyncPostgresSaver,
+) -> None:
+    # Adapted from the tests in https://github.com/langchain-ai/langgraph/blob/909190cede6a80bb94a2d4cfe7dedc49ef0d4127/libs/langgraph/tests/test_prebuilt.py
+    model = FakeToolCallingModel()
+
+    agent = create_react_agent(model, [], checkpointer=checkpointer)
+    inputs = [HumanMessage("hi?")]
+    response = await agent.ainvoke(
+        {"messages": inputs}, config=thread_agent_config, debug=True
+    )
+    expected_response = {"messages": inputs + [AIMessage(content="hi?", id="0")]}
+    assert response == expected_response
+
+    def _AnyIdHumanMessage(**kwargs: Any) -> HumanMessage:
+        """Create a HumanMessage whose id is an AnyStr() placeholder (presumably equal to any string id - confirm)."""
+        message = HumanMessage(**kwargs)
+        message.id = AnyStr()
+        return message
+
+    saved = await checkpointer.aget_tuple(thread_agent_config)  # no checkpoint id in config -> latest checkpoint
+    assert saved is not None
+    assert saved.checkpoint["channel_values"] == {
+        "messages": [
+            _AnyIdHumanMessage(content="hi?"),
+            AIMessage(content="hi?", id="0"),
+        ],
+        "agent": "agent",
+    }
+    assert saved.metadata == {
+        "parents": {},
+        "source": "loop",
+        "writes": {"agent": {"messages": [AIMessage(content="hi?", id="0")]}},
+        "step": 1,
+        "thread_id": "123",
+    }
+    assert saved.pending_writes == []
+
+
 @pytest.mark.asyncio
 async def test_metadata(
     checkpointer: AsyncPostgresSaver,
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 670fc6ce..04634358 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -28,10 +28,6 @@
 from sqlalchemy.pool import NullPool
 
 from langchain_google_cloud_sql_pg import Column, PostgresEngine
-from langchain_google_cloud_sql_pg.engine import (
-    CHECKPOINT_WRITES_TABLE,
-    CHECKPOINTS_TABLE,
-)
 
 DEFAULT_TABLE = "test_table" + str(uuid.uuid4()).replace("-", "_")
 CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")