add torchvision detector support #486

Merged
merged 8 commits into from
Jun 18, 2022
8 changes: 5 additions & 3 deletions README.md
@@ -75,6 +75,8 @@ Object detection and instance segmentation are by far the most important fields

- `HuggingFace` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_huggingface.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-huggingface"></a> (NEW)

- `TorchVision` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_torchvision.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-torchvision"></a> (NEW)

<a href="https://huggingface.co/spaces/fcakyon/sahi-yolox"><img width="600" src="https://user-images.githubusercontent.com/34196005/144092739-c1d9bade-a128-4346-947f-424ce00e5c4f.gif" alt="sahi-yolox"></a>


@@ -111,17 +113,17 @@ conda install pytorch=1.10.2 torchvision=0.11.3 cudatoolkit=11.3 -c pytorch
- Install your desired detection framework (yolov5):

```console
pip install yolov5
pip install yolov5==6.1.3
```

- Install your desired detection framework (mmdet):

```console
pip install mmcv-full==1.4.4 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html
pip install mmcv-full==1.5.3 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html
```

```console
pip install mmdet==2.21.0
pip install mmdet==2.25.0
```

- Install your desired detection framework (detectron2):
1 change: 1 addition & 0 deletions demo/inference_for_torchvision.ipynb

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion sahi/auto_model.py
@@ -1,12 +1,19 @@
from typing import Dict, Optional

from sahi.model import Detectron2DetectionModel, HuggingfaceDetectionModel, MmdetDetectionModel, Yolov5DetectionModel
from sahi.model import (
    Detectron2DetectionModel,
    HuggingfaceDetectionModel,
    MmdetDetectionModel,
    TorchVisionDetectionModel,
    Yolov5DetectionModel,
)

MODEL_TYPE_TO_MODEL_CLASS_NAME = {
    "mmdet": MmdetDetectionModel,
    "yolov5": Yolov5DetectionModel,
    "detectron2": Detectron2DetectionModel,
    "huggingface": HuggingfaceDetectionModel,
    "torchvision": TorchVisionDetectionModel,
}


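The remainder of `auto_model.py` is not shown in this diff. For reference, a minimal sketch (not part of this diff) of how the new `"torchvision"` entry can be exercised directly; the config path is one of the test fixtures added below, and the constructor arguments are the ones shown in `sahi/model.py`:

```python
# Sketch: resolve the detection model class from the new "torchvision" key.
from sahi.auto_model import MODEL_TYPE_TO_MODEL_CLASS_NAME

DetectionModelClass = MODEL_TYPE_TO_MODEL_CLASS_NAME["torchvision"]
detection_model = DetectionModelClass(
    config_path="tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml",
    confidence_threshold=0.5,
    device="cpu",  # or "cuda:0"
)
```
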
206 changes: 206 additions & 0 deletions sahi/model.py
@@ -879,3 +879,209 @@ def _create_object_prediction_list_from_original_predictions(
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image


@check_requirements(["torch", "torchvision"])
class TorchVisionDetectionModel(DetectionModel):
    def __init__(
        self,
        model_path: Optional[str] = None,
        model: Optional[Any] = None,
        config_path: Optional[str] = None,
        device: Optional[str] = None,
        mask_threshold: float = 0.5,
        confidence_threshold: float = 0.3,
        category_mapping: Optional[Dict] = None,
        category_remapping: Optional[Dict] = None,
        load_at_init: bool = True,
        image_size: int = None,
    ):

        super().__init__(
            model_path=model_path,
            model=model,
            config_path=config_path,
            device=device,
            mask_threshold=mask_threshold,
            confidence_threshold=confidence_threshold,
            category_mapping=category_mapping,
            category_remapping=category_remapping,
            load_at_init=load_at_init,
            image_size=image_size,
        )

    def load_model(self):
        import torch

        from sahi.utils.torchvision import MODEL_NAME_TO_CONSTRUCTOR

        # read config params
        model_name = None
        num_classes = None
        if self.config_path is not None:
            import yaml

            with open(self.config_path, "r") as stream:
                try:
                    config = yaml.safe_load(stream)
                except yaml.YAMLError as exc:
                    raise RuntimeError(exc)

            model_name = config.get("model_name", None)
            num_classes = config.get("num_classes", None)

        # complete params if not provided in config
        if not model_name:
            model_name = "fasterrcnn_resnet50_fpn"
            logger.warning(f"model_name not provided in config, using default model_name: {model_name}")
        if num_classes is None:
            logger.warning("num_classes not provided in config, using default num_classes: 91")
            num_classes = 91
        if self.model_path is None:
            logger.warning("model_path not provided, using pretrained weights and default num_classes: 91.")
            pretrained = True
            num_classes = 91
        else:
            pretrained = False

        # load model, then load local weights only when a model_path is given
        model = MODEL_NAME_TO_CONSTRUCTOR[model_name](num_classes=num_classes, pretrained=pretrained)
        if self.model_path is not None:
            try:
                model.load_state_dict(torch.load(self.model_path))
            except Exception as e:
                raise TypeError(f"model_path is not a valid torchvision model path: {e}")

        self.set_model(model)

    def set_model(self, model: Any):
        """
        Sets the underlying TorchVision model.
        Args:
            model: Any
                A TorchVision model
        """

        model.eval()
        self.model = model.to(self.device)

        # set category_mapping
        from sahi.utils.torchvision import COCO_CLASSES

        if self.category_mapping is None:
            category_names = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))}
            self.category_mapping = category_names

    def perform_inference(self, image: np.ndarray, image_size: int = None):
        """
        Prediction is performed using self.model and the prediction result is set to self._original_predictions.
        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
            image_size: int
                Inference input size.
        """
        from sahi.utils.torch import to_float_tensor

        # arrange model input size
        if self.image_size is not None:
            # get min and max of image height and width
            min_shape, max_shape = min(image.shape[:2]), max(image.shape[:2])
            # torchvision resize transform scales the shorter dimension to the target size;
            # scale it so that the longer dimension ends up at self.image_size
            image_size = self.image_size * min_shape / max_shape
            self.model.transform.min_size = (image_size,)  # default is (800,)
            self.model.transform.max_size = image_size  # default is 1333

        image = to_float_tensor(image)
        image = image.to(self.device)
        prediction_result = self.model([image])

        self._original_predictions = prediction_result

    @property
    def num_categories(self):
        """
        Returns number of categories
        """
        return len(self.category_mapping)

    @property
    def has_mask(self):
        """
        Returns if model output contains segmentation mask
        """
        # torchvision detectors return masks only when a mask head is present (e.g. Mask R-CNN)
        return hasattr(self.model, "roi_heads") and getattr(self.model.roi_heads, "mask_predictor", None) is not None

    @property
    def category_names(self):
        return list(self.category_mapping.values())

    def _create_object_prediction_list_from_original_predictions(
        self,
        shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
        full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.20
        if isinstance(shift_amount_list[0], int):
            shift_amount_list = [shift_amount_list]
        if full_shape_list is not None and isinstance(full_shape_list[0], int):
            full_shape_list = [full_shape_list]

        object_prediction_list_per_image = []
        for image_predictions in original_predictions:

            # get indices of boxes with score > confidence_threshold
            scores = image_predictions["scores"].cpu().detach().numpy()
            selected_indices = np.where(scores > self.confidence_threshold)[0]

            # parse boxes, masks, scores, category_ids from predictions
            category_ids = list(image_predictions["labels"][selected_indices].cpu().detach().numpy())
            boxes = list(image_predictions["boxes"][selected_indices].cpu().detach().numpy())
            scores = scores[selected_indices]

            # check if predictions contain mask
            masks = image_predictions.get("masks", None)
            if masks is not None:
                masks = list(image_predictions["masks"][selected_indices].cpu().detach().numpy())

            # create object_prediction_list
            object_prediction_list = []

            shift_amount = shift_amount_list[0]
            full_shape = None if full_shape_list is None else full_shape_list[0]

            for ind in range(len(boxes)):

                if masks is not None:
                    mask = np.array(masks[ind])
                else:
                    mask = None

                object_prediction = ObjectPrediction(
                    bbox=boxes[ind],
                    bool_mask=mask,
                    category_id=int(category_ids[ind]),
                    category_name=self.category_mapping[str(int(category_ids[ind]))],
                    shift_amount=shift_amount,
                    score=scores[ind],
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image
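
A minimal end-to-end usage sketch (not part of this diff): the new `TorchVisionDetectionModel` plugs into SAHI's existing `get_sliced_prediction` helper; the image path below is only an example.

```python
# Sketch: sliced inference with the new TorchVision detector (example image path).
from sahi.model import TorchVisionDetectionModel
from sahi.predict import get_sliced_prediction

detection_model = TorchVisionDetectionModel(
    config_path="tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml",
    confidence_threshold=0.5,
    image_size=640,
    device="cpu",
)

result = get_sliced_prediction(
    "demo/demo_data/small-vehicles1.jpeg",  # example image
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
print(len(result.object_prediction_list))
```
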
124 changes: 124 additions & 0 deletions sahi/utils/torchvision.py
@@ -0,0 +1,124 @@
# OBSS SAHI Tool
# Code written by Kadir Nar, 2022.


from packaging import version

from sahi.utils.import_utils import _torchvision_available, _torchvision_version, is_available


class TorchVisionTestConstants:
    FASTERRCNN_CONFIG_PATH = "tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml"
    SSD300_CONFIG_PATH = "tests/data/models/torchvision/ssd300_vgg16.yaml"


if _torchvision_available:
    import torchvision

    MODEL_NAME_TO_CONSTRUCTOR = {
        "fasterrcnn_resnet50_fpn": torchvision.models.detection.fasterrcnn_resnet50_fpn,
        "fasterrcnn_mobilenet_v3_large_fpn": torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn,
        "fasterrcnn_mobilenet_v3_large_320_fpn": torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn,
        "retinanet_resnet50_fpn": torchvision.models.detection.retinanet_resnet50_fpn,
        "ssd300_vgg16": torchvision.models.detection.ssd300_vgg16,
        "ssdlite320_mobilenet_v3_large": torchvision.models.detection.ssdlite320_mobilenet_v3_large,
    }

    # fcos requires torchvision >= 0.12.0
    if version.parse(_torchvision_version) >= version.parse("0.12.0"):
        MODEL_NAME_TO_CONSTRUCTOR["fcos_resnet50_fpn"] = torchvision.models.detection.fcos_resnet50_fpn


COCO_CLASSES = [
    "__background__",
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "N/A",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "N/A",
    "backpack",
    "umbrella",
    "N/A",
    "N/A",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "N/A",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "N/A",
    "dining table",
    "N/A",
    "N/A",
    "toilet",
    "N/A",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "N/A",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
]
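
For reference, a short sketch (not part of this diff) of how `MODEL_NAME_TO_CONSTRUCTOR` and `COCO_CLASSES` fit together; it mirrors what `load_model()` and `set_model()` do in `sahi/model.py`.

```python
# Sketch: build a pretrained detector by name and map its integer labels to names.
from sahi.utils.torchvision import COCO_CLASSES, MODEL_NAME_TO_CONSTRUCTOR

model = MODEL_NAME_TO_CONSTRUCTOR["fasterrcnn_resnet50_fpn"](num_classes=91, pretrained=True)
category_mapping = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))}
print(category_mapping["1"])  # -> "person"
```
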
2 changes: 2 additions & 0 deletions tests/data/models/torchvision/fasterrcnn_resnet50_fpn.yaml
@@ -0,0 +1,2 @@
model_name: fasterrcnn_resnet50_fpn
num_classes: 91
2 changes: 2 additions & 0 deletions tests/data/models/torchvision/ssd300_vgg16.yaml
@@ -0,0 +1,2 @@
model_name: ssd300_vgg16
num_classes: 91
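
As `load_model()` above implies, a fine-tuned checkpoint can be paired with a custom config following the same two-key layout as these fixtures. A sketch with hypothetical paths and class count:

```python
# Sketch: pair fine-tuned weights with a custom config (all paths are hypothetical).
# The config would contain, e.g.:
#   model_name: fasterrcnn_resnet50_fpn
#   num_classes: 2
from sahi.model import TorchVisionDetectionModel

detection_model = TorchVisionDetectionModel(
    model_path="path/to/fasterrcnn_finetuned.pt",
    config_path="path/to/fasterrcnn_finetuned.yaml",
    confidence_threshold=0.5,
    device="cpu",
)
```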