27
27
from smolagents .default_tools import BASE_PYTHON_TOOLS
28
28
from smolagents .local_python_executor import (
29
29
DANGEROUS_FUNCTIONS ,
30
- DANGEROUS_MODULES ,
31
30
InterpreterError ,
32
31
LocalPythonExecutor ,
33
32
PrintContainer ,
41
40
)
42
41
43
42
43
+ # Non-exhaustive list of dangerous modules that should not be imported
44
+ DANGEROUS_MODULES = [
45
+ "builtins" ,
46
+ "io" ,
47
+ "multiprocessing" ,
48
+ "os" ,
49
+ "pathlib" ,
50
+ "pty" ,
51
+ "shutil" ,
52
+ "socket" ,
53
+ "subprocess" ,
54
+ "sys" ,
55
+ ]
56
+
57
+
44
58
# Fake function we will use as tool
45
59
def add_two (x ):
46
60
return x + 2
@@ -504,10 +518,10 @@ def test_imports(self):
504
518
505
519
# Test submodules are handled properly, thus not raising error
506
520
code = "import numpy.random as rd\n rng = rd.default_rng(12345)\n rng.random()"
507
- result , _ = evaluate_python_code (code , BASE_PYTHON_TOOLS , state = {}, authorized_imports = ["numpy" ])
521
+ result , _ = evaluate_python_code (code , BASE_PYTHON_TOOLS , state = {}, authorized_imports = ["numpy.random " ])
508
522
509
523
code = "from numpy.random import default_rng as d_rng\n rng = d_rng(12345)\n rng.random()"
510
- result , _ = evaluate_python_code (code , BASE_PYTHON_TOOLS , state = {}, authorized_imports = ["numpy" ])
524
+ result , _ = evaluate_python_code (code , BASE_PYTHON_TOOLS , state = {}, authorized_imports = ["numpy.random " ])
511
525
512
526
def test_additional_imports (self ):
513
527
code = "import numpy as np"
@@ -1773,7 +1787,16 @@ def test_vulnerability_via_importlib(self, additional_authorized_imports, expect
1773
1787
"code, additional_authorized_imports, expected_error" ,
1774
1788
[
1775
1789
# os submodule
1776
- ("import queue; queue.threading._os.system(':')" , [], InterpreterError ("Forbidden access to module: os" )),
1790
+ (
1791
+ "import queue; queue.threading._os.system(':')" ,
1792
+ [],
1793
+ InterpreterError ("Forbidden access to module: threading" ),
1794
+ ),
1795
+ (
1796
+ "import queue; queue.threading._os.system(':')" ,
1797
+ ["threading" ],
1798
+ InterpreterError ("Forbidden access to module: os" ),
1799
+ ),
1777
1800
("import random; random._os.system(':')" , [], InterpreterError ("Forbidden access to module: os" )),
1778
1801
(
1779
1802
"import random; random.__dict__['_os'].system(':')" ,
@@ -1783,22 +1806,42 @@ def test_vulnerability_via_importlib(self, additional_authorized_imports, expect
1783
1806
(
1784
1807
"import doctest; doctest.inspect.os.system(':')" ,
1785
1808
["doctest" ],
1809
+ InterpreterError ("Forbidden access to module: inspect" ),
1810
+ ),
1811
+ (
1812
+ "import doctest; doctest.inspect.os.system(':')" ,
1813
+ ["doctest" , "inspect" ],
1786
1814
InterpreterError ("Forbidden access to module: os" ),
1787
1815
),
1788
1816
# subprocess submodule
1789
1817
(
1790
1818
"import asyncio; asyncio.base_events.events.subprocess" ,
1791
1819
["asyncio" ],
1820
+ InterpreterError ("Forbidden access to module: asyncio.base_events" ),
1821
+ ),
1822
+ (
1823
+ "import asyncio; asyncio.base_events.events.subprocess" ,
1824
+ ["asyncio" , "asyncio.base_events" ],
1825
+ InterpreterError ("Forbidden access to module: asyncio.events" ),
1826
+ ),
1827
+ (
1828
+ "import asyncio; asyncio.base_events.events.subprocess" ,
1829
+ ["asyncio" , "asyncio.base_events" , "asyncio.events" ],
1792
1830
InterpreterError ("Forbidden access to module: subprocess" ),
1793
1831
),
1794
1832
# sys submodule
1795
1833
(
1796
1834
"import queue; queue.threading._sys.modules['os'].system(':')" ,
1797
1835
[],
1836
+ InterpreterError ("Forbidden access to module: threading" ),
1837
+ ),
1838
+ (
1839
+ "import queue; queue.threading._sys.modules['os'].system(':')" ,
1840
+ ["threading" ],
1798
1841
InterpreterError ("Forbidden access to module: sys" ),
1799
1842
),
1800
1843
# Allowed
1801
- ("import pandas; pandas.io" , ["pandas" ], None ),
1844
+ ("import pandas; pandas.io" , ["pandas" , "pandas.io" ], None ),
1802
1845
],
1803
1846
)
1804
1847
def test_vulnerability_via_submodules (self , code , additional_authorized_imports , expected_error ):
@@ -1851,8 +1894,12 @@ def test_vulnerability_builtins_via_sys(self, additional_authorized_imports, add
1851
1894
@pytest .mark .parametrize (
1852
1895
"additional_authorized_imports, additional_tools, expected_error" ,
1853
1896
[
1854
- ([], [], InterpreterError ("Forbidden access to module: builtins" )),
1855
- (["builtins" , "os" ], ["__import__" ], None ),
1897
+ ([], [], InterpreterError ("Forbidden access to module: smolagents.local_python_executor" )),
1898
+ (
1899
+ ["builtins" , "os" ],
1900
+ ["__import__" ],
1901
+ InterpreterError ("Forbidden access to module: smolagents.local_python_executor" ),
1902
+ ),
1856
1903
],
1857
1904
)
1858
1905
def test_vulnerability_builtins_via_traceback (
@@ -1888,8 +1935,18 @@ def test_vulnerability_builtins_via_traceback(
1888
1935
@pytest .mark .parametrize (
1889
1936
"additional_authorized_imports, additional_tools, expected_error" ,
1890
1937
[
1891
- ([], [], InterpreterError ("Forbidden access to module: builtins" )),
1892
- (["builtins" , "os" ], ["__import__" ], None ),
1938
+ ([], [], InterpreterError ("Forbidden access to module: warnings" )),
1939
+ (["warnings" ], [], InterpreterError ("Forbidden access to module: builtins" )),
1940
+ (
1941
+ ["warnings" , "builtins" ],
1942
+ [],
1943
+ (
1944
+ InterpreterError ("Forbidden access to function: __import__" ),
1945
+ InterpreterError ("Forbidden access to module: os" ),
1946
+ ),
1947
+ ),
1948
+ (["warnings" , "builtins" , "os" ], [], (InterpreterError ("Forbidden access to function: __import__" ), None )),
1949
+ (["warnings" , "builtins" , "os" ], ["__import__" ], None ),
1893
1950
],
1894
1951
)
1895
1952
def test_vulnerability_builtins_via_class_catch_warnings (
@@ -1902,20 +1959,22 @@ def test_vulnerability_builtins_via_class_catch_warnings(
1902
1959
from builtins import __import__
1903
1960
1904
1961
executor .send_tools ({"__import__" : __import__ })
1905
- with (
1906
- pytest .raises (type (expected_error ), match = f".*{ expected_error } " )
1907
- if isinstance (expected_error , Exception )
1908
- else does_not_raise ()
1909
- ):
1962
+ if isinstance (expected_error , tuple ): # different error depending on patch status
1963
+ expected_error = expected_error [patch_builtin_import_module ]
1964
+ if isinstance (expected_error , Exception ):
1965
+ expectation = pytest .raises (type (expected_error ), match = f".*{ expected_error } " )
1966
+ elif expected_error is None :
1967
+ expectation = does_not_raise ()
1968
+ with expectation :
1910
1969
executor (
1911
1970
dedent (
1912
1971
"""
1913
1972
classes = {}.__class__.__base__.__subclasses__()
1914
1973
for cls in classes:
1915
1974
if cls.__name__ == "catch_warnings":
1916
- builtins = cls()._module.__builtins__
1917
- builtins_import = builtins["__import__"]
1918
1975
break
1976
+ builtins = cls()._module.__builtins__
1977
+ builtins_import = builtins["__import__"]
1919
1978
os_module = builtins_import('os')
1920
1979
os_module.system(":")
1921
1980
"""
0 commit comments