
Remove pycocotools #1791

Merged 21 commits on Feb 10, 2024
21 commits
f9fa509
Deprecate cache and cache_dir parameters support
BloodAxe Jan 24, 2024
5a904e0
Deprecate tight box rotation support for COCO dataset
BloodAxe Jan 24, 2024
24a74b3
Merge branch 'master' into feature/SG-1333-remove-tight-box-rotation
BloodAxe Jan 24, 2024
85f66ab
Remove pycocotools from detection
BloodAxe Jan 25, 2024
4328979
Remove pycocotools from pose estimation
BloodAxe Jan 25, 2024
0fba00a
Update COCOKeypointsDataset to use our own parser
BloodAxe Jan 25, 2024
ef0c50a
Merge branch 'master' into feature/SG-1333-remove-pycocotools
BloodAxe Jan 25, 2024
430ceb3
Add missing installation of dev-dependencies
BloodAxe Jan 25, 2024
1e43db6
Merge remote-tracking branch 'origin/feature/SG-1333-remove-pycocotoo…
BloodAxe Jan 25, 2024
dad98a2
Remove methods that are not used
BloodAxe Jan 26, 2024
897d08f
add install
BloodAxe Jan 26, 2024
f5b9694
Fix imports
BloodAxe Jan 26, 2024
b596a51
Merge branch 'master' into feature/SG-1333-remove-pycocotools
BloodAxe Jan 26, 2024
dd482fd
Fix bug in parse_coco_into_keypoints_annotations
BloodAxe Jan 26, 2024
f63bf37
Bugfix
BloodAxe Jan 26, 2024
e70b5b0
Fix dtype
BloodAxe Jan 27, 2024
2a9cce6
Merge remote-tracking branch 'origin/feature/SG-1333-remove-pycocotoo…
BloodAxe Jan 27, 2024
b55c5ef
Merge master
BloodAxe Feb 6, 2024
fa50ec9
Updated change_bbox_bounds_for_image_size to not be a BC
BloodAxe Feb 8, 2024
ab20fc1
Merge branch 'master' into feature/SG-1333-remove-pycocotools
BloodAxe Feb 9, 2024
1e4176c
Merge branch 'master' into feature/SG-1333-remove-pycocotools
BloodAxe Feb 9, 2024
3 changes: 3 additions & 0 deletions .circleci/config.yml
@@ -187,6 +187,7 @@ jobs:
. venv/bin/activate
python3 -m pip install pip==23.1.2
python3 -m pip install -r requirements.txt
python3 -m pip install -r requirements.dev.txt
- run:
name: edit package version
command: |
@@ -539,6 +540,7 @@ jobs:
source << parameters.sg_new_env_name >>/bin/activate
python3.8 -m pip install --upgrade setuptools pip wheel
python3.8 -m pip install -r requirements.txt
python3.8 -m pip install -r requirements.dev.txt
python3.8 -m pip install .
python3.8 -m pip install torch torchvision torchaudio
make sweeper_test
@@ -576,6 +578,7 @@ jobs:
source << parameters.sg_new_env_name >>/bin/activate
python3.8 -m pip install --upgrade setuptools pip wheel
python3.8 -m pip install -r requirements.txt
python3.8 -m pip install -r requirements.dev.txt
python3.8 -m pip install .
python3.8 -m pip install torch torchvision torchaudio
python3.8 -m pip install pytorch-quantization==2.1.2 --extra-index-url https://pypi.ngc.nvidia.com
1 change: 1 addition & 0 deletions requirements.dev.txt
@@ -4,3 +4,4 @@ pre-commit==2.20.0
gitpython>=3.1.0
ipykernel==6.25
nbconvert==7.8.0
pycocotools==2.0.6
1 change: 0 additions & 1 deletion requirements.txt
@@ -21,7 +21,6 @@ pillow>=5.3.0,!=8.3
pip-tools>=6.12.1
pyparsing==2.4.5
einops==0.3.2
pycocotools==2.0.6
protobuf==3.20.3
treelib==1.6.1
termcolor==1.1.0
@@ -1,16 +1,18 @@
import copy
import dataclasses
import json
import os

import numpy as np
from pycocotools.coco import COCO
from typing import List, Optional
from typing import List, Optional, Tuple

from contextlib import redirect_stdout
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
from super_gradients.common.deprecate import deprecated_parameter
from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy_inplace
from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
from super_gradients.training.utils.detection_utils import change_bbox_bounds_for_image_size

logger = get_logger(__name__)

@@ -48,6 +50,10 @@ def __init__(
:param class_ids_to_ignore: List of class ids to ignore in the dataset. By default, doesn't ignore any class.
:param tight_box_rotation: This parameter is deprecated and will be removed in SuperGradients 3.8.
"""
if tight_box_rotation is not None:
logger.warning(
"Parameter `tight_box_rotation` is deprecated and will be removed in a SuperGradients 3.8." "Please remove this parameter from your code."
)
self.images_dir = images_dir
self.json_annotation_file = json_annotation_file
self.with_crowd = with_crowd
@@ -72,36 +78,34 @@ def __init__(
)

def _setup_data_source(self) -> int:
"""Initialize img_and_target_path_list and warn if label file is missing

:return: List of tuples made of (img_path,target_path)
"""
Parse COCO annotation file
:return: Number of images in annotation JSON
"""
annotation_file_path = os.path.join(self.data_dir, self.json_annotation_file)
if not os.path.exists(annotation_file_path):
raise ValueError("Could not find annotation file under " + str(annotation_file_path))

self.coco = self._init_coco()
self.class_ids = sorted(cls_id for cls_id in self.coco.getCatIds() if cls_id not in self.class_ids_to_ignore)
self.original_classes = list([category["name"] for category in self.coco.loadCats(self.class_ids)])
all_class_names, annotations = parse_coco_into_detection_annotations(
annotation_file_path,
exclude_classes=None,
include_classes=None,
# This parameter exists solely to keep backward compatibility with the old code.
# Once we refactor the base dataset, we can remove this parameter and use only exclude_classes/include_classes
# at parsing time instead.
class_ids_to_ignore=self.class_ids_to_ignore,
image_path_prefix=os.path.join(self.data_dir, self.images_dir),
)

self.original_classes = list(all_class_names)
self.classes = copy.deepcopy(self.original_classes)
self.sample_id_to_coco_id = self.coco.getImgIds()
return len(self.sample_id_to_coco_id)
self._annotations = annotations
return len(annotations)

@property
def _all_classes(self) -> List[str]:
return self.original_classes

def _init_coco(self) -> COCO:
annotation_file_path = os.path.join(self.data_dir, self.json_annotation_file)
if not os.path.exists(annotation_file_path):
raise ValueError("Could not find annotation file under " + str(annotation_file_path))

if not self.verbose:
with redirect_stdout(open(os.devnull, "w")):
coco = COCO(annotation_file_path)
else:
coco = COCO(annotation_file_path)

remove_useless_info(coco, False)
return coco

def _load_annotation(self, sample_id: int) -> dict:
"""
Load relevant information of a specific image.
@@ -115,81 +119,167 @@ def _load_annotation(self, sample_id: int) -> dict:
:return img_path: Path to the associated image
"""

img_id = self.sample_id_to_coco_id[sample_id]

img_metadata = self.coco.loadImgs(img_id)[0]
width = img_metadata["width"]
height = img_metadata["height"]

img_annotation_ids = self.coco.getAnnIds(imgIds=[int(img_id)])
img_annotations = self.coco.loadAnns(img_annotation_ids)

cleaned_annotations = []
for annotation in img_annotations:
x1 = np.max((0, annotation["bbox"][0]))
y1 = np.max((0, annotation["bbox"][1]))
x2 = np.min((width, x1 + np.max((0, annotation["bbox"][2]))))
y2 = np.min((height, y1 + np.max((0, annotation["bbox"][3]))))
if annotation["area"] > 0 and x2 >= x1 and y2 >= y1:
annotation["clean_bbox"] = [x1, y1, x2, y2]
cleaned_annotations.append(annotation)
annotation = self._annotations[sample_id]

non_crowd_annotations = [annotation for annotation in cleaned_annotations if annotation["iscrowd"] == 0]
width = annotation.image_width
height = annotation.image_height

target = np.zeros((len(non_crowd_annotations), 5))
# Make a copy of the annotations, so that we can modify them
boxes_xyxy = change_bbox_bounds_for_image_size(annotation.ann_boxes_xyxy, img_shape=(height, width))
iscrowd = annotation.ann_is_crowd.copy()
labels = annotation.ann_labels.copy()

for ix, annotation in enumerate(non_crowd_annotations):
cls = self.class_ids.index(annotation["category_id"])
target[ix, 0:4] = annotation["clean_bbox"]
target[ix, 4] = cls

crowd_annotations = [annotation for annotation in cleaned_annotations if annotation["iscrowd"] == 1]

crowd_target = np.zeros((len(crowd_annotations), 5))
for ix, annotation in enumerate(crowd_annotations):
cls = self.class_ids.index(annotation["category_id"])
crowd_target[ix, 0:4] = annotation["clean_bbox"]
crowd_target[ix, 4] = cls
# Exclude boxes with invalid dimensions (x1 > x2 or y1 > y2)
mask = np.logical_and(boxes_xyxy[:, 2] >= boxes_xyxy[:, 0], boxes_xyxy[:, 3] >= boxes_xyxy[:, 1])
boxes_xyxy = boxes_xyxy[mask]
iscrowd = iscrowd[mask]
labels = labels[mask]

# Currently, the base class includes a feature to resize the image, so we need to resize the target as well when self.input_dim is set.
initial_img_shape = (height, width)
if self.input_dim is not None:
r = min(self.input_dim[0] / height, self.input_dim[1] / width)
target[:, :4] *= r
crowd_target[:, :4] *= r
resized_img_shape = (int(height * r), int(width * r))
scale_factor = min(self.input_dim[0] / height, self.input_dim[1] / width)
resized_img_shape = (int(height * scale_factor), int(width * scale_factor))
else:
resized_img_shape = initial_img_shape

file_name = img_metadata["file_name"] if "file_name" in img_metadata else "{:012}".format(img_id) + ".jpg"
img_path = os.path.join(self.data_dir, self.images_dir, file_name)
img_id = self.sample_id_to_coco_id[sample_id]
targets = np.concatenate([boxes_xyxy[~iscrowd], labels[~iscrowd, None]], axis=1).astype(np.float32)
crowd_targets = np.concatenate([boxes_xyxy[iscrowd], labels[iscrowd, None]], axis=1).astype(np.float32)

annotation = {
"target": target,
"crowd_target": crowd_target,
"target": targets,
"crowd_target": crowd_targets,
"initial_img_shape": initial_img_shape,
"resized_img_shape": resized_img_shape,
"img_path": img_path,
"id": np.array([img_id]),
"img_path": annotation.image_path,
}
return annotation
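
A minimal sketch of the clamp/filter/concatenate steps above, with change_bbox_bounds_for_image_size approximated by np.clip and made-up boxes, labels, and image size:

import numpy as np

height, width = 480, 640                              # made-up image size
boxes_xyxy = np.array([[-5.0, 10.0, 100.0, 200.0],    # sticks out of the image -> clamped
                       [20.0, 30.0, 220.0, 230.0],    # valid crowd box
                       [50.0, 60.0, 40.0, 80.0]])     # degenerate (x2 < x1) -> dropped
labels = np.array([3, 1, 2])
iscrowd = np.array([False, True, False])

# Clamp to image bounds (stand-in for change_bbox_bounds_for_image_size)
boxes_xyxy[:, [0, 2]] = boxes_xyxy[:, [0, 2]].clip(0, width)
boxes_xyxy[:, [1, 3]] = boxes_xyxy[:, [1, 3]].clip(0, height)

# Drop boxes whose corners are inverted after clamping
mask = (boxes_xyxy[:, 2] >= boxes_xyxy[:, 0]) & (boxes_xyxy[:, 3] >= boxes_xyxy[:, 1])
boxes_xyxy, labels, iscrowd = boxes_xyxy[mask], labels[mask], iscrowd[mask]

# Non-crowd and crowd targets as (N, 5) arrays of [x1, y1, x2, y2, class_id]
targets = np.concatenate([boxes_xyxy[~iscrowd], labels[~iscrowd, None]], axis=1).astype(np.float32)
crowd_targets = np.concatenate([boxes_xyxy[iscrowd], labels[iscrowd, None]], axis=1).astype(np.float32)

# Resized shape bookkeeping when input_dim is set
input_dim = (320, 320)
scale_factor = min(input_dim[0] / height, input_dim[1] / width)
resized_img_shape = (int(height * scale_factor), int(width * scale_factor))

print(targets.shape, crowd_targets.shape, resized_img_shape)  # (1, 5) (1, 5) (240, 320)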


def remove_useless_info(coco: COCO, use_seg_info: bool = False) -> None:
@dataclasses.dataclass
class DetectionAnnotation:
image_path: str
image_width: int
image_height: int

# Bounding boxes in XYXY format
ann_boxes_xyxy: np.ndarray
ann_is_crowd: np.ndarray
ann_labels: np.ndarray


def parse_coco_into_detection_annotations(
ann: str,
exclude_classes: Optional[List[str]] = None,
include_classes: Optional[List[str]] = None,
class_ids_to_ignore: Optional[List[int]] = None,
image_path_prefix=None,
) -> Tuple[List[str], List[DetectionAnnotation]]:
"""
Remove useless info in coco dataset. COCO object is modified inplace.
This function is mainly used for saving memory (save about 30% mem).
Load COCO detection dataset from annotation file.
:param ann: A path to the JSON annotation file in COCO format.
:param exclude_classes: List of classes to exclude from the dataset. All other classes will be included.
This parameter is mutually exclusive with include_classes and class_ids_to_ignore.

:param include_classes: List of classes to include in the dataset. All other classes will be excluded.
This parameter is mutually exclusive with exclude_classes and class_ids_to_ignore.
:param class_ids_to_ignore: List of category ids to ignore in the dataset. All other classes will be included.
This parameter was added for backward compatibility with the class_ids_to_ignore
argument of COCOFormatDetectionDataset but will be
removed in the future in favor of include_classes/exclude_classes.
This parameter is mutually exclusive with exclude_classes and include_classes.
:param image_path_prefix: A prefix to add to the image paths in the annotation file.
:return: Tuple (class_names, annotations) where class_names is a list of class names
(respecting include_classes/exclude_classes/class_ids_to_ignore) and
annotations is a list of DetectionAnnotation objects.
"""
if isinstance(coco, COCO):
dataset = coco.dataset
dataset.pop("info", None)
dataset.pop("licenses", None)
for img in dataset["images"]:
img.pop("license", None)
img.pop("coco_url", None)
img.pop("date_captured", None)
img.pop("flickr_url", None)
if "annotations" in coco.dataset and not use_seg_info:
for anno in coco.dataset["annotations"]:
anno.pop("segmentation", None)
with open(ann, "r") as f:
coco = json.load(f)

# Extract class names and class ids
category_ids = np.array([category["id"] for category in coco["categories"]], dtype=int)
category_names = np.array([category["name"] for category in coco["categories"]], dtype=str)

# Extract box annotations
ann_box_xyxy = xywh_to_xyxy_inplace(np.array([annotation["bbox"] for annotation in coco["annotations"]], dtype=np.float32), image_shape=None)

ann_category_id = np.array([annotation["category_id"] for annotation in coco["annotations"]], dtype=int)
ann_iscrowd = np.array([annotation["iscrowd"] for annotation in coco["annotations"]], dtype=bool)
ann_image_ids = np.array([annotation["image_id"] for annotation in coco["annotations"]], dtype=int)

# Extract image stuff
img_ids = np.array([img["id"] for img in coco["images"]], dtype=int)
img_paths = np.array([img["file_name"] if "file_name" in img else "{:012}".format(img["id"]) + ".jpg" for img in coco["images"]], dtype=str)
img_width_height = np.array([(img["width"], img["height"]) for img in coco["images"]], dtype=int)

# Now we can drop the annotations that belong to the excluded classes
if int(class_ids_to_ignore is not None) + int(exclude_classes is not None) + int(include_classes is not None) > 1:
raise ValueError("Only one of exclude_classes, class_ids_to_ignore or include_classes can be specified")
elif exclude_classes is not None:
if len(exclude_classes) != len(set(exclude_classes)):
raise ValueError("The excluded classes must be unique")
classes_not_in_dataset = set(exclude_classes).difference(set(category_names))
if len(classes_not_in_dataset) > 0:
raise ValueError(f"One or more of the excluded classes does not exist in the dataset: {classes_not_in_dataset}")
keep_classes_mask = np.isin(category_names, exclude_classes, invert=True)
elif class_ids_to_ignore is not None:
if len(class_ids_to_ignore) != len(set(class_ids_to_ignore)):
raise ValueError("The ignored classes must be unique")
classes_not_in_dataset = set(class_ids_to_ignore).difference(set(category_ids))
if len(classes_not_in_dataset) > 0:
raise ValueError(f"One or more of the ignored classes does not exist in the dataset: {classes_not_in_dataset}")
keep_classes_mask = np.isin(category_ids, class_ids_to_ignore, invert=True)
elif include_classes is not None:
if len(include_classes) != len(set(include_classes)):
raise ValueError("The included classes must be unique")
classes_not_in_dataset = set(include_classes).difference(set(category_names))
if len(classes_not_in_dataset) > 0:
raise ValueError(f"One or more of the included classes does not exist in the dataset: {classes_not_in_dataset}")
keep_classes_mask = np.isin(category_names, include_classes)
else:
keep_classes_mask = None

if keep_classes_mask is not None:
category_ids = category_ids[keep_classes_mask]
category_names = category_names[keep_classes_mask]

keep_anns_mask = np.isin(ann_category_id, category_ids)
ann_category_id = ann_category_id[keep_anns_mask]

# category_ids can be non-sequential and not ordered
num_categories = len(category_ids)

# Make sequential
order = np.argsort(category_ids, kind="stable")
category_ids = category_ids[order]
category_names = category_names[order]

# Remap category ids to be in range [0, num_categories)
class_label_table = np.zeros(np.max(category_ids) + 1, dtype=int) - 1
new_class_ids = np.arange(num_categories, dtype=int)
class_label_table[category_ids] = new_class_ids

# Remap category ids in annotations
ann_category_id = class_label_table[ann_category_id]
if (ann_category_id < 0).any():
raise ValueError("Some annotations have class ids that are not in the list of classes. This probably indicates a bug in the annotation file")

annotations = []

for img_id, image_path, (image_width, image_height) in zip(img_ids, img_paths, img_width_height):
mask = ann_image_ids == img_id

if image_path_prefix is not None:
image_path = os.path.join(image_path_prefix, image_path)

ann = DetectionAnnotation(
image_path=image_path,
image_width=image_width,
image_height=image_height,
ann_boxes_xyxy=ann_box_xyxy[mask],
ann_is_crowd=ann_iscrowd[mask],
ann_labels=ann_category_id[mask],
)
annotations.append(ann)

return category_names, annotations
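
To make the new flow concrete, here is a hypothetical usage sketch of the pycocotools-free parser. The import path and the file paths are assumptions for illustration (the module path is not visible in this diff view) and may differ from the actual package layout.

# Assumed module path; adjust to wherever parse_coco_into_detection_annotations lives in the package.
from super_gradients.training.datasets.detection_datasets.coco_format_detection import parse_coco_into_detection_annotations

class_names, annotations = parse_coco_into_detection_annotations(
    ann="/data/coco/annotations/instances_val2017.json",  # placeholder annotation file
    include_classes=["person", "car"],                    # keep only these categories
    image_path_prefix="/data/coco/images/val2017",        # prepended to each file_name
)

# Class names come back ordered by their original COCO category id and map to contiguous labels 0..N-1
print(list(class_names))

first = annotations[0]
print(first.image_path, first.image_width, first.image_height)
print(first.ann_boxes_xyxy.shape)  # (num_boxes, 4) in absolute XYXY pixel coordinates
print(first.ann_labels)            # contiguous class ids indexing into class_names
print(first.ann_is_crowd)          # boolean mask marking crowd annotations

Compared with the previous COCO()-based path, the dataset classes now rely only on the JSON file and NumPy, while pycocotools stays available as a dev-only dependency via requirements.dev.txt.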