diff --git a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py index de623bdcc..51c33bde0 100644 --- a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py +++ b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py @@ -1,5 +1,6 @@ import json import logging +import re import time from typing import Dict, List @@ -102,14 +103,64 @@ async def editor_sql_run(run_param: dict = Body()): sql = run_param["sql"] if not db_name and not sql: return Result.failed(msg="SQL run param error!") + + # Validate database type and prevent dangerous operations conn = CFG.local_db_manager.get_connector(db_name) + db_type = getattr(conn, "db_type", "").lower() + + # Block dangerous operations for DuckDB + if db_type == "duckdb": + # Block file operations and system commands + dangerous_keywords = [ + # File operations + "copy", + "export", + "import", + "load", + "install", + "read_", + "write_", + "save", + "from_", + "to_", + # System commands + "create_", + "drop_", + ".execute(", + "system", + "shell", + # Additional DuckDB specific operations + "attach", + "detach", + "pragma", + "checkpoint", + "load_extension", + "unload_extension", + # File paths + "/'", + "'/'", + "\\", + "://", + ] + sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass + if any(keyword in sql_lower for keyword in dangerous_keywords): + logger.warning(f"Blocked dangerous SQL operation attempt: {sql}") + return Result.failed(msg="Operation not allowed for security reasons") + + # Additional check for file path patterns + if re.search(r"['\"].*[/\\].*['\"]", sql): + logger.warning(f"Blocked file path in SQL: {sql}") + return Result.failed(msg="File operations not allowed") try: start_time = time.time() * 1000 - colunms, sql_result = conn.query_ex(sql) - # 转换结果类型 - sql_result = [tuple(x) for x in sql_result] - # 计算执行耗时 + # Add timeout protection + colunms, sql_result = conn.query_ex(sql, timeout=30) + # Convert result type safely + sql_result = [ + tuple(str(x) if x is not None else None for x in row) for row in sql_result + ] + # Calculate execution time end_time = time.time() * 1000 sql_run_data: SqlRunData = SqlRunData( result_info="", @@ -119,7 +170,7 @@ async def editor_sql_run(run_param: dict = Body()): ) return Result.succ(sql_run_data) except Exception as e: - logging.error("editor_sql_run exception!" + str(e)) + logger.error(f"editor_sql_run exception: {str(e)}", exc_info=True) return Result.succ( SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[]) )