diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 000000000..8d1accee1
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,48 @@
+# Git
+.git
+.gitignore
+.github
+
+# Python
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.Python
+env/
+venv/
+.venv/
+pip-log.txt
+pip-delete-this-directory.txt
+.tox
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.log
+.idea
+.mypy_cache
+
+# Documentation
+*.md
+docs/
+
+# Development files
+.vscode/
+*.swp
+*.swo
+*~
+
+# Build artifacts (keep minimal)
+predictions/
+submission/
+inputs/
+.tmp_spec.json
+
+# Keep these for the container
+!requirements.txt
+!examples/
+!src/
+!scripts/
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 000000000..f0fc9ffe0
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,30 @@
+name: CI - Build and Test
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test-scaffold:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install pyyaml dataclasses-json
+          # Install any additional requirements from participants
+          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+
+      - name: Test script help and argument parsing
+        run: |
+          python hackathon/predict_hackathon.py --help
diff --git a/.gitignore b/.gitignore
index 3d20fc11a..f98708824 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,6 +161,14 @@ cython_debug/
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 
+# Data files
+hackathon_data/
+
 # Boltz prediction outputs
 # All result files generated from a boltz prediction call
-boltz_results_*/
\ No newline at end of file
+boltz_results_*/
+my_predictions/
+tmp/
+my_results/
+asos_public_evaluation/
+abag_public_evaluation/
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 000000000..2257ff8bb
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,71 @@
+# Based on Liana64's contribution https://github.com/jwohlwend/boltz/blob/5ee0e6b9740b85ff24349aacc4d69615f499490b/Dockerfile
+ARG MINIFORGE_NAME=Miniforge3
+ARG MINIFORGE_VERSION=23.3.1-0
+ARG BASE_IMAGE=nvidia/cuda:12.3.0-runtime-ubuntu22.04
+
+FROM ${BASE_IMAGE} AS builder
+ARG MINIFORGE_NAME
+ARG MINIFORGE_VERSION
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    git \
+    build-essential \
+    python3 \
+    python3-venv \
+    python3-dev \
+    wget \
+    && wget --no-check-certificate --no-hsts https://github.com/conda-forge/miniforge/releases/download/${MINIFORGE_VERSION}/${MINIFORGE_NAME}-${MINIFORGE_VERSION}-Linux-$(uname -m).sh -O miniforge.sh \
+    && bash miniforge.sh -b -p /opt/conda \
+    && rm miniforge.sh \
+    && /opt/conda/bin/mamba init bash
+
+WORKDIR /app
+COPY environment.yml /app/
+COPY src /app/src
+COPY pyproject.toml /app/pyproject.toml
+
+RUN /opt/conda/bin/mamba env create -f environment.yml --name boltz && \
+    /opt/conda/bin/mamba init bash && \
+    /opt/conda/bin/conda run -n boltz pip install --no-cache-dir --upgrade pip && \
+    /opt/conda/bin/conda run -n boltz pip install --no-cache-dir .[cuda] && \
+    apt-get purge -y git build-essential wget && \
+    apt-get autoremove -y && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
+
+FROM ${BASE_IMAGE}
+COPY --from=builder /opt/conda /opt/conda
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3 
\ + build-essential \ + python3-dev \ + && apt-get autoremove -y \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* + +ENV PATH="/opt/conda/bin:$PATH" \ + LANG=C.UTF-8 \ + PYTHONUNBUFFERED=1 + +ARG USERNAME=boltz +ARG UID=900 +ARG GID=900 +RUN groupadd --gid $GID $USERNAME && \ + useradd --uid $UID --gid $GID --create-home --shell /bin/bash $USERNAME + +WORKDIR /app + +# Copy everything +COPY . /app/ + +RUN chown -R $USERNAME:$USERNAME /app + +USER $USERNAME + +# Initialize mamba and activate the boltz environment +SHELL ["/bin/bash", "-c"] +RUN mamba init bash && \ + echo "mamba activate boltz" >> ~/.bashrc + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/README.md b/README.md index 8bb05bf73..8a3acdb1c 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,11 @@ ![](docs/boltz1_pred_figure.png) +## Hackathon Instructions + +Please refer to the [hackathon README](hackathon/README_Hackathon.md) for instructions on how to participate in the Boltz hackathon. + +**Note:** The hackathon requires at least one CUDA-enabled GPU for running inference. ## Introduction diff --git a/environment.yml b/environment.yml new file mode 100644 index 000000000..53bb204db --- /dev/null +++ b/environment.yml @@ -0,0 +1,7 @@ +name: boltzplatz +channels: + - conda-forge +dependencies: + - python=3.11 + - pip + - pymol-open-source \ No newline at end of file diff --git a/examples/specs/example_protein_complex.json b/examples/specs/example_protein_complex.json new file mode 100644 index 000000000..e235d3c86 --- /dev/null +++ b/examples/specs/example_protein_complex.json @@ -0,0 +1,19 @@ +{ + "datapoint_id": "complex", + "task_type": "protein_complex", + "proteins": [ + { + "id": "A", + "sequence": "MVTPEGNVSLVDESLLVGV", + "msa_path": "examples/msa/seq1.a3m" + }, + { + "id": "B", + "sequence": "MVTPEGNVSLVDESLLVGK", + "msa_path": "examples/msa/seq2.a3m" + } + ], + "ground_truth": { + "structure": "examples/ground_truth/complex.pdb" + } +} diff --git a/examples/specs/example_protein_ligand.json b/examples/specs/example_protein_ligand.json new file mode 100644 index 000000000..1af859039 --- /dev/null +++ b/examples/specs/example_protein_ligand.json @@ -0,0 +1,20 @@ +{ + "datapoint_id": "affinity", + "task_type": "protein_ligand", + "proteins": [ + { + "id": "A", + "sequence": "MVTPEGNVSLVDESLLVGV", + "msa_path": "examples/msa/seq1.a3m" + } + ], + "ligands": [ + { + "id": "B", + "smiles": "N[C@@H](Cc1ccc(O)cc1)C(=O)O" + } + ], + "ground_truth": { + "structure": "examples/ground_truth/affinity.pdb" + } +} diff --git a/examples/test_dataset.jsonl b/examples/test_dataset.jsonl new file mode 100644 index 000000000..63a401e88 --- /dev/null +++ b/examples/test_dataset.jsonl @@ -0,0 +1,2 @@ +{"datapoint_id": "affinity", "task_type": "protein_ligand", "proteins": [{"id": "A", "sequence": "MVTPEGNVSLVDESLLVGV", "msa_path": "examples/msa/seq1.a3m"}], "ligands": [{"id": "B", "smiles": "N[C@@H](Cc1ccc(O)cc1)C(=O)O"}], "ground_truth": {"structure": "examples/ground_truth/affinity.pdb"}} +{"datapoint_id": "complex", "task_type": "protein_complex", "proteins": [{"id": "A", "sequence": "MVTPEGNVSLVDESLLVGV", "msa_path": "examples/msa/seq1.a3m"}, {"id": "B", "sequence": "MVTPEGNVSLVDESLLVGK", "msa_path": "examples/msa/seq2.a3m"}], "ground_truth": {"structure": "examples/ground_truth/complex.pdb"}} diff --git a/hackathon/README_Hackathon.md b/hackathon/README_Hackathon.md new file mode 100644 index 000000000..737386a1d --- /dev/null +++ b/hackathon/README_Hackathon.md @@ -0,0 +1,323 @@ +# M-Boltz 
Hackathon Template 🧬
+
+Welcome to the M-Boltz hackathon!
+It is great to have you here!
+
+This repository is a fork of the [Boltz](https://github.com/jwohlwend/boltz) repository and has been modified for the M-Boltz hackathon to allow a straightforward evaluation of your contributions for the antibody-antigen complex prediction challenge and the allosteric-orthosteric ligand challenge.
+
+Please read these instructions carefully before you start.
+
+## Setup ⚙️
+
+First, create a fork of the template repository!
+
+In contrast to the original installation instructions, please set up your environment by first using `conda` or `mamba` to create the environment and then using `pip` to install the `boltz` package.
+
+```
+git clone YOUR_FORKED_REPO_URL
+cd YOUR_FORKED_REPO_NAME
+conda env create -f environment.yml --name boltz
+conda activate boltz
+pip install -e ".[cuda]"
+```
+
+**_NOTE:_** The hackathon requires at least one CUDA-enabled GPU for running inference. CPU-only setups are not supported for this hackathon; use them at your own risk.
+
+## Download the datasets 📥
+
+Please download the data for both challenges from the link below and place it in the `hackathon_data` folder.
+The data includes the actual datasets, pre-computed MSAs, example predictions, and evaluations.
+
+```
+wget https://d2v9mdonbgo0hk.cloudfront.net/hackathon_data.tar.gz
+mkdir hackathon_data
+tar -xvf hackathon_data.tar.gz -C hackathon_data
+```
+
+## Quick Start ⚡️
+
+To participate in the hackathon:
+
+1. **Modify the code**: Edit the functions in `hackathon/predict_hackathon.py`:
+   - `prepare_protein_complex()` or `prepare_protein_ligand()` - Customize input configurations and CLI arguments
+   - `post_process_protein_complex()` or `post_process_protein_ligand()` - Re-rank or post-process predictions
+   - You can also modify any Boltz source code in `src/boltz/` as needed
+
+   These functions already contain minimal code so that the following steps can be executed successfully.
+   We explain the functions in more detail below.
+
+2. **Run predictions**: Execute the prediction script on a validation dataset:
+   ```bash
+   python hackathon/predict_hackathon.py \
+     --input-jsonl hackathon_data/datasets/abag_public/abag_public.jsonl \
+     --msa-dir hackathon_data/datasets/abag_public/msa/ \
+     --submission-dir ./my_predictions \
+     --intermediate-dir ./tmp/ \
+     --result-folder ./my_results
+   ```
+
+   - `--input-jsonl` provides information about task type, input molecules, and ground truth for evaluation
+   - `--msa-dir` contains the pre-computed MSAs
+   - `--submission-dir` is the output directory for the predicted structures
+   - `--intermediate-dir` is the directory for temporary files
+   - `--result-folder` is the output directory for the evaluation results (metrics)
+
+   **_NOTE:_** If this is your first time using `boltz`, some files (model weights, CCD library) will get downloaded to your machine first. This can take a while and should *not* be interrupted, to avoid corrupting the files. So take the chance, grab a coffee, and talk to some other participants!
+
+3. **Evaluate**: Results will be automatically computed and saved to the `--result-folder` directory.
+Review the metrics to assess your improvements.
+
+4. **Iterate**: Refine your approach based on evaluation results and repeat!
+
+5. **Submit**: Before the deadline, push your final code to your forked repository and fill out the [submission form](TBD).
+
+## Entrypoints for Participants 💻
+
+### `hackathon/predict_hackathon.py`
+
+We will evaluate your contributions by calling `hackathon/predict_hackathon.py` (see example above). This script performs the following main steps for each data point (e.g., protein complex) of a dataset (defined in `--input-jsonl`):
+
+1. Generate one or more combinations of a Boltz input YAML file (protein and molecule input data) and Boltz CLI arguments
+
+2. Call Boltz with each specified combination of YAML file and CLI arguments
+
+3. Post-process and rank the predictions from all combinations and store the top 5 predictions in the submission directory
+
+You can modify steps 1 and 3 by editing the functions in `hackathon/predict_hackathon.py`.
+
+![Overview of the hackathon workflow](img/overview.svg)
+
+#### Modifying step 1: Generating input YAML files and CLI arguments
+
+To adapt step 1, modify the following function for the antibody-antigen complex prediction challenge (the allosteric-orthosteric ligand prediction challenge is similar):
+
+`def prepare_protein_complex(datapoint_id: str, proteins: list[Protein], input_dict: dict, msa_dir: Optional[Path] = None) -> list[tuple[dict, list[str]]]:`
+
+This function enables modification of [Boltz inputs](https://github.com/jwohlwend/boltz/tree/main?tab=readme-ov-file#inference) - the YAML file with molecular information (e.g., proteins, ligands, constraints, etc.) and CLI arguments (e.g., the number of diffusion samples or recycling steps).
+
+This function receives as input:
+
+- `datapoint_id: str`: The ID of the current datapoint
+- `proteins: list[Protein]`: A list of `Protein` objects to be processed (defined in `hackathon_api.Protein`)
+- `input_dict: dict`: A pre-filled dictionary containing the YAML definition for that data point
+- `msa_dir: Path`: The directory with the precomputed MSA files. MSA files are always provided.
+
+For example input information, see `hackathon_data/datasets/abag_public/abag_public.jsonl`. This information will be automatically converted to the above-specified objects.
+
+Each protein has the following attributes:
+
+- `id: str`: The chain ID of the protein
+- `sequence: str`: The amino acid sequence of the protein
+- `msa: str`: The name of the MSA file within `msa_dir`. We always provide a precomputed MSA.
+
+Each data point contains three proteins with IDs: `"H"` (heavy chain segment), `"L"` (light chain segment), and `"A"` (antigen).
+
+The function should return a **list of tuples**, where each tuple contains:
+
+- A modified `input_dict` with any changes made during preparation, which will be reflected in the Boltz input YAML.
+- A list of CLI arguments that should be passed to Boltz for this configuration.
+
+By returning multiple tuples, you can run Boltz with different configurations for the same datapoint (e.g., different sampling strategies, different constraints, different hyperparameters). Each configuration will be run separately with its own YAML file and CLI argument combination.
+
+Note that we have already precomputed MSAs for all proteins, and they will be passed in alongside the protein sequences. Thus, you cannot change the MSA calculation. However, you can post-process the input MSA within the `prepare_protein_complex` function before it is passed to the Boltz model, e.g. save a sub-sampled MSA to a *new* CSV file and adjust the MSA path of the protein in `input_dict` accordingly. You can find example MSAs in the provided data. A minimal sketch of such a customization is shown below.
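+
+As an illustration, here is a minimal sketch of a customized `prepare_protein_complex` that returns two configurations: the pre-filled defaults with more diffusion samples, and a deep copy with an added contact constraint. The residue indices and sample counts below are made-up placeholders, not tuned recommendations:
+
+```python
+import copy
+from pathlib import Path
+from typing import Optional
+
+from hackathon_api import Protein
+
+def prepare_protein_complex(datapoint_id: str, proteins: list[Protein], input_dict: dict,
+                            msa_dir: Optional[Path] = None) -> list[tuple[dict, list[str]]]:
+    # Configuration 1: the unmodified inputs, but with more diffusion samples.
+    configs = [(input_dict, ["--diffusion_samples", "10"])]
+
+    # Configuration 2: the same inputs plus a hypothetical contact constraint
+    # between a heavy-chain residue and an antigen residue (placeholder indices).
+    constrained = copy.deepcopy(input_dict)
+    constrained["constraints"] = [
+        {"contact": {"token1": ["H", 100], "token2": ["A", 55]}}
+    ]
+    configs.append((constrained, ["--diffusion_samples", "10", "--recycling_steps", "5"]))
+
+    return configs
+```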
+
+#### Step 2: Running Boltz
+
+With the provided information, the script will then call Boltz once for each configuration.
+
+You are also welcome to make modifications to the Boltz code as needed.
+
+#### Modifying step 3: Post-processing and ranking predictions
+
+Afterwards, the following function gets called:
+
+`def post_process_protein_complex(datapoint: Datapoint, input_dicts: list[dict[str, Any]], cli_args_list: list[list[str]], prediction_dirs: list[Path]) -> list[Path]:`
+
+This function enables modifying, combining, or re-ranking the predicted structures from multiple configurations. It outputs paths for multiple structure candidates for each data point.
+
+This function receives:
+- `datapoint: Datapoint`: The original datapoint object (defined in `hackathon_api.Datapoint`)
+- `input_dicts: list[dict[str, Any]]`: A list of input dictionaries used (one per configuration)
+- `cli_args_list: list[list[str]]`: A list of CLI arguments used (one per configuration)
+- `prediction_dirs: list[Path]`: A list of directories containing prediction results (one per configuration)
+
+The function should return a list of **Path objects** pointing to the PDB files of the final structure candidates.
+The order is important!
+The first path will be your top 1 prediction, and we will evaluate up to 5 predictions for each data point.
+A sketch of a confidence-based re-ranking is shown at the end of this section.
+
+#### Allosteric-orthosteric ligand prediction challenge
+
+For the allosteric-orthosteric ligand challenge, the functions are similar to those for the antibody-antigen complex challenge explained above. Only the parts that differ between the two challenges are summarized here, so please read the explanations above first.
+
+`def prepare_protein_ligand(datapoint_id: str, protein: Protein, ligands: list[SmallMolecule], input_dict: dict, msa_dir: Optional[Path] = None) -> list[tuple[dict, list[str]]]:`
+
+Here, `protein` is a single protein object and `ligands` is a list containing a single small molecule object (defined in `hackathon_api.SmallMolecule`).
+We initially thought of allowing multiple ligands, but for this challenge we will only have a single ligand per data point.
+
+For example input information, see `hackathon_data/datasets/asos_public/asos_public.jsonl`.
+
+The small molecule has the following attributes:
+- `id`: The ID of the small molecule
+- `smiles`: The SMILES string of the small molecule
+
+This function also returns a **list of tuples** to support multiple configurations per datapoint.
+
+For post-processing and re-ranking, use the function
+
+`def post_process_protein_ligand(datapoint: Datapoint, input_dicts: list[dict[str, Any]], cli_args_list: list[list[str]], prediction_dirs: list[Path]) -> list[Path]:`
+
+This function receives lists of configurations and returns a list of **Path objects** pointing to the ranked PDB files.
+
+### Dependencies
+
+Add any additional Python packages to `pyproject.toml` under the `[project.dependencies]` section.
+If you need non-Python dependencies, you can add those in `environment.yml`.
+We strongly advise against adding non-Python dependencies that are not available through any public conda channel.
+If you still want to add them, please install them directly on your machine and make sure that you modify `Dockerfile` accordingly.
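+
+To make step 3 concrete, here is a sketch of a confidence-based re-ranking post-processor (the protein-ligand variant is analogous). It assumes Boltz's default output naming, where each `<name>_model_<i>.pdb` is accompanied by a `confidence_<name>_model_<i>.json` file containing a `confidence_score` entry; verify the exact file names in your own prediction directories before relying on this:
+
+```python
+import json
+from pathlib import Path
+from typing import Any
+
+from hackathon_api import Datapoint
+
+def post_process_protein_complex(datapoint: Datapoint, input_dicts: list[dict[str, Any]],
+                                 cli_args_list: list[list[str]],
+                                 prediction_dirs: list[Path]) -> list[Path]:
+    scored: list[tuple[float, Path]] = []
+    for prediction_dir in prediction_dirs:
+        for pdb in sorted(prediction_dir.glob("*_model_*.pdb")):
+            # One confidence JSON per model is assumed to sit next to each PDB,
+            # e.g. confidence_<name>_model_0.json next to <name>_model_0.pdb.
+            conf_file = prediction_dir / f"confidence_{pdb.stem}.json"
+            score = 0.0
+            if conf_file.exists():
+                with open(conf_file) as fh:
+                    score = float(json.load(fh).get("confidence_score", 0.0))
+            scored.append((score, pdb))
+    # Highest confidence first; only the top 5 models are evaluated.
+    scored.sort(key=lambda item: item[0], reverse=True)
+    return [pdb for _, pdb in scored[:5]]
+```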
+
+## Evaluation Limits ⏱️
+
+When evaluating your contributions, your code will run in an environment with the following hardware specs:
+
+- 1x NVIDIA L40 GPU (48GB)
+- 32 CPU cores
+- 300 GB RAM
+
+On this machine, the full end-to-end prediction for a single datapoint, including pre-processing, Boltz prediction, and post-processing, should complete within 15 minutes on average.
+As a reference, one typical antibody-antigen complex with 5 diffusion samples and default settings takes around 80-90 seconds end-to-end on that kind of hardware.
+
+To protect our proprietary data and ensure a fair competition, the evaluation environment will have **no internet access**.
+
+## Validation Sets 🧪
+
+For both challenges, we provide a validation data set that you can use to test your contributions and track your progress.
+
+### Antibody-Antigen Complex Prediction Challenge
+
+The validation set for the antibody-antigen complex challenge comprises 10 public PDB structures, all released after the cut-off date for Boltz training data.
+
+To run the prediction and evaluation, use:
+
+```bash
+python hackathon/predict_hackathon.py \
+  --input-jsonl hackathon_data/datasets/abag_public/abag_public.jsonl \
+  --msa-dir hackathon_data/datasets/abag_public/msa/ \
+  --submission-dir <SUBMISSION_DIR> \
+  --intermediate-dir ./tmp/ \
+  --result-folder <RESULT_DIR>
+```
+
+Replace `<SUBMISSION_DIR>` with the path to a directory where you want to store your structure predictions and `<RESULT_DIR>` with the path to a directory where you want to store the evaluation results.
+If you do not provide `--result-folder`, the script will only run the predictions and not the evaluation.
+
+If you just want to run the evaluation on already existing predictions:
+
+```bash
+python hackathon/evaluate_abag.py \
+  --dataset-file hackathon_data/datasets/abag_public/abag_public.jsonl \
+  --submission-folder <SUBMISSION_DIR> \
+  --result-folder ./abag_public_evaluation/
+```
+
+The evaluation script will compute the CAPRI-Q docking assessment classification scores (high, medium, acceptable, incorrect, error) for each of your top 5 predictions per data point.
+Error means that the prediction or evaluation did not finish due to a programmatic error.
+It will then print the distribution of classifications for the top 1 predictions across all data points.
+Additionally, it will compute the number of "successful" predictions, i.e., the number of data points for which the top 1 prediction is classified as "acceptable" or better.
+You will find more stats in the file `combined_results.csv` in the result folder.
+
+On the validation set, Boltz-2 with default settings should give you the following distribution of classifications for the top 1 predictions:
+
+- High: 2/10
+- Medium: 0/10
+- Acceptable: 0/10
+- Incorrect: 8/10
+
+The winner of this challenge will be the team with the highest number of successful top 1 predictions on our *internal* test set.
+Ties are broken by looking at the number of predictions with "high" classification, then with "medium" classification, and finally with "acceptable" classification.
+
+### Allosteric-Orthosteric Ligand Prediction Challenge
+
+The validation set for the allosteric-orthosteric ligand challenge comprises 40 structures that were also used in the recent paper by Nittinger et al. [1].
+
+To run the prediction and evaluation, use:
+
+```bash
+python hackathon/predict_hackathon.py \
+  --input-jsonl hackathon_data/datasets/asos_public/asos_public.jsonl \
+  --msa-dir hackathon_data/datasets/asos_public/msa/ \
+  --submission-dir <SUBMISSION_DIR> \
+  --intermediate-dir ./tmp/ \
+  --result-folder <RESULT_DIR>
+```
+
+Replace `<SUBMISSION_DIR>` with the path to a directory where you want to store your predictions and `<RESULT_DIR>` with the path to a directory where you want to store the evaluation results.
+If you do not provide `--result-folder`, the script will only run the predictions and not the evaluation.
+
+If you just want to run the evaluation on already existing predictions:
+
+```bash
+python hackathon/evaluate_asos.py \
+  --dataset-file hackathon_data/datasets/asos_public/asos_public.jsonl \
+  --submission-folder <SUBMISSION_DIR> \
+  --result-folder ./asos_public_evaluation/
+```
+
+The evaluation script will compute the ligand RMSD for each of your top 5 predictions per data point and print the mean of the top 1 RMSDs across all data points, just the allosteric data points, and just the orthosteric data points.
+Additionally, it will compute the mean of the minimum RMSDs in the top 5 predictions and the number of data points with minimum RMSD < 2Å in the top 5 predictions.
+You will find more stats in the file `combined_results.csv` in the result folder.
+
+Below you see examples of per-structure RMSD plots that Boltz-2 with default settings should give you, with a mean top-1 RMSD of ~6.26Å on this validation set.
+
+![Example per-structure results for the allosteric ligands](img/allosteric_rmsd.png)
+![Example per-structure results for the orthosteric ligands](img/orthosteric_rmsd.png)
+
+The winner of this challenge will be the team with the lowest mean RMSD of the top 1 predictions on our *internal* test set.
+
+## Submission Format 📦
+
+If you make deeper changes to the provided code, make sure your final predictions are organized in the following structure:
+```
+{submission_dir}/
+├── {datapoint_id_1}/
+│   ├── model_0.pdb
+│   ├── model_1.pdb
+│   ├── model_2.pdb
+│   ├── model_3.pdb
+│   └── model_4.pdb
+└── {datapoint_id_2}/
+    ├── model_0.pdb
+    └── ...
+```
+
+## Handing In Your Final Submission 🎉
+
+Before the deadline on **21st October 2025, 17:30 CEST / 11:30 EDT**, please submit your final code by pushing to your forked repository on GitHub.
+Then fill out the [submission form](TBD) and enter
+
+- your group name
+- the link to your repository
+- the commit SHA you want us to evaluate (if not provided, we will evaluate the latest commit on the `main` branch)
+- the challenge you are submitting for (antibody-antigen complex prediction, allosteric-orthosteric ligand prediction)
+- a link to a short description of your method (e.g., a README file in your repository or a separate document)
+
+If you want to submit for both challenges, please fill out the form twice.
+You can use different commit SHAs for each challenge if you want.
+
+Before submitting, we advise you to make sure that the following steps work in your repository:
+
+- check out a fresh clone of your repository
+- create the `conda` environment and install the dependencies
+- run the prediction and evaluation on the validation set
+- check that the submission format is correct
+- check that the Docker image builds successfully (run `docker build -t boltz-hackathon .` in the root of your repository)
+
+## Need Help? 🆘
+
+If you are on-site, feel free to ask any of the organizers or your fellow participants for help.
+If you are joining virtually, please reach out on Slack in the `#m-boltz-hackathon` channel.
+
+## References
+
+1. Nittinger, Eva, et al. "Co-folding, the future of docking – prediction of allosteric and orthosteric ligands." Artificial Intelligence in the Life Sciences, vol. 8, 2025, p. 100136. Elsevier.
+
+Good luck, have fun! 🚀
diff --git a/hackathon/evaluate_abag.py b/hackathon/evaluate_abag.py
new file mode 100644
index 000000000..96f0c8274
--- /dev/null
+++ b/hackathon/evaluate_abag.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+import argparse
+import os
+import sys
+import subprocess
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Optional
+import pandas as pd
+
+from hackathon_api import Datapoint
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Parallel CAPRI-Q evaluation runner (Python port)")
+    parser.add_argument('--dataset-file', type=str, default=str(Path.cwd() / 'inputs'), help='Path to input JSONL file')
+    parser.add_argument('--result-folder', type=str, default=str(Path.cwd() / 'outputs'), help='Directory to store result files')
+    parser.add_argument('--submission-folder', type=str, default=str(Path.cwd() / 'predictions'), help='Directory containing prediction files')
+    parser.add_argument('--njobs', type=int, default=50, help='Number of parallel jobs to run')
+    parser.add_argument('--nsamples', type=int, default=5, help='Number of samples to evaluate per structure')
+    return parser.parse_args()
+
+
+def run_evaluation(gt_dir, gt_structures: dict[str, Any], structure_name: str, i: int, args) -> Optional[pd.DataFrame]:
+    output_subdir = Path(args.result_folder) / f"{structure_name}_{i}"
+    output_subdir.mkdir(parents=True, exist_ok=True)
+    prediction_file = Path(args.submission_folder) / structure_name / f"model_{i}.pdb"
+    if not prediction_file.exists():
+        print(f"No prediction file {prediction_file} found. Skipping.")
+        return None
+
+    capriq_cmd = [
+        "/capri-q/bin/capriq",
+        "-a", "--dontwrite",
+        "-t", f"/app/ground_truth/{gt_structures['structure_complex']}",
+        "-u", f"/app/ground_truth/{gt_structures['structure_ab']}",
+        "-u", f"/app/ground_truth/{gt_structures['structure_ligand']}",
+        "-z", "/app/outputs/",
+        "-p", "65",
+        "-o", f"/app/outputs/{structure_name}_{i}_results.txt",
+        "-l", f"/app/outputs/{structure_name}_{i}_errors.txt",
+        "/app/predictions/prediction.pdb",
+        "&&",
+        "chown", "-R", f"{os.getuid()}:{os.getgid()}", "/app/outputs"
+    ]
+
+    docker_cmd = [
+        "docker", "run", "--group-add", str(os.getgid()), "--rm", "--network", "none",
+        "-v", f"{gt_dir}:/app/ground_truth/",
+        "-v", f"{output_subdir.absolute()}:/app/outputs",
+        "-v", f"{prediction_file.absolute()}:/app/predictions/prediction.pdb",
+        "gitlab-registry.in2p3.fr/cmsb-public/capri-q",
+        "/bin/bash", "-c",
+        f"{' '.join(capriq_cmd)}"
+    ]
+    print(f"Evaluating {structure_name} model {i}... Prediction file: {prediction_file}")
+    # print(f"Docker command: {' '.join(docker_cmd)}")
+    try:
+        subprocess.run(docker_cmd, check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"Docker run failed for {structure_name} model {i}. 
Error: {e}", file=sys.stderr) + return pd.DataFrame({ + 'structure_name': [structure_name], + 'structure_index': [i], + 'nclash': [None], + 'clash_fraction': [None], + 'classification': ['error'], + 'error': [str(e)] + }) + + # load result file + result_file = output_subdir / f"{structure_name}_{i}_results.txt" + if not result_file.exists(): + print(f"No result file {result_file} found. Skipping.") + return pd.DataFrame({ + 'structure_name': [structure_name], + 'structure_index': [i], + 'nclash': [None], + 'clash_fraction': [None], + 'classification': ['error'], + 'error': ['Result file not found'] + }) + + df = pd.read_csv(result_file, sep='\\s+') + df['clash_fraction'] = df['model'].str.replace("/", "").astype(float) / df['nclash'] + df['nclash'] = df['model'].str.replace("/", "").astype(int) + df.drop(columns=['model'], inplace=True) + df['structure_name'] = structure_name + df['structure_index'] = i + return df + +def load_dataset(input_jsonl: str) -> list[Datapoint]: + with open(input_jsonl, 'r') as f: + data = [Datapoint.from_json(line) for line in f] + return data + +def main(): + args = parse_args() + input_jsonl = args.dataset_file + dataset = load_dataset(input_jsonl) + gt_dir = Path(args.dataset_file).parent / 'ground_truth' + result_dfs = [] + with ThreadPoolExecutor(max_workers=args.njobs) as executor: + futures = [] + for datapoint in dataset: + structure_name = datapoint.datapoint_id + gt_structures = datapoint.ground_truth + for i in range(args.nsamples): + futures.append(executor.submit(run_evaluation, gt_dir, gt_structures, structure_name, i, args)) + for future in as_completed(futures): + result = future.result() + if result is not None: + result_dfs.append(result) + combined_results = pd.concat(result_dfs, ignore_index=True) + combined_results.to_csv(Path(args.result_folder) / 'combined_results.csv', index=False) + + # select structure 0 and count "classification" + nsuccessful = 0 + good_classes = ['high', 'medium', 'acceptable'] + bad_classes = ['incorrect', 'error'] + for classification in good_classes + bad_classes: + n = len(combined_results[(combined_results['structure_index'] == 0) & (combined_results['classification'].str.contains(classification))]) + if classification in good_classes: + nsuccessful += n + print(f"Number of {classification} classifications in top 1: {n}") + + # print number of successful top 1 predictions + print(f"Number of successful top 1 predictions: {nsuccessful} out of {len(dataset)}") + + print("All evaluations completed.") + +if __name__ == "__main__": + main() diff --git a/hackathon/evaluate_asos.py b/hackathon/evaluate_asos.py new file mode 100644 index 000000000..cddddb556 --- /dev/null +++ b/hackathon/evaluate_asos.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +""" +Evaluate ASOS predictions by calculating ligand RMSD values. + +This script aligns predicted structures with experimental structures using PyMOL, +then calculates RMSD values for ligands using the Hungarian algorithm for optimal assignment. +""" + +import os +import argparse +import numpy as np +import pandas as pd +from scipy.optimize import linear_sum_assignment +from Bio import PDB +import json +import matplotlib.pyplot as plt +import tempfile +import shutil + + +def best_fit_rmsd(exp, pred): + """ + Calculate RMSD using the Hungarian algorithm for optimal atom assignment. 
+
+    Args:
+        exp: Experimental coordinates (numpy array)
+        pred: Predicted coordinates (numpy array)
+
+    Returns:
+        RMSD value
+    """
+    # Calculate pairwise distances
+    distances = np.linalg.norm(exp[:, np.newaxis] - pred[np.newaxis, :], axis=2)
+
+    # Use the Hungarian algorithm to find the optimal assignment
+    row_indices, col_indices = linear_sum_assignment(distances)
+
+    # Calculate the RMSD based on the optimal assignment
+    rmsd = np.sqrt(np.mean(np.sum((exp[row_indices] - pred[col_indices]) ** 2, axis=1)))
+
+    return rmsd
+
+
+def get_coordinates(structure, ligand_name):
+    """
+    Extract coordinates for a specific ligand from a PDB structure.
+
+    Args:
+        structure: BioPython structure object
+        ligand_name: Residue name of the ligand
+
+    Returns:
+        List of coordinates [[x, y, z], ...]
+    """
+    coordinates = []  # Initialize an empty list to store coordinates
+
+    for model in structure:
+        for chain in model:
+            for residue in chain:
+                if residue.get_resname() == ligand_name:
+                    for atom in residue.get_atoms():
+                        vector = atom.get_vector()  # Get the vector for the atom
+                        coordinates.append(vector)  # Append the vector to the list
+
+    # Extract just the x, y, z values from the vectors
+    coordinates_list = [[vector[0], vector[1], vector[2]] for vector in coordinates]
+
+    return coordinates_list
+
+
+def get_ligand_rmsd(exp_file, pred_file, ligand_name_exp, ligand_name_pred):
+    """
+    Calculate ligand RMSD between experimental and predicted structures.
+
+    Args:
+        exp_file: Path to experimental structure file
+        pred_file: Path to predicted structure file
+        ligand_name_exp: Ligand residue name in experimental structure
+        ligand_name_pred: Ligand residue name in predicted structure
+
+    Returns:
+        RMSD value
+    """
+    parser = PDB.PDBParser()
+    structure_exp = parser.get_structure('experiment', exp_file)
+    structure_pred = parser.get_structure('model', pred_file)
+
+    coordinates_exp = get_coordinates(structure_exp, ligand_name_exp)
+    coordinates_pred = get_coordinates(structure_pred, ligand_name_pred)
+
+    exp = np.array(coordinates_exp)
+    pred = np.array(coordinates_pred)
+
+    # Calculate best-fit RMSD
+    rmsd = best_fit_rmsd(exp, pred)
+
+    return rmsd
+
+
+def load_dataset(dataset_folder, dataset_file):
+    """
+    Load the ASOS dataset and extract ligand information.
+
+    Args:
+        dataset_folder: Path to dataset folder
+        dataset_file: Dataset filename
+
+    Returns:
+        Tuple of (dataset, ligand_info)
+    """
+    dataset = [json.loads(line) for line in open(os.path.join(dataset_folder, dataset_file))]
+    print(f"Loaded {len(dataset)} samples from the dataset.")
+
+    ligand_info = {}
+    for datapoint in dataset:
+        for ligand in datapoint['ground_truth']['ligand_types']:
+            ligand_id = datapoint['datapoint_id']
+            if ligand_id not in ligand_info:
+                ligand_info[ligand_id] = []  # Initialize a list for each ligand ID
+            ligand_info[ligand_id].append({  # Append the type and CCD information
+                'type': ligand['type'],
+                'ccd': ligand['ccd'],
+                'chain_prot': ligand['chain'],
+                'chain_lig': ligand.get("ligand_chain") or ligand["ligand_id"],  # Fall back to the ligand_id if no ligand_chain is provided
+            })
+
+    return dataset, ligand_info
+
+
+def align_structures(dataset, ligand_info, dataset_folder, submission_folder, tempfolder):
+    """
+    Align predicted structures with experimental structures using PyMOL. 
+ + Args: + dataset: Dataset list + ligand_info: Dictionary of ligand information + dataset_folder: Path to dataset folder + submission_folder: Path to submission folder with predictions + tempfolder: Temporary folder for aligned structures + """ + # Clean and create temp folder + shutil.rmtree(tempfolder, ignore_errors=True) + os.makedirs(tempfolder, exist_ok=True) + + # Generate PyMOL alignment script + for datapoint in dataset: + datapoint_id = datapoint['datapoint_id'] + gt_structure = datapoint['ground_truth']["structure"] + keep = [] + for chain in ligand_info[datapoint_id]: + keep.append(chain["chain_prot"]) + keep.append(chain["chain_lig"]) + + keep_selection = "+".join(keep) + with open(os.path.join(tempfolder, "align.pml"), "a") as a: + a.write(f"load {dataset_folder}/ground_truth/{gt_structure}, exp\n") + a.write(f"sele not chain {keep_selection}\n") + a.write(f"remove sele\n") + for model in range(5): + a.write(f"load {submission_folder}/{datapoint_id}/model_{model}.pdb, pred_{model}\n") + a.write(f"align pred_{model}, exp\n") + a.write(f"save {tempfolder}/{datapoint_id}_model_{model}.pdb, pred_{model}\n") + a.write(f"save {tempfolder}/{datapoint_id}_exp.pdb, exp\n") + a.write("delete all\n") + + # Run PyMOL alignment + sh_file = os.path.join(tempfolder, "pymol_align.sh") + with open(sh_file, "w") as p: + p.write("module load pymol\n") + p.write(f"pymol -c {os.path.join(tempfolder, 'align.pml')}\n") + + os.system("chmod u+x " + sh_file) + print("Running PyMOL alignment...") + os.system(sh_file) + print("PyMOL alignment complete.") + + +def calculate_rmsds(dataset, ligand_info, tempfolder): + """ + Calculate RMSD values for all ligands across all models. + + Args: + dataset: Dataset list + ligand_info: Dictionary of ligand information + tempfolder: Temporary folder with aligned structures + + Returns: + Dictionary of RMSD values + """ + rmsds = {} + for datapoint in dataset: + exp_file = os.path.join(tempfolder, f"{datapoint['datapoint_id']}_exp.pdb") + try: + for ligand_i in range(len(ligand_info[datapoint['datapoint_id']])): + ligand = ligand_info[datapoint['datapoint_id']][ligand_i] + rmsds_ligand = [] + for model in range(5): + pred_file = os.path.join(tempfolder, f"{datapoint['datapoint_id']}_model_{model}.pdb") + rmsds_ligand.append(get_ligand_rmsd(exp_file, pred_file, ligand["ccd"], "LIG")) + + rmsds[f"{datapoint['datapoint_id']}"] = { + "rmsd": rmsds_ligand, + "type": ligand['type'] + } + + except Exception as e: + print(f"Unsuccessful for {datapoint['datapoint_id']}") + import traceback + traceback.print_exc() + + return rmsds + + +def plot_results(rmsds, result_folder): + """ + Create boxplots for orthosteric and allosteric ligand RMSD values. 
+
+    Args:
+        rmsds: Dictionary of RMSD values
+        result_folder: Folder to save plots
+    """
+    os.makedirs(result_folder, exist_ok=True)
+
+    # Create a horizontal boxplot for orthosteric ligands
+    rmsd_subset = {key: value["rmsd"] for key, value in rmsds.items() if value["type"] == "orthosteric"}
+    if rmsd_subset:
+        data = [values for values in rmsd_subset.values()]
+        labels = list(rmsd_subset.keys())
+        plt.figure(figsize=(10, 6))
+        plt.boxplot(data, vert=False)  # vert=False makes the boxplot horizontal
+        plt.yticks(range(1, len(labels) + 1), labels)  # Set the y-ticks to the datapoint names
+        plt.xlabel('RMSD Values')
+        plt.title('Orthosteric Ligand RMSD Values')
+        plt.tight_layout()
+        plt.savefig(os.path.join(result_folder, 'orthosteric_rmsd.png'), dpi=300, bbox_inches='tight')
+        plt.close()
+        print(f"Saved orthosteric RMSD plot to {result_folder}/orthosteric_rmsd.png")
+
+    # Create a horizontal boxplot for allosteric ligands
+    rmsd_subset = {key: value["rmsd"] for key, value in rmsds.items() if value["type"] == "allosteric"}
+    if rmsd_subset:
+        data = [values for values in rmsd_subset.values()]
+        labels = list(rmsd_subset.keys())
+        plt.figure(figsize=(10, 6))
+        plt.boxplot(data, vert=False)  # vert=False makes the boxplot horizontal
+        plt.yticks(range(1, len(labels) + 1), labels)  # Set the y-ticks to the datapoint names
+        plt.xlabel('RMSD Values')
+        plt.title('Allosteric Ligand RMSD Values')
+        plt.tight_layout()
+        plt.savefig(os.path.join(result_folder, 'allosteric_rmsd.png'), dpi=300, bbox_inches='tight')
+        plt.close()
+        print(f"Saved allosteric RMSD plot to {result_folder}/allosteric_rmsd.png")
+
+
+def save_results(rmsds, result_folder):
+    """
+    Save RMSD results to a CSV file.
+
+    Args:
+        rmsds: Dictionary of RMSD values
+        result_folder: Folder to save results
+    """
+    os.makedirs(result_folder, exist_ok=True)
+
+    # CSV with combined results
+    output_file = os.path.join(result_folder, 'combined_results.csv')
+    df_metrics = pd.DataFrame([
+        {
+            "datapoint_id": key,
+            "type": value["type"],
+            "top1_rmsd": value["rmsd"][0],
+            "top5_mean_rmsd": np.mean(value["rmsd"][:5]),
+            "top5_min_rmsd": np.min(value["rmsd"][:5]),
+            "rmsd_under_2A": any(rmsd < 2.0 for rmsd in value["rmsd"][:5]),
+            "rmsd_model_0": value["rmsd"][0],
+            "rmsd_model_1": value["rmsd"][1],
+            "rmsd_model_2": value["rmsd"][2],
+            "rmsd_model_3": value["rmsd"][3],
+            "rmsd_model_4": value["rmsd"][4],
+        }
+        for key, value in rmsds.items()
+    ])
+    df_metrics.to_csv(output_file, index=False)
+    print(f"Saved metrics results to {output_file}")
+
+    return df_metrics
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Evaluate ASOS predictions by calculating ligand RMSD values.',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        '--dataset-folder',
+        required=False,
+        help='Path to the dataset folder containing the JSONL file and ground_truth subdirectory'
+    )
+    parser.add_argument(
+        '--dataset-file',
+        required=True,
+        help='Name of the dataset JSONL file'
+    )
+    parser.add_argument(
+        '--submission-folder',
+        required=True,
+        help='Path to the submission folder containing predicted structures'
+    )
+    parser.add_argument(
+        '--result-folder',
+        default='./evaluation_results',
+        help='Path to save evaluation results and plots'
+    )
+    parser.add_argument(
+        '--temp-folder',
+        default='./tmp',
+        help='Path to temporary folder for aligned structures'
+    )
+
+    args = parser.parse_args()
+
+    print("=" * 80)
+    print("ASOS Ligand RMSD Evaluation")
+    print("=" * 80)
+    print(f"Dataset folder: {args.dataset_folder}")
+    print(f"Dataset file: {args.dataset_file}")
+    print(f"Submission folder: {args.submission_folder}")
+    print(f"Result folder: {args.result_folder}")
+    print(f"Temp folder: {args.temp_folder}")
+    print("=" * 80)
+
+    # Load dataset
+    print("\n1. Loading dataset...")
+    if not args.dataset_folder:
+        args.dataset_folder = os.path.dirname(os.path.abspath(args.dataset_file))
+
+    dataset, ligand_info = load_dataset(args.dataset_folder, args.dataset_file)
+    print(f"Found ligand information for {len(ligand_info)} datapoints")
+
+    # Align structures
+    print("\n2. Aligning structures with PyMOL...")
+    align_structures(dataset, ligand_info, args.dataset_folder, args.submission_folder, args.temp_folder)
+
+    # Calculate RMSDs
+    print("\n3. Calculating ligand RMSD values...")
+    rmsds = calculate_rmsds(dataset, ligand_info, args.temp_folder)
+    print(f"Calculated RMSD for {len(rmsds)} ligands")
+
+    # Save results
+    print("\n4. Saving results...")
+    df_metrics = save_results(rmsds, args.result_folder)
+
+    # Plot results
+    print("\n5. Generating plots...")
+    plot_results(rmsds, args.result_folder)
+
+    print("\n" + "=" * 80)
+    print("Evaluation complete!")
+    print("Top 1 model RMSD summary:")
+    # all, orthosteric, allosteric
+    for ligand_type in [None, "orthosteric", "allosteric"]:
+        if ligand_type:
+            df_filtered = df_metrics[df_metrics['type'] == ligand_type]
+            label = ligand_type.capitalize()
+        else:
+            df_filtered = df_metrics
+            label = "All"
+
+        if not df_filtered.empty:
+            mean_top1_rmsd = df_filtered['top1_rmsd'].mean()
+            mean_top5_min_rmsd = df_filtered['top5_min_rmsd'].mean()
+            num_below_2A = df_filtered['rmsd_under_2A'].sum()
+            total = len(df_filtered)
+            print(f"{label} ligands - Mean Top 1 RMSD: {mean_top1_rmsd:.2f}, Mean Top 5 Min RMSD: {mean_top5_min_rmsd:.2f}, "
+                  f"RMSD < 2Å in Top 5: {num_below_2A}/{total} ({(num_below_2A/total)*100:.1f}%)")
+        else:
+            print(f"No data for {label} ligands.")
+    print("=" * 80)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/hackathon/hackathon_api.py b/hackathon/hackathon_api.py
new file mode 100644
index 000000000..20260012c
--- /dev/null
+++ b/hackathon/hackathon_api.py
@@ -0,0 +1,46 @@
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Optional
+
+from dataclasses_json import dataclass_json
+
+
+# Enum for task_type
+class TaskType(str, Enum):
+    """Enum for valid hackathon task types."""
+
+    PROTEIN_COMPLEX = "protein_complex"
+    PROTEIN_LIGAND = "protein_ligand"
+
+
+@dataclass_json
+@dataclass
+class Protein:
+    """Represents a protein sequence for Boltz prediction."""
+
+    id: str
+    sequence: str
+    msa: Optional[str] = None  # A3M path (always provided for hackathon)
+
+
+@dataclass_json
+@dataclass
+class SmallMolecule:
+    """Represents a small molecule/ligand for Boltz prediction."""
+
+    id: str
+    smiles: Optional[str] = None  # SMILES string for the ligand
+
+
+@dataclass_json
+@dataclass
+class Datapoint:
+    """Represents a single hackathon datapoint for Boltz prediction."""
+
+    datapoint_id: str
+    task_type: TaskType
+    proteins: list[Protein]
+    ligands: Optional[list[SmallMolecule]] = None  # We will only have a SINGLE ligand for the allosteric/orthosteric binding challenge
+    ground_truth: Optional[dict[str, Any]] = None
+
diff --git a/hackathon/img/allosteric_rmsd.png b/hackathon/img/allosteric_rmsd.png
new file mode 100644
index 000000000..9ed065d68
Binary files /dev/null and b/hackathon/img/allosteric_rmsd.png differ
diff --git a/hackathon/img/orthosteric_rmsd.png 
b/hackathon/img/orthosteric_rmsd.png new file mode 100644 index 000000000..d9ba63f10 Binary files /dev/null and b/hackathon/img/orthosteric_rmsd.png differ diff --git a/hackathon/img/overview.drawio b/hackathon/img/overview.drawio new file mode 100644 index 000000000..b41d1486c --- /dev/null +++ b/hackathon/img/overview.drawio @@ -0,0 +1 @@ +5Vpdc5s4FP01nkkf7BEIsP0YO91up+1sZtOZbp52FJCxJhhRIRK7v34lEN+ixjHka5+MrqRrcXSudK7EBK53+08MRdtv1MPBxATefgKvJqZpmAsofqTlkFkWtp0ZfEY81ag03JBfWBmBsibEw3GtIac04CSqG10ahtjlNRtijD7Wm21oUP/XCPm4ZbhxUdC2/iAe3+ZvAUr7n5j42/yfDaBqdihvrAzxFnn0sWKCHydwzSjl2dNuv8aBBC/HJev3R0dtMTCGQ96nQ3D79+4n+OtwFdq2/SOa0i/oyxSqwT2gIFFvrEbLDzkEjCahh6UXYwJXj1vC8U2EXFn7KCZd2LZ8F6hq5Q4zjvedAzWK1xe8wXSHOTuIJnmHHFxFGTsvP5YTAC1l21bAXyobUnPuF65LWMSDQuYElEz7VJQqiOgAC9AdDlbIvffTbmsaUCZqPbxBSSDeZrUhQZBbJyb0oLtxN+PAay16wuuMBq/z+klogZ4ozQdAydBFpZMRg4rxm44vn68QR/KlKEltab1wXW2Smz3y0DRVm02Fm89XR3wIs8ZNaqrMlPMzkSvaKkKeR0J/yqmYnEvRBET7tBvIa1gGWrUu73xsrN9RfC93gkOE8yZ3TDPeTihOGXDXoLT/qEGoae1kNjjObBQQPxTPAd5IZ5LiRGxWl8qcjn3VwXVNRHTS3wR1+hfBX6G/oaO/YQ7Bf6OT/89Ct7MpshIK5ZeopwmPEh6LJ/P/wBpoNhbNZZs1Sw1p4BCcmbcwwp6QcqoY0hDXQRGvyQ7/SABndl68VXimhat9rXRQpU6cYpowV/212sA4Yj7OW6ldTo7qt2AyHCBOHuoC9BxozHcXTcb7i50iMHLVqxEcY8WO8/TQMQzjlOBJS9eYETFGzE6JKNiOKPulAgr2kmfXjHJMwvOU2Vfio9B7ujLr9vzt5vJEtwOGkfy7dtYzRBgZ9TAqtqRKGDmaMLIHCCNN2tIzjoyZEAjVMJpZlj1aKFntUMpRGi6W0q6XjKFDpUGaqsQVz9dZ7pLPnt2QndCyqxNwtL2jjpvKCctGUE5f8Sq9ZtTqFeif0kiQiLJEcKwrqAYMnyECZXE8UOyRAkV3nFKFVSMW3GyxkDIhkIOZeojdXzD/7kKMZz1J85X050P2K2tMYNXqDGB++KBVGBHDEkzRxqvl1CBde0HEqJe4spolYj2Xh40b4idMRAINpR7ZiLGZQOmUftMvpomfpDR2xPNkZy1V6mQaYR21jDY9irOOKj+GyP/mz8QPeRYtC9WHDoZM7NUFCYX8LDliZH3WXz/L12V+KkyV+2ZbU9PWVG1ns9nE7j586c0bV8y23AhemDnGvMEcq80c7cnZEEp2ceIO3Pm+A2+ErZ3LtBooQavuItuwVa/fbIEtR80FOtvmW46esBsabd0rufvmGAobqZbuBFynEQc5ptBJiiZlQ+9SXl1JxAIUx8RtyMY94Vn2BS1TlW8rdaVglIWaXsy6AVhTm1NhWVrH5Kbw1VSbT5egaoWvSdCegddTn5wZnxboWMVOjU+rcZEAm7dUA8Znj2uq/tyy57DCLTBbOPO+/LIto8kv8OL8Mvqu7M9DMHMO6wRrXhz1JVjz0LXIiUYg2PKtbK8wzzFyUJpR1x/duiPHGA9d3TXoy6ZFNOYi93FxHBe5EMNThsL7NCvCHnE59tKRscTlCcNx+56D01oKJU9MTWCXDrIcqvct3ltTG82rNNuxWmpDu9AMkUkVa+PZEVtdXY2FZvse/ASpt6RaPFPcnJYuivqd/BzqXzCLvLsy1UsfshorqwHvKv1z6nR3QJvuEI6krnMSVqiA5XN6UCN9+4iEsaSE3zizU4n6W0PbnJuzxgY1byczBrBHwrudEWoWnCAgUdwFSRXbOMo+4duQvQRpDDpauUiu0nEsNrZzPY0gf1XoLJ8RnXa2Al83OrbuS7mR0NF/KLd8pp2ut0LMzsOLC3pZwMiV62ntjPw17W5nfRbYFHO6gNEdHT1By4li+XFwlk6Un1jDj/8B \ No newline at end of file diff --git a/hackathon/img/overview.svg b/hackathon/img/overview.svg new file mode 100644 index 000000000..715f0bc87 --- /dev/null +++ b/hackathon/img/overview.svg @@ -0,0 +1,4 @@ + + + +
[overview.svg: embedded diagram text only — a data point (ID, task type, protein, ligand, MSA, ground truth) flows through (1) "prepare data point and produce run configurations for Boltz" → [(input data 1, CLI args 1), (input data 2, CLI args 2), ...], (2) "run Boltz on each configuration" → Boltz outputs 1, Boltz outputs 2, ..., and (3) "postprocess and re-rank predicted structures to produce top 5 predictions" → [model_0.pdb, ..., model_4.pdb], followed by evaluation against ground truth data.]
\ No newline at end of file
diff --git a/hackathon/predict_hackathon.py b/hackathon/predict_hackathon.py
new file mode 100644
index 000000000..bb66ed319
--- /dev/null
+++ b/hackathon/predict_hackathon.py
@@ -0,0 +1,396 @@
+# predict_hackathon.py
+import argparse
+import json
+import os
+import shutil
+import subprocess
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any, List, Optional
+
+import yaml
+from hackathon_api import Datapoint, Protein, SmallMolecule
+
+# ---------------------------------------------------------------------------
+# ---- Participants should modify these four functions ----------------------
+# ---------------------------------------------------------------------------
+
+def prepare_protein_complex(datapoint_id: str, proteins: List[Protein], input_dict: dict, msa_dir: Optional[Path] = None) -> List[tuple[dict, List[str]]]:
+    """
+    Prepare input dict and CLI args for a protein complex prediction.
+    You can return multiple configurations to run by returning a list of (input_dict, cli_args) tuples.
+    Args:
+        datapoint_id: The unique identifier for this datapoint
+        proteins: List of protein sequences to predict as a complex
+        input_dict: Prefilled input dict
+        msa_dir: Directory containing MSA files (for computing relative paths)
+    Returns:
+        List of tuples of (final input dict that will get exported as YAML, list of CLI args). Each tuple represents a separate configuration to run.
+    """
+    # Please note:
+    # `proteins` will contain 3 chains
+    #   H,L: heavy and light chain of the Fv or Fab region
+    #   A: the antigen
+    #
+    # you can modify input_dict to change the input yaml file going into the prediction, e.g.
+    # ```
+    # input_dict["constraints"] = [{
+    #     "contact": {
+    #       "token1" : [CHAIN_ID, RES_IDX/ATOM_NAME],
+    #       "token2" : [CHAIN_ID, RES_IDX/ATOM_NAME]
+    #     }
+    # }]
+    # ```
+    #
+    # will add contact constraints to the input_dict
+
+    # Example: predict 5 structures
+    cli_args = ["--diffusion_samples", "5"]
+    return [(input_dict, cli_args)]
+
+def prepare_protein_ligand(datapoint_id: str, protein: Protein, ligands: list[SmallMolecule], input_dict: dict, msa_dir: Optional[Path] = None) -> List[tuple[dict, List[str]]]:
+    """
+    Prepare input dict and CLI args for a protein-ligand prediction.
+    You can return multiple configurations to run by returning a list of (input_dict, cli_args) tuples.
+    Args:
+        datapoint_id: The unique identifier for this datapoint
+        protein: The protein sequence
+        ligands: A list of a single small molecule ligand object
+        input_dict: Prefilled input dict
+        msa_dir: Directory containing MSA files (for computing relative paths)
+    Returns:
+        List of tuples of (final input dict that will get exported as YAML, list of CLI args). Each tuple represents a separate configuration to run.
+    """
+    # Please note:
+    # `protein` is a single-chain target protein sequence with id A
+    # `ligands` contains a single small molecule ligand object with unknown binding sites
+    # you can modify input_dict to change the input yaml file going into the prediction, e.g.
+    # ```
+    # input_dict["constraints"] = [{
+    #     "contact": {
+    #       "token1" : [CHAIN_ID, RES_IDX/ATOM_NAME],
+    #       "token2" : [CHAIN_ID, RES_IDX/ATOM_NAME]
+    #     }
+    # }]
+    # ```
+    #
+    # will add contact constraints to the input_dict
+
+    # Example: predict 5 structures
+    cli_args = ["--diffusion_samples", "5"]
+    return [(input_dict, cli_args)]
+
+def post_process_protein_complex(datapoint: Datapoint, input_dicts: List[dict[str, Any]], cli_args_list: List[list[str]], prediction_dirs: List[Path]) -> List[Path]:
+    """
+    Return ranked model files for protein complex submission.
+    Args:
+        datapoint: The original datapoint object
+        input_dicts: List of input dictionaries used for predictions (one per config)
+        cli_args_list: List of command line arguments used for predictions (one per config)
+        prediction_dirs: List of directories containing prediction results (one per config)
+    Returns:
+        Sorted pdb file paths that should be used as your submission.
+    """
+    # Collect all PDBs from all configurations
+    all_pdbs = []
+    for prediction_dir in prediction_dirs:
+        config_pdbs = sorted(prediction_dir.glob(f"{datapoint.datapoint_id}_config_*_model_*.pdb"))
+        all_pdbs.extend(config_pdbs)
+
+    # Sort all PDBs and return their paths
+    all_pdbs = sorted(all_pdbs)
+    return all_pdbs
+
+def post_process_protein_ligand(datapoint: Datapoint, input_dicts: List[dict[str, Any]], cli_args_list: List[list[str]], prediction_dirs: List[Path]) -> List[Path]:
+    """
+    Return ranked model files for protein-ligand submission.
+    Args:
+        datapoint: The original datapoint object
+        input_dicts: List of input dictionaries used for predictions (one per config)
+        cli_args_list: List of command line arguments used for predictions (one per config)
+        prediction_dirs: List of directories containing prediction results (one per config)
+    Returns:
+        Sorted pdb file paths that should be used as your submission. 
+ """ + # Collect all PDBs from all configurations + all_pdbs = [] + for prediction_dir in prediction_dirs: + config_pdbs = sorted(prediction_dir.glob(f"{datapoint.datapoint_id}_config_*_model_*.pdb")) + all_pdbs.extend(config_pdbs) + + # Sort all PDBs and return their paths + all_pdbs = sorted(all_pdbs) + return all_pdbs + +# ----------------------------------------------------------------------------- +# ---- End of participant section --------------------------------------------- +# ----------------------------------------------------------------------------- + + +DEFAULT_OUT_DIR = Path("predictions") +DEFAULT_SUBMISSION_DIR = Path("submission") +DEFAULT_INPUTS_DIR = Path("inputs") + +ap = argparse.ArgumentParser( + description="Hackathon scaffold for Boltz predictions", + epilog="Examples:\n" + " Single datapoint: python predict_hackathon.py --input-json examples/specs/example_protein_ligand.json --msa-dir ./msa --submission-dir submission --intermediate-dir intermediate\n" + " Multiple datapoints: python predict_hackathon.py --input-jsonl examples/test_dataset.jsonl --msa-dir ./msa --submission-dir submission --intermediate-dir intermediate", + formatter_class=argparse.RawDescriptionHelpFormatter +) + +input_group = ap.add_mutually_exclusive_group(required=True) +input_group.add_argument("--input-json", type=str, + help="Path to JSON datapoint for a single datapoint") +input_group.add_argument("--input-jsonl", type=str, + help="Path to JSONL file with multiple datapoint definitions") + +ap.add_argument("--msa-dir", type=Path, + help="Directory containing MSA files (for computing relative paths in YAML)") +ap.add_argument("--submission-dir", type=Path, required=False, default=DEFAULT_SUBMISSION_DIR, + help="Directory to place final submissions") +ap.add_argument("--intermediate-dir", type=Path, required=False, default=Path("hackathon_intermediate"), + help="Directory to place generated input YAML files and predictions") +ap.add_argument("--group-id", type=str, required=False, default=None, + help="Group ID to set for submission directory (sets group rw access if specified)") +ap.add_argument("--result-folder", type=Path, required=False, default=None, + help="Directory to save evaluation results. If set, will automatically run evaluation after predictions.") + +args = ap.parse_args() + +def _prefill_input_dict(datapoint_id: str, proteins: Iterable[Protein], ligands: Optional[list[SmallMolecule]] = None, msa_dir: Optional[Path] = None) -> dict: + """ + Prepare input dict for Boltz YAML. + """ + seqs = [] + for p in proteins: + if msa_dir and p.msa: + if Path(p.msa).is_absolute(): + msa_full_path = Path(p.msa) + else: + msa_full_path = msa_dir / p.msa + try: + msa_relative_path = os.path.relpath(msa_full_path, Path.cwd()) + except ValueError: + msa_relative_path = str(msa_full_path) + else: + msa_relative_path = p.msa + entry = { + "protein": { + "id": p.id, + "sequence": p.sequence, + "msa": msa_relative_path + } + } + seqs.append(entry) + if ligands: + def _format_ligand(ligand: SmallMolecule) -> dict: + output = { + "ligand": { + "id": ligand.id, + "smiles": ligand.smiles + } + } + return output + + for ligand in ligands: + seqs.append(_format_ligand(ligand)) + doc = { + "version": 1, + "sequences": seqs, + } + return doc + +def _run_boltz_and_collect(datapoint) -> None: + """ + New flow: prepare input dict, write yaml, run boltz, post-process, copy submissions. 
+ """ + out_dir = args.intermediate_dir / "predictions" + out_dir.mkdir(parents=True, exist_ok=True) + subdir = args.submission_dir / datapoint.datapoint_id + subdir.mkdir(parents=True, exist_ok=True) + + # Prepare input dict and CLI args + base_input_dict = _prefill_input_dict(datapoint.datapoint_id, datapoint.proteins, datapoint.ligands, args.msa_dir) + + if datapoint.task_type == "protein_complex": + configs = prepare_protein_complex(datapoint.datapoint_id, datapoint.proteins, base_input_dict, args.msa_dir) + elif datapoint.task_type == "protein_ligand": + configs = prepare_protein_ligand(datapoint.datapoint_id, datapoint.proteins[0], datapoint.ligands, base_input_dict, args.msa_dir) + else: + raise ValueError(f"Unknown task_type: {datapoint.task_type}") + + # Run boltz for each configuration + all_input_dicts = [] + all_cli_args = [] + all_pred_subfolders = [] + + input_dir = args.intermediate_dir / "input" + input_dir.mkdir(parents=True, exist_ok=True) + + for config_idx, (input_dict, cli_args) in enumerate(configs): + # Write input YAML with config index suffix + yaml_path = input_dir / f"{datapoint.datapoint_id}_config_{config_idx}.yaml" + with open(yaml_path, "w") as f: + yaml.safe_dump(input_dict, f, sort_keys=False) + + # Run boltz + cache = os.environ.get("BOLTZ_CACHE", str(Path.home() / ".boltz")) + fixed = [ + "boltz", "predict", str(yaml_path), + "--devices", "1", + "--out_dir", str(out_dir), + "--cache", cache, + "--no_kernels", + "--output_format", "pdb", + ] + cmd = fixed + cli_args + print(f"Running config {config_idx}:", " ".join(cmd), flush=True) + subprocess.run(cmd, check=True) + + # Compute prediction subfolder for this config + pred_subfolder = out_dir / f"boltz_results_{datapoint.datapoint_id}_config_{config_idx}" / "predictions" / f"{datapoint.datapoint_id}_config_{config_idx}" + + all_input_dicts.append(input_dict) + all_cli_args.append(cli_args) + all_pred_subfolders.append(pred_subfolder) + + # Post-process and copy submissions + if datapoint.task_type == "protein_complex": + ranked_files = post_process_protein_complex(datapoint, all_input_dicts, all_cli_args, all_pred_subfolders) + elif datapoint.task_type == "protein_ligand": + ranked_files = post_process_protein_ligand(datapoint, all_input_dicts, all_cli_args, all_pred_subfolders) + else: + raise ValueError(f"Unknown task_type: {datapoint.task_type}") + + if not ranked_files: + raise FileNotFoundError(f"No model files found for {datapoint.datapoint_id}") + + for i, file_path in enumerate(ranked_files): + target = subdir / (f"model_{i}.pdb" if file_path.suffix == ".pdb" else f"model_{i}{file_path.suffix}") + shutil.copy2(file_path, target) + print(f"Saved: {target}") + + if args.group_id: + try: + subprocess.run(["chgrp", "-R", args.group_id, str(subdir)], check=True) + subprocess.run(["chmod", "-R", "g+rw", str(subdir)], check=True) + except Exception as e: + print(f"WARNING: Failed to set group ownership or permissions: {e}") + +def _load_datapoint(path: Path): + """Load JSON datapoint file.""" + with open(path) as f: + return Datapoint.from_json(f.read()) + +def _run_evaluation(input_file: str, task_type: str, submission_dir: Path, result_folder: Path): + """ + Run the appropriate evaluation script based on task type. 
+
+    Args:
+        input_file: Path to the input JSON or JSONL file
+        task_type: Either "protein_complex" or "protein_ligand"
+        submission_dir: Directory containing prediction submissions
+        result_folder: Directory to save evaluation results
+    """
+    script_dir = Path(__file__).parent
+
+    if task_type == "protein_complex":
+        eval_script = script_dir / "evaluate_abag.py"
+    elif task_type == "protein_ligand":
+        eval_script = script_dir / "evaluate_asos.py"
+    else:
+        raise ValueError(f"Unknown task_type: {task_type}")
+
+    cmd = [
+        "python", str(eval_script),
+        "--dataset-file", input_file,
+        "--submission-folder", str(submission_dir),
+        "--result-folder", str(result_folder)
+    ]
+
+    print(f"\n{'=' * 80}")
+    print(f"Running evaluation for {task_type}...")
+    print(f"Command: {' '.join(cmd)}")
+    print(f"{'=' * 80}\n")
+
+    subprocess.run(cmd, check=True)
+    print(f"\nEvaluation complete. Results saved to {result_folder}")
+
+def _process_jsonl(jsonl_path: str, msa_dir: Optional[Path] = None):
+    """Process multiple datapoints from a JSONL file."""
+    print(f"Processing JSONL file: {jsonl_path}")
+
+    for line_num, line in enumerate(Path(jsonl_path).read_text().splitlines(), 1):
+        if not line.strip():
+            continue
+
+        print(f"\n--- Processing line {line_num} ---")
+
+        try:
+            datapoint = Datapoint.from_json(line)
+            _run_boltz_and_collect(datapoint)
+
+        except json.JSONDecodeError as e:
+            print(f"ERROR: Invalid JSON on line {line_num}: {e}")
+            continue
+        except Exception as e:
+            print(f"ERROR: Failed to process datapoint on line {line_num}: {e}")
+            raise
+
+def _process_json(json_path: str, msa_dir: Optional[Path] = None):
+    """Process a single datapoint from a JSON file."""
+    print(f"Processing JSON file: {json_path}")
+
+    try:
+        datapoint = _load_datapoint(Path(json_path))
+        _run_boltz_and_collect(datapoint)
+    except Exception as e:
+        print(f"ERROR: Failed to process datapoint: {e}")
+        raise
+
+def main():
+    """Main entry point for the hackathon scaffold."""
+    # Determine task type from first datapoint for evaluation
+    task_type = None
+    input_file = None
+
+    if args.input_json:
+        input_file = args.input_json
+        _process_json(args.input_json, args.msa_dir)
+        # Get task type from the single datapoint
+        try:
+            datapoint = _load_datapoint(Path(args.input_json))
+            task_type = datapoint.task_type
+        except Exception as e:
+            print(f"WARNING: Could not determine task type: {e}")
+    elif args.input_jsonl:
+        input_file = args.input_jsonl
+        _process_jsonl(args.input_jsonl, args.msa_dir)
+        # Get task type from first datapoint in JSONL
+        try:
+            with open(args.input_jsonl) as f:
+                first_line = f.readline().strip()
+                if first_line:
+                    first_datapoint = Datapoint.from_json(first_line)
+                    task_type = first_datapoint.task_type
+        except Exception as e:
+            print(f"WARNING: Could not determine task type: {e}")
+
+    # Run evaluation if result folder is specified and task type was determined
+    if args.result_folder and task_type and input_file:
+        try:
+            _run_evaluation(input_file, task_type, args.submission_dir, args.result_folder)
+        except Exception as e:
+            print(f"WARNING: Evaluation failed: {e}")
+            import traceback
+            traceback.print_exc()
+
+if __name__ == "__main__":
+    main()
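For orientation, the Boltz input document that `_prefill_input_dict` assembles has the shape sketched below. This is a minimal, hypothetical example — the IDs, sequence, SMILES string, and MSA filename are invented — dumped with `yaml.safe_dump(..., sort_keys=False)` just as the scaffold writes its per-config YAML files:

    import yaml

    # Hypothetical protein-ligand datapoint; only the structure mirrors
    # what _prefill_input_dict builds above.
    doc = {
        "version": 1,
        "sequences": [
            {"protein": {"id": "A", "sequence": "MKTAYIAKQRQISFVKSHFSRQ", "msa": "msa/example_protein.csv"}},
            {"ligand": {"id": "L", "smiles": "CC(=O)Oc1ccccc1C(=O)O"}},
        ],
    }
    print(yaml.safe_dump(doc, sort_keys=False))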
diff --git a/hackathon/predict_in_docker.sh b/hackathon/predict_in_docker.sh
new file mode 100755
index 000000000..51b67538e
--- /dev/null
+++ b/hackathon/predict_in_docker.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+
+# Usage:
+# ./predict_in_docker.sh --image <docker-image> --dataset <dataset.jsonl> --submission-dir <submission-dir> --msa-dir <msa-dir> --boltz-cache-dir <boltz-cache-dir>
+
+set -e
+
+show_help() {
+    echo "Usage: $0 --image <docker-image> --dataset <dataset.jsonl> --submission-dir <submission-dir> --msa-dir <msa-dir> --boltz-cache-dir <boltz-cache-dir>"
+    exit 1
+}
+
+# Parse named parameters
+while [[ $# -gt 0 ]]; do
+    key="$1"
+    case $key in
+        --image)
+            DOCKER_IMAGE_TAG="$2"
+            shift; shift
+            ;;
+        --dataset)
+            DATASET_PATH="$2"
+            shift; shift
+            ;;
+        --submission-dir)
+            SUBMISSION_DIR="$2"
+            shift; shift
+            ;;
+        --msa-dir)
+            MSA_DIR="$2"
+            shift; shift
+            ;;
+        --boltz-cache-dir)
+            BOLTZ_CACHE_DIR="$2"
+            shift; shift
+            ;;
+        --help)
+            show_help
+            ;;
+        *)
+            echo "Unknown parameter: $1"
+            show_help
+            ;;
+    esac
+done
+
+# Check required parameters
+if [[ -z "$DOCKER_IMAGE_TAG" || -z "$DATASET_PATH" || -z "$SUBMISSION_DIR" || -z "$MSA_DIR" || -z "$BOLTZ_CACHE_DIR" ]]; then
+    show_help
+fi
+
+# Make sure submission directory exists
+mkdir -p "$SUBMISSION_DIR"
+
+# Expose only the GPUs named in CUDA_VISIBLE_DEVICES when it is set; otherwise all GPUs
+GPU_SPEC=${CUDA_VISIBLE_DEVICES:+device=$CUDA_VISIBLE_DEVICES}
+GPU_SPEC=${GPU_SPEC:-all}
+
+# Run the docker container with the required mounts and arguments
+# set -x
+docker run --rm \
+    --gpus "$GPU_SPEC" \
+    --network none \
+    --shm-size=16G \
+    --mount type=bind,source=$HOME/.ssh/cacert.pem,target=/etc/ssl/certs/cacert.pem,readonly \
+    -e REQUESTS_CA_BUNDLE=/etc/ssl/certs/cacert.pem \
+    -e CURL_CA_BUNDLE=/etc/ssl/certs/cacert.pem \
+    -e SSL_CERT_FILE=/etc/ssl/certs/cacert.pem \
+    -e BOLTZ_CACHE=/db/boltz \
+    -v "${DATASET_PATH}:/app/dataset.jsonl:ro" \
+    -v "${SUBMISSION_DIR}:/app/submission:rw" \
+    -v "${MSA_DIR}:/app/msa:ro" \
+    -v "${BOLTZ_CACHE_DIR}:/db/boltz:rw" \
+    -it \
+    "${DOCKER_IMAGE_TAG}" \
+    conda run -n boltz python hackathon/predict_hackathon.py --input-jsonl "/app/dataset.jsonl" --msa-dir "/app/msa"
diff --git a/pyproject.toml b/pyproject.toml
index 9e22f29ef..0e1e302e9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,6 +32,9 @@ dependencies = [
     "gemmi==0.6.5",
     "scikit-learn==1.6.1",
     "chembl_structure_pipeline==1.2.2",
+    "dataclasses-json",
+    "jupyterlab",
+    "matplotlib"
 ]
 
 [project.scripts]
diff --git a/scripts/generate_local_msa.py b/scripts/generate_local_msa.py
new file mode 100644
index 000000000..dbd2cb500
--- /dev/null
+++ b/scripts/generate_local_msa.py
@@ -0,0 +1,376 @@
+from __future__ import annotations
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
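Before the implementation, a worked sketch of the ColabFold A3M header convention that `_parse_header` below decodes. The header `#245,180\t1,1` is an invented two-chain example; the `10<N>` chain-naming scheme and the cumulative-length ranges mirror the code:

    # Invented two-chain header: lengths 245 and 180, oligomeric state "1,1".
    first_line = "#245,180\t1,1"
    lengths, oligomeric_state = first_line.split("\t")
    chain_lengths = [int(x) for x in lengths[1:].split(",")]   # [245, 180]
    chain_names = [f"10{i + 1}" for i in range(len(oligomeric_state.split(",")))]  # ["101", "102"]
    seq_ranges = {
        name: (sum(chain_lengths[:i]), sum(chain_lengths[:i + 1]))
        for i, name in enumerate(chain_names)
    }
    print(seq_ranges)  # {'101': (0, 245), '102': (245, 425)}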
+import argparse +import csv +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +# Constants +CHAIN_INFO_LENGTH = 2 +DEFAULT_CHAIN_NAME = "101" + + +@dataclass +class LocalColabFoldConfig: + """Configuration for ColabFold search.""" + + colabsearch: str + query_fpath: str + db_dir: str + results_dir: str + mmseqs_path: Optional[str] = None + db1: str = "uniref30_2302_db" + db2: Optional[str] = None + db3: Optional[str] = "colabfold_envdb_202108_db" + use_env: int = 1 + filter: int = 1 + db_load_mode: int = 0 + + +class A3MProcessor: + """Processor for A3M file format.""" + + def __init__(self, a3m_file: str, out_dir: str) -> None: + self.out_dir = out_dir + self.a3m_file = Path(a3m_file) + self.a3m_content = self._read_a3m_file() + self.chain_info = self._parse_header() + + def _read_a3m_file(self) -> str: + """Read A3M file content.""" + return self.a3m_file.read_text() + + def _parse_header(self) -> tuple[list[str], dict[str, tuple[int, int]]]: + """Parse A3M header to get chain information.""" + first_line = self.a3m_content.split("\n")[0] + if first_line[0] == "#": + lengths, oligomeric_state = first_line.split("\t") + + chain_lengths = [int(x) for x in lengths[1:].split(",")] + chain_names = [ + f"10{x + 1}" for x in range(len(oligomeric_state.split(","))) + ] + + # Calculate sequence ranges for each chain + seq_ranges = {} + for i, name in enumerate(chain_names): + start = sum(chain_lengths[:i]) + end = sum(chain_lengths[: i + 1]) + seq_ranges[name] = (start, end) + else: + chain_names = [DEFAULT_CHAIN_NAME] + seq_ranges = {DEFAULT_CHAIN_NAME: (0, len(self.a3m_content.split("\n")[1]))} + + return chain_names, seq_ranges + + def _extract_sequence(self, line: str, range_tuple: tuple[int, int]) -> str: + """Extract sequence for specific range.""" + seq = [] + no_insert_count = 0 + start, end = range_tuple + + for char in line: + if char.isupper() or char == "-": + no_insert_count += 1 + # we keep insertions + if start < no_insert_count <= end: + seq.append(char) + elif no_insert_count > end: + break + + return "".join(seq) + + def _process_sequence_lines( # noqa: C901 + self, + lines: list[str], + seq_ranges: dict[str, tuple[int, int]], + chain_names: list[str], + ) -> tuple[dict[str, list[str]], dict[str, list[str]]]: + """Process sequence lines to separate pairing and non-pairing sequences.""" + pairing_a3ms = {name: [] for name in chain_names} + nonpairing_a3ms = {name: [] for name in chain_names} + + current_query = None + for line in lines: + if line.startswith("#"): + continue + + if line.startswith(">"): + name = line[1:] + if name in chain_names: + current_query = chain_names[chain_names.index(name)] + elif name == "\t".join(chain_names): + current_query = None + + # Add header line to appropriate dictionary + if current_query: + nonpairing_a3ms[current_query].append(line) + else: + for chain_name in chain_names: + pairing_a3ms[chain_name].append(line) + continue + + # Process sequence line + if not line: + continue + + if current_query: + seq = self._extract_sequence(line, seq_ranges[current_query]) + nonpairing_a3ms[current_query].append(seq) + else: + for chain_name in chain_names: + seq = self._extract_sequence(line, seq_ranges[chain_name]) + pairing_a3ms[chain_name].append(seq) + + return nonpairing_a3ms, pairing_a3ms + + def _get_query_sequences(self, + chain_names: list[str], + pairing_a3ms: dict[str, list[str]], + nonpairing_a3ms: dict[str, list[str]]) -> dict[str, str]: + + query_sequences = {} + for chain_name 
in chain_names: + # Try to get query from pairing first, then non-pairing + pairing_lines = pairing_a3ms.get(chain_name, []) + nonpairing_lines = nonpairing_a3ms.get(chain_name, []) + + if len(pairing_lines) > 1: + query_sequences[chain_name] = pairing_lines[1] + elif len(nonpairing_lines) > 1: + query_sequences[chain_name] = nonpairing_lines[1] + else: + query_sequences[chain_name] = "" + + return query_sequences + + def split_sequences(self) -> None: + """Split A3M file into pairing and non-pairing sequences.""" + out_dir = Path(self.out_dir) + chain_names, seq_ranges = self.chain_info + + nonpairing_a3ms, pairing_a3ms = self._process_sequence_lines( + self.a3m_content.split("\n"), seq_ranges, chain_names + ) + + # Extract query sequences for each chain + query_sequences = self._get_query_sequences(chain_names, pairing_a3ms, nonpairing_a3ms) + + self._write_output_files(out_dir, nonpairing_a3ms, pairing_a3ms, query_sequences) + + def _write_msa_to_csv( + self, + csv_file_name: Path, + query_sequence: str, + pairing_sequences: list[str], + nonpairing_sequences: list[str], + ) -> None: + """ + Write MSA sequences to a CSV file with query sequence always first. + + Args: + csv_file_name: Path to the output CSV file + query_sequence: The query sequence (always written with key=0 as first row) + pairing_sequences: List of pairing MSA sequences (written with keys starting from 1) + nonpairing_sequences: List of non-pairing MSA sequences (written with key=-1) + """ + with csv_file_name.open(mode="w", newline="") as csv_file: + writer = csv.writer(csv_file) + writer.writerow(["key", "sequence"]) # Write header + + # ALWAYS write query sequence first with key=0 + writer.writerow([0, query_sequence]) + + # Write pairing sequences with positive keys starting from 1 + for i, seq in enumerate(pairing_sequences, start=1): + if seq and not seq.startswith(">"): + writer.writerow([i, seq]) + + # Write non-pairing sequences with key=-1 + for seq in nonpairing_sequences: + if seq and not seq.startswith(">"): + writer.writerow([-1, seq]) + + def _write_output_files( + self, + out_dir: Path, + nonpairing_a3ms: dict[str, list[str]], + pairing_a3ms: dict[str, list[str]], + query_sequences: dict[str, str], + ) -> None: + """ + Write split sequences to output files. + + This method combines pairing and non-pairing MSAs into a single CSV file per chain, + ensuring the query sequence is always written first with key=0. 
+
+        Args:
+            out_dir: Output directory for CSV files
+            nonpairing_a3ms: Dictionary of non-pairing MSA sequences by chain
+            pairing_a3ms: Dictionary of pairing MSA sequences by chain
+            query_sequences: Dictionary of query sequences by chain
+        """
+        out_dir.mkdir(exist_ok=True)
+
+        # Get all unique chain names from both dictionaries
+        all_chain_names = list(pairing_a3ms.keys())
+
+        # Process each chain and write combined MSA to CSV
+        for i, chain_name in enumerate(all_chain_names):
+            csv_file_name = out_dir / f"msa_{i}.csv"
+
+            # Get sequences from both sources
+            pairing_lines = pairing_a3ms.get(chain_name, [])
+            nonpairing_lines = nonpairing_a3ms.get(chain_name, [])
+
+            # Get the query sequence for this chain
+            query_seq = query_sequences.get(chain_name, "")
+
+            # Validate that we have a query sequence
+            if not query_seq:
+                print(f"Warning: No query sequence found for chain {chain_name}")
+                continue
+
+            # Extract sequences, skipping header at index 0
+            # Skip index 1 if it's identical to the query sequence
+            pairing_sequences = []
+            nonpairing_sequences = []
+
+            # Process pairing sequences
+            for idx, line in enumerate(pairing_lines):
+                if idx == 0:  # Skip header
+                    continue
+                if idx == 1 and line == query_seq:  # Skip if identical to query
+                    continue
+                if line and not line.startswith(">"):
+                    pairing_sequences.append(line)
+
+            # Process non-pairing sequences
+            for idx, line in enumerate(nonpairing_lines):
+                if idx == 0:  # Skip header
+                    continue
+                if idx == 1 and line == query_seq:  # Skip if identical to query
+                    continue
+                if line and not line.startswith(">"):
+                    nonpairing_sequences.append(line)
+
+            # Write combined MSA to CSV with query sequence always first
+            self._write_msa_to_csv(
+                csv_file_name,
+                query_seq,
+                pairing_sequences,
+                nonpairing_sequences,
+            )
+
+
+def run_colabfold_search(config: LocalColabFoldConfig) -> str:
+    """Run ColabFold search with given configuration."""
+    cmd = [config.colabsearch, config.query_fpath, config.db_dir, config.results_dir]
+
+    # Add optional parameters
+    if config.db1:
+        cmd.extend(["--db1", config.db1])
+    if config.db2:
+        cmd.extend(["--db2", config.db2])
+    if config.db3:
+        cmd.extend(["--db3", config.db3])
+    if config.mmseqs_path:
+        cmd.extend(["--mmseqs", config.mmseqs_path])
+    else:
+        cmd.extend(["--mmseqs", "mmseqs"])
+    if config.use_env:
+        cmd.extend(["--use-env", str(config.use_env)])
+    if config.filter:
+        cmd.extend(["--filter", str(config.filter)])
+    if config.db_load_mode:
+        cmd.extend(["--db-load-mode", str(config.db_load_mode)])
+
+    # Use subprocess instead of os.system for security
+    subprocess.run(cmd, check=True)  # noqa: S603
+
+    # Return the first .a3m file in the results directory (sorted for determinism)
+    result_files = sorted(Path(config.results_dir).glob("*.a3m"))
+    if not result_files:
+        error_msg = f"No .a3m files found in {config.results_dir}"
+        raise FileNotFoundError(error_msg)
+    return str(result_files[0])
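A minimal sketch of driving this search programmatically, as `hackathon_compute_msa.py` later in this diff does; the paths are placeholders, and `colabfold_search` plus MMseqs2 must be installed for the call to succeed:

    from generate_local_msa import A3MProcessor, LocalColabFoldConfig, run_colabfold_search

    config = LocalColabFoldConfig(
        colabsearch="colabfold_search",   # or a full path to the executable
        query_fpath="query.fasta",        # placeholder query file
        db_dir="/data/colabfold_dbs",     # placeholder database directory
        results_dir="results",
    )
    a3m_file = run_colabfold_search(config)   # raises FileNotFoundError if no .a3m is produced
    A3MProcessor(a3m_file, "results").split_sequences()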
"--mmseqs_path", help="Path to MMseqs2 binary", default="mmseqs" + ) + parser.add_argument("--db1", help="First database name", default="uniref30_2302_db") + parser.add_argument("--db2", help="Templates database") + parser.add_argument( + "--db3", help="Environmental database (default: colabfold_envdb_202108_db)" + ) + parser.add_argument( + "--use_env", help="Use environment settings", type=int, default=1 + ) + parser.add_argument("--filter", help="Apply filtering", type=int, default=1) + parser.add_argument( + "--db_load_mode", help="Database load mode", type=int, default=0 + ) + parser.add_argument( + "--output_split", help="Directory for split A3M files", default=None + ) + return parser.parse_args() + + +def main(args: argparse.Namespace): + # Create configuration from arguments + config = LocalColabFoldConfig( + colabsearch=args.colabsearch, + query_fpath=args.query_fpath, + db_dir=args.db_dir, + results_dir=args.results_dir, + mmseqs_path=args.mmseqs_path, + db1=args.db1, + db2=args.db2, + db3=args.db3, + use_env=args.use_env, + filter=args.filter, + db_load_mode=args.db_load_mode, + ) + + # Run search + results_a3m = run_colabfold_search(config) + + processor = A3MProcessor(results_a3m, args.results_dir) + if len(processor.chain_info) == CHAIN_INFO_LENGTH: + processor.split_sequences() + + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/hackathon_compute_msa.py b/scripts/hackathon_compute_msa.py new file mode 100755 index 000000000..a59f6dca0 --- /dev/null +++ b/scripts/hackathon_compute_msa.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +""" +Script to compute MSAs for protein sequences in a JSONL file using ColabFold search. + +This script processes a JSONL file containing protein sequences, generates MSAs for each +unique protein sequence using ColabFold search, and outputs a new JSONL file with +updated MSA paths. +""" + +from __future__ import annotations + +import argparse +import json +import tempfile +from pathlib import Path +from typing import Any, Dict, List, Optional +import hashlib +import shutil + +from generate_local_msa import LocalColabFoldConfig, run_colabfold_search, A3MProcessor + + +def create_fasta_from_sequences(sequences: List[str], seq_ids: List[str], output_path: Path) -> str: + """Create a FASTA file from protein sequences, concatenating multiple sequences with ':'.""" + if len(sequences) == 1: + # Single sequence + fasta_content = f">{seq_ids[0]}\n{sequences[0]}\n" + else: + # Multiple sequences - concatenate with ':' + concatenated_sequence = ":".join(sequences) + concatenated_id = "_".join(seq_ids) + fasta_content = f">{concatenated_id}\n{concatenated_sequence}\n" + + output_path.write_text(fasta_content) + return str(output_path) + + +def get_entry_hash(entry: Dict[str, Any]) -> str: + """Generate a hash for a JSONL entry based on its protein sequences.""" + return entry["datapoint_id"] + + +def extract_entry_sequences(jsonl_data: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + """ + Extract protein sequences for each JSONL entry. 
+
+
+def extract_entry_sequences(jsonl_data: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
+    """
+    Extract protein sequences for each JSONL entry.
+
+    Returns:
+        Dict mapping entry hash to entry data with sequences and IDs
+    """
+    entry_data = {}
+
+    for entry in jsonl_data:
+        if "proteins" in entry:
+            sequences = []
+            seq_ids = []
+
+            for protein in entry["proteins"]:
+                if "sequence" in protein:
+                    sequences.append(protein["sequence"])
+                    seq_ids.append(protein.get("id", f"protein_{len(seq_ids)}"))
+
+            if sequences:  # Only process entries with protein sequences
+                entry_hash = get_entry_hash(entry)
+                entry_data[entry_hash] = {
+                    "sequences": sequences,
+                    "seq_ids": seq_ids,
+                    "original_entry": entry
+                }
+
+    return entry_data
+
+
+def process_msa_generation(
+    entry_data: Dict[str, Dict[str, Any]],
+    temp_dir: Path,
+    msa_dir: Path,
+    colabsearch_path: str,
+    db_dir: str,
+    mmseqs_path: Optional[str] = None,
+    db1: str = "uniref30_2302_db",
+    db2: Optional[str] = None,
+    db3: str = "colabfold_envdb_202108_db",
+) -> Dict[str, Any]:
+    """
+    Generate MSAs for all entries.
+
+    Returns:
+        Dict mapping entry hash to MSA CSV path(s): a str for one protein, a list otherwise
+    """
+    msa_paths = {}
+
+    for entry_hash, data in entry_data.items():
+        sequences = data["sequences"]
+        seq_ids = data["seq_ids"]
+
+        print(f"Processing entry {entry_hash} with {len(sequences)} sequences...")
+
+        # Create FASTA file
+        fasta_path = temp_dir / f"{entry_hash}.fasta"
+        create_fasta_from_sequences(sequences, seq_ids, fasta_path)
+
+        # Create temporary results directory for this entry
+        temp_results_dir = temp_dir / f"results_{entry_hash}"
+        temp_results_dir.mkdir(exist_ok=True)
+
+        # Configure ColabFold search
+        config = LocalColabFoldConfig(
+            colabsearch=colabsearch_path,
+            query_fpath=str(fasta_path),
+            db_dir=db_dir,
+            results_dir=str(temp_results_dir),
+            mmseqs_path=mmseqs_path,
+            db1=db1,
+            db2=db2,
+            db3=db3,
+        )
+
+        try:
+            # Run ColabFold search
+            a3m_file = run_colabfold_search(config)
+
+            # Process A3M file to generate CSV
+            processor = A3MProcessor(a3m_file, str(temp_results_dir))
+            processor.split_sequences()
+
+            # Move the CSV files to the final MSA directory (sorted to keep chain order)
+            csv_files = sorted(temp_results_dir.glob("msa_*.csv"))
+            if csv_files:
+                if len(sequences) == 1:
+                    # Single sequence - use one CSV file
+                    source_csv = csv_files[0]
+                    target_csv = msa_dir / f"{entry_hash}.csv"
+                    shutil.move(str(source_csv), str(target_csv))
+                    msa_paths[entry_hash] = str(target_csv)
+                else:
+                    # Multiple sequences - store paths to multiple CSV files
+                    csv_paths = []
+                    for i, source_csv in enumerate(csv_files):
+                        target_csv = msa_dir / f"{entry_hash}_{i}.csv"
+                        shutil.move(str(source_csv), str(target_csv))
+                        csv_paths.append(str(target_csv))
+                    msa_paths[entry_hash] = csv_paths
+
+                print(f"MSA generated for {entry_hash}: {msa_paths[entry_hash]}")
+            else:
+                print(f"Warning: No CSV files generated for entry {entry_hash}")
+
+        except Exception as e:
+            print(f"Error processing entry {entry_hash}: {e}")
+            continue
+
+    return msa_paths
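The mapping this returns is heterogeneous by design: a plain path string for single-protein entries, a list of per-chain paths otherwise. A sketch with invented paths:

    # Invented example of the msa_paths mapping returned above.
    msa_paths = {
        "dp_0001": ["msa/dp_0001_0.csv", "msa/dp_0001_1.csv"],  # multi-chain entry
        "dp_0002": "msa/dp_0002.csv",                           # single-chain entry
    }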
+
+
+def update_jsonl_with_msa_paths(
+    entry_data: Dict[str, Dict[str, Any]],
+    msa_paths: Dict[str, Any],
+    output_jsonl: Path,
+) -> None:
+    """Update JSONL data with new MSA paths (filenames only) and write to output file."""
+
+    updated_entries = []
+
+    for entry_hash, data in entry_data.items():
+        entry = data["original_entry"].copy()
+
+        if entry_hash in msa_paths:
+            msa_path_data = msa_paths[entry_hash]
+
+            if "proteins" in entry:
+                if isinstance(msa_path_data, list):
+                    # Multiple MSA files for multiple proteins
+                    for i, protein in enumerate(entry["proteins"]):
+                        if i < len(msa_path_data):
+                            # Store only the filename, not the full path
+                            protein["msa"] = Path(msa_path_data[i]).name
+                        else:
+                            print(f"Warning: No MSA file for protein {i} in entry {entry.get('datapoint_id', 'unknown')}")
+                else:
+                    # Single MSA file - assign to all proteins (for concatenated sequences)
+                    for protein in entry["proteins"]:
+                        # Store only the filename, not the full path
+                        protein["msa"] = Path(msa_path_data).name
+        else:
+            print(f"Warning: No MSA found for entry {entry.get('datapoint_id', 'unknown')}")
+
+        updated_entries.append(entry)
+
+    # Write updated JSONL
+    with output_jsonl.open("w") as f:
+        for entry in updated_entries:
+            f.write(json.dumps(entry) + "\n")
+
+
+def load_jsonl(file_path: Path) -> List[Dict[str, Any]]:
+    """Load JSONL file into a list of dictionaries."""
+    data = []
+    with file_path.open("r") as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                data.append(json.loads(line))
+    return data
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="Compute MSAs for protein sequences in a JSONL file",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Required arguments
+    parser.add_argument(
+        "--input-jsonl",
+        type=Path,
+        required=True,
+        help="Input JSONL file containing protein sequences"
+    )
+    parser.add_argument(
+        "--output-jsonl",
+        type=Path,
+        required=True,
+        help="Output JSONL file with updated MSA paths"
+    )
+    parser.add_argument(
+        "--msa-dir",
+        type=Path,
+        required=True,
+        help="Directory to store final MSA CSV files"
+    )
+    parser.add_argument(
+        "--db-dir",
+        type=str,
+        required=True,
+        help="Directory containing ColabFold databases"
+    )
+
+    # Optional arguments
+    parser.add_argument(
+        "--temp-dir",
+        type=Path,
+        default=None,
+        help="Temporary directory for intermediate files (default: system temp dir)"
+    )
+    parser.add_argument(
+        "--colabsearch",
+        type=str,
+        default="colabfold_search",
+        help="Path to colabfold_search executable"
+    )
+    parser.add_argument(
+        "--mmseqs-path",
+        type=str,
+        default="mmseqs",
+        help="Path to MMseqs2 binary"
+    )
+    parser.add_argument(
+        "--db1",
+        type=str,
+        default="uniref30_2302_db",
+        help="First database name"
+    )
+    parser.add_argument(
+        "--db2",
+        type=str,
+        default=None,
+        help="Templates database"
+    )
+    parser.add_argument(
+        "--db3",
+        type=str,
+        default="colabfold_envdb_202108_db",
+        help="Environmental database"
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """Main function."""
+    args = parse_args()
+
+    # Validate input file
+    if not args.input_jsonl.exists():
+        raise FileNotFoundError(f"Input JSONL file not found: {args.input_jsonl}")
+
+    # Create output directories
+    args.msa_dir.mkdir(parents=True, exist_ok=True)
+    args.output_jsonl.parent.mkdir(parents=True, exist_ok=True)
+
+    try:
+        if not args.temp_dir:
+            with tempfile.TemporaryDirectory() as temp_dir_str:
+                temp_dir = Path(temp_dir_str)
+                _process_msa_workflow(args, temp_dir)
+        else:
+            args.temp_dir.mkdir(parents=True, exist_ok=True)
+            _process_msa_workflow(args, args.temp_dir)
+
+    except Exception as e:
+        print(f"Error: {e}")
+        raise
+
+
+def _process_msa_workflow(args: argparse.Namespace, temp_dir: Path):
+    """Process the MSA generation workflow."""
+    print(f"Loading JSONL data from {args.input_jsonl}")
+    jsonl_data = load_jsonl(args.input_jsonl)
+
+    print("Extracting entry sequences...")
+    entry_data = extract_entry_sequences(jsonl_data)
+    print(f"Found {len(entry_data)} entries with protein sequences")
+
+    if not entry_data:
+        print("No protein sequences found in input file")
+        return
+
+    print("Generating MSAs...")
+    msa_paths = 
process_msa_generation( + entry_data=entry_data, + temp_dir=temp_dir, + msa_dir=args.msa_dir, + colabsearch_path=args.colabsearch, + db_dir=args.db_dir, + mmseqs_path=args.mmseqs_path, + db1=args.db1, + db2=args.db2, + db3=args.db3, + ) + + print(f"Successfully generated MSA files for {len(msa_paths)} entries") + + print(f"Updating JSONL with MSA paths and writing to {args.output_jsonl}") + update_jsonl_with_msa_paths(entry_data, msa_paths, args.output_jsonl) + + print("MSA generation complete!") + + +if __name__ == "__main__": + main()
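Closing the loop, a sketch of consuming one generated CSV, following the key convention documented in `_write_msa_to_csv` (row 0 is always the query with key 0, positive keys are pairing sequences, key -1 marks non-pairing ones). The filename is invented:

    import csv

    with open("msa/dp_0001_0.csv", newline="") as f:   # invented filename
        rows = list(csv.DictReader(f))

    query = rows[0]["sequence"]                                   # key == "0", always first
    paired = [r["sequence"] for r in rows if int(r["key"]) > 0]
    unpaired = [r["sequence"] for r in rows if int(r["key"]) == -1]
    print(f"{len(paired)} pairing / {len(unpaired)} non-pairing sequences")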