Skip to content

Commit

Permalink
Expose model_selection_n_permutations parameter. (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
mlondschien authored Jan 13, 2022
1 parent 0a0eca8 commit 86daae2
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

- Upgrade `biosphere` to `0.2.1` fixing a bug in `RandomForest`.

**Other changes:**

- New parameter `model_selection_n_permutations`.

## 0.4.0 - (2022-01-11)

**New features:**
Expand Down
2 changes: 2 additions & 0 deletions changeforest-py/changeforest/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(
minimal_relative_segment_length=None,
minimal_gain_to_split=None,
model_selection_alpha=None,
model_selection_n_permutations=None,
number_of_wild_segments=None,
seeded_segments_alpha=None,
seed=None,
Expand All @@ -23,6 +24,7 @@ def __init__(
)
self.minimal_gain_to_split = _to_float(minimal_gain_to_split)
self.model_selection_alpha = _to_float(model_selection_alpha)
self.model_selection_n_permutations = _to_int(model_selection_n_permutations)
self.number_of_wild_segments = _to_int(number_of_wild_segments)
self.seeded_segments_alpha = _to_float(seeded_segments_alpha)
self.seed = _to_int(seed)
Expand Down
6 changes: 6 additions & 0 deletions changeforest-py/src/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ pub fn control_from_pyobj(py: Python, obj: Option<PyObject>) -> PyResult<Control
}
};

if let Ok(pyvalue) = obj.getattr(py, "model_selection_n_permutations") {
if let Ok(value) = pyvalue.extract::<usize>(py) {
control = control.with_model_selection_n_permutations(value);
}
};

if let Ok(pyvalue) = obj.getattr(py, "number_of_wild_segments") {
if let Ok(value) = pyvalue.extract::<usize>(py) {
control = control.with_number_of_wild_segments(value);
Expand Down
1 change: 1 addition & 0 deletions changeforest-py/tests/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
("X_test", "bs", "random_forest", {"random_forest_n_trees": 100}, [5]),
("X_correlated", "bs", "random_forest", {"random_forest_max_depth": 1}, []),
("X_correlated", "bs", "random_forest", {"random_forest_max_depth": 2}, [49]),
("iris", "bs", "random_forest", {"model_selection_n_permutations": 10}, []),
],
)
def test_control_model_selection_parameters(
Expand Down
3 changes: 3 additions & 0 deletions changeforest-r/R/control.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Control = R6::R6Class(
minimal_relative_segment_length = NULL,
minimal_gain_to_split = NULL,
model_selection_alpha = NULL,
model_selection_n_permutations = NULL,
number_of_wild_segments = NULL,
seeded_segments_alpha = NULL,
seed = NULL,
Expand All @@ -20,6 +21,7 @@ Control = R6::R6Class(
minimal_relative_segment_length = NULL,
minimal_gain_to_split = NULL,
model_selection_alpha = NULL,
model_selection_n_permutations = NULL,
number_of_wild_segments = NULL,
seeded_segments_alpha = NULL,
seed = NULL,
Expand All @@ -31,6 +33,7 @@ Control = R6::R6Class(
self$minimal_relative_segment_length = minimal_relative_segment_length
self$minimal_gain_to_split = minimal_gain_to_split
self$model_selection_alpha = model_selection_alpha
self$model_selection_n_permutations = model_selection_n_permutations
self$number_of_wild_segments = number_of_wild_segments
self$seeded_segments_alpha = seeded_segments_alpha
self$seed = seed
Expand Down
8 changes: 8 additions & 0 deletions changeforest-r/src/rust/src/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ impl<'a> FromRobj<'a> for MyControl {
control = control.with_model_selection_alpha(value);
}

if let Some(value) = robj
.dollar("model_selection_n_permutations")
.unwrap()
.as_real()
{
control = control.with_model_selection_n_permutations(value as usize);
}

// as_integer does not seem to work.
if let Some(value) = robj.dollar("number_of_wild_segments").unwrap().as_real() {
control = control.with_number_of_wild_segments(value as usize);
Expand Down
3 changes: 3 additions & 0 deletions changeforest-r/tests/testthat/test-control.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,8 @@ test_that("control", {
expect_lists_equal(changeforest(X, "random_forest", "bs", Control$new(random_forest_n_trees=1))$split_points(), c())
expect_lists_equal(changeforest(X, "random_forest", "bs", Control$new(random_forest_n_trees=1))$split_points(), c())
expect_lists_equal(changeforest(X, "random_forest", "bs", Control$new(random_forest_n_trees=10))$split_points(), c(3, 5))

# model_selection_n_permutations
expect_lists_equal(changeforest(X_iris, "random_forest", "bs", Control$new(model_selection_n_permutations=10))$split_points(), c())
})

12 changes: 12 additions & 0 deletions src/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ pub struct Control {
/// Type I error (significance level) in model selection to be approximated. Relevant
/// for classifier based changepoint detection.
pub model_selection_alpha: f64,
/// Number of permutations for model selection in classifier-based change point
/// detection.
pub model_selection_n_permutations: usize,
/// Number of randomly drawn segments. Corresponds to parameter `M` in
/// https://arxiv.org/pdf/1411.0858.pdf.
pub number_of_wild_segments: usize,
Expand All @@ -35,6 +38,7 @@ impl Control {
minimal_relative_segment_length: 0.1,
minimal_gain_to_split: 0.1,
model_selection_alpha: 0.05,
model_selection_n_permutations: 99,
number_of_wild_segments: 100,
seeded_segments_alpha: std::f64::consts::FRAC_1_SQRT_2, // 1 / sqrt(2)
seed: 0,
Expand Down Expand Up @@ -75,6 +79,14 @@ impl Control {
self
}

pub fn with_model_selection_n_permutations(
mut self,
model_selection_n_permutations: usize,
) -> Self {
self.model_selection_n_permutations = model_selection_n_permutations;
self
}

pub fn with_number_of_wild_segments(mut self, number_of_wild_segments: usize) -> Self {
self.number_of_wild_segments = number_of_wild_segments;
self
Expand Down
5 changes: 2 additions & 3 deletions src/gain/classifier_gain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ where
/// maximal gain to compute a p-value.
fn model_selection(&self, optimizer_result: &OptimizerResult) -> ModelSelectionResult {
let mut rng = StdRng::seed_from_u64(self.control().seed);
let n_permutations = 99;

let mut max_gain = -f64::INFINITY;
let mut deltas: Vec<Array1<f64>> = Vec::with_capacity(3);
Expand All @@ -69,7 +68,7 @@ where
let mut p_value: u32 = 1;
let segment_length = optimizer_result.stop - optimizer_result.start;

for _ in 0..n_permutations {
for _ in 0..self.control().model_selection_n_permutations {
let mut values = likelihood_0.clone();

// Test if for any jdx=1,2,3 the gain (likelihood_0[jdx] + cumsum(deltas[jdx]))
Expand All @@ -90,7 +89,7 @@ where

// Up to here p_value is # of permutations for which the max_gain is higher than
// the non-permuted max_gain. From this create a true p_value.
let p_value = p_value as f64 / (n_permutations + 1) as f64;
let p_value = p_value as f64 / (self.control().model_selection_n_permutations + 1) as f64;
let is_significant = p_value < self.control().model_selection_alpha;

ModelSelectionResult {
Expand Down

0 comments on commit 86daae2

Please sign in to comment.