Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logging update #267

Merged
merged 10 commits into from
Jan 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions collections/nemo_asr/nemo_asr/audio_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,7 @@ def __init__(
frame_splicing=frame_splicing,
stft_conv=stft_conv,
pad_value=pad_value,
mag_power=mag_power,
logger=self._logger
mag_power=mag_power
)
self.featurizer.to(self._device)

Expand Down
9 changes: 4 additions & 5 deletions collections/nemo_asr/nemo_asr/data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from functools import partial
import torch

import nemo
from nemo.backends.pytorch import DataLayerNM
from nemo.core import DeviceType
from nemo.core.neural_types import *
Expand Down Expand Up @@ -146,14 +147,13 @@ def __init__(
'trim': trim_silence,
'bos_id': bos_id,
'eos_id': eos_id,
'logger': self._logger,
'load_audio': load_audio}

self._dataset = AudioDataset(**dataset_params)

# Set up data loader
if self._placement == DeviceType.AllGpu:
self._logger.info('Parallelizing DATALAYER')
nemo.logging.info('Parallelizing DATALAYER')
sampler = torch.utils.data.distributed.DistributedSampler(
self._dataset)
else:
Expand Down Expand Up @@ -272,13 +272,12 @@ def __init__(
'labels': labels,
'min_duration': min_duration,
'max_duration': max_duration,
'normalize': normalize_transcripts,
'logger': self._logger}
'normalize': normalize_transcripts}
self._dataset = KaldiFeatureDataset(**dataset_params)

# Set up data loader
if self._placement == DeviceType.AllGpu:
self._logger.info('Parallelizing DATALAYER')
nemo.logging.info('Parallelizing DATALAYER')
sampler = torch.utils.data.distributed.DistributedSampler(
self._dataset)
else:
Expand Down
44 changes: 13 additions & 31 deletions collections/nemo_asr/nemo_asr/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2019 NVIDIA Corporation

import torch
import nemo

from .metrics import word_error_rate

Expand Down Expand Up @@ -31,8 +32,7 @@ def __ctc_decoder_predictions_tensor(tensor, labels):
def monitor_asr_train_progress(tensors: list,
labels: list,
eval_metric='WER',
tb_logger=None,
logger=None):
tb_logger=None):
"""
Takes output of greedy ctc decoder and performs ctc decoding algorithm to
remove duplicates and special symbol. Prints sample to screen, computes
Expand All @@ -42,7 +42,6 @@ def monitor_asr_train_progress(tensors: list,
labels: A list of labels
eval_metric: An optional string from 'WER', 'CER'. Defaults to 'WER'.
tb_logger: Tensorboard logging object
logger:
Returns:
None
"""
Expand Down Expand Up @@ -72,16 +71,10 @@ def monitor_asr_train_progress(tensors: list,
wer = word_error_rate(hypotheses, references, use_cer=use_cer)
if tb_logger is not None:
tb_logger.add_scalar(tag, wer)
if logger:
logger.info(f'Loss: {tensors[0]}')
logger.info(f'{tag}: {wer*100 : 5.2f}%')
logger.info(f'Prediction: {hypotheses[0]}')
logger.info(f'Reference: {references[0]}')
else:
print(f'Loss: {tensors[0]}')
print(f'{tag}: {wer*100 : 5.2f}%')
print(f'Prediction: {hypotheses[0]}')
print(f'Reference: {references[0]}')
nemo.logging.info(f'Loss: {tensors[0]}')
nemo.logging.info(f'{tag}: {wer*100 : 5.2f}%')
nemo.logging.info(f'Prediction: {hypotheses[0]}')
nemo.logging.info(f'Reference: {references[0]}')


def __gather_losses(losses_list: list) -> list:
Expand Down Expand Up @@ -146,8 +139,7 @@ def process_evaluation_batch(tensors: dict, global_vars: dict, labels: list):

def process_evaluation_epoch(global_vars: dict,
eval_metric='WER',
tag=None,
logger=None):
tag=None):
"""
Calculates the aggregated loss and WER across the entire evaluation dataset
"""
Expand All @@ -165,24 +157,14 @@ def process_evaluation_epoch(global_vars: dict,
use_cer=use_cer)

if tag is None:
if logger:
logger.info(f"==========>>>>>>Evaluation Loss: {eloss}")
logger.info(f"==========>>>>>>Evaluation {eval_metric}: "
f"{wer*100 : 5.2f}%")
else:
print(f"==========>>>>>>Evaluation Loss: {eloss}")
print(f"==========>>>>>>Evaluation {eval_metric}: "
f"{wer*100 : 5.2f}%")
nemo.logging.info(f"==========>>>>>>Evaluation Loss: {eloss}")
nemo.logging.info(f"==========>>>>>>Evaluation {eval_metric}: "
f"{wer*100 : 5.2f}%")
return {"Evaluation_Loss": eloss, f"Evaluation_{eval_metric}": wer}
else:
if logger:
logger.info(f"==========>>>>>>Evaluation Loss {tag}: {eloss}")
logger.info(f"==========>>>>>>Evaluation {eval_metric} {tag}: "
f"{wer*100 : 5.2f}%")
else:
print(f"==========>>>>>>Evaluation Loss {tag}: {eloss}")
print(f"==========>>>>>>Evaluation {eval_metric} {tag}:"
f" {wer*100 : 5.2f}%")
nemo.logging.info(f"==========>>>>>>Evaluation Loss {tag}: {eloss}")
nemo.logging.info(f"==========>>>>>>Evaluation {eval_metric} {tag}: "
f"{wer*100 : 5.2f}%")
return {f"Evaluation_Loss_{tag}": eloss,
f"Evaluation_{eval_metric}_{tag}": wer}

Expand Down
13 changes: 6 additions & 7 deletions collections/nemo_asr/nemo_asr/las/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pprint import pformat

import torch
import nemo
from nemo.backends.pytorch.common.metrics import char_lm_metrics

from nemo_asr.metrics import word_error_rate
Expand Down Expand Up @@ -55,7 +56,7 @@ def process_evaluation_batch(tensors, global_vars, labels, specials,

def process_evaluation_epoch(global_vars,
metrics=('loss', 'bpc', 'ppl'), calc_wer=False,
logger=None, mode='eval', tag='none'):
mode='eval', tag='none'):
tag = '_'.join(tag.lower().strip().split())
return_dict = {}
for metric in metrics:
Expand All @@ -70,17 +71,15 @@ def process_evaluation_epoch(global_vars,
transcript_texts = list(chain(*global_vars['transcript_texts']))
prediction_texts = list(chain(*global_vars['prediction_texts']))

if logger:
logger.info(f'Ten examples (transcripts and predictions)')
logger.info(transcript_texts[:10])
logger.info(prediction_texts[:10])
nemo.logging.info(f'Ten examples (transcripts and predictions)')
nemo.logging.info(transcript_texts[:10])
nemo.logging.info(prediction_texts[:10])

wer = word_error_rate(hypotheses=prediction_texts,
references=transcript_texts)
return_dict[f'metric/{mode}_wer_{tag}'] = wer

if logger:
logger.info(pformat(return_dict))
nemo.logging.info(pformat(return_dict))

return return_dict

Expand Down
30 changes: 14 additions & 16 deletions collections/nemo_asr/nemo_asr/parts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import torch
from torch.utils.data import Dataset

import nemo

from .manifest import ManifestBase, ManifestEN


Expand Down Expand Up @@ -131,7 +133,6 @@ def __init__(
trim=False,
bos_id=None,
eos_id=None,
logger=False,
load_audio=True,
manifest_class=ManifestEN):
m_paths = manifest_filepath.split(',')
Expand All @@ -141,19 +142,17 @@ def __init__(
max_utts=max_utts,
blank_index=blank_index,
unk_index=unk_index,
normalize=normalize,
logger=logger)
normalize=normalize)
self.featurizer = featurizer
self.trim = trim
self.eos_id = eos_id
self.bos_id = bos_id
self.load_audio = load_audio
if logger:
logger.info(
"Dataset loaded with {0:.2f} hours. Filtered {1:.2f} "
"hours.".format(
self.manifest.duration / 3600,
self.manifest.filtered_duration / 3600))
nemo.logging.info(
"Dataset loaded with {0:.2f} hours. Filtered {1:.2f} "
"hours.".format(
self.manifest.duration / 3600,
self.manifest.filtered_duration / 3600))

def __getitem__(self, index):
sample = self.manifest[index]
Expand Down Expand Up @@ -214,8 +213,7 @@ def __init__(
unk_index=-1,
blank_index=-1,
normalize=True,
eos_id=None,
logger=None):
eos_id=None):
self.eos_id = eos_id
self.unk_index = unk_index
self.blank_index = blank_index
Expand Down Expand Up @@ -245,8 +243,8 @@ def __init__(
f"KaldiFeatureDataset max_duration or min_duration is set but"
f" utt2dur file not found in {kaldi_dir}."
)
elif logger:
logger.info(
else:
nemo.logging.info(
f"Did not find utt2dur when loading data from "
f"{kaldi_dir}. Skipping dataset duration calculations."
)
Expand All @@ -265,7 +263,7 @@ def __init__(
text = line[split_idx:].strip()
if normalize:
text = ManifestEN.normalize_text(
text, labels, logger=logger)
text, labels)
dur = id2dur[utt_id] if id2dur else None

# Filter by duration if specified & utt2dur exists
Expand Down Expand Up @@ -295,9 +293,9 @@ def __init__(
print(f"Stop parsing due to max_utts ({max_utts})")
break

if logger and id2dur:
if id2dur:
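# utt2dur durations are in seconds. NOTE(review): the message below divides by 60 (minutes) but labels the value "hours" — presumably should be / 3600; confirm.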
logger.info(
nemo.logging.info(
f"Dataset loaded with {duration/60 : .2f} hours. "
f"Filtered {filtered_duration/60 : .2f} hours.")

Expand Down
13 changes: 4 additions & 9 deletions collections/nemo_asr/nemo_asr/parts/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from .segment import AudioSegment
from torch_stft import STFT

import nemo

CONSTANT = 1e-5


Expand Down Expand Up @@ -127,7 +129,6 @@ def __init__(
stft_conv=False,
pad_value=0,
mag_power=2.,
logger=None
):
super(FilterbankFeatures, self).__init__()
if (n_window_size is None or n_window_stride is None
Expand All @@ -137,21 +138,15 @@ def __init__(
raise ValueError(
f"{self} got an invalid value for either n_window_size or "
f"n_window_stride. Both must be positive ints.")
if logger:
logger.info(f"PADDING: {pad_to}")
else:
print(f"PADDING: {pad_to}")
nemo.logging.info(f"PADDING: {pad_to}")

self.win_length = n_window_size
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
self.stft_conv = stft_conv

if stft_conv:
if logger:
logger.info("STFT using conv")
else:
print("STFT using conv")
nemo.logging.info("STFT using conv")

# Create helper class to patch forward func for use with AMP
class STFTPatch(STFT):
Expand Down
23 changes: 8 additions & 15 deletions collections/nemo_asr/nemo_asr/parts/manifest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Taken straight from Patter https://github.com/ryanleary/patter
# TODO: review, add copyright notice, and fix/add comments
import json
import nemo
import string

from nemo.utils import get_logger
from .cleaners import clean_text


Expand All @@ -17,8 +17,7 @@ def __init__(self,
max_utts=0,
blank_index=-1,
unk_index=-1,
normalize=True,
logger=None):
normalize=True):
self.min_duration = min_duration
self.max_duration = max_duration
self.sort_by_duration = sort_by_duration
Expand All @@ -27,9 +26,6 @@ def __init__(self,
self.unk_index = unk_index
self.normalize = normalize
self.labels_map = {label: i for i, label in enumerate(labels)}
self.logger = logger
if logger is None:
self.logger = get_logger('')

data = []
duration = 0.0
Expand All @@ -53,9 +49,9 @@ def __init__(self,
filtered_duration += item['duration']
continue
if normalize:
text = self.normalize_text(text, labels, logger=self.logger)
text = self.normalize_text(text, labels)
if not isinstance(text, str):
self.logger.warning(
nemo.logging.warning(
"WARNING: Got transcript: {}. It is not a "
"string. Dropping data point".format(text)
)
Expand All @@ -69,7 +65,7 @@ def __init__(self,

# support files using audio_filename
if 'audio_filename' in item and 'audio_filepath' not in item:
self.logger.warning(
nemo.logging.warning(
"Malformed manifest: The key audio_filepath was not "
"found in the manifest. Using audio_filename instead."
)
Expand All @@ -79,7 +75,7 @@ def __init__(self,
duration += item['duration']

if max_utts > 0 and len(data) >= max_utts:
self.logger.info(
nemo.logging.info(
'Stop parsing due to max_utts ({})'.format(max_utts))
break

Expand Down Expand Up @@ -155,7 +151,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@staticmethod
def normalize_text(text, labels, logger=None):
def normalize_text(text, labels):
# Punctuation to remove
punctuation = string.punctuation
# Define punctuation that will be handled by text cleaner
Expand Down Expand Up @@ -183,10 +179,7 @@ def normalize_text(text, labels, logger=None):
try:
text = clean_text(text, table, punctuation_to_replace)
except BaseException:
if logger:
logger.warning("WARNING: Normalizing {} failed".format(text))
else:
print("WARNING: Normalizing {} failed".format(text))
nemo.logging.warning("WARNING: Normalizing {} failed".format(text))
return None

return text
Loading