Improve scripts/amg.py: paths, progress bar, multi-GPU, resuming #398

Open · wants to merge 1 commit into base: main
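The diff below adds `--rank`, `--world`, and `--num-workers` flags so a directory of images can be split across several processes. As a minimal, hypothetical launch sketch (not part of the PR; the GPU count, checkpoint file, and input/output directories are placeholders), one process per GPU could be started like this:

```python
# Hypothetical launcher for the sharded mode added by this PR: one amg.py process
# per GPU, each pinned to a single device and given a distinct --rank.
# GPU count, checkpoint path, and directories below are placeholders.
import os
import subprocess
import sys

NUM_GPUS = 4  # assumed number of visible GPUs

procs = []
for rank in range(NUM_GPUS):
    cmd = [
        sys.executable, "scripts/amg.py",
        "--checkpoint", "sam_vit_h_4b8939.pth",  # placeholder checkpoint path
        "--model-type", "vit_h",
        "--input", "images/",                    # placeholder input directory
        "--output", "masks/",                    # placeholder output directory
        "--rank", str(rank),
        "--world", str(NUM_GPUS),
        "--num-workers", "4",
    ]
    # Pin each process to one GPU; inside the script --device stays "cuda".
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(rank)}
    procs.append(subprocess.Popen(cmd, env=env))

for p in procs:
    p.wait()
```

Under this split, each process handles every `world`-th image (`targets[args.rank::args.world]` in the diff), and the skip-existing check lets a crashed shard be relaunched with the same arguments.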
99 changes: 69 additions & 30 deletions scripts/amg.py
@@ -4,15 +4,17 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import cv2 # type: ignore

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, List

import cv2 # type: ignore
import tqdm
import torch.utils.data

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

parser = argparse.ArgumentParser(
description=(
"Runs automatic mask generation on an input image or directory of images, "
@@ -53,6 +55,9 @@
)

parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
parser.add_argument("--rank", type=int, default=0, help="Rank of the current process.")
parser.add_argument("--world", type=int, default=1, help="Number of processes.")
parser.add_argument("--num-workers", type=int, default=4, help="Dataloader workers.")

parser.add_argument(
"--convert-to-rle",
@@ -149,13 +154,13 @@
)


def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
def write_masks_to_folder(masks: List[Dict[str, Any]], path: Path) -> None:
header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa
metadata = [header]
for i, mask_data in enumerate(masks):
mask = mask_data["segmentation"]
filename = f"{i}.png"
cv2.imwrite(os.path.join(path, filename), mask * 255)
cv2.imwrite((path / filename).as_posix(), mask * 255)
mask_metadata = [
str(i),
str(mask_data["area"]),
@@ -167,8 +172,7 @@ def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
]
row = ",".join(mask_metadata)
metadata.append(row)
metadata_path = os.path.join(path, "metadata.csv")
with open(metadata_path, "w") as f:
with open(path / "metadata.csv", "w") as f:
f.write("\n".join(metadata))

return
@@ -192,6 +196,21 @@ def get_amg_kwargs(args):
return amg_kwargs


class ImageDataset(torch.utils.data.Dataset):
def __init__(self, paths: List[Path], base: Path):
self.paths = paths
self.base = base

def __len__(self):
return len(self.paths)

def __getitem__(self, idx):
path = self.paths[idx]
image = cv2.imread((self.base / path).as_posix())
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return path, image


def main(args: argparse.Namespace) -> None:
print("Loading model...")
sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
@@ -200,34 +219,54 @@ def main(args: argparse.Namespace) -> None:
amg_kwargs = get_amg_kwargs(args)
generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)

if not os.path.isdir(args.input):
targets = [args.input]
args.input = Path(args.input).expanduser().resolve()
args.output = Path(args.output).expanduser().resolve()

if not args.input.is_dir():
targets = ImageDataset([args.input.name], args.input.parent)
else:
targets = [
f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
f.relative_to(args.input) for f in args.input.rglob("*")
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
targets = [os.path.join(args.input, f) for f in targets]

os.makedirs(args.output, exist_ok=True)

for t in targets:
print(f"Processing '{t}'...")
image = cv2.imread(t)
if image is None:
print(f"Could not load '{t}' as an image, skipping...")
continue
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

print(f"Found {len(targets)} images in {args.input}.")

# Per-process split
if args.world > 1:
targets = targets[args.rank::args.world]
print(f"Rank {args.rank}/{args.world} will process {len(targets)} images.")

# Skip existing
if output_mode == "binary_mask":
targets = [
f for f in targets
if not Path.is_dir(args.output / f.with_suffix(""))
]
else:
targets = [
f for f in targets
if not Path.is_file(args.output / f.with_suffix(".json"))
]
print(f"Skip already processed images, {len(targets)} remain to do.")

targets = torch.utils.data.DataLoader(
ImageDataset(targets, args.input),
batch_size=None,
shuffle=False,
num_workers=args.num_workers,
collate_fn=lambda x: x,
)

for path, image in tqdm.tqdm(targets, ncols=0):
masks = generator.generate(image)

base = os.path.basename(t)
base = os.path.splitext(base)[0]
save_base = os.path.join(args.output, base)
if output_mode == "binary_mask":
os.makedirs(save_base, exist_ok=False)
write_masks_to_folder(masks, save_base)
save_dir = args.output / path.with_suffix("")
save_dir.mkdir(parents=True, exist_ok=False)
write_masks_to_folder(masks, save_dir)
else:
save_file = save_base + ".json"
save_file = args.output / path.with_suffix(".json")
save_file.parent.mkdir(parents=True, exist_ok=True)
with open(save_file, "w") as f:
json.dump(masks, f)
print("Done!")