Skip to content

Commit

Permalink
Kapre 0.3.4 (#101)
Browse files Browse the repository at this point in the history
* working on example

* fix window function in backend

* add release note; bump version

Co-authored-by: Keunwoo Choi <keunwoo.choi@bytedance.com>
  • Loading branch information
keunwoochoi and keunwoochoi committed Sep 29, 2020
1 parent 1988e84 commit 3ca2a9c
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 44 deletions.
98 changes: 57 additions & 41 deletions docs/quickstart.rst
Original file line number Diff line number Diff line change
@@ -1,46 +1,62 @@
One-shot example
^^^^^^^^^^^^^^^^
Examples
========

How To Import
-------------

.. code-block:: python
import kapre # to import the whole library
from kapre import ( # `time_frequency` layers can be directly imported from `kapre`
STFT,
InverseSTFT,
Magnitude,
Phase,
MagnitudeToDecibel,
ApplyFilterbank,
Delta,
ConcatenateFrequencyMap,
)
from kapre import ( # `signal` layers can be also directly imported from kapre
Frame,
Energy,
MuLawEncoding,
MuLawDecoding,
LogmelToMFCC,
)
# from kapre import backend # we can do this, but `backend` might be a too general name
import kapre.backend # for namespace sanity, you might prefer this
from kapre import backend as kapre_backend # or maybe this
from kapre.composed import ( # function names in `composed` are purposefully verbose.
get_stft_magnitude_layer,
get_melspectrogram_layer,
get_log_frequency_spectrogram_layer,
get_perfectly_reconstructing_stft_istft,
get_stft_mag_phase,
get_frequency_aware_conv2d,
)
Use STFT Magnitude
------------------

.. code-block:: python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, GlobalAveragePooling2D, Dense, Softmax
from kapre import STFT, Magnitude, MagnitudeToDecibel
from kapre.composed import get_melspectrogram_layer, get_log_frequency_spectrogram_layer, get_stft_magnitude_layer
# 6 channels (!), maybe 1-sec audio signal, for an example.
input_shape = (6, 44100)
sr = 44100
model = Sequential()
# A STFT layer
model.add(STFT(n_fft=2048, win_length=2018, hop_length=1024,
window_fn=None, pad_end=False,
input_data_format='channels_last', output_data_format='channels_last',
input_shape=input_shape))
model.add(Magnitude())
model.add(MagnitudeToDecibel()) # these three layers can be replaced with get_stft_magnitude_layer()
# Alternatively, you may want to use a melspectrogram layer
# melgram_layer = get_melspectrogram_layer()
# or log-frequency layer
# log_stft_layer = get_log_frequency_spectrogram_layer()
# add more layers as you want
model.add(Conv2D(32, (3, 3), strides=(2, 2)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(GlobalAveragePooling2D())
model.add(Dense(10))
model.add(Softmax())
# Compile the model
model.compile('adam', 'categorical_crossentropy') # if single-label classification
# train it with raw audio sample inputs
# for example, you may have functions that load your data as below.
x = load_x() # e.g., x.shape = (10000, 6, 44100)
y = load_y() # e.g., y.shape = (10000, 10) if it's 10-class classification
# then..
model.fit(x, y)
# Done!
from tensorflow.keras.models import Sequential
from kapre import STFT, Magnitude, MagnitudeToDecibel
from kapre.composed import get_stft_magnitude_layer

sampling_rate = 16000 # sampling rate of your input audio
duration = 20.0 # duration of the audio
num_channel = 2 # number of channels of the audio
input_shape = (num_channel, int(sampling_rate * duration)) # let's follow `channels_last` convention even for audio

model = Sequential()
model.add(STFT(n_fft=2048, win_length=2018, hop_length=1024,
window_name='hann_window', pad_end=False,
input_data_format='channels_last', output_data_format='channels_last',
input_shape=input_shape))
model.add(Magnitude())
model.add(MagnitudeToDecibel()) # these three layers can be replaced with get_stft_magnitude_layer()


* See the Jupyter notebook at the `example folder <https://github.com/keunwoochoi/kapre/tree/master/examples>`_
5 changes: 5 additions & 0 deletions docs/release_note.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Release Note
^^^^^^^^^^^^

* 29 Sep 2020
- 0.3.4
- Fix a bug in `kapre.backend.get_window_fn()`. Previously, it only correctly worked with `None` input and
an erorr was raised when non-default value was set for `window_name` in any layer.

* 15 Sep 2020
- 0.3.3
- `kapre.augmentation` is added
Expand Down
2 changes: 1 addition & 1 deletion kapre/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.3.3'
__version__ = '0.3.4'
VERSION = __version__

from . import composed
Expand Down
2 changes: 1 addition & 1 deletion kapre/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_window_fn(window_name=None):
)
)

return window_name
return available_windows[window_name]


def validate_data_format_str(data_format):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='kapre',
version='0.3.3',
version='0.3.4',
description='Kapre: Keras Audio Preprocessors. Tensorflow.Keras layers for audio pre-processing in deep learning',
author='Keunwoo Choi',
url='http://github.com/keunwoochoi/kapre/',
Expand Down
60 changes: 60 additions & 0 deletions tests/test_time_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,66 @@ def _get_stft_model(following_layer=None):
allclose_phase(np.angle(S_complex), S)


@pytest.mark.parametrize('data_format', ['channels_first', 'channels_last'])
@pytest.mark.parametrize('window_name', [None, 'hann_window', 'hamming_window'])
def test_spectrogram_correctness_more(data_format, window_name):
def _get_stft_model(following_layer=None):
# compute with kapre
stft_model = tensorflow.keras.models.Sequential()
stft_model.add(
STFT(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
window_name=window_name,
pad_end=False,
input_data_format=data_format,
output_data_format=data_format,
input_shape=input_shape,
name='stft',
)
)
if following_layer is not None:
stft_model.add(following_layer)
return stft_model

n_fft = 512
hop_length = 256
n_ch = 2

src_mono, batch_src, input_shape = get_audio(data_format=data_format, n_ch=n_ch)
win_length = n_fft # test with x2
# compute with librosa
S_ref = librosa.core.stft(
src_mono,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
center=False,
window=window_name.replace('_window', '') if window_name else 'hann',
).T # (time, freq)

S_ref = np.expand_dims(S_ref, axis=2) # time, freq, ch=1
S_ref = np.tile(S_ref, [1, 1, n_ch]) # time, freq, ch=n_ch
if data_format == 'channels_first':
S_ref = np.transpose(S_ref, (2, 0, 1)) # ch, time, freq

stft_model = _get_stft_model()

S_complex = stft_model.predict(batch_src)[0] # 3d representation
allclose_complex_numbers(S_ref, S_complex)

# test Magnitude()
stft_mag_model = _get_stft_model(Magnitude())
S = stft_mag_model.predict(batch_src)[0] # 3d representation
np.testing.assert_allclose(np.abs(S_ref), S, atol=2e-4)

# # test Phase()
stft_phase_model = _get_stft_model(Phase())
S = stft_phase_model.predict(batch_src)[0] # 3d representation
allclose_phase(np.angle(S_complex), S)


@pytest.mark.parametrize('n_fft', [512])
@pytest.mark.parametrize('sr', [22050])
@pytest.mark.parametrize('hop_length', [None, 256])
Expand Down

0 comments on commit 3ca2a9c

Please sign in to comment.