Skip to content

Commit

Permalink
[YOLOv8] Fix support for --dataset_dir argument (#1520)
Browse files Browse the repository at this point in the history
* working

* working

* addressing comments from review

* add more verbose error message

* Apply suggestions from code review

* Update helpers.py
  • Loading branch information
dbogunowicz authored and anmarques committed Apr 26, 2023
1 parent 6057002 commit b03cd91
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 33 deletions.
12 changes: 2 additions & 10 deletions src/sparseml/yolov8/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import click
from sparseml.yolov8.trainers import SparseYOLO
from ultralytics.yolo.utils import USER_CONFIG_DIR, get_settings, yaml_save


# Options generated from
Expand Down Expand Up @@ -82,18 +80,12 @@
help="cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on",
)
@click.option(
"--datasets-dir",
"--dataset-path",
type=str,
default=None,
help="Path to override default datasets dir.",
help="Path to override default dataset path.",
)
def main(**kwargs):
if kwargs["datasets_dir"] is not None:
settings = get_settings()
settings["datasets_dir"] = os.path.abspath(
os.path.expanduser(kwargs["datasets_dir"])
)
yaml_save(USER_CONFIG_DIR / "settings.yaml", settings)

model = SparseYOLO(kwargs["model"])
model.export(**kwargs)
Expand Down
16 changes: 6 additions & 10 deletions src/sparseml/yolov8/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.

import logging
import os

import click
from sparseml.yolov8.trainers import SparseYOLO
from ultralytics.yolo.utils import USER_CONFIG_DIR, get_settings, yaml_save
from sparseml.yolov8.utils import data_from_dataset_path


logger = logging.getLogger()
Expand Down Expand Up @@ -212,18 +211,15 @@
"--copy-paste", type=float, default=0.0, help="segment copy-paste (probability)"
)
@click.option(
"--datasets-dir",
"--dataset-path",
type=str,
default=None,
help="Path to override default datasets dir.",
help="Path to override default dataset path.",
)
def main(**kwargs):
if kwargs["datasets_dir"] is not None:
settings = get_settings()
settings["datasets_dir"] = os.path.abspath(
os.path.expanduser(kwargs["datasets_dir"])
)
yaml_save(USER_CONFIG_DIR / "settings.yaml", settings)
if kwargs["dataset_path"] is not None:
kwargs["data"] = data_from_dataset_path(kwargs["data"], kwargs["dataset_path"])
del kwargs["dataset_path"]

model = SparseYOLO(kwargs["model"])
model.train(**kwargs)
Expand Down
16 changes: 14 additions & 2 deletions src/sparseml/yolov8/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
from sparseml.pytorch.utils.helpers import download_framework_model_by_recipe_type
from sparseml.pytorch.utils.logger import LoggerManager, PythonLogger, WANDBLogger
from sparseml.yolov8.modules import Bottleneck, Conv
from sparseml.yolov8.utils import check_coco128_segmentation, create_grad_sampler
from sparseml.yolov8.utils import (
check_coco128_segmentation,
create_grad_sampler,
data_from_dataset_path,
)
from sparseml.yolov8.utils.export_samples import export_sample_inputs_outputs
from sparseml.yolov8.validators import (
SparseClassificationValidator,
Expand Down Expand Up @@ -655,6 +659,11 @@ def export(self, **kwargs):
if kwargs["device"] is not None and "cpu" not in kwargs["device"]:
overrides["device"] = "cuda:" + kwargs["device"]
overrides["deterministic"] = kwargs["deterministic"]
if kwargs["dataset_path"] is not None:
overrides["data"] = data_from_dataset_path(
overrides["data"], kwargs["dataset_path"]
)

trainer = self.TrainerClass(overrides=overrides)
self.model = self.model.to(trainer.device)

Expand Down Expand Up @@ -710,9 +719,12 @@ def export(self, **kwargs):
if args["export_samples"]:
trainer_config = get_cfg(cfg=DEFAULT_SPARSEML_CONFIG_PATH)

if args["dataset_path"] is not None:
args["data"] = data_from_dataset_path(
args["data"], args["dataset_path"]
)
trainer_config.data = args["data"]
trainer_config.imgsz = args["imgsz"]

trainer = DetectionTrainer(trainer_config)
# inconsistency in name between
# validation and test sets
Expand Down
43 changes: 42 additions & 1 deletion src/sparseml/yolov8/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import warnings
from argparse import Namespace
from typing import Any, Dict

import yaml

from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.data.utils import ROOT
from ultralytics.yolo.engine.model import DetectionModel
from ultralytics.yolo.engine.trainer import BaseTrainer


__all__ = ["check_coco128_segmentation", "create_grad_sampler"]
__all__ = [
"check_coco128_segmentation",
"create_grad_sampler",
"data_from_dataset_path",
]


def check_coco128_segmentation(args: Namespace) -> Namespace:
Expand Down Expand Up @@ -69,3 +77,36 @@ def create_grad_sampler(
/ train_loader.batch_size,
)
return grad_sampler


def data_from_dataset_path(data: str, dataset_path: str) -> str:
    """
    Given a dataset name, look up the matching yaml config under the
    Ultralytics ROOT directory, overwrite its 'path' attribute
    (dataset root dir) to point to `dataset_path` and save the result
    to the current working directory.
    This allows creating data yaml config files that point
    to arbitrary directories on the disk.

    :param data: name of the dataset (e.g. "coco.yaml")
    :param dataset_path: path to the dataset directory; a leading `~` is
        expanded and a relative path is resolved against the current
        working directory before it is written to the config
    :return: a path to the new yaml config file
        (saved in the current working directory)
    :raises ValueError: if the recursive search under ROOT does not
        match exactly one file named `data`
    """
    matches = glob.glob(os.path.join(ROOT, "**", data), recursive=True)
    if len(matches) != 1:
        # zero matches (unknown dataset) and several matches (ambiguous name)
        # are both reported together with the list of what was found
        raise ValueError(
            "Expected to find a single path to the "
            f"dataset yaml file: {data}, but found {matches}"
        )

    with open(matches[0], "r", encoding="utf-8") as f:
        yaml_config = yaml.safe_load(f)

    # Store a normalized absolute path so the generated config does not depend
    # on how its consumer treats `~` or relative paths; this mirrors the
    # abspath(expanduser(...)) normalization the CLI applied to user paths
    # before this helper was introduced.
    yaml_config["path"] = os.path.abspath(os.path.expanduser(dataset_path))

    yaml_save_path = os.path.join(os.getcwd(), data)

    # save the new dataset yaml file
    with open(yaml_save_path, "w", encoding="utf-8") as outfile:
        yaml.dump(yaml_config, outfile, default_flow_style=False)
    return yaml_save_path
16 changes: 6 additions & 10 deletions src/sparseml/yolov8/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import click
from sparseml.yolov8.trainers import SparseYOLO
from ultralytics.yolo.utils import USER_CONFIG_DIR, get_settings, yaml_save
from sparseml.yolov8.utils import data_from_dataset_path


@click.command(
Expand Down Expand Up @@ -72,18 +71,15 @@
)
@click.option("--plots", default=False, is_flag=True, help="show plots during training")
@click.option(
"--datasets-dir",
"--dataset-path",
type=str,
default=None,
help="Path to override default datasets dir.",
help="Path to override default datasets path.",
)
def main(**kwargs):
if kwargs["datasets_dir"] is not None:
settings = get_settings()
settings["datasets_dir"] = os.path.abspath(
os.path.expanduser(kwargs["datasets_dir"])
)
yaml_save(USER_CONFIG_DIR / "settings.yaml", settings)
if kwargs["dataset_path"] is not None:
kwargs["data"] = data_from_dataset_path(kwargs["data"], kwargs["dataset_path"])
del kwargs["dataset_path"]

model = SparseYOLO(kwargs["model"])
if hasattr(model, "overrides"):
Expand Down

0 comments on commit b03cd91

Please sign in to comment.