File tree 1 file changed +8
-12
lines changed
1 file changed +8
-12
lines changed Original file line number Diff line number Diff line change @@ -115,23 +115,19 @@ def forward(
115
115
groups = None
116
116
if self .recompute_ext :
117
117
assert splt is not None
118
- trgt_np = np .squeeze (trgt . cpu (). numpy () )
119
- splt_np = np .squeeze (splt . cpu (). numpy () )
120
- mask_np = np .squeeze (mask . cpu (). numpy () )
121
- groups = create_mapping (trgt_np , splt_np , mask_np )
118
+ trgt = torch .squeeze (trgt )
119
+ splt = torch .squeeze (splt )
120
+ mask = torch .squeeze (mask )
121
+ groups = create_mapping (trgt . cpu (). numpy (), splt . cpu (). numpy (), mask . cpu (). numpy () )
122
122
trgt = splt
123
123
124
124
trgt = trgt .to (torch .int )
125
125
126
- # Extract unique IDs
127
- ids = np .unique (trgt [mask > 0 ].cpu ().numpy ())
128
-
129
- # Remove 0s from the IDs if `mask_background` is True
126
+ # Filter out background and get unique IDs
127
+ masked_trgt = trgt [mask > 0 ]
130
128
if self .mask_background :
131
- ids = ids [ids != 0 ]
132
-
133
- # Convert numpy array to a Python list
134
- ids = ids .tolist ()
129
+ masked_trgt = masked_trgt [masked_trgt != 0 ]
130
+ ids = torch .unique (masked_trgt ).tolist ()
135
131
136
132
# Recompute external matrix
137
133
mext = self .compute_ext_matrix (ids , groups , self .recompute_ext , device )
You can’t perform that action at this time.
0 commit comments