Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor dashboard #548

Merged
merged 22 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions dashboard/.streamlit/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[theme]
base="light"
primaryColor="0e749b"
# backgroundColor=
secondaryBackgroundColor="#e4f3f9"
# textColor=
# font=

[browser]
gatherUsageStats = false
50 changes: 50 additions & 0 deletions dashboard/_image_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np
from keras import utils as keras_utils
from PIL import Image
from PIL import ImageStat
from dianna.utils import move_axis
from dianna.utils import to_xarray


def preprocess_img_resnet(path):
"""Resnet specific function for preprocessing.

Reshape figure to 224,224 and get colour channel at position 0.
Also: for resnet preprocessing: normalize the data. This works specifically for ImageNet.
See: https://github.com/onnx/models/tree/main/vision/classification/resnet
"""
img = keras_utils.load_img(path, target_size=(224, 224))
img_data = keras_utils.img_to_array(img)
if img_data.shape[0] != 3:
# Colour channel is not in position 0; reshape the data
xarray = to_xarray(img_data, {0: 'height', 1: 'width', 2: 'channels'})
reshaped_data = move_axis(xarray, 'channels', 0)
img_data = np.array(reshaped_data)

# definitions for normalisation (for ImageNet)
mean_vec = np.array([0.485, 0.456, 0.406])
stddev_vec = np.array([0.229, 0.224, 0.225])

norm_img_data = np.zeros(img_data.shape).astype('float32')

for i in range(img_data.shape[0]):
# for each pixel in each channel, divide the values by 255 ([0,1]), and normalize
# using mean and standard deviation from values above
norm_img_data[i, :, :] = (img_data[i, :, :] / 255 -
mean_vec[i]) / stddev_vec[i]

return norm_img_data, img


def open_image(file):
"""Open an image from a file and returns it as a numpy array."""
im = Image.open(file).convert('RGB')
stat = ImageStat.Stat(im)
im = np.asarray(im).astype(np.float32)

if sum(stat.sum
) / 3 == stat.sum[0]: # check the avg with any element value
return np.expand_dims(im[:, :, 0], axis=2) / 255, im # if grayscale
else:
# else it's colour, reshape to 224x224x3 for resnet
return preprocess_img_resnet(file)
27 changes: 27 additions & 0 deletions dashboard/_model_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
import onnx


def preprocess_function(image):
"""For LIME: we divided the input data by 256 for the model (binary mnist) and LIME needs RGB values."""
return (image / 256).astype(np.float32)


def fill_segmentation(values, segmentation):
"""For KernelSHAP: fill each pixel with SHAP values."""
out = np.zeros(segmentation.shape)
for i, _ in enumerate(values):
out[segmentation == i] = values[i]
return out


def load_model(file):
onnx_model = onnx.load(file)
return onnx_model


def load_labels(file):
labels = [line.decode().rstrip() for line in file.readlines()]
if labels is None or labels == ['']:
raise ValueError(labels)
return labels
60 changes: 60 additions & 0 deletions dashboard/_models_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import tempfile
import numpy as np
import streamlit as st
from _model_utils import fill_segmentation
from _model_utils import preprocess_function
from onnx_tf.backend import prepare
from dianna import explain_image


def get_top_indices(predictions, n_top):
indices = np.array(np.argpartition(predictions, -n_top)[-n_top:])
indices = indices[np.argsort(predictions[indices])]
indices = np.flip(indices)
return indices


@st.cache_data
def predict(*, model, image):
output_node = prepare(model, gen_tensor_dict=True).outputs[0]
predictions = (prepare(model).run(image[None, ...])[str(output_node)])
return predictions[0]


@st.cache_data
def _run_rise_image(model, image, i, **kwargs):
relevances = explain_image(
model,
image,
**kwargs,
)
return relevances[0]


@st.cache_data
def _run_lime_image(model, image, i, **kwargs):
relevances = explain_image(
model,
image * 256,
preprocess_function=preprocess_function,
**kwargs,
)
return relevances[0]


@st.cache_data
def _run_kernelshap_image(model, image, i, **kwargs):
# Kernelshap interface is different. Write model to temporary file.
with tempfile.NamedTemporaryFile() as f:
f.write(model)
f.flush()
shap_values, segments_slic = explain_image(f.name, image, **kwargs)

return fill_segmentation(shap_values[i][0], segments_slic)


explain_image_dispatcher = {
'RISE': _run_rise_image,
'LIME': _run_lime_image,
'KernelSHAP': _run_kernelshap_image,
}
37 changes: 37 additions & 0 deletions dashboard/_models_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import streamlit as st
from _movie_model import MovieReviewsModelRunner
from dianna import explain_text
from dianna.utils.tokenizers import SpacyTokenizer


tokenizer = SpacyTokenizer()


@st.cache_data
def predict(*, model, text_input):
model_runner = MovieReviewsModelRunner(model)
predictions = model_runner(text_input)
return predictions


@st.cache_data
def _run_rise_text(_model, text, **kwargs):
relevances = explain_text(
_model,
text,
tokenizer,
**kwargs,
)
return relevances


@st.cache_data
def _run_lime_text(_model, text, **kwargs):
relevances = explain_text(_model, text, tokenizer, **kwargs)
return relevances


explain_text_dispatcher = {
'RISE': _run_rise_text,
'LIME': _run_lime_text,
}
53 changes: 53 additions & 0 deletions dashboard/_movie_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
from pathlib import Path
import numpy as np
from scipy.special import expit as sigmoid
from torchtext.vocab import Vectors
import dianna
from dianna import utils
from dianna.utils.tokenizers import SpacyTokenizer


class MovieReviewsModelRunner:
"""Creates runner for movie review model."""

def __init__(self, model, word_vectors=None, max_filter_size=5):
"""Initializes the class."""
if word_vectors is None:
dianna_root_dir = Path(dianna.__file__).parents[1]
word_vectors = dianna_root_dir / 'tutorials' / 'data' / 'movie_reviews_word_vectors.txt'

self.run_model = utils.get_function(model)
self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))
self.max_filter_size = max_filter_size
self.tokenizer = SpacyTokenizer()

def __call__(self, sentences):
"""Call Runner."""
# ensure the input has a batch axis
if isinstance(sentences, str):
sentences = [sentences]

tokenized_sentences = []
for sentence in sentences:
# tokenize and pad to minimum length
tokens = self.tokenizer.tokenize(sentence)
if len(tokens) < self.max_filter_size:
tokens += ['<pad>'] * (self.max_filter_size - len(tokens))

# numericalize the tokens
tokens_numerical = [
self.vocab.stoi[token]
if token in self.vocab.stoi else self.vocab.stoi['<unk>']
for token in tokens
]
tokenized_sentences.append(tokens_numerical)

# run the model, applying a sigmoid because the model outputs logits
logits = self.run_model(tokenized_sentences)
pred = np.apply_along_axis(sigmoid, 1, logits)

# output pos/neg
positivity = pred[:, 0]
negativity = 1 - positivity
return np.transpose([negativity, positivity])
7 changes: 7 additions & 0 deletions dashboard/_text_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from _models_text import tokenizer
from dianna.visualization.text import _create_html


def format_word_importances(text, relevances) -> str:
tokens = tokenizer.tokenize(text)
return _create_html(tokens, relevances)
Loading