Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions dataflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,22 +330,47 @@ def eval_local_cmd():
def pdf2model_init(
    cache: Path = typer.Option(Path("."), help="Cache dir"),
    qa: str = typer.Option("kbc", help="Which pipeline to init (vqa or kbc)"),
    model: Optional[str] = typer.Option(None, help="Base model name or path"),
    train_backend: str = typer.Option(
        "base",
        "--train-backend",
        help="With --qa kbc: 'base' (LlamaFactory) or a registered dataflex-* backend (see cli_pdf.DATAFLEX_BACKEND_SPECS). vqa only allows 'base'.",
    ),
):
    # Validate the CLI options, then delegate to
    # dataflow.cli_funcs.cli_pdf.cli_pdf2model_init.
    #
    # Options:
    #   cache          directory forwarded (as a string) as ``cache_path``.
    #   qa             pipeline selector; only "vqa" and "kbc" are accepted.
    #   model          optional base model name or path, forwarded as
    #                  ``model_name``.
    #   train_backend  forwarded as ``pdf2model_train_backend``. "vqa" accepts
    #                  only "base"; "kbc" accepts "base" plus any key of
    #                  DATAFLEX_BACKEND_SPECS.
    #
    # Every failure path prints a red message through ``_echo`` and exits with
    # status 1 via ``typer.Exit``.
    #
    # NOTE: documented with comments rather than a docstring on purpose --
    # Typer surfaces a command's docstring as its ``--help`` text, and the
    # decorator (outside this block) may or may not already set ``help=``.
    if qa not in ("vqa", "kbc"):
        _echo(f"Invalid qa type: {qa}. Must be 'vqa' or 'kbc'.", "red")
        raise typer.Exit(code=1)

    if qa == "vqa":
        # The vqa pipeline has no alternative training backends.
        if train_backend != "base":
            _echo("vqa only supports --train-backend base.", "red")
            raise typer.Exit(code=1)
    else:
        # Imported lazily so the backend registry is only loaded when a kbc
        # backend actually has to be validated.
        from dataflow.cli_funcs.cli_pdf import DATAFLEX_BACKEND_SPECS  # type: ignore

        allowed_kbc = {"base", *DATAFLEX_BACKEND_SPECS.keys()}
        if train_backend not in allowed_kbc:
            # Sorted so the error message is deterministic across runs.
            supported = ", ".join(sorted(allowed_kbc))
            _echo(
                f"Invalid --train-backend={train_backend!r} for --qa kbc. Supported: {supported}.",
                "red",
            )
            raise typer.Exit(code=1)

    try:
        from dataflow.cli_funcs.cli_pdf import cli_pdf2model_init  # type: ignore

        # Single call that carries the selected backend; the earlier call
        # without ``pdf2model_train_backend`` is superseded by this one.
        cli_pdf2model_init(
            cache_path=str(cache),
            qa_type=qa,
            model_name=model,
            pdf2model_train_backend=train_backend,
        )
    except Exception as e:
        # Top-level CLI boundary: report the failure and exit non-zero
        # instead of showing a raw traceback.
        _echo(f"pdf2model init error: {e}", "red")
        raise typer.Exit(code=1)


@pdf_app.command("train")
def pdf2model_train(cache: Path = typer.Option(Path("."), help="Cache dir"),
lf_yaml: Optional[Path] = typer.Option(None, help="LlamaFactory yaml")):
lf_yaml: Optional[Path] = typer.Option(None, help="LlamaFactory yaml (base backend only)")):

try:
from dataflow.cli_funcs.cli_pdf import cli_pdf2model_train # type: ignore
Expand Down
Loading
Loading