Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ def _get_provider_class(provider_type: ProviderType):

return SearchBackend
if provider_type == ProviderType.PLAN:

return PlanBackend
if provider_type == ProviderType.CODE:

return CodeBackend
return VanillaBackend

Expand Down
139 changes: 113 additions & 26 deletions backends/code_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def _schema_to_python_type(schema):

if "anyOf" in schema or "oneOf" in schema:
variants = schema.get("anyOf") or schema.get("oneOf", [])
mapped = [_schema_to_python_type(v) for v in variants if v.get("type") != "null"]
mapped = [
_schema_to_python_type(v) for v in variants if v.get("type") != "null"
]
has_null = any(v.get("type") == "null" for v in variants)
if not mapped:
return "None" if has_null else "Any"
Expand Down Expand Up @@ -81,13 +83,19 @@ def _build_stub(tool):
params = []
for name, prop in properties.items():
type_str = _schema_to_python_type(prop)
params.append(f"{name}: {type_str}" if name in required else f"{name}: {type_str} = None")
params.append(
f"{name}: {type_str}" if name in required else f"{name}: {type_str} = None"
)

req = [p for p in params if "= None" not in p]
opt = [p for p in params if "= None" in p]
all_params = req + opt

sig = f"async def {tool.name}(*, {', '.join(all_params)})" if all_params else f"async def {tool.name}()"
sig = (
f"async def {tool.name}(*, {', '.join(all_params)})"
if all_params
else f"async def {tool.name}()"
)
sig += " -> dict:"

doc_lines = []
Expand Down Expand Up @@ -172,7 +180,6 @@ def _build_stub(tool):


class CodeBackend(BaseProvider):

def initialize(self, config):
self._config = config
self._tools = []
Expand Down Expand Up @@ -205,12 +212,15 @@ async def call_tool(tool_name, arguments):
module = pytypes.ModuleType("tools")

for tool in self._tools:

def _make_stub(t):
async def stub(**kwargs):
return await call_tool(t.name, kwargs)

stub.__name__ = t.name
stub.__doc__ = t.description
return stub

setattr(module, tool.name, _make_stub(tool))

return module
Expand Down Expand Up @@ -240,7 +250,9 @@ def search_tools(query):
for info in tool_index.values():
text = f"{info['name']} {info['description'] or ''}".lower()
if query_lower in text:
results.append({"name": info["name"], "description": info["description"]})
results.append(
{"name": info["name"], "description": info["description"]}
)
return results

module.list_tools = list_tools
Expand Down Expand Up @@ -317,24 +329,98 @@ def serve_resources(self):
return [(resource, read_result)]

async def _execute_code(self, code: str, timeout: int = 30) -> dict:
_safe_builtins = {k: v for k, v in __builtins__.items() if k not in (
"__import__", "exec", "eval", "compile", "open",
"breakpoint", "exit", "quit", "globals", "locals",
"getattr", "setattr", "delattr", "vars", "dir",
"memoryview", "type", "__build_class__",
)} if isinstance(__builtins__, dict) else {k: getattr(__builtins__, k) for k in (
"print", "len", "range", "enumerate", "zip", "map", "filter",
"sorted", "reversed", "list", "dict", "set", "tuple", "str",
"int", "float", "bool", "bytes", "bytearray",
"min", "max", "sum", "abs", "round", "pow", "divmod",
"any", "all", "isinstance", "issubclass", "hasattr",
"repr", "format", "hash", "id", "callable",
"iter", "next", "slice", "frozenset", "complex",
"chr", "ord", "hex", "oct", "bin",
"ValueError", "TypeError", "KeyError", "IndexError",
"AttributeError", "RuntimeError", "StopIteration",
"Exception", "BaseException", "True", "False", "None",
) if hasattr(__builtins__, k)}
_safe_builtins = (
{
k: v
for k, v in __builtins__.items()
if k
not in (
"__import__",
"exec",
"eval",
"compile",
"open",
"breakpoint",
"exit",
"quit",
"globals",
"locals",
"getattr",
"setattr",
"delattr",
"vars",
"dir",
"memoryview",
"type",
"__build_class__",
)
}
if isinstance(__builtins__, dict)
else {
k: getattr(__builtins__, k)
for k in (
"print",
"len",
"range",
"enumerate",
"zip",
"map",
"filter",
"sorted",
"reversed",
"list",
"dict",
"set",
"tuple",
"str",
"int",
"float",
"bool",
"bytes",
"bytearray",
"min",
"max",
"sum",
"abs",
"round",
"pow",
"divmod",
"any",
"all",
"isinstance",
"issubclass",
"hasattr",
"repr",
"format",
"hash",
"id",
"callable",
"iter",
"next",
"slice",
"frozenset",
"complex",
"chr",
"ord",
"hex",
"oct",
"bin",
"ValueError",
"TypeError",
"KeyError",
"IndexError",
"AttributeError",
"RuntimeError",
"StopIteration",
"Exception",
"BaseException",
"True",
"False",
"None",
)
if hasattr(__builtins__, k)
}
)
namespace = {"__builtins__": _safe_builtins}

# Inject modules — user code calls tools.X(), runtime.Y()
Expand All @@ -343,7 +429,6 @@ async def _execute_code(self, code: str, timeout: int = 30) -> dict:
namespace["tools"] = self._tools_module
namespace["runtime"] = self._runtime_module


indented = textwrap.indent(code, " ")
wrapped = f"async def __user_main__():\n{indented}\n"

Expand All @@ -354,8 +439,10 @@ async def _execute_code(self, code: str, timeout: int = 30) -> dict:
compiled = compile(wrapped, "<user_code>", "exec")
exec(compiled, namespace)

with contextlib.redirect_stdout(stdout_capture), \
contextlib.redirect_stderr(stderr_capture):
with (
contextlib.redirect_stdout(stdout_capture),
contextlib.redirect_stderr(stderr_capture),
):
await asyncio.wait_for(
namespace["__user_main__"](),
timeout=timeout,
Expand Down