[CPU]Fix qwen rope fusion with special_zero (#24727)
### Details:
 - *Extend Qwen Rope Fusion*
 - PR to master branch: #24750

### Tickets:
 - *CVS-142523*
zhangYiIntel committed May 29, 2024
1 parent a986ce2 commit 3af803b
Showing 2 changed files with 51 additions and 11 deletions.
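
The fix hinges on the semantics of `Reshape` with `special_zero = true`: a zero in the target shape means "copy that dimension from the input", so a constant such as `{0, 0, 32, 2, 64}` yields the same `[?, ?, 32, 2, 64]` layout that the fusion previously matched only via a dynamically assembled shape. Below is a minimal, self-contained sketch of these semantics (illustrative only, not code from this patch):

```cpp
#include <iostream>
#include <openvino/openvino.hpp>
#include <openvino/opsets/opset1.hpp>

// Sketch of Reshape's special_zero semantics: with special_zero = true, a 0 in
// the target shape copies the corresponding dimension from the input, so
// {0, 0, 32, 2, 64} keeps the dynamic batch and sequence dimensions and splits
// the trailing 4096 channels into 32 x 2 x 64.
int main() {
    auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f32,
                                                        ov::PartialShape{-1, -1, 4096});
    auto target = ov::opset1::Constant::create(ov::element::i32, ov::Shape{5}, {0, 0, 32, 2, 64});
    auto reshape = std::make_shared<ov::opset1::Reshape>(data, target, /*special_zero=*/true);
    std::cout << reshape->get_output_partial_shape(0) << std::endl;  // prints [?,?,32,2,64]
    return 0;
}
```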
@@ -20,6 +20,7 @@
 #include "ov_ops/type_relaxed.hpp"
 #include "transformations/cpu_opset/common/op/rope.hpp"
 #include "utils/gen_pattern.hpp"
+#include "utils/general_utils.h"
 
 using namespace ov::gen_pattern;
 
@@ -689,13 +690,12 @@ ov::intel_cpu::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
                                             {{"special_zero", false}});  // tensor_array<f32[?,?,32,2,64]>
     };
 
-    auto reshape_opt2 = [&](std::shared_ptr<Node> input_BLHS) {
-        return makePattern<opset1::Reshape>({input_BLHS, {0, 0, 0, 2, head_size / 2}},
-                                            {{"special_zero", true}});  // tensor_array<f32[?,?,32,2,64]>
-    };
+    // If special_zero is set, const_shape must be validated later, in the matcher callback
+    auto const_shape = makePattern<opset1::Constant>({}, {});
+    auto reshape_special = makePattern<opset1::Reshape>({slice_Slice_543, const_shape}, {{"special_zero", true}});
 
     auto ListUnpack_586_Split =
-        makePattern<opset1::Split>({reshape_opt1(slice_Slice_543) | reshape_opt2(slice_Slice_543), -2},
+        makePattern<opset1::Split>({reshape_opt1(slice_Slice_543) | reshape_special, -2},
                                    {{"num_splits", 2}});  // tensor_array<f32[?,?,32,1,64] f32[?,?,32,1,64]>
     ListUnpack_586_Split->set_output_size(2);
     auto Multiply_567527 =
@@ -746,6 +746,29 @@ ov::intel_cpu::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
         config.slice_stop = config.slice_start + config.head_cnt * config.head_size;
     }
 
+    if (pattern_map.count(reshape_special)) {
+        // check the correctness of the constant shape supplied to reshape_special
+        auto reshape_special_node = pattern_map.at(reshape_special).get_node_shared_ptr();
+        auto data_shape = reshape_special_node->get_input_partial_shape(0);
+        auto reshape_shape = pattern_map.at(const_shape);
+        auto node = ov::as_type_ptr<opset1::Constant>(reshape_shape.get_node_shared_ptr());
+        const auto& target = node->cast_vector<int32_t>();
+        // ensure the target shape has sufficient rank
+        if (target.size() < 3) {
+            return false;
+        }
+        int32_t head_size = config.head_size;
+        int32_t head_cnt = config.head_cnt;
+        // the reshape splits the head_size of the input into [2, head_size / 2];
+        // the head_cnt entry of the target shape may be either 0 or head_cnt
+        size_t target_rank = target.size();
+        bool is_ok = (target[target_rank - 1] == head_size / 2) && (target[target_rank - 2] == 2) &&
+                     (ov::intel_cpu::one_of(target[target_rank - 3], 0, head_cnt));
+        if (!is_ok) {
+            return false;
+        }
+    }
+
     new_args.push_back(pattern_map.at(qkv_proj));
     new_args.push_back(pattern_map.at(rotary_emb_cos));
     new_args.push_back(pattern_map.at(rotary_emb_sin));
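The essence of the new matcher-callback check can be restated as a standalone predicate. This is an illustrative sketch, not code from the patch: `qwen_reshape_target_ok` is a hypothetical helper, with the `one_of` check from `utils/general_utils.h` inlined so the snippet compiles on its own:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical restatement of the check added above: when the pattern matched
// the special_zero Reshape, the trailing entries of its constant target shape
// must be [head_cnt (or 0, i.e. "copy from input"), 2, head_size / 2].
static bool qwen_reshape_target_ok(const std::vector<int32_t>& target,
                                   int32_t head_cnt,
                                   int32_t head_size) {
    if (target.size() < 3)
        return false;  // need at least [head_cnt, 2, head_size / 2]
    const size_t rank = target.size();
    return target[rank - 1] == head_size / 2 && target[rank - 2] == 2 &&
           (target[rank - 3] == 0 || target[rank - 3] == head_cnt);
}

int main() {
    // Qwen-7B uses head_cnt = 32, head_size = 128, so both of these pass:
    assert(qwen_reshape_target_ok({0, 0, 32, 2, 64}, 32, 128));
    assert(qwen_reshape_target_ok({0, 0, 0, 2, 64}, 32, 128));
    // ...while a shape that does not split head_size into [2, 64] is rejected:
    assert(!qwen_reshape_target_ok({0, 0, 32, 64, 2}, 32, 128));
    return 0;
}
```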
@@ -325,8 +325,14 @@ TEST_F(RoPECPUTestChatGLM, smoke_CompareWithRefs) {
     CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
 }
 
-class RoPECPUTestQwen7b : public SubgraphBaseTest {
+class RoPECPUTestQwen7b : public SubgraphBaseTest, public testing::WithParamInterface<bool> {
 public:
+    static std::string getTestCaseName(const testing::TestParamInfo<bool>& obj) {
+        const bool specialReshape = obj.param;
+        std::ostringstream result;
+        result << "specialReshape=" << specialReshape;
+        return result.str();
+    }
     void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
         const auto& funcInputs = function->inputs();
 
@@ -346,7 +352,7 @@ class RoPECPUTestQwen7b : public SubgraphBaseTest {
     }
 
 protected:
-    std::shared_ptr<ov::Model> buildROPE_QWen7b() {
+    std::shared_ptr<ov::Model> buildROPE_QWen7b(bool specialReshape) {
         auto input =
             std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, 4096 + 4096 + 4096});
         auto cos_cache = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{1, -1, 1, 128});
@@ -401,8 +407,13 @@ class RoPECPUTestQwen7b : public SubgraphBaseTest {
             makeOP<opset1::Reshape>({floor_divide_Floor, {-1}}, {{"special_zero", false}});
         auto ListConstruct_493_Concat =
             makeOP<opset1::Concat>({Gather_239390, {2}, ListConstruct_493_Reshape_3}, {{"axis", 0}});
-        auto reshape_Reshape =
-            makeOP<opset1::Reshape>({slice_Slice_470, ListConstruct_493_Concat}, {{"special_zero", false}});
+        std::shared_ptr<ov::Node> reshape_Reshape = nullptr;
+        if (specialReshape) {
+            reshape_Reshape = makeOP<opset1::Reshape>({slice_Slice_470, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
+        } else {
+            reshape_Reshape =
+                makeOP<opset1::Reshape>({slice_Slice_470, ListConstruct_493_Concat}, {{"special_zero", false}});
+        }
         auto ListUnpack_496_Split = makeOP<opset1::Split>({reshape_Reshape, -2}, {{"num_splits", 2}});
         auto ListUnpack_496_Squeeze_0 = makeOP<opset1::Squeeze>({ListUnpack_496_Split->output(1), -2});
         auto Constant_296840_compressed = makeConst(element::f16,
@@ -444,19 +455,25 @@ class RoPECPUTestQwen7b : public SubgraphBaseTest {
     }
     void SetUp() override {
         targetDevice = ov::test::utils::DEVICE_CPU;
+        const bool specialReshape = this->GetParam();
         const int batch = 2;
         const int seq_length = 7;
         InputShape inpShape = {{batch, -1, 4096 + 4096 + 4096}, {{batch, seq_length, 4096 + 4096 + 4096}}};
         init_input_shapes({inpShape});
-        function = buildROPE_QWen7b();
+        function = buildROPE_QWen7b(specialReshape);
     }
 };
 
-TEST_F(RoPECPUTestQwen7b, smoke_CompareWithRefs) {
+TEST_P(RoPECPUTestQwen7b, smoke_CompareWithRefs) {
     run();
     CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
 }
 
+INSTANTIATE_TEST_SUITE_P(smoke_RoPECPUTestQwen7b,
+                         RoPECPUTestQwen7b,
+                         ::testing::Values(true, false),
+                         RoPECPUTestQwen7b::getTestCaseName);
+
 class RoPECPUTestGPTJ : public SubgraphBaseTest, public testing::WithParamInterface<bool> {
 public:
     static std::string getTestCaseName(const testing::TestParamInfo<bool>& obj) {
