Experiment API v2 (#296)
* WIP implementation

* Finish Adapter work, start test and editing experiments

* Add tests to HF

* Lot of documentation

* Mypy

* Lint

* Add test for Experiment.get_pool

* Update to autotqdm, improve example

* Improve base example

* Fix documentation

* Bump to 2.0.0
Dref360 authored Jun 10, 2024
1 parent 5970ad1 commit 85c3866
Showing 76 changed files with 211,052 additions and 273,513 deletions.
28 changes: 14 additions & 14 deletions README.md
@@ -91,7 +91,7 @@ The framework consists of four main parts, as demonstrated in the flowchart below
- ActiveLearningLoop

<p align="center">
<img src="docs/research/literature/images/Baalscheme.svg">
<img src="docs/learn/literature/images/Baalscheme.svg">
</p>

To get started, wrap your dataset in our _[**ActiveLearningDataset**](baal/active/dataset.py)_ class. This will ensure
@@ -114,19 +114,19 @@ In conclusion, your script should be similar to this:
dataset = ActiveLearningDataset(your_dataset)
dataset.label_randomly(INITIAL_POOL)  # label some data
model = MCDropoutModule(your_model)
model = ModelWrapper(model, args=TrainingArgs(...))
active_loop = ActiveLearningLoop(dataset,
                                 get_probabilities=model.predict_on_dataset,
                                 heuristic=heuristics.BALD(),
                                 iterations=20,  # Number of MC sampling.
                                 query_size=QUERY_SIZE)  # Number of item to label.
for al_step in range(N_ALSTEP):
    model.train_on_dataset(dataset)
    metrics = model.test_on_dataset(test_dataset)
    # Label the next most uncertain items.
    if not active_loop.step():
        # We're done!
        break
wrapper = ModelWrapper(model, args=TrainingArgs(...))
experiment = ActiveLearningExperiment(
    trainer=wrapper,  # BaalTransformersTrainer or ModelWrapper to train
    al_dataset=dataset,  # Active learning dataset
    eval_dataset=test_dataset,  # Evaluation dataset
    heuristic=BALD(),  # Uncertainty heuristic to use
    query_size=100,  # How many items to label per round.
    iterations=20,  # How many MC samples to draw per item.
    pool_size=None,  # Optionally limit the size of the unlabelled pool.
    criterion=None  # Stopping criterion for the experiment.
)
# The experiment will run until all items are labelled.
metrics = experiment.start()
```

For a complete experiment, see _[experiments/vgg_mcdropout_cifar10.py](experiments/vgg_mcdropout_cifar10.py)_ .
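The new `ActiveLearningExperiment.start()` call returns the per-round metrics instead of requiring a manual loop. A minimal sketch of consuming those records (the metric key names depend on your trainer and are illustrative here, not part of the API):

```python
records = experiment.start()

# Each entry merges one round's train and eval metrics
# (see `records.append({**train_metrics, **eval_metrics})` in baal/experiments/base.py).
for round_idx, round_metrics in enumerate(records):
    print(round_idx, round_metrics.get("test_loss"))  # "test_loss" is an assumed key
```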
2 changes: 1 addition & 1 deletion baal/active/active_loop.py
@@ -11,7 +11,7 @@
from . import heuristics
from .dataset import ActiveLearningDataset

log = structlog.get_logger(__name__)
log = structlog.get_logger("baal")
pjoin = os.path.join


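Several modules in this commit consolidate their loggers under the single `"baal"` name (here and in `stochastics.py`, `calibration.py`, `modelwrapper.py`, and the new `experiments` package), so the whole library can be filtered in one place. A minimal sketch, assuming your application routes structlog through the standard library (your structlog configuration may differ):

```python
import logging

import structlog

# Route structlog output through stdlib logging so that named loggers
# can be filtered with the usual logging machinery.
structlog.configure(logger_factory=structlog.stdlib.LoggerFactory())

# Silence everything below WARNING coming from Baal, in one place.
logging.getLogger("baal").setLevel(logging.WARNING)
```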
4 changes: 2 additions & 2 deletions baal/active/dataset/pytorch_dataset.py
@@ -188,12 +188,12 @@ def label(self, index: Union[list, int], value: Optional[Any] = None) -> None:
        elif self.can_label and val is None:
            raise ValueError(
                """The dataset is able to label data, but no label was provided.
                If this is a research setting, please set the
                If this is an active learning setting, please set the
                `ActiveLearningDataset.can_label` to `False`.
                """
            )
        else:
            # Regular research usecase.
            # Regular active learning use case.
            self.labelled_map[idx] = active_step
            if val is not None:
                warnings.warn(
6 changes: 5 additions & 1 deletion baal/active/heuristics/heuristics_gpu.py
@@ -105,7 +105,11 @@ def predict_on_dataset(
        half=False,
        verbose=True,
    ):
        return super().predict_on_dataset(dataset, iterations, half, verbose).reshape([-1])
        return (
            super()
            .predict_on_dataset(dataset, iterations, half, verbose)
            .reshape([-1])  # type: ignore
        )

    def predict_on_batch(self, data, iterations=1):
        """Rank the predictions according to their uncertainties."""
2 changes: 1 addition & 1 deletion baal/active/heuristics/stochastics.py
@@ -7,7 +7,7 @@

from baal.active.heuristics import AbstractHeuristic, Sequence

log = structlog.get_logger(__name__)
log = structlog.get_logger("baal")
EPSILON = 1e-8


4 changes: 2 additions & 2 deletions baal/bayesian/consistent_dropout.py
@@ -12,7 +12,7 @@

class ConsistentDropout(_DropoutNd):
    """
    ConsistentDropout is useful when doing research.
    ConsistentDropout is useful when doing active learning.
    It guarantees that the masks are the same between batches
    during inference, while being different inside the batch.
@@ -59,7 +59,7 @@ def eval(self):

class ConsistentDropout2d(_DropoutNd):
    """
    ConsistentDropout is useful when doing research.
    ConsistentDropout is useful when doing active learning.
    It guarantees that while the masks are the same between batches,
    they are different inside the batch.
2 changes: 1 addition & 1 deletion baal/calibration/calibration.py
@@ -11,7 +11,7 @@
from baal.modelwrapper import TrainingArgs
from baal.utils.metrics import ECE, ECE_PerCLs

log = structlog.get_logger("Calibrating...")
log = structlog.get_logger("baal")


class DirichletCalibrator(object):
20 changes: 20 additions & 0 deletions baal/experiments/__init__.py
@@ -0,0 +1,20 @@
import abc
from typing import Dict, List, Union

from numpy.typing import NDArray

from baal.active.dataset.base import Dataset


class FrameworkAdapter(abc.ABC):
    def reset_weights(self):
        raise NotImplementedError

    def train(self, al_dataset: Dataset) -> Dict[str, float]:
        raise NotImplementedError

    def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
        raise NotImplementedError

    def evaluate(self, dataset: Dataset, average_predictions: int) -> Dict[str, float]:
        raise NotImplementedError
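`FrameworkAdapter` is the seam that lets `ActiveLearningExperiment` drive either framework through the same four calls. The contract is small enough that a third-party training loop could be plugged in the same way. A hypothetical sketch (the trainer object and its `fit`/`mc_predict`/`test` methods are illustrative assumptions, not part of Baal):

```python
from copy import deepcopy
from typing import Dict, List, Union

from numpy.typing import NDArray

from baal.active.dataset.base import Dataset
from baal.experiments import FrameworkAdapter


class MyTrainerAdapter(FrameworkAdapter):
    """Hypothetical adapter wiring a custom trainer into ActiveLearningExperiment."""

    def __init__(self, trainer):
        self.trainer = trainer
        # Snapshot the initial weights so every round restarts from scratch,
        # mirroring the ModelWrapper and Transformers adapters in this commit.
        self._init_weights = deepcopy(trainer.model.state_dict())

    def reset_weights(self):
        self.trainer.model.load_state_dict(self._init_weights)

    def train(self, al_dataset: Dataset) -> Dict[str, float]:
        return self.trainer.fit(al_dataset)  # assumed to return a metrics dict

    def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
        return self.trainer.mc_predict(dataset, iterations)  # assumed MC-sampling API

    def evaluate(self, dataset: Dataset, average_predictions: int) -> Dict[str, float]:
        return self.trainer.test(dataset, average_predictions)  # assumed API
```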
121 changes: 121 additions & 0 deletions baal/experiments/base.py
@@ -0,0 +1,121 @@
import itertools
from typing import Union, Optional, Any

import numpy as np
import structlog
from torch.utils.data import Subset
from tqdm.autonotebook import tqdm

from baal import ModelWrapper, ActiveLearningDataset
from baal.active.dataset.base import Dataset
from baal.active.heuristics import AbstractHeuristic
from baal.active.stopping_criteria import StoppingCriterion, LabellingBudgetStoppingCriterion
from baal.experiments import FrameworkAdapter
from baal.experiments.modelwrapper import ModelWrapperAdapter

try:
    import transformers
    from baal.transformers_trainer_wrapper import BaalTransformersTrainer
    from baal.experiments.transformers import TransformersAdapter

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    BaalTransformersTrainer = None  # type: ignore
    TransformersAdapter = None  # type: ignore
    TRANSFORMERS_AVAILABLE = False

log = structlog.get_logger("baal")


class ActiveLearningExperiment:
    """Experiment manager for Baal.

    Takes care of:
    1. Train the model on the labelled set.
    2. Evaluate the model on the evaluation set.
    3. Predict on the unlabelled examples.
    4. Label the most uncertain examples.
    5. Stop the experiment if finished.

    Args:
        trainer: BaalTransformersTrainer or ModelWrapper to train.
        al_dataset: Active learning dataset.
        eval_dataset: Evaluation dataset.
        heuristic: Uncertainty heuristic to use.
        query_size: How many items to label per round.
        iterations: How many MC samples to draw per item.
        pool_size: Optionally limit the size of the unlabelled pool.
        criterion: Stopping criterion for the experiment.
    """

    def __init__(
        self,
        trainer: Union[ModelWrapper, "BaalTransformersTrainer"],
        al_dataset: ActiveLearningDataset,
        eval_dataset: Dataset,
        heuristic: AbstractHeuristic,
        query_size: int = 100,
        iterations: int = 20,
        pool_size: Optional[int] = None,
        criterion: Optional[StoppingCriterion] = None,
    ):
        self.al_dataset = al_dataset
        self.eval_dataset = eval_dataset
        self.heuristic = heuristic
        self.query_size = query_size
        self.iterations = iterations
        self.criterion = criterion or LabellingBudgetStoppingCriterion(
            al_dataset, labelling_budget=al_dataset.n_unlabelled
        )
        self.pool_size = pool_size
        self.adapter = self._get_adapter(trainer)

    def start(self):
        records = []
        _start = len(self.al_dataset)
        if _start == 0:
            raise ValueError(
                "No item labelled in the training set."
                " Did you run `ActiveLearningDataset.label_randomly`?"
            )
        for _ in tqdm(
            itertools.count(start=0),  # Infinite counter to rely on the criterion.
            desc="Active Experiment",
            # Upper bound estimation.
            total=np.round(self.al_dataset.n_unlabelled // self.query_size),
        ):
            self.adapter.reset_weights()
            train_metrics = self.adapter.train(self.al_dataset)
            eval_metrics = self.adapter.evaluate(
                self.eval_dataset, average_predictions=self.iterations
            )
            pool = self._get_pool()
            ranks, uncertainty = self.heuristic.get_ranks(
                self.adapter.predict(pool, iterations=self.iterations)
            )
            self.al_dataset.label(ranks[: self.query_size])
            records.append({**train_metrics, **eval_metrics})
            if self.criterion.should_stop(eval_metrics, uncertainty):
                log.info("Experiment complete", num_labelled=len(self.al_dataset) - _start)
                return records

    def _get_adapter(
        self, trainer: Union[ModelWrapper, "BaalTransformersTrainer"]
    ) -> FrameworkAdapter:
        if isinstance(trainer, ModelWrapper):
            return ModelWrapperAdapter(trainer)
        elif TRANSFORMERS_AVAILABLE and isinstance(trainer, BaalTransformersTrainer):
            return TransformersAdapter(trainer)
        raise ValueError(
            f"{type(trainer)} is not a supported trainer."
            " Baal supports ModelWrapper and BaalTransformersTrainer"
        )

    def _get_pool(self):
        if self.pool_size is None:
            return self.al_dataset.pool
        pool = self.al_dataset.pool
        if len(pool) < self.pool_size:
            return pool
        indices = np.random.choice(len(pool), min(len(pool), self.pool_size), replace=False)
        return Subset(pool, indices)
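The default criterion above stops once the labelling budget (the whole unlabelled pool) is exhausted, but `start()` also feeds each round's eval metrics and uncertainties to `criterion.should_stop`. A sketch of a custom criterion under two assumptions: the `StoppingCriterion` base takes the active dataset (as `LabellingBudgetStoppingCriterion(al_dataset, ...)` suggests), and the metrics dict contains an accuracy key, whose exact name depends on your trainer:

```python
from baal.active.stopping_criteria import StoppingCriterion


class TargetAccuracyCriterion(StoppingCriterion):
    """Illustrative criterion: stop once eval accuracy reaches a target."""

    def __init__(self, active_dataset, target: float = 0.95):
        super().__init__(active_dataset)  # assumed base-class signature
        self.target = target

    def should_stop(self, metrics, uncertainty) -> bool:
        # "test_accuracy" is an assumed key; adapt it to what your trainer reports.
        return metrics.get("test_accuracy", 0.0) >= self.target
```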
28 changes: 28 additions & 0 deletions baal/experiments/modelwrapper.py
@@ -0,0 +1,28 @@
from copy import deepcopy
from typing import Dict, Union, List

from numpy.typing import NDArray

from baal import ModelWrapper
from baal.active.dataset.base import Dataset
from baal.experiments import FrameworkAdapter


class ModelWrapperAdapter(FrameworkAdapter):
    def __init__(self, wrapper: ModelWrapper):
        self.wrapper = wrapper
        self._init_weight = deepcopy(self.wrapper.state_dict())

    def reset_weights(self):
        self.wrapper.load_state_dict(self._init_weight)

    def train(self, al_dataset: Dataset) -> Dict[str, float]:
        self.wrapper.train_on_dataset(al_dataset)
        return self.wrapper.get_metrics("train")

    def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
        return self.wrapper.predict_on_dataset(dataset, iterations=iterations)

    def evaluate(self, dataset: Dataset, average_predictions: int) -> Dict[str, float]:
        self.wrapper.test_on_dataset(dataset, average_predictions=average_predictions)
        return self.wrapper.get_metrics("test")
29 changes: 29 additions & 0 deletions baal/experiments/transformers.py
@@ -0,0 +1,29 @@
from copy import deepcopy
from typing import Dict, cast, List, Union

from numpy.typing import NDArray

from baal.active.dataset.base import Dataset
from baal.experiments import FrameworkAdapter
from baal.transformers_trainer_wrapper import BaalTransformersTrainer


class TransformersAdapter(FrameworkAdapter):
    def __init__(self, wrapper: BaalTransformersTrainer):
        self.wrapper = wrapper
        self._init_weight = deepcopy(self.wrapper.model.state_dict())

    def reset_weights(self):
        self.wrapper.model.load_state_dict(self._init_weight)
        self.wrapper.lr_scheduler = None
        self.wrapper.optimizer = None

    def train(self, al_dataset: Dataset) -> Dict[str, float]:
        self.wrapper.train_dataset = al_dataset
        return self.wrapper.train().metrics  # type: ignore

    def predict(self, dataset: Dataset, iterations: int) -> Union[NDArray, List[NDArray]]:
        return self.wrapper.predict_on_dataset(dataset, iterations=iterations)

    def evaluate(self, dataset: Dataset, average_predictions: int) -> Dict[str, float]:
        return cast(Dict[str, float], self.wrapper.evaluate(dataset))
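With this adapter in place, a HuggingFace model goes through the same experiment entry point as a `ModelWrapper`. A minimal sketch (assumes `model`, `train_ds`, and `eval_ds` are an MC-Dropout-enabled model and tokenized datasets you have already prepared):

```python
from baal import ActiveLearningDataset
from baal.active.heuristics import BALD
from baal.experiments.base import ActiveLearningExperiment
from baal.transformers_trainer_wrapper import BaalTransformersTrainer

al_dataset = ActiveLearningDataset(train_ds)
al_dataset.label_randomly(100)  # seed the labelled set

trainer = BaalTransformersTrainer(model=model)

experiment = ActiveLearningExperiment(
    trainer=trainer,  # dispatched to TransformersAdapter by _get_adapter
    al_dataset=al_dataset,
    eval_dataset=eval_ds,
    heuristic=BALD(),
    query_size=50,
    iterations=20,
)
records = experiment.start()
```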
13 changes: 6 additions & 7 deletions baal/modelwrapper.py
@@ -3,15 +3,16 @@
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable, Optional
from typing import Callable, Optional, Union, List

import numpy as np
import structlog
import torch
from numpy.typing import NDArray
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm
from tqdm.autonotebook import tqdm

from baal.active.dataset.base import Dataset
from baal.metrics.mixin import MetricMixin
@@ -22,7 +23,7 @@
from baal.utils.metrics import Loss
from baal.utils.warnings import raise_warnings_cache_replicated

log = structlog.get_logger("ModelWrapper")
log = structlog.get_logger("baal")


def _stack_preds(out):
@@ -52,8 +53,7 @@ class ModelWrapper(MetricMixin):
    Args:
        model (nn.Module): The model to optimize.
        criterion (Callable): A loss function.
        replicate_in_memory (bool): Replicate in memory optional.
        args (TrainingArgs): Model arguments for training/predicting.
    """

    def __init__(self, model, args: TrainingArgs):
@@ -148,7 +148,6 @@ def train_and_test_on_datasets(
        Args:
            train_dataset (Dataset): Dataset to train on.
            test_dataset (Dataset): Dataset to evaluate on.
            optimizer (Optimizer): Optimizer to use during training.
            return_best_weights (bool): If True, will keep the best weights and return them.
            patience (Optional[int]): If provided, will use early stopping to stop after
                `patience` epochs without improvement.
@@ -236,7 +235,7 @@ def predict_on_dataset(
        iterations: int,
        half=False,
        verbose=True,
    ):
    ) -> Union[NDArray, List[NDArray]]:
        """
        Use the model to predict on a dataset `iterations` times.