Add TFLite accuracy estimation tool #1854

Merged · 1 commit · Feb 12, 2019
40 changes: 3 additions & 37 deletions evaluate.py
@@ -12,7 +12,6 @@
 import tables
 import tensorflow as tf
 
-from attrdict import AttrDict
 from collections import namedtuple
 from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from multiprocessing import Pool, cpu_count
@@ -21,9 +20,9 @@
 from util.config import Config, initialize_globals
 from util.flags import create_flags, FLAGS
 from util.logging import log_error
-from util.preprocess import pmap, preprocess
-from util.text import Alphabet, wer_cer_batch, levenshtein
+from util.preprocess import preprocess
+from util.text import Alphabet, levenshtein
+from util.evaluate_tools import process_decode_result, calculate_report
 
 def split_data(dataset, batch_size):
     remainder = len(dataset) % batch_size
@@ -45,39 +44,6 @@ def pad_to_dense(jagged):
     return padded
 
 
-def process_decode_result(item):
-    label, decoding, distance, loss = item
-    word_distance = levenshtein(label.split(), decoding.split())
-    word_length = float(len(label.split()))
-    return AttrDict({
-        'src': label,
-        'res': decoding,
-        'loss': loss,
-        'distance': distance,
-        'wer': word_distance / word_length,
-    })
-
-
-def calculate_report(labels, decodings, distances, losses):
-    r'''
-    This routine will calculate a WER report.
-    It'll compute the `mean` WER and create ``Sample`` objects of the ``report_count`` top lowest
-    loss items from the provided WER results tuple (only items with WER!=0 and ordered by their WER).
-    '''
-    samples = pmap(process_decode_result, zip(labels, decodings, distances, losses))
-
-    # Getting the WER and CER from the accumulated edit distances and lengths
-    samples_wer, samples_cer = wer_cer_batch(labels, decodings)
-
-    # Order the remaining items by their loss (lowest loss on top)
-    samples.sort(key=lambda s: s.loss)
-
-    # Then order by WER (highest WER on top)
-    samples.sort(key=lambda s: s.wer, reverse=True)
-
-    return samples_wer, samples_cer, samples
-
-
 def evaluate(test_data, inference_graph):
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                     FLAGS.lm_binary_path, FLAGS.lm_trie_path,
105 changes: 105 additions & 0 deletions evaluate_tflite.py
@@ -0,0 +1,105 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

import argparse
import numpy as np
import wave
import csv
import sys

from six.moves import zip, range
from multiprocessing import JoinableQueue, Pool, Process, Queue, cpu_count
from deepspeech import Model

from util.text import levenshtein
from util.evaluate_tools import process_decode_result, calculate_report

r'''
This module should be self-contained:
  - build libdeepspeech.so with TFLite:
    - add a dep in native_client/BUILD against TFLite: '//tensorflow:linux_x86_64': [ "//tensorflow/contrib/lite/kernels:builtin_ops" ]
    - bazel build [...] --copt=-DUSE_TFLITE [...] //native_client:libdeepspeech.so
  - make -C native_client/python/ TFDIR=... bindings
  - setup a virtualenv
  - pip install native_client/python/dist/deepspeech*.whl
  - pip install -r requirements_eval_tflite.txt

Then run with a TF Lite model, alphabet, LM/trie and a CSV test file
'''

BEAM_WIDTH = 500
LM_ALPHA = 0.75
LM_BETA = 1.85
N_FEATURES = 26
N_CONTEXT = 9

def tflite_worker(model, alphabet, lm, trie, queue_in, queue_out):
    ds = Model(model, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
    ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)

    while True:
        msg = queue_in.get()

        fin = wave.open(msg['filename'], 'rb')
        fs = fin.getframerate()
        audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
        audio_length = fin.getnframes() * (1/16000)  # duration in seconds, assuming 16kHz audio
        fin.close()

        decoded = ds.stt(audio, fs)

        queue_out.put({'prediction': decoded, 'ground_truth': msg['transcript']})
        queue_in.task_done()

def main():
    parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
    parser.add_argument('--model', required=True,
                        help='Path to the model (protocol buffer binary file)')
    parser.add_argument('--alphabet', required=True,
                        help='Path to the configuration file specifying the alphabet used by the network')
    parser.add_argument('--lm', required=True,
                        help='Path to the language model binary file')
    parser.add_argument('--trie', required=True,
                        help='Path to the language model trie file created with native_client/generate_trie')
    parser.add_argument('--csv', required=True,
                        help='Path to the CSV source file')
    args = parser.parse_args()

    work_todo = JoinableQueue()  # this is where we are going to store input data
    work_done = Queue()  # this is where we are going to push the results

    processes = []
    for i in range(cpu_count()):
        worker_process = Process(target=tflite_worker, args=(args.model, args.alphabet, args.lm, args.trie, work_todo, work_done), daemon=True, name='tflite_process_{}'.format(i))
        worker_process.start()  # launch tflite_worker() as a separate python process
        processes.append(worker_process)

    print([x.name for x in processes])

    ground_truths = []
    predictions = []
    losses = []

    with open(args.csv, 'r') as csvfile:
        csvreader = csv.DictReader(csvfile)
        for row in csvreader:
            work_todo.put({'filename': row['wav_filename'], 'transcript': row['transcript']})
    work_todo.join()

    while not work_done.empty():
        msg = work_done.get()
        losses.append(0.0)  # no per-sample loss is available from the TFLite client, so report zero
        ground_truths.append(msg['ground_truth'])
        predictions.append(msg['prediction'])

    distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]

    wer, cer, samples = calculate_report(ground_truths, predictions, distances, losses)
    mean_loss = np.mean(losses)

    print('Test - WER: %f, CER: %f, loss: %f' %
          (wer, cer, mean_loss))

if __name__ == '__main__':
    main()
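The script's main() reads wav_filename and transcript columns from the --csv file. Below is a minimal sketch of preparing such a file and driving the tool end to end; the clip and model artifact paths are hypothetical placeholders, not files shipped with this change.

import csv
import subprocess

# Hypothetical 16kHz mono WAV clips and their reference transcripts.
rows = [
    {'wav_filename': '/data/clips/sample-000000.wav', 'transcript': 'hello world'},
    {'wav_filename': '/data/clips/sample-000001.wav', 'transcript': 'good morning'},
]

with open('/tmp/test.csv', 'w') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=['wav_filename', 'transcript'])
    writer.writeheader()
    writer.writerows(rows)

# Hypothetical artifact paths; point these at your exported TFLite model,
# alphabet, language model binary and trie.
subprocess.check_call([
    'python', 'evaluate_tflite.py',
    '--model', '/models/output_graph.tflite',
    '--alphabet', '/models/alphabet.txt',
    '--lm', '/models/lm.binary',
    '--trie', '/models/trie',
    '--csv', '/tmp/test.csv',
])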
7 changes: 7 additions & 0 deletions requirements_eval_tflite.txt
@@ -0,0 +1,7 @@
attrdict==2.0.0
deepspeech
numpy==1.16.0
pkg-resources==0.0.0
progressbar2==3.39.2
python-utils==2.3.0
six==1.12.0
45 changes: 45 additions & 0 deletions util/evaluate_tools.py
@@ -0,0 +1,45 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function

from attrdict import AttrDict
from multiprocessing.dummy import Pool
from util.text import wer_cer_batch, levenshtein

def pmap(fun, iterable):
    pool = Pool()
    results = pool.map(fun, iterable)
    pool.close()
    return results

def process_decode_result(item):
    label, decoding, distance, loss = item
    word_distance = levenshtein(label.split(), decoding.split())
    word_length = float(len(label.split()))
    return AttrDict({
        'src': label,
        'res': decoding,
        'loss': loss,
        'distance': distance,
        'wer': word_distance / word_length,
    })


def calculate_report(labels, decodings, distances, losses):
    r'''
    This routine will calculate a WER report.
    It'll compute the mean WER and CER over the whole batch and return the
    decoded samples ordered by their WER (highest first), with ties broken
    by their loss (lowest first).
    '''
    samples = pmap(process_decode_result, zip(labels, decodings, distances, losses))

    # Getting the WER and CER from the accumulated edit distances and lengths
    samples_wer, samples_cer = wer_cer_batch(labels, decodings)

    # Order the samples by their loss (lowest loss on top)
    samples.sort(key=lambda s: s.loss)

    # Then order by WER (highest WER on top)
    samples.sort(key=lambda s: s.wer, reverse=True)

    return samples_wer, samples_cer, samples
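Since these helpers are now shared by evaluate.py and evaluate_tflite.py, a toy run of calculate_report shows the shape of its output — a minimal sketch, assuming it is executed from the repository root so the util package and its attrdict dependency resolve:

from util.evaluate_tools import calculate_report
from util.text import levenshtein

labels = ['hello world', 'good morning']
decodings = ['hello word', 'good morning']

# Character-level edit distances, mirroring the call in evaluate_tflite.py
distances = [levenshtein(a, b) for a, b in zip(labels, decodings)]
losses = [0.0, 0.0]  # evaluate_tflite.py has no per-sample loss, so it fills in zeros

wer, cer, samples = calculate_report(labels, decodings, distances, losses)
print('Test - WER: %f, CER: %f' % (wer, cer))
for sample in samples:  # ordered worst WER first
    print('%s -> %s (WER %.2f)' % (sample.src, sample.res, sample.wer))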