Skip to content

Commit 925b814

Browse files
authored
Fix ReAct tool behavior when lacking type hints (#7655)
1 parent 80cdcbe commit 925b814

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

dspy/predict/react.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from dspy.signatures.signature import ensure_signature
99
from dspy.utils.callback import with_callbacks
1010

11-
1211
class Tool:
1312
def __init__(
1413
self,
@@ -18,25 +17,47 @@ def __init__(
1817
args: dict[str, Any] = None,
1918
):
2019
annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__
20+
2121
self.func = func
2222
self.name = name or getattr(func, "__name__", type(func).__name__)
23-
self.desc = desc or getattr(func, "__doc__", None) or getattr(annotations_func, "__doc__", "")
23+
self.desc = (
24+
desc
25+
or getattr(func, "__doc__", None)
26+
or getattr(annotations_func, "__doc__", "")
27+
)
2428
self.args = {}
2529
self.arg_types = {}
26-
for k, v in (args or get_type_hints(annotations_func)).items():
30+
31+
# If an explicit args dict is passed, use that; otherwise, extract from the function.
32+
if args is not None:
33+
hints = args
34+
else:
35+
# Use inspect.signature to get all parameter names
36+
sig = inspect.signature(annotations_func)
37+
# Get available type hints
38+
available_hints = get_type_hints(annotations_func)
39+
# Build a dictionary of parameter name -> type (defaulting to Any when missing)
40+
hints = {
41+
param_name: available_hints.get(param_name, Any)
42+
for param_name in sig.parameters.keys()
43+
}
44+
45+
# Process each argument's type to generate its JSON schema.
46+
for k, v in hints.items():
2747
self.arg_types[k] = v
2848
if k == "return":
2949
continue
30-
if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel):
50+
# Check if the type (or its origin) is a subclass of Pydantic's BaseModel
51+
origin = get_origin(v) or v
52+
if isinstance(origin, type) and issubclass(origin, BaseModel):
3153
self.args[k] = v.model_json_schema()
3254
else:
33-
self.args[k] = TypeAdapter(v).json_schema()
55+
self.args[k] = TypeAdapter(v).json_schema() or "Any"
3456

3557
@with_callbacks
3658
def __call__(self, *args, **kwargs):
3759
return self.func(*args, **kwargs)
3860

39-
4061
class ReAct(Module):
4162
def __init__(self, signature, tools: list[Callable], max_iters=5):
4263
"""

0 commit comments

Comments
 (0)