From b4ba6f90ce67c2c5642784c42b93ace7e1ff438a Mon Sep 17 00:00:00 2001 From: Edward Wang Date: Wed, 13 May 2020 14:21:39 -0700 Subject: [PATCH] Add Occlusion to Insights (#369) Summary: ![image](https://user-images.githubusercontent.com/53842584/81026858-6efefb00-8e30-11ea-970d-5c6907fe3e7b.png) Pull Request resolved: https://github.com/pytorch/captum/pull/369 Reviewed By: vivekmig, J0Nreynolds Differential Revision: D21394665 Pulled By: edward-io fbshipit-source-id: 4f6848928fa271b99ee8a376b6232985fc739b2c --- captum/insights/api.py | 16 ++- captum/insights/config.py | 56 ++++++-- captum/insights/frontend/src/App.js | 120 ++++++++++-------- captum/insights/frontend/widget/src/Widget.js | 19 ++- 4 files changed, 137 insertions(+), 74 deletions(-) diff --git a/captum/insights/api.py b/captum/insights/api.py index 36c0da13db..7125334187 100644 --- a/captum/insights/api.py +++ b/captum/insights/api.py @@ -89,7 +89,7 @@ class FilterConfig(NamedTuple): arg: config.value # type: ignore for arg, config in ATTRIBUTION_METHOD_CONFIG[ IntegratedGradients.get_name() - ].items() + ].params.items() } prediction: str = "all" classes: List[str] = [] @@ -221,6 +221,12 @@ def _calculate_attribution( attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[self._config.attribution_method] attribution_method = attribution_cls(net) args = self._config.attribution_arguments + param_config = ATTRIBUTION_METHOD_CONFIG[self._config.attribution_method] + if param_config.post_process: + for k, v in args.items(): + if k in param_config.post_process: + args[k] = param_config.post_process[k](v) + # TODO support multiple baselines baseline = baselines[0] if baselines and len(baselines) > 0 else None label = ( @@ -329,7 +335,9 @@ def _serve_colab(self, blocking=False, debug=False, port=None): def _get_labels_from_scores( self, scores: Tensor, indices: Tensor ) -> List[OutputScore]: - pred_scores = [] + pred_scores: List[OutputScore] = [] + if indices.nelement() < 2: + return pred_scores for i in range(len(indices)): score = scores[i] pred_scores.append( @@ -542,6 +550,8 @@ def get_insights_config(self): return { "classes": self.classes, "methods": list(ATTRIBUTION_NAMES_TO_METHODS.keys()), - "method_arguments": namedtuple_to_dict(ATTRIBUTION_METHOD_CONFIG), + "method_arguments": namedtuple_to_dict( + {k: v.params for (k, v) in ATTRIBUTION_METHOD_CONFIG.items()} + ), "selected_method": self._config.attribution_method, } diff --git a/captum/insights/config.py b/captum/insights/config.py index 897835e8f6..d1abccb5a7 100644 --- a/captum/insights/config.py +++ b/captum/insights/config.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import Dict, List, NamedTuple, Optional, Tuple +from typing import Dict, List, NamedTuple, Optional, Tuple, Callable, Any, Union from captum.attr import ( Deconvolution, @@ -9,6 +9,7 @@ InputXGradient, IntegratedGradients, Saliency, + Occlusion, ) from captum.attr._utils.approximation_methods import SUPPORTED_METHODS @@ -25,6 +26,13 @@ class StrEnumConfig(NamedTuple): type: str = "enum" +class StrConfig(NamedTuple): + value: str + type: str = "string" + + +Config = Union[NumberConfig, StrEnumConfig, StrConfig] + SUPPORTED_ATTRIBUTION_METHODS = [ Deconvolution, DeepLift, @@ -33,20 +41,50 @@ class StrEnumConfig(NamedTuple): IntegratedGradients, Saliency, FeatureAblation, + Occlusion, ] + +class ConfigParameters(NamedTuple): + params: Dict[str, Config] + help_info: Optional[str] = None # TODO fill out help for each method + post_process: Optional[Dict[str, Callable[[Any], Any]]] = None + + ATTRIBUTION_NAMES_TO_METHODS = { # mypy bug - treating it as a type instead of a class cls.get_name(): cls # type: ignore for cls in SUPPORTED_ATTRIBUTION_METHODS } -ATTRIBUTION_METHOD_CONFIG: Dict[str, Dict[str, tuple]] = { - IntegratedGradients.get_name(): { - "n_steps": NumberConfig(value=25, limit=(2, None)), - "method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"), - }, - FeatureAblation.get_name(): { - "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)), - }, + +def _str_to_tuple(s): + if isinstance(s, tuple): + return s + return tuple([int(i) for i in s.split()]) + + +ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = { + IntegratedGradients.get_name(): ConfigParameters( + params={ + "n_steps": NumberConfig(value=25, limit=(2, None)), + "method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"), + }, + post_process={"n_steps": int}, + ), + FeatureAblation.get_name(): ConfigParameters( + params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))}, + ), + Occlusion.get_name(): ConfigParameters( + params={ + "sliding_window_shapes": StrConfig(value=""), + "strides": StrConfig(value=""), + "perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)), + }, + post_process={ + "sliding_window_shapes": _str_to_tuple, + "strides": _str_to_tuple, + "perturbations_per_eval": int, + }, + ), } diff --git a/captum/insights/frontend/src/App.js b/captum/insights/frontend/src/App.js index 99a271b0af..461965ddf8 100644 --- a/captum/insights/frontend/src/App.js +++ b/captum/insights/frontend/src/App.js @@ -8,6 +8,7 @@ import "./App.css"; const ConfigType = Object.freeze({ Number: "number", Enum: "enum", + String: "string", }); const Plot = createPlotlyComponent(Plotly); @@ -153,62 +154,71 @@ class FilterContainer extends React.Component { } } -class ClassFilter extends React.Component { - render() { - return ( - + ); +} + +function NumberArgument(props) { + var min = props.limit[0]; + var max = props.limit[1]; + return ( +
+ {props.name}: + - ); - } +
+ ); } -class NumberArgument extends React.Component { - render() { - var min = this.props.limit[0]; - var max = this.props.limit[1]; - return ( -
- {this.props.name + ": "} - -
- ); - } +function EnumArgument(props) { + const options = props.limit.map((item, key) => ( + + )); + return ( +
+ {props.name}: + +
+ ); } -class EnumArgument extends React.Component { - render() { - const options = this.props.limit.map((item, key) => ( - - )); - return ( -
- {this.props.name + ": "} - -
- ); - } +function StringArgument(props) { + return ( +
+ {props.name}: + +
+ ); } class Filter extends React.Component { @@ -232,6 +242,14 @@ class Filter extends React.Component { handleInputChange={this.props.handleArgumentChange} /> ); + case ConfigType.String: + return ( + + ); } }; diff --git a/captum/insights/frontend/widget/src/Widget.js b/captum/insights/frontend/widget/src/Widget.js index 5800f269a0..324b4b0d5a 100644 --- a/captum/insights/frontend/widget/src/Widget.js +++ b/captum/insights/frontend/widget/src/Widget.js @@ -12,10 +12,10 @@ class Widget extends React.Component { config: { classes: [], methods: [], - method_arguments: {} + method_arguments: {}, }, loading: false, - callback: null + callback: null, }; this.backbone = this.props.backbone; } @@ -47,14 +47,11 @@ class Widget extends React.Component { _fetchInit = () => { this.setState({ - config: this.backbone.model.get("insights_config") + config: this.backbone.model.get("insights_config"), }); }; - fetchData = filterConfig => { - filterConfig.approximation_steps = parseInt( - filterConfig.approximation_steps - ); + fetchData = (filterConfig) => { this.setState({ loading: true }, () => { this.backbone.model.save({ config: filterConfig, output: [] }); }); @@ -64,7 +61,7 @@ class Widget extends React.Component { this.setState({ callback: callback }, () => { this.backbone.model.save({ label_details: { labelIndex, instance }, - attribution: {} + attribution: {}, }); }); }; @@ -90,8 +87,8 @@ var CaptumInsightsModel = widgets.DOMWidgetModel.extend({ _model_module: "jupyter-captum-insights", _view_module: "jupyter-captum-insights", _model_module_version: "0.1.0", - _view_module_version: "0.1.0" - }) + _view_module_version: "0.1.0", + }), }); var CaptumInsightsView = widgets.DOMWidgetView.extend({ @@ -99,7 +96,7 @@ var CaptumInsightsView = widgets.DOMWidgetView.extend({ const $app = document.createElement("div"); ReactDOM.render(, $app); this.el.append($app); - } + }, }); export { Widget as default, CaptumInsightsModel, CaptumInsightsView };