diff --git a/src/sparseml/yolov8/trainers.py b/src/sparseml/yolov8/trainers.py index cada1792fac..46398bcfb94 100644 --- a/src/sparseml/yolov8/trainers.py +++ b/src/sparseml/yolov8/trainers.py @@ -754,12 +754,26 @@ def export(self, **kwargs): deployment_folder = exporter.create_deployment_folder(onnx_model_name=name) if args["export_samples"]: trainer_config = get_cfg(cfg=DEFAULT_SPARSEML_CONFIG_PATH) + # First check if the yaml exists locally + if os.path.exists(args["data"]): + trainer_config.data = args["data"] + else: + # If it does not exist, may be a uralytics provided yaml. Try + # downloading and updating path to dataset_path + # Does this case actually happen? I.e. is args["data"] ever not a + # checkpointed local yaml path? + try: + if args["dataset_path"] is not None: + args["data"] = data_from_dataset_path( + args["data"], args["dataset_path"] + ) + trainer_config.data = args["data"] + except ValueError: + LOGGER.info( + f"yaml in {args['data']} could not be found. " + "Using default config" + ) - if args["dataset_path"] is not None: - args["data"] = data_from_dataset_path( - args["data"], args["dataset_path"] - ) - trainer_config.data = args["data"] trainer_config.imgsz = args["imgsz"] trainer = DetectionTrainer(trainer_config) # inconsistency in name between