diff --git a/models/rfd3/src/rfd3/utils/inference.py b/models/rfd3/src/rfd3/utils/inference.py index 2d04e149..406f2196 100644 --- a/models/rfd3/src/rfd3/utils/inference.py +++ b/models/rfd3/src/rfd3/utils/inference.py @@ -405,21 +405,27 @@ def infer_ori_from_hotspots(atom_array: struc.AtomArray): # We can only perform distance computations on atoms with non-NaN coordinates nan_coords_mask = np.any(np.isnan(atom_array.coord), axis=1) - non_nan_atom_array = atom_array[~nan_coords_mask] + motif_mask = atom_array.is_motif_atom_with_fixed_coord.astype(bool) + non_nan_motif_atom_array = atom_array[~nan_coords_mask & motif_mask] + if non_nan_motif_atom_array.array_length() == 0: + raise ValueError( + "infer_ori_from_hotspots requires at least one fixed motif atom " + "(is_motif_atom_with_fixed_coord=True) to compute nearby atoms COM." + ) # Perform the distance computation # RFD2 used 10 Angstroms instead of 12, but was for residue-level hotspots DISTANCE_CUTOFF = 12.0 - cell_list = struc.CellList(non_nan_atom_array, cell_size=DISTANCE_CUTOFF) + cell_list = struc.CellList(non_nan_motif_atom_array, cell_size=DISTANCE_CUTOFF) nearby_atoms_mask = get_atom_mask_from_cell_list( hotspot_atom_array.coord, cell_list, - len(non_nan_atom_array), + len(non_nan_motif_atom_array), cutoff=DISTANCE_CUTOFF, ) # (n_query, n_cell_list) nearby_atoms_mask = np.any(nearby_atoms_mask, axis=0) # (n_cell_list,) - nearby_atoms_com = non_nan_atom_array.coord[nearby_atoms_mask].mean(axis=0) + nearby_atoms_com = non_nan_motif_atom_array.coord[nearby_atoms_mask].mean(axis=0) vector_from_core_to_hotspot = hotspot_com - nearby_atoms_com vector_from_core_to_hotspot = vector_from_core_to_hotspot / np.linalg.norm( diff --git a/models/rfd3na/src/rfd3na/utils/inference.py b/models/rfd3na/src/rfd3na/utils/inference.py index 61f82927..1838aedb 100644 --- a/models/rfd3na/src/rfd3na/utils/inference.py +++ b/models/rfd3na/src/rfd3na/utils/inference.py @@ -405,21 +405,26 @@ def infer_ori_from_hotspots(atom_array: struc.AtomArray): # We can only perform distance computations on atoms with non-NaN coordinates nan_coords_mask = np.any(np.isnan(atom_array.coord), axis=1) - non_nan_atom_array = atom_array[~nan_coords_mask] - + motif_mask = atom_array.is_motif_atom_with_fixed_coord.astype(bool) + non_nan_motif_atom_array = atom_array[~nan_coords_mask & motif_mask] + if non_nan_motif_atom_array.array_length() == 0: + raise ValueError( + "infer_ori_from_hotspots requires at least one fixed motif atom " + "(is_motif_atom_with_fixed_coord=True) to compute nearby atoms COM." + ) # Perform the distance computation # RFD2 used 10 Angstroms instead of 12, but was for residue-level hotspots DISTANCE_CUTOFF = 12.0 - cell_list = struc.CellList(non_nan_atom_array, cell_size=DISTANCE_CUTOFF) + cell_list = struc.CellList(non_nan_motif_atom_array, cell_size=DISTANCE_CUTOFF) nearby_atoms_mask = get_atom_mask_from_cell_list( hotspot_atom_array.coord, cell_list, - len(non_nan_atom_array), + len(non_nan_motif_atom_array), cutoff=DISTANCE_CUTOFF, ) # (n_query, n_cell_list) nearby_atoms_mask = np.any(nearby_atoms_mask, axis=0) # (n_cell_list,) - nearby_atoms_com = non_nan_atom_array.coord[nearby_atoms_mask].mean(axis=0) + nearby_atoms_com = non_nan_motif_atom_array.coord[nearby_atoms_mask].mean(axis=0) vector_from_core_to_hotspot = hotspot_com - nearby_atoms_com vector_from_core_to_hotspot = vector_from_core_to_hotspot / np.linalg.norm(