From 8c54a66dde9cc879576523b5cd7fcd9cf890b824 Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Sun, 10 Mar 2024 18:10:23 -0700 Subject: [PATCH] working --- nemo/collections/asr/data/audio_to_text.py | 2 +- .../collections/asr/models/rnnt_bpe_models.py | 30 +++---- nemo/collections/asr/models/rnnt_models.py | 87 +++++++++++-------- 3 files changed, 64 insertions(+), 55 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index 5e3823a160f5..f881aa047307 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -1460,7 +1460,7 @@ def __init__( super().__init__() def _collate_fn(self, batch): - return _speech_collate_fn(batch[0], self.wrapped_dataset.pad_id) + return _speech_collate_fn(batch[0], self.wrapped_dataset.asr_pad_id, self.wrapped_dataset.st_pad_id) def __iter__(self): return BucketingIterator( diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index 274647b41931..5dea45c431d1 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -285,7 +285,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): cfg = model_utils.maybe_update_config_version(cfg) # Tokenizer is necessary for this model - if 'asr_tokenizer' not in cfg 'tokenizer' not in cfg: + if 'asr_tokenizer' not in cfg and 'tokenizer' not in cfg: raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !") if not isinstance(cfg, DictConfig): @@ -299,9 +299,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): else: assert(0) - # Initialize a dummy vocabulary -# vocabulary = self.tokenizer.tokenizer.get_vocab() -# print("self.tokenizer.tokenizer is", self.tokenizer.tokenizer) if hasattr(self.asr_tokenizer.tokenizer, 'vocab') and callable(self.asr_tokenizer.tokenizer.vocab): asr_vocabulary = self.asr_tokenizer.tokenizer.vocab() else: @@ -311,7 +308,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): st_vocabulary = self.st_tokenizer.tokenizer.vocab() else: st_vocabulary = self.st_tokenizer.tokenizer.get_vocab() - # Set the new vocabulary with open_dict(cfg): @@ -334,19 +330,19 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.asr_decoding = RNNTBPEDecoding( decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.asr_tokenizer, ) + if self.st_encoder != None: - self.st_decoding = RNNTBPEDecoding( - decoding_cfg=self.cfg.decoding, decoder=self.st_decoder, joint=self.st_joint, tokenizer=self.st_tokenizer, - ) + self.st_decoding = RNNTBPEDecoding( + decoding_cfg=self.cfg.decoding, decoder=self.st_decoder, joint=self.st_joint, tokenizer=self.st_tokenizer, + ) - # Setup wer object - self.bleu = BLEU( - decoding=self.st_decoding, - batch_dim_index=0, -# use_cer=self._cfg.get('use_cer', False), - log_prediction=self._cfg.get('log_prediction', True), - dist_sync_on_step=True, - ) + # Setup wer object + self.bleu = BLEU( + decoding=self.st_decoding, + batch_dim_index=0, + log_prediction=self._cfg.get('log_prediction', True), + dist_sync_on_step=True, + ) self.wer = WER( decoding=self.asr_decoding, @@ -358,7 +354,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup fused Joint step if flag is set if self.joint.fuse_loss_wer: - assert(0) + pass def change_vocabulary( self, diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index c4c7809db55a..db4249b23152 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -62,7 +62,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Initialize components self.preprocessor = EncDecRNNTModel.from_config_dict(self.cfg.preprocessor) self.encoder = EncDecRNNTModel.from_config_dict(self.cfg.encoder) - self.st_encoder = EncDecRNNTModel.from_config_dict(self.cfg.st_encoder) + if 'st_encoder' in self.cfg: + self.st_encoder = EncDecRNNTModel.from_config_dict(self.cfg.st_encoder) + else: + self.st_encoder = None # Update config values required by components dynamically with open_dict(self.cfg.decoder): @@ -77,24 +80,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder) self.joint = EncDecRNNTModel.from_config_dict(self.cfg.joint) - with open_dict(self.cfg.decoder): - self.cfg.decoder.vocab_size = len(self.cfg.st_labels) - with open_dict(self.cfg.joint): - self.cfg.joint.num_classes = len(self.cfg.st_labels) - self.cfg.joint.vocabulary = self.cfg.st_labels - - self.st_decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder) - self.st_joint = EncDecRNNTModel.from_config_dict(self.cfg.joint) - # Setup RNNT Loss loss_name, loss_kwargs = self.extract_rnnt_loss_cfg(self.cfg.get("loss", None)) asr_num_classes = self.joint.num_classes_with_blank - 1 # for standard RNNT and multi-blank - st_num_classes = self.st_joint.num_classes_with_blank - 1 # for standard RNNT and multi-blank if loss_name == 'tdt': asr_num_classes = asr_num_classes - self.joint.num_extra_outputs - st_num_classes = st_num_classes - self.joint.num_extra_outputs self.asr_loss = RNNTLoss( num_classes=asr_num_classes, @@ -103,13 +95,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): reduction=self.cfg.get("rnnt_reduction", "mean_batch"), ) - self.st_loss = RNNTLoss( - num_classes=st_num_classes, - loss_name=loss_name, - loss_kwargs=loss_kwargs, - reduction=self.cfg.get("rnnt_reduction", "mean_batch"), - ) - if hasattr(self.cfg, 'spec_augment') and self._cfg.spec_augment is not None: self.spec_augmentation = EncDecRNNTModel.from_config_dict(self.cfg.spec_augment) else: @@ -120,9 +105,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoding = RNNTDecoding( decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, vocabulary=self.joint.vocabulary, ) - self.st_decoding = RNNTDecoding( - decoding_cfg=self.cfg.decoding, decoder=self.st_decoder, joint=self.st_joint, vocabulary=self.st_joint.vocabulary, - ) # Setup WER calculation self.wer = WER( decoding=self.decoding, @@ -131,12 +113,39 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): log_prediction=self._cfg.get('log_prediction', True), dist_sync_on_step=True, ) - self.bleu = BLEU( - decoding=self.st_decoding, - batch_dim_index=0, - log_prediction=self._cfg.get('log_prediction', True), - dist_sync_on_step=True, - ) + + if self.st_encoder is not None: + with open_dict(self.cfg.decoder): + self.cfg.decoder.vocab_size = len(self.cfg.st_labels) + with open_dict(self.cfg.joint): + self.cfg.joint.num_classes = len(self.cfg.st_labels) + self.cfg.joint.vocabulary = self.cfg.st_labels + + self.st_decoder = EncDecRNNTModel.from_config_dict(self.cfg.decoder) + self.st_joint = EncDecRNNTModel.from_config_dict(self.cfg.joint) + st_num_classes = self.st_joint.num_classes_with_blank - 1 # for standard RNNT and multi-blank + + if loss_name == 'tdt': + st_num_classes = st_num_classes - self.joint.num_extra_outputs + + self.st_loss = RNNTLoss( + num_classes=st_num_classes, + loss_name=loss_name, + loss_kwargs=loss_kwargs, + reduction=self.cfg.get("rnnt_reduction", "mean_batch"), + ) + + self.st_decoding = RNNTDecoding( + decoding_cfg=self.cfg.decoding, decoder=self.st_decoder, joint=self.st_joint, vocabulary=self.st_joint.vocabulary, + ) + + self.bleu = BLEU( + decoding=self.st_decoding, + batch_dim_index=0, + log_prediction=self._cfg.get('log_prediction', True), + dist_sync_on_step=True, + ) + # Whether to compute loss during evaluation if 'compute_eval_loss' in self.cfg: @@ -148,7 +157,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if self.joint.fuse_loss_wer or ( self.decoding.joint_fused_batch_size is not None and self.decoding.joint_fused_batch_size > 0 ): - assert(0) + pass +# assert(0) # Setup optimization normalization (if provided in config) self.setup_optim_normalization() @@ -635,7 +645,10 @@ def forward( processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) - st_encoded, st_encoded_len = self.st_encoder(audio_signal=encoded, length=encoded_len) + if self.st_encoder is not None: + st_encoded, st_encoded_len = self.st_encoder(audio_signal=encoded, length=encoded_len) + else: + st_encoded, st_encoded_len = None, None return encoded, encoded_len, st_encoded, st_encoded_len @@ -697,18 +710,18 @@ def training_step(self, batch, batch_nb): if (sample_id + 1) % log_every_n_steps == 0: self.bleu.update( - predictions=encoded, - predictions_lengths=encoded_len, - targets=asr_transcript, - targets_lengths=asr_transcript_len, + predictions=st_encoded, + predictions_lengths=st_encoded_len, + targets=st_transcript, + targets_lengths=st_transcript_len, ) bleu = self.bleu.compute(return_all_metrics=False)['bleu'] self.bleu.reset() self.wer.update( predictions=encoded, predictions_lengths=encoded_len, - targets=transcript, - targets_lengths=transcript_len, + targets=asr_transcript, + targets_lengths=asr_transcript_len, ) _, scores, words = self.wer.compute() self.wer.reset() @@ -763,7 +776,7 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0): joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) asr_loss_value = self.asr_loss( - log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + log_probs=joint, targets=asr_transcript, input_lengths=encoded_len, target_lengths=target_length ) st_decoder, st_target_length, st_states = self.st_decoder(targets=st_transcript, target_length=st_transcript_len)