diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 4dcaedfc..0eba321d 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -1,3 +1,7 @@ +""" +模块说明:Alembic 迁移环境与配置入口。 +""" + import asyncio from logging.config import fileConfig diff --git a/backend/app/__init__.py b/backend/app/__init__.py index e69de29b..520d265d 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -0,0 +1,3 @@ +""" +模块说明:包初始化与导出。 +""" diff --git a/backend/app/api/__init__.py b/backend/app/api/__init__.py index e69de29b..61bc8108 100644 --- a/backend/app/api/__init__.py +++ b/backend/app/api/__init__.py @@ -0,0 +1,3 @@ +""" +模块说明:API 路由与依赖定义:__init__。 +""" diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 7a795338..ebd11c65 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -1,3 +1,7 @@ +""" +模块说明:API 路由与依赖定义:deps。 +""" + from typing import Generator, Optional from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer @@ -19,32 +23,56 @@ async def get_current_user( db: AsyncSession = Depends(get_db), token: str = Depends(reusable_oauth2) ) -> User: + """ + 从请求令牌解析并返回当前用户。 + + 处理流程: + - 解析并校验 JWT + - 构建 TokenPayload + - 查询用户并校验状态 + """ + # 解析并验证 JWT try: + # 解码令牌载荷 payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] ) + # 将载荷解析为结构化数据 token_data = token_schema.TokenPayload(**payload) except (JWTError, ValidationError): + # 令牌无效或格式不正确时返回 401 raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无法验证凭据", headers={"WWW-Authenticate": "Bearer"}, ) - + # 查询用户记录 result = await db.execute(select(User).where(User.id == token_data.sub)) + # 提取用户对象 user = result.scalars().first() - + # 若用户不存在则返回 404 if not user: raise HTTPException(status_code=404, detail="用户不存在") + # 若用户已被禁用则返回 400 if not user.is_active: raise HTTPException(status_code=400, detail="用户已被禁用") + # 返回当前用户 return user async def get_current_active_superuser( current_user: User = Depends(get_current_user), ) -> User: + """ + 校验当前用户是否为超级管理员。 + + 处理流程: + - 依赖注入获取当前用户 + - 检查超级管理员标识 + """ + # 若非超级管理员则拒绝访问 if not current_user.is_superuser: raise HTTPException( status_code=400, detail="权限不足" ) + # 返回通过校验的用户 return current_user diff --git a/backend/app/api/v1/__init__.py b/backend/app/api/v1/__init__.py index e69de29b..61bc8108 100644 --- a/backend/app/api/v1/__init__.py +++ b/backend/app/api/v1/__init__.py @@ -0,0 +1,3 @@ +""" +模块说明:API 路由与依赖定义:__init__。 +""" diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py index 285a3aa0..50213819 100644 --- a/backend/app/api/v1/api.py +++ b/backend/app/api/v1/api.py @@ -1,3 +1,7 @@ +""" +模块说明:API 路由与依赖定义:api。 +""" + from fastapi import APIRouter from app.api.v1.endpoints import auth, users, projects, tasks, scan, members, config, database, prompts, rules, agent_tasks, embedding_config, ssh_keys diff --git a/backend/app/api/v1/endpoints/__init__.py b/backend/app/api/v1/endpoints/__init__.py index e69de29b..61bc8108 100644 --- a/backend/app/api/v1/endpoints/__init__.py +++ b/backend/app/api/v1/endpoints/__init__.py @@ -0,0 +1,3 @@ +""" +模块说明:API 路由与依赖定义:__init__。 +""" diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index d976d9fd..6eb5663c 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -231,7 +231,14 @@ class TaskSummaryResponse(BaseModel): def is_task_cancelled(task_id: str) -> bool: - """检查任务是否已被取消""" + """ + 检查任务是否已被取消。 + + 处理流程: + - 查询已取消任务集合 + - 返回是否存在 + """ + # 判断是否在取消集合内 return task_id in _cancelled_tasks @@ -241,6 +248,7 @@ async def _execute_agent_task(task_id: str): 架构:OrchestratorAgent 作为大脑,动态调度子 Agent """ + # 延迟导入 Agent 与依赖组件 from app.services.agent.agents import OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent from app.services.agent.event_manager import EventManager, AgentEventEmitter from app.services.llm.service import LLMService @@ -249,20 +257,21 @@ async def _execute_agent_task(task_id: str): from app.core.config import settings import time - # 🔥 在任务最开始就初始化 Docker 沙箱管理器 + # 在任务最开始就初始化 Docker 沙箱管理器 # 这样可以确保整个任务生命周期内使用同一个管理器,并且尽早发现 Docker 问题 logger.info(f"🚀 Starting execution for task {task_id}") sandbox_manager = SandboxManager() await sandbox_manager.initialize() logger.info(f"🐳 Global Sandbox Manager initialized (Available: {sandbox_manager.is_available})") - # 🔥 提前创建事件管理器,以便在克隆仓库和索引时发送实时日志 + # 提前创建事件管理器,以便在克隆仓库和索引时发送实时日志 from app.services.agent.event_manager import EventManager, AgentEventEmitter event_manager = EventManager(db_session_factory=async_session_factory) event_manager.create_queue(task_id) event_emitter = AgentEventEmitter(task_id, event_manager) _running_event_managers[task_id] = event_manager + # 打开异步数据库会话 async with async_session_factory() as db: orchestrator = None start_time = time.time() @@ -280,7 +289,7 @@ async def _execute_agent_task(task_id: str): logger.error(f"Project not found for task {task_id}") return - # 🔥 发送任务开始事件 - 使用 phase_start 让前端知道进入准备阶段 + # 发送任务开始事件 - 使用 phase_start 让前端知道进入准备阶段 await event_emitter.emit_phase_start("preparation", f"🚀 任务开始执行: {project.name}") # 更新任务阶段为准备中 @@ -292,13 +301,13 @@ async def _execute_agent_task(task_id: str): # 获取用户配置(需要在获取项目根目录之前,以便传递 token) user_config = await _get_user_config(db, task.created_by) - # 从用户配置中提取 token和SSH密钥(用于私有仓库克隆) + # 从用户配置中提取 token 和 SSH 密钥(用于私有仓库克隆) other_config = (user_config or {}).get('otherConfig', {}) github_token = other_config.get('githubToken') or settings.GITHUB_TOKEN gitlab_token = other_config.get('gitlabToken') or settings.GITLAB_TOKEN gitea_token = other_config.get('giteaToken') or settings.GITEA_TOKEN - # 解密SSH私钥 + # 解密 SSH 私钥 ssh_private_key = None if 'sshPrivateKey' in other_config: try: @@ -308,23 +317,23 @@ async def _execute_agent_task(task_id: str): except Exception as e: logger.warning(f"解密SSH私钥失败: {e}") - # 获取项目根目录(传递任务指定的分支和认证 token/SSH密钥) - # 🔥 传递 event_emitter 以发送克隆进度 + # 获取项目根目录(传递任务指定的分支和认证 token/SSH 密钥) + # 传递 event_emitter 以发送克隆进度 project_root = await _get_project_root( project, task_id, task.branch_name, github_token=github_token, gitlab_token=gitlab_token, - gitea_token=gitea_token, # 🔥 新增 - ssh_private_key=ssh_private_key, # 🔥 新增SSH密钥 - event_emitter=event_emitter, # 🔥 新增 + gitea_token=gitea_token, + ssh_private_key=ssh_private_key, + event_emitter=event_emitter, ) - # 🔥 自动修正 target_files 路径 + # 自动修正 target_files 路径 # 如果发生了目录调整(例如 ZIP 解压后只有一层目录,root 被下移), # 原有的 target_files (如 "Prefix/file.php") 可能无法匹配。 - # 我们需要检测并移除这些无效的前缀。 + # 需要检测并移除这些无效的前缀。 if task.target_files and len(task.target_files) > 0: # 1. 检查是否存在不匹配的文件 all_exist = True @@ -370,7 +379,7 @@ async def _execute_agent_task(task_id: str): await event_emitter.emit_info(f"🔧 自动修正了 {fixed_count} 个目标文件的路径") task.target_files = new_target_files - # 🔥 重新验证修正后的文件 + # 重新验证修正后的文件 valid_target_files = [] if task.target_files: for tf in task.target_files: @@ -389,7 +398,7 @@ async def _execute_agent_task(task_id: str): logger.info(f"🚀 Task {task_id} started with Dynamic Agent Tree architecture") - # 🔥 获取项目根目录后检查取消 + # 获取项目根目录后检查取消 if is_task_cancelled(task_id): logger.info(f"[Cancel] Task {task_id} cancelled after project preparation") raise asyncio.CancelledError("任务已取消") @@ -670,35 +679,50 @@ def check_global_cancel(): async def _get_user_config(db: AsyncSession, user_id: Optional[str]) -> Optional[Dict[str, Any]]: - """获取用户配置""" + """ + 获取用户配置。 + + 处理流程: + - 校验用户 ID + - 查询用户配置 + - 解密敏感字段并返回 + """ + # 无用户 ID 直接返回 if not user_id: return None try: + # 延迟导入解密工具与字段列表 from app.api.v1.endpoints.config import ( decrypt_config, SENSITIVE_LLM_FIELDS, SENSITIVE_OTHER_FIELDS ) + # 查询用户配置 result = await db.execute( select(UserConfig).where(UserConfig.user_id == user_id) ) config = result.scalar_one_or_none() if config and config.llm_config: + # 解析配置 user_llm_config = json.loads(config.llm_config) if config.llm_config else {} user_other_config = json.loads(config.other_config) if config.other_config else {} + # 解密敏感字段 user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS) user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS) + # 返回统一配置结构 return { "llmConfig": user_llm_config, "otherConfig": user_other_config, } except Exception as e: + # 捕获异常并记录警告 logger.warning(f"Failed to get user config: {e}") + # 默认无配置 return None @@ -713,7 +737,8 @@ async def _initialize_tools( event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志 task_id: Optional[str] = None, # 🔥 新增:用于取消检查 ) -> Dict[str, Dict[str, Any]]: - """初始化工具集 + """ + 初始化工具集。 Args: project_root: 项目根目录 @@ -726,6 +751,7 @@ async def _initialize_tools( event_emitter: 事件发送器(用于发送实时日志) task_id: 任务 ID(用于取消检查) """ + # 导入基础工具 from app.services.agent.tools import ( FileReadTool, FileSearchTool, ListFilesTool, PatternMatchTool, CodeAnalysisTool, DataFlowAnalysisTool, @@ -737,11 +763,12 @@ async def _initialize_tools( # 🔥 RAG 工具 RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool, ) + # 导入安全知识工具 from app.services.agent.knowledge import ( SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool, ) - # 🔥 RAG 相关导入 + # RAG 相关导入 from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService, IndexUpdateMode from app.core.config import settings @@ -1042,7 +1069,8 @@ async def _collect_project_info( exclude_patterns: Optional[List[str]] = None, target_files: Optional[List[str]] = None, ) -> Dict[str, Any]: - """收集项目信息 + """ + 收集项目信息。 Args: project_root: 项目根目录 @@ -1055,6 +1083,7 @@ async def _collect_project_info( """ import fnmatch + # 初始化项目信息结构 info = { "name": project_name, "root": project_root, @@ -1081,13 +1110,14 @@ async def _collect_project_info( # 目标文件集合 target_files_set = set(target_files) if target_files else None + # 文件后缀到语言映射 lang_map = { ".py": "Python", ".js": "JavaScript", ".ts": "TypeScript", ".java": "Java", ".go": "Go", ".php": "PHP", ".rb": "Ruby", ".rs": "Rust", ".c": "C", ".cpp": "C++", } - # 🔥 收集过滤后的文件列表 + # 收集过滤后的文件列表 filtered_files = [] filtered_dirs = set() @@ -1126,13 +1156,13 @@ async def _collect_project_info( if ext in lang_map and lang_map[ext] not in info["languages"]: info["languages"].append(lang_map[ext]) - # 🔥 根据是否有目标文件限制,生成不同的结构信息 + # 根据是否有目标文件限制,生成不同的结构信息 if target_files_set: # 当指定了目标文件时,只显示目标文件和相关目录 info["structure"] = { "directories": sorted(list(filtered_dirs))[:20], "files": filtered_files[:30], - "scope_limited": True, # 🔥 标记这是限定范围的视图 + "scope_limited": True, # 标记这是限定范围的视图 "scope_message": f"审计范围限定为 {len(filtered_files)} 个指定文件", } else: @@ -1148,8 +1178,10 @@ async def _collect_project_info( pass except Exception as e: + # 记录异常并返回已有信息 logger.warning(f"Failed to collect project info: {e}") + # 返回项目信息 return info @@ -1160,7 +1192,7 @@ async def _save_findings( project_root: Optional[str] = None, ) -> int: """ - 保存发现到数据库 + 保存发现到数据库。 🔥 增强版:支持多种 Agent 输出格式,健壮的字段映射 🔥 v2.1: 添加文件路径验证,过滤幻觉发现 @@ -1176,13 +1208,14 @@ async def _save_findings( """ from app.models.agent_task import VulnerabilityType + # 记录保存开始 logger.info(f"[SaveFindings] Starting to save {len(findings)} findings for task {task_id}") if not findings: logger.warning(f"[SaveFindings] No findings to save for task {task_id}") return 0 - # 🔥 Case-insensitive mapping preparation + # Case-insensitive mapping preparation severity_map = { "critical": VulnerabilitySeverity.CRITICAL, "high": VulnerabilitySeverity.HIGH, @@ -1191,6 +1224,7 @@ async def _save_findings( "info": VulnerabilitySeverity.INFO, } + # 漏洞类型映射 type_map = { "sql_injection": VulnerabilityType.SQL_INJECTION, "nosql_injection": VulnerabilityType.NOSQL_INJECTION, @@ -1212,16 +1246,18 @@ async def _save_findings( "memory_corruption": VulnerabilityType.MEMORY_CORRUPTION, } + # 初始化保存计数 saved_count = 0 logger.info(f"Saving {len(findings)} findings for task {task_id}") + # 逐条处理发现 for finding in findings: if not isinstance(finding, dict): logger.debug(f"[SaveFindings] Skipping non-dict finding: {type(finding)}") continue try: - # 🔥 Handle severity (case-insensitive, support multiple field names) + # 解析严重程度(兼容多字段) raw_severity = str( finding.get("severity") or finding.get("risk") or @@ -1229,7 +1265,7 @@ async def _save_findings( ).lower().strip() severity_enum = severity_map.get(raw_severity, VulnerabilitySeverity.MEDIUM) - # 🔥 Handle vulnerability type (case-insensitive & snake_case normalization) + # 解析漏洞类型(兼容多字段) # Support multiple field names: vulnerability_type, type, vuln_type raw_type = str( finding.get("vulnerability_type") or @@ -1416,7 +1452,15 @@ async def _save_findings( def _calculate_security_score(findings: List[Dict]) -> float: - """计算安全评分""" + """ + 计算安全评分。 + + 处理流程: + - 为空时返回满分 + - 根据严重程度累计扣分 + - 返回最终评分 + """ + # 无发现返回满分 if not findings: return 100.0 @@ -1429,26 +1473,33 @@ def _calculate_security_score(findings: List[Dict]) -> float: "info": 1, } + # 累计扣分 total_deduction = 0 for f in findings: if isinstance(f, dict): sev = f.get("severity", "low") total_deduction += deductions.get(sev, 3) + # 计算最终得分 score = max(0, 100 - total_deduction) return float(score) async def _save_agent_tree(db: AsyncSession, task_id: str) -> None: """ - 保存 Agent 树到数据库 + 保存 Agent 树到数据库。 - 🔥 在任务完成前调用,将内存中的 Agent 树持久化到数据库 + 处理流程: + - 从注册表获取树结构 + - 计算节点深度 + - 批量写入数据库 """ + # 延迟导入模型与注册表 from app.models.agent_task import AgentTreeNode from app.services.agent.core import agent_registry try: + # 获取 Agent 树结构 tree = agent_registry.get_agent_tree() nodes = tree.get("nodes", {}) @@ -1535,13 +1586,19 @@ async def create_agent_task( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 创建并启动 Agent 审计任务 + 创建并启动 Agent 审计任务。 + + 处理流程: + - 校验项目归属 + - 创建任务记录 + - 后台启动执行 """ # 验证项目 project = await db.get(Project, request.project_id) if not project: raise HTTPException(status_code=404, detail="项目不存在") + # 校验项目权限 if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此项目") @@ -1563,6 +1620,7 @@ async def create_agent_task( created_by=current_user.id, ) + # 保存任务 db.add(task) await db.commit() await db.refresh(task) @@ -1570,6 +1628,7 @@ async def create_agent_task( # 在后台启动任务(项目根目录在任务内部获取) background_tasks.add_task(_execute_agent_task, task.id) + # 记录日志 logger.info(f"Created agent task {task.id} for project {project.name}") return task @@ -1585,7 +1644,12 @@ async def list_agent_tasks( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取 Agent 任务列表 + 获取 Agent 任务列表。 + + 处理流程: + - 查询用户项目 + - 过滤任务条件 + - 返回分页结果 """ # 获取用户的项目 projects_result = await db.execute( @@ -1599,9 +1663,11 @@ async def list_agent_tasks( # 构建查询 query = select(AgentTask).where(AgentTask.project_id.in_(user_project_ids)) + # 过滤项目 if project_id: query = query.where(AgentTask.project_id == project_id) + # 过滤状态 if status: try: status_enum = AgentTaskStatus(status) @@ -1609,9 +1675,11 @@ async def list_agent_tasks( except ValueError: pass + # 排序与分页 query = query.order_by(AgentTask.created_at.desc()) query = query.offset(skip).limit(limit) + # 查询任务 result = await db.execute(query) tasks = result.scalars().all() @@ -1625,8 +1693,14 @@ async def get_agent_task( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取 Agent 任务详情 + 获取 Agent 任务详情。 + + 处理流程: + - 校验任务存在 + - 校验权限 + - 组装进度字段 """ + # 查询任务 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") @@ -1722,30 +1796,38 @@ async def cancel_agent_task( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 取消 Agent 任务 + 取消 Agent 任务。 + + 处理流程: + - 校验任务与权限 + - 标记取消标志并中断执行 + - 更新任务状态 """ + # 查询任务 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权操作此任务") + # 已结束的任务不允许取消 if task.status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: raise HTTPException(status_code=400, detail="任务已结束,无法取消") - # 🔥 0. 立即标记任务为已取消(用于前置操作的取消检查) + # 0. 立即标记任务为已取消(用于前置操作的取消检查) _cancelled_tasks.add(task_id) logger.info(f"[Cancel] Added task {task_id} to cancelled set") - # 🔥 1. 设置 Agent 的取消标志 + # 1. 设置 Agent 的取消标志 runner = _running_tasks.get(task_id) if runner: runner.cancel() logger.info(f"[Cancel] Set cancel flag for task {task_id}") - # 🔥 2. 通过 agent_registry 取消所有子 Agent + # 2. 通过 agent_registry 取消所有子 Agent from app.services.agent.core import agent_registry from app.services.agent.core.graph_controller import stop_all_agents try: @@ -1755,7 +1837,7 @@ async def cancel_agent_task( except Exception as e: logger.warning(f"[Cancel] Failed to stop agents via registry: {e}") - # 🔥 3. 强制取消 asyncio Task(立即中断 LLM 调用) + # 3. 强制取消 asyncio Task(立即中断 LLM 调用) asyncio_task = _running_asyncio_tasks.get(task_id) if asyncio_task and not asyncio_task.done(): asyncio_task.cancel() @@ -1778,18 +1860,33 @@ async def stream_agent_events( current_user: User = Depends(deps.get_current_user), ): """ - 获取 Agent 事件流 (SSE) + 获取 Agent 事件流 (SSE)。 + + 处理流程: + - 校验任务与权限 + - 轮询数据库事件 + - 按序推送 SSE 数据 """ + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此任务") async def event_generator(): - """生成 SSE 事件流""" + """ + 生成 SSE 事件流。 + + 处理流程: + - 按序查询事件 + - 空闲超时自动退出 + - 推送心跳与状态更新 + """ + # 初始化序号与轮询状态 last_sequence = after_sequence poll_interval = 0.5 max_idle = 300 # 5 分钟无事件后关闭 @@ -1869,7 +1966,7 @@ async def stream_agent_with_thinking( current_user: User = Depends(deps.get_current_user), ): """ - 增强版事件流 (SSE) + 增强版事件流 (SSE)。 支持: - LLM 思考过程的 Token 级流式输出 (仅运行时) @@ -1880,17 +1977,27 @@ async def stream_agent_with_thinking( 优先使用内存中的事件队列 (支持 thinking_token), 如果任务未在运行,则回退到数据库轮询 (不支持 thinking_token 复盘)。 """ + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此任务") # 定义 SSE 格式化函数 def format_sse_event(event_data: Dict[str, Any]) -> str: - """格式化为 SSE 事件""" + """ + 格式化为 SSE 事件。 + + 处理流程: + - 推导事件类型 + - 统一字段 + - 输出 SSE 格式字符串 + """ + # 获取事件类型 event_type = event_data.get("event_type") or event_data.get("type") # 统一字段 @@ -1900,7 +2007,14 @@ def format_sse_event(event_data: Dict[str, Any]) -> str: return f"event: {event_type}\ndata: {json.dumps(event_data, ensure_ascii=False)}\n\n" async def enhanced_event_generator(): - """生成增强版 SSE 事件流""" + """ + 生成增强版 SSE 事件流。 + + 处理流程: + - 优先使用内存事件队列 + - 若不存在则回退到数据库轮询 + - 按需过滤 thinking/tool 事件 + """ # 1. 检查任务是否在运行中 (内存) event_manager = _running_event_managers.get(task_id) @@ -1921,7 +2035,7 @@ async def enhanced_event_generator(): if event_type in skip_types: continue - # 🔥 Debug: 记录 thinking_token 事件 + # Debug: 记录 thinking_token 事件 if event_type == "thinking_token": token = event.get("metadata", {}).get("token", "")[:20] logger.debug(f"Stream {task_id}: Sending thinking_token: '{token}...'") @@ -1929,7 +2043,7 @@ async def enhanced_event_generator(): # 格式化并 yield yield format_sse_event(event) - # 🔥 CRITICAL: 为 thinking_token 添加微小延迟 + # 为 thinking_token 添加微小延迟 # 确保事件在不同的 TCP 包中发送,让前端能够逐个处理 # 没有这个延迟,所有 token 会在一次 read() 中被接收,导致 React 批量更新 if event_type == "thinking_token": @@ -2060,16 +2174,24 @@ async def list_agent_events( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取 Agent 事件列表 + 获取 Agent 事件列表。 + + 处理流程: + - 校验任务与权限 + - 分页查询事件 + - 返回事件序列 """ + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此任务") + # 查询事件 result = await db.execute( select(AgentEvent) .where(AgentEvent.task_id == task_id) @@ -2079,7 +2201,7 @@ async def list_agent_events( ) events = result.scalars().all() - # 🔥 Debug logging + # Debug logging logger.debug(f"[EventsList] Task {task_id}: returning {len(events)} events (after_sequence={after_sequence})") if events: logger.debug(f"[EventsList] First event: type={events[0].event_type}, seq={events[0].sequence}") @@ -2100,18 +2222,27 @@ async def list_agent_findings( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取 Agent 发现列表 + 获取 Agent 发现列表。 + + 处理流程: + - 校验任务与权限 + - 应用过滤条件 + - 分页返回结果 """ + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此任务") + # 构建查询 query = select(AgentFinding).where(AgentFinding.task_id == task_id) + # 按严重程度过滤 if severity: try: sev_enum = VulnerabilitySeverity(severity) @@ -2119,6 +2250,7 @@ async def list_agent_findings( except ValueError: pass + # 仅验证通过 if verified_only: query = query.where(AgentFinding.is_verified == True) @@ -2131,9 +2263,11 @@ async def list_agent_findings( VulnerabilitySeverity.INFO: 4, } + # 应用排序与分页 query = query.order_by(AgentFinding.severity, AgentFinding.created_at.desc()) query = query.offset(skip).limit(limit) + # 查询结果 result = await db.execute(query) findings = result.scalars().all() @@ -2147,12 +2281,19 @@ async def get_task_summary( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取任务摘要 + 获取任务摘要。 + + 处理流程: + - 校验任务与权限 + - 汇总发现统计 + - 生成阶段与持续时间 """ + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此任务") @@ -2193,6 +2334,7 @@ async def get_task_summary( ) phases = [str(p[0]) for p in phases_result.fetchall() if p[0]] + # 组装响应 return TaskSummaryResponse( task_id=task_id, status=str(task.status), # status 已经是字符串 @@ -2215,24 +2357,34 @@ async def update_finding_status( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 更新发现状态 + 更新发现状态。 + + 处理流程: + - 校验输入状态 + - 校验任务与权限 + - 更新发现状态 """ + # 读取状态 status = body.get("status") if not status: raise HTTPException(status_code=400, detail="缺少 status 字段") + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权操作") + # 校验发现归属 finding = await db.get(AgentFinding, finding_id) if not finding or finding.task_id != task_id: raise HTTPException(status_code=404, detail="发现不存在") + # 校验状态合法性 VALID_FINDING_STATUSES = { FindingStatus.NEW, FindingStatus.ANALYZING, FindingStatus.VERIFIED, FindingStatus.FALSE_POSITIVE, FindingStatus.NEEDS_REVIEW, @@ -2241,8 +2393,10 @@ async def update_finding_status( if status not in VALID_FINDING_STATUSES: raise HTTPException(status_code=400, detail=f"无效的状态: {status}") + # 更新状态 finding.status = status + # 提交修改 await db.commit() return {"message": "状态已更新", "finding_id": finding_id, "status": status} @@ -2252,7 +2406,7 @@ async def update_finding_status( def validate_git_url(url: str) -> bool: """ - 验证 Git URL 是否安全 + 验证 Git URL 是否安全。 Args: url: Git URL @@ -2260,9 +2414,11 @@ def validate_git_url(url: str) -> bool: Returns: bool: URL 是否安全 """ + # 空值直接拒绝 if not url: return False + # 解析 URL from urllib.parse import urlparse parsed = urlparse(url) @@ -2277,11 +2433,12 @@ def validate_git_url(url: str) -> bool: if pattern in url: return False + # 通过校验 return True def validate_branch_name(branch: str) -> bool: """ - 验证 Git 分支名称是否安全 + 验证 Git 分支名称是否安全。 Args: branch: 分支名称 @@ -2289,6 +2446,7 @@ def validate_branch_name(branch: str) -> bool: Returns: bool: 分支名称是否安全 """ + # 空值直接拒绝 if not branch: return False @@ -2306,11 +2464,12 @@ def validate_branch_name(branch: str) -> bool: if len(branch) > 256: return False + # 通过校验 return True def is_path_safe(base_path: str, target_path: str) -> bool: """ - 检查目标路径是否在基础目录内(防止路径遍历) + 检查目标路径是否在基础目录内(防止路径遍历)。 Args: base_path: 基础目录 @@ -2328,17 +2487,19 @@ def is_path_safe(base_path: str, target_path: str) -> bool: def safe_extract_zip(zip_ref: zipfile.ZipFile, extract_dir: str, task_id: str) -> None: """ - 安全解压 ZIP 文件,防止 Zip Slip 攻击 + 安全解压 ZIP 文件,防止 Zip Slip 攻击。 Args: zip_ref: ZipFile 对象 extract_dir: 解压目标目录 task_id: 任务 ID(用于取消检查) """ + # 取消检查回调 def check_cancelled(): if is_task_cancelled(task_id): raise asyncio.CancelledError("任务已取消") + # 获取压缩包内文件列表 file_list = zip_ref.namelist() # 找到公共前缀 @@ -2346,6 +2507,7 @@ def check_cancelled(): common_prefix = file_list[0].split('/')[0] + '/' for i, file_name in enumerate(file_list): + # 分批检查取消 if i % 50 == 0: check_cancelled() @@ -2355,11 +2517,12 @@ def check_cancelled(): if target_path: full_target = os.path.join(extract_dir, target_path) - # 🔥 安全检查:防止路径遍历 + # 安全检查:防止路径遍历 if not is_path_safe(extract_dir, target_path): logger.warning(f"⚠️ 检测到路径遍历攻击: {file_name}") continue + # 目录/文件分流处理 if file_name.endswith('/'): os.makedirs(full_target, exist_ok=True) else: @@ -2415,11 +2578,12 @@ async def emit(message: str, level: str = "info"): elif level == "error": await event_emitter.emit_error(message) - # 🔥 辅助函数:检查取消状态 + # 辅助函数:检查取消状态 def check_cancelled(): if is_task_cancelled(task_id): raise asyncio.CancelledError("任务已取消") + # 设置任务临时目录 base_path = f"/tmp/deepaudit/{task_id}" # 确保目录存在且为空 @@ -2427,24 +2591,25 @@ def check_cancelled(): shutil.rmtree(base_path) os.makedirs(base_path, exist_ok=True) - # 🔥 在开始任何操作前检查取消 + # 在开始任何操作前检查取消 check_cancelled() # 根据项目类型处理 if project.source_type == "zip": - # 🔥 ZIP 项目:解压 ZIP 文件 - check_cancelled() # 🔥 解压前检查 + # ZIP 项目:解压 ZIP 文件 + check_cancelled() # 解压前检查 await emit(f"📦 正在解压项目文件...") from app.services.zip_storage import load_project_zip + # 加载 ZIP 路径 zip_path = await load_project_zip(project.id) if zip_path and os.path.exists(zip_path): try: - check_cancelled() # 🔥 解压前再次检查 + check_cancelled() # 解压前再次检查 with zipfile.ZipFile(zip_path, 'r') as zip_ref: - # 🔥 逐个文件解压,支持取消检查 - # 🔥 Security Fix: 使用 safe_extract_zip 替代 extract,防止 Zip Slip 和软链接攻击 + # 逐个文件解压,支持取消检查 + # 使用 safe_extract_zip 替代 extract,防止 Zip Slip 和软链接攻击 safe_extract_zip(zip_ref, base_path, task_id) logger.info(f"✅ Extracted ZIP project {project.id} to {base_path}") await emit(f"✅ ZIP 文件解压完成") @@ -2895,7 +3060,7 @@ async def get_agent_tree( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取任务的 Agent 树结构 + 获取任务的 Agent 树结构。 返回动态 Agent 树的完整结构,包括: - 所有 Agent 节点 @@ -2903,10 +3068,12 @@ async def get_agent_tree( - 执行状态 - 发现统计 """ + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此任务") @@ -2916,6 +3083,7 @@ async def get_agent_tree( logger.debug(f"[AgentTree API] task_id={task_id}, runner exists={runner is not None}") if runner: + # 运行中任务从注册表实时读取 from app.services.agent.core import agent_registry tree = agent_registry.get_agent_tree() @@ -2923,13 +3091,13 @@ async def get_agent_tree( logger.debug(f"[AgentTree API] tree nodes={len(tree.get('nodes', {}))}, root={tree.get('root_agent_id')}") logger.debug(f"[AgentTree API] 节点详情: {list(tree.get('nodes', {}).keys())}") - # 🔥 获取 root agent ID,用于判断是否是 Orchestrator + # 获取 root agent ID,用于判断是否是 Orchestrator root_agent_id = tree.get("root_agent_id") # 构建节点列表 nodes = [] for agent_id, node_data in tree.get("nodes", {}).items(): - # 🔥 从 Agent 实例获取实时统计数据 + # 从 Agent 实例获取实时统计数据 iterations = 0 tool_calls = 0 tokens_used = 0 @@ -2942,8 +3110,8 @@ async def get_agent_tree( tool_calls = agent_stats.get("tool_calls", 0) tokens_used = agent_stats.get("tokens_used", 0) - # 🔥 FIX: 对于 Orchestrator (root agent),使用 task 的 findings_count - # 这确保了正确显示聚合的 findings 总数 + # 对于 Orchestrator (root agent),使用 task 的 findings_count + # 确保显示聚合的 findings 总数 if agent_id == root_agent_id: findings_count = task.findings_count or 0 else: @@ -2968,7 +3136,7 @@ async def get_agent_tree( children=[], )) - # 🔥 使用 task.findings_count 作为 total_findings,确保一致性 + # 使用 task.findings_count 作为 total_findings,确保一致性 return AgentTreeResponse( task_id=task_id, root_agent_id=root_agent_id, @@ -2983,6 +3151,7 @@ async def get_agent_tree( # 从数据库获取(已完成的任务) from app.models.agent_task import AgentTreeNode + # 查询持久化树结构 result = await db.execute( select(AgentTreeNode) .where(AgentTreeNode.task_id == task_id) @@ -2990,6 +3159,7 @@ async def get_agent_tree( ) db_nodes = result.scalars().all() + # 无节点直接返回空树 if not db_nodes: return AgentTreeResponse( task_id=task_id, @@ -3014,8 +3184,8 @@ async def get_agent_tree( elif node.status == "failed": failed += 1 - # 🔥 FIX: 对于 Orchestrator (root agent),使用 task 的 findings_count - # 这确保了正确显示聚合的 findings 总数 + # 对于 Orchestrator (root agent),使用 task 的 findings_count + # 确保显示聚合的 findings 总数 if node.parent_agent_id is None: # Root agent uses task's total findings node_findings_count = task.findings_count or 0 @@ -3041,7 +3211,7 @@ async def get_agent_tree( children=[], )) - # 🔥 使用 task.findings_count 作为 total_findings,确保一致性 + # 使用 task.findings_count 作为 total_findings,确保一致性 return AgentTreeResponse( task_id=task_id, root_agent_id=root_id, @@ -3084,33 +3254,41 @@ async def list_checkpoints( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取任务的检查点列表 + 获取任务的检查点列表。 用于: - 查看执行历史 - 状态恢复 - 调试分析 """ + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此任务") + # 延迟导入检查点模型 from app.models.agent_task import AgentCheckpoint + # 构建查询 query = select(AgentCheckpoint).where(AgentCheckpoint.task_id == task_id) + # 按 agent 过滤 if agent_id: query = query.where(AgentCheckpoint.agent_id == agent_id) + # 排序并限制数量 query = query.order_by(AgentCheckpoint.created_at.desc()).limit(limit) + # 执行查询 result = await db.execute(query) checkpoints = result.scalars().all() + # 序列化为响应结构 return [ CheckpointResponse( id=cp.id, @@ -3138,20 +3316,24 @@ async def get_checkpoint_detail( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取检查点详情 + 获取检查点详情。 返回完整的 Agent 状态数据 """ + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此任务") + # 延迟导入检查点模型 from app.models.agent_task import AgentCheckpoint + # 查询检查点 checkpoint = await db.get(AgentCheckpoint, checkpoint_id) if not checkpoint or checkpoint.task_id != task_id: raise HTTPException(status_code=404, detail="检查点不存在") @@ -3164,6 +3346,7 @@ async def get_checkpoint_detail( except json.JSONDecodeError: pass + # 构建响应 return { "id": checkpoint.id, "task_id": checkpoint.task_id, @@ -3194,14 +3377,16 @@ async def generate_audit_report( current_user: User = Depends(deps.get_current_user), ): """ - 生成审计报告 + 生成审计报告。 支持 Markdown 和 JSON 格式 """ + # 校验任务存在 task = await db.get(AgentTask, task_id) if not task: raise HTTPException(status_code=404, detail="任务不存在") + # 校验权限 project = await db.get(Project, task.project_id) if not project or project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此任务") @@ -3223,7 +3408,7 @@ async def generate_audit_report( ) findings = findings.scalars().all() - # 🔥 Helper function to normalize severity for comparison (case-insensitive) + # Helper function to normalize severity for comparison (case-insensitive) def normalize_severity(sev: str) -> str: return str(sev).lower().strip() if sev else "" @@ -3308,6 +3493,7 @@ def normalize_severity(sev: str) -> str: else: duration_str = f"{int(duration)} 秒" + # 汇总输出内容 md_lines = [] # Header @@ -3396,10 +3582,12 @@ def normalize_severity(sev: str) -> str: if not severity_findings: continue + # 分级输出标题 md_lines.append(f"## {severity_name} 漏洞") md_lines.append("") for i, f in enumerate(severity_findings, 1): + # 标记验证与 PoC verified_badge = "[已验证]" if f.is_verified else "[未验证]" poc_badge = " [含 PoC]" if f.has_poc else "" @@ -3429,7 +3617,7 @@ def normalize_severity(sev: str) -> str: md_lines.append("") if f.code_snippet: - # 🔥 v2.1: 增强语言检测,避免默认 python 标记错误 + # v2.1: 增强语言检测,避免默认 python 标记错误 lang = "text" # 默认使用 text 而非 python if f.file_path: ext = f.file_path.split('.')[-1].lower() @@ -3478,6 +3666,7 @@ def normalize_severity(sev: str) -> str: 'groovy': 'groovy', 'gradle': 'groovy', } lang = lang_map.get(ext, 'text') + # 添加代码片段 md_lines.append("**漏洞代码:**") md_lines.append("") md_lines.append(f"```{lang}") @@ -3492,6 +3681,7 @@ def normalize_severity(sev: str) -> str: md_lines.append("") if f.fix_code: + # 输出修复示例 md_lines.append("**参考修复代码:**") md_lines.append("") md_lines.append(f"```{lang if f.file_path else 'text'}") @@ -3499,7 +3689,7 @@ def normalize_severity(sev: str) -> str: md_lines.append("```") md_lines.append("") - # 🔥 添加 PoC 详情 + # 添加 PoC 详情 if f.has_poc: md_lines.append("**概念验证 (PoC):**") md_lines.append("") @@ -3509,6 +3699,7 @@ def normalize_severity(sev: str) -> str: md_lines.append("") if f.poc_steps: + # 输出复现步骤 md_lines.append("**复现步骤:**") md_lines.append("") for step_idx, step in enumerate(f.poc_steps, 1): @@ -3516,6 +3707,7 @@ def normalize_severity(sev: str) -> str: md_lines.append("") if f.poc_code: + # 输出 PoC 代码 md_lines.append("**PoC 代码:**") md_lines.append("") md_lines.append("```") @@ -3554,8 +3746,10 @@ def normalize_severity(sev: str) -> str: md_lines.append("") content = "\n".join(md_lines) + # 构造下载文件名 filename = f"audit_report_{task.id[:8]}_{datetime.now().strftime('%Y%m%d')}.md" + # 返回 Markdown 下载 from fastapi.responses import Response return Response( content=content, diff --git a/backend/app/api/v1/endpoints/auth.py b/backend/app/api/v1/endpoints/auth.py index 506135b1..262799ab 100644 --- a/backend/app/api/v1/endpoints/auth.py +++ b/backend/app/api/v1/endpoints/auth.py @@ -1,3 +1,7 @@ +""" +模块说明:API 路由与依赖定义:auth。 +""" + from datetime import timedelta from typing import Any from fastapi import APIRouter, Depends, HTTPException, status @@ -17,6 +21,7 @@ router = APIRouter() class RegisterRequest(BaseModel): + """注册请求体模型。""" email: EmailStr password: str full_name: str @@ -27,22 +32,32 @@ async def login( form_data: OAuth2PasswordRequestForm = Depends() ) -> Any: """ - OAuth2 compatible token login, get an access token for future requests. - Username field should contain the email address. + OAuth2 兼容登录,返回访问令牌。 + + 处理流程: + - 通过邮箱查找用户 + - 校验密码与用户状态 + - 生成并返回访问令牌 """ + # 按邮箱查询用户 result = await db.execute(select(User).where(User.email == form_data.username)) + # 取出用户对象 user = result.scalars().first() - + # 校验用户存在性与密码正确性 if not user or not security.verify_password(form_data.password, user.hashed_password): raise HTTPException(status_code=400, detail="邮箱或密码错误") + # 校验用户是否被禁用 elif not user.is_active: raise HTTPException(status_code=400, detail="用户已被禁用") - + # 计算访问令牌过期时间 access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + # 返回访问令牌与类型 return { + # 创建访问令牌 "access_token": security.create_access_token( user.id, expires_delta=access_token_expires ), + # 固定令牌类型 "token_type": "bearer", } @@ -53,32 +68,49 @@ async def register( user_in: RegisterRequest, ) -> Any: """ - Register a new user. + 注册新用户。 + + 处理流程: + - 检查邮箱是否已被注册 + - 判断是否为首个用户 + - 创建用户并设置角色 """ - # Check if user already exists + # 检查邮箱是否已存在 result = await db.execute(select(User).where(User.email == user_in.email)) + # 提取已存在用户 existing_user = result.scalars().first() + # 若已存在则返回 400 if existing_user: raise HTTPException( status_code=400, detail="该邮箱已被注册", ) - - # Check if this is the first user (make them admin) + # 查询用户总数,用于判断是否首个用户 count_result = await db.execute(select(User)) + # 取出全部用户 all_users = count_result.scalars().all() + # 判断首个用户 is_first_user = len(all_users) == 0 - - # Create new user + # 构建新用户对象 db_user = User( + # 邮箱 email=user_in.email, + # 密码哈希 hashed_password=security.get_password_hash(user_in.password), + # 显示名 full_name=user_in.full_name, + # 启用用户 is_active=True, + # 首个用户设为超级管理员 is_superuser=is_first_user, + # 首个用户设为 admin 角色 role="admin" if is_first_user else "member", ) + # 写入数据库 db.add(db_user) + # 提交事务 await db.commit() + # 刷新对象以获取数据库生成字段 await db.refresh(db_user) + # 返回新用户 return db_user diff --git a/backend/app/api/v1/endpoints/config.py b/backend/app/api/v1/endpoints/config.py index 38e1b435..b5e99a81 100644 --- a/backend/app/api/v1/endpoints/config.py +++ b/backend/app/api/v1/endpoints/config.py @@ -28,20 +28,42 @@ def encrypt_config(config: dict, sensitive_fields: list) -> dict: - """加密配置中的敏感字段""" + """ + 加密配置中的敏感字段。 + + 处理流程: + - 复制配置,避免原地修改 + - 遍历敏感字段并加密 + - 返回加密后的配置 + """ + # 复制配置以避免修改原对象 encrypted = config.copy() + # 遍历敏感字段 for field in sensitive_fields: + # 仅在字段存在且非空时加密 if field in encrypted and encrypted[field]: encrypted[field] = encrypt_sensitive_data(encrypted[field]) + # 返回加密后的配置 return encrypted def decrypt_config(config: dict, sensitive_fields: list) -> dict: - """解密配置中的敏感字段""" + """ + 解密配置中的敏感字段。 + + 处理流程: + - 复制配置,避免原地修改 + - 遍历敏感字段并解密 + - 返回解密后的配置 + """ + # 复制配置以避免修改原对象 decrypted = config.copy() + # 遍历敏感字段 for field in sensitive_fields: + # 仅在字段存在且非空时解密 if field in decrypted and decrypted[field]: decrypted[field] = decrypt_sensitive_data(decrypted[field]) + # 返回解密后的配置 return decrypted @@ -107,7 +129,15 @@ class Config: def get_default_config() -> dict: - """获取系统默认配置""" + """ + 获取系统默认配置。 + + 处理流程: + - 从系统设置读取默认值 + - 组装 LLM 与其他配置 + - 返回默认配置字典 + """ + # 返回系统默认配置 return { "llmConfig": { "llmProvider": settings.LLM_PROVIDER, @@ -150,7 +180,14 @@ def get_default_config() -> dict: @router.get("/defaults") async def get_default_config_endpoint() -> Any: - """获取系统默认配置(无需认证)""" + """ + 获取系统默认配置(无需认证)。 + + 处理流程: + - 调用默认配置构建函数 + - 返回配置结果 + """ + # 返回系统默认配置 return get_default_config() @@ -159,15 +196,24 @@ async def get_my_config( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """获取当前用户的配置(合并用户配置和系统默认配置)""" + """ + 获取当前用户的配置(合并用户配置和系统默认配置)。 + + 处理流程: + - 查询用户配置记录 + - 获取系统默认配置 + - 解密并合并用户配置 + - 返回合并后的配置 + """ + # 查询用户配置记录 result = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) + # 获取配置对象 config = result.scalar_one_or_none() - # 获取系统默认配置 default_config = get_default_config() - + # 若用户没有保存配置则返回默认配置 if not config: print(f"[Config] 用户 {current_user.id} 没有保存的配置,返回默认配置") # 返回系统默认配置 @@ -178,23 +224,21 @@ async def get_my_config( otherConfig=default_config["otherConfig"], created_at="", ) - - # 合并用户配置和默认配置(用户配置优先) + # 读取用户配置 JSON user_llm_config = json.loads(config.llm_config) if config.llm_config else {} user_other_config = json.loads(config.other_config) if config.other_config else {} - # 解密敏感字段 user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS) user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS) - + # 输出调试信息 print(f"[Config] 用户 {current_user.id} 的保存配置:") print(f" - llmProvider: {user_llm_config.get('llmProvider')}") print(f" - llmApiKey: {'***' + user_llm_config.get('llmApiKey', '')[-4:] if user_llm_config.get('llmApiKey') else '(空)'}") print(f" - llmModel: {user_llm_config.get('llmModel')}") - + # 合并默认配置与用户配置(用户配置优先) merged_llm_config = {**default_config["llmConfig"], **user_llm_config} merged_other_config = {**default_config["otherConfig"], **user_other_config} - + # 返回合并后的配置 return UserConfigResponse( id=config.id, user_id=config.user_id, @@ -211,20 +255,29 @@ async def update_my_config( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """更新当前用户的配置""" + """ + 更新当前用户的配置。 + + 处理流程: + - 查询现有配置记录 + - 加密敏感字段 + - 创建或更新配置记录 + - 返回合并后的配置 + """ + # 查询现有配置记录 result = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) + # 获取配置对象 config = result.scalar_one_or_none() - - # 准备要保存的配置数据(加密敏感字段) + # 提取待保存的 LLM 配置 llm_data = config_in.llmConfig.dict(exclude_none=True) if config_in.llmConfig else {} + # 提取待保存的其他配置 other_data = config_in.otherConfig.dict(exclude_none=True) if config_in.otherConfig else {} - # 加密敏感字段 llm_data_encrypted = encrypt_config(llm_data, SENSITIVE_LLM_FIELDS) other_data_encrypted = encrypt_config(other_data, SENSITIVE_OTHER_FIELDS) - + # 若不存在配置记录则创建 if not config: # 创建新配置 config = UserConfig( @@ -232,6 +285,7 @@ async def update_my_config( llm_config=json.dumps(llm_data_encrypted), other_config=json.dumps(other_data_encrypted), ) + # 写入数据库 db.add(config) else: # 更新现有配置 @@ -239,31 +293,34 @@ async def update_my_config( existing_llm = json.loads(config.llm_config) if config.llm_config else {} # 先解密现有数据,再合并新数据,最后加密 existing_llm = decrypt_config(existing_llm, SENSITIVE_LLM_FIELDS) + # 合并新数据 existing_llm.update(llm_data) # 使用未加密的新数据合并 + # 重新加密并保存 config.llm_config = json.dumps(encrypt_config(existing_llm, SENSITIVE_LLM_FIELDS)) - + # 更新其他配置 if config_in.otherConfig: existing_other = json.loads(config.other_config) if config.other_config else {} # 先解密现有数据,再合并新数据,最后加密 existing_other = decrypt_config(existing_other, SENSITIVE_OTHER_FIELDS) + # 合并新数据 existing_other.update(other_data) # 使用未加密的新数据合并 + # 重新加密并保存 config.other_config = json.dumps(encrypt_config(existing_other, SENSITIVE_OTHER_FIELDS)) - + # 提交事务 await db.commit() + # 刷新对象 await db.refresh(config) - # 获取系统默认配置并合并(与 get_my_config 保持一致) default_config = get_default_config() user_llm_config = json.loads(config.llm_config) if config.llm_config else {} user_other_config = json.loads(config.other_config) if config.other_config else {} - # 解密后返回给前端 user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS) user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS) - + # 合并默认配置与用户配置 merged_llm_config = {**default_config["llmConfig"], **user_llm_config} merged_other_config = {**default_config["otherConfig"], **user_other_config} - + # 返回更新后的配置 return UserConfigResponse( id=config.id, user_id=config.user_id, @@ -279,16 +336,25 @@ async def delete_my_config( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """删除当前用户的配置(恢复为默认)""" + """ + 删除当前用户的配置(恢复为默认)。 + + 处理流程: + - 查询用户配置记录 + - 删除并提交 + - 返回删除结果 + """ + # 查询用户配置记录 result = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) + # 获取配置对象 config = result.scalar_one_or_none() - + # 若存在配置则删除 if config: await db.delete(config) await db.commit() - + # 返回删除结果 return {"message": "配置已删除"} @@ -316,24 +382,33 @@ async def test_llm_connection( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """测试LLM连接是否正常""" + """ + 测试 LLM 连接是否正常。 + + 处理流程: + - 读取并解密用户配置 + - 计算调试参数与默认值 + - 创建对应适配器并发送测试请求 + - 返回测试结果与调试信息 + """ + # 延迟导入 LLM 相关组件 from app.services.llm.factory import LLMFactory, NATIVE_ONLY_PROVIDERS from app.services.llm.adapters import LiteLLMAdapter, BaiduAdapter, MinimaxAdapter, DoubaoAdapter from app.services.llm.types import LLMConfig, LLMProvider, LLMRequest, LLMMessage, DEFAULT_MODELS, DEFAULT_BASE_URLS import traceback import time - + # 记录开始时间 start_time = time.time() - # 获取用户保存的配置 result = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) + # 获取配置记录 user_config_record = result.scalar_one_or_none() - # 解析用户配置 saved_llm_config = {} saved_other_config = {} + # 若存在配置则解密字段 if user_config_record: if user_config_record.llm_config: saved_llm_config = decrypt_config( @@ -345,7 +420,6 @@ async def test_llm_connection( json.loads(user_config_record.other_config), SENSITIVE_OTHER_FIELDS ) - # 从保存的配置中获取参数(用于调试显示) saved_timeout_ms = saved_llm_config.get('llmTimeout', settings.LLM_TIMEOUT * 1000) saved_temperature = saved_llm_config.get('llmTemperature', settings.LLM_TEMPERATURE) @@ -354,7 +428,7 @@ async def test_llm_connection( saved_gap_ms = saved_other_config.get('llmGapMs', settings.LLM_GAP_MS) saved_max_files = saved_other_config.get('maxAnalyzeFiles', settings.MAX_ANALYZE_FILES) saved_output_lang = saved_other_config.get('outputLanguage', settings.OUTPUT_LANGUAGE) - + # 组装调试信息 debug_info = { "provider": request.provider, "model_requested": request.model, @@ -374,7 +448,7 @@ async def test_llm_connection( } try: - # 解析provider + # 解析 provider 映射 provider_map = { 'gemini': LLMProvider.GEMINI, 'openai': LLMProvider.OPENAI, @@ -389,7 +463,9 @@ async def test_llm_connection( 'ollama': LLMProvider.OLLAMA, } + # 获取匹配的 provider provider = provider_map.get(request.provider.lower()) + # 若 provider 不支持则返回失败 if not provider: debug_info["error_type"] = "unsupported_provider" return LLMTestResponse( @@ -397,16 +473,14 @@ async def test_llm_connection( message=f"不支持的LLM提供商: {request.provider}", debug=debug_info ) - - # 获取默认模型 + # 获取默认模型与默认 Base URL model = request.model or DEFAULT_MODELS.get(provider) base_url = request.baseUrl or DEFAULT_BASE_URLS.get(provider, "") - - # 测试时使用用户保存的所有配置参数 + # 测试时使用用户保存的配置参数 test_timeout = int(saved_timeout_ms / 1000) if saved_timeout_ms else settings.LLM_TIMEOUT test_temperature = saved_temperature if saved_temperature is not None else settings.LLM_TEMPERATURE test_max_tokens = saved_max_tokens if saved_max_tokens else settings.LLM_MAX_TOKENS - + # 记录调试参数 debug_info["model_used"] = model debug_info["base_url_used"] = base_url debug_info["is_native_adapter"] = provider in NATIVE_ONLY_PROVIDERS @@ -415,10 +489,9 @@ async def test_llm_connection( "temperature": test_temperature, "max_tokens": test_max_tokens, } - + # 打印测试请求日志 print(f"[LLM Test] 开始测试: provider={provider.value}, model={model}, base_url={base_url}, temperature={test_temperature}, timeout={test_timeout}s, max_tokens={test_max_tokens}") - - # 创建配置 + # 创建 LLM 配置 config = LLMConfig( provider=provider, api_key=request.apiKey, @@ -428,8 +501,7 @@ async def test_llm_connection( temperature=test_temperature, max_tokens=test_max_tokens, ) - - # 直接创建新的适配器实例(不使用缓存),确保使用最新的配置 + # 直接创建新的适配器实例(不使用缓存) if provider in NATIVE_ONLY_PROVIDERS: native_adapter_map = { LLMProvider.BAIDU: BaiduAdapter, @@ -443,7 +515,7 @@ async def test_llm_connection( debug_info["adapter_type"] = "LiteLLMAdapter" # 获取 LiteLLM 实际使用的模型名 debug_info["litellm_model"] = getattr(adapter, '_get_litellm_model', lambda: model)() if hasattr(adapter, '_get_litellm_model') else model - + # 构建测试请求 test_request = LLMRequest( messages=[ LLMMessage(role="user", content="Say 'Hello' in one word.") @@ -451,33 +523,41 @@ async def test_llm_connection( temperature=test_temperature, max_tokens=test_max_tokens, ) - + # 发送测试请求 print(f"[LLM Test] 发送测试请求...") response = await adapter.complete(test_request) - + # 计算耗时 elapsed_time = time.time() - start_time debug_info["elapsed_time_ms"] = round(elapsed_time * 1000, 2) # 验证响应内容 if not response or not response.content: + # 标记为空响应错误类型 debug_info["error_type"] = "empty_response" + # 保存原始响应便于排查 debug_info["raw_response"] = str(response) if response else None + # 记录日志 print(f"[LLM Test] 空响应: {response}") + # 返回失败结果 return LLMTestResponse( success=False, message="LLM 返回空响应,请检查 API Key 和配置", debug=debug_info ) + # 记录响应长度 debug_info["response_length"] = len(response.content) + # 记录 token 使用情况 debug_info["usage"] = { "prompt_tokens": getattr(response, 'prompt_tokens', None), "completion_tokens": getattr(response, 'completion_tokens', None), "total_tokens": getattr(response, 'total_tokens', None), } + # 输出成功日志 print(f"[LLM Test] 成功! 响应: {response.content[:50]}... 耗时: {elapsed_time:.2f}s") + # 返回成功响应 return LLMTestResponse( success=True, message=f"连接成功 ({elapsed_time:.2f}s)", @@ -487,27 +567,35 @@ async def test_llm_connection( ) except Exception as e: + # 计算耗时 elapsed_time = time.time() - start_time + # 解析错误信息 error_msg = str(e) + # 获取错误类型 error_type = type(e).__name__ + # 更新调试信息 debug_info["elapsed_time_ms"] = round(elapsed_time * 1000, 2) debug_info["error_type"] = error_type debug_info["error_message"] = error_msg debug_info["traceback"] = traceback.format_exc() + # 提取 LLMError 中的 api_response # 提取 LLMError 中的 api_response if hasattr(e, 'api_response') and e.api_response: debug_info["api_response"] = e.api_response if hasattr(e, 'status_code') and e.status_code: debug_info["status_code"] = e.status_code + # 打印失败日志 print(f"[LLM Test] 失败: {error_type}: {error_msg}") print(f"[LLM Test] Traceback:\n{traceback.format_exc()}") + # 构造更友好的错误信息 # 提供更友好的错误信息 friendly_message = error_msg + # 根据错误关键字分类 # 优先检查余额不足(因为某些 API 用 429 表示余额不足) if any(keyword in error_msg for keyword in ["余额不足", "资源包", "充值", "quota", "insufficient", "balance", "402"]): friendly_message = "账户余额不足或配额已用尽,请充值后重试" @@ -533,6 +621,7 @@ async def test_llm_connection( else: debug_info["error_category"] = "unknown" + # 返回失败响应 return LLMTestResponse( success=False, message=friendly_message, @@ -542,12 +631,22 @@ async def test_llm_connection( @router.get("/llm-providers") async def get_llm_providers() -> Any: - """获取支持的LLM提供商列表""" + """ + 获取支持的 LLM 提供商列表。 + + 处理流程: + - 读取系统支持的 provider + - 组装默认模型与可用模型 + - 返回 provider 列表 + """ + # 延迟导入工厂与常量 from app.services.llm.factory import LLMFactory from app.services.llm.types import LLMProvider, DEFAULT_BASE_URLS - + # 构建 provider 列表 providers = [] + # 遍历所有支持的 provider for provider in LLMFactory.get_supported_providers(): + # 追加 provider 配置信息 providers.append({ "id": provider.value, "name": provider.value.upper(), @@ -555,6 +654,5 @@ async def get_llm_providers() -> Any: "models": LLMFactory.get_available_models(provider), "defaultBaseUrl": DEFAULT_BASE_URLS.get(provider, ""), }) - + # 返回 provider 列表 return {"providers": providers} - diff --git a/backend/app/api/v1/endpoints/database.py b/backend/app/api/v1/endpoints/database.py index 079273ce..73ee66d7 100644 --- a/backend/app/api/v1/endpoints/database.py +++ b/backend/app/api/v1/endpoints/database.py @@ -24,7 +24,7 @@ class DatabaseExportResponse(BaseModel): - """数据库导出响应""" + """数据库导出响应模型。""" export_date: str user_id: str data: Dict[str, Any] @@ -39,8 +39,12 @@ async def export_database( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 导出当前用户的所有数据 - 包括:项目、任务、问题、即时分析、用户配置 + 导出当前用户的所有数据。 + + 处理流程: + - 查询项目、任务、问题、即时分析与配置 + - 构建标准化导出结构 + - 返回导出响应 """ try: # 1. 获取用户的所有项目 @@ -182,6 +186,7 @@ async def export_database( ], } + # 返回导出结果 return DatabaseExportResponse( export_date=export_data["export_date"], user_id=current_user.id, @@ -189,12 +194,13 @@ async def export_database( ) except Exception as e: + # 记录错误并返回异常 print(f"导出数据失败: {e}") raise HTTPException(status_code=500, detail=f"导出数据失败: {str(e)}") class DatabaseImportRequest(BaseModel): - """数据库导入请求""" + """数据库导入请求模型。""" data: Dict[str, Any] @@ -205,8 +211,12 @@ async def import_database( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 从 JSON 文件导入数据 - 注意:导入会合并数据,不会删除现有数据 + 从 JSON 文件导入数据。 + + 处理流程: + - 读取并解析 JSON + - 校验用户身份 + - 逐类导入并合并数据 """ try: # 读取文件内容 @@ -216,12 +226,14 @@ async def import_database( if not isinstance(import_data, dict) or "data" not in import_data: raise HTTPException(status_code=400, detail="无效的导入文件格式") + # 提取数据载荷 data = import_data["data"] - # 验证用户ID(只能导入自己的数据) + # 验证用户 ID(只能导入自己的数据) if data.get("user", {}).get("id") != current_user.id: raise HTTPException(status_code=403, detail="只能导入自己的数据") + # 初始化导入计数 imported_count = { "projects": 0, "tasks": 0, @@ -235,6 +247,7 @@ async def import_database( for p_data in data["projects"]: existing = await db.get(Project, p_data.get("id")) if not existing: + # 构建项目对象 project = Project( id=p_data.get("id"), name=p_data.get("name"), @@ -250,6 +263,7 @@ async def import_database( db.add(project) imported_count["projects"] += 1 + # 提交项目导入 await db.commit() # 2. 导入任务(需要先有项目) @@ -260,6 +274,7 @@ async def import_database( # 检查项目是否存在 project = await db.get(Project, t_data.get("project_id")) if project: + # 构建任务对象 task = AuditTask( id=t_data.get("id"), project_id=t_data.get("project_id"), @@ -278,6 +293,7 @@ async def import_database( db.add(task) imported_count["tasks"] += 1 + # 提交任务导入 await db.commit() # 3. 导入问题(需要先有任务) @@ -288,6 +304,7 @@ async def import_database( # 检查任务是否存在 task = await db.get(AuditTask, i_data.get("task_id")) if task: + # 构建问题对象 issue = AuditIssue( id=i_data.get("id"), task_id=i_data.get("task_id"), @@ -307,6 +324,7 @@ async def import_database( db.add(issue) imported_count["issues"] += 1 + # 提交问题导入 await db.commit() # 4. 导入即时分析 @@ -314,6 +332,7 @@ async def import_database( for a_data in data["instant_analyses"]: existing = await db.get(InstantAnalysis, a_data.get("id")) if not existing: + # 构建分析记录 analysis = InstantAnalysis( id=a_data.get("id"), user_id=current_user.id, @@ -327,6 +346,7 @@ async def import_database( db.add(analysis) imported_count["analyses"] += 1 + # 提交分析导入 await db.commit() # 5. 导入用户配置(合并) @@ -338,6 +358,7 @@ async def import_database( config = config_result.scalar_one_or_none() if not config: + # 创建配置记录 config = UserConfig( user_id=current_user.id, llm_config=json.dumps(data["user_config"].get("llm_config", {})), @@ -355,8 +376,10 @@ async def import_database( imported_count["config"] = 1 + # 提交配置导入 await db.commit() + # 返回导入统计 return { "message": "数据导入成功", "imported": imported_count @@ -365,6 +388,7 @@ async def import_database( except json.JSONDecodeError: raise HTTPException(status_code=400, detail="无效的 JSON 文件格式") except Exception as e: + # 记录错误并回滚 print(f"导入数据失败: {e}") await db.rollback() raise HTTPException(status_code=500, detail=f"导入数据失败: {str(e)}") @@ -376,10 +400,15 @@ async def clear_database( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 清空当前用户的所有数据 - 注意:此操作不可恢复,请谨慎使用 + 清空当前用户的所有数据。 + + 处理流程: + - 删除问题、任务、项目、分析与配置 + - 统计删除数量 + - 提交事务并返回结果 """ try: + # 初始化删除计数 deleted_count = { "projects": 0, "tasks": 0, @@ -450,21 +479,24 @@ async def clear_database( for member in members: await db.delete(member) + # 提交删除 await db.commit() + # 返回删除结果 return { "message": "数据已清空", "deleted": deleted_count } except Exception as e: + # 记录错误并回滚 print(f"清空数据失败: {e}") await db.rollback() raise HTTPException(status_code=500, detail=f"清空数据失败: {str(e)}") class DatabaseStatsResponse(BaseModel): - """数据库统计信息响应""" + """数据库统计信息响应模型。""" total_projects: int active_projects: int total_tasks: int @@ -490,7 +522,12 @@ async def get_database_stats( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取当前用户的数据库统计信息 + 获取当前用户的数据库统计信息。 + + 处理流程: + - 统计项目、任务与问题 + - 统计即时分析与成员 + - 汇总并返回 """ try: # 1. 项目统计 @@ -561,6 +598,7 @@ async def get_database_stats( ) has_config = config_result.scalar_one_or_none() is not None + # 返回统计结果 return DatabaseStatsResponse( total_projects=total_projects, active_projects=active_projects, @@ -582,12 +620,13 @@ async def get_database_stats( ) except Exception as e: + # 记录错误并返回异常 print(f"获取统计信息失败: {e}") raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}") class DatabaseHealthResponse(BaseModel): - """数据库健康检查响应""" + """数据库健康检查响应模型。""" status: str # healthy, warning, error database_connected: bool total_records: int @@ -602,9 +641,16 @@ async def check_database_health( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 检查数据库健康状态 + 检查数据库健康状态。 + + 处理流程: + - 校验数据库连接 + - 统计记录数量 + - 检查数据完整性 + - 返回健康状态 """ try: + # 初始化状态与指标 issues = [] warnings = [] database_connected = True @@ -679,6 +725,7 @@ async def check_database_health( else: status = "healthy" + # 返回健康检查结果 return DatabaseHealthResponse( status=status, database_connected=database_connected, @@ -689,6 +736,6 @@ async def check_database_health( ) except Exception as e: + # 记录错误并返回异常 print(f"健康检查失败: {e}") raise HTTPException(status_code=500, detail=f"健康检查失败: {str(e)}") - diff --git a/backend/app/api/v1/endpoints/embedding_config.py b/backend/app/api/v1/endpoints/embedding_config.py index 4a1e68d6..573a15c1 100644 --- a/backend/app/api/v1/endpoints/embedding_config.py +++ b/backend/app/api/v1/endpoints/embedding_config.py @@ -176,18 +176,32 @@ class TestEmbeddingResponse(BaseModel): async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> EmbeddingConfig: - """从数据库获取嵌入配置(异步)""" + """ + 从数据库获取嵌入配置(异步)。 + + 处理流程: + - 查询用户配置记录 + - 解析 other_config 并读取嵌入配置 + - 若无配置则返回默认值 + """ + # 查询用户配置记录 result = await db.execute( select(UserConfig).where(UserConfig.user_id == user_id) ) + # 获取用户配置对象 user_config = result.scalar_one_or_none() + # 若存在 other_config 则尝试解析 if user_config and user_config.other_config: try: + # 解析 other_config other_config = json.loads(user_config.other_config) if isinstance(user_config.other_config, str) else user_config.other_config + # 读取嵌入配置数据 embedding_data = other_config.get(EMBEDDING_CONFIG_KEY) + # 若存在嵌入配置则构建对象 if embedding_data: + # 组装嵌入配置对象 config = EmbeddingConfig( provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER), model=embedding_data.get("model", settings.EMBEDDING_MODEL), @@ -196,9 +210,12 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd dimensions=embedding_data.get("dimensions"), batch_size=embedding_data.get("batch_size", 100), ) + # 记录读取日志 print(f"[EmbeddingConfig] 读取用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}") + # 返回用户配置 return config except (json.JSONDecodeError, AttributeError) as e: + # 记录解析失败日志 print(f"[EmbeddingConfig] 解析用户 {user_id} 配置失败: {e}") # 返回默认配置 @@ -213,10 +230,20 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: EmbeddingConfig) -> None: - """保存嵌入配置到数据库(异步)""" + """ + 保存嵌入配置到数据库(异步)。 + + 处理流程: + - 查询用户配置记录 + - 构建嵌入配置数据 + - 更新或创建配置记录 + - 提交事务 + """ + # 查询用户配置记录 result = await db.execute( select(UserConfig).where(UserConfig.user_id == user_id) ) + # 获取用户配置对象 user_config = result.scalar_one_or_none() # 准备嵌入配置数据 @@ -236,7 +263,9 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em except (json.JSONDecodeError, TypeError): other_config = {} + # 写入嵌入配置 other_config[EMBEDDING_CONFIG_KEY] = embedding_data + # 保存到 other_config 字段 user_config.other_config = json.dumps(other_config) # 🔥 显式标记 other_config 字段已修改,确保 SQLAlchemy 检测到变化 flag_modified(user_config, "other_config") @@ -248,9 +277,12 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em llm_config="{}", other_config=json.dumps({EMBEDDING_CONFIG_KEY: embedding_data}), ) + # 新配置入库 db.add(user_config) + # 提交事务 await db.commit() + # 记录保存日志 print(f"[EmbeddingConfig] 已保存用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}") @@ -261,8 +293,13 @@ async def list_embedding_providers( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取可用的嵌入模型提供商列表 + 获取可用的嵌入模型提供商列表。 + + 处理流程: + - 依赖鉴权 + - 返回静态提供商列表 """ + # 返回提供商列表 return EMBEDDING_PROVIDERS @@ -272,13 +309,20 @@ async def get_current_config( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取当前嵌入模型配置(从数据库读取) + 获取当前嵌入模型配置(从数据库读取)。 + + 处理流程: + - 查询数据库配置 + - 计算默认维度 + - 返回配置响应 """ + # 读取用户配置 config = await get_embedding_config_from_db(db, current_user.id) # 获取维度:优先使用用户配置的维度,否则使用默认值 dimensions = config.dimensions if config.dimensions else _get_model_dimensions(config.provider, config.model) + # 返回配置响应 return EmbeddingConfigResponse( provider=config.provider, model=config.model, @@ -296,7 +340,12 @@ async def update_config( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 更新嵌入模型配置(持久化到数据库) + 更新嵌入模型配置(持久化到数据库)。 + + 处理流程: + - 校验提供商合法性 + - 检查 API Key 要求 + - 保存配置 """ # 验证提供商 provider_ids = [p.id for p in EMBEDDING_PROVIDERS] @@ -314,6 +363,7 @@ async def update_config( # 保存到数据库 await save_embedding_config_to_db(db, current_user.id, config) + # 返回保存结果 return {"message": "配置已保存", "provider": config.provider, "model": config.model} @@ -323,14 +373,24 @@ async def test_embedding( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 测试嵌入模型配置 + 测试嵌入模型配置。 + + 处理流程: + - 初始化嵌入服务 + - 执行测试嵌入 + - 固定响应时间防止时间侧信道 + - 返回测试结果 """ + # 固定响应时间,防止 SSRF 时间侧信道攻击 FIXED_DURATION = 3.0 # 固定响应时间,防止SSRF时间侧信道攻击 + # 记录开始时间 start_time = time.time() try: + # 延迟导入嵌入服务 from app.services.rag.embeddings import EmbeddingService + # 创建嵌入服务实例 service = EmbeddingService( provider=request.provider, model=request.model, @@ -340,13 +400,17 @@ async def test_embedding( cache_enabled=False, ) + # 生成嵌入向量 embedding = await service.embed(request.test_text) + # 计算耗时 elapsed = time.time() - start_time + # 记录实际延迟 latency_ms = int(elapsed * 1000) # 在sleep前计算实际延迟 + # 若耗时不足则补足固定时长 if elapsed < FIXED_DURATION: await asyncio.sleep(FIXED_DURATION - elapsed) - + # 返回成功结果 return TestEmbeddingResponse( success=True, message=f"嵌入成功! 维度: {len(embedding)}", @@ -354,13 +418,15 @@ async def test_embedding( sample_embedding=embedding[:5], # 返回前 5 维 latency_ms=latency_ms, ) - + except Exception as e: + # 发生异常时也同样等待,确保时间特征一致 # 发生异常时也同样等待,确保时间特征一致 elapsed = time.time() - start_time if elapsed < FIXED_DURATION: await asyncio.sleep(FIXED_DURATION - elapsed) + # 返回失败结果 return TestEmbeddingResponse( success=False, message=f"嵌入失败: {str(e)}", @@ -373,13 +439,19 @@ async def get_provider_models( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取指定提供商的模型列表 + 获取指定提供商的模型列表。 + + 处理流程: + - 查询提供商信息 + - 校验存在性 + - 返回模型信息 """ + # 查找提供商信息 provider_info = next((p for p in EMBEDDING_PROVIDERS if p.id == provider), None) - + # 若提供商不存在则返回 404 if not provider_info: raise HTTPException(status_code=404, detail=f"提供商不存在: {provider}") - + # 返回提供商模型信息 return { "provider": provider, "models": provider_info.models, @@ -389,7 +461,14 @@ async def get_provider_models( def _get_model_dimensions(provider: str, model: str) -> int: - """获取模型维度""" + """ + 获取模型维度。 + + 处理流程: + - 通过模型名称查表 + - 若无匹配则返回默认值 + """ + # 模型维度映射表 dimensions_map = { # OpenAI "text-embedding-3-small": 1536, @@ -431,5 +510,5 @@ def _get_model_dimensions(provider: str, model: str) -> int: "text-embedding-v2": 1536, # 支持维度: 1536 } + # 返回模型维度,默认 768 return dimensions_map.get(model, 768) - diff --git a/backend/app/api/v1/endpoints/members.py b/backend/app/api/v1/endpoints/members.py index 0bb8ae68..bb6ae7af 100644 --- a/backend/app/api/v1/endpoints/members.py +++ b/backend/app/api/v1/endpoints/members.py @@ -1,3 +1,7 @@ +""" +模块说明:API 路由与依赖定义:members。 +""" + from typing import Any, List, Optional from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession @@ -14,8 +18,8 @@ router = APIRouter() -# Schemas class UserSchema(BaseModel): + """用户基础信息的响应模型。""" id: str email: Optional[str] = None full_name: Optional[str] = None @@ -27,6 +31,7 @@ class Config: class ProjectMemberSchema(BaseModel): + """项目成员的响应模型,包含用户信息。""" id: str project_id: str user_id: str @@ -41,11 +46,13 @@ class Config: class AddMemberRequest(BaseModel): + """新增项目成员的请求体模型。""" user_id: str role: str = "member" class UpdateMemberRequest(BaseModel): + """更新项目成员角色与权限的请求体模型。""" role: Optional[str] = None permissions: Optional[str] = None @@ -57,19 +64,30 @@ async def get_project_members( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Get all members of a project. + 获取指定项目的成员列表。 + + 处理流程: + - 校验项目是否存在 + - 查询成员并加载关联用户信息 + - 按加入时间倒序返回 """ - # Verify project exists + # 读取项目,确认项目是否存在 project = await db.get(Project, project_id) + # 若项目不存在则返回 404 if not project: raise HTTPException(status_code=404, detail="项目不存在") - + # 查询项目成员并预加载用户信息 result = await db.execute( + # 构造成员查询 select(ProjectMember) + # 预加载 user 关联,避免 N+1 .options(selectinload(ProjectMember.user)) + # 过滤当前项目 .where(ProjectMember.project_id == project_id) + # 按加入时间倒序 .order_by(ProjectMember.joined_at.desc()) ) + # 返回成员列表 return result.scalars().all() @@ -81,50 +99,68 @@ async def add_project_member( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Add a member to a project. + 向指定项目添加成员。 + + 处理流程: + - 校验项目是否存在 + - 校验当前用户权限 + - 校验目标用户是否存在 + - 校验是否已是成员 + - 创建成员记录并返回 """ - # Verify project exists + # 读取项目,确认项目是否存在 project = await db.get(Project, project_id) + # 若项目不存在则返回 404 if not project: raise HTTPException(status_code=404, detail="项目不存在") - - # Check if user is project owner or admin + # 校验当前用户是否为项目所有者或超级管理员 if project.owner_id != current_user.id and not current_user.is_superuser: raise HTTPException(status_code=403, detail="权限不足") - - # Check if user exists + # 读取目标用户,确认用户是否存在 user = await db.get(User, member_in.user_id) + # 若用户不存在则返回 404 if not user: raise HTTPException(status_code=404, detail="用户不存在") - - # Check if already a member + # 查询是否已存在成员记录 existing = await db.execute( + # 构造成员查询条件 select(ProjectMember) + # 同时匹配项目与用户 .where( ProjectMember.project_id == project_id, ProjectMember.user_id == member_in.user_id ) ) + # 若已有成员记录则返回 400 if existing.scalars().first(): raise HTTPException(status_code=400, detail="用户已是项目成员") - - # Create member + # 构建成员记录 member = ProjectMember( + # 绑定项目 project_id=project_id, + # 绑定用户 user_id=member_in.user_id, + # 角色信息 role=member_in.role, + # 权限 JSON 字符串占位 permissions="{}" ) + # 追加到会话 db.add(member) + # 提交事务 await db.commit() + # 刷新以获取数据库生成字段 await db.refresh(member) - - # Reload with user relationship + # 重新加载成员并预加载用户关系 result = await db.execute( + # 构造成员查询 select(ProjectMember) + # 预加载 user 关联 .options(selectinload(ProjectMember.user)) + # 根据成员 ID 定位 .where(ProjectMember.id == member.id) ) + # 返回新成员 return result.scalars().first() @@ -137,41 +173,55 @@ async def update_project_member( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Update a project member's role or permissions. + 更新项目成员的角色或权限。 + + 处理流程: + - 校验项目是否存在 + - 校验当前用户权限 + - 获取成员记录 + - 更新字段并保存 + - 返回更新后的成员信息 """ - # Verify project exists + # 读取项目,确认项目是否存在 project = await db.get(Project, project_id) + # 若项目不存在则返回 404 if not project: raise HTTPException(status_code=404, detail="项目不存在") - - # Check permissions + # 校验当前用户是否为项目所有者或超级管理员 if project.owner_id != current_user.id and not current_user.is_superuser: raise HTTPException(status_code=403, detail="权限不足") - - # Get member + # 查询目标成员记录 result = await db.execute( + # 构造成员查询 select(ProjectMember) + # 同时匹配成员与项目 .where(ProjectMember.id == member_id, ProjectMember.project_id == project_id) ) + # 获取成员对象 member = result.scalars().first() + # 若成员不存在则返回 404 if not member: raise HTTPException(status_code=404, detail="成员不存在") - - # Update fields + # 如果传入角色则更新角色 if member_update.role: member.role = member_update.role + # 如果传入权限则更新权限 if member_update.permissions: member.permissions = member_update.permissions - + # 提交事务 await db.commit() + # 刷新对象 await db.refresh(member) - - # Reload with user relationship + # 重新加载成员并预加载用户关系 result = await db.execute( + # 构造成员查询 select(ProjectMember) + # 预加载 user 关联 .options(selectinload(ProjectMember.user)) + # 根据成员 ID 定位 .where(ProjectMember.id == member.id) ) + # 返回更新后的成员 return result.scalars().first() @@ -183,33 +233,42 @@ async def remove_project_member( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Remove a member from a project. + 从项目中移除成员。 + + 处理流程: + - 校验项目是否存在 + - 校验当前用户权限 + - 获取成员记录 + - 删除成员并提交 """ - # Verify project exists + # 读取项目,确认项目是否存在 project = await db.get(Project, project_id) + # 若项目不存在则返回 404 if not project: raise HTTPException(status_code=404, detail="项目不存在") - - # Check permissions + # 校验当前用户是否为项目所有者或超级管理员 if project.owner_id != current_user.id and not current_user.is_superuser: raise HTTPException(status_code=403, detail="权限不足") - - # Get member + # 查询目标成员记录 result = await db.execute( + # 构造成员查询 select(ProjectMember) + # 同时匹配成员与项目 .where(ProjectMember.id == member_id, ProjectMember.project_id == project_id) ) + # 获取成员对象 member = result.scalars().first() + # 若成员不存在则返回 404 if not member: raise HTTPException(status_code=404, detail="成员不存在") - + # 删除成员记录 await db.delete(member) + # 提交事务 await db.commit() - + # 返回删除结果 return {"message": "成员已移除"} - diff --git a/backend/app/api/v1/endpoints/projects.py b/backend/app/api/v1/endpoints/projects.py index 4cad10a4..51659b21 100644 --- a/backend/app/api/v1/endpoints/projects.py +++ b/backend/app/api/v1/endpoints/projects.py @@ -1,3 +1,7 @@ +""" +模块说明:API 路由与依赖定义:projects。 +""" + from typing import Any, List, Optional from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, UploadFile, File from fastapi.responses import FileResponse @@ -29,6 +33,7 @@ # Schemas class ProjectCreate(BaseModel): + """创建项目的请求模型。""" name: str source_type: Optional[str] = "repository" # 'repository' 或 'zip' repository_url: Optional[str] = None @@ -38,6 +43,7 @@ class ProjectCreate(BaseModel): programming_languages: Optional[List[str]] = None class ProjectUpdate(BaseModel): + """更新项目的请求模型。""" name: Optional[str] = None source_type: Optional[str] = None repository_url: Optional[str] = None @@ -47,6 +53,7 @@ class ProjectUpdate(BaseModel): programming_languages: Optional[List[str]] = None class OwnerSchema(BaseModel): + """项目所有者信息模型。""" id: str email: Optional[str] = None full_name: Optional[str] = None @@ -57,6 +64,7 @@ class Config: from_attributes = True class ProjectResponse(BaseModel): + """项目响应模型。""" id: str name: str description: Optional[str] = None @@ -75,6 +83,7 @@ class Config: from_attributes = True class StatsResponse(BaseModel): + """项目统计响应模型。""" total_projects: int active_projects: int total_tasks: int @@ -91,12 +100,18 @@ async def create_project( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Create new project. + 创建新项目。 + + 处理流程: + - 根据 source_type 计算默认值 + - 构建项目对象并保存 + - 返回创建结果 """ import json # 根据 source_type 设置默认值 source_type = project_in.source_type or "repository" + # 构建项目对象 project = Project( name=project_in.name, source_type=source_type, @@ -107,9 +122,11 @@ async def create_project( programming_languages=json.dumps(project_in.programming_languages or []), owner_id=current_user.id ) + # 写入数据库 db.add(project) await db.commit() await db.refresh(project) + # 返回项目 return project @router.get("/", response_model=List[ProjectResponse]) @@ -121,15 +138,23 @@ async def read_projects( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Retrieve projects for current user. + 获取当前用户的项目列表。 + + 处理流程: + - 构建查询并限制为当前用户 + - 可选过滤已删除项目 + - 分页排序后返回 """ + # 构建查询并预加载 owner query = select(Project).options(selectinload(Project.owner)) # 只返回当前用户的项目 query = query.where(Project.owner_id == current_user.id) if not include_deleted: query = query.where(Project.is_active == True) + # 排序与分页 query = query.order_by(Project.created_at.desc()).offset(skip).limit(limit) result = await db.execute(query) + # 返回项目列表 return result.scalars().all() @router.get("/deleted", response_model=List[ProjectResponse]) @@ -138,8 +163,13 @@ async def read_deleted_projects( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Retrieve deleted (soft-deleted) projects for current user. + 获取当前用户已删除(软删除)的项目列表。 + + 处理流程: + - 查询当前用户的非活跃项目 + - 按更新时间倒序返回 """ + # 查询已删除项目 result = await db.execute( select(Project) .options(selectinload(Project.owner)) @@ -147,6 +177,7 @@ async def read_deleted_projects( .where(Project.is_active == False) .order_by(Project.updated_at.desc()) ) + # 返回项目列表 return result.scalars().all() @router.get("/stats", response_model=StatsResponse) @@ -155,7 +186,12 @@ async def get_stats( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Get statistics for current user. + 获取当前用户的项目统计信息。 + + 处理流程: + - 查询项目与任务 + - 合并旧任务与 Agent 任务统计 + - 计算平均质量分 """ # 只统计当前用户的项目 projects_result = await db.execute( @@ -177,14 +213,14 @@ async def get_stats( ) issues = issues_result.scalars().all() - # 🔥 同时统计新的 AgentTask + # 同时统计新的 AgentTask agent_tasks_result = await db.execute( select(AgentTask).where(AgentTask.project_id.in_(project_ids)) if project_ids else select(AgentTask).where(False) ) agent_tasks = agent_tasks_result.scalars().all() agent_task_ids = [t.id for t in agent_tasks] - # 🔥 统计 AgentFinding + # 统计 AgentFinding agent_findings_result = await db.execute( select(AgentFinding).where(AgentFinding.task_id.in_(agent_task_ids)) if agent_task_ids else select(AgentFinding).where(False) ) @@ -209,6 +245,7 @@ async def get_stats( ) avg_quality_score = sum(quality_scores) / len(quality_scores) if quality_scores else 0.0 + # 返回统计结果 return { "total_projects": len(projects), "active_projects": len([p for p in projects if p.is_active]), @@ -226,8 +263,14 @@ async def read_project( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Get project by ID. + 获取项目详情。 + + 处理流程: + - 查询项目并预加载 owner + - 校验存在性与权限 + - 返回项目对象 """ + # 查询项目 result = await db.execute( select(Project) .options(selectinload(Project.owner)) @@ -241,6 +284,7 @@ async def read_project( if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权查看此项目") + # 返回项目 return project @router.put("/{id}", response_model=ProjectResponse) @@ -252,9 +296,15 @@ async def update_project( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Update project. + 更新项目信息。 + + 处理流程: + - 查询项目并校验权限 + - 处理可选字段与序列化 + - 保存更新并返回 """ import json + # 查询项目 result = await db.execute(select(Project).where(Project.id == id)) project = result.scalars().first() if not project: @@ -264,16 +314,20 @@ async def update_project( if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权更新此项目") + # 提取更新字段 update_data = project_in.model_dump(exclude_unset=True) if "programming_languages" in update_data and update_data["programming_languages"] is not None: update_data["programming_languages"] = json.dumps(update_data["programming_languages"]) + # 应用更新 for field, value in update_data.items(): setattr(project, field, value) + # 更新时间并提交 project.updated_at = datetime.now(timezone.utc) await db.commit() await db.refresh(project) + # 返回更新后的项目 return project @router.delete("/{id}") @@ -283,8 +337,14 @@ async def delete_project( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Soft delete project. + 软删除项目。 + + 处理流程: + - 查询项目并校验权限 + - 标记为非活跃 + - 保存并返回结果 """ + # 查询项目 result = await db.execute(select(Project).where(Project.id == id)) project = result.scalars().first() if not project: @@ -294,9 +354,11 @@ async def delete_project( if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权删除此项目") + # 标记为非活跃 project.is_active = False project.updated_at = datetime.now(timezone.utc) await db.commit() + # 返回删除结果 return {"message": "项目已删除"} @router.post("/{id}/restore") @@ -306,8 +368,14 @@ async def restore_project( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Restore soft-deleted project. + 恢复已软删除项目。 + + 处理流程: + - 查询项目并校验权限 + - 标记为活跃 + - 保存并返回结果 """ + # 查询项目 result = await db.execute(select(Project).where(Project.id == id)) project = result.scalars().first() if not project: @@ -317,9 +385,11 @@ async def restore_project( if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权恢复此项目") + # 标记为活跃 project.is_active = True project.updated_at = datetime.now(timezone.utc) await db.commit() + # 返回恢复结果 return {"message": "项目已恢复"} @router.delete("/{id}/permanent") @@ -329,8 +399,14 @@ async def permanently_delete_project( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Permanently delete project. + 永久删除项目。 + + 处理流程: + - 查询项目并校验权限 + - 删除 ZIP 资源(如有) + - 删除项目并提交 """ + # 查询项目 result = await db.execute(select(Project).where(Project.id == id)) project = result.scalars().first() if not project: @@ -340,7 +416,7 @@ async def permanently_delete_project( if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权永久删除此项目") - # 如果是ZIP类型项目,删除关联的ZIP文件和元数据 + # 如果是 ZIP 类型项目,删除关联的 ZIP 文件和元数据 if project.source_type == "zip": try: await delete_project_zip(id) @@ -348,8 +424,10 @@ async def permanently_delete_project( except Exception as e: print(f"[Warning] 删除ZIP文件失败: {e}") + # 删除项目记录 await db.delete(project) await db.commit() + # 返回删除结果 return {"message": "项目已永久删除"} @@ -362,16 +440,23 @@ async def get_project_files( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Get list of files in the project. + 获取项目文件列表。 + 可选参数: - branch: 指定仓库分支(仅对仓库类型项目有效) - exclude_patterns: JSON 格式的排除模式数组,如 ["node_modules/**", "*.log"] + + 处理流程: + - 校验项目与权限 + - 解析排除模式 + - 根据来源类型返回文件列表 """ + # 获取项目信息 project = await db.get(Project, id) if not project: raise HTTPException(status_code=404, detail="项目不存在") - # Check permissions + # 检查权限 if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权查看此项目") @@ -383,10 +468,11 @@ async def get_project_files( except json.JSONDecodeError: pass + # 准备返回的文件列表 files = [] if project.source_type == "zip": - # Handle ZIP project + # 处理 ZIP 类型项目 zip_path = await load_project_zip(id) print(f"📦 ZIP项目 {id} 文件路径: {zip_path}") if not zip_path or not os.path.exists(zip_path): @@ -394,6 +480,7 @@ async def get_project_files( return [] try: + # 遍历 ZIP 内文件 with zipfile.ZipFile(zip_path, 'r') as zip_ref: for file_info in zip_ref.infolist(): if not file_info.is_dir(): @@ -410,11 +497,11 @@ async def get_project_files( raise HTTPException(status_code=500, detail="无法读取项目文件") elif project.source_type == "repository": - # Handle Repository project + # 处理仓库类型项目 if not project.repository_url: return [] - # Get tokens from user config + # 从用户配置中获取 Token from sqlalchemy.future import select from app.core.encryption import decrypt_sensitive_data from app.core.config import settings @@ -422,16 +509,19 @@ async def get_project_files( SENSITIVE_OTHER_FIELDS = ['githubToken', 'gitlabToken', 'sshPrivateKey'] + # 查询用户配置 result = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) config = result.scalar_one_or_none() + # 初始化 Token 与 SSH 私钥 github_token = settings.GITHUB_TOKEN gitlab_token = settings.GITLAB_TOKEN ssh_private_key = None if config and config.other_config: + # 解密用户配置 other_config = json.loads(config.other_config) for field in SENSITIVE_OTHER_FIELDS: if field in other_config and other_config[field]: @@ -443,13 +533,13 @@ async def get_project_files( elif field == 'sshPrivateKey': ssh_private_key = decrypted_val - # 检查是否为SSH URL + # 检查是否为 SSH URL is_ssh_url = GitSSHOperations.is_ssh_url(project.repository_url) target_branch = branch or project.default_branch or "main" try: if is_ssh_url: - # 使用SSH方式获取文件列表 + # 使用 SSH 方式获取文件列表 if not ssh_private_key: raise HTTPException( status_code=400, @@ -463,9 +553,10 @@ async def get_project_files( target_branch, parsed_exclude_patterns ) + # 将文件内容长度作为 size 返回 files = [{"path": f["path"], "size": len(f.get("content", ""))} for f in files_with_content] else: - # 使用API方式获取文件列表 + # 使用 API 方式获取文件列表 repo_type = project.repository_type or "other" if repo_type == "github": @@ -484,9 +575,11 @@ async def get_project_files( print(f"Error fetching repo files: {e}") raise HTTPException(status_code=500, detail=f"无法获取仓库文件: {str(e)}") + # 返回文件列表 return files class ScanRequest(BaseModel): + """扫描任务请求模型。""" file_paths: Optional[List[str]] = None full_scan: bool = True exclude_patterns: Optional[List[str]] = None @@ -502,8 +595,15 @@ async def scan_project( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Start a scan task. + 启动项目扫描任务。 + + 处理流程: + - 校验项目与权限 + - 创建任务记录 + - 解析用户配置并解密敏感字段 + - 触发后台扫描任务 """ + # 查询项目 project = await db.get(Project, id) if not project: raise HTTPException(status_code=404, detail="项目不存在") @@ -512,7 +612,7 @@ async def scan_project( branch_name = scan_request.branch_name if scan_request else None exclude_patterns = scan_request.exclude_patterns if scan_request else None - # Create Task Record + # 创建任务记录 task = AuditTask( project_id=project.id, created_by=current_user.id, @@ -538,19 +638,24 @@ async def scan_project( SENSITIVE_OTHER_FIELDS = ['githubToken', 'gitlabToken'] def decrypt_config(config_dict: dict, sensitive_fields: list) -> dict: - """解密配置中的敏感字段""" + """解密配置中的敏感字段。""" + # 拷贝配置避免原地修改 decrypted = config_dict.copy() + # 遍历敏感字段并解密 for field in sensitive_fields: if field in decrypted and decrypted[field]: decrypted[field] = decrypt_sensitive_data(decrypted[field]) + # 返回解密结果 return decrypted + # 查询用户配置 result = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) config = result.scalar_one_or_none() user_config = {} if config: + # 解析配置 llm_config = json.loads(config.llm_config) if config.llm_config else {} other_config = json.loads(config.other_config) if config.other_config else {} # 解密敏感字段 @@ -565,15 +670,17 @@ def decrypt_config(config_dict: dict, sensitive_fields: list) -> dict: if scan_request and scan_request.file_paths: user_config['scan_config'] = {'file_paths': scan_request.file_paths} - # Trigger Background Task + # 触发后台任务 background_tasks.add_task(scan_repo_task, task.id, AsyncSessionLocal, user_config) + # 返回任务状态 return {"task_id": task.id, "status": "started"} # ============ ZIP文件管理端点 ============ class ZipFileMetaResponse(BaseModel): + """ZIP 文件元数据响应模型。""" has_file: bool original_filename: Optional[str] = None file_size: Optional[int] = None @@ -587,13 +694,19 @@ async def get_project_zip_info( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取项目ZIP文件信息 + 获取项目 ZIP 文件信息。 + + 处理流程: + - 校验项目存在性 + - 查询 ZIP 元数据 + - 返回元数据信息 """ + # 查询项目 project = await db.get(Project, id) if not project: raise HTTPException(status_code=404, detail="项目不存在") - # 检查是否有ZIP文件 + # 检查是否有 ZIP 文件 has_file = await has_project_zip(id) if not has_file: return {"has_file": False} @@ -601,6 +714,7 @@ async def get_project_zip_info( # 获取元数据 meta = await get_project_zip_meta(id) if meta: + # 返回元数据 return { "has_file": True, "original_filename": meta.get("original_filename"), @@ -608,6 +722,7 @@ async def get_project_zip_info( "uploaded_at": meta.get("uploaded_at") } + # 仅标记存在文件 return {"has_file": True} @@ -619,8 +734,15 @@ async def upload_project_zip( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 上传或更新项目ZIP文件 + 上传或更新项目 ZIP 文件。 + + 处理流程: + - 校验项目、权限与类型 + - 校验文件与大小 + - 保存至持久化存储 + - 清理临时文件 """ + # 查询项目 project = await db.get(Project, id) if not project: raise HTTPException(status_code=404, detail="项目不存在") @@ -642,6 +764,7 @@ async def upload_project_zip( temp_file_path = f"/tmp/{temp_file_id}.zip" try: + # 写入临时文件 with open(temp_file_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) @@ -653,6 +776,7 @@ async def upload_project_zip( # 保存到持久化存储 meta = await save_project_zip(id, temp_file_path, file.filename) + # 返回上传结果与元数据 return { "message": "ZIP文件上传成功", "original_filename": meta["original_filename"], @@ -672,8 +796,14 @@ async def delete_project_zip_file( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 删除项目ZIP文件 + 删除项目 ZIP 文件。 + + 处理流程: + - 校验项目与权限 + - 删除 ZIP 文件 + - 返回结果 """ + # 查询项目 project = await db.get(Project, id) if not project: raise HTTPException(status_code=404, detail="项目不存在") @@ -682,11 +812,14 @@ async def delete_project_zip_file( if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权操作此项目") + # 执行删除 deleted = await delete_project_zip(id) if deleted: + # 返回删除成功 return {"message": "ZIP文件已删除"} else: + # 返回未找到文件 return {"message": "没有找到ZIP文件"} @@ -699,8 +832,14 @@ async def get_project_branches( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取项目仓库的分支列表 + 获取项目仓库的分支列表。 + + 处理流程: + - 校验项目类型与仓库地址 + - 解密用户 Token + - 按仓库类型获取分支列表 """ + # 查询项目 project = await db.get(Project, id) if not project: raise HTTPException(status_code=404, detail="项目不存在") @@ -721,6 +860,7 @@ async def get_project_branches( ) config = config.scalar_one_or_none() + # 初始化 Token github_token = settings.GITHUB_TOKEN gitea_token = settings.GITEA_TOKEN gitlab_token = settings.GITLAB_TOKEN @@ -729,6 +869,7 @@ async def get_project_branches( if config and config.other_config: import json + # 解密用户 Token other_config = json.loads(config.other_config) for field in SENSITIVE_OTHER_FIELDS: if field in other_config and other_config[field]: @@ -771,6 +912,7 @@ async def get_project_branches( branches.remove(default_branch) branches.insert(0, default_branch) + # 返回分支信息 return {"branches": branches, "default_branch": default_branch} except Exception as e: diff --git a/backend/app/api/v1/endpoints/prompts.py b/backend/app/api/v1/endpoints/prompts.py index 9045c5b4..341e3d48 100644 --- a/backend/app/api/v1/endpoints/prompts.py +++ b/backend/app/api/v1/endpoints/prompts.py @@ -35,20 +35,27 @@ async def list_prompt_templates( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """获取提示词模板列表""" + """ + 获取提示词模板列表。 + + 处理流程: + - 构建模板查询(系统模板 + 用户模板) + - 应用筛选与排序 + - 统计总数并分页 + - 解析变量并返回列表 + """ + # 构建基础查询 query = select(PromptTemplate) - # 过滤条件:系统模板 + 当前用户创建的模板 query = query.where( (PromptTemplate.is_system == True) | (PromptTemplate.created_by == current_user.id) ) - + # 可选过滤条件 if template_type: query = query.where(PromptTemplate.template_type == template_type) if is_active is not None: query = query.where(PromptTemplate.is_active == is_active) - # 排序:系统模板优先,然后按排序权重和创建时间 query = query.order_by( PromptTemplate.is_system.desc(), @@ -56,25 +63,24 @@ async def list_prompt_templates( PromptTemplate.sort_order.asc(), PromptTemplate.created_at.desc() ) - - # 计数 + # 统计总数 count_query = select(sql_func.count()).select_from(query.subquery()) total = (await db.execute(count_query)).scalar() - - # 分页 + # 分页查询 query = query.offset(skip).limit(limit) result = await db.execute(query) templates = result.scalars().all() - + # 组装响应项 items = [] for t in templates: + # 解析变量 JSON variables = {} if t.variables: try: variables = json.loads(t.variables) except: pass - + # 构建响应对象 items.append(PromptTemplateResponse( id=t.id, name=t.name, @@ -91,7 +97,7 @@ async def list_prompt_templates( created_at=t.created_at, updated_at=t.updated_at, )) - + # 返回列表响应 return PromptTemplateListResponse(items=items, total=total) @@ -101,26 +107,34 @@ async def get_prompt_template( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """获取单个提示词模板""" + """ + 获取单个提示词模板。 + + 处理流程: + - 查询模板 + - 校验存在性与权限 + - 解析变量并返回 + """ + # 查询模板 result = await db.execute( select(PromptTemplate).where(PromptTemplate.id == template_id) ) + # 获取模板对象 template = result.scalar_one_or_none() - + # 若模板不存在则返回 404 if not template: raise HTTPException(status_code=404, detail="模板不存在") - - # 检查权限 + # 校验权限:系统模板对所有人可见 if not template.is_system and template.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权访问此模板") - + # 解析变量 JSON variables = {} if template.variables: try: variables = json.loads(template.variables) except: pass - + # 返回模板响应 return PromptTemplateResponse( id=template.id, name=template.name, @@ -145,7 +159,15 @@ async def create_prompt_template( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """创建提示词模板""" + """ + 创建提示词模板。 + + 处理流程: + - 构建模板对象 + - 保存并刷新 + - 返回创建结果 + """ + # 构建模板对象 template = PromptTemplate( name=template_in.name, description=template_in.description, @@ -159,11 +181,12 @@ async def create_prompt_template( is_default=False, created_by=current_user.id, ) - + # 保存模板 db.add(template) await db.commit() + # 刷新对象以获取最新数据 await db.refresh(template) - + # 返回创建结果 return PromptTemplateResponse( id=template.id, name=template.name, @@ -189,15 +212,24 @@ async def update_prompt_template( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """更新提示词模板""" + """ + 更新提示词模板。 + + 处理流程: + - 查询模板 + - 校验权限与系统模板限制 + - 应用更新并保存 + - 返回更新结果 + """ + # 查询模板 result = await db.execute( select(PromptTemplate).where(PromptTemplate.id == template_id) ) + # 获取模板对象 template = result.scalar_one_or_none() - + # 若模板不存在则返回 404 if not template: raise HTTPException(status_code=404, detail="模板不存在") - # 系统模板不允许修改核心内容,只能修改启用状态 if template.is_system: if template_in.is_active is not None: @@ -205,10 +237,9 @@ async def update_prompt_template( else: raise HTTPException(status_code=403, detail="系统模板不允许修改") else: - # 检查权限 + # 校验权限 if template.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权修改此模板") - # 更新字段 update_data = template_in.dict(exclude_unset=True) for field, value in update_data.items(): @@ -216,17 +247,18 @@ async def update_prompt_template( setattr(template, field, json.dumps(value)) elif field != "is_default": # 不允许用户设置默认 setattr(template, field, value) - + # 提交更新 await db.commit() + # 刷新对象 await db.refresh(template) - + # 解析变量 JSON variables = {} if template.variables: try: variables = json.loads(template.variables) except: pass - + # 返回更新结果 return PromptTemplateResponse( id=template.id, name=template.name, @@ -251,24 +283,34 @@ async def delete_prompt_template( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """删除提示词模板""" + """ + 删除提示词模板。 + + 处理流程: + - 查询模板 + - 校验存在性与权限 + - 删除并提交 + - 返回删除结果 + """ + # 查询模板 result = await db.execute( select(PromptTemplate).where(PromptTemplate.id == template_id) ) + # 获取模板对象 template = result.scalar_one_or_none() - + # 若模板不存在则返回 404 if not template: raise HTTPException(status_code=404, detail="模板不存在") - + # 系统模板禁止删除 if template.is_system: raise HTTPException(status_code=403, detail="系统模板不允许删除") - + # 校验权限 if template.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权删除此模板") - + # 删除模板 await db.delete(template) await db.commit() - + # 返回删除结果 return {"message": "模板已删除"} @@ -278,19 +320,28 @@ async def test_prompt_template( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """测试提示词效果""" + """ + 测试提示词效果。 + + 处理流程: + - 读取并解密用户配置 + - 创建 LLM 服务 + - 使用自定义提示词进行分析 + - 返回测试结果 + """ + # 延迟导入依赖 from app.services.llm.service import LLMService from app.models.user_config import UserConfig from app.core.encryption import decrypt_sensitive_data - + # 记录开始时间 start_time = time.time() - try: # 获取用户配置 user_config = {} result_config = await db.execute( select(UserConfig).where(UserConfig.user_id == current_user.id) ) + # 获取配置对象 config = result_config.scalar_one_or_none() if config: # 需要解密的敏感字段 @@ -299,17 +350,15 @@ async def test_prompt_template( 'qwenApiKey', 'deepseekApiKey', 'zhipuApiKey', 'moonshotApiKey', 'baiduApiKey', 'minimaxApiKey', 'doubaoApiKey' ] - + # 解析并解密 LLM 配置 llm_config = json.loads(config.llm_config) if config.llm_config else {} for field in SENSITIVE_LLM_FIELDS: if field in llm_config and llm_config[field]: llm_config[field] = decrypt_sensitive_data(llm_config[field]) - + # 组装用户配置 user_config = {'llmConfig': llm_config} - - # 创建使用用户配置的LLM服务实例 + # 创建使用用户配置的 LLM 服务实例 llm_service = LLMService(user_config=user_config) - # 使用自定义提示词进行分析 result = await llm_service.analyze_code_with_custom_prompt( code=request.code, @@ -317,19 +366,22 @@ async def test_prompt_template( custom_prompt=request.content, output_language=request.output_language, ) - + # 计算耗时 execution_time = time.time() - start_time - + # 返回成功结果 return PromptTestResponse( success=True, result=result, execution_time=round(execution_time, 2), ) except Exception as e: + # 计算耗时 execution_time = time.time() - start_time + # 记录异常信息 import traceback print(f"❌ 提示词测试失败: {e}") print(traceback.format_exc()) + # 返回失败结果 return PromptTestResponse( success=False, error=str(e), @@ -343,15 +395,25 @@ async def set_default_template( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """设置默认模板(仅管理员)""" + """ + 设置默认模板(仅管理员)。 + + 处理流程: + - 校验管理员权限 + - 查询模板 + - 取消同类型默认模板 + - 设置当前模板为默认 + """ + # 仅管理员可设置默认模板 if not current_user.is_superuser: raise HTTPException(status_code=403, detail="仅管理员可设置默认模板") - + # 查询目标模板 result = await db.execute( select(PromptTemplate).where(PromptTemplate.id == template_id) ) + # 获取模板对象 template = result.scalar_one_or_none() - + # 若模板不存在则返回 404 if not template: raise HTTPException(status_code=404, detail="模板不存在") diff --git a/backend/app/api/v1/endpoints/rules.py b/backend/app/api/v1/endpoints/rules.py index d59db6d5..bfd23af0 100644 --- a/backend/app/api/v1/endpoints/rules.py +++ b/backend/app/api/v1/endpoints/rules.py @@ -42,7 +42,15 @@ async def list_rule_sets( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """获取审计规则集列表""" + """ + 获取审计规则集列表。 + + 处理流程: + - 构建可访问规则集查询 + - 应用过滤条件与排序 + - 分页并构建响应 + """ + # 构建规则集查询并预加载 rules query = select(AuditRuleSet).options(selectinload(AuditRuleSet.rules)) # 过滤条件:系统规则集 + 当前用户创建的规则集 @@ -51,6 +59,7 @@ async def list_rule_sets( (AuditRuleSet.created_by == current_user.id) ) + # 应用过滤条件 if language: query = query.where(AuditRuleSet.language == language) if rule_type: @@ -80,8 +89,10 @@ async def list_rule_sets( result = await db.execute(query) rule_sets = result.scalars().unique().all() + # 构建响应项 items = [] for rs in rule_sets: + # 解析严重程度权重 severity_weights = {"critical": 10, "high": 5, "medium": 2, "low": 1} if rs.severity_weights: try: @@ -89,6 +100,7 @@ async def list_rule_sets( except: pass + # 组装规则列表 rules = [ AuditRuleResponse( id=r.id, @@ -109,6 +121,7 @@ async def list_rule_sets( for r in rs.rules ] + # 组装规则集响应 items.append(AuditRuleSetResponse( id=rs.id, name=rs.name, @@ -128,6 +141,7 @@ async def list_rule_sets( enabled_rules_count=len([r for r in rules if r.enabled]), )) + # 返回分页结果 return AuditRuleSetListResponse(items=items, total=total) @@ -137,7 +151,15 @@ async def get_rule_set( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """获取单个规则集""" + """ + 获取单个规则集详情。 + + 处理流程: + - 查询规则集并预加载规则 + - 校验存在性与权限 + - 组装并返回响应 + """ + # 查询规则集 result = await db.execute( select(AuditRuleSet) .options(selectinload(AuditRuleSet.rules)) @@ -145,12 +167,15 @@ async def get_rule_set( ) rule_set = result.scalar_one_or_none() + # 校验存在性 if not rule_set: raise HTTPException(status_code=404, detail="规则集不存在") + # 校验访问权限 if not rule_set.is_system and rule_set.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权访问此规则集") + # 解析严重程度权重 severity_weights = {"critical": 10, "high": 5, "medium": 2, "low": 1} if rule_set.severity_weights: try: @@ -158,6 +183,7 @@ async def get_rule_set( except: pass + # 组装规则列表 rules = [ AuditRuleResponse( id=r.id, @@ -178,6 +204,7 @@ async def get_rule_set( for r in rule_set.rules ] + # 返回规则集详情 return AuditRuleSetResponse( id=rule_set.id, name=rule_set.name, @@ -204,7 +231,15 @@ async def create_rule_set( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """创建审计规则集""" + """ + 创建审计规则集。 + + 处理流程: + - 构建规则集对象 + - 逐条创建规则 + - 提交并返回结果 + """ + # 构建规则集对象 rule_set = AuditRuleSet( name=rule_set_in.name, description=rule_set_in.description, @@ -218,12 +253,14 @@ async def create_rule_set( created_by=current_user.id, ) + # 写入数据库并获取 rule_set.id db.add(rule_set) await db.flush() # 创建规则 rules = [] for rule_in in (rule_set_in.rules or []): + # 构建规则对象 rule = AuditRule( rule_set_id=rule_set.id, rule_code=rule_in.rule_code, @@ -240,9 +277,11 @@ async def create_rule_set( db.add(rule) rules.append(rule) + # 提交并刷新 await db.commit() await db.refresh(rule_set) + # 返回规则集响应 return AuditRuleSetResponse( id=rule_set.id, name=rule_set.name, @@ -288,7 +327,15 @@ async def update_rule_set( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """更新审计规则集""" + """ + 更新审计规则集。 + + 处理流程: + - 查询规则集并校验权限 + - 更新字段并提交 + - 返回最新规则集 + """ + # 查询规则集 result = await db.execute( select(AuditRuleSet) .options(selectinload(AuditRuleSet.rules)) @@ -296,6 +343,7 @@ async def update_rule_set( ) rule_set = result.scalar_one_or_none() + # 校验存在性 if not rule_set: raise HTTPException(status_code=404, detail="规则集不存在") @@ -306,9 +354,11 @@ async def update_rule_set( else: raise HTTPException(status_code=403, detail="系统规则集不允许修改") else: + # 校验拥有者权限 if rule_set.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权修改此规则集") + # 应用更新字段 update_data = rule_set_in.dict(exclude_unset=True) for field, value in update_data.items(): if field == "severity_weights" and value is not None: @@ -316,9 +366,11 @@ async def update_rule_set( elif field != "is_default": setattr(rule_set, field, value) + # 提交并刷新 await db.commit() await db.refresh(rule_set) + # 解析严重程度权重 severity_weights = {"critical": 10, "high": 5, "medium": 2, "low": 1} if rule_set.severity_weights: try: @@ -326,6 +378,7 @@ async def update_rule_set( except: pass + # 组装规则列表 rules = [ AuditRuleResponse( id=r.id, @@ -346,6 +399,7 @@ async def update_rule_set( for r in rule_set.rules ] + # 返回更新结果 return AuditRuleSetResponse( id=rule_set.id, name=rule_set.name, @@ -372,24 +426,37 @@ async def delete_rule_set( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """删除审计规则集""" + """ + 删除审计规则集。 + + 处理流程: + - 查询规则集并校验权限 + - 删除规则集并提交 + - 返回结果 + """ + # 查询规则集 result = await db.execute( select(AuditRuleSet).where(AuditRuleSet.id == rule_set_id) ) rule_set = result.scalar_one_or_none() + # 校验存在性 if not rule_set: raise HTTPException(status_code=404, detail="规则集不存在") + # 系统规则集不可删除 if rule_set.is_system: raise HTTPException(status_code=403, detail="系统规则集不允许删除") + # 校验拥有者权限 if rule_set.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权删除此规则集") + # 删除并提交 await db.delete(rule_set) await db.commit() + # 返回删除结果 return {"message": "规则集已删除"} @@ -399,7 +466,15 @@ async def export_rule_set( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """导出规则集为JSON""" + """ + 导出规则集为 JSON。 + + 处理流程: + - 查询规则集与规则 + - 校验权限并构建导出结构 + - 返回 JSON 响应并设置文件名 + """ + # 查询规则集 result = await db.execute( select(AuditRuleSet) .options(selectinload(AuditRuleSet.rules)) @@ -407,12 +482,15 @@ async def export_rule_set( ) rule_set = result.scalar_one_or_none() + # 校验存在性 if not rule_set: raise HTTPException(status_code=404, detail="规则集不存在") + # 校验权限 if not rule_set.is_system and rule_set.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权导出此规则集") + # 解析严重程度权重 severity_weights = {"critical": 10, "high": 5, "medium": 2, "low": 1} if rule_set.severity_weights: try: @@ -420,6 +498,7 @@ async def export_rule_set( except: pass + # 构建导出结构 export_data = { "name": rule_set.name, "description": rule_set.description, @@ -448,6 +527,7 @@ async def export_rule_set( from urllib.parse import quote encoded_filename = quote(f"{rule_set.name}.json") + # 返回 JSON 文件响应 return JSONResponse( content=export_data, headers={ @@ -462,7 +542,15 @@ async def import_rule_set( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """导入规则集""" + """ + 导入规则集。 + + 处理流程: + - 创建规则集对象 + - 批量创建规则 + - 提交并返回结果 + """ + # 构建规则集对象 rule_set = AuditRuleSet( name=import_data.name, description=import_data.description, @@ -475,11 +563,14 @@ async def import_rule_set( created_by=current_user.id, ) + # 写入数据库并获取 rule_set.id db.add(rule_set) await db.flush() + # 创建规则 rules = [] for rule_in in import_data.rules: + # 构建规则对象 rule = AuditRule( rule_set_id=rule_set.id, rule_code=rule_in.rule_code, @@ -496,9 +587,11 @@ async def import_rule_set( db.add(rule) rules.append(rule) + # 提交并刷新 await db.commit() await db.refresh(rule_set) + # 返回规则集响应 return AuditRuleSetResponse( id=rule_set.id, name=rule_set.name, @@ -546,21 +639,33 @@ async def add_rule_to_set( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """向规则集添加规则""" + """ + 向规则集添加规则。 + + 处理流程: + - 查询规则集并校验权限 + - 创建规则并提交 + - 返回规则响应 + """ + # 查询规则集 result = await db.execute( select(AuditRuleSet).where(AuditRuleSet.id == rule_set_id) ) rule_set = result.scalar_one_or_none() + # 校验存在性 if not rule_set: raise HTTPException(status_code=404, detail="规则集不存在") + # 系统规则集不可修改 if rule_set.is_system: raise HTTPException(status_code=403, detail="系统规则集不允许添加规则") + # 校验拥有者权限 if rule_set.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权修改此规则集") + # 构建规则对象 rule = AuditRule( rule_set_id=rule_set_id, rule_code=rule_in.rule_code, @@ -575,10 +680,12 @@ async def add_rule_to_set( sort_order=rule_in.sort_order, ) + # 保存规则 db.add(rule) await db.commit() await db.refresh(rule) + # 返回规则响应 return AuditRuleResponse( id=rule.id, rule_set_id=rule.rule_set_id, @@ -605,21 +712,33 @@ async def update_rule( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """更新规则""" + """ + 更新规则。 + + 处理流程: + - 校验规则集权限 + - 查询规则并更新字段 + - 提交并返回结果 + """ + # 查询规则集 result = await db.execute( select(AuditRuleSet).where(AuditRuleSet.id == rule_set_id) ) rule_set = result.scalar_one_or_none() + # 校验存在性 if not rule_set: raise HTTPException(status_code=404, detail="规则集不存在") + # 系统规则集不可修改 if rule_set.is_system: raise HTTPException(status_code=403, detail="系统规则集不允许修改规则") + # 校验拥有者权限 if rule_set.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权修改此规则集") + # 查询规则 result = await db.execute( select(AuditRule).where( AuditRule.id == rule_id, @@ -628,16 +747,20 @@ async def update_rule( ) rule = result.scalar_one_or_none() + # 校验规则存在性 if not rule: raise HTTPException(status_code=404, detail="规则不存在") + # 应用更新字段 update_data = rule_in.dict(exclude_unset=True) for field, value in update_data.items(): setattr(rule, field, value) + # 提交并刷新 await db.commit() await db.refresh(rule) + # 返回更新后的规则 return AuditRuleResponse( id=rule.id, rule_set_id=rule.rule_set_id, @@ -663,21 +786,33 @@ async def delete_rule( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """删除规则""" + """ + 删除规则。 + + 处理流程: + - 校验规则集权限 + - 查询规则并删除 + - 返回结果 + """ + # 查询规则集 result = await db.execute( select(AuditRuleSet).where(AuditRuleSet.id == rule_set_id) ) rule_set = result.scalar_one_or_none() + # 校验存在性 if not rule_set: raise HTTPException(status_code=404, detail="规则集不存在") + # 系统规则集不可删除规则 if rule_set.is_system: raise HTTPException(status_code=403, detail="系统规则集不允许删除规则") + # 校验拥有者权限 if rule_set.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权修改此规则集") + # 查询规则 result = await db.execute( select(AuditRule).where( AuditRule.id == rule_id, @@ -686,12 +821,15 @@ async def delete_rule( ) rule = result.scalar_one_or_none() + # 校验规则存在性 if not rule: raise HTTPException(status_code=404, detail="规则不存在") + # 删除规则并提交 await db.delete(rule) await db.commit() + # 返回删除结果 return {"message": "规则已删除"} @@ -702,12 +840,21 @@ async def toggle_rule( db: AsyncSession = Depends(get_db), current_user: User = Depends(deps.get_current_user), ) -> Any: - """切换规则启用状态""" + """ + 切换规则启用状态。 + + 处理流程: + - 校验规则集权限 + - 查询规则并切换启用状态 + - 返回最新状态 + """ + # 查询规则集 result = await db.execute( select(AuditRuleSet).where(AuditRuleSet.id == rule_set_id) ) rule_set = result.scalar_one_or_none() + # 校验存在性 if not rule_set: raise HTTPException(status_code=404, detail="规则集不存在") @@ -715,6 +862,7 @@ async def toggle_rule( if not rule_set.is_system and rule_set.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权修改此规则集") + # 查询规则 result = await db.execute( select(AuditRule).where( AuditRule.id == rule_id, @@ -723,10 +871,13 @@ async def toggle_rule( ) rule = result.scalar_one_or_none() + # 校验规则存在性 if not rule: raise HTTPException(status_code=404, detail="规则不存在") + # 切换启用状态 rule.enabled = not rule.enabled await db.commit() + # 返回切换结果 return {"enabled": rule.enabled, "message": f"规则已{'启用' if rule.enabled else '禁用'}"} diff --git a/backend/app/api/v1/endpoints/scan.py b/backend/app/api/v1/endpoints/scan.py index 0b1bc20a..d539b64d 100644 --- a/backend/app/api/v1/endpoints/scan.py +++ b/backend/app/api/v1/endpoints/scan.py @@ -1,3 +1,7 @@ +""" +模块说明:API 路由与依赖定义:scan。 +""" + from fastapi import APIRouter, UploadFile, File, Form, Depends, BackgroundTasks, HTTPException from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -29,10 +33,13 @@ def normalize_path(path: str) -> str: """ - 统一路径分隔符为正斜杠,确保跨平台兼容性 - Windows 使用反斜杠 (\),Unix/Mac 使用正斜杠 (/) - 统一转换为正斜杠以保证一致性 + 统一路径分隔符为正斜杠,确保跨平台兼容性。 + + 处理流程: + - 将反斜杠替换为正斜杠 + - 返回标准化路径 """ + # 统一替换为正斜杠 return path.replace("\\", "/") @@ -45,24 +52,36 @@ def normalize_path(path: str) -> str: async def process_zip_task(task_id: str, file_path: str, db_session_factory, user_config: dict = None): - """后台ZIP文件处理任务""" + """ + 后台 ZIP 文件处理任务。 + + 处理流程: + - 查询任务并标记运行 + - 解压 ZIP 并筛选可扫描文件 + - 按配置逐文件分析并写入问题 + - 汇总结果并更新任务状态 + """ + # 打开数据库会话 async with db_session_factory() as db: + # 查询任务 task = await db.get(AuditTask, task_id) if not task: return try: + # 更新任务为运行中 task.status = "running" task.started_at = datetime.now(timezone.utc) await db.commit() - # 创建使用用户配置的LLM服务实例 + # 创建使用用户配置的 LLM 服务实例 llm_service = LLMService(user_config=user_config or {}) - # Extract ZIP + # 准备解压目录 extract_dir = Path(f"/tmp/{task_id}") extract_dir.mkdir(parents=True, exist_ok=True) + # 解压 ZIP with zipfile.ZipFile(file_path, 'r') as zip_ref: zip_ref.extractall(extract_dir) @@ -70,13 +89,14 @@ async def process_zip_task(task_id: str, file_path: str, db_session_factory, use scan_config = (user_config or {}).get('scan_config', {}) custom_exclude_patterns = scan_config.get('exclude_patterns', []) - # Find files + # 扫描可分析文件 files_to_scan = [] for root, dirs, files in os.walk(extract_dir): # 排除常见非代码目录 dirs[:] = [d for d in dirs if d not in ['node_modules', '__pycache__', '.git', 'dist', 'build', 'vendor']] for file in files: + # 计算相对路径 full_path = Path(root) / file # 统一使用正斜杠,确保跨平台兼容性 rel_path = normalize_path(str(full_path.relative_to(extract_dir))) @@ -84,6 +104,7 @@ async def process_zip_task(task_id: str, file_path: str, db_session_factory, use # 检查文件类型和排除规则(包含用户自定义排除模式) if is_text_file(rel_path) and not should_exclude(rel_path, custom_exclude_patterns): try: + # 读取文本内容并检查大小 content = full_path.read_text(errors='ignore') if len(content) <= settings.MAX_FILE_SIZE_BYTES: files_to_scan.append({ @@ -105,15 +126,19 @@ async def process_zip_task(task_id: str, file_path: str, db_session_factory, use # 统一目标文件路径的分隔符,确保匹配一致性 normalized_targets = {normalize_path(p) for p in target_files} print(f"🎯 ZIP任务: 指定分析 {len(normalized_targets)} 个文件") + # 仅保留目标文件 files_to_scan = [f for f in files_to_scan if f['path'] in normalized_targets] elif max_analyze_files > 0: + # 按最大文件数裁剪 files_to_scan = files_to_scan[:max_analyze_files] + # 更新任务统计信息 task.total_files = len(files_to_scan) await db.commit() print(f"📊 ZIP任务 {task_id}: 找到 {len(files_to_scan)} 个文件 (最大文件数: {max_analyze_files}, 请求间隔: {llm_gap_ms}ms)") + # 初始化统计变量 total_issues = 0 total_lines = 0 quality_scores = [] @@ -131,11 +156,12 @@ async def process_zip_task(task_id: str, file_path: str, db_session_factory, use return try: + # 读取文件内容与语言 content = file_info['content'] total_lines += content.count('\n') + 1 language = get_language_from_path(file_info['path']) - # 获取规则集和提示词模板ID + # 获取规则集和提示词模板 ID scan_config = (user_config or {}).get('scan_config', {}) rule_set_id = scan_config.get('rule_set_id') prompt_template_id = scan_config.get('prompt_template_id') @@ -149,8 +175,10 @@ async def process_zip_task(task_id: str, file_path: str, db_session_factory, use db_session=db ) else: + # 使用默认规则分析 result = await llm_service.analyze_code(content, language) + # 写入问题记录 issues = result.get("issues", []) for i in issues: issue = AuditIssue( @@ -171,9 +199,11 @@ async def process_zip_task(task_id: str, file_path: str, db_session_factory, use db.add(issue) total_issues += 1 + # 记录质量评分 if "quality_score" in result: quality_scores.append(result["quality_score"]) + # 更新任务进度 scanned_files += 1 task.scanned_files = scanned_files task.total_lines = total_lines @@ -186,6 +216,7 @@ async def process_zip_task(task_id: str, file_path: str, db_session_factory, use await asyncio.sleep(llm_gap_ms / 1000) except Exception as file_error: + # 记录单文件失败 failed_files += 1 print(f"❌ ZIP任务分析文件失败 ({file_info['path']}): {file_error}") await asyncio.sleep(llm_gap_ms / 1000) @@ -204,6 +235,7 @@ async def process_zip_task(task_id: str, file_path: str, db_session_factory, use await db.commit() print(f"❌ ZIP任务 {task_id} 失败: 所有 {len(files_to_scan)} 个文件分析均失败,请检查 LLM API 配置") else: + # 标记完成并写入统计 task.status = "completed" task.completed_at = datetime.now(timezone.utc) task.scanned_files = scanned_files @@ -212,16 +244,19 @@ async def process_zip_task(task_id: str, file_path: str, db_session_factory, use task.quality_score = avg_quality_score await db.commit() print(f"✅ ZIP任务 {task_id} 完成: 扫描 {scanned_files} 个文件, 发现 {total_issues} 个问题") + # 清理任务控制状态 task_control.cleanup_task(task_id) except Exception as e: + # 记录失败 print(f"❌ ZIP扫描失败: {e}") task.status = "failed" task.completed_at = datetime.now(timezone.utc) await db.commit() + # 清理任务控制状态 task_control.cleanup_task(task_id) finally: - # Cleanup - 只清理解压目录,不删除源ZIP文件(已持久化存储) + # Cleanup - 只清理解压目录,不删除源 ZIP 文件(已持久化存储) if extract_dir.exists(): shutil.rmtree(extract_dir) @@ -236,10 +271,15 @@ async def scan_zip( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Upload and scan a ZIP file. - 上传ZIP文件并启动扫描,同时将ZIP文件保存到持久化存储 + 上传 ZIP 文件并启动扫描,同时将 ZIP 文件保存到持久化存储。 + + 处理流程: + - 校验项目与权限 + - 校验 ZIP 文件与大小 + - 保存 ZIP 到持久化存储 + - 创建任务并触发后台处理 """ - # Verify project exists + # 校验项目存在 project = await db.get(Project, project_id) if not project: raise HTTPException(status_code=404, detail="项目不存在") @@ -248,26 +288,26 @@ async def scan_zip( if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权操作此项目") - # Validate file + # 校验文件类型 if not file.filename.lower().endswith('.zip'): raise HTTPException(status_code=400, detail="请上传ZIP格式文件") - # Save Uploaded File to temp + # 保存上传文件到临时路径 file_id = str(uuid.uuid4()) file_path = f"/tmp/{file_id}.zip" with open(file_path, "wb") as buffer: shutil.copyfileobj(file.file, buffer) - # Check file size + # 校验文件大小 file_size = os.path.getsize(file_path) if file_size > 500 * 1024 * 1024: # 500MB limit os.remove(file_path) raise HTTPException(status_code=400, detail="文件大小不能超过500MB") - # 保存ZIP文件到持久化存储 + # 保存 ZIP 文件到持久化存储 await save_project_zip(project_id, file_path, file.filename) - # Parse scan_config if provided + # 解析扫描配置 parsed_scan_config = {} if scan_config: try: @@ -275,7 +315,7 @@ async def scan_zip( except json.JSONDecodeError: pass - # Create Task + # 创建扫描任务 task = AuditTask( project_id=project_id, created_by=current_user.id, @@ -299,14 +339,16 @@ async def scan_zip( 'prompt_template_id': parsed_scan_config.get('prompt_template_id'), } - # Trigger Background Task - 使用持久化存储的文件路径 + # 触发后台任务 - 使用持久化存储的文件路径 stored_zip_path = await load_project_zip(project_id) background_tasks.add_task(process_zip_task, task.id, stored_zip_path or file_path, AsyncSessionLocal, user_config) + # 返回任务状态 return {"task_id": task.id, "status": "queued"} class ScanRequest(BaseModel): + """扫描配置请求模型。""" file_paths: Optional[List[str]] = None full_scan: bool = True exclude_patterns: Optional[List[str]] = None @@ -323,9 +365,14 @@ async def scan_stored_zip( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 使用已存储的ZIP文件启动扫描(无需重新上传) + 使用已存储的 ZIP 文件启动扫描(无需重新上传)。 + + 处理流程: + - 校验项目与权限 + - 校验已存储 ZIP + - 创建任务并触发后台处理 """ - # Verify project exists + # 校验项目存在 project = await db.get(Project, project_id) if not project: raise HTTPException(status_code=404, detail="项目不存在") @@ -334,12 +381,12 @@ async def scan_stored_zip( if project.owner_id != current_user.id: raise HTTPException(status_code=403, detail="无权操作此项目") - # 检查是否有存储的ZIP文件 + # 检查是否有存储的 ZIP 文件 stored_zip_path = await load_project_zip(project_id) if not stored_zip_path: raise HTTPException(status_code=400, detail="项目没有已存储的ZIP文件,请先上传") - # Create Task + # 创建扫描任务 task = AuditTask( project_id=project_id, created_by=current_user.id, @@ -363,19 +410,22 @@ async def scan_stored_zip( 'prompt_template_id': scan_request.prompt_template_id, } - # Trigger Background Task + # 触发后台任务 background_tasks.add_task(process_zip_task, task.id, stored_zip_path, AsyncSessionLocal, user_config) + # 返回任务状态 return {"task_id": task.id, "status": "queued"} class InstantAnalysisRequest(BaseModel): + """即时分析请求模型。""" code: str language: str prompt_template_id: Optional[str] = None class InstantAnalysisResponse(BaseModel): + """即时分析响应模型。""" id: str user_id: str language: str @@ -390,7 +440,15 @@ class Config: async def get_user_config_dict(db: AsyncSession, user_id: str) -> dict: - """获取用户配置字典(包含解密敏感字段)""" + """ + 获取用户配置字典(包含解密敏感字段)。 + + 处理流程: + - 查询用户配置 + - 解析并解密敏感字段 + - 返回统一配置字典 + """ + # 延迟导入解密函数 from app.core.encryption import decrypt_sensitive_data # 需要解密的敏感字段列表(与 config.py 保持一致) @@ -402,16 +460,29 @@ async def get_user_config_dict(db: AsyncSession, user_id: str) -> dict: SENSITIVE_OTHER_FIELDS = ['githubToken', 'gitlabToken'] def decrypt_config(config: dict, sensitive_fields: list) -> dict: - """解密配置中的敏感字段""" + """ + 解密配置中的敏感字段。 + + 处理流程: + - 拷贝配置 + - 遍历敏感字段并解密 + - 返回解密结果 + """ + # 拷贝配置 decrypted = config.copy() + # 遍历敏感字段 for field in sensitive_fields: + # 有值时解密 if field in decrypted and decrypted[field]: decrypted[field] = decrypt_sensitive_data(decrypted[field]) + # 返回解密后的配置 return decrypted + # 查询用户配置 result = await db.execute( select(UserConfig).where(UserConfig.user_id == user_id) ) + # 获取配置对象 config = result.scalar_one_or_none() if not config: return {} @@ -424,6 +495,7 @@ def decrypt_config(config: dict, sensitive_fields: list) -> dict: llm_config = decrypt_config(llm_config, SENSITIVE_LLM_FIELDS) other_config = decrypt_config(other_config, SENSITIVE_OTHER_FIELDS) + # 返回组合后的配置 return { 'llmConfig': llm_config, 'otherConfig': other_config, @@ -437,14 +509,21 @@ async def instant_analysis( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Perform instant code analysis. + 执行即时代码分析。 + + 处理流程: + - 获取用户配置 + - 调用 LLM 分析 + - 保存分析记录 + - 返回分析结果 """ # 获取用户配置 user_config = await get_user_config_dict(db, current_user.id) - # 创建使用用户配置的LLM服务实例 + # 创建使用用户配置的 LLM 服务实例 llm_service = LLMService(user_config=user_config) + # 记录开始时间 start_time = datetime.now(timezone.utc) try: @@ -465,10 +544,11 @@ async def instant_analysis( detail=f"代码分析失败: {error_msg}" ) + # 计算分析耗时 end_time = datetime.now(timezone.utc) duration = (end_time - start_time).total_seconds() - # Save record + # 保存分析记录 analysis = InstantAnalysis( user_id=current_user.id, language=req.language, @@ -482,7 +562,7 @@ async def instant_analysis( await db.commit() await db.refresh(analysis) - # Return result with analysis ID for export functionality + # 返回结果并附带分析记录 ID return { **result, "analysis_id": analysis.id, @@ -497,14 +577,20 @@ async def get_instant_analysis_history( limit: int = 20, ) -> Any: """ - Get user's instant analysis history. + 获取用户即时分析历史记录。 + + 处理流程: + - 查询用户历史记录 + - 按时间倒序返回 """ + # 查询历史记录 result = await db.execute( select(InstantAnalysis) .where(InstantAnalysis.user_id == current_user.id) .order_by(InstantAnalysis.created_at.desc()) .limit(limit) ) + # 返回记录列表 return result.scalars().all() @@ -515,21 +601,30 @@ async def delete_instant_analysis( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Delete a specific instant analysis record. + 删除指定的即时分析记录。 + + 处理流程: + - 查询记录 + - 校验存在性 + - 删除并提交 """ + # 查询分析记录 result = await db.execute( select(InstantAnalysis) .where(InstantAnalysis.id == analysis_id) .where(InstantAnalysis.user_id == current_user.id) ) + # 获取记录对象 analysis = result.scalar_one_or_none() if not analysis: raise HTTPException(status_code=404, detail="分析记录不存在") + # 删除记录 await db.delete(analysis) await db.commit() + # 返回删除结果 return {"message": "删除成功"} @@ -539,15 +634,22 @@ async def delete_all_instant_analyses( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Delete all instant analysis records for current user. + 删除当前用户的全部即时分析记录。 + + 处理流程: + - 批量删除用户记录 + - 提交事务 """ + # 延迟导入 delete from sqlalchemy import delete + # 执行批量删除 await db.execute( delete(InstantAnalysis).where(InstantAnalysis.user_id == current_user.id) ) await db.commit() + # 返回删除结果 return {"message": "已清空所有历史记录"} @@ -558,8 +660,14 @@ async def export_instant_report_pdf( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Export instant analysis report as PDF by analysis ID. + 根据分析记录 ID 导出即时分析 PDF 报告。 + + 处理流程: + - 查询分析记录 + - 解析分析结果 + - 生成 PDF 并返回 """ + # 延迟导入响应与报告生成器 from fastapi.responses import Response from app.services.report_generator import ReportGenerator @@ -587,7 +695,7 @@ async def export_instant_report_pdf( analysis.analysis_time ) - # 返回 PDF 文件 + # 构建文件名并返回 PDF 文件 filename = f"instant-analysis-{analysis.language}-{analysis.id[:8]}.pdf" return Response( content=pdf_bytes, diff --git a/backend/app/api/v1/endpoints/ssh_keys.py b/backend/app/api/v1/endpoints/ssh_keys.py index 81f067aa..8d7d36b6 100644 --- a/backend/app/api/v1/endpoints/ssh_keys.py +++ b/backend/app/api/v1/endpoints/ssh_keys.py @@ -23,21 +23,25 @@ # Schemas class SSHKeyGenerateResponse(BaseModel): + """SSH 密钥生成响应模型。""" public_key: str message: str class SSHKeyResponse(BaseModel): + """SSH 密钥查询响应模型。""" has_key: bool public_key: Optional[str] = None fingerprint: Optional[str] = None class SSHKeyTestRequest(BaseModel): + """SSH 密钥测试请求模型。""" repo_url: str class SSHKeyTestResponse(BaseModel): + """SSH 密钥测试响应模型。""" success: bool message: str output: Optional[str] = None @@ -50,12 +54,16 @@ async def generate_ssh_key( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 生成新的SSH密钥对 + 生成新的 SSH 密钥对。 - 生成RSA 4096格式的SSH密钥对,私钥加密存储在用户配置中,公钥返回给用户 + 处理流程: + - 生成 RSA 4096 密钥对 + - 获取或创建用户配置 + - 加密并保存私钥,保存公钥 + - 返回公钥与指纹 """ try: - # 生成SSH密钥对 (RSA 4096) + # 生成 SSH 密钥对 (RSA 4096) private_key, public_key = SSHKeyService.generate_rsa_key(key_size=4096) # 获取或创建用户配置 @@ -65,6 +73,7 @@ async def generate_ssh_key( user_config = result.scalar_one_or_none() if not user_config: + # 创建新的用户配置 user_config = UserConfig( user_id=current_user.id, llm_config="{}", @@ -72,22 +81,25 @@ async def generate_ssh_key( ) db.add(user_config) - # 解析现有的other_config + # 解析现有的 other_config other_config = json.loads(user_config.other_config) if user_config.other_config else {} # 加密并存储私钥 encrypted_private_key = encrypt_sensitive_data(private_key) other_config['sshPrivateKey'] = encrypted_private_key - other_config['sshPublicKey'] = public_key # 公钥不需要加密 + # 公钥不需要加密 + other_config['sshPublicKey'] = public_key # 更新配置 user_config.other_config = json.dumps(other_config) + # 提交事务 await db.commit() # 计算公钥指纹 fingerprint = SSHKeyService.get_public_key_fingerprint(public_key) + # 返回公钥与指纹 return { "public_key": public_key, "fingerprint": fingerprint, @@ -105,7 +117,12 @@ async def get_ssh_key( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取当前用户的SSH公钥 + 获取当前用户的 SSH 公钥。 + + 处理流程: + - 读取用户配置 + - 解析公钥并计算指纹 + - 返回公钥信息或空结果 """ try: # 获取用户配置 @@ -115,21 +132,25 @@ async def get_ssh_key( user_config = result.scalar_one_or_none() if not user_config or not user_config.other_config: + # 未找到配置则返回无密钥 return {"has_key": False} # 解析配置 other_config = json.loads(user_config.other_config) if 'sshPublicKey' in other_config: + # 读取公钥并计算指纹 public_key = other_config['sshPublicKey'] fingerprint = SSHKeyService.get_public_key_fingerprint(public_key) + # 返回公钥信息 return { "has_key": True, "public_key": public_key, "fingerprint": fingerprint } else: + # 未找到公钥 return {"has_key": False} except Exception as e: @@ -143,7 +164,12 @@ async def delete_ssh_key( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 删除当前用户的SSH密钥 + 删除当前用户的 SSH 密钥。 + + 处理流程: + - 读取用户配置 + - 删除私钥与公钥字段 + - 保存配置并返回结果 """ try: # 获取用户配置 @@ -153,12 +179,13 @@ async def delete_ssh_key( user_config = result.scalar_one_or_none() if not user_config or not user_config.other_config: + # 未找到密钥记录 raise HTTPException(status_code=404, detail="未找到SSH密钥") # 解析配置 other_config = json.loads(user_config.other_config) - # 删除SSH密钥 + # 删除 SSH 密钥 if 'sshPrivateKey' in other_config: del other_config['sshPrivateKey'] if 'sshPublicKey' in other_config: @@ -166,8 +193,10 @@ async def delete_ssh_key( # 更新配置 user_config.other_config = json.dumps(other_config) + # 提交事务 await db.commit() + # 返回删除结果 return {"message": "SSH密钥已删除"} except HTTPException: @@ -185,7 +214,7 @@ async def test_ssh_key( test_request: SSHKeyTestRequest, ) -> Any: """ - 测试SSH密钥是否有效 + 测试 SSH 密钥是否有效。 Args: test_request: 包含repo_url的测试请求 @@ -198,16 +227,19 @@ async def test_ssh_key( user_config = result.scalar_one_or_none() if not user_config or not user_config.other_config: + # 未找到密钥记录 raise HTTPException(status_code=404, detail="未找到SSH密钥,请先生成SSH密钥") # 解析配置 other_config = json.loads(user_config.other_config) if 'sshPrivateKey' not in other_config: + # 缺少私钥 raise HTTPException(status_code=404, detail="未找到SSH密钥,请先生成SSH密钥") # 解密私钥 private_key = decrypt_sensitive_data(other_config['sshPrivateKey']) + # 读取公钥 public_key = other_config.get('sshPublicKey', '') # 验证密钥对是否匹配 @@ -215,15 +247,17 @@ async def test_ssh_key( logger.debug(f"Key pair validation result: {is_valid}") if not is_valid: + # 返回密钥对不匹配结果 return { "success": False, "message": "密钥对验证失败:私钥和公钥不匹配", "output": "请重新生成SSH密钥" } - # 测试SSH连接 + # 测试 SSH 连接 result = GitSSHOperations.test_ssh_key(test_request.repo_url, private_key) + # 返回测试结果 return result except HTTPException: @@ -238,20 +272,24 @@ async def clear_known_hosts_file( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 清理known_hosts文件 + 清理 known_hosts 文件。 - 清空SSH known_hosts文件中保存的所有主机密钥。 - 下次连接时会重新接受并保存新的host key。 + 处理流程: + - 调用系统清理函数 + - 返回清理结果 """ try: + # 清理 known_hosts 文件 success = clear_known_hosts() if success: + # 返回成功结果 return { "success": True, "message": "known_hosts文件已清理,下次连接时会重新保存主机密钥" } else: + # 清理失败 raise HTTPException(status_code=500, detail="清理known_hosts文件失败") except HTTPException: diff --git a/backend/app/api/v1/endpoints/tasks.py b/backend/app/api/v1/endpoints/tasks.py index 16a52c6d..a09b046b 100644 --- a/backend/app/api/v1/endpoints/tasks.py +++ b/backend/app/api/v1/endpoints/tasks.py @@ -1,3 +1,7 @@ +""" +模块说明:API 路由与依赖定义:tasks。 +""" + from typing import Any, List, Optional from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession @@ -16,8 +20,8 @@ router = APIRouter() -# Schemas class AuditIssueSchema(BaseModel): + """审计问题的响应模型。""" id: str task_id: str file_path: str @@ -41,10 +45,12 @@ class Config: class IssueUpdateSchema(BaseModel): + """问题状态更新的请求体模型。""" status: Optional[str] = None class ProjectSchema(BaseModel): + """项目基础信息的响应模型。""" id: str name: str description: Optional[str] = None @@ -63,6 +69,7 @@ class Config: class AuditTaskSchema(BaseModel): + """审计任务的响应模型。""" id: str project_id: str task_type: str @@ -92,21 +99,32 @@ async def list_tasks( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - List tasks for current user's projects. + 获取当前用户项目的任务列表。 + + 处理流程: + - 查询当前用户拥有的项目列表 + - 构建任务查询并过滤项目范围 + - 可选按项目过滤 + - 返回按创建时间倒序的任务列表 """ - # 先获取当前用户的项目ID列表 + # 查询当前用户拥有的项目 ID 列表 projects_result = await db.execute( select(Project.id).where(Project.owner_id == current_user.id) ) + # 提取项目 ID 列表 user_project_ids = [p[0] for p in projects_result.fetchall()] - + # 构建任务查询并预加载项目 query = select(AuditTask).options(selectinload(AuditTask.project)) - # 只返回当前用户项目的任务 + # 仅返回当前用户项目的任务 query = query.where(AuditTask.project_id.in_(user_project_ids)) if user_project_ids else query.where(False) + # 可选按项目过滤 if project_id: query = query.where(AuditTask.project_id == project_id) + # 按创建时间倒序 query = query.order_by(AuditTask.created_at.desc()) + # 执行查询 result = await db.execute(query) + # 返回任务列表 return result.scalars().all() @@ -117,21 +135,28 @@ async def read_task( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Get task status by ID. + 获取指定任务信息。 + + 处理流程: + - 查询任务并预加载项目 + - 校验任务存在性 + - 校验是否为创建者 """ + # 查询任务并加载项目 result = await db.execute( select(AuditTask) .options(selectinload(AuditTask.project)) .where(AuditTask.id == id) ) + # 提取任务对象 task = result.scalars().first() + # 若任务不存在则返回 404 if not task: raise HTTPException(status_code=404, detail="任务不存在") - - # 检查权限:只有任务创建者可以查看 + # 校验权限:仅任务创建者可查看 if task.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权查看此任务") - + # 返回任务信息 return task @@ -142,28 +167,34 @@ async def cancel_task( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Cancel a running task. + 取消运行中的任务。 + + 处理流程: + - 查询任务 + - 校验创建者权限 + - 校验任务状态 + - 标记任务取消并更新数据库 """ + # 查询任务 result = await db.execute(select(AuditTask).where(AuditTask.id == id)) + # 提取任务对象 task = result.scalars().first() + # 若任务不存在则返回 404 if not task: raise HTTPException(status_code=404, detail="任务不存在") - - # 检查权限:只有任务创建者可以取消 + # 校验权限:仅任务创建者可取消 if task.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权取消此任务") - + # 校验可取消状态 if task.status not in ["pending", "running"]: raise HTTPException(status_code=400, detail="只能取消待处理或运行中的任务") - # 标记任务为取消 task_control.cancel_task(id) - # 更新数据库状态 task.status = "cancelled" task.completed_at = datetime.now(timezone.utc) await db.commit() - + # 返回取消结果 return {"message": "任务已取消", "task_id": id} @@ -174,20 +205,26 @@ async def read_task_issues( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Get issues for a specific task. + 获取指定任务的问题列表。 + + 处理流程: + - 查询任务并校验存在性 + - 校验创建者权限 + - 查询问题并按严重程度排序 """ - # 先检查任务是否存在且属于当前用户 + # 查询任务以校验权限 task_result = await db.execute( select(AuditTask).where(AuditTask.id == id) ) + # 提取任务对象 task = task_result.scalars().first() + # 若任务不存在则返回 404 if not task: raise HTTPException(status_code=404, detail="任务不存在") - - # 检查权限:只有任务创建者可以查看问题 + # 校验权限:仅任务创建者可查看问题 if task.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权查看此任务的问题") - + # 查询任务问题并排序 result = await db.execute( select(AuditIssue) .where(AuditIssue.task_id == id) @@ -197,6 +234,7 @@ async def read_task_issues( AuditIssue.created_at.desc() ) ) + # 返回问题列表 return result.scalars().all() @@ -209,24 +247,35 @@ async def update_issue( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Update issue status (e.g., resolve, mark as false positive). + 更新问题状态(如已解决、误报)。 + + 处理流程: + - 查询问题并校验存在性 + - 更新状态与解决信息 + - 保存并返回更新结果 """ + # 查询指定问题 result = await db.execute( select(AuditIssue) .where(AuditIssue.id == issue_id, AuditIssue.task_id == task_id) ) + # 提取问题对象 issue = result.scalars().first() + # 若问题不存在则返回 404 if not issue: raise HTTPException(status_code=404, detail="问题不存在") - + # 如果传入状态则更新 if issue_update.status: issue.status = issue_update.status + # 若标记为已解决则记录处理人和时间 if issue_update.status == "resolved": issue.resolved_by = current_user.id issue.resolved_at = datetime.now(timezone.utc) - + # 提交事务 await db.commit() + # 刷新对象 await db.refresh(issue) + # 返回更新结果 return issue @@ -237,34 +286,40 @@ async def export_task_report_pdf( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - Export task audit report as PDF. + 导出任务审计报告为 PDF。 + + 处理流程: + - 查询任务并校验权限 + - 查询问题并整理数据 + - 生成 PDF 并返回文件响应 """ + # 延迟导入响应类 from fastapi.responses import Response + # 延迟导入报告生成器 from app.services.report_generator import ReportGenerator - # 获取任务 result = await db.execute( select(AuditTask) .options(selectinload(AuditTask.project)) .where(AuditTask.id == id) ) + # 提取任务对象 task = result.scalars().first() + # 若任务不存在则返回 404 if not task: raise HTTPException(status_code=404, detail="任务不存在") - - # 检查权限 + # 校验权限:仅创建者可导出 if task.created_by != current_user.id: raise HTTPException(status_code=403, detail="无权导出此任务报告") - - # 获取问题列表 + # 查询问题列表 issues_result = await db.execute( select(AuditIssue) .where(AuditIssue.task_id == id) .order_by(AuditIssue.severity.desc(), AuditIssue.created_at.desc()) ) + # 提取问题列表 issues = issues_result.scalars().all() - - # 转换为字典 + # 将任务信息转换为字典 task_dict = { 'id': task.id, 'status': task.status, @@ -277,7 +332,7 @@ async def export_task_report_pdf( 'created_at': task.created_at.isoformat() if task.created_at else None, 'completed_at': task.completed_at.isoformat() if task.completed_at else None, } - + # 将问题列表转换为字典列表 issues_list = [ { 'title': issue.title, @@ -292,14 +347,13 @@ async def export_task_report_pdf( } for issue in issues ] - + # 获取项目名称 project_name = task.project.name if task.project else "Unknown Project" - - # 生成 PDF + # 生成 PDF 内容 pdf_bytes = ReportGenerator.generate_task_report(task_dict, issues_list, project_name) - - # 返回 PDF 文件 + # 构建文件名 filename = f"audit-report-{task.id[:8]}-{datetime.now(timezone.utc).strftime('%Y%m%d')}.pdf" + # 返回 PDF 响应 return Response( content=pdf_bytes, media_type="application/pdf", diff --git a/backend/app/api/v1/endpoints/users.py b/backend/app/api/v1/endpoints/users.py index 67a3958a..66176910 100644 --- a/backend/app/api/v1/endpoints/users.py +++ b/backend/app/api/v1/endpoints/users.py @@ -1,3 +1,7 @@ +""" +模块说明:API 路由与依赖定义:users。 +""" + from typing import Any, List, Optional from fastapi import APIRouter, Body, Depends, HTTPException, Query from fastapi.encoders import jsonable_encoder @@ -24,40 +28,53 @@ async def read_users( current_user: User = Depends(deps.get_current_active_superuser), ) -> Any: """ - 获取用户列表(支持分页、搜索、筛选) + 获取用户列表(支持分页、搜索、筛选)。 + + 处理流程: + - 构建基础查询与计数查询 + - 根据搜索、角色、状态追加过滤条件 + - 计算总数 + - 执行分页查询并返回结果 """ + # 基础查询:获取用户 query = select(User) + # 统计查询:获取总数 count_query = select(func.count(User.id)) - - # 搜索条件 + # 追加搜索条件 if search: + # 构建模糊搜索过滤条件 search_filter = or_( User.email.ilike(f"%{search}%"), User.full_name.ilike(f"%{search}%"), User.phone.ilike(f"%{search}%") ) + # 将搜索条件应用到数据查询 query = query.where(search_filter) + # 将搜索条件应用到统计查询 count_query = count_query.where(search_filter) - - # 角色筛选 + # 追加角色筛选 if role: + # 限定角色 query = query.where(User.role == role) + # 限定统计范围 count_query = count_query.where(User.role == role) - - # 状态筛选 + # 追加状态筛选 if is_active is not None: + # 限定启用状态 query = query.where(User.is_active == is_active) + # 限定统计范围 count_query = count_query.where(User.is_active == is_active) - - # 获取总数 + # 执行统计查询 total_result = await db.execute(count_query) + # 取出总数 total = total_result.scalar() - - # 分页查询 + # 应用排序与分页 query = query.order_by(User.created_at.desc()).offset(skip).limit(limit) + # 执行数据查询 result = await db.execute(query) + # 取出用户列表 users = result.scalars().all() - + # 返回分页结果 return { "users": users, "total": total, @@ -73,28 +90,46 @@ async def create_user( current_user: User = Depends(deps.get_current_active_superuser), ) -> Any: """ - 创建新用户(仅管理员) + 创建新用户(仅管理员)。 + + 处理流程: + - 校验邮箱是否已被注册 + - 构建用户对象并写入数据库 """ + # 检查邮箱是否已存在 result = await db.execute(select(User).where(User.email == user_in.email)) + # 提取已有用户 user = result.scalars().first() + # 若已存在则返回 400 if user: raise HTTPException( status_code=400, detail="该邮箱已被注册", ) - + # 构建用户对象 db_user = User( + # 邮箱 email=user_in.email, + # 密码哈希 hashed_password=security.get_password_hash(user_in.password), + # 显示名 full_name=user_in.full_name, + # 手机号 phone=user_in.phone, + # 角色 role=user_in.role, + # 是否启用 is_active=user_in.is_active if user_in.is_active is not None else True, + # 是否超级管理员 is_superuser=user_in.is_superuser if user_in.is_superuser is not None else False, ) + # 写入数据库 db.add(db_user) + # 提交事务 await db.commit() + # 刷新对象 await db.refresh(db_user) + # 返回创建结果 return db_user @router.get("/me", response_model=UserSchema) @@ -102,8 +137,13 @@ async def read_user_me( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取当前用户信息 + 获取当前用户信息。 + + 处理流程: + - 依赖注入获取当前用户 + - 直接返回用户对象 """ + # 返回当前用户 return current_user @router.put("/me", response_model=UserSchema) @@ -114,25 +154,33 @@ async def update_user_me( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 更新当前用户信息 + 更新当前用户信息。 + + 处理流程: + - 提取可更新字段 + - 移除普通用户不可修改字段 + - 处理密码更新 + - 写入数据库并返回 """ + # 获取请求中的可更新字段 update_data = user_in.model_dump(exclude_unset=True) - - # 普通用户不能修改自己的角色和超级管理员状态 + # 移除普通用户不可修改的字段 update_data.pop('role', None) update_data.pop('is_superuser', None) update_data.pop('is_active', None) - - # 如果更新密码 + # 如果包含密码则进行哈希处理 if 'password' in update_data and update_data['password']: update_data['hashed_password'] = security.get_password_hash(update_data['password']) + # 移除明文密码字段 update_data.pop('password', None) - + # 将更新字段写入当前用户对象 for field, value in update_data.items(): setattr(current_user, field, value) - + # 提交事务 await db.commit() + # 刷新对象 await db.refresh(current_user) + # 返回更新结果 return current_user @router.get("/{user_id}", response_model=UserSchema) @@ -142,12 +190,21 @@ async def read_user( current_user: User = Depends(deps.get_current_user), ) -> Any: """ - 获取指定用户信息 + 获取指定用户信息。 + + 处理流程: + - 根据用户 ID 查询用户 + - 不存在则返回 404 + - 返回用户信息 """ + # 按用户 ID 查询 result = await db.execute(select(User).where(User.id == user_id)) + # 提取用户对象 user = result.scalars().first() + # 若用户不存在则返回 404 if not user: raise HTTPException(status_code=404, detail="用户不存在") + # 返回用户 return user @router.put("/{user_id}", response_model=UserSchema) @@ -159,25 +216,36 @@ async def update_user( current_user: User = Depends(deps.get_current_active_superuser), ) -> Any: """ - 更新用户信息(仅管理员) + 更新用户信息(仅管理员)。 + + 处理流程: + - 查询目标用户 + - 提取更新字段 + - 处理密码更新 + - 写入数据库并返回 """ + # 查询目标用户 result = await db.execute(select(User).where(User.id == user_id)) + # 提取用户对象 user = result.scalars().first() + # 若用户不存在则返回 404 if not user: raise HTTPException(status_code=404, detail="用户不存在") - + # 获取请求中的可更新字段 update_data = user_in.model_dump(exclude_unset=True) - - # 如果更新密码 + # 如果包含密码则进行哈希处理 if 'password' in update_data and update_data['password']: update_data['hashed_password'] = security.get_password_hash(update_data['password']) + # 移除明文密码字段 update_data.pop('password', None) - + # 写入更新字段 for field, value in update_data.items(): setattr(user, field, value) - + # 提交事务 await db.commit() + # 刷新对象 await db.refresh(user) + # 返回更新结果 return user @router.delete("/{user_id}") @@ -187,18 +255,28 @@ async def delete_user( current_user: User = Depends(deps.get_current_active_superuser), ) -> Any: """ - 删除用户(仅管理员) + 删除用户(仅管理员)。 + + 处理流程: + - 禁止删除自身 + - 查询目标用户 + - 删除用户并提交 """ + # 禁止删除自己的账户 if user_id == current_user.id: raise HTTPException(status_code=400, detail="不能删除自己的账户") - + # 查询目标用户 result = await db.execute(select(User).where(User.id == user_id)) + # 提取用户对象 user = result.scalars().first() + # 若用户不存在则返回 404 if not user: raise HTTPException(status_code=404, detail="用户不存在") - + # 删除用户 await db.delete(user) + # 提交事务 await db.commit() + # 返回删除结果 return {"message": "用户已删除"} @router.post("/{user_id}/toggle-status", response_model=UserSchema) @@ -208,23 +286,33 @@ async def toggle_user_status( current_user: User = Depends(deps.get_current_active_superuser), ) -> Any: """ - 切换用户状态(启用/禁用) + 切换用户状态(启用/禁用)。 + + 处理流程: + - 禁止禁用自身账户 + - 查询目标用户 + - 切换启用状态并保存 """ + # 禁止禁用自己的账户 if user_id == current_user.id: raise HTTPException(status_code=400, detail="不能禁用自己的账户") - + # 查询目标用户 result = await db.execute(select(User).where(User.id == user_id)) + # 提取用户对象 user = result.scalars().first() + # 若用户不存在则返回 404 if not user: raise HTTPException(status_code=404, detail="用户不存在") - + # 切换启用状态 user.is_active = not user.is_active + # 提交事务 await db.commit() + # 刷新对象 await db.refresh(user) + # 返回更新结果 return user - diff --git a/backend/app/core/__init__.py b/backend/app/core/__init__.py index e69de29b..5e6a94da 100644 --- a/backend/app/core/__init__.py +++ b/backend/app/core/__init__.py @@ -0,0 +1,3 @@ +""" +模块说明:核心配置与安全组件:__init__。 +""" diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 50733690..8042f9bc 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -1,3 +1,7 @@ +""" +模块说明:核心配置与安全组件:config。 +""" + from typing import List, Union, Optional from pydantic import AnyHttpUrl, validator from pydantic_settings import BaseSettings diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 3e652fc7..36df3437 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -1,3 +1,7 @@ +""" +模块说明:核心配置与安全组件:security。 +""" + from datetime import datetime, timedelta, timezone from typing import Any, Union from jose import jwt diff --git a/backend/app/db/__init__.py b/backend/app/db/__init__.py index e69de29b..97e352a3 100644 --- a/backend/app/db/__init__.py +++ b/backend/app/db/__init__.py @@ -0,0 +1,3 @@ +""" +模块说明:数据库初始化与会话管理:__init__。 +""" diff --git a/backend/app/db/base.py b/backend/app/db/base.py index 8d235153..4d403c93 100644 --- a/backend/app/db/base.py +++ b/backend/app/db/base.py @@ -1,3 +1,7 @@ +""" +模块说明:数据库初始化与会话管理:base。 +""" + from sqlalchemy.orm import as_declarative, declared_attr @as_declarative() diff --git a/backend/app/db/session.py b/backend/app/db/session.py index 823a85b6..be50ca4a 100644 --- a/backend/app/db/session.py +++ b/backend/app/db/session.py @@ -1,3 +1,7 @@ +""" +模块说明:数据库初始化与会话管理:session。 +""" + from contextlib import asynccontextmanager from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker diff --git a/backend/app/main.py b/backend/app/main.py index 45c5ca8d..c19f5b11 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,3 +1,7 @@ +""" +模块说明:应用模块:main。 +""" + import logging from contextlib import asynccontextmanager from fastapi import FastAPI @@ -19,71 +23,115 @@ async def check_agent_services(): - """检查 Agent 必须服务的可用性""" + """ + 检查 Agent 依赖服务的可用性。 + + 处理流程: + - 检查 Docker 客户端与守护进程 + - 检查 Redis 连接可用性 + - 返回所有不可用项 + """ + # 存放不可用服务的提示信息 issues = [] # 检查 Docker/沙箱服务 try: + # 延迟导入,避免未安装时直接报错 import docker + # 创建 Docker 客户端 client = docker.from_env() + # 发送 ping 以验证可用性 client.ping() + # 记录服务可用 logger.info(" - Docker 服务可用") except ImportError: + # Docker 客户端库未安装 issues.append("Docker Python 库未安装 (pip install docker)") except Exception as e: + # Docker 服务不可用或连接异常 issues.append(f"Docker 服务不可用: {e}") # 检查 Redis 连接(可选警告) try: + # 延迟导入,避免未安装时直接报错 import redis + # 读取环境变量 import os + # 获取 Redis 连接地址 redis_url = os.environ.get("REDIS_URL", "redis://localhost:6379/0") + # 创建 Redis 客户端 r = redis.from_url(redis_url) + # 发送 ping 验证连接 r.ping() + # 记录服务可用 logger.info(" - Redis 服务可用") except ImportError: + # Redis 客户端库未安装 logger.warning(" - Redis Python 库未安装,部分功能可能受限") except Exception as e: + # Redis 服务连接失败 logger.warning(f" - Redis 服务连接失败: {e}") + # 返回不可用服务列表 return issues @asynccontextmanager async def lifespan(app: FastAPI): """ - 应用生命周期管理 - 启动时初始化数据库(创建默认账户等) + 应用生命周期管理。 + + 处理流程: + - 启动时初始化数据库 + - 检查 Agent 依赖服务 + - 输出启动日志 + - 关闭时输出关闭日志 """ + # 记录启动日志 logger.info("DeepAudit 后端服务启动中...") # 初始化数据库(创建默认账户) # 注意:需要先运行 alembic upgrade head 创建表结构 try: + # 打开异步数据库会话 async with AsyncSessionLocal() as db: + # 执行初始化逻辑 await init_db(db) + # 记录初始化完成 logger.info(" - 数据库初始化完成") except Exception as e: + # 解析异常信息 # 表不存在时静默跳过,等待用户运行数据库迁移 error_msg = str(e) + # 若表不存在则提示迁移 if "does not exist" in error_msg or "UndefinedTableError" in error_msg: logger.info("数据库表未创建,请先运行: alembic upgrade head") else: + # 其他异常仅记录警告 logger.warning(f"数据库初始化跳过: {e}") # 检查 Agent 服务 logger.info("检查 Agent 核心服务...") + # 执行服务检查 issues = await check_agent_services() + # 如果存在问题则输出详细信息 if issues: + # 打印分隔线 logger.warning("=" * 50) + # 打印问题标题 logger.warning("Agent 服务检查发现问题:") + # 逐条输出问题 for issue in issues: logger.warning(f" - {issue}") + # 输出提醒信息 logger.warning("部分功能可能不可用,请检查配置") + # 打印分隔线 logger.warning("=" * 50) else: + # 记录检查通过 logger.info(" - Agent 核心服务检查通过") + # 输出启动信息 logger.info("=" * 50) logger.info("DeepAudit 后端服务已启动") logger.info(f"API 文档: http://localhost:8000/docs") @@ -91,8 +139,10 @@ async def lifespan(app: FastAPI): logger.info("演示账户: demo@example.com / demo123") logger.info("=" * 50) + # 交出控制权给应用运行 yield + # 记录关闭日志 logger.info("DeepAudit 后端服务已关闭") @@ -116,11 +166,25 @@ async def lifespan(app: FastAPI): @app.get("/health") async def health_check(): + """ + 健康检查接口。 + + 处理流程: + - 直接返回固定状态 + """ + # 返回健康状态 return {"status": "ok"} @app.get("/") async def root(): + """ + 根路径接口。 + + 处理流程: + - 返回欢迎信息与文档入口 + """ + # 返回 API 欢迎信息 return { "message": "Welcome to DeepAudit API", "docs": "/docs", diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 06527e3c..eb30bf3c 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,3 +1,7 @@ +""" +模块说明:数据模型定义:__init__。 +""" + from .user import User from .user_config import UserConfig from .project import Project, ProjectMember diff --git a/backend/app/models/analysis.py b/backend/app/models/analysis.py index 9eedb5c8..215e2a80 100644 --- a/backend/app/models/analysis.py +++ b/backend/app/models/analysis.py @@ -1,3 +1,7 @@ +""" +模块说明:数据模型定义:analysis。 +""" + import uuid from sqlalchemy import Column, String, Integer, DateTime, Float, Text, ForeignKey from sqlalchemy.sql import func diff --git a/backend/app/models/audit.py b/backend/app/models/audit.py index f207a233..43d5ad7d 100644 --- a/backend/app/models/audit.py +++ b/backend/app/models/audit.py @@ -1,3 +1,7 @@ +""" +模块说明:数据模型定义:audit。 +""" + import uuid from sqlalchemy import Column, String, Integer, DateTime, ForeignKey, Text, Float from sqlalchemy.sql import func diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 5debcbfa..0cdc1a18 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -1,3 +1,7 @@ +""" +模块说明:数据模型定义:project。 +""" + import uuid from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, Text from sqlalchemy.sql import func diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 89d61576..cd6dec37 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -1,3 +1,7 @@ +""" +模块说明:数据模型定义:user。 +""" + import uuid from sqlalchemy import Column, String, Boolean, DateTime from sqlalchemy.sql import func diff --git a/backend/app/schemas/__init__.py b/backend/app/schemas/__init__.py index e69de29b..3529cfd0 100644 --- a/backend/app/schemas/__init__.py +++ b/backend/app/schemas/__init__.py @@ -0,0 +1,3 @@ +""" +模块说明:数据校验与序列化模型:__init__。 +""" diff --git a/backend/app/schemas/token.py b/backend/app/schemas/token.py index ab1a709f..8b289ca0 100644 --- a/backend/app/schemas/token.py +++ b/backend/app/schemas/token.py @@ -1,3 +1,7 @@ +""" +模块说明:数据校验与序列化模型:token。 +""" + from typing import Optional from pydantic import BaseModel diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index 88825e9f..19e606fb 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -1,3 +1,7 @@ +""" +模块说明:数据校验与序列化模型:user。 +""" + from typing import Optional, List from pydantic import BaseModel, EmailStr diff --git a/backend/app/utils/__init__.py b/backend/app/utils/__init__.py index e69de29b..520d265d 100644 --- a/backend/app/utils/__init__.py +++ b/backend/app/utils/__init__.py @@ -0,0 +1,3 @@ +""" +模块说明:包初始化与导出。 +""" diff --git a/backend/app/utils/repo_utils.py b/backend/app/utils/repo_utils.py index 58246dfa..184599c2 100644 --- a/backend/app/utils/repo_utils.py +++ b/backend/app/utils/repo_utils.py @@ -1,3 +1,7 @@ +""" +模块说明:应用模块:repo_utils。 +""" + from urllib.parse import urlparse, urlunparse from typing import Dict, Optional diff --git a/backend/check_docker_direct.py b/backend/check_docker_direct.py index 95394d9f..3ff0bb3c 100644 --- a/backend/check_docker_direct.py +++ b/backend/check_docker_direct.py @@ -1,3 +1,7 @@ +""" +模块说明:后端模块:check_docker_direct。 +""" + import sys try: diff --git a/backend/check_sandbox.py b/backend/check_sandbox.py index c5f092b3..5e221e82 100644 --- a/backend/check_sandbox.py +++ b/backend/check_sandbox.py @@ -1,3 +1,7 @@ +""" +模块说明:后端模块:check_sandbox。 +""" + import asyncio import logging diff --git a/backend/main.py b/backend/main.py index a5cd78ad..f5d37858 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,3 +1,7 @@ +""" +模块说明:后端模块:main。 +""" + def main(): print("Hello from deepaudit-backend!") diff --git a/backend/scripts/test_verification_sandbox.py b/backend/scripts/test_verification_sandbox.py index e788cc2e..3fd95c52 100644 --- a/backend/scripts/test_verification_sandbox.py +++ b/backend/scripts/test_verification_sandbox.py @@ -1,3 +1,7 @@ +""" +模块说明:用于脚本化执行的工具脚本:test_verification_sandbox。 +""" + import asyncio import os diff --git a/backend/verify_llm.py b/backend/verify_llm.py index b14659bf..02827a8a 100644 --- a/backend/verify_llm.py +++ b/backend/verify_llm.py @@ -1,3 +1,7 @@ +""" +模块说明:后端模块:verify_llm。 +""" + import asyncio import logging import sys diff --git a/docs/AGENT_AUDIT_ARCHITECTURE.md b/docs/AGENT_AUDIT_ARCHITECTURE.md index 405288f1..59865b7a 100644 --- a/docs/AGENT_AUDIT_ARCHITECTURE.md +++ b/docs/AGENT_AUDIT_ARCHITECTURE.md @@ -1,4 +1,4 @@ -# DeepAudit Agent 审计架构文档 +# DeepAudit 项目架构文档 ## 目录 @@ -17,6 +17,9 @@ 9. [关键设计模式](#9-关键设计模式) 10. [安全与健壮性](#10-安全与健壮性) 11. [扩展指南](#11-扩展指南) +12. [项目整体架构](#12-项目整体架构) +13. [核心业务流程](#13-核心业务流程) +14. [数据模型与存储](#14-数据模型与存储) --- @@ -1144,6 +1147,99 @@ switch (event.type) { --- +## 12. 项目整体架构 + +### 12.1 代码组织与层次 + +``` +DeepAudit_improve/ +├── backend/ # FastAPI 后端 +│ ├── app/ +│ │ ├── api/ # REST API 路由 +│ │ ├── core/ # 配置与安全 +│ │ ├── db/ # 数据库会话与初始化 +│ │ ├── models/ # ORM 数据模型 +│ │ ├── schemas/ # Pydantic Schema +│ │ ├── services/ # 业务服务层(Agent/LLM/RAG/扫描/存储) +│ │ └── main.py # 应用入口 +│ └── tests/ +├── frontend/ # React + TypeScript 前端 +│ └── src/ +│ ├── pages/ # 页面级模块 +│ ├── shared/api/ # API 调用与 SSE 客户端 +│ └── components/ # UI 组件 +├── docker/ # 沙箱与数据库镜像 +└── docs/ # 文档 +``` + +### 12.2 后端分层职责 + +- 入口层:`app/main.py` 创建 FastAPI 实例、CORS、中间件与生命周期管理 +- API 层:`app/api/v1/api.py` 聚合各业务路由并挂载到 `/api/v1` +- 依赖与鉴权:`app/api/deps.py` 提供 OAuth2/JWT 认证依赖 +- 业务服务层:`app/services/*` 实现扫描、Agent、LLM、RAG、ZIP 存储等核心逻辑 +- 数据层:`app/models/*` 定义数据模型,`app/db/session.py` 管理异步会话 + +### 12.3 前端分层职责 + +- 页面层:`pages/AgentAudit` 负责 Agent 审计的任务创建、日志、树状视图与报告导出 +- 状态层:`pages/AgentAudit/hooks` 与 `shared/api` 实现状态归集和 SSE 流处理 +- API 层:`shared/api/agentTasks.ts` 与 `shared/api/agentStream.ts` 负责任务与事件调用 + +--- + +## 13. 核心业务流程 + +### 13.1 项目创建与代码来源 + +1. 创建项目:通过 `projects` API 创建 Project 记录 +2. 代码来源: + - 仓库模式:根据仓库类型拉取代码(GitHub/GitLab/Gitea) + - ZIP 模式:上传 ZIP 并存储在本地 ZIP 存储目录 +3. 项目关联:记录项目元信息与用户所有权,供后续扫描与审计复用 + +### 13.2 传统扫描流程(非 Agent) + +1. 创建扫描任务:`scan` API 创建 `AuditTask` +2. LLM 分析:`LLMService` 解析文件与风险点,生成 `AuditIssue` +3. 任务生命周期:任务可取消,完成后持久化结果并返回摘要 + +### 13.3 Agent 审计流程(动态多 Agent) + +1. 任务创建:`agent_tasks` API 创建 `AgentTask` 并启动后台执行器 +2. 初始化上下文:收集项目结构、初始化工具、启动事件流 +3. 动态调度:Orchestrator 生成并派发 Recon/Analysis/Verification 子 Agent +4. 事件与结果:事件流写入数据库并同步 SSE,发现记录写入 `AgentFinding` +5. 报告生成:按严重程度汇总发现,生成 Markdown/JSON 报告 + +### 13.4 前端交互流程 + +1. 任务创建后进入 AgentAudit 页面 +2. SSE 流式接入:展示 thinking、tool、phase、finding 等事件 +3. 任务结束后展示报告导出入口与统计摘要 + +--- + +## 14. 数据模型与存储 + +### 14.1 核心实体 + +- 用户与权限:`User`、`ProjectMember` +- 项目与任务:`Project`、`AuditTask`、`AgentTask` +- 扫描结果:`AuditIssue`、`AgentFinding` +- 事件流:`AgentEvent` +- 规则与提示词:`AuditRuleSet`、`AuditRule`、`PromptTemplate` +- 用户配置:`UserConfig` + +### 14.2 存储策略 + +- 关系数据:PostgreSQL(通过 SQLAlchemy 异步会话访问) +- ZIP 代码包:本地文件系统存储,保存 ZIP 与元数据文件 +- RAG 索引:向量存储(Chroma/InMemory)持久化项目代码块与元数据 +- 事件流:运行态内存队列 + 数据库持久化,支持断线重连 + +--- + ## 附录: 关键文件索引 ``` @@ -1190,5 +1286,5 @@ API: --- -*文档版本: 1.0* -*最后更新: 2025-12-13* +*文档版本: 1.1* +*最后更新: 2026-03-06* diff --git a/docs/Interceptor/Java.md b/docs/Interceptor/Java.md new file mode 100644 index 00000000..f4ffacdb --- /dev/null +++ b/docs/Interceptor/Java.md @@ -0,0 +1,118 @@ +# Java Interceptor 清单 + +## 1. Spring MVC 拦截器 +- HandlerInterceptor +- HandlerInterceptorAdapter (旧版) +- AsyncHandlerInterceptor +- WebMvcConfigurer.addInterceptors(...) +- HandlerMethodArgumentResolver + +## 2. Spring AOP / Method Interceptor +- org.aopalliance.intercept.MethodInterceptor +- org.springframework.aop.MethodBeforeAdvice +- org.springframework.aop.AfterReturningAdvice +- org.springframework.aop.ThrowsAdvice +- org.springframework.aop.framework.ProxyFactory + +## 3. Spring WebFlux 拦截 +- HandlerFilterFunction +- RouterFunction.filter(...) +- WebFilter (接近拦截器语义) +- GlobalFilter (Spring Cloud Gateway) +- GatewayFilter + +## 4. Spring Security +- AbstractSecurityInterceptor +- FilterSecurityInterceptor +- MethodSecurityInterceptor +- PreAuthorize / PostAuthorize +- SecurityExpressionHandler + +## 5. JAX-RS / Jersey / RESTEasy +- ContainerRequestFilter +- ContainerResponseFilter +- ReaderInterceptor +- WriterInterceptor +- ClientRequestFilter +- ClientResponseFilter + +## 6. Struts2 / WebWork +- com.opensymphony.xwork2.interceptor.Interceptor +- com.opensymphony.xwork2.interceptor.AbstractInterceptor +- com.opensymphony.xwork2.interceptor.MethodFilterInterceptor +- com.opensymphony.xwork2.interceptor.PreResultListener + +## 7. MyBatis / Hibernate +- MyBatis: org.apache.ibatis.plugin.Interceptor +- MyBatis: @Intercepts / @Signature +- Hibernate: org.hibernate.Interceptor +- Hibernate: EmptyInterceptor + +## 8. 旧框架与安全相关拦截 +- Apache Shiro: org.apache.shiro.aop.MethodInterceptor +- Apache Shiro: org.apache.shiro.aop.MethodInvocation +- Apache Shiro: org.apache.shiro.aop.AdviceFilter +- Apache Shiro: org.apache.shiro.aop.AnnotationMethodInterceptor +- Apache Shiro: org.apache.shiro.aop.MethodAnnotationResolver +- Apache Shiro: org.apache.shiro.aop.SpringAnnotationResolver +- Apache Shiro: org.apache.shiro.web.filter.authc.AuthenticatingFilter +- Apache Shiro: org.apache.shiro.web.filter.AccessControlFilter +- Apache Shiro: org.apache.shiro.web.filter.PathMatchingFilter +- SiteMesh: com.opensymphony.sitemesh.webapp.SiteMeshFilter (拦截视图) +- SiteMesh 3: org.sitemesh.webapp.SiteMeshFilter +- Wicket: IRequestCycleListener +- Wicket: IRequestCycleListener.onBeginRequest / onEndRequest +- Wicket: IRequestCycleListener.onRequestHandlerResolved +- Wicket: IRequestCycleListener.onRequestHandlerExecuted +- Wicket: RequestCycleListener.onBeginRequest() +- Wicket: RequestCycleListener.onEndRequest() +- Wicket: RequestCycleListener.onRequestHandlerScheduled(...) +- Wicket: RequestCycleListener.onRequestHandlerExecuted(...) +- Wicket: RequestCycleListener.onException(...) +- Wicket: RequestCycleListener.onDetach(...) +- Wicket: IComponentInstantiationListener +- JSF: javax.faces.event.PhaseListener +- JSF: javax.faces.event.PhaseEvent.getPhaseId() +- JSF: PhaseId.ANY_PHASE / RESTORE_VIEW / APPLY_REQUEST_VALUES +- JSF: PhaseId.PROCESS_VALIDATIONS / UPDATE_MODEL_VALUES +- JSF: PhaseId.INVOKE_APPLICATION / RENDER_RESPONSE +- JSF: javax.faces.event.SystemEventListener +- Struts1: org.apache.struts.action.RequestProcessor +- Struts1: RequestProcessor.processActionForm(...) +- Struts1: RequestProcessor.processValidate(...) +- Struts1: RequestProcessor.processPreprocess(...) +- Struts1: org.apache.struts.action.ActionServlet +- Apache CXF: org.apache.cxf.interceptor.Interceptor +- Apache CXF: org.apache.cxf.phase.PhaseInterceptor +- Apache CXF: org.apache.cxf.interceptor.InInterceptor +- Apache CXF: org.apache.cxf.interceptor.OutInterceptor +- Apache CXF: org.apache.cxf.interceptor.FaultInterceptor + +## 9. 安全漏洞相关拦截器与 SQL 检查 +- org.springframework.security.access.intercept.aopalliance.MethodSecurityInterceptor +- org.springframework.security.access.prepost.PreInvocationAuthorizationAdvice +- org.springframework.security.access.vote.AffirmativeBased +- org.hibernate.resource.jdbc.spi.StatementInspector +- org.hibernate.Interceptor.onPrepareStatement(...) +- MyBatis: org.apache.ibatis.executor.statement.StatementHandler +- MyBatis-Plus: com.baomidou.mybatisplus.extension.plugins.inner.IllegalSQLInnerInterceptor +- MyBatis-Plus: com.baomidou.mybatisplus.extension.plugins.inner.BlockAttackInnerInterceptor + +## 10. 漏洞类型对应拦截器补充 +- SQL 注入: StatementInspector / StatementHandler / IllegalSQLInnerInterceptor +- SSRF: org.springframework.http.client.ClientHttpRequestInterceptor +- SSRF: org.springframework.web.reactive.function.client.ExchangeFilterFunction +- XSS 输出: org.springframework.web.servlet.mvc.method.annotation.ResponseBodyAdvice +- 不安全重定向: HandlerInterceptor / HandlerMethodArgumentResolver +- 路径遍历: HandlerInterceptor / WebFilter + +## 11. 漏洞类型对应常见框架/库 +- SQL 注入: MyBatis Interceptor / MyBatis-Plus InnerInterceptor +- SQL 注入: Hibernate StatementInspector +- SQL 注入: JOOQ VisitListener +- SSRF: Spring Cloud Gateway GlobalFilter +- SSRF: OkHttp Interceptor +- XXE: JAXB Unmarshaller Listener +- XSS: OWASP Java HTML Sanitizer +- 反序列化: Jackson PolymorphicTypeValidator +- 模板注入: Thymeleaf Dialect 前置拦截 diff --git a/docs/Interceptor/JavaScript.md b/docs/Interceptor/JavaScript.md new file mode 100644 index 00000000..8b2f89e7 --- /dev/null +++ b/docs/Interceptor/JavaScript.md @@ -0,0 +1,68 @@ +# JavaScript Interceptor 清单 + +## 1. Express / Koa +- Express: app.use / router.use +- Express: middleware (req, res, next) +- Koa: app.use(async (ctx, next) => ...) +- koa-compose + +## 2. NestJS +- NestInterceptor (intercept) +- Guard (CanActivate) +- Pipe (transform) +- ExceptionFilter (catch) +- Middleware (NestMiddleware) + +## 3. Fastify / Hapi +- Fastify: addHook("onRequest"|"preHandler"|"onSend"|"onResponse") +- Hapi: server.ext("onRequest"|"onPreHandler"|"onPreResponse") + +## 4. Next.js / Remix / Nuxt +- Next.js Middleware (middleware.ts) +- Next.js API handler wrapper +- Remix: loader/action 包装 +- Nuxt: server middleware / nitro plugins + +## 5. GraphQL / RPC +- Apollo Server plugins +- graphql-middleware +- tRPC: middleware / procedures +- gRPC Node: interceptors + +## 6. 旧框架/自定义 +- Restify: server.pre / server.use +- Restify: server.on("after") / server.on("restifyError") +- Restify: plugins.acceptParser / plugins.queryParser / plugins.bodyParser +- Sails: policies +- LoopBack: interceptors +- 自定义拦截: function(req, res, next) { ... } +- Perl Mojolicious: app->hook(before_dispatch/after_dispatch) + +## 7. 安全漏洞相关拦截与校验 +- NestJS: AuthGuard / RolesGuard +- NestJS: ValidationPipe +- NestJS: ClassSerializerInterceptor +- Express: express-validator middleware +- Express: csurf middleware +- Express: express-rate-limit +- Koa: koa-validate +- Koa: koa-csrf +- Fastify: preValidation hook +- Fastify: @fastify/csrf-protection +- Hapi: @hapi/crumb + +## 8. 漏洞类型对应拦截补充 +- SQL 注入: knex raw 校验拦截 / prisma middleware +- SSRF: fetch/axios wrapper middleware +- XSS 输出: DOMPurify 前置拦截 / 自定义 sanitizer +- 原型污染: merge/assign 拦截中间件 +- 不安全重定向: res.redirect 包装拦截 + +## 9. 漏洞类型对应常见框架/库 +- SQL 注入: Sequelize beforeFind 钩子 +- SQL 注入: Prisma middleware +- SQL 注入: TypeORM subscribers +- SSRF: undici Dispatcher/Agent 包装 +- XXE: xml2js 安全配置 +- XSS: sanitize-html +- 原型污染: lodash merge guard diff --git a/docs/Interceptor/Python.md b/docs/Interceptor/Python.md new file mode 100644 index 00000000..ea37a563 --- /dev/null +++ b/docs/Interceptor/Python.md @@ -0,0 +1,76 @@ +# Python Interceptor 清单 + +## 1. Django +- Middleware (process_request/process_view/process_response) +- View decorator (自定义装饰器) +- Signal: request_started / request_finished +- DRF: BaseAuthentication.authenticate(...) +- DRF: BasePermission.has_permission(...) +- DRF: BaseThrottle.allow_request(...) + +## 2. Flask +- @app.before_request / after_request +- @app.teardown_request +- Blueprint.before_request / after_request +- Flask-Login: @login_required +- Flask-JWT-Extended: @jwt_required + +## 3. FastAPI / Starlette +- @app.middleware("http") +- BaseHTTPMiddleware +- Dependency Injection (Depends) +- APIRoute.get_route_handler 自定义包装 + +## 4. Sanic / Aiohttp / Tornado +- Sanic: @app.middleware("request"|"response") +- aiohttp: @web.middleware +- Tornado: RequestHandler.prepare() +- Tornado: RequestHandler.on_finish() + +## 5. Falcon / Pyramid / Bottle +- Falcon: middleware.process_request / process_response +- Pyramid: tween_factory +- Pyramid: @subscriber(NewRequest) +- Bottle: app.add_hook("before_request"/"after_request") + +## 6. Celery / 任务拦截 +- Celery: task_prerun / task_postrun +- Celery: Task.before_start / after_return +- RQ: job hooks + +## 7. 旧框架/自定义 +- web.py: web.application.add_processor(...) +- Pylons: pylons.middleware +- TurboGears: @before_call / @after_call +- TurboGears: @before_validate / @before_render +- web2py: response.middleware +- web2py: @auth.requires / @auth.requires_login +- Perl Mojolicious: app->hook(before_dispatch/after_dispatch) + +## 8. 安全漏洞相关拦截与装饰器 +- Django: django.views.decorators.csrf.csrf_protect +- Django: django.views.decorators.csrf.ensure_csrf_cookie +- Django: django.views.decorators.clickjacking.xframe_options_deny +- Django: django.views.decorators.clickjacking.xframe_options_sameorigin +- DRF: permission_classes(...) +- DRF: throttle_classes(...) +- Flask: flask_wtf.csrf.CSRFProtect +- Flask: flask_talisman.Talisman +- FastAPI: fastapi.security.HTTPBearer +- FastAPI: fastapi.security.OAuth2PasswordBearer + +## 9. 漏洞类型对应拦截补充 +- SQL 注入: Django connection.execute_wrapper +- SQL 注入: SQLAlchemy event.listen(engine, "before_cursor_execute", ...) +- SSRF: requests.Session.request 自定义包装 +- XSS 输出: Jinja2 Environment.autoescape +- 不安全重定向: Django redirect 包装器 + +## 10. 漏洞类型对应常见框架/库 +- SQL 注入: Django ORM QuerySet 参数化 +- SQL 注入: SQLAlchemy event.listen / before_cursor_execute +- SQL 注入: Peewee pre-save hook 校验 +- SSRF: httpx.Client 自定义 transport 限制 +- XXE: defusedxml 元素解析 +- XSS: Bleach 清洗装饰器 +- 反序列化: itsdangerous BadSignature 拦截 diff --git a/docs/filter/Java.md b/docs/filter/Java.md new file mode 100644 index 00000000..b211cdb0 --- /dev/null +++ b/docs/filter/Java.md @@ -0,0 +1,134 @@ +# Java Filter / 中间件 清单 + +## 1. Servlet 规范过滤器 +- javax.servlet.Filter +- javax.servlet.FilterChain +- javax.servlet.FilterConfig +- javax.servlet.ServletRequestWrapper +- javax.servlet.ServletResponseWrapper +- javax.servlet.http.HttpServletRequestWrapper +- javax.servlet.http.HttpServletResponseWrapper + +## 2. Spring Web MVC 过滤与拦截 +- OncePerRequestFilter +- GenericFilterBean +- DelegatingFilterProxy +- FilterRegistrationBean +- HandlerInterceptor +- HandlerInterceptorAdapter (旧版) +- WebMvcConfigurer.addInterceptors(...) +- HandlerMethodArgumentResolver + +## 3. Spring Security 过滤链 +- FilterChainProxy +- SecurityFilterChain +- SecurityContextPersistenceFilter +- UsernamePasswordAuthenticationFilter +- BasicAuthenticationFilter +- BearerTokenAuthenticationFilter +- OAuth2LoginAuthenticationFilter +- ExceptionTranslationFilter +- CsrfFilter +- HeaderWriterFilter + +## 4. Spring WebFlux 过滤与拦截 +- org.springframework.web.server.WebFilter +- org.springframework.web.server.WebFilterChain +- WebFilterChainProxy +- HandlerFilterFunction (Functional endpoints) +- RouterFunction.filter(...) +- GlobalFilter (Spring Cloud Gateway) +- GatewayFilter / GatewayFilterSpec + +## 5. JAX-RS 过滤与拦截 +- ContainerRequestFilter +- ContainerResponseFilter +- ClientRequestFilter +- ClientResponseFilter +- WriterInterceptor +- ReaderInterceptor + +## 6. Struts2 过滤与拦截 +- org.apache.struts2.dispatcher.filter.StrutsPrepareAndExecuteFilter +- org.apache.struts2.dispatcher.filter.StrutsPrepareFilter +- org.apache.struts2.dispatcher.filter.StrutsExecuteFilter +- com.opensymphony.xwork2.interceptor.Interceptor +- com.opensymphony.xwork2.interceptor.AbstractInterceptor + +## 7. Micronaut / Quarkus 过滤 +- io.micronaut.http.filter.HttpServerFilter +- io.micronaut.http.filter.ServerFilterPhase +- io.micronaut.http.filter.FilterRunner +- io.quarkus.vertx.http.runtime.filters.Filter +- io.quarkus.vertx.http.runtime.filters.FilterBuildItem +- io.quarkus.resteasy.reactive.server.spi.ResteasyReactiveContainerRequestFilter + +## 8. Vert.x / Undertow / Netty +- io.vertx.ext.web.handler.HandlerInterceptor (基于 Handler) +- io.vertx.ext.web.RoutingContextHandler +- io.vertx.core.Handler +- io.undertow.server.HttpHandler +- io.undertow.server.handlers.PredicateHandler +- io.netty.channel.ChannelInboundHandler +- io.netty.handler.codec.http.HttpObjectAggregator + +## 9. Play Framework / Jersey / RESTEasy +- play.mvc.Action +- play.mvc.EssentialFilter +- play.mvc.Filter +- org.glassfish.jersey.server.ContainerRequest +- org.glassfish.jersey.server.ContainerResponse +- org.jboss.resteasy.spi.interception.PreProcessInterceptor +- org.jboss.resteasy.spi.interception.MessageBodyReaderInterceptor + +## 10. 老牌框架/自定义过滤 +- Apache Shiro: org.apache.shiro.web.servlet.AbstractShiroFilter +- Apache Shiro: org.apache.shiro.web.filter.PathMatchingFilter +- Apache Shiro: org.apache.shiro.web.filter.authc.FormAuthenticationFilter +- Apache Shiro: org.apache.shiro.web.filter.authc.BasicHttpAuthenticationFilter +- Apache Shiro: org.apache.shiro.web.filter.AccessControlFilter +- Apache Shiro: org.apache.shiro.web.filter.mgt.DefaultFilterChainManager +- SiteMesh: com.opensymphony.sitemesh.webapp.SiteMeshFilter +- SiteMesh: com.opensymphony.sitemesh.webapp.SiteMeshWebAppContext +- SiteMesh 3: org.sitemesh.config.ConfigurableSiteMeshFilter +- SiteMesh 3: org.sitemesh.webapp.SiteMeshFilter +- WebWork: com.opensymphony.webwork.dispatcher.FilterDispatcher +- WebWork: com.opensymphony.webwork.dispatcher.ServletDispatcher +- WebWork: com.opensymphony.webwork.interceptor.Interceptor +- 自定义 Filter: implements javax.servlet.Filter + +## 11. 安全漏洞相关过滤器与防护组件 +- org.springframework.web.filter.CorsFilter +- org.springframework.security.web.csrf.CsrfFilter +- org.springframework.security.web.header.HeaderWriterFilter +- org.springframework.security.web.firewall.StrictHttpFirewall +- com.alibaba.druid.wall.WallFilter +- com.alibaba.druid.wall.WallConfig +- com.alibaba.druid.filter.FilterEventAdapter +- org.owasp.esapi.waf.ESAPIWebApplicationFirewall + +## 12. 漏洞类型对应过滤器补充 +- SQL 注入: com.alibaba.druid.wall.WallFilter / WallConfig +- 命令执行: 自定义 CommandInjectionFilter +- 路径遍历: 自定义 PathTraversalFilter +- SSRF: 自定义 UrlAllowlistFilter +- XXE: 自定义 XmlSecurityFilter +- 模板注入: 自定义 TemplateInputFilter +- 代码注入: 自定义 ScriptInjectionFilter +- 反序列化: java.io.ObjectInputFilter +- 不安全重定向: 自定义 RedirectValidationFilter +- XSS 输出: org.owasp.esapi.waf.ESAPIWebApplicationFirewall +- 日志注入: 自定义 LogSanitizerFilter + +## 13. 漏洞类型对应常见框架/库 +- SQL 注入: MyBatis-Plus IllegalSQLInnerInterceptor / Druid WallFilter +- SQL 注入: Hibernate StatementInspector +- SQL 注入: JPA @Query(nativeQuery) 统一拦截器 +- SSRF: Spring Cloud Gateway GlobalFilter +- SSRF: Spring WebClient ExchangeFilterFunction +- SSRF: Apache HttpClient HttpRequestInterceptor +- XXE: Xerces SAXParser 安全特性过滤器 +- XSS: AntiSamy Filter / ESAPI WAF +- 路径遍历: Spring ResourceHandlerInterceptor +- 不安全重定向: Spring Security RedirectStrategy 包装 +- 反序列化: Jackson ObjectMapper 默认类型限制过滤 diff --git a/docs/filter/JavaScript.md b/docs/filter/JavaScript.md new file mode 100644 index 00000000..5fdcb0c5 --- /dev/null +++ b/docs/filter/JavaScript.md @@ -0,0 +1,99 @@ +# JavaScript Filter / 中间件 清单 + +## 1. Express 中间件 +- app.use(middleware) +- router.use(middleware) +- app.use((req, res, next) => ...) +- error-handling middleware (err, req, res, next) +- express.json() / express.urlencoded() + +## 2. Koa 中间件 +- app.use(async (ctx, next) => ...) +- koa-compose +- koa-body / koa-router middleware +- koa-session + +## 3. NestJS 过滤与拦截 +- NestMiddleware +- middleware consumer.apply(...) +- Guard (CanActivate) +- Interceptor (CallHandler) +- Pipe (transform) +- ExceptionFilter (catch) + +## 4. Fastify 中间件 +- addHook("onRequest"|"preParsing"|"preValidation"|"preHandler"|"preSerialization"|"onSend"|"onResponse"|"onError") +- addHook("onRoute") +- register(plugin, opts) + +## 5. Hapi / AdonisJS / LoopBack +- Hapi: server.ext("onRequest"|"onPreAuth"|"onPostAuth"|"onPreHandler"|"onPostHandler"|"onPreResponse") +- Hapi: server.route options.pre +- AdonisJS: Server.middleware.register(...) +- AdonisJS: Route.middleware(...) +- LoopBack 4: middleware sequence / interceptors +- LoopBack 3: middleware.json / middleware.urlencoded + +## 6. Next.js / Nuxt / Remix +- Next.js Middleware (middleware.ts) +- Next.js API Route handler wrapper +- Nuxt: server middleware (server/middleware) +- Nuxt: route rules / nitro plugins +- Remix: loader/action wrapper + +## 7. GraphQL / RPC +- Apollo Server: plugins / context function +- graphql-middleware +- Yoga: plugins +- gRPC Node: ServerInterceptor / ServerCredentials wrapper + +## 8. 老牌框架/自定义 +- Restify: server.pre / server.use +- Restify: server.on("after") +- Restify: server.on("restifyError") +- Restify: plugins.acceptParser / plugins.queryParser / plugins.bodyParser +- Sails: policies / hooks +- Hapi v16: server.ext(...) +- 自定义中间件: function(req, res, next) { ... } + +## 9. 安全漏洞相关中间件与防护 +- Express: helmet +- Express: csurf +- Express: express-rate-limit +- Express: hpp +- Express: express-validator +- Express: xss-clean +- Koa: koa-helmet +- Koa: koa-csrf +- Koa: koa-ratelimit +- Koa: @koa/cors +- Fastify: @fastify/helmet +- Fastify: @fastify/csrf-protection +- Fastify: @fastify/rate-limit +- Fastify: @fastify/cors +- Hapi: @hapi/crumb +- Hapi: hapi-rate-limit + +## 10. 漏洞类型对应中间件补充 +- SQL 注入: express-validator / Joi 校验中间件 +- 命令执行: 自定义 CommandGuard middleware +- 路径遍历: 自定义 PathGuard middleware +- SSRF: 自定义 UrlAllowlist middleware +- 反序列化: 自定义 DeserializeGuard middleware +- 模板注入: 自定义 TemplateGuard middleware +- 代码注入: 自定义 CodeGuard middleware +- 原型污染: hpp / qs 校验 / 自定义 merge-guard +- 不安全重定向: 自定义 RedirectGuard middleware +- XSS 输出: helmet / xss-clean / 自定义 output sanitizer +- 日志注入: 自定义 LogSanitizer middleware + +## 11. 漏洞类型对应常见框架/库 +- SQL 注入: Sequelize replacements / bind +- SQL 注入: Knex 参数化查询 +- SQL 注入: Prisma 参数化 query +- SSRF: undici Dispatcher/Agent 白名单包装 +- SSRF: axios 请求拦截器 +- XXE: fast-xml-parser 安全选项 +- XSS: DOMPurify / sanitize-html +- 原型污染: qs allowPrototypes 禁用 +- 不安全重定向: Next.js middleware 重定向白名单 diff --git a/docs/filter/Python.md b/docs/filter/Python.md new file mode 100644 index 00000000..96ef8cd9 --- /dev/null +++ b/docs/filter/Python.md @@ -0,0 +1,109 @@ +# Python Filter / 中间件 清单 + +## 1. Django 中间件 +- MIDDLEWARE 列表项 +- django.utils.deprecation.MiddlewareMixin +- process_request(request) +- process_view(request, view_func, view_args, view_kwargs) +- process_template_response(request, response) +- process_response(request, response) +- process_exception(request, exception) + +## 2. Django REST Framework +- DEFAULT_AUTHENTICATION_CLASSES +- DEFAULT_PERMISSION_CLASSES +- DEFAULT_THROTTLE_CLASSES +- DEFAULT_PARSER_CLASSES +- BaseAuthentication.authenticate(...) +- BasePermission.has_permission(...) +- BasePermission.has_object_permission(...) +- BaseThrottle.allow_request(...) + +## 3. Flask 中间件 +- @app.before_request +- @app.after_request +- @app.teardown_request +- @app.errorhandler +- werkzeug.middleware.* (ProxyFix, DispatcherMiddleware) +- wsgi_app / app.wsgi_app 包装 + +## 4. FastAPI / Starlette 中间件 +- app.add_middleware(...) +- BaseHTTPMiddleware +- @app.middleware("http") +- Starlette Middleware 接口 +- ExceptionMiddleware / CORSMiddleware / SessionMiddleware + +## 5. Sanic / Aiohttp / Tornado +- Sanic: @app.middleware("request"|"response") +- Sanic: @app.on_request / @app.on_response +- aiohttp: @web.middleware +- aiohttp: app.middlewares.append(...) +- Tornado: RequestHandler.prepare() +- Tornado: RequestHandler.on_finish() +- Tornado: Application.add_transform(...) + +## 6. Falcon / Pyramid / Bottle +- Falcon: middleware.process_request / process_response / process_resource +- Pyramid: tween_factory +- Pyramid: @subscriber(NewRequest) +- Bottle: app.add_hook("before_request"/"after_request") +- Bottle: @hook("before_request") + +## 7. GraphQL / RPC 过滤 +- Graphene: middleware (resolve) +- Ariadne: middleware / Extension +- Strawberry: extensions +- gRPC: ServerInterceptor +- Thrift: TProcessor / TServerEventHandler + +## 8. Celery / RQ / 任务拦截 +- Celery: task_prerun / task_postrun 信号 +- Celery: Task.before_start / after_return +- RQ: job hooks (on_success/on_failure) + +## 9. 老牌框架/自定义 +- web.py: web.application.add_processor(...) +- web.py: web.webapi.handle() +- web.py: web.application.processors +- Pylons: pylons.config["pylons.response_options"] +- Pylons: pylons.wsgiapp.PylonsApp +- Pylons: pylons.middleware +- TurboGears: @before_validate / @before_call +- TurboGears: app_globals / request hooks +- 自定义 WSGI 中间件 (callable(environ, start_response)) + +## 10. 安全漏洞相关中间件与防护 +- Django: django.middleware.security.SecurityMiddleware +- Django: django.middleware.csrf.CsrfViewMiddleware +- Django: django.middleware.clickjacking.XFrameOptionsMiddleware +- Flask: flask_wtf.csrf.CSRFProtect +- Flask: flask_talisman.Talisman +- Flask: flask_seasurf.SeaSurf +- Flask: flask_limiter.Limiter +- Starlette: TrustedHostMiddleware +- Starlette: HTTPSRedirectMiddleware +- Starlette: CORSMiddleware + +## 11. 漏洞类型对应中间件补充 +- SQL 注入: 自定义 SQLValidationMiddleware / connection.execute_wrapper +- 命令执行: 自定义 CommandValidationMiddleware +- 路径遍历: 自定义 PathSanitizerMiddleware +- SSRF: 自定义 UrlAllowlistMiddleware +- 反序列化: 自定义 DeserializationGuardMiddleware +- 模板注入: 自定义 TemplateSanitizerMiddleware +- 代码注入: 自定义 CodeExecutionGuardMiddleware +- 不安全重定向: 自定义 RedirectGuardMiddleware +- XSS 输出: 自定义 OutputSanitizerMiddleware +- 日志注入: 自定义 LogSanitizerMiddleware + +## 12. 漏洞类型对应常见框架/库 +- SQL 注入: Django ORM QuerySet 参数化 +- SQL 注入: SQLAlchemy Query / text 绑定参数 +- SQL 注入: Peewee Model.select().where(...) +- SSRF: requests.Session 请求白名单包装 +- SSRF: aiohttp.ClientSession 自定义 TCPConnector 限制 +- XXE: defusedxml 安全解析封装 +- XSS: Bleach HTML 清洗中间件 +- 反序列化: itsdangerous URLSafeSerializer +- 不安全重定向: Django allowed_hosts / RedirectFallbackMiddleware diff --git a/docs/sink/Java.md b/docs/sink/Java.md new file mode 100644 index 00000000..6d33baa8 --- /dev/null +++ b/docs/sink/Java.md @@ -0,0 +1,117 @@ +# Java 漏洞 Sink 点清单 + +## 1. 命令执行 +- Runtime.getRuntime().exec(...) +- new ProcessBuilder(...).start() +- ScriptEngine.eval(...) 当脚本拼接输入 +- new GroovyShell().evaluate(...) +- new GroovyShell().parse(...).run() +- new javax.tools.ToolProvider.getSystemJavaCompiler().run(...) +- org.codehaus.groovy.control.CompilationUnit.compile(...) + +## 2. SQL 注入 +- Statement.execute(...) +- Statement.executeQuery(...) +- Statement.executeUpdate(...) +- PreparedStatement 直接拼接 SQL 字符串后执行 +- CallableStatement.execute(...) +- CallableStatement.executeQuery(...) +- QueryRunner.query(...) 直接拼接 SQL +- EntityManager.createNativeQuery(sql) +- Session.createSQLQuery(sql) / createNativeQuery(sql) + +## 3. 路径遍历/任意文件读写 +- new FileInputStream(userInputPath) +- new FileOutputStream(userInputPath) +- Files.readAllBytes(Paths.get(userInputPath)) +- Files.newBufferedReader(Paths.get(userInputPath)) +- Files.write(Paths.get(userInputPath), ...) +- RandomAccessFile(userInputPath, "rw") +- Files.newInputStream(Paths.get(userInputPath)) +- Files.newOutputStream(Paths.get(userInputPath)) +- Files.copy(Paths.get(userInputPath), ...) +- Files.move(Paths.get(userInputPath), ...) +- new FileReader(userInputPath) +- new FileWriter(userInputPath) +- new BufferedReader(new FileReader(userInputPath)) +- new BufferedWriter(new FileWriter(userInputPath)) +- File.delete() / Files.delete(Paths.get(userInputPath)) +- File.listFiles() / Files.list(Paths.get(userInputPath)) + +## 4. 反序列化 +- ObjectInputStream.readObject() +- XMLDecoder.readObject() +- XStream.fromXML(...) +- Yaml.load(...) / SnakeYAML load(...) +- Yaml.loadAll(...) +- ObjectMapper.readValue(byte[], Object.class) +- Kryo.readClassAndObject(...) +- HessianInput.readObject() +- JSON.parseObject(...) 允许 autoType 时 +- java.io.ObjectInputStream.readUnshared() + +## 5. SSRF / 外部请求 +- new URL(userInput).openStream() +- URLConnection.connect() +- HttpURLConnection.getInputStream() +- Apache HttpClient.execute(...) +- OkHttpClient.newCall(request).execute() +- HttpClient.newHttpClient().send(...) +- WebClient.get().uri(userInput).retrieve() +- RestTemplate.getForObject(userInput, ...) +- RestTemplate.exchange(userInput, ...) +- Jsoup.connect(userInput).get() + +## 6. XXE / XML 解析 +- DocumentBuilder.parse(...) +- SAXParser.parse(...) +- JAXB.unmarshal(...) +- XMLInputFactory.createXMLStreamReader(...) +- SAXReader.read(...) (dom4j) +- DocumentBuilderFactory.newDocumentBuilder().parse(...) +- TransformerFactory.newTransformer().transform(...) +- XPathFactory.newInstance().newXPath().evaluate(...) + +## 7. 模板注入 +- VelocityEngine.evaluate(...) +- Freemarker Template.process(...) +- Thymeleaf 解析含用户输入的模板 +- Mustache.compile(userInput).execute(...) +- StringSubstitutor.replace(userInput) +- MessageFormat.format(userInput, ...) + +## 8. 代码注入 / 脚本执行 +- GroovyShell.evaluate(...) +- JavaCompiler.run(...) 处理用户输入源码 +- javax.script.ScriptEngine.eval(...) +- NashornScriptEngine.eval(...) +- ScriptEngineManager.getEngineByName(...).eval(...) + +## 9. 反射与类加载 +- Class.forName(userInput) +- ClassLoader.loadClass(userInput) +- Method.invoke(...) 目标或参数来自用户输入 +- Constructor.newInstance(...) +- Class.forName(userInput, true, classLoader) +- URLClassLoader.newInstance(urls) + +## 10. 不安全重定向/跳转 +- HttpServletResponse.sendRedirect(userInput) +- Response.temporaryRedirect(URI.create(userInput)) +- Response.seeOther(URI.create(userInput)) +- ModelAndView.setViewName(userInput) + +## 11. XSS 输出 +- response.getWriter().write(userInput) +- response.getOutputStream().print(userInput) +- JSP 表达式/脚本输出未转义内容 +- response.sendError(..., userInput) +- PrintWriter.print(userInput) +- out.write(userInput) (JSP) + +## 12. 日志注入 +- logger.info(userInput) +- logger.warn(userInput) +- logger.error(userInput) +- Logger.log(Level.INFO, userInput) +- Logger.log(Level.WARNING, userInput) diff --git a/docs/sink/JavaScript.md b/docs/sink/JavaScript.md new file mode 100644 index 00000000..e0b4399a --- /dev/null +++ b/docs/sink/JavaScript.md @@ -0,0 +1,114 @@ +# JavaScript 漏洞 Sink 点清单 + +## 1. 命令执行(Node.js) +- child_process.exec(userInput) +- child_process.execSync(userInput) +- child_process.spawn(command, args) command/args 来自用户输入 +- child_process.spawnSync(...) +- child_process.execFile(file, args) +- child_process.execFileSync(file, args) +- child_process.fork(modulePath, args) + +## 2. SQL 注入 +- database.query(sqlString) +- connection.query(sqlString) +- sequelize.query(sqlString) +- knex.raw(sqlString) +- pg.Client.query(sqlString) +- mysql.createConnection().query(sqlString) +- mssql.Request().query(sqlString) +- prisma.$queryRawUnsafe(sqlString) +- sqlite3.Database().all(sqlString) + +## 3. 路径遍历/任意文件读写 +- fs.readFile(userInputPath, ...) +- fs.readFileSync(userInputPath, ...) +- fs.writeFile(userInputPath, ...) +- fs.writeFileSync(userInputPath, ...) +- fs.createReadStream(userInputPath) +- fs.createWriteStream(userInputPath) +- path.join(base, userInputPath) 后直接读写 +- fs.promises.readFile(userInputPath) +- fs.promises.writeFile(userInputPath, ...) +- fs.promises.readdir(userInputPath) +- fs.promises.unlink(userInputPath) +- fs.promises.rename(userInputPath, ...) +- fs.open(userInputPath, ...) +- fs.rm(userInputPath, ...) +- fs.readdir(userInputPath, ...) +- tar.extract({ cwd: userInputPath }) +- unzipper.Extract({ path: userInputPath }) + +## 4. 反序列化 +- JSON.parse(userInput) 在信任边界外 +- yaml.load(userInput) / yaml.safeLoad(userInput) 配置不当 +- serialize/deserialize 库的 deserialize(userInput) +- safe-json-parse(userInput) 处理后续逻辑 +- bson.deserialize(userInput) +- node-serialize.unserialize(userInput) +- msgpack.decode(userInput) + +## 5. SSRF / 外部请求 +- fetch(userInputUrl) +- axios.get(userInputUrl) +- request(userInputUrl) +- got(userInputUrl) +- http.request(userInputUrl) +- https.request(userInputUrl) +- axios.post(userInputUrl) +- node-fetch(userInputUrl) +- superagent.get(userInputUrl) +- undici.request(userInputUrl) + +## 6. 模板注入 +- ejs.render(userInput, ...) +- handlebars.compile(userInput)(...) +- pug.render(userInput, ...) +- lodash.template(userInput)(...) +- nunjucks.renderString(userInput, ...) +- mustache.render(userInput, ...) +- twig.render(userInput, ...) + +## 7. 代码注入 +- eval(userInput) +- new Function(userInput)() +- setTimeout(userInput, ...) +- setInterval(userInput, ...) +- vm.runInThisContext(userInput) +- vm.runInNewContext(userInput) +- vm.Script(userInput).runInThisContext() +- vm.runInContext(userInput, sandbox) +- require(userInput) / import(userInput) + +## 8. 原型污染关键点 +- lodash.merge(target, userInput) +- Object.assign(target, userInput) +- deep merge 自定义实现接收用户输入对象 +- _.defaultsDeep(target, userInput) +- qs.parse(userInput) +- jQuery.extend(true, target, userInput) +- hoek.merge(target, userInput) + +## 9. 不安全重定向/跳转 +- res.redirect(userInput) +- window.location = userInput +- location.href = userInput +- res.writeHead(302, { Location: userInput }) +- document.location = userInput + +## 10. XSS 输出 +- innerHTML = userInput +- dangerouslySetInnerHTML +- document.write(userInput) +- res.send(userInput) / res.write(userInput) +- res.end(userInput) +- element.outerHTML = userInput +- insertAdjacentHTML("beforeend", userInput) +- jQuery.html(userInput) + +## 11. 日志注入 +- console.log(userInput) +- logger.info(userInput) +- logger.error(userInput) +- console.error(userInput) +- console.warn(userInput) diff --git a/docs/sink/Python.md b/docs/sink/Python.md new file mode 100644 index 00000000..f64346ba --- /dev/null +++ b/docs/sink/Python.md @@ -0,0 +1,105 @@ +# Python 漏洞 Sink 点清单 + +## 1. 命令执行 +- os.system(user_input) +- subprocess.run(user_input, shell=True) +- subprocess.Popen(user_input, shell=True) +- subprocess.call(user_input, shell=True) +- eval(user_input) / exec(user_input) +- os.popen(user_input) +- subprocess.check_output(user_input, shell=True) +- subprocess.check_call(user_input, shell=True) +- pexpect.run(user_input) +- os.spawnl(os.P_WAIT, cmd, ...) +- os.spawnlp(os.P_WAIT, cmd, ...) + +## 2. SQL 注入 +- cursor.execute(sql_string) +- cursor.executemany(sql_string) +- connection.execute(sql_string) +- ORM 中的 raw/extra 直接拼接 SQL +- sqlalchemy.text(sql_string) 直接拼接 +- Model.objects.raw(sql_string) +- connection.cursor().execute(sql_string) +- pandas.read_sql(sql_string, connection) + +## 3. 路径遍历/任意文件读写 +- open(user_input_path, "r|w|a") +- pathlib.Path(user_input_path).read_text() +- pathlib.Path(user_input_path).write_text(...) +- shutil.copy(user_input_path, ...) +- os.remove(user_input_path) +- os.listdir(user_input_path) +- pathlib.Path(user_input_path).read_bytes() +- pathlib.Path(user_input_path).write_bytes(...) +- shutil.copy2(user_input_path, ...) +- shutil.move(user_input_path, ...) +- os.path.exists(user_input_path) +- glob.glob(user_input_path) +- tempfile.NamedTemporaryFile(dir=user_input_path) +- zipfile.ZipFile.extract(member, path=user_input_path) +- tarfile.TarFile.extractall(path=user_input_path) + +## 4. 反序列化 +- pickle.loads(user_input) +- pickle.load(file) +- yaml.load(user_input) / yaml.load(file) +- marshal.loads(user_input) +- jsonpickle.decode(user_input) +- dill.loads(user_input) +- shelve.open(user_input) +- joblib.load(user_input) +- numpy.load(user_input, allow_pickle=True) + +## 5. SSRF / 外部请求 +- requests.get(user_input_url) +- requests.post(user_input_url) +- urllib.request.urlopen(user_input_url) +- httpx.get(user_input_url) +- aiohttp.ClientSession().get(user_input_url) +- requests.put(user_input_url) +- requests.delete(user_input_url) +- httpx.post(user_input_url) +- urllib.request.Request(user_input_url) +- urllib3.PoolManager().request("GET", user_input_url) + +## 6. 模板注入 +- jinja2.Template(user_input).render(...) +- jinja2.Environment().from_string(user_input).render(...) +- mako.template.Template(user_input).render(...) +- tornado.template.Template(user_input).generate(...) +- django.template.Template(user_input).render(...) + +## 7. 代码注入 / 动态导入 +- importlib.import_module(user_input) +- __import__(user_input) +- eval/exec 处理用户输入表达式 +- pkgutil.resolve_name(user_input) +- pydoc.locate(user_input) +- runpy.run_module(user_input) +- runpy.run_path(user_input) + +## 8. 命令/代码生成 +- ast.literal_eval(user_input) 在未限制上下文时 +- compile(user_input, ..., "exec") +- code.compile_command(user_input) +- ast.parse(user_input) + +## 9. 不安全重定向/跳转 +- Flask/Django redirect(user_input) +- Response(location=user_input) +- werkzeug.utils.redirect(user_input) +- HttpResponseRedirect(user_input) + +## 10. XSS 输出 +- 返回 HTML 内容时直接拼接 user_input +- 模板 autoescape 关闭或使用 |safe +- markupsafe.Markup(user_input) +- django.utils.safestring.mark_safe(user_input) + +## 11. 日志注入 +- logging.info(user_input) +- logging.error(user_input) +- print(user_input) 写入审计日志 +- logging.warning(user_input) +- logger.exception(user_input) diff --git a/docs/source/Java.md b/docs/source/Java.md new file mode 100644 index 00000000..963c8e78 --- /dev/null +++ b/docs/source/Java.md @@ -0,0 +1,162 @@ +# Java 漏洞 Source 点清单 + +## 1. Web 请求参数(Servlet/JAX-RS/Spring) +- HttpServletRequest.getParameter(...) +- HttpServletRequest.getParameterValues(...) +- HttpServletRequest.getParameterMap() +- HttpServletRequest.getHeader(...) +- HttpServletRequest.getHeaders(...) +- HttpServletRequest.getQueryString() +- HttpServletRequest.getCookies() +- HttpServletRequest.getInputStream() +- HttpServletRequest.getReader() +- @RequestParam / @RequestHeader / @CookieValue +- @PathVariable / @MatrixVariable +- JAX-RS: @QueryParam / @PathParam / @HeaderParam / @CookieParam / @FormParam +- Spring WebFlux: ServerRequest.queryParam(...) +- Spring WebFlux: ServerRequest.bodyToMono(...) +- Spring MVC: @RequestBody / @RequestParam / @PathVariable / @RequestHeader +- Spring MVC: WebRequest.getParameter(...) +- Spring MVC: NativeWebRequest.getParameter(...) +- Spring Cloud Gateway: ServerWebExchange.getRequest().getQueryParams() +- Spring Cloud Gateway: ServerWebExchange.getRequest().getHeaders() +- Micronaut: @QueryValue / @PathVariable / @Body / @Header +- Micronaut: HttpRequest.getParameters() +- Quarkus: @QueryParam / @PathParam / @HeaderParam / @FormParam +- Quarkus: RoutingContext.request().getParam(...) +- Play Framework: request.getQueryString(...) +- Play Framework: request.body().asJson() +- Vert.x: RoutingContext.request().getParam(...) +- Vert.x: RoutingContext.getBodyAsString() +- Jersey: ContainerRequestContext.getHeaders() + +## 2. 请求体/表单/JSON 反序列化 +- @RequestBody 绑定的 DTO/Map/JsonNode +- ObjectMapper.readValue(InputStream, ...) +- ObjectMapper.readTree(...) +- Gson.fromJson(...) +- Jackson 的 JsonNode.get(...) 读取客户端字段 +- MultipartFile.getInputStream() +- MultipartFile.getBytes() +- Jackson: ObjectReader.readValue(...) +- Jackson: ObjectMapper.convertValue(...) +- Fastjson: JSON.parseObject(...) +- Fastjson: JSONObject.getString(...) +- org.json: new JSONObject(body).get(...) +- Protobuf: request.getXxx() (protobuf message) +- Spring WebFlux: ServerRequest.bodyToFlux(...) + +## 3. 文件上传与多部件表单 +- MultipartFile.getOriginalFilename() +- MultipartFile.getInputStream() +- Part.getInputStream() +- Part.getSubmittedFileName() +- DiskFileItem.getInputStream() +- Commons FileUpload: FileItem.getName() +- Commons FileUpload: FileItem.getInputStream() +- Servlet 3.0: request.getParts() +- Spring: MultipartHttpServletRequest.getFile(...) + +## 4. URL/路径相关输入 +- HttpServletRequest.getRequestURI() +- HttpServletRequest.getRequestURL() +- HttpServletRequest.getPathInfo() +- HttpServletRequest.getServletPath() +- ServerHttpRequest.getPath() + +## 5. 认证/会话/Token 输入 +- HttpServletRequest.getRemoteUser() +- HttpServletRequest.getUserPrincipal().getName() +- HttpSession.getAttribute(...) +- SecurityContextHolder.getContext().getAuthentication().getName() +- OAuth2AuthenticationToken.getPrincipal().getAttributes() +- Jwt.getClaimAsString(...) + +## 6. RPC/微服务输入 +- gRPC: request.getXxx() +- Dubbo: 参数对象 getter +- Spring Cloud OpenFeign: 请求参数绑定对象 +- Spring Cloud: @RequestBody / @RequestParam 绑定 DTO +- Apache Thrift: request.getXxx() +- Hessian: 反序列化后的参数对象 +- RSocket: Payload.getDataUtf8() + +## 7. 模板/表达式输入 +- Spring EL: ${param.xxx} / ${header.xxx} +- Thymeleaf: ${param.xxx} / ${#request.getParameter(...)} +- Velocity: $request.getParameter(...) +- Freemarker: RequestParameters / request +- Spring EL: #request, #session, #params +- JSP EL: ${param.xxx} / ${header.xxx} / ${cookie.xxx} +- Pebble: {{ request.parameter("...") }} + +## 8. 消息队列/事件流输入 +- Kafka ConsumerRecord.value() +- RabbitMQ Message.getBody() +- RocketMQ Message.getBody() +- Pulsar Message.getData() +- Spring Kafka: @Payload String +- Spring AMQP: Message.getMessageProperties().getHeaders() +- Spring Cloud Stream: @StreamListener 参数 + +## 9. 数据库中不可信字段 +- ResultSet.getString(...) +- ResultSet.getObject(...) +- ORM 实体字段来自外部同步数据源 +- JpaRepository 查询结果字段 +- MyBatis ResultMap 字段 + +## 10. 外部系统/配置输入 +- System.getenv(...) +- System.getProperty(...) +- Properties.getProperty(...) +- JNDI: InitialContext.lookup(...) +- Spring Environment.getProperty(...) +- Spring Cloud Config: Environment.getProperty(...) +- Vault: Logical.read(...) +- Consul: KVClient.getValue(...) + +## 11. 框架细分(Spring Security / Spring Cloud Gateway) +- Spring Security: SecurityContextHolder.getContext().getAuthentication() +- Spring Security: Authentication.getPrincipal() / getCredentials() / getDetails() +- Spring Security: HttpSecurity.oauth2Login().userInfoEndpoint() 返回的用户属性 +- Spring Security: OAuth2AuthenticationToken.getPrincipal().getAttributes() +- Spring Security: JwtAuthenticationToken.getTokenAttributes() +- Spring Security: Jwt.getClaim(...) / getClaimAsString(...) +- Spring Security: ReactiveSecurityContextHolder.getContext() +- Spring Cloud Gateway: ServerWebExchange.getRequest().getQueryParams() +- Spring Cloud Gateway: ServerWebExchange.getRequest().getHeaders() +- Spring Cloud Gateway: ServerWebExchange.getRequest().getCookies() +- Spring Cloud Gateway: ServerWebExchange.getRequest().getBody() +- Spring Cloud Gateway: ServerHttpRequest.getURI().getRawQuery() + +## 12. 框架细分(Spring WebFlux / Struts2 / Micronaut / Quarkus) +- Spring WebFlux: ServerRequest.queryParam(...) +- Spring WebFlux: ServerRequest.pathVariable(...) +- Spring WebFlux: ServerRequest.headers().firstHeader(...) +- Spring WebFlux: ServerRequest.cookies() +- Spring WebFlux: ServerRequest.bodyToMono(...) +- Spring WebFlux: ServerWebExchange.getRequest().getQueryParams() +- Spring WebFlux: ServerWebExchange.getRequest().getHeaders() +- Spring WebFlux: ServerWebExchange.getRequest().getBody() +- Struts2: ActionContext.getContext().getParameters() +- Struts2: ServletActionContext.getRequest().getParameter(...) +- Struts2: ServletActionContext.getRequest().getHeader(...) +- Struts2: ActionContext.getContext().getName() +- Struts2: ValueStack.findValue(...) +- Micronaut: HttpRequest.getParameters() +- Micronaut: HttpRequest.getHeaders().get(...) +- Micronaut: HttpRequest.getBody() +- Micronaut: @Body / @QueryValue / @PathVariable +- Quarkus: RoutingContext.request().getParam(...) +- Quarkus: RoutingContext.request().getHeader(...) +- Quarkus: RoutingContext.getBodyAsString() +- Quarkus: @QueryParam / @PathParam / @HeaderParam + +## 13. 其他老牌 Java Web 框架 +- WebWork: ActionContext.getContext().getParameters() +- Tapestry: Request.getParameter(...) +- Wicket: RequestCycle.get().getRequest().getRequestParameters() +- JSF: FacesContext.getCurrentInstance().getExternalContext().getRequestParameterMap() +- Apache CXF: Message.getContextualProperty(...) +- Spring MVC 早期: HttpServletRequest.getParameter(...) diff --git a/docs/source/JavaScript.md b/docs/source/JavaScript.md new file mode 100644 index 00000000..62b89329 --- /dev/null +++ b/docs/source/JavaScript.md @@ -0,0 +1,134 @@ +# JavaScript 漏洞 Source 点清单 + +## 1. Web 请求参数(Express/Koa/Nest/Next) +- Express: req.query.xxx +- Express: req.params.xxx +- Express: req.body.xxx +- Express: req.headers["..."] +- Express: req.get("header") +- Express: req.cookies.xxx +- Koa: ctx.query.xxx +- Koa: ctx.params.xxx +- Koa: ctx.request.body +- Koa: ctx.headers["..."] +- NestJS: @Query() / @Param() / @Body() +- Next.js API: req.query / req.body +- Next.js Middleware: request.nextUrl.searchParams.get(...) +- Fastify: request.query / request.body / request.params +- Hapi: request.query / request.payload / request.params +- Sails: req.param(...) +- AdonisJS: request.input(...) +- LoopBack: req.params / ctx.args +- Remix: request.url / await request.json() +- Nuxt Server: event.node.req.url / await readBody(event) + +## 2. 文件上传/多部件表单 +- multer: req.file / req.files +- busboy: file stream / field value +- formidable: files / fields +- koa-body: ctx.request.files +- fastify-multipart: await req.file() +- hapi: request.payload.file + +## 3. URL/路径相关输入 +- req.originalUrl +- req.baseUrl +- req.path +- req.url +- new URL(req.url, base).searchParams.get(...) +- request.ip / req.ips +- req.hostname / req.protocol + +## 4. 认证/会话/Token 输入 +- req.session.xxx +- req.user / req.auth +- passport: req.user +- JWT: req.headers.authorization +- JWT: token payload claims +- express-session: req.session +- koa-session: ctx.session +- cookies: req.signedCookies / ctx.cookies.get(...) + +## 5. GraphQL/RPC 输入 +- GraphQL resolver args +- Apollo: context.req.headers +- gRPC: call.request. +- JSON-RPC: params +- tRPC: ctx.input + +## 6. 模板/表达式输入 +- ejs: req.query / req.body 传入模板变量 +- handlebars: 传入的模板上下文对象 +- pug: locals 对象字段 +- nunjucks: renderString 的 data 参数 +- liquidjs: engine.parseAndRender(source, data) +- eta: render(source, data) + +## 7. 浏览器端输入 +- location.href / location.search +- URLSearchParams.get(...) +- document.cookie +- window.name +- localStorage.getItem(...) +- sessionStorage.getItem(...) +- postMessage event.data +- form input.value +- history.state +- navigator.userAgent + +## 8. 消息队列/事件流输入 +- KafkaJS: message.value +- amqplib: msg.content +- redis: message +- NATS: msg.data +- BullMQ: job.data +- KafkaJS: eachMessage.message.value + +## 9. 数据库中不可信字段 +- ORM 查询结果字段 +- 外部同步数据源字段 +- mongoose 文档字段 +- sequelize 模型字段 + +## 10. 外部系统/配置输入 +- process.env +- dotenv: process.env["..."] +- config.get("...") +- JSON.parse(fs.readFileSync(...)) +- rc 配置文件读取 +- nconf.get("...") + +## 11. CLI/脚本输入 +- process.argv +- yargs argv +- commander.opts() +- minimist(argv) +- zx: argv + +## 12. 框架细分(NestJS / Fastify / Koa) +- NestJS: @Req() req / @Headers() headers / @Body() dto / @Query() query / @Param() param +- NestJS: ExecutionContext.switchToHttp().getRequest() +- NestJS: GraphQL @Args() / @Context() / @Req() +- Fastify: request.body / request.query / request.params / request.headers +- Fastify: request.cookies / request.ip / request.hostname +- Fastify: reply.request.body +- Koa: ctx.request.body / ctx.request.query / ctx.request.headers +- Koa: ctx.params / ctx.cookies.get(...) +- Koa: ctx.request.rawBody + +## 13. 框架细分(Hapi / AdonisJS / LoopBack) +- Hapi: request.payload +- Hapi: request.query / request.params +- Hapi: request.headers / request.state +- AdonisJS: request.input(...) +- AdonisJS: request.all() / request.only(...) +- AdonisJS: request.qs() / request.params() +- LoopBack: ctx.args / req.params +- LoopBack: req.body / req.query + +## 14. 其他老牌 Node/Web 框架 +- Sails: req.param(...) / req.allParams() +- Express 早期: req.param(...) +- Restify: req.params / req.query / req.body +- Restify: req.header(...) +- Hapi v16: request.params / request.payload diff --git a/docs/source/Python.md b/docs/source/Python.md new file mode 100644 index 00000000..3c4c8771 --- /dev/null +++ b/docs/source/Python.md @@ -0,0 +1,127 @@ +# Python 漏洞 Source 点清单 + +## 1. Web 请求参数(Django/Flask/FastAPI) +- Django: request.GET.get(...) +- Django: request.POST.get(...) +- Django: request.body +- Django: request.META.get("HTTP_*") +- Django: request.headers.get(...) +- Django: request.COOKIES.get(...) +- Flask: request.args.get(...) +- Flask: request.form.get(...) +- Flask: request.values.get(...) +- Flask: request.get_json(...) +- Flask: request.data / request.stream +- FastAPI: Query(...) / Path(...) / Header(...) +- FastAPI: Body(...) / Cookie(...) +- Starlette: request.query_params.get(...) +- Starlette: await request.json() +- Starlette: await request.body() +- Django REST Framework: request.data +- Django REST Framework: request.query_params.get(...) +- Django REST Framework: request.headers.get(...) +- Flask-RESTful: request.args / request.json +- Sanic: request.args / request.json / request.body +- Tornado: self.get_argument(...) +- Falcon: req.get_param(...) +- Pyramid: request.params.get(...) +- Bottle: request.query / request.forms / request.json +- aiohttp.web: request.query / await request.json() + +## 2. 文件上传/多部件表单 +- Django: request.FILES.get(...) +- Flask: request.files.get(...) +- FastAPI: UploadFile.filename +- FastAPI: await UploadFile.read() +- Django: request.FILES["..."] +- Flask: FileStorage.stream.read() +- Sanic: request.files.get(...) +- Tornado: self.request.files["..."] +- Starlette: await UploadFile.read() + +## 3. URL/路径相关输入 +- request.url / request.base_url +- request.path / request.full_path +- request.url_root +- request.host +- request.scheme +- request.headers.get("Host") + +## 4. 认证/会话/Token 输入 +- Django: request.user.get_username() +- Django: request.session.get(...) +- Flask: session.get(...) +- FastAPI: OAuth2PasswordBearer token +- JWT: jwt.get_unverified_header(...) +- JWT: jwt.decode(...) 读取 claim +- Django REST Framework: request.auth +- Flask-JWT-Extended: get_jwt_identity() +- FastAPI: OAuth2PasswordRequestForm.username / password + +## 5. RPC/微服务输入 +- gRPC: request. +- Celery: task args/kwargs +- GraphQL: resolver args / info.context +- Celery: task.request.kwargs +- nameko: ctx.data +- thriftpy2: request. +- xmlrpc.server: params + +## 6. 模板/表达式输入 +- Django template: {{ request.GET.xxx }} +- Jinja2: {{ request.args.xxx }} +- Jinja2: {{ request.headers.xxx }} +- Mako: ${request.params.get(...)} +- Django: {{ request.POST.xxx }} +- Jinja2: {{ request.form.xxx }} +- Tornado: self.get_argument(...) + +## 7. 消息队列/事件流输入 +- Kafka: msg.value() +- RabbitMQ: body +- Redis: pubsub message["data"] +- Pulsar: msg.data() +- Dramatiq: message.args / message.kwargs +- RQ: job.args / job.kwargs +- kombu: message.body + +## 8. 数据库中不可信字段 +- ORM 模型字段来自用户创建记录 +- cursor.fetchone()/fetchall() 结果字段 + +## 9. 外部系统/配置输入 +- os.environ.get(...) +- os.getenv(...) +- configparser.ConfigParser().get(...) +- yaml.safe_load(...) 加载配置 +- json.load(...) 读取外部配置 +- dynaconf: settings.get(...) +- hvac: client.read(...) + +## 10. CLI/脚本输入 +- sys.argv +- argparse.ArgumentParser().parse_args() +- click.get_current_context().params +- typer.Option / typer.Argument + +## 11. 框架细分(Django REST Framework 等) +- DRF: APIView.request.data +- DRF: APIView.request.query_params +- DRF: APIView.request.headers.get(...) +- DRF: APIView.request.auth / request.user +- DRF: Serializer.validated_data +- DRF: Serializer.initial_data +- DRF: request._full_data +- django-filter: request.query_params.get(...) +- django-rest-framework-simplejwt: token.payload / token["claim"] + +## 12. 框架细分(Falcon / Sanic / 旧框架) +- Falcon: req.media +- Falcon: req.params / req.get_param(...) +- Falcon: req.get_header(...) +- Sanic: request.json / request.args / request.form +- Sanic: request.headers / request.cookies +- Sanic: request.files / request.raw_body +- web.py: web.input() +- Pylons: request.params / request.POST +- TurboGears: request.params / request.json_body