@@ -69,6 +69,33 @@ def save_trajectory(
6969 raise ValueError (f"Invalid scaler type: { scaler_type } " )
7070
7171
72+ def _assign_coords_to_array (
73+ array_copy : AtomArrayStack ,
74+ coords : np .ndarray ,
75+ reward_param_mask : np .ndarray ,
76+ ) -> None :
77+ """Assign trajectory coords into an AtomArrayStack, handling shape mismatches.
78+
79+ When the trajectory spans all atoms in the array (e.g. model
80+ trajectories saved during a has_mismatch run (see pure_guidance.py or fk_steering.py)), coords
81+ are assigned directly to ``.coord``. Otherwise the ``reward_param_mask``
82+ is used to index into the correct atom subset.
83+ """
84+ n_atoms_array = array_copy .coord .shape [- 2 ] # pyright: ignore[reportOptionalMemberAccess]
85+ n_atoms_coords = coords .shape [- 2 ]
86+
87+ if n_atoms_coords == n_atoms_array :
88+ array_copy .coord = coords
89+ elif n_atoms_coords == int (reward_param_mask .sum ()):
90+ array_copy .coord [:, reward_param_mask ] = coords # pyright: ignore[reportOptionalSubscript]
91+ else :
92+ raise ValueError (
93+ f"Trajectory coords ({ n_atoms_coords } atoms) match neither "
94+ f"the full atom array ({ n_atoms_array } ) nor the masked subset "
95+ f"({ int (reward_param_mask .sum ())} )"
96+ )
97+
98+
7299def _save_trajectory (
73100 trajectory , atom_array , output_dir , reward_param_mask , subdir_name , save_every
74101):
@@ -89,7 +116,7 @@ def _save_trajectory(
89116 continue
90117 array_copy = atom_array .copy ()
91118 array_copy = stack ([array_copy ] * ensemble_size )
92- array_copy . coord [:, reward_param_mask ] = coords .detach ().numpy () # type: ignore[reportOptionalSubscript] coords will be subscriptable
119+ _assign_coords_to_array ( array_copy , coords .detach ().numpy (), reward_param_mask )
93120 save_structure (str (output_dir / f"trajectory_{ i } .cif" ), array_copy )
94121
95122
@@ -115,7 +142,7 @@ def _save_fk_steering_trajectory(
115142 array_copy = stack ([array_copy ] * ensemble_size )
116143 # we save only the first ensemble out of n_particles, since saving
117144 # each particle at every step would clog trajectory saving
118- array_copy . coord [:, reward_param_mask ] = coords [0 ].detach ().numpy () # type: ignore[reportOptionalSubscript] coords will be subscriptable
145+ _assign_coords_to_array ( array_copy , coords [0 ].detach ().numpy (), reward_param_mask )
119146 save_structure (str (output_dir / f"trajectory_{ i } .cif" ), array_copy )
120147
121148
0 commit comments