Skip to content

Commit

Permalink
Explicitly remove the stride keys from the checkpoint if they are pre…
Browse files Browse the repository at this point in the history
…sent which should fix the issue with DeciDet checkpoints (#1386)
  • Loading branch information
BloodAxe committed Aug 21, 2023
1 parent 67f7a4e commit 15802c5
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/super_gradients/training/models/detection_models/yolo_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import math
import warnings
from typing import Union, Type, List, Tuple, Optional
Expand Down Expand Up @@ -590,6 +591,15 @@ def forward(self, x):

def load_state_dict(self, state_dict, strict=True):
try:
keys_dropped_in_sg_320 = {
"stride",
"_head.anchors._stride",
"_head.anchors._anchors",
"_head.anchors._anchor_grid",
"_head._modules_list.14.stride",
}
state_dict = collections.OrderedDict([(k, v) for k, v in state_dict.items() if k not in keys_dropped_in_sg_320])

super().load_state_dict(state_dict, strict)
except RuntimeError as e:
raise RuntimeError(
Expand Down

0 comments on commit 15802c5

Please sign in to comment.