Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Tensorflow Hub detector support #501

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
7179f68
SAHI için tensorflow_hub kütüphanesi eklendi.
kadirnar Jun 18, 2022
f79e846
reformatter
kadirnar Jun 18, 2022
3a1f3d2
hatalar düzeltildi
kadirnar Jun 19, 2022
5197fa5
To do ve check işlemleri düzenlendi.
kadirnar Jun 19, 2022
7e2d6fc
reformatter
kadirnar Jun 19, 2022
ab27f33
kurulum ve image preprocces kodları düzenlendi
kadirnar Jun 19, 2022
26f388e
CI dosyalarına tensorflow paketleri eklendi
kadirnar Jun 19, 2022
119d7a6
tensorflow sürümü düzeltildi.
kadirnar Jun 19, 2022
4e55416
sürüm düzeltildi.
kadirnar Jun 19, 2022
9b4788a
Merge branch 'main' into tfhub
kadirnar Jun 19, 2022
e7f7567
CI dosyası düzenlendi
kadirnar Jun 19, 2022
db9b319
model_path hatası düzeltildi
kadirnar Jun 19, 2022
373e032
gereksiz kütüphaneler ve Model_type düzenlendi.
kadirnar Jun 20, 2022
aac9425
model ismi düzeltildi
kadirnar Jun 20, 2022
2696758
class isimleri düzeltildi.
kadirnar Jun 20, 2022
82d359a
tensorflow için notebook oluşturuldu.
kadirnar Jun 20, 2022
3f9c324
Tensorflow kütüphanesi için GPU kodu yazılmıştır.
kadirnar Jun 20, 2022
50c19b7
Adding new class to variable Coco_classes
kadirnar Jul 18, 2022
58faf09
update automodel loading method to from_pretrained
kadirnar Jul 18, 2022
9c7acd9
Merge branch 'main' into tfhub
fcakyon Jul 18, 2022
c07222d
category_name count error fixed
kadirnar Jul 18, 2022
c03ba63
Merge branch 'tfhub' of https://github.com/kadirnar/sahi into tfhub
kadirnar Jul 18, 2022
0fb6d94
category_mapping variable moved to load model.
kadirnar Aug 5, 2022
533a80a
category_mapping variable edited
kadirnar Aug 5, 2022
d6e46fa
Merge branch 'main' into tfhub
fcakyon Aug 7, 2022
6c91202
The check_requirements function has been updated.
kadirnar Aug 8, 2022
9262bea
Merge branch 'main' into tfhub
fcakyon Aug 29, 2022
38e26a8
Merge branch 'main' into tfhub
fcakyon Aug 29, 2022
4c94739
added set_device functions
kadirnar Aug 30, 2022
f6ed750
update set_device
kadirnar Aug 30, 2022
a31a89e
Update sahi/model.py
kadirnar Sep 1, 2022
50014b1
update set_device
kadirnar Sep 3, 2022
30bc488
Merge branch 'main' into tfhub
kadirnar Sep 3, 2022
9aafc97
Update sahi/model.py
kadirnar Sep 3, 2022
22f4f80
update automatically set_device
kadirnar Sep 3, 2022
a7575f0
added self.set_device
kadirnar Sep 3, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ jobs:
run: >
pip install pycocotools==2.0.4

- name: Install tensorflow(2.9.1) and tensorflow_hub(0.12.0)
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
run: >
pip install tensorflow==2.9.1
pip install tensorflow_hub==0.12.0

- name: Install SAHI package from local setup.py
run: >
pip install -e .
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/package_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ jobs:
run: >
pip install pycocotools==2.0.4

- name: Install tensorflow(2.9.1) and tensorflow_hub(0.12.0)
run: >
pip install tensorflow==2.9.1
pip install tensorflow_hub==0.12.0

- name: Install latest SAHI package
run: >
pip install --upgrade --force-reinstall sahi
Expand Down
97 changes: 96 additions & 1 deletion sahi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Code written by Fatih C Akyon, 2020.

import logging
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -1063,3 +1062,99 @@ def _create_object_prediction_list_from_original_predictions(
object_prediction_list_per_image.append(object_prediction_list)

self._object_prediction_list_per_image = object_prediction_list_per_image


@check_requirements(["tensorflow", "tensorflow_hub"])
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
class TensorflowHubDetectionModel(DetectionModel):
def load_model(self):
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
import tensorflow_hub as hub

self.model = hub.load(self.model_path)

def perform_inference(self, image: np.ndarray):
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
from sahi.utils.tfhub import get_image, resize

if self.image_size is not None:
img = get_image(image)
img = resize(img, self.image_size)
prediction_result = self.model(img)

else:
img = get_image(image)
prediction_result = self.model(img)

self._original_predictions = prediction_result
# TODO: add support for multiple image prediction
self.image_height, self.image_width = image.shape[0], image.shape[1]

if self.category_mapping is None:
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
from sahi.utils.tfhub import COCO_CLASSES

category_mapping = {str(i): COCO_CLASSES[i] for i in range(len(COCO_CLASSES))}
self.category_mapping = category_mapping

@property
def num_categories(self):
num_categories = len(self.category_mapping)
return num_categories

@property
def has_mask(self):
# TODO: check if model output contains segmentation mask
return False

@property
def category_names(self):
return list(self.category_mapping.values())

def _create_object_prediction_list_from_original_predictions(
self,
shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
full_shape_list: Optional[List[List[int]]] = None,
):
original_predictions = self._original_predictions

# compatilibty for sahi v0.8.20
if isinstance(shift_amount_list[0], int):
shift_amount_list = [shift_amount_list]
if full_shape_list is not None and isinstance(full_shape_list[0], int):
full_shape_list = [full_shape_list]

shift_amount = shift_amount_list[0]
full_shape = None if full_shape_list is None else full_shape_list[0]

boxes = original_predictions["detection_boxes"][0].numpy()
scores = original_predictions["detection_scores"][0].numpy()
category_ids = original_predictions["detection_classes"][0].numpy()

# create object_prediction_list
object_prediction_list = []
object_prediction_list_per_image = []

for i in range(min(boxes.shape[0], 100)):
if scores[i] >= self.confidence_threshold:
score = float(scores[i])
category_id = int(category_ids[i])
# Tfhub categories start from 1
category_names = self.category_mapping[str(category_id - 1)]
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
box = [float(box) for box in boxes[i]]
x1, y1, x2, y2 = (
int(box[1] * self.image_width),
int(box[0] * self.image_height),
int(box[3] * self.image_width),
int(box[2] * self.image_height),
)
bbox = [x1, y1, x2, y2]

object_prediction = ObjectPrediction(
bbox=bbox,
bool_mask=None,
category_id=category_id,
category_name=category_names,
shift_amount=shift_amount,
score=score,
full_shape=full_shape,
)
object_prediction_list.append(object_prediction)
object_prediction_list_per_image.append(object_prediction_list)
self._object_prediction_list_per_image = object_prediction_list_per_image
102 changes: 102 additions & 0 deletions sahi/utils/tfhub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
def get_image(array):
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import tensorflow as tf

array = np.asarray(array, np.float32)
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
normalized_array = array
if array.max() <= 1:
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
normalized_array = array * 255.0

normalized_array = np.asarray(normalized_array, np.uint8)
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
normalized_array = tf.convert_to_tensor([normalized_array], tf.uint8)
return normalized_array


def resize(array, size):
import tensorflow as tf

return tf.image.resize(array, [size, size]).numpy()


COCO_CLASSES = (
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"dining table",
"toilet",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush",
)
96 changes: 96 additions & 0 deletions tests/test_tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# OBSS SAHI Tool
kadirnar marked this conversation as resolved.
Show resolved Hide resolved
# Code written by Kadir Nar, 2022.

import unittest
kadirnar marked this conversation as resolved.
Show resolved Hide resolved

from sahi.model import TensorflowHubDetectionModel
from sahi.utils.cv import read_image

MODEL_DEVICE = "cpu"
CONFIDENCE_THRESHOLD = 0.3
IMAGE_SIZE = 320
EFFICIENTDET_URL = "https://tfhub.dev/tensorflow/efficientdet/d0/1"


kadirnar marked this conversation as resolved.
Show resolved Hide resolved
class TestTensorflowHubDetectionModel(unittest.TestCase):
def test_load_model(self):
from sahi.model import TensorflowHubDetectionModel

tensorflow_hub_model = TensorflowHubDetectionModel(
model_path=EFFICIENTDET_URL,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
category_remapping=None,
load_at_init=True,
)
self.assertNotEqual(tensorflow_hub_model.model, None)

def test_perform_inference(self):
from sahi.model import TensorflowHubDetectionModel

tensorflow_hub_model = TensorflowHubDetectionModel(
model_path=EFFICIENTDET_URL,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
category_remapping=None,
load_at_init=True,
image_size=IMAGE_SIZE,
)

# prepare image
image_path = "tests/data/small-vehicles1.jpeg"
image = read_image(image_path)
image_height, image_width = image.shape[0], image.shape[1]
# perform inference

tensorflow_hub_model.perform_inference(image)
original_prediction = tensorflow_hub_model.original_predictions

boxes = original_prediction["detection_boxes"][0]
box = [float(box) for box in boxes[0].numpy()]
x1, y1, x2, y2 = (
int(box[1] * image_width),
int(box[0] * image_height),
int(box[3] * image_width),
int(box[2] * image_height),
)
bbox = [x1, y1, x2, y2]
# compare
desidred_bbox = [317, 324, 381, 364]
predicted_bbox = [x1, y1, x2, y2]
self.assertEqual(desidred_bbox, predicted_bbox)

def test_convert_original_predictions(self):

tensorflow_hub_model = TensorflowHubDetectionModel(
model_path=EFFICIENTDET_URL,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
category_remapping=None,
load_at_init=True,
image_size=IMAGE_SIZE,
)

# prepare image
image_path = "tests/data/small-vehicles1.jpeg"
image = read_image(image_path)
image_height, image_width = image.shape[0], image.shape[1]

# perform inference
tensorflow_hub_model.perform_inference(image)

# convert predictions to ObjectPrediction list
tensorflow_hub_model.convert_original_predictions()
object_prediction_list = tensorflow_hub_model.object_prediction_list

# compare
self.assertEqual(len(object_prediction_list), 5)
self.assertEqual(object_prediction_list[0].category.id, 3)
self.assertEqual(object_prediction_list[0].category.name, "car")
desidred_bbox = [317, 324, 64, 40]
predicted_bbox = object_prediction_list[0].bbox.to_coco_bbox()
self.assertEqual(desidred_bbox, predicted_bbox)


if __name__ == "__main__":
unittest.main()