Skip to content

Commit

Permalink
Merge pull request #395 from asogaard/dataset-paths
Browse files Browse the repository at this point in the history
Dataset paths
  • Loading branch information
asogaard committed Jan 26, 2023
2 parents dbaa885 + a26d5d7 commit ebff360
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 3 deletions.
2 changes: 1 addition & 1 deletion configs/datasets/test_data_sqlite.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
path: /groups/icecube/asogaard/work/development/graphnet/data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db
path: $GRAPHNET/data/tests/sqlite/oscNext_genie_level7_v02/oscNext_genie_level7_v02_first_5_frames.db
pulsemaps:
- SRTInIcePulses
features:
Expand Down
2 changes: 1 addition & 1 deletion configs/datasets/training_example_data_parquet.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
path: /groups/icecube/asogaard/work/development/graphnet/data/examples/parquet/prometheus/prometheus-events.parquet
path: $GRAPHNET/data/examples/parquet/prometheus/prometheus-events.parquet
pulsemaps:
- total
features:
Expand Down
2 changes: 1 addition & 1 deletion configs/datasets/training_example_data_sqlite.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
path: /groups/icecube/asogaard/work/development/graphnet/data/examples/sqlite/prometheus/prometheus-events.db
path: $GRAPHNET/data/examples/sqlite/prometheus/prometheus-events.db
pulsemaps:
- total
features:
Expand Down
19 changes: 19 additions & 0 deletions src/graphnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch.utils.data import ConcatDataset
from torch_geometric.data import Data

from graphnet.constants import GRAPHNET_ROOT_DIR
from graphnet.utilities.config import (
Configurable,
DatasetConfig,
Expand Down Expand Up @@ -105,6 +106,21 @@ def _construct_dataset_from_list_of_strings(

return cls.concatenate(datasets)

@classmethod
def _resolve_graphnet_paths(
cls, path: Union[str, List[str]]
) -> Union[str, List[str]]:
if isinstance(path, list):
return [cast(str, cls._resolve_graphnet_paths(p)) for p in path]

assert isinstance(path, str)
return (
path.replace("$graphnet", GRAPHNET_ROOT_DIR)
.replace("$GRAPHNET", GRAPHNET_ROOT_DIR)
.replace("${graphnet}", GRAPHNET_ROOT_DIR)
.replace("${GRAPHNET}", GRAPHNET_ROOT_DIR)
)

@save_dataset_config
def __init__(
self,
Expand Down Expand Up @@ -177,6 +193,9 @@ def __init__(
assert isinstance(features, (list, tuple))
assert isinstance(truth, (list, tuple))

# Resolve reference to `$GRAPHNET` in path(s)
path = self._resolve_graphnet_paths(path)

# Member variable(s)
self._path = path
self._selection = None
Expand Down

0 comments on commit ebff360

Please sign in to comment.