Skip to content

Commit

Permalink
Fixing yolov8-seg validation (#1522)
Browse files Browse the repository at this point in the history
* Fixing yolov8-seg validation

* Moving detach to sparseml/pytorch/utils
  • Loading branch information
abhinavnmagic committed Apr 14, 2023
1 parent 674d50e commit b442abc
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
12 changes: 12 additions & 0 deletions src/sparseml/pytorch/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
"MEMORY_BOUNDED",
"memory_aware_threshold",
"download_framework_model_by_recipe_type",
"detach",
]


Expand Down Expand Up @@ -1162,3 +1163,14 @@ def download_framework_model_by_recipe_type(
framework_model = zoo_model.training.default.get_file(model_name)

return framework_model.path


def detach(x: Union[torch.Tensor, List, Tuple]):
if isinstance(x, torch.Tensor):
return x.detach()
elif isinstance(x, List):
return [detach(e) for e in x]
elif isinstance(x, Tuple):
return tuple([detach(e) for e in x])
else:
raise ValueError("Unexpected type to detach")
3 changes: 2 additions & 1 deletion src/sparseml/yolov8/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,8 @@ def _load(self, weights: str):
self.model = self.ModelClass(dict(self.ckpt["model_yaml"]))
else:
self.model = self.ModelClass(dict(self.ckpt["model"].yaml))
if "recipe" in self.ckpt:

if "recipe" in self.ckpt and self.ckpt["recipe"]:
manager = ScheduledModifierManager.from_yaml(self.ckpt["recipe"])
epoch = self.ckpt.get("epoch", -1)
if epoch < 0:
Expand Down
3 changes: 2 additions & 1 deletion src/sparseml/yolov8/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
from tqdm import tqdm

from sparseml.pytorch.utils import detach
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.yolo.engine.validator import BaseValidator
Expand Down Expand Up @@ -136,7 +137,7 @@ def __call__(self, trainer=None, model=None):

# During QAT the resulting preds are grad required, breaking
# the update metrics function.
detached_preds = [p.detach() for p in preds]
detached_preds = detach(preds)
self.update_metrics(detached_preds, batch)

if self.args.plots and batch_i < 3:
Expand Down

0 comments on commit b442abc

Please sign in to comment.