Skip to content

Commit 30861aa

Browse files
committed
refactor(mean): minor code modifications to improve efficiency
1 parent e0a2028 commit 30861aa

File tree

1 file changed

+8
-12
lines changed

1 file changed

+8
-12
lines changed

deepem/loss/mean.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -115,23 +115,19 @@ def forward(
115115
groups = None
116116
if self.recompute_ext:
117117
assert splt is not None
118-
trgt_np = np.squeeze(trgt.cpu().numpy())
119-
splt_np = np.squeeze(splt.cpu().numpy())
120-
mask_np = np.squeeze(mask.cpu().numpy())
121-
groups = create_mapping(trgt_np, splt_np, mask_np)
118+
trgt = torch.squeeze(trgt)
119+
splt = torch.squeeze(splt)
120+
mask = torch.squeeze(mask)
121+
groups = create_mapping(trgt.cpu().numpy(), splt.cpu().numpy(), mask.cpu().numpy())
122122
trgt = splt
123123

124124
trgt = trgt.to(torch.int)
125125

126-
# Extract unique IDs
127-
ids = np.unique(trgt[mask > 0].cpu().numpy())
128-
129-
# Remove 0s from the IDs if `mask_background` is True
126+
# Filter out background and get unique IDs
127+
masked_trgt = trgt[mask > 0]
130128
if self.mask_background:
131-
ids = ids[ids != 0]
132-
133-
# Convert numpy array to a Python list
134-
ids = ids.tolist()
129+
masked_trgt = masked_trgt[masked_trgt != 0]
130+
ids = torch.unique(masked_trgt).tolist()
135131

136132
# Recompute external matrix
137133
mext = self.compute_ext_matrix(ids, groups, self.recompute_ext, device)

0 commit comments

Comments
 (0)