1+ import json
12import logging
23import random
3- import json
4- from collections import defaultdict
5- from copy import deepcopy
64from typing import Any , Callable , Protocol , TypedDict
7- from dspy . predict . react import ReAct
5+
86from gepa import EvaluationBatch , GEPAAdapter
97from gepa .core .adapter import ProposalFn
108
1311from dspy .adapters .types import History
1412from dspy .adapters .types .base_type import Type
1513from dspy .evaluate import Evaluate
14+ from dspy .predict .react import ReAct
1615from dspy .primitives import Example , Prediction
1716from dspy .teleprompt .bootstrap_trace import TraceData
1817
@@ -137,7 +136,7 @@ def default_instruction_proposer(
137136 react_module_proposer = None
138137 if self .optimize_tool_descriptions :
139138 from .instruction_proposal import ReActModuleProposer
140-
139+
141140 react_module_proposer = ReActModuleProposer ()
142141
143142 def propose_component_texts (
@@ -160,7 +159,7 @@ def propose_component_texts(
160159 reflective_dataset = reflective_dataset ,
161160 components_to_update = components_to_update ,
162161 )
163-
162+
164163 # Otherwise, route to appropriate proposers
165164 # Separate react_module components from regular instruction components
166165 react_module_components = [c for c in components_to_update if c .startswith ("react_module" )]
@@ -188,7 +187,7 @@ def propose_component_texts(
188187 )
189188 )
190189
191- # Handle ReAct module components
190+ # Handle ReAct module components
192191 if react_module_components :
193192 logger .debug (f"Routing { len (react_module_components )} react_module components to react_module_proposer" )
194193 if self .reflection_lm is not None :
@@ -220,60 +219,60 @@ def propose_component_texts(
220219
221220 def build_program (self , candidate : dict [str , str ]):
222221 new_prog = self .student .deepcopy ()
223-
222+
224223 # Apply regular predictor instructions
225224 for name , pred in new_prog .named_predictors ():
226225 if name in candidate :
227226 pred .signature = pred .signature .with_instructions (candidate [name ])
228227
229228 # Apply ReAct module updates (JSON configs for ReAct modules: react, extract, tools)
230229 if self .optimize_tool_descriptions :
231-
230+
232231 for module_path , module in new_prog .named_sub_modules ():
233232 # Only process ReAct modules
234233 if not isinstance (module , ReAct ):
235234 continue
236-
235+
237236 # Build module key
238237 prefix = module_path .removeprefix ("self." ) if module_path != "self" else ""
239238 module_key = "react_module" if prefix == "" else f"react_module:{ prefix } "
240-
239+
241240 # Check if this module was optimized
242241 if module_key not in candidate :
243242 continue
244-
243+
245244 # Deserialize JSON containing optimized module configuration
246245 try :
247246 module_config = json .loads (candidate [module_key ])
248247 logger .debug (f"Applying optimized module config to { module_key } " )
249-
248+
250249 # Apply react instruction
251250 if "react" in module_config :
252251 module .react .signature = module .react .signature .with_instructions (module_config ["react" ])
253- logger .debug (f " Updated react instruction" )
254-
252+ logger .debug (" Updated react instruction" )
253+
255254 # Apply extract instruction
256255 if "extract" in module_config :
257256 module .extract .predict .signature = module .extract .predict .signature .with_instructions (module_config ["extract" ])
258- logger .debug (f " Updated extract instruction" )
259-
257+ logger .debug (" Updated extract instruction" )
258+
260259 # Apply tool descriptions
261260 if "tools" in module_config :
262261 for tool_name , tool_config in module_config ["tools" ].items ():
263262 tool = module .tools [tool_name ]
264-
263+
265264 # Update tool description
266265 if tool_config .get ("desc" ):
267266 tool .desc = tool_config ["desc" ]
268267 logger .debug (f" Updated tool '{ tool_name } ' description" )
269-
268+
270269 # Update tool arg descriptions
271270 arg_desc = tool_config .get ("arg_desc" )
272271 if arg_desc :
273272 tool .arg_desc = tool .arg_desc or {}
274273 tool .arg_desc .update (arg_desc )
275274 logger .debug (f" Updated tool '{ tool_name } ' arg descriptions: { list (arg_desc .keys ())} " )
276-
275+
277276 except json .JSONDecodeError as e :
278277 logger .error (f"Failed to parse JSON config for { module_key } : { e } " )
279278 raise
@@ -341,14 +340,14 @@ def make_reflective_dataset(
341340
342341 for pred_name in components_to_update :
343342 logger .info (f"Processing component: { pred_name } " )
344-
343+
345344 # Handle ReAct module components - use extract predictor for final outputs
346345 if pred_name .startswith ("react_module" ):
347346 module_name = pred_name .replace ("react_module:" , "" ) if ":" in pred_name else None
348347 react_module = getattr (program , module_name ) if module_name else program
349348 module = react_module .extract .predict
350349 logger .debug (f" ReAct module detected: using { module_name or 'top-level' } .extract for final outputs" )
351-
350+
352351 # Regular predictor - find by name
353352 else :
354353 module = None
@@ -449,7 +448,7 @@ def make_reflective_dataset(
449448 actual_pred_name = pred_name .split (":" , 1 )[1 ] + ".react" if ":" in pred_name else "react"
450449 else :
451450 actual_pred_name = pred_name
452-
451+
453452 feedback_fn = self .feedback_map [actual_pred_name ]
454453 fb = feedback_fn (
455454 predictor_output = outputs ,
@@ -461,11 +460,12 @@ def make_reflective_dataset(
461460 d ["Feedback" ] = fb ["feedback" ]
462461 if fb ["score" ] != module_score :
463462 if self .warn_on_score_mismatch :
463+ logger .warning ("The score returned by the metric with pred_name is different from the overall metric score. This can indicate 2 things: Either the metric is non-deterministic (e.g., LLM-as-judge, Semantic score, etc.) or the metric returned a score specific to pred_name that differs from the module level score. Currently, GEPA does not support predictor level scoring (support coming soon), and only requires a feedback text to be provided, which can be specific to the predictor or program level. GEPA will ignore the differing score returned, and instead use module level score. You can safely ignore this warning if using a semantic metric, however, if this mismatch is caused due to predictor scoring, please return module-level scores. To disable this warning, set warn_on_score_mismatch=False." )
464464 self .warn_on_score_mismatch = False
465465 fb ["score" ] = module_score
466466
467467 items .append (d )
468-
468+
469469 # Log exact reflective example that reflection LM will see
470470 if pred_name .startswith ("react_module" ) and len (items ) == 1 :
471471 logger .info (f" First reflective example for { pred_name } :" )
@@ -480,7 +480,7 @@ def make_reflective_dataset(
480480 if len (items ) == 0 :
481481 logger .warning (f" No valid reflective examples found for { pred_name } " )
482482 continue
483-
483+
484484 ret_d [pred_name ] = items
485485 logger .info (f" Created { len (items )} reflective examples for { pred_name } " )
486486
0 commit comments