diff --git a/llm_web_kit/api/dependencies.py b/llm_web_kit/api/dependencies.py
index d949e305..3f4fb68d 100644
--- a/llm_web_kit/api/dependencies.py
+++ b/llm_web_kit/api/dependencies.py
@@ -4,13 +4,27 @@
"""
import logging
+import os
+from contextvars import ContextVar
from functools import lru_cache
+from logging.handlers import TimedRotatingFileHandler
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
logger = logging.getLogger(__name__)
+# A ContextVar that stores the request_id, with a default of None
+request_id_var: ContextVar[Optional[str]] = ContextVar("request_id", default=None)
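+# Note: ContextVar state is task-local under asyncio, so concurrent requests
+# each see their own request_id.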
+
+
+class RequestIdFilter(logging.Filter):
+    """Logging filter that injects the request_id from the ContextVar into log records."""
+
+ def filter(self, record):
+ record.request_id = request_id_var.get()
+ return True
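+
+# Formatters that reference %(request_id)s rely on this filter being attached;
+# records logged outside a request context render the field as 'None'.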
+
class Settings(BaseSettings):
"""应用配置设置."""
@@ -27,6 +41,8 @@ class Settings(BaseSettings):
     # Logging configuration
log_level: str = "INFO"
+ log_dir: str = "logs"
+ log_filename: str = "api.log"
     # Model configuration
model_path: Optional[str] = None
@@ -38,8 +54,8 @@ class Settings(BaseSettings):
     # Database configuration
     database_url: Optional[str] = None  # read from the DATABASE_URL environment variable
- db_pool_size: int = 5
- db_max_overflow: int = 10
+ db_pool_size: int = 200
+ db_max_overflow: int = 100
     # pydantic v2 configuration style
model_config = SettingsConfigDict(
@@ -57,14 +73,39 @@ def get_settings() -> Settings:
def get_logger(name: str = __name__) -> logging.Logger:
"""获取配置好的日志记录器."""
logger = logging.getLogger(name)
+    logger.setLevel(get_settings().log_level)
+
     if not logger.handlers:
-        handler = logging.StreamHandler()
-        formatter = logging.Formatter(
-            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+        # Attach the filter once, together with the handlers, so repeated
+        # get_logger() calls do not stack duplicate filters
+        logger.addFilter(RequestIdFilter())
+
+        # Console handler
+        stream_handler = logging.StreamHandler()
+ stream_formatter = logging.Formatter(
+ '%(asctime)s - %(request_id)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ stream_handler.setFormatter(stream_formatter)
+ logger.addHandler(stream_handler)
+
+        # File handler (rotates daily)
+        settings = get_settings()
+        log_dir = settings.log_dir
+        os.makedirs(log_dir, exist_ok=True)
+
+ log_file_path = os.path.join(log_dir, settings.log_filename)
+
+ file_handler = TimedRotatingFileHandler(
+ log_file_path,
+            when="midnight",  # rotate at midnight each day
+            interval=1,
+            backupCount=30,  # keep 30 days of logs
+ encoding='utf-8'
)
- handler.setFormatter(formatter)
- logger.addHandler(handler)
- logger.setLevel(get_settings().log_level)
+ file_formatter = logging.Formatter(
+ '%(asctime)s - %(request_id)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ file_handler.setFormatter(file_formatter)
+ logger.addHandler(file_handler)
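+        # Rotated files get a date suffix (e.g. api.log.2024-01-01) from
+        # TimedRotatingFileHandler's default naming scheme.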
+
return logger
diff --git a/llm_web_kit/api/main.py b/llm_web_kit/api/main.py
index 561e8df3..38932fc1 100644
--- a/llm_web_kit/api/main.py
+++ b/llm_web_kit/api/main.py
@@ -4,12 +4,14 @@
"""
import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
-from .dependencies import get_inference_service, get_logger, get_settings
+from .dependencies import (get_inference_service, get_logger, get_settings,
+ request_id_var)
from .routers import htmls
+from .services.request_log_service import RequestLogService
settings = get_settings()
logger = get_logger(__name__)
@@ -33,6 +35,30 @@
allow_headers=["*"],
)
+
+@app.middleware("http")
+async def request_id_middleware(request: Request, call_next):
+    """Middleware that assigns a request_id and propagates it across the request lifecycle via a ContextVar."""
+    # Take the request_id from the request headers; generate a new one if absent
+ request_id = request.headers.get("X-Request-ID")
+ if not request_id:
+ request_id = RequestLogService._generate_request_id()
+
+    # Store the request_id in the ContextVar
+    token = request_id_var.set(request_id)
+
+    try:
+        # Handle the request
+        response = await call_next(request)
+
+        # Echo the request_id in the response headers
+        response.headers["X-Request-ID"] = request_id
+        return response
+    finally:
+        # Reset the ContextVar even if the handler raises
+        request_id_var.reset(token)
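+
+# Example (hypothetical client usage): callers may pin their own trace id, e.g.
+#   curl -H "X-Request-ID: trace-123" -X POST http://localhost:8000/api/v1/html/parse ...
+# and the same value is echoed back in the X-Request-ID response header.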
+
+
 # Register routers
 app.include_router(htmls.router, prefix="/api/v1", tags=["HTML Processing"])
diff --git a/llm_web_kit/api/routers/htmls.py b/llm_web_kit/api/routers/htmls.py
index 2d3f60a2..b90c9a57 100644
--- a/llm_web_kit/api/routers/htmls.py
+++ b/llm_web_kit/api/routers/htmls.py
@@ -3,13 +3,17 @@
 API endpoints providing HTML parsing, content extraction, and related features.
"""
+import base64
+import html
+import time
from typing import Optional
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
+from fastapi import (APIRouter, BackgroundTasks, Body, Depends, File,
+ HTTPException, UploadFile)
from sqlalchemy.ext.asyncio import AsyncSession
from ..database import get_db_session
-from ..dependencies import get_logger, get_settings
+from ..dependencies import get_logger, get_settings, request_id_var
from ..models.request import HTMLParseRequest
from ..models.response import HTMLParseResponse
from ..services.html_service import HTMLService
@@ -23,16 +27,20 @@
@router.post('/html/parse', response_model=HTMLParseResponse)
async def parse_html(
- request: HTMLParseRequest,
- html_service: HTMLService = Depends(HTMLService),
- db_session: Optional[AsyncSession] = Depends(get_db_session)
+ background_tasks: BackgroundTasks,
+ request: HTMLParseRequest = Body(...),
+ html_service: HTMLService = Depends(HTMLService),
+ db_session: Optional[AsyncSession] = Depends(get_db_session)
):
"""解析 HTML 内容.
接收 HTML 字符串并返回解析后的结构化内容。
"""
-    # Generate a request ID
-    request_id = RequestLogService.generate_request_id()
+    # Fetch the request_id set by the middleware from the context var
+    request_id = request_id_var.get()
+
+    # html_content is expected to be base64-encoded, HTML-escaped UTF-8;
+    # decode it only when present so URL-only requests keep working
+    unescaped_html = None
+    if request.html_content:
+        decoded_bytes = base64.b64decode(request.html_content)
+        unescaped_html = html.unescape(decoded_bytes.decode('utf-8'))
     # Determine the input type
if request.html_content:
@@ -43,35 +51,32 @@ async def parse_html(
input_type = 'unknown'
     # Create the request log
- await RequestLogService.create_log(
+ start_time = time.time()
+ await RequestLogService.initial_log(
session=db_session,
request_id=request_id,
input_type=input_type,
- input_html=request.html_content,
+ input_html=unescaped_html,
url=request.url,
)
-
-    # Commit immediately so the 'processing' state is visible in the database
-    if db_session:
-        try:
-            await db_session.commit()
-        except Exception as commit_error:
-            logger.error(f'Error committing the initial log: {commit_error}')
+    end_time = time.time()
+    logger.info(f'Log creation took: {end_time - start_time}s')
try:
-        logger.info(f'Starting HTML parse [request_id={request_id}], content length: {len(request.html_content) if request.html_content else 0}')
+        logger.info(f'Starting HTML parse, content length: {len(unescaped_html) if unescaped_html else 0}')
result = await html_service.parse_html(
- html_content=request.html_content,
+ html_content=unescaped_html,
url=request.url,
+ request_id=request_id,
options=request.options
)
-        # Update the log to success
- await RequestLogService.update_log_success(
- session=db_session,
- request_id=request_id,
- output_markdown=result.get('markdown'),
+        # Queue the success-log update as a background task
+ background_tasks.add_task(
+ RequestLogService.log_success_bg,
+ request_id,
+ result.get('markdown')
)
return HTMLParseResponse(
@@ -81,37 +86,32 @@ async def parse_html(
request_id=request_id
)
except Exception as e:
-        logger.error(f'HTML parsing failed [request_id={request_id}]: {str(e)}')
-
-        # Update the log to failed
-        await RequestLogService.update_log_failure(
-            session=db_session,
-            request_id=request_id,
-            error_message=str(e),
-        )
-        # Commit manually so the failure log is persisted
-        if db_session:
-            try:
-                await db_session.commit()
-            except Exception as commit_error:
-                logger.error(f'Error committing the failure log: {commit_error}')
-
-        raise HTTPException(status_code=500, detail=f'HTML parsing failed: {str(e)}')
+        error_message = str(e)
+        logger.error(f'HTML parsing failed: {error_message}')
+
+        # Background tasks attached to the request are dropped once an
+        # HTTPException is raised, so persist the failure log directly
+        await RequestLogService.log_failure_bg(request_id, error_message)
+        raise HTTPException(status_code=500, detail=f'HTML parsing failed: {error_message}')
@router.post('/html/upload')
async def upload_html_file(
- file: UploadFile = File(...),
- html_service: HTMLService = Depends(HTMLService),
- db_session: Optional[AsyncSession] = Depends(get_db_session)
+ background_tasks: BackgroundTasks,
+ file: UploadFile = File(...),
+ html_service: HTMLService = Depends(HTMLService),
+ db_session: Optional[AsyncSession] = Depends(get_db_session)
):
"""上传 HTML 文件进行解析.
支持上传 HTML 文件,自动解析并返回结果。
"""
-    # Generate a request ID
-    request_id = RequestLogService.generate_request_id()
+    # Fetch the request_id set by the middleware from the context var
+ request_id = request_id_var.get()
try:
if not file.filename.endswith(('.html', '.htm')):
@@ -120,31 +120,26 @@ async def upload_html_file(
content = await file.read()
html_content = content.decode('utf-8')
-        logger.info(f'Uploaded HTML file [request_id={request_id}]: {file.filename}, size: {len(content)} bytes')
-
+        logger.info(f'Uploaded HTML file: {file.filename}, size: {len(content)} bytes')
+ start_time = time.time()
         # Create the request log
- await RequestLogService.create_log(
+ await RequestLogService.initial_log(
session=db_session,
request_id=request_id,
input_type='file',
input_html=html_content,
url=None,
)
+ end_time = time.time()
+        logger.info(f'Log creation took: {end_time - start_time}s')
-        # Commit immediately so the 'processing' state is visible in the database
-        if db_session:
-            try:
-                await db_session.commit()
-            except Exception as commit_error:
-                logger.error(f'Error committing the initial log: {commit_error}')
-
- result = await html_service.parse_html(html_content=html_content, url="www.baidu.com")
+ result = await html_service.parse_html(html_content=html_content, url="www.baidu.com", request_id=request_id)
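+        # NOTE: "www.baidu.com" appears to act as a placeholder source URL for
+        # uploaded files, which carry no real page origin.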
-        # Update the log to success
- await RequestLogService.update_log_success(
- session=db_session,
- request_id=request_id,
- output_markdown=result.get('markdown'),
+        # Queue the success-log update as a background task
+ background_tasks.add_task(
+ RequestLogService.log_success_bg,
+ request_id,
+ result.get('markdown')
)
return HTMLParseResponse(
@@ -154,23 +149,17 @@ async def upload_html_file(
request_id=request_id
)
except Exception as e:
-        logger.error(f'HTML file parsing failed [request_id={request_id}]: {str(e)}')
-
-        # Update the log to failed
-        await RequestLogService.update_log_failure(
-            session=db_session,
-            request_id=request_id,
-            error_message=str(e),
-        )
-        # Commit manually so the failure log is persisted
-        if db_session:
-            try:
-                await db_session.commit()
-            except Exception as commit_error:
-                logger.error(f'Error committing the failure log: {commit_error}')
-
-        raise HTTPException(status_code=500, detail=f'HTML file parsing failed: {str(e)}')
+        error_message = str(e)
+        logger.error(f'HTML file parsing failed: {error_message}')
+
+        # Background tasks attached to the request are dropped once an
+        # HTTPException is raised, so persist the failure log directly
+        await RequestLogService.log_failure_bg(request_id, error_message)
+        raise HTTPException(status_code=500, detail=f'HTML file parsing failed: {error_message}')
@router.get('/html/status')
diff --git a/llm_web_kit/api/services/html_service.py b/llm_web_kit/api/services/html_service.py
index 0bb247a0..03a48789 100644
--- a/llm_web_kit/api/services/html_service.py
+++ b/llm_web_kit/api/services/html_service.py
@@ -3,12 +3,18 @@
 Bridges the original project's HTML parsing and content-extraction features behind a unified API.
"""
+import time
from typing import Any, Dict, Optional
import httpx
from llm_web_kit.api.dependencies import (get_inference_service, get_logger,
get_settings)
+from llm_web_kit.input.pre_data_json import PreDataJson, PreDataJsonKey
+from llm_web_kit.main_html_parser.parser.tag_mapping import \
+ MapItemToHtmlTagsParser
+from llm_web_kit.main_html_parser.simplify_html.simplify_html import \
+ simplify_html
from llm_web_kit.simple import extract_content_from_main_html
logger = get_logger(__name__)
@@ -32,10 +38,11 @@ def _init_components(self):
return None
async def parse_html(
- self,
- html_content: Optional[str] = None,
- url: Optional[str] = None,
- options: Optional[Dict[str, Any]] = None
+ self,
+ html_content: Optional[str] = None,
+ url: Optional[str] = None,
+        request_id: Optional[str] = None,
+ options: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""解析 HTML 内容."""
try:
@@ -60,37 +67,37 @@ async def parse_html(
if not html_content:
                 raise ValueError('Either HTML content or a valid URL must be provided')
-            # Import lazily so an import-time failure doesn't break the service class
- try:
- from llm_web_kit.input.pre_data_json import (PreDataJson,
- PreDataJsonKey)
- from llm_web_kit.main_html_parser.parser.tag_mapping import \
- MapItemToHtmlTagsParser
- from llm_web_kit.main_html_parser.simplify_html.simplify_html import \
- simplify_html
- except Exception as import_err:
-                logger.error(f'Dependency import failed: {import_err}')
- raise
-
+
         # Simplify the page
try:
+ start_time = time.time()
simplified_html, typical_raw_tag_html = simplify_html(html_content)
+ end_time = time.time()
+            logger.info(f'Simplification finished, took: {end_time - start_time}s')
except Exception as e:
             logger.error(f'Failed to simplify the page: {e}')
raise
         # Model inference
+ start_time = time.time()
llm_response = await self._parse_with_model(simplified_html, options)
-
+ end_time = time.time()
+        logger.info(f'Model inference total time: {end_time - start_time}s')
         # Map the results
+ start_time = time.time()
pre_data = PreDataJson({})
pre_data[PreDataJsonKey.TYPICAL_RAW_HTML] = html_content
pre_data[PreDataJsonKey.TYPICAL_RAW_TAG_HTML] = typical_raw_tag_html
pre_data[PreDataJsonKey.LLM_RESPONSE] = llm_response
parser = MapItemToHtmlTagsParser({})
pre_data = parser.parse_single(pre_data)
+ end_time = time.time()
+        logger.info(f'Mapping took: {end_time - start_time}s')
main_html = pre_data[PreDataJsonKey.TYPICAL_MAIN_HTML]
+ start_time = time.time()
mm_nlp_md = extract_content_from_main_html(url, main_html, 'mm_md', use_raw_image_url=True)
+ end_time = time.time()
+        logger.info(f'Markdown extraction took: {end_time - start_time}s')
pre_data['markdown'] = mm_nlp_md
         # Convert PreDataJson to a plain dict to avoid response-model validation errors
return dict(pre_data.items())
@@ -110,11 +117,14 @@ async def _parse_with_model(self, html_content: str, options: Optional[Dict[str,
         # Re-import to ensure the latest code is loaded, bypassing caching issues
from llm_web_kit.api.dependencies import get_settings
+
settings = get_settings()
async def main():
async with httpx.AsyncClient() as client:
- response = await client.post(settings.crawl_url, json={'url': 'https://aws.amazon.com/what-is/retrieval-augmented-generation/'}, timeout=60)
+ response = await client.post(settings.crawl_url,
+ json={'url': 'https://aws.amazon.com/what-is/retrieval-augmented-generation/'},
+ timeout=60)
response.raise_for_status()
data = response.json()
html_content = data.get('html')
diff --git a/llm_web_kit/api/services/inference_service.py b/llm_web_kit/api/services/inference_service.py
index 6401f0fd..fb28ff2c 100644
--- a/llm_web_kit/api/services/inference_service.py
+++ b/llm_web_kit/api/services/inference_service.py
@@ -293,6 +293,7 @@ def __init__(self):
"""初始化推理服务,延迟加载模型."""
self._llm = None
self._tokenizer = None
+        self._sampling_params = None  # sampling params, built once at init
self._initialized = False
         self._init_lock = None  # lock used for async initialization
self._model_path = None
@@ -343,7 +344,23 @@ async def _init_model(self):
             # max_model_len=config.max_tokens,  # shorten the sequence length to avoid running out of memory
)
-            logger.info(f'Model initialized successfully: {self.model_path}')
+            # Build the sampling params once at initialization time
+ if config.use_logits_processor:
+ token_state = Token_state(self.model_path)
+ self._sampling_params = SamplingParams(
+ temperature=config.temperature,
+ top_p=config.top_p,
+ max_tokens=config.max_output_tokens,
+ logits_processors=[token_state.process_logit]
+ )
+ else:
+ self._sampling_params = SamplingParams(
+ temperature=config.temperature,
+ top_p=config.top_p,
+ max_tokens=config.max_output_tokens
+ )
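+
+        # Building SamplingParams once at init avoids re-creating Token_state
+        # and its logits processor on every request (previously rebuilt per
+        # call in _run_real_inference).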
+
+        logger.info(f'Model and sampling params initialized successfully: {self.model_path}')
except Exception as e:
             logger.error(f'Model initialization failed: {e}')
@@ -374,31 +391,21 @@ async def _run_real_inference(self, simplified_html: str, options: dict | None =
prompt = create_prompt(simplified_html)
chat_prompt = add_template(prompt, self._tokenizer)
-            # Set up the sampling params
- if config.use_logits_processor:
- token_state = Token_state(self.model_path)
- sampling_params = SamplingParams(
- temperature=config.temperature,
- top_p=config.top_p,
- max_tokens=config.max_output_tokens,
- logits_processors=[token_state.process_logit]
- )
- else:
- sampling_params = SamplingParams(
- temperature=config.temperature,
- top_p=config.top_p,
- max_tokens=config.max_output_tokens
- )
+            # Use the sampling params prepared during initialization
+            if self._sampling_params is None:
+                logger.error('Sampling params not initialized; returning placeholder result')
+                return self._get_placeholder_result()
             # Run inference
start_time = time.time()
- output = self._llm.generate(chat_prompt, sampling_params)
+ output = self._llm.generate(chat_prompt, self._sampling_params)
end_time = time.time()
output_json = clean_output(output)
             # Format the result
result = reformat_map(output_json)
-            logger.info(f'Inference finished, result: {result}, took: {end_time - start_time}s')
+            logger.info(f'Inference finished, result: {result}')
+            logger.info(f'Inference finished, took: {end_time - start_time}s')
return result
except Exception as e:
diff --git a/llm_web_kit/api/services/request_log_service.py b/llm_web_kit/api/services/request_log_service.py
index 1de0950d..f1196d54 100644
--- a/llm_web_kit/api/services/request_log_service.py
+++ b/llm_web_kit/api/services/request_log_service.py
@@ -10,6 +10,7 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
+from ..database import get_db_manager
from ..dependencies import get_logger
from ..models.db_models import RequestLog
@@ -18,18 +19,19 @@
class RequestLogService:
"""请求日志服务类."""
+
@staticmethod
- def generate_request_id() -> str:
+ def _generate_request_id() -> str:
"""生成唯一的请求ID."""
return str(uuid.uuid4())
@staticmethod
async def create_log(
- session: Optional[AsyncSession],
- request_id: str,
- input_type: str,
- input_html: Optional[str] = None,
- url: Optional[str] = None,
+ session: Optional[AsyncSession],
+ request_id: str,
+ input_type: str,
+ input_html: Optional[str] = None,
+ url: Optional[str] = None,
) -> Optional[RequestLog]:
"""创建请求日志记录.
@@ -63,11 +65,33 @@ async def create_log(
logger.error(f"创建请求日志失败: {e}")
return None
+ @staticmethod
+ async def initial_log(
+ session: Optional[AsyncSession],
+ request_id: str,
+ input_type: str,
+ input_html: Optional[str] = None,
+ url: Optional[str] = None,
+ ):
+        """Create and commit the initial log so the 'processing' state is visible immediately."""
+        if not session:
+            logger.debug("Database session is empty; skipping the initial log")
+ return
+
+ await RequestLogService.create_log(
+ session, request_id, input_type, input_html, url
+ )
+ try:
+ await session.commit()
+ except Exception as e:
+            logger.error(f"Error committing the initial log: {e}")
+ await session.rollback()
+
@staticmethod
async def update_log_success(
- session: Optional[AsyncSession],
- request_id: str,
- output_markdown: Optional[str] = None,
+ session: Optional[AsyncSession],
+ request_id: str,
+ output_markdown: Optional[str] = None,
) -> bool:
"""更新请求日志为成功状态.
@@ -99,11 +123,23 @@ async def update_log_success(
logger.error(f"更新请求日志失败: {e}")
return False
+ @staticmethod
+ async def log_success_bg(request_id: str, output_markdown: Optional[str] = None):
+        """Update the log to success as a background task."""
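+        # The request-scoped session is closed by the time background tasks
+        # run, so open a dedicated session for this update.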
+ async with get_db_manager().get_session() as bg_session:
+ updated = await RequestLogService.update_log_success(
+ session=bg_session,
+ request_id=request_id,
+ output_markdown=output_markdown,
+ )
+ if updated:
+ await bg_session.commit()
+
@staticmethod
async def update_log_failure(
- session: Optional[AsyncSession],
- request_id: str,
- error_message: str,
+ session: Optional[AsyncSession],
+ request_id: str,
+ error_message: str,
) -> bool:
"""更新请求日志为失败状态.
@@ -136,10 +172,22 @@ async def update_log_failure(
logger.error(f"更新请求日志失败: {e}")
return False
+ @staticmethod
+ async def log_failure_bg(request_id: str, error_message: str):
+        """Update the log to failure as a background task."""
+ async with get_db_manager().get_session() as bg_session:
+ updated = await RequestLogService.update_log_failure(
+ session=bg_session,
+ request_id=request_id,
+ error_message=error_message,
+ )
+ if updated:
+ await bg_session.commit()
+
@staticmethod
async def get_log_by_request_id(
- session: Optional[AsyncSession],
- request_id: str,
+ session: Optional[AsyncSession],
+ request_id: str,
) -> Optional[RequestLog]:
"""根据请求ID查询日志.
diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/not_html.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/not_html.html
index bd492e4a..f253b572 100644
--- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/not_html.html
+++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/not_html.html
@@ -654,7 +654,7 @@
-->