diff --git a/__init__.py b/__init__.py index edad52f..9810a2e 100644 --- a/__init__.py +++ b/__init__.py @@ -49,10 +49,8 @@ def _get_provider_class(provider_type: ProviderType): return SearchBackend if provider_type == ProviderType.PLAN: - return PlanBackend if provider_type == ProviderType.CODE: - return CodeBackend return VanillaBackend diff --git a/backends/code_backend.py b/backends/code_backend.py index 9804cc0..455d7dc 100644 --- a/backends/code_backend.py +++ b/backends/code_backend.py @@ -39,7 +39,9 @@ def _schema_to_python_type(schema): if "anyOf" in schema or "oneOf" in schema: variants = schema.get("anyOf") or schema.get("oneOf", []) - mapped = [_schema_to_python_type(v) for v in variants if v.get("type") != "null"] + mapped = [ + _schema_to_python_type(v) for v in variants if v.get("type") != "null" + ] has_null = any(v.get("type") == "null" for v in variants) if not mapped: return "None" if has_null else "Any" @@ -81,13 +83,19 @@ def _build_stub(tool): params = [] for name, prop in properties.items(): type_str = _schema_to_python_type(prop) - params.append(f"{name}: {type_str}" if name in required else f"{name}: {type_str} = None") + params.append( + f"{name}: {type_str}" if name in required else f"{name}: {type_str} = None" + ) req = [p for p in params if "= None" not in p] opt = [p for p in params if "= None" in p] all_params = req + opt - sig = f"async def {tool.name}(*, {', '.join(all_params)})" if all_params else f"async def {tool.name}()" + sig = ( + f"async def {tool.name}(*, {', '.join(all_params)})" + if all_params + else f"async def {tool.name}()" + ) sig += " -> dict:" doc_lines = [] @@ -172,7 +180,6 @@ def _build_stub(tool): class CodeBackend(BaseProvider): - def initialize(self, config): self._config = config self._tools = [] @@ -205,12 +212,15 @@ async def call_tool(tool_name, arguments): module = pytypes.ModuleType("tools") for tool in self._tools: + def _make_stub(t): async def stub(**kwargs): return await call_tool(t.name, kwargs) + stub.__name__ = t.name stub.__doc__ = t.description return stub + setattr(module, tool.name, _make_stub(tool)) return module @@ -240,7 +250,9 @@ def search_tools(query): for info in tool_index.values(): text = f"{info['name']} {info['description'] or ''}".lower() if query_lower in text: - results.append({"name": info["name"], "description": info["description"]}) + results.append( + {"name": info["name"], "description": info["description"]} + ) return results module.list_tools = list_tools @@ -317,24 +329,98 @@ def serve_resources(self): return [(resource, read_result)] async def _execute_code(self, code: str, timeout: int = 30) -> dict: - _safe_builtins = {k: v for k, v in __builtins__.items() if k not in ( - "__import__", "exec", "eval", "compile", "open", - "breakpoint", "exit", "quit", "globals", "locals", - "getattr", "setattr", "delattr", "vars", "dir", - "memoryview", "type", "__build_class__", - )} if isinstance(__builtins__, dict) else {k: getattr(__builtins__, k) for k in ( - "print", "len", "range", "enumerate", "zip", "map", "filter", - "sorted", "reversed", "list", "dict", "set", "tuple", "str", - "int", "float", "bool", "bytes", "bytearray", - "min", "max", "sum", "abs", "round", "pow", "divmod", - "any", "all", "isinstance", "issubclass", "hasattr", - "repr", "format", "hash", "id", "callable", - "iter", "next", "slice", "frozenset", "complex", - "chr", "ord", "hex", "oct", "bin", - "ValueError", "TypeError", "KeyError", "IndexError", - "AttributeError", "RuntimeError", "StopIteration", - "Exception", "BaseException", "True", "False", "None", - ) if hasattr(__builtins__, k)} + _safe_builtins = ( + { + k: v + for k, v in __builtins__.items() + if k + not in ( + "__import__", + "exec", + "eval", + "compile", + "open", + "breakpoint", + "exit", + "quit", + "globals", + "locals", + "getattr", + "setattr", + "delattr", + "vars", + "dir", + "memoryview", + "type", + "__build_class__", + ) + } + if isinstance(__builtins__, dict) + else { + k: getattr(__builtins__, k) + for k in ( + "print", + "len", + "range", + "enumerate", + "zip", + "map", + "filter", + "sorted", + "reversed", + "list", + "dict", + "set", + "tuple", + "str", + "int", + "float", + "bool", + "bytes", + "bytearray", + "min", + "max", + "sum", + "abs", + "round", + "pow", + "divmod", + "any", + "all", + "isinstance", + "issubclass", + "hasattr", + "repr", + "format", + "hash", + "id", + "callable", + "iter", + "next", + "slice", + "frozenset", + "complex", + "chr", + "ord", + "hex", + "oct", + "bin", + "ValueError", + "TypeError", + "KeyError", + "IndexError", + "AttributeError", + "RuntimeError", + "StopIteration", + "Exception", + "BaseException", + "True", + "False", + "None", + ) + if hasattr(__builtins__, k) + } + ) namespace = {"__builtins__": _safe_builtins} # Inject modules — user code calls tools.X(), runtime.Y() @@ -343,7 +429,6 @@ async def _execute_code(self, code: str, timeout: int = 30) -> dict: namespace["tools"] = self._tools_module namespace["runtime"] = self._runtime_module - indented = textwrap.indent(code, " ") wrapped = f"async def __user_main__():\n{indented}\n" @@ -354,8 +439,10 @@ async def _execute_code(self, code: str, timeout: int = 30) -> dict: compiled = compile(wrapped, "", "exec") exec(compiled, namespace) - with contextlib.redirect_stdout(stdout_capture), \ - contextlib.redirect_stderr(stderr_capture): + with ( + contextlib.redirect_stdout(stdout_capture), + contextlib.redirect_stderr(stderr_capture), + ): await asyncio.wait_for( namespace["__user_main__"](), timeout=timeout,