Skip to content

Commit

Permalink
Hotfix/sg 000 Fix support of arbitrary number of heads (#1431)
Browse files Browse the repository at this point in the history
* Added missing import

* Fix YoloNASDFLHead to respect the number of heads

* Revert back dtype
  • Loading branch information
BloodAxe committed Aug 30, 2023
1 parent 76d9ee7 commit 8278880
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
11 changes: 9 additions & 2 deletions src/super_gradients/module_interfaces/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule, ModelHasNoPreprocessingParamsException

__all__ = ["HasPredict", "HasPreprocessingParams", "SupportsReplaceNumClasses", "ExportableObjectDetectionModel", "AbstractObjectDetectionDecodingModule"]
__all__ = [
"HasPredict",
"HasPreprocessingParams",
"SupportsReplaceNumClasses",
"ExportableObjectDetectionModel",
"AbstractObjectDetectionDecodingModule",
"ModelHasNoPreprocessingParamsException",
]
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def _init_weights(self):

@torch.jit.ignore
def forward_train(self, feats: Tuple[Tensor, ...]):
feats = feats[: self.num_heads]
anchors, anchor_points, num_anchors_list, stride_tensor = generate_anchors_for_grid_cell(
feats, self.fpn_strides, self.grid_cell_scale, self.grid_cell_offset
)
Expand All @@ -215,7 +216,7 @@ def forward_train(self, feats: Tuple[Tensor, ...]):
return cls_score_list, reg_distri_list, anchors, anchor_points, num_anchors_list, stride_tensor

def forward_eval(self, feats: Tuple[Tensor, ...]) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, ...]]:

feats = feats[: self.num_heads]
cls_score_list, reg_distri_list, reg_dist_reduced_list = [], [], []

for i, feat in enumerate(feats):
Expand Down Expand Up @@ -290,7 +291,7 @@ def _generate_anchors(self, feats=None, dtype=None, device=None):
else:
shift_y, shift_x = torch.meshgrid(shift_y, shift_x)

anchor_point = torch.stack([shift_x, shift_y], dim=-1)
anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype=dtype)
anchor_points.append(anchor_point.reshape([-1, 2]))
stride_tensor.append(torch.full([h * w, 1], stride, dtype=dtype))
anchor_points = torch.cat(anchor_points)
Expand Down

0 comments on commit 8278880

Please sign in to comment.