diff --git a/.agents/skills/biomodals-app-development/SKILL.md b/.agents/skills/biomodals-app-development/SKILL.md index 1cb8e60..9205c08 100644 --- a/.agents/skills/biomodals-app-development/SKILL.md +++ b/.agents/skills/biomodals-app-development/SKILL.md @@ -22,7 +22,7 @@ For new apps, ask the user which data-flow class applies before choosing archite ## Implementation Rules -Keep app code compatible with `biomodals help` and app discovery: +Keep app code compatible with `biomodals app help` and app discovery: - Name files `_app.py` under `src/biomodals/app//`. - Use a user-facing module docstring with upstream links, prerequisites, and output behavior. @@ -30,8 +30,18 @@ Keep app code compatible with `biomodals help` and app discovery: - Use module-level `CONF = AppConfig(...)` for new apps; pin `repo_commit_hash` or `version`. - Let `gpu` and `timeout` be overridden from `os.environ`. - Build runtime images through `patch_image_for_helper(...)`. -- Prefer helpers from `biomodals.helper` and `biomodals.helper.shell` instead of open-coded shell, archive, copy, download, hashing, or warmup logic. -- Name local entrypoints `submit__task(...)` and use Google-style `Args:` docstrings so `biomodals help ` renders flags. +- Before adding app or workflow helpers, check `biomodals.helper` first. + Reuse existing helper APIs for local output paths, shell, archive, copy, + download, hashing, warmup, and serialization behavior; only define local + helpers when the behavior is app-specific and no shared helper fits. +- Prefer `CONF.mounts(...)` for model and output volumes. Import shared volumes + from `biomodals.helper.constant` only when a function needs a nonstandard + mountpoint, a shared database/cache volume, or an explicit `commit()`. + When using `Volume.with_mount_options(...)` directly, combine read-only and + subpath options in one call. +- Avoid extracting trivial two- or three-line helpers that are used only once or + twice. Inline them and add a short comment when the intent is not obvious. +- Name local entrypoints `submit__task(...)` and use Google-style `Args:` docstrings so `biomodals app help ` renders flags. - Use `🧬` for local entrypoint status messages and `💊` for remote Modal-container status messages. - Keep Modal function return values primitive when practical: `int`, `str`, `float`, `bool`, `bytes`, `list`, `dict`, or `None`. Return complex objects @@ -47,14 +57,14 @@ When reviewing or finishing an app change, check: - Discovery: path, filename, app name, and local entrypoint name match CLI expectations. - Reproducibility: upstream version or commit is pinned. - Runtime boundaries: dependencies used only inside Modal images stay lazily imported. -- Volumes: model volumes are read-only for inference unless the tool writes caches there; writable volumes are committed after writes. -- Data flow: quick jobs return `.tar.zst` bytes via `package_outputs(...)`; persistent, resumable, or batch jobs use `CONF.get_out_volume()` or shared volumes. +- Volumes: model/cache mounts use app-specific subdirectories when practical; inference mounts are read-only unless the tool writes caches there; writable volumes are committed after writes; mounted volume paths are logged or returned as `VolumePath` when they cross app/workflow boundaries. +- Data flow: quick jobs return `.tar.zst` bytes via `package_outputs(...)`; persistent, resumable, or batch jobs use `CONF.output_volume`, `CONF.mounts(output_volume=True)`, or shared volumes. - Modal return payloads: prefer primitive, `cloudpickle`-serializable values; avoid returning `Path` objects directly or nested inside tuples, lists, dicts, or dataclasses. - Output safety: local output directories are created, existing tarballs are not overwritten accidentally, and final paths or Modal volume locations are printed. - CLI docs: local entrypoint docstrings use exact Google-style `Args:` formatting with continuation indentation. -- Verification: run `prek run --files ` when practical, plus `uv run biomodals list` and `uv run biomodals help ` for CLI or discovery changes. +- Verification: run `prek run --files ` when practical, plus `uv run biomodals app list` and `uv run biomodals app help ` for CLI or discovery changes. ## Reference diff --git a/.agents/skills/biomodals-app-development/references/app-development.md b/.agents/skills/biomodals-app-development/references/app-development.md index 410fa89..e543ff4 100644 --- a/.agents/skills/biomodals-app-development/references/app-development.md +++ b/.agents/skills/biomodals-app-development/references/app-development.md @@ -6,9 +6,9 @@ This reference is the maintained app-development standard for files under `src/b Biomodals apps are self-contained Modal applications wrapping bioinformatics tools. They live under `src/biomodals/app//`. -- Name app files `_app.py`; the `_app.py` suffix is how `cli.py` discovers apps with `APP_HOME.glob("*/*_app.py")`. +- Name app files `_app.py`; the `_app.py` suffix is how `biomodals.helper.catalog` discovers apps by scanning `src/biomodals/app//*_app.py`. - Place apps in an appropriate category such as `fold/`, `design/`, `score/`, or `bioinfo/`. -- The CLI app name is the filename stem with `_app` stripped, for example `protenix_app.py` becomes `protenix`. +- The catalog entry name is the filename stem with `_app` stripped, for example `protenix_app.py` becomes `protenix`. - Use section banners to keep modules scan-friendly: - module docstring - imports @@ -20,7 +20,7 @@ Biomodals apps are self-contained Modal applications wrapping bioinformatics too ## Module Docstring -The module docstring is rendered verbatim by `biomodals help ` as Markdown. Keep it user-facing and include the upstream source URL, important prerequisites, caveats, and output behavior. +The module docstring is rendered verbatim by `biomodals app help ` as Markdown. Keep it user-facing and include the upstream source URL, important prerequisites, caveats, and output behavior. Typical shape: @@ -48,6 +48,12 @@ Use optional configuration tables only when the local entrypoint docstring is in New apps should define module-level `CONF = AppConfig(...)`. +The pure Pydantic schema lives in `biomodals.schema.app.AppConfig`. Import the +Modal-compatible wrapper from `biomodals.app.config` when you need volume or +image helpers; otherwise import the schema directly. The wrapper adds +`output_volume`, `output_volume_name`, and `mounts(...)` helpers while keeping +the same schema fields and validators. + ```python from biomodals.app.config import AppConfig @@ -62,6 +68,7 @@ CONF = AppConfig( cuda_version="cu128", gpu=os.environ.get("GPU", "L40S"), timeout=int(os.environ.get("TIMEOUT", "3600")), + depends_on_apps=("gromacs",), # only for workflows that compose other apps ) ``` @@ -70,9 +77,11 @@ Rules: - Pin either `repo_commit_hash` or `version`, or both. - Let `gpu` and `timeout` be overridden by environment variables with sensible defaults. - Use `CONF.default_env` when setting image environment variables. It provides standard UV, Hugging Face, Torch, and torch backend environment. -- Use `CONF.model_dir`, `CONF.git_clone_dir`, `CONF.model_volume_mountpoint`, and related fields instead of hardcoded paths. +- Use `CONF.git_clone_dir`, `CONF.model_volume_mountpoint`, + `CONF.model_volume_subdir`, and related fields instead of hardcoded paths. +- Set `depends_on_apps` only for workflow apps that compose other Biomodals apps; standalone apps should leave it empty. -Use an `AppInfo` dataclass only when grouping several related app constants improves readability. For a few simple constants, module-level constants such as `OUT_VOLUME` or `OUTPUTS_DIR` are acceptable. +Use an `AppInfo` dataclass only when grouping several related app constants improves readability. Prefer `CONF.output_volume`, `CONF.output_volume_name`, and `CONF.mounts(...)` over module-level output-volume aliases in new code. ## Image Construction @@ -97,11 +106,41 @@ app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) ## Volumes -- Import shared volumes from `biomodals.app.constant`, such as `MODEL_VOLUME` or `MSA_CACHE_VOLUME`. -- Mount model weights read-only for inference when the function only reads model artifacts. -- Use `CONF.model_volume_mountpoint` for model volume mount paths. -- Commit volume changes explicitly after writes with `VOLUME.commit()`. -- Use `CONF.get_out_volume()` for app-specific persistent outputs. +- Import shared volumes from `biomodals.helper.constant`, such as `MODEL_VOLUME` or `MSA_CACHE_VOLUME`. +- Prefer `CONF.mounts(...)` for standard app model and output mounts. Use raw + `Volume.with_mount_options(...)` only for nonstandard mountpoints, shared + database/cache volumes, or code that must explicitly access a volume object. +- Mount only the subdirectory a function needs when a shared volume contains + app-specific data. For model weights under `MODEL_VOLUME`, the usual pattern + is `CONF.mounts(model_volume=True)`, which mounts `CONF.model_volume_subdir` + at `CONF.model_volume_mountpoint`. +- Mount model weights read-only for inference when the function only reads + model artifacts. If both read-only and subpath behavior are needed, use + `CONF.mounts(model_volume=True)`, or + `MODEL_VOLUME.with_mount_options(read_only=True, sub_path=...)` for custom + mountpoints. + Do not chain `MODEL_VOLUME.read_only().with_mount_options(...)`; Modal rejects + adding mount options after a read-only wrapper has already been created. +- Use `CONF.model_volume_mountpoint` for app-specific model directories. Use + `CONF.mounts(model_volume=True, is_huggingface=True)` only when the tool + stores Hugging Face-managed artifacts under `CONF.default_env` paths such as + `HF_HOME`. +- When upstream code expects a hardcoded cache path, mount the app-specific + shared-volume subdirectory at that path rather than changing unrelated app + logic. PaddleOCR, AbNatiV, and AntiFold are examples of this pattern. +- Shared cache volumes such as `MSA_CACHE_VOLUME` should also use subpath mounts + when an app only needs its own namespace. Shared database volumes that expose a + complete database root can stay mounted whole. +- For download/setup functions that populate model volumes, use + `CONF.mounts(model_volume=True, model_ro=False)` and commit the backing shared + volume after writes. +- Commit output or cache volume changes explicitly after writes with + `VOLUME.commit()`. +- Use `CONF.output_volume` and `CONF.mounts(output_volume=True)` for + app-specific persistent outputs; output volumes are normally mounted whole. +- Use `volume_path_from_mount_path(...)` when printing or returning remote + volume paths so logs show a validated `VolumePath` instead of an ambiguous + absolute container path. ## Remote Functions @@ -117,6 +156,9 @@ Always specify a timeout with `CONF.timeout` or `MAX_TIMEOUT`. Add resource hint `float`, `bool`, `bytes`, `list`, `dict`, or `None`. Return complex objects only when they provide much more benefit than a primitive representation, and the returned type must be serializable by `cloudpickle`. +- Workflow-compatible app functions are the main exception to the primitive + preference: return `AppRunResult` from `biomodals.schema` so workflows can + materialize `AppOutput` artifacts consistently. - Keep `Path` objects internal to the local process or Modal container. Return file paths, volume paths, and relative output paths as `str(path)`, including paths nested inside tuples, lists, dicts, or dataclasses. Convert back with @@ -131,7 +173,7 @@ Resource pattern: cpu=(0.125, 16.125), memory=(1024, 65536), timeout=MAX_TIMEOUT, - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME.read_only()}, + volumes=CONF.mounts(model_volume=True), ) ``` @@ -152,6 +194,11 @@ Prefer existing helpers instead of reimplementing common behavior: - `hash_string(s)` from `biomodals.helper` for cache keys. - `patch_image_for_helper(image)` from `biomodals.helper` for Modal images. +Avoid extracting trivial two- or three-line helpers that are used once or twice. +Inline those operations with a short comment when that reads better. Add a +local helper only when the behavior is app-specific, repeated enough to clarify +the module, or absent from `biomodals.helper`. + ## Local Entrypoint The `@app.local_entrypoint()` function is the user-facing orchestration layer on the local machine. @@ -164,7 +211,7 @@ The `@app.local_entrypoint()` function is the user-facing orchestration layer on - Write returned tarball bytes locally. - Print final local path or Modal volume location. -Docstring rules for `biomodals help`: +Docstring rules for `biomodals app help`: - Use Google-style docstrings with an `Args:` section. - Put `Args:` on its own line. @@ -187,6 +234,9 @@ Choose architecture by job type: - Short-lived inference usually sends local input bytes to remote functions and returns tarball bytes directly. - Long-running apps should cache intermediate and final results in Modal volumes. - Parallel or interruptible runs should use queues, locks, stable run IDs, and resumable runners where possible. +- Workflow-compatible app functions should reuse existing remote app behavior + where practical, preserve standalone local entrypoints unchanged, and return + `AppRunResult` with `VolumePath` storage for durable outputs. Before choosing data flow for a new app, ask whether it is short-lived inference, long-running/cached, or parallel/resumable unless already clear from the request. @@ -203,10 +253,10 @@ Older apps can use raw constants such as `GPU`, `TIMEOUT`, and `APP_NAME`. When ## Examples And Verification -- When app development changes invocation or adds a new app, add or update an example bash script under `examples/app/` using `biomodals run`. +- When app development changes invocation or adds a new app, add or update an example bash script under `examples/app/` using `biomodals app run`. - Use small example inputs under `examples/data/` only when existing data is insufficient. - For Modal functions, verify returned payloads are primitive or otherwise intentionally complex and `cloudpickle`-serializable; convert returned paths to strings. - After edits, run `prek run --files ` when practical. -- For CLI or app discovery changes, smoke test `uv run biomodals list` and `uv run biomodals help ` when practical. +- For CLI or app discovery changes, smoke test `uv run biomodals app list` and `uv run biomodals app help ` when practical. diff --git a/.agents/skills/biomodals-workflow-development/SKILL.md b/.agents/skills/biomodals-workflow-development/SKILL.md new file mode 100644 index 0000000..b9ae8d2 --- /dev/null +++ b/.agents/skills/biomodals-workflow-development/SKILL.md @@ -0,0 +1,54 @@ +--- +name: biomodals-workflow-development +description: Use when creating, editing, or reviewing Biomodals workflow code under src/biomodals/workflow/, shared workflow schemas under src/biomodals/schema/, workflow-compatible app functions, or workflow CLI/tests, including ShortMD-style DAG construction, orchestrator composition, app dependency inclusion, workflow artifacts, and Modal volume handling. +--- + +# Biomodals Workflow Development + +Use this skill for Biomodals workflow scripts, the reusable workflow runtime, +workflow schemas, and workflow-compatible app integration points. + +## Core Workflow + +Before making non-trivial workflow changes, read +`references/workflow-development.md` for the maintained standards. + +Use `src/biomodals/workflow/shortmd_workflow.py` as the primary end-to-end +example for app-composed workflows. Ignore +`src/biomodals/workflow/ppiflow_workflow.py` as a reference pattern for now; it +is expected to be refactored. + +## Working Rules + +- Keep `biomodals.schema` pure Pydantic and free of Modal imports. +- Compose workflow apps with `from biomodals.workflow.core import orchestrator` + and `modal.App(...).include(orchestrator.app)`. +- Declare app dependencies on `AppConfig.depends_on_apps`, mirror them into + `CONF.tags["depends_on"]` for Modal UI metadata, and compose them with + `include_dependency_apps(app, CONF.depends_on_apps)`. +- Prefer included-app Modal handles over deployed-app lookup strings. Do not add + `modal.Function.from_name(...)` to new workflow code when the dependency app + can be included. +- Prefer `AppBackedNode` for nodes that primarily call app functions. + Add `WorkflowNativeNode` only for adapters, summaries, selectors, and + workflow-specific file-management glue. +- Store hydrated Modal functions/classes in a small `*ModalNamespace` dataclass + typed as `modal.Function` or `modal.Cls`, and exclude that namespace from DAG + hashing with `repr=False`, `compare=False`, and `metadata={"dag_hash": False}`. +- Define workflow-specific remote file-management functions as top-level + `@app.function`s in the workflow module and put their hydrated handles in the + workflow's `*ModalNamespace`. Do not make ordinary node methods Modal + functions. +- Import app-owned volume handles, volume names, and mountpoints from source app + modules. Avoid duplicating volume strings in workflow scripts. +- Use `volume_path_from_mount_path(...)` to convert mounted app paths into + `VolumePath` workflow storage references. +- Keep the core runtime slim. Add public orchestrator/runtime API only for clear + missing capabilities, not one-off workflow conveniences. + +## Verification + +For workflow changes, run focused pytest coverage first, then `prek run --files +` when practical. For CLI or discovery changes, also smoke test +`uv run biomodals workflow list` and the affected `biomodals workflow help/run` +path. diff --git a/.agents/skills/biomodals-workflow-development/agents/openai.yaml b/.agents/skills/biomodals-workflow-development/agents/openai.yaml new file mode 100644 index 0000000..53a5830 --- /dev/null +++ b/.agents/skills/biomodals-workflow-development/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: "Biomodals Workflow Development" + short_description: "Build ShortMD-style Biomodals workflows" + default_prompt: "Use $biomodals-workflow-development to update a Biomodals workflow while following the ShortMD reference pattern." diff --git a/.agents/skills/biomodals-workflow-development/references/workflow-development.md b/.agents/skills/biomodals-workflow-development/references/workflow-development.md new file mode 100644 index 0000000..58a767d --- /dev/null +++ b/.agents/skills/biomodals-workflow-development/references/workflow-development.md @@ -0,0 +1,429 @@ +# Biomodals Workflow Development + +Use this guide when creating or changing files under +`src/biomodals/workflow/` or shared workflow contracts under +`src/biomodals/schema/`. + +Use `src/biomodals/workflow/shortmd_workflow.py` as the primary end-to-end +workflow example. Ignore `src/biomodals/workflow/ppiflow_workflow.py` as a +reference pattern for now; it is expected to be refactored. + +## Vocabulary + +- **App**: a deployed Modal app that owns tool runtime and app functions. +- **App Function**: a callable remote Modal function exposed by an app. +- **Local Entrypoint**: a CLI-only `@app.local_entrypoint`. +- **Workflow-Compatible App Function**: a remote app function returning + `AppRunResult`. +- **Workflow Node**: one semantic DAG vertex. +- **App-Backed Node**: a workflow node that calls app functions. +- **Workflow-Native Node**: a workflow node implemented in workflow code. +- **Workflow Runtime**: validates and schedules workflow nodes. +- **Workflow Orchestrator**: the Modal-hosted coordinator class hosting the runtime for one workflow run. +- **Workflow Artifact**: durable data passed between workflow nodes. +- **Artifact Selector**: a named reference to upstream artifacts. + +Avoid the terms `app node`, `runner node`, `engine`, and `workflow +entrypoint`; they are ambiguous in this codebase. + +## ShortMD Reference Pattern + +ShortMD is the current reference for executable workflow apps. Its data flow is: + +1. The local entrypoint discovers local `.pdb` files, sanitizes the workflow + `run_id`, reads PDB bytes, builds a static `Workflow`, and submits that + object to the included `WorkflowOrchestrator`. +2. The workflow app composes the shared orchestrator and the GROMACS app with + `modal.App(...).include(orchestrator.app)` plus + `include_dependency_apps(app, CONF.depends_on_apps)`. +3. `ShortMDModalNamespace` carries the hydrated GROMACS functions and + workflow-native remote functions across the orchestrator boundary. +4. `ShortMDPrepNode` prepares one input PDB once through the GROMACS app. +5. `ShortMDCloneNode` clones prepared production inputs into per-replicate + directories. This file management is workflow-native because the standalone + GROMACS app does not need it. +6. `ShortMDReplicateNode` runs each production replicate through the GROMACS app + and collects trajectory stats. +7. `ShortMDSummaryNode` emits a Markdown report from completed production + artifacts. + +Follow this structure for new app-composed workflows: stage local inputs before +DAG construction, build a static fan-out DAG, keep app-specific runtime work in +included app functions, keep workflow-only adapters in the workflow module, and +return durable artifacts as `AppRunResult` outputs. + +## Schema Boundaries + +Shared contracts live in `biomodals.schema`. + +Schema modules must not import `modal`, `biomodals.app`, or +`biomodals.workflow`. They should contain Pydantic models and primitive fields +only. The shared `AppConfig` Pydantic schema lives in `biomodals.schema.app`. +Modal-specific helpers that construct volumes, images, or apps must stay in +`biomodals.app` or `biomodals.helper`, with compatibility imports allowed during +the transition from `biomodals.app.config`. + +Workflow-compatible app functions return `AppRunResult`. The workflow runtime +materializes each `AppOutput` into one or more `WorkflowArtifact` manifests. +Inline byte outputs are for UTF-8 text bytes only. They must be written into the +workflow run volume before they cross a node boundary. Binary outputs, archives, +and other non-text bytes must be written to deterministic volume paths and +returned as `VolumePath` storage. + +`AppRunResult.logs` are durable workflow artifacts too. The runtime writes log +outputs under `nodes//attempts//logs/` and records +artifact manifests for them so failed or partial attempts retain diagnostic +state. + +Volume path outputs may either be referenced in place or copied into the +workflow run volume when the source volume is mounted locally. Reference mode is +the default because many app outputs are already durable in their owning app +volume. Copy mode is for workflows that need a self-contained run directory. + +The first workflow runtime is Python-first. Pass a `Workflow` object across the +orchestrator boundary; serialized workflow dictionaries are intentionally +deferred until the node and app-function contracts stabilize. + +## Node Execution Policy + +Every workflow node checks durable SQLite run state before execution and skips work +when completed artifact manifests already exist. + +Incomplete nodes use one of two policies: + +- `RERUN`: discard incomplete attempt state and recompute. +- `RESUME`: use a durable node cache to resume or skip completed subwork. + +Long-running nodes must be idempotent against deterministic run, node, input, +and attempt identifiers. Store resumable state in volumes, not container-local +scratch paths. + +`AppRunStatus.PARTIAL` is terminal but not successful in the first runtime. The +runtime records the node as failed, records the run as failed, preserves logs, +and does not unblock downstream nodes. + +Forced workflow runs replace the existing run directory before creating a fresh +ledger. Use force only when discarding previous artifacts, node caches, and +attempt records is intentional. + +## Node Placement + +Use `ORCHESTRATOR` placement for lightweight workflow-native logic such as +filtering, ranking, reporting, and small manifest transforms. + +Use `REMOTE` placement for long-running work, app-backed work, and work that +benefits from failure isolation. + +The runtime routes `REMOTE` nodes through an injected remote-node runner when +one is available. The Modal orchestrator supplies a thin remote runner that +executes one node in a separate Modal function and commits workflow volume +writes after node code returns. Unit tests use fake runners and must not call +live Modal APIs. + +## Ledger Layout + +The first durable run layout is: + +```text +/// + ledger.sqlite3 + inputs/ + nodes/ + / + attempts/ + / + logs/ + raw_outputs/ + materialized_outputs/ + cache/ + artifacts/ + / + final/ +``` + +The workflow ledger is one SQLite database per run. The orchestrator is the only +ledger writer. Remote nodes write deterministic output files and logs, then the +orchestrator reloads the volume, reconciles those files, and updates the ledger. + +Ledger updates mutate SQLite rows directly. Do not preserve obsolete +Pydantic-status update patterns such as `model_copy(update=...)` for ledger +state. + +After orchestrator ledger writes inside Modal containers, call `commit()`. +Before reading data written by another container, call `reload()`. Resuming a +run with a different DAG hash fails unless the run is forced, because stale node +state cannot safely be reused across workflow definition changes. + +Record a Modal `FunctionCall.object_id` in `remote_calls` immediately after +submitting remote node work. On orchestrator startup or restart, reattach with +`modal.FunctionCall.from_id(call_id)` and poll before launching replacement +work. Reconcile existing pending, succeeded, failed, or expired calls and their +deterministic output files before applying `RERUN` or `RESUME`. Do not blindly +resubmit work while an older call may still be writing the same node outputs. + +Use these tables for the first ledger schema: + +```text +runs(run_id, workflow_name, dag_hash, status, created_at, updated_at, metadata_json) +nodes(node_id, status, execution_policy, placement, current_attempt_id, error, started_at, completed_at, updated_at) +attempts(attempt_id, node_id, status, started_at, completed_at, app_result_json, error, metadata_json) +remote_calls(call_id, node_id, attempt_id, function_name, call_kind, status, submitted_at, completed_at, error, metadata_json) +artifacts(artifact_id, producing_node_id, kind, volume_name, storage_path, source_app_output_name, created_at, metadata_json) +artifact_files(artifact_id, path, role, media_type, size_bytes, metadata_json) +node_inputs(node_id, input_name, artifact_id) +node_outputs(node_id, artifact_id) +``` + +Keep large payloads in files and store paths in SQLite. Store non-Pydantic +metadata JSON text with `orjson`. Store Pydantic payload snapshots with +`model_dump_json()` and load them with `model_validate_json(...)`. A human +should be able to debug a run with `sqlite3` by +checking `runs.status`, stalled rows in `nodes`, outstanding `remote_calls`, +and artifact paths in `artifacts` plus `artifact_files`. + +## Modal Preemption + +All Modal functions are subject to preemption. Treat remote functions as +restartable with the same inputs. + +Remote workflow code should: + +- split long work into smaller retryable tasks; +- expose enough attempt status, artifacts, and logs for the orchestrator to + record ledger state before and after work; +- write cache checkpoints for `RESUME` nodes; +- use deterministic output paths from run and node identifiers; +- leave enough artifacts and logs to reconcile after restart. + +## Fan-Out + +The first workflow runtime supports static DAG fan-out. Build one node per known +unit of work during DAG construction, as ShortMD does for per-PDB preparation +and per-replicate production runs. + +Use barriered fan-out first: a node starts only after all declared upstream +dependencies are complete. Streaming between nodes is deferred. + +Independent ready nodes may run in parallel when all dependencies for each node +are satisfied. + +## Orchestrator Submission + +The reusable workflow orchestrator lives under `biomodals.workflow.core` and is +not a user-facing workflow script. Workflow scripts should import the module and +compose its app into their own Modal app: + +```python +from biomodals.workflow.core import orchestrator + +app = modal.App(...).include(orchestrator.app) +``` + +All remote orchestration functions should live as methods on +`WorkflowOrchestrator`. Workflow apps may use the included +`WorkflowOrchestrator` methods for run submission, but the reusable orchestrator +must not perform deployed app lookups, import workflow app functions by name, or +handle hydration details for workflow-specific apps. Domain-specific input +staging and DAG construction belong in top-level workflow scripts. + +Keep the public orchestrator method surface minimal. The intended remote methods +are `WorkflowOrchestrator.run(...)` for a whole workflow run and +`WorkflowOrchestrator.run_node(...)` for isolated remote node execution. Do not +add convenience wrappers or alternate submission APIs unless they cover a large +missing capability or a clear ergonomics gap. + +The reusable orchestrator module should not expose a local entrypoint for generic +workflow submission. Each user-facing workflow script owns its own local +entrypoint, stages its own inputs, builds its `Workflow` object, and submits that +object to the included `WorkflowOrchestrator`. + +The orchestrator API accepts `Workflow` objects only. Workflow scripts build the +DAG locally and submit that object to the included `WorkflowOrchestrator` +method. The orchestrator should not accept serialized workflow dictionaries or +workflow factory import strings as its primary run contract. Workflow node +classes must therefore be importable in remote containers by their canonical +package-qualified module names. + +## CLI Namespace + +Use `biomodals app ...` for app commands and `biomodals workflow ...` for +workflow commands. App and workflow discovery should live behind catalog helper +APIs; `cli.py` should not import app or workflow home constants directly. + +The workflow namespace should expose `list` and `help` first. Other workflow +commands can exist as placeholders until the runtime execution interface is +stable. Existing top-level app commands may remain as deprecated aliases for one +transition period, but documentation and smoke tests should prefer the +namespaced commands. + +Workflows should be launched through the `biomodals workflow run` CLI rather +than by running workflow Python files directly. The run command is responsible +for importing workflow modules through the catalog/package path so workflow node +classes serialize with stable canonical module names before being submitted to +the included `WorkflowOrchestrator`. Its user-facing flags should mirror +`biomodals app run`, including Modal mode, detach, timeout, and pass-through +workflow flags after `--`. +The command may accept workflow paths only when they resolve to package-qualified +modules under the Biomodals workflow package. Reject ad hoc workflow files that +cannot be imported by a stable package module path. +Use Modal's module mode for workflow runs, for example +`python -m modal run -m biomodals.workflow.shortmd_workflow::submit_shortmd_workflow`, +so local and remote containers agree on workflow node class module names. + +## Workflow App Composition + +Workflow scripts should compose every Modal app they need at import time. Define +dependency app names once on `AppConfig.depends_on_apps`, mirror that list into +`CONF.tags["depends_on"]` for Modal UI visibility, and call +`include_dependency_apps(app, CONF.depends_on_apps)` after including the shared +orchestrator app. + +```python +DEPENDENCY_APPS = ("gromacs",) +CONF = AppConfig( + name="ShortMDWorkflow", + depends_on_apps=DEPENDENCY_APPS, + tags={"depends_on": ",".join(DEPENDENCY_APPS)}, +) + +app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags).include( + orchestrator.app, inherit_tags=True +) +app = include_dependency_apps(app, CONF.depends_on_apps) +``` + +`depends_on_apps` is a composition declaration, not a deployment command. Do not +auto-deploy dependency apps from workflow submission paths. Including dependency +apps gives the workflow access to hydrated Modal functions and classes while +letting Modal reuse normal image caching behavior. + +Import dependency app modules directly for app metadata, Modal function handles, +volume objects, volume names, and mountpoints. Do not duplicate volume names, +mount paths, or app function names as workflow-local string constants when the +source app exports them. + +## App Interfaces + +Local entrypoints stay CLI-only. They parse local paths, submit remote work, +download or report outputs, print user messages, and return `None`. + +Workflow reuse happens through workflow-compatible remote app functions. These +functions may reuse behavior from local entrypoints or existing remote +functions, but they return `AppRunResult` and avoid local filesystem UX. + +For new Biomodals workflows that depend on other Biomodals apps, prefer +included-app Modal handles over deployed-app lookup strings. Avoid +`modal.Function.from_name(...)` in workflow definitions when the dependency app +can be included; remote orchestrator containers can otherwise re-import the +workflow module and see unhydrated function globals. Use deployed-app lookup only +for legacy workflows or external apps that cannot be composed into the workflow +app, and document that reason near the node. + +When nodes need included Modal functions or classes, group those hydrated +objects in a small workflow-local dataclass named `*ModalNamespace`. Type the +fields as `modal.Function` or `modal.Cls`; avoid overly generic callable +protocols. Store the namespace on nodes as runtime-only state: + +```python +@dataclass(frozen=True) +class ShortMDModalNamespace: + prepare_gpu: modal.Function + production_gpu: modal.Function + + +@dataclass +class ShortMDPrepNode(AppBackedNode): + modal_namespace: ShortMDModalNamespace = field( + repr=False, + compare=False, + metadata={"dag_hash": False}, + ) +``` + +The namespace is allowed to cross the orchestrator boundary because it contains +Modal objects from apps included into the workflow app. Excluding it from the +DAG hash keeps retry and resume behavior tied to semantic workflow inputs rather +than runtime hydration objects. + +Prefer `AppBackedNode` for nodes whose primary job is to invoke app functions. +Workflow definitions should reuse existing app functions whenever possible. Add +`WorkflowNativeNode` implementations only when the source app lacks a needed +function or when workflow-specific adapters are required to transform artifacts +between apps. Use native nodes for lightweight transforms, selectors, summaries, +and file-management glue that is not part of the source app's standalone +contract. + +If a workflow-native adapter needs a remote Modal boundary, define a top-level +`@app.function` in the workflow module and put that hydrated function in the +workflow's `*ModalNamespace`. Do not try to make ordinary node methods remote +Modal methods; node methods are plain Python methods unless the node itself is a +Modal `@app.cls`, which is not the generic workflow-node model. + +Keep workflow-specific file cloning, cleanup, and adapter logic in workflow +scripts, not in app modules, when the standalone app does not require that +behavior. Conversely, if a function is useful to the app outside workflows, add +it to the app and preserve the app's existing standalone local entrypoints. + +Group repeated app arguments in a compact workflow settings dataclass when that +keeps node constructors readable. Avoid extracting trivial two- or three-line +helpers that are used once or twice; inline those operations with a comment if +the intent is not obvious. + +## Volumes And Artifacts + +Workflows that import multiple apps should treat each app's volume metadata as +owned by that app. Import volume handles, volume names, and mountpoints from the +source app module rather than hardcoding them in the workflow. + +When an app function returns an absolute path under its mounted volume, convert +that path to workflow storage with +`biomodals.helper.volume_run.volume_path_from_mount_path(...)`. The helper takes +`str` inputs and returns a single validated `VolumePath`; do not construct a +`VolumePath` only to extract `.path` and wrap it again. + +Workflow-native remote functions that mutate mounted volumes must call +`reload()` before reading data written by other containers and `commit()` after +writing, copying, or deleting files. Validate artifact storage paths with +`VolumePath` before joining them to mounted paths. + +## DAG Construction + +Build workflow DAGs locally from already-staged primitive data or Pydantic +models. Discover local inputs before DAG construction, sanitize user-derived +identifiers with `sanitize_filename`, and reject duplicate sanitized names. Use +stable node ids derived from sanitized names and deterministic indices so +resume, force, and ledger debugging stay predictable. + +Use static fan-out when the input cardinality is known at submission time. For +example, create one prep node per input, one clone node per replicate, one +production node per clone, and a final summary node that depends on all +production outputs. Keep per-run namespace prefixes explicit when the same input +filenames may appear across workflow runs. + +Summary/report nodes should usually be `WorkflowNativeNode` instances with +`ORCHESTRATOR` placement when they only aggregate manifests or emit text +reports. Return reports as UTF-8 `InlineBytes`; return binary files, +directories, and archives as durable `VolumePath` outputs. + +When adding a workflow-compatible app function, keep existing local entrypoint +behavior unchanged and add a focused pytest contract test that does not call +Modal live APIs. + +## Testing + +Keep tests under top-level `tests/`. + +Use pytest for non-Modal tests. Tests must not call `.remote()`, `.spawn()`, +`modal.Function.from_name(...)`, real `modal.Queue`, real `modal.Volume`, or +deployed Modal apps. Mock Modal boundaries with fake objects and deterministic +`AppRunResult` or `WorkflowArtifact` payloads. + +For included-app workflows, tests should assert that the workflow app declares +the expected `depends_on_apps`, composes dependency apps through +`include_dependency_apps`, and imports app-owned volume metadata instead of +hardcoding it. Patch `modal.Function.from_name` to fail in tests that exercise +new included-app nodes so accidental deployed-app lookup regressions are caught. + +Use fake Modal namespace objects at node boundaries. Cast those fakes to +`modal.Function` or `modal.Cls` in tests when needed to satisfy static typing; +the production node contract should remain Modal-object based. diff --git a/.gitignore b/.gitignore index d306fbf..d1af3a1 100644 --- a/.gitignore +++ b/.gitignore @@ -217,3 +217,10 @@ __marimo__/ # Streamlit .streamlit/secrets.toml + +# Zed editor +.zed/ + +# Agent-generated local files +PLAN.md +.claude/settings.local.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1e80cfa..c5e0669 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,34 +8,25 @@ repos: # list of supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace - id: end-of-file-fixer - - id: check-docstring-first - - id: check-yaml - id: debug-statements - id: detect-private-key - id: check-executables-have-shebangs - id: check-toml + - id: check-yaml - id: check-case-conflict + - id: check-merge-conflict - id: check-added-large-files args: ["--maxkb", "16384"] # python code formatting - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.15.9" + rev: "v0.15.13" hooks: - id: ruff-check - args: - [ - "--select", - "D,E,F,W,I,S,UP", - "--extend-ignore", - "E402,E501,F401,F841,S101", - "--exclude", - "logs/*,data/*", - "--fix", - "--exit-non-zero-on-fix", - ] - types_or: [python, pyi, jupyter] + args: ["--fix", "--exit-non-zero-on-fix"] + types_or: [python, pyi, jupyter, pyproject] - id: ruff-format + types_or: [python, pyi, jupyter] # yaml formatting # - repo: https://github.com/google/yamlfmt @@ -50,17 +41,12 @@ repos: - id: shellcheck # md formatting - - repo: https://github.com/executablebooks/mdformat - rev: 0.7.22 + - repo: https://github.com/rvben/rumdl-pre-commit + rev: v0.2.0 hooks: - - id: mdformat - args: ["--number"] - additional_dependencies: - - mdformat-gfm - - mdformat-tables - - mdformat_frontmatter - # - mdformat-toc - # - mdformat-black + - id: rumdl # Lint only (fails on issues) + args: ["-d", "MD013"] # comma-separated ignored rules + - id: rumdl-fmt # Auto-format and fail if issues remain # jupyter notebook cell output clearing - repo: https://github.com/kynan/nbstripout diff --git a/AGENTS.md b/AGENTS.md index 51387e0..3fe7229 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -3,12 +3,16 @@ ## Repository expectations - This is a Python 3.12+ project for running bioinformatics tools on Modal. -- Prefer `uv run ...` for project commands and `biomodals` CLI smoke tests. +- Prefer `uv run ...` for project commands and `biomodals` CLI smoke tests. Never run `python ...` directly. - Use `polars` for tabular data parsing and writing in Python code; avoid new `csv` or `pandas` parsing unless an upstream tool API specifically requires it. +- Use `orjson` for non-Pydantic JSON serialization and deserialization. For + Pydantic models, serialize with `model_dump_json()` and parse JSON bytes or + strings with `model_validate_json(...)`. - CI runs `prek` against `.pre-commit-config.yaml`; after code edits, run `prek run --files ` when practical. -- For CLI or app-discovery changes, smoke test with `uv run biomodals list` and `uv run biomodals help ` when practical. +- For CLI or app-discovery changes, smoke test with `uv run biomodals app list`, `uv run biomodals app help `, and `uv run biomodals workflow list` when practical. - Keep generated archives, large run outputs, Modal result directories, and local test data out of commits unless the user explicitly asks for them. +- Avoid extracting trivial helper functions that are only a couple of lines and used once or twice; inline the logic and add comments when that is clearer. ## Instruction maintenance @@ -37,6 +41,13 @@ This repo is built on Modal, a serverless cloud platform for running Python code When creating, editing, or reviewing files under `src/biomodals/app/**/*_app.py`, use the repo-local `biomodals-app-development` skill. See `docs/agents/app-development.md`. +### Workflow development + +When creating or editing reusable workflow runtime code under +`src/biomodals/workflow/` or shared workflow schemas under +`src/biomodals/schema/`, use the repo-local `biomodals-workflow-development` +skill. See `docs/agents/workflow-development.md`. + ## Biomodals app development The detailed app-development standards are consolidated in `.agents/skills/biomodals-app-development/`. @@ -48,3 +59,9 @@ Use these apps as current implementation references: - `src/biomodals/app/design/boltzgen_app.py` When developing new apps that must violate the skill's conventions for good reason, document the reason for the deviation in `docs/agents/` and link that note from `docs/agents/app-development.md`. + +## Biomodals workflow development + +The detailed workflow-development standards are consolidated in `.agents/skills/biomodals-workflow-development/`. + +Use `src/biomodals/workflow/shortmd_workflow.py` as the primary end-to-end workflow reference. Ignore `src/biomodals/workflow/ppiflow_workflow.py` as a reference pattern for now because it is expected to be refactored. diff --git a/CONTEXT.md b/CONTEXT.md new file mode 100644 index 0000000..c8dcc7c --- /dev/null +++ b/CONTEXT.md @@ -0,0 +1,176 @@ + + +# Biomodals + +Biomodals runs bioinformatics tools as Modal apps and composes them into reusable computational workflows. + +## Language + +**Workflow Artifact**: +A durable record of data produced or consumed by a workflow step, including its data category, storage location, and metadata needed by downstream steps. +_Avoid_: raw app output, untyped file path, loose tarball + +**Inline Byte Output**: +A workflow-compatible app output whose bytes are UTF-8 text small enough to serialize directly in a Pydantic JSON payload before materialization. +_Avoid_: binary archive, non-text bytes + +**Workflow Node**: +A semantic step in a workflow DAG that consumes workflow artifacts and produces workflow artifacts. +_Avoid_: Modal function, app function + +**App**: +A deployed Modal app that owns tool runtime, images, volumes, and exported app functions. +_Avoid_: workflow node, app node + +**App Function**: +A callable Modal remote function exposed by a Biomodals app or another Modal app and invoked by a workflow node. +_Avoid_: workflow node + +**Local Entrypoint**: +A CLI-facing Modal entrypoint that parses local user inputs, submits app functions, downloads or reports outputs, and returns no workflow contract. +_Avoid_: workflow entrypoint + +**CLI Namespace**: +A top-level `biomodals` command group that separates app commands from workflow commands. +_Avoid_: mixed app/workflow command collection + +**Workflow-Compatible App Function**: +An app function with standardized workflow input and output schemas suitable for app-backed workflow nodes. +_Avoid_: local entrypoint, submit function + +**Shared Schema**: +A stable Pydantic contract used across Biomodals packages without depending on app or workflow implementation modules. +_Avoid_: app config, internal model + +**App Configuration Schema**: +The pure Pydantic fields and validators that describe a Biomodals app's metadata and runtime settings. +_Avoid_: Modal volume factory, image helper + +**App-Backed Node**: +A workflow node implemented by calling one or more app functions. +_Avoid_: app node, runner node + +**Workflow-Native Node**: +A workflow node implemented directly in workflow code for orchestration, transformation, selection, ranking, packaging, or reporting. +_Avoid_: runtime node, orchestrator node + +**Workflow Builder**: +A Python interface for declaring workflow nodes, dependencies, artifact selectors, and execution settings before a workflow run. +_Avoid_: workflow YAML, scheduler config + +**Artifact Selector**: +A named input reference that selects upstream workflow artifacts by kind, file role, path pattern, metadata, or producing node. +_Avoid_: raw input path, wildcard-only dependency + +**Control Edge**: +A dependency between workflow nodes that enforces execution order without passing workflow artifacts. +_Avoid_: dummy artifact + +**Dynamic Task Fan-Out**: +A workflow node execution pattern where the DAG node is fixed but the number of per-input tasks is determined from upstream artifacts at runtime. +_Avoid_: dynamic DAG + +**Worker Pool**: +A fixed-size set of remote workers spawned by one workflow node to process that node's runtime task queue. +_Avoid_: server pool, runner server + +**Workflow Runtime**: +The reusable library that validates a workflow DAG, schedules workflow nodes, tracks durable run state, and materializes workflow artifacts. +_Avoid_: engine + +**Workflow Orchestrator**: +A Modal-hosted coordinator that owns one workflow run, hosts the workflow runtime, records durable run state, and uses Modal lifecycle hooks to reconcile interrupted work. +_Avoid_: workflow node, runner + +**Workflow Ledger**: +A per-run SQLite database written by the workflow orchestrator that records run, node, attempt, remote-call, fan-out task, and artifact state for recovery and manual debugging. +_Avoid_: scattered JSON state files, worker-owned database + +**Node Placement**: +The execution location for a workflow node, either inline in the workflow orchestrator or in a separate remote Modal function. +_Avoid_: runner location, execution site + +**Node Execution Policy**: +The restart and recovery contract for an incomplete workflow node when Modal interrupts or retries the node. +_Avoid_: runner tag, retry hint + +**Durable Node Cache**: +Volume-backed intermediate state that lets a long-running workflow node resume or safely skip completed work after restart. +_Avoid_: temporary scratch, local cache + +## Relationships + +- A **Workflow Builder** defines a workflow DAG in Python code. +- A **Workflow Builder** connects nodes through named **Artifact Selectors** and optional **Control Edges**. +- A **Workflow Node** declares a **Node Execution Policy**. +- A **Workflow Node** declares **Node Placement**. +- A **Workflow Node** may use **Dynamic Task Fan-Out** without changing the workflow DAG shape. +- A **Workflow Node** may use a **Worker Pool** to process dynamically fanned-out tasks. +- A **Workflow Runtime** schedules **Workflow Nodes** and does not contain tool-specific biological logic. +- A **Workflow Runtime** may run independent ready nodes in parallel when all of each node's dependencies are satisfied. +- A **Workflow Orchestrator** runs the **Workflow Runtime** remotely on Modal and is responsible for run-level lifecycle recovery. +- A **Workflow Orchestrator** is the only writer to the **Workflow Ledger**. +- Remote workflow nodes and workers write deterministic files and logs; the **Workflow Orchestrator** reconciles those files into the **Workflow Ledger**. +- A **Workflow Orchestrator** records Modal function call ids before waiting on remote work and reattaches to those calls during recovery before starting replacement work. +- A workflow step produces zero or more **Workflow Artifacts**. +- A workflow step consumes zero or more **Workflow Artifacts** from upstream steps. +- A **Workflow Artifact** references durable files stored in a remote Modal volume. +- An **Inline Byte Output** is normalized into a volume-backed **Workflow Artifact** before it crosses a workflow node boundary. +- Binary or non-text app outputs are written to volume paths and represented as volume-backed artifacts, not serialized as inline bytes. +- A **Workflow Node** may invoke one or more **App Functions** to fulfill one semantic step. +- An **App** may expose many **App Functions**. +- A **Local Entrypoint** remains CLI-only and should not be called by the workflow orchestrator. +- The **CLI Namespace** separates `biomodals app ...` commands from `biomodals workflow ...` commands. +- A **Workflow-Compatible App Function** may reuse behavior from a **Local Entrypoint**, but exposes a remote app function contract for workflows. +- A **Shared Schema** may be imported by app and workflow modules, but it must not import app or workflow modules. +- An **App Configuration Schema** lives in `biomodals.schema` when it is shared across apps. +- Modal-specific helpers that construct volumes, images, or app objects wrap the **App Configuration Schema** outside `biomodals.schema`. +- App-specific configuration models remain with their app until they become stable cross-module contracts. +- An **App-Backed Node** calls one or more **App Functions** and processes their outputs into workflow artifacts. +- A **Workflow-Native Node** performs lightweight workflow logic without calling a bioinformatics app. +- A long-running **Workflow Node** must use a **Durable Node Cache** so interruption and restart do not corrupt outputs or repeat unsafe work. +- A short-running **Workflow Node** may choose a rerun-on-restart policy when recomputation is cheaper than durable checkpointing. +- A lightweight **Workflow-Native Node** may run inline in the **Workflow Orchestrator** when remote execution overhead is not justified. +- A long-running or failure-isolated **Workflow Node** should run as a separate remote Modal function. +- Every **Workflow Node** checks durable run state before execution and skips work when completed artifacts already exist. +- A **Workflow** may compose **App Functions** from any Modal app when those functions can be described by node input and output contracts. + +## Example dialogue + +> **Dev:** "Should the LigandMPNN app pass its tarball directly to FlowPacker?" +> **Domain expert:** "No — the workflow should record a **Workflow Artifact** for the LigandMPNN result, materialize it into the run volume, and let FlowPacker consume the artifact's selected structure files." +> +> **Dev:** "Is AF3Score four workflow steps because it has lock, prepare, run, and postprocess functions?" +> **Domain expert:** "No — those are **App Functions** inside one **Workflow Node** when the workflow cares about one scoring step." +> +> **Dev:** "Should users write workflow YAML first?" +> **Domain expert:** "No — complex workflows should start with the **Workflow Builder** so node contracts and artifact selectors stay explicit in Python." +> +> **Dev:** "Can every interrupted node just run again?" +> **Domain expert:** "Only short-running nodes should default to rerun. Long-running nodes need a **Durable Node Cache** and a **Node Execution Policy** that makes restart behavior explicit." +> +> **Dev:** "Does the workflow orchestrator spawn app-backed nodes as runners?" +> **Domain expert:** "No — the **Workflow Orchestrator** runs the **Workflow Runtime**, the runtime schedules **Workflow Nodes**, and an **App-Backed Node** calls app functions as its implementation." +> +> **Dev:** "Should every node be its own Modal function?" +> **Domain expert:** "No — **Node Placement** determines whether the node runs inline for lightweight workflow logic or remotely for long-running and failure-isolated work." +> +> **Dev:** "Can the workflow call an app's local entrypoint?" +> **Domain expert:** "No — **Local Entrypoints** stay CLI-only. Workflows call **Workflow-Compatible App Functions** that may be derived from the same behavior." +> +> **Dev:** "How does one node consume only PDB files from an upstream design step?" +> **Domain expert:** "Use an **Artifact Selector** that names the upstream node, selects structure artifacts, and filters files by role or pattern." +> +> **Dev:** "Does one PPIFlow output structure create a new node in the DAG?" +> **Domain expert:** "No — the downstream **Workflow Node** uses **Dynamic Task Fan-Out** to create per-structure tasks while the DAG shape stays fixed." +> +> **Dev:** "If two scoring nodes depend on the same design node, should they wait for each other?" +> **Domain expert:** "No — once their shared upstream dependency is complete, the **Workflow Runtime** can schedule both nodes in parallel." + +## Flagged ambiguities + +- "artifact" can mean either inline app bytes or remote files. Resolved: an **Inline Byte Output** is a UTF-8 text app output before materialization; a **Workflow Artifact** is durable volume-backed state after materialization. +- "step" can mean either a semantic workflow operation or one callable remote function. Resolved: use **Workflow Node** for the semantic DAG unit and **App Function** for a Modal remote callable. +- "app node" can mean either a Modal deployment unit or a DAG vertex backed by that app. Resolved: use **App** for the deployment unit and **App-Backed Node** for the DAG vertex. +- "workflow entrypoint" can be confused with Modal's local entrypoint. Resolved: use **Workflow-Compatible App Function** for reusable remote app functions and **Local Entrypoint** for CLI wrappers. +- "dynamic workflow" can mean changing the DAG at runtime or changing only the task count. Resolved: first-version workflows use static DAGs with **Dynamic Task Fan-Out** only. diff --git a/biomodals b/biomodals index dd3bc1b..6ac3367 100755 --- a/biomodals +++ b/biomodals @@ -12,7 +12,7 @@ if [ $# -eq 0 ] || { [ $# -eq 1 ] && [[ "$1" == "${DEFAULT_BIN_NAME}" ]] ; }; th uv run --project "${SCRIPT_DIR}" "${DEFAULT_BIN_NAME}" --help # Run the subcommand if the first argument matches a known subcommand -elif [ $# -gt 0 ] && [[ "$1" =~ ^(list|ls|l|run|r|help|h|deploy|d)$ ]]; then +elif [ $# -gt 0 ] && [[ "$1" =~ ^(app|workflow|list|ls|l|run|r|help|h|deploy|d)$ ]]; then uv run --project "${SCRIPT_DIR}" "${DEFAULT_BIN_NAME}" "$@" # Otherwise, run the command as-is within the uv environment diff --git a/docs/agents/app-development.md b/docs/agents/app-development.md index 89ea394..a0b9f61 100644 --- a/docs/agents/app-development.md +++ b/docs/agents/app-development.md @@ -11,13 +11,23 @@ The previous `.github/instructions/app-development.instructions.md` file has bee - Invoke or read the `biomodals-app-development` skill before creating, editing, or reviewing Biomodals app files. - Treat the skill as the baseline for app discovery, `AppConfig`, Modal image construction, helper usage, volumes, data flow, local entrypoint docstrings, examples, and smoke tests. +- For app model/output volumes, prefer `CONF.mounts(...)`. For shared Modal + volumes with custom mountpoints, mount only the needed subdirectory with + `Volume.with_mount_options(sub_path=...)` and combine read-only and subpath + options in the same call when inference should not write to model artifacts. +- Treat `AppConfig` as a shared schema from `biomodals.schema.app`; keep + Modal-specific volume and image helpers outside `biomodals.schema`. - Compare non-trivial app changes against the current reference apps: - `src/biomodals/app/fold/alphafold3_app.py` - `src/biomodals/app/bioinfo/rosetta_app.py` - `src/biomodals/app/design/boltzgen_app.py` +- When adding workflow-compatible app functions, also follow + `docs/agents/workflow-development.md`. ## Maintenance - Update the skill when app-development standards change. - Keep this document as a pointer and coordination note, not a duplicate copy of the skill. - If an app needs to intentionally deviate from the skill, add a focused note under `docs/agents/` explaining why and link it from this document. +- Keep local entrypoints CLI-only. Workflow reuse should happen through remote + app functions that return shared schemas from `biomodals.schema`. diff --git a/docs/agents/workflow-development.md b/docs/agents/workflow-development.md new file mode 100644 index 0000000..3a72a1e --- /dev/null +++ b/docs/agents/workflow-development.md @@ -0,0 +1,25 @@ +# Biomodals Workflow Development + +Detailed workflow-development instructions for `src/biomodals/workflow/` and +shared workflow schemas under `src/biomodals/schema/` live in the repo-local +skill: + +- `.agents/skills/biomodals-workflow-development/SKILL.md` +- `.agents/skills/biomodals-workflow-development/references/workflow-development.md` + +## How Agents Should Use It + +- Invoke or read the `biomodals-workflow-development` skill before creating, + editing, or reviewing Biomodals workflow code. +- Treat `src/biomodals/workflow/shortmd_workflow.py` as the primary end-to-end + reference workflow. +- Ignore `src/biomodals/workflow/ppiflow_workflow.py` as a reference pattern + for now because it is expected to be refactored. +- When adding workflow-compatible app functions under `src/biomodals/app/`, also + follow `docs/agents/app-development.md`. + +## Maintenance + +- Update the workflow skill when workflow standards change. +- Keep this document as a pointer and coordination note, not a duplicate copy of + the skill. diff --git a/examples/app/abnativ.sh b/examples/app/abnativ.sh index ffa9c21..cf4639d 100755 --- a/examples/app/abnativ.sh +++ b/examples/app/abnativ.sh @@ -6,14 +6,14 @@ fi SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) BIOMODALS_ROOT=$(realpath "${SCRIPT_DIR}/../../") -ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/run.sh") +ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/biomodals") pembro_vh='QVQLVQSGVEVKKPGASVKVSCKASGYTFTNYYMYWVRQAPGQGLEWMGGINPSNGGTNFNEKFKNRVTLTTDSSTTTAYMELKSLQFDDTAVYYCARRDYRFDMGFDYWGQGTTVTVSS' pembro_vl='EIVLTQSPATLSLSPGERATLSCRASKGVSTSGYSYLHWYQQKPGQAPRLLIYLASYLESGVPARFSGSGSGTDFTLTISSLEPEDFAVYYCQHSRDLPLTFGGGTKVEIKTSENLYFQ' temp_dir=$(mktemp -d) -"${ENTRY_BIN}" r abnativ -- \ +"${ENTRY_BIN}" app r abnativ -- \ --run-name biomodals_abnativ_example \ --out-dir "${temp_dir}" \ --input-vh-seq "${pembro_vh}" \ diff --git a/examples/app/af3score.sh b/examples/app/af3score.sh new file mode 100755 index 0000000..ef512e3 --- /dev/null +++ b/examples/app/af3score.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -euo pipefail +if [ "${DEBUG:-0}" -eq 1 ]; then + set -x +fi + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +BIOMODALS_ROOT=$(realpath "${SCRIPT_DIR}/../../") +ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/biomodals") + +pembro_pdb="${SCRIPT_DIR}/../data/5B8C.pdb.gz" + +temp_dir=$(mktemp -d) +mkdir -p "${temp_dir}/inputs" +gunzip -c "${pembro_pdb}" > "${temp_dir}/inputs/5B8C.pdb" +"${ENTRY_BIN}" app r af3score -- \ + --input-dir "${temp_dir}/inputs" \ + --output-dir "${temp_dir}/outputs" \ + --run-name 'biomodals-af3score-test' \ + --max-batches 2 diff --git a/examples/app/antifold.sh b/examples/app/antifold.sh index c4482cb..d4dc1bc 100755 --- a/examples/app/antifold.sh +++ b/examples/app/antifold.sh @@ -6,14 +6,14 @@ fi SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) BIOMODALS_ROOT=$(realpath "${SCRIPT_DIR}/../../") -ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/run.sh") +ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/biomodals") pembro_pdb="${SCRIPT_DIR}/../data/5B8C.pdb.gz" temp_dir=$(mktemp -d) gunzip -c "${pembro_pdb}" > "${temp_dir}/5B8C.pdb" -"${ENTRY_BIN}" r antifold -- \ +"${ENTRY_BIN}" app r antifold -- \ --run-name biomodals_antifold_example \ --struct-file "${temp_dir}/5B8C.pdb" \ --out-dir "${temp_dir}" \ diff --git a/examples/app/flowpacker.sh b/examples/app/flowpacker.sh new file mode 100755 index 0000000..67b8be3 --- /dev/null +++ b/examples/app/flowpacker.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -euo pipefail +if [ "${DEBUG:-0}" -eq 1 ]; then + set -x +fi + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +BIOMODALS_ROOT=$(realpath "${SCRIPT_DIR}/../../") +ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/biomodals") + +pembro_pdb="${SCRIPT_DIR}/../data/5B8C.pdb.gz" + +temp_dir=$(mktemp -d) + +gunzip -c "${pembro_pdb}" > "${temp_dir}/5B8C.pdb" +"${ENTRY_BIN}" app r flowpacker -- \ + --input-path "${temp_dir}/5B8C.pdb" \ + --out-dir "${temp_dir}" \ + --use-confidence diff --git a/examples/app/ligandmpnn.sh b/examples/app/ligandmpnn.sh index 05f031c..9728461 100755 --- a/examples/app/ligandmpnn.sh +++ b/examples/app/ligandmpnn.sh @@ -6,21 +6,21 @@ fi SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) BIOMODALS_ROOT=$(realpath "${SCRIPT_DIR}/../../") -ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/run.sh") +ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/biomodals") pembro_pdb="${SCRIPT_DIR}/../data/5B8C.pdb.gz" temp_dir=$(mktemp -d) gunzip -c "${pembro_pdb}" > "${temp_dir}/5B8C.pdb" -"${ENTRY_BIN}" r ligandmpnn -- \ +"${ENTRY_BIN}" app r ligandmpnn -- \ --run-name biomodals_ligandmpnn_score_example \ --input-pdb "${temp_dir}/5B8C.pdb" \ --out-dir "${temp_dir}" \ --script-mode score \ --model-type ligand_mpnn -"${ENTRY_BIN}" r ligandmpnn -- \ +"${ENTRY_BIN}" app r ligandmpnn -- \ --run-name biomodals_ligandmpnn_design_example \ --input-pdb "${temp_dir}/5B8C.pdb" \ --out-dir "${temp_dir}" \ diff --git a/examples/app/ppiflow_app.sh b/examples/app/ppiflow_app.sh index e53c61e..2db1f27 100755 --- a/examples/app/ppiflow_app.sh +++ b/examples/app/ppiflow_app.sh @@ -6,7 +6,7 @@ fi SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) BIOMODALS_ROOT=$(realpath "${SCRIPT_DIR}/../../") -ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/run.sh") +ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/biomodals") pembro_pdb="${SCRIPT_DIR}/../data/5B8C.pdb.gz" vhh_pdb="${SCRIPT_DIR}/../data/7eow_nanobody_framework.pdb.gz" @@ -16,5 +16,5 @@ cd "${temp_dir}" || exit 1 gunzip -c "${pembro_pdb}" > "${temp_dir}/5B8C.pdb" gunzip -c "${vhh_pdb}" > "${temp_dir}/7eow_nanobody_framework.pdb" -"${ENTRY_BIN}" r ppiflow -- \ +"${ENTRY_BIN}" app r ppiflow -- \ --input-yaml "${SCRIPT_DIR}/../data/ppiflow_vhh.yaml" diff --git a/examples/app/rosetta.sh b/examples/app/rosetta.sh index a5213f0..7ffeb9d 100755 --- a/examples/app/rosetta.sh +++ b/examples/app/rosetta.sh @@ -6,7 +6,7 @@ fi SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) BIOMODALS_ROOT=$(realpath "${SCRIPT_DIR}/../../") -ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/run.sh") +ENTRY_BIN=$(realpath "${BIOMODALS_ROOT}/biomodals") pembro_pdb="${SCRIPT_DIR}/../data/5B8C.pdb.gz" @@ -15,6 +15,6 @@ cd "${temp_dir}" || exit 1 gunzip -c "${pembro_pdb}" > "${temp_dir}/5B8C.pdb" cp -an "${SCRIPT_DIR}/../data/rosetta_example.csv" "${temp_dir}/rosetta_example.csv" -"${ENTRY_BIN}" r rosetta -- \ +"${ENTRY_BIN}" app r rosetta -- \ --input-csv "${temp_dir}/rosetta_example.csv" \ --max-num-pods 2 diff --git a/pyproject.toml b/pyproject.toml index 6e928f0..acc9d50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,9 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "aiofiles>=25.1.0", - "modal>=1.2.1", + "modal>=1.4.3", "niquests>=3.18.4", + "orjson>=3.11.8", "pydantic>=2.12.5", "pyyaml>=6.0.3", "rich>=14.2.0", @@ -26,4 +27,31 @@ build-backend = "uv_build" [dependency-groups] dev = [ "ipykernel>=7.2.0", + "pytest>=9.0.3", ] + +[tool.ruff] +extend-exclude = ["logs/*", "data/*"] +extend-include = ["*.ipynb"] +line-length = 88 + +[tool.ruff.format] +preview = true + +[tool.ruff.lint] +extend-ignore = ["E501"] +select = [ + "D", + "E", + "F", + "W", + "I", + "S", + "UP", + "B" +] +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.lint.per-file-ignores] +"tests/**/*.py" = ["S101"] diff --git a/src/biomodals/app/bioinfo/gromacs/postprocess-traj.sh b/src/biomodals/app/bioinfo/gromacs/postprocess-traj.sh index 3535c53..3b8c05d 100755 --- a/src/biomodals/app/bioinfo/gromacs/postprocess-traj.sh +++ b/src/biomodals/app/bioinfo/gromacs/postprocess-traj.sh @@ -120,7 +120,7 @@ mkdir -p "${out_dir}" ########################################## # Extract protein trajectory, center, and remove PBC ########################################## -if [ ! -f "${output_stem}_centered.pdb" ]; then +if [ ! -f "${output_file}" ]; then # 1. make molecules whole printf "Protein\nProtein\n" | \ @@ -157,6 +157,9 @@ printf "Protein\nProtein\n" | \ -f "${output_file}.fit1st.xtc" -o "${output_file}" \ -center -ur compact -pbc mol -boxcenter tric && \ rm "${output_file}.fit1st.xtc" +fi + +if [ ! -f "${output_stem}_centered.pdb" ]; then echo 'Protein' | "${GROMACS}" trjconv -s "${tpr_file}" -f "${output_file}" -dump 0 -o "${output_stem}_centered.gro" echo 'Protein' | "${GROMACS}" trjconv -s "${tpr_file}" -f "${output_file}" -dump 0 -o "${output_stem}_centered.pdb" diff --git a/src/biomodals/app/bioinfo/gromacs_app.py b/src/biomodals/app/bioinfo/gromacs_app.py index 9d3de9c..c7da4cb 100644 --- a/src/biomodals/app/bioinfo/gromacs_app.py +++ b/src/biomodals/app/bioinfo/gromacs_app.py @@ -12,14 +12,15 @@ import os from dataclasses import dataclass -from pathlib import Path, PurePosixPath +from pathlib import Path import modal from biomodals.app.config import AppConfig -from biomodals.app.constant import MAX_TIMEOUT from biomodals.helper import patch_image_for_helper +from biomodals.helper.constant import MAX_TIMEOUT from biomodals.helper.shell import run_command +from biomodals.helper.volume_run import volume_path_from_mount_path ########################################## # Modal configs @@ -32,7 +33,7 @@ python_version="3.13", cuda_version="cu128", gpu=os.environ.get("GPU", "L40S"), - timeout=int(os.environ.get("TIMEOUT", str(MAX_TIMEOUT))), + timeout=int(os.environ.get("TIMEOUT", MAX_TIMEOUT)), ) @@ -53,10 +54,8 @@ class AppInfo: # Image and app definitions ########################################## APP_INFO = AppInfo() -OUTPUTS_VOLUME = CONF.get_out_volume() -OUTPUTS_VOLUME_NAME = OUTPUTS_VOLUME.name or f"{CONF.name}-outputs" -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .from_registry( "nvidia/cuda:12.8.1-devel-ubuntu24.04", add_python=CONF.python_version @@ -173,13 +172,15 @@ class AppInfo: "echo 'source /usr/local/gromacs/bin/GMXRC' >> /etc/profile", ) .add_local_dir(Path(__file__).parent / "gromacs", APP_INFO.gmx_scripts, copy=True) + .pipe(patch_image_for_helper) ) -biotite_image = patch_image_for_helper( +biotite_image = ( modal.Image .debian_slim(python_version=CONF.python_version) .apt_install("git", "build-essential") .uv_pip_install("biotite", "numpy", "scipy", "seaborn", "matplotlib") + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -205,7 +206,7 @@ def file1_needs_update(file1: Path, file2: Path) -> bool: cpu=APP_INFO.gmx_threads + 0.125, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def prepare_tpr_gpu( pdb_content: bytes, @@ -240,7 +241,7 @@ def prepare_tpr_gpu( input_pdb_path = work_path / f"{run_name}.pdb" input_pdb_path.write_bytes(pdb_content) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() script_path = Path(APP_INFO.gmx_scripts) / "prepare-tpr.sh" if not script_path.exists(): @@ -268,7 +269,7 @@ def prepare_tpr_gpu( cmd.append("--use-openmp-threads") # Modal adds this automatically but we want Gromacs to handle threading _ = run_command(cmd, cwd=str(work_path), env={"OMP_NUM_THREADS": None}) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() return str(work_path) @@ -276,7 +277,7 @@ def prepare_tpr_gpu( cpu=APP_INFO.gmx_threads + 0.125, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def prepare_tpr_cpu( pdb_content: bytes, @@ -311,7 +312,7 @@ def prepare_tpr_cpu( input_pdb_path = work_path / f"{run_name}.pdb" input_pdb_path.write_bytes(pdb_content) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() script_path = Path(APP_INFO.gmx_scripts) / "prepare-tpr.sh" if not script_path.exists(): @@ -340,7 +341,7 @@ def prepare_tpr_cpu( # Modal adds this automatically but we want Gromacs to handle threading _ = run_command(cmd, cwd=str(work_path), env={"OMP_NUM_THREADS": None}) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() return str(work_path) @@ -348,10 +349,21 @@ def prepare_tpr_cpu( image=runtime_image, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) -def find_traj_last_step(traj_file: str) -> int: - """Calculated simulated steps from the simulation time (in ps) in a trajectory.""" +def find_traj_last_time_ns(traj_file: str) -> float: + """Calculate the last-readable simulation time (ns) in a trajectory. + + In our setup, dt=2fs=0.002ps; `gmx check` normally reports the simulation + time in ps, so we can convert it to #steps by dividing by `dt=0.002`. + + Because we setup the simulation by inputting the expected nanoseconds, + #steps = ns * 500000. + + When the simulation was interrupted, `gmx check` may only report the #frames + and timestep size, so we need to manually calculate the closest last step + that is within the trajectory bounds. + """ import shutil traj_path = Path(traj_file) @@ -365,13 +377,27 @@ def find_traj_last_step(traj_file: str) -> int: cmd = [gmx, "check", "-f", str(traj_path)] result = run_command(cmd, cwd=traj_path.parent, verbose=False) - last_time = None for line in result: + # Last frame 20000 time 200000.000 if line.startswith("Last frame"): - last_time = float(line.strip().split(" ")[-1]) - if last_time is None: - raise ValueError("Last frame time not found in trajectory") - return int(last_time / 0.002) # dt=2 fs, which is 0.002 ps + last_time_ps = float(line.strip().split(" ")[-1]) + return last_time_ps * 0.001 + + # Be robust in case the run was interrupted + # Item #frames Timestep (ps) + # Step 20001 10 + header_line_idx = -1 + header_cols = ["Item", "#frames", "Timestep", "(ps)"] + for i, line in enumerate(result): + if line.startswith("Item") and line.strip().split() == header_cols: + header_line_idx = i + break + if header_line_idx != -1: + readable_line = result[header_line_idx + 1].strip() + _, frames, timestep_ps = readable_line.split() + return float((int(frames) - 1) * float(timestep_ps)) * 0.001 + + raise ValueError("Last frame time not found in trajectory") @app.function( @@ -379,7 +405,7 @@ def find_traj_last_step(traj_file: str) -> int: cpu=APP_INFO.gmx_threads + 0.125, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def production_run_gpu( run_name: str, @@ -399,11 +425,10 @@ def production_run_gpu( # Pick up exisiting trajectory and continue simulation when checkpoint exists traj_file_path = work_path / f"{deffnm}.xtc" checkpoint_file_path = work_path / f"{deffnm}.cpt" - nsteps = -2 # default: find nsteps from the mdp file + nsteps = -2 # default: use nsteps from the prepared TPR if traj_file_path.exists() and checkpoint_file_path.exists(): - simulated_steps = find_traj_last_step.remote(str(traj_file_path)) - total_steps = simulation_time_ns * 500000 # 2 fs timestep - nsteps = total_steps - simulated_steps + simulated_ns = find_traj_last_time_ns.remote(str(traj_file_path)) + nsteps = int((simulation_time_ns - simulated_ns) * 500000) # 2 fs timestep if nsteps <= 0: print("✅ Production run already completed, skipping.") return str(work_path) @@ -441,7 +466,7 @@ def production_run_gpu( # Modal adds this automatically but we want Gromacs to handle threading _ = run_command(cmd, cwd=str(work_path), env={"OMP_NUM_THREADS": None}) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() return str(work_path) @@ -449,7 +474,7 @@ def production_run_gpu( cpu=APP_INFO.gmx_threads + 0.125, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def production_run_cpu( run_name: str, @@ -469,11 +494,10 @@ def production_run_cpu( # Pick up exisiting trajectory and continue simulation when checkpoint exists traj_file_path = work_path / f"{deffnm}.xtc" checkpoint_file_path = work_path / f"{deffnm}.cpt" - nsteps = -2 # default: find nsteps from the mdp file + nsteps = -2 # default: use nsteps from the prepared TPR if traj_file_path.exists() and checkpoint_file_path.exists(): - simulated_steps = find_traj_last_step.remote(str(traj_file_path)) - total_steps = simulation_time_ns * 500000 # 2 fs timestep - nsteps = total_steps - simulated_steps + simulated_ns = find_traj_last_time_ns.remote(str(traj_file_path)) + nsteps = int((simulation_time_ns - simulated_ns) * 500000) # 2 fs timestep if nsteps <= 0: print("✅ Production run already completed, skipping.") return str(work_path) @@ -511,7 +535,7 @@ def production_run_cpu( # Modal adds this automatically but we want Gromacs to handle threading _ = run_command(cmd, cwd=str(work_path), env={"OMP_NUM_THREADS": None}) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() return str(work_path) @@ -519,7 +543,7 @@ def production_run_cpu( image=runtime_image, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def postprocess_traj( traj_file: str, @@ -552,7 +576,7 @@ def postprocess_traj( env={"OMP_NUM_THREADS": None}, verbose=False, ) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() @app.function( @@ -560,7 +584,7 @@ def postprocess_traj( cpu=1, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def collect_traj_stats( traj_prefix: str, @@ -596,7 +620,9 @@ def collect_traj_stats( str(processed_traj_path), ref_struct_file=str(work_path / f"{run_name}.pdb"), ) - OUTPUTS_VOLUME.reload() + + out_vol = CONF.output_volume + out_vol.reload() traj_1st_frame_pdb_path = work_path / f"{traj_prefix}{run_name}_nopbc_centered.pdb" if not traj_1st_frame_pdb_path.exists(): raise RuntimeError( @@ -625,7 +651,7 @@ def collect_traj_stats( trajectory = xtc_file.get_structure(template) if not save_processed_traj: processed_traj_path.unlink() - OUTPUTS_VOLUME.commit() + out_vol.commit() # Get simulation time (ns) for plotting purposes time = xtc_file.get_time() / 1000.0 @@ -641,7 +667,7 @@ def collect_traj_stats( last_frame_path.unlink(missing_ok=True) # remove outdated last frame if not last_frame_path.exists(): strucio.save_structure(last_frame_path, trajectory[-1]) - OUTPUTS_VOLUME.commit() + out_vol.commit() # RMSD vs. the initial frame rmsd_fig_path = work_path / f"rmsd_{traj_prefix}{run_name}.png" @@ -659,7 +685,7 @@ def collect_traj_stats( header="time_ns,rmsd", comments="", ) - OUTPUTS_VOLUME.commit() + out_vol.commit() if not rmsd_fig_path.exists() and make_figures: figure, ax = plt.subplots(figsize=(6, 3), dpi=200, layout="constrained") @@ -671,7 +697,7 @@ def collect_traj_stats( figure.savefig(rmsd_fig_path) plt.close(figure) - OUTPUTS_VOLUME.commit() + out_vol.commit() # Radius of gyration rg_fig_path = work_path / f"rg_{traj_prefix}{run_name}.png" @@ -689,7 +715,7 @@ def collect_traj_stats( header="time_ns,rg", comments="", ) - OUTPUTS_VOLUME.commit() + out_vol.commit() if not rg_fig_path.exists() and make_figures: figure, ax = plt.subplots(figsize=(6, 3), dpi=200, layout="constrained") ax.plot(time, rg, color=biotite.colors["dimgreen"]) @@ -700,7 +726,7 @@ def collect_traj_stats( figure.savefig(rg_fig_path) plt.close(figure) - OUTPUTS_VOLUME.commit() + out_vol.commit() # RMSF of each residue rmsf_fig_path = work_path / f"rmsf_{traj_prefix}{run_name}.png" @@ -722,7 +748,7 @@ def collect_traj_stats( header="residue_index,rmsf", comments="", ) - OUTPUTS_VOLUME.commit() + out_vol.commit() if not rmsf_fig_path.exists() and make_figures: # Sidechain atoms fluctuate too much, so we only consider CA atoms figure, ax = plt.subplots(figsize=(6, 3), dpi=200, layout="constrained") @@ -734,7 +760,7 @@ def collect_traj_stats( figure.savefig(rmsf_fig_path) plt.close(figure) - OUTPUTS_VOLUME.commit() + out_vol.commit() return str(work_path) @@ -827,8 +853,7 @@ def submit_gromacs_task( _ = modal.FunctionCall.gather(*process_traj_tasks, prod_traj_task) - remote_volume_dir = PurePosixPath(remote_workdir).relative_to( - CONF.output_volume_mountpoint + remote_vol = volume_path_from_mount_path( + remote_workdir, CONF.output_volume_mountpoint, CONF.output_volume_name ) - print("🧬 Gromacs preparation complete! Check data with: \n") - print(f" modal volume ls {OUTPUTS_VOLUME_NAME} {remote_volume_dir}") + print(f"🧬 Gromacs preparation complete! Check data in {remote_vol}") diff --git a/src/biomodals/app/bioinfo/paddleocr_app.py b/src/biomodals/app/bioinfo/paddleocr_app.py index 4535e73..705c29d 100644 --- a/src/biomodals/app/bioinfo/paddleocr_app.py +++ b/src/biomodals/app/bioinfo/paddleocr_app.py @@ -3,15 +3,13 @@ # ruff: noqa: PLC0415 import os -from dataclasses import dataclass from pathlib import Path import modal from biomodals.app.config import AppConfig -from biomodals.app.constant import MODEL_VOLUME from biomodals.helper import patch_image_for_helper -from biomodals.helper.shell import package_outputs, softlink_dir +from biomodals.helper.shell import package_outputs ########################################## # Modal configs @@ -26,28 +24,18 @@ cuda_version="cu126", gpu=os.environ.get("GPU", "L40S"), timeout=int(os.environ.get("TIMEOUT", "86400")), + model_volume_mountpoint="/root/.paddlex", ) - -@dataclass -class AppInfo: - """Container for PaddleOCR-specific configuration and constants.""" - - model_weights_path: str = "/root/.paddlex" - - ########################################## # Image and app definitions ########################################## -APP_INFO = AppInfo() runtime_image = ( - patch_image_for_helper( - modal.Image - .debian_slim(python_version=CONF.python_version) - .apt_install("git", "build-essential", "libgl1-mesa-glx", "libglib2.0-0") - .env(CONF.default_env), - copy_patch_files=True, - ) + modal.Image + .debian_slim(python_version=CONF.python_version) + .apt_install("git", "build-essential", "libgl1-mesa-glx", "libglib2.0-0") + .env(CONF.default_env) + .pipe(patch_image_for_helper, copy_patch_files=True) .uv_pip_install( "paddlepaddle-gpu==3.2.1", index_url=f"https://www.paddlepaddle.org.cn/packages/stable/{CONF.cuda_version}/", @@ -67,7 +55,7 @@ class AppInfo: gpu=CONF.gpu, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME}, + volumes=CONF.mounts(model_volume=True, model_ro=False), ) def run_paddleocr(input_content: bytes, input_name: str) -> bytes: """Run PaddleOCR on the input PDF content and return extracted markdown and images.""" @@ -75,11 +63,6 @@ def run_paddleocr(input_content: bytes, input_name: str) -> bytes: from paddleocr import PaddleOCRVL # type: ignore[ty:unresolved-import] - # PaddleOCR hardcodes the model cache directory - model_cache_dir = Path(CONF.model_volume_mountpoint) / CONF.name - model_cache_dir.mkdir(parents=True, exist_ok=True) - softlink_dir(model_cache_dir, Path(APP_INFO.model_weights_path)) - run_dir = ".".join(input_name.split(".")[:-1]) workdir = Path(mkdtemp(prefix=f"{CONF.name}_")) / run_dir workdir.mkdir() diff --git a/src/biomodals/app/bioinfo/rosetta_app.py b/src/biomodals/app/bioinfo/rosetta_app.py index 7f2205d..5788f9f 100644 --- a/src/biomodals/app/bioinfo/rosetta_app.py +++ b/src/biomodals/app/bioinfo/rosetta_app.py @@ -12,7 +12,6 @@ import os from collections.abc import Iterable -from dataclasses import dataclass from io import BytesIO from pathlib import Path from uuid import uuid4 @@ -22,7 +21,8 @@ from biomodals.app.config import AppConfig from biomodals.helper import hash_string, patch_image_for_helper -from biomodals.helper.shell import package_outputs +from biomodals.helper.shell import package_outputs, warmup_directory +from biomodals.helper.volume_run import volume_path_from_mount_path ########################################## # Modal configs @@ -36,83 +36,17 @@ python_version="3.12", timeout=int(os.environ.get("TIMEOUT", "14400")), ) -OUT_VOLUME = CONF.get_out_volume() ROSETTA_DIR = Path(__file__).parent / "rosetta" -@dataclass(frozen=True) -class _RosettaCommandJob: - index: str - binary: str - pdb_path: Path - out_dir: Path - rosetta_script_path: Path | None = None - flags_path: Path | None = None - - -def _build_rosetta_command( - *, - binary: str, - pdb_path: Path, - out_dir: Path, - rosetta_script_path: Path | None = None, - flags_path: Path | None = None, -) -> list[str]: - cmd = [binary] - if rosetta_script_path is not None: - cmd.extend(["-parser:protocol", str(rosetta_script_path)]) - if flags_path is not None: - cmd.append(f"@{flags_path}") - cmd.extend(["-s", str(pdb_path), "-out:path:all", str(out_dir)]) - return cmd - - -def _command_for_rosetta_job(job: _RosettaCommandJob) -> list[str]: - return _build_rosetta_command( - binary=job.binary, - pdb_path=job.pdb_path, - out_dir=job.out_dir, - rosetta_script_path=job.rosetta_script_path, - flags_path=job.flags_path, - ) - - -def _required_rosetta_job_value(job_spec: dict[str, object], key: str) -> object: - value = job_spec[key] - if value is None: - raise ValueError(f"Rosetta job is missing {key!r}") - return value - - -def _optional_mounted_path(mount_dir: Path, path: object) -> Path | None: - if path is None: - return None - return mount_dir / str(path) - - -def _normalize_volume_rosetta_job( - job_spec: dict[str, object], *, mount_dir: Path, workdir: Path -) -> _RosettaCommandJob: - task_idx = str(_required_rosetta_job_value(job_spec, "index")) - return _RosettaCommandJob( - index=task_idx, - binary=str(_required_rosetta_job_value(job_spec, "binary")), - pdb_path=mount_dir / str(_required_rosetta_job_value(job_spec, "pdb")), - out_dir=workdir / task_idx, - rosetta_script_path=_optional_mounted_path( - mount_dir, job_spec.get("rosetta_script") - ), - flags_path=_optional_mounted_path(mount_dir, job_spec.get("flags_file")), - ) - - ########################################## # Image and app definitions ########################################## -runtime_image = patch_image_for_helper( - modal.Image.from_registry( - "rosettacommons/rosetta:serial-420", add_python=CONF.python_version - ) +runtime_image = ( + modal.Image + .from_registry("rosettacommons/rosetta:serial-420", add_python=CONF.python_version) + .env(CONF.default_env) + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -124,7 +58,7 @@ def _normalize_volume_rosetta_job( cpu=(0.125, 30.125), # Each pod can run 1-30 jobs memory=(1024, 43008), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: OUT_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def run_rosetta(run_name: str, run_id: str, num_cpu_per_pod: int): """Run Rosetta scripts.""" @@ -141,15 +75,29 @@ def _worker(worker_idx: int) -> None: print(f"💊 No more jobs in queue for worker {worker_idx}") return - job = _normalize_volume_rosetta_job( - job_spec, - mount_dir=mount_dir, - workdir=workdir, - ) - run_command_with_log( - _command_for_rosetta_job(job), log_file=job.out_dir / "rosetta.log" - ) - OUT_VOLUME.commit() + task_idx = str(job_spec["index"]) + binary = job_spec["binary"] + pdb_path = job_spec["pdb"] + if binary is None or pdb_path is None: + raise ValueError(f"Rosetta job is missing required values: {job_spec}") + + out_dir = workdir / task_idx + cmd = [str(binary)] + if job_spec.get("rosetta_script") is not None: + cmd.extend([ + "-parser:protocol", + str(mount_dir / str(job_spec["rosetta_script"])), + ]) + if job_spec.get("flags_file") is not None: + cmd.append(f"@{mount_dir / str(job_spec['flags_file'])}") + cmd.extend([ + "-s", + str(mount_dir / str(pdb_path)), + "-out:path:all", + str(out_dir), + ]) + run_command_with_log(cmd, log_file=out_dir / "rosetta.log") + CONF.output_volume.commit() # Run workers in parallel within the pod from concurrent.futures import ThreadPoolExecutor @@ -164,7 +112,7 @@ def _worker(worker_idx: int) -> None: cpu=(1.125, 16.125), # burst for tar compression memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: OUT_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def package_outputs_helper( root: str | Path, @@ -173,6 +121,7 @@ def package_outputs_helper( num_threads: int = 16, ) -> bytes: """Modal runner to package directories into a tar.zst archive and return as bytes.""" + warmup_directory(root) return package_outputs( root, paths_to_bundle=paths_to_bundle, @@ -392,7 +341,7 @@ def submit_rosetta_task( print(f"🧬 Preparing queue for {run_name} tasks...") queue = modal.Queue.from_name(f"{CONF.name}-queue-{run_id}", create_if_missing=True) uploaded_files = set() - with OUT_VOLUME.batch_upload() as batch: + with CONF.output_volume.batch_upload() as batch: for r in tasks_df.iter_rows(named=True): # Structure file should always be present local_pdb = Path(r["pdb"]).expanduser().resolve() @@ -446,14 +395,13 @@ def submit_rosetta_task( modal.Queue.objects.delete(f"{CONF.name}-queue-{run_id}") # Save results locally - out_vol_name = OUT_VOLUME.name or f"{CONF.name}-outputs" - remote_data_dir = f"/{run_name}-{run_id}" - + out_vol = volume_path_from_mount_path( + f"{CONF.output_volume_mountpoint}/{run_name}-{run_id}", + CONF.output_volume_mountpoint, + CONF.output_volume_name, + ) if out_dir is None: - print( - f"🧬 {CONF.name} run complete!\n" - f"Results saved to Modal volume '{out_vol_name}' at '{remote_data_dir}'" - ) + print(f"🧬 {CONF.name} run complete!\nResults saved to {out_vol}") return local_out_dir = Path(out_dir).expanduser().resolve() diff --git a/src/biomodals/app/config.py b/src/biomodals/app/config.py index 5588a2d..219a31d 100644 --- a/src/biomodals/app/config.py +++ b/src/biomodals/app/config.py @@ -1,150 +1,75 @@ -"""Common configurations for Biomodals apps.""" +"""Compatibility helpers for Biomodals app configuration.""" + +from __future__ import annotations -import os from functools import cached_property -from pathlib import Path - -from modal import Volume -from pydantic import BaseModel, computed_field, model_validator - -from biomodals.app.constant import MAX_TIMEOUT - - -class AppConfig(BaseModel): - """Base configuration model for Biomodals apps.""" - - # Metadata - name: str - repo_url: str | None = None - repo_commit_hash: str | None = None - package_name: str | None = None - version: str | None = None - python_version: str | None = None - tags: dict[str, str] | None = None - - # Runtime configs - # Model GPU (https://modal.com/docs/guide/gpu) - # 16GB: T4 - # 24GB: L4, A10G - # 40GB: A100-40G, A100 (using A100 may cause Modal to auto-upgrade to A100-80G) - # 48GB: L40S - # 80GB: A100-80G, H100 (may auto-upgrade to H200, use H100! to avoid) - # 96GB: RTX-PRO-6000 - # 141GB: H200 - # 180GB: B200 (B200+ may auto-upgrade to B300, which requires CUDA13.0+) - gpu: str = "A10G" - # https://modal.com/docs/guide/cuda - cuda_version: str = "cu128" - # Default execution timeout in seconds (https://modal.com/docs/guide/timeouts) - timeout: int = int(os.environ.get("TIMEOUT", "1800")) - # Location to cache model weights and other large artifacts - model_volume_mountpoint: str = "/biomodals-store" - # Location to mount output volume (if in use) - output_volume_mountpoint: str = "/biomodals-outputs" - - def get_out_volume(self) -> Volume: - """Volume for storing outputs.""" - vol_name = f"{self.name}-outputs" - return Volume.from_name(vol_name, create_if_missing=True, version=2) +from pathlib import PurePosixPath - @computed_field - @cached_property - def default_env(self) -> dict[str, str]: - """Environment variables to set in the runtime image.""" - model_cache_dir = Path(self.model_volume_mountpoint).resolve() - return { - "UV_COMPILE_BYTECODE": "1", # slower image build, faster runtime - "HF_XET_HIGH_PERFORMANCE": "1", - "HF_HOME": str(model_cache_dir / "huggingface"), - "TORCH_HOME": str(model_cache_dir / "torch"), - "UV_TORCH_BACKEND": self.cuda_version, - } +from modal import CloudBucketMount, Volume +from pydantic import ConfigDict, computed_field - @computed_field - @cached_property - def model_dir(self) -> Path: - """Directory to store model weights.""" - return Path(self.model_volume_mountpoint) / self.name +from biomodals.schema.app import AppConfig as SchemaAppConfig + + +class AppConfig(SchemaAppConfig): + """App configuration with Modal-specific compatibility helpers.""" + + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) @computed_field @cached_property - def git_clone_dir(self) -> Path: - """Directory to store cloned Git repositories.""" - return Path(f"/opt/{self.name}") + def output_volume_name(self) -> str: + """Name of the output volume.""" + return f"{self.name}-outputs" @computed_field @cached_property - def cuda_version_numeric(self) -> str: - """Numeric CUDA version, e.g., '128' for 'cu128'. - - https://github.com/astral-sh/uv/blob/main/crates/uv-torch/src/backend.rs + def output_volume(self) -> Volume: + """Volume for storing outputs.""" + return Volume.from_name( + self.output_volume_name, create_if_missing=True, version=2 + ) + + def mounts( + self, + output_volume: bool = False, + model_volume: bool = False, + *, + model_ro: bool = True, + model_mount_subdir: bool = True, + is_huggingface: bool = False, + ) -> dict[str | PurePosixPath, Volume | CloudBucketMount]: + """Generate the volume mountpoints for modal.Function definitions. + + Args: + output_volume: Whether to mount the output volume. + model_volume: Whether to mount the model volume for storing checkpoints. + `self.model_volume_mountpoint` will be used as the mount point. + model_ro: Whether to mount the model volume as read-only. + model_mount_subdir: If True, only mount a subdirectory of the volume + for isolation from other apps. Otherwise, mount the full volume. + is_huggingface: Whether the model is managed by HuggingFace. + + Returns: + A dictionary mapping volume mount points to volumes. """ - if not self.cuda_version.startswith("cu"): - return "" - - available_uv_backends = { - "130", - "129", - "128", - "126", - "125", - "124", - "123", - "122", - "121", - "120", - "118", - "117", - "116", - "115", - "114", - "113", - "112", - "111", - "110", - "102", - "101", - "100", - "92", - "91", - "90", - } - - if (cuda_ver := self.cuda_version[2:]) not in available_uv_backends: - raise ValueError( - f"CUDA version {self.cuda_version} is not supported by UV. " - f"Available versions: {available_uv_backends}" - ) - return f"{cuda_ver[:-1]}.{cuda_ver[-1]}.0" - - @model_validator(mode="after") - def ensure_package_info(self): - """Ensure that the package information is complete.""" - if self.repo_url is None and self.package_name is None: - raise ValueError( - "At least one of 'repo_url' or 'package_name' must be provided." + volumes = {} + if output_volume: + volumes[self.output_volume_mountpoint] = self.output_volume + if model_volume: + from biomodals.helper.constant import MODEL_VOLUME + + if model_mount_subdir: + sub_path = ( + "/huggingface" if is_huggingface else self.model_volume_subdir + ) + else: + sub_path = None + + volumes[self.model_volume_mountpoint] = MODEL_VOLUME.with_mount_options( + read_only=model_ro, sub_path=sub_path ) - if self.repo_commit_hash is None and self.version is None: - raise ValueError( - "Provide 'repo_commit_hash' or 'version' for reproducibility." - ) - return self - - @model_validator(mode="after") - def ensure_cuda_gpu_compatibility(self): - """Ensure that the specified CUDA version is compatible with the GPU.""" - if not self.cuda_version.startswith("cu"): - raise ValueError("CUDA version must start with 'cu', e.g., 'cu128'.") - - is_cu12 = self.cuda_version.startswith("cu12") - if is_cu12 and self.gpu.startswith("B200+"): - raise ValueError("CUDA 12.x is not compatible with 'B200+ / B300' GPU.") - - return self - - @model_validator(mode="after") - def ensure_timeout_within_range(self): - """Ensure that the specified timeout is within a reasonable range.""" - # between 1 second and 24 hours - self.timeout = max(1, min(self.timeout, MAX_TIMEOUT)) - return self + return volumes + + +__all__ = ["AppConfig"] diff --git a/src/biomodals/app/constant.py b/src/biomodals/app/constant.py deleted file mode 100644 index b0e7f4b..0000000 --- a/src/biomodals/app/constant.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Constants used across Biomodals apps.""" - -from modal import Volume - -# Volume for caching all model weights -MODEL_VOLUME_NAME = "biomodals-store" -MODEL_VOLUME = Volume.from_name(MODEL_VOLUME_NAME, create_if_missing=True) - -# Volume for caching MSA databases, which are large and shared across apps -AF3_MSA_DB_VOLUME = Volume.from_name( - "AlphaFold3-msa-db", create_if_missing=True, version=2 -) -PROTENIX_MSA_DB_VOLUME = Volume.from_name( - "Protenix-msa-db", create_if_missing=True, version=2 -) - -# Volume for caching MSA search results -MSA_CACHE_VOLUME = Volume.from_name( - "biomodals-msa-cache", create_if_missing=True, version=2 -) - -# Max timeout for any function, in seconds (24 hours) -MAX_TIMEOUT = 86400 diff --git a/src/biomodals/app/design/antifold_app.py b/src/biomodals/app/design/antifold_app.py index 1c6300e..4705719 100644 --- a/src/biomodals/app/design/antifold_app.py +++ b/src/biomodals/app/design/antifold_app.py @@ -20,7 +20,6 @@ import modal from biomodals.app.config import AppConfig -from biomodals.app.constant import MODEL_VOLUME from biomodals.helper import patch_image_for_helper from biomodals.helper.shell import package_outputs, run_command @@ -36,13 +35,16 @@ python_version="3.10", cuda_version="cu121", gpu=os.environ.get("GPU", "A10G"), + # AntiFold hard-coded the download logic to look for models in + # ./models/model.pt + # sys.exec_prefix points to /usr/local + model_volume_mountpoint="/usr/local/lib/python3.10/site-packages/models", ) -MODEL_DIR = CONF.model_dir ########################################## # Image and app definitions ########################################## -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .debian_slim(python_version=CONF.python_version) .apt_install("git", "build-essential", "wget") @@ -54,6 +56,8 @@ find_links=f"https://data.pyg.org/whl/torch-2.2.0+{CONF.cuda_version}.html", extra_options="--no-build-isolation", # https://github.com/astral-sh/uv/issues/5040 ) + .uv_pip_install("backports.strenum") # <3.11 + .pipe(patch_image_for_helper, skip_deps=["uniaf3"]) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -68,7 +72,7 @@ memory=(1024, 65536), # reserve 1GB, OOM at 64GB image=runtime_image, timeout=CONF.timeout, - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME}, + volumes=CONF.mounts(model_volume=True, model_ro=False), ) def antifold_inference( struct_bytes: bytes, @@ -89,20 +93,6 @@ def antifold_inference( """Manage AntiFold runs and return all inference results.""" from tempfile import TemporaryDirectory - # AntiFold hard-coded the download logic to look for models in ./models/model.pt - model_path = ( - Path(sys.exec_prefix) - / "lib" - / f"python{CONF.python_version}" - / "site-packages" - / "models" - / "model.pt" - ) - cache_model_path = MODEL_DIR / "model.pt" - if cache_model_path.exists(): - model_path.parent.mkdir(parents=True, exist_ok=True) - model_path.symlink_to(cache_model_path) - with TemporaryDirectory() as tmpdir: work_path = Path(tmpdir) / f"{output_id}_antifold" work_path.mkdir() @@ -148,13 +138,6 @@ def antifold_inference( print("💊 Packaging results...") tarball_bytes = package_outputs(work_path) - if not cache_model_path.exists(): - # Cache the model for future runs - import shutil - - shutil.copyfile(model_path, cache_model_path) - MODEL_VOLUME.commit() - return tarball_bytes diff --git a/src/biomodals/app/design/boltzgen_app.py b/src/biomodals/app/design/boltzgen_app.py index e24c225..43dff5b 100644 --- a/src/biomodals/app/design/boltzgen_app.py +++ b/src/biomodals/app/design/boltzgen_app.py @@ -17,8 +17,8 @@ import orjson from biomodals.app.config import AppConfig -from biomodals.app.constant import MAX_TIMEOUT, MODEL_VOLUME from biomodals.helper import patch_image_for_helper +from biomodals.helper.constant import MAX_TIMEOUT, MODEL_VOLUME from biomodals.helper.shell import ( package_outputs, run_command, @@ -26,6 +26,7 @@ sanitize_filename, warmup_directory, ) +from biomodals.helper.volume_run import volume_path_from_mount_path ########################################## # Modal configs @@ -42,15 +43,10 @@ gpu=os.environ.get("GPU", "L40S"), ) -# Volumes to be mounted -OUTPUTS_VOLUME = CONF.get_out_volume() -OUTPUTS_VOLUME_NAME = OUTPUTS_VOLUME.name or f"{CONF.name}-outputs" -OUTPUTS_DIR = CONF.output_volume_mountpoint - ########################################## # Image and app definitions ########################################## -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .debian_slim(python_version=CONF.python_version) .apt_install("git", "build-essential", "zstd", "fd-find") @@ -58,6 +54,7 @@ .uv_pip_install("polars[pandas,numpy,calamine,xlsxwriter]", "tqdm") .uv_pip_install(f"git+{CONF.repo_url}@{CONF.repo_commit_hash}") .workdir(str(CONF.git_clone_dir)) + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -70,7 +67,7 @@ cpu=(1.125, 16.125), # burst for tar compression memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=MAX_TIMEOUT, - volumes={OUTPUTS_DIR: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), image=runtime_image, ) def package_outputs_helper( @@ -193,7 +190,7 @@ def find_paths_in_list( # Fetch model weights ########################################## @app.function( - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME}, + volumes=CONF.mounts(model_volume=True, model_ro=False, is_huggingface=True), secrets=[modal.Secret.from_name("huggingface")], timeout=MAX_TIMEOUT, ) @@ -213,12 +210,12 @@ def boltzgen_download(force: bool = False) -> None: ########################################## # Inference functions ########################################## -@app.function(timeout=CONF.timeout, volumes={OUTPUTS_DIR: OUTPUTS_VOLUME}) +@app.function(timeout=CONF.timeout, volumes=CONF.mounts(output_volume=True)) def prepare_boltzgen_run( yaml_content: bytes, run_name: str, additional_files: dict[str, bytes] ) -> None: """Prepare BoltzGen input and output directories.""" - workdir = Path(OUTPUTS_DIR) / run_name + workdir = Path(CONF.output_volume_mountpoint) / run_name for d in ("inputs", "outputs"): (workdir / d).mkdir(parents=True, exist_ok=True) @@ -233,10 +230,10 @@ def prepare_boltzgen_run( file_path.parent.mkdir(parents=True, exist_ok=True) file_path.write_bytes(content) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() -@app.function(timeout=CONF.timeout, volumes={OUTPUTS_DIR: OUTPUTS_VOLUME}) +@app.function(timeout=CONF.timeout, volumes=CONF.mounts(output_volume=True)) def get_run_ids( run_name: str, num_parallel_runs: int, @@ -249,8 +246,8 @@ def get_run_ids( from datetime import UTC, datetime from uuid import uuid4 - OUTPUTS_VOLUME.reload() - outdir = Path(OUTPUTS_DIR) / run_name / "outputs" + CONF.output_volume.reload() + outdir = Path(CONF.output_volume_mountpoint) / run_name / "outputs" if not salvage_mode: today: str = datetime.now(UTC).strftime("%Y%m%d") @@ -282,7 +279,7 @@ def get_run_ids( @app.function( memory=(128, 65536), # reserve 128MB, OOM at 64GB timeout=MAX_TIMEOUT, - volumes={OUTPUTS_DIR: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def collect_boltzgen_data( run_name: str, @@ -296,8 +293,9 @@ def collect_boltzgen_data( filter_rmsd_threshold: float = 4.0, ) -> bytes | list[str]: """Collect BoltzGen output data from multiple runs.""" - OUTPUTS_VOLUME.reload() - outdir = Path(OUTPUTS_DIR) / run_name / "outputs" + out_vol = CONF.output_volume + out_vol.reload() + outdir = Path(CONF.output_volume_mountpoint) / run_name / "outputs" config_dir = outdir.parent / "inputs" / "config" config_dir.mkdir(parents=True, exist_ok=True) @@ -327,15 +325,18 @@ def collect_boltzgen_data( else: print("💊 All planned BoltzGen runs are already complete; skipping relaunch.") - OUTPUTS_VOLUME.reload() + out_vol.reload() + vol_path = volume_path_from_mount_path( + str(outdir), CONF.output_volume_mountpoint, CONF.output_volume_name + ) if filter_results: # Rerun BoltzGen filters on all run IDs, and only download the designs # that passed all filters (also limited by the `budget`) - print("💊 Collecting BoltzGen outputs...") + print(f"💊 Collecting BoltzGen outputs in {vol_path}...") combine_multiple_runs.remote(run_name, run_ids) print("💊 Filtering combined BoltzGen designs...") refilter_designs.remote(run_name, budget, filter_rmsd_threshold) - OUTPUTS_VOLUME.reload() + out_vol.reload() print("💊 Packaging filtered BoltzGen outputs...") tarball_bytes = package_outputs_helper.remote( @@ -348,12 +349,10 @@ def collect_boltzgen_data( ], ) return tarball_bytes - else: - print("💊 Skipping refiltering of BoltzGen outputs.") - print( - f"💊 Results are available at: '{outdir.relative_to(OUTPUTS_DIR)}' in volume '{OUTPUTS_VOLUME_NAME}'." - ) - return run_ids + + print("💊 Skipping refiltering of BoltzGen outputs.") + print(f"💊 Results are available at: {vol_path}.") + return run_ids @app.cls( @@ -361,10 +360,7 @@ def collect_boltzgen_data( cpu=1.125, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=MAX_TIMEOUT, - volumes={ - OUTPUTS_DIR: OUTPUTS_VOLUME, - CONF.model_volume_mountpoint: MODEL_VOLUME.read_only(), - }, + volumes=CONF.mounts(output_volume=True, model_volume=True, is_huggingface=True), ) class BoltzGenRunner: """Class to run BoltzGen on a YAML specification.""" @@ -409,10 +405,10 @@ def boltzgen_run( if lock_dir.exists() and (lock_dir.stat().st_mtime < (time.time() - 24 * 3600)): print(f"💊 Removing stale lock for {out_dir}.") lock_dir.rmdir() - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() try: lock_dir.mkdir(exist_ok=False) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() except FileExistsError: print( f"💊 Another worker is already running BoltzGen for {out_dir}; skipping." @@ -452,13 +448,13 @@ def clean_locks(self): if self.lock_dir.exists(): print(f"💊 Cleaning up lock directory {self.lock_dir}") self.lock_dir.rmdir() - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() @app.function( memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=MAX_TIMEOUT, - volumes={OUTPUTS_DIR: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def combine_multiple_runs(run_name: str, run_ids: list[str]): """Combine outputs from multiple BoltzGen runs into a single table.""" @@ -468,10 +464,10 @@ def combine_multiple_runs(run_name: str, run_ids: list[str]): import polars as pl from tqdm import tqdm - workdir = Path(OUTPUTS_DIR) / run_name / "outputs" - out_dir = Path(OUTPUTS_DIR) / run_name / "combined-outputs" + workdir = Path(CONF.output_volume_mountpoint) / run_name / "outputs" + out_dir = Path(CONF.output_volume_mountpoint) / run_name / "combined-outputs" (out_dir / "refold_cif").mkdir(parents=True, exist_ok=True) - OUTPUTS_VOLUME.reload() + CONF.output_volume.reload() metrics_dfs: list[pl.DataFrame] = [] ca_coords_seqs_dfs: list[pl.DataFrame] = [] @@ -534,7 +530,7 @@ def combine_multiple_runs(run_name: str, run_ids: list[str]): @app.function( memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={OUTPUTS_DIR: OUTPUTS_VOLUME}, + volumes=CONF.mounts(output_volume=True), ) def refilter_designs( run_name: str, @@ -546,7 +542,7 @@ def refilter_designs( import polars as pl from boltzgen.task.filter.filter import Filter # type: ignore[ty:unresolved-import] - workdir = Path(OUTPUTS_DIR) / run_name + workdir = Path(CONF.output_volume_mountpoint) / run_name warmup_directory(workdir / "combined-outputs") filter_task = Filter( @@ -705,7 +701,7 @@ def submit_boltzgen_task( f"🧬 Including additional referenced files: {list(yml_parser.additional_files.keys())}" ) - # TODO: use OUTPUTS_VOLUME.batch_upload to avoid spinning up container + # TODO: use CONF.output_volume.batch_upload to avoid spinning up container print(f"🧬 Submitting BoltzGen run for yaml: {input_yaml}") yaml_str = yaml_path.read_bytes() @@ -771,7 +767,7 @@ def submit_boltzgen_task( "modal", "volume", "get", - OUTPUTS_VOLUME_NAME, + CONF.output_volume_name, f"{remote_root_dir}/{subdir}", ], cwd=run_out_dir, diff --git a/src/biomodals/app/design/ligandmpnn_app.py b/src/biomodals/app/design/ligandmpnn_app.py index 54eec84..897aa30 100644 --- a/src/biomodals/app/design/ligandmpnn_app.py +++ b/src/biomodals/app/design/ligandmpnn_app.py @@ -19,8 +19,8 @@ import modal from biomodals.app.config import AppConfig -from biomodals.app.constant import MAX_TIMEOUT, MODEL_VOLUME from biomodals.helper import patch_image_for_helper +from biomodals.helper.constant import MAX_TIMEOUT, MODEL_VOLUME from biomodals.helper.shell import ( find_with_fd, package_outputs, @@ -43,7 +43,6 @@ cuda_version="cu121", gpu=os.environ.get("GPU", "A10G"), ) -REPO_DIR = CONF.git_clone_dir AVAILABLE_MODELS = { # ProteinMPNN @@ -80,22 +79,13 @@ ########################################## # Image and app definitions ########################################## -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .debian_slim(python_version=CONF.python_version) .apt_install("git", "build-essential", "wget") .env(CONF.default_env) - # .run_commands( - # " && ".join( - # ( - # f"git clone {CONF.repo_url} {REPO_DIR}", - # f"cd {REPO_DIR}", - # f"git checkout {CONF.repo_commit_hash}", - # "uv pip install --system -r requirements.txt", - # ) - # ) - # ) .uv_pip_install(f"{CONF.package_name}=={CONF.version}") + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -131,8 +121,7 @@ def torch_to_numpy(pt_file: str | Path) -> dict[str, Any]: # Fetch model weights ########################################## @app.function( - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME}, - timeout=MAX_TIMEOUT, + volumes=CONF.mounts(model_volume=True, model_ro=False), timeout=MAX_TIMEOUT ) def download_weights() -> None: """Download ProteinMPNN models into the mounted volume. @@ -141,12 +130,13 @@ def download_weights() -> None: AbMPNN ref: https://zenodo.org/records/8164693 """ base_url = "https://files.ipd.uw.edu/pub/ligandmpnn" + model_dir = Path(CONF.model_volume_mountpoint) ligandmpnn_weights = { - f"{base_url}/{model_name}": CONF.model_dir / "model_params" / model_name + f"{base_url}/{model_name}": model_dir / "model_params" / model_name for model_name in AVAILABLE_MODELS } abmpnn_dict = { - "https://zenodo.org/records/8164693/files/abmpnn.pt?download=1": CONF.model_dir + "https://zenodo.org/records/8164693/files/abmpnn.pt?download=1": model_dir / "model_params" / "abmpnn.pt" } @@ -214,7 +204,7 @@ def build_base_command( gpu=CONF.gpu, memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=86400, - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME.read_only()}, + volumes=CONF.mounts(model_volume=True), ) def ligandmpnn_run( run_name: str, @@ -242,6 +232,7 @@ def ligandmpnn_run( omit_aa_per_residue_bytes, ) + model_dir = Path(CONF.model_volume_mountpoint) log_path = workdir / "ligandmpnn-run.log" print(f"💊 Running LigandMPNN, saving logs to {log_path}") for seed in tqdm(seeds, desc="Inference seeds"): @@ -251,7 +242,7 @@ def ligandmpnn_run( "--out_folder", str(workdir / "outputs" / f"seed-{seed}"), ] - run_command_with_log(cmd, log_file=log_path, cwd=CONF.model_dir) + run_command_with_log(cmd, log_file=log_path, cwd=model_dir) # Convert .pt outputs to numpy print("💊 Converting .pt outputs to numpy...") @@ -439,7 +430,7 @@ def submit_ligandmpnn_task( cli_args[f"--checkpoint_{model_type}"] = checkpoint elif model_type == "abmpnn": cli_args["--checkpoint_protein_mpnn"] = str( - CONF.model_dir / "model_params" / "abmpnn.pt" + Path(CONF.model_volume_mountpoint) / "model_params" / "abmpnn.pt" ) if fixed_residues is not None: cli_args["--fixed_residues"] = fixed_residues @@ -516,8 +507,4 @@ def submit_ligandmpnn_task( print(f"🧬 Downloading results for {run_name}...") (local_out_dir / f"{run_name}-{script_mode}.tar.zst").write_bytes(res_bytes) - # run_command( - # ["modal", "volume", "get", OUTPUTS_VOLUME_NAME, str(remote_results_dir)], - # cwd=local_out_dir, - # ) print(f"🧬 Results saved to: {local_out_dir.resolve()}") diff --git a/src/biomodals/app/design/ppiflow_app.py b/src/biomodals/app/design/ppiflow_app.py index bc8ab96..7e9e0d5 100644 --- a/src/biomodals/app/design/ppiflow_app.py +++ b/src/biomodals/app/design/ppiflow_app.py @@ -11,9 +11,11 @@ from pydantic import BaseModel, computed_field, model_validator from biomodals.app.config import AppConfig -from biomodals.app.constant import MAX_TIMEOUT, MODEL_VOLUME, MODEL_VOLUME_NAME from biomodals.helper import patch_image_for_helper +from biomodals.helper.constant import MAX_TIMEOUT, MODEL_VOLUME, MODEL_VOLUME_NAME from biomodals.helper.shell import run_command_with_log, sanitize_filename +from biomodals.helper.volume_run import volume_path_from_mount_path +from biomodals.schema import AppOutput, AppRunResult, AppRunStatus, ArtifactKind ########################################## # Modal configs @@ -29,15 +31,12 @@ gpu=os.environ.get("GPU", "L40S"), ) -# Volumes to be mounted -OUTPUTS_VOLUME = CONF.get_out_volume() -OUTPUTS_VOLUME_NAME = OUTPUTS_VOLUME.name or f"{CONF.name}-outputs" SCRIPTS_DIR = CONF.git_clone_dir / "tool" / "PPIFlow" ########################################## # Image and app definitions ########################################## -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .micromamba(python_version=CONF.python_version) .apt_install("git", "build-essential") @@ -119,6 +118,7 @@ find_links=f"https://data.pyg.org/whl/torch-2.10.0+{CONF.cuda_version}.html", ) .workdir(str(SCRIPTS_DIR)) + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -262,10 +262,12 @@ def model_weights_name(self) -> str: ########################################## # Fetch model weights ########################################## -@app.function(volumes={CONF.model_volume_mountpoint: MODEL_VOLUME}, timeout=MAX_TIMEOUT) +@app.function( + volumes=CONF.mounts(model_volume=True, model_ro=False), timeout=MAX_TIMEOUT +) def fetch_model_weights(force: bool = False) -> None: """Download PPIFlow models into the mounted volume.""" - model_dir = CONF.model_dir + model_dir = Path(CONF.model_volume_mountpoint) base_url = "https://drive.google.com/uc?export=download&confirm=t&id=" tasks = { f"{base_url}1WBSjCTEtia9S1hJ54mYH1PZdDqpLVsgw": model_dir / "antibody.ckpt", @@ -295,10 +297,7 @@ def fetch_model_weights(force: bool = False) -> None: cpu=(0.125, 16.125), memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=MAX_TIMEOUT, - volumes={ - CONF.output_volume_mountpoint: OUTPUTS_VOLUME, - CONF.model_volume_mountpoint: MODEL_VOLUME.read_only(), - }, + volumes=CONF.mounts(output_volume=True, model_volume=True), ) def ppiflow_run(args: PPIFlowArgs, run_name: str) -> str: """Actual remote runner of PPIFlow.""" @@ -311,7 +310,7 @@ def ppiflow_run(args: PPIFlowArgs, run_name: str) -> str: return str(workdir) # Build command - model_weights_path = CONF.model_dir / args.model_weights_name + model_weights_path = Path(CONF.model_volume_mountpoint) / args.model_weights_name arg_fields = args.args.model_dump(exclude_none=True) cmd = [ sys.executable, @@ -326,10 +325,42 @@ def ppiflow_run(args: PPIFlowArgs, run_name: str) -> str: print(f"💊 Running {CONF.name}, saving logs to {log_path}") run_command_with_log(cmd, log_file=log_path) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() return str(workdir) +@app.function( + gpu=CONF.gpu, + cpu=(0.125, 16.125), + memory=(1024, 65536), + timeout=MAX_TIMEOUT, + volumes=CONF.mounts(output_volume=True, model_volume=True), +) +def ppiflow_run_workflow(args: PPIFlowArgs, run_name: str) -> AppRunResult: + """Run PPIFlow and return a workflow-compatible app result.""" + safe_run_name = sanitize_filename(run_name) + remote_workdir = ppiflow_run.get_raw_f()(args=args, run_name=safe_run_name) + return AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="ppiflow_outputs", + kind=ArtifactKind.DIRECTORY, + storage=volume_path_from_mount_path( + str(remote_workdir), + CONF.output_volume_mountpoint, + CONF.output_volume_name, + ), + metadata={ + "run_name": safe_run_name, + "script_name": args.script_name, + "model_weights_name": args.model_weights_name, + }, + ) + ], + ) + + ########################################## # Entrypoint for ephemeral usage ########################################## @@ -394,10 +425,12 @@ def submit_ppiflow_task( raise ValueError(f"Unsupported design_mode: {design_mode}") # NOTE: make sure names are unique for different inputs - - with OUTPUTS_VOLUME.batch_upload() as batch: + remote_dir = volume_path_from_mount_path( + str(remote_workdir), CONF.output_volume_mountpoint, CONF.output_volume_name + ) + with CONF.output_volume.batch_upload() as batch: for file in files_to_upload: - print(f"🧬 Uploading '{file}' to volume {OUTPUTS_VOLUME_NAME}...") + print(f"🧬 Uploading '{file}' to {remote_dir}...") batch.put_file(file, f"{run_name}/{Path(file).name}") print(f"🧬 Submitting PPIFlow task with run name: {run_name}") diff --git a/src/biomodals/app/design/rfdiffusion_app.py b/src/biomodals/app/design/rfdiffusion_app.py index 6b5ce88..38abcd6 100644 --- a/src/biomodals/app/design/rfdiffusion_app.py +++ b/src/biomodals/app/design/rfdiffusion_app.py @@ -71,6 +71,7 @@ run_command, warmup_directory, ) +from biomodals.helper.volume_run import volume_path_from_mount_path # ------------------------- # Modal configs @@ -80,7 +81,8 @@ APP_NAME = os.environ.get("MODAL_APP", "RFdiffusion") RFD_VOLUME = Volume.from_name("rfdiffusion-models", create_if_missing=True) -RFD_OUT_VOLUME = Volume.from_name("rfdiffusion-outputs", create_if_missing=True) +RFD_OUT_VOLUME_NAME = "rfdiffusion-outputs" +RFD_OUT_VOLUME = Volume.from_name(RFD_OUT_VOLUME_NAME, create_if_missing=True) RFD_REPO_DIR = "/root/RFdiffusion" RFD_MODELS_DIR = f"{RFD_REPO_DIR}/models" @@ -107,7 +109,7 @@ # For an example of a newer, modern CUDA/PyTorch-style Docker environment, see: # https://github.com/JMB-Scripts/RFdiffusion-dockerfile-nvidia-RTX5090/blob/main/RTX-5090.dockerfile # The runtime image below is defined directly with Modal and is not built from that Dockerfile. -runtime_image = patch_image_for_helper( +runtime_image = ( Image .debian_slim(python_version="3.10") .apt_install( @@ -162,6 +164,7 @@ "python -m pip install --no-cache-dir -r requirements.txt && " "python setup.py install" ) + .pipe(patch_image_for_helper) ) @@ -259,7 +262,7 @@ async def download_rfdiffusion_models(force: bool = False) -> None: timeout=TIMEOUT, image=runtime_image, volumes={ - RFD_MODELS_DIR: RFD_VOLUME.read_only(), + RFD_MODELS_DIR: RFD_VOLUME.with_mount_options(read_only=True), # output cache volume. RFD_OUT_DIR: RFD_OUT_VOLUME, }, @@ -323,6 +326,12 @@ def rfdiffusion_infer( # ---- commit cached outputs ---- RFD_OUT_VOLUME.commit() + remote_run_dir = volume_path_from_mount_path( + str(cached_run_dir), + RFD_OUT_DIR, + RFD_OUT_VOLUME_NAME, + ) + print(f"RFdiffusion cached outputs: {remote_run_dir}", flush=True) # ---- bundle outputs for return ---- warmup_directory(run_dir) @@ -348,39 +357,29 @@ def submit_rfdiffusion_task( ): """Submit an RFdiffusion inference job to Modal. - Parameters - ---------- - run_name : str - Unique name for this run. Used as the output-volume cache key and as part - of the returned output archive filename. - input_pdb : str - Path to the input PDB file on the local machine. The file will be uploaded - to the Modal worker before inference starts. - contigs : str | None - Convenience wrapper for `contigmap.contigs` (Hydra override). This argument - simplifies common RFdiffusion use cases such as binder or scaffold design. - num_designs : int - Convenience wrapper for `inference.num_designs` (Hydra override). - hotspot_res : str | None - Convenience wrapper for `ppi.hotspot_res` (Hydra override), typically used - for binder design. - rfd_args : str - Raw RFdiffusion Hydra overrides passed directly to the inference script. - This acts as an escape hatch for advanced or unsupported options. - download_models : bool - If set, download RFdiffusion checkpoint weights into the persistent models - volume and exit without running inference. - force_redownload : bool - Force re-download checkpoint weights even if they already exist in the - models volume. - out_dir : str | None - Optional local directory where the output `.tar.zst` archive will be written. - Defaults to the current working directory. + Args: + run_name: Unique name for this run. Used as the output-volume cache key + and as part of the returned output archive filename. + input_pdb: Path to the input PDB file on the local machine. The file + will be uploaded to the Modal worker before inference starts. + contigs: Convenience wrapper for `contigmap.contigs` (Hydra override). + This simplifies common RFdiffusion use cases such as binder or + scaffold design. + num_designs: Convenience wrapper for `inference.num_designs`. + hotspot_res: Convenience wrapper for `ppi.hotspot_res`, typically used + for binder design. + rfd_args: Raw RFdiffusion Hydra overrides passed directly to the + inference script. This is an escape hatch for advanced options. + download_models: If set, download RFdiffusion checkpoint weights into + the persistent models volume and exit without running inference. + force_redownload: Force re-download checkpoint weights even if they + already exist in the models volume. + out_dir: Optional local directory where the output `.tar.zst` archive + will be written. Defaults to the current working directory. Notes: - ----- - - For longer jobs, increase TIMEOUT via environment variable: - TIMEOUT=360000 modal run rfdiffusion_app.py ... + For longer jobs, increase `TIMEOUT` via environment variable: + `TIMEOUT=360000 modal run rfdiffusion_app.py ...`. """ if download_models: download_rfdiffusion_models.remote(force=force_redownload) @@ -389,7 +388,7 @@ def submit_rfdiffusion_task( if run_name is None: raise ValueError("Missing required --run-name") - run_name = validate_run_name(run_name) + run_name = validate_run_name(run_name) if input_pdb is None: raise ValueError("Missing required --input-pdb (path to local .pdb)") diff --git a/src/biomodals/app/fold/abcfold2_app.py b/src/biomodals/app/fold/abcfold2_app.py index 7fe81d8..4c8cf4e 100644 --- a/src/biomodals/app/fold/abcfold2_app.py +++ b/src/biomodals/app/fold/abcfold2_app.py @@ -23,8 +23,8 @@ import modal from biomodals.app.config import AppConfig -from biomodals.app.constant import MODEL_VOLUME from biomodals.helper import patch_image_for_helper +from biomodals.helper.constant import MODEL_VOLUME from biomodals.helper.shell import package_outputs from biomodals.helper.web import download_files @@ -73,33 +73,29 @@ class AppInfo: ########################################## APP_INFO = AppInfo() -# Volumes -OUTPUTS_VOLUME = CONF.get_out_volume() -OUTPUTS_VOLUME_NAME = OUTPUTS_VOLUME.name -OUTPUTS_DIR = CONF.output_volume_mountpoint - -download_image = patch_image_for_helper( +download_image = ( modal.Image .debian_slim() .uv_pip_install("huggingface_hub>=1.10") .env( CONF.default_env | { - "CHAI_DOWNLOADS_DIR": str(ChaiConf.model_dir), - "BOLTZ_CACHE": str(BoltzConf.model_dir), + "CHAI_DOWNLOADS_DIR": ChaiConf.model_volume_mountpoint, + "BOLTZ_CACHE": BoltzConf.model_volume_mountpoint, } ) + .pipe(patch_image_for_helper) ) -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .debian_slim() .apt_install("git", "build-essential") .env( CONF.default_env | { - "CHAI_DOWNLOADS_DIR": str(ChaiConf.model_dir), - "BOLTZ_CACHE": str(BoltzConf.model_dir), + "CHAI_DOWNLOADS_DIR": ChaiConf.model_volume_mountpoint, + "BOLTZ_CACHE": BoltzConf.model_volume_mountpoint, } ) .run_commands( @@ -125,6 +121,7 @@ class AppInfo: .env({"PATH": f"{APP_INFO.abcfold_dir}/.venv/bin:$PATH"}) .apt_install("kalign") # for Chai templates .workdir(APP_INFO.abcfold_dir) + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -134,7 +131,7 @@ class AppInfo: # Fetch model weights ########################################## @app.function( - volumes={str(BoltzConf.model_volume_mountpoint): MODEL_VOLUME}, + volumes=BoltzConf.mounts(model_volume=True, model_ro=False, is_huggingface=True), timeout=CONF.timeout, image=download_image, ) @@ -147,23 +144,24 @@ def download_boltz_models(force: bool = False) -> None: from huggingface_hub import snapshot_download # type: ignore[ty:unresolved-import] + boltz_download_dir = Path(BoltzConf.model_volume_mountpoint) snapshot_download( repo_id="boltz-community/boltz-2", revision=APP_INFO.boltz_model_hash, - local_dir=BoltzConf.model_dir, + local_dir=boltz_download_dir, force_download=force, ) - boltz_download_dir = BoltzConf.model_dir + MODEL_VOLUME.commit() + tar_mols = boltz_download_dir / "mols.tar" if not (boltz_download_dir / "mols").exists(): with tarfile.open(str(tar_mols), "r") as tar: tar.extractall(boltz_download_dir) # noqa: S202 - MODEL_VOLUME.commit() @app.function( - volumes={str(ChaiConf.model_volume_mountpoint): MODEL_VOLUME}, + volumes=ChaiConf.mounts(model_volume=True, model_ro=False), timeout=CONF.timeout, image=download_image, ) @@ -182,11 +180,12 @@ async def download_chai_models(force=False): ] # launch downloads concurrently - chai_model_dir = ChaiConf.model_dir + chai_model_dir = Path(ChaiConf.model_volume_mountpoint) download_tasks = { f"{base_url}{dep}": chai_model_dir / dep for dep in inference_dependencies } download_files(download_tasks, progress_bar_desc="Downloading Chai models") + MODEL_VOLUME.commit() # Special treatment for ESM esm2_path = chai_model_dir / "esm2" / "traced_sdpa_esm2_t36_3B_UR50D_fp16.pt" @@ -237,10 +236,7 @@ def get_run_id(yaml_str: bytes) -> str: @app.function( image=runtime_image, timeout=CONF.timeout, - volumes={ - OUTPUTS_DIR: OUTPUTS_VOLUME, - BoltzConf.model_volume_mountpoint: MODEL_VOLUME, - }, + volumes=CONF.mounts(output_volume=True) | BoltzConf.mounts(model_volume=True), ) def prepare_abcfold2( yaml_str: bytes, search_templates: bool, msa_chains: str | None = None @@ -258,7 +254,8 @@ def prepare_abcfold2( run_id: str = get_run_id.local(yaml_str=yaml_str) if not search_templates: run_id = f"{run_id}-no-tmpl" - out_dir_full: Path = Path(OUTPUTS_DIR) / run_id[:2] / run_id + out_root = Path(CONF.output_volume_mountpoint) + out_dir_full: Path = out_root / run_id[:2] / run_id out_dir_full.mkdir(parents=True, exist_ok=True) # Check if MSA and templates were already generated for a previous run with same ID @@ -277,21 +274,21 @@ def prepare_abcfold2( force=True, chains=msa_chains, search_templates=search_templates, - template_cache_dir=Path(OUTPUTS_DIR) / ".cache" / "rcsb", + template_cache_dir=out_root / ".cache" / "rcsb", ) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() # Generate inputs for Boltz and Chai if not (out_dir_full / "boltz_models" / f"{run_id}.yaml").exists(): _ = prepare_boltz(conf_file=yaml_path, out_dir=out_dir_full) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() if not (out_dir_full / "chai_models" / f"{run_id}.yaml").exists(): _ = prepare_chai( conf_file=yaml_path, out_dir=out_dir_full, - ccd_lib_dir=BoltzConf.model_dir / "mols", + ccd_lib_dir=Path(BoltzConf.model_volume_mountpoint) / "mols", ) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() # Pull run parameters from YAML conf = load_params_from_run_yaml(yaml_path) @@ -306,10 +303,7 @@ def prepare_abcfold2( memory=(1024, 65536), # reserve 1GB, OOM at 64GB image=runtime_image, timeout=CONF.timeout, - volumes={ - OUTPUTS_DIR: OUTPUTS_VOLUME, - BoltzConf.model_volume_mountpoint: MODEL_VOLUME, - }, + volumes=CONF.mounts(output_volume=True) | BoltzConf.mounts(model_volume=True), ) def collect_abcfold2_boltz_data( run_conf: dict[str, str | list[int] | int | list[str] | None], @@ -321,7 +315,7 @@ def collect_abcfold2_boltz_data( run_id = run_conf["run_id"] work_path = work_path / "boltz_models" boltz_conf_path = work_path / f"{run_id}.yaml" - OUTPUTS_VOLUME.reload() + CONF.output_volume.reload() if not boltz_conf_path.exists(): raise FileNotFoundError(f"Boltz config file not found: {boltz_conf_path}") @@ -343,7 +337,7 @@ def collect_abcfold2_boltz_data( for boltz_run_dir in run_abcfold2_boltz.map(seeds_to_run, kwargs=run_conf): print(f"Boltz run complete: {boltz_run_dir}") - OUTPUTS_VOLUME.reload() + CONF.output_volume.reload() print("💊 Packaging Boltz results...") boltz_tarball_bytes = package_outputs( work_path, @@ -366,10 +360,7 @@ def collect_abcfold2_boltz_data( memory=(1024, 65536), # reserve 1GB, OOM at 64GB image=runtime_image, timeout=CONF.timeout, - volumes={ - OUTPUTS_DIR: OUTPUTS_VOLUME, - BoltzConf.model_volume_mountpoint: MODEL_VOLUME, - }, + volumes=CONF.mounts(output_volume=True) | BoltzConf.mounts(model_volume=True), ) def run_abcfold2_boltz( seed: int, @@ -386,7 +377,7 @@ def run_abcfold2_boltz( run_boltz, ) - OUTPUTS_VOLUME.reload() + CONF.output_volume.reload() work_path = Path(workdir).expanduser().resolve() work_path = work_path / "boltz_models" boltz_conf_path = work_path / f"{run_id}.yaml" @@ -402,7 +393,7 @@ def run_abcfold2_boltz( num_diffn_samples=num_diffn_samples, boltz_additional_cli_args=boltz_additional_cli_args, ) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() return str(boltz_run_dir) @@ -411,10 +402,7 @@ def run_abcfold2_boltz( memory=(1024, 65536), # reserve 1GB, OOM at 64GB image=runtime_image, timeout=CONF.timeout, - volumes={ - OUTPUTS_DIR: OUTPUTS_VOLUME, - ChaiConf.model_volume_mountpoint: MODEL_VOLUME, - }, + volumes=CONF.mounts(output_volume=True) | ChaiConf.mounts(model_volume=True), ) def collect_abcfold2_chai_data( run_conf: dict[str, str | list[int] | int | list[str] | None], @@ -426,7 +414,7 @@ def collect_abcfold2_chai_data( run_id = run_conf["run_id"] work_path = work_path / "chai_models" chai_conf_path = work_path / f"{run_id}.yaml" - OUTPUTS_VOLUME.reload() + CONF.output_volume.reload() if not chai_conf_path.exists(): raise FileNotFoundError(f"Chai config file not found: {chai_conf_path}") @@ -448,7 +436,7 @@ def collect_abcfold2_chai_data( for chai_run_dir in run_abcfold2_chai.map(seeds_to_run, kwargs=run_conf): print(f"Chai run complete: {chai_run_dir}") - OUTPUTS_VOLUME.reload() + CONF.output_volume.reload() print("💊 Packaging Chai results...") chai_tarball_bytes = package_outputs(work_path) return chai_tarball_bytes @@ -459,10 +447,7 @@ def collect_abcfold2_chai_data( memory=(1024, 65536), # reserve 1GB, OOM at 64GB image=runtime_image, timeout=CONF.timeout, - volumes={ - OUTPUTS_DIR: OUTPUTS_VOLUME, - ChaiConf.model_volume_mountpoint: MODEL_VOLUME, - }, + volumes=CONF.mounts(output_volume=True) | ChaiConf.mounts(model_volume=True), ) def run_abcfold2_chai( seed: int, @@ -479,7 +464,7 @@ def run_abcfold2_chai( run_chai, ) - OUTPUTS_VOLUME.reload() + CONF.output_volume.reload() work_path = Path(workdir).expanduser().resolve() chai_work_path = work_path / "chai_models" chai_conf_path = chai_work_path / f"{run_id}.yaml" @@ -500,7 +485,7 @@ def run_abcfold2_chai( num_diffn_samples=num_diffn_samples, num_trunk_samples=num_trunk_samples, ) - OUTPUTS_VOLUME.commit() + CONF.output_volume.commit() return str(chai_run_dir) diff --git a/src/biomodals/app/fold/alphafold3_app.py b/src/biomodals/app/fold/alphafold3_app.py index 778edab..59cf564 100644 --- a/src/biomodals/app/fold/alphafold3_app.py +++ b/src/biomodals/app/fold/alphafold3_app.py @@ -37,13 +37,13 @@ ) from biomodals.app.config import AppConfig -from biomodals.app.constant import ( +from biomodals.helper import hash_string, patch_image_for_helper +from biomodals.helper.constant import ( AF3_MSA_DB_VOLUME, MAX_TIMEOUT, - MODEL_VOLUME, MSA_CACHE_VOLUME, + MSA_CACHE_VOLUME_NAME, ) -from biomodals.helper import hash_string, patch_image_for_helper from biomodals.helper.io import ( build_local_output_path, resolve_local_output_dir, @@ -62,10 +62,10 @@ CONF = AppConfig( tags={"group": Path(__file__).parent.name}, name="AlphaFold3", - repo_url="https://github.com/google-deepmind/alphafold3", - repo_commit_hash="87bd9e678d9acacc4aa9baa05e820f32b80e1b49", + repo_url="https://github.com/y1zhou/alphafold3", + repo_commit_hash="987ad1cb7d7028b6d35908cf63fe7d951d98d6b6", package_name="alphafold3", - version="3.0.1", + version="3.0.2", python_version="3.12", cuda_version="cu130", gpu=os.environ.get("GPU", "L40S"), @@ -80,7 +80,8 @@ class AppInfo: # Volume mount path for genetic search databases msa_db_dir: str = f"/{CONF.name}-msa-db" # Volume mount path for MSA output cache - msa_cache_dir: str = "/biomodals-msa-cache" + msa_cache_dir: str = f"/{MSA_CACHE_VOLUME_NAME}" + msa_cache_volume_subdir: str = f"/{CONF.name}" ########################################## @@ -89,7 +90,7 @@ class AppInfo: APP_INFO = AppInfo() # Ref: https://github.com/google-deepmind/alphafold3/blob/main/docker/Dockerfile -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .debian_slim(python_version=CONF.python_version) .apt_install("git", "build-essential", "zstd", "zlib1g-dev", "wget") @@ -132,6 +133,7 @@ class AppInfo: .uv_pip_install(str(CONF.git_clone_dir)) .run_commands("build_data") # installed in the previous step .env({"PATH": "/hmmer/bin:$PATH"}) + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -224,6 +226,15 @@ def _cache_conf_unpaired_msa(conf: AF3Config, msa_cache_dir: Path) -> AF3Config: return conf +def _af3_sanitised_name(name: str) -> str: + """Return sanitised version of the name that can be used as a filename.""" + import string + + spaceless_name = name.replace(" ", "_") + allowed_chars = set(string.ascii_letters + string.digits + "_-.") + return "".join(x for x in spaceless_name if x in allowed_chars) + + ########################################## # Inference functions ########################################## @@ -235,10 +246,12 @@ def _cache_conf_unpaired_msa(conf: AF3Config, msa_cache_dir: Path) -> AF3Config: # mmCIF templates .tar.zst: 57.6GiB # ephemeral_disk=1024 * round(304.8 + 5), # MiB, billed by memory at 20:1 ratio timeout=CONF.timeout, - volumes={ - CONF.model_volume_mountpoint: MODEL_VOLUME, + volumes=CONF.mounts(model_volume=True) + | { APP_INFO.msa_db_dir: AF3_MSA_DB_VOLUME, - APP_INFO.msa_cache_dir: MSA_CACHE_VOLUME, + APP_INFO.msa_cache_dir: MSA_CACHE_VOLUME.with_mount_options( + sub_path=APP_INFO.msa_cache_volume_subdir + ), }, ) def run_data_pipeline(json_bytes: bytes, copy_msa_to_ssd: bool = True) -> bytes: @@ -247,7 +260,7 @@ def run_data_pipeline(json_bytes: bytes, copy_msa_to_ssd: bool = True) -> bytes: from tempfile import mkdtemp # Try to fill config with cached MSA - msa_cache_dir = Path(APP_INFO.msa_cache_dir) / CONF.name + msa_cache_dir = Path(APP_INFO.msa_cache_dir) conf = _load_conf_from_bytes(json_bytes) MSA_CACHE_VOLUME.reload() conf = _cache_conf_unpaired_msa(conf, msa_cache_dir) @@ -255,7 +268,7 @@ def run_data_pipeline(json_bytes: bytes, copy_msa_to_ssd: bool = True) -> bytes: # Check if all protein/RNA sequences have MSA results temp_dir: Path = Path(mkdtemp(prefix="alphafold3_data_")) - run_name = conf.name + run_name = _af3_sanitised_name(conf.name) input_json_path = temp_dir / f"{run_name}.json" conf.to_files(temp_dir, run_name) all_protein_msa_filled = all( @@ -269,6 +282,7 @@ def run_data_pipeline(json_bytes: bytes, copy_msa_to_ssd: bool = True) -> bytes: if (rna_seq := seq.rna) is not None ) if all_protein_msa_filled and all_rna_msa_filled: + print("💊 MSA cache hit, returning results...") return input_json_path.read_bytes() # TODO: test sharded DB @@ -302,25 +316,26 @@ def run_data_pipeline(json_bytes: bytes, copy_msa_to_ssd: bool = True) -> bytes: # if p.returncode != 0: # raise RuntimeError("Failed to extract mmCIF template files") msa_db_dir.append(str(db_dir)) - msa_db_dir.append(str(msa_db_path)) # fallback for mmCIF templates and RNA + msa_db_dir.append(APP_INFO.msa_db_dir) # fallback for mmCIF templates and RNA - work_dir = temp_dir / run_name - work_dir.mkdir(exist_ok=True) cmd = [ sys.executable, str(CONF.git_clone_dir / "run_alphafold.py"), "--run_inference=false", f"--json_path={input_json_path}", - f"--output_dir={work_dir}", - f"--model_dir={CONF.model_dir}", + f"--output_dir={temp_dir}", + f"--model_dir={CONF.model_volume_mountpoint}", *(f"--db_dir={d}" for d in msa_db_dir), "--jackhmmer_n_cpu=8", "--nhmmer_n_cpu=8", ] - run_command_with_log(cmd, log_file=work_dir / f"{run_name}.log", verbose=True) + run_command(cmd, verbose=True) # Cache unpaired MSA files in separate directories for future use - msa_json_path = work_dir / f"{run_name}_data.json" + msa_json_path = temp_dir / run_name / f"{run_name}_data.json" + if not msa_json_path.exists(): + print([x.relative_to(temp_dir) for x in temp_dir.rglob("*")]) + raise FileNotFoundError(f"MSA JSON file not found: {msa_json_path}") _ = _cache_conf_unpaired_msa(AF3Config.from_file(msa_json_path), msa_cache_dir) MSA_CACHE_VOLUME.commit() return msa_json_path.read_bytes() @@ -349,8 +364,7 @@ def search_msa_and_templates( tmp_path = Path(tmp_dir) data_pipeline_futures = [] for i, msa_chain in msa_chains: - input_conf = conf.model_copy(deep=True) - input_conf.sequences = [msa_chain] + input_conf = conf.model_copy(update={"sequences": [msa_chain]}) input_conf.to_files(tmp_path, str(i)) data_pipeline_futures.append( run_data_pipeline.spawn( @@ -376,9 +390,12 @@ def search_msa_and_templates( cpu=(0.125, 16.125), # burst for tar compression memory=(1024, 131072), # reserve 1GB, OOM at 128GB timeout=MAX_TIMEOUT, - volumes={ - CONF.model_volume_mountpoint: MODEL_VOLUME, # JAX cache - APP_INFO.msa_cache_dir: MSA_CACHE_VOLUME.read_only(), + # Writable model dir because AlphaFold3 writes its JAX cache next to weights + volumes=CONF.mounts(model_volume=True, model_ro=False) + | { + APP_INFO.msa_cache_dir: MSA_CACHE_VOLUME.with_mount_options( + read_only=True, sub_path=APP_INFO.msa_cache_volume_subdir + ) }, ) def run_inference_pipeline( @@ -404,6 +421,7 @@ def run_inference_pipeline( print(f"💊 Running inference for {run_name} with seeds {model_seeds}") out_dir = temp_path / run_name + model_dir = Path(CONF.model_volume_mountpoint) cmd = [ sys.executable, str(CONF.git_clone_dir / "run_alphafold.py"), @@ -411,8 +429,8 @@ def run_inference_pipeline( "--run_data_pipeline=false", f"--json_path={input_json_path}", f"--output_dir={out_dir}", - f"--model_dir={CONF.model_dir}", - f"--jax_compilation_cache_dir={CONF.model_dir / 'jax_cache'}", + f"--model_dir={model_dir}", + f"--jax_compilation_cache_dir={model_dir / 'jax_cache'}", f"--num_recycles={recycle}", f"--num_diffusion_samples={sample}", ] @@ -422,6 +440,127 @@ def run_inference_pipeline( return package_outputs(out_dir / run_name) +def predict_structures( + conf: AF3Config, + local_out_dir: Path, + recycle: int, + sample: int, + num_containers: int, + *, + poll_timeout: int = 5, +) -> Path: + """Run AF3 inference pipeline and save outputs to .tar.zst file.""" + run_name = conf.name + out_file = build_local_output_path(local_out_dir, run_name=run_name) + if out_file.exists(): + print(f"🧬 File already exists, skipping inference: {out_file}") + return out_file + + # Directly run inference pipeline if only one container is specified + json_bytes = conf.to_json().encode() + model_seeds = conf.modelSeeds + if num_containers == 1: + tarball_content = run_inference_pipeline.remote( + json_bytes, recycle=recycle, sample=sample, model_seeds=model_seeds + ) + write_local_tarball(out_file, tarball_content) + return out_file + + tar_binary = shutil.which("tar") or None + if tar_binary is None: + raise RuntimeError("🧬 tar command not found") + tar_cmd = [tar_binary, "-I", "zstd"] + + def _part_file(i: int) -> Path: + return local_out_dir / f"{run_name}_part{i}.tar.zst" + + def _is_good_tarball(tarball_file: Path) -> bool: + """Return whether an existing tarball is good enough to skip.""" + if not tarball_file.exists() or tarball_file.stat().st_size == 0: + return False + try: + run_command([*tar_cmd, "-tf", str(tarball_file)], verbose=False) + except Exception as exc: + print( + f"🧬 Existing part tarball is not readable; rerunning {tarball_file}: {exc}" + ) + return False + return True + + # Run inference in parallel for parts that are missing + inference_func_calls: dict[int, modal.FunctionCall] = {} + good_part_indices: set[int] = set() + for i in range(num_containers): + tarball_file = _part_file(i) + if _is_good_tarball(tarball_file): + good_part_indices.add(i) + continue + fc = run_inference_pipeline.spawn( + json_bytes, + recycle=recycle, + sample=sample, + model_seeds=model_seeds[i::num_containers], + ) + inference_func_calls[i] = fc + + # Collect results as they become available + failures: list[tuple[int, Exception]] = [] + while inference_func_calls: + for i, fc in inference_func_calls.copy().items(): + try: + tarball_content = fc.get(timeout=poll_timeout) + except TimeoutError: + print(f"🧬 Task {i} still running...") + continue + except Exception as exc: + failures.append((i, exc)) + del inference_func_calls[i] + print(f"🧬 Task {i} failed: {exc}") + continue + + tarball_file = _part_file(i) + tmp_file = tarball_file.with_suffix(".tmp") + write_local_tarball(tmp_file, tarball_content, overwrite=True) + tmp_file.replace(tarball_file) + del inference_func_calls[i] + + # Go through all expected tarball part files + tarball_part_files = [_part_file(i) for i in range(num_containers)] + for i, tarball_part_file in enumerate(tarball_part_files): + if i not in good_part_indices and _is_good_tarball(tarball_part_file): + good_part_indices.add(i) + unusable_part_files = [ + p for i, p in enumerate(tarball_part_files) if i not in good_part_indices + ] + if unusable_part_files: + saved = ( + ", ".join(str(tarball_part_files[i]) for i in sorted(good_part_indices)) + or "none" + ) + failed = "; ".join(f"part {i}: {exc}" for i, exc in failures) or "unknown" + raise RuntimeError( + "Some AlphaFold3 inference parts failed or did not produce readable " + "tarballs. " + f"Saved part tarballs: {saved}. Failed parts: {failed}. " + "Rerun the command to resume only missing parts." + ) + + # Run local extraction after everything is saved to avoid errors + with TemporaryDirectory() as tmp_dir: + for tar_filename in tarball_part_files: + run_command( + [*tar_cmd, "-xf", str(tar_filename)], verbose=False, cwd=tmp_dir + ) + + # Combine the parts into a single .tar.zst file + tarball_content = package_outputs(Path(tmp_dir) / run_name) + write_local_tarball(out_file, tarball_content) + print( + f"🧬 Note that top-level {run_name}_*.{{cif,json,csv}} may not be correct since they are from parallel workers" + ) + return out_file + + ########################################## # Entrypoint for ephemeral usage ########################################## @@ -452,6 +591,7 @@ def submit_alphafold3_task( on the number of model seeds in the JSON config. recycle: Number of Pairformer recycles to use during inference. sample: Number of diffusion samples to generate per seed. + """ # Validate and read input input_path = Path(input_json).expanduser().resolve() @@ -461,6 +601,7 @@ def submit_alphafold3_task( conf = AF3Config.from_file(input_path) if run_name is None: run_name = conf.name + conf.name = run_name # Run inference if search_msa: @@ -470,48 +611,14 @@ def submit_alphafold3_task( json_bytes = input_path.read_bytes() local_out_dir = resolve_local_output_dir(out_dir) - out_file = build_local_output_path(local_out_dir, run_name=run_name) - num_seeds = len(conf.modelSeeds) + + new_conf = _load_conf_from_bytes(json_bytes) + new_conf.name = run_name + new_conf.modelSeeds = conf.modelSeeds + num_seeds = len(new_conf.modelSeeds) num_containers = max(1, min(max_num_gpus, num_seeds)) print(f"🧬 Running {CONF.name} inference pipeline with {num_containers=}...") - if num_containers == 1: - tarball_bytes = run_inference_pipeline.remote( - json_bytes, recycle=recycle, sample=sample, model_seeds=conf.modelSeeds - ) - else: - inference_futures = [ - run_inference_pipeline.spawn( - json_bytes, - recycle=recycle, - sample=sample, - model_seeds=conf.modelSeeds[i::num_containers], - ) - for i in range(num_containers) - ] - tarballs = modal.FunctionCall.gather(*inference_futures) - - with TemporaryDirectory() as tmp_dir: - for i, tarball_bytes in enumerate(tarballs): - tar_filename = out_file.with_name(f"{run_name}_part{i}.tar.zst") - write_local_tarball(tar_filename, tarball_bytes, overwrite=True) - run_command( - [ - shutil.which("tar") or "tar", - "-I", - "zstd", - "-xf", - str(tar_filename), - ], - verbose=False, - cwd=tmp_dir, - ) - - # Combine the parts into a single .tar.zst file - tarball_bytes = package_outputs(Path(tmp_dir) / run_name) - print( - f"🧬 Note that top-level {run_name}_*.{{cif,json,csv}} may not be correct since they are from parallel workers" - ) - - # Save results locally - write_local_tarball(out_file, tarball_bytes) + out_file = predict_structures( + new_conf, local_out_dir, recycle, sample, num_containers + ) print(f"🧬 {CONF.name} run complete! Results saved to {out_file}") diff --git a/src/biomodals/app/fold/flowpacker_app.py b/src/biomodals/app/fold/flowpacker_app.py index 9c667d3..4f6d1a9 100644 --- a/src/biomodals/app/fold/flowpacker_app.py +++ b/src/biomodals/app/fold/flowpacker_app.py @@ -39,14 +39,27 @@ import modal from biomodals.app.config import AppConfig -from biomodals.app.constant import MODEL_VOLUME from biomodals.helper import patch_image_for_helper +from biomodals.helper.constant import MODEL_VOLUME from biomodals.helper.io import ( build_local_output_path, resolve_local_output_dir, write_local_tarball, ) -from biomodals.helper.shell import package_outputs, run_command_with_log +from biomodals.helper.shell import ( + copy_files, + package_outputs, + run_command_with_log, + sanitize_filename, + softlink_dir, +) +from biomodals.helper.volume_run import volume_path_from_mount_path +from biomodals.schema import ( + AppOutput, + AppRunResult, + AppRunStatus, + ArtifactKind, +) ########################################## # Modal configs @@ -96,7 +109,7 @@ class AppInfo: ########################################## APP_INFO = AppInfo() -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .debian_slim(python_version=CONF.python_version) .apt_install("git", "git-lfs", "build-essential") @@ -114,6 +127,7 @@ class AppInfo: find_links=f"https://data.pyg.org/whl/torch-2.3.0+{CONF.cuda_version}.html", extra_options=f"--exclude-newer {APP_INFO.dependency_cutoff}", ) + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -121,65 +135,55 @@ class AppInfo: ########################################## # Fetch model weights ########################################## -def _checkpoint_path(name: str) -> Path: - """Return the shared-volume path for a FlowPacker checkpoint.""" - return CONF.model_dir / "checkpoints" / f"{name}.pth" - - @app.function( cpu=(0.125, 8.125), timeout=CONF.timeout, - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME}, + volumes=CONF.mounts(model_volume=True, model_ro=False), ) def download_flowpacker_checkpoints(force: bool = False) -> None: """Download FlowPacker Git LFS checkpoints into the model volume.""" from biomodals.helper.shell import run_command - checkpoint_dir = CONF.model_dir / "checkpoints" - checkpoint_dir.mkdir(parents=True, exist_ok=True) - - wanted = [_checkpoint_path(name) for name in APP_INFO.checkpoint_names] - if not force and all(path.exists() for path in wanted): + cache_dir = Path(CONF.model_volume_mountpoint) + all_checkpoints = [cache_dir / f"{name}.pth" for name in APP_INFO.checkpoint_names] + missing_checkpoints = [ + f for f in all_checkpoints if not (f.exists() and f.stat().st_size > 1024) + ] + if not force and len(missing_checkpoints) == 0: print("💊 FlowPacker checkpoints already exist in the model volume") return if force: - _ = [path.unlink(missing_ok=True) for path in wanted] + _ = [path.unlink(missing_ok=True) for path in all_checkpoints] + missing_checkpoints = all_checkpoints - include_paths = ",".join( - f"checkpoints/{name}.pth" for name in APP_INFO.checkpoint_names - ) + include_paths = ",".join(f"checkpoints/{f.name}" for f in missing_checkpoints) run_command(["git", "lfs", "install", "--skip-repo"], cwd=CONF.git_clone_dir) run_command( - ["git", "lfs", "pull", "--include", include_paths, "--exclude", ""], + ["git", "lfs", "pull", "--include", include_paths], cwd=CONF.git_clone_dir, env={"GIT_LFS_SKIP_SMUDGE": "0"}, ) - - for name in APP_INFO.checkpoint_names: - src = CONF.git_clone_dir / "checkpoints" / f"{name}.pth" - dst = _checkpoint_path(name) - shutil.copy2(src, dst) - MODEL_VOLUME.commit() + checkpoint_dir = CONF.git_clone_dir / "checkpoints" + files_to_copy: dict[str | Path, str | Path] = {} + for f in missing_checkpoints: + source_path = checkpoint_dir / f.name + if not source_path.exists(): + raise FileNotFoundError( + f"FlowPacker Git LFS checkpoint was not downloaded: {source_path}" + ) + files_to_copy[source_path] = f + copy_files(files_to_copy) + MODEL_VOLUME.commit() print("💊 FlowPacker checkpoint download complete") ########################################## # Inference functions ########################################## -def _ensure_checkpoint_symlink() -> None: - """Link the repo checkpoint directory to the mounted model volume.""" - checkpoint_src = CONF.model_dir / "checkpoints" - checkpoint_link = CONF.git_clone_dir / "checkpoints" - if checkpoint_link.is_symlink() and checkpoint_link.resolve() == checkpoint_src: - return - - if checkpoint_link.is_dir() and not checkpoint_link.is_symlink(): - shutil.rmtree(checkpoint_link) - elif checkpoint_link.exists() or checkpoint_link.is_symlink(): - checkpoint_link.unlink() - - checkpoint_link.symlink_to(checkpoint_src, target_is_directory=True) +def _checkpoint_path(checkpoint_name: str) -> Path: + """Get the path to the checkpoint file in the Git LFS directory.""" + return CONF.git_clone_dir / "checkpoints" / f"{checkpoint_name}.pth" def _write_flowpacker_config( @@ -195,7 +199,7 @@ def _write_flowpacker_config( """Write the upstream FlowPacker inference config for one Modal run.""" import yaml - conf_ckpt = "./checkpoints/confidence.pth" if use_confidence else None + conf_ckpt = str(_checkpoint_path("confidence")) if use_confidence else None config = { "mode": "vf", "data": { @@ -209,7 +213,7 @@ def _write_flowpacker_config( "max_radius": 16.0, "max_neighbors": 30, }, - "ckpt": f"./checkpoints/{model_name}.pth", + "ckpt": str(_checkpoint_path(model_name)), "conf_ckpt": conf_ckpt, "sample": { "batch_size": 1, @@ -230,8 +234,8 @@ def _write_flowpacker_config( cpu=(0.125, 16.125), memory=(1024, 65536), timeout=CONF.timeout, - # Cannot mount as ro as FlowPacker mkdirs there for whatever reason - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME}, + # Cannot mount as read-only because FlowPacker runs mkdir for some reason + volumes=CONF.mounts(model_volume=True, model_ro=False), ) def run_flowpacker( input_files: list[tuple[str, bytes]], @@ -254,6 +258,10 @@ def run_flowpacker( raise ValueError( f"Unsupported model '{model_name}'. Choose one of: {APP_INFO.supported_models}" ) + ckpt_dir = CONF.git_clone_dir / "checkpoints" + if ckpt_dir.exists() and not ckpt_dir.is_symlink(): + shutil.rmtree(ckpt_dir) + softlink_dir(CONF.model_volume_mountpoint, ckpt_dir) ckpt_path = _checkpoint_path(model_name) if not ckpt_path.exists(): raise FileNotFoundError(f"FlowPacker checkpoint is missing: {ckpt_path}") @@ -264,8 +272,7 @@ def run_flowpacker( f"FlowPacker confidence checkpoint is missing: {confidence_ckpt_path}" ) - _ensure_checkpoint_symlink() - + run_name = sanitize_filename(run_name) input_dir = Path(mkdtemp(prefix="flowpacker_inputs_")) sample_dir = CONF.git_clone_dir / "samples" / run_name if sample_dir.exists(): @@ -286,14 +293,7 @@ def run_flowpacker( sample_coeff=sample_coeff, ) - cmd = [ - sys.executable, - "sampler_pdb.py", - "biomodals", - run_name, - "--seed", - str(seed), - ] + cmd = [sys.executable, "sampler_pdb.py", "biomodals", run_name, "--seed", str(seed)] if save_traj: cmd.extend(["--save_traj", "True"]) if use_gt_masks: @@ -315,6 +315,68 @@ def run_flowpacker( return package_outputs(sample_dir) +@app.function( + gpu=CONF.gpu, + cpu=(0.125, 16.125), + memory=(1024, 65536), + timeout=CONF.timeout, + volumes=CONF.mounts(output_volume=True, model_volume=True, model_ro=False), +) +def run_flowpacker_workflow( + input_files: list[tuple[str, bytes]], + run_name: str, + model_name: str = "cluster", + use_confidence: bool = False, + n_samples: int = 1, + num_steps: int = APP_INFO.default_num_steps, + sample_coeff: float = APP_INFO.default_sample_coeff, + use_gt_masks: bool = False, + inpaint: str | None = None, + save_traj: bool = False, + seed: int = 42, +) -> AppRunResult: + """Run FlowPacker and return a workflow-compatible app result.""" + safe_run_name = sanitize_filename(run_name) + tarball_bytes = run_flowpacker.get_raw_f()( + input_files=input_files, + run_name=safe_run_name, + model_name=model_name, + use_confidence=use_confidence, + n_samples=n_samples, + num_steps=num_steps, + sample_coeff=sample_coeff, + use_gt_masks=use_gt_masks, + inpaint=inpaint, + save_traj=save_traj, + seed=seed, + ) + archive_filename = f"{safe_run_name}.tar.zst" + volume_root = Path(CONF.output_volume_mountpoint) + archive_path = volume_root / "workflow" / safe_run_name / archive_filename + archive_path.parent.mkdir(parents=True, exist_ok=True) + archive_path.write_bytes(tarball_bytes) + CONF.output_volume.commit() + return AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="flowpacker_outputs", + kind=ArtifactKind.ARCHIVE, + storage=volume_path_from_mount_path( + remote_path=str(archive_path), + mount_root=str(volume_root), + volume_name=CONF.output_volume_name, + media_type="application/zstd", + ), + metadata={ + "archive_format": "tar.zst", + "filename": archive_filename, + }, + ) + ], + ) + + ########################################## # Entrypoint for ephemeral usage ########################################## @@ -381,6 +443,7 @@ def submit_flowpacker_task( seed: Random seed for FlowPacker inference. download_models: Download FlowPacker checkpoints and exit without inference. force_redownload: Force checkpoint redownload even when cached files exist. + """ if model_name not in APP_INFO.supported_models: raise ValueError( @@ -405,14 +468,18 @@ def submit_flowpacker_task( run_name = ( resolved_input.stem if resolved_input.is_file() else resolved_input.name ) + safe_run_name = sanitize_filename(run_name) local_out_dir = resolve_local_output_dir(out_dir) - out_file = build_local_output_path(local_out_dir, run_name=run_name) + out_file = build_local_output_path(local_out_dir, run_name=safe_run_name) - print(f"🧬 Submitting FlowPacker run '{run_name}' with {len(input_files)} input(s)") + print( + f"🧬 Submitting FlowPacker run '{safe_run_name}' " + f"with {len(input_files)} input(s)" + ) tarball_bytes = run_flowpacker.remote( input_files=input_files, - run_name=run_name, + run_name=safe_run_name, model_name=model_name, use_confidence=use_confidence, n_samples=n_samples, diff --git a/src/biomodals/app/fold/protenix_app.py b/src/biomodals/app/fold/protenix_app.py index e9e399e..9487624 100644 --- a/src/biomodals/app/fold/protenix_app.py +++ b/src/biomodals/app/fold/protenix_app.py @@ -36,8 +36,13 @@ import modal from biomodals.app.config import AppConfig -from biomodals.app.constant import MAX_TIMEOUT, MODEL_VOLUME, MSA_CACHE_VOLUME from biomodals.helper import hash_string, patch_image_for_helper +from biomodals.helper.constant import ( + MAX_TIMEOUT, + MODEL_VOLUME, + MSA_CACHE_VOLUME, + MSA_CACHE_VOLUME_NAME, +) from biomodals.helper.io import ( build_local_output_path, resolve_local_output_dir, @@ -79,9 +84,7 @@ class AppInfo: cuda_tag = f"{CONF.cuda_version_numeric}-devel-ubuntu24.04" # Volume for preprocessed MSA/template intermediates (MSA_CACHE_VOLUME) - msa_cache_dir: str = "/protenix-msa" - # Volume for prediction outputs (enables skip/resume across interrupted runs) - model_dir: Path = CONF.model_dir + msa_cache_volume_subdir: str = f"/{CONF.name}" # Base URL for downloading checkpoints and data caches # https://github.com/bytedance/Protenix/blob/main/protenix/web_service/dependency_url.py @@ -111,7 +114,7 @@ class AppInfo: # Image and app definitions ########################################## APP_INFO = AppInfo() -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .from_registry(f"nvidia/cuda:{APP_INFO.cuda_tag}", add_python=CONF.python_version) .entrypoint([]) # remove verbose logging in the base image @@ -120,8 +123,10 @@ class AppInfo: CONF.default_env | { "PYTHONUNBUFFERED": "1", - "PROTENIX_ROOT_DIR": str(CONF.model_dir), - "PROTENIX_CHECKPOINT_DIR": str(CONF.model_dir / "checkpoint"), + "PROTENIX_ROOT_DIR": CONF.model_volume_mountpoint, + "PROTENIX_CHECKPOINT_DIR": str( + Path(CONF.model_volume_mountpoint) / "checkpoint" + ), } ) .uv_pip_install( @@ -134,6 +139,7 @@ class AppInfo: gpu=CONF.gpu, env={"LAYERNORM_TYPE": "fast_layernorm"}, # default, but just in case ) + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -142,7 +148,7 @@ class AppInfo: # Fetch model weights and data caches ########################################## @app.function( - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME}, timeout=CONF.timeout + volumes=CONF.mounts(model_volume=True, model_ro=False), timeout=CONF.timeout ) def download_protenix_data( model_name: str = "protenix_base_default_v1.0.0", @@ -157,7 +163,7 @@ def download_protenix_data( include_templates: Also download template-related data files. """ - data_root = CONF.model_dir + data_root = Path(CONF.model_volume_mountpoint) files_to_download: dict[str, str | Path] = {} # Download common data caches @@ -192,7 +198,14 @@ def download_protenix_data( ########################################## # Inference functions ########################################## -@app.function(timeout=CONF.timeout, volumes={APP_INFO.msa_cache_dir: MSA_CACHE_VOLUME}) +@app.function( + timeout=CONF.timeout, + volumes={ + MSA_CACHE_VOLUME_NAME: MSA_CACHE_VOLUME.with_mount_options( + sub_path=APP_INFO.msa_cache_volume_subdir + ) + }, +) def query_protenix_msa_server( query_command: str, input_json_path: str, output_dir: str, msa_server_mode: str ) -> None: @@ -247,7 +260,14 @@ def _get_new_location(old_path: str | None) -> str | None: MSA_CACHE_VOLUME.commit() -@app.function(timeout=CONF.timeout, volumes={APP_INFO.msa_cache_dir: MSA_CACHE_VOLUME}) +@app.function( + timeout=CONF.timeout, + volumes={ + MSA_CACHE_VOLUME_NAME: MSA_CACHE_VOLUME.with_mount_options( + sub_path=APP_INFO.msa_cache_volume_subdir + ) + }, +) def prepare_protenix_inputs( input_bytes: bytes, msa_server_mode: str = "protenix", @@ -300,11 +320,7 @@ def prepare_protenix_inputs( else hash_string(":".join(protein_seqs)) ) cache_dir = ( - Path(APP_INFO.msa_cache_dir) - / CONF.name - / msa_server_mode - / hash_key[:2] - / hash_key + Path(MSA_CACHE_VOLUME_NAME) / msa_server_mode / hash_key[:2] / hash_key ) cache_dir.mkdir(parents=True, exist_ok=True) output_dirs.append(str(cache_dir)) @@ -350,9 +366,11 @@ def prepare_protenix_inputs( cpu=(1.125, 16.125), # burst for tar compression memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=MAX_TIMEOUT, - volumes={ - CONF.model_volume_mountpoint: MODEL_VOLUME.read_only(), - APP_INFO.msa_cache_dir: MSA_CACHE_VOLUME, + volumes=CONF.mounts(model_volume=True) + | { + MSA_CACHE_VOLUME_NAME: MSA_CACHE_VOLUME.with_mount_options( + sub_path=APP_INFO.msa_cache_volume_subdir + ) }, ) def run_protenix( @@ -443,8 +461,8 @@ def run_protenix( input_seqs = struct2seq(input_file) cache_key = hash_string(":".join(x[1] for x in input_seqs)) score_msa_cache_dir = ( - Path(APP_INFO.msa_cache_dir) - / f"{CONF.name}_score" + Path(MSA_CACHE_VOLUME_NAME) + / "score" / msa_server_mode / cache_key[:2] / cache_key diff --git a/src/biomodals/app/score/abnativ_app.py b/src/biomodals/app/score/abnativ_app.py index c66b799..0d44458 100644 --- a/src/biomodals/app/score/abnativ_app.py +++ b/src/biomodals/app/score/abnativ_app.py @@ -15,9 +15,9 @@ import modal from biomodals.app.config import AppConfig -from biomodals.app.constant import MAX_TIMEOUT, MODEL_VOLUME from biomodals.helper import patch_image_for_helper -from biomodals.helper.shell import package_outputs, run_command, softlink_dir +from biomodals.helper.constant import MAX_TIMEOUT, MODEL_VOLUME +from biomodals.helper.shell import package_outputs, run_command ########################################## # Modal configs @@ -31,15 +31,14 @@ python_version="3.12", cuda_version="cu128", gpu=os.environ.get("GPU", "A10G"), + model_volume_mountpoint="/root/.abnativ/models/pretrained_models", ) -# AbNatiV hard-coded cache directory for model weights -ABNATIV_MODEL_DIR = "/root/.abnativ/models/pretrained_models" ########################################## # Image and app definitions ########################################## -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .micromamba(python_version=CONF.python_version) .apt_install("git", "build-essential", "wget", "zstd") @@ -47,6 +46,7 @@ .micromamba_install(["openmm", "pdbfixer", "biopython"], channels=["conda-forge"]) .micromamba_install(["anarci"], channels=["bioconda"]) .uv_pip_install(f"{CONF.package_name}=={CONF.version}") + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) @@ -56,14 +56,11 @@ ########################################## @app.function( cpu=(1.125, 16.125), - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME}, + volumes=CONF.mounts(model_volume=True, model_ro=False), timeout=MAX_TIMEOUT, ) def download_abnativ_models(force: bool = False) -> None: """Download AbNatiV models into the mounted volume.""" - # Make soft link from AbNatiV's expected model directory to the mounted volume - softlink_dir(CONF.model_dir, ABNATIV_MODEL_DIR) - # Download all artifacts print(f"💊 Downloading {CONF.name} models...") cmd = ["abnativ", "init"] @@ -83,7 +80,7 @@ def download_abnativ_models(force: bool = False) -> None: cpu=(1.125, 16.125), # burst for tar compression memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME.read_only()}, + volumes=CONF.mounts(model_volume=True), ) def abnativ_score_unpaired( fasta_bytes: bytes, @@ -98,8 +95,6 @@ def abnativ_score_unpaired( """Manage AbNatiV runs and return all score results.""" from tempfile import TemporaryDirectory - softlink_dir(CONF.model_dir, ABNATIV_MODEL_DIR) - with TemporaryDirectory() as tmpdir: work_path = Path(tmpdir) / f"{output_id}_abnativ_{nativeness_type}" work_path.mkdir() @@ -141,7 +136,7 @@ def abnativ_score_unpaired( cpu=(1.125, 16.125), # burst for tar compression memory=(1024, 65536), # reserve 1GB, OOM at 64GB timeout=CONF.timeout, - volumes={CONF.model_volume_mountpoint: MODEL_VOLUME.read_only()}, + volumes=CONF.mounts(model_volume=True), ) def abnativ_score_paired( csv_bytes: bytes, @@ -154,8 +149,6 @@ def abnativ_score_paired( """Manage AbNatiV runs and return all score results.""" from tempfile import TemporaryDirectory - softlink_dir(CONF.model_dir, ABNATIV_MODEL_DIR) - with TemporaryDirectory() as tmpdir: work_path = Path(tmpdir) / f"{output_id}_abnativ_paired" work_path.mkdir() diff --git a/src/biomodals/app/score/af3score_app.py b/src/biomodals/app/score/af3score_app.py index fe81583..c912334 100644 --- a/src/biomodals/app/score/af3score_app.py +++ b/src/biomodals/app/score/af3score_app.py @@ -15,6 +15,7 @@ import os import shutil +import string import sys from dataclasses import dataclass from pathlib import Path @@ -23,7 +24,6 @@ import modal from biomodals.app.config import AppConfig -from biomodals.app.constant import MODEL_VOLUME from biomodals.helper import patch_image_for_helper from biomodals.helper.shell import ( copy_files, @@ -34,6 +34,7 @@ from biomodals.helper.volume_run import ( build_volume_run_paths, has_completed_output_files, + volume_path_from_mount_path, ) ########################################## @@ -56,14 +57,19 @@ class AppInfo: """Container for AF3Score-specific configuration and constants.""" af3_weights: str = "AlphaFold3/af3.bin" - out_volume: modal.Volume = CONF.get_out_volume() + metrics_filename: str = "af3score_metrics.csv" + completion_sample_subdir: str = "seed-10_sample-0" + completion_required_files: tuple[str, ...] = ( + "summary_confidences.json", + "confidences.json", + ) ########################################## # Image and app definitions ########################################## APP_INFO = AppInfo() -runtime_image = patch_image_for_helper( +runtime_image = ( modal.Image .debian_slim(python_version=CONF.python_version) .apt_install( @@ -89,41 +95,14 @@ class AppInfo: .workdir(str(CONF.git_clone_dir)) .uv_pip_install(str(CONF.git_clone_dir), "biopython", "h5py", "pandas") .run_commands("build_data") + .pipe(patch_image_for_helper) ) app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) ########################################## -# Helper functions +# Local input collection ########################################## -def _get_af3_sanitized_name(name: str) -> str: - """Sanitize a name to be compatible with AF3Score's internal naming.""" - import string - - lower_spaceless_name = name.lower().replace(" ", "_") - allowed_chars = set(string.ascii_lowercase + string.digits + "_-.") - return "".join(c for c in lower_spaceless_name if c in allowed_chars) - - -def _run_paths(run_name: str) -> dict[str, Path]: - """Return the standard run-level paths for one AF3Score output directory.""" - return build_volume_run_paths( - Path(CONF.output_volume_mountpoint), - run_name, - metrics_filename="af3score_metrics.csv", - ) - - -def _has_completed_outputs(output_dir: Path, input_id: str) -> bool: - """Check whether AF3Score wrote the required output JSON files.""" - return has_completed_output_files( - output_dir, - input_id, - sample_subdir="seed-10_sample-0", - required_files=("summary_confidences.json", "confidences.json"), - ) - - def _collect_input_files(input_root: Path, stage_dir: Path) -> list[Path]: """Collect supported AF3Score input files from a file or directory.""" if not input_root.exists(): @@ -138,8 +117,14 @@ def _collect_input_files(input_root: Path, stage_dir: Path) -> list[Path]: raise ValueError(f"No .pdb files were found in '{input_root}'.") symlinks: list[Path] = [] + allowed_chars = set(string.ascii_lowercase + string.digits + "_-.") for f in all_files: - symlink_path = stage_dir / _get_af3_sanitized_name(f.name) + safe_name = "".join( + c for c in f.name.lower().replace(" ", "_") if c in allowed_chars + ) + if not safe_name: + raise ValueError(f"Input file name has no AF3Score-safe characters: {f}") + symlink_path = stage_dir / safe_name if symlink_path.exists(): raise ValueError(f"Duplicated sanitized file name: {symlink_path.name}") symlink_path.symlink_to(f) @@ -147,18 +132,6 @@ def _collect_input_files(input_root: Path, stage_dir: Path) -> list[Path]: return symlinks -def _adjust_num_cpu_gpu( - total_num_files: int, max_num_batches: int, max_num_workers: int -) -> tuple[int, int]: - """Adjust the number of CPU workers and GPU batches based on the total number of files.""" - max_num_batches = min(max(1, max_num_batches), total_num_files) - num_jobs_per_batch = max( - 1, (total_num_files + max_num_batches - 1) // max_num_batches - ) - adjusted_max_num_workers = min(max(1, max_num_workers), num_jobs_per_batch) - return max_num_batches, adjusted_max_num_workers - - @dataclass class ChunkSpec: """Container for AF3Score batch chunk specifications.""" @@ -188,13 +161,17 @@ class TaskSpec: cpu=(0.125, 1.125), memory=(512, 2048), timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: AppInfo.out_volume}, + volumes=CONF.mounts(output_volume=True), ) def af3score_manage_lock(run_name: str, acquire: bool = True) -> None: """Internal-only remote helper for acquiring or releasing one run-level lock.""" # TODO: replace with a task queue; mkdir in Volumes may not be atomic - AppInfo.out_volume.reload() - paths = _run_paths(run_name) + CONF.output_volume.reload() + paths = build_volume_run_paths( + CONF.output_volume_mountpoint, + run_name, + metrics_filename=APP_INFO.metrics_filename, + ) root_dir = paths["run_root"] lock_dir = root_dir / ".run.lock" if acquire: @@ -205,25 +182,30 @@ def af3score_manage_lock(run_name: str, acquire: bool = True) -> None: raise RuntimeError( f"`{run_name=}` is already in use by another active AF3Score run." ) from exc - AppInfo.out_volume.commit() + CONF.output_volume.commit() return if lock_dir.exists(): lock_dir.rmdir() - AppInfo.out_volume.commit() + CONF.output_volume.commit() @app.function( cpu=(1.125, 16.125), memory=(1024, 32768), timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: AppInfo.out_volume}, + volumes=CONF.mounts(output_volume=True), ) def af3score_prepare( - paths: dict[str, Path], input_files: list[str], num_jobs: int, prepare_workers: int + run_name: str, input_files: list[str], num_jobs: int, prepare_workers: int ) -> TaskSpec: """Prepare AF3Score batches from staged inputs.""" - AppInfo.out_volume.reload() + CONF.output_volume.reload() + paths = build_volume_run_paths( + CONF.output_volume_mountpoint, + run_name, + metrics_filename=APP_INFO.metrics_filename, + ) staged_dir = paths["inputs_dir"].resolve() if not staged_dir.exists(): raise FileNotFoundError(f"Staged input directory not found: {staged_dir}") @@ -240,7 +222,12 @@ def af3score_prepare( skipped = 0 out_dir = paths["output_dir"] for pdb_file in all_files: - if _has_completed_outputs(out_dir, pdb_file.stem): + if has_completed_output_files( + out_dir, + pdb_file.stem, + sample_subdir=APP_INFO.completion_sample_subdir, + required_files=APP_INFO.completion_required_files, + ): skipped += 1 continue pending_files.append(pdb_file) @@ -268,9 +255,9 @@ def af3score_prepare( for source_path in pending_files }) # Adjust CPU and GPU resources - n_batches, n_cpu = _adjust_num_cpu_gpu( - len(pending_files), num_jobs, prepare_workers - ) + n_batches = min(max(1, num_jobs), len(pending_files)) + num_jobs_per_batch = max(1, (len(pending_files) + n_batches - 1) // n_batches) + n_cpu = min(max(1, prepare_workers), num_jobs_per_batch) run_command([ sys.executable, str(CONF.git_clone_dir / "01_prepare_get_json.py"), @@ -314,16 +301,20 @@ def af3score_prepare( cpu=(0.125, 16.125), memory=(1024, 65536), timeout=CONF.timeout, - volumes={ - CONF.output_volume_mountpoint: AppInfo.out_volume, - CONF.model_volume_mountpoint: MODEL_VOLUME.read_only(), - }, + volumes=CONF.mounts( + output_volume=True, model_volume=True, model_mount_subdir=False + ), ) def af3score_run( - paths: dict[str, Path], batch_name: str, batch_json_dir: str, batch_pdb_dir: str -): + run_name: str, batch_name: str, batch_json_dir: str, batch_pdb_dir: str +) -> None: """Run one AF3Score batch.""" - AppInfo.out_volume.reload() + CONF.output_volume.reload() + paths = build_volume_run_paths( + CONF.output_volume_mountpoint, + run_name, + metrics_filename=APP_INFO.metrics_filename, + ) af3_weights = Path(CONF.model_volume_mountpoint) / APP_INFO.af3_weights if not af3_weights.exists(): raise FileNotFoundError(f"AlphaFold3 model weights not found: {af3_weights}") @@ -371,20 +362,23 @@ def af3score_run( ], log_file=out_dir / f"{batch_name}.log", ) - AppInfo.out_volume.commit() + CONF.output_volume.commit() @app.function( cpu=(0.125, 16.125), memory=(1024, 16384), timeout=CONF.timeout, - volumes={CONF.output_volume_mountpoint: AppInfo.out_volume}, + volumes=CONF.mounts(output_volume=True), ) -def af3score_postprocess( - input_files: list[str], paths: dict[str, Path] -) -> dict[str, int | str]: +def af3score_postprocess(run_name: str, input_files: list[str]) -> dict[str, int | str]: """Validate records and collect metrics for all inputs.""" - AppInfo.out_volume.reload() + CONF.output_volume.reload() + paths = build_volume_run_paths( + CONF.output_volume_mountpoint, + run_name, + metrics_filename=APP_INFO.metrics_filename, + ) for path in (paths["output_dir"], paths["failed_dir"]): path.mkdir(parents=True, exist_ok=True) @@ -395,7 +389,12 @@ def af3score_postprocess( for input_name in input_files: input_id = Path(input_name).stem failed_record = paths["failed_dir"] / f"{input_id}.err" - if _has_completed_outputs(out_dir, input_id): + if has_completed_output_files( + out_dir, + input_id, + sample_subdir=APP_INFO.completion_sample_subdir, + required_files=APP_INFO.completion_required_files, + ): if failed_record.exists(): failed_record.unlink() processed += 1 @@ -437,7 +436,7 @@ def af3score_postprocess( if paths["prep_dir"].exists(): shutil.rmtree(paths["prep_dir"]) - AppInfo.out_volume.commit() + CONF.output_volume.commit() return { "output_dir": str(out_dir), "failed_dir": str(paths["failed_dir"]), @@ -485,18 +484,27 @@ def submit_af3score_task( print(f"🧬 Total files: {num_files} found in '{input_root}'") run_name = sanitize_filename(run_name) - run_paths = _run_paths(run_name) + run_paths = build_volume_run_paths( + CONF.output_volume_mountpoint, + run_name, + metrics_filename=APP_INFO.metrics_filename, + ) if not force: - for x in AppInfo.out_volume.iterdir("/"): + for x in CONF.output_volume.iterdir("/"): if x.path == run_name: raise ValueError( f"Run name '{run_name}' already exists in Modal volume." ) + remote_run_dir = volume_path_from_mount_path( + str(run_paths["run_root"]), + CONF.output_volume_mountpoint, + CONF.output_volume_name, + ) af3score_manage_lock.remote(run_name=run_name, acquire=True) try: - print(f"🧬 Uploading '{input_root}' to Modal") + print(f"🧬 Uploading '{input_root}' to {remote_run_dir}") stage_root = run_paths["inputs_dir"].relative_to(run_paths["mount_root"]) - with AppInfo.out_volume.batch_upload(force=force) as batch: + with CONF.output_volume.batch_upload(force=force) as batch: if num_files == 1: f = all_files[0] batch.put_file(f, f"/{stage_root}/{f.name}") @@ -504,7 +512,7 @@ def submit_af3score_task( batch.put_directory(all_files[0].parent, f"/{stage_root}/") prepare_result = af3score_prepare.remote( - paths=run_paths, + run_name=run_name, input_files=[path.name for path in all_files], num_jobs=max_batches, prepare_workers=prepare_workers, @@ -517,15 +525,6 @@ def submit_af3score_task( chunk_specs = prepare_result.chunk_specs total_chunks = len(chunk_specs) - def _af3score_run(spec: ChunkSpec) -> None: - """Submit one AF3Score batch run as a remote function call.""" - af3score_run.remote( - paths=run_paths, - batch_name=spec.batch_name, - batch_json_dir=spec.batch_json_dir, - batch_pdb_dir=spec.batch_pdb_dir, - ) - if total_chunks: max_batches = min(max_batches, total_chunks) print(f"🧬 Running {total_chunks} batches with a max of {max_batches} GPUs") @@ -533,13 +532,22 @@ def _af3score_run(spec: ChunkSpec) -> None: from concurrent.futures import ThreadPoolExecutor with ThreadPoolExecutor(max_workers=max_batches) as executor: - futures = [executor.submit(_af3score_run, spec) for spec in chunk_specs] + futures = [ + executor.submit( + af3score_run.remote, + run_name=run_name, + batch_name=spec.batch_name, + batch_json_dir=spec.batch_json_dir, + batch_pdb_dir=spec.batch_pdb_dir, + ) + for spec in chunk_specs + ] for future in futures: future.result() # wait for all workers to finish postprocess_result = af3score_postprocess.remote( + run_name=run_name, input_files=prepare_result.input_files, - paths=run_paths, ) for key, value in postprocess_result.items(): prefix = "[METRICS]" if str(key).startswith("metrics_") else "[POSTPROCESS]" @@ -558,7 +566,7 @@ def _af3score_run(spec: ChunkSpec) -> None: local_metrics_csv = local_out_dir / f"{run_name}_af3score_metrics.csv" print("🧬 Downloading metrics CSV...") with open(local_metrics_csv, "wb") as f: - for chunk in AppInfo.out_volume.read_file( + for chunk in CONF.output_volume.read_file( str(run_paths["metrics_csv"].relative_to(run_paths["mount_root"])) ): f.write(chunk) diff --git a/src/biomodals/app/score/dockq_app.py b/src/biomodals/app/score/dockq_app.py new file mode 100644 index 0000000..9f6b885 --- /dev/null +++ b/src/biomodals/app/score/dockq_app.py @@ -0,0 +1,314 @@ +"""DockQ source repo: . + +DockQ compares a predicted/model complex against a reference/native complex and +reports continuous docking-quality metrics. This wrapper accepts a CSV of +model/reference file pairs for standalone use and exposes a deployed batch +function for workflows that already have structure bytes in memory. + +The upstream DockQ documentation is available at . + +## Outputs + +Results are saved locally as `.tar.zst`. The archive contains +`dockq_results.csv` and one raw `.log` file per scored pair. +""" + +# Ignore ruff warnings about import location +# ruff: noqa: PLC0415 + +from __future__ import annotations + +import os +import re +import shlex +import subprocess +from collections.abc import Iterable +from pathlib import Path +from tempfile import TemporaryDirectory + +import modal +import polars as pl + +from biomodals.app.config import AppConfig +from biomodals.helper import patch_image_for_helper +from biomodals.helper.io import ( + build_local_output_path, + resolve_local_output_dir, + write_local_tarball, +) +from biomodals.helper.shell import package_outputs, sanitize_filename + +########################################## +# Modal configs +########################################## +CONF = AppConfig( + tags={"group": Path(__file__).parent.name}, + name="DockQ", + repo_url="https://github.com/y1zhou/DockQ", + package_name="DockQ", + version="2.1.3", + python_version="3.12", + timeout=int(os.environ.get("TIMEOUT", "3600")), +) + + +########################################## +# Image and app definitions +########################################## +runtime_image = ( + modal.Image + .debian_slim(python_version=CONF.python_version) + .apt_install("zstd") + .env(CONF.default_env) + .uv_pip_install(f"{CONF.package_name}=={CONF.version}", "pandas") + .pipe(patch_image_for_helper) +) +app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) + + +########################################## +# Helper functions +########################################## +METRIC_KEYS = { + "dockq": "dockq", + "irmsd": "irmsd", + "lrmsd": "lrmsd", + "fnat": "fnat", + "fnonnat": "fnonnat", + "f1": "f1", + "clashes": "clashes", +} + + +def _parse_short_metrics(output: str) -> dict[str, str]: + """Parse DockQ short-output key/value metrics.""" + metrics: dict[str, str] = {} + tokens = output.replace("\t", " ").split() + for idx, token in enumerate(tokens[:-1]): + normalized = token.rstrip(":").lower() + if normalized in METRIC_KEYS: + value = tokens[idx + 1] + if re.fullmatch(r"-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?", value): + metrics[METRIC_KEYS[normalized]] = value + + mapping_match = re.search(r"\bmapping\s+(\S+)", output) + if mapping_match: + metrics["mapping"] = mapping_match.group(1) + return metrics + + +def _safe_structure_name(name: str, fallback: str) -> str: + """Return a safe structure filename preserving common structure suffixes.""" + raw = Path(name).name or fallback + safe = sanitize_filename(raw) + suffix = Path(safe).suffix.lower() + if suffix not in {".pdb", ".cif", ".gz"}: + safe = f"{safe}.pdb" + return safe + + +def _row_from_pair( + pair: dict[str, object], + *, + pair_idx: int, + workdir: Path, + dockq_args: list[str], +) -> dict[str, str]: + """Run DockQ for one model/reference pair and return a CSV row.""" + pair_id = str(pair.get("id") or f"pair_{pair_idx}") + pair_dir = workdir / sanitize_filename(pair_id) + pair_dir.mkdir(parents=True, exist_ok=True) + + model_name = _safe_structure_name(str(pair.get("model_name") or ""), "model.pdb") + reference_name = _safe_structure_name( + str(pair.get("reference_name") or ""), "reference.pdb" + ) + model_path = pair_dir / model_name + reference_path = pair_dir / reference_name + + model_bytes = pair.get("model_bytes") + reference_bytes = pair.get("reference_bytes") + if not isinstance(model_bytes, bytes): + raise TypeError(f"DockQ pair {pair_id!r} is missing model_bytes") + if not isinstance(reference_bytes, bytes): + raise TypeError(f"DockQ pair {pair_id!r} is missing reference_bytes") + + model_path.write_bytes(model_bytes) + reference_path.write_bytes(reference_bytes) + + cmd = ["DockQ", str(model_path), str(reference_path), *dockq_args] + if mapping := pair.get("mapping"): + cmd.extend(["--mapping", str(mapping)]) + + print(f"💊 Running DockQ for {pair_id}") + completed = subprocess.run( # noqa: S603 + cmd, + check=False, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + raw_output = completed.stdout or "" + log_path = pair_dir / "dockq.log" + log_path.write_text(raw_output) + + metrics = _parse_short_metrics(raw_output) + row = { + "id": pair_id, + "model": model_name, + "reference": reference_name, + "dockq": metrics.get("dockq", ""), + "irmsd": metrics.get("irmsd", ""), + "lrmsd": metrics.get("lrmsd", ""), + "fnat": metrics.get("fnat", ""), + "fnonnat": metrics.get("fnonnat", ""), + "f1": metrics.get("f1", ""), + "clashes": metrics.get("clashes", ""), + "mapping": metrics.get("mapping", str(pair.get("mapping") or "")), + "returncode": str(completed.returncode), + "log": str(log_path.relative_to(workdir)), + } + if completed.returncode != 0: + row["error"] = raw_output.strip().splitlines()[-1] if raw_output.strip() else "" + else: + row["error"] = "" + return row + + +def _write_results_csv(rows: Iterable[dict[str, str]], csv_path: Path) -> None: + """Write DockQ result rows to CSV.""" + fieldnames = [ + "id", + "model", + "reference", + "dockq", + "irmsd", + "lrmsd", + "fnat", + "fnonnat", + "f1", + "clashes", + "mapping", + "returncode", + "error", + "log", + ] + normalized = [ + {field: str(row.get(field, "")) for field in fieldnames} for row in rows + ] + pl.DataFrame( + normalized, + schema={field: pl.String for field in fieldnames}, + ).write_csv(csv_path) + + +########################################## +# Inference functions +########################################## +@app.function( + cpu=(0.125, 16.125), + memory=(512, 16384), + timeout=CONF.timeout, +) +def run_dockq_batch( + pairs: list[dict[str, object]], + run_name: str, + dockq_args: list[str] | None = None, +) -> bytes: + """Run DockQ on model/reference pairs and return packaged outputs.""" + dockq_args = dockq_args or ["--short"] + if not pairs: + raise ValueError("At least one DockQ pair is required") + + safe_run_name = sanitize_filename(run_name) + with TemporaryDirectory(prefix=f"dockq_{safe_run_name}_") as tmpdir: + out_dir = Path(tmpdir) / safe_run_name + out_dir.mkdir(parents=True, exist_ok=True) + + rows = [ + _row_from_pair( + pair, + pair_idx=idx, + workdir=out_dir, + dockq_args=dockq_args, + ) + for idx, pair in enumerate(pairs, start=1) + ] + _write_results_csv(rows, out_dir / "dockq_results.csv") + return package_outputs(out_dir) + + +########################################## +# Entrypoint for ephemeral usage +########################################## +def _pairs_from_csv(input_csv: Path) -> list[dict[str, object]]: + """Read standalone DockQ pair specs from a local CSV.""" + df = pl.read_csv(input_csv, infer_schema_length=0) + required = {"model", "reference"} + if not required.issubset(set(df.columns)): + raise ValueError( + f"DockQ input CSV must contain columns {sorted(required)}, got {df.columns}" + ) + + pairs: list[dict[str, object]] = [] + for idx, row in enumerate(df.iter_rows(named=True), start=1): + model = Path(row["model"]).expanduser() + reference = Path(row["reference"]).expanduser() + if not model.is_absolute(): + model = input_csv.parent / model + if not reference.is_absolute(): + reference = input_csv.parent / reference + if not model.exists(): + raise FileNotFoundError(f"Model structure not found: {model}") + if not reference.exists(): + raise FileNotFoundError(f"Reference structure not found: {reference}") + pairs.append({ + "id": row.get("id") or f"pair_{idx}", + "model_name": model.name, + "model_bytes": model.read_bytes(), + "reference_name": reference.name, + "reference_bytes": reference.read_bytes(), + "mapping": row.get("mapping") or None, + }) + return pairs + + +@app.local_entrypoint() +def submit_dockq_task( + input_csv: str, + out_dir: str | None = None, + run_name: str | None = None, + dockq_args: str = "--short", +) -> None: + """Run DockQ scoring on model/reference pairs from a CSV. + + Args: + input_csv: CSV with `model` and `reference` columns, plus optional + `id` and `mapping` columns. Relative paths resolve from the CSV + parent directory. + out_dir: Optional local output directory. If not specified, outputs + will be saved in the current working directory. + run_name: Optional run name for output files. Defaults to the input + CSV filename stem. + dockq_args: Extra DockQ CLI arguments, shell-split before execution. + Defaults to `--short`. + """ + input_path = Path(input_csv).expanduser().resolve() + if not input_path.exists(): + raise FileNotFoundError(f"Input CSV not found: {input_path}") + if run_name is None: + run_name = input_path.stem + + local_out_dir = resolve_local_output_dir(out_dir) + out_file = build_local_output_path(local_out_dir, run_name=run_name) + + pairs = _pairs_from_csv(input_path) + print(f"🧬 Submitting DockQ run '{run_name}' with {len(pairs)} pair(s)") + tarball_bytes = run_dockq_batch.remote( + pairs=pairs, + run_name=run_name, + dockq_args=shlex.split(dockq_args), + ) + + write_local_tarball(out_file, tarball_bytes) + print(f"🧬 DockQ run complete! Results saved to {out_file}") diff --git a/src/biomodals/cli.py b/src/biomodals/cli.py index e2f3a06..7289dd0 100644 --- a/src/biomodals/cli.py +++ b/src/biomodals/cli.py @@ -9,12 +9,20 @@ from rich.markdown import Markdown from rich.table import Table -from biomodals.app.catalog import AppNotFoundError, BiomodalsApp, get_all_apps +from biomodals.helper.catalog import ( + WORKFLOW_HOME, + AppNotFoundError, + BiomodalsApp, + CatalogType, + get_catalog, +) from biomodals.helper.shell import run_command # ruff: noqa: S603 app = typer.Typer() +app_commands = typer.Typer(no_args_is_help=True) +workflow_commands = typer.Typer(no_args_is_help=True) console = Console() @@ -27,15 +35,40 @@ def callback(): ... +app.add_typer(app_commands, name="app", help="Discover and run Biomodals apps.") +app.add_typer( + workflow_commands, name="workflow", help="Discover Biomodals workflow entrypoints." +) + + ########################################## -# CLI Commands +# Helper functions ########################################## -def _load_app(name: str) -> BiomodalsApp: - """Load a biomodals app by name or path.""" +def _load_entry(entry_type: CatalogType, name: str) -> BiomodalsApp: + """Load a biomodals app or workflow by name or path.""" + all_entries = get_catalog(entry_type, use_absolute_paths=True) + name_or_path = name.partition("::")[0] + if entry_type == "workflow" and name_or_path not in all_entries: + workflow_path = Path(name_or_path).expanduser() + if workflow_path.exists() and not workflow_path.resolve().is_relative_to( + WORKFLOW_HOME + ): + console.print( + "[bold red]Error[/bold red] Workflow paths must be under " + f"'[green]{WORKFLOW_HOME}[/green]' so they import through " + "'[green]biomodals.workflow[/green]'." + ) + raise typer.Exit(code=1) + try: - return BiomodalsApp(name) + return BiomodalsApp( + name, + all_apps=all_entries, + ) except AppNotFoundError as e: - console.print(f"[bold red]Error[/bold red] failed to find app '{name}': {e}") + console.print( + f"[bold red]Error[/bold red] failed to find {entry_type} '{name}': {e}" + ) raise typer.Exit(code=1) from e except ImportError as e: console.print(f"[bold red]Error[/bold red] Failed to import '{name}': {e}") @@ -54,12 +87,83 @@ def _print_title(title: str) -> None: ########################################## # CLI Commands ########################################## -@app.command( +def _list_available_entries( + list_type: CatalogType, + *, + use_absolute_paths: bool, + sort_by: Literal["name", "category", "group", "path"], + reverse: bool, + short: bool, +) -> dict[str, Path]: + """Show a list of available biomodals apps or workflows.""" + title = list_type.capitalize() + table_headers = [f"{title} name", "Category", f"{title} path"] + available_apps = get_catalog(list_type, use_absolute_paths=use_absolute_paths) + table_rows: list[tuple[str, str, str]] = [] + for app_name, app_path in available_apps.items(): + app_category = app_path.parent.name + table_rows.append((f"[green]{app_name}[/green]", app_category, str(app_path))) + match sort_by: + case "name": + sort_by_idx = 0 + case "category" | "group": + sort_by_idx = 1 + case "path": + sort_by_idx = 2 + case _: + raise ValueError(f"Invalid sort key: {sort_by}") + table_rows.sort(key=lambda x: x[sort_by_idx], reverse=reverse) + if short: + for r in table_rows: + console.print(r[0]) + return available_apps + + table = Table(*table_headers) + for r in table_rows: + table.add_row(*r) + + if list_type == "app": + console.print( + "\n:dna: To see help for an application, use:\n" + " [bold]biomodals app help <[green]app-name-or-path[/green]>[/bold]" + ) + console.print( + "\n:dna: To run an application on [link=https://modal.com]modal.com[/link], use:\n" + r" [bold]biomodals app run <[green]app-name-or-path[/green]>[/bold] -- [gray]\[OPTIONS][/gray]" + ) + console.print( + "\n:dna: If an app contains multiple local entrypoints, use it as:\n" + " [bold]<[green]app-name-or-path[/green]>::<[green]function-name[/green]>[/bold]\n" + ) + else: + console.print( + "\n:dna: To see help for a workflow, use:\n" + " [bold]biomodals workflow help <[green]workflow-name-or-path[/green]>[/bold]" + ) + console.print( + "\n:dna: To run a workflow on [link=https://modal.com]modal.com[/link], use:\n" + r" [bold]biomodals workflow run <[green]workflow-name-or-path[/green]>[/bold] -- [gray]\[OPTIONS][/gray]" + ) + console.print( + "\n:dna: If a workflow contains multiple local entrypoints, use it as:\n" + " [bold]<[green]workflow-name-or-path[/green]>::<[green]function-name[/green]>[/bold]\n" + ) + console.print(f"\n:dna: [bold]Available biomodals {list_type}s:[/bold]") + console.print(table) + return available_apps + + +@app_commands.command( name="list", help="Show a list of all available biomodals applications (aliases: ls, l).", ) -@app.command(name="ls", hidden=True) -@app.command(name="l", hidden=True) +@app_commands.command(name="ls", hidden=True) +@app_commands.command(name="l", hidden=True) +@app.command( + name="list", help="Deprecated alias for 'biomodals app list'.", deprecated=True +) +@app.command(name="ls", hidden=True, deprecated=True) +@app.command(name="l", hidden=True, deprecated=True) def list_available_apps( use_absolute_paths: Annotated[ bool, @@ -83,88 +187,76 @@ def list_available_apps( short: Annotated[ bool, typer.Option( - "--short", - help="Only show app names without paths or additional info.", - is_flag=True, + "--short", help="Only show app names without paths or additional info." ), ] = False, ) -> dict[str, Path]: """Show a list of all available biomodals applications.""" - table_headers = ["App name", "Category", "App path"] - available_apps = get_all_apps(use_absolute_paths) - table_rows: list[tuple[str, str, str]] = [] - for app_name, app_path in available_apps.items(): - app_category = app_path.parent.name - table_rows.append(( - f"[green]{app_name}[/green]", - app_category, - str(app_path), - )) - match sort_by: - case "name": - sort_by_idx = table_headers.index("App name") - case "category" | "group": - sort_by_idx = table_headers.index("Category") - case "path": - sort_by_idx = table_headers.index("App path") - case _: - raise ValueError(f"Invalid sort key: {sort_by}") - table_rows.sort(key=lambda x: x[sort_by_idx], reverse=reverse) - if short: - for r in table_rows: - console.print(r[0]) - return available_apps - - table = Table(*table_headers) - for r in table_rows: - table.add_row(*r) - - console.print( - "\n:dna: To see help for an application, use:\n" - " [bold]biomodals help <[green]app-name-or-path[/green]>[/bold]" - ) - console.print( - "\n:dna: To run an application on [link=https://modal.com]modal.com[/link], use:\n" - r" [bold]biomodals run <[green]app-name-or-path[/green]>[/bold] -- [gray]\[OPTIONS][/gray]" + return _list_available_entries( + "app", + use_absolute_paths=use_absolute_paths, + sort_by=sort_by, + reverse=reverse, + short=short, ) - console.print( - "\n:dna: If an app contains multiple local entrypoints, use it as:\n" - " [bold]<[green]app-name-or-path[/green]>::<[green]function-name[/green]>[/bold]\n" - ) - console.print("\n:dna: [bold]Available biomodals applications:[/bold]") - console.print(table) - return available_apps -@app.command( - name="help", - no_args_is_help=True, - help="Show help for a specific biomodals application (alias: h).", +@workflow_commands.command( + name="list", + help="Show a list of all available biomodals workflows (aliases: ls, l).", ) -@app.command(name="h", no_args_is_help=True, hidden=True) -def show_app_help( - app_name: Annotated[ - str, typer.Argument(help="Name or path of the app to show help for.") - ], - verbose: Annotated[ +@workflow_commands.command(name="ls", hidden=True) +@workflow_commands.command(name="l", hidden=True) +def list_available_workflows( + use_absolute_paths: Annotated[ bool, - typer.Option("--verbose", "-v", help="Show detailed help for all functions."), + typer.Option( + "--absolute", "-a", help="Use absolute paths for workflow locations." + ), ] = False, -): - """Show help for a specific biomodals application. + sort_by: Annotated[ + Literal["name", "category", "group", "path"], + typer.Option( + "--sort-by", + "-s", + help="Key to sort the workflows by in the table display.", + case_sensitive=False, + ), + ] = "path", + reverse: Annotated[ + bool, + typer.Option( + "--reverse", "-r", help="Reverse the sorting order in the table display." + ), + ] = False, + short: Annotated[ + bool, + typer.Option( + "--short", help="Only show workflow names without paths or additional info." + ), + ] = False, +) -> dict[str, Path]: + """Show a list of all available biomodals workflows.""" + return _list_available_entries( + "workflow", + use_absolute_paths=use_absolute_paths, + sort_by=sort_by, + reverse=reverse, + short=short, + ) - If unsure which app to use, run `biomodals list` to see available apps. - If you would like to see help for a local entrypoint or Modal function, - add `::` to the app name to show help for that specific function. - """ - app = _load_app(app_name) - if app._entrypoint is not None: + +def _show_entry_help(list_type: CatalogType, entry_name: str, *, verbose: bool) -> None: + """Show help for a specific biomodals app or workflow.""" + catalog_entry = _load_entry(list_type, entry_name) + if catalog_entry._entrypoint is not None: # When an entrypoint name is specified, show only its docstring - f = app[app._entrypoint] + f = catalog_entry[catalog_entry._entrypoint] console.print( f"[bold]Help for {f.func_type} function" f"'[green]{f.name}[/green]'" - f" in app '[green]{app.name}[/green]' ({app.category}):[/bold]\n" + f" in {list_type} '[green]{catalog_entry.name}[/green]'" + f" ({catalog_entry.category}):[/bold]\n" ) console.print(f.docstring or "No documentation available.") if table_rows := f.args_table: @@ -174,16 +266,19 @@ def show_app_help( # When no entrypoint is specified, show the app help console.print( - "[bold]Help for application" - f" '[green]{app.name}[/green]' ({app.category}):[/bold]" + f"[bold]Help for {list_type}" + f" '[green]{catalog_entry.name}[/green]'" + f" ({catalog_entry.category}):[/bold]" ) - if app.module_doc: + if catalog_entry.module_doc: _print_title("Module documentation") - console.print(Markdown(app.module_doc)) - if app._remote_modal_func_idx: - remote_modal_functions = [app[x] for x in app._remote_modal_func_idx] + console.print(Markdown(catalog_entry.module_doc)) + if catalog_entry._remote_modal_func_idx: + remote_modal_functions = [ + catalog_entry[x] for x in catalog_entry._remote_modal_func_idx + ] - _print_title("Remote Modal functions in this app") + _print_title(f"Remote Modal functions in this {list_type}") remote_func_names = ", ".join([x.name for x in remote_modal_functions]) console.print(f"[green]{remote_func_names}[/green]\n") if verbose: @@ -192,10 +287,10 @@ def show_app_help( console.print(f"\n[bold green]{f.name}[/bold green]") console.print(Markdown(f.docstring)) - if f_indices := app._local_entrypoint_idx: - _print_title("Local entrypoint(s) in this app") + if f_indices := catalog_entry._local_entrypoint_idx: + _print_title(f"Local entrypoint(s) in this {list_type}") for f_idx in f_indices: - f = app[f_idx] + f = catalog_entry[f_idx] if f.args_table: console.print(f"[bold green]{f.name}[/bold green] CLI flags:\n") @@ -205,12 +300,69 @@ def show_app_help( console.print(Markdown(f.docstring)) +@app_commands.command( + name="help", + no_args_is_help=True, + help="Show help for a specific biomodals application (alias: h).", +) +@app_commands.command(name="h", no_args_is_help=True, hidden=True) @app.command( + name="help", + no_args_is_help=True, + help="Deprecated alias for 'biomodals app help'.", + deprecated=True, +) +@app.command(name="h", no_args_is_help=True, hidden=True, deprecated=True) +def show_app_help( + app_name: Annotated[ + str, typer.Argument(help="Name or path of the app to show help for.") + ], + verbose: Annotated[ + bool, + typer.Option("--verbose", "-v", help="Show detailed help for all functions."), + ] = False, +) -> None: + """Show help for a specific biomodals application. + + If unsure which app to use, run `biomodals app list` to see available apps. + If you would like to see help for a local entrypoint or Modal function, + add `::` to the app name to show help for that specific function. + """ + _show_entry_help("app", app_name, verbose=verbose) + + +@workflow_commands.command( + name="help", + no_args_is_help=True, + help="Show help for a specific biomodals workflow (alias: h).", +) +@workflow_commands.command(name="h", no_args_is_help=True, hidden=True) +def show_workflow_help( + workflow_name: Annotated[ + str, typer.Argument(help="Name or path of the workflow to show help for.") + ], + verbose: Annotated[ + bool, + typer.Option("--verbose", "-v", help="Show detailed help for all functions."), + ] = False, +) -> None: + """Show help for a specific biomodals workflow.""" + _show_entry_help("workflow", workflow_name, verbose=verbose) + + +@app_commands.command( name="run", no_args_is_help=True, help="Run a biomodals application on Modal (alias: r).", ) -@app.command(name="r", no_args_is_help=True, hidden=True) +@app_commands.command(name="r", no_args_is_help=True, hidden=True) +@app.command( + name="run", + no_args_is_help=True, + help="Deprecated alias for 'biomodals app run'.", + deprecated=True, +) +@app.command(name="r", no_args_is_help=True, hidden=True, deprecated=True) def run_modal_app( app_name_or_path: Annotated[ str, typer.Argument(help="Name or path of the app to run.") @@ -244,9 +396,12 @@ def run_modal_app( Use with: `biomodals run [OPTIONS] -- [app-options]`, where `[app-options]` are additional flags to pass to the `modal run ` command. """ + # TODO(workflows): add workflow run semantics separately from Modal app runs + # so workflow-* names can stage workflow inputs before invoking orchestrators. + import os import sys - app = _load_app(app_name_or_path) + app = _load_entry("app", app_name_or_path) full_app = ( str(app.path) if app._entrypoint is None else f"{app.path}::{app._entrypoint}" @@ -261,31 +416,136 @@ def run_modal_app( "To start an interactive shell for the app, run:\n" f"[bold green]{shlex.join(cmd)}[/bold green]" ) - elif flags: - # TODO: figure out a way to tag run names into the app. - # Previously we used the MODAL_APP environment variable for ephemeral - # apps run with the --run-name flag, but with the new AppConfig API - # this is no longer read. - import os - - env = os.environ.copy() - if gpu is not None: - env["GPU"] = gpu - if timeout is not None: - env["TIMEOUT"] = str(timeout) + return + + # TODO: figure out a way to tag run names into the app. + # Previously we used the MODAL_APP environment variable for ephemeral + # apps run with the --run-name flag, but with the new AppConfig API + # this is no longer read. + env = os.environ.copy() + if gpu is not None: + env["GPU"] = gpu + if timeout is not None: + env["TIMEOUT"] = str(timeout) + + if flags: run_command([*cmd, *flags], env=env) elif app._entrypoint is not None: - run_command(["biomodals", "help", str(full_app)], try_rich_print=True) + run_command(cmd, env=env) else: - run_command(["biomodals", "help", str(app.path)], try_rich_print=True) + run_command(["biomodals", "app", "help", str(app.path)], try_rich_print=True) -@app.command( +def _resolve_workflow_entrypoint(workflow: BiomodalsApp) -> str: + """Return the explicit or only local workflow entrypoint.""" + if workflow._entrypoint is not None: + return workflow._entrypoint + + local_entrypoints = [ + workflow[entrypoint_idx] for entrypoint_idx in workflow._local_entrypoint_idx + ] + if len(local_entrypoints) == 1: + return local_entrypoints[0].name + + if len(local_entrypoints) > 1: + entrypoint_names = ", ".join( + f"[green]{workflow.name}::{entrypoint.name}[/green]" + for entrypoint in local_entrypoints + ) + console.print( + "[bold red]Error[/bold red] Workflow " + f"'[green]{workflow.name}[/green]' contains multiple local entrypoints; " + f"choose one explicitly: {entrypoint_names}" + ) + raise typer.Exit(code=1) + + console.print( + "[bold red]Error[/bold red] Workflow " + f"'[green]{workflow.name}[/green]' does not define a local entrypoint." + ) + raise typer.Exit(code=1) + + +@workflow_commands.command( + name="run", + no_args_is_help=True, + help="Run a biomodals workflow on Modal (alias: r).", +) +@workflow_commands.command(name="r", no_args_is_help=True, hidden=True) +def run_workflow( + workflow_name_or_path: Annotated[ + str, typer.Argument(help="Name or path of the workflow to run.") + ], + modal_mode: Annotated[ + str, + typer.Option("--mode", "-m", help="Modal command to use ('run' or 'shell')."), + ] = "run", + detach: Annotated[ + bool, + typer.Option("--detach", "-d", help="Run the modal command in detached mode."), + ] = False, + gpu: Annotated[ + str | None, + typer.Option("--gpu", help="GPU type to use for the modal run (e.g. 'L40S'). "), + ] = None, + timeout: Annotated[ + int | None, + typer.Option( + "--timeout", + help="Timeout in seconds for the modal run. If not specified, use the workflow default.", + ), + ] = None, + flags: Annotated[ + list[str] | None, + typer.Argument(help="Additional flags to pass to the workflow entrypoint."), + ] = None, +): + """Run a biomodals workflow on Modal. + + Use with: `biomodals workflow run [OPTIONS] -- [workflow-options]`, + where `[workflow-options]` are passed to the workflow local entrypoint. + """ + import os + import sys + + workflow = _load_entry("workflow", workflow_name_or_path) + entrypoint = _resolve_workflow_entrypoint(workflow) + full_workflow = f"{workflow.module}::{entrypoint}" + + cmd = [sys.executable, "-m", "modal", modal_mode] + if detach: + cmd.append("-d") + cmd.extend(["-m", full_workflow]) + + if modal_mode == "shell": + console.print( + "To start an interactive shell for the workflow, run:\n" + f"[bold green]{shlex.join(cmd)}[/bold green]" + ) + return + + env = os.environ.copy() + if gpu is not None: + env["GPU"] = gpu + if timeout is not None: + env["TIMEOUT"] = str(timeout) + + run_command([*cmd, *(flags or [])], env=env) + + +@app_commands.command( name="deploy", no_args_is_help=True, help="Deploy a biomodals application to Modal (alias: d).", ) -@app.command(name="d", no_args_is_help=True, hidden=True) +@app_commands.command(name="d", no_args_is_help=True, hidden=True) +@app.command( + name="deploy", + no_args_is_help=True, + help="Deprecated alias for 'biomodals app deploy'.", + deprecated=True, +) +@app.command(name="d", no_args_is_help=True, hidden=True, deprecated=True) def deploy_app( app_name_or_path: Annotated[ str, typer.Argument(help="Name or path of the app to deploy.") @@ -299,7 +559,7 @@ def deploy_app( ] = None, ): """Deploy a biomodals application to Modal.""" - app = _load_app(app_name_or_path) + app = _load_entry("app", app_name_or_path) cmd = ["modal", "deploy"] if name: cmd.extend(["--name", name]) diff --git a/src/biomodals/helper/__init__.py b/src/biomodals/helper/__init__.py index ae2dc7e..20be466 100644 --- a/src/biomodals/helper/__init__.py +++ b/src/biomodals/helper/__init__.py @@ -1,9 +1,17 @@ """Helper utility scripts.""" +from collections.abc import Iterable + from modal import Image -def patch_image_for_helper(image: Image, copy_patch_files: bool = False) -> Image: +def patch_image_for_helper( + image: Image, + *, + copy_patch_files: bool = False, + include_workflow_modules: bool = False, + skip_deps: Iterable[str] | None = None, +) -> Image: """Patch a Modal Image to include helper dependencies. Args: @@ -15,6 +23,11 @@ def patch_image_for_helper(image: Image, copy_patch_files: bool = False) -> Imag This can slow down iteration since it requires a rebuild of the Image and any subsequent build steps whenever the included files change, but it is required if you want to run additional build steps after this one. + include_workflow_modules: Whether to include workflow modules in the patch. + By default, only helper dependencies are included. + skip_deps: A list of package names to skip when installing + `biomodals` dependencies. By default, all dependencies are included. + This is to help with older project apps on Python <3.12. """ # This is a bit hacky, but because Modal's .add_local_python_source() # does not install the package, the metadata.requires call would not work @@ -26,17 +39,25 @@ def patch_image_for_helper(image: Image, copy_patch_files: bool = False) -> Imag except metadata.PackageNotFoundError: helper_deps = [] - return ( - image - .apt_install("zstd", "fd-find") - .uv_pip_install(helper_deps) - .add_local_python_source( - "biomodals.helper", - "biomodals.app.constant", - "biomodals.app.config", - copy=copy_patch_files, - ) - ) + mods = ["biomodals.helper", "biomodals.app.config", "biomodals.schema"] + if include_workflow_modules: + mods.append("biomodals.workflow") + + new_image = image.apt_install("zstd", "fd-find") + if skip_deps is not None: + import re + + skip_deps_set = set(skip_deps) + package_name_pattern = re.compile(r"^[\w_\-.]+") + helper_deps = [ + dep + for dep in helper_deps + if next(package_name_pattern.finditer(dep)).group(0) not in skip_deps_set + ] + if helper_deps: + new_image = new_image.uv_pip_install(helper_deps) + + return new_image.add_local_python_source(*mods, copy=copy_patch_files) def hash_string(s: str) -> str: diff --git a/src/biomodals/app/catalog.py b/src/biomodals/helper/catalog.py similarity index 70% rename from src/biomodals/app/catalog.py rename to src/biomodals/helper/catalog.py index 41b4919..2252fd9 100644 --- a/src/biomodals/app/catalog.py +++ b/src/biomodals/helper/catalog.py @@ -2,14 +2,17 @@ import importlib import inspect -from collections.abc import Callable +from collections.abc import Callable, Iterable from dataclasses import dataclass from pathlib import Path from typing import Literal import modal -APP_HOME = Path(__file__).parent.resolve() +BIOMODALS_HOME = Path(__file__).parent.parent.resolve() +APP_HOME = BIOMODALS_HOME / "app" +WORKFLOW_HOME = BIOMODALS_HOME / "workflow" +CatalogType = Literal["app", "workflow"] class AppNotFoundError(ValueError): @@ -21,26 +24,82 @@ def __init__(self, app_name: str) -> None: super().__init__(f"Application '{app_name}' not found.") -def get_all_apps( - use_absolute_paths: bool = False, +def get_all_scripts( + root_dir: Path, + glob_prefix: str, + glob_suffix: str, *, - app_home: Path = APP_HOME, + use_absolute_paths: bool = False, cwd: Path | None = None, ) -> dict[str, Path]: """Retrieve all available biomodals applications.""" available_apps: dict[str, Path] = {} base_cwd = Path.cwd() if cwd is None else cwd - for app_file in app_home.glob("*/*_app.py"): + glob_pattern = f"{glob_prefix}*{glob_suffix}.py" + for app_file in root_dir.glob(glob_pattern): app_path = ( app_file.resolve() if use_absolute_paths else app_file.relative_to(base_cwd, walk_up=True) ) - app_name = app_file.stem.replace("_app", "") + app_name = app_file.stem.removesuffix(glob_suffix) available_apps[app_name] = app_path return available_apps +def get_catalog( + catalog_type: CatalogType, + *, + use_absolute_paths: bool = False, + cwd: Path | None = None, +) -> dict[str, Path]: + """Retrieve app or workflow catalog entries.""" + match catalog_type: + case "app": + return get_all_scripts( + APP_HOME, "*/", "_app", use_absolute_paths=use_absolute_paths, cwd=cwd + ) + case "workflow": + return get_all_scripts( + WORKFLOW_HOME, + "", + "_workflow", + use_absolute_paths=use_absolute_paths, + cwd=cwd, + ) + case _: + raise ValueError(f"Unknown catalog type: {catalog_type}") + + +def include_dependency_apps(app: modal.App, dependencies: Iterable[str]) -> modal.App: + """Include catalog app definitions into an existing Modal app.""" + all_apps = get_catalog("app", use_absolute_paths=True) + for dependency in dependencies: + dependency_metadata = BiomodalsApp(dependency, all_apps=all_apps) + dependency_module = importlib.import_module(dependency_metadata.module) + dependency_app = getattr(dependency_module, "app", None) + if not isinstance(dependency_app, modal.App): + raise TypeError( + f"Dependency app '{dependency}' does not expose a modal.App named app" + ) + + function_collisions = set(app._local_state.functions) & set( + dependency_app._local_state.functions + ) + class_collisions = set(app._local_state.classes) & set( + dependency_app._local_state.classes + ) + duplicate_tags = sorted(function_collisions | class_collisions) + if duplicate_tags: + duplicate_list = ", ".join(duplicate_tags) + raise ValueError( + f"Dependency app '{dependency}' has Modal tag collisions: " + f"{duplicate_list}" + ) + app.include(dependency_app, inherit_tags=False) + return app + + @dataclass(frozen=True) class AppFunction: """Information about a Modal or local entrypoint function.""" @@ -66,6 +125,7 @@ class BiomodalsApp: _func_idx (dict[str, int]): A mapping of function names to their index in the functions list. _local_entrypoint_idx (list[int]): A list of indices of local entrypoint functions in the functions list. _remote_modal_func_idx (list[int]): A list of indices of remote Modal functions in the functions list. + """ def __init__( @@ -79,9 +139,7 @@ def __init__( self._entrypoint = entrypoint_name # Normalize app name & path - self._all_apps = all_apps or get_all_apps( - use_absolute_paths=True, app_home=APP_HOME - ) + self._all_apps = all_apps or get_catalog("app", use_absolute_paths=True) self.name, self.path = self.resolve_app_path(name_or_path) self.category = self.path.parent.name self.module = self.app_path_to_module_path(self.path) @@ -121,19 +179,29 @@ def resolve_app_path(self, app_name_or_path: str) -> tuple[str, Path]: app_path = Path(app_name_or_path).expanduser() if not app_path.exists(): raise AppNotFoundError(app_name_or_path) - return app_path.stem.removesuffix("_app"), app_path + return app_path.stem.removesuffix("_app").removesuffix("_workflow"), app_path @staticmethod def app_path_to_module_path(app_path: Path) -> str: """Convert an app path to a module path.""" + resolved_path = app_path.resolve() + if resolved_path.is_relative_to(APP_HOME): + module_path = ( + str(resolved_path.relative_to(APP_HOME)) + .replace("/", ".") + .replace("\\", ".") + .replace(".py", "") + .replace("-", "_") + ) + return f"biomodals.app.{module_path}" module_path = ( - str(app_path.resolve().relative_to(APP_HOME)) + str(resolved_path.relative_to(BIOMODALS_HOME)) .replace("/", ".") .replace("\\", ".") .replace(".py", "") .replace("-", "_") ) - return f"biomodals.app.{module_path}" + return f"biomodals.{module_path}" def populate_functions(self): """Collect all functions within the app.""" diff --git a/src/biomodals/helper/constant.py b/src/biomodals/helper/constant.py new file mode 100644 index 0000000..5e127ac --- /dev/null +++ b/src/biomodals/helper/constant.py @@ -0,0 +1,30 @@ +"""Shared constants used across Biomodals apps and workflows.""" + +from modal import Volume + +# Volume for caching all model weights. +MODEL_VOLUME_NAME = "biomodals-store" +MODEL_VOLUME = Volume.from_name(MODEL_VOLUME_NAME, create_if_missing=True) + +# Volume for caching MSA databases, which are large and shared across apps. +AF3_MSA_DB_VOLUME = Volume.from_name( + "AlphaFold3-msa-db", create_if_missing=True, version=2 +) +PROTENIX_MSA_DB_VOLUME = Volume.from_name( + "Protenix-msa-db", create_if_missing=True, version=2 +) + +# Volume for caching MSA search results. +MSA_CACHE_VOLUME_NAME = "biomodals-msa-cache" +MSA_CACHE_VOLUME = Volume.from_name( + MSA_CACHE_VOLUME_NAME, create_if_missing=True, version=2 +) + +# Durable workflow-orchestrator output ledger/artifact volume. +WORKFLOW_ORCHESTRATOR_VOLUME_NAME = "biomodals-workflow-orchestrator" +WORKFLOW_ORCHESTRATOR_VOLUME = Volume.from_name( + WORKFLOW_ORCHESTRATOR_VOLUME_NAME, create_if_missing=True, version=2 +) + +# Max timeout for any function, in seconds (24 hours). +MAX_TIMEOUT = 86_400 diff --git a/src/biomodals/helper/io.py b/src/biomodals/helper/io.py index 5d97964..f1c192f 100644 --- a/src/biomodals/helper/io.py +++ b/src/biomodals/helper/io.py @@ -17,7 +17,10 @@ def _clean_filename_part(value: str | Path | None) -> str: """Return one clean filename component.""" if value is None: return "" - cleaned = sanitize_filename(str(value)) + raw_value = str(value) + if not raw_value.strip(): + return "" + cleaned = sanitize_filename(raw_value) cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", cleaned) cleaned = re.sub(r"_+", "_", cleaned).strip("._-") return cleaned @@ -45,13 +48,7 @@ def build_local_output_path( ) -> Path: """Build a clean local output path and raise if it would overwrite a file.""" parts = [ - part - for part in ( - _clean_filename_part(prefix), - _clean_filename_part(run_name), - _clean_filename_part(suffix), - ) - if part + p for part in (prefix, run_name, suffix) if (p := _clean_filename_part(part)) ] if not parts: raise ValueError( diff --git a/src/biomodals/helper/shell.py b/src/biomodals/helper/shell.py index a5dc955..1ae078d 100644 --- a/src/biomodals/helper/shell.py +++ b/src/biomodals/helper/shell.py @@ -105,9 +105,16 @@ def run_command_with_log( """Run a shell command and log output to a file.""" import shlex import subprocess as sp - from datetime import UTC, datetime, timedelta + from datetime import datetime, timedelta from time import time + try: + from datetime import UTC + except ImportError: + from datetime import timezone + + UTC = timezone.utc # noqa: UP017 + if isinstance(cmd, str): cmd = shlex.split(cmd) @@ -317,7 +324,7 @@ def copy_files(src_dst_mapping: dict[str | Path, str | Path]) -> None: for p in subprocesses: _, p_stderr = p.communicate() if p.returncode != 0: - p_cmd = shlex.join(p.args) + p_cmd = shlex.join(p.args) # type: ignore[ty:invalid-argument-type] p_err_msg = p_stderr.decode().strip() err_msgs.append( f"'{p_cmd}' failed with return code {p.returncode}: {p_err_msg}" @@ -345,4 +352,7 @@ def sanitize_filename(filename: str, separator: str = "_") -> str: root_dir = Path(os.sep) f = (root_dir / filename.strip()).resolve().relative_to(root_dir) - return separator.join(f.parts) + sanitized = separator.join(f.parts) + if not sanitized: + raise ValueError("Value must contain at least one safe filename component") + return sanitized diff --git a/src/biomodals/helper/volume_run.py b/src/biomodals/helper/volume_run.py index 987ab73..2e3655b 100644 --- a/src/biomodals/helper/volume_run.py +++ b/src/biomodals/helper/volume_run.py @@ -5,7 +5,35 @@ Modal-supported atomic primitive. """ -from pathlib import Path +from pathlib import Path, PurePosixPath + +from biomodals.schema import VolumePath + + +def volume_path_from_mount_path( + remote_path: str, + mount_root: str, + volume_name: str, + media_type: str | None = None, +) -> VolumePath: + """Convert an app mount path into a volume-relative workflow storage path.""" + resolved_remote_path = PurePosixPath(remote_path) + resolved_mount_root = PurePosixPath(mount_root) + try: + relative_path = resolved_remote_path.relative_to(resolved_mount_root) + except ValueError as exc: + raise ValueError( + f"Remote path is outside mounted volume root {mount_root}: {remote_path}" + ) from exc + if str(relative_path) == ".": + raise ValueError( + f"Remote path must be below mounted volume root {mount_root}: {remote_path}" + ) + return VolumePath( + volume_name=volume_name, + path=str(relative_path), + media_type=media_type, + ) def build_volume_run_paths( diff --git a/src/biomodals/schema/__init__.py b/src/biomodals/schema/__init__.py new file mode 100644 index 0000000..14628fb --- /dev/null +++ b/src/biomodals/schema/__init__.py @@ -0,0 +1,40 @@ +"""Shared Pydantic contracts for Biomodals apps and workflows.""" + +from biomodals.schema.app import AppConfig, AppOutput, AppRunResult, AppRunStatus +from biomodals.schema.storage import InlineBytes, StorageKind, VolumePath +from biomodals.schema.workflow import ( + ArtifactFile, + ArtifactKind, + ArtifactSelector, + AttemptRecord, + ControlEdge, + NodeExecutionPolicy, + NodePlacement, + NodeStatus, + NodeStatusRecord, + RunStatus, + WorkflowArtifact, + WorkflowRun, +) + +__all__ = [ + "AppConfig", + "AppOutput", + "AppRunResult", + "AppRunStatus", + "ArtifactFile", + "ArtifactKind", + "ArtifactSelector", + "AttemptRecord", + "ControlEdge", + "InlineBytes", + "NodeExecutionPolicy", + "NodePlacement", + "NodeStatus", + "NodeStatusRecord", + "RunStatus", + "StorageKind", + "VolumePath", + "WorkflowArtifact", + "WorkflowRun", +] diff --git a/src/biomodals/schema/app.py b/src/biomodals/schema/app.py new file mode 100644 index 0000000..6627b53 --- /dev/null +++ b/src/biomodals/schema/app.py @@ -0,0 +1,191 @@ +"""Schemas for Biomodals app configuration and function results.""" + +from __future__ import annotations + +from functools import cached_property +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator + +from biomodals.schema.storage import InlineBytes, VolumePath +from biomodals.schema.workflow import ArtifactKind + +try: + from enum import StrEnum +except ImportError: + from backports.strenum import StrEnum # type: ignore[ty:unresolved-import] # noqa: UP035,I001 + +APP_CONFIG_MAX_TIMEOUT = 86_400 + + +class AppConfig(BaseModel): + """Base configuration model for Biomodals apps.""" + + model_config = ConfigDict(frozen=True) + + # Metadata + name: str + repo_url: str | None = None + repo_commit_hash: str | None = None + package_name: str | None = None + version: str | None = None + python_version: str | None = None + tags: dict[str, str] | None = None + depends_on_apps: tuple[str, ...] = () + + # Runtime configs + # Model GPU (https://modal.com/docs/guide/gpu) + # 16GB: T4 + # 24GB: L4, A10G + # 40GB: A100-40G, A100 (using A100 may cause Modal to auto-upgrade to A100-80G) + # 48GB: L40S + # 80GB: A100-80G, H100 (may auto-upgrade to H200, use H100! to avoid) + # 96GB: RTX-PRO-6000 + # 141GB: H200 + # 180GB: B200 (B200+ may auto-upgrade to B300, which requires CUDA13.0+) + gpu: str = "A10G" + # https://modal.com/docs/guide/cuda + cuda_version: str = "cu128" + # Default execution timeout in seconds (https://modal.com/docs/guide/timeouts) + timeout: int = 1800 + # Location to cache model weights and other large artifacts + model_volume_mountpoint: str = "/biomodals-store" + # Location to mount output volume (if in use) + output_volume_mountpoint: str = "/biomodals-outputs" + + @computed_field + @cached_property + def default_env(self) -> dict[str, str]: + """Environment variables to set in the runtime image.""" + model_cache_dir = Path(self.model_volume_mountpoint).resolve() + return { + "UV_COMPILE_BYTECODE": "1", # slower image build, faster runtime + "HF_XET_HIGH_PERFORMANCE": "1", + "HF_HOME": str(model_cache_dir / "huggingface"), + "TORCH_HOME": str(model_cache_dir / "torch"), + "UV_TORCH_BACKEND": self.cuda_version, + } + + @computed_field + @cached_property + def model_volume_subdir(self) -> str: + """Subdirectory in the Modal volume to store model weights. + + Note: do not use this if the weights are managed using HuggingFace. + Instead, use "/huggingface" directly. + """ + return f"/{self.name}" + + @computed_field + @cached_property + def git_clone_dir(self) -> Path: + """Directory to store cloned Git repositories.""" + return Path(f"/opt/{self.name}") + + @computed_field + @cached_property + def cuda_version_numeric(self) -> str: + """Numeric CUDA version, e.g., '128' for 'cu128'. + + https://github.com/astral-sh/uv/blob/main/crates/uv-torch/src/backend.rs + """ + if not self.cuda_version.startswith("cu"): + return "" + + available_uv_backends = { + "130", + "129", + "128", + "126", + "125", + "124", + "123", + "122", + "121", + "120", + "118", + "117", + "116", + "115", + "114", + "113", + "112", + "111", + "110", + "102", + "101", + "100", + "92", + "91", + "90", + } + + if (cuda_ver := self.cuda_version[2:]) not in available_uv_backends: + raise ValueError( + f"CUDA version {self.cuda_version} is not supported by UV. " + f"Available versions: {available_uv_backends}" + ) + return f"{cuda_ver[:-1]}.{cuda_ver[-1]}.0" + + @model_validator(mode="after") + def ensure_package_info(self): + """Ensure that the package information is complete.""" + if self.repo_url is None and self.package_name is None: + raise ValueError( + "At least one of 'repo_url' or 'package_name' must be provided." + ) + if self.repo_commit_hash is None and self.version is None: + raise ValueError( + "Provide 'repo_commit_hash' or 'version' for reproducibility." + ) + return self + + @model_validator(mode="after") + def ensure_cuda_gpu_compatibility(self): + """Ensure that the specified CUDA version is compatible with the GPU.""" + if not self.cuda_version.startswith("cu"): + raise ValueError("CUDA version must start with 'cu', e.g., 'cu128'.") + + is_cu12 = self.cuda_version.startswith("cu12") + if is_cu12 and self.gpu.startswith("B200+"): + raise ValueError("CUDA 12.x is not compatible with 'B200+ / B300' GPU.") + + return self + + @model_validator(mode="after") + def ensure_timeout_within_range(self): + """Ensure that the specified timeout is within a reasonable range.""" + # between 1 second and 24 hours + if self.timeout != max(1, min(self.timeout, APP_CONFIG_MAX_TIMEOUT)): + raise ValueError( + f"Timeout must be between 1 and {APP_CONFIG_MAX_TIMEOUT} seconds." + ) + return self + + +class AppRunStatus(StrEnum): + """Common completion states returned by workflow-compatible app functions.""" + + SUCCEEDED = "succeeded" + FAILED = "failed" + PARTIAL = "partial" + + +class AppOutput(BaseModel): + """One output produced by a workflow-compatible app function.""" + + name: str + kind: ArtifactKind + storage: InlineBytes | VolumePath = Field(discriminator="kind") + metadata: dict[str, Any] = Field(default_factory=dict) + + +class AppRunResult(BaseModel): + """Standard result returned by workflow-compatible app functions.""" + + status: AppRunStatus + outputs: list[AppOutput] = Field(default_factory=list) + metrics: dict[str, str | int | float | bool] = Field(default_factory=dict) + warnings: list[str] = Field(default_factory=list) + logs: list[AppOutput] = Field(default_factory=list) diff --git a/src/biomodals/schema/storage.py b/src/biomodals/schema/storage.py new file mode 100644 index 0000000..fb13023 --- /dev/null +++ b/src/biomodals/schema/storage.py @@ -0,0 +1,75 @@ +"""Storage schemas shared by app results and workflow artifacts.""" + +from __future__ import annotations + +from pathlib import Path, PurePosixPath +from typing import Literal + +from pydantic import BaseModel, ConfigDict, field_validator + +try: + from enum import StrEnum +except ImportError: + from backports.strenum import StrEnum # type: ignore[ty:unresolved-import] # noqa: UP035,I001 + + +class StorageKind(StrEnum): + """Supported storage forms for app outputs and workflow artifacts.""" + + INLINE_BYTES = "inline_bytes" + VOLUME_PATH = "volume_path" + + +class InlineBytes(BaseModel): + """UTF-8 text returned directly before workflow materialization.""" + + model_config = ConfigDict(extra="forbid") + + kind: Literal[StorageKind.INLINE_BYTES] = StorageKind.INLINE_BYTES + data: bytes + filename: str + media_type: str | None = None + + @field_validator("data") + @classmethod + def ensure_utf8_text(cls, value: bytes) -> bytes: + """Reject non-text inline payloads.""" + try: + value.decode("utf-8") + except UnicodeDecodeError as exc: + raise ValueError( + "InlineBytes.data must be UTF-8 text; use VolumePath for binary data." + ) from exc + return value + + +class VolumePath(BaseModel): + """Path to data stored in a Modal volume.""" + + model_config = ConfigDict(extra="forbid") + + kind: Literal[StorageKind.VOLUME_PATH] = StorageKind.VOLUME_PATH + volume_name: str + path: str + media_type: str | None = None + + @field_validator("path") + @classmethod + def ensure_relative_volume_path(cls, value: str) -> str: + """Reject paths that can escape the declared volume root.""" + path = PurePosixPath(value) + if value == "" or value == ".": + raise ValueError("VolumePath.path must be a non-empty relative path") + if path.is_absolute() or any(part in {"", ".", ".."} for part in path.parts): + raise ValueError("VolumePath.path must be relative and must not traverse") + if "\\" in value: + raise ValueError("VolumePath.path must use POSIX separators") + return value + + def at_mountpoint(self, mountpoint: str | Path) -> Path: + """Return the path of this volume on the given mountpoint.""" + return Path(mountpoint) / self.path + + def __str__(self) -> str: + """Return a human-readable string representation of this volume path.""" + return f"'{self.path}' from volume '{self.volume_name}'" diff --git a/src/biomodals/schema/workflow.py b/src/biomodals/schema/workflow.py new file mode 100644 index 0000000..58d8cd6 --- /dev/null +++ b/src/biomodals/schema/workflow.py @@ -0,0 +1,139 @@ +"""Schemas for workflow artifacts, selectors, and durable run status.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, Field + +from biomodals.schema.storage import VolumePath + +# < Python 3.11 guards +try: + from datetime import UTC + from enum import StrEnum +except ImportError: + from datetime import timezone # noqa: I001 + + from backports.strenum import StrEnum # type: ignore[ty:unresolved-import] # noqa: UP035,I001 + + UTC = timezone.utc # noqa: UP017 + + +class ArtifactKind(StrEnum): + """Common artifact categories passed between workflow nodes.""" + + STRUCTURES = "structures" + SCORES = "scores" + REPORT = "report" + ARCHIVE = "archive" + DIRECTORY = "directory" + TABLE = "table" + LOGS = "logs" + + +class NodeExecutionPolicy(StrEnum): + """Restart behavior for an incomplete workflow node.""" + + RERUN = "rerun" + RESUME = "resume" + + +class NodePlacement(StrEnum): + """Execution location for a workflow node.""" + + ORCHESTRATOR = "orchestrator" + REMOTE = "remote" + + +class NodeStatus(StrEnum): + """Durable lifecycle states for one workflow node.""" + + PENDING = "pending" + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + SKIPPED = "skipped" + + +class RunStatus(StrEnum): + """Durable lifecycle states for one workflow run.""" + + PENDING = "pending" + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + + +class ArtifactFile(BaseModel): + """One file recorded inside a workflow artifact.""" + + path: str + role: str | None = None + media_type: str | None = None + size_bytes: int | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class WorkflowArtifact(BaseModel): + """Durable manifest for data produced by a workflow node.""" + + artifact_id: str + producing_node_id: str + kind: ArtifactKind + storage: VolumePath + files: list[ArtifactFile] = Field(default_factory=list) + source_app_output_name: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ArtifactSelector(BaseModel): + """Reference to upstream workflow artifacts consumed by a node input.""" + + producing_node_id: str + kind: ArtifactKind | None = None + pattern: str | None = None + role: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ControlEdge(BaseModel): + """Ordering dependency between workflow nodes without artifact passage.""" + + upstream_node_id: str + downstream_node_id: str + + +class WorkflowRun(BaseModel): + """Durable status record for one workflow run.""" + + workflow_name: str + run_id: str + status: RunStatus = RunStatus.PENDING + dag_hash: str | None = None + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class AttemptRecord(BaseModel): + """Durable record for one node execution attempt.""" + + node_id: str + attempt_id: str + status: NodeStatus = NodeStatus.RUNNING + metadata: dict[str, Any] = Field(default_factory=dict) + + +class NodeStatusRecord(BaseModel): + """Durable status record for one workflow node.""" + + node_id: str + status: NodeStatus + execution_policy: NodeExecutionPolicy = NodeExecutionPolicy.RERUN + placement: NodePlacement = NodePlacement.ORCHESTRATOR + input_artifact_ids: list[str] = Field(default_factory=list) + output_artifact_ids: list[str] = Field(default_factory=list) + attempts: list[str] = Field(default_factory=list) + error: str | None = None diff --git a/src/biomodals/workflow/__init__.py b/src/biomodals/workflow/__init__.py index e778128..15ce079 100644 --- a/src/biomodals/workflow/__init__.py +++ b/src/biomodals/workflow/__init__.py @@ -1 +1,23 @@ -"""Multi-app computational pipelines.""" +"""Executable workflow scripts and public workflow runtime types.""" + +from biomodals.workflow.core import ( + AppBackedNode, + NodeHandle, + NodeRunContext, + Workflow, + WorkflowDefinition, + WorkflowNativeNode, + WorkflowNode, + WorkflowNodeSpec, +) + +__all__ = [ + "AppBackedNode", + "NodeHandle", + "NodeRunContext", + "Workflow", + "WorkflowDefinition", + "WorkflowNativeNode", + "WorkflowNode", + "WorkflowNodeSpec", +] diff --git a/src/biomodals/workflow/core/__init__.py b/src/biomodals/workflow/core/__init__.py new file mode 100644 index 0000000..2f724e9 --- /dev/null +++ b/src/biomodals/workflow/core/__init__.py @@ -0,0 +1,25 @@ +"""Reusable workflow runtime internals.""" + +from biomodals.workflow.core.builder import ( + NodeHandle, + Workflow, + WorkflowDefinition, + WorkflowNodeSpec, +) +from biomodals.workflow.core.nodes import ( + AppBackedNode, + NodeRunContext, + WorkflowNativeNode, + WorkflowNode, +) + +__all__ = [ + "AppBackedNode", + "NodeHandle", + "NodeRunContext", + "Workflow", + "WorkflowDefinition", + "WorkflowNativeNode", + "WorkflowNode", + "WorkflowNodeSpec", +] diff --git a/src/biomodals/workflow/core/artifacts.py b/src/biomodals/workflow/core/artifacts.py new file mode 100644 index 0000000..7dff9a9 --- /dev/null +++ b/src/biomodals/workflow/core/artifacts.py @@ -0,0 +1,321 @@ +"""Local helpers for materializing app outputs into workflow artifacts.""" + +from __future__ import annotations + +import shutil +from collections.abc import Mapping +from pathlib import Path, PurePosixPath +from typing import Any, Literal + +import orjson +from pydantic import BaseModel + +from biomodals.helper.shell import sanitize_filename +from biomodals.schema import ( + AppRunResult, + ArtifactFile, + ArtifactKind, + InlineBytes, + VolumePath, + WorkflowArtifact, +) + + +def _artifact_id(producing_node_id: str, output_name: str) -> str: + return sanitize_filename(f"{producing_node_id}-{output_name}") + + +def _write_json(path: Path, payload: object) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_name(f".{path.name}.tmp") + if isinstance(payload, BaseModel): + tmp_path.write_text(payload.model_dump_json(indent=2), encoding="utf-8") + else: + tmp_path.write_bytes( + orjson.dumps( + payload, + option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS, + ) + ) + tmp_path.replace(path) + + +def _artifact_files(root: Path) -> list[ArtifactFile]: + if root.is_file(): + return [ + ArtifactFile( + path=root.name, + size_bytes=root.stat().st_size, + ) + ] + return [ + ArtifactFile( + path=str(path.relative_to(root)), + size_bytes=path.stat().st_size, + ) + for path in sorted(root.rglob("*")) + if path.is_file() + ] + + +def _validate_inline_text_bytes( + storage: InlineBytes, output_kind: ArtifactKind +) -> None: + if output_kind == ArtifactKind.ARCHIVE or getattr(storage, "archive_format", None): + raise ValueError( + "InlineBytes outputs are UTF-8 text only; archive outputs must use " + "VolumePath storage" + ) + try: + storage.data.decode("utf-8") + except UnicodeDecodeError as exc: + raise ValueError("InlineBytes outputs must contain UTF-8 text bytes") from exc + + +def _materialize_inline_bytes( + *, + storage: InlineBytes, + output_name: str, + output_kind: ArtifactKind, + workflow_volume_name: str, + attempt_dir: Path, + volume_root: Path | None, + producing_node_id: str, + metadata: dict[str, Any] | None = None, + artifact_output_name: str | None = None, + source_app_output_name: str | None = None, + raw_dir: Path | None = None, + materialized_parent: Path | None = None, +) -> WorkflowArtifact: + artifact_id = _artifact_id( + producing_node_id, + artifact_output_name or output_name, + ) + _validate_inline_text_bytes(storage, output_kind) + safe_filename = sanitize_filename(storage.filename) + raw_dir = raw_dir or attempt_dir / "raw_outputs" + raw_dir.mkdir(parents=True, exist_ok=True) + raw_path = raw_dir / safe_filename + raw_path.write_bytes(storage.data) + + materialized_parent = materialized_parent or attempt_dir / "materialized_outputs" + materialized_dir = materialized_parent / artifact_id + materialized_dir.mkdir(parents=True, exist_ok=True) + materialized_dir.joinpath(safe_filename).write_bytes(storage.data) + + return WorkflowArtifact( + artifact_id=artifact_id, + producing_node_id=producing_node_id, + kind=output_kind, + storage=VolumePath( + volume_name=workflow_volume_name, + path=_volume_path(materialized_dir, volume_root), + ), + files=_artifact_files(materialized_dir), + source_app_output_name=source_app_output_name or output_name, + metadata=metadata or {}, + ) + + +def _volume_path(path: Path, volume_root: Path | None) -> str: + if volume_root is None: + return str(path) + return path.relative_to(volume_root).as_posix() + + +def _resolve_volume_child(root: Path, path: str) -> Path: + relative = PurePosixPath(path) + if path == "" or path == ".": + raise ValueError("VolumePath.path must be a non-empty relative path") + if relative.is_absolute() or any( + part in {"", ".", ".."} for part in relative.parts + ): + raise ValueError("VolumePath.path must be relative and must not traverse") + if "\\" in path: + raise ValueError("VolumePath.path must use POSIX separators") + + resolved_root = root.resolve() + raw_path = resolved_root / Path(*relative.parts) + current = resolved_root + for part in relative.parts: + current /= part + if current.is_symlink(): + raise ValueError("VolumePath.path must not contain symlinks") + + resolved_path = raw_path.resolve() + try: + # Validate-only: reject paths that resolve outside the mounted volume. + resolved_path.relative_to(resolved_root) + except ValueError as exc: + raise ValueError("VolumePath.path escapes the mounted volume root") from exc + return resolved_path + + +def _copy_volume_path_tree( + *, + source_path: Path, + materialized_dir: Path, + source_root: Path, +) -> None: + resolved_source_root = source_root.resolve() + if source_path.is_symlink(): + raise ValueError("VolumePath copy source must not be a symlink") + if source_path.is_file(): + materialized_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2( + source_path, materialized_dir / source_path.name, follow_symlinks=False + ) + return + + materialized_dir.mkdir(parents=True, exist_ok=True) + for child in sorted(source_path.rglob("*")): + if child.is_symlink(): + raise ValueError("VolumePath copy source tree must not contain symlinks") + try: + child.resolve().relative_to(resolved_source_root) + except ValueError as exc: + raise ValueError( + "VolumePath copy source tree escapes the mounted volume root" + ) from exc + + destination = materialized_dir / child.relative_to(source_path) + if child.is_dir(): + destination.mkdir(parents=True, exist_ok=True) + elif child.is_file(): + destination.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(child, destination, follow_symlinks=False) + + +def _materialize_volume_path_copy( + *, + storage: VolumePath, + output_name: str, + output_kind: ArtifactKind, + workflow_volume_name: str, + attempt_dir: Path, + volume_root: Path | None, + producing_node_id: str, + metadata: dict[str, Any], + volume_roots: Mapping[str, Path], + artifact_output_name: str | None = None, + source_app_output_name: str | None = None, + materialized_parent: Path | None = None, +) -> WorkflowArtifact: + artifact_id = _artifact_id( + producing_node_id, + artifact_output_name or output_name, + ) + source_root = volume_roots.get(storage.volume_name) + if source_root is None: + raise ValueError( + f"Missing mounted volume root for output volume {storage.volume_name!r}" + ) + source_path = _resolve_volume_child(source_root, storage.path) + if not source_path.exists(): + raise FileNotFoundError(f"Volume output path not found: {source_path}") + + materialized_parent = materialized_parent or attempt_dir / "materialized_outputs" + materialized_dir = materialized_parent / artifact_id + _copy_volume_path_tree( + source_path=source_path, + materialized_dir=materialized_dir, + source_root=source_root, + ) + + return WorkflowArtifact( + artifact_id=artifact_id, + producing_node_id=producing_node_id, + kind=output_kind, + storage=VolumePath( + volume_name=workflow_volume_name, + path=_volume_path(materialized_dir, volume_root), + media_type=storage.media_type, + ), + files=_artifact_files(materialized_dir), + source_app_output_name=source_app_output_name or output_name, + metadata=metadata, + ) + + +def materialize_app_run_result( + *, + result: AppRunResult, + workflow_volume_name: str, + attempt_dir: Path, + artifact_dir: Path, + producing_node_id: str, + volume_root: Path | None = None, + volume_path_mode: Literal["reference", "copy"] = "reference", + volume_roots: Mapping[str, Path] | None = None, +) -> list[WorkflowArtifact]: + """Write app outputs into local workflow volume paths and return manifests.""" + artifacts: list[WorkflowArtifact] = [] + + def materialize_output( + output, + *, + artifact_output_name: str | None = None, + source_app_output_name: str | None = None, + raw_dir: Path | None = None, + materialized_parent: Path | None = None, + ) -> WorkflowArtifact: + artifact_id = _artifact_id( + producing_node_id, + artifact_output_name or output.name, + ) + if isinstance(output.storage, InlineBytes): + return _materialize_inline_bytes( + storage=output.storage, + output_name=output.name, + output_kind=output.kind, + workflow_volume_name=workflow_volume_name, + attempt_dir=attempt_dir, + volume_root=volume_root, + producing_node_id=producing_node_id, + metadata=output.metadata, + artifact_output_name=artifact_output_name, + source_app_output_name=source_app_output_name, + raw_dir=raw_dir, + materialized_parent=materialized_parent, + ) + + if volume_path_mode == "copy": + return _materialize_volume_path_copy( + storage=output.storage, + output_name=output.name, + output_kind=output.kind, + workflow_volume_name=workflow_volume_name, + attempt_dir=attempt_dir, + volume_root=volume_root, + producing_node_id=producing_node_id, + metadata=output.metadata, + volume_roots=volume_roots or {}, + artifact_output_name=artifact_output_name, + source_app_output_name=source_app_output_name, + materialized_parent=materialized_parent, + ) + return WorkflowArtifact( + artifact_id=artifact_id, + producing_node_id=producing_node_id, + kind=output.kind, + storage=output.storage, + source_app_output_name=source_app_output_name or output.name, + metadata=output.metadata, + ) + + for output in result.outputs: + artifact = materialize_output(output) + _write_json(artifact_dir / f"{artifact.artifact_id}.json", artifact) + artifacts.append(artifact) + + for log_output in result.logs: + artifact = materialize_output( + log_output, + artifact_output_name=f"logs-{log_output.name}", + source_app_output_name=log_output.name, + raw_dir=attempt_dir / "logs" / "raw_outputs", + materialized_parent=attempt_dir / "logs", + ) + _write_json(artifact_dir / f"{artifact.artifact_id}.json", artifact) + artifacts.append(artifact) + return artifacts diff --git a/src/biomodals/workflow/core/builder.py b/src/biomodals/workflow/core/builder.py new file mode 100644 index 0000000..55a9fab --- /dev/null +++ b/src/biomodals/workflow/core/builder.py @@ -0,0 +1,145 @@ +"""Python-first workflow DAG builder.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from biomodals.helper.shell import sanitize_filename +from biomodals.schema import ArtifactKind, ArtifactSelector +from biomodals.workflow.core.nodes import WorkflowNode + + +@dataclass(frozen=True) +class NodeHandle: + """Stable handle returned after adding a node to a workflow.""" + + node_id: str + + def outputs( + self, + kind: ArtifactKind | None = None, + pattern: str | None = None, + role: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> ArtifactSelector: + """Select artifacts produced by this node.""" + return ArtifactSelector( + producing_node_id=self.node_id, + kind=kind, + pattern=pattern, + role=role, + metadata=metadata or {}, + ) + + +@dataclass +class WorkflowNodeSpec: + """Builder-time node metadata.""" + + node_id: str + node: WorkflowNode + inputs: dict[str, ArtifactSelector] = field(default_factory=dict) + control_dependencies: set[str] = field(default_factory=set) + + +@dataclass(frozen=True) +class WorkflowDefinition: + """Validated workflow DAG definition.""" + + name: str + nodes: dict[str, WorkflowNodeSpec] + dependencies: dict[str, set[str]] + + +class Workflow: + """Python-first workflow DAG builder.""" + + def __init__(self, name: str): + """Initialize an empty workflow definition.""" + self.name = sanitize_filename(name) + self._nodes: dict[str, WorkflowNodeSpec] = {} + + def add_node( + self, + node: WorkflowNode, + *, + id: str, + inputs: dict[str, ArtifactSelector] | None = None, + depends_on: list[NodeHandle | str] | None = None, + ) -> NodeHandle: + """Add one node to the workflow and return its handle.""" + node_id = sanitize_filename(id) + if node_id in self._nodes: + raise ValueError(f"Duplicate workflow node id: {node_id}") + + control_dependencies = { + dependency.node_id if isinstance(dependency, NodeHandle) else dependency + for dependency in depends_on or [] + } + self._nodes[node_id] = WorkflowNodeSpec( + node_id=node_id, + node=node, + inputs=inputs or {}, + control_dependencies=control_dependencies, + ) + return NodeHandle(node_id=node_id) + + def add_control_edge( + self, + upstream: NodeHandle | str, + downstream: NodeHandle | str, + ) -> None: + """Add an ordering-only dependency between two existing nodes.""" + upstream_id = upstream.node_id if isinstance(upstream, NodeHandle) else upstream + downstream_id = ( + downstream.node_id if isinstance(downstream, NodeHandle) else downstream + ) + self._nodes[downstream_id].control_dependencies.add(upstream_id) + + def validate(self) -> WorkflowDefinition: + """Validate the workflow DAG and return an immutable definition.""" + dependencies = self._dependencies() + missing = { + dependency + for node_dependencies in dependencies.values() + for dependency in node_dependencies + if dependency not in self._nodes + } + if missing: + raise ValueError(f"Unknown workflow node dependencies: {sorted(missing)}") + self._raise_for_cycles(dependencies) + return WorkflowDefinition( + name=self.name, + nodes=dict(self._nodes), + dependencies=dependencies, + ) + + def _dependencies(self) -> dict[str, set[str]]: + dependencies: dict[str, set[str]] = {} + for node_id, spec in self._nodes.items(): + input_dependencies = { + selector.producing_node_id for selector in spec.inputs.values() + } + dependencies[node_id] = input_dependencies | spec.control_dependencies + return dependencies + + @staticmethod + def _raise_for_cycles(dependencies: dict[str, set[str]]) -> None: + temporary: set[str] = set() + permanent: set[str] = set() + + def visit(node_id: str) -> None: + if node_id in permanent: + return + if node_id in temporary: + raise ValueError("Workflow DAG contains a cycle") + temporary.add(node_id) + for dependency in dependencies.get(node_id, set()): + if dependency in dependencies: + visit(dependency) + temporary.remove(node_id) + permanent.add(node_id) + + for node_id in dependencies: + visit(node_id) diff --git a/src/biomodals/workflow/core/ledger.py b/src/biomodals/workflow/core/ledger.py new file mode 100644 index 0000000..03ff223 --- /dev/null +++ b/src/biomodals/workflow/core/ledger.py @@ -0,0 +1,1155 @@ +"""SQLite-backed durable ledger for Biomodals workflow runs.""" + +from __future__ import annotations + +import shutil +import sqlite3 +from collections.abc import Iterable, Iterator +from contextlib import closing, contextmanager +from datetime import datetime +from fnmatch import fnmatch +from pathlib import Path +from threading import RLock +from typing import Any + +import orjson + +from biomodals.schema import ( + AppRunResult, + ArtifactSelector, + AttemptRecord, + NodeExecutionPolicy, + NodePlacement, + NodeStatus, + NodeStatusRecord, + RunStatus, + VolumePath, + WorkflowArtifact, + WorkflowRun, +) + +try: + from datetime import UTC +except ImportError: + from datetime import timezone + + UTC = timezone.utc # noqa: UP017 + +LEDGER_FILENAME = "ledger.sqlite3" +LEDGER_TABLES = ( + "runs", + "nodes", + "attempts", + "remote_calls", + "artifacts", + "artifact_files", + "node_inputs", + "node_outputs", +) + + +class WorkflowLedger: + """SQLite-backed durable state for one workflow run.""" + + def __init__(self, volume_root: str | Path): + """Initialize a ledger rooted at a mounted workflow volume path.""" + self.volume_root = Path(volume_root) + self.workflow_name: str | None = None + self.run_id: str | None = None + self._connection: sqlite3.Connection | None = None + self._lock = RLock() + + @property + def run_root(self) -> Path: + """Return the root directory for the active workflow run.""" + if self.workflow_name is None or self.run_id is None: + raise RuntimeError("Workflow run has not been initialized") + return self.volume_root / self.workflow_name / self.run_id + + @property + def ledger_path(self) -> Path: + """Return the SQLite database path for the active workflow run.""" + return self.run_root / LEDGER_FILENAME + + def create_run(self, run: WorkflowRun) -> WorkflowRun: + """Create a run ledger and initialize its SQLite schema.""" + self._activate(run.workflow_name, run.run_id) + self._create_run_layout() + with self._transaction() as conn: + conn.execute( + """ + INSERT INTO runs ( + run_id, + workflow_name, + dag_hash, + status, + created_at, + updated_at, + metadata_json + ) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + run.run_id, + run.workflow_name, + run.dag_hash, + run.status.value, + _datetime_json(run.created_at), + _datetime_json(run.updated_at), + _json_dumps(run.metadata), + ), + ) + return run + + def load_run(self, workflow_name: str, run_id: str) -> WorkflowRun: + """Load an existing run ledger.""" + self._activate(workflow_name, run_id) + if not self.ledger_path.exists(): + raise FileNotFoundError(self.ledger_path) + row = self._fetch_one("SELECT * FROM runs WHERE run_id = ?", (run_id,)) + if row is None: + raise FileNotFoundError(f"Workflow run not found: {workflow_name}/{run_id}") + return _run_from_row(row) + + def run_exists(self, workflow_name: str, run_id: str) -> bool: + """Return whether a SQLite ledger exists for one workflow run.""" + ledger_path = self.volume_root / workflow_name / run_id / LEDGER_FILENAME + if not ledger_path.exists(): + return False + try: + with closing(sqlite3.connect(ledger_path)) as conn: + row = conn.execute( + "SELECT 1 FROM runs WHERE run_id = ? LIMIT 1", + (run_id,), + ).fetchone() + except sqlite3.Error: + return False + return row is not None + + def mark_run_status(self, status: RunStatus) -> WorkflowRun: + """Update the durable status for the active workflow run.""" + now = _now_json() + with self._transaction() as conn: + conn.execute( + "UPDATE runs SET status = ?, updated_at = ? WHERE run_id = ?", + (status.value, now, self._require_run_id()), + ) + return self.load_run(self._require_workflow_name(), self._require_run_id()) + + def reset_run(self, workflow_name: str, run_id: str) -> None: + """Remove all durable state for one workflow run.""" + if self.workflow_name == workflow_name and self.run_id == run_id: + self._close() + run_root = self.volume_root / workflow_name / run_id + if run_root.exists(): + shutil.rmtree(run_root) + + def node_has_state(self, node_id: str) -> bool: + """Return whether a node has any durable state in this run.""" + row = self._fetch_one( + """ + SELECT 1 FROM nodes WHERE node_id = ? + UNION ALL + SELECT 1 FROM attempts WHERE node_id = ? + UNION ALL + SELECT 1 FROM artifacts WHERE producing_node_id = ? + LIMIT 1 + """, + (node_id, node_id, node_id), + ) + return row is not None or (self.run_root / "nodes" / node_id).exists() + + def node_is_running(self, node_id: str) -> bool: + """Return whether a node is currently marked as running.""" + row = self._fetch_one( + "SELECT status FROM nodes WHERE node_id = ?", + (node_id,), + ) + return row is not None and row["status"] == NodeStatus.RUNNING.value + + def load_node_status(self, node_id: str) -> NodeStatusRecord: + """Load a node status row or return the default pending state.""" + return self._load_node_status_or_default(node_id) + + def reset_node(self, node_id: str) -> None: + """Remove durable state for one workflow node.""" + artifact_ids = [ + row["artifact_id"] + for row in self._fetch_all( + "SELECT artifact_id FROM artifacts WHERE producing_node_id = ?", + (node_id,), + ) + ] + node_dir = self.run_root / "nodes" / node_id + if node_dir.exists(): + shutil.rmtree(node_dir) + for artifact_id in artifact_ids: + manifest_path = self.run_root / "artifacts" / f"{artifact_id}.json" + if manifest_path.exists(): + manifest_path.unlink() + artifact_dir = self.run_root / "artifacts" / artifact_id + if artifact_dir.exists(): + shutil.rmtree(artifact_dir) + + with self._transaction() as conn: + for artifact_id in artifact_ids: + conn.execute( + "DELETE FROM artifact_files WHERE artifact_id = ?", + (artifact_id,), + ) + conn.execute( + "DELETE FROM node_outputs WHERE artifact_id = ?", + (artifact_id,), + ) + conn.execute( + "DELETE FROM artifacts WHERE artifact_id = ?", + (artifact_id,), + ) + conn.execute("DELETE FROM node_inputs WHERE node_id = ?", (node_id,)) + conn.execute("DELETE FROM node_outputs WHERE node_id = ?", (node_id,)) + conn.execute("DELETE FROM remote_calls WHERE node_id = ?", (node_id,)) + conn.execute("DELETE FROM attempts WHERE node_id = ?", (node_id,)) + conn.execute("DELETE FROM nodes WHERE node_id = ?", (node_id,)) + + def mark_node_pending(self, node_id: str) -> NodeStatusRecord: + """Mark a node as pending.""" + self._upsert_node_status(node_id, NodeStatus.PENDING) + return self._load_node_status_or_default(node_id) + + def mark_node_running( + self, + node_id: str, + attempt_id: str, + *, + input_artifact_ids: list[str] | None = None, + execution_policy: NodeExecutionPolicy | None = None, + placement: NodePlacement | None = None, + ) -> NodeStatusRecord: + """Mark a node as running and record its attempt id.""" + now = _now_json() + execution_policy = execution_policy or NodeExecutionPolicy.RERUN + placement = placement or NodePlacement.ORCHESTRATOR + with self._transaction() as conn: + conn.execute( + """ + INSERT INTO nodes ( + node_id, + status, + execution_policy, + placement, + current_attempt_id, + error, + started_at, + completed_at, + updated_at + ) + VALUES (?, ?, ?, ?, ?, NULL, ?, NULL, ?) + ON CONFLICT(node_id) DO UPDATE SET + status = excluded.status, + execution_policy = excluded.execution_policy, + placement = excluded.placement, + current_attempt_id = excluded.current_attempt_id, + error = NULL, + started_at = COALESCE(nodes.started_at, excluded.started_at), + completed_at = NULL, + updated_at = excluded.updated_at + """, + ( + node_id, + NodeStatus.RUNNING.value, + execution_policy.value, + placement.value, + attempt_id, + now, + now, + ), + ) + if input_artifact_ids is not None: + conn.execute( + "DELETE FROM node_inputs WHERE node_id = ? AND input_name = ''", + (node_id,), + ) + conn.executemany( + """ + INSERT OR IGNORE INTO node_inputs ( + node_id, + input_name, + artifact_id + ) + VALUES (?, '', ?) + """, + [(node_id, artifact_id) for artifact_id in input_artifact_ids], + ) + return self._load_node_status_or_default(node_id) + + def record_node_inputs( + self, + node_id: str, + inputs: dict[str, list[WorkflowArtifact]], + ) -> None: + """Record the named artifact inputs resolved for a node attempt.""" + with self._transaction() as conn: + conn.execute("DELETE FROM node_inputs WHERE node_id = ?", (node_id,)) + conn.executemany( + """ + INSERT OR IGNORE INTO node_inputs ( + node_id, + input_name, + artifact_id + ) + VALUES (?, ?, ?) + """, + [ + (node_id, input_name, artifact.artifact_id) + for input_name, artifacts in inputs.items() + for artifact in artifacts + ], + ) + + def mark_node_succeeded( + self, node_id: str, artifact_ids: list[str] + ) -> NodeStatusRecord: + """Mark a node as succeeded with output artifact ids.""" + now = _now_json() + with self._transaction() as conn: + conn.execute( + """ + INSERT INTO nodes ( + node_id, + status, + execution_policy, + placement, + current_attempt_id, + error, + started_at, + completed_at, + updated_at + ) + VALUES (?, ?, ?, ?, NULL, NULL, ?, ?, ?) + ON CONFLICT(node_id) DO UPDATE SET + status = excluded.status, + error = NULL, + completed_at = excluded.completed_at, + updated_at = excluded.updated_at + """, + ( + node_id, + NodeStatus.SUCCEEDED.value, + NodeExecutionPolicy.RERUN.value, + NodePlacement.ORCHESTRATOR.value, + now, + now, + now, + ), + ) + conn.execute("DELETE FROM node_outputs WHERE node_id = ?", (node_id,)) + conn.executemany( + """ + INSERT OR IGNORE INTO node_outputs (node_id, artifact_id) + VALUES (?, ?) + """, + [(node_id, artifact_id) for artifact_id in artifact_ids], + ) + return self._load_node_status_or_default(node_id) + + def mark_node_failed(self, node_id: str, error: str) -> NodeStatusRecord: + """Mark a node as failed with an error message.""" + now = _now_json() + with self._transaction() as conn: + conn.execute( + """ + INSERT INTO nodes ( + node_id, + status, + execution_policy, + placement, + current_attempt_id, + error, + started_at, + completed_at, + updated_at + ) + VALUES (?, ?, ?, ?, NULL, ?, ?, ?, ?) + ON CONFLICT(node_id) DO UPDATE SET + status = excluded.status, + error = excluded.error, + completed_at = excluded.completed_at, + updated_at = excluded.updated_at + """, + ( + node_id, + NodeStatus.FAILED.value, + NodeExecutionPolicy.RERUN.value, + NodePlacement.ORCHESTRATOR.value, + error, + now, + now, + now, + ), + ) + row = conn.execute( + "SELECT current_attempt_id FROM nodes WHERE node_id = ?", + (node_id,), + ).fetchone() + if row is not None and row["current_attempt_id"] is not None: + conn.execute( + """ + UPDATE attempts + SET status = ?, + completed_at = COALESCE(completed_at, ?), + error = COALESCE(error, ?) + WHERE node_id = ? AND attempt_id = ? + """, + ( + NodeStatus.FAILED.value, + now, + error, + node_id, + row["current_attempt_id"], + ), + ) + return self._load_node_status_or_default(node_id) + + def record_attempt_started(self, node_id: str, attempt_id: str) -> AttemptRecord: + """Record that a node attempt started.""" + now = _now_json() + with self._transaction() as conn: + conn.execute( + """ + INSERT INTO attempts ( + attempt_id, + node_id, + status, + started_at, + completed_at, + app_result_json, + error, + metadata_json + ) + VALUES (?, ?, ?, ?, NULL, NULL, NULL, ?) + ON CONFLICT(attempt_id, node_id) DO UPDATE SET + status = excluded.status, + started_at = COALESCE(attempts.started_at, excluded.started_at), + completed_at = NULL, + error = NULL + """, + ( + attempt_id, + node_id, + NodeStatus.RUNNING.value, + now, + _json_dumps({}), + ), + ) + return AttemptRecord(node_id=node_id, attempt_id=attempt_id) + + def record_app_result( + self, node_id: str, attempt_id: str, result: AppRunResult + ) -> Path: + """Record the raw app result for one node attempt in SQLite.""" + with self._transaction() as conn: + conn.execute( + """ + UPDATE attempts + SET app_result_json = ? + WHERE node_id = ? AND attempt_id = ? + """, + (result.model_dump_json(), node_id, attempt_id), + ) + return self.ledger_path + + def load_attempt_app_result( + self, node_id: str, attempt_id: str + ) -> AppRunResult | None: + """Load the recorded app result for one attempt, if present.""" + row = self._fetch_one( + """ + SELECT app_result_json + FROM attempts + WHERE node_id = ? AND attempt_id = ? + """, + (node_id, attempt_id), + ) + if row is None or row["app_result_json"] is None: + return None + return AppRunResult.model_validate_json(row["app_result_json"]) + + def record_attempt_completed( + self, + node_id: str, + attempt_id: str, + status: NodeStatus, + *, + result: AppRunResult | None = None, + error: str | None = None, + ) -> AttemptRecord: + """Record terminal status for one node attempt.""" + now = _now_json() + app_result_json = result.model_dump_json() if result is not None else None + with self._transaction() as conn: + conn.execute( + """ + INSERT INTO attempts ( + attempt_id, + node_id, + status, + started_at, + completed_at, + app_result_json, + error, + metadata_json + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(attempt_id, node_id) DO UPDATE SET + status = excluded.status, + completed_at = excluded.completed_at, + app_result_json = COALESCE( + excluded.app_result_json, + attempts.app_result_json + ), + error = excluded.error + """, + ( + attempt_id, + node_id, + status.value, + now, + now, + app_result_json, + error, + _json_dumps({}), + ), + ) + return AttemptRecord(node_id=node_id, attempt_id=attempt_id, status=status) + + def record_remote_call( + self, + *, + call_id: str, + node_id: str, + attempt_id: str, + function_name: str, + call_kind: str, + status: str = "submitted", + metadata: dict[str, Any] | None = None, + ) -> None: + """Record a submitted Modal function call before waiting for it.""" + now = _now_json() + with self._transaction() as conn: + conn.execute( + """ + INSERT INTO remote_calls ( + call_id, + node_id, + attempt_id, + function_name, + call_kind, + status, + submitted_at, + completed_at, + error, + metadata_json + ) + VALUES (?, ?, ?, ?, ?, ?, ?, NULL, NULL, ?) + ON CONFLICT(call_id) DO UPDATE SET + node_id = excluded.node_id, + attempt_id = excluded.attempt_id, + function_name = excluded.function_name, + call_kind = excluded.call_kind, + status = excluded.status, + metadata_json = excluded.metadata_json + """, + ( + call_id, + node_id, + attempt_id, + function_name, + call_kind, + status, + now, + _json_dumps(metadata or {}), + ), + ) + + def mark_remote_call_status( + self, + call_id: str, + status: str, + *, + error: str | None = None, + completed: bool = False, + metadata: dict[str, Any] | None = None, + ) -> None: + """Update status for a recorded Modal function call.""" + completed_at = _now_json() if completed else None + with self._transaction() as conn: + if completed and metadata is not None: + conn.execute( + """ + UPDATE remote_calls + SET status = ?, + error = ?, + completed_at = ?, + metadata_json = ? + WHERE call_id = ? + """, + (status, error, completed_at, _json_dumps(metadata), call_id), + ) + elif completed: + conn.execute( + """ + UPDATE remote_calls + SET status = ?, + error = ?, + completed_at = ? + WHERE call_id = ? + """, + (status, error, completed_at, call_id), + ) + elif metadata is not None: + conn.execute( + """ + UPDATE remote_calls + SET status = ?, + error = ?, + metadata_json = ? + WHERE call_id = ? + """, + (status, error, _json_dumps(metadata), call_id), + ) + else: + conn.execute( + """ + UPDATE remote_calls + SET status = ?, + error = ? + WHERE call_id = ? + """, + (status, error, call_id), + ) + + def latest_remote_call( + self, + node_id: str, + *, + statuses: Iterable[str] | None = None, + ) -> dict[str, Any] | None: + """Return the latest remote call row for a node.""" + status_values: set[str] | None = None + if statuses is not None: + status_values = set(statuses) + if not status_values: + return None + rows = self._fetch_all( + """ + SELECT * + FROM remote_calls + WHERE node_id = ? + ORDER BY submitted_at DESC, call_id DESC + """, + (node_id,), + ) + for row in rows: + if status_values is None or row["status"] in status_values: + return _row_to_dict(row) + return None + + def load_remote_call(self, call_id: str) -> dict[str, Any] | None: + """Return one remote call row by id.""" + row = self._fetch_one( + "SELECT * FROM remote_calls WHERE call_id = ?", + (call_id,), + ) + if row is None: + return None + return _row_to_dict(row) + + def record_artifacts(self, artifacts: list[WorkflowArtifact]) -> list[Path]: + """Record workflow artifact manifests in SQLite rows.""" + paths: list[Path] = [] + with self._transaction() as conn: + for artifact in artifacts: + conn.execute( + """ + INSERT INTO artifacts ( + artifact_id, + producing_node_id, + kind, + volume_name, + storage_path, + source_app_output_name, + created_at, + metadata_json + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(artifact_id) DO UPDATE SET + producing_node_id = excluded.producing_node_id, + kind = excluded.kind, + volume_name = excluded.volume_name, + storage_path = excluded.storage_path, + source_app_output_name = excluded.source_app_output_name, + metadata_json = excluded.metadata_json + """, + ( + artifact.artifact_id, + artifact.producing_node_id, + artifact.kind.value, + artifact.storage.volume_name, + artifact.storage.path, + artifact.source_app_output_name, + _now_json(), + _json_dumps(artifact.metadata), + ), + ) + conn.execute( + "DELETE FROM artifact_files WHERE artifact_id = ?", + (artifact.artifact_id,), + ) + conn.executemany( + """ + INSERT OR REPLACE INTO artifact_files ( + artifact_id, + path, + role, + media_type, + size_bytes, + metadata_json + ) + VALUES (?, ?, ?, ?, ?, ?) + """, + [ + ( + artifact.artifact_id, + file.path, + file.role, + file.media_type, + file.size_bytes, + _json_dumps(file.metadata), + ) + for file in artifact.files + ], + ) + paths.append( + self.run_root / "artifacts" / f"{artifact.artifact_id}.json" + ) + return paths + + def load_artifact(self, artifact_id: str) -> WorkflowArtifact: + """Load one artifact manifest by id.""" + row = self._fetch_one( + "SELECT * FROM artifacts WHERE artifact_id = ?", + (artifact_id,), + ) + if row is None: + raise FileNotFoundError(f"Workflow artifact not found: {artifact_id}") + return self._artifact_from_row(row) + + def select_artifacts(self, selector: ArtifactSelector) -> list[WorkflowArtifact]: + """Return artifacts matching one upstream selector.""" + return [ + artifact + for artifact in self._load_artifacts_if_any() + if self._artifact_matches_selector(artifact, selector) + ] + + def node_is_complete(self, node_id: str) -> bool: + """Return whether a node has succeeded and all artifacts are recorded.""" + row = self._fetch_one( + "SELECT status FROM nodes WHERE node_id = ?", + (node_id,), + ) + if row is None or row["status"] != NodeStatus.SUCCEEDED.value: + return False + output_ids = [ + output_row["artifact_id"] + for output_row in self._fetch_all( + "SELECT artifact_id FROM node_outputs WHERE node_id = ?", + (node_id,), + ) + ] + if not output_ids: + return True + return all( + self._fetch_one( + "SELECT 1 FROM artifacts WHERE artifact_id = ?", + (artifact_id,), + ) + is not None + for artifact_id in output_ids + ) + + def next_attempt_id(self, node_id: str) -> str: + """Return the next deterministic attempt id for one node.""" + rows = self._fetch_all( + "SELECT attempt_id FROM attempts WHERE node_id = ?", + (node_id,), + ) + max_suffix = 0 + prefix = "attempt-" + for row in rows: + attempt_id = str(row["attempt_id"]) + suffix = attempt_id.removeprefix(prefix) + if suffix != attempt_id and suffix.isdecimal(): + max_suffix = max(max_suffix, int(suffix)) + return f"attempt-{max_suffix + 1}" + + def close(self) -> None: + """Close the active SQLite connection, if one is open.""" + with self._lock: + self._close() + + @contextmanager + def closed_for_volume_sync(self) -> Iterator[None]: + """Close the SQLite connection while synchronizing the backing volume.""" + with self._lock: + # Modal Volume sync fails if SQLite keeps ledger files open. Hold the + # ledger lock so another scheduler worker cannot reopen them mid-sync. + self._close() + yield + + @staticmethod + def _artifact_matches_selector( + artifact: WorkflowArtifact, + selector: ArtifactSelector, + ) -> bool: + if artifact.producing_node_id != selector.producing_node_id: + return False + if selector.kind is not None and artifact.kind != selector.kind: + return False + for key, expected in selector.metadata.items(): + if artifact.metadata.get(key) != expected: + return False + if selector.pattern is None and selector.role is None: + return True + return any( + (selector.pattern is None or fnmatch(file.path, selector.pattern)) + and (selector.role is None or file.role == selector.role) + for file in artifact.files + ) + + def _load_node_status_or_default(self, node_id: str) -> NodeStatusRecord: + row = self._fetch_one("SELECT * FROM nodes WHERE node_id = ?", (node_id,)) + if row is None: + return NodeStatusRecord(node_id=node_id, status=NodeStatus.PENDING) + attempts = [ + attempt_row["attempt_id"] + for attempt_row in self._fetch_all( + """ + SELECT attempt_id + FROM attempts + WHERE node_id = ? + ORDER BY started_at, attempt_id + """, + (node_id,), + ) + ] + input_artifact_ids = [ + input_row["artifact_id"] + for input_row in self._fetch_all( + """ + SELECT artifact_id + FROM node_inputs + WHERE node_id = ? + ORDER BY input_name, artifact_id + """, + (node_id,), + ) + ] + output_artifact_ids = [ + output_row["artifact_id"] + for output_row in self._fetch_all( + """ + SELECT artifact_id + FROM node_outputs + WHERE node_id = ? + ORDER BY artifact_id + """, + (node_id,), + ) + ] + return NodeStatusRecord( + node_id=node_id, + status=NodeStatus(row["status"]), + execution_policy=NodeExecutionPolicy(row["execution_policy"]), + placement=NodePlacement(row["placement"]), + input_artifact_ids=input_artifact_ids, + output_artifact_ids=output_artifact_ids, + attempts=attempts, + error=row["error"], + ) + + def _load_artifacts_if_any(self) -> list[WorkflowArtifact]: + return [ + self._artifact_from_row(row) + for row in self._fetch_all( + "SELECT * FROM artifacts ORDER BY artifact_id", + ) + ] + + def _artifact_from_row(self, row: sqlite3.Row) -> WorkflowArtifact: + file_rows = self._fetch_all( + """ + SELECT * + FROM artifact_files + WHERE artifact_id = ? + ORDER BY path + """, + (row["artifact_id"],), + ) + return WorkflowArtifact.model_validate({ + "artifact_id": row["artifact_id"], + "producing_node_id": row["producing_node_id"], + "kind": row["kind"], + "storage": VolumePath( + volume_name=row["volume_name"], + path=row["storage_path"], + ), + "files": [ + { + "path": file_row["path"], + "role": file_row["role"], + "media_type": file_row["media_type"], + "size_bytes": file_row["size_bytes"], + "metadata": _json_loads(file_row["metadata_json"]), + } + for file_row in file_rows + ], + "source_app_output_name": row["source_app_output_name"], + "metadata": _json_loads(row["metadata_json"]), + }) + + def _upsert_node_status(self, node_id: str, status: NodeStatus) -> None: + now = _now_json() + with self._transaction() as conn: + conn.execute( + """ + INSERT INTO nodes ( + node_id, + status, + execution_policy, + placement, + current_attempt_id, + error, + started_at, + completed_at, + updated_at + ) + VALUES (?, ?, ?, ?, NULL, NULL, NULL, NULL, ?) + ON CONFLICT(node_id) DO UPDATE SET + status = excluded.status, + updated_at = excluded.updated_at + """, + ( + node_id, + status.value, + NodeExecutionPolicy.RERUN.value, + NodePlacement.ORCHESTRATOR.value, + now, + ), + ) + + def _create_run_layout(self) -> None: + self.run_root.mkdir(parents=True, exist_ok=True) + for name in ("inputs", "nodes", "artifacts", "final"): + self.run_root.joinpath(name).mkdir(exist_ok=True) + + def _activate(self, workflow_name: str, run_id: str) -> None: + with self._lock: + if self.workflow_name == workflow_name and self.run_id == run_id: + return + self._close() + self.workflow_name = workflow_name + self.run_id = run_id + + @contextmanager + def _transaction(self): + with self._lock: + conn = self._connect() + try: + yield conn + except Exception: + conn.rollback() + raise + else: + conn.commit() + + def _fetch_one( + self, + sql: str, + params: Iterable[Any] = (), + ) -> sqlite3.Row | None: + with self._lock: + return self._connect().execute(sql, tuple(params)).fetchone() + + def _fetch_all( + self, + sql: str, + params: Iterable[Any] = (), + ) -> list[sqlite3.Row]: + with self._lock: + return list(self._connect().execute(sql, tuple(params)).fetchall()) + + def _connect(self) -> sqlite3.Connection: + if self._connection is None: + self.run_root.mkdir(parents=True, exist_ok=True) + self._connection = sqlite3.connect( + self.ledger_path, + check_same_thread=False, + ) + self._connection.row_factory = sqlite3.Row + self._connection.execute("PRAGMA foreign_keys = ON") + self._connection.execute("PRAGMA journal_mode = WAL") + self._ensure_schema() + return self._connection + + def _ensure_schema(self) -> None: + if self._connection is None: + raise RuntimeError("SQLite connection has not been opened") + self._connection.executescript( + """ + CREATE TABLE IF NOT EXISTS runs ( + run_id TEXT PRIMARY KEY, + workflow_name TEXT NOT NULL, + dag_hash TEXT, + status TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + metadata_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS nodes ( + node_id TEXT PRIMARY KEY, + status TEXT NOT NULL, + execution_policy TEXT NOT NULL, + placement TEXT NOT NULL, + current_attempt_id TEXT, + error TEXT, + started_at TEXT, + completed_at TEXT, + updated_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS attempts ( + attempt_id TEXT NOT NULL, + node_id TEXT NOT NULL, + status TEXT NOT NULL, + started_at TEXT NOT NULL, + completed_at TEXT, + app_result_json TEXT, + error TEXT, + metadata_json TEXT NOT NULL, + PRIMARY KEY (attempt_id, node_id) + ); + + CREATE TABLE IF NOT EXISTS remote_calls ( + call_id TEXT PRIMARY KEY, + node_id TEXT NOT NULL, + attempt_id TEXT NOT NULL, + function_name TEXT NOT NULL, + call_kind TEXT NOT NULL, + status TEXT NOT NULL, + submitted_at TEXT NOT NULL, + completed_at TEXT, + error TEXT, + metadata_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS artifacts ( + artifact_id TEXT PRIMARY KEY, + producing_node_id TEXT NOT NULL, + kind TEXT NOT NULL, + volume_name TEXT NOT NULL, + storage_path TEXT NOT NULL, + source_app_output_name TEXT, + created_at TEXT NOT NULL, + metadata_json TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS artifact_files ( + artifact_id TEXT NOT NULL, + path TEXT NOT NULL, + role TEXT, + media_type TEXT, + size_bytes INTEGER, + metadata_json TEXT NOT NULL, + PRIMARY KEY (artifact_id, path) + ); + + CREATE TABLE IF NOT EXISTS node_inputs ( + node_id TEXT NOT NULL, + input_name TEXT NOT NULL, + artifact_id TEXT NOT NULL, + PRIMARY KEY (node_id, input_name, artifact_id) + ); + + CREATE TABLE IF NOT EXISTS node_outputs ( + node_id TEXT NOT NULL, + artifact_id TEXT NOT NULL, + PRIMARY KEY (node_id, artifact_id) + ); + + CREATE INDEX IF NOT EXISTS idx_nodes_status + ON nodes(status); + CREATE INDEX IF NOT EXISTS idx_attempts_node + ON attempts(node_id); + CREATE INDEX IF NOT EXISTS idx_remote_calls_node_status + ON remote_calls(node_id, status); + CREATE INDEX IF NOT EXISTS idx_artifacts_producing_node + ON artifacts(producing_node_id); + """ + ) + self._connection.commit() + + def _close(self) -> None: + if self._connection is not None: + self._connection.close() + self._connection = None + + def _require_workflow_name(self) -> str: + if self.workflow_name is None: + raise RuntimeError("Workflow run has not been initialized") + return self.workflow_name + + def _require_run_id(self) -> str: + if self.run_id is None: + raise RuntimeError("Workflow run has not been initialized") + return self.run_id + + @staticmethod + def _read_json(path: Path) -> object: + """Read JSON files written by artifact materialization helpers.""" + return orjson.loads(path.read_bytes()) + + +def _run_from_row(row: sqlite3.Row) -> WorkflowRun: + return WorkflowRun( + workflow_name=row["workflow_name"], + run_id=row["run_id"], + status=RunStatus(row["status"]), + dag_hash=row["dag_hash"], + created_at=row["created_at"], + updated_at=row["updated_at"], + metadata=_json_loads(row["metadata_json"]), + ) + + +def _row_to_dict(row: sqlite3.Row) -> dict[str, Any]: + return {key: row[key] for key in row.keys()} + + +def _json_dumps(value: object) -> str: + return orjson.dumps(value, option=orjson.OPT_SORT_KEYS).decode("utf-8") + + +def _json_loads(value: str | bytes | None) -> dict[str, Any]: + if not value: + return {} + return orjson.loads(value) + + +def _now_json() -> str: + return _datetime_json(datetime.now(UTC)) + + +def _datetime_json(value: datetime) -> str: + return value.astimezone(UTC).isoformat() diff --git a/src/biomodals/workflow/core/nodes.py b/src/biomodals/workflow/core/nodes.py new file mode 100644 index 0000000..ced80b8 --- /dev/null +++ b/src/biomodals/workflow/core/nodes.py @@ -0,0 +1,57 @@ +"""Base node contracts for Biomodals workflows.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Protocol + +from biomodals.schema import ( + AppRunResult, + NodeExecutionPolicy, + NodePlacement, + WorkflowArtifact, +) + + +@dataclass(frozen=True) +class NodeRunContext: + """Runtime context passed to workflow node implementations.""" + + run_id: str + node_id: str + attempt_id: str + cache_dir: Path + inputs: dict[str, list[WorkflowArtifact]] + + +class WorkflowNode(Protocol): + """Protocol for one semantic workflow DAG vertex.""" + + execution_policy: NodeExecutionPolicy + placement: NodePlacement + + def run(self, context: NodeRunContext) -> AppRunResult: + """Execute node implementation and return app-compatible outputs.""" + + +class WorkflowNativeNode: + """Base class for workflow nodes implemented directly in workflow code.""" + + execution_policy = NodeExecutionPolicy.RERUN + placement = NodePlacement.ORCHESTRATOR + + def run(self, context: NodeRunContext) -> AppRunResult: + """Execute workflow-native logic.""" + raise NotImplementedError + + +class AppBackedNode: + """Base class for workflow nodes implemented by calling app functions.""" + + execution_policy = NodeExecutionPolicy.RERUN + placement = NodePlacement.REMOTE + + def run(self, context: NodeRunContext) -> AppRunResult: + """Execute the app-backed node implementation.""" + raise NotImplementedError diff --git a/src/biomodals/workflow/core/orchestrator.py b/src/biomodals/workflow/core/orchestrator.py new file mode 100644 index 0000000..c1e6169 --- /dev/null +++ b/src/biomodals/workflow/core/orchestrator.py @@ -0,0 +1,144 @@ +"""Workflow orchestrator helpers and Modal class boundary.""" + +from __future__ import annotations + +import os +from pathlib import Path + +import modal + +from biomodals.app.config import AppConfig +from biomodals.helper import patch_image_for_helper +from biomodals.helper.constant import ( + MAX_TIMEOUT, + WORKFLOW_ORCHESTRATOR_VOLUME, + WORKFLOW_ORCHESTRATOR_VOLUME_NAME, +) +from biomodals.schema import AppRunResult +from biomodals.workflow.core.builder import Workflow +from biomodals.workflow.core.nodes import NodeRunContext, WorkflowNode +from biomodals.workflow.core.runtime import RemoteFunctionCall, WorkflowRuntime + +CONF = AppConfig( + tags={"group": "workflow"}, + name="WorkflowOrchestrator", + package_name="biomodals-workflow-orchestrator", + version="0.1.0", + python_version="3.13", + timeout=int(os.environ.get("TIMEOUT", str(MAX_TIMEOUT))), +) +OUT_VOLUME = WORKFLOW_ORCHESTRATOR_VOLUME +OUT_VOLUME_NAME = WORKFLOW_ORCHESTRATOR_VOLUME_NAME +REMOTE_NODE_FUNCTION_NAME = "run_node" + +runtime_image = ( + modal.Image + .debian_slim(python_version=CONF.python_version) + .env(CONF.default_env) + .pipe(patch_image_for_helper, include_workflow_modules=True) +) +app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags) + + +@app.cls( + cpu=(1.125, 16.125), + memory=(1024, 65536), + timeout=MAX_TIMEOUT, + volumes={CONF.output_volume_mountpoint: OUT_VOLUME}, +) +class WorkflowOrchestrator: + """Modal-hosted coordinator for one workflow run.""" + + @modal.enter() + def enter(self) -> None: + """Refresh the workflow volume before serving orchestrator methods.""" + self._close_runtime() + self._exit_cleanup_done = False + OUT_VOLUME.reload() + + @modal.method() + def run( + self, + workflow: Workflow, + run_id: str, + force: bool = False, + max_ready_workers: int = 32, + ) -> AppRunResult: + """Run one workflow definition through the workflow runtime.""" + if not isinstance(workflow, Workflow): + raise TypeError("workflow must be a Workflow object") + + def remote_node_runner( + node: WorkflowNode, + context: NodeRunContext, + ) -> RemoteFunctionCall: + function_call = self.run_node.spawn(node, context) + if not hasattr(function_call, "object_id") or not hasattr( + function_call, "get" + ): + raise TypeError( + "Remote workflow node spawn did not return a FunctionCall" + ) + return function_call + + OUT_VOLUME.reload() + self._runtime = WorkflowRuntime( + workflow=workflow, + volume_root=Path(CONF.output_volume_mountpoint), + workflow_volume_name=OUT_VOLUME_NAME, + workflow_volume=OUT_VOLUME, + remote_node_runner=remote_node_runner, + remote_node_function_name=REMOTE_NODE_FUNCTION_NAME, + function_call_resolver=modal.FunctionCall.from_id, + max_ready_workers=max_ready_workers, + ) + try: + return self._runtime.run(run_id=run_id, force=force) + finally: + self._close_runtime() + OUT_VOLUME.commit() + + @modal.method() + def run_node( + self, + node: WorkflowNode, + context: NodeRunContext, + ) -> AppRunResult: + """Run one failure-isolated workflow node in a separate Modal method call.""" + OUT_VOLUME.reload() + try: + return node.run(context) + finally: + OUT_VOLUME.commit() + + @modal.exit() + def exit(self) -> None: + """Persist any pending workflow volume writes on container shutdown.""" + if not getattr(self, "_exit_cleanup_done", False): + self._exit_cleanup_done = True + runtime = getattr(self, "_runtime", None) + if runtime is not None: + cancel_active_remote_calls = getattr( + runtime, + "cancel_active_remote_calls", + None, + ) + if cancel_active_remote_calls is not None: + try: + cancel_active_remote_calls(terminate_containers=True) + except Exception as exc: # noqa: BLE001 + print( + "[workflow] Remote call cleanup failed during " + f"orchestrator exit: {exc}", + flush=True, + ) + self._close_runtime() + OUT_VOLUME.commit() + + def _close_runtime(self) -> None: + runtime = getattr(self, "_runtime", None) + if runtime is not None: + close = getattr(runtime, "close", None) + if close is not None: + close() + self._runtime = None diff --git a/src/biomodals/workflow/core/runtime.py b/src/biomodals/workflow/core/runtime.py new file mode 100644 index 0000000..bdf1341 --- /dev/null +++ b/src/biomodals/workflow/core/runtime.py @@ -0,0 +1,700 @@ +"""Local workflow runtime scheduler.""" + +from __future__ import annotations + +import hashlib +import traceback +from collections.abc import Callable, Mapping +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import fields, is_dataclass +from enum import Enum +from pathlib import Path +from threading import RLock +from typing import Protocol + +import orjson +from pydantic import BaseModel + +from biomodals.schema import ( + AppRunResult, + AppRunStatus, + ArtifactSelector, + AttemptRecord, + NodeExecutionPolicy, + NodePlacement, + NodeStatus, + RunStatus, + WorkflowArtifact, + WorkflowRun, +) +from biomodals.workflow.core.artifacts import materialize_app_run_result +from biomodals.workflow.core.builder import Workflow, WorkflowDefinition +from biomodals.workflow.core.ledger import WorkflowLedger +from biomodals.workflow.core.nodes import NodeRunContext, WorkflowNode + + +class RemoteFunctionCall(Protocol): + """Minimal Modal FunctionCall boundary used by remote workflow nodes.""" + + object_id: str + + def get(self, timeout: float | int | None = None) -> object: + """Return the remote function result or raise TimeoutError.""" + + +RemoteNodeRunner = Callable[ + [WorkflowNode, NodeRunContext], AppRunResult | RemoteFunctionCall +] +FunctionCallResolver = Callable[[str], RemoteFunctionCall] + + +class WorkflowVolume(Protocol): + """Minimal Modal Volume boundary used by the workflow runtime.""" + + def commit(self) -> object: + """Persist pending writes to the mounted volume.""" + + def reload(self) -> object: + """Refresh local view of writes made by other containers.""" + + +class WorkflowRuntime: + """Local runtime core for scheduling workflow nodes against a ledger.""" + + def __init__( + self, + *, + workflow: Workflow, + volume_root: str | Path, + workflow_volume_name: str, + workflow_volume: WorkflowVolume | None = None, + remote_node_runner: RemoteNodeRunner | None = None, + remote_node_function_name: str | None = None, + function_call_resolver: FunctionCallResolver | None = None, + remote_call_poll_timeout: float | int = 0, + max_ready_workers: int = 32, + ): + """Initialize a runtime for one workflow and ledger root.""" + self.workflow = workflow + self.volume_root = Path(volume_root) + self.workflow_volume_name = workflow_volume_name + self.workflow_volume = workflow_volume + self.remote_node_runner = remote_node_runner + self.remote_node_function_name = remote_node_function_name + self.function_call_resolver = function_call_resolver + self.remote_call_poll_timeout = remote_call_poll_timeout + self.max_ready_workers = max_ready_workers + self.ledger = WorkflowLedger(self.volume_root) + self.executed_waves: list[list[str]] = [] + self._active_remote_calls: dict[str, RemoteFunctionCall] = {} + self._active_remote_calls_lock = RLock() + + def run(self, *, run_id: str, force: bool = False) -> AppRunResult: + """Run the workflow until every node succeeds or no progress is possible.""" + definition = self.workflow.validate() + dag_hash = self._dag_hash(definition) + print( + f"[workflow] Starting workflow '{definition.name}' run '{run_id}' " + f"with {len(definition.nodes)} node(s)", + flush=True, + ) + print( + "[workflow] DAG graph: node_id [placement; class] <- dependency", flush=True + ) + for node_id, spec in definition.nodes.items(): + dependencies = sorted(definition.dependencies[node_id]) + dependency_text = ", ".join(dependencies) if dependencies else "-" + node_class = ( + f"{spec.node.__class__.__module__}.{spec.node.__class__.__qualname__}" + ) + print( + f"[workflow] {node_id} " + f"[{spec.node.placement.value}; {node_class}] <- {dependency_text}", + flush=True, + ) + self._reload_volume() + run_exists = self.ledger.run_exists(definition.name, run_id) + if run_exists and force: + self.ledger.reset_run(definition.name, run_id) + self._commit_volume() + self.ledger.create_run( + WorkflowRun( + workflow_name=definition.name, + run_id=run_id, + dag_hash=dag_hash, + ) + ) + self._commit_volume() + elif run_exists: + existing_run = self.ledger.load_run(definition.name, run_id) + if existing_run.dag_hash is not None and existing_run.dag_hash != dag_hash: + raise ValueError( + "DAG hash does not match existing workflow run; rerun with force" + ) + else: + self.ledger.create_run( + WorkflowRun( + workflow_name=definition.name, + run_id=run_id, + dag_hash=dag_hash, + ) + ) + self._commit_volume() + self.ledger.mark_run_status(RunStatus.RUNNING) + self._commit_volume() + + while True: + completed = self._completed_nodes(definition.nodes.keys()) + if len(completed) == len(definition.nodes): + self.ledger.mark_run_status(RunStatus.SUCCEEDED) + self._commit_volume() + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + ready = [ + node_id + for node_id, dependencies in definition.dependencies.items() + if node_id not in completed + and dependencies.issubset(completed) + and self._node_can_make_progress(node_id) + ] + if not ready: + running = [ + node_id + for node_id, dependencies in definition.dependencies.items() + if node_id not in completed + and dependencies.issubset(completed) + and self.ledger.node_is_running(node_id) + ] + if running: + self._commit_volume() + return AppRunResult( + status=AppRunStatus.PARTIAL, + warnings=[ + "Workflow has in-flight nodes without a recoverable " + f"remote call: {', '.join(sorted(running))}" + ], + ) + self.ledger.mark_run_status(RunStatus.FAILED) + self._commit_volume() + return AppRunResult( + status=AppRunStatus.FAILED, + warnings=["No runnable workflow nodes remain"], + ) + + self.executed_waves.append(ready) + for node_id, node_result in self._run_ready_nodes(ready): + if node_result.status in {AppRunStatus.FAILED, AppRunStatus.PARTIAL}: + error = self._node_error_message(node_result) + node_status = self.ledger.load_node_status(node_id) + if node_status.status != NodeStatus.FAILED or not node_status.error: + self.ledger.mark_node_failed(node_id, error) + self.ledger.mark_run_status(RunStatus.FAILED) + self._commit_volume() + return AppRunResult( + status=node_result.status, + warnings=node_result.warnings or [error], + ) + + def _completed_nodes(self, node_ids) -> set[str]: + return { + node_id for node_id in node_ids if self.ledger.node_is_complete(node_id) + } + + def _node_can_make_progress(self, node_id: str) -> bool: + if self.ledger.node_is_complete(node_id): + return False + if not self.ledger.node_is_running(node_id): + return True + return ( + self.ledger.latest_remote_call( + node_id, + statuses=("submitted", "running", "succeeded"), + ) + is not None + ) + + def _run_ready_nodes(self, node_ids: list[str]) -> list[tuple[str, AppRunResult]]: + results: list[tuple[str, AppRunResult]] = [] + max_workers = min(len(node_ids), max(1, self.max_ready_workers)) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(self._run_node, node_id): node_id + for node_id in node_ids + } + for future in as_completed(futures): + node_id = futures[future] + try: + results.append((node_id, future.result())) + except Exception as exc: # noqa: BLE001 + error = "".join( + traceback.format_exception(type(exc), exc, exc.__traceback__) + ) + print( + f"[workflow] Node failed: {node_id}: {exc}", + flush=True, + ) + self.ledger.mark_node_failed(node_id, error) + self._commit_volume() + results.append(( + node_id, + AppRunResult( + status=AppRunStatus.FAILED, + warnings=[str(exc)], + ), + )) + return results + + def _run_node(self, node_id: str) -> AppRunResult: + definition = self.workflow.validate() + spec = definition.nodes[node_id] + recovered = self._recover_remote_node_if_possible(node_id) + if recovered is not None: + return recovered + if ( + spec.node.execution_policy == NodeExecutionPolicy.RERUN + and self.ledger.node_has_state(node_id) + and not self.ledger.node_is_complete(node_id) + ): + self.ledger.reset_node(node_id) + self._commit_volume() + attempt_id = self._next_attempt_id(node_id) + inputs = self._resolve_inputs(spec.inputs) + input_artifact_ids = [ + artifact.artifact_id + for artifacts in inputs.values() + for artifact in artifacts + ] + self.ledger.mark_node_running( + node_id, + attempt_id, + input_artifact_ids=input_artifact_ids, + execution_policy=spec.node.execution_policy, + placement=spec.node.placement, + ) + print( + f"[workflow] Node started: {node_id} attempt={attempt_id} " + f"placement={spec.node.placement.value}", + flush=True, + ) + self.ledger.record_node_inputs(node_id, inputs) + attempt = self.ledger.record_attempt_started(node_id, attempt_id) + attempt_dir = self._attempt_dir(attempt) + cache_dir = self.ledger.run_root / "nodes" / node_id / "cache" + cache_dir.mkdir(parents=True, exist_ok=True) + self._commit_volume() + + result = self._dispatch_node( + spec.node, + NodeRunContext( + run_id=self.ledger.run_id or "", + node_id=node_id, + attempt_id=attempt_id, + cache_dir=cache_dir, + inputs=inputs, + ), + ) + return self._finalize_node_result( + node_id=node_id, + attempt_id=attempt_id, + attempt_dir=attempt_dir, + result=result, + ) + + def _finalize_node_result( + self, + *, + node_id: str, + attempt_id: str, + attempt_dir: Path, + result: AppRunResult, + ) -> AppRunResult: + self.ledger.record_app_result(node_id, attempt_id, result) + if result.status in {AppRunStatus.FAILED, AppRunStatus.PARTIAL}: + print( + f"[workflow] Node failed: {node_id} attempt={attempt_id}: " + f"{self._node_error_message(result)}", + flush=True, + ) + if result.logs: + log_artifacts = materialize_app_run_result( + result=AppRunResult(status=result.status, logs=result.logs), + workflow_volume_name=self.workflow_volume_name, + attempt_dir=attempt_dir, + artifact_dir=self.ledger.run_root / "artifacts", + producing_node_id=node_id, + volume_root=self.volume_root, + ) + self.ledger.record_artifacts(log_artifacts) + self.ledger.record_attempt_completed( + node_id, + attempt_id, + NodeStatus.FAILED, + result=result, + error=self._node_error_message(result), + ) + self._commit_volume() + return result + + artifacts = materialize_app_run_result( + result=result, + workflow_volume_name=self.workflow_volume_name, + attempt_dir=attempt_dir, + artifact_dir=self.ledger.run_root / "artifacts", + producing_node_id=node_id, + volume_root=self.volume_root, + ) + self.ledger.record_artifacts(artifacts) + self.ledger.mark_node_succeeded( + node_id, + [artifact.artifact_id for artifact in artifacts], + ) + self.ledger.record_attempt_completed( + node_id, + attempt_id, + NodeStatus.SUCCEEDED, + result=result, + ) + self._commit_volume() + print( + f"[workflow] Node succeeded: {node_id} attempt={attempt_id} " + f"artifacts={len(artifacts)}", + flush=True, + ) + return result + + def _resolve_inputs( + self, + selectors: dict[str, ArtifactSelector], + ) -> dict[str, list[WorkflowArtifact]]: + self._reload_volume() + return { + input_name: self.ledger.select_artifacts(selector) + for input_name, selector in selectors.items() + } + + def _dispatch_node( + self, node: WorkflowNode, context: NodeRunContext + ) -> AppRunResult: + if ( + node.placement == NodePlacement.REMOTE + and self.remote_node_runner is not None + ): + return self._run_remote_node(node, context) + return node.run(context) + + def _run_remote_node( + self, + node: WorkflowNode, + context: NodeRunContext, + ) -> AppRunResult: + if self.remote_node_runner is None: + return node.run(context) + + remote_result = self.remote_node_runner(node, context) + if isinstance(remote_result, AppRunResult): + self._reload_volume() + return remote_result + + call_id = str(remote_result.object_id) + with self._active_remote_calls_lock: + self._active_remote_calls[call_id] = remote_result + try: + self.ledger.record_remote_call( + call_id=call_id, + node_id=context.node_id, + attempt_id=context.attempt_id, + function_name=self._remote_function_name(node), + call_kind="node", + ) + self._commit_volume() + result = self._collect_remote_call(call_id, remote_result) + self._reload_volume() + return result + finally: + with self._active_remote_calls_lock: + self._active_remote_calls.pop(call_id, None) + + def cancel_active_remote_calls(self, *, terminate_containers: bool = True) -> None: + """Cancel Modal function calls spawned by this runtime instance.""" + with self._active_remote_calls_lock: + active_remote_calls = dict(self._active_remote_calls) + if not active_remote_calls: + return + + print( + "[workflow] Cancelling " + f"{len(active_remote_calls)} in-flight remote call(s)", + flush=True, + ) + for call_id, function_call in active_remote_calls.items(): + cancel = getattr(function_call, "cancel", None) + if cancel is None: + print( + f"[workflow] Remote call cannot be cancelled: {call_id}", + flush=True, + ) + continue + try: + cancel(terminate_containers=terminate_containers) + except Exception as exc: # noqa: BLE001 + print( + f"[workflow] Remote call cancellation failed: {call_id}: {exc}", + flush=True, + ) + continue + + print(f"[workflow] Remote call cancelled: {call_id}", flush=True) + try: + self.ledger.mark_remote_call_status( + call_id, + "cancelled", + completed=True, + ) + self._commit_volume() + except Exception as exc: # noqa: BLE001 + print( + "[workflow] Remote call cancellation status could not be " + f"recorded: {call_id}: {exc}", + flush=True, + ) + + def _recover_remote_node_if_possible(self, node_id: str) -> AppRunResult | None: + succeeded_call = self.ledger.latest_remote_call( + node_id, + statuses=("succeeded",), + ) + if succeeded_call is not None: + result = self.ledger.load_attempt_app_result( + node_id, + str(succeeded_call["attempt_id"]), + ) + if result is not None: + self._reload_volume() + return self._finalize_node_result( + node_id=node_id, + attempt_id=str(succeeded_call["attempt_id"]), + attempt_dir=self.ledger.run_root + / "nodes" + / node_id + / "attempts" + / str(succeeded_call["attempt_id"]), + result=result, + ) + + remote_call = self.ledger.latest_remote_call( + node_id, + statuses=("submitted", "running"), + ) + if remote_call is None: + return None + call_id = str(remote_call["call_id"]) + try: + function_call = self._resolve_function_call(call_id) + result = self._collect_remote_call(call_id, function_call) + except _RemoteCallExpired: + return None + self._reload_volume() + return self._finalize_node_result( + node_id=node_id, + attempt_id=str(remote_call["attempt_id"]), + attempt_dir=self.ledger.run_root + / "nodes" + / node_id + / "attempts" + / str(remote_call["attempt_id"]), + result=result, + ) + + def _collect_remote_call( + self, + call_id: str, + function_call: RemoteFunctionCall, + ) -> AppRunResult: + try: + try: + raw_result = function_call.get(timeout=self.remote_call_poll_timeout) + except TimeoutError: + self.ledger.mark_remote_call_status(call_id, "running") + self._commit_volume() + raw_result = function_call.get() + except Exception as exc: + self._record_remote_call_exception(call_id, exc) + raise + + try: + result = AppRunResult.model_validate(raw_result) + except Exception as exc: + self.ledger.mark_remote_call_status( + call_id, + "failed", + error=str(exc), + completed=True, + ) + self._commit_volume() + raise + + remote_call = self.ledger.load_remote_call(call_id) + if remote_call is not None: + self.ledger.record_app_result( + str(remote_call["node_id"]), + str(remote_call["attempt_id"]), + result, + ) + self._commit_volume() + self.ledger.mark_remote_call_status( + call_id, + "succeeded", + completed=True, + metadata={"result_status": result.status.value}, + ) + self._commit_volume() + return result + + def _record_remote_call_exception(self, call_id: str, exc: Exception) -> None: + if exc.__class__.__name__ == "OutputExpiredError": + self.ledger.mark_remote_call_status( + call_id, + "expired", + error=str(exc), + completed=True, + ) + self._commit_volume() + raise _RemoteCallExpired(str(exc)) from exc + self.ledger.mark_remote_call_status( + call_id, + "failed", + error=str(exc), + completed=True, + ) + self._commit_volume() + + def _resolve_function_call(self, call_id: str) -> RemoteFunctionCall: + if self.function_call_resolver is not None: + return self.function_call_resolver(call_id) + + import modal + + return modal.FunctionCall.from_id(call_id) + + def _remote_function_name(self, node: WorkflowNode) -> str: + if self.remote_node_function_name is not None: + return self.remote_node_function_name + if self.remote_node_runner is not None: + function_name = getattr(self.remote_node_runner, "function_name", None) + if function_name is not None: + return str(function_name) + node_name = f"{node.__class__.__module__}.{node.__class__.__qualname__}" + return node_name + + @staticmethod + def _node_error_message(result: AppRunResult) -> str: + if result.warnings: + return result.warnings[0] + if result.status == AppRunStatus.PARTIAL: + return "Node returned partial status" + return "Node returned failed status" + + @staticmethod + def _dag_hash(definition: WorkflowDefinition) -> str: + payload = { + "name": definition.name, + "nodes": { + node_id: WorkflowRuntime._node_hash_payload(spec.node) + | { + "inputs": { + input_name: selector.model_dump(mode="json") + for input_name, selector in sorted(spec.inputs.items()) + }, + "control_dependencies": sorted(spec.control_dependencies), + "dependencies": sorted(definition.dependencies[node_id]), + } + for node_id, spec in sorted(definition.nodes.items()) + }, + } + encoded = orjson.dumps( + payload, + option=orjson.OPT_SORT_KEYS, + ) + return hashlib.sha256(encoded).hexdigest() + + @staticmethod + def _node_hash_payload(node: WorkflowNode) -> dict[str, object]: + payload: dict[str, object] = { + "class": f"{node.__class__.__module__}.{node.__class__.__qualname__}", + "execution_policy": node.execution_policy.value, + "placement": node.placement.value, + } + if is_dataclass(node): + payload["dataclass"] = WorkflowRuntime._stable_json_value(node) + return payload + + @staticmethod + def _stable_json_value(value: object) -> object: + if isinstance(value, BaseModel): + return value.model_dump(mode="json", round_trip=True) + if isinstance(value, Enum): + return value.value + if isinstance(value, bytes): + return { + "bytes_sha256": hashlib.sha256(value).hexdigest(), + "size_bytes": len(value), + } + if isinstance(value, Path): + return value.as_posix() + if is_dataclass(value) and not isinstance(value, type): + return { + field.name: WorkflowRuntime._stable_json_value( + getattr(value, field.name) + ) + for field in fields(value) + if field.metadata.get("dag_hash") is not False + } + if isinstance(value, Mapping): + return { + str(key): WorkflowRuntime._stable_json_value(item) + for key, item in sorted(value.items(), key=lambda pair: str(pair[0])) + } + if isinstance(value, (list, tuple)): + return [WorkflowRuntime._stable_json_value(item) for item in value] + if isinstance(value, (set, frozenset)): + stable_items = [WorkflowRuntime._stable_json_value(item) for item in value] + return sorted( + stable_items, + key=lambda item: orjson.dumps(item, option=orjson.OPT_SORT_KEYS), + ) + if value is None or isinstance(value, (str, int, float, bool)): + return value + raise TypeError( + f"Unsupported DAG hash value type: {type(value).__module__}." + f"{type(value).__qualname__}" + ) + + def _next_attempt_id(self, node_id: str) -> str: + return self.ledger.next_attempt_id(node_id) + + def _attempt_dir(self, attempt: AttemptRecord) -> Path: + return ( + self.ledger.run_root + / "nodes" + / attempt.node_id + / "attempts" + / attempt.attempt_id + ) + + def _commit_volume(self) -> None: + if self.workflow_volume is not None: + with self.ledger.closed_for_volume_sync(): + self.workflow_volume.commit() + + def _reload_volume(self) -> None: + if self.workflow_volume is not None: + with self.ledger.closed_for_volume_sync(): + self.workflow_volume.reload() + + def close(self) -> None: + """Close durable local resources owned by the runtime.""" + self.ledger.close() + + +class _RemoteCallExpired(RuntimeError): + """Raised when Modal no longer has a recoverable function result.""" diff --git a/src/biomodals/workflow/ppiflow_workflow.py b/src/biomodals/workflow/ppiflow_workflow.py index 9db9b0e..40c72f8 100644 --- a/src/biomodals/workflow/ppiflow_workflow.py +++ b/src/biomodals/workflow/ppiflow_workflow.py @@ -1,1625 +1,555 @@ -"""PPIFlow source repo: . - -This file (`ppiflow_app.py`) is a **single Modal entrypoint** that routes to multiple upstream -PPIFlow sampling scripts (binder / antibody / nanobody / monomer / partial-flow variants), -while enforcing a **stable output layout** and **inference-safe config override**. - -## What this wrapper guarantees - -- **One CLI** (`--task`) for multiple PPIFlow scripts in `/ppiflow/sample_*.py`. -- **Role-based uploads**: local input files are uploaded with stable filenames (e.g. `binder_input.pdb`, - `antigen.pdb`, `framework.pdb`, `complex.pdb`, `motif.csv`) so the remote worker never guesses ordering. -- **Forced outputs** (remote side): - - `--output_dir` is forced to `/runs///outputs` - - `--name` is forced to `` -- **Effective config** (remote side): - - If a `--config` is provided, an `effective_config.yaml` is written under the run directory with: - `model.use_deepspeed_evo_attention = False` - - This makes inference **portable** (does not require deepspeed/nvcc kernels). - -## Configuration - -### Primary flags (local entrypoint) - -| Flag | Default | Description | -|------|---------|-------------| -| `--task` | `binder` | Task router. One of: `binder`, `antibody`, `nanobody`, `monomer`, `scaffolding`, `ab_partial_flow`, `nb_partial_flow`, `binder_partial_flow`, `mpnn_stage1`, `mpnn_stage2`. | -| `--run-name` | `test1` | Unique run identifier. Controls output directory name and tarball name (`..tar.gz`). | -| `--out-dir` | `./ppiflow_outputs` | Local directory to write the returned run bundle (`.tar.gz`). | -| `--model-weights` | **Required** | Local/remote path to a checkpoint. Remote resolves to `/models/` unless already under `/models/`. | -| `--config` | `None` | YAML config path (absolute in container or repo-relative). If provided, it will be rewritten to `effective_config.yaml` with deepspeed evo attention disabled. | - -### Input file flags (local -> uploaded to Modal) - -| Task | Required local file flags | Uploaded filename on worker | -|------|---------------------------|-----------------------------| -| `binder` | `--binder-input-pdb` | `binder_input.pdb` | -| `binder_partial_flow` | `--binder-input-pdb` | `binder_input.pdb` | -| `antibody` / `nanobody` | `--ab-antigen-pdb`, `--ab-framework-pdb` | `antigen.pdb`, `framework.pdb` | -| `ab_partial_flow` / `nb_partial_flow` | `--pf-complex-pdb` | `complex.pdb` | -| `scaffolding` | `--scaffold-motif-csv` | `motif.csv` | -| `monomer` | *(no file upload required)* | *(none)* | -| `mpnn_stage1` | *(no local upload; uses existing run)* | *(none; reads `/runs///outputs/*.pdb`)* | -| `mpnn_stage2` | *(no local upload; uses existing run)* | *(none; reads `/runs///outputs/*.pdb`)* | - -### Binder args (sample_binder.py) - -| Flag | Default | Description | -|------|---------|-------------| -| `--binder-target-chain` | `B` | Target chain ID passed to `--target_chain`. | -| `--binder-binder-chain` | `A` | Binder chain ID passed to `--binder_chain`. | -| `--binder-specified-hotspots` | `None` | Hotspots string, e.g. `"B119,B141,B200"`. | -| `--binder-samples-min-length` | `75` | Minimum binder length. | -| `--binder-samples-max-length` | `76` | Maximum binder length. | -| `--binder-samples-per-target` | `5` | Number of samples per target. | - -### Antibody / Nanobody args (sample_antibody_nanobody.py) - -| Flag | Default | Description | -|------|---------|-------------| -| `--ab-antigen-chain` | `None` | **Required** for `antibody/nanobody`. Passed to `--antigen_chain`. | -| `--ab-heavy-chain` | `None` | **Required** for `antibody/nanobody`. Passed to `--heavy_chain`. | -| `--ab-light-chain` | `None` | Optional light chain. Passed to `--light_chain` when provided. | -| `--ab-specified-hotspots` | `None` | Optional hotspot residues on antigen, e.g. `"A56,A58"`. | -| `--ab-cdr-length` | `None` | Optional CDR length override (string format per upstream script). | -| `--ab-samples-per-target` | `5` | Samples per target for antibody/nanobody. | - -### Monomer unconditional args (sample_monomer.py) - -| Flag | Default | Description | -|------|---------|-------------| -| `--mono-length-subset` | `None` | **Required** for `monomer`. String list, e.g. `"[60, 80, 100]"`. | -| `--mono-samples-num` | `5` | Number of unconditional samples. | - -### Scaffolding args (sample_monomer.py motif mode) - -| Flag | Default | Description | -|------|---------|-------------| -| `--scaffold-motif-names` | `None` | Optional motif name filter passed as `--motif_names`. | -| `--scaffold-samples-num` | `5` | Number of scaffolding samples. | - -### Partial flow antibody / nanobody args (sample_antibody_nanobody_partial.py) - -| Flag | Default | Description | -|------|---------|-------------| -| `--pf-fixed-positions` | `None` | **Required**. Fixed positions string, e.g. `"H26,H27,H28,L50-63"`. | -| `--pf-cdr-position` | `None` | **Required**. CDR ranges string, e.g. `"H26-32,H45-56,H97-113"`. | -| `--pf-start-t` | `None` | **Required**. Partial flow start time (float). | -| `--pf-samples-per-target` | `None` | **Required**. Samples per target. | -| `--pf-retry-limit` | `10` | Passed as `--retry_Limit` (upstream spelling). | -| `--pf-specified-hotspots` | `None` | Optional hotspots for partial flow. | -| `--pf-antigen-chain` | `None` | **Required**. Passed to `--antigen_chain`. | -| `--pf-heavy-chain` | `None` | **Required**. Passed to `--heavy_chain`. | -| `--pf-light-chain` | `None` | Optional. Passed to `--light_chain` when provided. | - -### Partial flow binder args (sample_binder_partial.py) - -| Flag | Default | Description | -|------|---------|-------------| -| `--bpf-target-chain` | `B` | Target chain passed to `--target_chain`. | -| `--bpf-binder-chain` | `A` | Binder chain passed to `--binder_chain`. | -| `--bpf-start-t` | `0.7` | Partial flow start time passed to `--start_t`. | - -### MPNN / ABMPNN tasks -For a complete set of protein_mpnn CLI options that can be used, see . - - -`mpnn_stage1` and `mpnn_stage2` both run sequence design on an existing backbone run: - -- Required source run flags: `--mpnn-source-task`, `--mpnn-source-run` -- Input backbones come from: `/runs///outputs/*.pdb` -- Model choice: `--mpnn-model-name` (use `abmpnn` to run ABMPNN weights from `/models/abmpnn.pt`) -- Runtime controls: `--mpnn-batch-size` (default `1`), `--mpnn-seed` (default `0`) -- Optional design constraints: - - `--mpnn-chain-list` (chains to design) - - `--mpnn-position-list` (CSV path for fixed positions; if omitted, design is unconstrained/full-chain) - - `--mpnn-omit-aas` (optional amino acids to omit, e.g. `C`) - - `--mpnn-use-soluble-model` (optional flag for ProteinMPNN soluble weights mode) - -Stage semantics in this wrapper: - -- `mpnn_stage1` (exploration): default `--mpnn-num-seq-per-target-stage1 8`, `--mpnn-temp-stage1 0.5` -- `mpnn_stage2` (conservative refinement): default `--mpnn-num-seq-per-target-stage2 4`, `--mpnn-temp-stage2 0.1` - -Current behavior note: - -- `mpnn_stage2` currently does not consume AF3-filtered outputs automatically; it is a second MPNN pass with different sampling settings. - -## Environment variables (Modal) - -| Environment variable | Default | Description | -|----------------------|---------|-------------| -| `MODAL_APP` | `ppiflow` | Name of the Modal app. | -| `GPU` | `L40S` | GPU type for the worker (e.g. `A10G`, `A100`, `L40S`). | -| `TIMEOUT` | `36000` | Modal function timeout (seconds). | - -## Persistent volumes & paths - -- Models volume: `ppiflow-models` mounted at `/models` -- Runs volume: `ppiflow-runs` mounted at `/ppiflow-runs` - -Expected checkpoint layout (one-time upload examples): - 1) ppiflow - modal volume put ppiflow-models antibody.ckpt /antibody.ckpt - modal volume put ppiflow-models binder.ckpt /binder.ckpt - modal volume put ppiflow-models monomer.ckpt /monomer.ckpt - modal volume put ppiflow-models nanobody.ckpt /nanobody.ckpt 2) proteinmpnn - 2) Upload ProteinMPNN weights to persistent Volume (one-time) - modal volume put ppiflow-models v_48_002.pt /proteinmpnn_v_48_002.pt - modal volume put ppiflow-models v_48_010.pt /proteinmpnn_v_48_010.pt - modal volume put ppiflow-models v_48_020.pt /proteinmpnn_v_48_020.pt - 3) abmpnn weights - modal volume put ppiflow-models abmpnn.pt /abmpnn.pt - -## Outputs - -- Each run is stored under the runs volume at: - `/runs///` - with: - - `inputs/` (uploaded inputs) - - `outputs/` (upstream script outputs; forced `--output_dir`) - - `effective_config.yaml` (if `--config` provided) - - `cmd.txt` (exact executed command) - - `stdout.log` (combined stdout/stderr) - - `artifacts/` (best-effort collected: metrics/config + any `.csv`) - -- The local CLI saves a `.tar.gz` bundle to: - `/..tar.gz` - -## Typical usage - - # Binder (de novo) - modal run ppiflow_app.py --task binder -- \ - --binder-input-pdb ~/target.pdb \ - --binder-target-chain B \ - --binder-binder-chain A \ - --binder-specified-hotspots "B119,B141,B200" \ - --binder-samples-min-length 75 \ - --binder-samples-max-length 76 \ - --binder-samples-per-target 5 \ - --config /ppiflow/configs/inference_binder.yaml \ - --model-weights /models/binder.ckpt \ - --run-name test1 \ - --out-dir ./ppiflow_outputs - - # Antibody partial flow - modal run ppiflow_app.py --task ab_partial_flow -- \ - --pf-complex-pdb ~/complex.pdb \ - --pf-fixed-positions "H26,H27,H28,L50-63" \ - --pf-cdr-position "H26-32,H45-56,H97-113" \ - --pf-start-t 0.8 \ - --pf-samples-per-target 5 \ - --pf-antigen-chain A \ - --pf-heavy-chain H \ - --pf-light-chain L \ - --model-weights /models/antibody.ckpt \ - --run-name abp1 - - # MPNN stage1 (exploration) on an existing binder run - modal run ppiflow_app.py --task mpnn_stage1 -- \ - --mpnn-source-task binder \ - --mpnn-source-run test1 \ - --mpnn-model-name v_48_020 \ - --mpnn-num-seq-per-target-stage1 8 \ - --mpnn-temp-stage1 0.5 \ - --run-name mpnn_test1\ - --mpnn-batch-size 1 \ - --mpnn-seed 0 - - - # MPNN stage2 (conservative refinement) on the same source run - modal run ppiflow_app.py --task mpnn_stage2 -- \ - --mpnn-source-task binder \ - --mpnn-source-run test1 \ - --mpnn-model-name v_48_020 \ - --mpnn-num-seq-per-target-stage2 4 \ - --mpnn-temp-stage2 0.1 \ - --run-name test1 \ - --mpnn-batch-size 1 \ - --mpnn-seed 0 - - # ABMPNN with fixed framework positions (design only selected residues/chains) - modal run ppiflow_app.py --task mpnn_stage1 -- \ - --mpnn-source-task nanobody \ - --mpnn-source-run nb1 \ - --mpnn-model-name abmpnn \ - --mpnn-chain-list A \ - --mpnn-position-list ./fixed_positions.csv \ - --mpnn-num-seq-per-target-stage1 8 \ - --mpnn-temp-stage1 0.5 - - # ABMPNN full design (no fixed positions file provided) - modal run ppiflow_app.py --task mpnn_stage1 -- \ - --mpnn-source-task nanobody \ - --mpnn-source-run nb1 \ - --mpnn-model-name abmpnn \ - --mpnn-num-seq-per-target-stage1 8 \ - --mpnn-temp-stage1 0.5 - -""" - -# TODO: reuse *_app modules for constructing the workflow +"""PPIFlow workflow definition built on the reusable workflow runtime.""" from __future__ import annotations -import csv -import json import os -import subprocess -import sys -import tarfile -import tempfile -from collections.abc import Iterable -from io import StringIO +from copy import deepcopy +from dataclasses import dataclass, field from pathlib import Path from typing import Any -from modal import App, Image, Volume - -# ------------------------- -# Modal configs -# ------------------------- -APP_NAME = os.environ.get("MODAL_APP", "ppiflow") -GPU = os.environ.get("GPU", "L40S") # e.g. A10G, A100, L40S -TIMEOUT = int(os.environ.get("TIMEOUT", "36000")) - -# Persistent Volumes -MODELS_VOL = Volume.from_name("ppiflow-models", create_if_missing=True) -RUNS_VOL = Volume.from_name("ppiflow-runs", create_if_missing=True) - -MODELS_DIR = Path("/models") -RUNS_DIR = Path("/ppiflow-runs") +import modal +import yaml + +from biomodals.app.design import ppiflow_app +from biomodals.helper import patch_image_for_helper +from biomodals.helper.catalog import include_dependency_apps +from biomodals.helper.constant import MAX_TIMEOUT +from biomodals.helper.shell import sanitize_filename +from biomodals.helper.volume_run import volume_path_from_mount_path +from biomodals.schema import ( + AppConfig, + AppRunResult, + ArtifactKind, + NodeExecutionPolicy, + NodePlacement, +) +from biomodals.workflow.core import ( + AppBackedNode, + NodeRunContext, + Workflow, + WorkflowNativeNode, + orchestrator, +) -# ------------------------- -# Image definition -# ------------------------- -PPIFLOW_REPO = ( - "https://github.com/zhuqianhui2-hash/PPIFlow.git" # updated at 2026-02-10-18:00 +PPI_FLOW_OUTPUT_LAYOUT = ( + "stage1/", + "stage2/", + "design_output/", + "design_output/ranked_designs.csv", + "design_output/design_report.md", +) +PPI_FLOW_APP_STEPS = ("PPIFlowStep", "PartialStep") + +DEPENDENCY_APPS = ("ppiflow",) +CONF = AppConfig( + tags={"depends_on": ",".join(DEPENDENCY_APPS)}, + depends_on_apps=DEPENDENCY_APPS, + name="PPIFlowWorkflow", + package_name="biomodals-ppiflow-workflow", + version="0.1.0", + python_version="3.13", + timeout=int(os.environ.get("TIMEOUT", str(MAX_TIMEOUT))), ) -PPIFLOW_DIR = "/ppiflow" - -PYTORCH_CU121_INDEX = "https://download.pytorch.org/whl/cu121" -PYG_WHL = "https://data.pyg.org/whl/torch-2.3.0+cu121.html" - -TORCH_PKGS = [ - "torch==2.3.1+cu121", - "torchvision==0.18.1+cu121", - "torchaudio==2.3.1+cu121", -] - -PYG_PKGS = [ - "pyg-lib==0.4.0+pt23cu121", - "torch-scatter==2.1.2+pt23cu121", - "torch-sparse==0.6.18+pt23cu121", - "torch-cluster==1.6.3+pt23cu121", - "torch-spline-conv==1.2.2+pt23cu121", - "torch-geometric==2.6.1", -] - -INFER_PKGS = [ - "numpy==1.26.3", - "scipy==1.15.2", - "pandas==2.2.3", - "scikit-learn==1.2.2", - "pyyaml==6.0.2", - "omegaconf==2.3.0", - "hydra-core==1.3.2", - "hydra-submitit-launcher==1.2.0", - "submitit==1.5.3", - "tqdm==4.67.1", - "lightning==2.5.0.post0", - "pytorch-lightning==2.5.0.post0", - "torchmetrics==1.6.2", - "lightning-utilities==0.14.0", - "einops==0.8.1", - "dm-tree==0.1.6", - "optree==0.14.1", - "opt-einsum==3.4.0", - "opt-einsum-fx==0.1.4", - "e3nn==0.5.6", - "fair-esm==2.0.0", - "biopython==1.83", - "biotite==1.0.1", - "biotraj==1.2.2", - "gemmi==0.6.5", - "ihm==2.2", - "modelcif==0.7", - "tmtools==0.2.0", - "freesasa==2.2.1", - "mdtraj==1.10.3", - "requests==2.32.3", - "packaging==24.2", - "typing-extensions==4.12.2", - "protobuf==3.20.2", - "tensorboard==2.19.0", - "tensorboard-data-server==0.7.2", - "grpcio==1.72.1", - "gputil==1.4.0", - "gpustat==1.1.1", - "hjson==3.1.0", - "ninja==1.11.1.3", -] runtime_image = ( - Image.debian_slim(python_version="3.10") - .apt_install( - "git", - "curl", - "ca-certificates", - "build-essential", - "python3-dev", - "pkg-config", - "gfortran", - "libopenblas-dev", - "liblapack-dev", - "libhdf5-dev", - "libnetcdf-dev", - "zlib1g-dev", - "libbz2-dev", - "liblzma-dev", - ) - .env({"PYTHONUNBUFFERED": "1", "PYTHONPATH": PPIFLOW_DIR}) - .run_commands( - f"rm -rf {PPIFLOW_DIR} && git clone --depth 1 {PPIFLOW_REPO} {PPIFLOW_DIR}" - ) - .pip_install(*TORCH_PKGS, extra_index_url=PYTORCH_CU121_INDEX) - .uv_pip_install(*PYG_PKGS, find_links=PYG_WHL) - .uv_pip_install(*INFER_PKGS) + modal.Image + .debian_slim(python_version=CONF.python_version) + .env(CONF.default_env) + .pipe(patch_image_for_helper, include_workflow_modules=True) +) +app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags).include( + orchestrator.app, inherit_tags=True ) +app = include_dependency_apps(app, CONF.depends_on_apps) -app = App(APP_NAME) - -# ------------------------- -# Task routing (PPIFlow sampling scripts only) -# ------------------------- -TASK_TO_SCRIPT: dict[str, str] = { - "binder": f"{PPIFLOW_DIR}/sample_binder.py", - "antibody": f"{PPIFLOW_DIR}/sample_antibody_nanobody.py", - "nanobody": f"{PPIFLOW_DIR}/sample_antibody_nanobody.py", - "monomer": f"{PPIFLOW_DIR}/sample_monomer.py", - "scaffolding": f"{PPIFLOW_DIR}/sample_monomer.py", - "ab_partial_flow": f"{PPIFLOW_DIR}/sample_antibody_nanobody_partial.py", - "nb_partial_flow": f"{PPIFLOW_DIR}/sample_antibody_nanobody_partial.py", - "binder_partial_flow": f"{PPIFLOW_DIR}/sample_binder_partial.py", -} - -MPNN_TASKS = {"mpnn_stage1", "mpnn_stage2"} - - -# ------------------------- -# Helpers -# ------------------------- -def _tar_dir(src_dir: Path, out_tar_gz: Path) -> None: - with tarfile.open(out_tar_gz, "w:gz") as tf: - tf.add(src_dir, arcname=src_dir.name) - - -def _iter_files(root: Path) -> Iterable[Path]: - for p in root.rglob("*"): - if p.is_file(): - yield p - - -def _write_effective_config(src_config: Path, dst_config: Path) -> None: - import yaml - - cfg = yaml.safe_load(src_config.read_text()) or {} - cfg.setdefault("model", {}) - cfg["model"]["use_deepspeed_evo_attention"] = False - dst_config.write_text(yaml.safe_dump(cfg, sort_keys=False)) - - -def _collect_artifacts(run_dir: Path) -> None: - artifacts = run_dir / "artifacts" - artifacts.mkdir(parents=True, exist_ok=True) - - want_exact = {"metrics.csv", "config.yml", "config.yaml"} - for f in _iter_files(run_dir): - if f.name in want_exact: - dst = artifacts / f.name - if not dst.exists(): - dst.write_bytes(f.read_bytes()) - - for f in _iter_files(run_dir): - if f.suffix.lower() == ".csv": - dst = artifacts / f.name - if not dst.exists(): - dst.write_bytes(f.read_bytes()) - - -def _script_help_text(script: Path) -> str: - p = subprocess.run( # noqa: S603 - [sys.executable, str(script), "--help"], - check=False, - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - return p.stdout or "" +@dataclass(frozen=True) +class PPIFlowModalNamespace: + """Hydrated Modal objects carried across the orchestrator boundary.""" -def _script_supports_flag( - script: Path, flag: str, help_text: str | None = None -) -> bool: - txt = help_text if help_text is not None else _script_help_text(script) - return flag in txt + ppiflow_run: modal.Function -def _find_abmpnn_run_script() -> Path: - """Find ABMPNN/ProteinMPNN runnable script under /ppiflow/ProteinMPNN. +@dataclass +class PPIFlowWorkflowNode(AppBackedNode): + """Base class for PPIFlow v2 app-backed workflow nodes.""" - We try common names, then fallback to a glob search. - """ - base = Path(PPIFLOW_DIR) / "ProteinMPNN" - candidates = [ - base / "protein_mpnn_run.py", - base / "run.py", - base / "protein_mpnn_run_main.py", - Path(PPIFLOW_DIR) / "protein_mpnn_run.py", - ] - for p in candidates: - if p.exists(): - return p - - hits = sorted(base.rglob("*mpnn*run*.py")) - if hits: - return hits[0] - - raise FileNotFoundError( - f"ABMPNN run script not found under {base}. " - "Expected something like ProteinMPNN/protein_mpnn_run.py" + step_name: str + modal_namespace: PPIFlowModalNamespace = field( + repr=False, + compare=False, + metadata={"dag_hash": False}, ) + config: dict[str, Any] = field(default_factory=dict) + execution_policy: NodeExecutionPolicy = NodeExecutionPolicy.RERUN + placement: NodePlacement = NodePlacement.REMOTE + + def run(self, context: NodeRunContext) -> AppRunResult: + """Run a workflow-compatible PPIFlow app step.""" + if self.step_name not in PPI_FLOW_APP_STEPS: + raise NotImplementedError( + f"PPIFlow workflow step {self.step_name!r} does not yet have a " + "workflow-compatible app adapter." + ) + raw_args = self.config.get("args", self.config) + if not isinstance(raw_args, dict): + raise ValueError(f"PPIFlow step {self.step_name!r} args must be a mapping") -def _resolve_abmpnn_ckpt(user_path: str | None) -> Path: - if user_path: - p = Path(user_path) - else: - p = Path(PPIFLOW_DIR) / "ProteinMPNN" / "model_weights" / "abmpnn.pt" - if not p.exists(): - raise FileNotFoundError(f"ABMPNN checkpoint not found: {p}") - return p - - -def _write_csv(path: Path, rows: list[dict[str, Any]], fieldnames: list[str]) -> None: - buf = StringIO() - w = csv.DictWriter(buf, fieldnames=fieldnames) - w.writeheader() - for r in rows: - w.writerow(r) - path.write_text(buf.getvalue()) - - -def _expand_position_token(token: str) -> list[int]: - token = token.strip() - if not token: - return [] - if "-" in token: - a, b = token.split("-", 1) - start = int(a.strip()) - end = int(b.strip()) - if end < start: - raise ValueError(f"Invalid range: {token}") - return list(range(start, end + 1)) - return [int(token)] - - -def _normalize_fixed_positions_csv_bytes(raw_csv: bytes) -> bytes: - """Accept user-friendly fixed-position syntax and convert to ProteinMPNN style. - - Supported input for the second column (e.g. motif_index / fixed_positions): - - "1 2 3 10 11" - - "1-25,34-50,59-96,118-127" - - Output second column is normalized to: - - "1 2 3 ... 127" - """ - text = raw_csv.decode("utf-8-sig") - reader = csv.DictReader(StringIO(text)) - if not reader.fieldnames: - raise ValueError("mpnn position CSV has no header") - - fns = list(reader.fieldnames) - lower_map = {f.lower(): f for f in fns} - name_col = lower_map.get("pdb_name") or lower_map.get("pdb_id") - pos_col = lower_map.get("motif_index") or lower_map.get("fixed_positions") - if not name_col or not pos_col: - raise ValueError( - "mpnn position CSV must contain (pdb_name or pdb_id) and " - "(motif_index or fixed_positions) columns" + run_name = sanitize_filename( + str(self.config.get("run_name") or f"{context.run_id}-{self.step_name}") + ) + app_args = ppiflow_app.PPIFlowArgs.model_validate({"args": raw_args}) + return AppRunResult.model_validate( + self.modal_namespace.ppiflow_run.remote( + args=app_args, + run_name=run_name, + ) ) - rows = list(reader) - if not rows: - raise ValueError("mpnn position CSV has no data rows") - - out_rows: list[dict[str, str]] = [] - for r in rows: - pdb_name = (r.get(name_col) or "").strip() - if not pdb_name: - raise ValueError(f"Empty pdb name in row: {r}") - raw = (r.get(pos_col) or "").strip() - if not raw: - out = "" - else: - pieces = [p.strip() for p in raw.split(",") if p.strip()] - values: list[int] = [] - for p in pieces: - values.extend(_expand_position_token(p)) - vals = sorted(set(values)) - out = " ".join(str(v) for v in vals) - out_rows.append({"pdb_name": pdb_name, "motif_index": out}) - - buf = StringIO() - writer = csv.DictWriter(buf, fieldnames=["pdb_name", "motif_index"]) - writer.writeheader() - writer.writerows(out_rows) - return buf.getvalue().encode("utf-8") - - -def _chain_order_from_pdb(pdb_path: Path) -> list[str]: - order: list[str] = [] - seen: set[str] = set() - with pdb_path.open("r", encoding="utf-8", errors="ignore") as f: - for line in f: - if not (line.startswith("ATOM") or line.startswith("HETATM")): - continue - if len(line) < 22: - continue - chain_id = line[21].strip() - if not chain_id or chain_id in seen: - continue - seen.add(chain_id) - order.append(chain_id) - return order +@dataclass +class FilterStructuresNode(WorkflowNativeNode): + """Filter structures using score artifacts.""" -def _rewrite_fixed_positions_for_proteinmpnn( - csv_path: Path, pdb_folder: Path, chain_list: str -) -> None: - """Generate fixed-position CSV for ProteinMPNN. + step_name: str + config: dict[str, Any] = field(default_factory=dict) - Convert normalized CSV (single position list per pdb) into upstream legacy format - expected by ProteinMPNN helper script: - second column uses '-' to separate per-chain residue lists in pdb chain order. - """ - designed_chains = [c for c in chain_list.split() if c] - if not designed_chains: - raise ValueError( - "--mpnn-chain-list is required when --mpnn-position-list is provided" - ) - if len(designed_chains) != 1: - raise ValueError( - "Current fixed-position CSV format supports one designed chain. " - f"Got --mpnn-chain-list={chain_list!r}" - ) + def run(self, context: NodeRunContext) -> AppRunResult: + """Execute filtering logic.""" + raise NotImplementedError - text = csv_path.read_text(encoding="utf-8-sig") - reader = csv.DictReader(StringIO(text)) - if not reader.fieldnames: - raise ValueError("mpnn_fixed_positions.csv has no header") - rows = list(reader) - if not rows: - raise ValueError("mpnn_fixed_positions.csv has no data rows") - - out_rows: list[dict[str, str]] = [] - target_chain = designed_chains[0] - for row in rows: - pdb_name = (row.get("pdb_name") or row.get("pdb_id") or "").strip() - if not pdb_name: - raise ValueError(f"Empty pdb_name/pdb_id in row: {row}") - pos_raw = (row.get("motif_index") or row.get("fixed_positions") or "").strip() - - pdb_file = pdb_folder / f"{pdb_name}.pdb" - if not pdb_file.exists(): - raise FileNotFoundError( - f"PDB not found for fixed positions row: {pdb_file}" - ) - chains = _chain_order_from_pdb(pdb_file) - if not chains: - raise ValueError(f"Could not detect chain order from: {pdb_file}") - if target_chain not in chains: - raise ValueError( - f"Chain {target_chain!r} not found in {pdb_file.name}; chains={chains}" - ) - segments: list[str] = [] - for ch in chains: - segments.append(pos_raw if ch == target_chain else "") - out_rows.append({"pdb_name": pdb_name, "motif_index": "-".join(segments)}) +@dataclass +class RankAndReportNode(WorkflowNativeNode): + """Rank final designs and write report artifacts.""" - buf = StringIO() - writer = csv.DictWriter(buf, fieldnames=["pdb_name", "motif_index"]) - writer.writeheader() - writer.writerows(out_rows) - csv_path.write_text(buf.getvalue(), encoding="utf-8") + step_name: str + config: dict[str, Any] = field(default_factory=dict) + def run(self, context: NodeRunContext) -> AppRunResult: + """Execute ranking and report logic.""" + raise NotImplementedError -def _write_empty_fixed_positions_csv(csv_path: Path, pdb_folder: Path) -> None: - """Write a no-op fixed-positions CSV for upstream protein_mpnn_run.py compatibility. - This avoids its UnboundLocalError when --position_list is omitted. - """ - rows: list[dict[str, str]] = [] - for pdb in sorted(pdb_folder.glob("*.pdb")): - rows.append({"pdb_name": pdb.stem, "motif_index": ""}) - if not rows: - raise FileNotFoundError(f"No .pdb found under: {pdb_folder}") - - buf = StringIO() - writer = csv.DictWriter(buf, fieldnames=["pdb_name", "motif_index"]) - writer.writeheader() - writer.writerows(rows) - csv_path.write_text(buf.getvalue(), encoding="utf-8") - - -def _collect_pdbs_for_mpnn(src_outputs: Path) -> list[Path]: - pdbs = sorted(src_outputs.glob("*.pdb")) - if not pdbs: - raise FileNotFoundError(f"No .pdb found under: {src_outputs}") - return pdbs - - -def _gather_mpnn_fastas(mpnn_out_root: Path) -> list[Path]: - # accept .fa / .fasta / .faa etc. under seqs/ - hits: list[Path] = [] - for pat in ( - "seqs/*.fa", - "seqs/*.fasta", - "seqs/*.faa", - "seqs/*.fa.gz", - "seqs/*.fasta.gz", - ): - hits.extend(mpnn_out_root.rglob(pat)) - return sorted(set(hits)) +def build_ppiflow_workflow( + *, + task_yaml_bytes: bytes, + steps_yaml_bytes: bytes, + stage: int | None = None, + modal_namespace: PPIFlowModalNamespace | None = None, +) -> Workflow: + """Build a PPIFlow workflow DAG from upstream-style YAML files.""" + if stage not in {None, 1, 2}: + raise ValueError("stage must be omitted, 1, or 2") + if modal_namespace is None: + modal_namespace = PPIFlowModalNamespace( + ppiflow_run=ppiflow_app.ppiflow_run_workflow, + ) + task_doc = _load_yaml_bytes(task_yaml_bytes) + steps_doc = _load_yaml_bytes(steps_yaml_bytes) + task = _task_section(task_doc) + enabled = _enabled_section(task_doc) + gentype = str(task.get("gentype") or task.get("design_mode") or "binder") + workflow = Workflow("ppiflow-v2") + + stage1_tail = None + if stage in {None, 1}: + stage1_tail = _add_stage1_nodes( + workflow=workflow, + enabled=enabled, + steps=steps_doc, + gentype=gentype, + modal_namespace=modal_namespace, + ) -def _detect_abmpnn_cli_flags(mpnn_script: Path) -> dict[str, str]: - ht = _script_help_text(mpnn_script) + if stage in {None, 2}: + _add_stage2_nodes( + workflow=workflow, + enabled=enabled, + steps=steps_doc, + gentype=gentype, + upstream=stage1_tail if stage is None else None, + modal_namespace=modal_namespace, + ) - # folder input (this is required for ProteinMPNN/protein_mpnn_run.py) - folder_flag = ( - "--folder_with_pdbs_path" - if _script_supports_flag(mpnn_script, "--folder_with_pdbs_path", ht) - else "" - ) - if not folder_flag and _script_supports_flag(mpnn_script, "--pdb_dir", ht): - folder_flag = "--pdb_dir" - if not folder_flag and _script_supports_flag(mpnn_script, "--input_folder", ht): - folder_flag = "--input_folder" - - # single pdb input (optional, many scripts do NOT support it) - pdb_flag = ( - "--pdb_path" if _script_supports_flag(mpnn_script, "--pdb_path", ht) else "" - ) - if not pdb_flag and _script_supports_flag(mpnn_script, "--input_pdb", ht): - pdb_flag = "--input_pdb" + return workflow - # out dir - out_flag = ( - "--out_folder" if _script_supports_flag(mpnn_script, "--out_folder", ht) else "" - ) - if not out_flag and _script_supports_flag(mpnn_script, "--output_dir", ht): - out_flag = "--output_dir" - if not out_flag and _script_supports_flag(mpnn_script, "--out_dir", ht): - out_flag = "--out_dir" - - # checkpoint (ABMPNN single file OR ProteinMPNN model dir system) - ckpt_candidates = [ - "--checkpoint_path", - "--checkpoint", - "--model_path", - "--weights", - "--ckpt", - "--checkpoint_file", - ] - ckpt_flag = "" - for f in ckpt_candidates: - if _script_supports_flag(mpnn_script, f, ht): - ckpt_flag = f - break - - # ProteinMPNN classic interface (dir + model_name) - weights_dir_flag = ( - "--path_to_model_weights" - if _script_supports_flag(mpnn_script, "--path_to_model_weights", ht) - else "" - ) - model_name_flag = ( - "--model_name" if _script_supports_flag(mpnn_script, "--model_name", ht) else "" - ) - nseq_flag = ( - "--num_seq_per_target" - if _script_supports_flag(mpnn_script, "--num_seq_per_target", ht) - else "" - ) - if not nseq_flag and _script_supports_flag(mpnn_script, "--num_seqs", ht): - nseq_flag = "--num_seqs" +def _add_stage1_nodes( + *, + workflow: Workflow, + enabled: dict[str, bool], + steps: dict[str, Any], + gentype: str, + modal_namespace: PPIFlowModalNamespace, +): + tail = None + if _step_enabled(enabled, "PPIFlowStep"): + tail = workflow.add_node( + _app_step_node(steps, "PPIFlowStep", modal_namespace), + id="stage1-ppiflow-design", + ) - temp_flag = ( - "--sampling_temp" - if _script_supports_flag(mpnn_script, "--sampling_temp", ht) - else "" - ) - seed_flag = "--seed" if _script_supports_flag(mpnn_script, "--seed", ht) else "" - batch_flag = ( - "--batch_size" if _script_supports_flag(mpnn_script, "--batch_size", ht) else "" - ) + mpnn_step = None + if gentype == "binder" and _step_enabled(enabled, "MPNNStep_stage1"): + mpnn_step = ("stage1-ligandmpnn", "MPNNStep_stage1") + elif gentype in {"antibody", "nanobody"} and _step_enabled( + enabled, "AbMPNNStep_stage1" + ): + mpnn_step = ("stage1-abmpnn", "AbMPNNStep_stage1") + if mpnn_step is not None: + node_id, step_name = mpnn_step + tail = workflow.add_node( + _app_step_node(steps, step_name, modal_namespace), + id=node_id, + inputs=_structure_inputs(tail), + ) - if not out_flag: - raise RuntimeError( - f"Cannot detect output flag for {mpnn_script}. " - "Expected one of: --out_folder / --output_dir / --out_dir" + if _step_enabled(enabled, "FlowpackerStep_stage1"): + tail = workflow.add_node( + _app_step_node(steps, "FlowpackerStep_stage1", modal_namespace), + id="stage1-flowpacker", + inputs=_structure_inputs(tail), ) - # For protein_mpnn_run.py, folder_flag is required - if "protein_mpnn_run.py" in str(mpnn_script) and not folder_flag: - raise RuntimeError( - f"{mpnn_script} appears to require folder input, but folder flag not detected. " - "Expected --folder_with_pdbs_path." + score = None + if _step_enabled(enabled, "AF3scoreStep_stage1"): + score = workflow.add_node( + _app_step_node(steps, "AF3scoreStep_stage1", modal_namespace), + id="stage1-af3score", + inputs=_structure_inputs(tail), ) - return { - "folder_flag": folder_flag, - "pdb_flag": pdb_flag, - "out_flag": out_flag, - "ckpt_flag": ckpt_flag, - "weights_dir_flag": weights_dir_flag, - "model_name_flag": model_name_flag, - "nseq_flag": nseq_flag, - "temp_flag": temp_flag, - "seed_flag": seed_flag, - "batch_flag": batch_flag, - } + if _step_enabled(enabled, "FilterStep_stage1"): + inputs = _structure_inputs(tail) + if score is not None: + inputs["scores"] = score.outputs(kind=ArtifactKind.SCORES) + tail = workflow.add_node( + FilterStructuresNode( + "FilterStep_stage1", + _step_cfg(steps, "FilterStep_stage1"), + ), + id="stage1-filter", + inputs=inputs, + ) + return tail -def _run_abmpnn_on_folder( +def _add_stage2_nodes( *, - mpnn_script: Path, - flags: dict[str, str], - pdb_folder: Path, - out_folder: Path, - ckpt_path: Path | None, - num_seqs: int, - sampling_temp: float, - seed: int, - batch_size: int, - model_name: str | None, - chain_list: str | None, - position_list_csv: Path | None, - omit_aas: str | None, - use_soluble_model: bool, - log_path: Path, + workflow: Workflow, + enabled: dict[str, bool], + steps: dict[str, Any], + gentype: str, + upstream, + modal_namespace: PPIFlowModalNamespace, ) -> None: - """Run ProteinMPNN / ABMPNN script in folder mode. + tail = upstream + if _step_enabled(enabled, "RosettaFixStep"): + tail = workflow.add_node( + _app_step_node(steps, "RosettaFixStep", modal_namespace), + id="stage2-rosetta-fix", + inputs=_structure_inputs(tail), + ) - This is intentionally "no patch mode": execute upstream script as-is. - """ - out_folder.mkdir(parents=True, exist_ok=True) - - run_script = mpnn_script - patch_dir = mpnn_script.parent # e.g. /ppiflow/ProteinMPNN - - # ------------------------- - # Step 1: build argv - # ------------------------- - argv: list[str] = ["python", str(run_script)] - - folder_flag = flags.get("folder_flag") or "" - pdb_flag = flags.get("pdb_flag") or "" - if folder_flag: - argv += [folder_flag, str(pdb_folder)] - elif pdb_flag: - raise RuntimeError( - "MPNN runner is configured for folder mode, but folder_flag is missing and pdb_flag exists. " - "Implement a single-PDB runner or ensure folder_flag is detected correctly." + if _step_enabled(enabled, "PartialStep"): + tail = workflow.add_node( + _app_step_node(steps, "PartialStep", modal_namespace), + id="stage2-partial-ppiflow", + inputs=_structure_inputs(tail), ) - else: - raise RuntimeError("Neither folder_flag nor pdb_flag detected for MPNN script.") - - out_flag = flags.get("out_flag") or "" - if not out_flag: - raise RuntimeError("No out_flag detected for MPNN script.") - argv += [out_flag, str(out_folder)] - - ckpt_flag = flags.get("ckpt_flag") or "" - if ckpt_flag: - if ckpt_path is None: - raise ValueError( - f"MPNN script expects {ckpt_flag}, but no checkpoint path was resolved." - ) - argv += [ckpt_flag, str(ckpt_path)] - else: - weights_dir_flag = flags.get("weights_dir_flag") or "" - model_name_flag = flags.get("model_name_flag") or "" - if not weights_dir_flag or not model_name_flag: - raise RuntimeError( - "MPNN script does not expose a checkpoint flag and does not expose " - "--path_to_model_weights/--model_name. Cannot pass weights." - ) - weights_dir = Path("/models") - argv += [weights_dir_flag, str(weights_dir)] - if not model_name: - raise ValueError( - "mpnn_model_name is required for this ProteinMPNN script (--model_name). " - "You passed None/empty." - ) - argv += [model_name_flag, str(model_name)] - - # --- sanity check for classic ProteinMPNN --- - expected_pt = weights_dir / f"{model_name}.pt" - if not expected_pt.exists(): - available = sorted(p.name for p in weights_dir.glob("*.pt")) - raise FileNotFoundError( - f"ProteinMPNN weight not found: {expected_pt}\n" - f"Available under /models: {available}" - ) - # --- auto-detect CA-only checkpoint and toggle --ca_only --- - try: - import torch - - ckpt = torch.load(str(expected_pt), map_location="cpu") - sd = ckpt.get("model_state_dict", ckpt) - w = sd.get("features.edge_embedding.weight", None) - if w is not None and hasattr(w, "shape") and len(w.shape) == 2: - if int(w.shape[1]) == 167: - argv += ["--ca_only"] - except Exception as e: - print( - f"Warning: failed to auto-detect CA-only checkpoint for {expected_pt}: {e}" - ) + mpnn_step = None + if gentype == "binder" and _step_enabled(enabled, "MPNNStep_stage2"): + mpnn_step = ("stage2-ligandmpnn", "MPNNStep_stage2") + elif gentype in {"antibody", "nanobody"} and _step_enabled( + enabled, "AbMPNNStep_stage2" + ): + mpnn_step = ("stage2-abmpnn", "AbMPNNStep_stage2") + if mpnn_step is not None: + node_id, step_name = mpnn_step + tail = workflow.add_node( + _app_step_node(steps, step_name, modal_namespace), + id=node_id, + inputs=_structure_inputs(tail), + ) - nseq_flag = flags.get("nseq_flag") or "" - temp_flag = flags.get("temp_flag") or "" - seed_flag = flags.get("seed_flag") or "" - batch_flag = flags.get("batch_flag") or "" - - if nseq_flag: - argv += [nseq_flag, str(int(num_seqs))] - if temp_flag: - argv += [temp_flag, str(float(sampling_temp))] - if seed_flag: - argv += [seed_flag, str(int(seed))] - if batch_flag: - argv += [batch_flag, str(int(batch_size))] - - if chain_list: - argv += ["--chain_list", chain_list] - if position_list_csv: - argv += ["--position_list", str(position_list_csv)] - if omit_aas: - argv += ["--omit_AAs", omit_aas] - if use_soluble_model: - argv += ["--use_soluble_model"] - - # ------------------------- - # Step 3: run with cwd = original ProteinMPNN dir to satisfy relative imports - # ------------------------- - p = subprocess.run( # noqa: S603 - argv, - check=False, - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - cwd=str(patch_dir), # critical: make `protein_mpnn_utils` importable - ) + if _step_enabled(enabled, "FlowpackerStep_stage2"): + tail = workflow.add_node( + _app_step_node(steps, "FlowpackerStep_stage2", modal_namespace), + id="stage2-flowpacker", + inputs=_structure_inputs(tail), + ) - dbg: list[str] = [] - dbg.append("=== runner ===") - dbg.append(f"original_script: {mpnn_script}") - dbg.append(f"executed_script: {run_script}") - dbg.append(f"cwd: {patch_dir}") - dbg.append("") - dbg.append("=== argv ===") - dbg.append(" ".join(argv)) - dbg.append("") - dbg.append("=== output ===") - dbg.append(p.stdout or "") - log_path.write_text("\n".join(dbg)) - - if p.returncode != 0: - raise RuntimeError( - f"ABMPNN/ProteinMPNN failed (exit {p.returncode}). See {log_path}" + score = None + if _step_enabled(enabled, "AF3scoreStep_stage2"): + score = workflow.add_node( + _app_step_node(steps, "AF3scoreStep_stage2", modal_namespace), + id="stage2-af3score", + inputs=_structure_inputs(tail), ) + filtered = tail + if _step_enabled(enabled, "FilterStep_stage2"): + inputs = _structure_inputs(tail) + if score is not None: + inputs["scores"] = score.outputs(kind=ArtifactKind.SCORES) + filtered = workflow.add_node( + FilterStructuresNode( + "FilterStep_stage2", + _step_cfg(steps, "FilterStep_stage2"), + ), + id="stage2-filter", + inputs=inputs, + ) -def _stage_pdb_folder(pdbs: list[Path], dst_dir: Path) -> None: - dst_dir.mkdir(parents=True, exist_ok=True) - for p in pdbs: - # keep stable filenames - (dst_dir / p.name).write_bytes(p.read_bytes()) + refold = None + if _step_enabled(enabled, "ReFoldStep"): + refold = workflow.add_node( + _app_step_node(steps, "ReFoldStep", modal_namespace), + id="stage2-alphafold3-refold", + inputs=_structure_inputs(filtered), + ) + if _step_enabled(enabled, "RosettaRelaxStep"): + workflow.add_node( + _app_step_node(steps, "RosettaRelaxStep", modal_namespace), + id="stage2-rosetta-relax", + inputs=_structure_inputs(filtered), + ) -# ------------------------- -# Remote GPU job: unified runner -# ------------------------- -@app.function( - gpu=GPU, - cpu=(2, 8), - timeout=TIMEOUT, - image=runtime_image, - volumes={str(MODELS_DIR): MODELS_VOL, str(RUNS_DIR): RUNS_VOL}, -) -def run_ppiflow_structured( - # ---------- common ---------- - task: str, - run_name: str, - input_files: list[tuple[str, bytes]], - model_weights: str | None, - config: str | None, - # ---------- binder (sample_binder.py) ---------- - binder_target_chain: str, - binder_binder_chain: str, - binder_specified_hotspots: str | None, - binder_samples_min_length: int, - binder_samples_max_length: int, - binder_samples_per_target: int, - # ---------- antibody/nanobody (sample_antibody_nanobody.py) ---------- - ab_antigen_chain: str | None, - ab_heavy_chain: str | None, - ab_light_chain: str | None, - ab_specified_hotspots: str | None, - ab_cdr_length: str | None, - ab_samples_per_target: int, - # ---------- monomer (sample_monomer.py unconditional) ---------- - mono_length_subset: str | None, - mono_samples_num: int, - # ---------- scaffolding (sample_monomer.py motif) ---------- - scaffold_motif_names: str | None, - scaffold_samples_num: int, - # ---------- partial flow antibody/nanobody ---------- - pf_fixed_positions: str | None, - pf_cdr_position: str | None, - pf_specified_hotspots: str | None, - pf_start_t: float | None, - pf_samples_per_target: int | None, - pf_retry_limit: int, - pf_antigen_chain: str | None, - pf_heavy_chain: str | None, - pf_light_chain: str | None, - # ---------- partial flow binder ---------- - bpf_target_chain: str, - bpf_binder_chain: str, - bpf_start_t: float, - # ---------- mpnn ---------- - mpnn_source_task: str | None, - mpnn_source_run: str | None, - mpnn_num_seq_per_target_stage1: int, - mpnn_num_seq_per_target_stage2: int, - mpnn_temp_stage1: float, - mpnn_temp_stage2: float, - mpnn_model_name: str | None, - mpnn_batch_size: int, - mpnn_seed: int, - mpnn_ckpt_path: str | None, - mpnn_chain_list: str | None, - mpnn_omit_aas: str | None, - mpnn_use_soluble_model: bool, -) -> bytes: - """Unified runner for PPIFlow structured sampling tasks.""" - # ------------------------- - # Branch 1: MPNN tasks (operate on existing run dir) - # ------------------------- - if task in MPNN_TASKS: - if not mpnn_source_task or not mpnn_source_run: - raise ValueError( - "mpnn_stage1/2 requires --mpnn-source-task and --mpnn-source-run" - ) + dockq = None + if _step_enabled(enabled, "DockQStep"): + inputs = _structure_inputs(filtered) + if refold is not None: + inputs["models"] = refold.outputs(kind=ArtifactKind.STRUCTURES) + dockq = workflow.add_node( + _app_step_node(steps, "DockQStep", modal_namespace), + id="stage2-dockq", + inputs=inputs, + ) - src_run_dir = RUNS_DIR / mpnn_source_task / mpnn_source_run - src_outputs = src_run_dir / "outputs" - if not src_outputs.exists(): - raise FileNotFoundError(f"Source outputs missing: {src_outputs}") + if _step_enabled(enabled, "RankStep") or _step_enabled(enabled, "ReportStep"): + inputs = _structure_inputs(filtered) + if dockq is not None: + inputs["dockq"] = dockq.outputs(kind=ArtifactKind.SCORES) + workflow.add_node( + RankAndReportNode("RankAndReportStep", _rank_report_cfg(steps)), + id="stage2-rank-report", + inputs=inputs, + ) - pdbs = _collect_pdbs_for_mpnn(src_outputs) - mpnn_script = _find_abmpnn_run_script() +def _structure_inputs(upstream) -> dict[str, Any]: + if upstream is None: + return {} + return {"structures": upstream.outputs(kind=ArtifactKind.STRUCTURES)} - # detect flags first, then resolve checkpoint only when needed - flags = _detect_abmpnn_cli_flags(mpnn_script) - ckpt_path: Path | None = None - if flags.get("ckpt_flag"): - ckpt_path = _resolve_abmpnn_ckpt(mpnn_ckpt_path) - (src_run_dir / f"{task}.mpnn_help.txt").write_text( - _script_help_text(mpnn_script) - ) - mpnn_dir = src_run_dir / task - mpnn_out = mpnn_dir / "out" - mpnn_dir.mkdir(parents=True, exist_ok=True) - - if task == "mpnn_stage1": - num_seq = int(mpnn_num_seq_per_target_stage1) - temp = float(mpnn_temp_stage1) - manifest_name = "candidates_stage1.csv" - else: - num_seq = int(mpnn_num_seq_per_target_stage2) - temp = float(mpnn_temp_stage2) - manifest_name = "candidates_stage2.csv" - - mpnn_out = mpnn_dir / "out" - pdb_folder = mpnn_dir / "pdbs" - mpnn_inputs = mpnn_dir / "inputs" - mpnn_dir.mkdir(parents=True, exist_ok=True) - mpnn_inputs.mkdir(parents=True, exist_ok=True) - - _stage_pdb_folder(pdbs, pdb_folder) - - position_list_csv: Path | None = None - for fname, content in input_files: - dst = mpnn_inputs / Path(fname).name - dst.write_bytes(content) - if Path(fname).name == "mpnn_fixed_positions.csv": - position_list_csv = dst - if position_list_csv and not mpnn_chain_list: - raise ValueError( - f"{task} requires --mpnn-chain-list when --mpnn-position-list is provided" - ) - if position_list_csv: - _rewrite_fixed_positions_for_proteinmpnn( - csv_path=position_list_csv, - pdb_folder=pdb_folder, - chain_list=mpnn_chain_list or "", - ) - else: - # Keep "full design" behavior while working around upstream UnboundLocalError. - position_list_csv = mpnn_inputs / "mpnn_fixed_positions.auto.csv" - _write_empty_fixed_positions_csv(position_list_csv, pdb_folder) - - log_path = mpnn_dir / "mpnn_folder.log" - _run_abmpnn_on_folder( - mpnn_script=mpnn_script, - flags=flags, - pdb_folder=pdb_folder, - out_folder=mpnn_out, - ckpt_path=ckpt_path, - num_seqs=num_seq, - sampling_temp=temp, - seed=int(mpnn_seed), - batch_size=int(mpnn_batch_size), - model_name=mpnn_model_name, - chain_list=mpnn_chain_list, - position_list_csv=position_list_csv, - omit_aas=mpnn_omit_aas, - use_soluble_model=mpnn_use_soluble_model, - log_path=log_path, - ) +def _app_step_node( + steps: dict[str, Any], + step_name: str, + modal_namespace: PPIFlowModalNamespace, +) -> PPIFlowWorkflowNode: + return PPIFlowWorkflowNode( + step_name=step_name, + modal_namespace=modal_namespace, + config=_step_cfg(steps, step_name), + ) - fastas = _gather_mpnn_fastas(mpnn_out) - if not fastas: - raise RuntimeError( - f"No fasta outputs found under {mpnn_out} (expected */seqs/*.fa*)" - ) - rows: list[dict[str, Any]] = [] - for fa in fastas: - # .../out//seqs/.fa - try: - pdb_stem = fa.parent.parent.name - except Exception: - pdb_stem = fa.stem - rows.append( - { - "pdb_path": str((src_outputs / f"{pdb_stem}.pdb").resolve()), - "seq_fasta": str(fa.resolve()), - "iptm": "", - "ptm": "", - "passed": "", - } - ) +def _load_yaml_bytes(data: bytes) -> dict[str, Any]: + loaded = yaml.safe_load(data.decode("utf-8")) or {} + if not isinstance(loaded, dict): + raise ValueError("YAML root must be a mapping") + return loaded - _write_csv( - src_run_dir / manifest_name, - rows, - ["pdb_path", "seq_fasta", "iptm", "ptm", "passed"], - ) - (src_run_dir / f"{task}.meta.json").write_text( - json.dumps( - { - "mpnn_script": str(mpnn_script), - "ckpt_path": str(ckpt_path) if ckpt_path else None, - "num_seq_per_target": num_seq, - "sampling_temp": temp, - "batch_size": int(mpnn_batch_size), - "seed": int(mpnn_seed), - "chain_list": mpnn_chain_list, - "position_list_provided": bool(position_list_csv), - "omit_aas": mpnn_omit_aas, - "use_soluble_model": bool(mpnn_use_soluble_model), - "detected_flags": flags, - }, - indent=2, - ) - ) +def _task_section(task_doc: dict[str, Any]) -> dict[str, Any]: + section = task_doc.get("task", task_doc) + if not isinstance(section, dict): + raise ValueError("task.yaml must contain a mapping under 'task'") + return section - RUNS_VOL.commit() - # Return a tar of the SOURCE run dir (so you get both backbone outputs and mpnn outputs) - with tempfile.TemporaryDirectory() as td: - tar_path = Path(td) / f"{task}.{run_name}.tar.gz" - _tar_dir(src_run_dir, tar_path) - return tar_path.read_bytes() +def _enabled_section(task_doc: dict[str, Any]) -> dict[str, bool]: + enabled = task_doc.get("steps", {}) + if not isinstance(enabled, dict): + raise ValueError("task.yaml 'steps' section must be a mapping") + return {str(key): bool(value) for key, value in enabled.items()} - # ------------------------- - # Branch 2: PPIFlow sampling tasks - # ------------------------- - if task not in TASK_TO_SCRIPT: - raise ValueError( - f"Unknown task={task}. Choose from {sorted(TASK_TO_SCRIPT) + sorted(MPNN_TASKS)}" - ) - script = Path(TASK_TO_SCRIPT[task]) - if not script.exists(): - raise FileNotFoundError(f"Script not found in image: {script}") - - run_dir = RUNS_DIR / task / run_name - inputs_dir = run_dir / "inputs" - outputs_dir = run_dir / "outputs" - run_dir.mkdir(parents=True, exist_ok=True) - inputs_dir.mkdir(parents=True, exist_ok=True) - outputs_dir.mkdir(parents=True, exist_ok=True) - - # ---- write inputs (role-based filenames) ---- - for fname, content in input_files: - (inputs_dir / Path(fname).name).write_bytes(content) - - # ---- resolve model checkpoint (required for sampling tasks) ---- - if not model_weights: - raise ValueError("--model-weights is required for PPIFlow sampling tasks") - mw = Path(model_weights) - model_ckpt = mw if str(mw).startswith(str(MODELS_DIR)) else (MODELS_DIR / mw.name) - if not model_ckpt.exists(): - raise FileNotFoundError(f"Model checkpoint not found: {model_ckpt}") - - # ---- resolve config + write effective config (optional) ---- - effective_config: Path | None = None - if config: - cfg_path = Path(config) - if not cfg_path.is_absolute(): - cfg_guess = Path(PPIFLOW_DIR) / config - cfg_path = cfg_guess if cfg_guess.exists() else cfg_path - if not cfg_path.exists(): - raise FileNotFoundError( - f"Config not found: {config} (resolved: {cfg_path})" - ) - effective_config = run_dir / "effective_config.yaml" - _write_effective_config(cfg_path, effective_config) - - def p_in(name: str) -> Path: - p = inputs_dir / name - if not p.exists(): - raise FileNotFoundError(f"Required input file missing: {p}") - return p - - argv: list[str] = ["python", str(script)] - - if task == "binder": - input_pdb = p_in("binder_input.pdb") - argv += ["--input_pdb", str(input_pdb)] - argv += ["--target_chain", binder_target_chain] - argv += ["--binder_chain", binder_binder_chain] - if effective_config: - argv += ["--config", str(effective_config)] - if binder_specified_hotspots: - argv += ["--specified_hotspots", binder_specified_hotspots] - argv += [ - "--samples_min_length", - str(binder_samples_min_length), - "--samples_max_length", - str(binder_samples_max_length), - "--samples_per_target", - str(binder_samples_per_target), - "--model_weights", - str(model_ckpt), - "--output_dir", - str(outputs_dir), - "--name", - run_name, - ] - - elif task in {"antibody", "nanobody"}: - antigen_pdb = p_in("antigen.pdb") - framework_pdb = p_in("framework.pdb") - argv += [ - "--antigen_pdb", - str(antigen_pdb), - "--framework_pdb", - str(framework_pdb), - ] - - if not ab_antigen_chain: - raise ValueError("antibody/nanobody requires --ab-antigen-chain") - if not ab_heavy_chain: - raise ValueError("antibody/nanobody requires --ab-heavy-chain") - - argv += ["--antigen_chain", ab_antigen_chain, "--heavy_chain", ab_heavy_chain] - if ab_light_chain: - argv += ["--light_chain", ab_light_chain] - if ab_specified_hotspots: - argv += ["--specified_hotspots", ab_specified_hotspots] - if ab_cdr_length: - argv += ["--cdr_length", ab_cdr_length] - if effective_config: - argv += ["--config", str(effective_config)] - argv += [ - "--samples_per_target", - str(ab_samples_per_target), - "--model_weights", - str(model_ckpt), - "--output_dir", - str(outputs_dir), - "--name", - run_name, - ] - - elif task == "monomer": - if mono_length_subset is None: - raise ValueError("monomer requires --mono-length-subset") - if effective_config: - argv += ["--config", str(effective_config)] - argv += [ - "--model_weights", - str(model_ckpt), - "--output_dir", - str(outputs_dir), - "--length_subset", - mono_length_subset, - "--samples_num", - str(mono_samples_num), - ] - - elif task == "scaffolding": - motif_csv = p_in("motif.csv") - if effective_config: - argv += ["--config", str(effective_config)] - argv += [ - "--model_weights", - str(model_ckpt), - "--output_dir", - str(outputs_dir), - "--motif_csv", - str(motif_csv), - ] - if scaffold_motif_names: - argv += ["--motif_names", scaffold_motif_names] - argv += ["--samples_num", str(scaffold_samples_num)] - - elif task in {"ab_partial_flow", "nb_partial_flow"}: - complex_pdb = p_in("complex.pdb") - - if not pf_fixed_positions: - raise ValueError(f"{task} requires --pf-fixed-positions") - if not pf_cdr_position: - raise ValueError(f"{task} requires --pf-cdr-position") - if pf_start_t is None: - raise ValueError(f"{task} requires --pf-start-t") - if pf_samples_per_target is None: - raise ValueError(f"{task} requires --pf-samples-per-target") - if not pf_antigen_chain: - raise ValueError(f"{task} requires --pf-antigen-chain") - if not pf_heavy_chain: - raise ValueError(f"{task} requires --pf-heavy-chain") - - argv += [ - "--complex_pdb", - str(complex_pdb), - "--fixed_positions", - pf_fixed_positions, - "--cdr_position", - pf_cdr_position, - "--start_t", - str(pf_start_t), - "--samples_per_target", - str(pf_samples_per_target), - "--output_dir", - str(outputs_dir), - "--retry_Limit", - str(pf_retry_limit), - ] - if pf_specified_hotspots: - argv += ["--specified_hotspots", pf_specified_hotspots] - if effective_config: - argv += ["--config", str(effective_config)] - - argv += [ - "--model_weights", - str(model_ckpt), - "--antigen_chain", - pf_antigen_chain, - "--heavy_chain", - pf_heavy_chain, - ] - if pf_light_chain: - argv += ["--light_chain", pf_light_chain] - argv += ["--name", run_name] - - elif task == "binder_partial_flow": - input_pdb = p_in("binder_input.pdb") - argv += ["--input_pdb", str(input_pdb)] - if effective_config: - argv += ["--config", str(effective_config)] - argv += [ - "--model_weights", - str(model_ckpt), - "--target_chain", - bpf_target_chain, - "--binder_chain", - bpf_binder_chain, - "--start_t", - str(bpf_start_t), - "--output_dir", - str(outputs_dir), - ] +def _step_enabled(enabled: dict[str, bool], step_name: str) -> bool: + return bool(enabled.get(step_name, False)) - else: - raise ValueError(f"Task routed but not implemented: {task}") - - # ---- run ---- - (run_dir / "cmd.txt").write_text(" ".join(argv) + "\n") - run_cwd = inputs_dir if task == "scaffolding" else None - - p = subprocess.run( # noqa: S603 - argv, - check=False, - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - cwd=str(run_cwd) if run_cwd else None, - ) - (run_dir / "stdout.log").write_text(p.stdout or "") - # if fail, but outputs exist, keep; else raise - pdbs = sorted((run_dir / "outputs").glob("*.pdb")) - if p.returncode != 0 and not pdbs: - raise RuntimeError( - f"PPIFlow failed (exit {p.returncode}). See {run_dir}/stdout.log" - ) +def _step_cfg(steps: dict[str, Any], step_name: str) -> dict[str, Any]: + cfg = steps.get(step_name, {}) + if cfg is None: + return {} + if not isinstance(cfg, dict): + raise ValueError(f"steps.yaml entry {step_name!r} must be a mapping") + return cfg - _collect_artifacts(run_dir) - RUNS_VOL.commit() - with tempfile.TemporaryDirectory() as td: - tar_path = Path(td) / f"{task}.{run_name}.tar.gz" - _tar_dir(run_dir, tar_path) - return tar_path.read_bytes() +def _rank_report_cfg(steps: dict[str, Any]) -> dict[str, Any]: + return { + "RankStep": _step_cfg(steps, "RankStep"), + "ReportStep": _step_cfg(steps, "ReportStep"), + } -# ------------------------- -# Local entrypoint -# ------------------------- -@app.local_entrypoint() -def submit_ppiflow( - # ---------- common ---------- - task: str = "binder", - run_name: str = "test1", - out_dir: str = "./ppiflow_outputs", - # ---------- local input files ---------- - binder_input_pdb: str | None = None, - ab_antigen_pdb: str | None = None, - ab_framework_pdb: str | None = None, - pf_complex_pdb: str | None = None, - scaffold_motif_csv: str | None = None, - # ---------- model weights (sampling tasks only) ---------- - model_weights: str | None = None, - # ---------- config ---------- - config: str | None = None, - # ---------- binder args ---------- - binder_target_chain: str = "B", - binder_binder_chain: str = "A", - binder_specified_hotspots: str | None = None, - binder_samples_min_length: int = 75, - binder_samples_max_length: int = 76, - binder_samples_per_target: int = 5, - # ---------- antibody/nanobody args ---------- - ab_antigen_chain: str | None = None, - ab_heavy_chain: str | None = None, - ab_light_chain: str | None = None, - ab_specified_hotspots: str | None = None, - ab_cdr_length: str | None = None, - ab_samples_per_target: int = 5, - # ---------- monomer unconditional ---------- - mono_length_subset: str | None = None, - mono_samples_num: int = 5, - # ---------- scaffolding ---------- - scaffold_motif_names: str | None = None, - scaffold_samples_num: int = 5, - # ---------- partial flow antibody/nanobody ---------- - pf_fixed_positions: str | None = None, - pf_cdr_position: str | None = None, - pf_specified_hotspots: str | None = None, - pf_start_t: float | None = None, - pf_samples_per_target: int | None = None, - pf_retry_limit: int = 10, - pf_antigen_chain: str | None = None, - pf_heavy_chain: str | None = None, - pf_light_chain: str | None = None, - # ---------- partial flow binder ---------- - bpf_target_chain: str = "B", - bpf_binder_chain: str = "A", - bpf_start_t: float = 0.7, - # ---------- mpnn ---------- - mpnn_source_task: str | None = None, - mpnn_source_run: str | None = None, - mpnn_num_seq_per_target_stage1: int = 8, - mpnn_num_seq_per_target_stage2: int = 4, - mpnn_temp_stage1: float = 0.5, - mpnn_temp_stage2: float = 0.1, - mpnn_model_name: str = "v_48_020", - mpnn_batch_size: int = 1, - mpnn_seed: int = 0, - mpnn_ckpt_path: str | None = None, - mpnn_chain_list: str | None = None, - mpnn_position_list: str | None = None, - mpnn_omit_aas: str | None = None, - mpnn_use_soluble_model: bool = False, -) -> None: - """Unified Modal CLI. +def _ppiflow_input_fields(args: object) -> tuple[str, ...]: + if isinstance(args, ppiflow_app.SampleAntibodyNanobodyConfig): + return ("antigen_pdb", "framework_pdb") + if isinstance(args, ppiflow_app.SampleAntibodyNanobodyPartialConfig): + return ("complex_pdb",) + if isinstance( + args, + (ppiflow_app.SampleBinderConfig, ppiflow_app.SampleBinderPartialConfig), + ): + return ("input_pdb",) + raise TypeError(f"Unsupported PPIFlow args type: {type(args).__name__}") + + +def _active_ppiflow_app_steps( + task_doc: dict[str, Any], stage: int | None +) -> tuple[str, ...]: + """Return PPIFlow app steps that should be staged for the selected run.""" + if stage not in {None, 1, 2}: + raise ValueError("stage must be omitted, 1, or 2") + enabled = _enabled_section(task_doc) + active_steps: list[str] = [] + if stage in {None, 1} and _step_enabled(enabled, "PPIFlowStep"): + active_steps.append("PPIFlowStep") + if stage in {None, 2} and _step_enabled(enabled, "PartialStep"): + active_steps.append("PartialStep") + return tuple(active_steps) + + +def _stage_ppiflow_app_inputs( + *, + steps_doc: dict[str, Any], + run_id: str, + app_steps: tuple[str, ...], +) -> dict[str, Any]: + """Upload local PPIFlow app inputs and rewrite step args to mounted paths.""" + staged_steps = deepcopy(steps_doc) + uploads: list[tuple[Path, str]] = [] + volume_root = Path(ppiflow_app.CONF.output_volume_mountpoint) + + for step_name in app_steps: + if step_name not in staged_steps: + continue + cfg = _step_cfg(staged_steps, step_name) + raw_args = cfg.get("args", cfg) + if not isinstance(raw_args, dict): + continue + + app_args = ppiflow_app.PPIFlowArgs.model_validate({"args": raw_args}) + for field_name in _ppiflow_input_fields(app_args.args): + current_value = getattr(app_args.args, field_name) + current_path = Path(current_value) + if current_path.is_absolute() and current_path.is_relative_to(volume_root): + continue - - Sampling tasks upload inputs (role-based). - - MPNN tasks do NOT upload inputs; they operate on an existing run: - --mpnn-source-task --mpnn-source-run - """ - allowed = set(TASK_TO_SCRIPT) | set(MPNN_TASKS) - if task not in allowed: - raise ValueError(f"--task must be one of {sorted(allowed)}") - - def _read_file_as(role_name: str, path: str | None) -> tuple[str, bytes] | None: - if not path: - return None - pp = Path(path).expanduser() - if not pp.exists(): - raise FileNotFoundError(f"Local file not found: {pp}") - return (role_name, pp.read_bytes()) - - input_files: list[tuple[str, bytes]] = [] - - # Build uploads only for sampling tasks - if task in TASK_TO_SCRIPT: - if task in {"binder", "binder_partial_flow"}: - if not binder_input_pdb: - raise ValueError(f"{task} requires --binder-input-pdb") - item = _read_file_as("binder_input.pdb", binder_input_pdb) - if item: - input_files.append(item) - - if task in {"antibody", "nanobody"}: - if not ab_antigen_pdb or not ab_framework_pdb: - raise ValueError( - f"{task} requires --ab-antigen-pdb and --ab-framework-pdb" - ) - input_files.append(_read_file_as("antigen.pdb", ab_antigen_pdb)) - input_files.append(_read_file_as("framework.pdb", ab_framework_pdb)) - - if task in {"ab_partial_flow", "nb_partial_flow"}: - if not pf_complex_pdb: - raise ValueError(f"{task} requires --pf-complex-pdb") - input_files.append(_read_file_as("complex.pdb", pf_complex_pdb)) - - if task == "scaffolding": - if not scaffold_motif_csv: - raise ValueError("scaffolding requires --scaffold-motif-csv") - - csv_path = Path(scaffold_motif_csv).expanduser() - if not csv_path.exists(): - raise FileNotFoundError(f"Local file not found: {csv_path}") - - csv_text = csv_path.read_text(encoding="utf-8-sig") - reader = csv.DictReader(StringIO(csv_text)) - required_cols = {"target", "length", "contig", "motif_path"} - if not reader.fieldnames or not required_cols.issubset( - set(reader.fieldnames) - ): - raise ValueError( - f"motif.csv must have columns {sorted(required_cols)}, got {reader.fieldnames}" + local_path = current_path.expanduser().resolve() + if not local_path.exists(): + raise FileNotFoundError( + f"PPIFlow {step_name} input {field_name!r} was not found " + f"locally or in the mounted output volume: {current_value}" ) - rows = list(reader) - if not rows: - raise ValueError("motif.csv has no data rows") - - csv_dir = csv_path.parent - motif_files: dict[str, Path] = {} - for r in rows: - mp = (r.get("motif_path") or "").strip() - if not mp: - raise ValueError(f"motif_path is empty in row: {r}") - - mp_path = Path(mp).expanduser() if mp.startswith("~") else Path(mp) - if not mp_path.is_absolute(): - mp_path = (csv_dir / mp_path).resolve() - if not mp_path.exists(): - raise FileNotFoundError( - f"motif_path file not found: {mp_path} (from motif_path={mp!r})" - ) - - stable_name = mp_path.name - if stable_name in motif_files and motif_files[stable_name] != mp_path: - stable_name = ( - f"{mp_path.stem}.{len(motif_files) + 1}{mp_path.suffix}" - ) - motif_files[stable_name] = mp_path - r["motif_path"] = stable_name - - out_buf = StringIO() - writer = csv.DictWriter(out_buf, fieldnames=reader.fieldnames) - writer.writeheader() - writer.writerows(rows) - input_files.append(("motif.csv", out_buf.getvalue().encode("utf-8"))) - - for stable_name, p in motif_files.items(): - input_files.append((stable_name, p.read_bytes())) - - if task == "scaffolding" and scaffold_motif_names: - s = scaffold_motif_names.strip() - if not s.startswith("["): - scaffold_motif_names = json.dumps([s]) - elif task in MPNN_TASKS: - if mpnn_position_list and not mpnn_chain_list: - raise ValueError( - f"{task} requires --mpnn-chain-list when --mpnn-position-list is provided" + remote_rel = ( + Path(run_id) + / sanitize_filename(step_name) + / sanitize_filename(field_name) + / sanitize_filename(local_path.name) ) - if mpnn_position_list: - pos_csv = Path(mpnn_position_list).expanduser() - if not pos_csv.exists(): - raise FileNotFoundError(f"Local file not found: {pos_csv}") - normalized = _normalize_fixed_positions_csv_bytes(pos_csv.read_bytes()) - input_files.append(("mpnn_fixed_positions.csv", normalized)) - - # dispatch - tar_bytes = run_ppiflow_structured.remote( - task=task, - run_name=run_name, - input_files=input_files, - model_weights=model_weights, - config=config, - binder_target_chain=binder_target_chain, - binder_binder_chain=binder_binder_chain, - binder_specified_hotspots=binder_specified_hotspots, - binder_samples_min_length=binder_samples_min_length, - binder_samples_max_length=binder_samples_max_length, - binder_samples_per_target=binder_samples_per_target, - ab_antigen_chain=ab_antigen_chain, - ab_heavy_chain=ab_heavy_chain, - ab_light_chain=ab_light_chain, - ab_specified_hotspots=ab_specified_hotspots, - ab_cdr_length=ab_cdr_length, - ab_samples_per_target=ab_samples_per_target, - mono_length_subset=mono_length_subset, - mono_samples_num=mono_samples_num, - scaffold_motif_names=scaffold_motif_names, - scaffold_samples_num=scaffold_samples_num, - pf_fixed_positions=pf_fixed_positions, - pf_cdr_position=pf_cdr_position, - pf_specified_hotspots=pf_specified_hotspots, - pf_start_t=pf_start_t, - pf_samples_per_target=pf_samples_per_target, - pf_retry_limit=pf_retry_limit, - pf_antigen_chain=pf_antigen_chain, - pf_heavy_chain=pf_heavy_chain, - pf_light_chain=pf_light_chain, - bpf_target_chain=bpf_target_chain, - bpf_binder_chain=bpf_binder_chain, - bpf_start_t=bpf_start_t, - mpnn_source_task=mpnn_source_task, - mpnn_source_run=mpnn_source_run, - mpnn_num_seq_per_target_stage1=mpnn_num_seq_per_target_stage1, - mpnn_num_seq_per_target_stage2=mpnn_num_seq_per_target_stage2, - mpnn_temp_stage1=mpnn_temp_stage1, - mpnn_temp_stage2=mpnn_temp_stage2, - mpnn_batch_size=mpnn_batch_size, - mpnn_seed=mpnn_seed, - mpnn_ckpt_path=mpnn_ckpt_path, - mpnn_model_name=mpnn_model_name, - mpnn_chain_list=mpnn_chain_list, - mpnn_omit_aas=mpnn_omit_aas, - mpnn_use_soluble_model=mpnn_use_soluble_model, + raw_args[field_name] = str(volume_root / remote_rel) + uploads.append((local_path, remote_rel.as_posix())) + + if uploads: + with ppiflow_app.CONF.output_volume.batch_upload() as batch: + for local_path, remote_rel in uploads: + remote_storage = volume_path_from_mount_path( + str(volume_root / remote_rel), + str(volume_root), + ppiflow_app.CONF.output_volume_name, + ) + print( + f"Uploading PPIFlow input '{local_path}' to {remote_storage}", + flush=True, + ) + batch.put_file(local_path, f"/{remote_storage.path}") + return staged_steps + + +@app.local_entrypoint() +def submit_ppiflow_workflow( + task_yaml: str, + steps_yaml: str, + run_id: str | None = None, + stage: int | None = None, + force: bool = False, + wait: bool = True, + max_parallel: int = 16, +) -> None: + """Build and submit a PPIFlow workflow from task and step YAML files. + + Args: + task_yaml: Path to the PPIFlow task YAML declaring enabled workflow + steps and design mode. + steps_yaml: Path to the YAML file containing per-step app arguments. + run_id: Stable workflow run id for durable ledger state. Defaults to + the task YAML filename stem. + stage: Optional stage selector. Use 1 for stage 1 only, 2 for stage 2 + only, or omit to build both stages. + force: Replace an existing workflow run ledger before running. + wait: Wait locally for the remote workflow result. Disable to print the + Modal function call id for asynchronous collection. + max_parallel: Maximum number of ready workflow nodes to execute + concurrently in one scheduler wave. + """ + task_yaml_path = Path(task_yaml).expanduser().resolve() + steps_yaml_path = Path(steps_yaml).expanduser().resolve() + resolved_run_id = sanitize_filename(run_id or task_yaml_path.stem) + task_yaml_bytes = task_yaml_path.read_bytes() + task_doc = _load_yaml_bytes(task_yaml_bytes) + steps_doc = _stage_ppiflow_app_inputs( + steps_doc=_load_yaml_bytes(steps_yaml_path.read_bytes()), + run_id=resolved_run_id, + app_steps=_active_ppiflow_app_steps(task_doc, stage), + ) + workflow = build_ppiflow_workflow( + task_yaml_bytes=task_yaml_bytes, + steps_yaml_bytes=yaml.safe_dump(steps_doc).encode("utf-8"), + stage=stage, ) - out_dir_p = Path(out_dir).expanduser() - out_dir_p.mkdir(parents=True, exist_ok=True) - out_tar = out_dir_p / f"{task}.{run_name}.tar.gz" - out_tar.write_bytes(tar_bytes) - print(f"[ok] saved: {out_tar}") + orchestrator_handle = orchestrator.WorkflowOrchestrator() + orchestrator_kwargs = { + "workflow": workflow, + "run_id": resolved_run_id, + "force": force, + "max_ready_workers": max_parallel, + } + print( + f"Submitting PPIFlow workflow '{resolved_run_id}' with " + f"{len(workflow.validate().nodes)} node(s)", + flush=True, + ) + if wait: + result: AppRunResult | str = AppRunResult.model_validate( + orchestrator_handle.run.remote(**orchestrator_kwargs) + ) + else: + function_call = orchestrator_handle.run.spawn(**orchestrator_kwargs) + result = str(getattr(function_call, "object_id", function_call)) + if isinstance(result, AppRunResult): + print(f"PPIFlow workflow run finished with status: {result.status}", flush=True) + else: + print(f"PPIFlow workflow run submitted. FunctionCall id: {result}", flush=True) diff --git a/src/biomodals/workflow/shortmd_workflow.py b/src/biomodals/workflow/shortmd_workflow.py new file mode 100644 index 0000000..203db7e --- /dev/null +++ b/src/biomodals/workflow/shortmd_workflow.py @@ -0,0 +1,625 @@ +"""ShortMD workflow for parallel short GROMACS production replicates. + +This proof-of-concept accepts a directory of PDB files, prepares each structure +once with the GROMACS app, clones the prepared production inputs into replicate +run directories, and runs many short production trajectories in parallel through +the reusable Biomodals workflow runtime. +""" + +# Ignore ruff warnings about import location +# ruff: noqa: PLC0415 + +from __future__ import annotations + +import os +import shutil +from dataclasses import dataclass, field +from pathlib import Path + +import modal + +from biomodals.app.bioinfo import gromacs_app +from biomodals.helper import patch_image_for_helper +from biomodals.helper.catalog import include_dependency_apps +from biomodals.helper.constant import MAX_TIMEOUT +from biomodals.helper.shell import sanitize_filename +from biomodals.helper.volume_run import volume_path_from_mount_path +from biomodals.schema import ( + AppConfig, + AppOutput, + AppRunResult, + AppRunStatus, + ArtifactKind, + InlineBytes, + NodeExecutionPolicy, + NodePlacement, + VolumePath, +) +from biomodals.workflow.core import ( + AppBackedNode, + NodeRunContext, + Workflow, + WorkflowNativeNode, + orchestrator, +) + +DEPENDENCY_APPS = ("gromacs",) +CONF = AppConfig( + tags={"depends_on": ",".join(DEPENDENCY_APPS)}, + depends_on_apps=DEPENDENCY_APPS, + name="ShortMDWorkflow", + package_name="biomodals-shortmd-workflow", + version="0.1.0", + python_version="3.13", + timeout=int(os.environ.get("TIMEOUT", str(MAX_TIMEOUT))), +) + +runtime_image = ( + modal.Image + .debian_slim(python_version=CONF.python_version) + .env(CONF.default_env) + .pipe(patch_image_for_helper, include_workflow_modules=True) +) +app = modal.App(CONF.name, image=runtime_image, tags=CONF.tags).include( + orchestrator.app, inherit_tags=True +) +app = include_dependency_apps(app, CONF.depends_on_apps) +GROMACS_OUTPUT_VOLUME = gromacs_app.CONF.output_volume +GROMACS_OUTPUT_VOLUME_NAME = gromacs_app.CONF.output_volume_name +GROMACS_OUTPUT_MOUNTPOINT = gromacs_app.CONF.output_volume_mountpoint + + +@dataclass(frozen=True) +class ShortMDGromacsSettings: + """Shared GROMACS arguments for ShortMD prep and production nodes.""" + + simulation_time_ns: int = 2 + run_pdbfixer: bool = False + cpu_only: bool = False + num_threads: int = 16 + use_openmp_threads: bool = False + ld_seed: int = -1 + gen_seed: int = -1 + genion_seed: int = 0 + save_processed_traj: bool = True + make_figures: bool = True + + +@dataclass(frozen=True) +class ShortMDModalNamespace: + """Hydrated Modal objects carried across the orchestrator boundary.""" + + clear: modal.Function + clone: modal.Function + prepare_cpu: modal.Function + prepare_gpu: modal.Function + production_cpu: modal.Function + production_gpu: modal.Function + collect_stats: modal.Function + + +@app.function( + image=runtime_image, + cpu=0.125, + memory=(512, 4096), + timeout=CONF.timeout, + volumes={GROMACS_OUTPUT_MOUNTPOINT: GROMACS_OUTPUT_VOLUME}, +) +def clear_shortmd_gromacs_run(run_name: str) -> None: + """Remove one ShortMD-managed GROMACS run directory from the app volume.""" + safe_run_name = sanitize_filename(run_name) + GROMACS_OUTPUT_VOLUME.reload() + run_dir = Path(GROMACS_OUTPUT_MOUNTPOINT) / safe_run_name + if run_dir.is_dir(): + shutil.rmtree(run_dir) + elif run_dir.exists(): + run_dir.unlink() + GROMACS_OUTPUT_VOLUME.commit() + + +@app.function( + image=runtime_image, + cpu=0.125, + memory=(512, 4096), + timeout=CONF.timeout, + volumes={GROMACS_OUTPUT_MOUNTPOINT: GROMACS_OUTPUT_VOLUME}, +) +def clone_prepared_shortmd_run( + source_storage_path: str, + source_run_name: str, + replicate_run_name: str, + overwrite: bool = False, +) -> str: + """Clone prepared GROMACS inputs into a ShortMD replicate directory.""" + source_storage_path = VolumePath( + volume_name=GROMACS_OUTPUT_VOLUME_NAME, + path=source_storage_path, + ).path + source_run_name = sanitize_filename(source_run_name) + replicate_run_name = sanitize_filename(replicate_run_name) + GROMACS_OUTPUT_VOLUME.reload() + + volume_root = Path(GROMACS_OUTPUT_MOUNTPOINT) + source_dir = volume_root / source_storage_path + replicate_dir = volume_root / replicate_run_name + if not source_dir.is_dir(): + raise FileNotFoundError(f"Prepared GROMACS run not found: {source_dir}") + + if overwrite and replicate_dir.exists(): + shutil.rmtree(replicate_dir) + + created_clone = False + if not replicate_dir.exists(): + shutil.copytree(source_dir, replicate_dir) + created_clone = True + else: + replicate_dir.mkdir(parents=True, exist_ok=True) + + source_pdb = replicate_dir / f"{source_run_name}.pdb" + if not source_pdb.exists(): + source_pdb = source_dir / f"{source_run_name}.pdb" + if not source_pdb.exists(): + raise FileNotFoundError(f"Prepared PDB not found: {source_pdb}") + shutil.copy2(source_pdb, replicate_dir / f"{replicate_run_name}.pdb") + + source_tpr = replicate_dir / f"production_{source_run_name}.tpr" + if not source_tpr.exists(): + source_tpr = source_dir / f"production_{source_run_name}.tpr" + if not source_tpr.exists(): + raise FileNotFoundError(f"Prepared production TPR not found: {source_tpr}") + shutil.copy2(source_tpr, replicate_dir / f"production_{replicate_run_name}.tpr") + + if created_clone: + keep_tpr = f"production_{replicate_run_name}.tpr" + for path in replicate_dir.glob("production_*"): + if path.name != keep_tpr and (path.is_file() or path.is_symlink()): + path.unlink() + for pattern in ( + "rmsd_production_*", + "rg_production_*", + "rmsf_production_*", + "production_*_nopbc*", + "production_*_last_frame.pdb", + ): + for path in replicate_dir.glob(pattern): + if path.is_file() or path.is_symlink(): + path.unlink() + + GROMACS_OUTPUT_VOLUME.commit() + return str(replicate_dir) + + +@dataclass +class ShortMDPrepNode(AppBackedNode): + """Workflow node that prepares one PDB for GROMACS production replicates.""" + + pdb_content: bytes + run_name: str + modal_namespace: ShortMDModalNamespace = field( + repr=False, + compare=False, + metadata={"dag_hash": False}, + ) + overwrite_existing: bool = False + gromacs: ShortMDGromacsSettings = field(default_factory=ShortMDGromacsSettings) + execution_policy: NodeExecutionPolicy = NodeExecutionPolicy.RESUME + placement: NodePlacement = NodePlacement.REMOTE + + def run(self, context: NodeRunContext) -> AppRunResult: + """Run GROMACS prep and return a workflow artifact for the run directory.""" + safe_run_name = sanitize_filename(self.run_name) + if self.overwrite_existing: + self.modal_namespace.clear.remote(run_name=safe_run_name) + app_function = ( + self.modal_namespace.prepare_cpu + if self.gromacs.cpu_only + else self.modal_namespace.prepare_gpu + ) + remote_workdir = app_function.remote( + pdb_content=self.pdb_content, + run_name=safe_run_name, + simulation_time_ns=self.gromacs.simulation_time_ns, + run_pdbfixer=self.gromacs.run_pdbfixer, + num_threads=self.gromacs.num_threads, + use_openmp_threads=self.gromacs.use_openmp_threads, + ld_seed=self.gromacs.ld_seed, + gen_seed=self.gromacs.gen_seed, + genion_seed=self.gromacs.genion_seed, + ) + return AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="prepared_gromacs_run", + kind=ArtifactKind.DIRECTORY, + storage=volume_path_from_mount_path( + remote_path=str(remote_workdir), + mount_root=GROMACS_OUTPUT_MOUNTPOINT, + volume_name=GROMACS_OUTPUT_VOLUME_NAME, + ), + metadata={"stage": "prep", "run_name": safe_run_name}, + ) + ], + ) + + +@dataclass +class ShortMDCloneNode(WorkflowNativeNode): + """Workflow-native adapter that clones prepared inputs for one replicate.""" + + source_run_name: str + replicate_run_name: str + modal_namespace: ShortMDModalNamespace = field( + repr=False, + compare=False, + metadata={"dag_hash": False}, + ) + overwrite_clone: bool = False + execution_policy: NodeExecutionPolicy = NodeExecutionPolicy.RESUME + placement: NodePlacement = NodePlacement.REMOTE + + def run(self, context: NodeRunContext) -> AppRunResult: + """Clone prepared inputs into a replicate run directory.""" + prepared_artifacts = context.inputs.get("prepared") or [] + if len(prepared_artifacts) != 1: + raise ValueError( + "ShortMD clone node requires exactly one prepared input artifact" + ) + prepared_artifact = prepared_artifacts[0] + if prepared_artifact.storage.volume_name != GROMACS_OUTPUT_VOLUME_NAME: + raise ValueError( + "ShortMD prepared artifact volume does not match the GROMACS " + f"output volume: {prepared_artifact.storage.volume_name}" + ) + safe_source_run_name = sanitize_filename( + str(prepared_artifact.metadata.get("run_name") or self.source_run_name) + ) + safe_replicate_run_name = sanitize_filename(self.replicate_run_name) + remote_workdir = self.modal_namespace.clone.remote( + source_storage_path=prepared_artifact.storage.path, + source_run_name=safe_source_run_name, + replicate_run_name=safe_replicate_run_name, + overwrite=self.overwrite_clone, + ) + return AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="cloned_gromacs_run", + kind=ArtifactKind.DIRECTORY, + storage=volume_path_from_mount_path( + remote_path=str(remote_workdir), + mount_root=GROMACS_OUTPUT_MOUNTPOINT, + volume_name=GROMACS_OUTPUT_VOLUME_NAME, + ), + metadata={ + "stage": "clone", + "run_name": safe_replicate_run_name, + "source_run_name": safe_source_run_name, + }, + ) + ], + ) + + +@dataclass +class ShortMDReplicateNode(AppBackedNode): + """Workflow node that runs one short production replicate through GROMACS.""" + + source_run_name: str + replicate_run_name: str + modal_namespace: ShortMDModalNamespace = field( + repr=False, + compare=False, + metadata={"dag_hash": False}, + ) + gromacs: ShortMDGromacsSettings = field(default_factory=ShortMDGromacsSettings) + execution_policy: NodeExecutionPolicy = NodeExecutionPolicy.RESUME + placement: NodePlacement = NodePlacement.REMOTE + + def run(self, context: NodeRunContext) -> AppRunResult: + """Launch one GROMACS production run, then collect trajectory stats.""" + cloned_artifacts = context.inputs.get("cloned") or [] + if len(cloned_artifacts) != 1: + raise ValueError( + "ShortMD replicate node requires exactly one cloned input artifact" + ) + cloned_artifact = cloned_artifacts[0] + if cloned_artifact.storage.volume_name != GROMACS_OUTPUT_VOLUME_NAME: + raise ValueError( + "ShortMD cloned artifact volume does not match the GROMACS " + f"output volume: {cloned_artifact.storage.volume_name}" + ) + safe_source_run_name = sanitize_filename( + str(cloned_artifact.metadata.get("source_run_name") or self.source_run_name) + ) + safe_replicate_run_name = sanitize_filename( + str(cloned_artifact.metadata.get("run_name") or self.replicate_run_name) + ) + app_function = ( + self.modal_namespace.production_cpu + if self.gromacs.cpu_only + else self.modal_namespace.production_gpu + ) + _ = app_function.remote( + run_name=safe_replicate_run_name, + simulation_time_ns=self.gromacs.simulation_time_ns, + num_threads=self.gromacs.num_threads, + use_openmp_threads=self.gromacs.use_openmp_threads, + ) + remote_workdir = self.modal_namespace.collect_stats.remote( + "production_", + run_name=safe_replicate_run_name, + save_processed_traj=self.gromacs.save_processed_traj, + make_figures=self.gromacs.make_figures, + ) + return AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="gromacs_production", + kind=ArtifactKind.DIRECTORY, + storage=volume_path_from_mount_path( + remote_path=str(remote_workdir), + mount_root=GROMACS_OUTPUT_MOUNTPOINT, + volume_name=GROMACS_OUTPUT_VOLUME_NAME, + ), + metadata={ + "stage": "production", + "run_name": safe_replicate_run_name, + "source_run_name": safe_source_run_name, + }, + ) + ], + ) + + +@dataclass +class ShortMDSummaryNode(WorkflowNativeNode): + """Workflow-native node that emits a manifest of production replicates.""" + + replicates: int + max_parallel: int + + def run(self, context: NodeRunContext) -> AppRunResult: + """Write a Markdown summary of all replicate output artifacts.""" + artifacts = [ + artifact + for artifacts in context.inputs.values() + for artifact in artifacts + if artifact.kind == ArtifactKind.DIRECTORY + ] + artifacts.sort(key=lambda artifact: artifact.metadata.get("run_name", "")) + lines = [ + "# ShortMD Workflow Summary", + "", + f"- Replicates per input: {self.replicates}", + f"- Max parallel workflow nodes: {self.max_parallel}", + "", + "| Source run | Replicate run | Volume | Path |", + "| --- | --- | --- | --- |", + ] + for artifact in artifacts: + source_run_name = str(artifact.metadata.get("source_run_name") or "") + run_name = str(artifact.metadata.get("run_name") or artifact.artifact_id) + lines.append( + "| " + f"{source_run_name} | " + f"{run_name} | " + f"{artifact.storage.volume_name} | " + f"{artifact.storage.path} |" + ) + summary = "\n".join(lines) + "\n" + return AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="shortmd_summary", + kind=ArtifactKind.REPORT, + storage=InlineBytes( + data=summary.encode("utf-8"), + filename="shortmd-summary.md", + media_type="text/markdown", + ), + metadata={ + "replicates": str(self.replicates), + "max_parallel": str(self.max_parallel), + }, + ) + ], + ) + + +def discover_pdb_inputs(input_dir: str | Path) -> list[tuple[str, bytes]]: + """Return ``(filename, bytes)`` pairs for PDB files in a directory.""" + input_path = Path(input_dir).expanduser().resolve() + if not input_path.is_dir(): + raise NotADirectoryError(input_path) + pdb_paths = list(input_path.glob("*.pdb")) + if not pdb_paths: + raise ValueError(f"No PDB files found in {input_path}") + return [(path.name, path.read_bytes()) for path in pdb_paths] + + +def build_shortmd_workflow( + *, + input_pdbs: list[tuple[str, bytes]], + run_namespace: str | None = None, + replicates: int = 50, + simulation_time_ns: int = 2, + run_pdbfixer: bool = False, + cpu_only: bool = False, + num_threads: int = 16, + use_openmp_threads: bool = False, + ld_seed: int = -1, + gen_seed: int = -1, + genion_seed: int = 0, + max_parallel: int = 16, + overwrite_existing: bool = False, +) -> Workflow: + """Build a ShortMD workflow DAG from local PDB payloads.""" + if replicates < 1: + raise ValueError("replicates must be at least 1") + workflow = Workflow("shortmd") + safe_run_namespace = ( + sanitize_filename(run_namespace) if run_namespace is not None else None + ) + gromacs = ShortMDGromacsSettings( + simulation_time_ns=simulation_time_ns, + run_pdbfixer=run_pdbfixer, + cpu_only=cpu_only, + num_threads=num_threads, + use_openmp_threads=use_openmp_threads, + ld_seed=ld_seed, + gen_seed=gen_seed, + genion_seed=genion_seed, + ) + modal_namespace = ShortMDModalNamespace( + clear=clear_shortmd_gromacs_run, + clone=clone_prepared_shortmd_run, + prepare_cpu=gromacs_app.prepare_tpr_cpu, + prepare_gpu=gromacs_app.prepare_tpr_gpu, + production_cpu=gromacs_app.production_run_cpu, + production_gpu=gromacs_app.production_run_gpu, + collect_stats=gromacs_app.collect_traj_stats, + ) + used_run_names: set[str] = set() + replicate_handles = {} + + for file_name, pdb_content in input_pdbs: + pdb_run_name = sanitize_filename(Path(file_name).stem) + run_name = ( + f"{safe_run_namespace}-{pdb_run_name}" + if safe_run_namespace is not None + else pdb_run_name + ) + if run_name in used_run_names: + raise ValueError(f"Duplicate sanitized PDB run name: {run_name}") + used_run_names.add(run_name) + prep = workflow.add_node( + ShortMDPrepNode( + pdb_content=pdb_content, + run_name=run_name, + modal_namespace=modal_namespace, + overwrite_existing=overwrite_existing, + gromacs=gromacs, + ), + id=f"prep-{run_name}", + ) + for replicate_idx in range(1, replicates + 1): + replicate_run_name = f"{run_name}-r{replicate_idx:03d}" + clone = workflow.add_node( + ShortMDCloneNode( + source_run_name=run_name, + replicate_run_name=replicate_run_name, + modal_namespace=modal_namespace, + overwrite_clone=overwrite_existing, + ), + id=f"clone-{replicate_run_name}", + inputs={"prepared": prep.outputs(kind=ArtifactKind.DIRECTORY)}, + ) + replicate = workflow.add_node( + ShortMDReplicateNode( + source_run_name=run_name, + replicate_run_name=replicate_run_name, + modal_namespace=modal_namespace, + gromacs=gromacs, + ), + id=f"replicate-{replicate_run_name}", + inputs={"cloned": clone.outputs(kind=ArtifactKind.DIRECTORY)}, + ) + replicate_handles[replicate_run_name] = replicate + + workflow.add_node( + ShortMDSummaryNode(replicates=replicates, max_parallel=max_parallel), + id="summary", + inputs={ + replicate_run_name: handle.outputs(kind=ArtifactKind.DIRECTORY) + for replicate_run_name, handle in replicate_handles.items() + }, + ) + return workflow + + +@app.local_entrypoint() +def submit_shortmd_workflow( + input_dir: str, + run_id: str | None = None, + replicates: int = 50, + simulation_time_ns: int = 2, + run_pdbfixer: bool = False, + cpu_only: bool = False, + num_threads: int = 16, + use_openmp_threads: bool = False, + ld_seed: int = -1, + gen_seed: int = -1, + genion_seed: int = 0, + force: bool = False, + wait: bool = True, + max_parallel: int = 16, +) -> None: + """Run ShortMD production replicate workflow for a directory of PDB files. + + Args: + input_dir: Directory containing `.pdb` files. Each filename stem becomes + the prepared GROMACS run name. + run_id: Stable workflow run id for durable ledger state. Defaults to + the input directory name. + replicates: Number of short production replicates per input PDB. + simulation_time_ns: Production simulation length in nanoseconds for + each replicate and the prepared production TPR. + run_pdbfixer: Whether to run PDBFixer during preparation. + cpu_only: Whether to run GROMACS preparation and production on CPU only. + num_threads: Number of CPU threads to pass to GROMACS. + use_openmp_threads: Whether to use OpenMP threading in GROMACS. + ld_seed: Random seed for Langevin dynamics during preparation. + gen_seed: Random seed for initial velocity generation during preparation. + genion_seed: Random seed for ion placement during preparation. + force: Replace an existing workflow run ledger before running. + wait: Wait locally for the remote workflow result. Disable to print the + Modal function call id for asynchronous collection. + max_parallel: Maximum number of ready workflow nodes to execute + concurrently in one scheduler wave. + """ + input_path = Path(input_dir).expanduser().resolve() + input_pdbs = discover_pdb_inputs(input_path) + resolved_run_id = sanitize_filename(run_id or input_path.name) + workflow = build_shortmd_workflow( + input_pdbs=input_pdbs, + run_namespace=resolved_run_id, + replicates=replicates, + simulation_time_ns=simulation_time_ns, + run_pdbfixer=run_pdbfixer, + cpu_only=cpu_only, + num_threads=num_threads, + use_openmp_threads=use_openmp_threads, + ld_seed=ld_seed, + gen_seed=gen_seed, + genion_seed=genion_seed, + max_parallel=max_parallel, + overwrite_existing=force, + ) + + orchestrator_handle = orchestrator.WorkflowOrchestrator() + orchestrator_kwargs = { + "workflow": workflow, + "run_id": resolved_run_id, + "force": force, + "max_ready_workers": max_parallel, + } + print( + f"Submitting ShortMD workflow '{resolved_run_id}' with " + f"{len(input_pdbs)} input PDB(s), {replicates} replicate(s) each", + flush=True, + ) + if wait: + result: AppRunResult | str = AppRunResult.model_validate( + orchestrator_handle.run.remote(**orchestrator_kwargs) + ) + else: + function_call = orchestrator_handle.run.spawn(**orchestrator_kwargs) + result = str(getattr(function_call, "object_id", function_call)) + if isinstance(result, AppRunResult): + print(f"ShortMD workflow run finished with status: {result.status}", flush=True) + else: + print(f"ShortMD workflow run submitted. FunctionCall id: {result}", flush=True) diff --git a/tests/app/test_af3score_standalone_contract.py b/tests/app/test_af3score_standalone_contract.py new file mode 100644 index 0000000..fcf8d91 --- /dev/null +++ b/tests/app/test_af3score_standalone_contract.py @@ -0,0 +1,18 @@ +"""Tests for AF3Score standalone app contracts.""" + +# ruff: noqa: D103 + +import inspect + +from biomodals.app.score import af3score_app + + +def test_af3score_remote_functions_do_not_accept_path_payloads() -> None: + for function_name in ( + "af3score_prepare", + "af3score_run", + "af3score_postprocess", + ): + signature = inspect.signature(getattr(af3score_app, function_name).get_raw_f()) + assert "paths" not in signature.parameters + assert "Path" not in str(signature) diff --git a/tests/app/test_alphafold3_standalone_contract.py b/tests/app/test_alphafold3_standalone_contract.py new file mode 100644 index 0000000..3e5aa40 --- /dev/null +++ b/tests/app/test_alphafold3_standalone_contract.py @@ -0,0 +1,61 @@ +"""Tests for standalone AlphaFold3 app behavior.""" + +# ruff: noqa: D103 + +from pathlib import Path + +from uniaf3.schema.alphafold3 import AF3Config, AF3Protein, AF3SequenceEntry + +from biomodals.app.fold import alphafold3_app + + +def test_submit_alphafold3_task_applies_run_name_to_prediction_config( + tmp_path: Path, + monkeypatch, +) -> None: + input_json = tmp_path / "input.json" + conf = AF3Config( + name="original", + modelSeeds=[11, 12], + sequences=[ + AF3SequenceEntry(protein=AF3Protein(id="A", sequence="ACDE")), + ], + ) + input_json.write_text(conf.model_dump_json(), encoding="utf-8") + captured = {} + + def fake_predict_structures( + prediction_conf, + local_out_dir: Path, + recycle: int, + sample: int, + num_containers: int, + ) -> Path: + captured["name"] = prediction_conf.name + captured["model_seeds"] = list(prediction_conf.modelSeeds) + captured["local_out_dir"] = local_out_dir + captured["recycle"] = recycle + captured["sample"] = sample + captured["num_containers"] = num_containers + return local_out_dir / f"{prediction_conf.name}.tar.zst" + + monkeypatch.setattr(alphafold3_app, "predict_structures", fake_predict_structures) + + alphafold3_app.submit_alphafold3_task.info.raw_f( + input_json=str(input_json), + out_dir=str(tmp_path), + run_name="renamed", + search_msa=False, + max_num_gpus=4, + recycle=3, + sample=2, + ) + + assert captured == { + "name": "renamed", + "model_seeds": [11, 12], + "local_out_dir": tmp_path, + "recycle": 3, + "sample": 2, + "num_containers": 2, + } diff --git a/tests/app/test_catalog_workflow_apps.py b/tests/app/test_catalog_workflow_apps.py new file mode 100644 index 0000000..d57a924 --- /dev/null +++ b/tests/app/test_catalog_workflow_apps.py @@ -0,0 +1,115 @@ +"""Tests for Biomodals catalog workflow discovery.""" + +# ruff: noqa: D103 + +from pathlib import Path +from types import SimpleNamespace + +import modal +import pytest + +from biomodals.helper import catalog +from biomodals.helper.catalog import BiomodalsApp, get_catalog, include_dependency_apps + + +def test_default_catalog_does_not_collect_workflows() -> None: + apps = get_catalog("app", use_absolute_paths=True) + + assert "workflow-ppiflow" not in apps + assert apps["ppiflow"].name == "ppiflow_app.py" + assert "ppiflow_workflow" not in apps + assert "workflow-orchestrator" not in apps + + +def test_app_catalog_does_not_collect_workflow_scripts() -> None: + apps = get_catalog("app", use_absolute_paths=True) + + assert "workflow-ppiflow" not in apps + assert "ppiflow_workflow" not in apps + + +def test_workflow_catalog_discovers_natural_workflow_names() -> None: + workflows = get_catalog("workflow", use_absolute_paths=True) + + assert "ppiflow" in workflows + assert "shortmd" in workflows + assert "workflow-ppiflow" not in workflows + assert "orchestrator" not in workflows + assert workflows["ppiflow"].name == "ppiflow_workflow.py" + assert workflows["shortmd"].name == "shortmd_workflow.py" + + +def test_workflow_file_resolves_to_workflow_module_with_natural_name() -> None: + workflows = get_catalog("workflow", use_absolute_paths=True) + app = BiomodalsApp("ppiflow", all_apps=workflows) + + assert app.module == "biomodals.workflow.ppiflow_workflow" + assert app.category == "workflow" + + +def test_include_dependency_apps_resolves_catalog_app_and_includes_modal_app( + monkeypatch, +) -> None: + workflow_app = modal.App("workflow") + dependency_app = modal.App("dependency") + + @dependency_app.function(name="dependency_function", serialized=True) + def dependency_function() -> None: + return None + + class FakeBiomodalsApp: + def __init__(self, app_name_or_path: str, all_apps: dict[str, Path]) -> None: + assert app_name_or_path == "dependency" + assert all_apps == {"dependency": Path("/apps/dependency_app.py")} + self.module = "fake.dependency_app" + + monkeypatch.setattr( + catalog, + "get_catalog", + lambda catalog_type, *, use_absolute_paths=False, cwd=None: { + "dependency": Path("/apps/dependency_app.py") + }, + ) + monkeypatch.setattr(catalog, "BiomodalsApp", FakeBiomodalsApp) + monkeypatch.setattr( + catalog.importlib, + "import_module", + lambda module_name: SimpleNamespace(app=dependency_app), + ) + + assert include_dependency_apps(workflow_app, ("dependency",)) is workflow_app + assert "dependency_function" in workflow_app._local_state.functions + + +def test_include_dependency_apps_rejects_duplicate_modal_tags(monkeypatch) -> None: + workflow_app = modal.App("workflow") + dependency_app = modal.App("dependency") + + @workflow_app.function(name="duplicate_function", serialized=True) + def workflow_duplicate_function() -> None: + return None + + @dependency_app.function(name="duplicate_function", serialized=True) + def dependency_duplicate_function() -> None: + return None + + class FakeBiomodalsApp: + def __init__(self, app_name_or_path: str, all_apps: dict[str, Path]) -> None: + self.module = "fake.dependency_app" + + monkeypatch.setattr( + catalog, + "get_catalog", + lambda catalog_type, *, use_absolute_paths=False, cwd=None: { + "dependency": Path("/apps/dependency_app.py") + }, + ) + monkeypatch.setattr(catalog, "BiomodalsApp", FakeBiomodalsApp) + monkeypatch.setattr( + catalog.importlib, + "import_module", + lambda module_name: SimpleNamespace(app=dependency_app), + ) + + with pytest.raises(ValueError, match="duplicate_function"): + include_dependency_apps(workflow_app, ("dependency",)) diff --git a/tests/app/test_cli_workflow_catalog.py b/tests/app/test_cli_workflow_catalog.py new file mode 100644 index 0000000..f95791a --- /dev/null +++ b/tests/app/test_cli_workflow_catalog.py @@ -0,0 +1,105 @@ +"""Tests for workflow-aware CLI catalog loading.""" + +# ruff: noqa: D103 + +from dataclasses import dataclass +from pathlib import Path + +import pytest +from typer.testing import CliRunner + +from biomodals.cli import _load_entry, app +from biomodals.helper.catalog import AppFunction + +runner = CliRunner() + + +def test_cli_loads_workflow_namespace_names() -> None: + workflow = _load_entry("workflow", "ppiflow") + + assert workflow.module == "biomodals.workflow.ppiflow_workflow" + assert workflow.category == "workflow" + + +def test_workflow_list_command_shows_workflow_names_without_legacy_prefix() -> None: + result = runner.invoke(app, ["workflow", "list", "--short"]) + + assert result.exit_code == 0 + assert "ppiflow" in result.output + assert "workflow-ppiflow" not in result.output + assert "orchestrator" not in result.output + + +def test_app_list_command_is_namespaced() -> None: + result = runner.invoke(app, ["app", "list", "--short"]) + + assert result.exit_code == 0 + assert "rosetta" in result.output + + +def test_top_level_list_remains_app_compatibility_alias() -> None: + result = runner.invoke(app, ["list", "--short"]) + + assert result.exit_code == 0 + assert "rosetta" in result.output + + +def test_app_deploy_command_is_namespaced() -> None: + result = runner.invoke(app, ["app", "deploy", "--help"]) + + assert result.exit_code == 0 + assert "Name or path of the app to deploy" in result.output + + +def test_top_level_deploy_remains_app_compatibility_alias() -> None: + result = runner.invoke(app, ["deploy", "--help"]) + + assert result.exit_code == 0 + assert "Name or path of the app to deploy" in result.output + + +def test_workflow_run_rejects_files_outside_workflow_package(tmp_path: Path) -> None: + ad_hoc_workflow = tmp_path / "ad_hoc_workflow.py" + ad_hoc_workflow.write_text('"""Not a packaged Biomodals workflow."""\n') + + result = runner.invoke(app, ["workflow", "run", str(ad_hoc_workflow)]) + + assert result.exit_code == 1 + assert "Workflow paths must be under" in result.output + assert "biomodals.workflow" in result.output + + +@dataclass +class _FakeWorkflow: + name: str = "ambiguous" + module: str = "biomodals.workflow.ambiguous_workflow" + path: Path = Path("src/biomodals/workflow/ambiguous_workflow.py") + _entrypoint: str | None = None + + def __post_init__(self) -> None: + self._local_entrypoint_idx = [0, 1] + self.functions = [ + AppFunction("first", "local_entrypoint", None, []), + AppFunction("second", "local_entrypoint", None, []), + ] + + def __getitem__(self, name: str | int) -> AppFunction: + if isinstance(name, str): + for function in self.functions: + if function.name == name: + return function + raise KeyError(name) + return self.functions[name] + + +def test_workflow_run_requires_entrypoint_for_multiple_local_entrypoints( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr("biomodals.cli._load_entry", lambda *_args: _FakeWorkflow()) + + result = runner.invoke(app, ["workflow", "run", "ambiguous"]) + + assert result.exit_code == 1 + assert "contains multiple local entrypoints" in result.output + assert "::first" in result.output + assert "::second" in result.output diff --git a/tests/app/test_dockq_standalone_contract.py b/tests/app/test_dockq_standalone_contract.py new file mode 100644 index 0000000..d7c3a02 --- /dev/null +++ b/tests/app/test_dockq_standalone_contract.py @@ -0,0 +1,21 @@ +"""Tests for standalone DockQ app helper behavior.""" + +# ruff: noqa: D103 + +from pathlib import Path + +import pytest + +from biomodals.app.score import dockq_app +from biomodals.helper import io as helper_io + + +def test_dockq_reuses_shared_local_output_helpers() -> None: + assert dockq_app.build_local_output_path is helper_io.build_local_output_path + assert dockq_app.resolve_local_output_dir is helper_io.resolve_local_output_dir + assert dockq_app.write_local_tarball is helper_io.write_local_tarball + + +def test_build_local_output_path_reports_blank_run_name(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="must be non-empty"): + dockq_app.build_local_output_path(tmp_path, run_name=" ") diff --git a/tests/app/test_flowpacker_workflow_contract.py b/tests/app/test_flowpacker_workflow_contract.py new file mode 100644 index 0000000..d4e9325 --- /dev/null +++ b/tests/app/test_flowpacker_workflow_contract.py @@ -0,0 +1,140 @@ +"""Tests for FlowPacker workflow-compatible result contracts.""" + +# ruff: noqa: D103 + +from pathlib import Path +from types import SimpleNamespace + +import yaml + +from biomodals.app.fold import flowpacker_app +from biomodals.schema import AppRunStatus, ArtifactKind, VolumePath + + +def test_flowpacker_workflow_result_stores_archive_in_volume( + tmp_path, + monkeypatch, +) -> None: + seen_kwargs = {} + + class FakeRunFlowPacker: + def get_raw_f(self): + def fake_run_flowpacker(**kwargs): + seen_kwargs.update(kwargs) + return b"tarball" + + return fake_run_flowpacker + + class FakeOutputVolume: + def __init__(self): + self.commit_count = 0 + + def commit(self): + self.commit_count += 1 + + output_volume = FakeOutputVolume() + output_volume_name = flowpacker_app.CONF.output_volume_name + monkeypatch.setattr(flowpacker_app, "run_flowpacker", FakeRunFlowPacker()) + monkeypatch.setattr( + flowpacker_app, + "CONF", + SimpleNamespace( + output_volume=output_volume, + output_volume_mountpoint=str(tmp_path), + output_volume_name=output_volume_name, + ), + ) + + result = flowpacker_app.run_flowpacker_workflow.get_raw_f()( + input_files=[("input.pdb", b"ATOM\n")], + run_name="../packed", + ) + + assert seen_kwargs["run_name"] == "packed" + assert result.status == AppRunStatus.SUCCEEDED + assert len(result.outputs) == 1 + output = result.outputs[0] + assert output.name == "flowpacker_outputs" + assert output.kind == ArtifactKind.ARCHIVE + assert output.storage == VolumePath( + volume_name=output_volume_name, + path="workflow/packed/packed.tar.zst", + media_type="application/zstd", + ) + assert output.metadata == { + "archive_format": "tar.zst", + "filename": "packed.tar.zst", + } + assert ( + Path(tmp_path) / "workflow" / "packed" / "packed.tar.zst" + ).read_bytes() == b"tarball" + assert output_volume.commit_count == 1 + + +def test_flowpacker_config_uses_volume_checkpoint_paths(tmp_path) -> None: + config_path = tmp_path / "biomodals.yaml" + input_dir = tmp_path / "inputs" + input_dir.mkdir() + + flowpacker_app._write_flowpacker_config( + config_path, + input_dir=input_dir, + model_name="cluster", + use_confidence=True, + n_samples=1, + num_steps=10, + sample_coeff=5.0, + ) + + config = yaml.safe_load(config_path.read_text()) + assert Path(config["ckpt"]) == flowpacker_app._checkpoint_path("cluster") + assert Path(config["ckpt"]).is_absolute() + assert ( + Path(config["ckpt"]).parent == flowpacker_app.CONF.git_clone_dir / "checkpoints" + ) + assert Path(config["conf_ckpt"]) == flowpacker_app._checkpoint_path("confidence") + + +def test_flowpacker_checkpoint_download_copies_git_lfs_files_to_volume( + tmp_path, + monkeypatch, +) -> None: + class FakeModelVolume: + def __init__(self): + self.commit_count = 0 + + def commit(self): + self.commit_count += 1 + + git_clone_dir = tmp_path / "FlowPacker" + checkpoint_dir = git_clone_dir / "checkpoints" + cache_dir = tmp_path / "model-cache" + checkpoint_dir.mkdir(parents=True) + cache_dir.mkdir() + + fake_conf = SimpleNamespace( + git_clone_dir=git_clone_dir, + model_volume_mountpoint=str(cache_dir), + ) + fake_model_volume = FakeModelVolume() + + def fake_run_command(cmd, *, cwd=None, env=None): + if cmd[:3] == ["git", "lfs", "pull"]: + assert cwd == git_clone_dir + assert env == {"GIT_LFS_SKIP_SMUDGE": "0"} + for checkpoint_name in flowpacker_app.APP_INFO.checkpoint_names: + (checkpoint_dir / f"{checkpoint_name}.pth").write_bytes( + f"{checkpoint_name}-weights".encode() + ) + + monkeypatch.setattr(flowpacker_app, "CONF", fake_conf) + monkeypatch.setattr(flowpacker_app, "MODEL_VOLUME", fake_model_volume) + monkeypatch.setattr("biomodals.helper.shell.run_command", fake_run_command) + + flowpacker_app.download_flowpacker_checkpoints.get_raw_f()(force=False) + + for checkpoint_name in flowpacker_app.APP_INFO.checkpoint_names: + assert (cache_dir / f"{checkpoint_name}.pth").read_bytes() == ( + f"{checkpoint_name}-weights".encode() + ) + assert fake_model_volume.commit_count == 1 diff --git a/tests/app/test_gromacs_standalone_contract.py b/tests/app/test_gromacs_standalone_contract.py new file mode 100644 index 0000000..56b81e6 --- /dev/null +++ b/tests/app/test_gromacs_standalone_contract.py @@ -0,0 +1,111 @@ +"""Tests for standalone GROMACS app behavior used by workflows.""" + +# ruff: noqa: D101,D102,D103,D107 + +import shutil +from pathlib import Path +from types import SimpleNamespace + +from biomodals.app.bioinfo import gromacs_app + + +def test_submit_gromacs_task_keeps_single_run_standalone_flow( + tmp_path: Path, + monkeypatch, +) -> None: + pdb_path = tmp_path / "input.pdb" + pdb_path.write_text("ATOM\n", encoding="utf-8") + prepare_kwargs = {} + production_kwargs = {} + spawned_stats = [] + + class FakePrepare: + def remote(self, **kwargs): + prepare_kwargs.update(kwargs) + return f"{gromacs_app.CONF.output_volume_mountpoint}/single" + + class FakeProduction: + def remote(self, **kwargs): + production_kwargs.update(kwargs) + return f"{gromacs_app.CONF.output_volume_mountpoint}/single" + + class FakeStats: + def spawn(self, traj_prefix, **kwargs): + spawned_stats.append((traj_prefix, kwargs)) + return f"stats-{traj_prefix}" + + class FakeFunctionCall: + @staticmethod + def gather(*tasks): + return list(tasks) + + monkeypatch.setattr(gromacs_app, "prepare_tpr_cpu", FakePrepare()) + monkeypatch.setattr(gromacs_app, "production_run_cpu", FakeProduction()) + monkeypatch.setattr(gromacs_app, "collect_traj_stats", FakeStats()) + monkeypatch.setattr(gromacs_app.modal, "FunctionCall", FakeFunctionCall) + + gromacs_app.submit_gromacs_task.info.raw_f( + input_pdb=str(pdb_path), + run_name="single", + simulation_time_ns=3, + cpu_only=True, + num_threads=2, + ) + + assert prepare_kwargs["run_name"] == "single" + assert prepare_kwargs["pdb_content"] == b"ATOM\n" + assert production_kwargs == { + "run_name": "single", + "simulation_time_ns": 3, + "num_threads": 2, + "use_openmp_threads": False, + } + assert spawned_stats == [ + ("nvt_", {"run_name": "single"}), + ("npt_", {"run_name": "single"}), + ( + "production_", + {"run_name": "single", "save_processed_traj": True}, + ), + ] + + +def test_fresh_production_run_uses_mdp_nsteps(tmp_path: Path, monkeypatch) -> None: + work_path = tmp_path / "fresh" + work_path.mkdir() + work_path.joinpath("production_fresh.tpr").write_text("tpr\n", encoding="utf-8") + captured = {} + + class FakeVolume: + def __init__(self) -> None: + self.commit_count = 0 + + def commit(self) -> None: + self.commit_count += 1 + + volume = FakeVolume() + monkeypatch.setattr( + gromacs_app, + "CONF", + SimpleNamespace(output_volume_mountpoint=str(tmp_path), output_volume=volume), + ) + monkeypatch.setattr(shutil, "which", lambda name: "/usr/bin/gmx") + + def fake_run_command(cmd, *, cwd, env): + captured["cmd"] = cmd + captured["cwd"] = cwd + captured["env"] = env + return [] + + monkeypatch.setattr(gromacs_app, "run_command", fake_run_command) + + result = gromacs_app.production_run_cpu.get_raw_f()( + run_name="fresh", + simulation_time_ns=2, + ) + + nsteps_index = captured["cmd"].index("-nsteps") + assert captured["cmd"][nsteps_index + 1] == "-2" + assert captured["cwd"] == str(work_path) + assert result == str(work_path) + assert volume.commit_count == 1 diff --git a/tests/app/test_rosetta_standalone_contract.py b/tests/app/test_rosetta_standalone_contract.py new file mode 100644 index 0000000..e52a3db --- /dev/null +++ b/tests/app/test_rosetta_standalone_contract.py @@ -0,0 +1,87 @@ +"""Standalone contract tests for the Rosetta app.""" + +# ruff: noqa: D103 + +from pathlib import Path +from types import SimpleNamespace + +from biomodals.app.bioinfo import rosetta_app + + +def test_rosetta_no_local_output_reports_volume_path( + tmp_path: Path, + monkeypatch, + capsys, +) -> None: + input_pdb = tmp_path / "demo.pdb" + input_pdb.write_text("ATOM\n", encoding="utf-8") + uploaded = [] + queued = [] + deleted = [] + + class FakeBatch: + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def put_file(self, local_path, remote_path): + uploaded.append((local_path, remote_path)) + + class FakeVolume: + def batch_upload(self): + return FakeBatch() + + class FakeQueue: + def put(self, item): + queued.append(item) + + monkeypatch.setattr( + rosetta_app, + "CONF", + SimpleNamespace( + name="Rosetta", + output_volume=FakeVolume(), + output_volume_mountpoint="/biomodals-outputs", + output_volume_name="Rosetta-outputs", + ), + ) + monkeypatch.setattr( + rosetta_app, + "uuid4", + lambda: SimpleNamespace(hex="abc123"), + ) + monkeypatch.setattr( + rosetta_app.modal, + "Queue", + SimpleNamespace( + from_name=lambda *args, **kwargs: FakeQueue(), + objects=SimpleNamespace(delete=lambda name: deleted.append(name)), + ), + ) + monkeypatch.setattr( + rosetta_app.modal, + "FunctionCall", + SimpleNamespace(gather=lambda *tasks: None), + ) + monkeypatch.setattr( + rosetta_app, + "run_rosetta", + SimpleNamespace(spawn=lambda *args: SimpleNamespace(object_id="call-1")), + ) + + rosetta_app.submit_rosetta_task( + rosetta_binary="relax", + input_pdb=str(input_pdb), + out_dir=None, + ) + + assert uploaded[0] == (input_pdb.resolve(), "/demo-abc123/1/demo.pdb") + assert uploaded[1][1] == "/demo-abc123/tasks.parquet" + assert queued[0]["pdb"] == "demo-abc123/1/demo.pdb" + assert deleted == ["Rosetta-queue-abc123"] + assert ( + "Results saved to 'demo-abc123' from volume 'Rosetta-outputs'" + in capsys.readouterr().out + ) diff --git a/tests/helper/test_volume_run.py b/tests/helper/test_volume_run.py new file mode 100644 index 0000000..afc2451 --- /dev/null +++ b/tests/helper/test_volume_run.py @@ -0,0 +1,51 @@ +"""Tests for reusable volume run helpers.""" + +from __future__ import annotations + +import pytest + +from biomodals.helper.volume_run import volume_path_from_mount_path +from biomodals.schema import VolumePath + + +def test_volume_path_from_mount_path_returns_relative_volume_path() -> None: + """Mount paths are converted to volume-relative storage paths.""" + assert volume_path_from_mount_path( + remote_path="/outputs/run-1/production", + mount_root="/outputs", + volume_name="Gromacs-outputs", + ) == VolumePath(volume_name="Gromacs-outputs", path="run-1/production") + + +def test_volume_path_from_mount_path_preserves_media_type() -> None: + """Optional media type is preserved on the returned storage object.""" + assert volume_path_from_mount_path( + remote_path="/outputs/run-1/archive.tar.zst", + mount_root="/outputs", + volume_name="FlowPacker-outputs", + media_type="application/zstd", + ) == VolumePath( + volume_name="FlowPacker-outputs", + path="run-1/archive.tar.zst", + media_type="application/zstd", + ) + + +def test_volume_path_from_mount_path_rejects_paths_outside_mount_root() -> None: + """Paths outside the mounted volume root are rejected.""" + with pytest.raises(ValueError, match="outside mounted volume root"): + volume_path_from_mount_path( + remote_path="/other/run-1", + mount_root="/outputs", + volume_name="Gromacs-outputs", + ) + + +def test_volume_path_from_mount_path_rejects_mount_root_itself() -> None: + """The mount root itself is not a valid artifact storage path.""" + with pytest.raises(ValueError, match="below mounted volume root"): + volume_path_from_mount_path( + remote_path="/outputs", + mount_root="/outputs", + volume_name="Gromacs-outputs", + ) diff --git a/tests/schema/test_workflow_schemas.py b/tests/schema/test_workflow_schemas.py new file mode 100644 index 0000000..f52383f --- /dev/null +++ b/tests/schema/test_workflow_schemas.py @@ -0,0 +1,160 @@ +"""Tests for shared workflow schema contracts.""" + +# ruff: noqa: D103 + +import ast +from pathlib import Path + +import pytest +from pydantic import ValidationError + +from biomodals.schema import ( + AppConfig, + AppOutput, + AppRunResult, + AppRunStatus, + ArtifactKind, + InlineBytes, + StorageKind, + VolumePath, + WorkflowArtifact, +) + + +def _valid_app_config(**overrides: object) -> AppConfig: + values = { + "name": "demo", + "package_name": "demo-package", + "version": "1.0.0", + } + values.update(overrides) + return AppConfig(**values) + + +def test_app_config_is_exported_from_schema_and_app_compatibility_module() -> None: + from biomodals.app.config import AppConfig as CompatAppConfig + + schema_config = _valid_app_config() + compat_config = CompatAppConfig( + name="demo", + package_name="demo-package", + version="1.0.0", + ) + + assert issubclass(CompatAppConfig, AppConfig) + assert ( + compat_config.model_dump(exclude={"output_volume", "output_volume_name"}) + == schema_config.model_dump() + ) + assert compat_config.model_volume_subdir == "/demo" + assert compat_config.git_clone_dir == Path("/opt/demo") + assert compat_config.cuda_version_numeric == "12.8.0" + assert compat_config.default_env["UV_TORCH_BACKEND"] == "cu128" + assert compat_config.output_volume_name == "demo-outputs" + assert hasattr(compat_config, "mounts") + + +def test_app_config_validates_source_reproducibility_and_runtime_bounds() -> None: + with pytest.raises(ValidationError, match="repo_url"): + AppConfig(name="missing-source", version="1.0.0") + + with pytest.raises(ValidationError, match="repo_commit_hash"): + AppConfig(name="missing-version", package_name="demo-package") + + with pytest.raises(ValidationError, match="CUDA version must start"): + _valid_app_config(cuda_version="12.8") + + with pytest.raises(ValidationError, match="CUDA 12.x"): + _valid_app_config(gpu="B200+", cuda_version="cu128") + + with pytest.raises(ValidationError, match="Timeout must be between"): + _valid_app_config(timeout=0) + + with pytest.raises(ValidationError, match="Timeout must be between"): + _valid_app_config(timeout=999_999) + + +def test_app_config_records_dependency_apps_without_modal_imports() -> None: + assert _valid_app_config().depends_on_apps == () + assert _valid_app_config(depends_on_apps=["gromacs"]).depends_on_apps == ( + "gromacs", + ) + + +def test_schema_modules_do_not_import_modal_app_or_workflow_packages() -> None: + schema_dir = Path(__file__).parents[2] / "src" / "biomodals" / "schema" + banned_import_roots = {"modal"} + banned_import_prefixes = ("biomodals.app", "biomodals.workflow") + + for path in sorted(schema_dir.glob("*.py")): + tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + imported_modules: list[str] = [] + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + imported_modules.extend(alias.name for alias in node.names) + elif isinstance(node, ast.ImportFrom) and node.module is not None: + imported_modules.append(node.module) + + banned = [ + module + for module in imported_modules + if module.split(".", 1)[0] in banned_import_roots + or module.startswith(banned_import_prefixes) + ] + assert banned == [], f"{path.name} imports runtime-only modules: {banned}" + + +def test_inline_bytes_round_trip() -> None: + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="packed", + kind=ArtifactKind.REPORT, + storage=InlineBytes( + data=b"hello\n", + filename="report.txt", + media_type="text/plain", + ), + ) + ], + ) + + dumped = result.model_dump_json() + loaded = AppRunResult.model_validate_json(dumped) + + assert loaded.outputs[0].storage.kind == StorageKind.INLINE_BYTES + assert isinstance(loaded.outputs[0].storage, InlineBytes) + assert loaded.outputs[0].storage.data == b"hello\n" + assert loaded.outputs[0].storage.filename == "report.txt" + assert "aGVsbG8K" not in dumped + assert "archive_format" not in InlineBytes.model_fields + + +def test_inline_bytes_rejects_binary_data_and_archive_metadata() -> None: + with pytest.raises(ValidationError, match="UTF-8"): + InlineBytes(data=b"\xff\x00", filename="binary.bin") + + with pytest.raises(ValidationError, match="archive_format"): + InlineBytes(data=b"text", filename="archive.zip", archive_format="zip") + + +def test_volume_path_rejects_absolute_and_traversal_paths() -> None: + for unsafe_path in ("/absolute/out", "../out", "a/../out", r"a\b"): + with pytest.raises(ValidationError, match="VolumePath.path"): + VolumePath(volume_name="Workflow-outputs", path=unsafe_path) + + +def test_workflow_artifact_is_volume_backed() -> None: + artifact = WorkflowArtifact( + artifact_id="art-packed", + producing_node_id="packed", + kind=ArtifactKind.STRUCTURES, + storage=VolumePath( + volume_name="Workflow-outputs", + path="ppiflow/run-1/artifacts/art-packed", + ), + ) + + assert artifact.storage.path == "ppiflow/run-1/artifacts/art-packed" diff --git a/tests/workflow/test_artifacts.py b/tests/workflow/test_artifacts.py new file mode 100644 index 0000000..b29916f --- /dev/null +++ b/tests/workflow/test_artifacts.py @@ -0,0 +1,443 @@ +"""Tests for local workflow artifact materialization.""" + +# ruff: noqa: D103 + +from pathlib import Path + +import pytest + +from biomodals.schema import ( + AppOutput, + AppRunResult, + AppRunStatus, + ArtifactKind, + InlineBytes, + VolumePath, +) +from biomodals.workflow.core.artifacts import materialize_app_run_result + + +def test_materialize_inline_bytes_writes_raw_and_volume_artifact( + tmp_path: Path, +) -> None: + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="summary", + kind=ArtifactKind.REPORT, + storage=InlineBytes(data=b"ok\n", filename="summary.txt"), + ) + ], + ) + + artifacts = materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "nodes" / "summary" / "attempts" / "1", + artifact_dir=tmp_path / "artifacts", + producing_node_id="summary", + volume_root=tmp_path, + ) + + raw_path = tmp_path / "nodes" / "summary" / "attempts" / "1" / "raw_outputs" + materialized_path = ( + tmp_path + / "nodes" + / "summary" + / "attempts" + / "1" + / "materialized_outputs" + / "summary-summary" + / "summary.txt" + ) + assert raw_path.joinpath("summary.txt").read_bytes() == b"ok\n" + assert materialized_path.read_bytes() == b"ok\n" + assert artifacts[0].storage == VolumePath( + volume_name="Workflow-outputs", + path="nodes/summary/attempts/1/materialized_outputs/summary-summary", + ) + assert artifacts[0].files[0].path == "summary.txt" + assert (tmp_path / "artifacts" / "summary-summary.json").exists() + + +def test_materialized_inline_artifact_path_is_volume_relative( + tmp_path: Path, +) -> None: + run_root = tmp_path / "demo" / "run-1" + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="summary", + kind=ArtifactKind.REPORT, + storage=InlineBytes(data=b"ok\n", filename="summary.txt"), + ) + ], + ) + + artifacts = materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=run_root / "nodes" / "summary" / "attempts" / "attempt-1", + artifact_dir=run_root / "artifacts", + producing_node_id="summary", + volume_root=tmp_path, + ) + + assert artifacts[0].storage == VolumePath( + volume_name="Workflow-outputs", + path=( + "demo/run-1/nodes/summary/attempts/attempt-1/" + "materialized_outputs/summary-summary" + ), + ) + + +def test_materialize_inline_bytes_preserves_output_metadata( + tmp_path: Path, +) -> None: + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="summary", + kind=ArtifactKind.REPORT, + storage=InlineBytes(data=b"ok\n", filename="summary.txt"), + metadata={"stage": "stage1"}, + ) + ], + ) + + artifacts = materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "attempt", + artifact_dir=tmp_path / "artifacts", + producing_node_id="summary", + volume_root=tmp_path, + ) + + assert artifacts[0].metadata == {"stage": "stage1"} + + +def test_materialize_app_run_result_persists_log_outputs_under_attempt_logs( + tmp_path: Path, +) -> None: + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + logs=[ + AppOutput( + name="stderr", + kind=ArtifactKind.LOGS, + storage=InlineBytes(data=b"warning\n", filename="stderr.log"), + metadata={"stream": "stderr"}, + ) + ], + ) + + artifacts = materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "attempt", + artifact_dir=tmp_path / "artifacts", + producing_node_id="node", + volume_root=tmp_path, + ) + + log_path = tmp_path / "attempt" / "logs" / "node-logs-stderr" / "stderr.log" + assert log_path.read_bytes() == b"warning\n" + assert artifacts[0].kind == ArtifactKind.LOGS + assert artifacts[0].source_app_output_name == "stderr" + assert artifacts[0].metadata == {"stream": "stderr"} + assert artifacts[0].storage == VolumePath( + volume_name="Workflow-outputs", + path="attempt/logs/node-logs-stderr", + ) + assert (tmp_path / "artifacts" / "node-logs-stderr.json").exists() + + +def test_materialize_volume_path_references_existing_remote_output( + tmp_path: Path, +) -> None: + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="scores", + kind=ArtifactKind.SCORES, + storage=VolumePath( + volume_name="AF3Score-outputs", + path="run-1/af3score_metrics.csv", + ), + ) + ], + ) + + artifacts = materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "attempt", + artifact_dir=tmp_path / "artifacts", + producing_node_id="score", + ) + + assert artifacts[0].storage == VolumePath( + volume_name="AF3Score-outputs", + path="run-1/af3score_metrics.csv", + ) + assert (tmp_path / "artifacts" / "score-scores.json").exists() + + +def test_materialize_volume_path_can_copy_from_mounted_volume( + tmp_path: Path, +) -> None: + source_root = tmp_path / "source-volume" + source_dir = source_root / "runs" / "run-1" + source_dir.mkdir(parents=True) + source_dir.joinpath("scores.csv").write_text("score\n1\n", encoding="utf-8") + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="scores", + kind=ArtifactKind.SCORES, + storage=VolumePath( + volume_name="AF3Score-outputs", + path="runs/run-1", + ), + ) + ], + ) + + artifacts = materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "workflow" / "attempt", + artifact_dir=tmp_path / "workflow" / "artifacts", + producing_node_id="score", + volume_root=tmp_path / "workflow", + volume_path_mode="copy", + volume_roots={"AF3Score-outputs": source_root}, + ) + + copied_file = ( + tmp_path + / "workflow" + / "attempt" + / "materialized_outputs" + / "score-scores" + / "scores.csv" + ) + assert copied_file.read_text(encoding="utf-8") == "score\n1\n" + assert artifacts[0].storage == VolumePath( + volume_name="Workflow-outputs", + path="attempt/materialized_outputs/score-scores", + ) + assert artifacts[0].files[0].path == "scores.csv" + + +def test_materialize_volume_path_copy_preserves_empty_directories( + tmp_path: Path, +) -> None: + source_root = tmp_path / "source-volume" + source_dir = source_root / "runs" / "run-1" + source_dir.mkdir(parents=True) + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="scores", + kind=ArtifactKind.SCORES, + storage=VolumePath( + volume_name="AF3Score-outputs", + path="runs/run-1", + ), + ) + ], + ) + + artifacts = materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "workflow" / "attempt", + artifact_dir=tmp_path / "workflow" / "artifacts", + producing_node_id="score", + volume_root=tmp_path / "workflow", + volume_path_mode="copy", + volume_roots={"AF3Score-outputs": source_root}, + ) + + materialized_dir = ( + tmp_path / "workflow" / "attempt" / "materialized_outputs" / "score-scores" + ) + assert materialized_dir.is_dir() + assert artifacts[0].storage == VolumePath( + volume_name="Workflow-outputs", + path="attempt/materialized_outputs/score-scores", + ) + + +def test_materialize_volume_path_copy_rejects_traversal( + tmp_path: Path, +) -> None: + source_root = tmp_path / "source-volume" + source_root.mkdir() + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="scores", + kind=ArtifactKind.SCORES, + storage=VolumePath.model_construct( + kind="volume_path", + volume_name="AF3Score-outputs", + path="../secret.csv", + ), + ) + ], + ) + + with pytest.raises(ValueError, match="relative"): + materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "workflow" / "attempt", + artifact_dir=tmp_path / "workflow" / "artifacts", + producing_node_id="score", + volume_root=tmp_path / "workflow", + volume_path_mode="copy", + volume_roots={"AF3Score-outputs": source_root}, + ) + + +def test_materialize_volume_path_copy_rejects_symlinked_children( + tmp_path: Path, +) -> None: + source_root = tmp_path / "source-volume" + source_dir = source_root / "runs" / "run-1" + source_dir.mkdir(parents=True) + source_dir.joinpath("scores.csv").write_text("score\n1\n", encoding="utf-8") + secret = tmp_path / "secret.csv" + secret.write_text("secret\n", encoding="utf-8") + source_dir.joinpath("secret-link.csv").symlink_to(secret) + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="scores", + kind=ArtifactKind.SCORES, + storage=VolumePath( + volume_name="AF3Score-outputs", + path="runs/run-1", + ), + ) + ], + ) + + with pytest.raises(ValueError, match="symlink"): + materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "workflow" / "attempt", + artifact_dir=tmp_path / "workflow" / "artifacts", + producing_node_id="score", + volume_root=tmp_path / "workflow", + volume_path_mode="copy", + volume_roots={"AF3Score-outputs": source_root}, + ) + + +def test_materialize_volume_path_copy_rejects_symlink_path_component( + tmp_path: Path, +) -> None: + source_root = tmp_path / "source-volume" + real_dir = source_root / "real-run" + real_dir.mkdir(parents=True) + real_dir.joinpath("scores.csv").write_text("score\n1\n", encoding="utf-8") + source_root.joinpath("linked-run").symlink_to(real_dir, target_is_directory=True) + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="scores", + kind=ArtifactKind.SCORES, + storage=VolumePath( + volume_name="AF3Score-outputs", + path="linked-run", + ), + ) + ], + ) + + with pytest.raises(ValueError, match="symlinks"): + materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "workflow" / "attempt", + artifact_dir=tmp_path / "workflow" / "artifacts", + producing_node_id="score", + volume_root=tmp_path / "workflow", + volume_path_mode="copy", + volume_roots={"AF3Score-outputs": source_root}, + ) + + +def test_materialize_inline_bytes_rejects_non_utf8_bytes( + tmp_path: Path, +) -> None: + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="archive", + kind=ArtifactKind.REPORT, + storage=InlineBytes.model_construct( + data=b"\xff\x00", + filename="archive.tar.zst", + ), + ) + ], + ) + + with pytest.raises(ValueError, match="UTF-8 text"): + materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "attempt", + artifact_dir=tmp_path / "artifacts", + producing_node_id="pack", + ) + + +def test_archive_outputs_use_volume_path_metadata(tmp_path: Path) -> None: + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="models", + kind=ArtifactKind.ARCHIVE, + storage=VolumePath( + volume_name="FlowPacker-outputs", + path="workflow/packed/packed.tar.zst", + media_type="application/zstd", + ), + metadata={"archive_format": "tar.zst"}, + ) + ], + ) + + artifacts = materialize_app_run_result( + result=result, + workflow_volume_name="Workflow-outputs", + attempt_dir=tmp_path / "attempt", + artifact_dir=tmp_path / "artifacts", + producing_node_id="pack", + ) + + assert artifacts[0].storage == VolumePath( + volume_name="FlowPacker-outputs", + path="workflow/packed/packed.tar.zst", + media_type="application/zstd", + ) + assert artifacts[0].metadata == {"archive_format": "tar.zst"} diff --git a/tests/workflow/test_builder.py b/tests/workflow/test_builder.py new file mode 100644 index 0000000..74bee5e --- /dev/null +++ b/tests/workflow/test_builder.py @@ -0,0 +1,79 @@ +"""Tests for the Python workflow builder.""" + +# ruff: noqa: D101,D102,D103 + +from dataclasses import fields + +import pytest + +import biomodals.workflow as workflow_api +from biomodals.schema import ArtifactKind +from biomodals.workflow import Workflow +from biomodals.workflow.core.builder import NodeHandle +from biomodals.workflow.core.nodes import WorkflowNativeNode + + +class DummyNode(WorkflowNativeNode): + def run(self, context): # pragma: no cover - builder tests do not execute nodes + raise NotImplementedError + + +def test_selector_input_creates_data_dependency() -> None: + workflow = Workflow("demo") + upstream = workflow.add_node(DummyNode(), id="design") + downstream = workflow.add_node( + DummyNode(), + id="score", + inputs={ + "structures": upstream.outputs( + kind=ArtifactKind.STRUCTURES, + pattern="**/*.pdb", + ) + }, + ) + + definition = workflow.validate() + + assert definition.dependencies["score"] == {"design"} + assert definition.nodes["score"].inputs["structures"].producing_node_id == "design" + assert downstream.node_id == "score" + + +def test_node_handle_exposes_only_node_id_and_selector_api() -> None: + assert [field.name for field in fields(NodeHandle)] == ["node_id"] + assert not hasattr(workflow_api, "NodeOutputRef") + + +def test_depends_on_creates_control_edge() -> None: + workflow = Workflow("demo") + ranked = workflow.add_node(DummyNode(), id="ranked") + packaged = workflow.add_node(DummyNode(), id="package", depends_on=[ranked]) + + definition = workflow.validate() + + assert definition.dependencies["package"] == {"ranked"} + assert definition.nodes["package"].control_dependencies == {"ranked"} + assert packaged.node_id == "package" + + +def test_duplicate_node_ids_raise_value_error() -> None: + workflow = Workflow("demo") + workflow.add_node(DummyNode(), id="design") + + with pytest.raises(ValueError, match="Duplicate workflow node id"): + workflow.add_node(DummyNode(), id="design") + + +def test_empty_sanitized_workflow_name_raises_value_error() -> None: + with pytest.raises(ValueError, match="safe filename"): + Workflow("///") + + +def test_cycles_raise_value_error() -> None: + workflow = Workflow("demo") + first = workflow.add_node(DummyNode(), id="first") + second = workflow.add_node(DummyNode(), id="second", depends_on=[first]) + workflow.add_control_edge(second, first) + + with pytest.raises(ValueError, match="cycle"): + workflow.validate() diff --git a/tests/workflow/test_ledger.py b/tests/workflow/test_ledger.py new file mode 100644 index 0000000..a4e0760 --- /dev/null +++ b/tests/workflow/test_ledger.py @@ -0,0 +1,251 @@ +"""Tests for the SQLite-backed workflow ledger.""" + +# ruff: noqa: D103 + +import sqlite3 +from pathlib import Path + +import orjson + +from biomodals.schema import ( + AppOutput, + AppRunResult, + AppRunStatus, + ArtifactFile, + ArtifactKind, + InlineBytes, + NodeStatus, + RunStatus, + VolumePath, + WorkflowArtifact, + WorkflowRun, +) +from biomodals.workflow.core import ledger as ledger_module +from biomodals.workflow.core.ledger import LEDGER_TABLES, WorkflowLedger + + +def _connect(tmp_path: Path) -> sqlite3.Connection: + conn = sqlite3.connect(tmp_path / "ppiflow" / "run-1" / "ledger.sqlite3") + conn.row_factory = sqlite3.Row + return conn + + +def test_create_run_initializes_sqlite_ledger_and_documented_tables( + tmp_path: Path, +) -> None: + ledger = WorkflowLedger(tmp_path) + run = WorkflowRun( + workflow_name="ppiflow", + run_id="run-1", + dag_hash="abc123", + ) + + created = ledger.create_run(run) + loaded = ledger.load_run("ppiflow", "run-1") + + assert created == run + assert loaded == run + assert tmp_path.joinpath("ppiflow", "run-1", "ledger.sqlite3").exists() + + with _connect(tmp_path) as conn: + tables = { + row["name"] + for row in conn.execute( + "SELECT name FROM sqlite_master WHERE type = 'table'" + ) + } + run_row = conn.execute("SELECT * FROM runs WHERE run_id = ?", ("run-1",)) + row = run_row.fetchone() + + assert set(LEDGER_TABLES).issubset(tables) + assert row["workflow_name"] == "ppiflow" + assert row["dag_hash"] == "abc123" + assert row["status"] == RunStatus.PENDING + + +def test_mark_run_status_updates_run_row(tmp_path: Path) -> None: + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="ppiflow", run_id="run-1")) + + updated = ledger.mark_run_status(RunStatus.RUNNING) + loaded = ledger.load_run("ppiflow", "run-1") + + assert updated.status == RunStatus.RUNNING + assert loaded.status == RunStatus.RUNNING + with _connect(tmp_path) as conn: + row = conn.execute("SELECT status FROM runs WHERE run_id = 'run-1'").fetchone() + assert row["status"] == RunStatus.RUNNING + + +def test_run_exists_closes_probe_connection(tmp_path: Path, monkeypatch) -> None: + ledger_path = tmp_path / "ppiflow" / "run-1" / "ledger.sqlite3" + ledger_path.parent.mkdir(parents=True) + ledger_path.touch() + + class FakeCursor: + def fetchone(self): + return (1,) + + class FakeConnection: + def __init__(self) -> None: + self.closed = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, traceback): + return None + + def execute(self, sql, params): + return FakeCursor() + + def close(self) -> None: + self.closed = True + + connection = FakeConnection() + monkeypatch.setattr( + ledger_module.sqlite3, + "connect", + lambda path: connection, + ) + + assert WorkflowLedger(tmp_path).run_exists("ppiflow", "run-1") + assert connection.closed is True + + +def test_node_attempt_app_result_and_artifacts_are_debuggable_with_sql( + tmp_path: Path, +) -> None: + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="ppiflow", run_id="run-1")) + + status = ledger.mark_node_running("design", "attempt-1") + attempt = ledger.record_attempt_started("design", "attempt-1") + ledger.record_app_result( + "design", + "attempt-1", + AppRunResult(status=AppRunStatus.SUCCEEDED), + ) + artifact = WorkflowArtifact( + artifact_id="artifact-1", + producing_node_id="design", + kind=ArtifactKind.STRUCTURES, + storage=VolumePath( + volume_name="Workflow-outputs", + path="ppiflow/run-1/artifacts/artifact-1", + ), + files=[ + ArtifactFile( + path="model.pdb", + role="structure", + media_type="chemical/x-pdb", + size_bytes=12, + ) + ], + ) + ledger.record_artifacts([artifact]) + succeeded = ledger.mark_node_succeeded("design", ["artifact-1"]) + + assert status.status == NodeStatus.RUNNING + assert attempt.attempt_id == "attempt-1" + assert succeeded.status == NodeStatus.SUCCEEDED + assert ledger.node_is_complete("design") + + with _connect(tmp_path) as conn: + node = conn.execute("SELECT * FROM nodes WHERE node_id = 'design'").fetchone() + attempt_row = conn.execute( + "SELECT * FROM attempts WHERE node_id = 'design'" + ).fetchone() + artifact_row = conn.execute( + "SELECT * FROM artifacts WHERE artifact_id = 'artifact-1'" + ).fetchone() + file_row = conn.execute( + "SELECT * FROM artifact_files WHERE artifact_id = 'artifact-1'" + ).fetchone() + output_row = conn.execute( + "SELECT * FROM node_outputs WHERE node_id = 'design'" + ).fetchone() + + assert node["status"] == NodeStatus.SUCCEEDED + assert node["current_attempt_id"] == "attempt-1" + assert ( + attempt_row["app_result_json"] + == AppRunResult(status=AppRunStatus.SUCCEEDED).model_dump_json() + ) + assert artifact_row["storage_path"] == "ppiflow/run-1/artifacts/artifact-1" + assert file_row["path"] == "model.pdb" + assert output_row["artifact_id"] == "artifact-1" + + +def test_node_is_not_complete_when_artifact_row_is_missing(tmp_path: Path) -> None: + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="ppiflow", run_id="run-1")) + ledger.mark_node_running("design", "attempt-1") + ledger.mark_node_succeeded("design", ["missing-artifact"]) + + assert not ledger.node_is_complete("design") + + +def test_record_app_result_stores_pydantic_inline_bytes_json(tmp_path: Path) -> None: + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="ppiflow", run_id="run-1")) + ledger.mark_node_running("node-1", "attempt-1") + ledger.record_attempt_started("node-1", "attempt-1") + result = AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="archive", + kind=ArtifactKind.ARCHIVE, + storage=InlineBytes(data=b"ok", filename="archive.txt"), + ) + ], + ) + + ledger.record_app_result("node-1", "attempt-1", result) + + with _connect(tmp_path) as conn: + row = conn.execute( + """ + SELECT app_result_json + FROM attempts + WHERE node_id = 'node-1' AND attempt_id = 'attempt-1' + """ + ).fetchone() + + data = orjson.loads(row["app_result_json"]) + assert isinstance(data["outputs"][0]["storage"]["data"], str) + + +def test_next_attempt_id_uses_highest_numeric_suffix(tmp_path: Path) -> None: + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="ppiflow", run_id="run-1")) + ledger.record_attempt_started("design", "attempt-2") + ledger.record_attempt_started("design", "attempt-uuid") + + assert ledger.next_attempt_id("design") == "attempt-3" + + +def test_remote_call_rows_are_human_debuggable(tmp_path: Path) -> None: + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="ppiflow", run_id="run-1")) + ledger.mark_node_running("remote", "attempt-1") + ledger.record_attempt_started("remote", "attempt-1") + + ledger.record_remote_call( + call_id="fc-123", + node_id="remote", + attempt_id="attempt-1", + function_name="run_node", + call_kind="node", + ) + ledger.mark_remote_call_status("fc-123", "running") + + with _connect(tmp_path) as conn: + row = conn.execute("SELECT * FROM remote_calls WHERE call_id = 'fc-123'") + remote_call = row.fetchone() + + assert remote_call["node_id"] == "remote" + assert remote_call["attempt_id"] == "attempt-1" + assert remote_call["status"] == "running" + assert remote_call["function_name"] == "run_node" diff --git a/tests/workflow/test_nodes.py b/tests/workflow/test_nodes.py new file mode 100644 index 0000000..d07e8d5 --- /dev/null +++ b/tests/workflow/test_nodes.py @@ -0,0 +1,20 @@ +"""Tests for reusable workflow node helpers.""" + +# ruff: noqa: D101,D102,D103,D107 + +from biomodals.schema import NodeExecutionPolicy, NodePlacement +from biomodals.workflow.core.nodes import AppBackedNode + + +def test_app_backed_node_is_marker_base_for_remote_app_work() -> None: + node = AppBackedNode() + + assert node.execution_policy == NodeExecutionPolicy.RERUN + assert node.placement == NodePlacement.REMOTE + + +def test_app_backed_node_no_longer_owns_modal_lookup_api() -> None: + assert not hasattr(AppBackedNode, "app_name") + assert not hasattr(AppBackedNode, "function_name") + assert not hasattr(AppBackedNode, "load_app_function") + assert not hasattr(AppBackedNode, "invoke_app_function") diff --git a/tests/workflow/test_orchestrator.py b/tests/workflow/test_orchestrator.py new file mode 100644 index 0000000..8ebca42 --- /dev/null +++ b/tests/workflow/test_orchestrator.py @@ -0,0 +1,375 @@ +"""Tests for the mocked workflow orchestrator boundary.""" + +# ruff: noqa: D101,D102,D103,D107 + +from pathlib import Path +from typing import Any, cast + +import pytest + +from biomodals.helper.constant import WORKFLOW_ORCHESTRATOR_VOLUME_NAME +from biomodals.schema import AppRunResult, AppRunStatus +from biomodals.workflow import Workflow +from biomodals.workflow.core import orchestrator +from biomodals.workflow.core.nodes import NodeRunContext, WorkflowNativeNode + + +class SucceedNode(WorkflowNativeNode): + def run(self, context): + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + +class RaisingNode(WorkflowNativeNode): + def run(self, context): + raise RuntimeError("remote failed") + + +class FakeVolume: + def __init__(self) -> None: + self.commit_count = 0 + self.reload_count = 0 + + def commit(self) -> None: + self.commit_count += 1 + + def reload(self) -> None: + self.reload_count += 1 + + +def _raw_orchestrator() -> tuple[Any, Any]: + raw_cls = cast(Any, orchestrator.WorkflowOrchestrator)._get_user_cls() + return raw_cls, raw_cls() + + +def test_orchestrator_run_constructs_runtime(monkeypatch) -> None: + calls: dict[str, object] = {} + volume = FakeVolume() + monkeypatch.setattr(orchestrator, "OUT_VOLUME", volume) + + class FakeRuntime: + def __init__( + self, + *, + workflow: Workflow, + volume_root: Path, + workflow_volume_name: str | None = None, + workflow_volume=None, + remote_node_runner=None, + remote_node_function_name=None, + function_call_resolver=None, + max_ready_workers: int = 32, + ): + calls["workflow"] = workflow + calls["volume_root"] = volume_root + calls["workflow_volume_name"] = workflow_volume_name + calls["workflow_volume"] = workflow_volume + calls["remote_node_runner"] = remote_node_runner + calls["remote_node_function_name"] = remote_node_function_name + calls["function_call_resolver"] = function_call_resolver + calls["max_ready_workers"] = max_ready_workers + + def run(self, *, run_id: str, force: bool = False) -> AppRunResult: + calls["run_id"] = run_id + calls["force"] = force + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + def close(self) -> None: + calls["closed"] = True + + monkeypatch.setattr(orchestrator, "WorkflowRuntime", FakeRuntime) + + raw_cls, instance = _raw_orchestrator() + workflow = Workflow("demo") + result = raw_cls.run._get_raw_f()( + instance, + workflow=workflow, + run_id="run-1", + force=True, + max_ready_workers=7, + ) + + assert result.status == AppRunStatus.SUCCEEDED + assert calls == { + "workflow": workflow, + "volume_root": Path(orchestrator.CONF.output_volume_mountpoint), + "workflow_volume_name": WORKFLOW_ORCHESTRATOR_VOLUME_NAME, + "workflow_volume": volume, + "remote_node_runner": calls["remote_node_runner"], + "remote_node_function_name": orchestrator.REMOTE_NODE_FUNCTION_NAME, + "function_call_resolver": calls["function_call_resolver"], + "max_ready_workers": 7, + "run_id": "run-1", + "force": True, + "closed": True, + } + assert calls["remote_node_runner"] is not None + assert callable(calls["function_call_resolver"]) + assert volume.reload_count == 1 + assert volume.commit_count == 1 + + +def test_orchestrator_run_passes_remote_node_runner_and_resolver( + monkeypatch, + tmp_path: Path, +) -> None: + calls: dict[str, object] = {} + volume = FakeVolume() + monkeypatch.setattr(orchestrator, "OUT_VOLUME", volume) + + class FakeRuntime: + def __init__( + self, + *, + workflow: Workflow, + volume_root: Path, + workflow_volume_name: str | None = None, + workflow_volume=None, + remote_node_runner=None, + remote_node_function_name=None, + function_call_resolver=None, + max_ready_workers: int = 32, + ) -> None: + self.remote_node_runner = remote_node_runner + calls["remote_node_runner"] = remote_node_runner + calls["remote_node_function_name"] = remote_node_function_name + calls["function_call_resolver"] = function_call_resolver + calls["max_ready_workers"] = max_ready_workers + + def run(self, *, run_id: str, force: bool = False) -> AppRunResult: + assert self.remote_node_runner is not None + context = NodeRunContext( + run_id=run_id, + node_id="remote", + attempt_id="attempt-1", + cache_dir=tmp_path / "cache", + inputs={}, + ) + calls["remote_call"] = self.remote_node_runner(SucceedNode(), context) + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + def close(self) -> None: + calls["closed"] = True + + monkeypatch.setattr(orchestrator, "WorkflowRuntime", FakeRuntime) + + class FakeFunctionCall: + object_id = "fc-remote" + + def get(self): + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + raw_cls, instance = _raw_orchestrator() + instance.run_node = cast( + Any, + type( + "FakeRemoteMethod", + (), + { + "__name__": "run_node", + "spawn": lambda self, node, context: ( + calls.update({ + "remote_method": self, + "remote_node": node, + "remote_context": context, + }) + or FakeFunctionCall() + ), + }, + )(), + ) + result = raw_cls.run._get_raw_f()( + instance, + workflow=Workflow("demo"), + run_id="run-1", + max_ready_workers=9, + ) + + assert result.status == AppRunStatus.SUCCEEDED + assert calls["remote_node_runner"] is not None + assert calls["remote_node_function_name"] == orchestrator.REMOTE_NODE_FUNCTION_NAME + assert callable(calls["function_call_resolver"]) + assert calls["max_ready_workers"] == 9 + assert getattr(calls["remote_method"], "__name__", None) == "run_node" + assert isinstance(calls["remote_node"], SucceedNode) + remote_context = cast(NodeRunContext, calls["remote_context"]) + remote_call = cast(Any, calls["remote_call"]) + assert remote_context.run_id == "run-1" + assert remote_call.object_id == "fc-remote" + assert calls["closed"] is True + assert volume.reload_count == 1 + assert volume.commit_count == 1 + + +def test_orchestrator_enter_closes_stale_runtime_before_reload(monkeypatch) -> None: + volume = FakeVolume() + monkeypatch.setattr(orchestrator, "OUT_VOLUME", volume) + + class FakeRuntime: + def __init__(self) -> None: + self.close_count = 0 + + def close(self) -> None: + self.close_count += 1 + + raw_cls, instance = _raw_orchestrator() + runtime = FakeRuntime() + instance._runtime = runtime + + raw_cls.enter._get_raw_f()(instance) + + assert runtime.close_count == 1 + assert instance._runtime is None + assert volume.reload_count == 1 + + +def test_orchestrator_exit_cancels_active_remote_calls_once(monkeypatch) -> None: + volume = FakeVolume() + monkeypatch.setattr(orchestrator, "OUT_VOLUME", volume) + + class FakeRuntime: + def __init__(self) -> None: + self.cancel_count = 0 + self.close_count = 0 + + def cancel_active_remote_calls(self, *, terminate_containers: bool) -> None: + assert terminate_containers is True + self.cancel_count += 1 + + def close(self) -> None: + self.close_count += 1 + + raw_cls, instance = _raw_orchestrator() + runtime = FakeRuntime() + instance._runtime = runtime + + raw_cls.exit._get_raw_f()(instance) + raw_cls.exit._get_raw_f()(instance) + + assert runtime.cancel_count == 1 + assert runtime.close_count == 1 + assert instance._runtime is None + assert volume.commit_count == 2 + + +def test_orchestrator_exit_still_closes_and_commits_when_cancel_fails( + monkeypatch, +) -> None: + volume = FakeVolume() + monkeypatch.setattr(orchestrator, "OUT_VOLUME", volume) + + class FakeRuntime: + def __init__(self) -> None: + self.close_count = 0 + + def cancel_active_remote_calls(self, *, terminate_containers: bool) -> None: + raise RuntimeError("cancel failed") + + def close(self) -> None: + self.close_count += 1 + + raw_cls, instance = _raw_orchestrator() + runtime = FakeRuntime() + instance._runtime = runtime + + raw_cls.exit._get_raw_f()(instance) + + assert runtime.close_count == 1 + assert instance._runtime is None + assert volume.commit_count == 1 + + +def test_orchestrator_remote_node_runner_rejects_non_function_call( + tmp_path: Path, + monkeypatch, +) -> None: + calls: dict[str, object] = {} + monkeypatch.setattr(orchestrator, "OUT_VOLUME", FakeVolume()) + + class FakeRuntime: + def __init__(self, *, remote_node_runner=None, **kwargs) -> None: + self.remote_node_runner = remote_node_runner + + def run(self, *, run_id: str, force: bool = False) -> AppRunResult: + assert self.remote_node_runner is not None + context = NodeRunContext( + run_id=run_id, + node_id="remote", + attempt_id="attempt-1", + cache_dir=tmp_path / "cache", + inputs={}, + ) + calls["remote_call"] = self.remote_node_runner(SucceedNode(), context) + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + def close(self) -> None: + calls["closed"] = True + + monkeypatch.setattr(orchestrator, "WorkflowRuntime", FakeRuntime) + + raw_cls, instance = _raw_orchestrator() + instance.run_node = cast( + Any, + type( + "BadRemoteMethod", + (), + {"spawn": lambda self, *args: object()}, + )(), + ) + with pytest.raises(TypeError, match="FunctionCall"): + raw_cls.run._get_raw_f()( + instance, + workflow=Workflow("demo"), + run_id="run-1", + ) + + +def test_orchestrator_remote_node_method_commits_after_exception( + tmp_path: Path, + monkeypatch, +) -> None: + volume = FakeVolume() + monkeypatch.setattr(orchestrator, "OUT_VOLUME", volume) + context = NodeRunContext( + run_id="run-1", + node_id="remote", + attempt_id="attempt-1", + cache_dir=tmp_path / "cache", + inputs={}, + ) + + with pytest.raises(RuntimeError, match="remote failed"): + raw_cls, instance = _raw_orchestrator() + raw_cls.run_node._get_raw_f()( + instance, + RaisingNode(), + context, + ) + + assert volume.reload_count == 1 + assert volume.commit_count == 1 + + +def test_orchestrator_rejects_serialized_workflow_dict() -> None: + raw_cls, instance = _raw_orchestrator() + + with pytest.raises(TypeError, match="Workflow object"): + raw_cls.run._get_raw_f()( + instance, + workflow={"nodes": []}, + run_id="run-1", + ) + + +def test_orchestrator_modal_app_uses_python_313_runtime() -> None: + assert orchestrator.CONF.python_version == "3.13" + assert orchestrator.WorkflowOrchestrator is not None + assert orchestrator.OUT_VOLUME_NAME == WORKFLOW_ORCHESTRATOR_VOLUME_NAME + + +def test_orchestrator_app_exposes_only_class_remote_surface() -> None: + functions = orchestrator.app._local_state.functions + + assert "WorkflowOrchestrator.*" in functions + assert "run_workflow_orchestrator" not in functions + assert "run_remote_workflow_node" not in functions + assert "submit_workflow_orchestrator_task" not in functions diff --git a/tests/workflow/test_ppiflow_workflow.py b/tests/workflow/test_ppiflow_workflow.py new file mode 100644 index 0000000..d46140a --- /dev/null +++ b/tests/workflow/test_ppiflow_workflow.py @@ -0,0 +1,322 @@ +"""Tests for the PPIFlow workflow definition.""" + +# ruff: noqa: D103 + +from pathlib import Path +from types import SimpleNamespace +from typing import cast + +import modal + +from biomodals.app.design import ppiflow_app +from biomodals.schema import ( + AppOutput, + AppRunResult, + AppRunStatus, + ArtifactKind, + VolumePath, +) +from biomodals.workflow.core import NodeRunContext +from biomodals.workflow.ppiflow_workflow import ( + CONF, + PPIFlowModalNamespace, + _active_ppiflow_app_steps, + _stage_ppiflow_app_inputs, + build_ppiflow_workflow, +) + + +class _FakePPIFlowFunction: + def __init__(self) -> None: + self.kwargs = {} + + def remote(self, **kwargs): + self.kwargs = kwargs + return AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="ppiflow_outputs", + kind=ArtifactKind.DIRECTORY, + storage=VolumePath( + volume_name=ppiflow_app.CONF.output_volume_name, + path="demo-run", + ), + ) + ], + ) + + +def _task_yaml(*, enabled_steps: str) -> bytes: + return f""" +task: + gentype: binder +steps: +{enabled_steps} +""".encode() + + +def test_ppiflow_workflow_declares_app_dependency() -> None: + assert CONF.depends_on_apps == ("ppiflow",) + assert CONF.tags == {"depends_on": "ppiflow"} + + +def test_ppiflow_app_step_uses_included_modal_namespace(tmp_path: Path) -> None: + fake_function = _FakePPIFlowFunction() + namespace = PPIFlowModalNamespace( + ppiflow_run=cast(modal.Function, fake_function), + ) + workflow = build_ppiflow_workflow( + task_yaml_bytes=_task_yaml(enabled_steps=" PPIFlowStep: true\n"), + steps_yaml_bytes=b""" +PPIFlowStep: + run_name: demo-run + args: + name: demo + specified_hotspots: A1 + input_pdb: /inputs/demo.pdb + binder_chain: B +""", + modal_namespace=namespace, + ) + + definition = workflow.validate() + spec = definition.nodes["stage1-ppiflow-design"] + result = spec.node.run( + NodeRunContext( + run_id="run-1", + node_id=spec.node_id, + attempt_id="attempt-1", + cache_dir=tmp_path, + inputs={}, + ) + ) + + assert result.status == AppRunStatus.SUCCEEDED + assert fake_function.kwargs["run_name"] == "demo-run" + assert isinstance(fake_function.kwargs["args"], ppiflow_app.PPIFlowArgs) + assert result.outputs[0].storage == VolumePath( + volume_name=ppiflow_app.CONF.output_volume_name, + path="demo-run", + ) + + +def test_ppiflow_unsupported_steps_fail_with_clear_adapter_error( + tmp_path: Path, +) -> None: + fake_function = _FakePPIFlowFunction() + namespace = PPIFlowModalNamespace( + ppiflow_run=cast(modal.Function, fake_function), + ) + workflow = build_ppiflow_workflow( + task_yaml_bytes=_task_yaml(enabled_steps=" FlowpackerStep_stage1: true\n"), + steps_yaml_bytes=b"FlowpackerStep_stage1: {}\n", + modal_namespace=namespace, + ) + + spec = workflow.validate().nodes["stage1-flowpacker"] + try: + spec.node.run( + NodeRunContext( + run_id="run-1", + node_id=spec.node_id, + attempt_id="attempt-1", + cache_dir=tmp_path, + inputs={}, + ) + ) + except NotImplementedError as exc: + assert "workflow-compatible app adapter" in str(exc) + else: + raise AssertionError("unsupported PPIFlow step should fail clearly") + + +def test_ppiflow_entrypoint_stages_local_app_inputs( + tmp_path: Path, + monkeypatch, +) -> None: + input_pdb = tmp_path / "input.pdb" + input_pdb.write_text("ATOM\n", encoding="utf-8") + uploaded = [] + + class FakeBatch: + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def put_file(self, local_path, remote_path): + uploaded.append((Path(local_path), remote_path)) + + class FakeVolume: + def batch_upload(self): + return FakeBatch() + + monkeypatch.setattr( + ppiflow_app, + "CONF", + SimpleNamespace( + output_volume=FakeVolume(), + output_volume_mountpoint="/biomodals-outputs", + output_volume_name="PPIFlow-outputs", + ), + ) + + steps_doc = { + "PPIFlowStep": { + "args": { + "name": "demo", + "specified_hotspots": "A1", + "input_pdb": str(input_pdb), + "binder_chain": "B", + } + } + } + + staged = _stage_ppiflow_app_inputs( + steps_doc=steps_doc, + run_id="run-1", + app_steps=("PPIFlowStep",), + ) + + assert staged["PPIFlowStep"]["args"]["input_pdb"] == ( + "/biomodals-outputs/run-1/PPIFlowStep/input_pdb/input.pdb" + ) + assert uploaded == [(input_pdb, "/run-1/PPIFlowStep/input_pdb/input.pdb")] + + +def test_ppiflow_staging_uses_active_stage_steps( + tmp_path: Path, + monkeypatch, +) -> None: + input_pdb = tmp_path / "input.pdb" + input_pdb.write_text("ATOM\n", encoding="utf-8") + uploaded = [] + + class FakeBatch: + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def put_file(self, local_path, remote_path): + uploaded.append((Path(local_path), remote_path)) + + class FakeVolume: + def batch_upload(self): + return FakeBatch() + + monkeypatch.setattr( + ppiflow_app, + "CONF", + SimpleNamespace( + output_volume=FakeVolume(), + output_volume_mountpoint="/biomodals-outputs", + output_volume_name="PPIFlow-outputs", + ), + ) + task_doc = { + "steps": { + "PPIFlowStep": True, + "PartialStep": True, + } + } + steps_doc = { + "PPIFlowStep": { + "args": { + "name": "demo", + "specified_hotspots": "A1", + "input_pdb": str(input_pdb), + "binder_chain": "B", + } + }, + "PartialStep": { + "args": { + "name": "demo-partial", + "specified_hotspots": "A1", + "input_pdb": str(tmp_path / "stage2-not-local.pdb"), + "fixed_positions": "B1", + "start_t": 0.5, + } + }, + } + + staged = _stage_ppiflow_app_inputs( + steps_doc=steps_doc, + run_id="run-1", + app_steps=_active_ppiflow_app_steps(task_doc, stage=1), + ) + + assert staged["PPIFlowStep"]["args"]["input_pdb"].endswith( + "/PPIFlowStep/input_pdb/input.pdb" + ) + assert staged["PartialStep"]["args"]["input_pdb"].endswith("stage2-not-local.pdb") + assert uploaded == [(input_pdb, "/run-1/PPIFlowStep/input_pdb/input.pdb")] + + +def test_ppiflow_staging_keeps_same_basename_inputs_distinct( + tmp_path: Path, + monkeypatch, +) -> None: + antigen_pdb = tmp_path / "antigen" / "input.pdb" + framework_pdb = tmp_path / "framework" / "input.pdb" + antigen_pdb.parent.mkdir() + framework_pdb.parent.mkdir() + antigen_pdb.write_text("ATOM antigen\n", encoding="utf-8") + framework_pdb.write_text("ATOM framework\n", encoding="utf-8") + uploaded = [] + + class FakeBatch: + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + def put_file(self, local_path, remote_path): + uploaded.append((Path(local_path), remote_path)) + + class FakeVolume: + def batch_upload(self): + return FakeBatch() + + monkeypatch.setattr( + ppiflow_app, + "CONF", + SimpleNamespace( + output_volume=FakeVolume(), + output_volume_mountpoint="/biomodals-outputs", + output_volume_name="PPIFlow-outputs", + ), + ) + steps_doc = { + "PPIFlowStep": { + "args": { + "name": "demo", + "specified_hotspots": "A1", + "antigen_pdb": str(antigen_pdb), + "antigen_chain": "A", + "framework_pdb": str(framework_pdb), + "heavy_chain": "H", + } + } + } + + staged = _stage_ppiflow_app_inputs( + steps_doc=steps_doc, + run_id="run-1", + app_steps=("PPIFlowStep",), + ) + + assert staged["PPIFlowStep"]["args"]["antigen_pdb"] == ( + "/biomodals-outputs/run-1/PPIFlowStep/antigen_pdb/input.pdb" + ) + assert staged["PPIFlowStep"]["args"]["framework_pdb"] == ( + "/biomodals-outputs/run-1/PPIFlowStep/framework_pdb/input.pdb" + ) + assert uploaded == [ + (antigen_pdb, "/run-1/PPIFlowStep/antigen_pdb/input.pdb"), + (framework_pdb, "/run-1/PPIFlowStep/framework_pdb/input.pdb"), + ] diff --git a/tests/workflow/test_runtime.py b/tests/workflow/test_runtime.py new file mode 100644 index 0000000..52c4dbf --- /dev/null +++ b/tests/workflow/test_runtime.py @@ -0,0 +1,995 @@ +"""Tests for the local workflow runtime scheduler.""" + +# ruff: noqa: D101,D102,D103,D107 + +import sqlite3 +from dataclasses import dataclass, field +from pathlib import Path +from threading import Barrier, BrokenBarrierError, Event, Thread + +import pytest +from pydantic import BaseModel, Field + +from biomodals.schema import ( + AppOutput, + AppRunResult, + AppRunStatus, + ArtifactKind, + InlineBytes, + NodeExecutionPolicy, + NodePlacement, + NodeStatus, + RunStatus, + VolumePath, + WorkflowArtifact, + WorkflowRun, +) +from biomodals.workflow import Workflow +from biomodals.workflow.core.ledger import WorkflowLedger +from biomodals.workflow.core.nodes import WorkflowNativeNode +from biomodals.workflow.core.runtime import WorkflowRuntime + + +class FakeNode(WorkflowNativeNode): + def __init__( + self, + *, + result: AppRunResult | None = None, + calls: list[str] | None = None, + policy: NodeExecutionPolicy = NodeExecutionPolicy.RERUN, + ): + self.result = result or AppRunResult( + status=AppRunStatus.SUCCEEDED, + outputs=[ + AppOutput( + name="output", + kind=ArtifactKind.REPORT, + storage=InlineBytes(data=b"ok", filename="output.txt"), + ) + ], + ) + self.calls = calls + self.execution_policy = policy + self.seen_cache_dir: Path | None = None + self.seen_inputs = None + + def run(self, context): + self.seen_cache_dir = context.cache_dir + self.seen_inputs = context.inputs + if self.calls is not None: + self.calls.append(context.node_id) + return self.result + + +class ExplodingNode(WorkflowNativeNode): + def run(self, context): + raise AssertionError(f"{context.node_id} should not run") + + +class RuntimeErrorNode(WorkflowNativeNode): + def run(self, context): + raise RuntimeError(f"{context.node_id} exploded") + + +class CommitObservedNode(WorkflowNativeNode): + def __init__(self, volume: "FakeVolume"): + self.volume = volume + self.commit_count_at_run = -1 + + def run(self, context): + self.commit_count_at_run = self.volume.commit_count + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + +class RemoteOnlyNode(WorkflowNativeNode): + placement = NodePlacement.REMOTE + + def run(self, context): + raise AssertionError("remote placement should use remote_node_runner") + + +class FakeRemoteCall: + def __init__( + self, + *, + object_id: str, + result: AppRunResult, + on_get=None, + effects: list[object] | None = None, + ): + self.object_id = object_id + self.result = result + self.on_get = on_get + self.effects = effects or [] + self.get_timeouts: list[float | int | None] = [] + self.cancel_calls: list[bool] = [] + + def get(self, timeout=None): + self.get_timeouts.append(timeout) + if self.on_get is not None: + self.on_get(timeout) + if self.effects: + effect = self.effects.pop(0) + if isinstance(effect, BaseException): + raise effect + return effect + return self.result + + def cancel(self, terminate_containers: bool = False) -> None: + self.cancel_calls.append(terminate_containers) + + +class FakeVolume: + def __init__(self) -> None: + self.commit_count = 0 + self.reload_count = 0 + self.on_commit = None + self.on_reload = None + + def commit(self) -> None: + if self.on_commit is not None: + self.on_commit() + self.commit_count += 1 + + def reload(self) -> None: + if self.on_reload is not None: + self.on_reload() + self.reload_count += 1 + + +class HashSettings(BaseModel): + visible: str + hidden: str = Field(repr=False) + + +@dataclass +class ConfiguredNode(WorkflowNativeNode): + settings: HashSettings + output_path: Path + + def run(self, context): + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + +@dataclass +class BytesConfiguredNode(WorkflowNativeNode): + payload: bytes + + def run(self, context): + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + +@dataclass +class RuntimeHandleNode(WorkflowNativeNode): + handle: object = field(metadata={"dag_hash": False}) + + def run(self, context): + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + +def test_volume_sync_closes_open_ledger_connection(tmp_path: Path) -> None: + workflow = Workflow("demo") + volume = FakeVolume() + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + workflow_volume=volume, + ) + runtime.ledger.create_run(WorkflowRun(workflow_name="demo", run_id="run-1")) + assert runtime.ledger._connection is not None + + def assert_ledger_closed() -> None: + assert runtime.ledger._connection is None + + volume.on_reload = assert_ledger_closed + volume.on_commit = assert_ledger_closed + + runtime._reload_volume() + runtime.ledger.load_run("demo", "run-1") + assert runtime.ledger._connection is not None + + runtime._commit_volume() + + assert volume.reload_count == 1 + assert volume.commit_count == 1 + assert runtime.ledger._connection is None + + +def test_completed_nodes_are_skipped(tmp_path: Path) -> None: + workflow = Workflow("demo") + workflow.add_node(ExplodingNode(), id="done") + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="demo", run_id="run-1")) + ledger.record_artifacts([ + WorkflowArtifact( + artifact_id="artifact-1", + producing_node_id="done", + kind=ArtifactKind.REPORT, + storage=VolumePath(volume_name="Workflow-outputs", path="done"), + ) + ]) + ledger.mark_node_succeeded("done", ["artifact-1"]) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert runtime.executed_waves == [] + + +def test_force_replaces_existing_run_ledger_and_reruns_completed_nodes( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + calls: list[str] = [] + workflow.add_node(FakeNode(calls=calls), id="one") + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + + first = runtime.run(run_id="run-1") + stale_file = tmp_path / "demo" / "run-1" / "nodes" / "one" / "cache" / "stale" + stale_file.write_text("old", encoding="utf-8") + second = runtime.run(run_id="run-1", force=True) + + assert first.status == AppRunStatus.SUCCEEDED + assert second.status == AppRunStatus.SUCCEEDED + assert calls == ["one", "one"] + assert not stale_file.exists() + + +def test_independent_ready_nodes_run_in_same_scheduler_wave( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + upstream = workflow.add_node(FakeNode(), id="design") + calls: list[str] = [] + workflow.add_node( + FakeNode(calls=calls), + id="score-a", + inputs={"structures": upstream.outputs(kind=ArtifactKind.STRUCTURES)}, + ) + workflow.add_node( + FakeNode(calls=calls), + id="score-b", + inputs={"structures": upstream.outputs(kind=ArtifactKind.STRUCTURES)}, + ) + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="demo", run_id="run-1")) + ledger.record_artifacts([ + WorkflowArtifact( + artifact_id="design-artifact", + producing_node_id="design", + kind=ArtifactKind.STRUCTURES, + storage=VolumePath(volume_name="Workflow-outputs", path="design"), + ) + ]) + ledger.mark_node_succeeded("design", ["design-artifact"]) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + runtime.run(run_id="run-1") + + assert set(calls) == {"score-a", "score-b"} + assert runtime.executed_waves == [["score-a", "score-b"]] + + +def test_independent_ready_nodes_execute_concurrently(tmp_path: Path) -> None: + barrier = Barrier(2, timeout=0.5) + workflow = Workflow("demo") + + class BarrierNode(WorkflowNativeNode): + def run(self, context): + try: + barrier.wait() + except BrokenBarrierError: + return AppRunResult(status=AppRunStatus.FAILED) + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + workflow.add_node(BarrierNode(), id="one") + workflow.add_node(BarrierNode(), id="two") + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert runtime.executed_waves == [["one", "two"]] + + +def test_failed_node_prevents_downstream_nodes_from_running(tmp_path: Path) -> None: + workflow = Workflow("demo") + failed = workflow.add_node( + FakeNode(result=AppRunResult(status=AppRunStatus.FAILED)), + id="fail", + ) + workflow.add_node(ExplodingNode(), id="downstream", depends_on=[failed]) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.FAILED + assert runtime.executed_waves == [["fail"]] + + +def test_partial_node_marks_run_failed_and_blocks_downstream(tmp_path: Path) -> None: + workflow = Workflow("demo") + partial = workflow.add_node( + FakeNode(result=AppRunResult(status=AppRunStatus.PARTIAL)), + id="partial", + ) + workflow.add_node(ExplodingNode(), id="downstream", depends_on=[partial]) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.PARTIAL + assert runtime.ledger.load_run("demo", "run-1").status == RunStatus.FAILED + status = runtime.ledger._load_node_status_or_default("partial") + assert status.status == NodeStatus.FAILED + assert status.error == "Node returned partial status" + assert runtime.executed_waves == [["partial"]] + + +def test_single_node_exception_marks_node_and_run_failed(tmp_path: Path) -> None: + workflow = Workflow("demo") + workflow.add_node(RuntimeErrorNode(), id="fail") + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.FAILED + assert result.warnings == ["fail exploded"] + assert runtime.ledger.load_run("demo", "run-1").status == RunStatus.FAILED + status = runtime.ledger._load_node_status_or_default("fail") + assert status.status == NodeStatus.FAILED + assert status.error is not None + assert "Traceback" in status.error + assert "RuntimeError: fail exploded" in status.error + + +def test_runtime_logs_dag_and_node_state_transitions( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + workflow = Workflow("demo") + first = workflow.add_node(FakeNode(), id="prepare") + workflow.add_node(FakeNode(), id="produce", depends_on=[first]) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + stdout = capsys.readouterr().out + assert "[workflow] Starting workflow 'demo' run 'run-1'" in stdout + assert "[workflow] DAG graph:" in stdout + assert "[workflow] prepare" in stdout + assert "[workflow] produce" in stdout + assert "<- prepare" in stdout + assert "[workflow] Node started: prepare attempt=attempt-1" in stdout + assert "[workflow] Node succeeded: prepare attempt=attempt-1" in stdout + assert "[workflow] Node started: produce attempt=attempt-1" in stdout + assert "[workflow] Node succeeded: produce attempt=attempt-1" in stdout + + +def test_runtime_logs_failed_node_transition( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + workflow = Workflow("demo") + workflow.add_node( + FakeNode(result=AppRunResult(status=AppRunStatus.FAILED)), + id="fail", + ) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.FAILED + stdout = capsys.readouterr().out + assert "[workflow] Node started: fail attempt=attempt-1" in stdout + assert "[workflow] Node failed: fail attempt=attempt-1" in stdout + + +def test_runtime_dag_hash_uses_stable_json_for_dataclass_node_config() -> None: + first_workflow = Workflow("demo") + first_workflow.add_node( + ConfiguredNode( + settings=HashSettings(visible="same", hidden="one"), + output_path=Path("outputs/report.txt"), + ), + id="configured", + ) + second_workflow = Workflow("demo") + second_workflow.add_node( + ConfiguredNode( + settings=HashSettings(visible="same", hidden="two"), + output_path=Path("outputs/report.txt"), + ), + id="configured", + ) + repeated_workflow = Workflow("demo") + repeated_workflow.add_node( + ConfiguredNode( + settings=HashSettings(visible="same", hidden="one"), + output_path=Path("outputs/report.txt"), + ), + id="configured", + ) + + first_hash = WorkflowRuntime._dag_hash(first_workflow.validate()) + second_hash = WorkflowRuntime._dag_hash(second_workflow.validate()) + repeated_hash = WorkflowRuntime._dag_hash(repeated_workflow.validate()) + + assert first_hash != second_hash + assert first_hash == repeated_hash + + +def test_runtime_dag_hash_supports_bytes_in_dataclass_node_config() -> None: + first_workflow = Workflow("demo") + first_workflow.add_node(BytesConfiguredNode(payload=b"ATOM 1\n"), id="configured") + second_workflow = Workflow("demo") + second_workflow.add_node(BytesConfiguredNode(payload=b"ATOM 2\n"), id="configured") + repeated_workflow = Workflow("demo") + repeated_workflow.add_node( + BytesConfiguredNode(payload=b"ATOM 1\n"), id="configured" + ) + + first_hash = WorkflowRuntime._dag_hash(first_workflow.validate()) + second_hash = WorkflowRuntime._dag_hash(second_workflow.validate()) + repeated_hash = WorkflowRuntime._dag_hash(repeated_workflow.validate()) + + assert first_hash != second_hash + assert first_hash == repeated_hash + + +def test_runtime_dag_hash_skips_dataclass_fields_marked_excluded() -> None: + first_workflow = Workflow("demo") + first_workflow.add_node(RuntimeHandleNode(handle=object()), id="node") + second_workflow = Workflow("demo") + second_workflow.add_node(RuntimeHandleNode(handle=object()), id="node") + + assert WorkflowRuntime._dag_hash(first_workflow.validate()) == ( + WorkflowRuntime._dag_hash(second_workflow.validate()) + ) + + +def test_running_node_without_recoverable_call_is_not_duplicated( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + calls: list[str] = [] + workflow.add_node(FakeNode(calls=calls), id="incomplete") + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="demo", run_id="run-1")) + ledger.mark_node_running("incomplete", "attempt-old") + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.PARTIAL + assert calls == [] + assert runtime.ledger.load_run("demo", "run-1").status == RunStatus.RUNNING + + +def test_rerun_policy_discards_failed_attempt_state(tmp_path: Path) -> None: + workflow = Workflow("demo") + calls: list[str] = [] + workflow.add_node(FakeNode(calls=calls), id="incomplete") + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="demo", run_id="run-1")) + ledger.mark_node_running("incomplete", "attempt-old") + ledger.record_attempt_started("incomplete", "attempt-old") + ledger.mark_node_failed("incomplete", "old failure") + old_cache = tmp_path / "demo" / "run-1" / "nodes" / "incomplete" / "cache" / "old" + old_cache.parent.mkdir(parents=True) + old_cache.write_text("old", encoding="utf-8") + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert calls == ["incomplete"] + assert not old_cache.exists() + status = runtime.ledger._load_node_status_or_default("incomplete") + assert status.attempts == ["attempt-1"] + + +def test_resume_policy_receives_durable_cache_path(tmp_path: Path) -> None: + workflow = Workflow("demo") + node = FakeNode(policy=NodeExecutionPolicy.RESUME) + workflow.add_node(node, id="long") + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + runtime.run(run_id="run-1") + + assert node.seen_cache_dir == tmp_path / "demo/run-1/nodes/long/cache" + assert node.seen_cache_dir is not None + assert node.seen_cache_dir.exists() + + +def test_runtime_commits_node_start_before_node_execution(tmp_path: Path) -> None: + workflow = Workflow("demo") + volume = FakeVolume() + node = CommitObservedNode(volume) + workflow.add_node(node, id="long") + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + workflow_volume=volume, + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert node.commit_count_at_run >= 3 + + +def test_remote_placement_uses_injected_remote_runner(tmp_path: Path) -> None: + workflow = Workflow("demo") + workflow.add_node(RemoteOnlyNode(), id="remote") + calls = [] + + def remote_runner(node, context): + calls.append((node, context.node_id)) + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + remote_node_runner=remote_runner, + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert calls == [(workflow.validate().nodes["remote"].node, "remote")] + + +def test_remote_placement_records_function_call_before_waiting( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + workflow.add_node(RemoteOnlyNode(), id="remote") + observed_statuses: list[str] = [] + + def observe_remote_call(timeout): + with sqlite3.connect(tmp_path / "demo" / "run-1" / "ledger.sqlite3") as conn: + row = conn.execute( + "SELECT status FROM remote_calls WHERE call_id = 'fc-new'" + ).fetchone() + observed_statuses.append(row[0]) + + def remote_runner(node, context): + return FakeRemoteCall( + object_id="fc-new", + result=AppRunResult(status=AppRunStatus.SUCCEEDED), + on_get=observe_remote_call, + ) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + remote_node_runner=remote_runner, + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert observed_statuses == ["submitted"] + with sqlite3.connect(tmp_path / "demo" / "run-1" / "ledger.sqlite3") as conn: + final_status = conn.execute( + "SELECT status FROM remote_calls WHERE call_id = 'fc-new'" + ).fetchone()[0] + assert final_status == "succeeded" + + +def test_remote_placement_records_configured_function_name( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + workflow.add_node(RemoteOnlyNode(), id="remote") + + def remote_runner(node, context): + return FakeRemoteCall( + object_id="fc-named", + result=AppRunResult(status=AppRunStatus.SUCCEEDED), + ) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + remote_node_runner=remote_runner, + remote_node_function_name="run_node", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + with sqlite3.connect(tmp_path / "demo" / "run-1" / "ledger.sqlite3") as conn: + function_name = conn.execute( + "SELECT function_name FROM remote_calls WHERE call_id = 'fc-named'" + ).fetchone()[0] + assert function_name == "run_node" + + +def test_remote_call_failure_after_timeout_is_recorded(tmp_path: Path) -> None: + workflow = Workflow("demo") + workflow.add_node(RemoteOnlyNode(), id="remote") + + def remote_runner(node, context): + return FakeRemoteCall( + object_id="fc-fail", + result=AppRunResult(status=AppRunStatus.SUCCEEDED), + effects=[TimeoutError(), RuntimeError("remote exploded")], + ) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + remote_node_runner=remote_runner, + ) + + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.FAILED + with sqlite3.connect(tmp_path / "demo" / "run-1" / "ledger.sqlite3") as conn: + row = conn.execute( + "SELECT status, error FROM remote_calls WHERE call_id = 'fc-fail'" + ).fetchone() + assert row == ("failed", "remote exploded") + + +def test_runtime_cleanup_cancels_in_flight_remote_function_call( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + workflow.add_node(RemoteOnlyNode(), id="remote") + waiting_for_result = Event() + release_result = Event() + + class BlockingRemoteCall(FakeRemoteCall): + def get(self, timeout=None): + self.get_timeouts.append(timeout) + if timeout == 0: + raise TimeoutError() + waiting_for_result.set() + release_result.wait(timeout=2) + raise RuntimeError("cancelled after cleanup") + + def cancel(self, terminate_containers: bool = False) -> None: + super().cancel(terminate_containers=terminate_containers) + release_result.set() + + remote_call = BlockingRemoteCall( + object_id="fc-blocking", + result=AppRunResult(status=AppRunStatus.SUCCEEDED), + ) + + def remote_runner(node, context): + return remote_call + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + remote_node_runner=remote_runner, + ) + results: list[AppRunResult] = [] + thread = Thread( + target=lambda: results.append(runtime.run(run_id="run-1")), + daemon=True, + ) + thread.start() + + assert waiting_for_result.wait(timeout=2) + try: + runtime.cancel_active_remote_calls(terminate_containers=True) + finally: + release_result.set() + thread.join(timeout=2) + + assert not thread.is_alive() + assert remote_call.cancel_calls == [True] + assert results[0].status == AppRunStatus.FAILED + + +def test_remote_success_records_app_result_before_succeeded_status( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + workflow.add_node(RemoteOnlyNode(), id="remote") + + def remote_runner(node, context): + return FakeRemoteCall( + object_id="fc-success", + result=AppRunResult(status=AppRunStatus.SUCCEEDED), + ) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + remote_node_runner=remote_runner, + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + with sqlite3.connect(tmp_path / "demo" / "run-1" / "ledger.sqlite3") as conn: + row = conn.execute( + """ + SELECT app_result_json + FROM attempts + WHERE node_id = 'remote' AND attempt_id = 'attempt-1' + """ + ).fetchone() + assert row[0] == AppRunResult(status=AppRunStatus.SUCCEEDED).model_dump_json() + + +def test_remote_success_reloads_volume_before_materializing_outputs( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + workflow.add_node(RemoteOnlyNode(), id="remote") + volume = FakeVolume() + + def remote_runner(node, context): + return FakeRemoteCall( + object_id="fc-reload", + result=AppRunResult(status=AppRunStatus.SUCCEEDED), + ) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + workflow_volume=volume, + remote_node_runner=remote_runner, + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert volume.reload_count >= 2 + + +def test_remote_recovery_reattaches_existing_call_before_rerun( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + workflow.add_node(RemoteOnlyNode(), id="remote") + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="demo", run_id="run-1")) + ledger.mark_node_running( + "remote", + "attempt-old", + placement=NodePlacement.REMOTE, + ) + ledger.record_attempt_started("remote", "attempt-old") + ledger.record_remote_call( + call_id="fc-old", + node_id="remote", + attempt_id="attempt-old", + function_name="run_node", + call_kind="node", + ) + resolved: list[str] = [] + remote_runner_calls: list[str] = [] + existing_call = FakeRemoteCall( + object_id="fc-old", + result=AppRunResult(status=AppRunStatus.SUCCEEDED), + ) + + def resolve_call(call_id: str): + resolved.append(call_id) + return existing_call + + def remote_runner(node, context): + remote_runner_calls.append(context.node_id) + raise AssertionError("existing remote call should be reattached") + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + remote_node_runner=remote_runner, + function_call_resolver=resolve_call, + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert resolved == ["fc-old"] + assert remote_runner_calls == [] + assert existing_call.get_timeouts == [0] + status = runtime.ledger._load_node_status_or_default("remote") + assert status.status == NodeStatus.SUCCEEDED + + +def test_runtime_passes_selected_upstream_artifacts_to_node_context( + tmp_path: Path, +) -> None: + workflow = Workflow("demo") + upstream = workflow.add_node(ExplodingNode(), id="design") + downstream = FakeNode() + workflow.add_node( + downstream, + id="score", + inputs={"structures": upstream.outputs(kind=ArtifactKind.STRUCTURES)}, + ) + ledger = WorkflowLedger(tmp_path) + ledger.create_run(WorkflowRun(workflow_name="demo", run_id="run-1")) + ledger.record_artifacts([ + WorkflowArtifact( + artifact_id="design-structures", + producing_node_id="design", + kind=ArtifactKind.STRUCTURES, + storage=VolumePath( + volume_name="Workflow-outputs", + path="demo/run-1/nodes/design/outputs", + ), + ) + ]) + ledger.mark_node_succeeded("design", ["design-structures"]) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert downstream.seen_inputs == { + "structures": [ + WorkflowArtifact( + artifact_id="design-structures", + producing_node_id="design", + kind=ArtifactKind.STRUCTURES, + storage=VolumePath( + volume_name="Workflow-outputs", + path="demo/run-1/nodes/design/outputs", + ), + ) + ] + } + status = runtime.ledger._load_node_status_or_default("score") + assert status.input_artifact_ids == ["design-structures"] + + +def test_runtime_records_succeeded_run_status(tmp_path: Path) -> None: + workflow = Workflow("demo") + workflow.add_node(FakeNode(), id="one") + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert runtime.ledger.load_run("demo", "run-1").status == RunStatus.SUCCEEDED + + +def test_runtime_records_dag_hash_and_timestamps(tmp_path: Path) -> None: + workflow = Workflow("demo") + workflow.add_node(FakeNode(), id="one") + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + run = runtime.ledger.load_run("demo", "run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert run.dag_hash + assert run.created_at <= run.updated_at + + +def test_runtime_rejects_resume_when_dag_hash_changed(tmp_path: Path) -> None: + first_workflow = Workflow("demo") + first_workflow.add_node(FakeNode(), id="one") + first_runtime = WorkflowRuntime( + workflow=first_workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + first_runtime.run(run_id="run-1") + + second_workflow = Workflow("demo") + second_workflow.add_node(FakeNode(), id="renamed") + second_runtime = WorkflowRuntime( + workflow=second_workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + + with pytest.raises(ValueError, match="DAG hash"): + second_runtime.run(run_id="run-1") + + +def test_runtime_records_failed_run_status(tmp_path: Path) -> None: + workflow = Workflow("demo") + workflow.add_node( + FakeNode(result=AppRunResult(status=AppRunStatus.FAILED)), + id="fail", + ) + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.FAILED + assert runtime.ledger.load_run("demo", "run-1").status == RunStatus.FAILED + + +def test_runtime_reloads_and_commits_workflow_volume(tmp_path: Path) -> None: + workflow = Workflow("demo") + workflow.add_node(FakeNode(), id="one") + volume = FakeVolume() + + runtime = WorkflowRuntime( + workflow=workflow, + volume_root=tmp_path, + workflow_volume_name="Workflow-outputs", + workflow_volume=volume, + ) + result = runtime.run(run_id="run-1") + + assert result.status == AppRunStatus.SUCCEEDED + assert volume.reload_count >= 1 + assert volume.commit_count >= 1 diff --git a/tests/workflow/test_shortmd_workflow.py b/tests/workflow/test_shortmd_workflow.py new file mode 100644 index 0000000..4488ade --- /dev/null +++ b/tests/workflow/test_shortmd_workflow.py @@ -0,0 +1,709 @@ +"""Tests for the ShortMD workflow definition.""" + +# ruff: noqa: D103 + +from pathlib import Path +from typing import cast + +import modal +import pytest + +from biomodals.app.bioinfo import gromacs_app +from biomodals.schema import ( + AppRunResult, + AppRunStatus, + ArtifactKind, + InlineBytes, + NodePlacement, + VolumePath, + WorkflowArtifact, +) +from biomodals.workflow import shortmd_workflow +from biomodals.workflow.core.nodes import NodeRunContext +from biomodals.workflow.shortmd_workflow import ( + ShortMDCloneNode, + ShortMDGromacsSettings, + ShortMDModalNamespace, + ShortMDPrepNode, + ShortMDReplicateNode, + ShortMDSummaryNode, + build_shortmd_workflow, + clone_prepared_shortmd_run, + discover_pdb_inputs, +) + + +class UnexpectedRemoteFunction: + """Sentinel remote object for paths a test must not call.""" + + def remote(self, *args: object, **kwargs: object) -> object: + """Fail if the sentinel is invoked.""" + pytest.fail(f"Unexpected remote call: args={args}, kwargs={kwargs}") + + +UNEXPECTED_REMOTE = cast(modal.Function, UnexpectedRemoteFunction()) + + +def test_shortmd_uses_gromacs_app_volume_metadata() -> None: + assert shortmd_workflow.CONF.depends_on_apps == ("gromacs",) + assert shortmd_workflow.CONF.tags == {"depends_on": "gromacs"} + assert ( + shortmd_workflow.GROMACS_OUTPUT_MOUNTPOINT + == gromacs_app.CONF.output_volume_mountpoint + ) + assert shortmd_workflow.GROMACS_OUTPUT_VOLUME is gromacs_app.CONF.output_volume + assert ( + shortmd_workflow.GROMACS_OUTPUT_VOLUME_NAME + == gromacs_app.CONF.output_volume_name + ) + + +def test_discover_pdb_inputs_globs_pdb_files(tmp_path: Path) -> None: + tmp_path.joinpath("b.pdb").write_text("B\n", encoding="utf-8") + tmp_path.joinpath("a.pdb").write_text("A\n", encoding="utf-8") + tmp_path.joinpath("ignore.txt").write_text("x\n", encoding="utf-8") + + discovered = discover_pdb_inputs(tmp_path) + + assert set(discovered) == {("a.pdb", b"A\n"), ("b.pdb", b"B\n")} + + +def test_discover_pdb_inputs_rejects_empty_directory(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="No PDB files"): + discover_pdb_inputs(tmp_path) + + +def test_build_shortmd_workflow_models_prep_replicate_summary_dependencies() -> None: + workflow = build_shortmd_workflow( + input_pdbs=[("alpha.pdb", b"ATOM\n"), ("beta.pdb", b"ATOM\n")], + replicates=2, + simulation_time_ns=2, + cpu_only=True, + max_parallel=8, + ) + + definition = workflow.validate() + + assert workflow.name == "shortmd" + assert set(definition.nodes) == { + "prep-alpha", + "clone-alpha-r001", + "clone-alpha-r002", + "replicate-alpha-r001", + "replicate-alpha-r002", + "prep-beta", + "clone-beta-r001", + "clone-beta-r002", + "replicate-beta-r001", + "replicate-beta-r002", + "summary", + } + assert definition.dependencies["clone-alpha-r001"] == {"prep-alpha"} + assert definition.dependencies["clone-alpha-r002"] == {"prep-alpha"} + assert definition.dependencies["replicate-alpha-r001"] == {"clone-alpha-r001"} + assert definition.dependencies["replicate-alpha-r002"] == {"clone-alpha-r002"} + assert definition.dependencies["summary"] == { + "replicate-alpha-r001", + "replicate-alpha-r002", + "replicate-beta-r001", + "replicate-beta-r002", + } + + prep_node = definition.nodes["prep-alpha"].node + clone_node = definition.nodes["clone-alpha-r001"].node + replicate_node = definition.nodes["replicate-alpha-r001"].node + summary_node = definition.nodes["summary"].node + + assert isinstance(prep_node, ShortMDPrepNode) + assert prep_node.placement == NodePlacement.REMOTE + assert prep_node.run_name == "alpha" + assert prep_node.pdb_content == b"ATOM\n" + assert { + "app_name", + "prep_cpu_function", + "prep_gpu_function", + "prep_cpu_function_name", + "prep_gpu_function_name", + }.isdisjoint(prep_node.__dict__) + assert isinstance(prep_node.modal_namespace, ShortMDModalNamespace) + + assert isinstance(clone_node, ShortMDCloneNode) + assert clone_node.placement == NodePlacement.REMOTE + assert clone_node.source_run_name == "alpha" + assert clone_node.replicate_run_name == "alpha-r001" + assert "clone_function" not in clone_node.__dict__ + assert clone_node.modal_namespace is prep_node.modal_namespace + + assert isinstance(replicate_node, ShortMDReplicateNode) + assert replicate_node.placement == NodePlacement.REMOTE + assert replicate_node.source_run_name == "alpha" + assert replicate_node.replicate_run_name == "alpha-r001" + assert replicate_node.gromacs.simulation_time_ns == 2 + assert replicate_node.gromacs.cpu_only is True + assert { + "app_name", + "production_cpu_function", + "production_gpu_function", + "stats_function", + "production_cpu_function_name", + "production_gpu_function_name", + "stats_function_name", + }.isdisjoint(replicate_node.__dict__) + assert replicate_node.modal_namespace is prep_node.modal_namespace + + assert isinstance(summary_node, ShortMDSummaryNode) + assert summary_node.max_parallel == 8 + + +def test_build_shortmd_workflow_rejects_duplicate_sanitized_stems() -> None: + with pytest.raises(ValueError, match="Duplicate"): + build_shortmd_workflow( + input_pdbs=[("../a.pdb", b"A\n"), ("a.pdb", b"B\n")], + replicates=1, + ) + + +def test_shortmd_prep_node_runs_gromacs_prepare_and_returns_artifact( + tmp_path: Path, + monkeypatch, +) -> None: + prepare_kwargs = {} + clear_kwargs = {} + events = [] + + class FakePrepareFunction: + def remote(self, **kwargs: object) -> object: + events.append("prepare") + prepare_kwargs.update(kwargs) + return f"{gromacs_app.CONF.output_volume_mountpoint}/prepared/source" + + class FakeClearFunction: + def remote(self, **kwargs: object) -> object: + events.append("clear") + clear_kwargs.update(kwargs) + return None + + modal_namespace = ShortMDModalNamespace( + clear=cast(modal.Function, FakeClearFunction()), + clone=UNEXPECTED_REMOTE, + prepare_cpu=cast(modal.Function, FakePrepareFunction()), + prepare_gpu=UNEXPECTED_REMOTE, + production_cpu=UNEXPECTED_REMOTE, + production_gpu=UNEXPECTED_REMOTE, + collect_stats=UNEXPECTED_REMOTE, + ) + monkeypatch.setattr(shortmd_workflow.modal.Function, "from_name", pytest.fail) + monkeypatch.setattr( + shortmd_workflow, + "clear_shortmd_gromacs_run", + pytest.fail, + ) + monkeypatch.setattr( + shortmd_workflow.gromacs_app, + "prepare_tpr_cpu", + pytest.fail, + ) + + node = ShortMDPrepNode( + pdb_content=b"ATOM\n", + run_name="../source", + modal_namespace=modal_namespace, + overwrite_existing=True, + gromacs=ShortMDGromacsSettings( + simulation_time_ns=2, + run_pdbfixer=True, + cpu_only=True, + num_threads=8, + use_openmp_threads=True, + ld_seed=11, + gen_seed=12, + genion_seed=13, + ), + ) + result = node.run( + NodeRunContext( + run_id="run-1", + node_id="prep-source", + attempt_id="attempt-1", + cache_dir=tmp_path / "cache", + inputs={}, + ) + ) + + assert prepare_kwargs == { + "pdb_content": b"ATOM\n", + "run_name": "source", + "simulation_time_ns": 2, + "run_pdbfixer": True, + "num_threads": 8, + "use_openmp_threads": True, + "ld_seed": 11, + "gen_seed": 12, + "genion_seed": 13, + } + assert clear_kwargs == {"run_name": "source"} + assert events == ["clear", "prepare"] + assert result.status == AppRunStatus.SUCCEEDED + assert result.outputs[0].name == "prepared_gromacs_run" + assert result.outputs[0].kind == ArtifactKind.DIRECTORY + assert result.outputs[0].storage == VolumePath( + volume_name=gromacs_app.CONF.output_volume_name, + path="prepared/source", + ) + assert result.outputs[0].metadata == {"stage": "prep", "run_name": "source"} + + +def test_shortmd_prep_node_rejects_workdir_outside_gromacs_mount( + tmp_path: Path, + monkeypatch, +) -> None: + class FakePrepareFunction: + def remote(self, **kwargs: object) -> object: + return "/outside-gromacs-output" + + modal_namespace = ShortMDModalNamespace( + clear=UNEXPECTED_REMOTE, + clone=UNEXPECTED_REMOTE, + prepare_cpu=UNEXPECTED_REMOTE, + prepare_gpu=cast(modal.Function, FakePrepareFunction()), + production_cpu=UNEXPECTED_REMOTE, + production_gpu=UNEXPECTED_REMOTE, + collect_stats=UNEXPECTED_REMOTE, + ) + monkeypatch.setattr( + shortmd_workflow.modal.Function, + "from_name", + pytest.fail, + ) + monkeypatch.setattr( + shortmd_workflow.gromacs_app, + "prepare_tpr_gpu", + FakePrepareFunction(), + ) + + node = ShortMDPrepNode( + pdb_content=b"ATOM\n", + run_name="source", + modal_namespace=modal_namespace, + ) + with pytest.raises(ValueError, match="outside"): + node.run( + NodeRunContext( + run_id="run-1", + node_id="prep-source", + attempt_id="attempt-1", + cache_dir=tmp_path / "cache", + inputs={}, + ) + ) + + +def test_clone_prepared_shortmd_run_copies_prepared_inputs_into_replicate( + tmp_path: Path, + monkeypatch, +) -> None: + class FakeOutputVolume: + def __init__(self) -> None: + self.commit_count = 0 + self.reload_count = 0 + + def commit(self) -> None: + self.commit_count += 1 + + def reload(self) -> None: + self.reload_count += 1 + + output_volume = FakeOutputVolume() + source_dir = tmp_path / "prepared" / "source" + source_dir.mkdir(parents=True) + source_dir.joinpath("source.pdb").write_text("ATOM\n", encoding="utf-8") + source_dir.joinpath("production_source.tpr").write_text("tpr\n", encoding="utf-8") + source_dir.joinpath("production_source.xtc").write_text("stale\n", encoding="utf-8") + source_dir.joinpath("npt_source.gro").write_text("npt\n", encoding="utf-8") + + monkeypatch.setattr(shortmd_workflow, "GROMACS_OUTPUT_MOUNTPOINT", str(tmp_path)) + monkeypatch.setattr(shortmd_workflow, "GROMACS_OUTPUT_VOLUME", output_volume) + + result = clone_prepared_shortmd_run.get_raw_f()( + source_storage_path="prepared/source", + source_run_name="source", + replicate_run_name="source-r001", + ) + + replicate_dir = tmp_path / "source-r001" + assert result == str(replicate_dir) + assert replicate_dir.joinpath("source-r001.pdb").read_text(encoding="utf-8") == ( + "ATOM\n" + ) + assert ( + replicate_dir.joinpath("production_source-r001.tpr").read_text(encoding="utf-8") + == "tpr\n" + ) + assert not replicate_dir.joinpath("production_source.xtc").exists() + assert output_volume.reload_count == 1 + assert output_volume.commit_count == 1 + + +def test_shortmd_clone_node_clones_prepared_run_and_returns_artifact( + tmp_path: Path, + monkeypatch, +) -> None: + clone_kwargs = {} + + class FakeCloneFunction: + def remote(self, **kwargs: object) -> object: + clone_kwargs.update(kwargs) + return f"{gromacs_app.CONF.output_volume_mountpoint}/source-r001" + + modal_namespace = ShortMDModalNamespace( + clear=UNEXPECTED_REMOTE, + clone=cast(modal.Function, FakeCloneFunction()), + prepare_cpu=UNEXPECTED_REMOTE, + prepare_gpu=UNEXPECTED_REMOTE, + production_cpu=UNEXPECTED_REMOTE, + production_gpu=UNEXPECTED_REMOTE, + collect_stats=UNEXPECTED_REMOTE, + ) + monkeypatch.setattr(shortmd_workflow.modal.Function, "from_name", pytest.fail) + monkeypatch.setattr( + shortmd_workflow, + "clone_prepared_shortmd_run", + pytest.fail, + ) + + node = ShortMDCloneNode( + source_run_name="source", + replicate_run_name="source-r001", + modal_namespace=modal_namespace, + overwrite_clone=True, + ) + result = node.run( + NodeRunContext( + run_id="run-1", + node_id="clone-source-r001", + attempt_id="attempt-1", + cache_dir=tmp_path / "cache", + inputs={ + "prepared": [ + WorkflowArtifact( + artifact_id="source", + producing_node_id="prep-source", + kind=ArtifactKind.DIRECTORY, + storage=VolumePath( + volume_name=gromacs_app.CONF.output_volume_name, + path="prepared/source", + ), + metadata={"stage": "prep", "run_name": "source"}, + ) + ] + }, + ) + ) + + assert clone_kwargs == { + "source_storage_path": "prepared/source", + "source_run_name": "source", + "replicate_run_name": "source-r001", + "overwrite": True, + } + assert result.status == AppRunStatus.SUCCEEDED + assert result.outputs[0].name == "cloned_gromacs_run" + assert result.outputs[0].kind == ArtifactKind.DIRECTORY + assert result.outputs[0].storage == VolumePath( + volume_name=gromacs_app.CONF.output_volume_name, + path="source-r001", + ) + assert result.outputs[0].metadata == { + "stage": "clone", + "run_name": "source-r001", + "source_run_name": "source", + } + + +def test_shortmd_replicate_node_runs_gromacs_production( + tmp_path: Path, + monkeypatch, +) -> None: + production_kwargs = {} + stats_kwargs = {} + + class FakeProductionFunction: + def remote(self, **kwargs: object) -> object: + production_kwargs.update(kwargs) + return f"{gromacs_app.CONF.output_volume_mountpoint}/source-r001" + + class FakeStatsFunction: + def remote(self, traj_prefix: str, **kwargs: object) -> object: + stats_kwargs["traj_prefix"] = traj_prefix + stats_kwargs.update(kwargs) + return f"{gromacs_app.CONF.output_volume_mountpoint}/production/source-r001" + + modal_namespace = ShortMDModalNamespace( + clear=UNEXPECTED_REMOTE, + clone=UNEXPECTED_REMOTE, + prepare_cpu=UNEXPECTED_REMOTE, + prepare_gpu=UNEXPECTED_REMOTE, + production_cpu=UNEXPECTED_REMOTE, + production_gpu=cast(modal.Function, FakeProductionFunction()), + collect_stats=cast(modal.Function, FakeStatsFunction()), + ) + monkeypatch.setattr(shortmd_workflow.modal.Function, "from_name", pytest.fail) + monkeypatch.setattr( + shortmd_workflow.gromacs_app, + "production_run_gpu", + pytest.fail, + ) + monkeypatch.setattr( + shortmd_workflow.gromacs_app, + "collect_traj_stats", + pytest.fail, + ) + + node = ShortMDReplicateNode( + source_run_name="source", + replicate_run_name="source-r001", + modal_namespace=modal_namespace, + ) + result = node.run( + NodeRunContext( + run_id="run-1", + node_id="replicate-source-r001", + attempt_id="attempt-1", + cache_dir=tmp_path / "cache", + inputs={ + "cloned": [ + WorkflowArtifact( + artifact_id="source-r001", + producing_node_id="clone-source-r001", + kind=ArtifactKind.DIRECTORY, + storage=VolumePath( + volume_name=gromacs_app.CONF.output_volume_name, + path="source-r001", + ), + metadata={ + "stage": "clone", + "run_name": "source-r001", + "source_run_name": "source", + }, + ) + ] + }, + ) + ) + + assert production_kwargs == { + "run_name": "source-r001", + "simulation_time_ns": 2, + "num_threads": 16, + "use_openmp_threads": False, + } + assert stats_kwargs == { + "traj_prefix": "production_", + "run_name": "source-r001", + "save_processed_traj": True, + "make_figures": True, + } + assert result.status == AppRunStatus.SUCCEEDED + assert result.outputs[0].name == "gromacs_production" + assert result.outputs[0].kind == ArtifactKind.DIRECTORY + assert result.outputs[0].storage == VolumePath( + volume_name=gromacs_app.CONF.output_volume_name, + path="production/source-r001", + ) + assert result.outputs[0].metadata["run_name"] == "source-r001" + assert result.outputs[0].metadata["source_run_name"] == "source" + + +def test_shortmd_summary_node_emits_markdown_manifest(tmp_path: Path) -> None: + node = ShortMDSummaryNode(replicates=2, max_parallel=4) + context = NodeRunContext( + run_id="run-1", + node_id="summary", + attempt_id="attempt-1", + cache_dir=tmp_path / "cache", + inputs={ + "alpha-r001": [ + WorkflowArtifact( + artifact_id="alpha-r001", + producing_node_id="replicate-alpha-r001", + kind=ArtifactKind.DIRECTORY, + storage=VolumePath( + volume_name=gromacs_app.CONF.output_volume_name, + path="alpha-r001", + ), + metadata={ + "source_run_name": "alpha", + "run_name": "alpha-r001", + }, + ) + ], + "alpha-r002": [ + WorkflowArtifact( + artifact_id="alpha-r002", + producing_node_id="replicate-alpha-r002", + kind=ArtifactKind.DIRECTORY, + storage=VolumePath( + volume_name=gromacs_app.CONF.output_volume_name, + path="alpha-r002", + ), + metadata={ + "source_run_name": "alpha", + "run_name": "alpha-r002", + }, + ) + ], + }, + ) + + result = node.run(context) + + assert len(result.outputs) == 1 + output = result.outputs[0] + assert output.name == "shortmd_summary" + assert output.kind == ArtifactKind.REPORT + assert isinstance(output.storage, InlineBytes) + assert output.storage.filename == "shortmd-summary.md" + report = output.storage.data.decode("utf-8") + assert "# ShortMD Workflow Summary" in report + assert ( + f"| alpha | alpha-r001 | {gromacs_app.CONF.output_volume_name} | alpha-r001 |" + in report + ) + assert ( + f"| alpha | alpha-r002 | {gromacs_app.CONF.output_volume_name} | alpha-r002 |" + in report + ) + + +def test_shortmd_app_includes_orchestrator_class() -> None: + functions = shortmd_workflow.app._local_state.functions + + assert "WorkflowOrchestrator.*" in functions + assert "prepare_tpr_cpu" in functions + assert "prepare_tpr_gpu" in functions + assert "production_run_cpu" in functions + assert "production_run_gpu" in functions + assert "collect_traj_stats" in functions + + +def test_submit_shortmd_workflow_uses_included_orchestrator_class_boundary( + tmp_path: Path, + monkeypatch, + capsys: pytest.CaptureFixture[str], +) -> None: + input_dir = tmp_path / "pdbs" + input_dir.mkdir() + input_dir.joinpath("alpha.pdb").write_text("ATOM\n", encoding="utf-8") + calls = {} + + class FakeOrchestratorMethod: + def remote(self, **kwargs): + calls["remote"] = kwargs + return AppRunResult(status=AppRunStatus.SUCCEEDED) + + def spawn(self, **kwargs): + calls["spawn"] = kwargs + return "call-1" + + class FakeWorkflowOrchestrator: + def __init__(self) -> None: + self.run = FakeOrchestratorMethod() + + monkeypatch.setattr( + shortmd_workflow.orchestrator, + "WorkflowOrchestrator", + FakeWorkflowOrchestrator, + ) + + raw_f = shortmd_workflow.submit_shortmd_workflow.info.raw_f + assert raw_f is not None + raw_f( + input_dir=str(input_dir), + run_id="shortmd-run", + replicates=1, + wait=False, + max_parallel=3, + ) + + assert "remote" not in calls + assert calls["spawn"]["workflow"].name == "shortmd" + definition = calls["spawn"]["workflow"].validate() + prep_node = definition.nodes["prep-shortmd-run-alpha"].node + replicate_node = definition.nodes["replicate-shortmd-run-alpha-r001"].node + + assert prep_node.run_name == "shortmd-run-alpha" + assert replicate_node.source_run_name == "shortmd-run-alpha" + assert replicate_node.replicate_run_name == "shortmd-run-alpha-r001" + assert {"prep_cpu_function", "prep_gpu_function"}.isdisjoint(prep_node.__dict__) + assert { + "production_cpu_function", + "production_gpu_function", + "stats_function", + }.isdisjoint(replicate_node.__dict__) + assert prep_node.modal_namespace.clear is shortmd_workflow.clear_shortmd_gromacs_run + assert prep_node.modal_namespace.prepare_cpu is gromacs_app.prepare_tpr_cpu + assert prep_node.modal_namespace.prepare_gpu is gromacs_app.prepare_tpr_gpu + assert ( + replicate_node.modal_namespace.production_cpu is gromacs_app.production_run_cpu + ) + assert ( + replicate_node.modal_namespace.production_gpu is gromacs_app.production_run_gpu + ) + assert ( + replicate_node.modal_namespace.collect_stats is gromacs_app.collect_traj_stats + ) + assert calls["spawn"]["run_id"] == "shortmd-run" + assert calls["spawn"]["force"] is False + assert calls["spawn"]["max_ready_workers"] == 3 + stdout = capsys.readouterr().out + assert "Submitting ShortMD workflow 'shortmd-run'" in stdout + assert "1 input PDB(s)" in stdout + assert "1 replicate(s)" in stdout + + +def test_submit_shortmd_workflow_propagates_force_to_gromacs_overwrite( + tmp_path: Path, + monkeypatch, +) -> None: + input_dir = tmp_path / "pdbs" + input_dir.mkdir() + input_dir.joinpath("alpha.pdb").write_text("ATOM\n", encoding="utf-8") + calls = {} + + class FakeOrchestratorMethod: + def spawn(self, **kwargs): + calls["spawn"] = kwargs + return "call-1" + + class FakeWorkflowOrchestrator: + def __init__(self) -> None: + self.run = FakeOrchestratorMethod() + + monkeypatch.setattr( + shortmd_workflow.orchestrator, + "WorkflowOrchestrator", + FakeWorkflowOrchestrator, + ) + + raw_f = shortmd_workflow.submit_shortmd_workflow.info.raw_f + assert raw_f is not None + raw_f( + input_dir=str(input_dir), + run_id="shortmd-run", + replicates=1, + force=True, + wait=False, + ) + + definition = calls["spawn"]["workflow"].validate() + prep_node = definition.nodes["prep-shortmd-run-alpha"].node + clone_node = definition.nodes["clone-shortmd-run-alpha-r001"].node + + assert prep_node.overwrite_existing is True + assert clone_node.overwrite_clone is True + assert "clone_function" not in clone_node.__dict__ + assert prep_node.modal_namespace.clear is shortmd_workflow.clear_shortmd_gromacs_run + assert ( + clone_node.modal_namespace.clone is shortmd_workflow.clone_prepared_shortmd_run + ) + assert calls["spawn"]["force"] is True