Skip to content

Commit

Permalink
[GPU] Enable SDPA by default (#24757)
Browse files Browse the repository at this point in the history
### Details:
 - Enabled SDPA by default
 - Added indirect inputs support (copy of #24665)
 - Updated SDPA decomposition rule to cover only well-checked cases
 - Updated functional tests accordingly
  • Loading branch information
sshlyapn committed May 29, 2024
1 parent 1d42420 commit 6a8079b
Show file tree
Hide file tree
Showing 24 changed files with 847 additions and 142 deletions.
4 changes: 4 additions & 0 deletions src/core/include/openvino/op/scaled_dot_product_attention.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class OPENVINO_API ScaledDotProductAttention : public Op {
return m_causal;
}

/// \brief Sets the causal-attention flag for this operation.
/// \param causal New value for the flag (stored in m_causal and returned by get_causal()).
void set_causal(bool causal) {
m_causal = causal;
}

private:
bool m_causal = false;
};
Expand Down
78 changes: 78 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "intel_gpu/op/sdpa.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/op.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

/// \brief GPU-plugin internal SDPA variant that carries an extra "beam table" input
///        used for indirect addressing of inputs along \c m_indirect_axis.
/// NOTE(review): the exact gather semantics of beam_table are implemented on the
/// kernel side — confirm against the plugin's SDPA primitive implementation.
class IndirectSDPA : public ov::intel_gpu::op::SDPA {
public:
OPENVINO_OP("IndirectSDPA", "gpu_opset");

IndirectSDPA() = default;

/// \brief Q/K/V + beam_table form (no explicit attention-mask or scale inputs).
/// \param indirect_axis Axis addressed indirectly through beam_table.
IndirectSDPA(const ov::Output<Node>& Q,
const ov::Output<Node>& K,
const ov::Output<Node>& V,
const ov::Output<Node>& beam_table,
const bool is_causal,
const int64_t indirect_axis,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out,
const ov::element::Type output_type = ov::element::undefined);

/// \brief Same as above, with an explicit attention-mask input preceding beam_table.
IndirectSDPA(const ov::Output<Node>& Q,
const ov::Output<Node>& K,
const ov::Output<Node>& V,
const ov::Output<Node>& attn_mask,
const ov::Output<Node>& beam_table,
const bool is_causal,
const int64_t indirect_axis,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out,
const ov::element::Type output_type = ov::element::undefined);

/// \brief Full form: attention mask and scale inputs, followed by beam_table.
IndirectSDPA(const ov::Output<Node>& Q,
const ov::Output<Node>& K,
const ov::Output<Node>& V,
const ov::Output<Node>& attn_mask,
const ov::Output<Node>& scale,
const ov::Output<Node>& beam_table,
const bool is_causal,
const int64_t indirect_axis,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor &visitor) override;
void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

/// \brief Forced output element type (inherited m_output_type; undefined = inferred).
ov::element::Type get_output_type() const { return m_output_type; }

/// \brief Axis addressed indirectly via the beam-table input; -1 when unset.
int64_t get_indirect_axis() const { return m_indirect_axis; }

using ov::intel_gpu::op::SDPA::default_order;

protected:
// Axis along which beam_table performs indirect indexing; -1 = not set.
int64_t m_indirect_axis = -1;
};

} // namespace op
} // namespace intel_gpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,4 @@ REGISTER_FACTORY(internal, IndirectGemm);
REGISTER_FACTORY(internal, Convolution);
REGISTER_FACTORY(internal, Placeholder);
REGISTER_FACTORY(internal, SDPA);
REGISTER_FACTORY(internal, IndirectSDPA);
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,31 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
scaled_dot_product_attention(const primitive_id& id,
const std::vector<cldnn::input_info> inputs,
bool is_causal,
int64_t indirect_axis = -1,
const std::vector<int64_t>& input_q_transpose_order = {},
const std::vector<int64_t>& input_k_transpose_order = {},
const std::vector<int64_t>& input_v_transpose_order = {},
const std::vector<int64_t>& output_transpose_order = {},
const padding& output_padding = padding())
: primitive_base(id, inputs, {output_padding})
, is_causal(is_causal)
, has_attn_mask_input(inputs.size() > 3)
, has_scale_input(inputs.size() > 4)
, indirect_axis(indirect_axis)
, input_q_transpose_order(input_q_transpose_order)
, input_k_transpose_order(input_k_transpose_order)
, input_v_transpose_order(input_v_transpose_order)
, output_transpose_order(output_transpose_order) {}
, output_transpose_order(output_transpose_order) {
auto data_inputs_num = inputs.size();
if (indirect_axis != -1)
data_inputs_num--;

has_attn_mask_input = data_inputs_num > 3;
has_scale_input = data_inputs_num > 4;
}

bool is_causal = false;            // apply causal attention masking
bool has_attn_mask_input = false;  // deduced in ctor from the data-input count (>3)
bool has_scale_input = false;      // deduced in ctor from the data-input count (>4)
int64_t indirect_axis = -1;        // axis for indirect (beam table) addressing; -1 = disabled

std::vector<int64_t> input_q_transpose_order;
std::vector<int64_t> input_k_transpose_order;
Expand All @@ -48,6 +55,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
seed = hash_combine(seed, is_causal);
seed = hash_combine(seed, has_attn_mask_input);
seed = hash_combine(seed, has_scale_input);
seed = hash_combine(seed, indirect_axis);
seed = hash_range(seed, input_q_transpose_order.begin(), input_q_transpose_order.end());
seed = hash_range(seed, input_k_transpose_order.begin(), input_k_transpose_order.end());
seed = hash_range(seed, input_v_transpose_order.begin(), input_v_transpose_order.end());
Expand All @@ -64,6 +72,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
return is_causal == rhs_casted.is_causal &&
has_attn_mask_input == rhs_casted.has_attn_mask_input &&
has_scale_input == rhs_casted.has_scale_input &&
indirect_axis == rhs_casted.indirect_axis &&
input_q_transpose_order == rhs_casted.input_q_transpose_order &&
input_k_transpose_order == rhs_casted.input_k_transpose_order &&
input_v_transpose_order == rhs_casted.input_v_transpose_order &&
Expand All @@ -75,6 +84,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
ob << is_causal;
ob << has_attn_mask_input;
ob << has_scale_input;
ob << indirect_axis;
ob << input_q_transpose_order;
ob << input_k_transpose_order;
ob << input_v_transpose_order;
Expand All @@ -86,6 +96,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
ib >> is_causal;
ib >> has_attn_mask_input;
ib >> has_scale_input;
ib >> indirect_axis;
ib >> input_q_transpose_order;
ib >> input_k_transpose_order;
ib >> input_v_transpose_order;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class debug_configuration {
std::vector<std::string> forced_impl_types; // Force implementation type either ocl or onednn
int max_kernels_per_batch; // Maximum number of kernels in a batch during compiling kernels
int impls_cache_capacity; // The maximum number of entries in the kernel impl cache
int enable_sdpa; // Allows to control SDPA decomposition
int disable_async_compilation; // Disable async compilation
int disable_winograd_conv; // Disable Winograd conv
int disable_dynamic_impl; // Disable dynamic implementation
Expand Down
Loading

0 comments on commit 6a8079b

Please sign in to comment.