Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

514 add channel masking #554

Merged
merged 6 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions dianna/utils/maskers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def generate_masks(input_data: np.array, number_of_masks: int, p_keep: float = 0.5):
"""Generate masks for time series data given a probability of keeping any time step unmasked.
"""Generate masks for time series data given a probability of keeping any time step or channel unmasked.

Args:
input_data: Timeseries data to be explained.
Expand All @@ -14,9 +14,39 @@ def generate_masks(input_data: np.array, number_of_masks: int, p_keep: float = 0
Returns:
Single array containing all masks where the first dimension represents the batch.
"""
series_length = input_data.shape[0]
number_of_steps_masked = _determine_number_of_steps_masked(p_keep, series_length)
if input_data.shape[-1] == 1: # univariate data
return generate_time_step_masks(input_data, number_of_masks, p_keep)

number_of_channel_masks = number_of_masks // 3
number_of_time_step_masks = number_of_channel_masks
number_of_combined_masks = number_of_masks - number_of_time_step_masks - number_of_channel_masks
Comment on lines +20 to +22
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is quite an interesting way to implement the masking for multi-channels. I can understand that by doing this we can have a very balanced masking array, which contains masking for entire channels, masking for certain time steps across channels and a mixture of them. But I have a feeling that this makes it a bit too complex.

What I have in mind are two simple ways:

  1. Simply flatten the input data and mask it by brute force
  2. Loop through all channels and treat them individually (mask each channel separately and concatenate them)

I can imagine that the current implementation makes the segmentation very tricky to code. Let's have a chat about it. It could be possible that I misunderstand something.

But thanks a lot for the effort! This also provides more insight about what we want.


time_step_masks = generate_time_step_masks(input_data, number_of_time_step_masks, p_keep)
channel_masks = generate_channel_masks(input_data, number_of_channel_masks, p_keep)
number_of_combined_masks = generate_time_step_masks(input_data, number_of_combined_masks,
p_keep) * generate_channel_masks(input_data,
number_of_combined_masks,
p_keep)

return np.concatenate([time_step_masks, channel_masks, number_of_combined_masks], axis=0)


def generate_channel_masks(input_data: np.ndarray, number_of_masks: int, p_keep: float):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the duplication part you mentioned in the top post I think, just a reminder for ourselves, that it will be solved in PR #562, I will leave a comment there as well.

"""Generate masks that mask one or multiple channels at a time."""
number_of_channels = input_data.shape[1]
number_of_channels_masked = _determine_number_masked(p_keep, number_of_channels)
masked_data_shape = [number_of_masks] + list(input_data.shape)
masks = np.ones(masked_data_shape, dtype=np.bool)
for i in range(number_of_masks):
channels_to_mask = np.random.choice(number_of_channels, number_of_channels_masked, False)
masks[i, :, channels_to_mask] = False
return masks


def generate_time_step_masks(input_data: np.ndarray, number_of_masks: int, p_keep: float):
"""Generate masks that mask one or multiple time steps at a time."""
series_length = input_data.shape[0]
number_of_steps_masked = _determine_number_masked(p_keep, series_length)
masked_data_shape = [number_of_masks] + list(input_data.shape)
masks = np.ones(masked_data_shape, dtype=np.bool)
for i in range(number_of_masks):
Expand Down Expand Up @@ -53,7 +83,7 @@ def _get_mask_value(data: np.array, mask_type: str) -> int:
raise ValueError(f'Unknown mask_type selected: {mask_type}')


def _determine_number_of_steps_masked(p_keep: float, series_length: int) -> int:
def _determine_number_masked(p_keep: float, series_length: int) -> int:
user_requested_steps = int(np.round(series_length * (1 - p_keep)))
if user_requested_steps == series_length:
warnings.warn('Warning: p_keep chosen too low. Continuing with leaving 1 time step unmasked per mask.')
Expand Down
68 changes: 66 additions & 2 deletions tests/methods/test_maskers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from dianna.utils.maskers import generate_channel_masks
from dianna.utils.maskers import generate_masks
from dianna.utils.maskers import mask_data

Expand Down Expand Up @@ -54,13 +55,76 @@ def test_mask_contains_correct_parts_are_mean_masked():


def _get_univariate_input_data() -> np.array:
    """Return a single-channel float series of shape (10, 1) holding the values 0 through 9."""
    time_steps = np.arange(10, dtype=float)
    return time_steps.reshape(10, 1)


def _get_multivariate_input_data() -> np.array:
return np.zeros((10, 6)) + np.arange(10).reshape(10, 1)
def _get_multivariate_input_data(number_of_channels: int = 6) -> np.array:
    """Get some multivariate test data.

    Args:
        number_of_channels: Number of channels (columns) in the generated data.

    Returns:
        Array of shape (20, number_of_channels): ten rows of zeros followed by ten rows of ones.
    """
    first_half = np.zeros((10, number_of_channels))
    second_half = np.ones((10, number_of_channels))
    # np.vstack is the supported spelling; the np.row_stack alias is deprecated as of NumPy 2.0.
    return np.vstack([first_half, second_half])


def _call_masking_function(input_data, number_of_masks=5, p_keep=.3, mask_type='mean'):
    """Generate masks for the input data and return the data with those masks applied."""
    return mask_data(
        input_data,
        generate_masks(input_data, number_of_masks, p_keep=p_keep),
        mask_type=mask_type,
    )


def test_channel_mask_has_correct_shape_multivariate():
    """Check that channel masks have the input's shape with a leading batch dimension."""
    input_data = _get_multivariate_input_data()
    number_of_masks = 15
    expected_shape = (number_of_masks, *input_data.shape)

    result = generate_channel_masks(input_data, number_of_masks, 0.5)

    assert result.shape == expected_shape


def test_channel_mask_has_does_not_contain_conflicting_values():
    """Check that each channel of each mask holds one single value throughout."""
    input_data = _get_multivariate_input_data()
    number_of_masks = 15

    result = generate_channel_masks(input_data, number_of_masks, 0.5)

    unexpected_results = []
    for mask_i, mask in enumerate(result):
        # Transposing the 2-D mask lets us iterate over channels (columns) directly.
        for channel_i, channel in enumerate(mask.T):
            if np.all(channel == channel[0]):
                continue
            unexpected_results.append(
                f'Mask {mask_i} contains conflicting values in channel {channel_i}. Channel: {channel}')
    assert not unexpected_results


def test_channel_mask_masks_correct_number_of_cells():
    """Tests that the fraction of True cells in a channel mask equals p_keep."""
    number_of_masks = 1
    input_data = _get_multivariate_input_data(number_of_channels=10)
    p_keep = 0.3

    result = generate_channel_masks(input_data, number_of_masks, p_keep)

    # result.size replaces np.product(result.shape); np.product was removed in NumPy 2.0.
    fraction_kept = result.sum() / result.size
    # Compare with a tolerance instead of exact float equality.
    assert np.isclose(fraction_kept, p_keep)


def test_masking_has_correct_shape_multivariate():
    """Check that generate_masks returns the input's shape with a leading batch dimension."""
    input_data = _get_multivariate_input_data()
    number_of_masks = 15
    expected_shape = (number_of_masks, *input_data.shape)

    result = generate_masks(input_data, number_of_masks, 0.5)

    assert result.shape == expected_shape


def test_masking_univariate_leaves_anything_unmasked():
    """Check that a univariate mask contains both True and False entries."""
    input_data = _get_univariate_input_data()

    result = generate_masks(input_data, 1, 0.5)

    assert result.any()
    assert (~result).any()
Comment on lines +129 to +130
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smart check!