Skip to content
11 changes: 11 additions & 0 deletions segment_anything/automatic_mask_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
uncrop_points,
)

import cv2, time

class SamAutomaticMaskGenerator:
def __init__(
Expand Down Expand Up @@ -161,6 +162,8 @@ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:

# Generate masks
mask_data = self._generate_masks(image)
if mask_data is None:
return None

# Filter small disconnected regions and holes in masks
if self.min_mask_region_area > 0:
Expand Down Expand Up @@ -204,8 +207,16 @@ def _generate_masks(self, image: np.ndarray) -> MaskData:
data = MaskData()
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)

data.cat(crop_data)

if data["crop_boxes"] is None or data["crop_boxes"].numel() == 0:
# No masks were found in any of the crops;
# save the problem image for later debugging.
cv2.imwrite(f"errorimg_{time.time()}.png",image)
return None


# Remove duplicate masks between crops
if len(crop_boxes) > 1:
# Prefer masks from smaller crops
Expand Down
45 changes: 44 additions & 1 deletion segment_anything/utils/amg.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,50 @@ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:

# Compute change indices
diff = tensor[:, 1:] ^ tensor[:, :-1]
change_indices = diff.nonzero()

# torch.nonzero() only supports tensors with at most INT_MAX (2**31 - 1) elements,
# so we first check whether this tensor exceeds that limit:
# Total elements in the tensor
b, w_h = diff.shape
total_elements = b * w_h

# Maximum number of elements allowed per chunk — torch uses 32-bit indexing for nonzero()
max_elements_per_chunk = 2**31 - 1

if total_elements < max_elements_per_chunk:
change_indices = (
diff.nonzero()
)  # fewer than 2**31 elements, so a single nonzero() call suffices.
else:
# Calculate the number of chunks needed
num_chunks = total_elements // max_elements_per_chunk
if total_elements % max_elements_per_chunk != 0:
num_chunks += 1

# Calculate the actual chunk size
chunk_size = b // num_chunks
if b % num_chunks != 0:
chunk_size += 1

# List to store the results from each chunk
all_indices = []

# Loop through the diff tensor in chunks
for i in range(num_chunks):
start = i * chunk_size
end = min((i + 1) * chunk_size, b)
chunk = diff[start:end, :]

# Get non-zero indices for the current chunk
indices = chunk.nonzero()

# Adjust the row indices to the original tensor
indices[:, 0] += start

all_indices.append(indices)

# Concatenate all the results
change_indices = torch.cat(all_indices, dim=0)

# Encode run length
out = []
Expand Down