add type promotion for complex and real number. #63842

Merged · 46 commits · May 9, 2024

Commits (46)
e71d2a3  add type promotion for complex and real number.  (zxcd, Jan 25, 2024)
b5b54d6  fix  (zxcd, Jan 25, 2024)
828b92b  reduce api support  (zxcd, Jan 26, 2024)
5b91d51  add more api support  (zxcd, Jan 30, 2024)
5dd3ce1  fix  (zxcd, Jan 30, 2024)
bca1b67  fix  (zxcd, Feb 1, 2024)
e6f09f4  remove matmul  (zxcd, Feb 5, 2024)
5f9c3f1  add T+S logic.  (zxcd, Feb 28, 2024)
2ba8764  fix bug  (zxcd, Feb 29, 2024)
7df3fa1  fix unittest  (zxcd, Mar 1, 2024)
05d93c1  fix  (zxcd, Mar 1, 2024)
f3af919  fix  (zxcd, Mar 1, 2024)
d0eef9b  fix unittest  (zxcd, Mar 1, 2024)
c2f10a9  Merge branch 'develop' into type_promotion_stage2_T_T  (zxcd, Mar 1, 2024)
1da794b  fix gumbel  (zxcd, Mar 1, 2024)
c9bf9a9  Merge branch 'type_promotion_stage2_T_T' of https://github.com/zxcd/P…  (zxcd, Mar 1, 2024)
290fe25  rm print  (zxcd, Mar 1, 2024)
9cd34a9  fix more unittests.  (zxcd, Mar 5, 2024)
ce3a6ec  Merge branch 'PaddlePaddle:develop' into type_promotion_stage2_T_T  (zxcd, Mar 12, 2024)
7912dc4  fix test_llama_group_log_softmax.py  (zxcd, Mar 12, 2024)
38adb7b  fix bug, and add 0-d + 0-d logic.  (zxcd, Mar 18, 2024)
437ca5b  rm print  (zxcd, Mar 19, 2024)
80f2132  fix behavior of bool and int  (zxcd, Mar 20, 2024)
8b08687  add unittest for all type promotion.  (zxcd, Mar 21, 2024)
8c98c16  Merge branch 'develop' into type_promotion_stage2_T_T  (zxcd, Mar 22, 2024)
f44e926  rm unintest which is unsupport dtype  (zxcd, Mar 22, 2024)
afb8788  Merge branch 'type_promotion_stage2_T_T' of https://github.com/zxcd/P…  (zxcd, Mar 22, 2024)
5598010  fix  (zxcd, Mar 22, 2024)
aa0bf9c  fix  (zxcd, Mar 22, 2024)
59d02a2  add error unittest  (zxcd, Mar 25, 2024)
a12bb8d  Merge branch 'PaddlePaddle:develop' into type_promotion_stage2_T_T  (zxcd, Mar 26, 2024)
ed5ed3d  fix increase unittest  (zxcd, Mar 26, 2024)
9b1caf1  bug fix  (zxcd, Mar 29, 2024)
c352742  Merge branch 'PaddlePaddle:develop' into type_promotion_stage2_T_T  (zxcd, Apr 7, 2024)
e2a7686  fixed by comment  (zxcd, Apr 12, 2024)
604db69  Merge branch 'type_promotion_stage2_T_T' of https://github.com/zxcd/P…  (zxcd, Apr 12, 2024)
4a54c32  remove useless code.  (zxcd, Apr 12, 2024)
d6d4598  Merge branch 'develop' into type_promotion_stage2_T_T  (zxcd, Apr 12, 2024)
0e62c85  fix  (zxcd, Apr 16, 2024)
d79cac2  Merge branch 'type_promotion_stage2_T_T' of https://github.com/zxcd/P…  (zxcd, Apr 16, 2024)
c34aa6b  fix  (zxcd, Apr 17, 2024)
dc624f8  fix TypePromotionForZeroDimTensor  (zxcd, Apr 18, 2024)
de8ac06  add inplace API support, add special case can skip type promotion (ad…  (zxcd, Apr 24, 2024)
ecd8c1a  Merge branch 'type_promotion_stage2_T_T' of https://github.com/zxcd/P…  (zxcd, Apr 24, 2024)
46238aa  add broatcast support for MultiPrecisionAddKernelImpl.  (zxcd, Apr 30, 2024)
c144bc7  Merge branch 'PaddlePaddle:develop' into type_promotion_stage2_T_T  (zxcd, May 8, 2024)
@@ -70,7 +70,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
   }
 
   // Type promotion Logic
-  if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {
+  if (phi::NeedTypePromotion("multiply", x.dtype(), y.dtype())) {
     VLOG(5) << "got different data type, run type promotion automatically.";
     LOG_FIRST_N(WARNING, 1)
         << "got different data type, run type promotion "
@@ -247,6 +247,22 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x,  // NOLINT
 
   VLOG(5)
       << " No AMP for multiply__ad_func because it is a inplace or cast api. ";
+
+  // Type promotion Logic
+  if (phi::NeedTypePromotion("multiply_", x.dtype(), y.dtype())) {
+    VLOG(5) << "got different data type, run type promotion automatically.";
+    LOG_FIRST_N(WARNING, 1)
+        << "got different data type, run type promotion "
+           "automatically, this may cause data type been changed.";
+    auto op_name = phi::TransToFluidOpName("multiply_");
+    auto promotion_type = phi::GetPromoteDtype(op_name, x.dtype(), y.dtype());
+
+    x = egr::PromoteCastInplace("x", x, promotion_type);
+    auto new_y = egr::PromoteCast("y", y, promotion_type);
+
+    return multiply__ad_func(x, new_y);
+  }
+
   // Layout autotune
 
   if (egr::Controller::Instance().UseLayoutAutoTune()) {
@@ -424,7 +440,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
   }
 
   // Type promotion Logic
-  if (phi::NeedTypePromotion(x.dtype(), y.dtype())) {
+  if (phi::NeedTypePromotion("multiply", x.dtype(), y.dtype())) {
     VLOG(5) << "got different data type, run type promotion automatically.";
     LOG_FIRST_N(WARNING, 1)
         << "got different data type, run type promotion "
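
The three hunks above wire the new op-aware NeedTypePromotion(op, x.dtype(), y.dtype()) check into the hand-written multiply forwards. A minimal sketch of the user-visible behavior this PR targets (illustrative only: it assumes paddle.ones accepts complex dtypes and that multiply participates in real-to-complex promotion, per the PR title):

import paddle

x = paddle.ones([2], dtype="float32")
y = paddle.ones([2], dtype="complex64")

out = paddle.multiply(x, y)  # x is promoted before the kernel runs
print(out.dtype)             # expected: paddle.complex64
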
83 changes: 81 additions & 2 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

@@ -85,11 +85,50 @@
 type_promote_white_list = {
     "add": ["x", "y"],
     "subtract": ["x", "y"],
+    "divide": ["x", "y"],
+    "floor_divide": ["x", "y"],
+    "elementwise_pow": ["x", "y"],
+    "where": ["x", "y"],
+    "equal": ["x", "y"],
+    "not_equal": ["x", "y"],
+    "less_than": ["x", "y"],
+    "less_equal": ["x", "y"],
+    "greater_than": ["x", "y"],
+    "greater_equal": ["x", "y"],
+    "logical_and": ["x", "y"],
+    "logical_or": ["x", "y"],
+    "logical_xor": ["x", "y"],
+    "fmax": ["x", "y"],
+    "fmin": ["x", "y"],
+    "maximum": ["x", "y"],
+    "minimum": ["x", "y"],
+    "remainder": ["x", "y"],
+    "huber_loss": ["input", "label"],
+    "nextafter": ["x", "y"],
+    "atan2": ["x", "y"],
 }
+
+type_promote_inplace_white_list = {
+    "add_": ["x", "y"],
+    "subtract_": ["x", "y"],
+    "divide_": ["x", "y"],
+    "floor_divide_": ["x", "y"],
+    "where_": ["x", "y"],
+    "equal_": ["x", "y"],
+    "not_equal_": ["x", "y"],
+    "less_than_": ["x", "y"],
+    "less_equal_": ["x", "y"],
+    "greater_than_": ["x", "y"],
+    "greater_equal_": ["x", "y"],
+    "logical_and_": ["x", "y"],
+    "logical_or_": ["x", "y"],
+    "logical_xor_": ["x", "y"],
+    "remainder_": ["x", "y"],
+}
 
 # dict of special api that forward api's output will affect backward api's output
 # backward api's output usually affected by backward api's input
 
 special_prune_dict = {
     "matmul_grad": {"x": "grad_y", "y": "grad_x"},
 }
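
These white lists drive the generator: each listed API gets the promotion block emitted into its forward function, and the _-suffixed inplace variants additionally cast their first input in place. A hedged sketch of the resulting inplace semantics (the exact promoted dtype comes from phi::GetPromoteDtype, assumed here to pick the wider type):

import paddle

x = paddle.ones([2], dtype="float32")
y = paddle.ones([2], dtype="float64")

x.add_(y)       # "add_" is in type_promote_inplace_white_list
print(x.dtype)  # expected: paddle.float64 (x itself was cast in place)
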
@@ -537,13 +576,13 @@ class {} : public egr::GradNodeBase {{
 }}
 """
 
-TYPE_PROMOTION_LOGIC_TEMPLATE = """  if (phi::NeedTypePromotion({x}.dtype(), {y}.dtype())) {{
+TYPE_PROMOTION_LOGIC_TEMPLATE = """  if (phi::NeedTypePromotion({op_func_name}, {x}.dtype(), {y}.dtype())) {{
     VLOG(5) << "got different data type, run type promotion automatically.";
     LOG_FIRST_N(WARNING, 1) << "got different data type, run type promotion automatically, this may cause data type been changed.";
     {op_name}
     auto promotion_type = phi::GetPromoteDtype(op_name, {x}.dtype(), {y}.dtype());
 
-    auto new_{x} = egr::PromoteCast("{x}", {x}, promotion_type);
+    {x_cast}
     auto new_{y} = egr::PromoteCast("{y}", {y}, promotion_type);
 
     {return_value}
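
With {op_func_name} and {x_cast} parameterized, one template now serves both the out-of-place and inplace paths. An illustrative fill for the out-of-place multiply case (values mirror the generated C++ shown earlier, but are written out by hand here, not captured from the generator):

rendered = TYPE_PROMOTION_LOGIC_TEMPLATE.format(
    op_func_name='"multiply"',
    x="x",
    y="y",
    x_cast='auto new_x = egr::PromoteCast("x", x, promotion_type);',
    op_name='auto op_name = phi::TransToFluidOpName("multiply");',
    return_value="return multiply_ad_func(new_x, new_y);",
)
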
@@ -1511,6 +1550,18 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
                     type_promote_inputs_call_list[pos] = f"new_{name}"
                 else:
                     type_promote_inputs_call_list[pos] = f"{name}"
+            elif forward_api_name in type_promote_inplace_white_list:
+                if name in type_promote_inplace_white_list[forward_api_name]:
+                    if (
+                        is_inplaced
+                        and forward_inplace_map
+                        and name in forward_inplace_map
+                    ):
+                        type_promote_inputs_call_list[pos] = f"{name}"
+                    else:
+                        type_promote_inputs_call_list[pos] = f"new_{name}"
+                else:
+                    type_promote_inputs_call_list[pos] = f"{name}"
             if IsPlainTensorType(ttype):
                 if is_optional:
                     if (
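
For an inplace API the inplaced input keeps its original name in the regenerated call, since the emitted code mutates it via PromoteCastInplace instead of binding a new_ temporary. A hypothetical trace for "add_" (where x is the inplaced input):

# forward_api_name == "add_":
#   x -> "x"      kept as-is; it is cast in place
#   y -> "new_y"  bound by PromoteCast
# emitted recursive call: return add__ad_func(x, new_y);
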
@@ -1601,6 +1652,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
         for name, atype, default_val, pos in forward_attrs_list:
             inputs_call_list[pos] = name
             amp_inputs_call_list[pos] = name
+            type_promote_inputs_call_list[pos] = name
             if default_val is not None:
                 inputs_args_declaration_list[
                     pos
@@ -1846,16 +1898,43 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
         # Forward type promotion logic
         if forward_api_name in type_promote_white_list:
             # only support two inputs
+            op_func_name = f"\"{forward_api_name}\""
             x = type_promote_white_list[forward_api_name][0]
             y = type_promote_white_list[forward_api_name][1]
             type_promote_inputs_call_args_str = ", ".join(
                 type_promote_inputs_call_list
             )
             type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"
 
+            x_cast = f"auto new_{x} = egr::PromoteCast(\"{x}\", {x}, promotion_type);"
+
             type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format(
+                op_func_name=op_func_name,
                 x=x,
                 y=y,
+                x_cast=x_cast,
                 op_name=kernel_trans2_op_name_str,
                 return_value=type_promote_call_list,
             )
+        elif forward_api_name in type_promote_inplace_white_list:
+            # only support two inputs
+            op_func_name = f"\"{forward_api_name}\""
+            x = type_promote_inplace_white_list[forward_api_name][0]
+            y = type_promote_inplace_white_list[forward_api_name][1]
+            type_promote_inputs_call_args_str = ", ".join(
+                type_promote_inputs_call_list
+            )
+            type_promote_call_list = f"return {forward_ad_function_name}({type_promote_inputs_call_args_str});"
+
+            x_cast = (
+                f"{x} = egr::PromoteCastInplace(\"{x}\", {x}, promotion_type);"
+            )
+
+            type_promotion_logic_str = TYPE_PROMOTION_LOGIC_TEMPLATE.format(
+                op_func_name=op_func_name,
+                x=x,
+                y=y,
+                x_cast=x_cast,
+                op_name=kernel_trans2_op_name_str,
+                return_value=type_promote_call_list,
+            )
11 changes: 11 additions & 0 deletions paddle/fluid/eager/type_promotion_utils.h

@@ -30,4 +30,15 @@ inline paddle::Tensor PromoteCast(const std::string& input_name,
   }
 }
 
+inline paddle::Tensor PromoteCastInplace(const std::string& input_name,
+                                         paddle::Tensor& input,  // NOLINT
+                                         const phi::DataType& dst_dtype,
+                                         bool trace_backward = true) {
+  if (input.dtype() != dst_dtype) {
+    return paddle::experimental::cast_(input, dst_dtype);
+  } else {
+    return input;
+  }
+}
+
 }  // namespace egr
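
PromoteCastInplace mirrors PromoteCast but routes through the inplace cast_, so the promoted dtype sticks to the caller's tensor and the recursive re-dispatch sees the already-cast x. A rough Python analogue of its short-circuit (a hypothetical helper; an inplace Tensor.cast_ Python API is an assumption, not something this diff confirms):

def promote_cast_inplace(tensor, dst_dtype):
    # Skip the cast when the dtype already matches, mirroring the C++ check.
    if tensor.dtype != dst_dtype:
        return tensor.cast_(dst_dtype)  # hypothetical inplace cast
    return tensor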