Commit: working

Hainan Xu committed Mar 11, 2024
1 parent a009c91 commit 8c54a66
Showing 3 changed files with 64 additions and 55 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/asr/data/audio_to_text.py
@@ -1460,7 +1460,7 @@ def __init__(
super().__init__()

def _collate_fn(self, batch):
-return _speech_collate_fn(batch[0], self.wrapped_dataset.pad_id)
+return _speech_collate_fn(batch[0], self.wrapped_dataset.asr_pad_id, self.wrapped_dataset.st_pad_id)

def __iter__(self):
return BucketingIterator(
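The change above threads two pad ids through the bucketing collate, one per output stream. As a rough sketch of what a dual-transcript collate can look like (illustrative names and tuple layout, not NeMo's actual _speech_collate_fn internals), assuming each batch item is an (audio, audio_len, asr_tokens, st_tokens) tuple:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def collate_dual_transcripts(batch, asr_pad_id, st_pad_id):
        # Pad audio to the longest signal in the batch.
        audio = pad_sequence([b[0] for b in batch], batch_first=True)
        audio_lens = torch.tensor([b[1] for b in batch])
        # Each label stream is padded with its own tokenizer's pad id,
        # mirroring the asr_pad_id / st_pad_id split introduced above.
        asr = pad_sequence([b[2] for b in batch], batch_first=True, padding_value=asr_pad_id)
        asr_lens = torch.tensor([len(b[2]) for b in batch])
        st = pad_sequence([b[3] for b in batch], batch_first=True, padding_value=st_pad_id)
        st_lens = torch.tensor([len(b[3]) for b in batch])
        return audio, audio_lens, asr, asr_lens, st, st_lens

Separate pad ids matter because the ASR and ST tokenizers generally have different vocabularies, so a single shared pad id is not guaranteed to be valid in both.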
30 changes: 13 additions & 17 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -285,7 +285,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
cfg = model_utils.maybe_update_config_version(cfg)

# Tokenizer is necessary for this model
-if 'asr_tokenizer' not in cfg 'tokenizer' not in cfg:
+if 'asr_tokenizer' not in cfg and 'tokenizer' not in cfg:
raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !")

if not isinstance(cfg, DictConfig):
@@ -299,9 +299,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
else:
assert(0)

-# Initialize a dummy vocabulary
-# vocabulary = self.tokenizer.tokenizer.get_vocab()
-# print("self.tokenizer.tokenizer is", self.tokenizer.tokenizer)
if hasattr(self.asr_tokenizer.tokenizer, 'vocab') and callable(self.asr_tokenizer.tokenizer.vocab):
asr_vocabulary = self.asr_tokenizer.tokenizer.vocab()
else:
@@ -311,7 +308,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
st_vocabulary = self.st_tokenizer.tokenizer.vocab()
else:
st_vocabulary = self.st_tokenizer.tokenizer.get_vocab()
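The two lookups above share a duck-typing pattern: SentencePiece-style tokenizers usually expose vocab() as a callable method, while HuggingFace-style tokenizers expose get_vocab(). A small helper (a hypothetical refactor, not part of this commit) would fold the repeated branches into one place:

    def extract_vocab(tokenizer_wrapper):
        # tokenizer_wrapper.tokenizer is the underlying library tokenizer object.
        inner = tokenizer_wrapper.tokenizer
        if hasattr(inner, 'vocab') and callable(inner.vocab):
            return inner.vocab()  # SentencePiece-style interface
        return inner.get_vocab()  # HuggingFace-style interface

with which both assignments reduce to asr_vocabulary = extract_vocab(self.asr_tokenizer) and st_vocabulary = extract_vocab(self.st_tokenizer).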


# Set the new vocabulary
with open_dict(cfg):
@@ -334,19 +330,19 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.asr_decoding = RNNTBPEDecoding(
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.asr_tokenizer,
)
-if self.st_encoder != None:

-    self.st_decoding = RNNTBPEDecoding(
-        decoding_cfg=self.cfg.decoding, decoder=self.st_decoder, joint=self.st_joint, tokenizer=self.st_tokenizer,
-    )
+self.st_decoding = RNNTBPEDecoding(
+    decoding_cfg=self.cfg.decoding, decoder=self.st_decoder, joint=self.st_joint, tokenizer=self.st_tokenizer,
+)

-# Setup wer object
-self.bleu = BLEU(
-    decoding=self.st_decoding,
-    batch_dim_index=0,
-    # use_cer=self._cfg.get('use_cer', False),
-    log_prediction=self._cfg.get('log_prediction', True),
-    dist_sync_on_step=True,
-)
+# Setup wer object
+self.bleu = BLEU(
+    decoding=self.st_decoding,
+    batch_dim_index=0,
+    log_prediction=self._cfg.get('log_prediction', True),
+    dist_sync_on_step=True,
+)

self.wer = WER(
decoding=self.asr_decoding,
@@ -358,7 +354,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

# Setup fused Joint step if flag is set
if self.joint.fuse_loss_wer:
-assert(0)
+pass

def change_vocabulary(
self,
87 changes: 50 additions & 37 deletions nemo/collections/asr/models/rnnt_models.py
@@ -62,7 +62,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Initialize components
self.preprocessor = EncDecRNNTModel.from_config_dict(self.cfg.preprocessor)
self.encoder = EncDecRNNTModel.from_config_dict(self.cfg.encoder)
-self.st_encoder = EncDecRNNTModel.from_config_dict(self.cfg.st_encoder)
+if 'st_encoder' in self.cfg:
+    self.st_encoder = EncDecRNNTModel.from_config_dict(self.cfg.st_encoder)
+else:
+    self.st_encoder = None

# Update config values required by components dynamically
with open_dict(self.cfg.decoder):
@@ -77,24 +80,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)

-with open_dict(self.cfg.decoder):
-    self.cfg.decoder.vocab_size = len(self.cfg.st_labels)
-with open_dict(self.cfg.joint):
-    self.cfg.joint.num_classes = len(self.cfg.st_labels)
-    self.cfg.joint.vocabulary = self.cfg.st_labels
-
-self.st_decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
-self.st_joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)

# Setup RNNT Loss
loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None))

asr_num_classes = self.joint.num_classes_with_blank - 1 # for standard RNNT and multi-blank
-st_num_classes = self.st_joint.num_classes_with_blank - 1  # for standard RNNT and multi-blank

if loss_name == 'tdt':
asr_num_classes = asr_num_classes - self.joint.num_extra_outputs
-    st_num_classes = st_num_classes - self.joint.num_extra_outputs

self.asr_loss = RNNTLoss(
num_classes=asr_num_classes,
@@ -103,13 +95,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
reduction=self.cfg.get("rnnt_reduction", "mean_batch"),
)

-self.st_loss = RNNTLoss(
-    num_classes=st_num_classes,
-    loss_name=loss_name,
-    loss_kwargs=loss_kwargs,
-    reduction=self.cfg.get("rnnt_reduction", "mean_batch"),
-)

if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None:
self.spec_augmentation = EncDecRNNTModel.from_config_dict(self.cfg.spec_augment)
else:
@@ -120,9 +105,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self.decoding = RNNTDecoding(
decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary,
)
-self.st_decoding = RNNTDecoding(
-    decoding_cfg=self.cfg.decoding, decoder=self.st_decoder, joint=self.st_joint, vocabulary=self.st_joint.vocabulary,
-)
# Setup WER calculation
self.wer = WER(
decoding=self.decoding,
@@ -131,12 +113,39 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
log_prediction=self._cfg.get('log_prediction', True),
dist_sync_on_step=True,
)
-self.bleu = BLEU(
-    decoding=self.st_decoding,
-    batch_dim_index=0,
-    log_prediction=self._cfg.get('log_prediction', True),
-    dist_sync_on_step=True,
-)

+if self.st_encoder is not None:
+    with open_dict(self.cfg.decoder):
+        self.cfg.decoder.vocab_size = len(self.cfg.st_labels)
+    with open_dict(self.cfg.joint):
+        self.cfg.joint.num_classes = len(self.cfg.st_labels)
+        self.cfg.joint.vocabulary = self.cfg.st_labels
+
+    self.st_decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder)
+    self.st_joint = EncDecRNNTModel.from_config_dict(self.cfg.joint)
+    st_num_classes = self.st_joint.num_classes_with_blank - 1  # for standard RNNT and multi-blank
+
+    if loss_name == 'tdt':
+        st_num_classes = st_num_classes - self.joint.num_extra_outputs
+
+    self.st_loss = RNNTLoss(
+        num_classes=st_num_classes,
+        loss_name=loss_name,
+        loss_kwargs=loss_kwargs,
+        reduction=self.cfg.get("rnnt_reduction", "mean_batch"),
+    )
+
+    self.st_decoding = RNNTDecoding(
+        decoding_cfg=self.cfg.decoding, decoder=self.st_decoder, joint=self.st_joint, vocabulary=self.st_joint.vocabulary,
+    )
+
+    self.bleu = BLEU(
+        decoding=self.st_decoding,
+        batch_dim_index=0,
+        log_prediction=self._cfg.get('log_prediction', True),
+        dist_sync_on_step=True,
+    )
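The st_num_classes bookkeeping in the new block follows the ASR branch: num_classes_with_blank counts the blank symbol, and for a TDT loss the joint's extra duration outputs are excluded as well. A standalone sketch of that arithmetic (local names, assuming the conventions stated in the inline comments above):

    def rnnt_loss_num_classes(num_classes_with_blank, loss_name, num_extra_outputs=0):
        n = num_classes_with_blank - 1  # drop the blank (standard RNNT and multi-blank)
        if loss_name == 'tdt':
            n -= num_extra_outputs  # TDT joints also emit duration logits; exclude them
        return n

Note the new block subtracts self.joint.num_extra_outputs rather than self.st_joint.num_extra_outputs; the two coincide here because both joints are built from the same joint/loss config, but the ST-side attribute would be the more defensive choice.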


# Whether to compute loss during evaluation
if 'compute_eval_loss' in self.cfg:
@@ -148,7 +157,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if self.joint.fuse_loss_wer or (
self.decoding.joint_fused_batch_size is not None and self.decoding.joint_fused_batch_size > 0
):
-assert(0)
+pass
+# assert(0)

# Setup optimization normalization (if provided in config)
self.setup_optim_normalization()
@@ -635,7 +645,10 @@ def forward(
processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
-st_encoded, st_encoded_len = self.st_encoder(audio_signal=encoded, length=encoded_len)
+if self.st_encoder is not None:
+    st_encoded, st_encoded_len = self.st_encoder(audio_signal=encoded, length=encoded_len)
+else:
+    st_encoded, st_encoded_len = None, None

return encoded, encoded_len, st_encoded, st_encoded_len
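The guard here pairs with the conditional construction in __init__: when no st_encoder is configured, the ST outputs are simply None and callers must check for that. A toy reduction of the pattern in plain PyTorch (illustrative names and dimensions only, not the NeMo modules):

    import torch
    import torch.nn as nn

    class DualHeadEncoder(nn.Module):
        def __init__(self, feat_dim, hidden, with_st=True):
            super().__init__()
            self.encoder = nn.Linear(feat_dim, hidden)
            # The second-stage ST encoder is optional, like cfg.st_encoder above.
            self.st_encoder = nn.Linear(hidden, hidden) if with_st else None

        def forward(self, x):
            encoded = self.encoder(x)
            # The ST branch consumes the ASR encoder's output, or yields None when absent.
            st_encoded = self.st_encoder(encoded) if self.st_encoder is not None else None
            return encoded, st_encoded

Stacking the ST encoder on top of the shared ASR encoder rather than on the raw features is exactly what the audio_signal=encoded argument expresses in the diff.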

@@ -697,18 +710,18 @@ def training_step(self, batch, batch_nb):

if (sample_id + 1) % log_every_n_steps == 0:
self.bleu.update(
-predictions=encoded,
-predictions_lengths=encoded_len,
-targets=asr_transcript,
-targets_lengths=asr_transcript_len,
+predictions=st_encoded,
+predictions_lengths=st_encoded_len,
+targets=st_transcript,
+targets_lengths=st_transcript_len,
)
bleu = self.bleu.compute(return_all_metrics=False)['bleu']
self.bleu.reset()
self.wer.update(
predictions=encoded,
predictions_lengths=encoded_len,
-targets=transcript,
-targets_lengths=transcript_len,
+targets=asr_transcript,
+targets_lengths=asr_transcript_len,
)
_, scores, words = self.wer.compute()
self.wer.reset()
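Both metric blocks follow the same train-time sampling pattern: update on the current batch only, compute, then reset, so the logged value is a per-batch snapshot rather than a running aggregate. In miniature (metric standing for any torchmetrics-style object):

    if (sample_id + 1) % log_every_n_steps == 0:
        metric.update(preds, target)   # accumulate state from this batch only
        value = metric.compute()       # per-batch score
        metric.reset()                 # clear state before the next logging step

The fix above also makes the two calls consistent: BLEU now scores the ST head against the ST references, and WER scores the ASR head against the renamed asr_transcript targets.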
@@ -763,7 +776,7 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0):
joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)

asr_loss_value = self.asr_loss(
-log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length
+log_probs=joint, targets=asr_transcript, input_lengths=encoded_len, target_lengths=target_length
)

st_decoder, st_target_length, st_states = self.st_decoder(targets=st_transcript, target_length=st_transcript_len)
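The visible hunk ends before showing how validation combines the two branches. A plausible continuation, mirroring the ASR-side joint and loss calls above, is sketched here as an assumption rather than the commit's actual code:

    # Hypothetical: ST joint + loss, symmetric to the asr_loss call above.
    st_joint_out = self.st_joint(encoder_outputs=st_encoded, decoder_outputs=st_decoder)
    st_loss_value = self.st_loss(
        log_probs=st_joint_out, targets=st_transcript, input_lengths=st_encoded_len, target_lengths=st_target_length
    )
    # The 1:1 weighting is an assumption, not taken from the diff.
    loss_value = asr_loss_value + st_loss_value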
