Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

378 update button for execution #396

Merged
merged 13 commits into from
Dec 7, 2022
268 changes: 136 additions & 132 deletions dashboard/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,147 +210,151 @@ def compute_value_i(method_sel, fn_m, fn_i):
dash.dependencies.Output('graph_img', 'figure'),
dash.dependencies.State("upload-model-img", "filename"),
dash.dependencies.State("upload-image", "filename"),
dash.dependencies.Input("signal_image", "data"),
dash.dependencies.Input("upload-model-img", "filename"),
dash.dependencies.Input("upload-image", "filename"),
dash.dependencies.Input("show_top", "value"),
dash.dependencies.Input("n_masks", "value"),
dash.dependencies.Input("feature_res", "value"),
dash.dependencies.Input("p_keep", "value"),
dash.dependencies.Input("n_samples", "value"),
dash.dependencies.Input("background", "value"),
dash.dependencies.Input("n_segments", "value"),
dash.dependencies.Input("sigma", "value"),
dash.dependencies.Input("random_state", "value")
dash.dependencies.State("signal_image", "data"),
dash.dependencies.State("upload-model-img", "filename"),
dash.dependencies.State("upload-image", "filename"),
dash.dependencies.State("show_top", "value"),
dash.dependencies.State("n_masks", "value"),
dash.dependencies.State("feature_res", "value"),
dash.dependencies.State("p_keep", "value"),
dash.dependencies.State("n_samples", "value"),
dash.dependencies.State("background", "value"),
dash.dependencies.State("n_segments", "value"),
dash.dependencies.State("sigma", "value"),
dash.dependencies.State("random_state", "value"),
dash.dependencies.Input("update_button", "n_clicks"),
dash.dependencies.Input("stop_button", "n_clicks")
)
# pylint: disable=too-many-locals
# pylint: disable=unused-argument
# pylint: disable=too-many-arguments
def update_multi_options_i(fn_m, fn_i, sel_methods, new_model, new_image,
show_top=2, n_masks=1000, feature_res=6, p_keep=0.1, n_samples=1000,
background=0, n_segments=200, sigma=0, random_state=2):
def update_multi_options_i(fn_m, fn_i, sel_methods, new_model, new_image, show_top=2, n_masks=1000, feature_res=6, p_keep=0.1, n_samples=1000,
background=0, n_segments=200, sigma=0, random_state=2, update_button=0, stop_button=0):
"""Takes in the last model and image uploaded filenames, the selected XAI method, and returns the selected XAI method."""
ctx = dash.callback_context

if ((ctx.triggered[0]["prop_id"] == "upload-model-img.filename") or
(ctx.triggered[0]["prop_id"] == "upload-image.filename") or
(not ctx.triggered)):
cache.clear()
return html.Div(['']), utilities.blank_fig()
if (not sel_methods):
return html.Div(['']), utilities.blank_fig()

# if ((ctx.triggered[0]["prop_id"] == "upload-model-img.filename") or
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this part is no longer needed, it can maybe be deleted?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this part relates to the global cache in #418. So I do not want to delete it yet.

# (ctx.triggered[0]["prop_id"] == "upload-image.filename") or
# (not ctx.triggered)):
# cache.clear()
# return html.Div(['']), utilities.blank_fig()
# if (not sel_methods):
# return html.Div(['']), utilities.blank_fig()
if (ctx.triggered[0]["prop_id"] == "stop_button.n_clicks"):
return (html.Div(['Explanation stopped.'], style={'margin-top' : '60px'}),
utilities.blank_fig())
# update graph
if (fn_m and fn_i) is not None:
data_path = os.path.join(folder_on_server, fn_i[0])
X_test, _ = utilities.open_image(data_path)

onnx_model_path = os.path.join(folder_on_server, fn_m[0])
onnx_model = onnx.load(onnx_model_path)
# get the output node
output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]

try:
predictions = (prepare(onnx_model).run(X_test[None, ...])
[f'{output_node}'])
if len(predictions[0]) == 2:
class_name = class_name_mnist
else:
class_name = class_names_imagenet
# get the predicted class
preds = np.array(predictions[0])
pred_class = class_name[np.argmax(preds)]
# get the top most likely results
if show_top > len(class_name):
show_top = len(class_name)
# make sure the top results are ordered most to least likely
ind = np.array(np.argpartition(preds, -show_top)[-show_top:])
ind = ind[np.argsort(preds[ind])]
ind = np.flip(ind)
top = [class_name[i] for i in ind]
n_rows = len(top)
fig = make_subplots(rows=n_rows, cols=3,
subplot_titles=("RISE", "KernelShap", "LIME"), row_titles=top,
shared_xaxes=True, vertical_spacing=0.02,
horizontal_spacing = 0.02)
# check which axis is color channel
if X_test.shape[2] <=3:
z_rise = X_test[:, :, 0]
axis_labels = {2: 'channels'}
colorscale='Bluered'
else:
z_rise = X_test[1, :, :]
axis_labels = {0: 'channels'}
colorscale='jet'
for m in sel_methods:
for i in range(n_rows):
if m == "RISE":
# RISE plot
relevances_rise = global_store_i('RISE',
onnx_model_path, X_test, labels=[ind[i]],
axis_labels=axis_labels, n_masks=n_masks,
feature_res=feature_res, p_keep=p_keep)
fig.add_trace(
go.Heatmap(z=z_rise, colorscale='gray',
showscale=False), i+1, 1)
fig.add_trace(
go.Heatmap(z=relevances_rise[0],
colorscale=colorscale, showscale=False,
opacity=0.7), i+1, 1)
elif m == "KernelSHAP":
shap_values, segments_slic = global_store_i(
m, onnx_model_path, X_test, labels=[ind[i]],
axis_labels=axis_labels, n_samples=n_samples,
background=background, n_segments=n_segments,
sigma=sigma)

# KernelSHAP plot
fig.add_trace(
go.Heatmap(z=z_rise, colorscale='gray',
showscale=False), i+1, 2)
fig.add_trace(
go.Heatmap(
z=utilities.fill_segmentation(shap_values[i][0],
segments_slic), colorscale='Bluered',
showscale=False, opacity=0.7), i+1, 2)
else:
relevances_lime = global_store_i(
m, onnx_model_path, X_test, labels=[ind[i]],
axis_labels=axis_labels, random_state=random_state)
# LIME plot
fig.add_trace(
go.Heatmap(z=z_rise, colorscale='gray',
showscale=False), i+1, 3)
fig.add_trace(
go.Heatmap(z=relevances_lime[0],
colorscale='bluered', showscale=False,
opacity=0.7), i+1, 3)

fig.update_layout(
width=650,
height=(200*n_rows+50),
paper_bgcolor=layouts.colors['blue4'])

fig.update_xaxes(showgrid=False, showticklabels=False,
zeroline=False)
fig.update_yaxes(showgrid=False, showticklabels=False,
zeroline=False, autorange="reversed")

return html.Div(['The predicted class is: ' + pred_class], style={
'fontSize': 18,
'font-weight': 'bold',
'text-decoration': 'underline',
'margin-top': '60px',
'textAlign' : 'center'
}), fig
elif (ctx.triggered[0]["prop_id"] == "update_button.n_clicks"):
if (fn_m and fn_i and sel_methods) is not None:

data_path = os.path.join(folder_on_server, fn_i[0])
X_test, _ = utilities.open_image(data_path)

onnx_model_path = os.path.join(folder_on_server, fn_m[0])
onnx_model = onnx.load(onnx_model_path)
# get the output node
output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]

try:
predictions = (prepare(onnx_model).run(X_test[None, ...])
[f'{output_node}'])
if len(predictions[0]) == 2:
class_name = class_name_mnist
else:
class_name = class_names_imagenet
# get the predicted class
preds = np.array(predictions[0])
pred_class = class_name[np.argmax(preds)]
# get the top most likely results
if show_top > len(class_name):
show_top = len(class_name)
# make sure the top results are ordered most to least likely
ind = np.array(np.argpartition(preds, -show_top)[-show_top:])
ind = ind[np.argsort(preds[ind])]
ind = np.flip(ind)
top = [class_name[i] for i in ind]
n_rows = len(top)
fig = make_subplots(rows=n_rows, cols=3,
subplot_titles=("RISE", "KernelShap", "LIME"), row_titles=top,
shared_xaxes=True, vertical_spacing=0.02,
horizontal_spacing = 0.02)
# check which axis is color channel
if X_test.shape[2] <=3:
z_rise = X_test[:, :, 0]
axis_labels = {2: 'channels'}
colorscale='Bluered'
else:
z_rise = X_test[1, :, :]
axis_labels = {0: 'channels'}
colorscale='jet'
for m in sel_methods:
for i in range(n_rows):
if m == "RISE":
# RISE plot
relevances_rise = global_store_i('RISE',
onnx_model_path, X_test, labels=[ind[i]],
axis_labels=axis_labels, n_masks=n_masks,
feature_res=feature_res, p_keep=p_keep)
fig.add_trace(
go.Heatmap(z=z_rise, colorscale='gray',
showscale=False), i+1, 1)
fig.add_trace(
go.Heatmap(z=relevances_rise[0],
colorscale=colorscale, showscale=False,
opacity=0.7), i+1, 1)
elif m == "KernelSHAP":
shap_values, segments_slic = global_store_i(
m, onnx_model_path, X_test, labels=[ind[i]],
axis_labels=axis_labels, n_samples=n_samples,
background=background, n_segments=n_segments,
sigma=sigma)

# KernelSHAP plot
fig.add_trace(
go.Heatmap(z=z_rise, colorscale='gray',
showscale=False), i+1, 2)
fig.add_trace(
go.Heatmap(
z=utilities.fill_segmentation(shap_values[i][0],
segments_slic), colorscale='Bluered',
showscale=False, opacity=0.7), i+1, 2)
else:
relevances_lime = global_store_i(
m, onnx_model_path, X_test, labels=[ind[i]],
axis_labels=axis_labels, random_state=random_state)
# LIME plot
fig.add_trace(
go.Heatmap(z=z_rise, colorscale='gray',
showscale=False), i+1, 3)
fig.add_trace(
go.Heatmap(z=relevances_lime[0],
colorscale='bluered', showscale=False,
opacity=0.7), i+1, 3)
fig.update_layout(
width=650,
height=(200*n_rows+50),
paper_bgcolor=layouts.colors['blue4'])

except Exception as e:
print(e)
return (html.Div(['There was an error running the model. Check' +
'either the test image or the model.']), utilities.blank_fig())
else:
return (html.Div(['Missing either model or image.']),
utilities.blank_fig())
fig.update_xaxes(showgrid=False, showticklabels=False,
zeroline=False)
fig.update_yaxes(showgrid=False, showticklabels=False,
zeroline=False, autorange="reversed")

return html.Div(['The predicted class is: ' + pred_class], style={
'fontSize': 18,
'font-weight': 'bold',
'text-decoration': 'underline',
'margin-top': '60px',
'textAlign' : 'center'
}), fig

except Exception as e:
print(e)
return (html.Div(['There was an error running the model. Check' +
'either the test image or the model.']), utilities.blank_fig())
else:
return (html.Div(['Missing model, image or XAI method.'], style={'margin-top' : '60px'}),
utilities.blank_fig())

###################################################################

Expand Down
41 changes: 39 additions & 2 deletions dashboard/layouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
'blue2' : '#0e749b',
'blue3' : '#15b3f0',
'blue4' : '#E4F3F9', #light blue
'red1' : '#FF0000',
'yellow1' : '#f0d515'
}

Expand Down Expand Up @@ -290,13 +291,49 @@ def get_uploads_images():
], className = 'three columns'
),
],
className = 'row', style = {'padding-bottom' : '3%'}
className = 'row', style = {'padding-bottom' : '1%'}
),
html.Div([
# update button
html.Div([
html.Button('Update explanation',
id='update_button',
n_clicks=0,
style={
'margin-left': '0px',
'margin-top': '0px',
'width': '20%',
'float': 'left',
'backgroundColor': colors['blue2'],
'color' : colors['white']
}
),
],
),
html.Div([
html.Button('Stop Explanation',
id='stop_button',
n_clicks=0,
style={
'margin-left': '40px',
'margin-top': '0px',
'width': '20%',
'float': 'left',
'backgroundColor': colors['red1'],
'color' : colors['white']
}
),
],
),
],
className = 'row', style = {'padding-bottom' : '1%'}
),
# Settings bar
html.Div([
html.Div([
html.H6(children='XAI method specific settings',
style={'font-weight': 'bold'}),
style={'font-weight': 'bold',
'margin-top': '30px',}),
], className='nine columns'
),
],
Expand Down