From b442abcbf17ee571eb862e7d37074448ad0cf207 Mon Sep 17 00:00:00 2001 From: abhinavnmagic <121893843+abhinavnmagic@users.noreply.github.com> Date: Fri, 14 Apr 2023 00:07:51 -0700 Subject: [PATCH] Fixing yolov8-seg validation (#1522) * Fixing yolov8-seg validation * Moving detach to sparseml/pytorch/utils --- src/sparseml/pytorch/utils/helpers.py | 12 ++++++++++++ src/sparseml/yolov8/trainers.py | 3 ++- src/sparseml/yolov8/validators.py | 3 ++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py index 31aee325f8b..7604af5e0ac 100644 --- a/src/sparseml/pytorch/utils/helpers.py +++ b/src/sparseml/pytorch/utils/helpers.py @@ -100,6 +100,7 @@ "MEMORY_BOUNDED", "memory_aware_threshold", "download_framework_model_by_recipe_type", + "detach", ] @@ -1162,3 +1163,14 @@ def download_framework_model_by_recipe_type( framework_model = zoo_model.training.default.get_file(model_name) return framework_model.path + + +def detach(x: Union[torch.Tensor, List, Tuple]): + if isinstance(x, torch.Tensor): + return x.detach() + elif isinstance(x, List): + return [detach(e) for e in x] + elif isinstance(x, Tuple): + return tuple([detach(e) for e in x]) + else: + raise ValueError("Unexpected type to detach") diff --git a/src/sparseml/yolov8/trainers.py b/src/sparseml/yolov8/trainers.py index 6a7008a2f86..0f2e51ce511 100644 --- a/src/sparseml/yolov8/trainers.py +++ b/src/sparseml/yolov8/trainers.py @@ -593,7 +593,8 @@ def _load(self, weights: str): self.model = self.ModelClass(dict(self.ckpt["model_yaml"])) else: self.model = self.ModelClass(dict(self.ckpt["model"].yaml)) - if "recipe" in self.ckpt: + + if "recipe" in self.ckpt and self.ckpt["recipe"]: manager = ScheduledModifierManager.from_yaml(self.ckpt["recipe"]) epoch = self.ckpt.get("epoch", -1) if epoch < 0: diff --git a/src/sparseml/yolov8/validators.py b/src/sparseml/yolov8/validators.py index 27067c891c3..12f2f29a91b 100644 --- a/src/sparseml/yolov8/validators.py +++ b/src/sparseml/yolov8/validators.py @@ -17,6 +17,7 @@ import torch from tqdm import tqdm +from sparseml.pytorch.utils import detach from ultralytics.nn.autobackend import AutoBackend from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset from ultralytics.yolo.engine.validator import BaseValidator @@ -136,7 +137,7 @@ def __call__(self, trainer=None, model=None): # During QAT the resulting preds are grad required, breaking # the update metrics function. - detached_preds = [p.detach() for p in preds] + detached_preds = detach(preds) self.update_metrics(detached_preds, batch) if self.args.plots and batch_i < 3: