Skip to content

Commit

Permalink
Merge pull request #481 from asogaard/concatdataset-to-ensembledataset
Browse files Browse the repository at this point in the history
Replace ConcatDataset with EnsembleDataset
  • Loading branch information
asogaard committed Apr 14, 2023
2 parents a8fec86 + 02a717a commit 484e72c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
12 changes: 5 additions & 7 deletions src/graphnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
Iterable,
)

from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import ConcatDataset
from torch_geometric.data import Data

from graphnet.constants import GRAPHNET_ROOT_DIR
Expand Down Expand Up @@ -46,9 +44,9 @@ def from_config( # type: ignore[override]
source: Union[DatasetConfig, str],
) -> Union[
"Dataset",
ConcatDataset,
"EnsembleDataset",
Dict[str, "Dataset"],
Dict[str, ConcatDataset],
Dict[str, "EnsembleDataset"],
]:
"""Construct `Dataset` instance from `source` configuration."""
if isinstance(source, str):
Expand All @@ -75,9 +73,9 @@ def from_config( # type: ignore[override]
def concatenate(
cls,
datasets: List["Dataset"],
) -> ConcatDataset:
) -> "EnsembleDataset":
"""Concatenate multiple `Dataset`s into one instance."""
return ConcatDataset(datasets)
return EnsembleDataset(datasets)

@classmethod
def _construct_datasets_from_dict(
Expand All @@ -90,7 +88,7 @@ def _construct_datasets_from_dict(
for key, selection in selections.items():
config.selection = selection
dataset = Dataset.from_config(config)
assert isinstance(dataset, (Dataset, ConcatDataset))
assert isinstance(dataset, (Dataset, EnsembleDataset))
datasets[key] = dataset

# Reset `selections`.
Expand Down
4 changes: 2 additions & 2 deletions src/graphnet/utilities/config/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ def __init__(self, **data: Any) -> None:
(...)
}
>>> dataset.config.dump("dataset.yml")
>>> datasets: Dict[str, ConcatDataset] = Dataset.from_config(
>>> datasets: Dict[str, EnsembleDataset] = Dataset.from_config(
"dataset.yml"
)
>>> datasets
{
"train": ConcatDataset(...),
"train": EnsembleDataset(...),
(...)
}
Expand Down

0 comments on commit 484e72c

Please sign in to comment.