Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(security): prevent SQL injection in chart data query (CVE-2024-10901) #2269

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 70 additions & 7 deletions dbgpt/app/openapi/api_v1/editor/api_editor_v1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import re
import time
from typing import Dict, List

Expand Down Expand Up @@ -171,18 +172,80 @@ async def editor_chart_run(run_param: dict = Body()):
db_name = run_param["db_name"]
sql = run_param["sql"]
chart_type = run_param["chart_type"]
if not db_name and not sql:
return Result.failed("SQL run param error!")

# Validate input parameters
if not db_name or not sql or not chart_type:
return Result.failed("Required parameters missing")

try:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
# Validate database type and prevent dangerous operations
db_conn = CFG.local_db_manager.get_connector(db_name)
colunms, sql_result = db_conn.query_ex(sql)
db_type = getattr(db_conn, "db_type", "").lower()

# Block dangerous operations for DuckDB
if db_type == "duckdb":
# Block file operations and system commands
dangerous_keywords = [
# File operations
"copy",
"export",
"import",
"load",
"install",
"read_",
"write_",
"save",
"from_",
"to_",
# System commands
"create_",
"drop_",
".execute(",
"system",
"shell",
# Additional DuckDB specific operations
"attach",
"detach",
"pragma",
"checkpoint",
"load_extension",
"unload_extension",
# File paths
"/'",
"'/'",
"\\",
"://",
]
sql_lower = sql.lower().replace(" ", "") # Remove spaces to prevent bypass
if any(keyword in sql_lower for keyword in dangerous_keywords):
logger.warning(
f"Blocked dangerous SQL operation attempt in chart: {sql}"
)
return Result.failed(msg="Operation not allowed for security reasons")

# Additional check for file path patterns
if re.search(r"['\"].*[/\\].*['\"]", sql):
logger.warning(f"Blocked file path in chart SQL: {sql}")
return Result.failed(msg="File operations not allowed")

dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()

start_time = time.time() * 1000

# Execute query with timeout
colunms, sql_result = db_conn.query_ex(sql, timeout=30)

# Safely convert and process results
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
colunms, sql_result, sql
colunms,
[
tuple(str(x) if x is not None else None for x in row)
for row in sql_result
],
sql,
)

start_time = time.time() * 1000
# 计算执行耗时
# Calculate execution time
end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(
result_info="",
Expand Down
Loading