diff --git a/boxmot/appearance/backends/base_backend.py b/boxmot/appearance/backends/base_backend.py
index dba72317e3..49eb877355 100644
--- a/boxmot/appearance/backends/base_backend.py
+++ b/boxmot/appearance/backends/base_backend.py
@@ -34,48 +34,41 @@ def __init__(self, weights, device, half):
         self.checker = RequirementsChecker()
         self.load_model(self.weights)
 
+
     def get_crops(self, xyxys, img):
-        crops = []
         h, w = img.shape[:2]
         resize_dims = (128, 256)
         interpolation_method = cv2.INTER_LINEAR
-        mean_array = np.array([0.485, 0.456, 0.406])
-        std_array = np.array([0.229, 0.224, 0.225])
-        # dets are of different sizes so batch preprocessing is not possible
-        for box in xyxys:
+        mean_array = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
+        std_array = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
+
+        # Preallocate tensor for crops
+        num_crops = len(xyxys)
+        crops = torch.empty((num_crops, 3, resize_dims[1], resize_dims[0]),
+                            dtype=torch.half if self.half else torch.float, device=self.device)
+
+        for i, box in enumerate(xyxys):
             x1, y1, x2, y2 = box.astype('int')
-            x1 = max(0, x1)
-            y1 = max(0, y1)
-            x2 = min(w - 1, x2)
-            y2 = min(h - 1, y2)
+            x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w - 1, x2), min(h - 1, y2)
             crop = img[y1:y2, x1:x2]
-            # resize
-            crop = cv2.resize(
-                crop,
-                resize_dims,  # from (x, y) to (128, 256) | (w, h)
-                interpolation=interpolation_method,
-            )
-
-            # (cv2) BGR 2 (PIL) RGB. The ReID models have been trained with this channel order
+
+            # Resize, then convert BGR (cv2) to RGB; the ReID models were trained with this channel order
+            crop = cv2.resize(crop, resize_dims, interpolation=interpolation_method)
             crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
-
-            crop = torch.from_numpy(crop).float()
-            crops.append(crop)
-
-        # List of torch tensor crops to unified torch tensor
-        crops = torch.stack(crops, dim=0)
-
-        # Normalize the batch
+
+            # Convert to a tensor on the target device; scaling to [0, 1] happens batch-wise below
+            crop = torch.from_numpy(crop).to(self.device, dtype=torch.half if self.half else torch.float)
+            crops[i] = torch.permute(crop, (2, 0, 1))  # Change to (C, H, W)
+
+        # Normalize the entire batch in one go
        crops = crops / 255.0
 
         # Standardize the batch
         crops = (crops - mean_array) / std_array
-
-        crops = torch.permute(crops, (0, 3, 1, 2))
-        crops = crops.to(dtype=torch.half if self.half else torch.float, device=self.device)
-
+
         return crops
 
+    @torch.no_grad()
     def get_features(self, xyxys, img):
         if xyxys.size != 0:
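
For reviewers: below is a minimal, self-contained sketch of the batched preprocessing this patch introduces, useful for sanity-checking shapes and the broadcasted normalization. The `preprocess_crops` name and the explicit `device`/`half` parameters are illustrative stand-ins for the backend's `self.device`/`self.half` attributes, not part of the patched class.

```python
import cv2
import numpy as np
import torch

def preprocess_crops(xyxys, img, device="cpu", half=False):
    """Illustrative stand-in for the patched get_crops (not the real method)."""
    h, w = img.shape[:2]
    resize_dims = (128, 256)  # (width, height), as cv2.resize expects
    dtype = torch.half if half else torch.float
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    # Preallocate the NCHW batch instead of stacking a Python list of tensors
    crops = torch.empty((len(xyxys), 3, resize_dims[1], resize_dims[0]),
                        dtype=dtype, device=device)
    for i, box in enumerate(xyxys):
        x1, y1, x2, y2 = box.astype('int')
        x1, y1, x2, y2 = max(0, x1), max(0, y1), min(w - 1, x2), min(h - 1, y2)
        crop = cv2.resize(img[y1:y2, x1:x2], resize_dims, interpolation=cv2.INTER_LINEAR)
        crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)  # BGR -> RGB for the ReID models
        crops[i] = torch.from_numpy(crop).to(device, dtype=dtype).permute(2, 0, 1)

    # One broadcasted pass over the whole batch: scale to [0, 1], then standardize
    return (crops / 255.0 - mean) / std

# Shape check with dummy data
img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
boxes = np.array([[10, 20, 110, 220], [50, 60, 200, 300]], dtype=np.float32)
print(preprocess_crops(boxes, img).shape)  # torch.Size([2, 3, 256, 128])
```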