diff --git a/dataflow/cli.py b/dataflow/cli.py index a27a9d9c..3f0b7b0f 100644 --- a/dataflow/cli.py +++ b/dataflow/cli.py @@ -330,14 +330,39 @@ def eval_local_cmd(): def pdf2model_init(cache: Path = typer.Option(Path("."), help = "Cache dir"), qa: str = typer.Option("kbc", help="Which pipeline to init (vqa or kbc)"), - model: Optional[str] = typer.Option(None, help="Base model name or path")): + model: Optional[str] = typer.Option(None, help="Base model name or path"), + train_backend: str = typer.Option( + "base", + "--train-backend", + help="With --qa kbc: 'base' (LlamaFactory) or a registered dataflex-* backend (see cli_pdf.DATAFLEX_BACKEND_SPECS). vqa only allows 'base'.", + )): if qa not in ["vqa", "kbc"]: _echo(f"Invalid qa type: {qa}. Must be 'vqa' or 'kbc'.", "red") raise typer.Exit(code=1) - + if qa == "vqa": + if train_backend != "base": + _echo("vqa only supports --train-backend base.", "red") + raise typer.Exit(code=1) + else: + from dataflow.cli_funcs.cli_pdf import DATAFLEX_BACKEND_SPECS # type: ignore + + allowed_kbc = {"base", *DATAFLEX_BACKEND_SPECS.keys()} + if train_backend not in allowed_kbc: + supported = ", ".join(sorted(allowed_kbc)) + _echo( + f"Invalid --train-backend={train_backend!r} for --qa kbc. Supported: {supported}.", + "red", + ) + raise typer.Exit(code=1) + try: from dataflow.cli_funcs.cli_pdf import cli_pdf2model_init # type: ignore - cli_pdf2model_init(cache_path=str(cache), qa_type=qa, model_name=model) + cli_pdf2model_init( + cache_path=str(cache), + qa_type=qa, + model_name=model, + pdf2model_train_backend=train_backend, + ) except Exception as e: _echo(f"pdf2model init error: {e}", "red") raise typer.Exit(code=1) @@ -345,7 +370,7 @@ def pdf2model_init(cache: Path = typer.Option(Path("."), @pdf_app.command("train") def pdf2model_train(cache: Path = typer.Option(Path("."), help="Cache dir"), - lf_yaml: Optional[Path] = typer.Option(None, help="LlamaFactory yaml")): + lf_yaml: Optional[Path] = typer.Option(None, help="LlamaFactory yaml (base backend only)")): try: from dataflow.cli_funcs.cli_pdf import cli_pdf2model_train # type: ignore diff --git a/dataflow/cli_funcs/cli_pdf.py b/dataflow/cli_funcs/cli_pdf.py index f688304f..1e9a2838 100644 --- a/dataflow/cli_funcs/cli_pdf.py +++ b/dataflow/cli_funcs/cli_pdf.py @@ -4,6 +4,7 @@ PDF to Model training pipeline with init/train/chat commands """ +import copy import subprocess import sys import yaml @@ -19,6 +20,126 @@ logger = get_logger() +PDF2MODEL_TRAIN_BACKEND_BASE = "base" + +# 可供后续可能的dataflex系列后端接入 +PDF2MODEL_TRAIN_BACKEND_DATAFLEX_LESS = "dataflex-less" + +DATAFLEX_BACKEND_SPECS = { + PDF2MODEL_TRAIN_BACKEND_DATAFLEX_LESS: { + "component_name": "less", + "suffix": "less", + }, +} + + +def _get_dataflex_backend_spec(pdf2model_train_backend: str): + return DATAFLEX_BACKEND_SPECS.get(pdf2model_train_backend) + + +def _assert_pdf2model_train_backend_allowed(qa_type: str, pdf2model_train_backend: str) -> None: + if pdf2model_train_backend.startswith("dataflex-"): + if qa_type != "kbc": + raise AssertionError( + f"pdf2model_train_backend={pdf2model_train_backend} is only valid when qa=kbc, got qa={qa_type!r}" + ) + if _get_dataflex_backend_spec(pdf2model_train_backend) is None: + supported_backends = ", ".join(sorted(DATAFLEX_BACKEND_SPECS.keys())) + raise AssertionError( + f"Unsupported pdf2model_train_backend={pdf2model_train_backend!r}. " + f"Supported dataflex backends: {supported_backends or '(none)'}" + ) + if qa_type == "vqa" and pdf2model_train_backend != PDF2MODEL_TRAIN_BACKEND_BASE: + raise AssertionError(f"vqa only supports pdf2model_train_backend={PDF2MODEL_TRAIN_BACKEND_BASE!r}, got {pdf2model_train_backend!r}") + + +def _path_for_yaml(cwd: Path, target: Path) -> str: + """Prefer path relative to cwd for yaml portability.""" + try: + return str(target.resolve().relative_to(cwd.resolve())) + except ValueError: + return str(target.resolve()) + + +def _build_dataflex_components_yaml_text(selector_cache_dir: str, component_name: str, backend_name: str) -> str: + return f"""# Auto-generated by dataflow pdf2model init ({backend_name}). +selectors: + {component_name}: + name: {component_name} + params: + cache_dir: {selector_cache_dir} + gradient_type: adam + proj_dim: 4096 + seed: 123 + save_interval: 16 +""" + + +def _build_pdf2model_dataflex_config_dict( + *, + cwd: Path, + model_name_or_path: str, + dataset_train: str, + dataset_eval: str, + dataset_dir: Path, + output_dir: Path, + components_cfg_file: Path, + train_config: dict, + component_name: str, +) -> dict: + """LlamaFactory-compatible fields + DataFlex dynamic_train block; ``template`` follows ``train_config`` (from LlamaFactoryTrainer defaults).""" + ds_dir_s = _path_for_yaml(cwd, dataset_dir) + out_s = _path_for_yaml(cwd, output_dir) + comp_s = _path_for_yaml(cwd, components_cfg_file) + # 参数配置,不直接调用dataflex的部件,而是直接构建LlamaFactory的配置,DataFlex的LESS Trainer会从中解析出自己的配置项 + return { + "model_name_or_path": model_name_or_path, + "trust_remote_code": train_config.get("trust_remote_code", True), + "stage": "sft", + "do_train": True, + "finetuning_type": train_config.get("finetuning_type", "lora"), + "lora_target": train_config.get("lora_target", "all"), + "lora_rank": train_config.get("lora_rank", 16), + "lora_alpha": train_config.get("lora_alpha", 32), + "dataset": dataset_train, + "eval_dataset": dataset_eval, + "template": train_config.get("template", "qwen"), + "dataset_dir": ds_dir_s, + "cutoff_len": train_config.get("cutoff_len", 1024), + "overwrite_cache": train_config.get("overwrite_cache", True), + "preprocessing_num_workers": train_config.get("preprocessing_num_workers", 4), + "dataloader_num_workers": train_config.get("dataloader_num_workers", 0), + "seed": train_config.get("seed", 42), + "output_dir": out_s, + "logging_steps": train_config.get("logging_steps", 10), + "save_steps": train_config.get("save_steps", 300), + "plot_loss": train_config.get("plot_loss", True), + "save_only_model": train_config.get("save_only_model", False), + "overwrite_output_dir": True, + "report_to": train_config.get("report_to", "none"), + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": train_config.get("learning_rate", 1.0e-4), + "num_train_epochs": train_config.get("num_train_epochs", 1.0), + "lr_scheduler_type": train_config.get("lr_scheduler_type", "cosine"), + "warmup_ratio": train_config.get("warmup_ratio", 0.1), + "bf16": train_config.get("bf16", True), + "fp16": train_config.get("fp16", False), + "ddp_timeout": train_config.get("ddp_timeout", 180000000), + "per_device_eval_batch_size": 1, + "metric_for_best_model": "eval_loss", + "greater_is_better": False, + "load_best_model_at_end": False, + "eval_strategy": "steps", + "eval_steps": 10, + "train_type": "dynamic_select", + "components_cfg_file": comp_s, + "component_name": component_name, + "warmup_step": 1, + "update_step": 1, + "update_times": 2, + } + def run_script_with_args(script_path: Path, description: str, args: list = None, cwd: str = None) -> bool: """Run a Python script with arguments and real-time output""" @@ -112,16 +233,30 @@ def copy_customizable_scripts(qa_type: str) -> bool: return False -def create_train_config_yaml(cache_path="./", model_name_or_path="Qwen/Qwen2.5-7B-Instruct", qa_type="kbc"): +def create_train_config_yaml( + cache_path="./", + model_name_or_path="Qwen/Qwen2.5-7B-Instruct", + qa_type="kbc", + pdf2model_train_backend: str = PDF2MODEL_TRAIN_BACKEND_BASE, +): """Create train_config.yaml file using built-in LlamaFactory configuration""" + _assert_pdf2model_train_backend_allowed(qa_type, pdf2model_train_backend) + cache_path_obj = Path(cache_path) if not cache_path_obj.is_absolute(): caller_cwd = Path(os.environ.get('PWD', os.getcwd())) cache_path_obj = caller_cwd / cache_path_obj + cwd = Path(os.getcwd()) + # 生成时间戳 timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") model_dir_name = f"pdf2model_cache_{timestamp}" # 改为pdf2model_cache前缀 + dataflex_spec = _get_dataflex_backend_spec(pdf2model_train_backend) + if dataflex_spec: + model_dir_name_dataflex = f"pdf2model_cache_{timestamp}_{dataflex_spec['suffix']}" + else: + model_dir_name_dataflex = model_dir_name cache_dir = cache_path_obj / ".cache" cache_dir.mkdir(parents=True, exist_ok=True) @@ -143,23 +278,68 @@ def create_train_config_yaml(cache_path="./", model_name_or_path="Qwen/Qwen2.5-7 trainer = llamafactory_module.LlamaFactoryTrainer(str(config_file), str(cache_path_obj)) config = trainer.get_default_config() # 只更新必要的动态参数 - config["model_name_or_path"] = model_name_or_path + if model_name_or_path: + config["model_name_or_path"] = model_name_or_path config["output_dir"] = str(cache_path_obj / ".cache" / "saves" / model_dir_name) - config["dataset_dir"] = str(cache_path_obj / ".cache" / "data") + config["dataset_dir"] = str(cache_path_obj / ".cache" / "data") if qa_type == "vqa": dataset_name = "pdf_vqa_dataset" else: dataset_name = "pdf_kbc_dataset" config["dataset"] = dataset_name + dataflex_yaml_path = None + components_yaml_path = cache_dir / "pdf2model_dataflex_components.yaml" + if dataflex_spec: + dataflex_yaml_path = cache_dir / f"pdf2model_dataflex_{dataflex_spec['suffix']}.yaml" + selector_cache_rel = _path_for_yaml( + cwd, cache_dir / f"dataflex_{dataflex_spec['suffix']}_selector_cache" + ) + + if dataflex_spec: + with open(components_yaml_path, "w", encoding="utf-8") as cf: + cf.write( + _build_dataflex_components_yaml_text( + selector_cache_rel, + component_name=dataflex_spec["component_name"], + backend_name=pdf2model_train_backend, + ) + ) + + dataset_eval_name = "pdf_kbc_eval_dataset" + dataflex_config = _build_pdf2model_dataflex_config_dict( + cwd=cwd, + model_name_or_path=config["model_name_or_path"], + dataset_train=dataset_name, + dataset_eval=dataset_eval_name, + dataset_dir=cache_path_obj / ".cache" / "data", + output_dir=cache_path_obj / ".cache" / "saves" / model_dir_name_dataflex, + components_cfg_file=components_yaml_path, + train_config=config, + component_name=dataflex_spec["component_name"], + ) + with open(dataflex_yaml_path, "w", encoding="utf-8") as lf: + yaml.dump(dataflex_config, lf, default_flow_style=False, allow_unicode=True, sort_keys=False, indent=2) + print(f"{dataflex_yaml_path.name} created ({pdf2model_train_backend}): {dataflex_yaml_path}") + print(f"pdf2model_dataflex_components.yaml created: {components_yaml_path}") + pdf2model_state = { "qa": qa_type, + "pdf2model_train_backend": pdf2model_train_backend, "train_config_file_dir": str(config_file), "output_dir": str(cache_path_obj / ".cache" / "saves" / model_dir_name), "dataset": dataset_name, - "dataset_dir": str(cache_path_obj / ".cache" / "data") , - "timestamp": timestamp + "dataset_dir": str(cache_path_obj / ".cache" / "data"), + "timestamp": timestamp, } + if dataflex_spec: + pdf2model_state["dataflex_backend_component"] = dataflex_spec["component_name"] + pdf2model_state["dataflex_train_yaml"] = str(dataflex_yaml_path) + pdf2model_state["dataflex_output_dir"] = str( + cache_path_obj / ".cache" / "saves" / model_dir_name_dataflex + ) + pdf2model_state["dataflex_components_yaml"] = str(components_yaml_path) + with open(cache_dir / "pdf2model_state.json", 'w') as f: json.dump(pdf2model_state, f, indent=2) @@ -179,8 +359,15 @@ def create_train_config_yaml(cache_path="./", model_name_or_path="Qwen/Qwen2.5-7 print(f"Failed to create train_config.yaml: {e}") return None -def generate_dataset_info(cache_path_obj, dataset_name, qa_type): +def generate_dataset_info( + cache_path_obj, + dataset_name, + qa_type, + pdf2model_train_backend: str = PDF2MODEL_TRAIN_BACKEND_BASE, +): """Create dataset_info.json configuration automatically""" + _assert_pdf2model_train_backend_allowed(qa_type, pdf2model_train_backend) + dataset_info_path = cache_path_obj / ".cache" / "data" / "dataset_info.json" dataset_info_path.parent.mkdir(parents=True, exist_ok=True) @@ -210,7 +397,18 @@ def generate_dataset_info(cache_path_obj, dataset_name, qa_type): } } - dataset_info = {dataset_name: config_entry} + # duplicate entry for eval when using DataFlex backend (same file; backend needs eval_dataset registered) + if ( + qa_type == "kbc" + and _get_dataflex_backend_spec(pdf2model_train_backend) is not None + and dataset_name == "pdf_kbc_dataset" + ): + dataset_info = { + dataset_name: copy.deepcopy(config_entry), + "pdf_kbc_eval_dataset": copy.deepcopy(config_entry), + } + else: + dataset_info = {dataset_name: config_entry} try: with open(dataset_info_path, 'w', encoding='utf-8') as f: @@ -281,15 +479,24 @@ def check_required_files(): return True -def cli_pdf2model_init(cache_path: str = "./", model_name: str = "Qwen/Qwen2.5-7B-Instruct", qa_type: str = "kbc") -> bool: +def cli_pdf2model_init( + cache_path: str = "./", + model_name: str = None, + qa_type: str = "kbc", + pdf2model_train_backend: str = PDF2MODEL_TRAIN_BACKEND_BASE, +) -> bool: """ PDF2Model initialization: 0. Copy only customizable scripts to current directory 1. Create train_config.yaml in .cache directory """ + _assert_pdf2model_train_backend_allowed(qa_type, pdf2model_train_backend) + + effective_model = model_name if model_name else "Qwen/Qwen2.5-7B-Instruct" print("Starting PDF2Model initialization...") print(f"Cache directory: {cache_path}") - print(f"Model: {model_name}") + print(f"Model: {effective_model}") + print(f"Train backend: {pdf2model_train_backend}") print(f"Output directory: pdf2model_cache_") # 更新输出目录显示 print("-" * 60) @@ -303,7 +510,12 @@ def cli_pdf2model_init(cache_path: str = "./", model_name: str = "Qwen/Qwen2.5-7 # Step 1: Create training configuration print("Step 1: Creating training configuration...") - config_file = create_train_config_yaml(cache_path, model_name, qa_type) + config_file = create_train_config_yaml( + cache_path, + effective_model, + qa_type, + pdf2model_train_backend=pdf2model_train_backend, + ) if config_file: print("PDF2Model initialization completed!") @@ -318,29 +530,21 @@ def cli_pdf2model_init(cache_path: str = "./", model_name: str = "Qwen/Qwen2.5-7 def get_latest_model_dir(cache_path_obj): - """获取最新的模型目录(基于时间戳)""" + """Latest adapter directory under saves (by mtime; supports pdf2model_cache_* suffixes e.g. _less).""" saves_dir = cache_path_obj / ".cache" / "saves" if not saves_dir.exists(): return None - # 查找所有 pdf2model_cache_ 开头的目录 - model_dirs = [] + candidates = [] for dir_path in saves_dir.iterdir(): - if dir_path.is_dir() and dir_path.name.startswith('pdf2model_cache_'): - # 检查是否包含正确的时间戳格式 (YYYYMMDD_HHMMSS) - timestamp_part = dir_path.name.replace('pdf2model_cache_', '') - if len(timestamp_part) == 15 and timestamp_part[8] == '_': - date_part = timestamp_part[:8] - time_part = timestamp_part[9:] - if date_part.isdigit() and time_part.isdigit() and len(time_part) == 6: - model_dirs.append(dir_path) - - if not model_dirs: + if dir_path.is_dir() and dir_path.name.startswith("pdf2model_cache_"): + candidates.append((dir_path.stat().st_mtime, dir_path)) + + if not candidates: return None - # 按名称排序(时间戳会自然排序) - model_dirs.sort(key=lambda x: x.name, reverse=True) - return model_dirs[0] + candidates.sort(key=lambda x: x[0], reverse=True) + return candidates[0][1] def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: str = "./") -> bool: @@ -355,23 +559,14 @@ def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: s if not cache_path_obj.is_absolute(): cache_path_obj = current_dir / cache_path_obj - config_path_obj = Path(lf_yaml) - if not config_path_obj.is_absolute(): - config_path_obj = current_dir / config_path_obj - - if not verify_environment(): - return False - - if not check_required_files(): - return False - - if not config_path_obj.exists(): - print(f"Training config file not found: {config_path_obj}") + train_cfg_path = cache_path_obj / ".cache" / "train_config.yaml" + if not train_cfg_path.exists(): + print(f"Training config file not found: {train_cfg_path}") print(f"{Style.BRIGHT}Run 'dataflow pdf2model init' first") return False try: - with open(config_path_obj, 'r', encoding='utf-8') as f: + with open(train_cfg_path, "r", encoding="utf-8") as f: train_config = yaml.safe_load(f) or {} except Exception as e: print(f"[ERROR] Could not read train_config.yaml: {e}") @@ -379,15 +574,43 @@ def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: s pdf2model_state_path = cache_path_obj / ".cache" / "pdf2model_state.json" try: - with open(pdf2model_state_path, 'r', encoding='utf-8') as f: + with open(pdf2model_state_path, "r", encoding="utf-8") as f: pdf2model_state = json.load(f) print(f"Loaded pipeline state from {pdf2model_state_path}") except Exception as e: - print(f"Warning: State file not found at {pdf2model_state_path}, using defaults.") - pdf2model_state = {} + print(f"Warning: State file not found at {pdf2model_state_path}, using defaults.") + pdf2model_state = {} qa_type = pdf2model_state.get("qa") - # pdf2model_model_path = pdf2model_state.get("model") + pdf2model_train_backend = pdf2model_state.get( + "pdf2model_train_backend", PDF2MODEL_TRAIN_BACKEND_BASE + ) + + try: + _assert_pdf2model_train_backend_allowed(qa_type or "kbc", pdf2model_train_backend) + except AssertionError as err: + print(f"{Fore.RED}Invalid pdf2model state: {err}{Style.RESET_ALL}") + return False + + if not verify_environment(): + return False + + if _get_dataflex_backend_spec(pdf2model_train_backend): + from dataflow.cli_funcs.pdf2model_pipeline.dataflex_pdf2model_launcher import verify_dataflex_available + + if not verify_dataflex_available(): + return False + + if not check_required_files(): + return False + + lf_path = Path(lf_yaml) + if not lf_path.is_absolute(): + lf_path = current_dir / lf_path + if pdf2model_train_backend == PDF2MODEL_TRAIN_BACKEND_BASE and not lf_path.exists(): + print(f"LlamaFactory training config not found: {lf_path}") + print(f"{Style.BRIGHT}Run 'dataflow pdf2model init' first or pass a valid --lf-yaml") + return False print("-" * 60) @@ -409,17 +632,22 @@ def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: s print(f"\n{Fore.BLUE}Step 2.5: Creating dataset_info.json{Style.RESET_ALL}") # 获取数据集名称 - dataset_name = train_config.get('dataset') + dataset_name = train_config.get("dataset") if isinstance(dataset_name, list): dataset_name = dataset_name[0] # 如果是列表,取第一个 - + if not dataset_name: print("Warning: No dataset name found in train_config.yaml, using default 'kb_qa'") - dataset_name = 'kb_qa' - + dataset_name = "kb_qa" + print(f"Dataset name from config: {dataset_name}") - if not generate_dataset_info(cache_path_obj, dataset_name, qa_type): + if not generate_dataset_info( + cache_path_obj, + dataset_name, + qa_type, + pdf2model_train_backend=pdf2model_train_backend, + ): return False print(f"{Fore.GREEN}✅ Step 2.5: Creating dataset_info.json completed{Style.RESET_ALL}") @@ -435,18 +663,33 @@ def cli_pdf2model_train(lf_yaml: str = ".cache/train_config.yaml", cache_path: s return False # Step 4: Training - script4_path = get_dataflow_script_path("llama_factory_trainer.py") - args4 = ["--config", str(config_path_obj), "--cache", cache_path] - if not run_script_with_args(script4_path, "Step 4: Training", args4, cwd=str(current_dir)): - return False - - # Show completion info - try: - with open(config_path_obj, 'r', encoding='utf-8') as f: - config = yaml.safe_load(f) - actual_output_dir = config.get('output_dir', 'unknown') - except Exception: - actual_output_dir = 'unknown' + if pdf2model_train_backend == PDF2MODEL_TRAIN_BACKEND_BASE: + script4_path = get_dataflow_script_path("llama_factory_trainer.py") + args4 = ["--config", str(lf_path), "--cache", cache_path] + if not run_script_with_args(script4_path, "Step 4: Training (LlamaFactory)", args4, cwd=str(current_dir)): + return False + try: + with open(lf_path, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + actual_output_dir = config.get("output_dir", "unknown") + except Exception: + actual_output_dir = "unknown" + else: + from dataflow.cli_funcs.pdf2model_pipeline.dataflex_pdf2model_launcher import run_dataflex_train + + dfx_yaml = pdf2model_state.get("dataflex_train_yaml") + if not dfx_yaml: + print(f"{Fore.RED}Missing dataflex_train_yaml in pdf2model_state.json; re-run pdf2model init.{Style.RESET_ALL}") + return False + dfx_path = Path(dfx_yaml) + if not dfx_path.is_absolute(): + dfx_path = current_dir / dfx_path + if not dfx_path.exists(): + print(f"{Fore.RED}DataFlex train yaml not found: {dfx_path}{Style.RESET_ALL}") + return False + if not run_dataflex_train(dfx_path, current_dir): + return False + actual_output_dir = pdf2model_state.get("dataflex_output_dir", "unknown") print("Training completed successfully!") print(f"Model saved to: {actual_output_dir}") diff --git a/dataflow/cli_funcs/pdf2model_pipeline/dataflex_pdf2model_launcher.py b/dataflow/cli_funcs/pdf2model_pipeline/dataflex_pdf2model_launcher.py new file mode 100644 index 00000000..cf730d95 --- /dev/null +++ b/dataflow/cli_funcs/pdf2model_pipeline/dataflex_pdf2model_launcher.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +"""Start DataFlex-backed training from the pdf2model CLI. +""" + +from __future__ import annotations + +import os +import random +import shutil +import subprocess +import sys +from pathlib import Path + + +def verify_dataflex_available() -> bool: + """Return True if ``dataflex.launcher`` imports in this Python environment.""" + try: + subprocess.run( + [sys.executable, "-c", "import dataflex.launcher"], + check=True, + capture_output=True, + text=True, + ) + return True + except subprocess.CalledProcessError: + print( + "❌ DataFlex not importable in this Python. Install with: pip install -e /path/to/DataFlex" + ) + return False + + +def _resolve_nproc_per_node() -> str: + explicit = os.environ.get("NPROC_PER_NODE") + if explicit: + return explicit + try: + import torch + + return str(max(torch.cuda.device_count(), 1)) + except Exception: + return "1" + + +def _want_torchrun() -> bool: + v = os.environ.get("FORCE_TORCHRUN", "1") + if str(v).lower() in ("1", "true", "yes"): + return True + try: + import torch + + return torch.cuda.device_count() > 1 + except Exception: + return False + + +def _torchrun_argv() -> list[str]: + exe = shutil.which("torchrun") + if exe: + return [exe] + return [sys.executable, "-m", "torch.distributed.run"] + + +def _launcher_cli_overrides() -> list[str]: + """OmegaConf-style args after yaml; merged in ``dataflex.launcher.read_args``.""" + allow_pin = os.environ.get("PDF2MODEL_DATAFLEX_ALLOW_PIN_MEMORY", "").lower() in ( + "1", + "true", + "yes", + ) or os.environ.get("DATAFLOW_LESS_ALLOW_PIN_MEMORY", "").lower() in ( + "1", + "true", + "yes", + ) + if allow_pin: + return [] + # VL batches (e.g. Qwen2.5-VL) can produce tensors with overlapping storage; pin_memory() crashes. + return ["dataloader_pin_memory=false"] + + +def _build_train_command(yaml_path: Path) -> list[str]: + y = str(yaml_path) + tail = [y, *_launcher_cli_overrides()] + if not _want_torchrun(): + return [sys.executable, "-m", "dataflex.launcher", *tail] + + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999))) + nproc = _resolve_nproc_per_node() + return _torchrun_argv() + [ + f"--nnodes={os.environ.get('NNODES', '1')}", + f"--node_rank={os.environ.get('NODE_RANK', '0')}", + f"--nproc_per_node={nproc}", + f"--master_addr={master_addr}", + f"--master_port={master_port}", + "--module", + "dataflex.launcher", + *tail, + ] + + +def run_dataflex_train(yaml_path: Path, cwd: Path) -> bool: + """ + Run training via DataFlex (``dataflex.launcher``) directly. + + Default env: ``FORCE_TORCHRUN=1``, ``DISABLE_VERSION_CHECK=1``. + ``cwd`` is the pdf2model project root (paths inside yaml are relative to it). + """ + if not yaml_path.is_file(): + print(f"❌ DataFlex train yaml not found: {yaml_path}") + return False + if not verify_dataflex_available(): + return False + + env = os.environ.copy() + env.setdefault("FORCE_TORCHRUN", "1") + env.setdefault("DISABLE_VERSION_CHECK", "1") + + cmd = _build_train_command(yaml_path) + print(f"Running: {' '.join(cmd)}") + print(f"Working directory: {cwd}") + + try: + subprocess.run(cmd, cwd=str(cwd), env=env, check=True, stdout=sys.stdout, stderr=sys.stderr, text=True) + print("✅ DataFlex training completed") + return True + except subprocess.CalledProcessError: + print("❌ DataFlex training failed") + return False + except KeyboardInterrupt: + print("\nTraining interrupted by user") + return False diff --git a/pyproject.toml b/pyproject.toml index d1f125e3..e4033b1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,20 @@ pdf2model =[ "mineru-vl-utils", "flash-mineru" ] +# 为pdf2model接入dataflex所需的依赖配置 +pdf2model-dataflex =[ + "gradio>=4.38.0,<=5.13.0", + "llamafactory[torch,metrics]>=0.9.4", + "vllm>=0.8.0,<0.9.0", + "transformers>=4.48.1,<=4.52.0", + "tokenizers>=0.19.0,<=0.21.1", + "datasets>=2.16.0,<=3.2.0", + "fsspec<=2024.9.0", + "numpy>=1.24,<2.0.0", + "mineru[pipeline]", + "mineru-vl-utils", + "flash-mineru" +] eval =[ "vllm>=0.7.0,<0.9.2" ]