Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,19 @@ outputs/

fastapi_app/invite_codes.txt
models/*
!models/.gitkeep
!models/README.md
!models/RMBG-2.0/
!models/sam3/
!models/sam3-official/
models/RMBG-2.0/*
!models/RMBG-2.0/.gitkeep
models/sam3/*
!models/sam3/.gitkeep
models/sam3-official/*
!models/sam3-official/sam3/
models/sam3-official/sam3/*
!models/sam3-official/sam3/.gitkeep
outputs/*
tmps/*
data/*
Expand Down
60 changes: 33 additions & 27 deletions dataflow_agent/toolkits/multimodaltool/bg_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
# - get_svg_render_desc: 返回 SVG 渲染工具的说明文本
# ================================================================
# BRIA-RMBG 2.0 高质量抠图工具
# - 模型:RMBG 2.0(ONNX
# - 依赖:onnxruntime, pillow, numpy
# - 模型:RMBG 2.0(Transformers / ModelScope 目录
# - 依赖:transformers, pillow, numpy
# ================================================================

from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path
import platform

import numpy as np
from PIL import Image, ImageFilter
Expand All @@ -31,57 +31,61 @@
from dataflow_agent.logger import get_logger

CURRENT_DIR = Path(__file__).resolve().parent
# Allow override via env var (e.g. in Docker: RMBG_MODEL_PATH=/app/models/RMBG-2.0)
PROJECT_ROOT = CURRENT_DIR.parents[2]
# Allow override via env var (e.g. RMBG_MODEL_PATH=/app/models/RMBG-2.0)
_env_model_path = os.environ.get("RMBG_MODEL_PATH")
MODEL_PATH = Path(_env_model_path) if _env_model_path else CURRENT_DIR / "onnx" / "model.onnx"
MODEL_PATH = Path(_env_model_path) if _env_model_path else PROJECT_ROOT / "models" / "RMBG-2.0"
OUTPUT_DIR = CURRENT_DIR

# 进程级抠图模型缓存:按 model_path 复用 BriaRMBG2Remover 实例
_BG_RMBG_MODEL_CACHE: dict[str, "BriaRMBG2Remover"] = {}
log = get_logger(__name__)


def _has_rmbg_model(model_path: Path) -> bool:
return (model_path / "config.json").exists() and (model_path / "model.safetensors").exists()


def ensure_model(model_path: Path) -> None:
"""
确保本地存在 RMBG-2.0 模型权重。

``model_path`` 对应的文件不存在,则通过 ModelScope 下载
``AI-ModelScope/RMBG-2.0`` 模型到该路径所在目录
当前实现期望 ``model_path`` 是一个 HuggingFace / ModelScope 风格的
模型目录;若目录缺失,则自动下载 ``AI-ModelScope/RMBG-2.0``。

参数
----
model_path:
本地模型文件路径(通常为 ONNX 或 transformers 权重)
本地模型目录路径

异常
----
FileNotFoundError
当下载结束后仍未在 ``model_path`` 处找到模型文件时抛出。
"""
if model_path.exists():
if model_path.is_file():
model_path = model_path.parent

if _has_rmbg_model(model_path):
log.info(f"模型已存在: {model_path}")
return

log.info("未检测到模型文件,正在下载 RMBG-2.0 权重...")

# 确保目录存在
model_path.parent.mkdir(parents=True, exist_ok=True)
# 判断当前系统是否为Windows
is_windows = platform.system().lower() == "windows"
# Windows用双引号包裹路径,Linux/macOS用单引号(保持原有逻辑)
quote = '"' if is_windows else "'"
# 直接下载到目标目录
cmd = (
f"modelscope download "
f"--model AI-ModelScope/RMBG-2.0 "
f"--local_dir {quote}{model_path}{quote} "
)
os.system(cmd)
log.info("未检测到 RMBG 模型目录,正在下载 RMBG-2.0 权重...")

model_path.mkdir(parents=True, exist_ok=True)
download_code = (
"from modelscope import snapshot_download\n"
"snapshot_download("
"'AI-ModelScope/RMBG-2.0', "
"local_dir=r'''%s''', "
"allow_patterns=['*.json', '*.py', '*.safetensors']"
")\n"
) % str(model_path)
subprocess.run([sys.executable, "-c", download_code], check=True)

# 检查下载是否成功
if not model_path.exists():
if not _has_rmbg_model(model_path):
raise FileNotFoundError(
f"模型下载失败:未找到 {model_path}。\n"
f"模型下载失败:目录 {model_path} 缺少 config.json 或 model.safetensors。\n"
"请检查 ModelScope 或手动下载。"
)

Expand Down Expand Up @@ -109,6 +113,8 @@ def __init__(self, model_path: Path | None = None, output_dir: Path | None = Non
输出目录。若为 None,则使用当前文件所在目录。
"""
self.model_path = Path(model_path) if model_path else MODEL_PATH
if self.model_path.is_file():
self.model_path = self.model_path.parent
self.output_dir = Path(output_dir) if output_dir else OUTPUT_DIR

ensure_model(self.model_path)
Expand Down
40 changes: 34 additions & 6 deletions dataflow_agent/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,37 @@
from .paper2video_subprocess import run_paper2video_via_subprocess
from .registry import RuntimeRegistry

# ---- 1. 自动发现并导入所有工作流定义模块 ---------------------------------
_pkg_path = Path(__file__).resolve().parent
for py in _pkg_path.glob("wf_*.py"):
mod_name = f"{__name__}.{py.stem}"
importlib.import_module(mod_name)
_imported_workflow_modules: set[str] = set()


def _workflow_module_names() -> list[str]:
    """Return the fully-qualified names of every ``wf_*.py`` workflow module.

    Scans the package directory for workflow definition files and yields
    deterministic (sorted) dotted module paths under this package.
    """
    stems = (path.stem for path in _pkg_path.glob("wf_*.py"))
    return sorted(f"{__name__}.{stem}" for stem in stems)


def _import_workflow_modules_until(name: str | None = None) -> None:
    """Lazily import workflow modules, stopping once *name* is registered.

    Eagerly importing every workflow at startup pulls in heavyweight
    dependencies (torchvision/transformers and friends) and stalls the
    backend. Instead, import ``wf_*`` modules one at a time until the
    requested workflow shows up in the registry; with ``name=None`` all
    modules are imported (used by the full-listing code path).
    """

    def _is_registered() -> bool:
        # A concrete name that is already in the registry means no further
        # imports are needed.
        return name is not None and name in RuntimeRegistry._workflows

    if _is_registered():
        return

    for module_name in _workflow_module_names():
        if module_name in _imported_workflow_modules:
            continue
        importlib.import_module(module_name)
        _imported_workflow_modules.add(module_name)
        # Re-check only after a fresh import, matching the registration
        # side effects that the import just triggered.
        if _is_registered():
            return

# ---- 2. 工作流的统一接口 ---------------------------------------------
def get_workflow(name: str):
Expand All @@ -23,6 +49,7 @@ def get_workflow(name: str):
Returns:
Callable: 用于构建该工作流图的工厂函数
"""
_import_workflow_modules_until(name)
return RuntimeRegistry.get(name)


Expand All @@ -37,8 +64,9 @@ async def run_workflow(name: str, state):


# ---- 3. 工作流注册信息公开接口 -------------------------------------------
list_workflows = RuntimeRegistry.all
def list_workflows():
    """Return all registered workflows after importing every wf_* module.

    A complete listing requires running the registration side effects of
    every workflow module, so this forces the lazy loader to import all of
    them (``name=None`` imports everything) before querying the registry.
    """
    _import_workflow_modules_until()
    return RuntimeRegistry.all()

# ---- 3. 工作流注册信息公开接口 -------------------------------------------
# 提供所有已注册工作流的列表,便于外部查询与 introspection
list_workflows = RuntimeRegistry.all
22 changes: 19 additions & 3 deletions deploy/app_config.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,24 @@
# Shared FastAPI runtime config for deploy scripts.
# Environment variables can override these defaults.

DEPLOY_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$DEPLOY_DIR/.." && pwd)"

APP_HOST="${APP_HOST:-0.0.0.0}"
APP_PORT="${APP_PORT:-8000}"
APP_WORKERS="${APP_WORKERS:-2}"
APP_WORKERS="${APP_WORKERS:-1}"
APP_CONDA_ENV="${APP_CONDA_ENV:-}"
APP_PYTHON="${APP_PYTHON:-}"
CONDA_SH="${CONDA_SH:-/root/miniconda3/etc/profile.d/conda.sh}"
APP_PYTHON="${APP_PYTHON:-/opt/conda/bin/python}"
APP_FALLBACK_PYTHON="${APP_FALLBACK_PYTHON:-/opt/conda/bin/python}"
CONDA_SH="${CONDA_SH:-/opt/conda/etc/profile.d/conda.sh}"

# Keep the legacy external repo as a fallback only.
PAPER2ANY_ASSET_ROOT="${PAPER2ANY_ASSET_ROOT:-/mnt/paper2any/lz/github-proj/Paper2Any}"
MODEL_SERVER_ENV_FILE="${MODEL_SERVER_ENV_FILE:-logs/model_servers.env}"

SAM3_SERVER_URLS="${SAM3_SERVER_URLS:-http://127.0.0.1:8021}"
# Leave model paths empty by default so deploy/start.sh can prefer repo-local models first.
SAM3_HOME="${SAM3_HOME:-}"
SAM3_CHECKPOINT_PATH="${SAM3_CHECKPOINT_PATH:-}"
SAM3_BPE_PATH="${SAM3_BPE_PATH:-}"
RMBG_MODEL_PATH="${RMBG_MODEL_PATH:-}"
Loading