#205 Add stochastic heuristics from Kirsch et al. #206

Merged · 6 commits · May 3, 2022
Changes from 5 commits
140 changes: 140 additions & 0 deletions baal/active/heuristics/stochastics.py
@@ -0,0 +1,140 @@
import types

import numpy as np
import structlog
from scipy.special import softmax
from scipy.stats import rankdata

from baal.active.heuristics import AbstractHeuristic, Sequence

log = structlog.get_logger(__name__)
EPSILON = 1e-8


class StochasticHeuristic(AbstractHeuristic):
def __init__(self, base_heuristic: AbstractHeuristic, query_size):
"""Heuristic that is stochastic to improve diversity.

        Common acquisition functions are heavily impacted by duplicates.
        With a `top-k` approach, where the most uncertain examples are
        selected, the acquisition function can select many duplicates.
        Techniques such as BADGE (Ash et al., 2019) or BatchBALD (Kirsch et al., 2019)
        are common solutions to this problem, but they are quite expensive.

        Stochastic acquisitions are cheap to compute and achieve similar performance.

References:
Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
https://arxiv.org/abs/2106.12059

Args:
base_heuristic: Heuristic to get uncertainty from before sampling.
query_size: These heuristics will return `query_size` items.
"""
# TODO handle reverse
super().__init__(reverse=False)
self._bh = base_heuristic
self.query_size = query_size

def get_ranks(self, predictions):
# Get the raw uncertainty from the base heuristic.
scores = self.get_scores(predictions)
# Create the distribution to sample from.
distributions = self._make_distribution(scores)
# Force normalization for np.random.choice
        distributions = np.clip(distributions, 0, None)
distributions /= distributions.sum()

# TODO Seed?
if (distributions > 0).sum() < self.query_size:
            log.warning("Not enough non-zero values, sampling uniformly.")
distributions = np.ones_like(distributions) / len(distributions)
return (
np.random.choice(len(distributions), self.query_size, replace=False, p=distributions),
distributions,
)

def get_scores(self, predictions):
if isinstance(predictions, types.GeneratorType):
scores = self._bh.get_uncertainties_generator(predictions)
else:
scores = self._bh.get_uncertainties(predictions)
if isinstance(scores, Sequence):
scores = np.concatenate(scores)
return scores

def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
raise NotImplementedError


class PowerSampling(StochasticHeuristic):
def __init__(self, base_heuristic: AbstractHeuristic, query_size, temperature=1.0):
"""Samples from the uncertainty distribution without modification beside
temperature scaling and normalization.

Stochastic heuristic that assumes that the uncertainty distribution
is positive and that items with near-zero uncertainty are uninformative.
Empirically worked the best in the paper.

References:
Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
https://arxiv.org/abs/2106.12059

Args:
base_heuristic: Heuristic to get uncertainty from before sampling.
query_size: These heuristics will return `query_size` items.
temperature: Value to temper the uncertainty distribution before sampling.
"""
super().__init__(base_heuristic=base_heuristic, query_size=query_size)
self.temperature = temperature

def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
scores = scores ** (1 / self.temperature)
scores = scores / scores.sum()
return scores
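    # Illustrative worked example (not part of the original change): with
    # scores [1., 2., 4.] and temperature=1.0, the sampling probabilities are
    # [1/7, 2/7, 4/7]; with temperature=0.5 the scores are squared first,
    # giving [1/21, 4/21, 16/21], which sharpens the distribution toward the
    # most uncertain items.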


class GibbsSampling(StochasticHeuristic):
def __init__(self, base_heuristic: AbstractHeuristic, query_size, temperature=1.0):
"""Samples from the uncertainty distribution after applying softmax.

References:
Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
https://arxiv.org/abs/2106.12059

Args:
base_heuristic: Heuristic to get uncertainty from before sampling.
query_size: These heuristics will return `query_size` items.
temperature: Value to temper the uncertainty distribution before sampling.
"""
super().__init__(base_heuristic=base_heuristic, query_size=query_size)
self.temperature = temperature

def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
        # Out-of-place division so the caller's array is not mutated.
        scores = scores / self.temperature
        # scores has shape [N]
scores = softmax(scores)
return scores
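    # Illustrative worked example (not part of the original change): with
    # scores [0., log(2)] and temperature=1.0, softmax gives [1/3, 2/3];
    # raising the temperature flattens the distribution toward uniform,
    # lowering it concentrates mass on the most uncertain items.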


class RankBasedSampling(StochasticHeuristic):
def __init__(self, base_heuristic: AbstractHeuristic, query_size, temperature=1.0):
"""Samples from the ranks of the uncertainty distribution.

References:
Stochastic Batch Acquisition for Deep Active Learning, Kirsch et al. (2022)
https://arxiv.org/abs/2106.12059

Args:
base_heuristic: Heuristic to get uncertainty from before sampling.
query_size: These heuristics will return `query_size` items.
temperature: Value to temper the uncertainty distribution before sampling.
"""
super().__init__(base_heuristic=base_heuristic, query_size=query_size)
self.temperature = temperature

def _make_distribution(self, scores: np.ndarray) -> np.ndarray:
rank = rankdata(-scores)
weights = rank ** (-1 / self.temperature)
normalized_weights: np.ndarray = weights / weights.sum()
return normalized_weights
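A minimal usage sketch for reviewers (illustrative, not part of the diff): it assumes MC-Dropout predictions shaped [n_samples, n_classes, n_iterations], the convention baal's heuristics consume, and uses random values as a stand-in for real model output.

import numpy as np

from baal.active.heuristics import BALD
from baal.active.heuristics.stochastics import PowerSampling

# Fake MC predictions: [n_samples, n_classes, n_iterations].
predictions = np.random.rand(1000, 10, 20)
heuristic = PowerSampling(BALD(), query_size=100, temperature=1.0)
indices, distribution = heuristic.get_ranks(predictions)
# `indices` holds the `query_size` sampled items; `distribution` is the
# normalized distribution they were drawn from.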
56 changes: 7 additions & 49 deletions tests/active/heuristic_test.py
@@ -17,64 +17,22 @@
Precomputed,
CombineHeuristics,
)
from tests.test_utils import make_fake_dist, make_3d_fake_dist, make_5d_fake_dist

N_ITERATIONS = 50
IMG_SIZE = 3
N_CLASS = 10


def chunks(l, n):
"""Yield successive n-sized chunks from l."""
for i in range(0, len(l), n):
yield l[i: i + n]


def _make_3d_fake_dist(means, stds, dims=10):
d = np.stack(
[_make_fake_dist(means, stds, dims=dims) for _ in range(N_ITERATIONS)]
) # 50 iterations
d = np.rollaxis(d, 0, 3)
# [n_sample, n_class, n_iter]
return d


def _make_5d_fake_dist(means, stds, dims=10):
d = np.stack(
[_make_3d_fake_dist(means, stds, dims=dims) for _ in range(IMG_SIZE ** 2)], -1
) # 3x3 image
b, c, i, hw = d.shape
d = np.reshape(d, [b, c, i, IMG_SIZE, IMG_SIZE])
d = np.rollaxis(d, 2, 5)
# [n_sample, n_class, H, W, iter]
return d


def _make_fake_dist(means, stds, dims=10):
"""
Create some fake discrete distributions
Args:
means: List of means
stds: List of standard deviations
dims: Dimensions of the distributions

Returns:
List of distributions
"""
n_trials = 100
distributions = []
for m, std in zip(means, stds):
dist = np.zeros([dims])
for i in range(n_trials):
dist[
np.round(np.clip(np.random.normal(m, std, 1), 0, dims - 1)).astype(int).item()
] += 1
distributions.append(dist / n_trials)
return np.array(distributions)


distribution_2d = _make_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)
distributions_3d = _make_3d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)
distributions_5d = _make_5d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)



distribution_2d = make_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)
distributions_3d = make_3d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)
distributions_5d = make_5d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=N_CLASS)


@pytest.mark.parametrize(
37 changes: 37 additions & 0 deletions tests/active/stochastic_heuristic_test.py
@@ -0,0 +1,37 @@
import numpy as np
import pytest
from scipy.stats import entropy

from baal.active.heuristics import BALD, Entropy
from baal.active.heuristics.stochastics import GibbsSampling, RankBasedSampling, PowerSampling

NUM_CLASSES = 10
NUM_ITERATIONS = 20
BATCH_SIZE = 32


@pytest.fixture
def sampled_predictions():
predictions = np.stack(
[np.histogram(np.random.rand(5), bins=np.linspace(-.5, .5, NUM_CLASSES + 1))[0] for _ in
range(BATCH_SIZE * NUM_ITERATIONS)]).reshape(
[BATCH_SIZE, NUM_ITERATIONS, NUM_CLASSES])
return np.rollaxis(predictions, -1, 1)


@pytest.mark.parametrize("stochastic_heuristic", [GibbsSampling, RankBasedSampling, PowerSampling])
@pytest.mark.parametrize("base_heuristic", [BALD, Entropy])
def test_stochastic_heuristic(stochastic_heuristic, base_heuristic, sampled_predictions):
heur_temp_1 = stochastic_heuristic(base_heuristic(), query_size=100, temperature=1.0)
heur_temp_10 = stochastic_heuristic(base_heuristic(), query_size=100, temperature=10.0)
    heur_temp_001 = stochastic_heuristic(base_heuristic(), query_size=100, temperature=0.01)

scores = heur_temp_1.get_scores(sampled_predictions)

    dist_temp_1, dist_temp_10, dist_temp_001 = (heur_temp_1._make_distribution(scores),
                                                heur_temp_10._make_distribution(scores),
                                                heur_temp_001._make_distribution(scores))

assert entropy(dist_temp_1) < entropy(dist_temp_10)
    # NOTE: this can fail in the unlikely case where dist_temp_1 already has minimal entropy.
    assert entropy(dist_temp_1) > entropy(dist_temp_001)
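Related to the `# TODO Seed?` note in `get_ranks`: selection goes through `np.random.choice`, so a test that needs deterministic draws can pin NumPy's global RNG. A minimal sketch under that assumption, reusing the fixtures above:

np.random.seed(1337)  # pin the global RNG so np.random.choice repeats
first, _ = heur_temp_1.get_ranks(sampled_predictions)
np.random.seed(1337)
second, _ = heur_temp_1.get_ranks(sampled_predictions)
assert np.array_equal(first, second)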
5 changes: 5 additions & 0 deletions tests/conftest.py
@@ -11,3 +11,8 @@ def fn(module: nn.Module, input_shape):
pred1 = module(inp).detach().cpu().numpy()
return all(np.allclose(pred1, module(inp).detach().cpu().numpy()) for _ in range(5))
return fn


@pytest.fixture
def sampled_predictions():
return np.random.randn(100, 10, 20)
47 changes: 47 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,47 @@
import numpy as np

N_ITERATIONS = 50
IMG_SIZE = 3


def make_3d_fake_dist(means, stds, dims=10):
d = np.stack(
[make_fake_dist(means, stds, dims=dims) for _ in range(N_ITERATIONS)]
) # 50 iterations
d = np.rollaxis(d, 0, 3)
# [n_sample, n_class, n_iter]
return d


def make_5d_fake_dist(means, stds, dims=10):
d = np.stack(
[make_3d_fake_dist(means, stds, dims=dims) for _ in range(IMG_SIZE ** 2)], -1
) # 3x3 image
b, c, i, hw = d.shape
d = np.reshape(d, [b, c, i, IMG_SIZE, IMG_SIZE])
d = np.rollaxis(d, 2, 5)
# [n_sample, n_class, H, W, iter]
return d


def make_fake_dist(means, stds, dims=10):
"""
Create some fake discrete distributions
Args:
means: List of means
stds: List of standard deviations
dims: Dimensions of the distributions

Returns:
List of distributions
"""
n_trials = 100
distributions = []
for m, std in zip(means, stds):
dist = np.zeros([dims])
for i in range(n_trials):
dist[
np.round(np.clip(np.random.normal(m, std, 1), 0, dims - 1)).astype(int).item()
] += 1
distributions.append(dist / n_trials)
return np.array(distributions)
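For reference, a quick illustration of the shapes these helpers produce (matching the shape comments in the code above):

from tests.test_utils import make_fake_dist, make_3d_fake_dist, make_5d_fake_dist

dist_2d = make_fake_dist([5, 6, 9], [0.1, 4, 2], dims=10)     # [3, 10]
dist_3d = make_3d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=10)  # [3, 10, 50]
dist_5d = make_5d_fake_dist([5, 6, 9], [0.1, 4, 2], dims=10)  # [3, 10, 3, 3, 50]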