Skip to content

Commit

Permalink
Create dataloaders in seperate function to allow for customization
Browse files Browse the repository at this point in the history
  • Loading branch information
Niclas Doll committed Nov 24, 2022
1 parent 6a6be58 commit b4f72a7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
14 changes: 13 additions & 1 deletion active/core/strategies/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ def _sample(self, query_size:int, pool:Dataset) -> None:
# set attribute
self._indices = indices

def dataloader(self, data:Dataset, **kwargs) -> DataLoader:
""" Create the dataloader for a given dataset with some specific configuration.
Args:
data (Dataset): dataset to use
**kwargs (Any): keyword arguments passed to the dataloader
Returns:
loader (DataLoader): dataloader from given dataset and configuration
"""
return DataLoader(data, **kwargs)

def query(
self,
pool:Dataset,
Expand All @@ -122,7 +134,7 @@ def query(
# check query size
assert query_size <= len(pool), "Query size (%i) larger than pool (%i)" % (query_size, len(pool))
# create dataloader
loader = DataLoader(
loader = self.dataloader(
pool,
batch_size=batch_size,
shuffle=False,
Expand Down
16 changes: 14 additions & 2 deletions active/helpers/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ def val_dataset(self) -> ConcatDataset:
else:
return self.train_dataset

def dataloader(self, data:Dataset, **kwargs) -> DataLoader:
""" Create the dataloader for a given dataset with some specific configuration.
Args:
data (Dataset): dataset to use
**kwargs (Any): keyword arguments passed to the dataloader
Returns:
loader (DataLoader): dataloader from given dataset and configuration
"""
return DataLoader(data, **kwargs)

def _reset(self):
""" Event handler to reset the engine, i.e. re-initialize the model,
reset optimizer and scheduler states and clear the sampled datasets.
Expand Down Expand Up @@ -144,8 +156,8 @@ def step(self, samples:Dataset):
self.fire_event(ActiveLearningEvents.DATA_SAMPLING_COMPLETED)

# create dataloaders and update validation loader in trainer
train_loader = DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=True)
self.trainer.val_loader = DataLoader(self.val_dataset, batch_size=self.eval_batch_size, shuffle=False)
train_loader = self.dataloader(self.train_dataset, batch_size=self.train_batch_size, shuffle=True)
self.trainer.val_loader = self.dataloader(self.val_dataset, batch_size=self.eval_batch_size, shuffle=False)
# run training
self.trainer.run(train_loader, **self.trainer_run_kwargs)

Expand Down

0 comments on commit b4f72a7

Please sign in to comment.