Commit 403238f: Issue 357 (#360)
* Issue 357: Fix to export jasper to onnx - logger and factory creation

Signed-off-by: adriana <adrifloresm@gmail.com>

* Issue 357: Changes to jasper_eval.py - added amp_opt_level, cache=False and formatting

Signed-off-by: adriana <adrifloresm@gmail.com>

* Issue 357: Updated changes for PR

Signed-off-by: adriana <adrifloresm@gmail.com>
Authored by adrifloresm on Feb 13, 2020 (1 parent: 3cdba88)
Showing 2 changed files with 31 additions and 28 deletions.
examples/asr/jasper_eval.py (53 changes: 26 additions & 27 deletions)
@@ -16,28 +16,30 @@
 def main():
     parser = argparse.ArgumentParser(description='Jasper')
-    parser.add_argument("--local_rank", default=None, type=int)
-    parser.add_argument("--batch_size", default=64, type=int)
+    # model params
     parser.add_argument("--model_config", type=str, required=True)
     parser.add_argument("--eval_datasets", type=str, required=True)
     parser.add_argument("--load_dir", type=str, required=True)
+    # run params
+    parser.add_argument("--local_rank", default=None, type=int)
+    parser.add_argument("--batch_size", default=64, type=int)
+    parser.add_argument("--amp_opt_level", default="O1", type=str)
+    # store results
     parser.add_argument("--save_logprob", default=None, type=str)

     # lm inference parameters
     parser.add_argument("--lm_path", default=None, type=str)
-    parser.add_argument(
-        '--alpha', default=2.0, type=float, help='value of LM weight', required=False,
-    )
+    parser.add_argument('--alpha', default=2.0, type=float, help='value of LM weight', required=False)
     parser.add_argument(
         '--alpha_max',
         type=float,
         help='maximum value of LM weight (for a grid search in \'eval\' mode)',
         required=False,
     )
     parser.add_argument(
-        '--alpha_step', type=float, help='step for LM weight\'s tuning in \'eval\' mode', required=False, default=0.1,
-    )
-    parser.add_argument(
-        '--beta', default=1.5, type=float, help='value of word count weight', required=False,
+        '--alpha_step', type=float, help='step for LM weight\'s tuning in \'eval\' mode', required=False, default=0.1
     )
+    parser.add_argument('--beta', default=1.5, type=float, help='value of word count weight', required=False)
     parser.add_argument(
         '--beta_max',
         type=float,
@@ -71,7 +73,7 @@ def main():
     neural_factory = nemo.core.NeuralModuleFactory(
         backend=nemo.core.Backend.PyTorch,
         local_rank=args.local_rank,
-        optimization_level=nemo.core.Optimization.mxprO1,
+        optimization_level=args.amp_opt_level,
         placement=device,
     )
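Editor's note: the hunk above replaces the hard-coded nemo.core.Optimization.mxprO1 enum with the string taken from the new --amp_opt_level flag. If a caller wants to validate that string before handing it to the factory, a small lookup against the enum works. This is a minimal, hypothetical sketch, not code from this PR; only Optimization.mxprO1 is confirmed by the diff, the other enum members are assumed.

import nemo

# Hypothetical validator (not part of this commit): map an AMP level string
# ("O0".."O3") to NeMo's Optimization enum before building the factory.
AMP_LEVELS = {
    "O0": nemo.core.Optimization.mxprO0,  # assumed member
    "O1": nemo.core.Optimization.mxprO1,  # confirmed by the diff above
    "O2": nemo.core.Optimization.mxprO2,  # assumed member
    "O3": nemo.core.Optimization.mxprO3,  # assumed member
}

def validate_amp_opt_level(amp_opt_level):
    # Fail fast with a clear message instead of erroring deep inside the factory.
    if amp_opt_level not in AMP_LEVELS:
        raise ValueError(
            "amp_opt_level must be one of {}, got {!r}".format(sorted(AMP_LEVELS), amp_opt_level)
        )
    return AMP_LEVELS[amp_opt_level]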

@@ -102,13 +104,13 @@ def main():
     nemo.logging.info('Evaluating {0} examples'.format(N))

     data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
-        sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"],
+        sample_rate=sample_rate, **jasper_params["AudioToMelSpectrogramPreprocessor"]
     )
     jasper_encoder = nemo_asr.JasperEncoder(
-        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"], **jasper_params["JasperEncoder"],
+        feat_in=jasper_params["AudioToMelSpectrogramPreprocessor"]["features"], **jasper_params["JasperEncoder"]
     )
     jasper_decoder = nemo_asr.JasperDecoderForCTC(
-        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], num_classes=len(vocab),
+        feat_in=jasper_params["JasperEncoder"]["jasper"][-1]["filters"], num_classes=len(vocab)
     )
     greedy_decoder = nemo_asr.GreedyCTCDecoder()
@@ -120,27 +122,25 @@ def main():
     )
     nemo.logging.info('================================')

-    (audio_signal_e1, a_sig_length_e1, transcript_e1, transcript_len_e1,) = data_layer()
+    # Define inference DAG
+    audio_signal_e1, a_sig_length_e1, transcript_e1, transcript_len_e1 = data_layer()
     processed_signal_e1, p_length_e1 = data_preprocessor(input_signal=audio_signal_e1, length=a_sig_length_e1)
     encoded_e1, encoded_len_e1 = jasper_encoder(audio_signal=processed_signal_e1, length=p_length_e1)
     log_probs_e1 = jasper_decoder(encoder_output=encoded_e1)
     predictions_e1 = greedy_decoder(log_probs=log_probs_e1)

-    eval_tensors = [
-        log_probs_e1,
-        predictions_e1,
-        transcript_e1,
-        transcript_len_e1,
-        encoded_len_e1,
-    ]
+    eval_tensors = [log_probs_e1, predictions_e1, transcript_e1, transcript_len_e1, encoded_len_e1]

-    evaluated_tensors = neural_factory.infer(tensors=eval_tensors, checkpoint_dir=load_dir, cache=True)
+    # inference
+    evaluated_tensors = neural_factory.infer(tensors=eval_tensors, checkpoint_dir=load_dir, cache=False)

     greedy_hypotheses = post_process_predictions(evaluated_tensors[1], vocab)
     references = post_process_transcripts(evaluated_tensors[2], evaluated_tensors[3], vocab)

     wer = word_error_rate(hypotheses=greedy_hypotheses, references=references)
     nemo.logging.info("Greedy WER {:.2f}%".format(wer * 100))

+    # language model
     if args.lm_path:
         if args.alpha_max is None:
             args.alpha_max = args.alpha
@@ -168,18 +168,17 @@ def main():
                 )
                 beam_predictions_e1 = beam_search_with_lm(log_probs=log_probs_e1, log_probs_length=encoded_len_e1)

-                evaluated_tensors = neural_factory.infer(tensors=[beam_predictions_e1], use_cache=True, verbose=False,)
+                evaluated_tensors = neural_factory.infer(tensors=[beam_predictions_e1], use_cache=False, verbose=False)

                 beam_hypotheses = []
                 # Over mini-batch
                 for i in evaluated_tensors[-1]:
                     # Over samples
                     for j in i:
                         beam_hypotheses.append(j[0][1])

-                wer = word_error_rate(hypotheses=beam_hypotheses, references=references)
-                nemo.logging.info("Beam WER {:.2f}%".format(wer * 100))
-                beam_wers.append(((alpha, beta), wer * 100))
+                lm_wer = word_error_rate(hypotheses=beam_hypotheses, references=references)
+                nemo.logging.info("Beam WER {:.2f}%".format(lm_wer * 100))
+                beam_wers.append(((alpha, beta), lm_wer * 100))

     nemo.logging.info('Beam WER for (alpha, beta)')
     nemo.logging.info('================================')
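Editor's note: after the grid search over LM weights, beam_wers holds one ((alpha, beta), WER) entry per configuration; renaming wer to lm_wer in the hunk above keeps the beam-search score from clobbering the greedy WER computed earlier. A minimal sketch of reducing the grid to its best setting follows; only the beam_wers name and entry shape come from the diff, the data and the reduction are illustrative, not part of this PR.

# Hypothetical reduction, for illustration: pick the (alpha, beta) pair with
# the lowest beam-search WER from the accumulated grid.
beam_wers = [((2.0, 1.5), 11.3), ((2.1, 1.5), 11.1), ((2.2, 1.5), 11.6)]  # example data

(best_alpha, best_beta), best_wer = min(beam_wers, key=lambda entry: entry[1])
print("Best (alpha, beta): ({}, {}), WER: {:.2f}%".format(best_alpha, best_beta, best_wer))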
scripts/export_jasper_to_onnx.py (6 changes: 5 additions & 1 deletion)
@@ -7,6 +7,8 @@
 import nemo
 import nemo.collections.asr as nemo_asr

+logging = nemo.logging
+

 def get_parser():
     parser = argparse.ArgumentParser(description="Convert Jasper NeMo checkpoint to ONNX")
@@ -58,10 +60,13 @@ def main(
     logging.info(" Num encoder input features: {}".format(num_encoder_input_features))
     logging.info(" Num decoder input features: {}".format(num_decoder_input_features))

+    nf = nemo.core.NeuralModuleFactory(create_tb_writer=False)
+
     logging.info("Initializing models...")
     jasper_encoder = nemo_asr.JasperEncoder(
         feat_in=num_encoder_input_features, **jasper_model_definition['JasperEncoder']
     )
+
     jasper_decoder = nemo_asr.JasperDecoderForCTC(
         feat_in=num_decoder_input_features, num_classes=len(jasper_model_definition['labels']),
     )
@@ -83,7 +88,6 @@ def main(
     jasper_encoder.restore_from(nn_encoder)
     jasper_decoder.restore_from(nn_decoder)

-    nf = nemo.core.NeuralModuleFactory(create_tb_writer=False)
     logging.info("Exporting encoder...")
     nf.deployment_export(
         jasper_encoder,
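Editor's note: the substantive fix in this file is ordering. The NeuralModuleFactory is now created before the Jasper modules are instantiated rather than just before export, presumably so module construction can find an active factory. A minimal sketch of the resulting flow, using only names that appear in the diff above (in a real run the config values are loaded from the model YAML):

import nemo
import nemo.collections.asr as nemo_asr

logging = nemo.logging  # module-level alias, as added by this commit

def build_for_export(jasper_model_definition, num_encoder_input_features, num_decoder_input_features):
    # 1. Create the factory BEFORE any neural module is constructed,
    #    mirroring the reordering this commit makes.
    nf = nemo.core.NeuralModuleFactory(create_tb_writer=False)

    # 2. Only then instantiate the modules that will later be exported
    #    via nf.deployment_export(...).
    logging.info("Initializing models...")
    jasper_encoder = nemo_asr.JasperEncoder(
        feat_in=num_encoder_input_features, **jasper_model_definition['JasperEncoder']
    )
    jasper_decoder = nemo_asr.JasperDecoderForCTC(
        feat_in=num_decoder_input_features,
        num_classes=len(jasper_model_definition['labels']),
    )
    return nf, jasper_encoder, jasper_decoder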
