-
Notifications
You must be signed in to change notification settings - Fork 2.5k
feat(gepa): add tool description optimization for multi-agent systems #8928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 35 commits
6412a5d
cf0be4f
aa53fe2
045c6cf
c4f2041
260ca80
04f7e3d
f92e184
7178869
e34703b
3f05311
4df9ce5
4296ccf
ea1204a
48d5cd6
548d9b6
e61d0a1
5c95412
19d7717
9ce5fe4
91331d0
3418b59
b26d39a
2791b5c
8e63c62
7a9d2f3
cd0de57
67bb739
b3026a7
4e107aa
78547e7
7fa829b
da0e7bc
e51158d
776ab9b
ec6bb7b
b6cc67b
1206f38
333cbbf
a50552a
965b157
5ddc6d3
2269de5
17456f0
c884c18
82dee25
ca84b9d
2eb8986
9f37ac1
bd4cdac
0ad4077
ef5563e
1b10b65
675a0cd
4a4d209
d84842f
bb28f5f
a590e46
6aceaf5
7a5bf05
12b01ed
265896c
fe19dac
38dd7cb
7f05a73
0a6016d
a635768
e35603a
ecb3726
d3693c9
a086646
0cecb75
9592c50
76d7af5
ac66e05
3ec4ada
b679ba2
02aa151
d37e433
d8b7c66
f62a68e
e031409
a133545
b1e4f3d
7f81e88
f267ccc
28ceb70
d8275ef
deeb010
4bcc714
4b872d7
5129586
ebe4221
2133b0b
9c05b6a
ec9241b
46d8f5e
5d33fc6
b564029
13209f5
09990a6
33fc771
fa72fc0
2a15e56
59f23e5
68d7021
d99ba1d
3fd9a0a
7d64e7a
4b3ee18
3a5fb7f
0e75d8c
734fbdf
1fb15ba
da2f6d0
a942246
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import inspect | ||
| import json | ||
| import logging | ||
| import random | ||
| from dataclasses import dataclass | ||
|
|
@@ -9,6 +10,7 @@ | |
| from gepa.proposer.reflective_mutation.base import ReflectionComponentSelector | ||
|
|
||
| from dspy.clients.lm import LM | ||
| from dspy.predict.react import ReAct | ||
| from dspy.primitives import Example, Module, Prediction | ||
| from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, DSPyTrace, PredictorFeedbackFn, ScoreWithFeedback | ||
| from dspy.teleprompt.teleprompt import Teleprompter | ||
|
|
@@ -273,6 +275,11 @@ def metric( | |
| warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when | ||
| called with and without the pred_name. This flag (defaults to True) determines whether a warning is | ||
| raised if a mismatch in module-level and predictor-level score is detected. | ||
| optimize_react_components: Whether to optimize ReAct module components including react | ||
| instructions, extract instructions, tool descriptions, and tool argument descriptions. | ||
| When enabled, GEPA jointly optimizes all four components of ReAct modules. See the | ||
| [ReAct Component Optimization guide](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#react-component-optimization) | ||
| for details on when to use this feature and how it works. Default is False. | ||
| seed: The random seed to use for reproducibility. Default is 0. | ||
| gepa_kwargs: (Optional) provide additional kwargs to be passed to [gepa.optimize](https://github.com/gepa-ai/gepa/blob/main/src/gepa/api.py) method | ||
|
|
||
|
|
@@ -328,6 +335,7 @@ def __init__( | |
| wandb_init_kwargs: dict[str, Any] | None = None, | ||
| track_best_outputs: bool = False, | ||
| warn_on_score_mismatch: bool = True, | ||
| optimize_react_components: bool = False, | ||
| use_mlflow: bool = False, | ||
| # Reproducibility | ||
| seed: int | None = 0, | ||
|
|
@@ -390,6 +398,7 @@ def __init__( | |
| self.wandb_api_key = wandb_api_key | ||
| self.wandb_init_kwargs = wandb_init_kwargs | ||
| self.warn_on_score_mismatch = warn_on_score_mismatch | ||
| self.optimize_react_components = optimize_react_components | ||
| self.use_mlflow = use_mlflow | ||
|
|
||
| if track_best_outputs: | ||
|
|
@@ -518,11 +527,55 @@ def feedback_fn( | |
| rng=rng, | ||
| reflection_lm=self.reflection_lm, | ||
| custom_instruction_proposer=self.custom_instruction_proposer, | ||
| warn_on_score_mismatch=self.warn_on_score_mismatch | ||
| warn_on_score_mismatch=self.warn_on_score_mismatch, | ||
| optimize_react_components=self.optimize_react_components, | ||
| ) | ||
|
|
||
| # Instantiate GEPA with the simpler adapter-based API | ||
| base_program = {name: pred.signature.instructions for name, pred in student.named_predictors()} | ||
|
|
||
| if self.optimize_react_components: | ||
| for module_path, module in student.named_sub_modules(): | ||
| # Only process ReAct modules | ||
| if not isinstance(module, ReAct): | ||
| continue | ||
| prefix = module_path.removeprefix("self.") if module_path != "self" else "" | ||
|
|
||
| # Get first predictor name as module identifier | ||
| for pred_name, _ in module.named_predictors(): | ||
| comp_name = pred_name if not prefix else f"{prefix}.{pred_name}" | ||
| module_key = f"react_module:{comp_name.split('.')[0]}" if prefix else "react_module" | ||
|
|
||
| # Build JSON config with tool args for reflection | ||
| config = { | ||
| "react": module.react.signature.instructions, | ||
| "extract": module.extract.predict.signature.instructions, | ||
| "tools": { | ||
| tool_name: { | ||
| "desc": tool.desc, | ||
| "args": tool.args, | ||
| "arg_desc": tool.arg_desc or {} | ||
| } | ||
| for tool_name, tool in module.tools.items() | ||
| if tool_name != "finish" | ||
| } | ||
| } | ||
|
|
||
| # Replace predictor keys with module key and extract key to prevent duplicates | ||
| base_program.pop(comp_name, None) | ||
| extract_key = f"{prefix}.extract.predict" if prefix else "extract.predict" | ||
| base_program.pop(extract_key, None) | ||
| base_program[module_key] = json.dumps(config, indent=2) | ||
| break | ||
|
|
||
| # Log base_program keys for debugging | ||
| logger.info(f"Initialized base_program with {len(base_program)} components:") | ||
| for key in sorted(base_program.keys()): | ||
| if key.startswith("react_module"): | ||
| logger.info(f" {key}: <ReAct module JSON config>") | ||
|
||
| else: | ||
| logger.info(f" {key}: <instruction>") | ||
|
|
||
| gepa_result: GEPAResult = optimize( | ||
| seed_candidate=base_program, | ||
| trainset=trainset, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of engineering this towards
dspy.ReAct, I recommend covering tool calling in general:optimize_react_components=>optimize_tools. ReAct is just one way for tool calling agent, and it's quite common for users to make customizations, and we may create other tool calling agent architectures in the near future.