Updated amg.py to allow batching of large mask sets and avoid over-sized torch.nonzero() calls #569

Open · wants to merge 8 commits into base: main
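For context, a minimal sketch of the failure this PR works around (the tensor shape is hypothetical, and the exact error text depends on the PyTorch version and device):

```python
import torch

# Hypothetical shapes: 4096 boolean masks of 1024*1024 pixels each give
# 4096 * 1048576 = 2**32 elements, which exceeds INT_MAX (2**31 - 1).
# Allocating this requires roughly 4 GiB of GPU memory.
masks = torch.zeros((4096, 1024 * 1024), dtype=torch.bool, device="cuda")
print(masks.numel() > 2**31 - 1)  # True

# On affected PyTorch builds this raises a RuntimeError along the lines of
# "nonzero is not supported for tensors with more than INT_MAX elements".
idx = masks.nonzero()
```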
45 changes: 44 additions & 1 deletion segment_anything/utils/amg.py
@@ -115,7 +115,50 @@ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:

# Compute change indices
diff = tensor[:, 1:] ^ tensor[:, :-1]

# torch.nonzero() only supports tensors with up to INT_MAX elements,
# so first check whether diff exceeds that limit.
# Total number of elements in the tensor:
b, w_h = diff.shape
total_elements = b * w_h

# Maximum allowable elements in one chunk, since torch uses 32-bit
# integers internally for this function
max_elements_per_chunk = 2**31 - 1

if total_elements < max_elements_per_chunk:
    # The tensor fits within the 32-bit limit, so a single call suffices
    change_indices = diff.nonzero()
else:
    # Number of chunks needed (ceiling division)
    num_chunks = total_elements // max_elements_per_chunk
    if total_elements % max_elements_per_chunk != 0:
        num_chunks += 1

    # Number of rows per chunk (ceiling division)
    chunk_size = b // num_chunks
    if b % num_chunks != 0:
        chunk_size += 1
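    # Worked example with hypothetical numbers: for b = 4096 masks of
    # 1024 x 1024 pixels, w_h = 1024 * 1024 - 1 = 1048575, so
    # total_elements = 4096 * 1048575 = 4294963200 > 2**31 - 1.
    # Then num_chunks = ceil(4294963200 / 2147483647) = 2 and
    # chunk_size = ceil(4096 / 2) = 2048 rows per nonzero() call.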

    # List to store the results from each chunk
    all_indices = []

    # Loop through the diff tensor in chunks
    for i in range(num_chunks):
        start = i * chunk_size
        end = min((i + 1) * chunk_size, b)
        chunk = diff[start:end, :]

        # Get non-zero indices for the current chunk
        indices = chunk.nonzero()

        # Shift row indices back into the coordinates of the full tensor
        indices[:, 0] += start

        all_indices.append(indices)

    # Concatenate the per-chunk results
    change_indices = torch.cat(all_indices, dim=0)
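    # The chunks are concatenated in row order; since nonzero() enumerates
    # indices row-major in practice, this reproduces the ordering a single
    # nonzero() call over the full tensor would give.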

# Encode run length
out = []
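As a quick sanity check, here is a minimal standalone sketch of the same chunking idea on a small tensor (the shapes and `chunk_size = 3` are arbitrary, chosen only to force several chunks; the assertion relies on `nonzero()` returning indices in row-major order):

```python
import torch

torch.manual_seed(0)
diff = torch.rand(10, 8) > 0.5

# Direct path: one nonzero() call over the whole tensor
direct = diff.nonzero()

# Chunked path: nonzero() per row-slice, with row indices shifted back
chunk_size = 3
parts = []
for start in range(0, diff.shape[0], chunk_size):
    chunk = diff[start : start + chunk_size, :]
    idx = chunk.nonzero()
    idx[:, 0] += start  # shift rows back into full-tensor coordinates
    parts.append(idx)
chunked = torch.cat(parts, dim=0)

assert torch.equal(direct, chunked)
print("chunked nonzero matches direct nonzero")
```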