Skip to content

Commit 6a6a25e

Browse files
Add context truncation logic for react (#7780)
1 parent 83ba5c4 commit 6a6a25e

File tree

3 files changed

+90
-16
lines changed

3 files changed

+90
-16
lines changed

dspy/adapters/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from abc import ABC, abstractmethod
22

3+
from litellm import ContextWindowExceededError
4+
35
from dspy.utils.callback import with_callbacks
46

7+
58
class Adapter(ABC):
69
def __init__(self, callbacks=None):
710
self.callbacks = callbacks or []
@@ -40,6 +43,9 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):
4043
return values
4144

4245
except Exception as e:
46+
if isinstance(e, ContextWindowExceededError):
47+
# On context window exceeded error, we don't want to retry with a different adapter.
48+
raise e
4349
from .json_adapter import JSONAdapter
4450
if not isinstance(self, JSONAdapter):
4551
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)

dspy/predict/react.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
import logging
12
from typing import Any, Callable, Literal, get_origin
23

4+
from litellm import ContextWindowExceededError
35
from pydantic import BaseModel
46

57
import dspy
68
from dspy.primitives.program import Module
79
from dspy.primitives.tool import Tool
810
from dspy.signatures.signature import ensure_signature
911

12+
logger = logging.getLogger(__name__)
13+
1014

1115
class ReAct(Module):
1216
def __init__(self, signature, tools: list[Callable], max_iters=5):
@@ -32,15 +36,11 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
3236
]
3337
)
3438

35-
finish_desc = (
36-
f"Signals that the final outputs, i.e. {outputs}, are now available and marks the task as complete."
37-
)
38-
finish_args = {} # k: v.annotation for k, v in signature.output_fields.items()}
3939
tools["finish"] = Tool(
4040
func=lambda **kwargs: "Completed.",
4141
name="finish",
42-
desc=finish_desc,
43-
args=finish_args,
42+
desc=f"Signals that the final outputs, i.e. {outputs}, are now available and marks the task as complete.",
43+
args={},
4444
)
4545

4646
for idx, tool in enumerate(tools.values()):
@@ -66,18 +66,15 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
6666
self.react = dspy.Predict(react_signature)
6767
self.extract = dspy.ChainOfThought(fallback_signature)
6868

69-
def forward(self, **input_args):
70-
def format(trajectory: dict[str, Any], last_iteration: bool):
71-
adapter = dspy.settings.adapter or dspy.ChatAdapter()
72-
trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x")
73-
return adapter.format_fields(trajectory_signature, trajectory, role="user")
69+
def _format_trajectory(self, trajectory: dict[str, Any]):
70+
adapter = dspy.settings.adapter or dspy.ChatAdapter()
71+
trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x")
72+
return adapter.format_fields(trajectory_signature, trajectory, role="user")
7473

74+
def forward(self, **input_args):
7575
trajectory = {}
7676
for idx in range(self.max_iters):
77-
pred = self.react(
78-
**input_args,
79-
trajectory=format(trajectory, last_iteration=(idx == self.max_iters - 1)),
80-
)
77+
pred = self._call_with_potential_trajectory_truncation(self.react, trajectory, **input_args)
8178

8279
trajectory[f"thought_{idx}"] = pred.next_thought
8380
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
@@ -102,9 +99,38 @@ def format(trajectory: dict[str, Any], last_iteration: bool):
10299
if pred.next_tool_name == "finish":
103100
break
104101

105-
extract = self.extract(**input_args, trajectory=format(trajectory, last_iteration=False))
102+
extract = self._call_with_potential_trajectory_truncation(self.extract, trajectory, **input_args)
106103
return dspy.Prediction(trajectory=trajectory, **extract)
107104

105+
def _call_with_potential_trajectory_truncation(self, module, trajectory, **input_args):
106+
while True:
107+
try:
108+
return module(
109+
**input_args,
110+
trajectory=self._format_trajectory(trajectory),
111+
)
112+
except ContextWindowExceededError:
113+
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
114+
trajectory = self.truncate_trajectory(trajectory)
115+
116+
def truncate_trajectory(self, trajectory):
117+
"""Truncates the trajectory so that it fits in the context window.
118+
119+
Users can override this method to implement their own truncation logic.
120+
"""
121+
keys = list(trajectory.keys())
122+
if len(keys) < 4:
123+
# Every tool call has 4 keys: thought, tool_name, tool_args, and observation.
124+
raise ValueError(
125+
"The trajectory is too long so your prompt exceeded the context window, but the trajectory cannot be "
126+
"truncated because it only has one tool call."
127+
)
128+
129+
for key in keys[:4]:
130+
trajectory.pop(key)
131+
132+
return trajectory
133+
108134

109135
"""
110136
Thoughts and Planned Improvements for dspy.ReAct.

tests/predict/test_react.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import dspy
66
from dspy.predict import react
77
from dspy.utils.dummies import DummyLM, dummy_rm
8+
import litellm
89

910
# def test_example_no_tools():
1011
# # Create a simple dataset which the model will use with the Retrieve tool.
@@ -260,3 +261,44 @@ def foo(a, b):
260261
"observation_1": "Completed.",
261262
}
262263
assert outputs.trajectory == expected_trajectory
264+
265+
266+
def test_trajectory_truncation():
267+
# Create a simple tool for testing
268+
def echo(text: str) -> str:
269+
return f"Echoed: {text}"
270+
271+
# Create ReAct instance with our echo tool
272+
react = dspy.ReAct("input_text -> output_text", tools=[echo])
273+
274+
# Mock react.react to simulate multiple tool calls
275+
call_count = 0
276+
277+
def mock_react(**kwargs):
278+
nonlocal call_count
279+
call_count += 1
280+
281+
if call_count < 3:
282+
# First 2 calls use the echo tool
283+
return dspy.Prediction(
284+
next_thought=f"Thought {call_count}",
285+
next_tool_name="echo",
286+
next_tool_args={"text": f"Text {call_count}"},
287+
)
288+
elif call_count == 3:
289+
# The 3rd call raises context window exceeded error
290+
raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider")
291+
else:
292+
# The 4th call finishes
293+
return dspy.Prediction(next_thought="Final thought", next_tool_name="finish", next_tool_args={})
294+
295+
react.react = mock_react
296+
react.extract = lambda **kwargs: dspy.Prediction(output_text="Final output")
297+
298+
# Call forward and get the result
299+
result = react(input_text="test input")
300+
301+
# Verify that older entries in the trajectory were truncated
302+
assert "thought_0" not in result.trajectory
303+
assert "thought_2" in result.trajectory
304+
assert result.output_text == "Final output"

0 commit comments

Comments
 (0)