Skip to content

Commit

Permalink
Add Occlusion to Insights (pytorch#369)
Browse files Browse the repository at this point in the history
Summary:
![image](https://user-images.githubusercontent.com/53842584/81026858-6efefb00-8e30-11ea-970d-5c6907fe3e7b.png)
Pull Request resolved: pytorch#369

Reviewed By: vivekmig, J0Nreynolds

Differential Revision: D21394665

Pulled By: edward-io

fbshipit-source-id: 4f6848928fa271b99ee8a376b6232985fc739b2c
  • Loading branch information
edward-io authored and facebook-github-bot committed May 13, 2020
1 parent 8e2d11a commit b4ba6f9
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 74 deletions.
16 changes: 13 additions & 3 deletions captum/insights/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class FilterConfig(NamedTuple):
arg: config.value # type: ignore
for arg, config in ATTRIBUTION_METHOD_CONFIG[
IntegratedGradients.get_name()
].items()
].params.items()
}
prediction: str = "all"
classes: List[str] = []
Expand Down Expand Up @@ -221,6 +221,12 @@ def _calculate_attribution(
attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[self._config.attribution_method]
attribution_method = attribution_cls(net)
args = self._config.attribution_arguments
param_config = ATTRIBUTION_METHOD_CONFIG[self._config.attribution_method]
if param_config.post_process:
for k, v in args.items():
if k in param_config.post_process:
args[k] = param_config.post_process[k](v)

# TODO support multiple baselines
baseline = baselines[0] if baselines and len(baselines) > 0 else None
label = (
Expand Down Expand Up @@ -329,7 +335,9 @@ def _serve_colab(self, blocking=False, debug=False, port=None):
def _get_labels_from_scores(
self, scores: Tensor, indices: Tensor
) -> List[OutputScore]:
pred_scores = []
pred_scores: List[OutputScore] = []
if indices.nelement() < 2:
return pred_scores
for i in range(len(indices)):
score = scores[i]
pred_scores.append(
Expand Down Expand Up @@ -542,6 +550,8 @@ def get_insights_config(self):
return {
"classes": self.classes,
"methods": list(ATTRIBUTION_NAMES_TO_METHODS.keys()),
"method_arguments": namedtuple_to_dict(ATTRIBUTION_METHOD_CONFIG),
"method_arguments": namedtuple_to_dict(
{k: v.params for (k, v) in ATTRIBUTION_METHOD_CONFIG.items()}
),
"selected_method": self._config.attribution_method,
}
56 changes: 47 additions & 9 deletions captum/insights/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
from typing import Dict, List, NamedTuple, Optional, Tuple
from typing import Dict, List, NamedTuple, Optional, Tuple, Callable, Any, Union

from captum.attr import (
Deconvolution,
Expand All @@ -9,6 +9,7 @@
InputXGradient,
IntegratedGradients,
Saliency,
Occlusion,
)
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS

Expand All @@ -25,6 +26,13 @@ class StrEnumConfig(NamedTuple):
type: str = "enum"


class StrConfig(NamedTuple):
value: str
type: str = "string"


Config = Union[NumberConfig, StrEnumConfig, StrConfig]

SUPPORTED_ATTRIBUTION_METHODS = [
Deconvolution,
DeepLift,
Expand All @@ -33,20 +41,50 @@ class StrEnumConfig(NamedTuple):
IntegratedGradients,
Saliency,
FeatureAblation,
Occlusion,
]


class ConfigParameters(NamedTuple):
params: Dict[str, Config]
help_info: Optional[str] = None # TODO fill out help for each method
post_process: Optional[Dict[str, Callable[[Any], Any]]] = None


ATTRIBUTION_NAMES_TO_METHODS = {
# mypy bug - treating it as a type instead of a class
cls.get_name(): cls # type: ignore
for cls in SUPPORTED_ATTRIBUTION_METHODS
}

ATTRIBUTION_METHOD_CONFIG: Dict[str, Dict[str, tuple]] = {
IntegratedGradients.get_name(): {
"n_steps": NumberConfig(value=25, limit=(2, None)),
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
},
FeatureAblation.get_name(): {
"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
},

def _str_to_tuple(s):
if isinstance(s, tuple):
return s
return tuple([int(i) for i in s.split()])


ATTRIBUTION_METHOD_CONFIG: Dict[str, ConfigParameters] = {
IntegratedGradients.get_name(): ConfigParameters(
params={
"n_steps": NumberConfig(value=25, limit=(2, None)),
"method": StrEnumConfig(limit=SUPPORTED_METHODS, value="gausslegendre"),
},
post_process={"n_steps": int},
),
FeatureAblation.get_name(): ConfigParameters(
params={"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100))},
),
Occlusion.get_name(): ConfigParameters(
params={
"sliding_window_shapes": StrConfig(value=""),
"strides": StrConfig(value=""),
"perturbations_per_eval": NumberConfig(value=1, limit=(1, 100)),
},
post_process={
"sliding_window_shapes": _str_to_tuple,
"strides": _str_to_tuple,
"perturbations_per_eval": int,
},
),
}
120 changes: 69 additions & 51 deletions captum/insights/frontend/src/App.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import "./App.css";
const ConfigType = Object.freeze({
Number: "number",
Enum: "enum",
String: "string",
});

const Plot = createPlotlyComponent(Plotly);
Expand Down Expand Up @@ -153,62 +154,71 @@ class FilterContainer extends React.Component {
}
}

class ClassFilter extends React.Component {
render() {
return (
<ReactTags
tags={this.props.classes}
autofocus={false}
suggestions={this.props.suggestedClasses}
handleDelete={this.props.handleClassDelete}
handleAddition={this.props.handleClassAdd}
minQueryLength={0}
placeholder="add new class..."
function ClassFilter(props) {
return (
<ReactTags
tags={props.classes}
autofocus={false}
suggestions={props.suggestedClasses}
handleDelete={props.handleClassDelete}
handleAddition={props.handleClassAdd}
minQueryLength={0}
placeholder="add new class..."
/>
);
}

function NumberArgument(props) {
var min = props.limit[0];
var max = props.limit[1];
return (
<div>
{props.name}:
<input
className={cx([styles.input, styles["input--narrow"]])}
name={props.name}
type="number"
value={props.value}
min={min}
max={max}
onChange={props.handleInputChange}
/>
);
}
</div>
);
}

class NumberArgument extends React.Component {
render() {
var min = this.props.limit[0];
var max = this.props.limit[1];
return (
<div>
{this.props.name + ": "}
<input
className={cx([styles.input, styles["input--narrow"]])}
name={this.props.name}
type="number"
value={this.props.value}
min={min}
max={max}
onChange={this.props.handleInputChange}
/>
</div>
);
}
function EnumArgument(props) {
const options = props.limit.map((item, key) => (
<option value={item}>{item}</option>
));
return (
<div>
{props.name}:
<select
className={styles.select}
name={props.name}
value={props.value}
onChange={props.handleInputChange}
>
{options}
</select>
</div>
);
}

class EnumArgument extends React.Component {
render() {
const options = this.props.limit.map((item, key) => (
<option value={item}>{item}</option>
));
return (
<div>
{this.props.name + ": "}
<select
className={styles.select}
name={this.props.name}
value={this.props.value}
onChange={this.props.handleInputChange}
>
{options}
</select>
</div>
);
}
function StringArgument(props) {
return (
<div>
{props.name}:
<input
className={cx([styles.input, styles["input--narrow"]])}
name={props.name}
type="text"
value={props.value}
onChange={props.handleInputChange}
/>
</div>
);
}

class Filter extends React.Component {
Expand All @@ -232,6 +242,14 @@ class Filter extends React.Component {
handleInputChange={this.props.handleArgumentChange}
/>
);
case ConfigType.String:
return (
<StringArgument
name={name}
value={config.value}
handleInputChange={this.props.handleArgumentChange}
/>
);
}
};

Expand Down
19 changes: 8 additions & 11 deletions captum/insights/frontend/widget/src/Widget.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ class Widget extends React.Component {
config: {
classes: [],
methods: [],
method_arguments: {}
method_arguments: {},
},
loading: false,
callback: null
callback: null,
};
this.backbone = this.props.backbone;
}
Expand Down Expand Up @@ -47,14 +47,11 @@ class Widget extends React.Component {

_fetchInit = () => {
this.setState({
config: this.backbone.model.get("insights_config")
config: this.backbone.model.get("insights_config"),
});
};

fetchData = filterConfig => {
filterConfig.approximation_steps = parseInt(
filterConfig.approximation_steps
);
fetchData = (filterConfig) => {
this.setState({ loading: true }, () => {
this.backbone.model.save({ config: filterConfig, output: [] });
});
Expand All @@ -64,7 +61,7 @@ class Widget extends React.Component {
this.setState({ callback: callback }, () => {
this.backbone.model.save({
label_details: { labelIndex, instance },
attribution: {}
attribution: {},
});
});
};
Expand All @@ -90,16 +87,16 @@ var CaptumInsightsModel = widgets.DOMWidgetModel.extend({
_model_module: "jupyter-captum-insights",
_view_module: "jupyter-captum-insights",
_model_module_version: "0.1.0",
_view_module_version: "0.1.0"
})
_view_module_version: "0.1.0",
}),
});

var CaptumInsightsView = widgets.DOMWidgetView.extend({
initialize() {
const $app = document.createElement("div");
ReactDOM.render(<Widget backbone={this} />, $app);
this.el.append($app);
}
},
});

export { Widget as default, CaptumInsightsModel, CaptumInsightsView };

0 comments on commit b4ba6f9

Please sign in to comment.