diff --git a/pxdesign/runner/dumper.py b/pxdesign/runner/dumper.py index 6c43ac7..df06215 100644 --- a/pxdesign/runner/dumper.py +++ b/pxdesign/runner/dumper.py @@ -136,26 +136,24 @@ def _save_structure( entity_poly_type=None, ): N_sample = pred_coordinates.shape[0] + # Set annotations once before the loop + atom_array.set_annotation( + "b_factor", np.round(np.zeros(len(atom_array)).astype(float), 2) + ) + if "occupancy" not in atom_array._annot: + atom_array.set_annotation( + "occupancy", np.round(np.ones(len(atom_array)), 2) + ) for sample_idx in range(N_sample): output_fpath = os.path.join( prediction_save_dir, f"{sample_name}_sample_{sample_idx}.cif" ) - # fake b_factor - atom_array.set_annotation( - "b_factor", np.round(np.zeros(len(atom_array)).astype(float), 2) - ) - if "occupancy" not in atom_array._annot: - # fake occupancy - atom_array.set_annotation( - "occupancy", np.round(np.ones(len(atom_array)), 2) - ) save_structure_cif( atom_array, pred_coordinates[sample_idx], output_fpath, entity_poly_type, sample_name, - # save_wounresol=False, ) def _save_confidence( diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_parallel_cif_writing.py b/tests/test_parallel_cif_writing.py new file mode 100644 index 0000000..9438731 --- /dev/null +++ b/tests/test_parallel_cif_writing.py @@ -0,0 +1,17 @@ +"""Test that annotations are set once before the CIF writing loop.""" + +from pathlib import Path + + +def test_annotations_set_before_loop(): + """Verify b_factor and occupancy are set before the per-sample loop.""" + source = ( + Path(__file__).parent.parent / "pxdesign" / "runner" / "dumper.py" + ).read_text() + method_start = source.find("def _save_structure") + method_source = source[method_start:] + b_factor_pos = method_source.find("b_factor") + loop_pos = method_source.find("for sample_idx") + assert b_factor_pos < loop_pos, ( + "Annotations (b_factor, occupancy) should be set once before the per-sample loop" + )