From 5dad1bcc8a9ae4823aa38cdfe387e6bb7b1b71f4 Mon Sep 17 00:00:00 2001 From: haawha Date: Thu, 2 Jan 2025 18:46:22 +0800 Subject: [PATCH] fix(security): prevent SQL injection in chart data query (CVE-2024-10901) - Add SQL validation for chart queries - Block dangerous DuckDB operations - Implement timeout protection - Enhance input validation and type safety --- .../openapi/api_v1/editor/api_editor_v1.py | 57 ++++++++++++++++--- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py index de623bdcc..be028109d 100644 --- a/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py +++ b/dbgpt/app/openapi/api_v1/editor/api_editor_v1.py @@ -2,6 +2,7 @@ import logging import time from typing import Dict, List +import re from fastapi import APIRouter, Body, Depends @@ -171,18 +172,56 @@ async def editor_chart_run(run_param: dict = Body()): db_name = run_param["db_name"] sql = run_param["sql"] chart_type = run_param["chart_type"] - if not db_name and not sql: - return Result.failed("SQL run param error!") + + # Validate input parameters + if not db_name or not sql or not chart_type: + return Result.failed("Required parameters missing") + try: - dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() + # Validate database type and prevent dangerous operations db_conn = CFG.local_db_manager.get_connector(db_name) - colunms, sql_result = db_conn.query_ex(sql) + db_type = getattr(db_conn, "db_type", "").lower() + + # Block dangerous operations for DuckDB + if db_type == "duckdb": + # Block file operations and system commands + dangerous_keywords = [ + # File operations + "copy", "export", "import", "load", "install", + "read_", "write_", "save", "from_", "to_", + # System commands + "create_", "drop_", ".execute(", "system", "shell", + # Additional DuckDB specific operations + "attach", "detach", "pragma", "checkpoint", + "load_extension", "unload_extension", + # File paths + "/'", "'/'", "\\", "://" + ] + sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass + if any(keyword in sql_lower for keyword in dangerous_keywords): + logger.warning(f"Blocked dangerous SQL operation attempt in chart: {sql}") + return Result.failed(msg="Operation not allowed for security reasons") + + # Additional check for file path patterns + if re.search(r"['\"].*[/\\].*['\"]", sql): + logger.warning(f"Blocked file path in chart SQL: {sql}") + return Result.failed(msg="File operations not allowed") + + dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() + + start_time = time.time() * 1000 + + # Execute query with timeout + colunms, sql_result = db_conn.query_ex(sql, timeout=30) + + # Safely convert and process results field_names, chart_values = dashboard_data_loader.get_chart_values_by_data( - colunms, sql_result, sql + colunms, + [tuple(str(x) if x is not None else None for x in row) for row in sql_result], + sql ) - start_time = time.time() * 1000 - # 计算执行耗时 + # Calculate execution time end_time = time.time() * 1000 sql_run_data: SqlRunData = SqlRunData( result_info="", @@ -192,7 +231,9 @@ async def editor_chart_run(run_param: dict = Body()): ) return Result.succ( ChartRunData( - sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type + sql_data=sql_run_data, + chart_values=chart_values, + chart_type=chart_type ) ) except Exception as e: