File tree 1 file changed +8
-12
lines changed
1 file changed +8
-12
lines changed Original file line number Diff line number Diff line change @@ -107,23 +107,19 @@ def forward(
107
107
groups = None
108
108
if self .recompute_ext :
109
109
assert splt is not None
110
- trgt_np = np .squeeze (trgt . cpu (). numpy () )
111
- splt_np = np .squeeze (splt . cpu (). numpy () )
112
- mask_np = np .squeeze (mask . cpu (). numpy () )
113
- groups = create_mapping (trgt_np , splt_np , mask_np )
110
+ trgt = torch .squeeze (trgt )
111
+ splt = torch .squeeze (splt )
112
+ mask = torch .squeeze (mask )
113
+ groups = create_mapping (trgt . cpu (). numpy (), splt . cpu (). numpy (), mask . cpu (). numpy () )
114
114
trgt = splt
115
115
116
116
trgt = trgt .to (torch .int )
117
117
118
- # Extract unique IDs
119
- ids = np .unique (trgt [mask > 0 ].cpu ().numpy ())
120
-
121
- # Remove 0s from the IDs if `mask_background` is True
118
+ # Filter out background and get unique IDs
119
+ masked_trgt = trgt [mask > 0 ]
122
120
if self .mask_background :
123
- ids = ids [ids != 0 ]
124
-
125
- # Convert numpy array to a Python list
126
- ids = ids .tolist ()
121
+ masked_trgt = masked_trgt [masked_trgt != 0 ]
122
+ ids = torch .unique (masked_trgt ).tolist ()
127
123
128
124
# Recompute external matrix
129
125
mext = self .compute_ext_matrix (ids , groups , self .recompute_ext , device )
You can’t perform that action at this time.
0 commit comments