Skip to content

Commit

Permalink
option to specify smaller model capacity (fixes #20) (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
jongwook committed May 31, 2018
1 parent 600990c commit dad5540
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 43 deletions.
8 changes: 5 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*.pyo
__pycache__

.cache
.ipynb_checkpoints
.pytest_cache

Expand All @@ -15,8 +16,9 @@ dist
build
*.egg-info

*.salience.png
*.salience.npy
*.activation.png
*.activation.npy
*.f0.csv

crepe/model.h5
crepe/model-*.h5
crepe/model-*.h5.bz2
15 changes: 13 additions & 2 deletions crepe/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from .core import process_file


def run(filename, output=None, viterbi=False, save_activation=False,
save_plot=False, plot_voicing=False, no_centering=False, step_size=10):
def run(filename, output=None, model_capacity='full', viterbi=False,
save_activation=False, save_plot=False, plot_voicing=False,
no_centering=False, step_size=10):
"""
Collect the WAV files to process and run the model
Expand All @@ -21,6 +22,9 @@ def run(filename, output=None, viterbi=False, save_activation=False,
output : str or None
Path to directory for saving output files. If None, output files will
be saved to the directory containing the input file.
model_capacity : 'tiny', 'small', 'medium', 'large', or 'full'
String specifying the model capacity; see the docstring of
:func:`~crepe.core.build_and_load_model`
viterbi : bool
Apply viterbi smoothing to the estimated pitch curve. False by default.
save_activation : bool
Expand Down Expand Up @@ -68,6 +72,7 @@ def run(filename, output=None, viterbi=False, save_activation=False,
print('CREPE: Processing {} ... ({}/{})'.format(file, i+1, len(files)),
file=sys.stderr)
process_file(file, output=output,
model_capacity=model_capacity,
viterbi=viterbi,
center=(not no_centering),
save_activation=save_activation,
Expand Down Expand Up @@ -117,6 +122,11 @@ def main():
'already exist; if not given, the output will be '
'saved to the same directory as the input WAV '
'file(s)')
parser.add_argument('--model-capacity', '-c', default='full',
choices=['tiny', 'small', 'medium', 'large', 'full'],
help='String specifying the model capacity; smaller '
'models are faster to compute, but may yield '
'less accurate pitch estimation')
parser.add_argument('--viterbi', '-V', action='store_true',
help='perform Viterbi decoding to smooth the pitch '
'curve')
Expand Down Expand Up @@ -145,6 +155,7 @@ def main():

run(args.filename,
output=args.output,
model_capacity=args.model_capacity,
viterbi=args.viterbi,
save_activation=args.save_activation,
save_plot=args.save_plot,
Expand Down
89 changes: 61 additions & 28 deletions crepe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,61 @@
import numpy as np
from numpy.lib.stride_tricks import as_strided

# store as a global variable, since only one model is supported at the moment
model = None
# store as a global variable, since we only support a few models for now
models = {
'tiny': None,
'small': None,
'medium': None,
'large': None,
'full': None
}

# the model is trained on 16kHz audio
model_srate = 16000


def build_and_load_model():
def build_and_load_model(model_capacity):
"""
Build the CNN model and load the weights; this needs to exactly match
what's saved in the Keras weights file
Build the CNN model and load the weights
Parameters
----------
model_capacity : 'tiny', 'small', 'medium', 'large', or 'full'
String specifying the model capacity, which determines the model's
capacity multiplier to 4 (tiny), 8 (small), 16 (medium), 24 (large),
or 32 (full). 'full' uses the model size specified in the paper,
and the others use a reduced number of filters in each convolutional
layer, resulting in a smaller model that is faster to evaluate at the
cost of slightly reduced pitch estimation accuracy.
Returns
-------
The keras model loaded in memory
"""
from keras.layers import Input, Reshape, Conv2D, BatchNormalization
from keras.layers import MaxPool2D, Dropout, Permute, Flatten, Dense
from keras.models import Model
global model

if model is None:
model_capacity = 32
if models[model_capacity] is None:
capacity_multiplier = {
'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
}[model_capacity]

layers = [1, 2, 3, 4, 5, 6]
filters = [n * model_capacity for n in [32, 4, 4, 4, 8, 16]]
filters = [n * capacity_multiplier for n in [32, 4, 4, 4, 8, 16]]
widths = [512, 64, 64, 64, 64, 64]
strides = [(4, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]

x = Input(shape=(1024,), name='input', dtype='float32')
y = Reshape(target_shape=(1024, 1, 1), name='input-reshape')(x)

for layer, filters, width, strides in zip(layers, filters, widths, strides):
y = Conv2D(filters, (width, 1), strides=strides, padding='same',
activation='relu', name="conv%d" % layer)(y)
y = BatchNormalization(name="conv%d-BN" % layer)(y)
for l, f, w, s in zip(layers, filters, widths, strides):
y = Conv2D(f, (w, 1), strides=s, padding='same',
activation='relu', name="conv%d" % l)(y)
y = BatchNormalization(name="conv%d-BN" % l)(y)
y = MaxPool2D(pool_size=(2, 1), strides=None, padding='valid',
name="conv%d-maxpool" % layer)(y)
y = Dropout(0.25, name="conv%d-dropout" % layer)(y)
name="conv%d-maxpool" % l)(y)
y = Dropout(0.25, name="conv%d-dropout" % l)(y)

y = Permute((2, 1, 3), name="transpose")(y)
y = Flatten(name="flatten")(y)
Expand All @@ -50,9 +71,14 @@ def build_and_load_model():
model = Model(inputs=x, outputs=y)

package_dir = os.path.dirname(os.path.realpath(__file__))
model.load_weights(os.path.join(package_dir, "model.h5"))
filename = "model-{}.h5".format(model_capacity)
model.load_weights(os.path.join(package_dir, filename))
model.compile('adam', 'binary_crossentropy')

models[model_capacity] = model

return models[model_capacity]


def output_path(file, suffix, output_dir):
"""
Expand Down Expand Up @@ -125,7 +151,7 @@ def to_viterbi_cents(salience):
range(len(observations))])


def get_activation(audio, sr, center=True, step_size=10):
def get_activation(audio, sr, model_capacity='full', center=True, step_size=10):
"""
Parameters
Expand All @@ -135,6 +161,9 @@ def get_activation(audio, sr, center=True, step_size=10):
sr : int
Sample rate of the audio samples. The audio will be resampled if
the sample rate is not 16 kHz, which is expected by the model.
model_capacity : 'tiny', 'small', 'medium', 'large', or 'full'
String specifying the model capacity; see the docstring of
:func:`~crepe.core.build_and_load_model`
center : boolean
- If `True` (default), the signal `audio` is padded so that frame
`D[:, t]` is centered at `audio[t * hop_length]`.
Expand All @@ -147,9 +176,7 @@ def get_activation(audio, sr, center=True, step_size=10):
activation : np.ndarray [shape=(T, 360)]
The raw activation matrix
"""
global model
if model is None:
build_and_load_model()
model = build_and_load_model(model_capacity)

if len(audio.shape) == 2:
audio = audio.mean(1) # make mono
Expand Down Expand Up @@ -179,7 +206,8 @@ def get_activation(audio, sr, center=True, step_size=10):
return model.predict(frames, verbose=1)


def predict(audio, sr, viterbi=False, center=True, step_size=10):
def predict(audio, sr, model_capacity='full',
viterbi=False, center=True, step_size=10):
"""
Perform pitch estimation on given audio
Expand All @@ -190,6 +218,9 @@ def predict(audio, sr, viterbi=False, center=True, step_size=10):
sr : int
Sample rate of the audio samples. The audio will be resampled if
the sample rate is not 16 kHz, which is expected by the model.
model_capacity : 'tiny', 'small', 'medium', 'large', or 'full'
String specifying the model capacity; see the docstring of
:func:`~crepe.core.build_and_load_model`
viterbi : bool
Apply viterbi smoothing to the estimated pitch curve. False by default.
center : boolean
Expand All @@ -212,7 +243,8 @@ def predict(audio, sr, viterbi=False, center=True, step_size=10):
activation: np.ndarray [shape=(T, 360)]
The raw activation matrix
"""
activation = get_activation(audio, sr, center=center, step_size=step_size)
activation = get_activation(audio, sr, model_capacity=model_capacity,
center=center, step_size=step_size)
confidence = activation.max(axis=1)

if viterbi:
Expand All @@ -228,9 +260,9 @@ def predict(audio, sr, viterbi=False, center=True, step_size=10):
return time, frequency, confidence, activation


def process_file(file, output=None, viterbi=False, center=True,
save_activation=False, save_plot=False, plot_voicing=False,
step_size=10):
def process_file(file, output=None, model_capacity='full', viterbi=False,
center=True, save_activation=False, save_plot=False,
plot_voicing=False, step_size=10):
"""
Use the input model to perform pitch estimation on the input file.
Expand All @@ -241,6 +273,9 @@ def process_file(file, output=None, viterbi=False, center=True,
output : str or None
Path to directory for saving output files. If None, output files will
be saved to the directory containing the input file.
model_capacity : 'tiny', 'small', 'medium', 'large', or 'full'
String specifying the model capacity; see the docstring of
:func:`~crepe.core.build_and_load_model`
viterbi : bool
Apply viterbi smoothing to the estimated pitch curve. False by default.
center : boolean
Expand All @@ -263,16 +298,14 @@ def process_file(file, output=None, viterbi=False, center=True,
-------
"""
# ensure that the model is loaded
build_and_load_model()

try:
sr, audio = wavfile.read(file)
except ValueError:
print("CREPE: Could not read %s" % file, file=sys.stderr)
raise

time, frequency, confidence, activation = predict(audio, sr,
model_capacity=model_capacity,
viterbi=viterbi,
center=center,
step_size=step_size)
Expand Down
Binary file removed crepe/model.h5.bz2
Binary file not shown.
34 changes: 24 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,33 @@
import imp
from setuptools import setup, find_packages

weight_file = 'model.h5'
try:
from urllib.request import urlretrieve
except ImportError:
from urllib import urlretrieve

model_capacities = ['tiny', 'small', 'medium', 'large', 'full']
weight_files = ['model-{}.h5'.format(cap) for cap in model_capacities]
base_url = 'https://github.com/marl/crepe/raw/models/'

if len(sys.argv) > 1 and sys.argv[1] == 'sdist':
# include the compressed weights file in sdist
weight_file = 'model.h5.bz2'
# exclude the weight files in sdist
weight_files = []
else:
# in all other cases, decompress the weights file if necessary
if not os.path.isfile(os.path.join('crepe', 'model.h5')):
print('Decompressing the model weights ...')
with bz2.BZ2File(os.path.join('crepe', 'model.h5.bz2'), 'rb') as source:
with open(os.path.join('crepe', 'model.h5'), 'wb') as target:
target.write(source.read())
print('Decompression complete')
for weight_file in weight_files:
weight_path = os.path.join('crepe', weight_file)
if not os.path.isfile(weight_path):
compressed_file = weight_file + '.bz2'
compressed_path = os.path.join('crepe', compressed_file)
if not os.path.isfile(compressed_file):
print('Downloading weight file {} ...'.format(compressed_file))
urlretrieve(base_url + compressed_file, compressed_path)
print('Decompressing ...')
with bz2.BZ2File(compressed_path, 'rb') as source:
with open(weight_path, 'wb') as target:
target.write(source.read())
print('Decompression complete')

version = imp.load_source('crepe.version', os.path.join('crepe', 'version.py'))

Expand Down Expand Up @@ -64,6 +78,6 @@
'scikit-learn>=0.16'
],
package_data={
'crepe': [weight_file]
'crepe': weight_files
},
)

0 comments on commit dad5540

Please sign in to comment.