-
Notifications
You must be signed in to change notification settings - Fork 6
/
model.py
155 lines (116 loc) · 5.94 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import logging
import tensorflow as tf
from neighbor_function import get_extractive_next_state_func
from neighbor_function import State
from lm.stitch import get_lm_metric_func
from cos.weighted_average import get_cos_metric_func
def get_model(metrics, mode='extractive'):
    """Assemble the TensorFlow graph that scores or searches summary states.

    Args:
        metrics: Dict of metric configurations. The keys ``'cos'`` and
            ``'lm'`` are recognised; each configuration is handed to the
            matching metric factory and must provide a ``'weight'`` entry.
        mode: ``'extractive'`` runs the hill-climbing search (or only scores
            the initial state when ``sentence_length <= summary_length`` or
            ``num_steps < 1``); ``'exhaustive'`` only scores the states fed
            in through ``initial_state``.

    Returns:
        A ``(inputs, outputs)`` pair of dicts. ``inputs`` maps each
        placeholder name to its placeholder. ``outputs`` holds ``'state'``
        and ``'score'`` in exhaustive mode, ``'states'`` and ``'scores'``
        in extractive mode.

    Raises:
        ValueError: If ``mode`` is neither ``'extractive'`` nor
            ``'exhaustive'``.
    """
    supported_modes = ('extractive', 'exhaustive')
    if mode not in supported_modes:
        raise ValueError('Model mode not supported: {}'.format(mode))

    inputs = {}
    outputs = {}

    def add_input(name, dtype, shape):
        # Create a placeholder and expose it in `inputs` under its own name.
        placeholder = tf.placeholder(dtype, shape=shape, name=name)
        inputs[name] = placeholder
        return placeholder

    sentence = add_input('sentence', tf.int32, [None])
    sentence_length = add_input('sentence_length', tf.int32, [])
    num_steps = add_input('num_steps', tf.int32, [])
    initial_state = add_input('initial_state', tf.int32, [None, None])
    if mode == 'extractive':
        # Only the extractive search carries an internal (boolean) state.
        initial_internal_state = add_input('initial_internal_state', tf.bool, [None, None])
        initial_state_tuple = State(state=initial_state,
                                    internal_state=initial_internal_state)
    summary_length = add_input('summary_length', tf.int32, [])

    batch_size = tf.shape(initial_state)[0]

    # Insertion order (cos, then lm) is the order the metrics are combined in.
    metric_funcs = {}
    if 'cos' in metrics:
        cos_config = metrics['cos']
        logging.info('create metric function: cos')
        cos_metric = get_cos_metric_func(tf.expand_dims(sentence, axis=0),
                                         tf.expand_dims(sentence_length, axis=0),
                                         cos_config)
        metric_funcs['cos'] = (cos_metric, cos_config['weight'])
    if 'lm' in metrics:
        lm_config = metrics['lm']
        logging.info('create metric function: lm')
        lm_metric = get_lm_metric_func(lm_config)
        metric_funcs['lm'] = (lm_metric, lm_config['weight'])

    get_score = get_score_func(metric_funcs)

    if mode == 'exhaustive':
        # No search: score exactly the states that were fed in.
        outputs['state'] = initial_state
        outputs['score'] = get_score(initial_state)
        return inputs, outputs

    def score_initial_only():
        # Add a leading axis of size one so the result has the same rank as
        # the stacked output of the search branch.
        start = initial_state_tuple.state
        return tf.expand_dims(start, axis=0), tf.expand_dims(get_score(start), axis=0)

    def search():
        next_state = get_extractive_next_state_func(batch_size,
                                                    sentence_length,
                                                    summary_length,
                                                    sentence)
        hill_climber = get_hill_climber(get_score, num_steps)
        return hill_climber(initial_state_tuple, next_state)

    # Skip the search when the sentence already fits the summary length or
    # when no steps were requested.
    skip_search = tf.logical_or(sentence_length <= summary_length, num_steps < 1)
    states, scores = tf.cond(skip_search, true_fn=score_initial_only, false_fn=search)
    outputs['states'] = states
    outputs['scores'] = scores
    return inputs, outputs
def get_hill_climber(get_score, num_steps):
    """Build a hill-climbing search that runs inside a ``tf.while_loop``.

    Args:
        get_score: Callable that maps a state tensor to its score tensor.
        num_steps: Loop bound; also the fixed size of the two recording
            ``TensorArray`` objects.

    Returns:
        ``hill_climber(initial_state, next_state)``. ``initial_state`` must
        expose ``.state`` and ``.internal_state``; ``next_state`` maps such an
        object to a new candidate of the same kind. The call returns the
        stacked ``(states, scores)`` recorded during the search.
    """

    def hill_climber(initial_state, next_state):
        def keep_going(step, _state, _score, _states_log, _scores_log):
            return step < num_steps

        def advance(step, current, current_score, states_log, scores_log):
            candidate = next_state(current)
            candidate_score = get_score(candidate.state)

            # Every proposed candidate is recorded, whether or not it is kept.
            states_log = states_log.write(step, candidate.state)
            scores_log = scores_log.write(step, candidate_score)

            # A candidate replaces the current state when it scores at least
            # as high, so ties move the search forward.
            accept = tf.greater_equal(candidate_score, current_score)
            kept = State(
                state=tf.where(accept, x=candidate.state, y=current.state),
                internal_state=tf.where(accept,
                                        x=candidate.internal_state,
                                        y=current.internal_state),
            )
            kept_score = tf.where(accept, x=candidate_score, y=current_score)
            return step + 1, kept, kept_score, states_log, scores_log

        start_score = get_score(initial_state.state)
        states_log = tf.TensorArray(dtype=initial_state.state.dtype,
                                    size=num_steps,
                                    element_shape=initial_state.state.shape,
                                    dynamic_size=False)
        scores_log = tf.TensorArray(dtype=start_score.dtype,
                                    size=num_steps,
                                    element_shape=start_score.shape,
                                    dynamic_size=False)

        # Slot 0 holds the starting point, so the loop begins at step 1.
        states_log = states_log.write(0, initial_state.state)
        scores_log = scores_log.write(0, start_score)

        _, _, _, states_log, scores_log = tf.while_loop(
            cond=keep_going,
            body=advance,
            loop_vars=[1, initial_state, start_score, states_log, scores_log],
            back_prop=False)

        return states_log.stack(), scores_log.stack()

    return hill_climber
def get_score_func(metric_funcs):
    """Combine weighted metric functions into a single scoring function.

    Args:
        metric_funcs: Mapping from metric name to a ``(metric_func, weight)``
            pair. Each ``metric_func`` is called as
            ``metric_func(sample, sample_length)``.

    Returns:
        ``get_score(sample, sample_length=None)``, which returns the product
        over all metrics of ``metric ** weight``.

    Raises:
        ValueError: If ``metric_funcs`` is empty.
    """
    if not metric_funcs:
        # Fail fast: handing back a non-callable value here would only show up
        # later as a confusing "object is not callable" error at the call site.
        raise ValueError('Need metrics to compute score')

    def get_score(sample, sample_length=None):
        """Score ``sample`` with every configured metric.

        When ``sample_length`` is omitted it is derived by summing a tensor of
        ones shaped like ``sample`` over the last axis, i.e. each entry equals
        the size of the last dimension.
        """
        if sample_length is None:
            sample_length = tf.reduce_sum(
                tf.ones(shape=tf.shape(sample), dtype=tf.int32), axis=-1)

        score = 1
        for name, (metric_func, weight) in metric_funcs.items():
            logging.info('score name: %s', name)
            logging.info('score weight: %s', weight)
            score_metric = metric_func(sample, sample_length)
            # Weighted geometric combination: each metric is raised to its
            # weight and multiplied into the running score.
            score = score * tf.pow(score_metric, weight)
        return score

    return get_score