Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions code_review_graph/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
code-review-graph register <path> [--alias name]
code-review-graph unregister <path_or_alias>
code-review-graph repos
code-review-graph query --sql <SELECT> [options]
code-review-graph query --file <path> [options]
"""

from __future__ import annotations
Expand Down Expand Up @@ -103,6 +105,7 @@ def _print_banner() -> None:
{g}postprocess{r} Run post-processing {d}(flows, communities, FTS){r}
{g}eval{r} Run evaluation benchmarks
{g}serve{r} Start MCP server {d}(stdio, or {g}--http{r} on localhost:5555){r}
{g}query{r} Execute a raw SQL SELECT against the graph database

{d}Run{r} {b}code-review-graph <command> --help{r} {d}for details{r}
""")
Expand Down Expand Up @@ -285,6 +288,70 @@ def _cli_post_process(store: GraphStore) -> None:
print(f"Communities: {pp['communities_detected']}")


def _handle_query(args: argparse.Namespace) -> None:
"""Execute a raw SQL SELECT against the graph database."""
from .incremental import find_project_root, get_db_path
from .query import _MAX_LIMIT, QueryError, format_table, run_query

# Resolve SQL source
if args.sql:
sql = args.sql
elif args.file:
try:
sql = Path(args.file).read_text(encoding="utf-8")
except OSError as exc:
print(json.dumps({"error": str(exc), "type": "file_error"}), file=sys.stderr)
sys.exit(1)
elif not sys.stdin.isatty():
sql = sys.stdin.read()
else:
print(
json.dumps(
{"error": "provide --sql, --file, or pipe SQL via stdin",
"type": "missing_input"}
),
file=sys.stderr,
)
sys.exit(1)

# Parse --param key=value flags
params: dict[str, str] = {}
for p in args.param or []:
if "=" not in p:
print(
json.dumps(
{"error": f"--param must be key=value format, got: {p!r}",
"type": "invalid_param"}
),
file=sys.stderr,
)
sys.exit(1)
k, _, v = p.partition("=")
params[k] = v

limit = min(max(1, args.limit), _MAX_LIMIT)

repo_root = Path(args.repo) if args.repo else find_project_root()
db_path = get_db_path(repo_root)

try:
rows = run_query(
db_path,
sql,
params=params or None,
limit=limit,
timeout=args.timeout,
)
except QueryError as exc:
print(json.dumps(exc.to_dict()), file=sys.stderr)
sys.exit(1)

if args.output_format == "table":
print(format_table(rows))
else:
print(json.dumps(rows, indent=2, default=str))


def main() -> None:
"""Main CLI entry point."""
ap = argparse.ArgumentParser(
Expand Down Expand Up @@ -486,6 +553,43 @@ def main() -> None:
detect_cmd.add_argument("--brief", action="store_true", help="Show brief summary only")
detect_cmd.add_argument("--repo", default=None, help="Repository root (auto-detected)")

# query
query_cmd = sub.add_parser(
"query", help="Execute a raw SQL SELECT against the graph database"
)
query_sql_group = query_cmd.add_mutually_exclusive_group()
query_sql_group.add_argument("--sql", default=None, help="SQL SELECT statement")
query_sql_group.add_argument("--file", default=None, help="Path to a .sql file")
query_cmd.add_argument(
"--param",
action="append",
default=[],
metavar="key=value",
help="Bind a named parameter (:key) in the SQL (repeatable)",
)
query_cmd.add_argument(
"--limit",
type=int,
default=100,
help="Maximum rows returned (default: 100, max: 1000)",
)
query_cmd.add_argument(
"--format",
choices=["json", "table"],
default="json",
dest="output_format",
help="Output format: json (default) | table",
)
query_cmd.add_argument(
"--timeout",
type=int,
default=10,
help="Query timeout in seconds (default: 10)",
)
query_cmd.add_argument(
"--repo", default=None, help="Repository root (default: current directory)"
)

# serve
serve_cmd = sub.add_parser(
"serve",
Expand Down Expand Up @@ -624,6 +728,11 @@ def main() -> None:
print(f" {entry['path']}{alias_str}")
return

if args.command == "query":
logging.basicConfig(level=logging.WARNING, format="%(levelname)s: %(message)s")
_handle_query(args)
return

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

from .graph import GraphStore
Expand Down
134 changes: 134 additions & 0 deletions code_review_graph/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Read-only raw SQL query against the graph database."""

from __future__ import annotations

import sqlite3
import threading
from pathlib import Path
from typing import Any
from urllib.parse import quote

import sqlglot
from sqlglot import expressions as exp

_DEFAULT_LIMIT = 100
_MAX_LIMIT = 1000
_DEFAULT_TIMEOUT = 10


class QueryError(Exception):
"""Structured error raised when a query violates safety rules or fails at runtime."""

def __init__(self, message: str, error_type: str) -> None:
super().__init__(message)
self.error_type = error_type

def to_dict(self) -> dict[str, str]:
return {"error": str(self), "type": self.error_type}


def _validate_sql(sql: str) -> None:
"""Raise QueryError if the SQL statement violates safety constraints.

Uses sqlglot to parse into an AST so string literals, double-quoted
identifiers, and comments can never fool the keyword check.
"""
try:
statements = sqlglot.parse(sql, dialect="sqlite",
error_level=sqlglot.ErrorLevel.RAISE)
except sqlglot.errors.ParseError as exc:
raise QueryError(str(exc), "sql_syntax_error") from exc

if not statements:
raise QueryError("empty query", "disallowed_statement")

if len(statements) > 1:
raise QueryError("only a single SQL statement is permitted",
"disallowed_statement")

stmt = statements[0]

if isinstance(stmt, exp.Pragma):
raise QueryError("PRAGMA statements are not permitted", "disallowed_pragma")

# exp.Select covers both plain SELECT and CTEs (WITH ... SELECT).
if not isinstance(stmt, exp.Select):
raise QueryError("only SELECT statements are permitted", "disallowed_statement")


def run_query(
db_path: Path,
sql: str,
params: dict[str, Any] | None = None,
limit: int = _DEFAULT_LIMIT,
timeout: int = _DEFAULT_TIMEOUT,
) -> list[dict[str, Any]]:
"""Execute a read-only SELECT against the graph database.

Opens the database in read-only URI mode so writes are rejected at the
connection level, not just by statement inspection.

Raises:
QueryError: on safety violations, syntax errors, timeout, or missing DB.
"""
limit = min(max(1, limit), _MAX_LIMIT)

_validate_sql(sql)

db_uri = f"file:{quote(str(db_path))}?mode=ro"
try:
conn = sqlite3.connect(db_uri, uri=True, check_same_thread=False)
except sqlite3.OperationalError as exc:
raise QueryError(str(exc), "db_error") from exc

conn.row_factory = sqlite3.Row
timed_out = threading.Event()

def _on_timeout() -> None:
timed_out.set()
conn.interrupt()

timer = threading.Timer(timeout, _on_timeout)
timer.start()
try:
cur = conn.execute(sql, params or {})
# Pull one extra row to detect truncation. We cap at the driver level
# rather than appending "LIMIT N" to the SQL so the user's query runs
# unmodified — appending LIMIT would break queries that already contain one.
rows = cur.fetchmany(limit + 1)
return [dict(row) for row in rows[:limit]]
except sqlite3.OperationalError as exc:
if timed_out.is_set():
raise QueryError(f"query exceeded timeout of {timeout}s", "timeout") from exc
raise QueryError(str(exc), "sql_error") from exc
except sqlite3.DatabaseError as exc:
raise QueryError(str(exc), "sql_syntax_error") from exc
finally:
timer.cancel()
conn.close()


def format_table(rows: list[dict[str, Any]]) -> str:
"""Format a list of row dicts as a pipe-delimited ASCII table."""
if not rows:
return "(no results)"

columns = list(rows[0].keys())

def _str(v: Any) -> str:
return "NULL" if v is None else str(v)

str_rows = [{col: _str(row.get(col)) for col in columns} for row in rows]
col_widths = {
col: max(len(col), max(len(r[col]) for r in str_rows))
for col in columns
}

header = " | ".join(col.ljust(col_widths[col]) for col in columns)
separator = "-+-".join("-" * col_widths[col] for col in columns)
data_lines = [
" | ".join(r[col].ljust(col_widths[col]) for col in columns)
for r in str_rows
]

return "\n".join([header, separator, *data_lines])
9 changes: 9 additions & 0 deletions docs/COMMANDS.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,15 @@ code-review-graph register <path> [--alias name] # Register a repository
code-review-graph unregister <path_or_alias> # Remove from registry
code-review-graph repos # List registered repositories

# Raw SQL query
code-review-graph query --sql "SELECT name, kind FROM nodes LIMIT 10"
code-review-graph query --file query.sql # Read SQL from file
code-review-graph query --sql "SELECT * FROM nodes WHERE name = :n" --param n=main
code-review-graph query --sql "..." --format table # ASCII table output
code-review-graph query --sql "..." --limit 500 # Up to 1000 rows
code-review-graph query --sql "..." --timeout 30 # Custom timeout (seconds)
code-review-graph query --sql "..." --repo /path/to/repo # Explicit repo root

# Evaluation
code-review-graph eval # Run evaluation benchmarks

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"tree-sitter-language-pack>=0.3.0,<1",
"networkx>=3.2,<4",
"watchdog>=4.0.0,<6",
"sqlglot>=30.4.3",
]

[project.urls]
Expand Down
Loading
Loading