Skip to content

Commit

Permalink
run with "make fmt"
Browse files Browse the repository at this point in the history
  • Loading branch information
haawha committed Jan 6, 2025
1 parent 0322d6b commit 94bd927
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions dbgpt/app/openapi/api_v1/editor/api_editor_v1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
import logging
import re
import time
from typing import Dict, List
import re

from fastapi import APIRouter, Body, Depends

Expand Down Expand Up @@ -103,25 +103,44 @@ async def editor_sql_run(run_param: dict = Body()):
sql = run_param["sql"]
if not db_name and not sql:
return Result.failed(msg="SQL run param error!")

# Validate database type and prevent dangerous operations
conn = CFG.local_db_manager.get_connector(db_name)
db_type = getattr(conn, "db_type", "").lower()

# Block dangerous operations for DuckDB
if db_type == "duckdb":
# Block file operations and system commands
dangerous_keywords = [
# File operations
"copy", "export", "import", "load", "install",
"read_", "write_", "save", "from_", "to_",
"copy",
"export",
"import",
"load",
"install",
"read_",
"write_",
"save",
"from_",
"to_",
# System commands
"create_", "drop_", ".execute(", "system", "shell",
"create_",
"drop_",
".execute(",
"system",
"shell",
# Additional DuckDB specific operations
"attach", "detach", "pragma", "checkpoint",
"load_extension", "unload_extension",
"attach",
"detach",
"pragma",
"checkpoint",
"load_extension",
"unload_extension",
# File paths
"/'", "'/'", "\\", "://"
"/'",
"'/'",
"\\",
"://",
]
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
if any(keyword in sql_lower for keyword in dangerous_keywords):
Expand All @@ -136,9 +155,11 @@ async def editor_sql_run(run_param: dict = Body()):
try:
start_time = time.time() * 1000
# Add timeout protection
colunms, sql_result = conn.query_ex(sql, timeout=30)
colunms, sql_result = conn.query_ex(sql, timeout=30)
# Convert result type safely
sql_result = [tuple(str(x) if x is not None else None for x in row) for row in sql_result]
sql_result = [
tuple(str(x) if x is not None else None for x in row) for row in sql_result
]
# Calculate execution time
end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(
Expand Down

0 comments on commit 94bd927

Please sign in to comment.