Commit 35b74ad

Minor updates to scripts.

1 parent: 4acae17

19 files changed: +438 −93 lines

README.md (+5 −1)
@@ -33,6 +33,7 @@ This is because certain dependencies are tricky to install directly.
 conda create --name jamun python=3.11 -y
 conda activate jamun
 conda install -c conda-forge ambertools=23 openmm pdbfixer pyemma -y
+conda install pulchra -c bioconda -y
 ```
 
 The remaining dependencies can be installed via `pip` or [`uv`](https://docs.astral.sh/uv/getting-started/installation/) (recommended).
@@ -191,7 +192,10 @@ Please run this script with the `-h` flag to see all simulation parameters.
 ## Preprocessing
 
 ```bash
-python scripts/process_mdgen.py --input-dir /data/bucket/kleinhej/mdgen --output-dir /data/bucket/kleinhej/mdgen/data/4AA_sims_partitioned_chunked
+source .env
+python scripts/process_mdgen.py \
+    --input-dir ${JAMUN_DATA_PATH}/mdgen \
+    --output-dir ${JAMUN_DATA_PATH}/mdgen/data/4AA_sims_partitioned_chunked
 ```
 
 ## Citation
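The `source .env` lines introduced here assume a `.env` file at the repository root that exports `JAMUN_DATA_PATH`. A minimal sketch (the path itself is site-specific and not part of this commit):

```bash
# .env -- sketch only; point JAMUN_DATA_PATH at your local data root.
export JAMUN_DATA_PATH=/path/to/data
```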

analysis/analysis_sweep.py (+1 −59)
@@ -11,6 +11,7 @@
 
 sys.path.append("./")
 
+from jamun.utils.slurm import wait_for_jobs
 import load_trajectory
 
 
@@ -43,65 +44,6 @@ def run_analysis(args) -> Tuple[str, Optional[str], Optional[str]]:
     return (peptide, None, None)
 
 
-def wait_for_jobs(job_ids: List[str], poll_interval: int = 60) -> int:
-    """Wait for all jobs to finish and print progress."""
-
-    previous_states = collections.defaultdict(str)
-    completion_count = 0
-    total_jobs = len(job_ids)
-
-    while True:
-        cmd = [
-            "sacct",
-            "-j", ",".join(job_ids),
-            "--format=JobID,State",
-            "--noheader",
-            "--parsable2"
-        ]
-
-        result = subprocess.run(cmd, capture_output=True, text=True)
-        current_states: Dict[str, str] = {}
-
-        # Parse current states.
-        for line in result.stdout.strip().split('\n'):
-            if not line: continue
-            jobid, state = line.split('|')
-            if '.' not in jobid:  # Only main jobs
-                current_states[jobid] = state
-
-                # If job just completed (wasn't completed before).
-                if state == 'COMPLETED' and previous_states[jobid] != 'COMPLETED':
-                    completion_count += 1
-                    print(f"Job {jobid} completed successfully. Progress: {completion_count}/{total_jobs}")
-
-        # Update states for next iteration.
-        previous_states.update(current_states)
-
-        # Group jobs by state for summary.
-        states_summary = collections.defaultdict(int)
-        for state in current_states.values():
-            states_summary[state] += 1
-
-        print(f"\nStatus summary:")
-        print(f"Completed: {completion_count}/{total_jobs} ({completion_count/total_jobs*100:.1f}%)")
-        print(f"Current states: {dict(states_summary)}")
-
-        # Check if all jobs reached terminal state.
-        all_done = all(state in ['COMPLETED', 'FAILED', 'TIMEOUT', 'OUT_OF_MEMORY', 'CANCELLED']
-                       for state in current_states.values())
-
-        if all_done:
-            print("\nAll jobs finished!")
-            failures = [jid for jid, state in current_states.items() if state != 'COMPLETED']
-            if failures:
-                print(f"Failed jobs: {failures}")
-            break
-
-        time.sleep(poll_interval)
-
-    return completion_count
-
-
 def main():
     parser = argparse.ArgumentParser(description="Run analysis for multiple peptides")
     parser.add_argument("--csv", type=str, required=True, help="CSV file containing wandb runs")

configs/experiment/sample_mdgen.yaml (+1 −2)
@@ -12,8 +12,7 @@ repeat_init_samples: 1
 continue_chain: true
 
 # MDGen
-wandb_train_run_path: prescient-design/jamun/lmnf3vyu
-
+wandb_train_run_path: prescient-design/jamun/brd51ln4
 
 checkpoint_type: best_so_far
 sigma: 0.04
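The experiment config now points sampling at the retrained run `brd51ln4`. Since this is a Hydra config, the run path can presumably also be overridden per launch; a sketch, assuming a sampling entry point (the script name is a placeholder, not confirmed by this commit):

```bash
# Sketch only: the entry point name is an assumption.
python scripts/sample.py experiment=sample_mdgen \
    wandb_train_run_path=prescient-design/jamun/brd51ln4
```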

configs/experiment/train_idrome.yaml (+56, new file)
@@ -0,0 +1,56 @@
+# @package _global_
+
+compute_average_squared_distance_from_data: false
+
+model:
+  average_squared_distance: 0.332
+  sigma_distribution:
+    _target_: jamun.distributions.ConstantSigma
+    sigma: 0.04
+  max_radius: 1.0
+  optim:
+    lr: 0.002
+  use_torch_compile: true
+  torch_compile_kwargs:
+    fullgraph: true
+    dynamic: true
+    mode: default
+
+callbacks:
+  viz:
+    sigma_list: ["${model.sigma_distribution.sigma}"]
+
+data:
+  datamodule:
+    num_workers: 4
+    batch_size: 32
+    datasets:
+      train:
+        _target_: jamun.data.parse_datasets_from_directory
+        root: "${paths.data_path}/IDRome_v4_preprocessed/all_atom_relaxed_combined/"
+        traj_pattern: "^(.*)/traj.xtc"
+        pdb_pattern: "^(.*)/top.pdb"
+
+      val:
+        _target_: jamun.data.parse_datasets_from_directory
+        root: "${paths.data_path}/IDRome_v4_preprocessed/all_atom_relaxed_combined/"
+        traj_pattern: "^(.*)/traj.xtc"
+        pdb_pattern: "^(.*)/top.pdb"
+        subsample: 100
+
+      test:
+        _target_: jamun.data.parse_datasets_from_directory
+        root: "${paths.data_path}/IDRome_v4_preprocessed/all_atom_relaxed_combined/"
+        traj_pattern: "^(.*)/traj.xtc"
+        pdb_pattern: "^(.*)/top.pdb"
+        subsample: 100
+
+trainer:
+  val_check_interval: 30000
+  limit_val_batches: 1000
+  max_epochs: 10
+
+logger:
+  wandb:
+    group: train_idrome
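Assuming the repository's usual Hydra-based training entry point, a launch with this new experiment config might look like the following sketch (the script path is an assumption, not part of this commit):

```bash
# Sketch only: the train entry point name is an assumption.
python scripts/train.py experiment=train_idrome
```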
(file name not captured) (+56, new file)
@@ -0,0 +1,56 @@
+# @package _global_
+
+compute_average_squared_distance_from_data: false
+
+model:
+  average_squared_distance: 0.332
+  sigma_distribution:
+    _target_: jamun.distributions.ConstantSigma
+    sigma: 0.08
+  max_radius: 2.0
+  optim:
+    lr: 0.002
+  use_torch_compile: true
+  torch_compile_kwargs:
+    fullgraph: true
+    dynamic: true
+    mode: default
+
+callbacks:
+  viz:
+    sigma_list: ["${model.sigma_distribution.sigma}"]
+
+data:
+  datamodule:
+    num_workers: 4
+    batch_size: 32
+    datasets:
+      train:
+        _target_: jamun.data.parse_datasets_from_directory
+        root: "${paths.data_path}/IDRome_v4_preprocessed/flat/"
+        traj_pattern: "^(.*)/traj.xtc"
+        pdb_pattern: "^(.*)/top.pdb"
+
+      val:
+        _target_: jamun.data.parse_datasets_from_directory
+        root: "${paths.data_path}/IDRome_v4_preprocessed/flat/"
+        traj_pattern: "^(.*)/traj.xtc"
+        pdb_pattern: "^(.*)/top.pdb"
+        subsample: 100
+
+      test:
+        _target_: jamun.data.parse_datasets_from_directory
+        root: "${paths.data_path}/IDRome_v4_preprocessed/flat/"
+        traj_pattern: "^(.*)/traj.xtc"
+        pdb_pattern: "^(.*)/top.pdb"
+        subsample: 100
+
+trainer:
+  val_check_interval: 30000
+  limit_val_batches: 1000
+  max_epochs: 10
+
+logger:
+  wandb:
+    group: train_idrome
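This second config differs from `train_idrome.yaml` only in the noise level and cutoff (`sigma: 0.08`, `max_radius: 2.0`) and in training on the `flat/` data rather than `all_atom_relaxed_combined/`. Under the same assumed entry point, Hydra's multirun mode could sweep the noise level directly; a sketch:

```bash
# Sketch only: entry point and override path are assumptions.
python scripts/train.py -m experiment=train_idrome \
    model.sigma_distribution.sigma=0.04,0.08
```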

configs/experiment/train_mdgen.yaml (+5)
@@ -32,6 +32,7 @@ data:
         pdb_pattern: "^(....).pdb"
         as_iterable: true
         subsample: 5
+        start_at_random_frame: true
 
       val:
        _target_: jamun.data.parse_datasets_from_directory
@@ -40,6 +41,7 @@ data:
         pdb_pattern: "^(....).pdb"
         as_iterable: true
         subsample: 100
+        start_at_random_frame: true
 
       test:
        _target_: jamun.data.parse_datasets_from_directory
@@ -48,8 +50,11 @@ data:
         pdb_pattern: "^(....).pdb"
         as_iterable: true
         subsample: 100
+        start_at_random_frame: true
 
 trainer:
+  val_check_interval: 30000
+  limit_val_batches: 1000
   max_epochs: 10
 
 logger:
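`start_at_random_frame: true` is added to all three iterable datasets, and validation now runs every 30000 steps on at most 1000 batches. As ordinary Hydra keys these should also be overridable per run; a sketch (entry point and key paths are assumptions based on the reconstructed indentation above):

```bash
# Sketch only: entry point and override path are assumptions.
python scripts/train.py experiment=train_mdgen \
    data.datamodule.datasets.train.start_at_random_frame=false
```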

scripts/IDRome/README.md (+16 −3)
@@ -8,13 +8,26 @@ conda install pulchra -c bioconda --yes
 ```
 
 ```bash
-python scripts/generate_data/run_simulation.py /homefs/home/daigavaa/jamun/145_181/all_atom/top_AA.pdb --energy-minimization-only --energy-minimization-steps=5000
+source .env
+sbatch scripts/IDRome/to_all_atom_batched.sh \
+    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/flat \
+    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/flat_by_frame/ \
+    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom/ \
+    1000
 ```
 
 ```bash
-sbatch to_all_atom_batched.sh /data/bucket/kleinhej/IDRome_v4_preprocessed/flat /data/bucket/kleinhej/IDRome_v4_preprocessed/flat_by_frame/ /data/bucket/kleinhej/IDRome_v4_preprocessed/all_atom/ 1000
+source .env
+sbatch scripts/IDRome/relax_structures_batched.sh \
+    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom \
+    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed \
+    1000
 ```
 
 ```bash
-sbatch relax_structures_batched.sh /data/bucket/kleinhej/IDRome_v4_preprocessed/all_atom /data/bucket/kleinhej/IDRome_v4_preprocessed/all_atom_relaxed 1000
+source .env
+sbatch scripts/IDRome/combine_frames.sh \
+    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed/ \
+    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/flat/ \
+    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed_combined/
 ```
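The three `sbatch` steps form a pipeline: coarse-grained to all-atom, relaxation, then recombination into per-trajectory folders. Since each stage consumes the previous stage's output, they could be chained with SLURM dependencies instead of launching each stage by hand; a sketch with the same arguments as above:

```bash
# Sketch only: chains the three stages with afterok dependencies.
source .env
jid1=$(sbatch --parsable scripts/IDRome/to_all_atom_batched.sh \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/flat \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/flat_by_frame/ \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom/ 1000)
jid2=$(sbatch --parsable --dependency=afterok:${jid1} \
    scripts/IDRome/relax_structures_batched.sh \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed 1000)
sbatch --dependency=afterok:${jid2} scripts/IDRome/combine_frames.sh \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed/ \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/flat/ \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed_combined/
```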

scripts/IDRome/combine_frames.py (+13 −22)
@@ -12,9 +12,8 @@
 
 
 
-def combine_frames(args, use_srun: bool = True) -> None:
+def combine_frames(name: str, input_dir: str, original_traj_dir: str, output_dir: str) -> None:
     """Combine relaxed IDRome v4 all-atom frames."""
-    name, input_dir, original_traj_dir, output_dir = args
 
     traj_AA = None
     frames = sorted(os.listdir(os.path.join(input_dir, name)),
@@ -39,33 +38,25 @@ def combine_frames(args, use_srun: bool = True) -> None:
             top_AA.add_atom(atom.name, element=atom.element, residue=res)
     top_AA.create_standard_bonds()
 
-    traj = md.load_xtc(os.path.join(original_traj_dir, f'{name}.xtc'), top=os.path.join(original_traj_dir, f'{name}.pdb'))
-    traj_AA = md.Trajectory(traj_AA.xyz, top_AA, traj.time, traj.unitcell_lengths, traj.unitcell_angles)
-    traj_AA[0].save_pdb(os.path.join(output_dir, f"{name}.pdb"))
-    traj_AA.save_xtc(os.path.join(output_dir, f"{name}.xtc"))
+    original_traj_path = os.path.join(original_traj_dir, name, 'traj.xtc')
+    original_top_path = os.path.join(original_traj_dir, name, 'top.pdb')
+    original_traj = md.load_xtc(original_traj_path, top=original_top_path)
+    original_traj = original_traj[0:traj_AA.n_frames]
 
+    os.makedirs(os.path.join(output_dir, name), exist_ok=True)
+    traj_AA = md.Trajectory(traj_AA.xyz, top_AA, original_traj.time, original_traj.unitcell_lengths, original_traj.unitcell_angles)
+    traj_AA[0].save_pdb(os.path.join(output_dir, name, 'top.pdb'))
+    traj_AA.save_xtc(os.path.join(output_dir, name, 'traj.xtc'))
+
+    py_logger.info(f"Successfully processed {name}")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Convert IDRome v4 data to all-atom.')
+    parser.add_argument('--name', help='Name of the trajectory.', type=str, required=True)
     parser.add_argument('--input-dir', help='Directory of relaxed all-atom trajectories (stored in each folder).', type=str, required=True)
     parser.add_argument('--original-traj-dir', help='Directory of original coarse-grained trajectories (stored in each folder).', type=str, required=True)
     parser.add_argument('--output-dir', '-o', help='Output directory to save combined relaxed all-atom trajectories (stored in each folder).', type=str, required=True)
-    parser.add_argument('--num-workers', type=int, default=multiprocessing.cpu_count(),
-                        help='Number of parallel workers')
     args = parser.parse_args()
 
-    # Run in parallel.
-    names = sorted(os.listdir(args.input_dir))
-    preprocess_args = list(
-        zip(
-            names,
-            [args.input_dir] * len(names),
-            [args.original_traj_dir] * len(names),
-            [args.output_dir] * len(names),
-        )
-    )
-    with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
-        results = list(executor.map(combine_frames, preprocess_args))
-
-
+    combine_frames(args.name, args.input_dir, args.original_traj_dir, args.output_dir)
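`combine_frames.py` now handles one trajectory per invocation; parallelism moves to the SLURM array in `combine_frames.sh` below. A hedged standalone usage sketch (`seq_00001` is a placeholder trajectory name):

```bash
# Sketch only: seq_00001 is a placeholder trajectory name.
source .env
python scripts/IDRome/combine_frames.py \
    --name seq_00001 \
    --input-dir ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed \
    --original-traj-dir ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/flat \
    --output-dir ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed_combined
```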

scripts/IDRome/combine_frames.sh (+51, new file)
@@ -0,0 +1,51 @@
+#!/bin/bash
+#SBATCH --partition=cpu
+#SBATCH --mem=1G
+#SBATCH --cpus-per-task=2
+#SBATCH --job-name=combine_frames
+#SBATCH --output=logs/%j_combine_frames.log
+#SBATCH --error=logs/%j_combine_frames.err
+#SBATCH --array=0-1
+
+# Directory containing all input directories.
+BASE_INPUT_DIR="$1"
+# Directory containing all original coarse-grained directories.
+BASE_ORIGINAL_DIR="$2"
+# Directory to store output.
+BASE_OUTPUT_DIR="$3"
+
+eval "$(conda shell.bash hook)"
+conda activate jamun
+
+# Get list of all directories and store in an array.
+# You can use a file with directory names or generate the list dynamically.
+DIRECTORIES=($(ls -d ${BASE_INPUT_DIR}/*/ | sort | xargs -n 1 basename))
+
+# Each array task processes 50 directories.
+START_IDX=$((SLURM_ARRAY_TASK_ID * 50))
+END_IDX=$(( (SLURM_ARRAY_TASK_ID + 1) * 50 - 1 ))
+
+for DIR_INDEX in $(seq ${START_IDX} ${END_IDX}); do
+    NAME="${DIRECTORIES[${DIR_INDEX}]}"
+
+    echo "Processing directory: ${NAME} (index: ${DIR_INDEX})"
+
+    # Create output directory.
+    mkdir -p "${BASE_OUTPUT_DIR}/${NAME}"
+
+    # Check that the input frame exists.
+    if [ ! -f "${BASE_INPUT_DIR}/${NAME}/0_minimized_protein_0.pdb" ]; then
+        echo "Input frame 0 does not exist in ${NAME}. Skipping."
+        continue
+    fi
+
+    python scripts/IDRome/combine_frames.py \
+        --name "${NAME}" \
+        --input-dir "${BASE_INPUT_DIR}" \
+        --original-traj-dir "${BASE_ORIGINAL_DIR}" \
+        --output-dir "${BASE_OUTPUT_DIR}"
+
+    echo "Completed processing directory ${NAME}"
+done
+
+exit 0
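One caveat: `--array=0-1` with 50 directories per task covers only the first 100 entries, and indices past the end of `DIRECTORIES` expand to an empty `NAME`. For larger datasets the array bound could be computed at submit time; a sketch, assuming the caller sizes the array before `sbatch`:

```bash
# Sketch only: size the SLURM array to the number of input directories.
source .env
N=$(ls -d ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed/*/ | wc -l)
sbatch --array=0-$(( (N - 1) / 50 )) scripts/IDRome/combine_frames.sh \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed/ \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/flat/ \
    ${JAMUN_DATA_PATH}/IDRome_v4_preprocessed/all_atom_relaxed_combined/
```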
