Skip to content

Commit

Permalink
Register signal ops for op set 17.
Browse files Browse the repository at this point in the history
Also:
* de-duplicate get_scalar_value_from_tensor
* fix some bugs that caused compilation errors with the experimental
  ops. Tested with `build.sh --ms_experimental`
* add function bodies for ms experimental ops
* fix some spelling errors

ghstack-source-id: 729a2230adbed583c044653b25dbce8c54a29e30
Pull Request resolved: microsoft#11733
  • Loading branch information
garymm committed Jun 8, 2022
1 parent f62a4d6 commit 51e9b33
Show file tree
Hide file tree
Showing 10 changed files with 876 additions and 776 deletions.
586 changes: 13 additions & 573 deletions onnxruntime/contrib_ops/cpu/signal/dft.cc

Large diffs are not rendered by default.

174 changes: 6 additions & 168 deletions onnxruntime/contrib_ops/cpu/signal/window_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "core/providers/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/signal/utils.h"
#include "core/util/math_cpuonly.h"
#include "Eigen/src/Core/Map.h"
#include "window_functions.h"
Expand All @@ -23,7 +24,7 @@ ONNX_OPERATOR_KERNEL_EX(
kMSExperimentalDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().MayInplace(0, 0)
KernelDefBuilder().MayInplace(0, 0) //
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t>())
.TypeConstraint("T2", BuildKernelDefConstraints<float, double, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>()),
HannWindow);
Expand All @@ -33,7 +34,7 @@ ONNX_OPERATOR_KERNEL_EX(
kMSExperimentalDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().MayInplace(0, 0)
KernelDefBuilder().MayInplace(0, 0) //
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t>())
.TypeConstraint("T2", BuildKernelDefConstraints<float, double, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>()),
HammingWindow);
Expand All @@ -43,24 +44,22 @@ ONNX_OPERATOR_KERNEL_EX(
kMSExperimentalDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().MayInplace(0, 0)
KernelDefBuilder().MayInplace(0, 0) //
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t>())
.TypeConstraint("T2", BuildKernelDefConstraints<float, double, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>()),
BlackmanWindow);


ONNX_OPERATOR_KERNEL_EX(
MelWeightMatrix,
kMSExperimentalDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().MayInplace(0, 0)
KernelDefBuilder().MayInplace(0, 0) //
.TypeConstraint("T1", BuildKernelDefConstraints<int32_t, int64_t>())
.TypeConstraint("T2", BuildKernelDefConstraints<float>())
.TypeConstraint("T3", BuildKernelDefConstraints<float, double, uint8_t, uint16_t, uint32_t, uint64_t, int8_t, int16_t, int32_t, int64_t>()),
MelWeightMatrix);


template <typename T>
static Status cosine_sum_window(Tensor* Y, size_t size, float a0, float a1, float a2) {
auto* Y_data = reinterpret_cast<T*>(Y->MutableDataRaw());
Expand All @@ -80,31 +79,12 @@ static Status cosine_sum_window(Tensor* Y, size_t size, float a0, float a1, floa
return Status::OK();
}

template <typename T>
static T get_scalar_value_from_tensor(const Tensor* tensor) {
ORT_ENFORCE(tensor->Shape().Size() == 1, "Tensor input should have a single value.");
auto data_type = tensor->DataType()->AsPrimitiveDataType()->GetDataType();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
return static_cast<T>(*reinterpret_cast<const float*>(tensor->DataRaw()));
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
return static_cast<T>(*reinterpret_cast<const double*>(tensor->DataRaw()));
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
return static_cast<T>(*reinterpret_cast<const int32_t*>(tensor->DataRaw()));
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
return static_cast<T>(*reinterpret_cast<const int64_t*>(tensor->DataRaw()));
default:
ORT_THROW("Unsupported input data type of ", data_type);
}
}

static Status create_cosine_sum_window(
OpKernelContext* ctx,
onnx::TensorProto_DataType output_datatype,
float a0, float a1, float a2) {

// Get the size of the window
auto size = get_scalar_value_from_tensor<int64_t>(ctx->Input<Tensor>(0));
auto size = ::onnxruntime::signal::get_scalar_value_from_tensor<int64_t>(ctx->Input<Tensor>(0));

// Get the output tensor
auto Y_shape = onnxruntime::TensorShape({size});
Expand Down Expand Up @@ -186,148 +166,6 @@ Status BlackmanWindow::Compute(OpKernelContext* ctx) const {
return create_cosine_sum_window(ctx, data_type_, a0, a1, a2);
}

static inline double hz_to_mel_scale(double hz) {
return 2595 * std::log10(1 + hz / 700);
}

static inline double mel_scale_to_hz(double mels) {
return 700 * (pow(10, (mels / 2595)) - 1);
}

template <typename T>
Status create_mel_weight_matrix(OpKernelContext* ctx, int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, float lower_edge_hertz, float upper_edge_hertz) {
// Determine the width of the spectrogram.
// This is determined as half the size of the fft size. The first element of the spectrum is always retained,
// and the remaining are halved. The second half can be discarded due to the conjugate symmetry of the output with real valued ffts.
// Taken together the formula for the size of the output will be std::floor(dft_length / 2) + 1.
int64_t num_spectrogram_bins = static_cast<int64_t>(std::floor(dft_length / 2 + 1));

// Checks
auto lowest_index = std::floor(((dft_length + 1) * lower_edge_hertz) / sample_rate);
auto highest_index = std::floor(((dft_length + 1) * upper_edge_hertz) / sample_rate);
ORT_ENFORCE(lowest_index >= 0 && lowest_index < num_spectrogram_bins, "lower_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the sample_rate.");
ORT_ENFORCE(highest_index >= 0 && highest_index < num_spectrogram_bins, "upper_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the sample_rate.");

// Create the output shape
onnxruntime::TensorShape output_shape(
{
static_cast<int64_t>(num_spectrogram_bins),
num_mel_bins
});
auto* Y = ctx->Output(0, output_shape);

// Get the raw output data
auto* Y_data = reinterpret_cast<T*>(Y->MutableDataRaw());

// Set the weight matrix to 0
memset(Y_data, 0, num_spectrogram_bins * num_mel_bins * sizeof(T));

// The mel filterbank is a triangular shaped peak with a height of 1 and a base equal to the size of the MEL range divided by
// the number of bins needed times 2. This triagle is then slid across the mel domain linearly, with a constant step size that
// is equal to half of the base of the triange. To accomodate N bins, N+2 data points will be needed to determine the
// start, center and end points of each mel triange filter.
//
// low_frequency where the mel triangle filter banks begin, and they end on the high_frequency_mel
// The range is divided evenly to create the needed points corresponding to the begin, center, end points of each triangle filterbank
std::vector<size_t> frequency_bins(num_mel_bins + 2);
auto low_frequency_mel = hz_to_mel_scale(lower_edge_hertz);
auto high_frequency_mel = hz_to_mel_scale(upper_edge_hertz);
auto mel_step = (high_frequency_mel - low_frequency_mel) / static_cast<float>(frequency_bins.size());

// Convert each point from mel scale back to hertz, and then compute the corresponding index in the fft
for (size_t i = 0; i < frequency_bins.size(); i++) {
auto hz = mel_scale_to_hz(low_frequency_mel + mel_step * i);
frequency_bins[i] = static_cast<size_t>(std::floor(((dft_length + 1) * hz) / sample_rate));
}

for (size_t i = 0; i < static_cast<size_t>(num_mel_bins); i++) {
auto lower_frequency_value = frequency_bins[i]; //left
auto center_frequency_point = frequency_bins[i+1]; //center
auto higher_frequency_point = frequency_bins[i+2]; //right

auto low_to_center = center_frequency_point - lower_frequency_value;
if (low_to_center == 0) {
auto& current_element = *(Y_data + (center_frequency_point * num_mel_bins) + i);
current_element = static_cast<T>(1);
} else {
for (size_t j = lower_frequency_value; j <= center_frequency_point; j++) {
auto& current_element = *(Y_data + (j * num_mel_bins) + i);
current_element = static_cast<T>((j - lower_frequency_value) / static_cast<T>(low_to_center));
}
}

auto center_to_high = higher_frequency_point - center_frequency_point;
if (center_to_high > 0) {
for (size_t j = center_frequency_point; j < higher_frequency_point; j++) {
auto& current_element = *(Y_data + (j * num_mel_bins) + i);
current_element = static_cast<T>((higher_frequency_point - j) / static_cast<T>(center_to_high));
}
}
}

return Status::OK();
}

static Status create_mel_weight_matrix(OpKernelContext* ctx, onnx::TensorProto_DataType output_datatype,
int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, float lower_edge_hertz, float upper_edge_hertz) {
switch (output_datatype) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<float>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<double>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<int8_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<int16_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<int32_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_INT64: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<int64_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT8: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<uint8_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<uint16_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT32: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<uint32_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {
ORT_RETURN_IF_ERROR((create_mel_weight_matrix<uint64_t>(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)));
break;
}
default:
ORT_THROW("Unsupported input data type of ", output_datatype);
}
return Status::OK();
}

Status MelWeightMatrix::Compute(OpKernelContext* ctx) const {
const auto num_mel_bins = get_scalar_value_from_tensor<int64_t>(ctx->Input<Tensor>(0));
const auto dft_length = get_scalar_value_from_tensor<int64_t>(ctx->Input<Tensor>(1));
const auto sample_rate = get_scalar_value_from_tensor<int64_t>(ctx->Input<Tensor>(2));
const auto lower_edge_hertz = get_scalar_value_from_tensor<float>(ctx->Input<Tensor>(3));
const auto upper_edge_hertz = get_scalar_value_from_tensor<float>(ctx->Input<Tensor>(4));

ORT_RETURN_IF_ERROR(create_mel_weight_matrix(ctx, data_type_, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz));
return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime

Expand Down
21 changes: 3 additions & 18 deletions onnxruntime/contrib_ops/cpu/signal/window_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,11 @@

#ifdef BUILD_MS_EXPERIMENTAL_OPS

#include "core/providers/cpu/signal/window_functions.h"

namespace onnxruntime {
namespace contrib {

class VariableOutputDataTypeBase : public OpKernel {
protected:
onnx::TensorProto_DataType data_type_;

public:
VariableOutputDataTypeBase(const OpKernelInfo& info) : OpKernel(info) {
data_type_ = static_cast<onnx::TensorProto_DataType>(info.GetAttrOrDefault<int64_t>("output_datatype", onnx::TensorProto_DataType::TensorProto_DataType_FLOAT));
}
};

class HannWindow final : public VariableOutputDataTypeBase {
public:
explicit HannWindow(const OpKernelInfo& info) : VariableOutputDataTypeBase(info) {
Expand All @@ -37,14 +29,7 @@ class BlackmanWindow final : public VariableOutputDataTypeBase {
Status Compute(OpKernelContext* ctx) const override;
};

class MelWeightMatrix final : public VariableOutputDataTypeBase {
public:
explicit MelWeightMatrix(const OpKernelInfo& info) : VariableOutputDataTypeBase(info) {
}
Status Compute(OpKernelContext* ctx) const override;
};

} // namespace contrib
} // namespace onnxruntime

#endif
#endif
37 changes: 28 additions & 9 deletions onnxruntime/core/graph/signal_ops/signal_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#include <cmath>

// NOTE: These were added to the standard op set. We register them under the MS domain
// for backwards compatibility, but new users should use the standard ops instead.
namespace onnxruntime {
namespace signal {

Expand Down Expand Up @@ -284,6 +286,25 @@ void RegisterSignalSchemas() {
updateOutputShape(ctx, 0, result_shape_proto);
});

ONNX_NAMESPACE::NodeProto idft_function_body_node;
idft_function_body_node.set_op_type("DFT");

auto* idft_function_body_inverse = idft_function_body_node.add_attribute();
idft_function_body_inverse->set_name("inverse");
idft_function_body_inverse->set_i(1);

auto* idft_function_body_axis = idft_function_body_node.add_attribute();
idft_function_body_axis->set_name("axis");
idft_function_body_axis->set_ref_attr_name("axis");

idft_function_body_node.add_input("input");
idft_function_body_node.add_input("dft_length");
idft_function_body_node.add_output("output");

ONNX_NAMESPACE::OperatorSetIdProto idft_function_body_op_set;
idft_function_body_op_set.set_domain(kOnnxDomain);
idft_function_body_op_set.set_version(17);

MS_SIGNAL_OPERATOR_SCHEMA(IDFT)
.SetDomain(kMSExperimentalDomain)
.SinceVersion(1)
Expand Down Expand Up @@ -335,8 +356,6 @@ void RegisterSignalSchemas() {
"Constrain scalar length types to int64_t.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
const int64_t batch_ndim = 1;

auto& input_shape = getInputShape(ctx, 0);
ONNX_NAMESPACE::TensorShapeProto result_shape = input_shape;
auto dim_size = static_cast<int64_t>(input_shape.dim_size());
Expand All @@ -349,7 +368,8 @@ void RegisterSignalSchemas() {
}

updateOutputShape(ctx, 0, result_shape);
});
})
.FunctionBody({idft_function_body_node}, {idft_function_body_op_set});

MS_SIGNAL_OPERATOR_SCHEMA(STFT)
.SetDomain(kMSExperimentalDomain)
Expand Down Expand Up @@ -471,7 +491,7 @@ void RegisterSignalSchemas() {

const ONNX_NAMESPACE::TensorShapeProto* window_shape = nullptr;
if (ctx.getNumInputs() >= 3) {
window_shape = getOptionalInputShape(ctx, 2);
window_shape = ONNX_NAMESPACE::getOptionalInputShape(ctx, 2);
} else {
window_shape = nullptr;
}
Expand Down Expand Up @@ -644,11 +664,10 @@ void RegisterSignalSchemas() {
}
)ONNX");

static const char* MelWeightMatrix_ver17_doc = R"DOC(
static const char* MelWeightMatrix_doc = R"DOC(
Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a linearly sampled frequency spectra
(from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range
on the mel scale.
This function defines the mel scale in terms of a frequency in hertz according to the following formula:
(from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range on
the mel scale. This function defines the mel scale in terms of a frequency in hertz according to the following formula:
mel(f) = 2595 * log10(1 + f/700)
Expand All @@ -661,7 +680,7 @@ linear scale spectrum values (e.g. STFT magnitudes) to generate a "mel spectrogr
MS_SIGNAL_OPERATOR_SCHEMA(MelWeightMatrix)
.SetDomain(kMSExperimentalDomain)
.SinceVersion(1)
.SetDoc(MelWeightMatrix_ver17_doc)
.SetDoc(MelWeightMatrix_doc)
.Attr("output_datatype",
"The data type of the output tensor. "
"Strictly must be one of the types from DataType enum in TensorProto.",
Expand Down
Loading

0 comments on commit 51e9b33

Please sign in to comment.