Skip to content

Commit 6df0bef

Browse files
committed
refactor(mean): minor code modifications to improve efficiency
1 parent 6a47003 commit 6df0bef

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

deepem/loss/mean.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -107,23 +107,19 @@ def forward(
107107
groups = None
108108
if self.recompute_ext:
109109
assert splt is not None
110-
trgt_np = np.squeeze(trgt.cpu().numpy())
111-
splt_np = np.squeeze(splt.cpu().numpy())
112-
mask_np = np.squeeze(mask.cpu().numpy())
113-
groups = create_mapping(trgt_np, splt_np, mask_np)
110+
trgt = torch.squeeze(trgt)
111+
splt = torch.squeeze(splt)
112+
mask = torch.squeeze(mask)
113+
groups = create_mapping(trgt.cpu().numpy(), splt.cpu().numpy(), mask.cpu().numpy())
114114
trgt = splt
115115

116116
trgt = trgt.to(torch.int)
117117

118-
# Extract unique IDs
119-
ids = np.unique(trgt[mask > 0].cpu().numpy())
120-
121-
# Remove 0s from the IDs if `mask_background` is True
118+
# Filter out background and get unique IDs
119+
masked_trgt = trgt[mask > 0]
122120
if self.mask_background:
123-
ids = ids[ids != 0]
124-
125-
# Convert numpy array to a Python list
126-
ids = ids.tolist()
121+
masked_trgt = masked_trgt[masked_trgt != 0]
122+
ids = torch.unique(masked_trgt).tolist()
127123

128124
# Recompute external matrix
129125
mext = self.compute_ext_matrix(ids, groups, self.recompute_ext, device)

0 commit comments

Comments
 (0)