[Transformation] SpaceToDepthFusion
Transform StridedSlice chains + Concat in yolov5 into SpaceToDepth

Signed-off-by: Li, Tingqian <tingqian.li@intel.com>
usstq committed Sep 16, 2021
1 parent 7b370a7 commit 105373e
Showing 5 changed files with 410 additions and 220 deletions.
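
Editorial note: the diff below replaces yolov5's four stride-2 StridedSlice branches plus channel Concat with a single SpaceToDepth node. For orientation, here is a minimal sketch of the node the pass emits; it is an illustration, not code from the commit, and the function name make_space_to_depth and the `input` argument are invented for the example.

// Illustrative sketch only (editorial addition, not part of the commit): how a
// single SpaceToDepth node reproduces the four stride-2 slices + channel Concat
// that the pass removes.
#include <memory>
#include <ngraph/opsets/opset8.hpp>

std::shared_ptr<ngraph::Node> make_space_to_depth(const ngraph::Output<ngraph::Node>& input) {
    const std::size_t block_size = 2;  // yolov5's Focus block slices with stride 2
    // For a [N, C, H, W] input this produces [N, C * 4, H / 2, W / 2], the same
    // shape as concatenating the four stride-2 slices on axis 1 (the channel
    // order differs, which is why the pass may also permute consumers' weights).
    return std::make_shared<ngraph::opset8::SpaceToDepth>(
        input,
        ngraph::opset8::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST,
        block_size);
}
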
@@ -28,10 +28,8 @@ class TRANSFORMATIONS_API SpaceToDepthFusion;
* +---> StridedSlice -> StridedSlice ----+
* +---> StridedSlice -> StridedSlice ----+
*
* to SpaceToDepth
* with SpaceToDepth when applicable.
*
* Restrictions:
* - input rank must be 4
*/

class ngraph::pass::SpaceToDepthFusion: public ngraph::pass::MatcherPass {
@@ -6,7 +6,7 @@

#include <limits>
#include <memory>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <numeric>
@@ -18,25 +18,79 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::SpaceToDepthFusion, "SpaceToDepthFusion", 0

using namespace ngraph;

const auto end_max = std::numeric_limits<int64_t>::max();
static const auto end_max = std::numeric_limits<int64_t>::max();

struct SliceSyntax {
struct SliceSemantics {
std::vector<int64_t> begin;
std::vector<int64_t> end;
std::vector<int64_t> stride;
bool b_valid = false;

SliceSyntax() = default;
SliceSemantics() = default;

SliceSemantics(std::shared_ptr<ngraph::opset8::StridedSlice> ss) : b_valid(false) {
Shape in_shape_max;

const auto& new_axis_mask = ss->get_new_axis_mask();
const auto& shrink_axis_mask = ss->get_shrink_axis_mask();
const auto& ellipsis_mask = ss->get_ellipsis_mask();

// no new, shrink or ellipsis axes are allowed
if (std::find(new_axis_mask.begin(), new_axis_mask.end(), 1) != new_axis_mask.end() ||
std::find(shrink_axis_mask.begin(), shrink_axis_mask.end(), 1) != shrink_axis_mask.end() ||
std::find(ellipsis_mask.begin(), ellipsis_mask.end(), 1) != ellipsis_mask.end())
return;

auto get_masked_input = [&](int input_id, std::vector<int64_t> mask, int64_t masked_value) {
std::vector<int64_t> ret;
auto input =
std::dynamic_pointer_cast<ngraph::opset8::Constant>(ss->input_value(input_id).get_node_shared_ptr());
if (!input)
return ret;

ret = input->cast_vector<int64_t>();

for (size_t k = 0; k < mask.size(); k++) {
if (mask[k] == 1)
ret[k] = masked_value;
}
return ret;
};

begin = get_masked_input(1, ss->get_begin_mask(), 0);
end = get_masked_input(2, ss->get_end_mask(), end_max);

const auto& pshape = ss->input_value(0).get_partial_shape();
if (pshape.is_static()) {
// use end_max to indicate that the whole range is selected
const auto static_shape = pshape.get_shape();
for (size_t k = 0; k < static_shape.size() && k < end.size(); k++) {
if (end[k] >= static_cast<int64_t>(static_shape[k]))
end[k] = end_max;
}
}

stride.resize(begin.size(), 1);
if (ss->get_input_size() >= 4) {
auto input = std::dynamic_pointer_cast<ngraph::opset8::Constant>(ss->input_value(3).get_node_shared_ptr());
if (input)
stride = input->cast_vector<int64_t>();
}
b_valid = true;
}

operator bool() const {
return begin.size() > 0 && end.size() > 0 && stride.size() > 0;
return b_valid;
}

/*
A -> StridedSlice1 -> B -> StridedSlice2 -> C
<=>
A -> StridedSlice3 -> C
Fusion of two consecutive StridedSlices can be done under certain conditions:
A -> StridedSlice1 -> B -> StridedSlice2 -> C
<=>
A -> StridedSlice3 -> C
for 1 particular dimension
for 1 particular dimension:
StridedSlice1 (b1,e1,s1): B[i]=A[i*s1+b1] for i*s1+b1<e1
StridedSlice2 (b2,e2,s2): C[i]=B[i*s2+b2] for i*s2+b2<e2
@@ -49,7 +103,7 @@ struct SliceSyntax {
b3 = b1 + b2*s1
e3 = MIN(e1, e2*s1+b1)
*/
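// Editorial worked example (not from the original source), checking the formula
// above: let StridedSlice1 have (b1=0, e1=end_max, s1=2) and StridedSlice2 have
// (b2=1, e2=end_max, s2=1). Then B[i] = A[2*i] and C[i] = B[i+1] = A[2*i + 2].
// The fused parameters give the same mapping:
//   s3 = s1*s2            = 2
//   b3 = b1 + b2*s1       = 2
//   e3 = MIN(e1, e2*s1 + b1), which remains end_max since the whole range is kept
// i.e. C[i] = A[2*i + 2].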
void fuse_with(const SliceSyntax& s2) {
void fuse_with(const SliceSemantics& s2) {
auto rank = s2.begin.size();

// expand rank to match s2
@@ -71,96 +125,23 @@
this->begin[i] = new_begin;
this->end[i] = new_end;
}
}
};

static SliceSyntax get_syntax(std::shared_ptr<ngraph::opset7::StridedSlice> ss) {
SliceSyntax s;
int rank;
Shape in_shape_max;

rank = ss->input_value(0).get_partial_shape().rank().get_length();

if (ss->input_value(0).get_partial_shape().is_static()) {
in_shape_max = ss->input_value(0).get_shape();
} else {
in_shape_max = Shape(rank, end_max);
}

const auto& new_axis_mask = ss->get_new_axis_mask();
const auto& shrink_axis_mask = ss->get_shrink_axis_mask();
const auto& ellipsis_mask = ss->get_ellipsis_mask();

// no new, deleted or ellipsis axis is allowed
for (auto& v : new_axis_mask) {
if (v == 1)
return s;
}
for (auto& v : shrink_axis_mask) {
if (v == 1)
return s;
}
for (auto& v : ellipsis_mask) {
if (v == 1)
return s;
}

auto get_masked_input = [&](int input_id, std::vector<int64_t> mask, int64_t masked_value) {
std::vector<int64_t> ret;
auto input =
std::dynamic_pointer_cast<ngraph::opset7::Constant>(ss->input_value(input_id).get_node_shared_ptr());
if (!input)
return ret;

ret = input->cast_vector<int64_t>();

for (size_t k = 0; k < mask.size(); k++) {
if (mask[k] == 1)
ret[k] = masked_value;
}
return ret;
};

s.begin = get_masked_input(1, ss->get_begin_mask(), 0);
s.end = get_masked_input(2, ss->get_end_mask(), end_max);
for (size_t k = 0; k < in_shape_max.size(); k++) {
if (s.end[k] >= static_cast<int64_t>(in_shape_max[k]))
s.end[k] = end_max;
}

s.stride.resize(s.begin.size(), 1);
if (ss->get_input_size() >= 4) {
auto input = std::dynamic_pointer_cast<ngraph::opset7::Constant>(ss->input_value(3).get_node_shared_ptr());
if (input)
s.stride = input->cast_vector<int64_t>();
}

return s;
}
};

ngraph::pass::SpaceToDepthFusion::SpaceToDepthFusion() {
MATCHER_SCOPE(SpaceToDepthFusion);

const char* env_p = ::getenv("CROSS_CHECK_TOOL");
const int cross_check_tool = env_p ? std::stol(env_p) : -1;

if (cross_check_tool == 0) {
printf("[%s]: cross_check_tool=%d, skipping.\n", __func__, cross_check_tool);
return;
} else {
printf("[%s]: cross_check_tool=%d, enabled.\n", __func__, cross_check_tool);
}

auto concat_pattern = pattern::wrap_type<opset7::Concat>({}, [](const Output<Node>& value) {
auto concat = std::dynamic_pointer_cast<opset7::Concat>(value.get_node_shared_ptr());
auto concat_pattern = pattern::wrap_type<opset8::Concat>({}, [](const Output<Node>& value) {
auto concat = std::dynamic_pointer_cast<opset8::Concat>(value.get_node_shared_ptr());
if (!concat)
return false;
return concat->get_axis() == 1;
});

ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto concat = std::dynamic_pointer_cast<opset7::Concat>(pattern_map.at(concat_pattern).get_node_shared_ptr());
auto concat = std::dynamic_pointer_cast<opset8::Concat>(m.get_match_root());
if (!concat)
return false;

@@ -175,22 +156,25 @@ ngraph::pass::SpaceToDepthFusion::SpaceToDepthFusion() {
Output<Node> common_input;

for (int i = 0; i < slice_cnt; i++) {
SliceSyntax slice_syntax;
SliceSemantics slice_semantics;
auto input = concat->get_input_source_output(i);
auto ss = std::dynamic_pointer_cast<opset7::StridedSlice>(input.get_node_shared_ptr());
auto ss = std::dynamic_pointer_cast<opset8::StridedSlice>(input.get_node_shared_ptr());
while (ss) {
nodes_to_delete.push_back(ss);

auto syntax = get_syntax(ss);
if (!syntax)
SliceSemantics semantics(ss);
if (!semantics)
return false;

slice_syntax.fuse_with(syntax);
slice_semantics.fuse_with(semantics);
input = ss->input_value(0);

ss = std::dynamic_pointer_cast<opset7::StridedSlice>(input.get_node_shared_ptr());
ss = std::dynamic_pointer_cast<opset8::StridedSlice>(input.get_node_shared_ptr());
}

if (!slice_semantics)
return false;

// all concatenated paths must originate from the same input
if (!common_input.get_node_shared_ptr())
common_input = input;
@@ -199,24 +183,28 @@ ngraph::pass::SpaceToDepthFusion::SpaceToDepthFusion() {
return false;

if (rank == 0)
rank = slice_syntax.stride.size();
rank = slice_semantics.stride.size();

if (rank == 0)
return false;

if (static_cast<int>(slice_syntax.stride.size()) != rank)
if (static_cast<int>(slice_semantics.stride.size()) != rank)
return false;

// [N, C, D1, D2, ...]
for (size_t k = 0; k < 2; k++) {
if (slice_syntax.stride[k] != 1 || slice_syntax.begin[k] != 0 || slice_syntax.end[k] < end_max)
if (slice_semantics.stride[k] != 1 || slice_semantics.begin[k] != 0 || slice_semantics.end[k] < end_max)
return false;
}

// check block size consistency
// do:
// - block size consistency check
// - slice count consistency check
// - begin/stride/end validation
// - slice order calculation
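// Editorial worked example (illustrative assumption, not from the commit): for a
// 4D input and block_size = 2 the expected branch count is presumably
// block_size^(rank-2) = 4. The four branches cover the fused (begin_H, begin_W)
// pairs (0,0), (0,1), (1,0) and (1,1) in some order, each with stride 2 and
// end = end_max, and the loop below encodes each pair as
// slice_order[i] = begin_H * block_size + begin_W, i.e. the values 0..3.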
for (int k = 2; k < rank; k++) {
if (block_size == 0) {
block_size = slice_syntax.stride[k];
block_size = slice_semantics.stride[k];
if (block_size < 2)
return false;

@@ -227,79 +215,69 @@
if (slice_expected != slice_cnt)
return false;
}
if (slice_syntax.stride[k] != block_size)
if (slice_semantics.begin[k] >= block_size)
return false;
if (slice_syntax.end[k] < end_max)
if (slice_semantics.stride[k] != block_size)
return false;
if (slice_semantics.end[k] < end_max)
return false;

slice_order[i] = slice_order[i] * block_size + slice_syntax.begin[k];
slice_order[i] = slice_order[i] * block_size + slice_semantics.begin[k];
}

if (slice_order[i] != i)
is_ordered = false;

if (slice_order[i] >= slice_cnt) {
printf("ERROR slice_order[i]=%d\n", slice_order[i]);
return false;
}
slice_from_order[slice_order[i]] = i;
}

if (is_ordered) {
std::shared_ptr<Node> new_root =
register_new_node<opset7::SpaceToDepth>(common_input,
opset7::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST,
block_size);
std::make_shared<opset8::SpaceToDepth>(common_input,
opset8::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST,
block_size);

new_root->set_friendly_name(concat->get_friendly_name());
copy_runtime_info(nodes_to_delete, new_root);
replace_node(m.get_match_root(), new_root);
} else {
// if output is connected to a Convolution node, channel re-order can be further fused
// into weights
bool b_further_opt = true;
// if the output is consumed by Convolution nodes only, the channel
// re-order can be further fused into their weights
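// Editorial note (not part of the commit): SpaceToDepth in BLOCKS_FIRST mode
// emits the interleaved channels in its own fixed order, which generally differs
// from the order of the original Concat branches. A Convolution is insensitive
// to a permutation of its input channels as long as the input-channel dimension
// (axis 1) of its constant weights is permuted the same way, so rather than
// inserting an extra reorder on the activations, the code below pushes the
// permutation into each consumer's weights.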
for (auto input_to : concat->get_default_output().get_target_inputs()) {
auto conv = std::dynamic_pointer_cast<opset7::Convolution>(input_to.get_node()->shared_from_this());
if (!conv) {
b_further_opt = false;
break;
}
auto filters = std::dynamic_pointer_cast<opset7::Constant>(conv->get_input_node_shared_ptr(1));
if (!filters) {
b_further_opt = false;
break;
}
}
auto conv = std::dynamic_pointer_cast<opset8::Convolution>(input_to.get_node()->shared_from_this());
if (!conv)
return false;

if (!b_further_opt)
return false;
auto filters = std::dynamic_pointer_cast<opset8::Constant>(conv->get_input_node_shared_ptr(1));
if (!filters)
return false;
}

std::shared_ptr<Node> new_root =
register_new_node<opset7::SpaceToDepth>(common_input,
opset7::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST,
block_size);
std::make_shared<opset8::SpaceToDepth>(common_input,
opset8::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST,
block_size);

new_root->set_friendly_name(concat->get_friendly_name());
copy_runtime_info(nodes_to_delete, new_root);

// add split & concat to Convolution's weights, const-folding will eliminate them later
// add Split & Concat to reorder the channels of the Convolution's weights;
// a later constant-folding pass will eliminate them.
for (auto input_to : concat->get_default_output().get_target_inputs()) {
auto conv = std::dynamic_pointer_cast<opset7::Convolution>(input_to.get_node()->shared_from_this());
auto filters = std::dynamic_pointer_cast<opset7::Constant>(conv->get_input_node_shared_ptr(1));
auto conv = std::dynamic_pointer_cast<opset8::Convolution>(input_to.get_node()->shared_from_this());
auto filters = std::dynamic_pointer_cast<opset8::Constant>(conv->get_input_node_shared_ptr(1));

// filters are ordered by slice-order, now re-order them
auto axis = register_new_node<opset7::Constant>(element::i32, Shape{}, std::vector<int32_t>{1});
auto split = register_new_node<opset7::Split>(filters, axis, slice_cnt);
auto axis = std::make_shared<opset8::Constant>(element::i32, Shape{}, std::vector<int32_t>{1});
auto split = std::make_shared<opset8::Split>(filters, axis, slice_cnt);
OutputVector reorder;
for (int i = 0; i < slice_cnt; i++)
reorder.push_back(split->output(slice_from_order[i]));
auto new_filter = register_new_node<opset7::Concat>(reorder, 1);
replace_node(filters, new_filter);
}
auto new_filter = std::make_shared<opset8::Concat>(reorder, 1);

conv->set_argument(1, new_filter->get_default_output());
}
replace_node(m.get_match_root(), new_root);
}

return true;
};
