-
Notifications
You must be signed in to change notification settings - Fork 1
/
lime_helper.py
70 lines (56 loc) · 2.42 KB
/
lime_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import sis
from lime.lime_text import LimeTextExplainer
## Based on some code at:
## http://data4thought.com/deep-lime.html
class TextPipeline(object):
    """Chain a tokenizer, a padder and a model into a single text predictor.

    The object exposes ``predict_proba`` so it can be handed to code that
    expects a "list of strings in, predictions out" callable.

    Attributes are stored exactly as given:
      model     -- object providing ``predict(sequences, batch_size=...)``
      tokenizer -- object providing ``texts_to_sequences(texts)``
      padder    -- callable applied to the tokenizer output
    """

    def __init__(self, model, tokenizer, padder):
        self.model, self.tokenizer, self.padder = model, tokenizer, padder

    def predict_proba(self, texts):
        """Return the model's predictions for `texts` (a list of d strings).

        The strings go through ``tokenizer.texts_to_sequences``, then
        through ``padder``, and the result is passed to ``model.predict``
        with a batch size of 128. Whatever ``model.predict`` returns is
        returned unchanged.
        """
        return self.model.predict(
            self.padder(self.tokenizer.texts_to_sequences(texts)),
            batch_size=128)
class DNASequencePipeline(object):
    """Pair an encoder with a model to score space-separated DNA strings.

    Attributes are stored exactly as given:
      model   -- passed through to ``sis.predict_for_embed_sequence``
      encoder -- callable applied to each cleaned sequence string
    """

    def __init__(self, model, encoder):
        self.encoder = encoder
        self.model = model

    def predict_proba(self, seqs):
        """Return predictions for `seqs` as a single column (d x 1).

        `seqs` is a list of d strings with each base separated by a space;
        a string may also contain the 'UNKWORDZ' token.
        """
        # First pass: turn every 'UNKWORDZ' token into 'N' and strip the
        # separating spaces so each sequence is one contiguous string.
        cleaned = []
        for raw in seqs:
            cleaned.append(raw.replace('UNKWORDZ', 'N').replace(' ', ''))

        # Second pass: encode each cleaned sequence.
        encoded = list(map(self.encoder, cleaned))

        scores = sis.predict_for_embed_sequence(
            encoded, self.model, batch_size=5000)
        # NOTE(review): `scores` is presumably a NumPy array, since
        # `.reshape` is called on it -- confirm against `sis`.
        return scores.reshape(-1, 1)
# For beer reviews data
def make_pipeline(model, tokenizer, max_words):
    """Build a TextPipeline around `model` and `tokenizer`.

    The pipeline's padder forwards its argument to ``sis.pad_sequences``
    with ``max_words`` fixed to the value given here.
    """
    def pad_to_max_words(seqs):
        return sis.pad_sequences(seqs, max_words=max_words)

    return TextPipeline(model, tokenizer, pad_to_max_words)
# For DNA sequence data
def make_pipeline_dna_seq(model, encoder):
    """Build a DNASequencePipeline from `model` and `encoder`."""
    return DNASequencePipeline(model, encoder)
def make_explainer(verbose=False):
    """Create the LimeTextExplainer used by this module.

    The explainer is configured with a single class named 'prediction',
    splits text on a single space, and has bag-of-words mode switched off
    (``bow=False``). `verbose` is forwarded unchanged.
    """
    settings = {
        'class_names': ['prediction'],
        'split_expression': ' ',
        'bow': False,
        'verbose': verbose,
    }
    return LimeTextExplainer(**settings)
def explain(text, explainer, pipeline, num_features=500, num_samples=5000):
    """Explain `text` with `explainer`, using `pipeline` as the predictor.

    Calls ``explainer.explain_instance`` with ``pipeline.predict_proba`` as
    the classifier function, asking only for label 0, and returns whatever
    that call returns. `num_features` and `num_samples` are forwarded
    unchanged.
    """
    options = {
        'labels': (0,),
        'num_features': num_features,
        'num_samples': num_samples,
    }
    return explainer.explain_instance(text, pipeline.predict_proba, **options)
def extract_word_order(explanation):
    """Return the label-0 feature ids of `explanation`, in reversed order.

    ``explanation.as_map()[0]`` is read as a sequence of
    ``(feature_id, weight)`` pairs. The sequence is reversed and only the
    feature ids are kept, returned as a tuple.

    The result should be ordered by absolute value of the weights, with
    the highest weight last (this relies on the order in which the
    explanation lists its pairs -- verify against the LIME version in use).

    An explanation with no pairs for label 0 yields an empty tuple instead
    of raising ``ValueError`` from unpacking an empty ``zip``.
    """
    pairs = explanation.as_map()[0]
    # Reverse so that the entry listed first by the explanation ends up last.
    return tuple(feature_id for feature_id, _weight in pairs[::-1])