From 403238f82d26879ba5fca53fbf75b3cdc70fb49b Mon Sep 17 00:00:00 2001
From: Adriana Flores
Date: Thu, 13 Feb 2020 12:31:53 -0700
Subject: [PATCH] Issue 357 (#360)

* Issue 357: Fix to export jasper to onnx - logger and factory creation

Signed-off-by: adriana

* Issue 357: Changes to jasper_eval.py - added amp_opt_level, cache=False and formatting

Signed-off-by: adriana

* Issue 357: Updated changes for PR

Signed-off-by: adriana
---
 examples/asr/jasper_eval.py      | 53 ++++++++++++++++----------------
 scripts/export_jasper_to_onnx.py |  6 +++-
 2 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/examples/asr/jasper_eval.py b/examples/asr/jasper_eval.py
index b4b16699d13f..9c5fac4eb36d 100644
--- a/examples/asr/jasper_eval.py
+++ b/examples/asr/jasper_eval.py
@@ -16,16 +16,20 @@ def main():
     parser = argparse.ArgumentParser(description='Jasper')
-    parser.add_argument("--local_rank", default=None, type=int)
-    parser.add_argument("--batch_size", default=64, type=int)
+    # model params
     parser.add_argument("--model_config", type=str, required=True)
     parser.add_argument("--eval_datasets", type=str, required=True)
     parser.add_argument("--load_dir", type=str, required=True)
+
+    # run params
+    parser.add_argument("--local_rank", default=None, type=int)
+    parser.add_argument("--batch_size", default=64, type=int)
+    parser.add_argument("--amp_opt_level", default="O1", type=str)
+
+    # store results
     parser.add_argument("--save_logprob", default=None, type=str)
+
+    # lm inference parameters
     parser.add_argument("--lm_path", default=None, type=str)
-    parser.add_argument(
-        '--alpha', default=2.0, type=float, help='value of LM weight', required=False,
-    )
+    parser.add_argument('--alpha', default=2.0, type=float, help='value of LM weight', required=False)
     parser.add_argument(
         '--alpha_max',
         type=float,
@@ -33,11 +37,9 @@ def main():
         required=False,
     )
     parser.add_argument(
-        '--alpha_step', type=float, help='step for LM weight\'s tuning in \'eval\' mode', required=False, default=0.1,
-    )
-    parser.add_argument(
-        '--beta', default=1.5, type=float, help='value of word count weight', required=False,
+        '--alpha_step', type=float, help='step for LM weight\'s tuning in \'eval\' mode', required=False, default=0.1
     )
+    parser.add_argument('--beta', default=1.5, type=float, help='value of word count weight', required=False)
     parser.add_argument(
         '--beta_max',
         type=float,
@@ -71,7 +73,7 @@ def main():
     neural_factory = nemo.core.NeuralModuleFactory(
         backend=nemo.core.Backend.PyTorch,
         local_rank=args.local_rank,
-        optimization_level=nemo.core.Optimization.mxprO1,
+        optimization_level=args.amp_opt_level,
         placement=device,
     )
@@ -102,13 +104,13 @@ def main():
     nemo.logging.info('Evaluating {0} examples'.format(N))

     data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
-        sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"],
+        sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
     )
     jasper_encoder = nemo_asr.JasperEncoder(
-        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"], **jasper_params["JasperEncoder"],
+        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"], **jasper_params["JasperEncoder"]
     )
     jasper_decoder = nemo_asr.JasperDecoderForCTC(
-        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], num_classes=len(vocab),
+        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], num_classes=len(vocab)
     )
     greedy_decoder = nemo_asr.GreedyCTCDecoder()
@@ -120,27 +122,25 @@ def main():
     )
     nemo.logging.info('================================')

-    (audio_signal_e1, a_sig_length_e1, transcript_e1, transcript_len_e1,) = data_layer()
+    # Define inference DAG
+    audio_signal_e1, a_sig_length_e1, transcript_e1, transcript_len_e1 = data_layer()
     processed_signal_e1, p_length_e1 = data_preprocessor(input_signal=audio_signal_e1, length=a_sig_length_e1)
     encoded_e1, encoded_len_e1 = jasper_encoder(audio_signal=processed_signal_e1, length=p_length_e1)
     log_probs_e1 = jasper_decoder(encoder_output=encoded_e1)
     predictions_e1 = greedy_decoder(log_probs=log_probs_e1)

-    eval_tensors = [
-        log_probs_e1,
-        predictions_e1,
-        transcript_e1,
-        transcript_len_e1,
-        encoded_len_e1,
-    ]
+    eval_tensors = [log_probs_e1, predictions_e1, transcript_e1, transcript_len_e1, encoded_len_e1]

-    evaluated_tensors = neural_factory.infer(tensors=eval_tensors, checkpoint_dir=load_dir, cache=True)
+    # inference
+    evaluated_tensors = neural_factory.infer(tensors=eval_tensors, checkpoint_dir=load_dir, cache=False)

     greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
     references = post_process_transcripts(evaluated_tensors[2], evaluated_tensors[3], vocab)
+
     wer = word_error_rate(hypotheses=greedy_hypotheses, references=references)
     nemo.logging.info("Greedy WER {:.2f}%".format(wer * 100))

+    # language model
     if args.lm_path:
         if args.alpha_max is None:
             args.alpha_max = args.alpha
@@ -168,7 +168,7 @@ def main():
             )
             beam_predictions_e1 = beam_search_with_lm(log_probs=log_probs_e1, log_probs_length=encoded_len_e1)

-            evaluated_tensors = neural_factory.infer(tensors=[beam_predictions_e1], use_cache=True, verbose=False,)
+            evaluated_tensors = neural_factory.infer(tensors=[beam_predictions_e1], use_cache=False, verbose=False)

             beam_hypotheses = []
             # Over mini-batch
@@ -176,10 +176,9 @@ def main():
                 # Over samples
                 for j in i:
                     beam_hypotheses.append(j[0][1])
-
-            wer = word_error_rate(hypotheses=beam_hypotheses, references=references)
-            nemo.logging.info("Beam WER {:.2f}%".format(wer * 100))
-            beam_wers.append(((alpha, beta), wer * 100))
+            lm_wer = word_error_rate(hypotheses=beam_hypotheses, references=references)
+            nemo.logging.info("Beam WER {:.2f}%".format(lm_wer * 100))
+            beam_wers.append(((alpha, beta), lm_wer * 100))

         nemo.logging.info('Beam WER for (alpha, beta)')
         nemo.logging.info('================================')
diff --git a/scripts/export_jasper_to_onnx.py b/scripts/export_jasper_to_onnx.py
index 84db7bddaf9a..dbb24023fa2f 100644
--- a/scripts/export_jasper_to_onnx.py
+++ b/scripts/export_jasper_to_onnx.py
@@ -7,6 +7,8 @@
 import nemo
 import nemo.collections.asr as nemo_asr

+logging = nemo.logging
+

 def get_parser():
     parser = argparse.ArgumentParser(description="Convert Jasper NeMo checkpoint to ONNX")
@@ -58,10 +60,13 @@ def main(
     logging.info(" Num encoder input features: {}".format(num_encoder_input_features))
     logging.info(" Num decoder input features: {}".format(num_decoder_input_features))

+    nf = nemo.core.NeuralModuleFactory(create_tb_writer=False)
+
     logging.info("Initializing models...")
     jasper_encoder = nemo_asr.JasperEncoder(
         feat_in=num_encoder_input_features, **jasper_model_definition['JasperEncoder']
     )
+
     jasper_decoder = nemo_asr.JasperDecoderForCTC(
         feat_in=num_decoder_input_features, num_classes=len(jasper_model_definition['labels']),
     )
@@ -83,7 +88,6 @@ def main(
     jasper_encoder.restore_from(nn_encoder)
     jasper_decoder.restore_from(nn_decoder)

-    nf = nemo.core.NeuralModuleFactory(create_tb_writer=False)
     logging.info("Exporting encoder...")
     nf.deployment_export(
         jasper_encoder,
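Note for reviewers: the core effect of the jasper_eval.py hunks is that the AMP optimization level is no longer hard-coded to nemo.core.Optimization.mxprO1 but taken from the new --amp_opt_level flag. Below is a minimal sketch of the resulting pattern, assuming the NeMo 0.x NeuralModuleFactory accepts the "O0"-"O3" strings directly (as the patched call implies); the argument names are copied from the diff, everything else is illustrative.

    import argparse

    import nemo

    # Sketch of the pattern introduced by this patch (assumed NeMo 0.x API):
    # the AMP level now comes from the command line instead of the hard-coded
    # nemo.core.Optimization.mxprO1 enum value.
    parser = argparse.ArgumentParser(description='Jasper')
    parser.add_argument("--local_rank", default=None, type=int)
    parser.add_argument("--amp_opt_level", default="O1", type=str)  # "O0".."O3"
    args = parser.parse_args()

    neural_factory = nemo.core.NeuralModuleFactory(
        backend=nemo.core.Backend.PyTorch,
        local_rank=args.local_rank,
        optimization_level=args.amp_opt_level,
    )

With this in place the evaluation script can be invoked, for example, as
python examples/asr/jasper_eval.py --model_config <cfg> --eval_datasets <manifest> --load_dir <ckpt_dir> --amp_opt_level O2
(the paths are placeholders; other flags are as defined in the diff above).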