Skip to content

Commit

Permalink
Use return_dict in REST server for more flexibility (#1745)
Browse files Browse the repository at this point in the history
  • Loading branch information
francoishernandez authored Mar 9, 2020
1 parent e442f3f commit d14613d
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 11 deletions.
2 changes: 1 addition & 1 deletion onmt/bin/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def translate():
for i in range(len(trans)):
response = {"src": inputs[i // n_best]['src'], "tgt": trans[i],
"n_best": n_best, "pred_score": scores[i]}
if aligns[i] is not None:
if aligns[i][0] is not None:
response["align"] = aligns[i]
out[i % n_best].append(response)
except ServerModelError as e:
Expand Down
7 changes: 4 additions & 3 deletions onmt/tests/test_translation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ def test_run(self):
for elem in scores:
self.assertIsInstance(elem, float)
self.assertIsInstance(aligns, list)
for align_string in aligns:
if align_string is not None:
self.assertIsInstance(align_string, string_types)
for align_list in aligns:
for align_string in align_list:
if align_string is not None:
self.assertIsInstance(align_string, string_types)
self.assertEqual(len(results), len(scores))
self.assertEqual(len(scores), len(inp) * n_best)
self.assertEqual(len(time), 1)
Expand Down
16 changes: 16 additions & 0 deletions onmt/translate/process_zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,47 @@
import pkuseg


def wrap_str_func(func):
"""
Wrapper to apply str function to the proper key of return_dict.
"""
def wrapper(some_dict):
some_dict["seg"] = [func(item) for item in some_dict["seg"]]
return some_dict
return wrapper


# Chinese segmentation
@wrap_str_func
def zh_segmentator(line):
return " ".join(pkuseg.pkuseg().cut(line))


# Chinese simplify -> Chinese traditional standard
@wrap_str_func
def zh_traditional_standard(line):
return HanLP.convertToTraditionalChinese(line)


# Chinese simplify -> Chinese traditional (HongKong)
@wrap_str_func
def zh_traditional_hk(line):
return HanLP.s2hk(line)


# Chinese simplify -> Chinese traditional (Taiwan)
@wrap_str_func
def zh_traditional_tw(line):
return HanLP.s2tw(line)


# Chinese traditional -> Chinese simplify (v1)
@wrap_str_func
def zh_simplify(line):
return HanLP.convertToSimplifiedChinese(line)


# Chinese traditional -> Chinese simplify (v2)
@wrap_str_func
def zh_simplify_v2(line):
return SnowNLP(line).han
42 changes: 35 additions & 7 deletions onmt/translate/translation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def run(self, inputs):
head_spaces = []
tail_spaces = []
sslength = []
all_preprocessed = []
for i, inp in enumerate(inputs):
src = inp['src']
if src.strip() == "":
Expand All @@ -406,9 +407,12 @@ def run(self, inputs):
if match_after is not None:
whitespaces_after = match_after.group(0)
head_spaces.append(whitespaces_before)
preprocessed_src = self.maybe_preprocess(src.strip())
tok = self.maybe_tokenize(preprocessed_src)
texts.append(tok)
# every segment becomes a dict for flexibility purposes
seg_dict = self.maybe_preprocess(src.strip())
all_preprocessed.append(seg_dict)
for seg in seg_dict["seg"]:
tok = self.maybe_tokenize(seg)
texts.append(tok)
sslength.append(len(tok.split()))
tail_spaces.append(whitespaces_after)

Expand Down Expand Up @@ -453,7 +457,9 @@ def flatten_list(_list): return sum(_list, [])
for result, src in zip(results, tiled_texts)]

aligns = [align for _, align in results]
results = [self.maybe_postprocess(seq) for seq, _ in results]
rebuilt_segs, scores, aligns = self.rebuild_seg_packages(
all_preprocessed, results, scores, aligns)
results = [self.maybe_postprocess(seg) for seg in rebuilt_segs]

# build back results with empty texts
for i in empty_indices:
Expand All @@ -470,6 +476,24 @@ def flatten_list(_list): return sum(_list, [])
self.logger.info("Translation Results: %d", len(results))
return results, scores, self.opt.n_best, timer.times, aligns

def rebuild_seg_packages(self, all_preprocessed, results, scores, aligns):
"""
Rebuild proper segment packages based on initial n_seg.
"""
offset = 0
rebuilt_segs = []
avg_scores = []
merged_aligns = []
for seg_dict in all_preprocessed:
seg_dict["seg"] = list(
list(zip(*results))[0][offset:offset+seg_dict["n_seg"]])
rebuilt_segs.append(seg_dict)
avg_scores.append(sum(
scores[offset:offset+seg_dict["n_seg"]])/seg_dict["n_seg"])
merged_aligns.append(aligns[offset:offset+seg_dict["n_seg"]])
offset += seg_dict["n_seg"]
return rebuilt_segs, avg_scores, merged_aligns

def do_timeout(self):
"""Timeout function that frees GPU memory.
Expand Down Expand Up @@ -535,7 +559,11 @@ def maybe_preprocess(self, sequence):
"""Preprocess the sequence (or not)
"""

if type(sequence) is str:
sequence = {
"seg": [sequence],
"n_seg": 1
}
if self.preprocess_opt is not None:
return self.preprocess(sequence)
return sequence
Expand Down Expand Up @@ -666,10 +694,10 @@ def maybe_postprocess(self, sequence):
"""Postprocess the sequence (or not)
"""

if self.postprocess_opt is not None:
return self.postprocess(sequence)
return sequence
else:
return sequence["seg"][0]

def postprocess(self, sequence):
"""Preprocess a single sequence.
Expand Down

0 comments on commit d14613d

Please sign in to comment.