def detach(
    x: Union[torch.Tensor, List, Tuple]
) -> Union[torch.Tensor, List, Tuple]:
    """
    Recursively detach every tensor in a (possibly nested) list/tuple structure.

    Tensors are replaced by the result of ``Tensor.detach()``; lists and tuples
    are rebuilt as new lists/tuples with each element detached in turn. The
    input structure itself is not modified.

    :param x: a tensor, or an arbitrarily nested list/tuple whose leaves are
        tensors
    :return: an object with the same list/tuple nesting as ``x`` in which every
        tensor has been detached
    :raises ValueError: if ``x`` (or any nested element) is not a tensor, list,
        or tuple
    """
    if isinstance(x, torch.Tensor):
        return x.detach()
    # Check against the builtin containers; the typing aliases are meant for
    # annotations, not for runtime isinstance checks.
    if isinstance(x, list):
        return [detach(e) for e in x]
    if isinstance(x, tuple):
        return tuple(detach(e) for e in x)
    # Name the offending type so failures inside nested outputs can be traced.
    raise ValueError(f"Unexpected type to detach: {type(x).__name__}")
function. - detached_preds = [p.detach() for p in preds] + detached_preds = detach(preds) self.update_metrics(detached_preds, batch) if self.args.plots and batch_i < 3: