Skip to content

Commit 9a605f0

Browse files
committed
RF3 has full array, deal with that during saving
1 parent 1005b21 commit 9a605f0

1 file changed

Lines changed: 29 additions & 2 deletions

File tree

src/sampleworks/utils/guidance_script_utils.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,33 @@ def save_trajectory(
6969
raise ValueError(f"Invalid scaler type: {scaler_type}")
7070

7171

72+
def _assign_coords_to_array(
73+
array_copy: AtomArrayStack,
74+
coords: np.ndarray,
75+
reward_param_mask: np.ndarray,
76+
) -> None:
77+
"""Assign trajectory coords into an AtomArrayStack, handling shape mismatches.
78+
79+
When the trajectory spans all atoms in the array (e.g. model
80+
trajectories saved during a has_mismatch run (see pure_guidance.py or fk_steering.py)), coords
81+
are assigned directly to ``.coord``. Otherwise the ``reward_param_mask``
82+
is used to index into the correct atom subset.
83+
"""
84+
n_atoms_array = array_copy.coord.shape[-2] # pyright: ignore[reportOptionalMemberAccess]
85+
n_atoms_coords = coords.shape[-2]
86+
87+
if n_atoms_coords == n_atoms_array:
88+
array_copy.coord = coords
89+
elif n_atoms_coords == int(reward_param_mask.sum()):
90+
array_copy.coord[:, reward_param_mask] = coords # pyright: ignore[reportOptionalSubscript]
91+
else:
92+
raise ValueError(
93+
f"Trajectory coords ({n_atoms_coords} atoms) match neither "
94+
f"the full atom array ({n_atoms_array}) nor the masked subset "
95+
f"({int(reward_param_mask.sum())})"
96+
)
97+
98+
7299
def _save_trajectory(
73100
trajectory, atom_array, output_dir, reward_param_mask, subdir_name, save_every
74101
):
@@ -89,7 +116,7 @@ def _save_trajectory(
89116
continue
90117
array_copy = atom_array.copy()
91118
array_copy = stack([array_copy] * ensemble_size)
92-
array_copy.coord[:, reward_param_mask] = coords.detach().numpy() # type: ignore[reportOptionalSubscript] coords will be subscriptable
119+
_assign_coords_to_array(array_copy, coords.detach().numpy(), reward_param_mask)
93120
save_structure(str(output_dir / f"trajectory_{i}.cif"), array_copy)
94121

95122

@@ -115,7 +142,7 @@ def _save_fk_steering_trajectory(
115142
array_copy = stack([array_copy] * ensemble_size)
116143
# we save only the first ensemble out of n_particles, since saving
117144
# each particle at every step would clog trajectory saving
118-
array_copy.coord[:, reward_param_mask] = coords[0].detach().numpy() # type: ignore[reportOptionalSubscript] coords will be subscriptable
145+
_assign_coords_to_array(array_copy, coords[0].detach().numpy(), reward_param_mask)
119146
save_structure(str(output_dir / f"trajectory_{i}.cif"), array_copy)
120147

121148

0 commit comments

Comments
 (0)