Skip to content

Commit

Permalink
[!96][SIMULTANEOUS] Add agent Wait-k with tags
Browse files Browse the repository at this point in the history
# Why is the change needed?
Currently, the wait-k agent does not support the tag-prediction module of the parallel model introduced in the paper ["Joint Speech Translation and Named Entity Recognition"](https://arxiv.org/pdf/2210.11987.pdf).

# What changes does the patch introduce?
Implements the wait-k inference with tags produced by the parallel model.

# How was this patch tested?
Unit tests (UTs) and manual runs.
  • Loading branch information
sarapapi authored and mgaido91 committed Sep 27, 2023
1 parent 4b75f39 commit 9892515
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _generate(
prefix_tokens: Optional[Tensor] = None,
constraints: Optional[Tensor] = None,
bos_token: Optional[int] = None,
pre_computed_encoder_outs: Optional[Tensor] = None,
):
incremental_states = torch.jit.annotate(
List[Dict[str, Dict[str, Optional[Tensor]]]],
Expand Down Expand Up @@ -147,7 +148,10 @@ def _generate(
self.min_len <= max_len
), "min_len cannot be larger than max_len, please adjust these!"
# compute the encoder output for each beam
encoder_outs = self.model.forward_encoder(net_input)
if pre_computed_encoder_outs is not None:
encoder_outs = pre_computed_encoder_outs
else:
encoder_outs = self.model.forward_encoder(net_input)

# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2023 FBK

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import torch
from examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk import WaitkAgent

try:
from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS
from simuleval.agents import SpeechAgent
from simuleval.states import ListEntry, SpeechStates
except ImportError:
print("Please install simuleval 'pip install simuleval'")


class WaitkAgentWithTags(WaitkAgent):
    """
    Wait-k agent that, in addition to the output tokens, emits the tags
    predicted for them.

    Each generated token comes with a tag id: ``0`` means "no tag", while a
    positive id ``t`` refers to ``self.tags[t - 1]``. Whenever the tag id
    changes between consecutive tokens, the corresponding ``</TAG>`` and/or
    ``<TAG>`` symbols (looked up in ``self.tgtdict``) are inserted into the
    tokens to be written.
    """

    def load_model_vocab(self, args):
        """Load model and vocabulary as the parent does, then read the tag list from the task data config."""
        super().load_model_vocab(args)
        self.tags = self.task.data_cfg.tags

    def initialize_states(self, states):
        """Initialize the parent states and add the tag-related ones."""
        super().initialize_states(states)
        # Store previous output tokens without considering emitted tags
        states.prev_toks = []
        # Tag id of the last emitted token (0 means that no tag is currently open)
        states.prev_tag = 0

    def _get_prefix(self, states):
        """
        Build the prefix to force during generation.

        The prefix is made of ``self.prefix_token_idx`` (if not None) followed
        by the previously emitted tokens stored in ``states.prev_toks``
        (which do not include tag symbols).

        Returns a ``1 x prefix_length`` int64 tensor, or None if there is
        neither a prefix token nor any previously emitted token.
        """
        forced_prefix = None
        if self.prefix_token_idx is not None:
            forced_prefix = torch.LongTensor([[self.prefix_token_idx]])
        if not states.prev_toks:
            return forced_prefix
        previous_tokens = torch.tensor([states.prev_toks], dtype=torch.int64)
        if forced_prefix is None:
            return previous_tokens
        return torch.cat((forced_prefix, previous_tokens), dim=1)

    def add_tags_to_target(self, states, hypo_tag):
        """
        Rewrite ``states.write`` interleaving the tag symbols with the tokens.

        Tokens in ``states.write`` are paired with the tag ids in *hypo_tag*.
        When the tag id of a token differs from the previous one
        (``states.prev_tag``), the closing symbol of the previous tag (if any)
        and the opening symbol of the new tag (if any) are emitted before the
        token. ``states.prev_tag`` is updated to the tag id of the last
        processed token.
        """
        hypo_tok = states.write
        states.write = []

        def tag_symbol(symbol, dtype):
            # Dictionary index of the tag symbol, with the same dtype as the tokens
            return torch.tensor(self.tgtdict.index(symbol), dtype=dtype)

        for token, tag in zip(hypo_tok, hypo_tag):
            if tag != states.prev_tag:
                # Close the tag that was open, if any...
                if states.prev_tag != 0:
                    states.write.append(
                        tag_symbol(f"</{self.tags[states.prev_tag - 1]}>", token.dtype))
                # ...and open the new one, if any
                if tag != 0:
                    states.write.append(
                        tag_symbol(f"<{self.tags[tag - 1]}>", token.dtype))
            states.write.append(token)
            states.prev_tag = tag

    def new_hypo(self, states):
        """
        Generate a new hypothesis and return its newly generated part.

        Returns a pair ``(tokens, tags)`` taken from the ``'tokens'`` and
        ``'tags'`` entries of the generated hypothesis, both moved to CPU,
        cast to int and stripped of the first ``prefix_len`` positions.
        """
        states.new_segment = False
        prefix_tokens = self._get_prefix(states)
        prefix_len = self._get_prefix_len(prefix_tokens)
        hypo = self.generate_hypothesis(states, prefix_tokens)
        new_hypo_tokens = hypo['tokens'].int().cpu()[prefix_len:]
        new_hypo_tags = hypo['tags'].int().cpu()[prefix_len:]
        return new_hypo_tokens, new_hypo_tags

    def waitk_prediction(self, states):
        """
        Select the words to emit according to the wait-k policy.

        The number of words to select is the number of audio words minus the
        already predicted words and ``self.waitk``. The selected tokens are
        stored in ``states.write`` (with tag symbols, when needed) and
        appended, without tag symbols, to ``states.prev_toks``.

        Returns True if there is something to write, False otherwise.
        """
        new_hypo, new_tags = self.new_hypo(states)
        selected_n_words = states.n_audio_words - (states.n_predicted_words + self.waitk)
        states.n_predicted_words += selected_n_words
        states.write = self._select_words(new_hypo, selected_n_words)
        if not states.write:
            return False
        # prev_toks is updated before tag symbols are added to states.write
        states.prev_toks += states.write
        new_tags = new_tags[:len(states.write)]
        # Tags have to be processed also when none is predicted but one is
        # still open, so that it gets closed
        if states.prev_tag != 0 or (new_tags != 0).any():
            self.add_tags_to_target(states, new_tags)
        return True

    def _emit_remaining_tokens(self, states):
        """
        Emit all the remaining tokens of a newly generated hypothesis,
        adding tag symbols when needed, and return ``WRITE_ACTION``.
        """
        # NOTE(review): if the last token has a non-zero tag, no closing
        # symbol is emitted after it - confirm that this cannot happen
        final_hypo, final_tags = self.new_hypo(states)
        states.write = final_hypo
        if states.prev_tag != 0 or (final_tags != 0).any():
            self.add_tags_to_target(states, final_tags)
        return WRITE_ACTION
2 changes: 1 addition & 1 deletion fbk_uts/simultaneous/test_base_simulst_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def initialize_agent(agent, args):
agent.feature_extractor = OnlineFeatureExtractor(args)
agent.eos = "<s>"
agent.eos_idx = 0
agent.prefix_token_idx = 0
agent.prefix_token_idx = None
agent.tgtdict = Dictionary()
agent.tgtdict.add_symbol(BOW_PREFIX + "I")
agent.tgtdict.add_symbol(BOW_PREFIX + "am")
Expand Down
121 changes: 121 additions & 0 deletions fbk_uts/simultaneous/test_waitk_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2023 FBK

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import unittest
from unittest.mock import patch
import copy

import torch

from examples.speech_to_text.simultaneous_translation.agents.base_simulst_agent import BOW_PREFIX
from examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags import WaitkAgentWithTags

from fbk_uts.simultaneous.test_base_simulst_agent import BaseSTAgentTestCase


class WaitkSimulSTWithTagsTestCase(BaseSTAgentTestCase, unittest.TestCase):
    """Unit tests for the wait-k agent emitting tags (``WaitkAgentWithTags``)."""

    def add_extra_args(self):
        self.args.waitk = 0
        self.args.parallel = False

    def create_agent(self):
        return WaitkAgentWithTags(self.args)

    # Patch decorators are applied bottom-up: the mock of the decorator closest
    # to the function (FairseqSimulSTAgent.__init__) is the first argument.
    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.load_model_vocab')
    @patch('examples.speech_to_text.simultaneous_translation.agents.base_simulst_agent.'
           'FairseqSimulSTAgent.__init__')
    def setUp(self, mock_simulst_agent_init, mock_load_model_vocab):
        mock_simulst_agent_init.return_value = None
        mock_load_model_vocab.return_value = None
        self.base_init()
        self.hypo = BOW_PREFIX + "quokka " + BOW_PREFIX + "is " + BOW_PREFIX + "pretty ."
        self.agent.tgtdict.add_symbol(BOW_PREFIX + "is")
        self.agent.tgtdict.add_symbol(BOW_PREFIX + "pretty")
        self.agent.tgtdict.add_symbol("<PERSON>")
        self.agent.tgtdict.add_symbol("</PERSON>")
        # The agent maps a tag id t to tags[t - 1]: the list is padded with
        # empty names so that "PERSON" is its 11th entry (tag id 11), which is
        # presumably also the dictionary index of "<PERSON>" used below as tag id
        self.agent.tags = ["", "", "", "", "", "", "", "", "", "", "PERSON"]
        self.encoded_hypo = self.agent.tgtdict.encode_line(self.hypo, add_if_not_exist=False)
        # Only the first token ("quokka") is tagged as PERSON
        self.predicted_tags = torch.tensor([self.agent.tgtdict.index("<PERSON>"), 0, 0, 0, 0])
        self.states.n_audio_words = 3

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_full_hypo(self, mock_new_hypo):
        # Full hypothesis emitted
        mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags
        self.states.n_predicted_words = 0
        WaitkAgentWithTags.waitk_prediction(self.agent, self.states)
        self.assertEqual(self.states.write, [11, 7, 12, 9, 10, 8, 2])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_wait_1(self, mock_new_hypo):
        # Partial hypothesis emitted
        mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags
        self.states.n_predicted_words = 0
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 1
        WaitkAgentWithTags.waitk_prediction(new_agent, self.states)
        self.assertEqual(self.states.write, [11, 7, 12, 9])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_wait_1_predicted_1(self, mock_new_hypo):
        # Partial hypothesis emitted considering already predicted words
        mock_new_hypo.return_value = self.encoded_hypo[1:], self.predicted_tags[1:]
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 1
        self.states.n_predicted_words = 1
        WaitkAgentWithTags.waitk_prediction(new_agent, self.states)
        self.assertEqual(self.states.write, [9])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_wait_3(self, mock_new_hypo):
        # No hypothesis emitted
        mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 3
        self.states.n_predicted_words = 0
        WaitkAgentWithTags.waitk_prediction(new_agent, self.states)
        self.assertEqual(self.states.write, [])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_emit_remaining_tokens_with_tags(self, mock_new_hypo):
        # Tag on the first word: the whole hypothesis is emitted regardless of wait-k
        mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 3
        self.states.n_predicted_words = 0
        WaitkAgentWithTags._emit_remaining_tokens(new_agent, self.states)
        self.assertEqual(self.states.write, [11, 7, 12, 9, 10, 8, 2])

        # Move tag towards the end (last word: "pretty")
        mock_new_hypo.return_value = self.encoded_hypo, torch.tensor(
            [0, 0, self.agent.tgtdict.index("<PERSON>"), 0, 0])
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 3
        self.states.n_predicted_words = 0
        WaitkAgentWithTags._emit_remaining_tokens(new_agent, self.states)
        self.assertEqual(self.states.write, [7, 9, 11, 10, 12, 8, 2])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags._emit_remaining_tokens')
    def test_finish_read(self, mock_emit_remaining_tokens):
        # "return_value" (singular) is the attribute that configures the mock
        mock_emit_remaining_tokens.return_value = None
        super().test_finish_read()


if __name__ == '__main__':
unittest.main()

0 comments on commit 9892515

Please sign in to comment.