Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

478 rise for time series #506

Merged
merged 24 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
bf44f46
add WIP rise timeseries (refs #478)
cwmeijer Mar 7, 2023
a4c8e9e
Merge branch '477-time-series-maskers' into 478-rise-for-time-series
cwmeijer Mar 9, 2023
1912245
add RISE for timeseries minimal viable (refs #478)
cwmeijer Mar 9, 2023
201f57d
add rise input-output shape test (#refs 478)
cwmeijer Mar 9, 2023
43d80b2
Merge branch 'main' into 478-rise-for-time-series
cwmeijer Mar 9, 2023
52c41e2
add common usage test for rise timeseries (refs #478)
cwmeijer Mar 9, 2023
2148b74
fix bug: move default value to active place (#478)
cwmeijer Mar 9, 2023
87ea848
fix linter issues (#478)
cwmeijer Mar 9, 2023
d4b63b4
fix pylint disable comment
cwmeijer Mar 9, 2023
f87f47a
run isort
cwmeijer Mar 21, 2023
040dc03
Merge branch '478-rise-for-time-series' of github.com:dianna-ai/diann…
cwmeijer Mar 22, 2023
666092c
reverse mask value meaning (!); add option for custom mask strategy
cwmeijer Mar 22, 2023
c0b7918
Merge branch '477-time-series-maskers' into 478-rise-for-time-series
cwmeijer Mar 22, 2023
5902565
add masking strategy option to rise for timeseries
cwmeijer Mar 22, 2023
b5c140a
add rise timeseries notebook
cwmeijer Mar 22, 2023
b56c365
add binary weather timeseries model for notebook
cwmeijer Mar 28, 2023
bdda754
add rise timeseries test with expert model
cwmeijer Mar 28, 2023
44109cd
Merge branch 'main' into 478-rise-for-time-series
cwmeijer Mar 28, 2023
fc75f2a
add docstring, fix comments, sort imports
cwmeijer Mar 28, 2023
ebe8352
add timeseries test case
cwmeijer Mar 28, 2023
b4d5022
change notebook kernel
cwmeijer Mar 28, 2023
5244afc
change download link to zenodo in rise timeseries notebook
cwmeijer Mar 29, 2023
6fafa1d
Apply suggestions from code review
cwmeijer Mar 30, 2023
ee2c4a5
fix isort
Mar 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions dianna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@
__version__ = "0.7.0"


def explain_timeseries(model_or_function, timeseries_data, method, labels, **kwargs):
    """Explain timeseries data given a model and a chosen method.

    Args:
        model_or_function (callable or str): The function that runs the model to be explained _or_
                                             the path to a ONNX model on disk.
        timeseries_data (np.ndarray): Timeseries data to be explained
        method (string): One of the supported methods: RISE, LIME or KernelSHAP
        labels (Iterable(int)): Labels to be explained
        **kwargs: Keyword arguments. They are handed to the explainer lookup, and the subset selected by
                  `utils.get_kwargs_applicable_to_function` is forwarded to the explainer's `explain` method.

    Returns:
        One heatmap per class.

    """
    timeseries_explainer = _get_explainer(method, kwargs, modality="Timeseries")
    explain = timeseries_explainer.explain
    # Only forward the keyword arguments that the explain method is able to receive.
    explain_kwargs = utils.get_kwargs_applicable_to_function(explain, kwargs)
    return explain(model_or_function, timeseries_data, labels, **explain_kwargs)


def explain_image(model_or_function, input_data, method, labels, **kwargs):
"""Explain an image (input_data) given a model and a chosen method.

Expand Down
1 change: 1 addition & 0 deletions dianna/methods/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from skimage.transform import resize
from tqdm import tqdm
from dianna import utils
from dianna.methods.rise_timeseries import RISETimeseries # noqa: F401 ignore unused import
cwmeijer marked this conversation as resolved.
Show resolved Hide resolved


def normalize(saliency, n_masks, p_keep):
Expand Down
70 changes: 70 additions & 0 deletions dianna/methods/rise_timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np
from tqdm import tqdm
from dianna import utils
from dianna.utils.maskers import generate_masks
from dianna.utils.maskers import mask_data


def _make_predictions(input_data, runner, batch_size):
    """Feed input_data to the model runner one batch at a time and return all predictions as one array."""
    total_instances = input_data.shape[0]
    batch_starts = range(0, total_instances, batch_size)
    # The last batch may be shorter than batch_size; slicing past the end is safe.
    outputs = [runner(input_data[start:start + batch_size])
               for start in tqdm(batch_starts, desc='Explaining')]
    return np.concatenate(outputs)


# Duplicate code from rise.py:
def normalize(saliency, n_masks, p_keep):
    """Scale the salience down by the number of masks and then by the keep probability."""
    per_mask_saliency = saliency / n_masks
    return per_mask_saliency / p_keep


class RISETimeseries:
    """RISE implementation for timeseries adapted from the image version of RISE."""

    def __init__(self, n_masks=1000, feature_res=8, p_keep=0.5,
                 preprocess_function=None):
        """RISE initializer.

        Args:
            n_masks (int): Number of masks to generate.
            feature_res (int): Resolution of features in masks.
                NOTE(review): this value is stored on the instance, but `explain` does not pass it on
                to the mask generation - confirm whether it should.
            p_keep (float): Fraction of input data to keep in each mask (Default: 0.5).
            preprocess_function (callable, optional): Function to preprocess input data with
        """
        self.n_masks = n_masks
        self.feature_res = feature_res
        self.p_keep = p_keep
        self.preprocess_function = preprocess_function
        # Filled in by `explain` so they can be inspected afterwards.
        self.masks = None
        self.predictions = None

    @staticmethod
    def _label_index(labels):
        """Turn `labels` into an index that selects whole per-label heatmaps along the first axis."""
        if np.isscalar(labels) or isinstance(labels, (list, np.ndarray)):
            return labels
        # A tuple would be read by numpy as one multi-dimensional index instead of a selection of
        # labels, so tuples (and other iterables) are converted to a list first.
        return list(labels)

    def explain(self, model_or_function, input_timeseries, labels, batch_size=100, mask_type='mean'):
        """Runs the RISE explainer on timeseries.

        The model will be called with masked timeseries,
        with a shape defined by `batch_size` and the shape of `input_timeseries`.

        Args:
            model_or_function (callable or str): The function that runs the model to be explained _or_
                                                 the path to a ONNX model on disk.
            input_timeseries (np.ndarray): Input timeseries data to be explained
            labels (Iterable(int)): Labels to be explained
            batch_size (int): Batch size to use for running the model.
            mask_type: Masking strategy for masked values. Choose from 'mean' or a callable(input_timeseries)

        Returns:
            Explanation heatmap for each entry in `labels` (np.ndarray), each with the shape of
            `input_timeseries`.
        """
        runner = utils.get_function(model_or_function, preprocess_function=self.preprocess_function)
        self.masks = generate_masks(input_timeseries, number_of_masks=self.n_masks, p_keep=self.p_keep)
        masked = mask_data(input_timeseries, self.masks, mask_type=mask_type)

        self.predictions = _make_predictions(masked, runner, batch_size)
        n_labels = self.predictions.shape[1]

        # Weigh every mask by the prediction it produced and sum over all masks, per label.
        flat_masks = self.masks.reshape(self.n_masks, -1)
        saliency = self.predictions.T.dot(flat_masks).reshape(n_labels, *input_timeseries.shape)
        selected_saliency = saliency[self._label_index(labels)]
        return normalize(selected_saliency, self.n_masks, self.p_keep)
16 changes: 9 additions & 7 deletions dianna/utils/maskers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from typing import Union
import numpy as np


Expand All @@ -17,35 +18,36 @@ def generate_masks(input_data: np.array, number_of_masks: int, p_keep: float = 0
number_of_steps_masked = _determine_number_of_steps_masked(p_keep, series_length)

masked_data_shape = [number_of_masks] + list(input_data.shape)
masks = np.zeros(masked_data_shape, dtype=np.bool)
masks = np.ones(masked_data_shape, dtype=np.bool)
for i in range(number_of_masks):
steps_to_mask = np.random.choice(series_length, number_of_steps_masked, False)
masked_value = 1
masks[i, steps_to_mask] = masked_value
masks[i, steps_to_mask] = False
return masks


def mask_data(data, masks, mask_type='mean'):
def mask_data(data: np.array, masks: np.array, mask_type: Union[object, str]):
"""Mask data given using a set of masks.

Args:
data: ?
cwmeijer marked this conversation as resolved.
Show resolved Hide resolved
masks: an array with shape [number_of_masks] + data.shape
mask_type: ?
mask_type: Masking strategy.

Returns:
Single array containing all masked input where the first dimension represents the batch.
"""
number_of_masks = masks.shape[0]
input_data_batch = np.repeat(np.expand_dims(data, 0), number_of_masks, axis=0)
result = np.empty(input_data_batch.shape)
result[~masks] = input_data_batch[~masks]
result[masks] = _get_mask_value(data, mask_type)
result[masks] = input_data_batch[masks]
result[~masks] = _get_mask_value(data, mask_type)
return result


def _get_mask_value(data: np.array, mask_type: str) -> int:
"""Calculates a masking value of the given type for the data."""
if callable(mask_type):
return mask_type(data)
if mask_type == 'mean':
return np.mean(data)
raise ValueError(f'Unknown mask_type selected: {mask_type}')
Expand Down
37 changes: 37 additions & 0 deletions tests/methods/test_rise_timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
import dianna
from tests.methods.time_series_test_case import average_temperature_timeseries_with_1_cold_and_1_hot_day
from tests.methods.time_series_test_case import input_train_mean
from tests.methods.time_series_test_case import run_expert_model
from tests.utils import run_model


def test_rise_timeseries_correct_output_shape():
    """RISE runs on a model function and returns one heatmap per label, shaped like the input."""
    timeseries = np.random.random((10, 1))
    requested_labels = [1]
    expected_shape = (len(requested_labels),) + timeseries.shape

    heatmaps = dianna.explain_timeseries(run_model, timeseries, "RISE", requested_labels,
                                         axis_labels=['t', 'channels'], n_masks=200, p_keep=.5)

    assert heatmaps.shape == expected_shape


def test_rise_timeseries_with_expert_model_for_correct_max_and_min():
    """Check that RISE marks the hot and the cold day as the extremes in this artificial example."""
    hot_day_index, cold_day_index = 6, 12
    temperature_timeseries = average_temperature_timeseries_with_1_cold_and_1_hot_day(cold_day_index, hot_day_index)

    explanations = dianna.explain_timeseries(run_expert_model,
                                             timeseries_data=temperature_timeseries,
                                             method='rise',
                                             labels=[0, 1],
                                             p_keep=0.1, n_masks=10000,
                                             mask_type=input_train_mean)
    # Label 0 is summer and label 1 is winter in the expert model.
    summer_explanation, winter_explanation = explanations

    # The hot day should speak most for summer and most against winter; the cold day the reverse.
    assert np.argmax(summer_explanation) == hot_day_index
    assert np.argmin(summer_explanation) == cold_day_index
    assert np.argmax(winter_explanation) == cold_day_index
    assert np.argmin(winter_explanation) == hot_day_index
35 changes: 35 additions & 0 deletions tests/methods/time_series_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np


"""In this test case, every test instance is a 28 days by 1 channel array indicating the max temp on a day."""
cwmeijer marked this conversation as resolved.
Show resolved Hide resolved


def input_train_mean(_data):
    """Return the overall mean temperature, which is always 14; the data argument is ignored."""
    overall_mean_temperature = 14
    return overall_mean_temperature


def average_temperature_timeseries_with_1_cold_and_1_hot_day(cold_day_index, hot_day_index):
    """Build a 28 by 1 temperature series of 14s with one hot (30) and one cold (-2) day."""
    series = np.full((28, 1), 14.0)
    # The hot day is written first, so the cold day wins when both indices are equal.
    series[hot_day_index] = 30
    series[cold_day_index] = -2
    return series


def run_expert_model(data):
    """A simple model that classifies a batch of timeseries.

    Every instance whose average is above 14 gets the scores [1.0, 0.0] (summer, class 0);
    every other instance gets [0.0, 1.0] (winter, class 1).
    """
    # Average over the second axis twice, leaving one value per instance.
    per_instance_mean = np.mean(np.mean(data, axis=1), axis=1)
    summer = per_instance_mean > 14

    # One row per instance, one column per class.
    scores = np.zeros((data.shape[0], 2))
    scores[:, 0] = summer
    scores[:, 1] = ~summer
    return scores
30 changes: 23 additions & 7 deletions tests/test_common_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,29 @@
from tests.utils import run_model


input_data = np.random.random((224, 224, 3))
axis_labels = {-1: 'channels'}
labels = [0, 1]
def test_common_RISE_image_pipeline():  # noqa: N802 ignore case
    """Creating a relevance map for an image and visualizing it raises no errors."""
    image = np.random.random((224, 224, 3))

    heatmaps = dianna.explain_image(run_model, image, "RISE", [0, 1], axis_labels={-1: 'channels'})
    first_heatmap = heatmaps[0]

    dianna.visualization.plot_image(first_heatmap, show_plot=False)
    # NOTE(review): image[0] is the first row of the image (shape (224, 3)), not the full
    # image - confirm that this is what plot_image should get as original_data.
    dianna.visualization.plot_image(first_heatmap, original_data=image[0], show_plot=False)


def test_common_RISE_pipeline(): # noqa: N802 ignore case
def test_common_RISE_timeseries_pipeline(): # noqa: N802 ignore case
"""No errors thrown while creating a relevance map and visualizing it."""
heatmap = dianna.explain_image(run_model, input_data, "RISE", labels, axis_labels=axis_labels)[0]
dianna.visualization.plot_image(heatmap, show_plot=False)
dianna.visualization.plot_image(heatmap, original_data=input_data[0], show_plot=False)
input_image = np.random.random((31, 1))
cwmeijer marked this conversation as resolved.
Show resolved Hide resolved
labels = [0]

heatmap = dianna.explain_timeseries(run_model, input_image, "RISE", labels)[0]
cwmeijer marked this conversation as resolved.
Show resolved Hide resolved
heatmap_channel = heatmap[:, 0]
segments = []
for i in range(len(heatmap_channel) - 1):
segments.append({
'index': i,
'start': i,
'stop': i + 1,
'weight': heatmap_channel[i]})
dianna.visualization.plot_timeseries(range(len(heatmap_channel)), input_image[:, 0], segments, show_plot=False)
cwmeijer marked this conversation as resolved.
Show resolved Hide resolved
Binary file not shown.
Loading