Skip to content

Commit 28cd6ac

Browse files
committed
Save progress. Average distance working.
1 parent 62c4adf commit 28cd6ac

File tree

6 files changed

+61
-99
lines changed

6 files changed

+61
-99
lines changed

README.md

+6-2
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,9 @@ We provide scripts for analysing JAMUN and original MD trajectories in [https://
162162

163163
## Data Generation
164164

165-
We also provide scripts for generating the MD simulation data with [OpenMM](https://openmm.org/), including energy minimization and calibration steps with NVT and NPT ensembles.
165+
### Running Molecular Dynamics with OpenMM
166+
167+
We provide scripts for generating MD simulation data with [OpenMM](https://openmm.org/), including energy minimization and calibration steps with NVT and NPT ensembles.
166168

167169
```bash
168170
python scripts/generate_data/run_simulation.py [INIT_PDB]
@@ -171,7 +173,9 @@ python scripts/generate_data/run_simulation.py [INIT_PDB]
171173
The defaults correspond to our setup for the capped diamines.
172174
Please run this script with the `-h` flag to see all simulation parameters.
173175

174-
## Preprocessing
176+
### Preprocessing
177+
178+
Some datasets require preprocessing for easier consumption, e.g. the MDGen data:
175179

176180
```bash
177181
source .env

src/jamun/cmdline/train.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import jamun
1818
from jamun.hydra import instantiate_dict_cfg
1919
from jamun.hydra.utils import format_resolver
20-
from jamun.utils import compute_average_squared_distance_from_data, dist_log, find_checkpoint
20+
from jamun.utils import compute_average_squared_distance_from_datasets, dist_log, find_checkpoint
2121

2222
dotenv.load_dotenv(".env", verbose=True)
2323
OmegaConf.register_new_resolver("format", format_resolver)
@@ -27,10 +27,9 @@ def compute_average_squared_distance_from_config(cfg: OmegaConf) -> float:
2727
"""Computes the average squared distance for normalization from the data."""
2828
datamodule = hydra.utils.instantiate(cfg.data.datamodule)
2929
datamodule.setup("compute_normalization")
30-
train_dataloader = datamodule.train_dataloader()
30+
train_datasets = datamodule.datasets["train"]
3131
cutoff = cfg.model.max_radius
32-
average_squared_distance = compute_average_squared_distance_from_data(train_dataloader, cutoff, cfg.trainer)
33-
average_squared_distance = float(average_squared_distance)
32+
average_squared_distance = compute_average_squared_distance_from_datasets(train_datasets, cutoff)
3433
return average_squared_distance
3534

3635

src/jamun/data/_utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import collections
22
import os
33
import re
4-
import random
54
from typing import List, Optional, Sequence
65

6+
import pandas as pd
77
import hydra
88
import requests
99
import torch
@@ -124,6 +124,7 @@ def parse_datasets_from_directory_new(
124124
max_datasets: Optional[int] = None,
125125
max_datasets_offset: Optional[int] = None,
126126
filter_codes: Optional[Sequence[str]] = None,
127+
split_csv: Optional[str] = None,
127128
as_iterable: bool = False,
128129
**dataset_kwargs,
129130
) -> List[MDtrajDataset]:
@@ -170,6 +171,9 @@ def parse_datasets_from_directory_new(
170171
pdb_files[code] = pdb_file
171172

172173
# Filter out codes
174+
if split_csv is not None:
175+
filter_codes = pd.read_csv("train.csv")["entry"].tolist()
176+
173177
if filter_codes is not None:
174178
codes = [code for code in codes if code in set(filter_codes)]
175179

src/jamun/metrics/_save_trajectory.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, save_true_trajectory: bool = False, *args, **kwargs):
2525
for ext in self.true_samples_extensions:
2626
os.makedirs(os.path.join(self.true_samples_dir, ext), exist_ok=True)
2727

28-
self.pred_samples_extensions = ["npy", "pdb", "dcd"]
28+
self.pred_samples_extensions = ["pdb", "dcd"]
2929
for ext in self.pred_samples_extensions:
3030
os.makedirs(os.path.join(self.pred_samples_dir, ext), exist_ok=True)
3131

@@ -69,29 +69,29 @@ def on_sample_end(self):
6969
label = self.dataset.label()
7070
label = label.replace("/", "_").replace("=", "-")
7171

72-
for ext in ["npy", "pdb", "dcd"]:
73-
filename = self.filename_pred("joined", ext)
74-
artifact = wandb.Artifact(f"{label}_pred_samples_joined", type="pred_samples_joined")
75-
artifact.add_file(filename, f"pred_samples_joined.{ext}")
76-
wandb.log_artifact(artifact)
72+
# for ext in self.pred_samples_extensions:
73+
# filename = self.filename_pred("joined", ext)
74+
# artifact = wandb.Artifact(f"{label}_pred_samples_joined", type="pred_samples_joined")
75+
# artifact.add_file(filename, f"pred_samples_joined.{ext}")
76+
# wandb.log_artifact(artifact)
7777

7878
def compute(self) -> Dict[str, float]:
7979
# Save the predicted samples as numpy files.
80-
samples_np = self.sample_tensors(new=True).cpu().detach().numpy()
81-
for trajectory_index, sample in enumerate(samples_np):
82-
np.save(self.filename_pred(trajectory_index, "npy"), sample)
80+
# samples_np = self.sample_tensors(new=True).cpu().detach().numpy()
81+
# for trajectory_index, sample in enumerate(samples_np):
82+
# np.save(self.filename_pred(trajectory_index, "npy"), sample)
8383

84-
samples_joined_np = self.joined_sample_tensor().cpu().detach().numpy()
85-
np.save(self.filename_pred("joined", "npy"), samples_joined_np)
84+
# samples_joined_np = self.joined_sample_tensor().cpu().detach().numpy()
85+
# np.save(self.filename_pred("joined", "npy"), samples_joined_np)
8686

8787
# Save the predict sample trajectory as a PDB and DCD file.
8888
pred_trajectories = self.sample_trajectories(new=True)
8989
for trajectory_index, pred_trajectory in enumerate(pred_trajectories, start=self.num_chains_seen):
90-
utils.save_pdb(pred_trajectory, self.filename_pred(trajectory_index, "pdb"))
90+
# utils.save_pdb(pred_trajectory, self.filename_pred(trajectory_index, "pdb"))
9191
pred_trajectory.save_dcd(self.filename_pred(trajectory_index, "dcd"))
9292

9393
pred_trajectory_joined = self.joined_sample_trajectory()
94-
utils.save_pdb(pred_trajectory_joined, self.filename_pred("joined", "pdb"))
94+
# utils.save_pdb(pred_trajectory_joined, self.filename_pred("joined", "pdb"))
9595
pred_trajectory_joined.save_dcd(self.filename_pred("joined", "dcd"))
9696

9797
return {}

src/jamun/utils/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .align import align_A_to_B, align_A_to_B_batched
2-
from .average_squared_distance import compute_average_squared_distance, compute_average_squared_distance_from_data
2+
from .average_squared_distance import compute_average_squared_distance, compute_average_squared_distance_from_datasets
33
from .checkpoint import find_checkpoint, find_checkpoint_directory, get_wandb_run_config
44
from .data_with_residue_info import DataWithResidueInformation
55
from .dist_log import dist_log, wandb_dist_log

src/jamun/utils/average_squared_distance.py

+33-78
Original file line numberDiff line numberDiff line change
@@ -112,45 +112,6 @@ def compute_final_statistics(self):
112112
return average_squared_distance
113113

114114

115-
def compute_average_squared_distance_from_data(
116-
datamodule: pl.LightningDataModule,
117-
cutoff: float,
118-
trainer_cfg: Dict[str, Any],
119-
num_estimation_graphs: int = 5000,
120-
verbose: bool = False
121-
):
122-
"""Compute normalization using a Lightning trainer.
123-
124-
Args:
125-
datamodule: The Lightning datamodule
126-
cutoff (float): The radius cutoff for distance calculations
127-
compute_average_squared_distance_fn (callable): Function to compute average
128-
squared distance for a graph
129-
trainer_cfg: Configuration for the Lightning trainer
130-
num_estimation_graphs (int): Maximum number of graphs to process
131-
verbose (bool): Whether to print detailed statistics
132-
133-
Returns:
134-
float: The computed average squared distance
135-
"""
136-
137-
# Create the normalization module
138-
norm_module = ComputeNormalizationModule(
139-
cutoff=cutoff,
140-
num_estimation_graphs=num_estimation_graphs,
141-
verbose=verbose
142-
)
143-
144-
# Create the trainer
145-
trainer = hydra.utils.instantiate(trainer_cfg)
146-
147-
# Fit without any callbacks or loggers
148-
trainer.fit(norm_module, datamodule=datamodule)
149-
150-
# Compute and return the final statistics
151-
return norm_module.compute_final_statistics()
152-
153-
154115
def compute_distance_matrix(x: np.ndarray, cutoff: Optional[float] = None) -> np.ndarray:
155116
"""Computes the distance matrix between points in x, ignoring self-distances."""
156117
if x.shape[-1] != 3:
@@ -177,42 +138,36 @@ def compute_average_squared_distance(x: np.ndarray, cutoff: Optional[float] = No
177138
return np.mean(dist_x**2)
178139

179140

180-
# def compute_average_squared_distance_from_data(
181-
# dataloader: torch.utils.data.DataLoader,
182-
# cutoff: float,
183-
# num_estimation_graphs: int = 5000,
184-
# verbose: bool = False,
185-
# ) -> float:
186-
# """Computes the average squared distance for normalization."""
187-
# avg_sq_dists = collections.defaultdict(list)
188-
# num_graphs = 0
189-
# for batch in dataloader:
190-
# for graph in batch.to_data_list():
191-
# pos = np.asarray(graph.pos)
192-
# avg_sq_dist = compute_average_squared_distance(pos, cutoff=cutoff)
193-
# avg_sq_dists[graph.dataset_label].append(avg_sq_dist)
194-
# num_graphs += 1
195-
196-
# if num_graphs >= num_estimation_graphs:
197-
# break
198-
199-
# mean_avg_sq_dist = sum(np.sum(avg_sq_dists[label]) for label in avg_sq_dists) / num_graphs
200-
# utils.dist_log(f"Mean average squared distance = {mean_avg_sq_dist:0.3f} nm^2")
201-
202-
# if verbose:
203-
# utils.dist_log(f"For cutoff {cutoff} nm:")
204-
# for label in sorted(avg_sq_dists):
205-
# utils.dist_log(
206-
# f"- Dataset {label}: Average squared distance = {np.mean(avg_sq_dists[label]):0.3f} +- {np.std(avg_sq_dists[label]):0.3f} nm^2"
207-
# )
208-
209-
# # Average across all processes, if distributed.
210-
# print("torch.distributed.is_initialized():", torch.distributed.is_initialized())
211-
# mean_avg_sq_dist = torch.tensor(mean_avg_sq_dist, device="cuda")
212-
213-
# print("mean_avg_sq_dist bef:", mean_avg_sq_dist)
214-
# torch.distributed.all_reduce(mean_avg_sq_dist, op=torch.distributed.ReduceOp.AVG)
215-
# mean_avg_sq_dist = mean_avg_sq_dist.item()
216-
# print("mean_avg_sq_dist aft:", mean_avg_sq_dist)
217-
218-
# return mean_avg_sq_dist
141+
def compute_average_squared_distance_from_datasets(
142+
datasets: Sequence[torch.utils.data.Dataset],
143+
cutoff: float,
144+
num_estimation_datasets: int = 50,
145+
num_estimation_graphs_per_dataset: int = 100,
146+
verbose: bool = False,
147+
) -> float:
148+
"""Computes the average squared distance for normalization."""
149+
avg_sq_dists = collections.defaultdict(list)
150+
151+
for dataset in datasets[:num_estimation_datasets]:
152+
num_graphs = 0
153+
154+
for graph in dataset:
155+
pos = np.asarray(graph.pos)
156+
avg_sq_dist = compute_average_squared_distance(pos, cutoff=cutoff)
157+
avg_sq_dists[graph.dataset_label].append(avg_sq_dist)
158+
num_graphs += 1
159+
160+
if num_graphs >= num_estimation_graphs_per_dataset:
161+
break
162+
163+
mean_avg_sq_dist = sum(np.sum(avg_sq_dists[label]) for label in avg_sq_dists) / num_graphs
164+
utils.dist_log(f"Mean average squared distance = {mean_avg_sq_dist:0.3f} nm^2")
165+
166+
if verbose:
167+
utils.dist_log(f"For cutoff {cutoff} nm:")
168+
for label in sorted(avg_sq_dists):
169+
utils.dist_log(
170+
f"- Dataset {label}: Average squared distance = {np.mean(avg_sq_dists[label]):0.3f} +- {np.std(avg_sq_dists[label]):0.3f} nm^2"
171+
)
172+
173+
return float(mean_avg_sq_dist)

0 commit comments

Comments
 (0)