1+ import logging
12from typing import Any , Callable , Literal , get_origin
23
4+ from litellm import ContextWindowExceededError
35from pydantic import BaseModel
46
57import dspy
68from dspy .primitives .program import Module
79from dspy .primitives .tool import Tool
810from dspy .signatures .signature import ensure_signature
911
12+ logger = logging .getLogger (__name__ )
13+
1014
1115class ReAct (Module ):
1216 def __init__ (self , signature , tools : list [Callable ], max_iters = 5 ):
@@ -32,15 +36,11 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
3236 ]
3337 )
3438
35- finish_desc = (
36- f"Signals that the final outputs, i.e. { outputs } , are now available and marks the task as complete."
37- )
38- finish_args = {} # k: v.annotation for k, v in signature.output_fields.items()}
3939 tools ["finish" ] = Tool (
4040 func = lambda ** kwargs : "Completed." ,
4141 name = "finish" ,
42- desc = finish_desc ,
43- args = finish_args ,
42+ desc = f"Signals that the final outputs, i.e. { outputs } , are now available and marks the task as complete." ,
43+ args = {} ,
4444 )
4545
4646 for idx , tool in enumerate (tools .values ()):
@@ -66,18 +66,15 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
6666 self .react = dspy .Predict (react_signature )
6767 self .extract = dspy .ChainOfThought (fallback_signature )
6868
69- def forward (self , ** input_args ):
70- def format (trajectory : dict [str , Any ], last_iteration : bool ):
71- adapter = dspy .settings .adapter or dspy .ChatAdapter ()
72- trajectory_signature = dspy .Signature (f"{ ', ' .join (trajectory .keys ())} -> x" )
73- return adapter .format_fields (trajectory_signature , trajectory , role = "user" )
69+ def _format_trajectory (self , trajectory : dict [str , Any ]):
70+ adapter = dspy .settings .adapter or dspy .ChatAdapter ()
71+ trajectory_signature = dspy .Signature (f"{ ', ' .join (trajectory .keys ())} -> x" )
72+ return adapter .format_fields (trajectory_signature , trajectory , role = "user" )
7473
74+ def forward (self , ** input_args ):
7575 trajectory = {}
7676 for idx in range (self .max_iters ):
77- pred = self .react (
78- ** input_args ,
79- trajectory = format (trajectory , last_iteration = (idx == self .max_iters - 1 )),
80- )
77+ pred = self ._call_with_potential_trajectory_truncation (self .react , trajectory , ** input_args )
8178
8279 trajectory [f"thought_{ idx } " ] = pred .next_thought
8380 trajectory [f"tool_name_{ idx } " ] = pred .next_tool_name
@@ -102,9 +99,38 @@ def format(trajectory: dict[str, Any], last_iteration: bool):
10299 if pred .next_tool_name == "finish" :
103100 break
104101
105- extract = self .extract ( ** input_args , trajectory = format ( trajectory , last_iteration = False ) )
102+ extract = self ._call_with_potential_trajectory_truncation ( self . extract , trajectory , ** input_args )
106103 return dspy .Prediction (trajectory = trajectory , ** extract )
107104
105+ def _call_with_potential_trajectory_truncation (self , module , trajectory , ** input_args ):
106+ while True :
107+ try :
108+ return module (
109+ ** input_args ,
110+ trajectory = self ._format_trajectory (trajectory ),
111+ )
112+ except ContextWindowExceededError :
113+ logger .warning ("Trajectory exceeded the context window, truncating the oldest tool call information." )
114+ trajectory = self .truncate_trajectory (trajectory )
115+
116+ def truncate_trajectory (self , trajectory ):
117+ """Truncates the trajectory so that it fits in the context window.
118+
119+ Users can override this method to implement their own truncation logic.
120+ """
121+ keys = list (trajectory .keys ())
122+ if len (keys ) < 4 :
123+ # Every tool call has 4 keys: thought, tool_name, tool_args, and observation.
124+ raise ValueError (
125+ "The trajectory is too long so your prompt exceeded the context window, but the trajectory cannot be "
126+ "truncated because it only has one tool call."
127+ )
128+
129+ for key in keys [:4 ]:
130+ trajectory .pop (key )
131+
132+ return trajectory
133+
108134
109135"""
110136Thoughts and Planned Improvements for dspy.ReAct.
0 commit comments