From 79db92f8fea2f14b354ea4bad8d5b17f19983efb Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Wed, 8 Jun 2022 15:45:40 -0700 Subject: [PATCH] clang-format signal_defs.cc (#11767) --- .../core/graph/signal_ops/signal_defs.cc | 522 +++++++++--------- 1 file changed, 259 insertions(+), 263 deletions(-) diff --git a/onnxruntime/core/graph/signal_ops/signal_defs.cc b/onnxruntime/core/graph/signal_ops/signal_defs.cc index e7cd7329bc243..27e077c9fefe4 100644 --- a/onnxruntime/core/graph/signal_ops/signal_defs.cc +++ b/onnxruntime/core/graph/signal_ops/signal_defs.cc @@ -193,91 +193,88 @@ void RegisterSignalSchemas() { {"tensor(int32)", "tensor(int64)"}, "Constrain scalar length types to int64_t.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - bool is_onesided = static_cast(getAttribute(ctx, "onesided", 0)); - bool inverse = static_cast(getAttribute(ctx, "inverse", 0)); + bool is_onesided = static_cast(getAttribute(ctx, "onesided", 0)); + bool inverse = static_cast(getAttribute(ctx, "inverse", 0)); - if (inverse && is_onesided) { - fail_shape_inference("is_onesided and inverse attributes cannot be enabled at the same time"); - } + if (inverse && is_onesided) { + fail_shape_inference("is_onesided and inverse attributes cannot be enabled at the same time"); + } - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (!hasInputShape(ctx, 0)) - { - // If no shape is available for the input, skip shape inference... - return; - } + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (!hasInputShape(ctx, 0)) { + // If no shape is available for the input, skip shape inference... + return; + } - // In general the output shape will match the input shape exactly - // So initialize the output shape with the input shape - auto& input_shape = getInputShape(ctx, 0); - ONNX_NAMESPACE::TensorShapeProto result_shape_proto = input_shape; - - // Get the axis where the DFT will be performed. - auto axis = static_cast(getAttribute(ctx, "axis", 1)); - auto rank = input_shape.dim_size(); - - if (!(-rank <= axis && axis < rank)) { - fail_shape_inference( - "axis attribute value ", - axis, - " is invalid for a tensor of rank ", - rank); - } + // In general the output shape will match the input shape exactly + // So initialize the output shape with the input shape + auto& input_shape = getInputShape(ctx, 0); + ONNX_NAMESPACE::TensorShapeProto result_shape_proto = input_shape; + + // Get the axis where the DFT will be performed. + auto axis = static_cast(getAttribute(ctx, "axis", 1)); + auto rank = input_shape.dim_size(); + + if (!(-rank <= axis && axis < rank)) { + fail_shape_inference( + "axis attribute value ", + axis, + " is invalid for a tensor of rank ", + rank); + } - auto axis_idx = (axis >= 0 ? axis : axis + rank); - - // If dft_length is specified, then we should honor the shape. - // Set the output dimension to match the dft_length on the axis. - // If onesided this will be adjusted later on... - const ONNX_NAMESPACE::TensorProto* dft_length = nullptr; - if (ctx.getNumInputs() >= 2 && ctx.getInputType(1) != nullptr) { - dft_length = ctx.getInputData(1); - if (dft_length == nullptr) { - // If we cannot read the dft_length, we cannot infer shape - // return... - return; - } + auto axis_idx = (axis >= 0 ? axis : axis + rank); + + // If dft_length is specified, then we should honor the shape. + // Set the output dimension to match the dft_length on the axis. + // If onesided this will be adjusted later on... + const ONNX_NAMESPACE::TensorProto* dft_length = nullptr; + if (ctx.getNumInputs() >= 2 && ctx.getInputType(1) != nullptr) { + dft_length = ctx.getInputData(1); + if (dft_length == nullptr) { + // If we cannot read the dft_length, we cannot infer shape + // return... + return; } + } - if (nullptr != dft_length) { - if (dft_length->dims_size() != 0) { - fail_shape_inference("dft_length input must be a scalar."); - } - auto dft_length_value = get_scalar_value_from_tensor(dft_length); - result_shape_proto.mutable_dim(axis_idx)->set_dim_value(dft_length_value); + if (nullptr != dft_length) { + if (dft_length->dims_size() != 0) { + fail_shape_inference("dft_length input must be a scalar."); } - // When DFT is onesided, the output shape is half the size of the input shape - // along the specified axis. - if (is_onesided) { - auto axis_dimension = result_shape_proto.dim(axis_idx); - // We need to update the output shape dimension along the specified axis, - // but sometimes the dimension will be a free dimension or be otherwise unset. - // Only perform inference when a input dimension value exists. - if (axis_dimension.has_dim_value()) - { - auto original_signal_size = axis_dimension.dim_value(); - auto half_signal_size = (original_signal_size >> 1) + 1; - result_shape_proto.mutable_dim(axis_idx)->set_dim_value(half_signal_size); - } else - { - // Clear the value and param (which would otherwie be inherited from the input). - result_shape_proto.mutable_dim(axis_idx)->clear_dim_value(); - result_shape_proto.mutable_dim(axis_idx)->clear_dim_param(); - } + auto dft_length_value = get_scalar_value_from_tensor(dft_length); + result_shape_proto.mutable_dim(axis_idx)->set_dim_value(dft_length_value); + } + // When DFT is onesided, the output shape is half the size of the input shape + // along the specified axis. + if (is_onesided) { + auto axis_dimension = result_shape_proto.dim(axis_idx); + // We need to update the output shape dimension along the specified axis, + // but sometimes the dimension will be a free dimension or be otherwise unset. + // Only perform inference when a input dimension value exists. + if (axis_dimension.has_dim_value()) { + auto original_signal_size = axis_dimension.dim_value(); + auto half_signal_size = (original_signal_size >> 1) + 1; + result_shape_proto.mutable_dim(axis_idx)->set_dim_value(half_signal_size); + } else { + // Clear the value and param (which would otherwie be inherited from the input). + result_shape_proto.mutable_dim(axis_idx)->clear_dim_value(); + result_shape_proto.mutable_dim(axis_idx)->clear_dim_param(); } + } - // Coerce the last dimension to 2. - auto dim_size = static_cast(result_shape_proto.dim_size()); - auto has_component_dimension = dim_size > 2; + // Coerce the last dimension to 2. + auto dim_size = static_cast(result_shape_proto.dim_size()); + auto has_component_dimension = dim_size > 2; - // This if check is retained in the contrib op and not the official spec for back compat - if (has_component_dimension) { - result_shape_proto.mutable_dim(static_cast(dim_size - 1))->set_dim_value(2); - } else { - result_shape_proto.add_dim()->set_dim_value(2); - } + // This if check is retained in the contrib op and not the official spec for back compat + if (has_component_dimension) { + result_shape_proto.mutable_dim(static_cast(dim_size - 1))->set_dim_value(2); + } else { + result_shape_proto.add_dim()->set_dim_value(2); + } - updateOutputShape(ctx, 0, result_shape_proto); + updateOutputShape(ctx, 0, result_shape_proto); }); MS_SIGNAL_OPERATOR_SCHEMA(IDFT) @@ -319,29 +316,29 @@ void RegisterSignalSchemas() { 1, OpSchema::NonDifferentiable) .TypeConstraint( - "T1", - {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, - "Constrain input and output types to float tensors.") + "T1", + {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") .TypeConstraint( - "T2", - {"tensor(int64)"}, - "Constrain scalar length types to int64_t.") + "T2", + {"tensor(int64)"}, + "Constrain scalar length types to int64_t.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - const int64_t batch_ndim = 1; - - auto& input_shape = getInputShape(ctx, 0); - ONNX_NAMESPACE::TensorShapeProto result_shape = input_shape; - auto dim_size = static_cast(input_shape.dim_size()); - auto has_component_dimension = dim_size > 2; - - if (has_component_dimension) { - result_shape.mutable_dim(static_cast(dim_size - 1))->set_dim_value(2); - } else { - result_shape.add_dim()->set_dim_value(2); - } + propagateElemTypeFromInputToOutput(ctx, 0, 0); + const int64_t batch_ndim = 1; + + auto& input_shape = getInputShape(ctx, 0); + ONNX_NAMESPACE::TensorShapeProto result_shape = input_shape; + auto dim_size = static_cast(input_shape.dim_size()); + auto has_component_dimension = dim_size > 2; + + if (has_component_dimension) { + result_shape.mutable_dim(static_cast(dim_size - 1))->set_dim_value(2); + } else { + result_shape.add_dim()->set_dim_value(2); + } - updateOutputShape(ctx, 0, result_shape); + updateOutputShape(ctx, 0, result_shape); }); MS_SIGNAL_OPERATOR_SCHEMA(STFT) @@ -349,21 +346,22 @@ void RegisterSignalSchemas() { .SinceVersion(1) .SetDoc(R"DOC(STFT)DOC") .Attr( - "onesided", - "If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because " - "the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w]=X[m,n_fft-w]*. " - "Note if the input or window tensors are complex, then onesided output is not possible. " - "Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT)." - "When invoked with real or complex valued input, the default value is 1. " - "Values can be 0 or 1.", - AttributeProto::INT, - static_cast(1)) + "onesided", + "If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because " + "the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w] = " + "X[m,n_fft-w]*. Note if the input or window tensors are complex, then onesided output is not possible. " + "Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT)." + "When invoked with real or complex valued input, the default value is 1. " + "Values can be 0 or 1.", + AttributeProto::INT, + static_cast(1)) .Input(0, "signal", "Input tensor representing a real or complex valued signal. " "For real input, the following shape is expected: [batch_size][signal_length][1]. " "For complex input, the following shape is expected: [batch_size][signal_length][2], where " - "[batch_size][signal_length][0] represents the real component and [batch_size][signal_length][1] represents the imaginary component of the signal.", + "[batch_size][signal_length][0] represents the real component and [batch_size][signal_length][1] " + "represents the imaginary component of the signal.", "T1", OpSchema::Single, true, @@ -399,8 +397,10 @@ void RegisterSignalSchemas() { .Output(0, "output", "The Short-time Fourier Transform of the signals." - "If onesided is 1, the output has the shape: [batch_size][frames][dft_unique_bins][2], where dft_unique_bins is frame_length // 2 + 1 (the unique components of the DFT) " - "If onesided is 0, the output has the shape: [batch_size][frames][frame_length][2], where frame_length is the length of the DFT.", + "If onesided is 1, the output has the shape: [batch_size][frames][dft_unique_bins][2], where " + "dft_unique_bins is frame_length // 2 + 1 (the unique components of the DFT) " + "If onesided is 0, the output has the shape: [batch_size][frames][frame_length][2], where frame_length " + "is the length of the DFT.", "T1", OpSchema::Single, true, @@ -409,141 +409,136 @@ void RegisterSignalSchemas() { .TypeConstraint( "T1", {"tensor(float)", - "tensor(float16)", - "tensor(double)", - "tensor(bfloat16)"}, + "tensor(float16)", + "tensor(double)", + "tensor(bfloat16)"}, "Constrain signal and output to float tensors.") .TypeConstraint( "T2", {"tensor(int32)", "tensor(int64)"}, "Constrain scalar length types to int64_t.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - - // Get signal size - // The signal size is needed to perform inference because the size of the signal - // is needed to compute the number of DFTs in the output. - // - // 1) Check if shape exists, return if not - // 2) Get the shape - // 3) Check if signal dim value exists, return if not - if (!hasInputShape(ctx, 0)) { - return; - } - - auto& input_shape = getInputShape(ctx, 0); - auto signal_dim = input_shape.dim(1); - if (!signal_dim.has_dim_value()) - { - return; - } - auto signal_size = signal_dim.dim_value(); - - // The frame step is a required input. - // Its value is needed to compute the number output nDFTs, so return early is missing. - const auto* frame_step = ctx.getInputData(1); - if (nullptr == frame_step) { - return; - } - auto frame_step_value = get_scalar_value_from_tensor(frame_step); - - // Determine the size of the DFT based on the 2 optional inputs window and frame_length. - // One must be set. - int64_t dft_size = -1; - const ONNX_NAMESPACE::TensorProto* frame_length = nullptr; - if (ctx.getNumInputs() >= 4 && ctx.getInputType(3) != nullptr) { - frame_length = ctx.getInputData(3); - if (frame_length == nullptr) { - // If we cannot read the frame_length, we cannot infer shape - // return... - return; - } - } - - const ONNX_NAMESPACE::TensorShapeProto* window_shape = nullptr; - if (ctx.getNumInputs() >= 3) { - window_shape = getOptionalInputShape(ctx, 2); - } else { - window_shape = nullptr; - } - - if (window_shape == nullptr && frame_length == nullptr) - { - // STFT expects to have at least one of these inputs set: [window, frame_length], - // but they may not be available at shape inference time - return; - } else if (window_shape != nullptr && frame_length != nullptr) - { - if (frame_length->dims_size() != 0) { - fail_shape_inference("frame_length input must be scalar."); - } - auto frame_length_value = get_scalar_value_from_tensor(frame_length); - - // Ensure that the window length and the dft_length match. - if (window_shape->dim_size() != 1) { - fail_shape_inference("window input must have rank = 1."); - } - if (window_shape->dim(0).has_dim_value()) - { - auto window_length = window_shape->dim(0).dim_value(); - if (window_length != frame_length_value) - { - fail_type_inference("If STFT has both a window input and frame_length specified, the dimension of the window must match the frame_length specified!"); - } - } - - dft_size = frame_length_value; - } else if (window_shape != nullptr) - { - // Ensure that the window length and the dft_length match. - if (window_shape->dim_size() != 1) { - fail_shape_inference("window input must have rank = 1."); - } - if (window_shape->dim(0).has_dim_value()) { - dft_size = window_shape->dim(0).dim_value(); - } else { - // Cannot determine the window size, and there is no frame_length, - // So shape inference cannot proceed. - return; - } - } else if (frame_length != nullptr) - { - if (frame_length->dims_size() != 0) { - fail_shape_inference("frame_length input must be scalar."); - } - dft_size = get_scalar_value_from_tensor(frame_length); - } - - bool is_onesided = static_cast(getAttribute(ctx, "onesided", 0)); - if (is_onesided) { - dft_size = is_onesided ? ((dft_size >> 1) + 1) : dft_size; - } - - auto n_dfts = static_cast((signal_size - dft_size) / static_cast(frame_step_value)) + 1; - - // The output has the following shape: [batch_size][frames][dft_unique_bins][2] - ONNX_NAMESPACE::TensorShapeProto result_shape_proto; - result_shape_proto.add_dim()->set_dim_value(input_shape.dim(0).dim_value()); // batch size - result_shape_proto.add_dim()->set_dim_value(n_dfts); - result_shape_proto.add_dim()->set_dim_value(dft_size); - result_shape_proto.add_dim()->set_dim_value(2); - updateOutputShape(ctx, 0, result_shape_proto); - }); + propagateElemTypeFromInputToOutput(ctx, 0, 0); + + // Get signal size + // The signal size is needed to perform inference because the size of the signal + // is needed to compute the number of DFTs in the output. + // + // 1) Check if shape exists, return if not + // 2) Get the shape + // 3) Check if signal dim value exists, return if not + if (!hasInputShape(ctx, 0)) { + return; + } + + auto& input_shape = getInputShape(ctx, 0); + auto signal_dim = input_shape.dim(1); + if (!signal_dim.has_dim_value()) { + return; + } + auto signal_size = signal_dim.dim_value(); + + // The frame step is a required input. + // Its value is needed to compute the number output nDFTs, so return early is missing. + const auto* frame_step = ctx.getInputData(1); + if (nullptr == frame_step) { + return; + } + auto frame_step_value = get_scalar_value_from_tensor(frame_step); + + // Determine the size of the DFT based on the 2 optional inputs window and frame_length. + // One must be set. + int64_t dft_size = -1; + const ONNX_NAMESPACE::TensorProto* frame_length = nullptr; + if (ctx.getNumInputs() >= 4 && ctx.getInputType(3) != nullptr) { + frame_length = ctx.getInputData(3); + if (frame_length == nullptr) { + // If we cannot read the frame_length, we cannot infer shape + // return... + return; + } + } + + const ONNX_NAMESPACE::TensorShapeProto* window_shape = nullptr; + if (ctx.getNumInputs() >= 3) { + window_shape = getOptionalInputShape(ctx, 2); + } else { + window_shape = nullptr; + } + + if (window_shape == nullptr && frame_length == nullptr) { + // STFT expects to have at least one of these inputs set: [window, frame_length], + // but they may not be available at shape inference time + return; + } else if (window_shape != nullptr && frame_length != nullptr) { + if (frame_length->dims_size() != 0) { + fail_shape_inference("frame_length input must be scalar."); + } + auto frame_length_value = get_scalar_value_from_tensor(frame_length); + + // Ensure that the window length and the dft_length match. + if (window_shape->dim_size() != 1) { + fail_shape_inference("window input must have rank = 1."); + } + if (window_shape->dim(0).has_dim_value()) { + auto window_length = window_shape->dim(0).dim_value(); + if (window_length != frame_length_value) { + fail_type_inference( + "If STFT has both a window input and frame_length specified, the dimension of the " + "window must match the frame_length specified!"); + } + } + + dft_size = frame_length_value; + } else if (window_shape != nullptr) { + // Ensure that the window length and the dft_length match. + if (window_shape->dim_size() != 1) { + fail_shape_inference("window input must have rank = 1."); + } + if (window_shape->dim(0).has_dim_value()) { + dft_size = window_shape->dim(0).dim_value(); + } else { + // Cannot determine the window size, and there is no frame_length, + // So shape inference cannot proceed. + return; + } + } else if (frame_length != nullptr) { + if (frame_length->dims_size() != 0) { + fail_shape_inference("frame_length input must be scalar."); + } + dft_size = get_scalar_value_from_tensor(frame_length); + } + + bool is_onesided = static_cast(getAttribute(ctx, "onesided", 0)); + if (is_onesided) { + dft_size = is_onesided ? ((dft_size >> 1) + 1) : dft_size; + } + + auto n_dfts = static_cast((signal_size - dft_size) / static_cast(frame_step_value)) + 1; + + // The output has the following shape: [batch_size][frames][dft_unique_bins][2] + ONNX_NAMESPACE::TensorShapeProto result_shape_proto; + result_shape_proto.add_dim()->set_dim_value(input_shape.dim(0).dim_value()); // batch size + result_shape_proto.add_dim()->set_dim_value(n_dfts); + result_shape_proto.add_dim()->set_dim_value(dft_size); + result_shape_proto.add_dim()->set_dim_value(2); + updateOutputShape(ctx, 0, result_shape_proto); + }); // Window Functions MS_SIGNAL_OPERATOR_SCHEMA(HannWindow) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) .FillUsing(CosineSumWindowOpDocGenerator("Hann")) - .TypeConstraint( + .TypeConstraint( "T1", {"tensor(int32)", "tensor(int64)"}, - "Constrain the input size to int64_t.") - .TypeConstraint( + "Constrain the input size to int64_t.") + .TypeConstraint( "T2", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), "Constrain output types to numeric tensors.") - .FunctionBody(R"ONNX( + .FunctionBody(R"ONNX( { A0 = Constant () A1 = Constant () @@ -565,22 +560,21 @@ void RegisterSignalSchemas() { Temp1 = Sub (A0, Temp0) output = Cast (Temp1) } - )ONNX" - ); + )ONNX"); MS_SIGNAL_OPERATOR_SCHEMA(HammingWindow) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) .FillUsing(CosineSumWindowOpDocGenerator("Hamming")) - .TypeConstraint( + .TypeConstraint( "T1", {"tensor(int32)", "tensor(int64)"}, - "Constrain the input size to int64_t.") - .TypeConstraint( + "Constrain the input size to int64_t.") + .TypeConstraint( "T2", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), "Constrain output types to numeric tensors.") - .FunctionBody(R"ONNX( + .FunctionBody(R"ONNX( { A0 = Constant () A1 = Constant () @@ -602,22 +596,21 @@ void RegisterSignalSchemas() { Temp1 = Sub (A0, Temp0) output = Cast (Temp1) } - )ONNX" - ); + )ONNX"); MS_SIGNAL_OPERATOR_SCHEMA(BlackmanWindow) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) .FillUsing(CosineSumWindowOpDocGenerator("Blackman")) - .TypeConstraint( + .TypeConstraint( "T1", {"tensor(int32)", "tensor(int64)"}, - "Constrain the input size to int64_t.") - .TypeConstraint( + "Constrain the input size to int64_t.") + .TypeConstraint( "T2", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), "Constrain output types to numeric tensors.") - .FunctionBody(R"ONNX( + .FunctionBody(R"ONNX( { A0 = Constant () A1 = Constant () @@ -639,18 +632,20 @@ void RegisterSignalSchemas() { Temp1 = Sub (A0, Temp0) output = Cast (Temp1) } - )ONNX" - ); + )ONNX"); -static const char* MelWeightMatrix_ver17_doc = R"DOC( -Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a linearly sampled frequency spectra (from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range on the mel scale. + static const char* MelWeightMatrix_ver17_doc = R"DOC( +Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a linearly sampled frequency spectra +(from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range +on the mel scale. This function defines the mel scale in terms of a frequency in hertz according to the following formula: mel(f) = 2595 * log10(1 + f/700) In the returned matrix, all the triangles (filterbanks) have a peak value of 1.0. -The returned MelWeightMatrix can be used to right-multiply a spectrogram S of shape [frames, num_spectrogram_bins] of linear scale spectrum values (e.g. STFT magnitudes) to generate a "mel spectrogram" M of shape [frames, num_mel_bins]. +The returned MelWeightMatrix can be used to right-multiply a spectrogram S of shape [frames, num_spectrogram_bins] of +linear scale spectrum values (e.g. STFT magnitudes) to generate a "mel spectrogram" M of shape [frames, num_mel_bins]. )DOC"; MS_SIGNAL_OPERATOR_SCHEMA(MelWeightMatrix) @@ -687,56 +682,57 @@ The returned MelWeightMatrix can be used to right-multiply a spectrogram S of sh "The MEL Matrix", "T3") .TypeConstraint( - "T1", - {"tensor(int32)", "tensor(int64)"}, - "Constrain to integer tensors.") + "T1", + {"tensor(int32)", "tensor(int64)"}, + "Constrain to integer tensors.") .TypeConstraint( - "T2", - {"tensor(float)", - "tensor(float16)", - "tensor(double)", - "tensor(bfloat16)"}, - "Constrain to float tensors") + "T2", + {"tensor(float)", + "tensor(float16)", + "tensor(double)", + "tensor(bfloat16)"}, + "Constrain to float tensors") .TypeConstraint( - "T3", - ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), - "Constrain to any numerical types.") + "T3", + ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), + "Constrain to any numerical types.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - auto output_datatype = getAttribute(ctx, "output_datatype", static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)); + auto output_datatype = getAttribute( + ctx, "output_datatype", static_cast(onnx::TensorProto::DataType::TensorProto_DataType_FLOAT)); updateOutputElemType(ctx, 0, static_cast(output_datatype)); if (!hasInputShape(ctx, 0) || !hasInputShape(ctx, 1)) { - return; + return; } const auto* num_mel_bins = ctx.getInputData(0); const auto* dft_length = ctx.getInputData(1); if (nullptr == num_mel_bins || nullptr == dft_length) { - return; + return; } int64_t num_mel_bins_value = -1; int64_t dft_length_value = -1; if (num_mel_bins->dims_size() != 0) { - fail_shape_inference("num_mel_bins input must be scalar."); + fail_shape_inference("num_mel_bins input must be scalar."); } num_mel_bins_value = get_scalar_value_from_tensor(num_mel_bins); if (dft_length->dims_size() != 0) { - fail_shape_inference("dft_length input must be scalar."); + fail_shape_inference("dft_length input must be scalar."); } dft_length_value = get_scalar_value_from_tensor(dft_length); if (num_mel_bins_value > 0 && dft_length_value > 0) { - ONNX_NAMESPACE::TensorShapeProto result_shape; - result_shape.add_dim()->set_dim_value(static_cast((dft_length_value >> 1) + 1)); - result_shape.add_dim()->set_dim_value(num_mel_bins_value); - updateOutputShape(ctx, 0, result_shape); + ONNX_NAMESPACE::TensorShapeProto result_shape; + result_shape.add_dim()->set_dim_value(static_cast((dft_length_value >> 1) + 1)); + result_shape.add_dim()->set_dim_value(num_mel_bins_value); + updateOutputShape(ctx, 0, result_shape); } }); } -} // namespace audio +} // namespace signal } // namespace onnxruntime #endif