diff --git a/docs/build_a_model/data_cleaning_script.py b/docs/build_a_model/data_cleaning_script.py
new file mode 100644
index 00000000..e1092c44
--- /dev/null
+++ b/docs/build_a_model/data_cleaning_script.py
@@ -0,0 +1,189 @@
+"""
+Script to read in parquet files and clean the data needed for the "How to Build a Model Using AtomWorks" tutorial. This script is meant to be run from the command line and takes the path to the parquet files and the path to your PDB mirror as arguments. It saves the cleaned data as a parquet file for future use in the tutorial.
+
+Example usage:
+    python data_cleaning_script.py /path/to/parquet/files /path/to/pdb_mirror
+
+Last edited: April 14, 2026
+"""
+
+import os
+import sys
+
+import numpy as np
+import pandas as pd
+
+def read_in_parquet_file(path_to_parquets: str | os.PathLike, parquet_file: str) -> pd.DataFrame:
+    """
+    Creates a pandas DataFrame from a parquet file.
+    :param path_to_parquets: The path to the directory containing the parquet files.
+    :param parquet_file: The name of the parquet file (without the .parquet extension).
+    :return: A pandas DataFrame containing the data from the parquet file.
+    """
+    file_path = os.path.join(path_to_parquets, parquet_file + ".parquet")
+
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"{file_path} not found")
+
+    return pd.read_parquet(file_path)
+
+path_to_parquets = None
+PDB_MIRROR_PATH = None
+
+if len(sys.argv) > 2:
+    path_to_parquets = sys.argv[1]
+    PDB_MIRROR_PATH = sys.argv[2]
+else:
+    print("To use this script, please provide the path to the parquet files and the path to your PDB mirror as command line arguments.")
+    sys.exit(1)
+
+def assign_split(cluster):
+    """
+    Helper function to assign each cluster to a split (train, val, or test) based on the sets of clusters we defined earlier.
+    :param cluster: The protein cluster to assign a split to.
+    :return: The split that the cluster belongs to (train, val, test, or unassigned).
+    """
+    if cluster in train_clusters:
+        return "train"
+    if cluster in val_clusters:
+        return "val"
+    if cluster in test_clusters:
+        return "test"
+    # Rows with a null cluster are removed before this function is called, so every
+    # cluster should fall into one of the three sets; "unassigned" is kept as a
+    # debugging aid.
+    return "unassigned"
+
+
+####################################################
+# Reading in the parquet files and cleaning the data.
+####################################################
+
+interfaces = read_in_parquet_file(path_to_parquets, "interfaces")
+pn_units = read_in_parquet_file(path_to_parquets, "pn_units")
+
+# We need to combine the information in the interfaces and pn_units datasets to get the data for each pn_unit involved in each interface.
+# However, the label "pn_unit_iid" in the pn_units dataset is not unique - it is only unique for a given pdb_id and assembly_id.
+# So when we merge these dataframes we need to merge on all three keys, and we need to do it twice - once for the first pn_unit in each interface and once for the second.
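+# A quick, hypothetical sanity check of the uniqueness claim above, should you want
+# to verify it on your copy of the data:
+#   assert pn_units.duplicated(subset=["pn_unit_iid"]).any()
+#   assert not pn_units.duplicated(subset=["pdb_id", "assembly_id", "pn_unit_iid"]).any()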
+
+# Save the particular columns we want to merge for the first set of pn_units.
+# We also keep the "is_polymer" column to filter our data later.
+u1_cols = pn_units[["pdb_id", "assembly_id", "pn_unit_iid", "is_polymer", "num_resolved_residues"]].copy()
+
+u2_cols = pn_units[["pdb_id", "assembly_id", "pn_unit_iid", "is_polymer", "num_resolved_residues"]].copy()
+
+# Rename the columns in these new dataframes so that we can merge on the pn_unit_iid labels (they are called "pn_unit_1_iid" and "pn_unit_2_iid" in the interfaces dataset) and to distinguish the other columns for the first and second pn_units in each interface.
+u1_cols = u1_cols.rename(columns={"pn_unit_iid": "pn_unit_1_iid",
+                                  "is_polymer": "u1_is_polymer",
+                                  "num_resolved_residues": "u1_num_resolved_residues"})
+
+u2_cols = u2_cols.rename(columns={"pn_unit_iid": "pn_unit_2_iid",
+                                  "is_polymer": "u2_is_polymer",
+                                  "num_resolved_residues": "u2_num_resolved_residues"})
+
+# Merge these three dataframes together to get the data for each pn_unit involved in each interface.
+df = interfaces.merge(u1_cols, on=["pdb_id", "assembly_id", "pn_unit_1_iid"], how="inner")
+df = df.merge(u2_cols, on=["pdb_id", "assembly_id", "pn_unit_2_iid"], how="inner")
+
+# Now we will filter the data to only keep drug-like protein-ligand pockets:
+# involves_loi==True so that at least one of the pn_units is a "ligand of interest"
+df = df[df["involves_loi"]]
+# is_inter_molecule==True to ensure the two pn_units belong to different molecules
+df = df[df["is_inter_molecule"]]
+# involves_metal==False to avoid interfaces that are metal-mediated
+df = df[~df["involves_metal"]]
+# involves_covalent_modification==False to avoid covalently modified residues
+df = df[~df["involves_covalent_modification"]]
+# Only keep interfaces where one pn_unit is a polymer and the other is not (the non-polymer side should be a small-molecule ligand). This will exclude protein-protein, ligand-ligand, and RNA/DNA interfaces.
+df = df[df["u1_is_polymer"] != df["u2_is_polymer"]]
+# Require the total num_resolved_residues to be < 200 to keep the size tractable for training.
+df = df[(df["u1_num_resolved_residues"] + df["u2_num_resolved_residues"]) < 200]
+
+# Remove duplicate rows
+df = df.drop_duplicates(subset=["pdb_id", "assembly_id", "pn_unit_1_iid", "pn_unit_2_iid"])
+
+# Add a unique identifier for each interface. This will be useful later in the tutorial.
+df["example_id"] = (
+    df["pdb_id"] + "_"
+    + df["assembly_id"].astype(str) + "_"
+    + df["pn_unit_1_iid"] + "_"
+    + df["pn_unit_2_iid"]
+)
+assert df["example_id"].nunique() == len(df), "example_id is not unique!"
+
+# Add path information to match each entry to the location of the files in your PDB mirror.
+# PDB_MIRROR_PATH was read from the command line above.
+df["path"] = df["pdb_id"].str.lower().map(
+    lambda x: f"{PDB_MIRROR_PATH}/{x[1:3]}/{x}.cif.gz"
+)
+
+# Save the cleaned data as a parquet file for future use in the tutorial.
+df.to_parquet("cleaned_data.parquet")
+
+####################################################
+# Splitting the data into training, testing, and validation sets.
+####################################################
+# There are many ways to do this, but for the tutorial we have decided to use the information stored in the column "protein_cluster_30", which groups proteins by 30% sequence identity.
+# We will use this information to ensure that all interfaces involving proteins from the same cluster land in the same set (training, testing, or validation).
+# This will help us to better evaluate the generalizability of our model to new proteins.
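+# (Optional, hypothetical exploration: to see how many 30%-identity clusters the
+# split will be drawn over, you can run:)
+#   print(pn_units["protein_cluster_30"].nunique())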
+
+protein_clusters = pn_units[["pdb_id", "assembly_id", "pn_unit_iid", "protein_cluster_30"]].copy()
+
+# Merge this with u1 and u2, since either could be the protein side of the interface. We use a left merge so that we only keep the keys from the "left" (df) dataframe: all rows in df are kept, and the cluster information is added where it exists.
+# Note: ligands will have a null cluster.
+
+# Merge cluster for u1 side
+df = df.merge(
+    protein_clusters.rename(columns={
+        "pn_unit_iid": "pn_unit_1_iid",
+        "protein_cluster_30": "u1_cluster"}),
+    on=["pdb_id", "assembly_id", "pn_unit_1_iid"],
+    how="left"
+)
+# Merge cluster for u2 side
+df = df.merge(
+    protein_clusters.rename(columns={
+        "pn_unit_iid": "pn_unit_2_iid",
+        "protein_cluster_30": "u2_cluster"}),
+    on=["pdb_id", "assembly_id", "pn_unit_2_iid"],
+    how="left"
+)
+
+# Get the protein_cluster information from whichever side is the polymer.
+# np.where takes u1_cluster when u1 is the polymer and u2_cluster otherwise; because
+# we filtered to interfaces where exactly one side is a polymer, this covers both cases.
+df["protein_cluster"] = np.where(
+    df["u1_is_polymer"],
+    df["u1_cluster"],
+    df["u2_cluster"]
+)
+
+# There may be cases where no clusters were assigned, for example low-quality entries. Let's remove these from our dataset:
+df = df[df["protein_cluster"].notna()].reset_index(drop=True)
+
+# Now we can finally split the data based on the protein_cluster information. We will use 80% of the clusters for training, 10% for testing, and 10% for validation.
+# Shuffle the clusters randomly. A seed is specified for reproducibility.
+# (Null clusters were already dropped above, so no dropna() is needed here.)
+unique_clusters = df["protein_cluster"].unique()
+rng = np.random.default_rng(seed=42)
+rng.shuffle(unique_clusters)
+
+# Split the clusters 80/10/10.
+n = len(unique_clusters)
+n_train = int(0.8 * n)
+n_val = int(0.1 * n)
+# test gets the remainder to avoid off-by-one gaps
+
+train_clusters = set(unique_clusters[:n_train])
+val_clusters = set(unique_clusters[n_train : n_train + n_val])
+test_clusters = set(unique_clusters[n_train + n_val :])
+
+# Create a dataset for each split.
+df["split"] = df["protein_cluster"].map(assign_split)
+
+df_train = df[df["split"] == "train"].reset_index(drop=True)
+df_val = df[df["split"] == "val"].reset_index(drop=True)
+df_test = df[df["split"] == "test"].reset_index(drop=True)
+
+# Save each split as a parquet file in the splits/ directory.
+os.makedirs("splits", exist_ok=True)
+df_train.to_parquet("splits/train.parquet", index=False)
+df_val.to_parquet("splits/val.parquet", index=False)
+df_test.to_parquet("splits/test.parquet", index=False)
diff --git a/docs/build_a_model/dataset_and_loader.py b/docs/build_a_model/dataset_and_loader.py
new file mode 100644
index 00000000..c6b6846b
--- /dev/null
+++ b/docs/build_a_model/dataset_and_loader.py
@@ -0,0 +1,71 @@
+"""
+This script is part of the "How to Build a Model Using AtomWorks" tutorial.
+It uses the AtomWorks API to apply a loader and transform to each 'example' (data point).
+
+Transforms applied:
+- RemoveHydrogens
+- RemoveUnresolvedAtoms
+- CropToPocket
+- FeaturizeForDocking
+
+Last updated: April 14, 2026
+"""
+
+import pandas as pd
+from atomworks.ml.datasets import PandasDataset
+from atomworks.ml.datasets.loaders import create_loader_with_query_pn_units
+from atomworks.ml.transforms.filters import RemoveHydrogens, RemoveUnresolvedAtoms
+from atomworks.ml.transforms.base import Compose
+
+from transforms import CropToPocket, FeaturizeForDocking
+
+# Load the training data as a pandas DataFrame
+df_train = pd.read_parquet("splits/train.parquet")
+
+# Define our transform pipeline: two transforms that ship with AtomWorks, followed by
+# the two custom transforms defined in transforms.py
+transforms_pipeline = Compose([
+    RemoveHydrogens(),
+    RemoveUnresolvedAtoms(),
+    CropToPocket(radius=10.0),
+    FeaturizeForDocking(),
+])
+
+
+# Use the data to create an AtomWorks PandasDataset object
+dataset = PandasDataset(
+    data=df_train,
+    name="docking_train",
+    id_column="example_id",
+    loader=create_loader_with_query_pn_units(
+        pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"]),
+    transform=transforms_pipeline,
+)
+
+################################
+# Print statements for testing #
+################################
+print(f"Dataset size: {len(dataset)}")
+
+example = dataset[0]
+
+print("\nLoaded one sample successfully.")
+print("Sample keys:", list(example.keys()))
+print("atomic_numbers:", example["atomic_numbers"].shape, example["atomic_numbers"].dtype)
+print("input_coords:", example["input_coords"].shape, example["input_coords"].dtype)
+print("target_coords:", example["target_coords"].shape, example["target_coords"].dtype)
+print("edge_index:", example["edge_index"].shape, example["edge_index"].dtype)
+print("is_ligand:", example["is_ligand"].shape, example["is_ligand"].dtype)
+
+assert example["atomic_numbers"].ndim == 1
+assert example["target_coords"].ndim == 2
+assert example["target_coords"].shape[1] == 3
+assert example["input_coords"].shape == example["target_coords"].shape
+assert example["edge_index"].ndim == 2
+assert example["edge_index"].shape[0] == 2
+assert example["edge_index"].max() < example["atomic_numbers"].shape[0]
+assert (example["input_coords"][example["is_ligand"]] == 0).all(), \
+    "Ligand coordinates should be zeroed"
+assert (example["input_coords"][~example["is_ligand"]] != 0).any(), \
+    "Pocket coordinates should not all be zero"
+
+print("\nFeaturizeForDocking smoke test passed.")
diff --git a/docs/build_a_model/model.py b/docs/build_a_model/model.py
new file mode 100644
index 00000000..d5060e66
--- /dev/null
+++ b/docs/build_a_model/model.py
@@ -0,0 +1,119 @@
+import torch
+import torch.nn as nn
+import pytorch_lightning as pl
+
+class PocketDockGNN(pl.LightningModule):
+    def __init__(
+        self,
+        num_atom_types: int = 119,  # atomic numbers 0 (unknown) through 118
+        hidden_dim: int = 128,
+        num_layers: int = 3,
+        learning_rate: float = 1e-3,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.learning_rate = learning_rate
+        self.atom_embedding = nn.Embedding(num_atom_types, hidden_dim)
+
+        # Project the concatenated (embedding, input coordinates) into hidden space
+        self.input_proj = nn.Sequential(
+            nn.Linear(hidden_dim + 3, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, hidden_dim),
+        )
+
+        self.conv_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim),
+                nn.ReLU(),
+                nn.Linear(hidden_dim, hidden_dim),
+            )
+            for _ in range(num_layers)
+        ])
+
+        self.update_layers = nn.ModuleList([
+            nn.Sequential(
+                nn.Linear(hidden_dim * 2, hidden_dim),
+                nn.ReLU(),
+                nn.Linear(hidden_dim, hidden_dim),
+            )
+            for _ in range(num_layers)
+        ])
+
+        self.layer_norms = nn.ModuleList([
+            nn.LayerNorm(hidden_dim)
+            for _ in range(num_layers)
+        ])
+
+        self.output_proj = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, 3),
+        )
+
+        self.loss_fn = nn.MSELoss()
+
+    def forward(
+        self,
+        atomic_numbers: torch.Tensor,
+        input_coords: torch.Tensor,
+        edge_index: torch.Tensor,
+    ) -> torch.Tensor:
+        x = self.atom_embedding(atomic_numbers)
+
+        x = self.input_proj(torch.cat([x, input_coords], dim=-1))
+
+        src, dst = edge_index[0], edge_index[1]
+
+        for conv, update, norm in zip(self.conv_layers, self.update_layers, self.layer_norms):
+            messages = conv(x)
+
+            # Sum incoming messages over the edges: agg[dst] += messages[src]
+            agg = torch.zeros_like(x)
+            agg.scatter_add_(0, dst.unsqueeze(-1).expand(-1, x.size(-1)), messages[src])
+
+            # Residual update from the node state and its aggregated neighborhood
+            x = x + norm(update(torch.cat([x, agg], dim=-1)))
+
+        pred_coords = self.output_proj(x)
+        return pred_coords
+
+    def _shared_step(self, batch: dict, stage: str) -> torch.Tensor:
+        if batch is None:
+            # collate_fn returns None when every example in the batch failed to load
+            return None
+        # Remove the batch dimension added by DataLoader (batch_size=1)
+        atomic_numbers = batch["atomic_numbers"].squeeze(0)  # (N,)
+        input_coords = batch["input_coords"].squeeze(0)      # (N, 3)
+        target_coords = batch["target_coords"].squeeze(0)    # (N, 3)
+        edge_index = batch["edge_index"].squeeze(0)          # (2, E)
+        is_ligand = batch["is_ligand"].squeeze(0)            # (N,)
+
+        pred_coords = self(atomic_numbers, input_coords, edge_index)
+        loss = self.loss_fn(pred_coords, target_coords)
+
+        with torch.no_grad():
+            ligand_rmsd = torch.sqrt(
+                ((pred_coords[is_ligand] - target_coords[is_ligand]) ** 2)
+                .sum(dim=-1).mean()
+            )
+
+        self.log(f"{stage}/loss", loss, prog_bar=True)
+        self.log(f"{stage}/ligand_rmsd", ligand_rmsd, prog_bar=True)
+
+        return loss
+
+    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
+        return self._shared_step(batch, "train")
+
+    def validation_step(self, batch: dict, batch_idx: int) -> None:
+        self._shared_step(batch, "val")
+
+    def test_step(self, batch: dict, batch_idx: int) -> None:
+        self._shared_step(batch, "test")
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
diff --git a/docs/build_a_model/train.py b/docs/build_a_model/train.py
new file mode 100644
index 00000000..e315183f
--- /dev/null
+++ b/docs/build_a_model/train.py
@@ -0,0 +1,180 @@
+import torch
+import pandas as pd
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from torch.utils.data import DataLoader, Dataset
+
+from atomworks.ml.datasets import PandasDataset
+from atomworks.ml.datasets.loaders import create_loader_with_query_pn_units
+from atomworks.ml.transforms.filters import RemoveHydrogens, RemoveUnresolvedAtoms
+from atomworks.ml.transforms.base import ConvertToTorch, Compose
+
+from transforms import CropToPocket, FeaturizeForDocking
+from model import PocketDockGNN
+
+torch.set_float32_matmul_precision("medium")
+pl.seed_everything(42)
+
+CONFIG = {
+    "hidden_dim": 128,
+    "num_layers": 3,
+    "learning_rate": 1e-3,
+    "batch_size": 1,
+    "max_epochs": 5,
+    "pocket_radius": 10.0,
+    "num_workers": 0,
+    "max_train": 100,
+    "max_val": 20,
+    "max_test": 20,
+}
+
+# The per-example keys produced by FeaturizeForDocking; these are converted to torch
+# tensors by ConvertToTorch and stacked by collate_fn.
+TENSOR_KEYS = ["atomic_numbers", "input_coords", "target_coords", "edge_index", "is_ligand"]
+
+class RobustDataset(Dataset):
+    """Wraps a dataset and returns None for examples that fail to load, recording their indices."""
+
+    def __init__(self, dataset: PandasDataset):
+        self.dataset = dataset
+        self.failed = []
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        try:
+            return self.dataset[idx]
+        except Exception:
+            self.failed.append(idx)
+            return None
+
+
+def collate_fn(batch):
+    batch = [b for b in batch if b is not None]
+    if len(batch) == 0:
+        return None
+
+    # Stacking works here because batch_size == 1; larger batches of variable-size
+    # graphs would need padding or a dedicated graph batcher.
+    return {
+        k: torch.stack([example[k] for example in batch])
+        for k in TENSOR_KEYS
+        if k in batch[0]
+    }
+
+
+def build_pipeline(radius: float) -> Compose:
+    return Compose([
+        RemoveHydrogens(),
+        RemoveUnresolvedAtoms(),
+        CropToPocket(radius=radius),
+        FeaturizeForDocking(),
+        ConvertToTorch(keys=TENSOR_KEYS),
+    ])
+
+def build_dataset(parquet_path: str, name: str, radius: float, max_examples: int | None = None) -> RobustDataset:
+    df = pd.read_parquet(parquet_path)
+    if max_examples is not None:
+        df = df.head(max_examples).reset_index(drop=True)
+
+    loader = create_loader_with_query_pn_units(
+        pn_unit_iid_colnames=["pn_unit_1_iid", "pn_unit_2_iid"]
+    )
+
+    dataset = PandasDataset(
+        data=df,
+        name=name,
+        id_column="example_id",
+        loader=loader,
+        transform=build_pipeline(radius),
+        save_failed_examples_to_dir="failed_examples/",
+    )
+    return RobustDataset(dataset)
+
+def build_dataloader(dataset, shuffle: bool) -> DataLoader:
+    return DataLoader(
+        dataset,
+        batch_size=CONFIG["batch_size"],
+        shuffle=shuffle,
+        num_workers=CONFIG["num_workers"],
+        collate_fn=collate_fn,
+        persistent_workers=False,
+    )
+
+print("Building datasets...")
+train_dataset = build_dataset(
+    "splits/train.parquet", "docking_train", CONFIG["pocket_radius"], CONFIG["max_train"]
+)
+val_dataset = build_dataset(
+    "splits/val.parquet", "docking_val", CONFIG["pocket_radius"], CONFIG["max_val"]
+)
+test_dataset = build_dataset(
+    "splits/test.parquet", "docking_test", CONFIG["pocket_radius"], CONFIG["max_test"]
+)
+
+print(f"  Train: {len(train_dataset):,} examples")
+print(f"  Val:   {len(val_dataset):,} examples")
+print(f"  Test:  {len(test_dataset):,} examples")
+
+train_loader = build_dataloader(train_dataset, shuffle=True)
+val_loader = build_dataloader(val_dataset, shuffle=False)
+test_loader = build_dataloader(test_dataset, shuffle=False)
+
+
+model = PocketDockGNN(
+    hidden_dim=CONFIG["hidden_dim"],
+    num_layers=CONFIG["num_layers"],
+    learning_rate=CONFIG["learning_rate"],
+)
+
+periodic_checkpoint = ModelCheckpoint(
+    dirpath="checkpoints/",
+    filename="pocketdockgnn-{epoch:02d}-{step}",
+    every_n_train_steps=50,
+    save_last=True,
+    verbose=True,
+)
+
+best_checkpoint = ModelCheckpoint(
+    dirpath="checkpoints/",
+    filename="pocketdockgnn-best-{epoch:02d}-{val/loss:.4f}",
+    monitor="val/loss",
+    mode="min",
+    save_top_k=3,
+    verbose=True,
+)
+
+callbacks = [
+    best_checkpoint,
+    periodic_checkpoint,
+    EarlyStopping(
+        monitor="val/loss",
+        patience=10,
+        mode="min",
+        verbose=True,
+    ),
+]
+
+trainer = pl.Trainer(
+    max_epochs=CONFIG["max_epochs"],
+    accelerator="gpu",  # set to "cpu" (or "auto") if no GPU is available
+    devices=1,
+    callbacks=callbacks,
+    log_every_n_steps=10,
+    val_check_interval=50,
+    enable_progress_bar=True,
+)
+print("\nStarting training...")
+trainer.fit(
+    model,
+    train_dataloaders=train_loader,
+    val_dataloaders=val_loader,
+)
+
+print(f"\nFailed examples during training: {len(train_dataset.failed)}")
+print(f"Failed examples during val: {len(val_dataset.failed)}")
+print(f"Best checkpoint: {best_checkpoint.best_model_path}")
+
+print("\nEvaluating on test set...")
+trainer.test(
+    model,
+    dataloaders=test_loader,
+    ckpt_path=best_checkpoint.best_model_path,
+)
+
+print("\nDone.")
diff --git a/docs/build_a_model/transforms.py b/docs/build_a_model/transforms.py
new file mode 100644
index 00000000..e37b441c
--- /dev/null
+++ b/docs/build_a_model/transforms.py
@@ -0,0 +1,117 @@
+from atomworks.ml.transforms.base import Transform
+import numpy as np
+from biotite.structure import AtomArray
+from scipy.spatial import cKDTree
+from atomworks.ml.transforms._checks import check_atom_array_annotation
+from atomworks.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER
+
+class CropToPocket(Transform):
+    requires_previous_transforms = ["RemoveHydrogens", "RemoveUnresolvedAtoms"]
+
+    def __init__(self, radius: float = 10.0) -> None:
+        super().__init__()
+        self.radius = radius
+
+    def forward(self, data: dict) -> dict:
+        data["atom_array"] = crop_to_pocket(
+            data["atom_array"],
+            query_pn_unit_iids=data["query_pn_unit_iids"],
+            chain_info=data["chain_info"],
+            radius=self.radius,
+        )
+        return data
+
+    def check_input(self, data: dict) -> None:
+        assert "atom_array" in data, "Missing atom_array"
+        assert "query_pn_unit_iids" in data, "Missing query_pn_unit_iids"
+        assert "chain_info" in data, "Missing chain_info"
+
+class FeaturizeForDocking(Transform):
+    requires_previous_transforms = ["CropToPocket"]
+
+    def check_input(self, data: dict) -> None:
+        check_atom_array_annotation(data, ["is_ligand"])
+
+    def forward(self, data: dict) -> dict:
+        features = featurize_for_docking(data["atom_array"])
+        data.update(features)
+        return data
+
+def crop_to_pocket(
+    atom_array: AtomArray,
+    query_pn_unit_iids: list,
+    chain_info: dict,
+    radius: float = 10.0,
+) -> AtomArray:
+    atom_array = atom_array.copy()
+
+    iid_a, iid_b = query_pn_unit_iids
+    chain_a = iid_a.split("_")[0]
+    chain_b = iid_b.split("_")[0]
+
+    a_is_polymer = chain_info.get(chain_a, {}).get("is_polymer", True)
+    b_is_polymer = chain_info.get(chain_b, {}).get("is_polymer", True)
+
+    # Decide which pn_unit is the ligand: prefer the non-polymer side; if both look
+    # like polymers, fall back to treating the smaller pn_unit as the ligand.
+    if not a_is_polymer:
+        ligand_iid, protein_iid = iid_a, iid_b
+    elif not b_is_polymer:
+        ligand_iid, protein_iid = iid_b, iid_a
+    else:
+        mask_a = atom_array.pn_unit_iid == iid_a
+        mask_b = atom_array.pn_unit_iid == iid_b
+        if mask_a.sum() <= mask_b.sum():
+            ligand_iid, protein_iid = iid_a, iid_b
+        else:
+            ligand_iid, protein_iid = iid_b, iid_a
+
+    ligand_mask = atom_array.pn_unit_iid == ligand_iid
+    protein_mask = atom_array.pn_unit_iid == protein_iid
+
+    ligand_coords = atom_array.coord[ligand_mask]
+    protein_coords = atom_array.coord[protein_mask]
+
+    if len(ligand_coords) == 0:
+        raise ValueError(f"Ligand {ligand_iid} has no atoms")
+    if len(protein_coords) == 0:
+        raise ValueError(f"Protein {protein_iid} has no atoms")
+
+    # Find all protein atoms within `radius` of any ligand atom
+    tree = cKDTree(protein_coords)
+    neighbor_indices = tree.query_ball_point(ligand_coords, r=radius)
+
+    total_neighbors = sum(len(n) for n in neighbor_indices)
+    if total_neighbors == 0:
+        raise ValueError(f"No protein atoms found within {radius}A of ligand {ligand_iid}")
+
+    pocket_local_indices = np.unique(np.concatenate(neighbor_indices).astype(int))
+    protein_global_indices = np.where(protein_mask)[0]
+    pocket_global_indices = protein_global_indices[pocket_local_indices]
+    ligand_global_indices = np.where(ligand_mask)[0]
+
+    keep = np.sort(np.concatenate([pocket_global_indices, ligand_global_indices]))
+    is_ligand = np.isin(keep, ligand_global_indices)
+    cropped = atom_array[keep]
+    # Record which of the kept atoms belong to the ligand; FeaturizeForDocking
+    # requires this annotation.
+    cropped.set_annotation("is_ligand", is_ligand)
+
+    return cropped
+
+def featurize_for_docking(atom_array: AtomArray) -> dict:
+    is_ligand = atom_array.is_ligand.astype(bool)
+    target_coords = atom_array.coord.astype(np.float32)
+
+    # The model should not see the ligand's true pose: zero out ligand coordinates
+    # in the input while keeping them as the prediction target.
+    input_coords = target_coords.copy()
+    input_coords[is_ligand] = 0.0
+
+    atomic_numbers = np.array(
+        [ELEMENT_NAME_TO_ATOMIC_NUMBER.get(e.upper(), 0) for e in atom_array.element],
+        dtype=np.int64,
+    )
+
+    # BondList.as_array() gives one (atom_i, atom_j, bond_type) row per bond; mirror
+    # the edges so messages flow in both directions during message passing.
+    bonds = atom_array.bonds.as_array()
+    one_way = bonds[:, :2].T.astype(np.int64)
+    edge_index = np.concatenate([one_way, one_way[::-1]], axis=1)
+
+    return {
+        "atomic_numbers": atomic_numbers,
+        "is_ligand": is_ligand,
+        "target_coords": target_coords,
"edge_index": edge_index, + "input_coords": input_coords, + } + diff --git a/docs/conf.py b/docs/conf.py index 358c8db6..9755db2e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -44,8 +44,11 @@ "sphinx.ext.viewcode", # Add source code links "sphinx.ext.napoleon", # Google/NumPy style docstrings "sphinx_gallery.gen_gallery", # Generates auto_examples/ from examples/ + "myst_parser", # Support for Markdown files + "sphinx_design", # For better layout and design components ] + templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "examples/GALLERY_HEADER.rst", "ml/preprocessing.rst"] diff --git a/docs/docs_requirements.txt b/docs/docs_requirements.txt index e3661864..a4d7811b 100644 --- a/docs/docs_requirements.txt +++ b/docs/docs_requirements.txt @@ -5,4 +5,5 @@ sphinx-autodoc-typehints>=1.20.0,<2 nbsphinx>=0.8.9,<1 sphinx-gallery>=0.8.1,<1 ghp-import>=2.0.0,<3 -pandoc>=2.0.0,<3 \ No newline at end of file +pandoc>=2.0.0,<3 +myst-parser>=5.0.0 diff --git a/docs/how_to_build_a_model.md b/docs/how_to_build_a_model.md new file mode 100644 index 00000000..eabb29cf --- /dev/null +++ b/docs/how_to_build_a_model.md @@ -0,0 +1,106 @@ +# How to Build a Model Using AtomWorks + +## Table of Contents + +(aw_build_model_intro)= +## Introduction +In this tutorial you will combine several of the functionalities of AtomWorks to build a fully functional machine learning model to predict how ligands will attach to protein pockets. + +```{important} +This tutorial will walk you through creating a script using the AtomWorks API. + +For those that want to use the tutorial text as structure and hints to write your own script, we have hidden the code in collapsible cells. + +If you would like to see the full script, it is provided in the tutorial files. +``` + +## Prerequisites +- The parquet files take up about 674 MB of space +- Pandas (though I think this comes with an AtomWorks installation) + - link to a good tutorial: https://www.w3schools.com/python/pandas/default.asp +- Having the PDB mirror set up requires ~100GB of space + +(aw_build_model_setup)= +## Setup +Before we start building a model, we need to organize the data we want to train on. For this tutorial we will be using subsets of existing parquets – sets of parsed PDBs – to train our model. + +Download and decompress the parquet files via: +```bash +wget https://files.ipd.uw.edu/pub/atomworks/dfs/pdb/2026_01_06.tar.gz +tar -xvf 2026_01_06.tar.gz +``` +After decompressing the folder you should see three parquet files: +- **assemblies.parquet:** There are multiple bio assemblies stored in a single PDB, each assembly will have multiple chains, interfaces, etc. +- **interfaces.parquet:** Contains metadata for all binary interfaces in the PDB +- **pn_units.parquet:** Contains metadata for each PN unit in the [PDB](https://www.rcsb.org/) + + + +For more information about these parquet files, see the [Data Mirroring Documentation](mirrors.rst) + +(aw_build_model_organize_data)= +## Organizing the Data + +```{note} +You will need to have a PDB mirror set up to follow this portion of the tutorial. +If you do not have this you can: +1. follow the instructions [here](mirrors.rst) +2. Use the pre-cleaned parquet file and skip to section +``` + +We will use the **interfaces** and **pn_units** parquet files for this tutorial. + +```{important} +``` + +### Loading the parquet file using Pandas +Let's first take a look at the information contained in the parquet file. 
+
+````{dropdown} Click to see the code
+Load the datasets:
+```python
+import pandas as pd
+
+interfaces = pd.read_parquet("2026_01_06/interfaces.parquet")
+pn_units = pd.read_parquet("2026_01_06/pn_units.parquet")
+```
+View the dataset columns:
+```python
+interfaces.columns
+pn_units.columns
+```
+````
+
+Let's take a closer look at a few of the columns in the interfaces parquet:
+- `pdb_id`: The identifier for the specific structure in the PDB
+- `assembly_id`: The identifier for the biological assembly within the PDB entry
+- `pn_unit_1_iid`: Label for the first PN unit in the interface
+- `pn_unit_2_iid`: Label for the second PN unit in the interface
+- `involves_loi`: Whether an LOI (Ligand of Interest) is part of the interface
+- `is_inter_molecule`: Whether the interface is between two molecules (True) or within the same molecule (False)
+- `involves_metal`: Whether the interface involves a metal atom
+- `involves_covalent_modification`: Whether the interface involves a covalent modification (e.g. glycosylation)
+- `num_contacts`: Number of contacts between the two PN units that create the interface
+- `min_distance`: Minimum distance between the contacts that create the interface
+
+And at a few of the columns in the PN units parquet file:
+- `pn_unit_iid`: Label for the PN unit; only unique within a given `pdb_id` and `assembly_id`
+- `is_polymer`: Whether the PN unit is a polymer (e.g. a protein chain) rather than a small molecule
+- `num_resolved_residues`: The number of residues in the PN unit with resolved coordinates
+- `pn_unit_type`: The type of the PN unit
+
+For the purposes of this tutorial, we want to remove rows where `involves_covalent_modification` is True and rows where the interface involves non-protein or non-ligand chains.
+
+````{dropdown} Click to see the code
+```python
+interfaces_no_covalent = interfaces[~interfaces["involves_covalent_modification"]]
+```
+````
+
+## Glossary
+
+- **parquet**: A compressed, columnar file format for tabular data; we read and write it with Pandas in this tutorial.
+- **PN unit**: See the AtomWorks docs [glossary](https://rosettacommons.github.io/atomworks/latest/glossary.html#chains-pn-units-and-molecules).
diff --git a/docs/index.rst b/docs/index.rst
index 5bdf7632..14390144 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -4,7 +4,7 @@
    contain the root `toctree` directive.
 
 atomworks documentation
-======================
+=======================
 
 Welcome to **atomworks** — a toolkit for converting, parsing, and manipulating biological structure and sequence data, inspired by the Biotite library. Quickly convert between formats, extract features, and prepare data for machine learning or structural analysis.
 
@@ -13,6 +13,12 @@ Welcome to **atomworks** — a toolkit for converting, parsing, and manipulating
 
    :alt: atomworks datapipelines
 
+.. toctree::
+   :maxdepth: 1
+   :caption: Building a Model with AtomWorks
+
+   how_to_build_a_model.md
+
 .. toctree::
    :maxdepth: 2
    :caption: Navigation