
training code for hybrid-autoregressive inference model
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
Hainan Xu committed Oct 10, 2024
1 parent 861f805 commit 6254186
Showing 1 changed file with 16 additions and 0 deletions.
nemo/collections/asr/modules/rnnt.py
@@ -1249,6 +1249,10 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin)
fused_batch_size: Optional int, required if `fuse_loss_wer` flag is set. Determines the size of the
sub-batches. Should be any value below the actual batch size per GPU.
masking_prob: Optional float, the probability of masking out the decoder output in the HAINAN
(Hybrid Autoregressive Inference Transducer) model, described in https://arxiv.org/pdf/2410.02597.
Defaults to -1.0, which runs the standard Joint network computation; if > 0, the decoder output is
masked out with the specified probability during training.
"""

@property
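
To make the new argument concrete, here is a minimal usage sketch. Only masking_prob comes from this commit; the jointnet keys and the numeric values are illustrative assumptions about how RNNTJoint is typically configured, not values shown in this diff.

# Illustrative sketch: every value below except masking_prob is an assumption,
# not something shown in this commit.
from nemo.collections.asr.modules.rnnt import RNNTJoint

joint = RNNTJoint(
    jointnet={
        "encoder_hidden": 512,  # assumed encoder output size
        "pred_hidden": 320,     # assumed prediction-network output size
        "joint_hidden": 640,    # assumed joint hidden size
        "activation": "relu",
    },
    num_classes=1024,           # assumed vocabulary size; blank is added internally
    masking_prob=0.5,           # HAINAN masking: drop the decoder output half the time during training
)

With the default masking_prob of -1.0, the masking branch is never taken and the joint behaves exactly as before.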
@@ -1313,6 +1317,7 @@ def __init__(
fuse_loss_wer: bool = False,
fused_batch_size: Optional[int] = None,
experimental_fuse_loss_wer: Any = None,
masking_prob: float = -1.0,
):
super().__init__()

@@ -1322,6 +1327,10 @@
self._num_extra_outputs = num_extra_outputs
self._num_classes = num_classes + 1 + num_extra_outputs # 1 is for blank

self.masking_prob = masking_prob
if self.masking_prob > 0.0:
assert self.masking_prob < 1.0, "masking_prob must be between 0 and 1"

if experimental_fuse_loss_wer is not None:
# Override fuse_loss_wer from deprecated argument
fuse_loss_wer = experimental_fuse_loss_wer
@@ -1578,6 +1587,13 @@ def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
"""
f = f.unsqueeze(dim=2) # (B, T, 1, H)
g = g.unsqueeze(dim=1) # (B, 1, U, H)

if self.training and self.masking_prob > 0:
[B, _, U, _] = g.shape
rand = torch.rand([B, 1, U, 1]).to(g.device)
rand = torch.gt(rand, self.masking_prob)
g = g * rand

inp = f + g # [B, T, U, H]

del f, g
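
To see the new masking step in isolation, here is a small self-contained sketch that mirrors the logic added to joint_after_projection: during training, each label position (the U dimension) of the projected decoder output g is zeroed out with probability masking_prob, broadcast across the time and hidden dimensions, before being added to the projected encoder output f. The tensor shapes below are made up for illustration.

import torch

# Made-up sizes for illustration: batch B, time T, label length U, hidden H.
B, T, U, H = 2, 5, 3, 4
masking_prob = 0.5

f = torch.randn(B, T, H).unsqueeze(dim=2)  # (B, T, 1, H) projected encoder output
g = torch.randn(B, U, H).unsqueeze(dim=1)  # (B, 1, U, H) projected decoder output

# Same masking as the commit: one uniform draw per (batch, label) position;
# the decoder output is kept only where the draw exceeds masking_prob.
rand = torch.rand([B, 1, U, 1]).to(g.device)
keep = torch.gt(rand, masking_prob)  # bool mask, broadcast over T and H
g = g * keep                         # masked positions contribute zeros

inp = f + g  # (B, T, U, H) joint network input
print(inp.shape)  # torch.Size([2, 5, 3, 4])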
