From de1302f0d2f8f8c3e73cc48497e5cbbb693765ce Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Fri, 27 Sep 2024 18:42:12 +0200 Subject: [PATCH] Optimize function builders --- .../src/subgraph_convert.cpp | 57 ++----------------- .../src/subgraph_simple.cpp | 6 +- 2 files changed, 7 insertions(+), 56 deletions(-) diff --git a/src/tests/ov_helpers/ov_snippets_models/src/subgraph_convert.cpp b/src/tests/ov_helpers/ov_snippets_models/src/subgraph_convert.cpp index dd6209d26500f7..1324d380294502 100644 --- a/src/tests/ov_helpers/ov_snippets_models/src/subgraph_convert.cpp +++ b/src/tests/ov_helpers/ov_snippets_models/src/subgraph_convert.cpp @@ -24,10 +24,7 @@ std::shared_ptr ConvertFunction::initOriginal() const { } std::shared_ptr ConvertFunction::initReference() const { auto data0 = std::make_shared(inType, input_shapes[0]); - auto indata0 = std::make_shared(inType, data0->get_shape()); - auto subgraph = std::make_shared(NodeVector{data0}, - std::make_shared(NodeVector{std::make_shared(indata0, outType)}, - ParameterVector{indata0})); + auto subgraph = std::make_shared(NodeVector{data0}, getOriginal()); return std::make_shared(NodeVector{subgraph}, ParameterVector{data0}); } @@ -41,13 +38,7 @@ std::shared_ptr ConvertInputFunction::initOriginal() const { std::shared_ptr ConvertInputFunction::initReference() const { auto data0 = std::make_shared(inType, input_shapes[0]); auto data1 = std::make_shared(outType, input_shapes[1]); - auto indata0 = std::make_shared(inType, data0->get_shape()); - auto indata1 = std::make_shared(outType, data1->get_shape()); - auto convert = std::make_shared(indata0, outType); - auto subgraph = std::make_shared(NodeVector{data0, data1}, - std::make_shared( - NodeVector{std::make_shared(convert, indata1)}, - ParameterVector{indata0, indata1})); + auto subgraph = std::make_shared(NodeVector{data0, data1}, getOriginal()); return std::make_shared(NodeVector{subgraph}, ParameterVector{data0, data1}); } @@ -61,14 +52,7 @@ std::shared_ptr ConvertOutputFunction::initOriginal() const { std::shared_ptr ConvertOutputFunction::initReference() const { auto data0 = std::make_shared(inType, input_shapes[0]); auto data1 = std::make_shared(inType, input_shapes[1]); - auto indata0 = std::make_shared(inType, data0->get_shape()); - auto indata1 = std::make_shared(inType, data1->get_shape()); - auto add = std::make_shared(indata0, indata1); - auto convert = std::make_shared(add, outType); - auto subgraph = std::make_shared(NodeVector{data0, data1}, - std::make_shared( - NodeVector{convert}, - ParameterVector{indata0, indata1})); + auto subgraph = std::make_shared(NodeVector{data0, data1}, getOriginal()); return std::make_shared(NodeVector{subgraph}, ParameterVector{data0, data1}); } @@ -132,15 +116,7 @@ std::shared_ptr ConvertManyOnInputsFunction::initOriginal() const { } std::shared_ptr ConvertManyOnInputsFunction::initReference() const { auto data0 = std::make_shared(types[0], input_shapes[0]); - auto indata0 = std::make_shared(types[0], data0->get_shape()); - std::shared_ptr out = indata0; - for (auto i = 1; i < types.size(); i++) { - auto convert = std::make_shared(out, types[i]); - out = convert; - } - auto relu = std::make_shared(out); - auto subgraph = std::make_shared(NodeVector{data0}, - std::make_shared(NodeVector{relu}, ParameterVector{indata0})); + auto subgraph = std::make_shared(NodeVector{data0}, getOriginal()); return std::make_shared(NodeVector{subgraph}, ParameterVector{data0}); } @@ -156,15 +132,7 @@ std::shared_ptr ConvertManyOnOutputsFunction::initOriginal() const { } std::shared_ptr ConvertManyOnOutputsFunction::initReference() const { auto data0 = std::make_shared(types[0], input_shapes[0]); - auto indata0 = std::make_shared(types[0], data0->get_shape()); - auto relu = std::make_shared(indata0); - std::shared_ptr out = relu; - for (auto i = 1; i < types.size(); i++) { - auto convert = std::make_shared(out, types[i]); - out = convert; - } - auto subgraph = std::make_shared(NodeVector{data0}, - std::make_shared(NodeVector{out}, ParameterVector{indata0})); + auto subgraph = std::make_shared(NodeVector{data0}, getOriginal()); return std::make_shared(NodeVector{subgraph}, ParameterVector{data0}); } @@ -185,20 +153,7 @@ std::shared_ptr ConvertManyOnInputOutputFunction::initOriginal() cons } std::shared_ptr ConvertManyOnInputOutputFunction::initReference() const { auto data0 = std::make_shared(inTypes[0], input_shapes[0]); - auto indata0 = std::make_shared(inTypes[0], data0->get_shape()); - std::shared_ptr out = indata0; - for (auto i = 1; i < inTypes.size(); i++) { - auto convert = std::make_shared(out, inTypes[i]); - out = convert; - } - auto relu = std::make_shared(data0); - out = relu; - for (auto i = 0; i < outTypes.size(); i++) { - auto convert = std::make_shared(out, outTypes[i]); - out = convert; - } - auto subgraph = std::make_shared(NodeVector{data0}, - std::make_shared(NodeVector{out}, ParameterVector{indata0})); + auto subgraph = std::make_shared(NodeVector{data0}, getOriginal()); return std::make_shared(NodeVector{subgraph}, ParameterVector{data0}); } } // namespace snippets diff --git a/src/tests/ov_helpers/ov_snippets_models/src/subgraph_simple.cpp b/src/tests/ov_helpers/ov_snippets_models/src/subgraph_simple.cpp index 1326864f89541a..12758a90c07652 100644 --- a/src/tests/ov_helpers/ov_snippets_models/src/subgraph_simple.cpp +++ b/src/tests/ov_helpers/ov_snippets_models/src/subgraph_simple.cpp @@ -20,11 +20,7 @@ std::shared_ptr AddFunction::initOriginal() const { std::shared_ptr AddFunction::initReference() const { auto data0 = std::make_shared(precision, input_shapes[0]); auto data1 = std::make_shared(precision, input_shapes[1]); - auto indata0 = std::make_shared(precision, data0->get_shape()); - auto indata1 = std::make_shared(precision, data1->get_shape()); - auto add = std::make_shared(NodeVector{data0, data1}, - std::make_shared(NodeVector{std::make_shared(indata0, indata1)}, - ParameterVector{indata0, indata1})); + auto add = std::make_shared(NodeVector{data0, data1}, getOriginal()); return std::make_shared(NodeVector{add}, ParameterVector{data0, data1}); } std::shared_ptr ExpFunction::initOriginal() const {