def __init__(
    self,
    allow_print: bool = True,
    extra_safe_modules: frozenset[str] = frozenset(),
    extra_deny_modules: frozenset[str] = frozenset(),
):
    """Set up an analyzer with optional user-configured module overrides.

    Args:
        allow_print: Stored on the instance as ``allow_print``.
        extra_safe_modules: Module names accepted in addition to
            ``SAFE_MODULES``. An exact name listed here is also removed
            from the hardcoded ``DANGEROUS_MODULES``.
        extra_deny_modules: Module names flagged as dangerous in addition
            to ``DANGEROUS_MODULES``. These always stay in the deny set,
            even when the same name is also listed as safe.
    """
    self.violations: list[Violation] = []
    self.allow_print = allow_print
    self.safe_modules = SAFE_MODULES | extra_safe_modules
    # A user-configured allow overrides only the *hardcoded* dangerous list,
    # and only for exact names — submodules must be allowed separately.
    # User-configured denies are unioned in afterwards so an explicit deny is
    # never cancelled by an allow for the same module: the import visitors
    # test the deny set before the safe set, so the stricter setting wins.
    self.deny_modules = (
        DANGEROUS_MODULES - extra_safe_modules
    ) | extra_deny_modules
@@ -627,12 +643,20 @@ def analyze_python_source(source: str, allow_print: bool = True) -> list[Violati except SyntaxError as e: return [Violation(e.lineno or 0, e.offset or 0, "syntax", str(e))] - analyzer = SafetyAnalyzer(allow_print=allow_print) + analyzer = SafetyAnalyzer( + allow_print=allow_print, + extra_safe_modules=extra_safe_modules, + extra_deny_modules=extra_deny_modules, + ) analyzer.visit(tree) return analyzer.violations -def analyze_python_file(path: Path) -> tuple[bool, str]: +def analyze_python_file( + path: Path, + extra_safe_modules: frozenset[str] = frozenset(), + extra_deny_modules: frozenset[str] = frozenset(), +) -> tuple[bool, str]: """ Analyze a Python file for safety. @@ -662,7 +686,11 @@ def analyze_python_file(path: Path) -> tuple[bool, str]: except (OSError, UnicodeDecodeError) as e: return False, f"cannot read file: {e}" - violations = analyze_python_source(source) + violations = analyze_python_source( + source, + extra_safe_modules=extra_safe_modules, + extra_deny_modules=extra_deny_modules, + ) if violations: # Return first violation as reason @@ -777,6 +805,11 @@ def classify(ctx: HandlerContext) -> Classification: """ tokens = ctx.tokens cwd = Path.cwd() + config = ctx.config + + # Build extra module sets from config + extra_safe = frozenset(config.python_allow_modules) if config else frozenset() + extra_deny = frozenset(config.python_deny_modules) if config else frozenset() desc = get_description(tokens) @@ -818,7 +851,9 @@ def classify(ctx: HandlerContext) -> Classification: return Classification("ask", description=desc) # Try to analyze the script - is_safe, reason = analyze_python_file(script_path) + is_safe, reason = analyze_python_file( + script_path, extra_safe_modules=extra_safe, extra_deny_modules=extra_deny + ) if is_safe: return Classification("allow", description=f"{desc} (analyzed)") diff --git a/src/dippy/core/analyzer.py b/src/dippy/core/analyzer.py index 68841c4..5fc4128 100644 --- a/src/dippy/core/analyzer.py +++ 
b/src/dippy/core/analyzer.py @@ -287,7 +287,9 @@ def _analyze_command( and position > base_idx ): handler = get_handler(base) - outer_result = handler.classify(HandlerContext(words[base_idx:])) + outer_result = handler.classify( + HandlerContext(words[base_idx:], config=config) + ) if outer_result.action != "allow": inner_cmd = _get_word_value(word).strip("$()") return Decision("ask", f"cmdsub injection risk: {inner_cmd}") @@ -448,7 +450,7 @@ def _analyze_simple_command( # 5. CLI-specific handlers handler = get_handler(base) if handler: - result = handler.classify(HandlerContext(tokens)) + result = handler.classify(HandlerContext(tokens, config=config)) desc = result.description or get_description(tokens, base) # Check handler-provided redirect targets against config (skip in remote mode) if result.redirect_targets and not remote: diff --git a/src/dippy/core/config.py b/src/dippy/core/config.py index 653c5be..0a04006 100644 --- a/src/dippy/core/config.py +++ b/src/dippy/core/config.py @@ -8,6 +8,30 @@ from dataclasses import dataclass, field, replace from pathlib import Path +# Valid Python module path: dotted identifiers (e.g. "numpy", "http.server") +_MODULE_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$") + + +def _parse_module_name(rest: str) -> str: + """Parse and validate a Python module name from a directive argument. + + Strips inline comments (# ...) and validates the module name. + Raises ValueError if the name is missing, has extra words, or is invalid. 
+ """ + # Strip inline comments + if "#" in rest: + rest = rest[: rest.index("#")].rstrip() + if not rest: + raise ValueError("requires a module name") + parts = rest.split() + if len(parts) != 1: + raise ValueError(f"requires exactly one module name, got: {rest!r}") + mod = parts[0] + if not _MODULE_RE.match(mod): + raise ValueError(f"invalid Python module name: {mod!r}") + return mod + + # Cache home directory at module load - fails fast if HOME is unset _HOME = Path.home() @@ -62,6 +86,12 @@ class Config: aliases: dict[str, str] = field(default_factory=dict) """Command aliases mapping source to target (e.g., ~/bin/gh -> gh).""" + python_allow_modules: list[str] = field(default_factory=list) + """Extra modules to treat as safe for Python static analysis.""" + + python_deny_modules: list[str] = field(default_factory=list) + """Extra modules to treat as dangerous for Python static analysis.""" + default: str = "ask" # 'allow' | 'ask' log: Path | None = None # None = no logging log_full: bool = False # log full command (requires log path) @@ -122,6 +152,9 @@ def _merge_configs(base: Config, overlay: Config) -> Config: after_mcp_rules=base.after_mcp_rules + overlay.after_mcp_rules, # Aliases: overlay wins for conflicting keys aliases={**base.aliases, **overlay.aliases}, + # Python module lists accumulate + python_allow_modules=base.python_allow_modules + overlay.python_allow_modules, + python_deny_modules=base.python_deny_modules + overlay.python_deny_modules, # Settings: overlay wins if set default=overlay.default if overlay.default != "ask" else base.default, log=overlay.log if overlay.log is not None else base.log, @@ -209,6 +242,8 @@ def parse_config(text: str, source: str | None = None) -> Config: mcp_rules: list[Rule] = [] after_mcp_rules: list[Rule] = [] aliases: dict[str, str] = {} + python_allow_modules: list[str] = [] + python_deny_modules: list[str] = [] settings: dict[str, bool | int | str | Path] = {} prefix = f"{source}: " if source else "" @@ -321,6 
+356,14 @@ def parse_config(text: str, source: str | None = None) -> Config: ) aliases[expanded_source] = alias_target + elif directive == "python-allow-module": + mod = _parse_module_name(rest) + python_allow_modules.append(mod) + + elif directive == "python-deny-module": + mod = _parse_module_name(rest) + python_deny_modules.append(mod) + elif directive == "set": _apply_setting(settings, rest) @@ -337,6 +380,8 @@ def parse_config(text: str, source: str | None = None) -> Config: mcp_rules=mcp_rules, after_mcp_rules=after_mcp_rules, aliases=aliases, + python_allow_modules=python_allow_modules, + python_deny_modules=python_deny_modules, default=settings.get("default", "ask"), log=settings.get("log"), log_full=settings.get("log_full", False), diff --git a/tests/cli/test_python.py b/tests/cli/test_python.py index b8fec89..8d539fb 100644 --- a/tests/cli/test_python.py +++ b/tests/cli/test_python.py @@ -8,6 +8,7 @@ import pytest +from dippy.core.config import Config from conftest import is_approved, needs_confirmation @@ -1285,3 +1286,156 @@ def test_analyze_safe_still_safe(self): """ violations = analyze_python_source(source) assert len(violations) == 0, f"Expected no violations, got {violations}" + + +class TestPythonConfigModules: + """Tests for configurable safe/unsafe module lists.""" + + def test_allow_module_via_config(self, check, tmp_path): + """User-allowed module should pass analysis.""" + script = tmp_path / "use_numpy.py" + script.write_text("import numpy\nx = numpy.array([1, 2, 3])") + config = Config(python_allow_modules=["numpy"]) + result = check(f"python {script}", config=config) + assert is_approved(result), "numpy should be approved via config" + + def test_deny_module_via_config(self, check, tmp_path): + """User-denied module should be flagged even if normally safe.""" + script = tmp_path / "use_json.py" + script.write_text("import json\njson.dumps({})") + config = Config(python_deny_modules=["json"]) + result = check(f"python {script}", 
config=config) + assert needs_confirmation(result), "json should be denied via config" + + def test_deny_overrides_safe(self, check, tmp_path): + """Deny should override the hardcoded safe list.""" + script = tmp_path / "use_math.py" + script.write_text("import math\nprint(math.pi)") + config = Config(python_deny_modules=["math"]) + result = check(f"python {script}", config=config) + assert needs_confirmation(result), "math should be denied via config override" + + def test_multiple_config_modules(self, check, tmp_path): + """Multiple allowed modules should all work.""" + script = tmp_path / "multi.py" + script.write_text("import numpy\nimport pandas\nx = 1") + config = Config(python_allow_modules=["numpy", "pandas"]) + result = check(f"python {script}", config=config) + assert is_approved(result), "multiple config modules should work" + + def test_no_config_unknown_module_blocked(self, check, tmp_path): + """Without config, unknown module should be blocked.""" + script = tmp_path / "use_numpy.py" + script.write_text("import numpy") + result = check(f"python {script}") + assert needs_confirmation(result), "unknown module without config should ask" + + +class TestPythonAllowOverridesDangerous: + """Tests for python-allow-module overriding hardcoded dangerous modules.""" + + def test_allow_pathlib_module(self, check, tmp_path): + """python-allow-module pathlib should override dangerous for pathlib.""" + script = tmp_path / "use_pathlib.py" + script.write_text("from pathlib import Path\np = Path('.')") + config = Config(python_allow_modules=["pathlib"]) + result = check(f"python {script}", config=config) + assert is_approved(result), "pathlib should be approved via config override" + + def test_allow_root_only_blocks_submodule(self, check, tmp_path): + """Allowing only root should NOT approve separately-listed submodules.""" + script = tmp_path / "use_http.py" + script.write_text("import http.server") + config = Config(python_allow_modules=["http"]) + result = 
check(f"python {script}", config=config) + assert needs_confirmation(result), "http.server needs separate allow" + + def test_allow_root_and_submodules(self, check, tmp_path): + """Allowing both http and http.server should approve the script.""" + script = tmp_path / "use_http.py" + script.write_text("import http.server\nprint('ok')") + config = Config(python_allow_modules=["http", "http.server"]) + result = check(f"python {script}", config=config) + assert is_approved(result), "http + http.server should be approved via config" + + def test_without_allow_still_blocked(self, check, tmp_path): + """Without config, dangerous modules should still be blocked.""" + script = tmp_path / "use_http.py" + script.write_text("import http.server") + result = check(f"python {script}") + assert needs_confirmation(result), "http should still be blocked without config" + + +class TestUnitAnalysisConfigModules: + """Unit tests for analyze_python_source with extra modules.""" + + def test_extra_safe_module(self): + from dippy.cli.python import analyze_python_source + + violations = analyze_python_source( + "import numpy", extra_safe_modules=frozenset({"numpy"}) + ) + assert len(violations) == 0 + + def test_extra_deny_module(self): + from dippy.cli.python import analyze_python_source + + violations = analyze_python_source( + "import json", extra_deny_modules=frozenset({"json"}) + ) + assert len(violations) > 0 + assert any("json" in v.detail for v in violations) + + def test_deny_overrides_builtin_safe(self): + from dippy.cli.python import analyze_python_source + + violations = analyze_python_source( + "import math", extra_deny_modules=frozenset({"math"}) + ) + assert len(violations) > 0 + assert any("math" in v.detail for v in violations) + + def test_from_import_respects_config(self): + from dippy.cli.python import analyze_python_source + + violations = analyze_python_source( + "from numpy import array", extra_safe_modules=frozenset({"numpy"}) + ) + assert len(violations) == 0 + + 
def test_allow_overrides_dangerous_exact(self): + """python-allow-module should override exact match in DANGEROUS_MODULES.""" + from dippy.cli.python import analyze_python_source + + violations = analyze_python_source( + "import http", extra_safe_modules=frozenset({"http"}) + ) + assert len(violations) == 0 + + def test_allow_does_not_override_submodules(self): + """Allowing root does NOT automatically allow separately-listed submodules.""" + from dippy.cli.python import analyze_python_source + + violations = analyze_python_source( + "import http.server", extra_safe_modules=frozenset({"http"}) + ) + assert len(violations) > 0 + + def test_allow_submodule_explicitly(self): + """Explicitly allowing a submodule should work.""" + from dippy.cli.python import analyze_python_source + + violations = analyze_python_source( + "import http.server", + extra_safe_modules=frozenset({"http", "http.server"}), + ) + assert len(violations) == 0 + + def test_allow_override_pathlib(self): + """Allowing pathlib should override dangerous for pathlib.""" + from dippy.cli.python import analyze_python_source + + violations = analyze_python_source( + "from pathlib import Path", extra_safe_modules=frozenset({"pathlib"}) + ) + assert len(violations) == 0 diff --git a/tests/test_config.py b/tests/test_config.py index 5393af0..47a8e9d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1973,3 +1973,91 @@ def test_alias_absolute_path(self, tmp_path): m = match_command(c, cfg, tmp_path) assert m is not None assert m.decision == "allow" + + +class TestPythonModuleDirectives: + """Tests for python-allow-module and python-deny-module directives.""" + + def test_python_allow_module(self): + cfg = parse_config("python-allow-module numpy") + assert cfg.python_allow_modules == ["numpy"] + + def test_python_deny_module(self): + cfg = parse_config("python-deny-module requests") + assert cfg.python_deny_modules == ["requests"] + + def test_multiple_allow_modules(self): + cfg = 
parse_config("python-allow-module numpy\npython-allow-module pandas") + assert cfg.python_allow_modules == ["numpy", "pandas"] + + def test_multiple_deny_modules(self): + cfg = parse_config("python-deny-module requests\npython-deny-module boto3") + assert cfg.python_deny_modules == ["requests", "boto3"] + + def test_empty_module_skipped(self): + """Empty module name should be skipped with warning.""" + cfg = parse_config("python-allow-module") + assert cfg.python_allow_modules == [] + + def test_inline_comment_after_module(self): + """Inline comment after module name should be stripped.""" + cfg = parse_config("python-allow-module numpy # math library") + assert cfg.python_allow_modules == ["numpy"] + + def test_merge_accumulates_modules(self): + """Merging configs should accumulate module lists.""" + base = Config(python_allow_modules=["numpy"]) + overlay = Config(python_allow_modules=["pandas"]) + merged = _merge_configs(base, overlay) + assert merged.python_allow_modules == ["numpy", "pandas"] + + def test_merge_accumulates_deny_modules(self): + base = Config(python_deny_modules=["requests"]) + overlay = Config(python_deny_modules=["boto3"]) + merged = _merge_configs(base, overlay) + assert merged.python_deny_modules == ["requests", "boto3"] + + def test_mixed_directives(self): + """Allow and deny modules can coexist.""" + cfg = parse_config("python-allow-module numpy\npython-deny-module requests") + assert cfg.python_allow_modules == ["numpy"] + assert cfg.python_deny_modules == ["requests"] + + def test_with_other_directives(self): + """Module directives coexist with other config directives.""" + cfg = parse_config("allow git\npython-allow-module numpy\ndeny rm") + assert cfg.python_allow_modules == ["numpy"] + assert len(cfg.rules) == 2 + + def test_dotted_module_name(self): + """Dotted module names like http.server should be valid.""" + cfg = parse_config("python-allow-module http.server") + assert cfg.python_allow_modules == ["http.server"] + + def 
test_deeply_nested_module_name(self): + """Deeply nested module names should be valid.""" + cfg = parse_config("python-allow-module xml.etree.ElementTree") + assert cfg.python_allow_modules == ["xml.etree.ElementTree"] + + @pytest.mark.parametrize( + "bad_name", + [ + "lol what", + "123bad", + "foo/bar", + "foo-bar", + ".leading.dot", + "trailing.", + "double..dot", + "foo bar baz", + ], + ) + def test_invalid_module_name_skipped(self, bad_name): + """Invalid Python module names should be rejected.""" + cfg = parse_config(f"python-allow-module {bad_name}") + assert cfg.python_allow_modules == [] + + def test_invalid_deny_module_name_skipped(self): + """Invalid module name in deny should also be rejected.""" + cfg = parse_config("python-deny-module not-valid!") + assert cfg.python_deny_modules == []