Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix/sg 000 fix yolox replace head #1411

Merged
merged 1 commit into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ def __init__(
self.n_anchors = 1
self.grid = [torch.zeros(1)] * self.detection_layers_num # init grid

self.register_buffer("stride", torch.tensor(stride), persistent=False)
if torch.is_tensor(stride):
stride = stride.clone().detach()
else:
stride = torch.tensor(stride)

self.register_buffer("stride", stride, persistent=False)
Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved

self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
Expand Down Expand Up @@ -691,7 +696,7 @@ def replace_head(self, new_num_classes=None, new_head=None):

new_last_layer = DetectX(
num_classes=new_num_classes,
stride=self._head.anchors.stride,
stride=self.strides,
Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved
activation_func_type=activation_type,
channels=[width_mult(v) for v in (256, 512, 1024)],
depthwise=isinstance(old_detectx.cls_convs[0][0], GroupedConvBlock),
Expand Down
17 changes: 11 additions & 6 deletions tests/unit_tests/replace_head_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os
import shutil
import unittest

import torch
Expand All @@ -14,6 +12,17 @@ def setUp(self) -> None:
self.device = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu"
super_gradients.init_trainer()

def test_yolox_replace_head(self):
input = torch.randn(1, 3, 640, 640).to(self.device)
for model in [Models.YOLOX_S, Models.YOLOX_M, Models.YOLOX_L, Models.YOLOX_T]:
model = models.get(model, pretrained_weights="coco").to(self.device).eval()
num_classes = 100
model.replace_head(new_num_classes=num_classes)
outputs = model.forward(input)
self.assertEqual(outputs[0].size(4), num_classes + 5)
self.assertEqual(outputs[1].size(4), num_classes + 5)
self.assertEqual(outputs[2].size(4), num_classes + 5)

def test_ppyolo_replace_head(self):
input = torch.randn(1, 3, 640, 640).to(self.device)
for model in [Models.PP_YOLOE_S, Models.PP_YOLOE_M, Models.PP_YOLOE_L, Models.PP_YOLOE_X]:
Expand All @@ -37,10 +46,6 @@ def test_dekr_replace_head(self):
self.assertEqual(heatmap.size(1), 20 + 1)
self.assertEqual(offsets.size(1), 20 * 2)

def tearDown(self) -> None:
if os.path.exists("~/.cache/torch/hub/"):
shutil.rmtree("~/.cache/torch/hub/")

Louis-Dupont marked this conversation as resolved.
Show resolved Hide resolved

if __name__ == "__main__":
unittest.main()