Feature/sg 1033 fix yolox anchors #1369

Merged: 37 commits, merged Aug 15, 2023

Commits (37):
d5d33ae  Update readme (BloodAxe, Aug 9, 2023)
a00a498  Fix small bug in __repr__ implementation of KeypointsImageToTensor (BloodAxe, Aug 9, 2023)
beae2a4  Test (BloodAxe, Aug 9, 2023)
5e1305a  Test (BloodAxe, Aug 9, 2023)
b1a9066  Test (BloodAxe, Aug 9, 2023)
0c11a52  Test (BloodAxe, Aug 9, 2023)
f45f448  Test (BloodAxe, Aug 9, 2023)
f2694c8  Test (BloodAxe, Aug 9, 2023)
1c4bda0  Make graphsurgeon an optional (BloodAxe, Aug 9, 2023)
b42c0a0  Make graphsurgeon an optional (BloodAxe, Aug 9, 2023)
7c260fd  Properly handle imports of optional packages (BloodAxe, Aug 9, 2023)
c8c62ce  Added empty __init__.py files (BloodAxe, Aug 9, 2023)
d8450c3  Do imports of gs inside the export call (BloodAxe, Aug 9, 2023)
cd278aa  Do imports of gs inside the export call (BloodAxe, Aug 9, 2023)
67c249d  Fix DEKR's missing HasPredict interface (BloodAxe, Aug 9, 2023)
30360c6  Update notebook & example doc to reflect changes in imports & functio… (BloodAxe, Aug 9, 2023)
d0dca14  Update readme (BloodAxe, Aug 9, 2023)
dfe2da5  Put back images (BloodAxe, Aug 9, 2023)
52091e8  Remove onnx_graphsurgeon from requirements and install it on demand (BloodAxe, Aug 9, 2023)
41dd1cb  Merge branch 'master' into feature/SG-000-fix-import-of-onnx-graphsur… (BloodAxe, Aug 10, 2023)
768c6e9  Merge branch 'master' into feature/SG-000-fix-import-of-onnx-graphsur… (BloodAxe, Aug 10, 2023)
f82c863  Install onnx_graphsurgeon in CI (BloodAxe, Aug 10, 2023)
20702a6  Install onnx_graphsurgeon in CI (BloodAxe, Aug 10, 2023)
1e5a29b  Merge branch 'master' into feature/SG-000-fix-import-of-onnx-graphsur… (BloodAxe, Aug 11, 2023)
2a97253  Working prototype of YoloX fix of Anchors that can load model weights… (BloodAxe, Aug 11, 2023)
af51c59  Added more tests for detection predict() and yolox checkpoint loading (BloodAxe, Aug 14, 2023)
6c7f2d5  Merge branch 'master' into feature/SG-1033-fix-yolox-anchors (BloodAxe, Aug 14, 2023)
1a9fa79  Fix version of ONNX-GS installed in CI and installed on-demand (BloodAxe, Aug 14, 2023)
0cb8d33  Merge branch 'master' into feature/SG-1033-fix-yolox-anchors (BloodAxe, Aug 14, 2023)
cc81568  Added docs (BloodAxe, Aug 14, 2023)
a117c77  Added docs (BloodAxe, Aug 14, 2023)
78f42e7  Added docs (BloodAxe, Aug 14, 2023)
af8327e  Remove leftover (BloodAxe, Aug 14, 2023)
6e840ed  Set ignore_errors=True to trainer test and declare why (BloodAxe, Aug 11, 2023)
77637c8  Merge remote-tracking branch 'origin/feature/SG-1033-fix-yolox-anchor… (BloodAxe, Aug 15, 2023)
e588bb8  Fix bug in maybe_remove_module_prefix (BloodAxe, Aug 15, 2023)
3e9f628  Merge branch 'master' into feature/SG-1033-fix-yolox-anchors (BloodAxe, Aug 15, 2023)
Changes from all commits:
@@ -206,7 +206,14 @@ def prep_model_for_conversion(self, input_size: Union[tuple, list] = None, **kwa
with convertible substitutes and remove all auxiliary or training related parts.
:param input_size: [H,W]
"""
self.head.cache_anchors(input_size)

# There is some discrepancy in what input_size is:
# when exporting to ONNX it is passed as a 4-element tuple (B, C, H, W),
# but when called from predict() it is just (H, W).
# Taking the last two elements of the tuple handles both cases, but ultimately this should be fixed.
h, w = input_size[-2:]

self.head.cache_anchors((h, w))

for module in self.modules():
if isinstance(module, RepVGGBlock):
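The change above exists because callers disagree on the shape of `input_size`: ONNX export passes a full `(B, C, H, W)` tuple while `predict()` passes only `(H, W)`. Below is a minimal, self-contained sketch of the same normalization pattern; the helper name is illustrative and not part of this PR.

```python
# Illustrative sketch of the normalization applied above; normalize_input_size
# is a hypothetical helper, not a function introduced by this PR.
from typing import Sequence, Tuple


def normalize_input_size(input_size: Sequence[int]) -> Tuple[int, int]:
    """Return (H, W) whether input_size is (H, W) or a full (B, C, H, W) shape."""
    if len(input_size) < 2:
        raise ValueError(f"input_size must contain at least (H, W), got {input_size}")
    h, w = input_size[-2:]
    return int(h), int(w)


assert normalize_input_size((1, 3, 640, 640)) == (640, 640)  # ONNX export path
assert normalize_input_size((640, 640)) == (640, 640)        # predict() path
```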
@@ -3,6 +3,7 @@
from typing import Union, Type, List, Tuple, Optional
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn

@@ -177,7 +178,7 @@ class DetectX(nn.Module):
def __init__(
self,
num_classes: int,
stride: torch.Tensor,
stride: np.ndarray,
activation_func_type: type,
channels: list,
depthwise=False,
@@ -203,7 +204,7 @@ def __init__(
self.n_anchors = 1
self.grid = [torch.zeros(1)] * self.detection_layers_num # init grid

self.register_buffer("stride", stride)
self.register_buffer("stride", torch.tensor(stride), persistent=False)

self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
@@ -409,7 +410,7 @@ def __init__(self, arch_params):
) # 24

self._shortcuts = nn.ModuleList([CrossModelSkipConnection() for _ in range(len(self._skip_connections_dict.keys()) - 1)])
self.anchors = anchors

self.width_mult = width_mult

def forward(self, intermediate_output):
@@ -481,6 +482,7 @@ def __init__(self, backbone: Type[nn.Module], arch_params: HpmStruct, initialize
self._image_processor: Optional[Processing] = None
self._default_nms_iou: Optional[float] = None
self._default_nms_conf: Optional[float] = None
self.register_buffer("strides", torch.tensor(self.arch_params.anchors.stride), persistent=False)

@staticmethod
def get_post_prediction_callback(conf: float, iou: float) -> DetectionPostPredictionCallback:
@@ -617,8 +619,6 @@ def _check_strides(self):
if not torch.equal(m.stride, stride):
raise RuntimeError("Provided anchor strides do not match the model strides")

self.register_buffer("stride", m.stride) # USED ONLY FOR CONVERSION

def _initialize_biases(self):
"""initialize biases into DetectX()"""
detect_module = self._head._modules_list[-1] # DetectX() module
@@ -650,7 +650,7 @@ def prep_model_for_conversion(self, input_size: Union[tuple, list] = None, **kwa
assert not self.training, "model has to be in eval mode to be converted"

# Verify dummy_input from converter is of multiple of the grid size
max_stride = int(max(self.stride))
max_stride = int(max(self.strides))

# Validate the image size
image_dims = input_size[-2:] # assume torch uses channels first layout
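The key detail in the hunks above is `persistent=False`: the strides still live on the module as buffers (so they move with `.to(device)` and are visible during tracing and export), but they are no longer written to `state_dict()`, which keeps checkpoints saved before and after this change interchangeable. A small standalone sketch of that behaviour:

```python
# Standalone sketch (not code from this PR) showing why persistent=False
# keeps the stride buffer out of checkpoints while keeping it on the module.
import torch
import torch.nn as nn


class ToyHead(nn.Module):
    def __init__(self, strides):
        super().__init__()
        # A non-persistent buffer: moves with the module, usable in forward/export,
        # but excluded from state_dict().
        self.register_buffer("strides", torch.tensor(strides, dtype=torch.float32), persistent=False)


head = ToyHead([8, 16, 32])
print(head.strides)                    # tensor([ 8., 16., 32.])
print("strides" in head.state_dict())  # False: old checkpoints without this key still load
```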
1,310 changes: 1,278 additions & 32 deletions src/super_gradients/training/utils/checkpoint_utils.py

Large diffs are not rendered by default.
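Most of the PR lives in `checkpoint_utils.py`, where a custom checkpoint solver remaps old YoloX state-dict keys onto the new module layout (see also the `maybe_remove_module_prefix` fix in commit e588bb8); the full diff is not rendered here. As a rough, hypothetical illustration only, and not the actual implementation, this kind of remapping usually amounts to rewriting key prefixes before `load_state_dict`:

```python
# Rough, hypothetical illustration of state-dict key remapping. The real solver
# in checkpoint_utils.py is far larger and handles YoloX-specific layer renames.
from collections import OrderedDict
from typing import Mapping

import torch


def remap_prefix(state_dict: Mapping[str, torch.Tensor], old: str, new: str) -> "OrderedDict[str, torch.Tensor]":
    """Rename every key that starts with `old` so that it starts with `new` instead."""
    remapped = OrderedDict()
    for key, value in state_dict.items():
        remapped[new + key[len(old):] if key.startswith(old) else key] = value
    return remapped


# Typical use: strip the "module." prefix left behind by nn.DataParallel checkpoints.
# clean_sd = remap_prefix(checkpoint["net"], old="module.", new="")
# model.load_state_dict(clean_sd)
```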

18 changes: 9 additions & 9 deletions src/super_gradients/training/utils/detection_utils.py
@@ -551,7 +551,7 @@ def visualize_batch(
return out_images


class Anchors(nn.Module):
class Anchors:
"""
A wrapper function to hold the anchors used by detection models such as Yolo
"""
@@ -568,15 +568,15 @@ def __init__(self, anchors_list: List[List], strides: List[int]):
super().__init__()

self.__anchors_list = anchors_list
self.__strides = strides
self.__strides = tuple(strides)

self._check_all_lists(anchors_list)
self._check_all_len_equal_and_even(anchors_list)

self._stride = nn.Parameter(torch.Tensor(strides).float(), requires_grad=False)
anchors = torch.Tensor(anchors_list).float().view(len(anchors_list), -1, 2)
self._anchors = nn.Parameter(anchors / self._stride.view(-1, 1, 1), requires_grad=False)
self._anchor_grid = nn.Parameter(anchors.clone().view(len(anchors_list), 1, -1, 1, 1, 2), requires_grad=False)
self._stride = np.array(strides, dtype=np.float32)
anchors = np.array(anchors_list, dtype=np.float32).reshape((len(anchors_list), -1, 2))
self._anchors = anchors / self._stride.reshape((-1, 1, 1))
self._anchor_grid = anchors.copy().reshape(len(anchors_list), 1, -1, 1, 1, 2)

@staticmethod
def _check_all_lists(anchors: list) -> bool:
Expand All @@ -592,15 +592,15 @@ def _check_all_len_equal_and_even(anchors: list) -> bool:
raise RuntimeError("All objects of anchors_list must be of the same even length")

@property
def stride(self) -> nn.Parameter:
def stride(self) -> np.ndarray:
return self._stride

@property
def anchors(self) -> nn.Parameter:
def anchors(self) -> np.ndarray:
return self._anchors

@property
def anchor_grid(self) -> nn.Parameter:
def anchor_grid(self) -> np.ndarray:
return self._anchor_grid

@property
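With `Anchors` no longer an `nn.Module`, its strides and anchor grids become plain numpy arrays instead of `nn.Parameter`s, so they no longer leak into model state_dicts. A short usage sketch; the anchor values below are just common YOLO defaults used for illustration.

```python
# Usage sketch for the reworked Anchors wrapper; anchor values are illustrative only.
import numpy as np

from super_gradients.training.utils.detection_utils import Anchors

anchors = Anchors(
    anchors_list=[[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]],
    strides=[8, 16, 32],
)
assert isinstance(anchors.stride, np.ndarray)           # array([ 8., 16., 32.], dtype=float32)
assert anchors.anchors.shape == (3, 3, 2)               # anchors expressed in units of their stride
assert anchors.anchor_grid.shape == (3, 1, 3, 1, 1, 2)  # pixel-space anchors reshaped for broadcasting
```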
5 changes: 4 additions & 1 deletion tests/end_to_end_tests/trainer_test.py
@@ -39,7 +39,10 @@ def tearDownClass(cls) -> None:
for experiment_name in cls.experiment_names:
experiment_dir = get_checkpoints_dir_path(experiment_name=experiment_name)
if os.path.isdir(experiment_dir):
shutil.rmtree(experiment_dir)
# TODO: Occasionally this method fails because log files are still open (see the setup_logging() call).
# TODO: We need to find a way to close them at the end of training; this is tricky to achieve
# TODO: because setup_logging() is called outside of the Trainer class.
shutil.rmtree(experiment_dir, ignore_errors=True)

@staticmethod
def get_classification_trainer(name=""):
2 changes: 1 addition & 1 deletion tests/unit_tests/export_detection_model_test.py
@@ -547,7 +547,7 @@ def manual_test_export_export_all_variants(self):
# pass

for model_type in [
# Models.YOLOX_S don't have full support for YOLOX so it's commented out,
Models.YOLOX_S,
Models.PP_YOLOE_S,
Models.YOLO_NAS_S,
]:
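`Models.YOLOX_S` is re-enabled in the export-all-variants test above, which suggests YoloX now goes through the same export flow as PP-YOLO-E and YOLO-NAS. A hedged sketch of what that usage might look like; the `export()` call and its defaults are assumptions based on the surrounding test, not something spelled out in this diff.

```python
# Hedged sketch: exporting a pretrained YoloX model, assuming it now supports the
# same export interface exercised by export_detection_model_test.py.
from super_gradients.common.object_names import Models
from super_gradients.training import models

model = models.get(Models.YOLOX_S, pretrained_weights="coco")
model.eval()
model.export("yolox_s.onnx")  # image size / NMS settings fall back to the library defaults
```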
30 changes: 30 additions & 0 deletions tests/unit_tests/yolox_unit_test.py
@@ -1,7 +1,11 @@
import os
import tempfile
import unittest

import torch

from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.losses import YoloXDetectionLoss, YoloXFastDetectionLoss
from super_gradients.training.models.detection_models.yolox import YoloX_N, YoloX_T, YoloX_S, YoloX_M, YoloX_L, YoloX_X
from super_gradients.training.utils.detection_utils import DetectionCollateFN
@@ -69,6 +73,32 @@ def test_yolox_loss(self):
result = loss(predictions, targets.to(device))
print(result)

def test_yolo_x_checkpoint_solver(self):
"""
This test checks whether we can:
1. load old pretrained weights for YoloX that have non-matching keys (using the custom solver under the hood),
2. load a regular checkpoint (as if one had trained the model from scratch), and
3. verify that both models produce the same output.

:return:
"""
model_variant = [Models.YOLOX_S, Models.YOLOX_M, Models.YOLOX_L, Models.YOLOX_T, Models.YOLOX_N]
for model_name in model_variant:
model = models.get(model_name, pretrained_weights="coco").eval()
input = torch.randn((1, 3, 320, 320))

output1 = model(input)

sd = model.state_dict()

with tempfile.TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, f"{model_name}_coco.pth")
torch.save({"net": sd}, path)
model = models.get(model_name, num_classes=80, checkpoint_path=path).eval()
output2 = model(input)

assert torch.allclose(output1[0], output2[0], atol=1e-4)


if __name__ == "__main__":
unittest.main()