Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add detectron2 support #322

Merged
merged 120 commits into from
Jan 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
120 commits
Select commit Hold shift + click to select a range
10faea7
model upload
Dec 30, 2021
befe7df
model upload
Dec 30, 2021
3a28861
detectron upload
Dec 30, 2021
ca9fb3a
gitmodules file upload
Dec 30, 2021
7823264
Update .gitmodules
kadirnar Dec 30, 2021
2296b85
detectron2 prediction kodları düzenlendi.
Dec 31, 2021
35d4738
Delete .gitmodules
kadirnar Dec 31, 2021
d132d5f
Update ci.yml
kadirnar Dec 31, 2021
6f6258a
Kodlara black ve isort işlemleri uygulandı.
Dec 31, 2021
0b2196b
kütüphaneler fonksiyonlar altında tanımlandı
Dec 31, 2021
1ef9c6e
requirements yapılan değişiklikler iptal edildi.
Dec 31, 2021
98f076e
detectron modülü kullanarak model indirilecek
kadirnar Dec 31, 2021
4dc10a5
Delete detectron2.py
kadirnar Dec 31, 2021
9707126
Update ci.yml
kadirnar Dec 31, 2021
68c6ae7
Update test_detectron2.py
kadirnar Dec 31, 2021
11e12e3
Update model.py
kadirnar Dec 31, 2021
be78662
update push
kadirnar Dec 31, 2021
9667d73
Update model.py
kadirnar Dec 31, 2021
fee145e
Update test_detectron2.py
kadirnar Dec 31, 2021
2ee7788
Update model.py
kadirnar Dec 31, 2021
ea3e7f4
Update test_detectron2.py
kadirnar Dec 31, 2021
0caaa20
Update test_detectron2.py
kadirnar Dec 31, 2021
caaa832
Update model.py
kadirnar Dec 31, 2021
20ebf70
boxes ve mask değişkenleri düzenlendi
kadirnar Dec 31, 2021
daa1b34
boxes ve mask değişkenleri düzeltildi
kadirnar Dec 31, 2021
20d4701
Update model.py
kadirnar Dec 31, 2021
e78adf3
Update test_detectron2.py
kadirnar Dec 31, 2021
deee4cf
refactor detectron2 class, fix 2 tests
fcakyon Jan 1, 2022
7f5b40d
remove redundant code
fcakyon Jan 1, 2022
953de9c
fix detectron class
fcakyon Jan 1, 2022
0c1c9cd
update workflows
fcakyon Jan 1, 2022
e14790c
add margin to bbox tests
fcakyon Jan 1, 2022
3e427aa
update test margin
fcakyon Jan 1, 2022
9da568c
handle empty mask
fcakyon Jan 1, 2022
8d1a836
print assertion error in tests
fcakyon Jan 1, 2022
c3194e9
update tests
fcakyon Jan 1, 2022
051ce6e
handle empty prediction masks
fcakyon Jan 1, 2022
a339b0a
add detectron2 model support for predict
fcakyon Jan 1, 2022
ca210fb
throw error for invalid mask prediction
fcakyon Jan 1, 2022
6912080
update workflow docstring
fcakyon Jan 1, 2022
a31d694
perform detectron2 tests only on linux builds
fcakyon Jan 1, 2022
657030f
fix workflow
fcakyon Jan 1, 2022
795eef0
fix detectron tests
fcakyon Jan 1, 2022
74261d9
attempt to fix detectron tests on win osx
fcakyon Jan 1, 2022
d908220
attempt to fix detectron tests
fcakyon Jan 1, 2022
e2e9b0d
update detectron2 model
fcakyon Jan 1, 2022
2ec5871
update detectron test
fcakyon Jan 1, 2022
f775de9
fix detectron class
fcakyon Jan 1, 2022
a72bc25
update tests
fcakyon Jan 1, 2022
b788961
update workflows
fcakyon Jan 1, 2022
a70390f
fix yolov5 test
fcakyon Jan 1, 2022
20a690f
model upload
Dec 30, 2021
b0e8f64
model upload
Dec 30, 2021
cd16f05
detectron upload
Dec 30, 2021
7a2d6a8
gitmodules file upload
Dec 30, 2021
9fb0631
Update .gitmodules
kadirnar Dec 30, 2021
ab55467
detectron2 prediction kodları düzenlendi.
Dec 31, 2021
1f3f6f1
Delete .gitmodules
kadirnar Dec 31, 2021
e4bcadb
Update ci.yml
kadirnar Dec 31, 2021
ed82f7b
Kodlara black ve isort işlemleri uygulandı.
Dec 31, 2021
b8306ae
kütüphaneler fonksiyonlar altında tanımlandı
Dec 31, 2021
7dad95e
requirements yapılan değişiklikler iptal edildi.
Dec 31, 2021
00ea910
detectron modülü kullanarak model indirilecek
kadirnar Dec 31, 2021
2db345a
Delete detectron2.py
kadirnar Dec 31, 2021
bccc66a
Update ci.yml
kadirnar Dec 31, 2021
c574ff7
Update test_detectron2.py
kadirnar Dec 31, 2021
86b288f
Update model.py
kadirnar Dec 31, 2021
dc7a0be
update push
kadirnar Dec 31, 2021
b452f5a
Update model.py
kadirnar Dec 31, 2021
449909d
Update test_detectron2.py
kadirnar Dec 31, 2021
5df0bb0
Update model.py
kadirnar Dec 31, 2021
56b141d
Update test_detectron2.py
kadirnar Dec 31, 2021
cd4026a
Update test_detectron2.py
kadirnar Dec 31, 2021
754ae8d
Update model.py
kadirnar Dec 31, 2021
51e8889
boxes ve mask değişkenleri düzenlendi
kadirnar Dec 31, 2021
e532cb3
boxes ve mask değişkenleri düzeltildi
kadirnar Dec 31, 2021
09eff03
Update model.py
kadirnar Dec 31, 2021
67cb5f9
Update test_detectron2.py
kadirnar Dec 31, 2021
4621b6a
refactor detectron2 class, fix 2 tests
fcakyon Jan 1, 2022
f68abcb
remove redundant code
fcakyon Jan 1, 2022
d8dc435
fix detectron class
fcakyon Jan 1, 2022
ab0e9b9
update workflows
fcakyon Jan 1, 2022
8bc344e
add margin to bbox tests
fcakyon Jan 1, 2022
73f27f3
update test margin
fcakyon Jan 1, 2022
b281df9
handle empty mask
fcakyon Jan 1, 2022
410a1bc
print assertion error in tests
fcakyon Jan 1, 2022
c93a338
update tests
fcakyon Jan 1, 2022
f4c9e2b
handle empty prediction masks
fcakyon Jan 1, 2022
daa59a3
add detectron2 model support for predict
fcakyon Jan 1, 2022
d06f0d6
throw error for invalid mask prediction
fcakyon Jan 1, 2022
67b92ce
update workflow docstring
fcakyon Jan 1, 2022
50024e9
perform detectron2 tests only on linux builds
fcakyon Jan 1, 2022
8e32bf4
fix workflow
fcakyon Jan 1, 2022
74ece34
fix detectron tests
fcakyon Jan 1, 2022
304d9a1
attempt to fix detectron tests on win osx
fcakyon Jan 1, 2022
49a3d20
attempt to fix detectron tests
fcakyon Jan 1, 2022
2474cd8
update detectron2 model
fcakyon Jan 1, 2022
cae6649
update detectron test
fcakyon Jan 1, 2022
5c926f0
fix detectron class
fcakyon Jan 1, 2022
1852322
update tests
fcakyon Jan 1, 2022
b3e46fd
update workflows
fcakyon Jan 1, 2022
422a95a
fix yolov5 test
fcakyon Jan 1, 2022
ee5b2f4
Merge branch 'detectron' of https://github.com/kadirnar/sahi into det…
fcakyon Jan 2, 2022
359b73e
update tests
fcakyon Jan 2, 2022
6201f12
fix tests for new postprocess
fcakyon Jan 2, 2022
7b908cc
fix tests
fcakyon Jan 2, 2022
cc4e96d
update workflows docstring
fcakyon Jan 2, 2022
fa3215a
install detectron on all platforms
fcakyon Jan 2, 2022
9961d75
update detectron2 notebooks
kadirnar Jan 2, 2022
e91443a
Delete Detectron2.ipynb
kadirnar Jan 2, 2022
bfdea11
update detectron2 notebook
kadirnar Jan 2, 2022
bcc9468
update detectron2 notebook
fcakyon Jan 2, 2022
4f46742
revert back sahi installation todetectron branch
fcakyon Jan 2, 2022
7c83c1d
handle when config_path is None
fcakyon Jan 3, 2022
729137a
remove default config value
fcakyon Jan 3, 2022
7a5113d
add export_cfg_as_yaml detectron2 util
fcakyon Jan 3, 2022
7bbd70a
detectron model can load custom finetuned models
fcakyon Jan 3, 2022
bfe9981
Merge branch 'main' of https://github.com/kadirnar/sahi into detectron
fcakyon Jan 6, 2022
86f3206
reformat with isort
fcakyon Jan 6, 2022
5e2a7d5
Merge branch 'main' of https://github.com/kadirnar/sahi into detectron
fcakyon Jan 11, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,21 @@ jobs:
if: matrix.operating-system == 'macos-latest'
run: pip install torch==1.10.0 torchvision==0.11.1

- name: Install MMDetection(2.19.1), YOLOv5(6.0.6) and Norfair(0.3.1)
- name: Install MMDetection(2.19.1) with MMCV(1.4.1)
run: >
pip install mmcv-full==1.4.1 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.10.0/index.html
pip install mmdet==2.19.1 yolov5==6.0.6 norfair==0.3.1
pip install mmdet==2.19.1

- name: Test with unittest
- name: Install YOLOv5(6.0.6) and Norfair(0.3.1)
run: >
pip install yolov5==6.0.6 norfair==0.3.1

- name: Install Detectron2(0.6)
run: >
python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html

- name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
run: |
pip install pytest
python -m unittest

- name: Install pycocotools(2.0.3)
Expand Down Expand Up @@ -113,4 +120,4 @@ jobs:
# coco evaluate
sahi coco evaluate --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json
# coco analyse
sahi coco analyse --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json --out_dir tests/data/coco_evaluate/
sahi coco analyse --dataset_json_path tests/data/coco_evaluate/dataset.json --result_json_path tests/data/coco_evaluate/result.json --out_dir tests/data/coco_evaluate/
15 changes: 11 additions & 4 deletions .github/workflows/package_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,18 @@ jobs:
if: matrix.operating-system == 'macos-latest'
run: pip install torch==1.10.0 torchvision==0.11.1

- name: Install MMDetection(2.19.1), YOLOv5(6.0.6) and Norfair(0.3.1)
- name: Install MMDetection(2.19.1) with MMCV(1.4.1)
run: >
pip install mmcv-full==1.4.1 -f https://download.openmmlab.com/mmcv/dist/cpu/torch1.10.0/index.html
pip install mmdet==2.19.1 yolov5==6.0.6 norfair==0.3.1
pip install mmdet==2.19.1

- name: Install YOLOv5(6.0.6) and Norfair(0.3.1)
run: >
pip install yolov5==6.0.6 norfair==0.3.1

- name: Install Detectron2(0.6)
run: >
python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html

- name: Install pycocotools(2.0.3)
run: >
Expand All @@ -75,9 +83,8 @@ jobs:
run: >
pip install --upgrade --force-reinstall sahi

- name: Test with unittest
- name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
run: |
pip install pytest
python -m unittest

- name: Test SAHI CLI
Expand Down
886 changes: 886 additions & 0 deletions demo/inference_for_detectron2.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ pillow>=8.2.0
pyyaml
fire
terminaltables
requests
requests
2 changes: 2 additions & 0 deletions sahi/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,8 @@ def __init__(
# https://github.com/obss/sahi/issues/235
if bbox_from_bool_mask is not None:
bbox = bbox_from_bool_mask
else:
raise ValueError("Invalid boolean mask.")
else:
self.mask = None

Expand Down
184 changes: 176 additions & 8 deletions sahi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
# Code written by Fatih C Akyon, 2020.

import logging
import os
import warnings
from typing import Dict, List, Optional, Union

import numpy as np

from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from sahi.utils.cv import get_bbox_from_bool_mask
from sahi.utils.torch import cuda_is_available, empty_cuda_cache

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -89,7 +89,6 @@ def perform_inference(self, image: np.ndarray, image_size: int = None):
"""
This function should be implemented in a way that prediction should be
performed using self.model and the prediction result should be set to self._original_predictions.

Args:
image: np.ndarray
A numpy array that contains the image to be predicted.
Expand All @@ -107,7 +106,6 @@ def _create_object_prediction_list_from_original_predictions(
This function should be implemented in a way that self._original_predictions should
be converted to a list of prediction.ObjectPrediction and set to
self._object_prediction_list. self.mask_threshold can also be utilized.

Args:
shift_amount: list
To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]
Expand Down Expand Up @@ -137,7 +135,6 @@ def convert_original_predictions(
"""
Converts original predictions of the detection model to a list of
prediction.ObjectPrediction object. Should be called after perform_inference().

Args:
shift_amount: list
To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y]
Expand Down Expand Up @@ -200,7 +197,6 @@ def load_model(self):
def perform_inference(self, image: np.ndarray, image_size: int = None):
"""
Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Args:
image: np.ndarray
A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
Expand Down Expand Up @@ -271,7 +267,6 @@ def _create_object_prediction_list_from_original_predictions(
"""
self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
self._object_prediction_list_per_image.

Args:
shift_amount_list: list of list
To shift the box and mask predictions from sliced image to full sized image, should
Expand Down Expand Up @@ -384,7 +379,6 @@ def load_model(self):
def perform_inference(self, image: np.ndarray, image_size: int = None):
"""
Prediction is performed using self.model and the prediction result is set to self._original_predictions.

Args:
image: np.ndarray
A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
Expand Down Expand Up @@ -436,7 +430,6 @@ def _create_object_prediction_list_from_original_predictions(
"""
self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
self._object_prediction_list_per_image.

Args:
shift_amount_list: list of list
To shift the box and mask predictions from sliced image to full sized image, should
Expand Down Expand Up @@ -495,3 +488,178 @@ def _create_object_prediction_list_from_original_predictions(
object_prediction_list_per_image.append(object_prediction_list)

self._object_prediction_list_per_image = object_prediction_list_per_image


class Detectron2DetectionModel(DetectionModel):
def load_model(self):
try:
import detectron2
except ImportError:
raise ImportError(
"Please install detectron2. Check "
"`https://detectron2.readthedocs.io/en/latest/tutorials/install.html` "
"for instalattion details."
)

from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.model_zoo import model_zoo

cfg = get_cfg()
cfg.MODEL.DEVICE = self.device

try: # try to load from model zoo
config_file = model_zoo.get_config_file(self.config_path)
cfg.merge_from_file(config_file)
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(self.config_path)
except Exception as e: # try to load from local
print(e)
if self.config_path is not None:
cfg.merge_from_file(self.config_path)
cfg.MODEL.WEIGHTS = self.model_path
# set input image size
if self.image_size is not None:
cfg.INPUT.MIN_SIZE_TEST = self.image_size
cfg.INPUT.MAX_SIZE_TEST = self.image_size
# init predictor
model = DefaultPredictor(cfg)

self.model = model

# detectron2 category mapping
if self.category_mapping is None:
try: # try to parse category names from metadata
metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
category_names = metadata.thing_classes
self.category_names = category_names
self.category_mapping = {
str(ind): category_name for ind, category_name in enumerate(self.category_names)
}
except Exception as e:
logger.warning(e)
# https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html#update-the-config-for-new-datasets
if cfg.MODEL.META_ARCHITECTURE == "RetinaNet":
num_categories = cfg.MODEL.RETINANET.NUM_CLASSES
else: # fasterrcnn/maskrcnn etc
num_categories = cfg.MODEL.ROI_HEADS.NUM_CLASSES
self.category_names = [str(category_id) for category_id in range(num_categories)]
self.category_mapping = {
str(ind): category_name for ind, category_name in enumerate(self.category_names)
}
else:
self.category_names = list(self.category_mapping.values())

def perform_inference(self, image: np.ndarray, image_size: int = None):
"""
Prediction is performed using self.model and the prediction result is set to self._original_predictions.
Args:
image: np.ndarray
A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
"""
try:
import detectron2
except ImportError:
raise ImportError("Please install detectron2 via `pip install detectron2`")

# confirm image_size is not provided
if image_size is not None:
warnings.warn("Set 'image_size' at DetectionModel init.")

# Confirm model is loaded
if self.model is None:
raise RuntimeError("Model is not loaded, load it by calling .load_model()")

if isinstance(image, np.ndarray) and self.model.input_format == "BGR":
# convert RGB image to BGR format
image = image[:, :, ::-1]

prediction_result = self.model(image)

self._original_predictions = prediction_result

@property
def num_categories(self):
"""
Returns number of categories
"""
num_categories = len(self.category_mapping)
return num_categories

def _create_object_prediction_list_from_original_predictions(
self,
shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
full_shape_list: Optional[List[List[int]]] = None,
):
"""
self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
self._object_prediction_list_per_image.
Args:
shift_amount_list: list of list
To shift the box and mask predictions from sliced image to full sized image, should
be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...]
full_shape_list: list of list
Size of the full image after shifting, should be in the form of
List[[height, width],[height, width],...]
"""
original_predictions = self._original_predictions
category_mapping = self.category_mapping

# compatilibty for sahi v0.8.15
if isinstance(shift_amount_list[0], int):
shift_amount_list = [shift_amount_list]
if full_shape_list is not None and isinstance(full_shape_list[0], int):
full_shape_list = [full_shape_list]

# parse boxes, masks, scores, category_ids from predictions
boxes = original_predictions["instances"].pred_boxes.tensor.tolist()
scores = original_predictions["instances"].scores.tolist()
category_ids = original_predictions["instances"].pred_classes.tolist()
try:
masks = original_predictions["instances"].pred_masks.tolist()
except AttributeError:
masks = None

# create object_prediction_list
num_categories = self.num_categories
object_prediction_list_per_image = []
object_prediction_list = []

# detectron2 DefaultPredictor supports single image
shift_amount = shift_amount_list[0]
full_shape = None if full_shape_list is None else full_shape_list[0]

for ind in range(len(boxes)):
score = scores[ind]
if score < self.confidence_threshold:
continue

category_id = category_ids[ind]

if masks is None:
bbox = boxes[ind]
mask = None
else:
mask = np.array(masks[ind])

# check if mask is valid
if get_bbox_from_bool_mask(mask) is not None:
bbox = None
else:
continue

object_prediction = ObjectPrediction(
bbox=bbox,
bool_mask=mask,
category_id=category_id,
category_name=self.category_mapping[str(category_id)],
shift_amount=shift_amount,
score=score,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)

# detectron2 DefaultPredictor supports single image
object_prediction_list_per_image = [object_prediction_list]

self._object_prediction_list_per_image = object_prediction_list_per_image
1 change: 1 addition & 0 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
MODEL_TYPE_TO_MODEL_CLASS_NAME = {
"mmdet": "MmdetDetectionModel",
"yolov5": "Yolov5DetectionModel",
"detectron2": "Detectron2DetectionModel",
}

LOW_MODEL_CONFIDENCE = 0.1
Expand Down
7 changes: 5 additions & 2 deletions sahi/utils/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,11 @@ def get_bbox_from_bool_mask(bool_mask):

ymin, ymax = np.where(rows)[0][[0, -1]]
xmin, xmax = np.where(cols)[0][[0, -1]]
# width = xmax - xmin
# height = ymax - ymin
width = xmax - xmin
height = ymax - ymin

if width == 0 or height == 0:
return None

return [xmin, ymin, xmax, ymax]

Expand Down
21 changes: 21 additions & 0 deletions sahi/utils/detectron2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pathlib import Path


class Detectron2TestConstants:
FASTERCNN_MODEL_ZOO_NAME = "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml"
RETINANET_MODEL_ZOO_NAME = "COCO-Detection/retinanet_R_50_FPN_3x.yaml"
MASKRCNN_MODEL_ZOO_NAME = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"


def export_cfg_as_yaml(cfg, export_path: str = "config.yaml"):
"""
Exports Detectron2 config object in yaml format so that it can be used later.
Args:
cfg (detectron2.config.CfgNode): Detectron2 config object.
export_path (str): Path to export the Detectron2 config.
Related Detectron2 doc: https://detectron2.readthedocs.io/en/stable/modules/config.html#detectron2.config.CfgNode.dump
"""
Path(export_path).parent.mkdir(exist_ok=True, parents=True)

with open(export_path, "w") as f:
f.write(cfg.dump())
Loading