-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
514 add channel masking #554
Changes from all commits
dda0b97
ec17ba9
3bd7f76
3aa167e
f7b8616
289275f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
|
||
|
||
def generate_masks(input_data: np.array, number_of_masks: int, p_keep: float = 0.5): | ||
"""Generate masks for time series data given a probability of keeping any time step unmasked. | ||
"""Generate masks for time series data given a probability of keeping any time step or channel unmasked. | ||
|
||
Args: | ||
input_data: Timeseries data to be explained. | ||
|
@@ -14,9 +14,39 @@ def generate_masks(input_data: np.array, number_of_masks: int, p_keep: float = 0 | |
Returns: | ||
Single array containing all masks where the first dimension represents the batch. | ||
""" | ||
series_length = input_data.shape[0] | ||
number_of_steps_masked = _determine_number_of_steps_masked(p_keep, series_length) | ||
if input_data.shape[-1] == 1: # univariate data | ||
return generate_time_step_masks(input_data, number_of_masks, p_keep) | ||
|
||
number_of_channel_masks = number_of_masks // 3 | ||
number_of_time_step_masks = number_of_channel_masks | ||
number_of_combined_masks = number_of_masks - number_of_time_step_masks - number_of_channel_masks | ||
|
||
time_step_masks = generate_time_step_masks(input_data, number_of_time_step_masks, p_keep) | ||
channel_masks = generate_channel_masks(input_data, number_of_channel_masks, p_keep) | ||
number_of_combined_masks = generate_time_step_masks(input_data, number_of_combined_masks, | ||
p_keep) * generate_channel_masks(input_data, | ||
number_of_combined_masks, | ||
p_keep) | ||
|
||
return np.concatenate([time_step_masks, channel_masks, number_of_combined_masks], axis=0) | ||
|
||
|
||
def generate_channel_masks(input_data: np.ndarray, number_of_masks: int, p_keep: float): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the duplication part you mentioned in the top post I think, just a reminder for ourselves, that it will be solved in PR #562, I will leave a comment there as well. |
||
"""Generate masks that mask one or multiple channels at a time.""" | ||
number_of_channels = input_data.shape[1] | ||
number_of_channels_masked = _determine_number_masked(p_keep, number_of_channels) | ||
masked_data_shape = [number_of_masks] + list(input_data.shape) | ||
masks = np.ones(masked_data_shape, dtype=np.bool) | ||
for i in range(number_of_masks): | ||
channels_to_mask = np.random.choice(number_of_channels, number_of_channels_masked, False) | ||
masks[i, :, channels_to_mask] = False | ||
return masks | ||
|
||
|
||
def generate_time_step_masks(input_data: np.ndarray, number_of_masks: int, p_keep: float): | ||
"""Generate masks that mask one or multiple time steps at a time.""" | ||
series_length = input_data.shape[0] | ||
number_of_steps_masked = _determine_number_masked(p_keep, series_length) | ||
masked_data_shape = [number_of_masks] + list(input_data.shape) | ||
masks = np.ones(masked_data_shape, dtype=np.bool) | ||
for i in range(number_of_masks): | ||
|
@@ -53,7 +83,7 @@ def _get_mask_value(data: np.array, mask_type: str) -> int: | |
raise ValueError(f'Unknown mask_type selected: {mask_type}') | ||
|
||
|
||
def _determine_number_of_steps_masked(p_keep: float, series_length: int) -> int: | ||
def _determine_number_masked(p_keep: float, series_length: int) -> int: | ||
user_requested_steps = int(np.round(series_length * (1 - p_keep))) | ||
if user_requested_steps == series_length: | ||
warnings.warn('Warning: p_keep chosen too low. Continuing with leaving 1 time step unmasked per mask.') | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import numpy as np | ||
import pytest | ||
from dianna.utils.maskers import generate_channel_masks | ||
from dianna.utils.maskers import generate_masks | ||
from dianna.utils.maskers import mask_data | ||
|
||
|
@@ -54,13 +55,76 @@ def test_mask_contains_correct_parts_are_mean_masked(): | |
|
||
|
||
def _get_univariate_input_data() -> np.array: | ||
"""Get some univariate test data.""" | ||
return np.zeros((10, 1)) + np.arange(10).reshape(10, 1) | ||
|
||
|
||
def _get_multivariate_input_data() -> np.array: | ||
return np.zeros((10, 6)) + np.arange(10).reshape(10, 1) | ||
def _get_multivariate_input_data(number_of_channels: int = 6) -> np.array: | ||
"""Get some multivariate test data.""" | ||
return np.row_stack([np.zeros((10, number_of_channels)), np.ones((10, number_of_channels))]) | ||
|
||
|
||
def _call_masking_function(input_data, number_of_masks=5, p_keep=.3, mask_type='mean'): | ||
"""Helper function with some defaults to call the code under test.""" | ||
masks = generate_masks(input_data, number_of_masks, p_keep=p_keep) | ||
return mask_data(input_data, masks, mask_type=mask_type) | ||
|
||
|
||
def test_channel_mask_has_correct_shape_multivariate(): | ||
"""Tests the output has the correct shape.""" | ||
number_of_masks = 15 | ||
input_data = _get_multivariate_input_data() | ||
|
||
result = generate_channel_masks(input_data, number_of_masks, 0.5) | ||
|
||
assert result.shape == tuple([number_of_masks] + list(input_data.shape)) | ||
|
||
|
||
def test_channel_mask_has_does_not_contain_conflicting_values(): | ||
"""Tests that only complete channels are masked.""" | ||
number_of_masks = 15 | ||
input_data = _get_multivariate_input_data() | ||
|
||
result = generate_channel_masks(input_data, number_of_masks, 0.5) | ||
|
||
unexpected_results = [] | ||
for mask_i, mask in enumerate(result): | ||
for channel_i in range(mask.shape[-1]): | ||
channel = mask[:, channel_i] | ||
value = channel[0] | ||
if (not value) in channel: | ||
unexpected_results.append( | ||
f'Mask {mask_i} contains conflicting values in channel {channel_i}. Channel: {channel}') | ||
assert not unexpected_results | ||
|
||
|
||
def test_channel_mask_masks_correct_number_of_cells(): | ||
"""Tests whether the correct fraction of cells is masked.""" | ||
number_of_masks = 1 | ||
input_data = _get_multivariate_input_data(number_of_channels=10) | ||
p_keep = 0.3 | ||
|
||
result = generate_channel_masks(input_data, number_of_masks, p_keep) | ||
|
||
assert result.sum() / np.product(result.shape) == p_keep | ||
|
||
|
||
def test_masking_has_correct_shape_multivariate(): | ||
"""Test for the correct output shape for the general masking function.""" | ||
number_of_masks = 15 | ||
input_data = _get_multivariate_input_data() | ||
|
||
result = generate_masks(input_data, number_of_masks, 0.5) | ||
|
||
assert result.shape == tuple([number_of_masks] + list(input_data.shape)) | ||
|
||
|
||
def test_masking_univariate_leaves_anything_unmasked(): | ||
"""Tests that something remains unmasked and some parts are masked for the univariate case.""" | ||
number_of_masks = 1 | ||
input_data = _get_univariate_input_data() | ||
|
||
result = generate_masks(input_data, number_of_masks, 0.5) | ||
|
||
assert np.any(result) | ||
assert np.any(~result) | ||
Comment on lines
+129
to
+130
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Smart check! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is quite an interesting way to implement the masking for multi-channels. I can understand that by doing this we can have a very balanced masking array, which contains masking for entire channels, masking for certain time steps across channels and a mixture of them. But I have a feeling that this makes it a bit too complex.
What I have in mind are two simple ways:
I can imagine that the current implementation makes the segmentation very tricky to code. Let's have a chat about it. It could be possible that I misunderstand something.
But thanks a lot for the effort! This also provides more insight about what we want.