Skip to content

Commit

Permalink
addressing comments from review
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Apr 14, 2023
1 parent 2149bb2 commit 752ed10
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 33 deletions.
4 changes: 2 additions & 2 deletions src/sparseml/yolov8/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@
help="cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on",
)
@click.option(
"--datasets-dir",
"--dataset_path",
type=str,
default=None,
help="Path to override default datasets dir.",
help="Path to override default dataset path.",
)
def main(**kwargs):

Expand Down
11 changes: 6 additions & 5 deletions src/sparseml/yolov8/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import click
from sparseml.yolov8.trainers import SparseYOLO
from sparseml.yolov8.utils import data_from_datasets_dir
from sparseml.yolov8.utils import data_from_dataset_path


logger = logging.getLogger()
Expand Down Expand Up @@ -211,14 +211,15 @@
"--copy-paste", type=float, default=0.0, help="segment copy-paste (probability)"
)
@click.option(
"--datasets-dir",
"--dataset_path",
type=str,
default=None,
help="Path to override default datasets dir.",
help="Path to override default dataset path.",
)
def main(**kwargs):
if kwargs["datasets_dir"] is not None:
kwargs["data"] = data_from_datasets_dir(kwargs["data"], kwargs["datasets_dir"])
if kwargs["dataset_path"] is not None:
kwargs["data"] = data_from_dataset_path(kwargs["data"], kwargs["dataset_path"])
del kwargs["dataset_path"]

model = SparseYOLO(kwargs["model"])
model.train(**kwargs)
Expand Down
14 changes: 7 additions & 7 deletions src/sparseml/yolov8/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from sparseml.yolov8.utils import (
check_coco128_segmentation,
create_grad_sampler,
data_from_datasets_dir,
data_from_dataset_path,
)
from sparseml.yolov8.utils.export_samples import export_sample_inputs_outputs
from sparseml.yolov8.validators import (
Expand Down Expand Up @@ -659,9 +659,9 @@ def export(self, **kwargs):
if kwargs["device"] is not None and "cpu" not in kwargs["device"]:
overrides["device"] = "cuda:" + kwargs["device"]
overrides["deterministic"] = kwargs["deterministic"]
if kwargs["datasets_dir"] is not None:
overrides["data"] = data_from_datasets_dir(
overrides["data"], kwargs["datasets_dir"]
if kwargs["dataset_path"] is not None:
overrides["data"] = data_from_dataset_path(
overrides["data"], kwargs["dataset_path"]
)

trainer = self.TrainerClass(overrides=overrides)
Expand Down Expand Up @@ -719,9 +719,9 @@ def export(self, **kwargs):
if args["export_samples"]:
trainer_config = get_cfg(cfg=DEFAULT_SPARSEML_CONFIG_PATH)

if args["datasets_dir"] is not None:
args["data"] = data_from_datasets_dir(
args["data"], args["datasets_dir"]
if args["dataset_path"] is not None:
args["data"] = data_from_dataset_path(
args["data"], args["dataset_path"]
)
trainer_config.data = args["data"]
trainer_config.imgsz = args["imgsz"]
Expand Down
26 changes: 13 additions & 13 deletions src/sparseml/yolov8/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
__all__ = [
"check_coco128_segmentation",
"create_grad_sampler",
"data_from_datasets_dir",
"data_from_dataset_path",
]


Expand Down Expand Up @@ -79,21 +79,25 @@ def create_grad_sampler(
return grad_sampler


def data_from_datasets_dir(data: str, datasets_dir: str) -> str:
def data_from_dataset_path(data: str, dataset_path: str) -> str:
"""
Given a dataset name, fetch the yaml config for the dataset
from the Ultralytics dataset repo, overwrite its 'path'
attribute (dataset root dir) to point to the `datasets_dir`
and finally save it to the `datasets_dir` directory.
attribute (dataset root dir) to point to the `dataset_path`
and finally save it to the current working directory.
This allows to create load data yaml config files that point
to the arbitrary directories on the disk.
:param data: name of the dataset (e.g. "coco.yaml")
:param datasets_dir: path to the directory where the dataset is expected to be
(and where the new yaml config will be saved)
:param dataset_path: path to the dataset directory
:return: a path to the new yaml config file
(saved in the current working directory)
"""
if os.path.basename(dataset_path) != os.path.splitext(data)[0]:
raise ValueError(
f"Dataset name (`data` argument): {data} "
f"does not match the `dataset_path` argument: {dataset_path}"
)
ultralytics_dataset_path = glob.glob(os.path.join(ROOT, "**", data), recursive=True)
if len(ultralytics_dataset_path) != 1:
raise ValueError(
Expand All @@ -103,13 +107,9 @@ def data_from_datasets_dir(data: str, datasets_dir: str) -> str:
ultralytics_dataset_path = ultralytics_dataset_path[0]
with open(ultralytics_dataset_path, "r") as f:
yaml_config = yaml.safe_load(f)
yaml_config["path"] = dataset_path

local_dataset_path = os.path.join(
datasets_dir, os.path.basename(yaml_config["path"])
)
yaml_config["path"] = local_dataset_path

yaml_save_path = os.path.join(datasets_dir, data)
yaml_save_path = os.path.join(os.getcwd(), data)

# save the new dataset yaml file
with open(yaml_save_path, "w") as outfile:
Expand Down
13 changes: 7 additions & 6 deletions src/sparseml/yolov8/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import click
from sparseml.yolov8.trainers import SparseYOLO
from sparseml.yolov8.utils import data_from_datasets_dir
from sparseml.yolov8.utils import data_from_dataset_path


@click.command(
Expand Down Expand Up @@ -71,14 +71,15 @@
)
@click.option("--plots", default=False, is_flag=True, help="show plots during training")
@click.option(
"--datasets-dir",
"--dataset_path",
type=str,
default="/home/ubuntu/damian/sparseml/funny_dir",
help="Path to override default datasets dir.",
default=None,
help="Path to override default datasets path.",
)
def main(**kwargs):
if kwargs["datasets_dir"] is not None:
kwargs["data"] = data_from_datasets_dir(kwargs["data"], kwargs["datasets_dir"])
if kwargs["dataset_path"] is not None:
kwargs["data"] = data_from_dataset_path(kwargs["data"], kwargs["dataset_path"])
del kwargs["dataset_path"]

model = SparseYOLO(kwargs["model"])
model.val(**kwargs)
Expand Down

0 comments on commit 752ed10

Please sign in to comment.