From b3af323248df9bc9506dff27b302547465950d42 Mon Sep 17 00:00:00 2001 From: I0558018 Date: Thu, 6 Jul 2023 02:57:42 +0000 Subject: [PATCH] fix the align_mask bug when calculating the prmsd_loss --- igfold/model/IgFold.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/igfold/model/IgFold.py b/igfold/model/IgFold.py index c6c7c3a..8cde62b 100644 --- a/igfold/model/IgFold.py +++ b/igfold/model/IgFold.py @@ -314,8 +314,8 @@ def forward( cum_seq_lens = np.cumsum([0] + seq_lens) for sl_i, sl in enumerate(seq_lens): align_mask_ = align_mask.clone() - align_mask_[:, :cum_seq_lens[sl_i]] = False - align_mask_[:, cum_seq_lens[sl_i + 1]:] = False + align_mask_[:, :4*cum_seq_lens[sl_i]] = False + align_mask_[:, 4*cum_seq_lens[sl_i + 1]:] = False res_batch_mask_ = res_batch_mask.clone() res_batch_mask_[:, :cum_seq_lens[sl_i]] = False res_batch_mask_[:, cum_seq_lens[sl_i + 1]:] = False @@ -485,4 +485,4 @@ def gradient_refine( output.coords = coords output.prmsd = prmsd - return output \ No newline at end of file + return output