
Commit

add decoding args
Hainan Xu committed Sep 3, 2024
1 parent cd2eab6 commit b516ce3
Showing 4 changed files with 32 additions and 305 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -317,7 +317,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.cfg.decoding = self.set_decoding_type_according_to_loss(self.cfg.decoding)
# Setup decoding object
self.decoding = RNNTBPEDecoding(
-            decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, decoder2=self.decoder2, joint2=self.joint2,
+            decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
)

# Setup wer object
@@ -465,7 +465,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
decoding_cfg = self.set_decoding_type_according_to_loss(decoding_cfg)

self.decoding = RNNTBPEDecoding(
-            decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, decoder2=self.decoder2, joint2=self.joint2,
+            decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer,
)

self.wer = WER(
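For context, a minimal usage sketch (not part of the commit) of the simplified call path: after this change, change_decoding_strategy builds RNNTBPEDecoding from only the decoder, joint, and tokenizer. The checkpoint name and config values below are illustrative assumptions, not taken from the diff.

from omegaconf import OmegaConf
from nemo.collections.asr.models import EncDecRNNTBPEModel

# Illustrative checkpoint; any RNNT/TDT BPE checkpoint would do.
model = EncDecRNNTBPEModel.from_pretrained("stt_en_conformer_transducer_large")

decoding_cfg = OmegaConf.create(
    {
        "strategy": "greedy_batch",     # batched greedy decoding
        "greedy": {"max_symbols": 10},  # cap on symbols emitted per encoder frame
    }
)
model.change_decoding_strategy(decoding_cfg)  # rebuilds RNNTBPEDecoding without decoder2/joint2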
120 changes: 10 additions & 110 deletions nemo/collections/asr/models/rnnt_models.py
@@ -68,35 +68,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Update config values required by components dynamically
with open_dict(self.cfg.decoder):
self.cfg.decoder.vocab_size = len(self.cfg.labels)
-        if 'decoder2' in self.cfg:
-            with open_dict(self.cfg.decoder2):
-                self.cfg.decoder2.vocab_size = len(self.cfg.labels)

with open_dict(self.cfg.joint):
self.cfg.joint.num_classes = len(self.cfg.labels)
self.cfg.joint.vocabulary = self.cfg.labels
self.cfg.joint.jointnet.encoder_hidden = self.cfg.model_defaults.enc_hidden
self.cfg.joint.jointnet.pred_hidden = self.cfg.model_defaults.pred_hidden
-        if 'joint2' in self.cfg:
-            with open_dict(self.cfg.joint2):
-                self.cfg.joint2.num_classes = len(self.cfg.labels)
-                self.cfg.joint2.vocabulary = self.cfg.labels
-                self.cfg.joint2.jointnet.encoder_hidden = self.cfg.model_defaults.enc_hidden
-                self.cfg.joint2.jointnet.pred_hidden = self.cfg.model_defaults.pred_hidden

self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)

-        self.decoder2 = None
-        self.joint2 = None
-        self.run_backward = False
-        if 'decoder2' in self.cfg:
-            assert 'joint2' in self.cfg
-            assert not self.joint.fuse_loss_wer
-            self.decoder2 = EncDecRNNTModel.from_config_dict(self.cfg.decoder2)
-            self.joint2 = EncDecRNNTModel.from_config_dict(self.cfg.joint2)
-            self.run_backward = True

# Setup RNNT Loss
loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None))

@@ -119,22 +100,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

self.cfg.decoding = self.set_decoding_type_according_to_loss(self.cfg.decoding)
# Setup decoding objects
-        if not self.run_backward:
-            self.decoding = RNNTDecoding(
-                decoding_cfg=self.cfg.decoding,
-                decoder=self.decoder,
-                joint=self.joint,
-                vocabulary=self.joint.vocabulary,
-            )
-        else:
-            self.decoding = RNNTDecoding(
-                decoding_cfg=self.cfg.decoding,
-                decoder=self.decoder,
-                decoder2=self.decoder2,
-                joint=self.joint,
-                joint2=self.joint2,
-                vocabulary=self.joint.vocabulary,
-            )
+        self.decoding = RNNTDecoding(
+            decoding_cfg=self.cfg.decoding,
+            decoder=self.decoder,
+            joint=self.joint,
+            vocabulary=self.joint.vocabulary,
+        )
# Setup WER calculation
self.wer = WER(
decoding=self.decoding,
@@ -716,19 +687,6 @@ def training_step(self, batch, batch_nb):

# During training, loss must be computed, so decoder forward is necessary
decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len)
-        if self.run_backward:
-            signal_len_list = signal_len.tolist()
-            transcript_len_list = transcript_len.tolist()
-
-            encoded2 = encoded * 1.0
-            for i in range(len(signal_len_list)):
-                encoded2[i, :, :encoded_len[i]] = torch.flip(encoded2[i, :, :encoded_len[i]], dims=(-1,))
-
-            transcript2 = transcript * 1
-            for i in range(len(transcript_len_list)):
-                transcript2[i, :transcript_len_list[i]] = torch.flip(transcript2[i, :transcript_len_list[i]], dims=(-1,))
-
-            decoder2, target_length2, states2 = self.decoder2(targets=transcript2, target_length=transcript_len)

if hasattr(self, '_trainer') and self._trainer is not None:
log_every_n_steps = self._trainer.log_every_n_steps
@@ -739,68 +697,10 @@

# If experimental fused Joint-Loss-WER is not used
if not self.joint.fuse_loss_wer:
-            if not self.run_backward:
-                joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
-                loss_value = self.loss(
-                    log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length
-                )
-            else:
-                if random.uniform(0.0, 1.0) < 0.5:
-                    rand = random.uniform(0.0, 1.0)
-                    joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
-
-                    joint2 = self.joint2(encoder_outputs=encoded2, decoder_outputs=decoder2)
-
-                    # joint2 = torch.flip(joint2, dims=(1,))
-                    # joint[:,:-1,:,:self.joint.num_extra_outputs] = (joint[:,:-1,:,:self.joint.num_extra_outputs] + joint2[:,1:,:,:self.joint.num_extra_outputs]) / 2.0
-                    # joint2[:,1:,:,:self.joint.num_extra_outputs] = joint[:,:-1,:,:self.joint.num_extra_outputs]
-                    # joint2 = torch.flip(joint2, dims=(1,))
-
-                    forward_loss_value = self.loss(
-                        log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length
-                    )
-
-                    backward_loss_value = self.loss(
-                        log_probs=joint2, targets=transcript2, input_lengths=encoded_len, target_lengths=target_length
-                    )
-
-                    loss_value = rand * forward_loss_value + backward_loss_value * (1 - rand)
-                else:
-                    rand = random.uniform(0.0, 1.0)
-                    joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
-
-                    joint2 = self.joint2(encoder_outputs=encoded2, decoder_outputs=decoder2)
-
-                    joint2 = torch.flip(joint2, dims=(1,))
-                    joint[:, :-1, :, :self.joint.num_extra_outputs] = (joint[:, :-1, :, :self.joint.num_extra_outputs] + joint2[:, 1:, :, :self.joint.num_extra_outputs]) / 2.0
-                    joint2[:, 1:, :, :self.joint.num_extra_outputs] = joint[:, :-1, :, :self.joint.num_extra_outputs]
-                    joint2 = torch.flip(joint2, dims=(1,))
-
-                    forward_loss_value = self.loss(
-                        log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length
-                    )
-
-                    backward_loss_value = self.loss(
-                        log_probs=joint2, targets=transcript2, input_lengths=encoded_len, target_lengths=target_length
-                    )
-
-                    loss_value = rand * forward_loss_value + backward_loss_value * (1 - rand)
-
-            # joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
-            # B, T, U, _ = joint.shape
-            # joint2 = self.joint2(encoder_outputs=encoded2, decoder_outputs=decoder2)
-            # joint2 = torch.flip(joint2, dims=(1,))
-            # rand2 = torch.rand([B, T - 1, U, 1]).to(joint.device)
-            # rand2 = torch.gt(rand2, 0.5)
-            # if random.uniform(0.0, 1.0) < 0.5:
-            #     joint[:,:-1,:,:] = joint[:,:-1,:,:] + joint2[:,1:,:,:] * rand2
-            # else:
-            #     joint[:,:-1,:,:] = joint[:,:-1,:,:] + joint2[:,1:,:,:]
-            # loss_value = self.loss(
-            #     log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length
-            # )
-
-
+            joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
+            loss_value = self.loss(
+                log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length
+            )

# Add auxiliary losses, if registered
loss_value = self.add_auxiliary_losses(loss_value)
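The deleted branch above trained a second, right-to-left decoder: it reversed the encoder output and the transcript over each sequence's valid length (leaving padding untouched), computed a second transducer loss, and mixed the two losses with a random weight. Below is a self-contained sketch of that per-sequence reversal, reconstructed from the removed lines; flip_padded is a hypothetical helper name, not NeMo API.

import torch

def flip_padded(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Reverse each sequence in a padded batch along its last axis, up to its valid length."""
    out = x.clone()
    for i, length in enumerate(lengths.tolist()):
        valid = out[i, ..., :length]  # unpadded region of sample i
        out[i, ..., :length] = torch.flip(valid, dims=(-1,))
    return out

# Mirrors the removed training_step logic:
#   encoded2 = flip_padded(encoded, encoded_len)              # (B, D, T) frames, time reversed
#   transcript2 = flip_padded(transcript, transcript_len)     # (B, U) token ids, order reversed
#   loss = rand * forward_loss + (1 - rand) * backward_loss   # random convex mix of both losses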
15 changes: 6 additions & 9 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
@@ -193,7 +193,7 @@ class AbstractRNNTDecoding(ConfidenceMixin):
blank_id: The id of the RNNT blank token.
"""

-    def __init__(self, decoding_cfg, decoder, joint, blank_id: int, decoder2, joint2):
+    def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
super(AbstractRNNTDecoding, self).__init__()

# Convert dataclass to config object
@@ -205,6 +205,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, decoder2, joint2):
self.num_extra_outputs = joint.num_extra_outputs
self.big_blank_durations = self.cfg.get("big_blank_durations", None)
self.durations = self.cfg.get("durations", None)
+        self.decoding_type = self.cfg.get("decoding_type", None)
+
self.compute_hypothesis_token_set = self.cfg.get("compute_hypothesis_token_set", False)
self.compute_langs = decoding_cfg.get('compute_langs', False)
self.preserve_alignments = self.cfg.get('preserve_alignments', None)
@@ -291,10 +293,9 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, decoder2, joint2):
self.decoding = rnnt_greedy_decoding.GreedyTDTInfer(
decoder_model=decoder,
joint_model=joint,
-                decoder_model2=decoder2,
-                joint_model2=joint2,
blank_index=self.blank_id,
durations=self.durations,
+                decoding_type=self.decoding_type,
max_symbols_per_step=(
self.cfg.greedy.get('max_symbols', None)
or self.cfg.greedy.get('max_symbols_per_step', None)
@@ -1182,8 +1183,6 @@ def __init__(
decoder,
joint,
vocabulary,
-        decoder2=None,
-        joint2=None,
):
# we need to ensure blank is the last token in the vocab for the case of RNNT and Multi-blank RNNT.
blank_id = len(vocabulary) + joint.num_extra_outputs
@@ -1198,8 +1197,6 @@ def __init__(
decoder=decoder,
joint=joint,
blank_id=blank_id,
-            decoder2=decoder2,
-            joint2=joint2,
)

if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer):
@@ -1453,7 +1450,7 @@ class RNNTBPEDecoding(AbstractRNNTDecoding):
tokenizer: The tokenizer which will be used for decoding.
"""

-    def __init__(self, decoding_cfg, decoder, joint, decoder2, joint2, tokenizer: TokenizerSpec):
+    def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec):
blank_id = tokenizer.tokenizer.vocab_size # RNNT or TDT models.

# multi-blank RNNTs
@@ -1463,7 +1460,7 @@ def __init__(self, decoding_cfg, decoder, joint, decoder2, joint2, tokenizer: TokenizerSpec):
self.tokenizer = tokenizer

super(RNNTBPEDecoding, self).__init__(
-            decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, decoder2=decoder2, joint2=joint2, blank_id=blank_id
+            decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id
)

if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer):
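The one addition matching the commit title is decoding_type: it is read from the decoding config (defaulting to None) and forwarded to GreedyTDTInfer. Below is a hedged sketch of a TDT decoding config carrying the new key; the concrete values are assumptions, and the accepted decoding_type values live in the greedy-decoding implementation (the fourth changed file, not shown).

from omegaconf import OmegaConf

decoding_cfg = OmegaConf.create(
    {
        "strategy": "greedy",
        "durations": [0, 1, 2, 3, 4],  # illustrative TDT duration outputs
        "decoding_type": None,         # new key from this commit; None keeps prior behavior
        "greedy": {"max_symbols": 10},
    }
)
# model.change_decoding_strategy(decoding_cfg)  # applied as in the earlier sketch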
[Diff for the fourth changed file not shown.]
