diff --git a/llm_web_kit/api/dependencies.py b/llm_web_kit/api/dependencies.py
index d949e305..3f4fb68d 100644
--- a/llm_web_kit/api/dependencies.py
+++ b/llm_web_kit/api/dependencies.py
@@ -4,13 +4,27 @@
 """
 import logging
+import os
+from contextvars import ContextVar
 from functools import lru_cache
+from logging.handlers import TimedRotatingFileHandler
 from typing import Optional
 
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 logger = logging.getLogger(__name__)
 
+# ContextVar that stores the request_id, with a default value of None
+request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
+
+
+class RequestIdFilter(logging.Filter):
+    """Logging filter that injects the request_id from the ContextVar into log records."""
+
+    def filter(self, record):
+        record.request_id = request_id_var.get()
+        return True
+
 
 class Settings(BaseSettings):
     """Application configuration settings."""
@@ -27,6 +41,8 @@ class Settings(BaseSettings):
 
     # Logging configuration
     log_level: str = "INFO"
+    log_dir: str = "logs"
+    log_filename: str = "api.log"
 
     # Model configuration
     model_path: Optional[str] = None
@@ -38,8 +54,8 @@
 
     # Database configuration
     database_url: Optional[str] = None  # read from the DATABASE_URL environment variable
-    db_pool_size: int = 5
-    db_max_overflow: int = 10
+    db_pool_size: int = 200
+    db_max_overflow: int = 100
 
     # pydantic v2 configuration style
     model_config = SettingsConfigDict(
@@ -57,14 +73,39 @@ def get_settings() -> Settings:
 def get_logger(name: str = __name__) -> logging.Logger:
     """Return a configured logger."""
     logger = logging.getLogger(name)
+    logger.setLevel(get_settings().log_level)
+    logger.addFilter(RequestIdFilter())  # attach the request-id filter
+
     if not logger.handlers:
-        handler = logging.StreamHandler()
-        formatter = logging.Formatter(
-            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+        # Console handler
+        stream_handler = logging.StreamHandler()
+        stream_formatter = logging.Formatter(
+            '%(asctime)s - %(request_id)s - %(name)s - %(levelname)s - %(message)s'
+        )
+        stream_handler.setFormatter(stream_formatter)
+        logger.addHandler(stream_handler)
+
+        # File handler (rotated daily)
+        settings = get_settings()
+        log_dir = settings.log_dir
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir)
+
+        log_file_path = os.path.join(log_dir, settings.log_filename)
+
+        file_handler = TimedRotatingFileHandler(
+            log_file_path,
+            when="midnight",  # rotate at midnight every day
+            interval=1,
+            backupCount=30,  # keep 30 days of logs
+            encoding='utf-8'
         )
-        handler.setFormatter(formatter)
-        logger.addHandler(handler)
-        logger.setLevel(get_settings().log_level)
+        file_formatter = logging.Formatter(
+            '%(asctime)s - %(request_id)s - %(name)s - %(levelname)s - %(message)s'
+        )
+        file_handler.setFormatter(file_formatter)
+        logger.addHandler(file_handler)
+
     return logger
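The pattern above — a `ContextVar` written by middleware and read by a `logging.Filter` — is what lets every log line carry a request ID without threading it through call signatures. A minimal, self-contained sketch of the mechanism (the logger name and `handle` coroutine are illustrative, not part of the patch):

```python
import asyncio
import logging
from contextvars import ContextVar
from typing import Optional

request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)


class RequestIdFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        # Stamp every record with the request_id of the current context.
        record.request_id = request_id_var.get()
        return True


logger = logging.getLogger("demo")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(request_id)s - %(levelname)s - %(message)s"))
logger.addHandler(handler)
logger.addFilter(RequestIdFilter())
logger.setLevel(logging.INFO)


async def handle(request_id: str) -> None:
    # Each asyncio task gets its own copy of the context, so concurrent
    # "requests" cannot observe each other's request_id.
    token = request_id_var.set(request_id)
    try:
        logger.info("processing")
        await asyncio.sleep(0)
        logger.info("done")
    finally:
        request_id_var.reset(token)


async def main() -> None:
    await asyncio.gather(handle("req-1"), handle("req-2"))


asyncio.run(main())
```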
request.headers.get("X-Request-ID") + if not request_id: + request_id = RequestLogService._generate_request_id() + + # 使用 ContextVar 设置 request_id + token = request_id_var.set(request_id) + + # 处理请求 + response = await call_next(request) + + # 在响应头中添加 request_id + response.headers["X-Request-ID"] = request_id + + # 重置 ContextVar + request_id_var.reset(token) + + return response + + # 注册路由 app.include_router(htmls.router, prefix="/api/v1", tags=["HTML 处理"]) diff --git a/llm_web_kit/api/routers/htmls.py b/llm_web_kit/api/routers/htmls.py index 2d3f60a2..b90c9a57 100644 --- a/llm_web_kit/api/routers/htmls.py +++ b/llm_web_kit/api/routers/htmls.py @@ -3,13 +3,17 @@ 提供 HTML 解析、内容提取等功能的 API 端点。 """ +import base64 +import html +import time from typing import Optional -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile +from fastapi import (APIRouter, BackgroundTasks, Body, Depends, File, + HTTPException, UploadFile) from sqlalchemy.ext.asyncio import AsyncSession from ..database import get_db_session -from ..dependencies import get_logger, get_settings +from ..dependencies import get_logger, get_settings, request_id_var from ..models.request import HTMLParseRequest from ..models.response import HTMLParseResponse from ..services.html_service import HTMLService @@ -23,16 +27,20 @@ @router.post('/html/parse', response_model=HTMLParseResponse) async def parse_html( - request: HTMLParseRequest, - html_service: HTMLService = Depends(HTMLService), - db_session: Optional[AsyncSession] = Depends(get_db_session) + background_tasks: BackgroundTasks, + request: HTMLParseRequest = Body(...), + html_service: HTMLService = Depends(HTMLService), + db_session: Optional[AsyncSession] = Depends(get_db_session) ): """解析 HTML 内容. 接收 HTML 字符串并返回解析后的结构化内容。 """ - # 生成请求ID - request_id = RequestLogService.generate_request_id() + # 从 context var 获取 request_id + request_id = request_id_var.get() + decoded_bytes = base64.b64decode(request.html_content) + decoded_str = decoded_bytes.decode('utf-8') + unescaped_html = html.unescape(decoded_str) # 确定输入类型 if request.html_content: @@ -43,35 +51,32 @@ async def parse_html( input_type = 'unknown' # 创建请求日志 - await RequestLogService.create_log( + start_time = time.time() + await RequestLogService.initial_log( session=db_session, request_id=request_id, input_type=input_type, - input_html=request.html_content, + input_html=unescaped_html, url=request.url, ) - - # 立即提交,使 processing 状态在数据库中可见 - if db_session: - try: - await db_session.commit() - except Exception as commit_error: - logger.error(f'提交初始日志时出错: {commit_error}') + end_time = time.time() + logger.info(f'创建日志耗时: {end_time - start_time}秒') try: - logger.info(f'开始解析 HTML [request_id={request_id}],内容长度: {len(request.html_content) if request.html_content else 0}') + logger.info(f'开始解析 HTML,内容长度: {len(unescaped_html) if unescaped_html else 0}') result = await html_service.parse_html( - html_content=request.html_content, + html_content=unescaped_html, url=request.url, + request_id=request_id, options=request.options ) - # 更新日志为成功 - await RequestLogService.update_log_success( - session=db_session, - request_id=request_id, - output_markdown=result.get('markdown'), + # 将成功日志更新操作添加到后台任务 + background_tasks.add_task( + RequestLogService.log_success_bg, + request_id, + result.get('markdown') ) return HTMLParseResponse( @@ -81,37 +86,32 @@ async def parse_html( request_id=request_id ) except Exception as e: - logger.error(f'HTML 解析失败 [request_id={request_id}]: {str(e)}') - - # 更新日志为失败 - await 
diff --git a/llm_web_kit/api/routers/htmls.py b/llm_web_kit/api/routers/htmls.py
index 2d3f60a2..b90c9a57 100644
--- a/llm_web_kit/api/routers/htmls.py
+++ b/llm_web_kit/api/routers/htmls.py
@@ -3,13 +3,17 @@
 API endpoints for HTML parsing, content extraction, and related features.
 """
+import base64
+import html
+import time
 from typing import Optional
 
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
+from fastapi import (APIRouter, BackgroundTasks, Body, Depends, File,
+                     HTTPException, UploadFile)
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from ..database import get_db_session
-from ..dependencies import get_logger, get_settings
+from ..dependencies import get_logger, get_settings, request_id_var
 from ..models.request import HTMLParseRequest
 from ..models.response import HTMLParseResponse
 from ..services.html_service import HTMLService
@@ -23,16 +27,20 @@
 
 @router.post('/html/parse', response_model=HTMLParseResponse)
 async def parse_html(
-    request: HTMLParseRequest,
-    html_service: HTMLService = Depends(HTMLService),
-    db_session: Optional[AsyncSession] = Depends(get_db_session)
+    background_tasks: BackgroundTasks,
+    request: HTMLParseRequest = Body(...),
+    html_service: HTMLService = Depends(HTMLService),
+    db_session: Optional[AsyncSession] = Depends(get_db_session)
 ):
     """Parse HTML content.
 
     Accepts an HTML string and returns the parsed, structured content.
     """
-    # Generate a request ID
-    request_id = RequestLogService.generate_request_id()
+    # Fetch the request_id from the context var
+    request_id = request_id_var.get()
+    # html_content is optional, so only decode it when it is present
+    unescaped_html = None
+    if request.html_content:
+        decoded_bytes = base64.b64decode(request.html_content)
+        decoded_str = decoded_bytes.decode('utf-8')
+        unescaped_html = html.unescape(decoded_str)
 
     # Determine the input type
     if request.html_content:
@@ -43,35 +51,32 @@ async def parse_html(
         input_type = 'unknown'
 
     # Create the request log
-    await RequestLogService.create_log(
+    start_time = time.time()
+    await RequestLogService.initial_log(
         session=db_session,
         request_id=request_id,
         input_type=input_type,
-        input_html=request.html_content,
+        input_html=unescaped_html,
         url=request.url,
     )
-
-    # Commit immediately so the "processing" status is visible in the database
-    if db_session:
-        try:
-            await db_session.commit()
-        except Exception as commit_error:
-            logger.error(f'Error committing the initial log: {commit_error}')
+    end_time = time.time()
+    logger.info(f'Creating the log took {end_time - start_time} seconds')
 
     try:
-        logger.info(f'Starting HTML parse [request_id={request_id}], content length: {len(request.html_content) if request.html_content else 0}')
+        logger.info(f'Starting HTML parse, content length: {len(unescaped_html) if unescaped_html else 0}')
 
         result = await html_service.parse_html(
-            html_content=request.html_content,
+            html_content=unescaped_html,
             url=request.url,
+            request_id=request_id,
             options=request.options
         )
 
-        # Update the log to success
-        await RequestLogService.update_log_success(
-            session=db_session,
-            request_id=request_id,
-            output_markdown=result.get('markdown'),
+        # Defer the success-log update to a background task
+        background_tasks.add_task(
+            RequestLogService.log_success_bg,
+            request_id,
+            result.get('markdown')
        )
 
         return HTMLParseResponse(
@@ -81,37 +86,32 @@ async def parse_html(
             request_id=request_id
         )
     except Exception as e:
-        logger.error(f'HTML parsing failed [request_id={request_id}]: {str(e)}')
-
-        # Update the log to failure
-        await RequestLogService.update_log_failure(
-            session=db_session,
-            request_id=request_id,
-            error_message=str(e),
+        error_message = str(e)
+        logger.error(f'HTML parsing failed: {error_message}')
+
+        # Defer the failure-log update to a background task
+        background_tasks.add_task(
+            RequestLogService.log_failure_bg,
+            request_id,
+            error_message
        )
 
-        # Commit manually to make sure the failure log is persisted
-        if db_session:
-            try:
-                await db_session.commit()
-            except Exception as commit_error:
-                logger.error(f'Error committing the failure log: {commit_error}')
-
-        raise HTTPException(status_code=500, detail=f'HTML parsing failed: {str(e)}')
+        raise HTTPException(status_code=500, detail=f'HTML parsing failed: {error_message}')
 
 
 @router.post('/html/upload')
 async def upload_html_file(
-    file: UploadFile = File(...),
-    html_service: HTMLService = Depends(HTMLService),
-    db_session: Optional[AsyncSession] = Depends(get_db_session)
+    background_tasks: BackgroundTasks,
+    file: UploadFile = File(...),
+    html_service: HTMLService = Depends(HTMLService),
+    db_session: Optional[AsyncSession] = Depends(get_db_session)
 ):
     """Upload an HTML file for parsing.
 
     Accepts an uploaded HTML file, parses it automatically, and returns the result.
     """
-    # Generate a request ID
-    request_id = RequestLogService.generate_request_id()
+    # Fetch the request_id from the context var
+    request_id = request_id_var.get()
 
     try:
         if not file.filename.endswith(('.html', '.htm')):
@@ -120,31 +120,26 @@ async def upload_html_file(
         content = await file.read()
         html_content = content.decode('utf-8')
 
-        logger.info(f'Uploaded HTML file [request_id={request_id}]: {file.filename}, size: {len(content)} bytes')
-
+        logger.info(f'Uploaded HTML file: {file.filename}, size: {len(content)} bytes')
+        start_time = time.time()
         # Create the request log
-        await RequestLogService.create_log(
+        await RequestLogService.initial_log(
             session=db_session,
             request_id=request_id,
             input_type='file',
             input_html=html_content,
             url=None,
         )
+        end_time = time.time()
+        logger.info(f'Creating the log took {end_time - start_time} seconds')
 
-        # Commit immediately so the "processing" status is visible in the database
-        if db_session:
-            try:
-                await db_session.commit()
-            except Exception as commit_error:
-                logger.error(f'Error committing the initial log: {commit_error}')
-
-        result = await html_service.parse_html(html_content=html_content, url="www.baidu.com")
+        result = await html_service.parse_html(html_content=html_content, url="www.baidu.com", request_id=request_id)
 
-        # Update the log to success
-        await RequestLogService.update_log_success(
-            session=db_session,
-            request_id=request_id,
-            output_markdown=result.get('markdown'),
+        # Defer the success-log update to a background task
+        background_tasks.add_task(
+            RequestLogService.log_success_bg,
+            request_id,
+            result.get('markdown')
        )
 
         return HTMLParseResponse(
@@ -154,23 +149,17 @@ async def upload_html_file(
             request_id=request_id
         )
     except Exception as e:
-        logger.error(f'HTML file parsing failed [request_id={request_id}]: {str(e)}')
-
-        # Update the log to failure
-        await RequestLogService.update_log_failure(
-            session=db_session,
-            request_id=request_id,
-            error_message=str(e),
+        error_message = str(e)
+        logger.error(f'HTML file parsing failed: {error_message}')
+
+        # Defer the failure-log update to a background task
+        background_tasks.add_task(
+            RequestLogService.log_failure_bg,
+            request_id,
+            error_message
        )
 
-        # Commit manually to make sure the failure log is persisted
-        if db_session:
-            try:
-                await db_session.commit()
-            except Exception as commit_error:
-                logger.error(f'Error committing the failure log: {commit_error}')
-
-        raise HTTPException(status_code=500, detail=f'HTML file parsing failed: {str(e)}')
+        raise HTTPException(status_code=500, detail=f'HTML file parsing failed: {error_message}')
 
 
 @router.get('/html/status')
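The router now defers all post-response database writes to FastAPI's `BackgroundTasks`, which run only after the response has been sent, so the endpoint no longer blocks on the success/failure commit. A minimal sketch of that scheduling behavior (the app and task body are stand-ins, not the project's code):

```python
from fastapi import BackgroundTasks, FastAPI

app = FastAPI()


async def log_success_bg(request_id: str, markdown: str) -> None:
    # Stand-in for RequestLogService.log_success_bg, which opens its own
    # database session rather than reusing the request-scoped one.
    print(f"[bg] {request_id}: {len(markdown)} chars persisted")


@app.post("/parse")
async def parse(background_tasks: BackgroundTasks) -> dict:
    request_id = "req-123"  # normally read from request_id_var
    markdown = "# extracted content ..."
    # Scheduled now, executed by Starlette after the response is returned.
    background_tasks.add_task(log_success_bg, request_id, markdown)
    return {"request_id": request_id, "status": "ok"}
```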
diff --git a/llm_web_kit/api/services/html_service.py b/llm_web_kit/api/services/html_service.py
index 0bb247a0..03a48789 100644
--- a/llm_web_kit/api/services/html_service.py
+++ b/llm_web_kit/api/services/html_service.py
@@ -3,12 +3,18 @@
 
 Bridges the HTML parsing and content-extraction features of the original
 project behind a unified API interface.
 """
+import time
 from typing import Any, Dict, Optional
 
 import httpx
 
 from llm_web_kit.api.dependencies import (get_inference_service, get_logger,
                                           get_settings)
+from llm_web_kit.input.pre_data_json import PreDataJson, PreDataJsonKey
+from llm_web_kit.main_html_parser.parser.tag_mapping import \
+    MapItemToHtmlTagsParser
+from llm_web_kit.main_html_parser.simplify_html.simplify_html import \
+    simplify_html
 from llm_web_kit.simple import extract_content_from_main_html
 
 logger = get_logger(__name__)
@@ -32,10 +38,11 @@ def _init_components(self):
         return None
 
     async def parse_html(
-        self,
-        html_content: Optional[str] = None,
-        url: Optional[str] = None,
-        options: Optional[Dict[str, Any]] = None
+            self,
+            html_content: Optional[str] = None,
+            url: Optional[str] = None,
+            request_id: Optional[str] = None,
+            options: Optional[Dict[str, Any]] = None
     ) -> Dict[str, Any]:
         """Parse HTML content."""
         try:
@@ -60,37 +67,37 @@ async def parse_html(
             if not html_content:
                 raise ValueError('Either HTML content or a valid URL must be provided')
 
-            # Import lazily, so an import-time failure does not make the service class unusable
-            try:
-                from llm_web_kit.input.pre_data_json import (PreDataJson,
-                                                             PreDataJsonKey)
-                from llm_web_kit.main_html_parser.parser.tag_mapping import \
-                    MapItemToHtmlTagsParser
-                from llm_web_kit.main_html_parser.simplify_html.simplify_html import \
-                    simplify_html
-            except Exception as import_err:
-                logger.error(f'Failed to import dependencies: {import_err}')
-                raise
-
+            # logger.info(f"html_content: {html_content}")
             # Simplify the page
             try:
+                start_time = time.time()
                 simplified_html, typical_raw_tag_html = simplify_html(html_content)
+                end_time = time.time()
+                logger.info(f'Simplification finished, took {end_time - start_time} seconds')
             except Exception as e:
                 logger.error(f'Failed to simplify the page: {e}')
                 raise
 
             # Model inference
+            start_time = time.time()
             llm_response = await self._parse_with_model(simplified_html, options)
-
+            end_time = time.time()
+            logger.info(f'Model inference took {end_time - start_time} seconds in total')
             # Map the results
+            start_time = time.time()
             pre_data = PreDataJson({})
             pre_data[PreDataJsonKey.TYPICAL_RAW_HTML] = html_content
             pre_data[PreDataJsonKey.TYPICAL_RAW_TAG_HTML] = typical_raw_tag_html
             pre_data[PreDataJsonKey.LLM_RESPONSE] = llm_response
             parser = MapItemToHtmlTagsParser({})
             pre_data = parser.parse_single(pre_data)
+            end_time = time.time()
+            logger.info(f'Mapping took {end_time - start_time} seconds')
             main_html = pre_data[PreDataJsonKey.TYPICAL_MAIN_HTML]
+            start_time = time.time()
             mm_nlp_md = extract_content_from_main_html(url, main_html, 'mm_md', use_raw_image_url=True)
+            end_time = time.time()
+            logger.info(f'Markdown extraction took {end_time - start_time} seconds')
             pre_data['markdown'] = mm_nlp_md
             # Convert PreDataJson to a plain dict to avoid response-model validation errors
             return dict(pre_data.items())
@@ -110,11 +117,14 @@ async def _parse_with_model(self, html_content: str, options: Optional[Dict[str,
 
         # Re-import to make sure the latest code is loaded, bypassing caching issues
         from llm_web_kit.api.dependencies import get_settings
+
         settings = get_settings()
 
         async def main():
             async with httpx.AsyncClient() as client:
-                response = await client.post(settings.crawl_url, json={'url': 'https://aws.amazon.com/what-is/retrieval-augmented-generation/'}, timeout=60)
+                response = await client.post(settings.crawl_url,
+                                             json={'url': 'https://aws.amazon.com/what-is/retrieval-augmented-generation/'},
+                                             timeout=60)
                 response.raise_for_status()
                 data = response.json()
                 html_content = data.get('html')
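The service now brackets each stage (simplify, inference, mapping, extraction) with explicit `start_time`/`end_time` pairs. The same measurement could be factored into a small context manager — purely an illustrative sketch, not part of the patch:

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def timed(stage: str):
    # Log how long the wrapped block took, mirroring the manual
    # start_time/end_time pairs in parse_html.
    start = time.time()
    try:
        yield
    finally:
        logger.info('%s took %.3f seconds', stage, time.time() - start)


# Usage, e.g. inside parse_html:
# with timed('simplify_html'):
#     simplified_html, typical_raw_tag_html = simplify_html(html_content)
```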
"""初始化推理服务,延迟加载模型.""" self._llm = None self._tokenizer = None + self._sampling_params = None # 新增采样参数成员 self._initialized = False self._init_lock = None # 用于异步初始化锁 self._model_path = None @@ -343,7 +344,23 @@ async def _init_model(self): # max_model_len=config.max_tokens, # 减少序列长度避免内存不足 ) - logger.info(f'模型初始化成功: {self.model_path}') + # 在初始化时创建采样参数 + if config.use_logits_processor: + token_state = Token_state(self.model_path) + self._sampling_params = SamplingParams( + temperature=config.temperature, + top_p=config.top_p, + max_tokens=config.max_output_tokens, + logits_processors=[token_state.process_logit] + ) + else: + self._sampling_params = SamplingParams( + temperature=config.temperature, + top_p=config.top_p, + max_tokens=config.max_output_tokens + ) + + logger.info(f'模型和采样参数初始化成功: {self.model_path}') except Exception as e: logger.error(f'模型初始化失败: {e}') @@ -374,31 +391,21 @@ async def _run_real_inference(self, simplified_html: str, options: dict | None = prompt = create_prompt(simplified_html) chat_prompt = add_template(prompt, self._tokenizer) - # 设置采样参数 - if config.use_logits_processor: - token_state = Token_state(self.model_path) - sampling_params = SamplingParams( - temperature=config.temperature, - top_p=config.top_p, - max_tokens=config.max_output_tokens, - logits_processors=[token_state.process_logit] - ) - else: - sampling_params = SamplingParams( - temperature=config.temperature, - top_p=config.top_p, - max_tokens=config.max_output_tokens - ) + # 直接使用初始化好的采样参数 + if self._sampling_params is None: + logger.error("采样参数未初始化,返回占位结果") + return self._get_placeholder_result() # 执行推理 start_time = time.time() - output = self._llm.generate(chat_prompt, sampling_params) + output = self._llm.generate(chat_prompt, self._sampling_params) end_time = time.time() output_json = clean_output(output) # 格式化结果 result = reformat_map(output_json) - logger.info(f'推理完成,结果: {result}, 耗时: {end_time - start_time}秒') + logger.info(f'推理完成,结果:{result}') + logger.info(f'推理完成, 耗时: {end_time - start_time}秒') return result except Exception as e: diff --git a/llm_web_kit/api/services/request_log_service.py b/llm_web_kit/api/services/request_log_service.py index 1de0950d..f1196d54 100644 --- a/llm_web_kit/api/services/request_log_service.py +++ b/llm_web_kit/api/services/request_log_service.py @@ -10,6 +10,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from ..database import get_db_manager from ..dependencies import get_logger from ..models.db_models import RequestLog @@ -18,18 +19,19 @@ class RequestLogService: """请求日志服务类.""" + @staticmethod - def generate_request_id() -> str: + def _generate_request_id() -> str: """生成唯一的请求ID.""" return str(uuid.uuid4()) @staticmethod async def create_log( - session: Optional[AsyncSession], - request_id: str, - input_type: str, - input_html: Optional[str] = None, - url: Optional[str] = None, + session: Optional[AsyncSession], + request_id: str, + input_type: str, + input_html: Optional[str] = None, + url: Optional[str] = None, ) -> Optional[RequestLog]: """创建请求日志记录. 
diff --git a/llm_web_kit/api/services/request_log_service.py b/llm_web_kit/api/services/request_log_service.py
index 1de0950d..f1196d54 100644
--- a/llm_web_kit/api/services/request_log_service.py
+++ b/llm_web_kit/api/services/request_log_service.py
@@ -10,6 +10,7 @@
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
+from ..database import get_db_manager
 from ..dependencies import get_logger
 from ..models.db_models import RequestLog
 
@@ -18,18 +19,19 @@
 class RequestLogService:
     """Request log service."""
+
     @staticmethod
-    def generate_request_id() -> str:
+    def _generate_request_id() -> str:
         """Generate a unique request ID."""
         return str(uuid.uuid4())
 
     @staticmethod
     async def create_log(
-        session: Optional[AsyncSession],
-        request_id: str,
-        input_type: str,
-        input_html: Optional[str] = None,
-        url: Optional[str] = None,
+            session: Optional[AsyncSession],
+            request_id: str,
+            input_type: str,
+            input_html: Optional[str] = None,
+            url: Optional[str] = None,
     ) -> Optional[RequestLog]:
         """Create a request log record.
 
@@ -63,11 +65,33 @@ async def create_log(
             logger.error(f"Failed to create the request log: {e}")
             return None
 
+    @staticmethod
+    async def initial_log(
+            session: Optional[AsyncSession],
+            request_id: str,
+            input_type: str,
+            input_html: Optional[str] = None,
+            url: Optional[str] = None,
+    ):
+        """Create and commit the initial log."""
+        if not session:
+            logger.debug("No database session; skipping the initial log record")
+            return
+
+        await RequestLogService.create_log(
+            session, request_id, input_type, input_html, url
+        )
+        try:
+            await session.commit()
+        except Exception as e:
+            logger.error(f"Error committing the initial log: {e}")
+            await session.rollback()
+
     @staticmethod
     async def update_log_success(
-        session: Optional[AsyncSession],
-        request_id: str,
-        output_markdown: Optional[str] = None,
+            session: Optional[AsyncSession],
+            request_id: str,
+            output_markdown: Optional[str] = None,
     ) -> bool:
         """Update the request log to the success state.
 
@@ -99,11 +123,23 @@ async def update_log_success(
             logger.error(f"Failed to update the request log: {e}")
             return False
 
+    @staticmethod
+    async def log_success_bg(request_id: str, output_markdown: Optional[str] = None):
+        """Update the log to success, run as a background task."""
+        async with get_db_manager().get_session() as bg_session:
+            updated = await RequestLogService.update_log_success(
+                session=bg_session,
+                request_id=request_id,
+                output_markdown=output_markdown,
+            )
+            if updated:
+                await bg_session.commit()
+
     @staticmethod
     async def update_log_failure(
-        session: Optional[AsyncSession],
-        request_id: str,
-        error_message: str,
+            session: Optional[AsyncSession],
+            request_id: str,
+            error_message: str,
     ) -> bool:
         """Update the request log to the failure state.
 
@@ -136,10 +172,22 @@ async def update_log_failure(
             logger.error(f"Failed to update the request log: {e}")
             return False
 
+    @staticmethod
+    async def log_failure_bg(request_id: str, error_message: str):
+        """Update the log to failure, run as a background task."""
+        async with get_db_manager().get_session() as bg_session:
+            updated = await RequestLogService.update_log_failure(
+                session=bg_session,
+                request_id=request_id,
+                error_message=error_message,
+            )
+            if updated:
+                await bg_session.commit()
+
     @staticmethod
     async def get_log_by_request_id(
-        session: Optional[AsyncSession],
-        request_id: str,
+            session: Optional[AsyncSession],
+            request_id: str,
     ) -> Optional[RequestLog]:
         """Look up a log record by request ID.
 
diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/not_html.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/not_html.html
index bd492e4a..f253b572 100644
--- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/not_html.html
+++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/not_html.html
@@ -654,7 +654,7 @@
 -->
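The background helpers `log_success_bg` and `log_failure_bg` deliberately open a fresh session from `get_db_manager()`, because the request-scoped session is already closed by the time a background task runs. What `get_session()` is assumed to provide, sketched with SQLAlchemy's async API (the engine URL is illustrative):

```python
from contextlib import asynccontextmanager

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

engine = create_async_engine("sqlite+aiosqlite:///request_logs.db")  # illustrative URL
session_factory = async_sessionmaker(engine, expire_on_commit=False)


@asynccontextmanager
async def get_session():
    # Yield an AsyncSession drawn from the manager's own pool; committing is
    # left to the caller, matching how log_*_bg commits only when the update
    # actually succeeded.
    async with session_factory() as session:
        yield session
```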