[cherry-pick 2.3] Cherry parallel fused transformer api #43505

Merged
paddle/fluid/operators/fused/fused_attention_op.cc (18 changes: 6 additions & 12 deletions)

@@ -190,7 +190,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
// the same as QKOut's shape.
ctx->SetOutputDim("AttnDropoutOut",
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
- if (ctx->Attrs().Get<bool>("attn_dropout_is_test") == false) {
+ if (ctx->Attrs().Get<bool>("is_test") == false) {
ctx->SetOutputDim("AttnDropoutMaskOut",
{x_dim[0], y_dim[1], x_dim[1], out_seq_len});
}
@@ -202,7 +202,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));

- if (ctx->Attrs().Get<bool>("dropout_is_test") == false) {
+ if (ctx->Attrs().Get<bool>("is_test") == false) {
ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
}

@@ -297,7 +297,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
platform::errors::InvalidArgument(
"'attn_dropout_rate' must be between 0.0 and 1.0."));
});
- AddAttr<bool>("attn_dropout_is_test",
+ AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
@@ -341,11 +341,6 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
platform::errors::InvalidArgument(
"'dropout_rate' must be between 0.0 and 1.0."));
});
-
- AddAttr<bool>("dropout_is_test",
- "(bool, default false) Set to true for inference only, false "
- "for training. Some layers may run faster when this is true.")
- .SetDefault(false);
AddAttr<bool>("dropout_fix_seed",
"A flag indicating whether to use a fixed seed to generate "
"random mask. NOTE: DO NOT set this flag to true in "
@@ -414,10 +409,9 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
- PADDLE_ENFORCE_EQ(
- ctx->Attrs().Get<bool>("attn_dropout_is_test"), false,
- platform::errors::InvalidArgument(
- "GradOp is only callable when attn_dropout_is_test is false"));
+ PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
+ platform::errors::InvalidArgument(
+ "GradOp is only callable when is_test is false"));

if (ctx->Attrs().Get<bool>("pre_layer_norm") == false) {
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
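
The hunks above fold the op-specific attn_dropout_is_test and dropout_is_test attributes of fused_attention into the single standard is_test flag. As a rough Python illustration of what that means for a caller, consider the attribute maps below; the dicts and the training variable are hypothetical, only the attribute names come from this diff:

```python
# Hypothetical sketch: how the test-mode attributes of a fused_attention op
# change with this PR. Only the attribute names come from the diff above.
training = True

# Before: each dropout carried its own test-mode flag.
attrs_before = {
    "attn_dropout_is_test": not training,  # dropout on the attention probabilities
    "dropout_is_test": not training,       # dropout after the output projection
}

# After: one shared flag gates both dropouts (and the grad op's is_test check).
attrs_after = {
    "is_test": not training,
}
```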
paddle/fluid/operators/fused/fused_attention_op.cu (4 changes: 2 additions & 2 deletions)

@@ -108,7 +108,7 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
const float ln_epsilon = ctx.Attr<float>("ln_epsilon");

float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
- bool is_test_1 = ctx.Attr<bool>("attn_dropout_is_test");
+ bool is_test_1 = ctx.Attr<bool>("is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
bool is_upscale_in_train_1 =
@@ -279,7 +279,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
const float ln2epsilon = ctx.Attr<float>("ln_epsilon");

float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
- bool is_test_1 = ctx.Attr<bool>("attn_dropout_is_test");
+ bool is_test_1 = ctx.Attr<bool>("is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
bool is_upscale_in_train_1 =
paddle/fluid/operators/fused/fused_dropout_helper.h (2 changes: 1 addition & 1 deletion)

@@ -82,7 +82,7 @@ struct DropoutParam {
auto& dropout_implementation =
context.Attr<std::string>(pre_fix + "implementation");
is_upscale_in_train = (dropout_implementation == "upscale_in_train");
- is_test = context.Attr<bool>(pre_fix + "is_test");
+ is_test = context.Attr<bool>("is_test");
fix_seed = context.Attr<bool>(pre_fix + "fix_seed");

std::string str_seed = "Dropout";
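
With this one-line change, DropoutParam reads is_test without the per-dropout prefix, while implementation and fix_seed keep their prefixed names. A minimal sketch of that lookup convention in plain Python; the dropout_param helper and the attrs dict are illustrative, not the real C++ helper:

```python
# Minimal sketch of the attribute-lookup convention DropoutParam now follows.
# Attribute names mirror the diffs in this PR; the dict-based helper itself
# is hypothetical, not the actual C++ implementation.
def dropout_param(attrs, prefix):
    return {
        "implementation": attrs[prefix + "implementation"],  # still prefixed
        "fix_seed": attrs[prefix + "fix_seed"],              # still prefixed
        "is_test": attrs["is_test"],                         # shared, no prefix
    }

# For fused_feedforward, both dropouts now share the one is_test flag.
attrs = {
    "dropout1_implementation": "upscale_in_train",
    "dropout1_fix_seed": False,
    "dropout2_implementation": "upscale_in_train",
    "dropout2_fix_seed": False,
    "is_test": False,
}
p1 = dropout_param(attrs, "dropout1_")
p2 = dropout_param(attrs, "dropout2_")
```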
paddle/fluid/operators/fused/fused_feedforward_op.cc (13 changes: 4 additions & 9 deletions)

@@ -61,14 +61,14 @@ class FusedFeedForwardOp : public framework::OperatorWithKernel {
tmp_dim_x[dim_x.size() - 1] =
dim_Linear1Weight[dim_Linear1Weight.size() - 1];
context->SetOutputDim("Out", dim_x);
- if (context->Attrs().Get<bool>("dropout1_is_test") == false) {
+ if (context->Attrs().Get<bool>("is_test") == false) {
context->SetOutputDim("Dropout1Mask", tmp_dim_x);
}
context->SetOutputDim("Dropout1Out", tmp_dim_x);
context->SetOutputDim("Linear1Out", tmp_dim_x);
context->SetOutputDim("Dropout2Out", dim_x);

- if (context->Attrs().Get<bool>("dropout2_is_test") == false) {
+ if (context->Attrs().Get<bool>("is_test") == false) {
context->SetOutputDim("Dropout2Mask", dim_x);
}
framework::DDim mean_dim =
@@ -185,9 +185,7 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
"dropout2_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
- AddAttr<bool>("dropout1_is_test", "the is_test of first dropout")
- .SetDefault(false);
- AddAttr<bool>("dropout2_is_test", "the is_test of second dropout")
+ AddAttr<bool>("is_test", "the is_test attribute of dropout")
.SetDefault(false);
AddAttr<bool>("dropout1_fix_seed", "the is_test of first dropout")
.SetDefault(false);
@@ -218,10 +216,7 @@ class FusedFeedForwardOpGrad : public framework::OperatorWithKernel {

protected:
void InferShape(framework::InferShapeContext *ctx) const override {
- PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout1_is_test"), false,
- platform::errors::InvalidArgument(
- "GradOp is only callable when is_test is false"));
- PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("dropout2_is_test"), false,
+ PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when is_test is false"));
bool pre_layer_norm = ctx->Attrs().Get<bool>("pre_layer_norm");
paddle/fluid/operators/fused/fused_multi_transformer_op.cc (2 changes: 1 addition & 1 deletion)

@@ -221,7 +221,7 @@ class FusedMultiTransformerOpOpMaker
"'dropout_rate' must be between 0.0 and 1.0."));
});

- AddAttr<bool>("dropout_is_test",
+ AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
Python test file (path not shown):

@@ -20,154 +20,11 @@
import paddle.fluid as fluid
from test_dist_base import TestDistRunnerBase, runtime_main
import paddle.distributed.fleet as fleet
- import paddle.incubate.nn.functional as incubate_f
-
- from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
- from paddle.fluid.dygraph.layers import Layer
- from paddle.fluid.layer_helper import LayerHelper
- from paddle.fluid import core
- from paddle.nn.initializer import Constant
+ from paddle.incubate.nn import FusedMultiHeadAttention

paddle.enable_static()


- def _set_var_distributed(var):
- if var is None:
- return
-
- var.is_distributed = True
-
- # NOTE: use current_block and find_var_recursive to support while_loop
- startup_block = paddle.static.default_startup_program().current_block()
- main_block = paddle.static.default_main_program().current_block()
- startup_block._find_var_recursive(var.name).is_distributed = True
- main_block._find_var_recursive(var.name).is_distributed = True
-
-
- class ParallelFusedMultiHeadAttention(Layer):
- def __init__(self,
- embed_dim,
- num_heads,
- dropout_rate=0.5,
- attn_dropout_rate=0.5,
- kdim=None,
- vdim=None,
- normalize_before=False,
- need_weights=False,
- qkv_weight_attr=None,
- qkv_bias_attr=None,
- linear_weight_attr=None,
- linear_bias_attr=None,
- pre_ln_scale_attr=None,
- pre_ln_bias_attr=None,
- ln_scale_attr=None,
- ln_bias_attr=None,
- epsilon=1e-5,
- nranks=1,
- ring_id=-1,
- name=None):
- super(ParallelFusedMultiHeadAttention, self).__init__()
-
- assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
- "but recieved {}".format(embed_dim))
- assert num_heads > 0, ("Expected nhead to be greater than 0, "
- "but recieved {}".format(num_heads))
-
- self.normalize_before = normalize_before
- self._dtype = self._helper.get_default_dtype()
- self._epsilon = epsilon
- self._ring_id = ring_id
-
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.head_dim = embed_dim // num_heads
- self.kdim = kdim
- self.vdim = vdim
- self.need_weights = need_weights
- assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
- assert need_weights == False, "Only support need_weight is False now."
-
- # tensor model parallel
- assert num_heads % nranks == 0
- num_heads = num_heads // nranks
-
- self.qkv_weight = self.create_parameter(
- shape=[3, num_heads, self.head_dim, embed_dim],
- attr=qkv_weight_attr,
- dtype=self._dtype,
- is_bias=False)
- self.qkv_bias = self.create_parameter(
- shape=[3, num_heads, self.head_dim],
- attr=qkv_bias_attr,
- dtype=self._dtype,
- is_bias=True)
- self.linear_weight = self.create_parameter(
- shape=[num_heads * self.head_dim, embed_dim],
- attr=linear_weight_attr,
- dtype=self._dtype,
- is_bias=False)
- self.linear_bias = self.create_parameter(
- shape=[embed_dim],
- attr=linear_bias_attr,
- dtype=self._dtype,
- is_bias=True)
-
- # tensor model parallel
- if nranks > 1:
- assert ring_id != -1
- # column parallel
- _set_var_distributed(self.qkv_weight)
- _set_var_distributed(self.qkv_bias)
- # row parallel
- _set_var_distributed(self.linear_weight)
-
- if normalize_before:
- self.pre_ln_scale = self.create_parameter(
- attr=pre_ln_scale_attr,
- shape=[embed_dim],
- default_initializer=Constant(value=1.0))
- self.pre_ln_bias = self.create_parameter(
- attr=pre_ln_bias_attr, shape=[embed_dim], is_bias=True)
- self.ln_scale = None
- self.ln_bias = None
- else:
- self.pre_ln_scale = None
- self.pre_ln_bias = None
- self.ln_scale = self.create_parameter(
- attr=ln_scale_attr,
- shape=[embed_dim],
- default_initializer=Constant(value=1.0))
- self.ln_bias = self.create_parameter(
- attr=ln_bias_attr, shape=[embed_dim], is_bias=True)
-
- self.dropout_rate = dropout_rate
- self.attn_dropout_rate = attn_dropout_rate
-
- self.name = name
-
- def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
- out = incubate_f.fused_multi_head_attention(
- x=query,
- qkv_weight=self.qkv_weight,
- linear_weight=self.linear_weight,
- pre_layer_norm=self.normalize_before,
- pre_ln_scale=self.pre_ln_scale,
- pre_ln_bias=self.pre_ln_bias,
- ln_scale=self.ln_scale,
- ln_bias=self.ln_bias,
- pre_ln_epsilon=self._epsilon,
- qkv_bias=self.qkv_bias,
- linear_bias=self.linear_bias,
- attn_mask=attn_mask,
- dropout_rate=self.dropout_rate,
- attn_dropout_rate=self.attn_dropout_rate,
- ln_epsilon=self._epsilon,
- training=self.training,
- ring_id=self._ring_id,
- name=self.name)
- return out
-
-
def get_param_attr(weight, bias):
weight_attr = paddle.ParamAttr(
initializer=fluid.initializer.NumpyArrayInitializer(weight))
@@ -206,7 +63,7 @@ def create_model(data, rank):
qkv_w_attr, qkv_b_attr = get_param_attr(col_qkv_w, col_qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(row_linear_w, linear_b)

- attn = ParallelFusedMultiHeadAttention(
+ attn = FusedMultiHeadAttention(
hidden,
n_head,
dropout_rate=0.0,
@@ -228,7 +85,7 @@ def create_model(data, rank):
qkv_w_attr, qkv_b_attr = get_param_attr(qkv_w, qkv_b)
linear_w_attr, linear_b_attr = get_param_attr(linear_w, linear_b)

- attn = ParallelFusedMultiHeadAttention(
+ attn = FusedMultiHeadAttention(
hidden,
n_head,
dropout_rate=0.0,
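
With the local ParallelFusedMultiHeadAttention wrapper deleted above, the test now builds the layer from paddle.incubate.nn.FusedMultiHeadAttention directly. A rough usage sketch follows; the keyword arguments mirror the deleted wrapper's constructor, and the sizes, attribute values, nranks and ring_id are placeholders rather than values from the test, since the exact call site is truncated in this view:

```python
# Rough sketch of the call pattern in the updated test. Sizes, ParamAttr
# values, nranks and ring_id are placeholders; executing the fused op needs
# a Paddle build that ships the fused GPU kernels.
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

paddle.enable_static()

hidden, n_head = 128, 8  # placeholder sizes

attn = FusedMultiHeadAttention(
    hidden,                # embed_dim
    n_head,                # num_heads
    dropout_rate=0.0,
    attn_dropout_rate=0.0,
    qkv_weight_attr=None,  # the test builds ParamAttrs via get_param_attr; None here for brevity
    qkv_bias_attr=None,
    linear_weight_attr=None,
    linear_bias_attr=None,
    nranks=1,              # >1 together with a valid ring_id enables tensor model parallelism
    ring_id=-1)

x = paddle.static.data(name="x", shape=[2, 4, hidden], dtype="float32")
out = attn(x)  # forward(query, key=None, value=None, attn_mask=None, cache=None)
```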