Skip to content

Commit ba17042

Browse files
Forbid all modules by default except whitelist authorized_imports (huggingface#935)
1 parent 13afa93 commit ba17042

File tree

4 files changed

+85
-50
lines changed

4 files changed

+85
-50
lines changed

docs/source/en/guided_tour.mdx

+3
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ agent = CodeAgent(tools=[], model=model, additional_authorized_imports=['request
181181
agent.run("Could you get me the title of the page at url 'https://huggingface.co/blog'?")
182182
```
183183

184+
Additionally, as an extra security layer, access to submodules is forbidden by default, unless explicitly authorized within the import list.
185+
For instance, to access the `numpy.random` submodule, you need to add `'numpy.random'` to the `additional_authorized_imports` list.
186+
184187
> [!WARNING]
185188
> The LLM can generate arbitrary code that will then be executed: do not add any unsafe imports!
186189

docs/source/en/tutorials/secure_code_execution.mdx

+2-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ We have re-built a more secure `LocalPythonExecutor` from the ground up.
5252

5353
To be precise, this interpreter works by loading the Abstract Syntax Tree (AST) from your code and executing it operation by operation, making sure to always follow certain rules:
5454
- By default, imports are disallowed unless they have been explicitly added to an authorization list by the user.
55-
- Even so, because some innocuous packages like `re` can give access to potentially harmful packages as in `re.subprocess`, subpackages that match a list of dangerous patterns are not imported.
55+
- Furthermore, access to submodules is disabled by default, and each must be explicitly authorized in the import list as well.
56+
- Note that some seemingly innocuous packages like `random` can give access to potentially harmful modules through their attributes, as in `random._os` (which is the `os` module).
5657
- The total count of elementary operations processed is capped to prevent infinite loops and resource bloating.
5758
- Any operation that has not been explicitly defined in our custom interpreter will raise an error.
5859

src/smolagents/local_python_executor.py

+5-33
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from collections.abc import Mapping
2525
from functools import wraps
2626
from importlib import import_module
27-
from importlib.util import find_spec
2827
from types import BuiltinFunctionType, FunctionType, ModuleType
2928
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
3029

@@ -114,19 +113,6 @@ def custom_print(*args):
114113
"complex": complex,
115114
}
116115

117-
DANGEROUS_MODULES = [
118-
"builtins",
119-
"io",
120-
"multiprocessing",
121-
"os",
122-
"pathlib",
123-
"pty",
124-
"shutil",
125-
"socket",
126-
"subprocess",
127-
"sys",
128-
]
129-
130116
DANGEROUS_FUNCTIONS = [
131117
"builtins.compile",
132118
"builtins.eval",
@@ -238,25 +224,11 @@ def _check_return(
238224
result = func(expression, state, static_tools, custom_tools, authorized_imports=authorized_imports)
239225
if "*" not in authorized_imports:
240226
if isinstance(result, ModuleType):
241-
for module_name in DANGEROUS_MODULES:
242-
if (
243-
module_name not in authorized_imports
244-
and result.__name__ == module_name
245-
# builtins has no __file__ attribute
246-
and getattr(result, "__file__", "")
247-
== (getattr(import_module(module_name), "__file__", "") if find_spec(module_name) else "")
248-
):
249-
raise InterpreterError(f"Forbidden access to module: {module_name}")
250-
elif isinstance(result, dict) and result.get("__name__"):
251-
for module_name in DANGEROUS_MODULES:
252-
if (
253-
module_name not in authorized_imports
254-
and result["__name__"] == module_name
255-
# builtins has no __file__ attribute
256-
and result.get("__file__", "")
257-
== (getattr(import_module(module_name), "__file__", "") if find_spec(module_name) else "")
258-
):
259-
raise InterpreterError(f"Forbidden access to module: {module_name}")
227+
if result.__name__ not in authorized_imports:
228+
raise InterpreterError(f"Forbidden access to module: {result.__name__}")
229+
elif isinstance(result, dict) and result.get("__spec__"):
230+
if result["__name__"] not in authorized_imports:
231+
raise InterpreterError(f"Forbidden access to module: {result['__name__']}")
260232
elif isinstance(result, (FunctionType, BuiltinFunctionType)):
261233
for qualified_function_name in DANGEROUS_FUNCTIONS:
262234
module_name, function_name = qualified_function_name.rsplit(".", 1)

tests/test_local_python_executor.py

+75-16
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from smolagents.default_tools import BASE_PYTHON_TOOLS
2828
from smolagents.local_python_executor import (
2929
DANGEROUS_FUNCTIONS,
30-
DANGEROUS_MODULES,
3130
InterpreterError,
3231
LocalPythonExecutor,
3332
PrintContainer,
@@ -41,6 +40,21 @@
4140
)
4241

4342

43+
# Non-exhaustive list of dangerous modules that should not be imported
44+
DANGEROUS_MODULES = [
45+
"builtins",
46+
"io",
47+
"multiprocessing",
48+
"os",
49+
"pathlib",
50+
"pty",
51+
"shutil",
52+
"socket",
53+
"subprocess",
54+
"sys",
55+
]
56+
57+
4458
# Fake function we will use as tool
4559
def add_two(x):
4660
return x + 2
@@ -504,10 +518,10 @@ def test_imports(self):
504518

505519
# Test submodules are handled properly, thus not raising error
506520
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
507-
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
521+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.random"])
508522

509523
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
510-
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
524+
result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy.random"])
511525

512526
def test_additional_imports(self):
513527
code = "import numpy as np"
@@ -1773,7 +1787,16 @@ def test_vulnerability_via_importlib(self, additional_authorized_imports, expect
17731787
"code, additional_authorized_imports, expected_error",
17741788
[
17751789
# os submodule
1776-
("import queue; queue.threading._os.system(':')", [], InterpreterError("Forbidden access to module: os")),
1790+
(
1791+
"import queue; queue.threading._os.system(':')",
1792+
[],
1793+
InterpreterError("Forbidden access to module: threading"),
1794+
),
1795+
(
1796+
"import queue; queue.threading._os.system(':')",
1797+
["threading"],
1798+
InterpreterError("Forbidden access to module: os"),
1799+
),
17771800
("import random; random._os.system(':')", [], InterpreterError("Forbidden access to module: os")),
17781801
(
17791802
"import random; random.__dict__['_os'].system(':')",
@@ -1783,22 +1806,42 @@ def test_vulnerability_via_importlib(self, additional_authorized_imports, expect
17831806
(
17841807
"import doctest; doctest.inspect.os.system(':')",
17851808
["doctest"],
1809+
InterpreterError("Forbidden access to module: inspect"),
1810+
),
1811+
(
1812+
"import doctest; doctest.inspect.os.system(':')",
1813+
["doctest", "inspect"],
17861814
InterpreterError("Forbidden access to module: os"),
17871815
),
17881816
# subprocess submodule
17891817
(
17901818
"import asyncio; asyncio.base_events.events.subprocess",
17911819
["asyncio"],
1820+
InterpreterError("Forbidden access to module: asyncio.base_events"),
1821+
),
1822+
(
1823+
"import asyncio; asyncio.base_events.events.subprocess",
1824+
["asyncio", "asyncio.base_events"],
1825+
InterpreterError("Forbidden access to module: asyncio.events"),
1826+
),
1827+
(
1828+
"import asyncio; asyncio.base_events.events.subprocess",
1829+
["asyncio", "asyncio.base_events", "asyncio.events"],
17921830
InterpreterError("Forbidden access to module: subprocess"),
17931831
),
17941832
# sys submodule
17951833
(
17961834
"import queue; queue.threading._sys.modules['os'].system(':')",
17971835
[],
1836+
InterpreterError("Forbidden access to module: threading"),
1837+
),
1838+
(
1839+
"import queue; queue.threading._sys.modules['os'].system(':')",
1840+
["threading"],
17981841
InterpreterError("Forbidden access to module: sys"),
17991842
),
18001843
# Allowed
1801-
("import pandas; pandas.io", ["pandas"], None),
1844+
("import pandas; pandas.io", ["pandas", "pandas.io"], None),
18021845
],
18031846
)
18041847
def test_vulnerability_via_submodules(self, code, additional_authorized_imports, expected_error):
@@ -1851,8 +1894,12 @@ def test_vulnerability_builtins_via_sys(self, additional_authorized_imports, add
18511894
@pytest.mark.parametrize(
18521895
"additional_authorized_imports, additional_tools, expected_error",
18531896
[
1854-
([], [], InterpreterError("Forbidden access to module: builtins")),
1855-
(["builtins", "os"], ["__import__"], None),
1897+
([], [], InterpreterError("Forbidden access to module: smolagents.local_python_executor")),
1898+
(
1899+
["builtins", "os"],
1900+
["__import__"],
1901+
InterpreterError("Forbidden access to module: smolagents.local_python_executor"),
1902+
),
18561903
],
18571904
)
18581905
def test_vulnerability_builtins_via_traceback(
@@ -1888,8 +1935,18 @@ def test_vulnerability_builtins_via_traceback(
18881935
@pytest.mark.parametrize(
18891936
"additional_authorized_imports, additional_tools, expected_error",
18901937
[
1891-
([], [], InterpreterError("Forbidden access to module: builtins")),
1892-
(["builtins", "os"], ["__import__"], None),
1938+
([], [], InterpreterError("Forbidden access to module: warnings")),
1939+
(["warnings"], [], InterpreterError("Forbidden access to module: builtins")),
1940+
(
1941+
["warnings", "builtins"],
1942+
[],
1943+
(
1944+
InterpreterError("Forbidden access to function: __import__"),
1945+
InterpreterError("Forbidden access to module: os"),
1946+
),
1947+
),
1948+
(["warnings", "builtins", "os"], [], (InterpreterError("Forbidden access to function: __import__"), None)),
1949+
(["warnings", "builtins", "os"], ["__import__"], None),
18931950
],
18941951
)
18951952
def test_vulnerability_builtins_via_class_catch_warnings(
@@ -1902,20 +1959,22 @@ def test_vulnerability_builtins_via_class_catch_warnings(
19021959
from builtins import __import__
19031960

19041961
executor.send_tools({"__import__": __import__})
1905-
with (
1906-
pytest.raises(type(expected_error), match=f".*{expected_error}")
1907-
if isinstance(expected_error, Exception)
1908-
else does_not_raise()
1909-
):
1962+
if isinstance(expected_error, tuple): # different error depending on patch status
1963+
expected_error = expected_error[patch_builtin_import_module]
1964+
if isinstance(expected_error, Exception):
1965+
expectation = pytest.raises(type(expected_error), match=f".*{expected_error}")
1966+
elif expected_error is None:
1967+
expectation = does_not_raise()
1968+
with expectation:
19101969
executor(
19111970
dedent(
19121971
"""
19131972
classes = {}.__class__.__base__.__subclasses__()
19141973
for cls in classes:
19151974
if cls.__name__ == "catch_warnings":
1916-
builtins = cls()._module.__builtins__
1917-
builtins_import = builtins["__import__"]
19181975
break
1976+
builtins = cls()._module.__builtins__
1977+
builtins_import = builtins["__import__"]
19191978
os_module = builtins_import('os')
19201979
os_module.system(":")
19211980
"""

0 commit comments

Comments
 (0)