From 62541863119482f4f6c2e597305d99723082d954 Mon Sep 17 00:00:00 2001
From: Hainan Xu
Date: Thu, 10 Oct 2024 16:28:30 -0400
Subject: [PATCH] training code for hybrid-autoregressive inference model

Signed-off-by: Hainan Xu
---
 nemo/collections/asr/modules/rnnt.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py
index 2355cfb7005b..cc6ec0452ece 100644
--- a/nemo/collections/asr/modules/rnnt.py
+++ b/nemo/collections/asr/modules/rnnt.py
@@ -1249,6 +1249,10 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin)
         fused_batch_size: Optional int, required if `fuse_loss_wer` flag is set. Determines the size of the
             sub-batches. Should be any value below the actual batch size per GPU.
+        masking_prob: Optional float, indicating the probability of masking out the decoder output in the HAINAN
+            (Hybrid Autoregressive Inference Transducer) model, described in https://arxiv.org/pdf/2410.02597.
+            Defaults to -1.0, which runs the standard Joint network computation; if > 0, the decoder output is
+            masked out with the specified probability.
     """
 
     @property
@@ -1313,6 +1317,7 @@ def __init__(
         fuse_loss_wer: bool = False,
         fused_batch_size: Optional[int] = None,
         experimental_fuse_loss_wer: Any = None,
+        masking_prob: float = -1.0,
     ):
         super().__init__()
 
@@ -1322,6 +1327,10 @@ def __init__(
         self._num_extra_outputs = num_extra_outputs
         self._num_classes = num_classes + 1 + num_extra_outputs  # 1 is for blank
 
+        self.masking_prob = masking_prob
+        if self.masking_prob > 0.0:
+            assert self.masking_prob < 1.0, "masking_prob must be between 0 and 1"
+
         if experimental_fuse_loss_wer is not None:
             # Override fuse_loss_wer from deprecated argument
             fuse_loss_wer = experimental_fuse_loss_wer
@@ -1578,6 +1587,13 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tens
         """
         f = f.unsqueeze(dim=2)  # (B, T, 1, H)
         g = g.unsqueeze(dim=1)  # (B, 1, U, H)
+
+        if self.training and self.masking_prob > 0:
+            [B, _, U, _] = g.shape
+            rand = torch.rand([B, 1, U, 1]).to(g.device)
+            rand = torch.gt(rand, self.masking_prob)
+            g = g * rand
+
         inp = f + g  # [B, T, U, H]
 
         del f, g
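
Note (not part of the patch): a minimal standalone sketch of the masking step added to joint_after_projection above, for illustration only. During training, the projected decoder (prediction-network) output g is zeroed out per (batch, label-position) with probability masking_prob, so the joint network sometimes sees only the encoder contribution. The function name mask_decoder_output is hypothetical and does not exist in NeMo.

import torch

def mask_decoder_output(g: torch.Tensor, masking_prob: float, training: bool = True) -> torch.Tensor:
    # g: projected decoder output of shape (B, 1, U, H), as in the patch above.
    if not training or masking_prob <= 0:
        # masking_prob defaults to -1.0, i.e. standard joint computation.
        return g
    B, _, U, _ = g.shape
    # One Bernoulli draw per (batch, label-position); broadcast over the T and H dimensions.
    keep = torch.gt(torch.rand([B, 1, U, 1], device=g.device), masking_prob)
    return g * keep

# Example: with masking_prob=0.5, roughly half of the U positions in each batch element
# contribute no decoder information to the joint network during training.
g = torch.randn(2, 1, 5, 8)
g_masked = mask_decoder_output(g, masking_prob=0.5)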