Skip to content

Commit

Permalink
Merge pull request #518 from ChrisCummins/optional-dataset-site-data
Browse files Browse the repository at this point in the history
[datasets] Make the `site_data_base` path optional.
  • Loading branch information
ChrisCummins authored Dec 15, 2021
2 parents 283941e + e8676f3 commit 362a5f5
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 13 deletions.
38 changes: 32 additions & 6 deletions compiler_gym/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
name: str,
description: str,
license: str, # pylint: disable=redefined-builtin
site_data_base: Path,
site_data_base: Optional[Path] = None,
benchmark_class=Benchmark,
references: Optional[Dict[str, str]] = None,
deprecated: Optional[str] = None,
Expand All @@ -61,8 +61,15 @@ def __init__(
:param license: The name of the dataset's license.
:param site_data_base: The base path of a directory that will be used to
store installed files.
:param site_data_base: An optional directory that can be used by the
dataset to house the "site data", i.e. persistent files on disk. The
site data directory is a subdirectory of this :code:`site_data_base`
path, which can be shared by multiple datasets. If not provided, the
:attr:`dataset.site_data_path
<compiler_gym.datasets.Dataset.site_data_path>` attribute will raise
an error. Use :attr:`dataset.has_site_data
<compiler_gym.datasets.Dataset.has_site_data>` to check if a site
data path was set.
:param benchmark_class: The class to use when instantiating benchmarks.
It must have the same constructor signature as :class:`Benchmark
Expand Down Expand Up @@ -110,8 +117,11 @@ def __init__(
self.benchmark_class = benchmark_class

# Set up the site data name.
basename = components.group("dataset_name")
self._site_data_path = Path(site_data_base).resolve() / self.protocol / basename
if site_data_base:
basename = components.group("dataset_name")
self._site_data_path = (
Path(site_data_base).resolve() / self.protocol / basename
)

def __repr__(self):
return self.name
Expand Down Expand Up @@ -212,14 +222,27 @@ def validatable(self) -> str:
"""
return self._validatable

@property
def has_site_data(self) -> bool:
"""Return whether the dataset has a site data directory.
:type: bool
"""
return hasattr(self, "_site_data_path")

@property
def site_data_path(self) -> Path:
"""The filesystem path used to store persistent dataset files.
This directory may not exist.
:type: Path
:raises ValueError: If no site data path was specified at constructor
time.
"""
if not self.has_site_data:
raise ValueError(f"Dataset has no site data path: {self.name}")
return self._site_data_path

@property
Expand All @@ -228,6 +251,9 @@ def site_data_size_in_bytes(self) -> int:
:type: int
"""
if not self.has_site_data:
return 0

if not self.site_data_path.is_dir():
return 0

Expand Down Expand Up @@ -314,7 +340,7 @@ def uninstall(self) -> None:
<compiler_gym.datasets.Dataset.install>`. The dataset can still be used
after calling this method.
"""
if self.site_data_path.is_dir():
if self.has_site_data() and self.site_data_path.is_dir():
shutil.rmtree(self.site_data_path)

def benchmarks(self) -> Iterable[Benchmark]:
Expand Down
4 changes: 1 addition & 3 deletions compiler_gym/envs/loop_tool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from compiler_gym.datasets import Benchmark, Dataset, benchmark
from compiler_gym.spaces import Reward
from compiler_gym.util.registration import register
from compiler_gym.util.runfiles_path import runfiles_path, site_data_path
from compiler_gym.util.runfiles_path import runfiles_path

LOOP_TOOL_SERVICE_BINARY: Path = runfiles_path(
"compiler_gym/envs/loop_tool/service/compiler_gym-loop_tool-service"
Expand Down Expand Up @@ -56,7 +56,6 @@ def __init__(self, *args, **kwargs):
name="benchmark://loop_tool-cuda-v0",
license="MIT",
description="loop_tool dataset",
site_data_base=site_data_path("loop_tool_dataset"),
)

def benchmark_uris(self) -> Iterable[str]:
Expand All @@ -72,7 +71,6 @@ def __init__(self, *args, **kwargs):
name="benchmark://loop_tool-cpu-v0",
license="MIT",
description="loop_tool dataset",
site_data_base=site_data_path("loop_tool_dataset"),
)

def benchmark_uris(self) -> Iterable[str]:
Expand Down
3 changes: 1 addition & 2 deletions examples/example_compiler_gym_service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from compiler_gym.datasets import Benchmark, Dataset
from compiler_gym.spaces import Reward
from compiler_gym.util.registration import register
from compiler_gym.util.runfiles_path import runfiles_path, site_data_path
from compiler_gym.util.runfiles_path import runfiles_path

EXAMPLE_CC_SERVICE_BINARY: Path = runfiles_path(
"examples/example_compiler_gym_service/service_cc/compiler_gym-example-service-cc"
Expand Down Expand Up @@ -58,7 +58,6 @@ def __init__(self, *args, **kwargs):
name="benchmark://example-v0",
license="MIT",
description="An example dataset",
site_data_base=site_data_path("example_dataset"),
)
self._benchmarks = {
"benchmark://example-v0/foo": Benchmark.from_file_contents(
Expand Down
2 changes: 0 additions & 2 deletions examples/example_compiler_gym_service/demo_without_bazel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from compiler_gym.spaces import Reward
from compiler_gym.util.logging import init_logging
from compiler_gym.util.registration import register
from compiler_gym.util.runfiles_path import site_data_path

EXAMPLE_PY_SERVICE_BINARY: Path = Path(
"example_compiler_gym_service/service_py/example_service.py"
Expand Down Expand Up @@ -65,7 +64,6 @@ def __init__(self, *args, **kwargs):
name="benchmark://example-v0",
license="MIT",
description="An example dataset",
site_data_base=site_data_path("example_dataset"),
)
self._benchmarks = {
"benchmark://example-v0/foo": Benchmark.from_file_contents(
Expand Down
25 changes: 25 additions & 0 deletions tests/datasets/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,30 @@ def test_logger_is_deprecated():
dataset.logger


def test_with_site_data():
"""Test the dataset property values."""
dataset = Dataset(
name="benchmark://test-v0",
description="A test dataset",
license="MIT",
site_data_base="test",
)
assert dataset.has_site_data


def test_without_site_data():
"""Test the dataset property values."""
dataset = Dataset(
name="benchmark://test-v0",
description="A test dataset",
license="MIT",
)
assert not dataset.has_site_data
with pytest.raises(
ValueError, match=r"^Dataset has no site data path: benchmark://test-v0$"
):
dataset.site_data_path # noqa


if __name__ == "__main__":
main()

0 comments on commit 362a5f5

Please sign in to comment.