
Commit 045c6cf

feat(gepa): implement tool-specific proposer for tool descriptions
- Add ToolProposer with GenerateImprovedToolDescription signature
- Implement routing logic to separate tools from signatures
- Tools use ToolProposer, signatures use custom or parent default
- Backward compatible: preserves existing custom_instruction_proposer behavior
- Add test verifying routing splits components correctly
1 parent aa53fe2 commit 045c6cf
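
The routing described in the bullets above rests on a naming convention: components whose names carry a `tool:` prefix are treated as tool descriptions, everything else is treated as a signature instruction. A minimal, self-contained sketch of that split (the component names are illustrative, not taken from a real program):

```python
# Sketch of the component routing introduced by this commit.
# Component names here are made up for illustration.
components_to_update = ["react", "tool:calculator", "tool:search"]

# Tool components are identified purely by the "tool:" name prefix.
tool_components = [c for c in components_to_update if c.startswith("tool:")]
sig_components = [c for c in components_to_update if not c.startswith("tool:")]

assert tool_components == ["tool:calculator", "tool:search"]
assert sig_components == ["react"]

# Each group goes to its own proposer; the results are merged into one mapping,
# roughly: new_texts = {**signature_proposer(...), **ToolProposer()(...)}
```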

File tree (3 files changed: +279 −18 lines)

- dspy/teleprompt/gepa/gepa_utils.py
- dspy/teleprompt/gepa/instruction_proposal.py
- tests/teleprompt/test_gepa_tool_optimization.py

dspy/teleprompt/gepa/gepa_utils.py

Lines changed: 72 additions & 18 deletions
```diff
@@ -15,13 +15,15 @@
 
 logger = logging.getLogger(__name__)
 
+
 class LoggerAdapter:
     def __init__(self, logger: logging.Logger):
         self.logger = logger
 
     def log(self, x: str):
         self.logger.info(x)
 
+
 DSPyTrace = list[tuple[Any, dict[str, Any], Prediction]]
 
 
@@ -31,15 +33,17 @@ class ReflectiveExample(TypedDict):
 
     Each example contains the predictor inputs, generated outputs, and feedback from evaluation.
     """
-    Inputs: dict[str, Any]  # Predictor inputs (may include str, dspy.Image, etc.)
-    Generated_Outputs: dict[str, Any] | str  # Success: dict with output fields, Failure: error message string
-    Feedback: str  # Always a string - from metric function or parsing error message
+
+    Inputs: dict[str, Any]  # Predictor inputs (may include str, dspy.Image, etc.)
+    Generated_Outputs: dict[str, Any] | str  # Success: dict with output fields, Failure: error message string
+    Feedback: str  # Always a string - from metric function or parsing error message
 
 
 class ScoreWithFeedback(Prediction):
     score: float
     feedback: str
 
+
 class PredictorFeedbackFn(Protocol):
     def __call__(
         predictor_output: dict[str, Any],
@@ -64,6 +68,7 @@ def __call__(
         """
         ...
 
+
 class DspyAdapter(GEPAAdapter[Example, TraceData, Prediction]):
     def __init__(
         self,
@@ -91,36 +96,80 @@ def __init__(
         self.warn_on_score_mismatch = warn_on_score_mismatch
         self.optimize_tool_descriptions = optimize_tool_descriptions
 
-        if self.custom_instruction_proposer is not None:
-            # We are only overriding the propose_new_texts method when a custom
-            # instruction proposer is provided. Otherwise, we use the GEPA
-            # default propose_new_texts.
+        if self.optimize_tool_descriptions or self.custom_instruction_proposer is not None:
+            # Set up combined proposer for tool optimization and/or custom instruction proposer.
+            # This routes components to appropriate proposers based on type:
+            #   - Signatures -> custom_instruction_proposer (if provided) OR parent default
+            #   - Tools -> ToolProposer (if optimize_tool_descriptions=True)
 
-            def custom_propose_new_texts(
+            # Determine which proposer handles signatures
+            if self.custom_instruction_proposer is not None:
+                signature_proposer = self.custom_instruction_proposer
+            else:
+                signature_proposer = super().propose_new_texts
+
+            def propose_new_texts(
                 candidate: dict[str, str],
                 reflective_dataset: dict[str, list[dict[str, Any]]],
-                components_to_update: list[str]
+                components_to_update: list[str],
             ) -> dict[str, str]:
+                """Propose new texts for both signatures and tools.
+
+                Splits components by type (tool: prefix vs signatures), calls appropriate
+                proposers, and merges results. Handles reflection_lm context if provided.
+                """
+                # Split by component type if tool optimization enabled
+                if self.optimize_tool_descriptions:
+                    tool_components = [c for c in components_to_update if c.startswith("tool:")]
+                    sig_components = [c for c in components_to_update if not c.startswith("tool:")]
+                else:
+                    tool_components = []
+                    sig_components = components_to_update
+
+                # Apply reflection_lm context to all proposer calls if provided
                 if self.reflection_lm is not None:
                     with dspy.context(lm=self.reflection_lm):
-                        return self.custom_instruction_proposer(
+                        sig_texts = signature_proposer(
                             candidate=candidate,
                             reflective_dataset=reflective_dataset,
-                            components_to_update=components_to_update
+                            components_to_update=sig_components,
                        )
+
+                        if tool_components:
+                            from .instruction_proposal import ToolProposer
+
+                            tool_texts = ToolProposer()(
+                                candidate=candidate,
+                                reflective_dataset=reflective_dataset,
+                                components_to_update=tool_components,
+                            )
+                            return {**sig_texts, **tool_texts}
+                        else:
+                            return sig_texts
                 else:
-                    return self.custom_instruction_proposer(
+                    sig_texts = signature_proposer(
                         candidate=candidate,
                         reflective_dataset=reflective_dataset,
-                        components_to_update=components_to_update
+                        components_to_update=sig_components,
                     )
 
-            self.propose_new_texts = custom_propose_new_texts
+                    if tool_components:
+                        from .instruction_proposal import ToolProposer
+
+                        tool_texts = ToolProposer()(
+                            candidate=candidate,
+                            reflective_dataset=reflective_dataset,
+                            components_to_update=tool_components,
+                        )
+                        return {**sig_texts, **tool_texts}
+                    else:
+                        return sig_texts
+
+            self.propose_new_texts = propose_new_texts
 
         # Cache predictor names/signatures
         self.named_predictors = list(self.student.named_predictors())
 
-
     def build_program(self, candidate: dict[str, str]):
         new_prog = self.student.deepcopy()
         for name, pred in new_prog.named_predictors():
@@ -176,16 +225,19 @@ def evaluate(self, batch, candidate, capture_traces=False):
             return_all_scores=True,
             failure_score=self.failure_score,
             provide_traceback=True,
-            max_errors=len(batch) * 100
+            max_errors=len(batch) * 100,
         )
         res = evaluator(program)
         outputs = [r[1] for r in res.results]
         scores = [r[2] for r in res.results]
         scores = [s["score"] if hasattr(s, "score") else s for s in scores]
         return EvaluationBatch(outputs=outputs, scores=scores, trajectories=None)
 
-    def make_reflective_dataset(self, candidate, eval_batch, components_to_update) -> dict[str, list[ReflectiveExample]]:
+    def make_reflective_dataset(
+        self, candidate, eval_batch, components_to_update
+    ) -> dict[str, list[ReflectiveExample]]:
         from dspy.teleprompt.bootstrap_trace import FailedPrediction
+
         program = self.build_program(candidate)
 
         ret_d: dict[str, list[ReflectiveExample]] = {}
@@ -284,7 +336,9 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) -
                     d["Feedback"] = fb["feedback"]
                     if fb["score"] != module_score:
                         if self.warn_on_score_mismatch:
-                            logger.warning("The score returned by the metric with pred_name is different from the overall metric score. This can indicate 2 things: Either the metric is non-deterministic (e.g., LLM-as-judge, Semantic score, etc.) or the metric returned a score specific to pred_name that differs from the module level score. Currently, GEPA does not support predictor level scoring (support coming soon), and only requires a feedback text to be provided, which can be specific to the predictor or program level. GEPA will ignore the differing score returned, and instead use module level score. You can safely ignore this warning if using a semantic metric, however, if this mismatch is caused due to predictor scoring, please return module-level scores. To disable this warning, set warn_on_score_mismatch=False.")
+                            logger.warning(
+                                "The score returned by the metric with pred_name is different from the overall metric score. This can indicate 2 things: Either the metric is non-deterministic (e.g., LLM-as-judge, Semantic score, etc.) or the metric returned a score specific to pred_name that differs from the module level score. Currently, GEPA does not support predictor level scoring (support coming soon), and only requires a feedback text to be provided, which can be specific to the predictor or program level. GEPA will ignore the differing score returned, and instead use module level score. You can safely ignore this warning if using a semantic metric, however, if this mismatch is caused due to predictor scoring, please return module-level scores. To disable this warning, set warn_on_score_mismatch=False."
+                            )
                            self.warn_on_score_mismatch = False
                        fb["score"] = module_score
```
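
To see how the new wiring is meant to be used, here is a condensed sketch of constructing the adapter with tool optimization enabled. The constructor arguments mirror those used in the test added by this commit; `calculator` and `simple_metric` are placeholder stand-ins, not part of the diff above.

```python
import dspy
from dspy.teleprompt.gepa.gepa_utils import DspyAdapter


def calculator(expression: str) -> str:
    """Toy tool used only for illustration."""
    return str(eval(expression))  # fine for a sketch; never eval untrusted input


def simple_metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    """Placeholder GEPA-style metric that always returns a perfect score."""
    return 1.0


calc_tool = dspy.Tool(calculator, name="calculator", desc="Does arithmetic")
react = dspy.ReAct("question -> answer", tools=[calc_tool])

adapter = DspyAdapter(
    student_module=react,
    metric_fn=simple_metric,
    feedback_map={},
    failure_score=0.0,
    optimize_tool_descriptions=True,  # enables the tool/signature routing shown above
    reflection_lm=None,               # or a dspy.LM to run proposals under a reflection model
)

# The adapter now owns a propose_new_texts that splits "tool:*" components from
# signature components and merges the two proposers' outputs into a single dict.
assert hasattr(adapter, "propose_new_texts")
```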

dspy/teleprompt/gepa/instruction_proposal.py

Lines changed: 138 additions & 0 deletions
```diff
@@ -310,3 +310,141 @@ def __call__(
             updated_components[component_name] = new_instruction
 
         return updated_components
+
+
+class GenerateImprovedToolDescriptionFromFeedback(dspy.Signature):
+    """I provided an assistant with the following description for a tool:
+    ```
+    <current_tool_description>
+    ```
+
+    This tool is available to the assistant. The following are examples of task inputs provided to the assistant, the assistant's decisions about which tools to use, and feedback on whether those decisions were correct:
+    ```
+    <examples_with_feedback>
+    ```
+
+    Your task is to write a better description for this tool.
+
+    Read the examples carefully and identify patterns in when the tool was used successfully versus when it was misused or overlooked. Identify any domain-specific information about the tool's capabilities or appropriate usage that may not be available to the assistant in the future. The assistant may have developed effective patterns for tool selection - if so, ensure the tool description supports those patterns.
+
+    Provide the new tool description within ``` blocks."""
+
+    current_tool_description = dspy.InputField(desc="The current description of the tool")
+    examples_with_feedback = dspy.InputField(desc="Examples showing tool usage decisions and feedback on correctness")
+
+    improved_tool_description = dspy.OutputField(
+        desc="An improved description that helps with tool selection decisions"
+    )
+
+
+class SingleComponentToolProposer(dspy.Module):
+    """dspy.Module for proposing improved tool descriptions based on feedback."""
+
+    def __init__(self):
+        super().__init__()
+        self.propose_description = dspy.Predict(GenerateImprovedToolDescriptionFromFeedback)
+
+    def forward(self, current_tool_description: str, reflective_dataset: list[ReflectiveExample]) -> str:
+        """Generate an improved tool description based on current description and feedback examples.
+
+        Args:
+            current_tool_description: The current description of the tool
+            reflective_dataset: List of examples with inputs, outputs, and feedback
+
+        Returns:
+            str: Improved tool description text
+        """
+        # Reuse formatting from SingleComponentMultiModalProposer
+        formatted_examples, _ = self._format_examples_for_instruction_generation(reflective_dataset)
+
+        result = self.propose_description(
+            current_tool_description=current_tool_description, examples_with_feedback=formatted_examples
+        )
+
+        return result.improved_tool_description
+
+    def _format_examples_for_instruction_generation(
+        self, reflective_dataset: list[ReflectiveExample]
+    ) -> tuple[str, dict[int, list[Type]]]:
+        """Format examples using GEPA's markdown structure.
+
+        Returns:
+            tuple: (formatted_text, image_map) where image_map is always empty for tools
+        """
+
+        def render_value(value, level=3):
+            if isinstance(value, dict):
+                s = ""
+                for k, v in value.items():
+                    s += f"{'#' * level} {k}\n"
+                    s += render_value(v, min(level + 1, 6))
+                if not value:
+                    s += "\n"
+                return s
+            elif isinstance(value, (list, tuple)):
+                s = ""
+                for i, item in enumerate(value):
+                    s += f"{'#' * level} Item {i + 1}\n"
+                    s += render_value(item, min(level + 1, 6))
+                if not value:
+                    s += "\n"
+                return s
+            else:
+                return f"{str(value).strip()}\n\n"
+
+        def convert_sample_to_markdown(sample, example_num):
+            s = f"# Example {example_num}\n"
+            for key, val in sample.items():
+                s += f"## {key}\n"
+                s += render_value(val, level=3)
+            return s
+
+        formatted_parts = []
+        for i, example_data in enumerate(reflective_dataset):
+            formatted_example = convert_sample_to_markdown(example_data, i + 1)
+            formatted_parts.append(formatted_example)
+
+        formatted_text = "\n\n".join(formatted_parts)
+        return formatted_text, {}
+
+
+class ToolProposer(ProposalFn):
+    """GEPA-compatible tool description proposer.
+
+    This class handles tool description optimization during GEPA optimization by using
+    a single-component proposer for each tool that needs to be updated.
+    """
+
+    def __init__(self):
+        self.single_proposer = SingleComponentToolProposer()
+
+    def __call__(
+        self,
+        candidate: dict[str, str],
+        reflective_dataset: dict[str, list[ReflectiveExample]],
+        components_to_update: list[str],
+    ) -> dict[str, str]:
+        """GEPA-compatible proposal function.
+
+        Args:
+            candidate: Current component name -> description mapping
+            reflective_dataset: Component name -> list of reflective examples
+            components_to_update: List of component names to update
+
+        Returns:
+            dict: Component name -> new description mapping
+        """
+        updated_components = {}
+
+        for component_name in components_to_update:
+            if component_name in candidate and component_name in reflective_dataset:
+                current_description = candidate[component_name]
+                component_reflective_data = reflective_dataset[component_name]
+
+                new_description = self.single_proposer(
+                    current_tool_description=current_description, reflective_dataset=component_reflective_data
+                )
+
+                updated_components[component_name] = new_description
+
+        return updated_components
```
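
The new ToolProposer can also be exercised on its own with a hand-built reflective dataset. A small sketch, assuming an LM has already been configured for the proposal call (the model name, the `tool:search` component, and the example contents are all illustrative):

```python
import dspy
from dspy.teleprompt.gepa.instruction_proposal import ToolProposer

# Illustrative model choice; any dspy-supported LM works here.
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

# Candidate maps component names to their current text; tools use the "tool:" prefix.
candidate = {"tool:search": "Searches the web."}

# Reflective examples follow the Inputs / Generated_Outputs / Feedback shape
# produced by make_reflective_dataset.
reflective_dataset = {
    "tool:search": [
        {
            "Inputs": {"question": "Who wrote Dune?"},
            "Generated_Outputs": {"answer": "Frank Herbert"},
            "Feedback": "Correct, but the assistant nearly answered without using the search tool.",
        }
    ]
}

new_texts = ToolProposer()(
    candidate=candidate,
    reflective_dataset=reflective_dataset,
    components_to_update=["tool:search"],
)
print(new_texts["tool:search"])  # the proposed, improved tool description
```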

tests/teleprompt/test_gepa_tool_optimization.py

Lines changed: 69 additions & 0 deletions
```diff
@@ -154,3 +154,72 @@ def forward(self, question):
     assert "search" in optimized.subagent.tools
     assert "calculator" in optimized.main_agent.tools
     assert "spawn_subagent" in optimized.main_agent.tools
+
+
+def test_tool_and_signature_optimization_with_proposer_routing():
+    """Test that routing logic correctly splits tools and signatures."""
+    from unittest.mock import Mock, patch
+
+    from dspy.teleprompt.gepa.gepa_utils import DspyAdapter
+
+    # Create module with BOTH signature and tools
+    calc_tool = dspy.Tool(calculator, name="calculator", desc="Original calculator description")
+    react = dspy.ReAct("question -> answer", tools=[calc_tool])
+
+    # Create adapter with tool optimization enabled
+    adapter = DspyAdapter(
+        student_module=react,
+        metric_fn=simple_metric,
+        feedback_map={},
+        failure_score=0.0,
+        optimize_tool_descriptions=True,
+        reflection_lm=None,
+    )
+
+    # Verify propose_new_texts was created
+    assert hasattr(adapter, "propose_new_texts"), "Routing logic should have set propose_new_texts"
+
+    # Mock the ToolProposer to verify it gets called with tools only
+    mock_tool_proposer_instance = Mock()
+    mock_tool_proposer_instance.return_value = {"tool:calculator": "Improved calculator description"}
+
+    mock_tool_proposer_class = Mock(return_value=mock_tool_proposer_instance)
+
+    # Mock parent propose_new_texts to verify it gets called with signatures only
+    mock_parent_propose = Mock(return_value={"react": "Improved signature instruction"})
+
+    with patch("dspy.teleprompt.gepa.instruction_proposal.ToolProposer", mock_tool_proposer_class):
+        with patch.object(adapter.__class__.__bases__[0], "propose_new_texts", mock_parent_propose, create=True):
+            # Rebuild adapter to pick up mocked parent
+            adapter_with_mock = DspyAdapter(
+                student_module=react,
+                metric_fn=simple_metric,
+                feedback_map={},
+                failure_score=0.0,
+                optimize_tool_descriptions=True,
+                reflection_lm=None,
+            )
+
+            candidate = {
+                "react": "Original signature",
+                "tool:calculator": "Original tool desc",
+            }
+
+            reflective_dataset = {
+                "react": [{"input": "test"}],
+                "tool:calculator": [{"input": "calc"}],
+            }
+
+            components = ["react", "tool:calculator"]
+
+            result = adapter_with_mock.propose_new_texts(candidate, reflective_dataset, components)
+
+            # Verify routing: ToolProposer was called with tools only
+            assert mock_tool_proposer_instance.called, "ToolProposer should have been called"
+            tool_call_args = mock_tool_proposer_instance.call_args[1]
+            assert "tool:calculator" in tool_call_args["components_to_update"]
+            assert "react" not in tool_call_args["components_to_update"]
+
+            # Verify both components in result
+            assert "react" in result
+            assert "tool:calculator" in result
```
