Skip to content

Commit

Permalink
Lint and reply to comments
Browse files Browse the repository at this point in the history
  • Loading branch information
garymm committed Jun 8, 2022
1 parent c66c02d commit 5067d11
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 142 deletions.
1 change: 1 addition & 0 deletions onnxruntime/core/graph/signal_ops/signal_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "onnx/defs/tensor_proto_util.h"

#include <cmath>
#include <string>

// NOTE: These were added to the standard op set. We register them under the MS domain
// for backwards compatibility, but new users should use the standard ops instead. Ideally these would be deleted.
Expand Down
23 changes: 12 additions & 11 deletions onnxruntime/core/providers/cpu/signal/dft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ static T compute_angular_velocity(size_t number_of_samples, bool inverse) {
template <typename T, typename U>
static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, size_t X_offset, size_t X_stride,
size_t Y_offset, size_t Y_stride, int64_t axis, size_t dft_length, const Tensor* window,
bool is_onesided, bool inverse, std::vector<std::complex<T>>& V,
std::vector<std::complex<T>>& temp_output) {
bool is_onesided, bool inverse, InlinedVector<std::complex<T>>& V,
InlinedVector<std::complex<T>>& temp_output) {
// Get shape and significant bits
const auto& X_shape = X->Shape();
size_t number_of_samples = static_cast<size_t>(X_shape[axis]);
Expand All @@ -183,7 +183,7 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s
std::complex<T>* Y_data;
if (is_onesided) {
if (temp_output.size() != dft_length) {
temp_output = std::vector<std::complex<T>>(dft_length);
temp_output = InlinedVector<std::complex<T>>(dft_length);
}
Y_data = temp_output.data();
} else {
Expand All @@ -195,7 +195,7 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s

// Create vandermonde matrix V ordered with the bit-reversed permutation
if (V.size() != dft_length) {
V = std::vector<std::complex<T>>(dft_length); // e^(i *2*pi / N * k)
V = InlinedVector<std::complex<T>>(dft_length); // e^(i *2*pi / N * k)
for (size_t i = 0; i < dft_length; i++) {
size_t bit_reversed_index = bit_reverse(i, significant_bits);
V[bit_reversed_index] = std::complex<T>(cos(i * angular_velocity), sin(i * angular_velocity));
Expand Down Expand Up @@ -293,7 +293,8 @@ static Status dft_naive(const Tensor* X, Tensor* Y, size_t X_offset, size_t X_st
template <typename T, typename U>
static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, Tensor* Y, int64_t axis,
int64_t dft_length, const Tensor* window, bool is_onesided, bool inverse,
std::vector<std::complex<T>>& V, std::vector<std::complex<T>>& temp_output) {
InlinedVector<std::complex<T>>& V,
InlinedVector<std::complex<T>>& temp_output) {
// Get shape
const auto& X_shape = X->Shape();
const auto& Y_shape = Y->Shape();
Expand Down Expand Up @@ -394,8 +395,8 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo

auto element_size = data_type->Size();
if (element_size == sizeof(float)) {
std::vector<std::complex<float>> V;
std::vector<std::complex<float>> temp_output;
InlinedVector<std::complex<float>> V;
InlinedVector<std::complex<float>> temp_output;
if (is_real_valued) {
ORT_RETURN_IF_ERROR((discrete_fourier_transform<float, float>(ctx, X, Y, axis, number_of_samples, nullptr,
is_onesided, inverse, V, temp_output)));
Expand All @@ -410,8 +411,8 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo
data_type);
}
} else if (element_size == sizeof(double)) {
std::vector<std::complex<double>> V;
std::vector<std::complex<double>> temp_output;
InlinedVector<std::complex<double>> V;
InlinedVector<std::complex<double>> temp_output;
if (is_real_valued) {
ORT_RETURN_IF_ERROR((discrete_fourier_transform<double, double>(ctx, X, Y, axis, number_of_samples, nullptr,
is_onesided, inverse, V, temp_output)));
Expand Down Expand Up @@ -510,8 +511,8 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside
auto dft_input_shape = onnxruntime::TensorShape({1, window_size, signal_components});
auto dft_output_shape = onnxruntime::TensorShape({1, dft_output_size, output_components});

std::vector<std::complex<T>> V;
std::vector<std::complex<T>> temp_output;
InlinedVector<std::complex<T>> V;
InlinedVector<std::complex<T>> temp_output;

// Run each dft of each batch as if it was a real-valued batch size 1 dft operation
for (int64_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/signal/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ template <typename T>
static T get_scalar_value_from_tensor(const Tensor* tensor) {
ORT_ENFORCE(tensor->Shape().Size() == 1, "ratio input should have a single value.");

auto data_type = tensor->DataType()->AsPrimitiveDataType()->GetDataType();
auto data_type = tensor->GetElementType();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
return static_cast<T>(*reinterpret_cast<const float*>(tensor->DataRaw()));
Expand Down
226 changes: 96 additions & 130 deletions onnxruntime/core/providers/cpu/signal/window_functions.cc
Original file line number Diff line number Diff line change
@@ -1,149 +1,115 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "utils.h"
#include "window_functions.h"
#include "core/providers/cpu/signal/window_functions.h"

#include <cmath>

#include "core/providers/common.h"
#include "core/providers/cpu/signal/utils.h"

namespace onnxruntime {
ONNX_CPU_OPERATOR_KERNEL(
MelWeightMatrix,
17,
KernelDefBuilder().MayInplace(0, 0) //
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t>()) //
.TypeConstraint("T2", BuildKernelDefConstraints<float>())
.TypeConstraint("T3", BuildKernelDefConstraints<float, double, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>()),
MelWeightMatrix);

// Maps a frequency in hertz onto the mel scale using
//   mel = 2595 * log10(1 + hz / 700).
// An input of 0 Hz yields 0 mel.
static inline double hz_to_mel_scale(double hz) {
  const double scaled_hz = hz / 700;
  const double log_term = std::log10(1 + scaled_hz);
  return 2595 * log_term;
}
ONNX_CPU_OPERATOR_KERNEL(MelWeightMatrix, 17,
KernelDefBuilder()
.MayInplace(0, 0) //
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t>()) //
.TypeConstraint("T2", BuildKernelDefConstraints<float>())
.TypeConstraint("T3",
BuildKernelDefConstraints<float, double, uint8_t, uint16_t, uint32_t,
uint64_t, int8_t, int16_t, int32_t, int64_t>()),
MelWeightMatrix);

// Maps a mel-scale value back to a frequency in hertz using
//   hz = 700 * (10^(mels / 2595) - 1).
// An input of 0 mel yields 0 Hz.
static inline double mel_scale_to_hz(double mels) {
  // Use the std:: overload explicitly rather than relying on <cmath> exposing
  // the C `pow` in the global namespace.
  const double power_of_ten = std::pow(10.0, mels / 2595);
  return 700 * (power_of_ten - 1);
}
// Maps a frequency in hertz onto the mel scale using
//   mel = 2595 * log10(1 + hz / 700).
// An input of 0 Hz yields 0 mel.
static inline double hz_to_mel_scale(double hz) {
  const double scaled_hz = hz / 700;
  const double log_term = std::log10(1 + scaled_hz);
  return 2595 * log_term;
}

// Maps a mel-scale value back to a frequency in hertz using
//   hz = 700 * (10^(mels / 2595) - 1).
// An input of 0 mel yields 0 Hz.
static inline double mel_scale_to_hz(double mels) {
  // Use the std:: overload explicitly rather than relying on <cmath> exposing
  // the C `pow` in the global namespace.
  const double power_of_ten = std::pow(10.0, mels / 2595);
  return 700 * (power_of_ten - 1);
}

template <typename T>
Status create_mel_weight_matrix(OpKernelContext* ctx, int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, float lower_edge_hertz, float upper_edge_hertz) {
// Determine the width of the spectrogram.
// This is determined as half the size of the fft size. The first element of the spectrum is always retained,
// and the remaining are halved. The second half can be discarded due to the conjugate symmetry of the output with real valued ffts.
// Taken together the formula for the size of the output will be std::floor(dft_length / 2) + 1.
int64_t num_spectrogram_bins = static_cast<int64_t>(std::floor(dft_length / 2 + 1));

// Checks
auto lowest_index = std::floor(((dft_length + 1) * lower_edge_hertz) / sample_rate);
auto highest_index = std::floor(((dft_length + 1) * upper_edge_hertz) / sample_rate);
ORT_ENFORCE(lowest_index >= 0 && lowest_index < num_spectrogram_bins, "lower_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the sample_rate.");
ORT_ENFORCE(highest_index >= 0 && highest_index < num_spectrogram_bins, "upper_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the sample_rate.");

// Create the output shape
onnxruntime::TensorShape output_shape(
{static_cast<int64_t>(num_spectrogram_bins),
num_mel_bins});
auto* Y = ctx->Output(0, output_shape);

// Get the raw output data
auto* Y_data = reinterpret_cast<T*>(Y->MutableDataRaw());

// Set the weight matrix to 0
memset(Y_data, 0, num_spectrogram_bins * num_mel_bins * sizeof(T));

// The mel filterbank is a triangular shaped peak with a height of 1 and a base equal to the size of the MEL range divided by
// the number of bins needed times 2. This triangle is then slid across the mel domain linearly, with a constant step size that
// is equal to half of the base of the triangle. To accommodate N bins, N+2 data points will be needed to determine the
// start, center and end points of each mel triangle filter.
//
// low_frequency where the mel triangle filter banks begin, and they end on the high_frequency_mel
// The range is divided evenly to create the needed points corresponding to the begin, center, end points of each triangle filterbank
std::vector<size_t> frequency_bins(num_mel_bins + 2);
auto low_frequency_mel = hz_to_mel_scale(lower_edge_hertz);
auto high_frequency_mel = hz_to_mel_scale(upper_edge_hertz);
auto mel_step = (high_frequency_mel - low_frequency_mel) / static_cast<float>(frequency_bins.size());

// Convert each point from mel scale back to hertz, and then compute the corresponding index in the fft
for (size_t i = 0; i < frequency_bins.size(); i++) {
auto hz = mel_scale_to_hz(low_frequency_mel + mel_step * i);
frequency_bins[i] = static_cast<size_t>(std::floor(((dft_length + 1) * hz) / sample_rate));
}
struct CreateMelWeightMatrix {
Status operator()(OpKernelContext* ctx, int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate,
float lower_edge_hertz, float upper_edge_hertz) {
// Determine the width of the spectrogram.
// This is determined as half the size of the fft size. The first element of the spectrum is always retained,
// and the remaining are halved. The second half can be discarded due to the conjugate symmetry of the output with
// real valued ffts. Taken together the formula for the size of the output will be std::floor(dft_length / 2) + 1.
int64_t num_spectrogram_bins = static_cast<int64_t>(std::floor(dft_length / 2 + 1));

// Checks
auto lowest_index = std::floor(((dft_length + 1) * lower_edge_hertz) / sample_rate);
auto highest_index = std::floor(((dft_length + 1) * upper_edge_hertz) / sample_rate);
ORT_ENFORCE(
lowest_index >= 0 && lowest_index < num_spectrogram_bins,
"lower_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the "
"sample_rate.");
ORT_ENFORCE(
highest_index >= 0 && highest_index < num_spectrogram_bins,
"upper_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the "
"sample_rate.");

// Create the output shape
onnxruntime::TensorShape output_shape({static_cast<int64_t>(num_spectrogram_bins), num_mel_bins});
auto* Y = ctx->Output(0, output_shape);

// Get the raw output data
auto* Y_data = reinterpret_cast<T*>(Y->MutableDataRaw());

// Set the weight matrix to 0
memset(Y_data, 0, num_spectrogram_bins * num_mel_bins * sizeof(T));

// The mel filterbank is a triangular shaped peak with a height of 1 and a base equal to the size of the MEL range
// divided by the number of bins needed times 2. This triangle is then slid across the mel domain linearly, with a
// constant step size that is equal to half of the base of the triangle. To accommodate N bins, N+2 data points will
// be needed to determine the start, center and end points of each mel triangle filter.
//
// low_frequency where the mel triangle filter banks begin, and they end on the high_frequency_mel
// The range is divided evenly to create the needed points corresponding to the begin, center, end points of each
// triangle filterbank
InlinedVector<size_t> frequency_bins(num_mel_bins + 2);
auto low_frequency_mel = hz_to_mel_scale(lower_edge_hertz);
auto high_frequency_mel = hz_to_mel_scale(upper_edge_hertz);
auto mel_step = (high_frequency_mel - low_frequency_mel) / static_cast<float>(frequency_bins.size());

// Convert each point from mel scale back to hertz, and then compute the corresponding index in the fft
for (size_t i = 0; i < frequency_bins.size(); i++) {
auto hz = mel_scale_to_hz(low_frequency_mel + mel_step * i);
frequency_bins[i] = static_cast<size_t>(std::floor(((dft_length + 1) * hz) / sample_rate));
}

for (size_t i = 0; i < static_cast<size_t>(num_mel_bins); i++) {
auto lower_frequency_value = frequency_bins[i]; // left
auto center_frequency_point = frequency_bins[i + 1]; // center
auto higher_frequency_point = frequency_bins[i + 2]; // right

auto low_to_center = center_frequency_point - lower_frequency_value;
if (low_to_center == 0) {
auto& current_element = *(Y_data + (center_frequency_point * num_mel_bins) + i);
current_element = static_cast<T>(1);
} else {
for (size_t j = lower_frequency_value; j <= center_frequency_point; j++) {
auto& current_element = *(Y_data + (j * num_mel_bins) + i);
current_element = static_cast<T>((j - lower_frequency_value) / static_cast<T>(low_to_center));
for (size_t i = 0; i < static_cast<size_t>(num_mel_bins); i++) {
auto lower_frequency_value = frequency_bins[i]; // left
auto center_frequency_point = frequency_bins[i + 1]; // center
auto higher_frequency_point = frequency_bins[i + 2]; // right

auto low_to_center = center_frequency_point - lower_frequency_value;
if (low_to_center == 0) {
auto& current_element = *(Y_data + (center_frequency_point * num_mel_bins) + i);
current_element = static_cast<T>(1);
} else {
for (size_t j = lower_frequency_value; j <= center_frequency_point; j++) {
auto& current_element = *(Y_data + (j * num_mel_bins) + i);
current_element = static_cast<T>((j - lower_frequency_value) / static_cast<T>(low_to_center));
}
}
}

auto center_to_high = higher_frequency_point - center_frequency_point;
if (center_to_high > 0) {
for (size_t j = center_frequency_point; j < higher_frequency_point; j++) {
auto& current_element = *(Y_data + (j * num_mel_bins) + i);
current_element = static_cast<T>((higher_frequency_point - j) / static_cast<T>(center_to_high));
auto center_to_high = higher_frequency_point - center_frequency_point;
if (center_to_high > 0) {
for (size_t j = center_frequency_point; j < higher_frequency_point; j++) {
auto& current_element = *(Y_data + (j * num_mel_bins) + i);
current_element = static_cast<T>((higher_frequency_point - j) / static_cast<T>(center_to_high));
}
}
}
}

return Status::OK();
}
return Status::OK();
}
};

static Status create_mel_weight_matrix(OpKernelContext* ctx, onnx::TensorProto_DataType output_datatype,
int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, float lower_edge_hertz, float upper_edge_hertz) {
switch (output_datatype) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<float>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<double>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<int8_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<int16_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<int32_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<int64_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<uint8_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<uint16_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT32: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<uint32_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<uint64_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
default:
ORT_THROW("Unsupported input data type of ", output_datatype);
}
return Status::OK();
int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate,
float lower_edge_hertz, float upper_edge_hertz) {
utils::MLTypeCallDispatcher<float, double, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t>
dispatcher(output_datatype);
return dispatcher.InvokeRet<Status, CreateMelWeightMatrix>(ctx, num_mel_bins, dft_length, sample_rate,
lower_edge_hertz, upper_edge_hertz);
}

Status MelWeightMatrix::Compute(OpKernelContext* ctx) const {
Expand All @@ -153,7 +119,7 @@ Status MelWeightMatrix::Compute(OpKernelContext* ctx) const {
const auto lower_edge_hertz = signal::get_scalar_value_from_tensor<float>(ctx->Input<Tensor>(3));
const auto upper_edge_hertz = signal::get_scalar_value_from_tensor<float>(ctx->Input<Tensor>(4));

ORT_RETURN_IF_ERROR(create_mel_weight_matrix(ctx, data_type_, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz));
return Status::OK();
return create_mel_weight_matrix(ctx, data_type_, num_mel_bins, dft_length, sample_rate, lower_edge_hertz,
upper_edge_hertz);
}
} // namespace onnxruntime

0 comments on commit 5067d11

Please sign in to comment.