Skip to content

Commit

Permalink
Yi3/shape inference 3rd batch (#8611)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangYiIntel committed Dec 15, 2021
1 parent e8d5cf4 commit 4fba88d
Show file tree
Hide file tree
Showing 18 changed files with 554 additions and 254 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "shape_inference.hpp"

#include <ngraph/runtime/host_tensor.hpp>
Expand All @@ -17,12 +16,16 @@
#include "convolution_shape_inference.hpp"
#include "experimental_detectron_detection_output_shape_inference.hpp"
#include "experimental_detectron_prior_grid_generator_shape_inference.hpp"
#include "experimental_detectron_topkrois_shape_inference.hpp"
#include "fake_quantize.hpp"
#include "gather_elements_shape_inference.hpp"
#include "gather_shape_inference.hpp"
#include "gather_tree_shape_inference.hpp"
#include "interpolate_shape_inference.hpp"
#include "lstm_cell_shape_inference.hpp"
#include "one_hot_shape_inference.hpp"
#include "read_value_shape_inference.hpp"
#include "reduce_shape_inference.hpp"
#include "experimental_detectron_topkrois_shape_inference.hpp"
#include "interpolate_shape_inference.hpp"
#include "scatter_elements_update_shape_inference.hpp"
#include "scatter_nd_base_shape_inference.hpp"
#include "shape_inference.hpp"
Expand Down Expand Up @@ -129,6 +132,14 @@ void shape_inference(ov::Node* op,
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset4::ScatterNDUpdate>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset6::GatherElements>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::op::util::GatherBase>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else if (auto node = ov::as_type<ov::opset1::GatherTree>(op)) {
shape_infer(node, input_shapes, output_shapes);
} else if (auto node = ov::as_type<ov::opset1::OneHot>(op)) {
shape_infer(node, input_shapes, output_shapes, constant_data);
} else {
ngraph::OutputVector new_inputs;
for (size_t i = 0; i < op->get_input_size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>

#include <gather_elements_shape_inference.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>

using namespace ov;

TEST(StaticShapeInferenceTest, GatherElementsTest) {
int64_t axis = -1;
auto D = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
auto GE = std::make_shared<op::v6::GatherElements>(D, I, axis);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{300, 3, 10, 1}, StaticShape{300, 3, 10, 33333}},
static_output_shapes = {StaticShape{}};
shape_inference(GE.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], (StaticShape{300, 3, 10, 33333}));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <gather_shape_inference.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>

using namespace ov;

TEST(StaticShapeInferenceTest, GatherV1Test) {
auto P = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1});
auto A = op::v0::Constant::create(element::i64, Shape{}, {0});
auto G = std::make_shared<op::v1::Gather>(P, I, A);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 2}, StaticShape{2, 2}, StaticShape{1}},
static_output_shapes = {StaticShape{}};
shape_inference(G.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], (StaticShape{2, 2, 2}));
}

TEST(StaticShapeInferenceTest, GatherV1TestNonConstantA) {
auto P = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1});
auto A = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
auto G = std::make_shared<op::v1::Gather>(P, I, A);
auto hostTensor = std::make_shared<HostTensor>(element::i32, Shape{});
int32_t val_a = 1;
hostTensor->write(&val_a, sizeof(int32_t));
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 2}, StaticShape{2, 2}, StaticShape{}},
static_output_shapes = {StaticShape{}};
shape_inference(G.get(), static_input_shapes, static_output_shapes, {{2, hostTensor}});
ASSERT_EQ(static_output_shapes[0], (StaticShape{3, 2, 2}));
}

TEST(StaticShapeInferenceTest, GatherV7Test) {
auto P = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1});
auto A = op::v0::Constant::create(element::i64, Shape{}, {0});
auto G = std::make_shared<op::v7::Gather>(P, I, A);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 2}, StaticShape{2, 2}, StaticShape{1}},
static_output_shapes = {StaticShape{}};
shape_inference(G.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], (StaticShape{2, 2, 2}));
}

TEST(StaticShapeInferenceTest, GatherV7TestNonConstantA) {
auto P = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1});
auto A = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
auto G = std::make_shared<op::v7::Gather>(P, I, A);
auto hostTensor = std::make_shared<HostTensor>(element::i32, Shape{});
int32_t val_a = 0;
hostTensor->write(&val_a, sizeof(int32_t));
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 2}, StaticShape{2, 2}, StaticShape{}},
static_output_shapes = {StaticShape{}};
shape_inference(G.get(), static_input_shapes, static_output_shapes, {{2, hostTensor}});
ASSERT_EQ(static_output_shapes[0], (StaticShape{2, 2, 2}));
}

TEST(StaticShapeInferenceTest, GatherV8Test) {
auto P = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1});
auto A = op::v0::Constant::create(element::i64, Shape{}, {0});
auto G = std::make_shared<op::v8::Gather>(P, I, A);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 2}, StaticShape{2, 2}, StaticShape{1}},
static_output_shapes = {StaticShape{}};
shape_inference(G.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], (StaticShape{2, 2, 2}));
}

TEST(StaticShapeInferenceTest, GatherV8TestNonConstantA) {
auto P = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
auto I = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{-1, -1});
auto A = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
auto G = std::make_shared<op::v8::Gather>(P, I, A);
auto hostTensor = std::make_shared<HostTensor>(element::i32, Shape{});
int32_t val_a = 0;
hostTensor->write(&val_a, sizeof(int32_t));
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 2}, StaticShape{2, 2}, StaticShape{}},
static_output_shapes = {StaticShape{}};
shape_inference(G.get(), static_input_shapes, static_output_shapes, {{2, hostTensor}});
ASSERT_EQ(static_output_shapes[0], (StaticShape{2, 2, 2}));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <gather_tree_shape_inference.hpp>
#include <openvino/op/gather_tree.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>

using namespace ov;

TEST(StaticShapeInferenceTest, GatherTreeTest) {
auto step_ids = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto parent_idx = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
auto end_token = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{Shape{}});
auto gather_tree = std::make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 2, 3},
StaticShape{1, 2, 3},
StaticShape{2},
StaticShape{}},
static_output_shapes = {StaticShape{}};
shape_inference(gather_tree.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], (StaticShape{1, 2, 3}));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>

#include <one_hot_shape_inference.hpp>
#include <openvino/core/coordinate_diff.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>

using namespace ov;

TEST(StaticShapeInferenceTest, OneHotTest) {
auto indices = std::make_shared<op::v0::Parameter>(element::i64, PartialShape{-1});
auto depth = op::v0::Constant::create(element::i64, Shape{}, {2});
auto on_value = op::v0::Constant::create(element::u32, Shape{}, {5});
auto off_value = op::v0::Constant::create(element::u32, Shape{}, {10});
int64_t axis = -1;
auto ont_hot = std::make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{3}, StaticShape{}, StaticShape{}, StaticShape{}},
static_output_shapes = {StaticShape{}};
shape_inference(ont_hot.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], (StaticShape{3, 2}));
}
4 changes: 4 additions & 0 deletions src/core/include/openvino/op/gather_elements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class OPENVINO_API GatherElements : public Op {

private:
int64_t m_axis;
template <class T>
void friend shape_infer(const GatherElements* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes);
};
} // namespace v6
} // namespace op
Expand Down
9 changes: 5 additions & 4 deletions src/core/include/openvino/op/one_hot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,16 @@ class OPENVINO_API OneHot : public Op {
bool has_evaluate() const override;

/// \return The index of the one-hot axis.
int64_t get_axis() const {
const int64_t& get_axis() const {
return m_axis;
}
void set_axis(int64_t axis) {
m_axis = axis;
}
void set_axis(int64_t axis);

protected:
int64_t m_axis;

private:
friend void inline resolve_axis(OneHot* op);
};
} // namespace v1
} // namespace op
Expand Down
1 change: 1 addition & 0 deletions src/core/include/openvino/op/util/gather_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class OPENVINO_API GatherBase : public Op {
OPENVINO_SUPPRESS_DEPRECATED_END

bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
const int64_t& get_batch_dims() const;

protected:
int64_t m_batch_dims = 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/op/gather_elements.hpp>

#include "utils.hpp"

namespace ov {
namespace op {
namespace v6 {
template <class T>
void shape_infer(const GatherElements* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1);
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;

const auto& data_pshape = input_shapes[0];
const auto& indices_pshape = input_shapes[1];
auto data_rank = data_pshape.rank();
auto indices_rank = indices_pshape.rank();
auto& output_shape = output_shapes[0];

int64_t axis = op->m_axis;
if (data_rank.is_static())
axis = ov::normalize_axis(op, axis, data_rank);

output_shape = indices_pshape;

NODE_VALIDATION_CHECK(op, data_rank.is_dynamic() || data_rank.get_length() >= 1, "data rank must be >= 1.");

NODE_VALIDATION_CHECK(op,
indices_rank.is_dynamic() || indices_rank.get_length() >= 1,
"indices rank must be >= 1.");

if (data_rank.is_static() && indices_rank.is_dynamic()) {
// output has the same rank of data
output_shape = data_pshape;
output_shape[axis] = DimType();
return;
}

if (data_rank.is_dynamic()) {
// can't decide rank, set it to all dynamic
if (indices_rank.is_dynamic())
output_shape = PartialShape::dynamic();
return;
}

// left only case when data_rank.is_static() && indices_rank.is_static()
NODE_VALIDATION_CHECK(op,
data_rank.get_length() == indices_rank.get_length(),
"data and indices rank must be equal. But instead got: ",
data_rank.get_length(),
" and ",
indices_rank.get_length());

for (int i = 0; i < indices_rank.get_length(); i++) {
if (i != axis) {
// if size of the current dimension of indices is unknown it will be retrieved from data
// e.g., if data_shape = {4, 4, ?}, indices_shape = {1, ?, 5} and axis = 0
// (and if intervals intersect) then output_pshape will be {1, 4, 5}

NODE_VALIDATION_CHECK(op,
data_pshape[i].compatible(indices_pshape[i]),
"Shapes ",
data_pshape,
" and ",
indices_pshape,
" are not consistent. data and indices must have equal or "
"intersecting sizes, except for axis ",
axis);

output_shape[i] = data_pshape[i] & indices_pshape[i];
}
}
}
} // namespace v6
} // namespace op
} // namespace ov
Loading

0 comments on commit 4fba88d

Please sign in to comment.