Skip to content

Commit

Permalink
Deprecate tight_box_rotation parameters in COCODetectionDataset (#1786)
Browse files Browse the repository at this point in the history
* Deprecate cache and cache_dir parameters support

* Deprecate tight box rotation support for COCO dataset

* deprecated_parameter

* Fix bug in deprecated_parameter

* Improve test but adding subTest to indicate a tested architecture and use np.testing.assert_array_almost_equal to get more detailed output if test fails
  • Loading branch information
BloodAxe committed Feb 6, 2024
1 parent 5feb788 commit d5a85fd
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 54 deletions.
98 changes: 98 additions & 0 deletions src/super_gradients/common/deprecate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import inspect
import warnings
from functools import wraps
from typing import Optional, Callable
from pkg_resources import parse_version

__all__ = ["deprecated", "deprecated_parameter", "deprecated_training_param", "deprecate_param"]


def deprecated(deprecated_since: str, removed_from: str, target: Optional[callable] = None, reason: str = ""):
"""
Expand Down Expand Up @@ -78,6 +81,101 @@ def wrapper(*args, **kwargs):
return decorator


def deprecated_parameter(parameter_name: str, deprecated_since: str, removed_from: str, target: Optional[callable] = None, reason: str = ""):
"""
Decorator to mark a parameter of a callable as deprecated.
It provides a clear and actionable warning message informing
the user about the version in which parameter was deprecated,
the version in which it will be removed, and guidance on how to replace it.
:param parameter_name: Name of the parameter
:param deprecated_since: Version number when the function was deprecated.
:param removed_from: Version number when the function will be removed.
:param target: (Optional) The new function that should be used as a replacement. If provided, it will guide the user to the updated function.
:param reason: (Optional) Additional information or reason for the deprecation.
Example usage:
If a parameter removed with no replacement alternative:
>>> @deprecated_parameter("c",deprecated_since='3.2.0', removed_from='4.0.0', reason="This argument is not used")
>>> def do_some_work(a,b,c = None):
>>> return a+b
If a parameter has new name:
>>> @deprecated_parameter("c", target="new_parameter", deprecated_since='3.2.0', removed_from='4.0.0', reason="This argument is not used")
>>> def do_some_work(a,b,target,c = None):
>>> return a+b+target
"""

def decorator(func: callable) -> callable:
argspec = inspect.getfullargspec(func)
argument_index = argspec.args.index(parameter_name)

default_value = None
sig = inspect.signature(func)
for name, param in sig.parameters.items():
if name == parameter_name:
default_value = param.default
break

@wraps(func)
def wrapper(*args, **kwargs):

# Initialize the value to the default value
value = default_value

# Try to get the actual value from the arguments
# Have to check both positional and keyword arguments
try:
value = args[argument_index]
except IndexError:
if parameter_name in kwargs:
value = kwargs[parameter_name]

if value != default_value and not wrapper._warned:
import super_gradients

is_still_supported = parse_version(super_gradients.__version__) < parse_version(removed_from)
status_msg = "is deprecated" if is_still_supported else "was deprecated and has been removed"
message = (
f"Parameter `{parameter_name}` of `{func.__module__}.{func.__name__}` {status_msg} since version `{deprecated_since}` "
f"and will be removed in version `{removed_from}`.\n"
)
if reason:
message += f"Reason: {reason}.\n"

if target is not None:
message += (
f"Please update your code:\n"
f" [-] from `{func.__name__}(..., {parameter_name}={value})`\n"
f" [+] to `{func.__name__}(..., {target}={value})`\n"
)
else:
# fmt: off
message += (
f"Please update your code:\n"
f" [-] from `{func.__name__}(..., {parameter_name}={value})`\n"
f" [+] to `{func.__name__}(...)`\n"
)
# fmt: on

if is_still_supported:
warnings.simplefilter("once", DeprecationWarning) # Required, otherwise the warning may never be displayed.
warnings.warn(message, DeprecationWarning, stacklevel=2)
wrapper._warned = True
else:
raise ImportError(message)

return func(*args, **kwargs)

# Each decorated object will have its own _warned state
# This state ensures that the warning will appear only once, to avoid polluting the console in case the function is called too often.
wrapper._warned = False
return wrapper

return decorator


def deprecated_training_param(deprecated_tparam_name: str, deprecated_since: str, removed_from: str, new_arg_assigner: Callable, message: str = ""):
"""
Decorator for deprecating training hyperparameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ train_dataset_params:
- DetectionTargetsFormatTransform:
input_dim: ${dataset_params.train_dataset_params.input_dim}
output_format: LABEL_CXCYWH
tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: False
Expand Down Expand Up @@ -65,7 +64,6 @@ val_dataset_params:
- DetectionTargetsFormatTransform:
input_dim: ${dataset_params.val_dataset_params.input_dim}
output_format: LABEL_CXCYWH
tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ train_dataset_params:
- DetectionTargetsFormatTransform:
output_format: LABEL_CXCYWH

tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: False
Expand Down Expand Up @@ -78,7 +77,6 @@ val_dataset_params:
std: [ 58.395, 57.12, 57.375 ]
- DetectionTargetsFormatTransform:
output_format: LABEL_CXCYWH
tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ train_dataset_params:
input_dim: ${dataset_params.train_dataset_params.input_dim}
output_format: LABEL_NORMALIZED_CXCYWH

tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: False
Expand Down Expand Up @@ -61,7 +60,6 @@ val_dataset_params:
- DetectionTargetsFormatTransform:
input_dim: ${dataset_params.val_dataset_params.input_dim}
output_format: LABEL_NORMALIZED_CXCYWH
tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ train_dataset_params:
- DetectionTargetsFormatTransform:
output_format: LABEL_CXCYWH

tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: False
Expand Down Expand Up @@ -153,7 +152,6 @@ val_dataset_params:
- DetectionTargetsFormatTransform:
input_dim: [640, 640]
output_format: LABEL_CXCYWH
tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ train_dataset_params:
- DetectionTargetsFormatTransform:
input_dim: ${dataset_params.train_dataset_params.input_dim}
output_format: LABEL_CXCYWH
tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: False
Expand Down Expand Up @@ -77,7 +76,6 @@ val_dataset_params:
- DetectionTargetsFormatTransform:
input_dim: ${dataset_params.val_dataset_params.input_dim}
output_format: LABEL_CXCYWH
tight_box_rotation: False
class_inclusion_list:
max_num_samples:
with_crowd: True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def __init__(
"""
:param json_file: Name of the coco json file, that resides in data_dir/annotations/json_file.
:param subdir: Sub directory of data_dir containing the data.
:param tight_box_rotation: bool, whether to use of segmentation maps convex hull as target_seg
(check get_sample docs).
:param with_crowd: Add the crowd groundtruths to __getitem__
kwargs:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import copy
import os

import cv2
import numpy as np
from pycocotools.coco import COCO
from typing import List, Optional

from contextlib import redirect_stdout
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.deprecate import deprecated_parameter
from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset
from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException
from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL
Expand All @@ -23,29 +23,33 @@ class COCOFormatDetectionDataset(DetectionDataset):
Output format: (x, y, x, y, class_id)
"""

@deprecated_parameter(
"tight_box_rotation",
deprecated_since="3.7.0",
removed_from="3.8.0",
reason="Support of `tight_box_rotation` has been removed. This parameter has no effect anymore.",
)
def __init__(
self,
data_dir: str,
json_annotation_file: str,
images_dir: str,
tight_box_rotation: bool = False,
with_crowd: bool = True,
class_ids_to_ignore: Optional[List[int]] = None,
tight_box_rotation=None,
*args,
**kwargs,
):
"""
:param data_dir: Where the data is stored.
:param json_annotation_file: Name of the coco json file. Path relative to data_dir.
:param images_dir: Name of the directory that includes all the images. Path relative to data_dir.
:param tight_box_rotation: bool, whether to use of segmentation maps convex hull as target_seg
(check get_sample docs).
:param with_crowd: Add the crowd groundtruths to __getitem__
:param class_ids_to_ignore: List of class ids to ignore in the dataset. By default, doesnt ignore any class.
:param tight_box_rotation: This parameter is deprecated and will be removed in a SuperGradients 3.8.
"""
self.images_dir = images_dir
self.json_annotation_file = json_annotation_file
self.tight_box_rotation = tight_box_rotation
self.with_crowd = with_crowd
self.class_ids_to_ignore = class_ids_to_ignore or []

Expand Down Expand Up @@ -95,7 +99,7 @@ def _init_coco(self) -> COCO:
else:
coco = COCO(annotation_file_path)

remove_useless_info(coco, self.tight_box_rotation)
remove_useless_info(coco, False)
return coco

def _load_annotation(self, sample_id: int) -> dict:
Expand Down Expand Up @@ -133,21 +137,11 @@ def _load_annotation(self, sample_id: int) -> dict:
non_crowd_annotations = [annotation for annotation in cleaned_annotations if annotation["iscrowd"] == 0]

target = np.zeros((len(non_crowd_annotations), 5))
num_seg_values = 98 if self.tight_box_rotation else 0
target_segmentation = np.ones((len(non_crowd_annotations), num_seg_values))
target_segmentation.fill(np.nan)

for ix, annotation in enumerate(non_crowd_annotations):
cls = self.class_ids.index(annotation["category_id"])
target[ix, 0:4] = annotation["clean_bbox"]
target[ix, 4] = cls
if self.tight_box_rotation:
seg_points = [j for i in annotation.get("segmentation", []) for j in i]
if seg_points:
seg_points_c = np.array(seg_points).reshape((-1, 2)).astype(np.int32)
seg_points_convex = cv2.convexHull(seg_points_c).ravel()
else:
seg_points_convex = []
target_segmentation[ix, : len(seg_points_convex)] = seg_points_convex

crowd_annotations = [annotation for annotation in cleaned_annotations if annotation["iscrowd"] == 1]

Expand All @@ -163,7 +157,6 @@ def _load_annotation(self, sample_id: int) -> dict:
r = min(self.input_dim[0] / height, self.input_dim[1] / width)
target[:, :4] *= r
crowd_target[:, :4] *= r
target_segmentation *= r
resized_img_shape = (int(height * r), int(width * r))
else:
resized_img_shape = initial_img_shape
Expand All @@ -175,7 +168,6 @@ def _load_annotation(self, sample_id: int) -> dict:
annotation = {
"target": target,
"crowd_target": crowd_target,
"target_segmentation": target_segmentation,
"initial_img_shape": initial_img_shape,
"resized_img_shape": resized_img_shape,
"img_path": img_path,
Expand Down
49 changes: 26 additions & 23 deletions tests/unit_tests/repvgg_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from super_gradients.training.utils.utils import HpmStruct
import torch
import copy
import numpy as np


class BackboneBasedModel(torch.nn.Module):
Expand Down Expand Up @@ -50,29 +51,31 @@ def test_deployment_architecture(self):
# skip custom constructors to keep all_arch_params as general as a possible
if "repvgg" not in arch_name or "custom" in arch_name:
continue
model = ARCHITECTURES[arch_name](arch_params=self.all_arch_params)
self.assertTrue(hasattr(model.stem, "branch_3x3")) # check single layer for training mode
self.assertTrue(model.build_residual_branches)

training_mode_sd = model.state_dict()
for module in training_mode_sd:
self.assertFalse("reparam" in module) # deployment block included in training mode
test_input = torch.ones((1, in_channels, image_size, image_size))
model.eval()
training_mode_output = model(test_input)

model.prep_model_for_conversion()
self.assertTrue(hasattr(model.stem, "rbr_reparam")) # check single layer for training mode
self.assertFalse(model.build_residual_branches)

deployment_mode_sd = model.state_dict()
for module in deployment_mode_sd:
self.assertFalse("running_mean" in module) # BN were not fused
self.assertFalse("branch" in module) # branches were not joined

deployment_mode_output = model(test_input)
# difference is of very low magnitude
self.assertFalse(False in torch.isclose(training_mode_output, deployment_mode_output, atol=1e-4))

with self.subTest(arch_name=arch_name):
model = ARCHITECTURES[arch_name](arch_params=self.all_arch_params)
self.assertTrue(hasattr(model.stem, "branch_3x3")) # check single layer for training mode
self.assertTrue(model.build_residual_branches)

training_mode_sd = model.state_dict()
for module in training_mode_sd:
self.assertFalse("reparam" in module) # deployment block included in training mode
test_input = torch.ones((1, in_channels, image_size, image_size))
model.eval()
training_mode_output = model(test_input)

model.prep_model_for_conversion()
self.assertTrue(hasattr(model.stem, "rbr_reparam")) # check single layer for training mode
self.assertFalse(model.build_residual_branches)

deployment_mode_sd = model.state_dict()
for module in deployment_mode_sd:
self.assertFalse("running_mean" in module) # BN were not fused
self.assertFalse("branch" in module) # branches were not joined

deployment_mode_output = model(test_input)
# difference is of very low magnitude
np.testing.assert_array_almost_equal(training_mode_output.detach().numpy(), deployment_mode_output.detach().numpy(), decimal=4)

def test_backbone_mode(self):
"""
Expand Down

0 comments on commit d5a85fd

Please sign in to comment.