Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【inplace api】Batch add inplace api gt_, ge_, lt_, le_, eq_, not_equal_, logical_and_, logical_or_, logical_xor_, logical_not_, divide_, floor_divide_, bitwise_and_ , bitwise_or_, bitwise_xor_, bitwise_not_ #55509

Merged
merged 84 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
ff20306
tmp commit
GGBond8488 Jul 6, 2023
fc4e297
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
GGBond8488 Jul 18, 2023
3b71a68
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into batch_a…
GGBond8488 Jul 25, 2023
9398b5e
add atan2
GGBond8488 Jul 31, 2023
7fef09c
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into batch_a…
GGBond8488 Jul 31, 2023
793bd3e
add inplace api
GGBond8488 Jul 31, 2023
7c0f852
fix error
GGBond8488 Jul 31, 2023
e62e055
add inpalce divide
GGBond8488 Aug 3, 2023
08efca5
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into batch_a…
GGBond8488 Aug 3, 2023
6bd940b
add inplace api
GGBond8488 Aug 3, 2023
2aa0846
add more inplace
GGBond8488 Aug 6, 2023
aef245b
add more inpalce
GGBond8488 Aug 7, 2023
29faa09
fix logical_not error
GGBond8488 Aug 7, 2023
a0f7316
support sinh and cosh in cpu
ScottWong98 Aug 7, 2023
500b04b
support asin, acos, atan, asinh, acosh, atanh in cpu
ScottWong98 Aug 7, 2023
29ae413
fix typro
GGBond8488 Aug 8, 2023
c6f5245
fix typro
GGBond8488 Aug 8, 2023
986cf39
mv out atan2 ldexp
GGBond8488 Aug 8, 2023
71fc663
mv out atan2 ldexp
GGBond8488 Aug 8, 2023
c8588ff
support sinh and cosh in gpu
ScottWong98 Aug 8, 2023
c7ba0ce
support asin, acos, atan, asinh, acosh, atanh in gpu
ScottWong98 Aug 8, 2023
066b96f
fix ge error
GGBond8488 Aug 9, 2023
4e6c4c5
fix dygraph commpare error
GGBond8488 Aug 9, 2023
7dcc9f5
fix dygraph commpare error
GGBond8488 Aug 9, 2023
c822064
check complex in python
ScottWong98 Aug 11, 2023
df933dd
fix cast inpalce error
GGBond8488 Aug 11, 2023
5d9537a
open inplace test
GGBond8488 Aug 11, 2023
d1c1ddb
fix ops.yaml error
GGBond8488 Aug 12, 2023
99164da
mv cast inpalce to python
GGBond8488 Aug 13, 2023
c8eac31
fix coverage ci
GGBond8488 Aug 14, 2023
d43672e
add last inplace
GGBond8488 Aug 14, 2023
3860b4a
fix inplace error
GGBond8488 Aug 15, 2023
7795e21
fix cast error
GGBond8488 Aug 15, 2023
52daba0
fix error
GGBond8488 Aug 15, 2023
47acec7
add nan_to_num_
GGBond8488 Aug 15, 2023
d86bdcf
fix typro
GGBond8488 Aug 15, 2023
969a775
fix sparse cast error
GGBond8488 Aug 15, 2023
8f3748a
Merge branch 'develop' into add_complex_support_for_math
ScottWong98 Aug 15, 2023
a7744b2
remove gpu 4
GGBond8488 Aug 15, 2023
a012645
fix static cast error
GGBond8488 Aug 16, 2023
2c37061
tmp commit
GGBond8488 Jul 6, 2023
87008b1
add atan2
GGBond8488 Jul 31, 2023
3a02bf9
add inplace api
GGBond8488 Jul 31, 2023
4c6f8f7
fix error
GGBond8488 Jul 31, 2023
5afd0e4
add inpalce divide
GGBond8488 Aug 3, 2023
7f10e3a
add inplace api
GGBond8488 Aug 3, 2023
2215a22
add more inplace
GGBond8488 Aug 6, 2023
ade74bf
add more inpalce
GGBond8488 Aug 7, 2023
7272e87
fix logical_not error
GGBond8488 Aug 7, 2023
ffd26b0
fix typro
GGBond8488 Aug 8, 2023
cfab627
fix typro
GGBond8488 Aug 8, 2023
1cb0529
mv out atan2 ldexp
GGBond8488 Aug 8, 2023
9104d60
mv out atan2 ldexp
GGBond8488 Aug 8, 2023
db301dd
fix ge error
GGBond8488 Aug 9, 2023
2c9a299
fix dygraph commpare error
GGBond8488 Aug 9, 2023
43e5484
fix dygraph commpare error
GGBond8488 Aug 9, 2023
27b8309
fix cast inpalce error
GGBond8488 Aug 11, 2023
3a0e180
open inplace test
GGBond8488 Aug 11, 2023
29d085f
fix ops.yaml error
GGBond8488 Aug 12, 2023
bd7756b
mv cast inpalce to python
GGBond8488 Aug 13, 2023
f44cb81
fix coverage ci
GGBond8488 Aug 14, 2023
77f64d0
add last inplace
GGBond8488 Aug 14, 2023
f3a8fc0
fix inplace error
GGBond8488 Aug 15, 2023
56b44fd
fix cast error
GGBond8488 Aug 15, 2023
c264e27
fix error
GGBond8488 Aug 15, 2023
5b0c9a3
add nan_to_num_
GGBond8488 Aug 15, 2023
502469d
fix typro
GGBond8488 Aug 15, 2023
2ca1c00
fix sparse cast error
GGBond8488 Aug 15, 2023
4d55a81
remove gpu 4
GGBond8488 Aug 15, 2023
dc1be49
fix static cast error
GGBond8488 Aug 16, 2023
aaa8743
fix cast error
GGBond8488 Aug 18, 2023
b0b268a
merge
GGBond8488 Aug 21, 2023
4a78509
fix
GGBond8488 Aug 21, 2023
d10ec7d
Revert "check complex in python"
GGBond8488 Aug 21, 2023
bfaeb7a
add renorm , fix error
GGBond8488 Aug 22, 2023
7950dc2
add coverage
GGBond8488 Aug 22, 2023
8957d0d
fix cumsum inpalce version error
GGBond8488 Aug 22, 2023
62bea67
add cast inpalce impl
GGBond8488 Aug 23, 2023
163ae79
rm test.log
GGBond8488 Aug 23, 2023
d9fa589
fix multiply_dyfunction and add multiply_backward test
GGBond8488 Aug 23, 2023
468528c
add and use is_same_tensor
GGBond8488 Aug 23, 2023
a95a282
fix typro
GGBond8488 Aug 24, 2023
a206eba
fix sone error
GGBond8488 Aug 24, 2023
1af98b2
fix typro
GGBond8488 Aug 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,32 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x, // NOLINT
VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str);
}

bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
trace_backward, x_autograd_meta, y_autograd_meta);

// Node Declaration
std::shared_ptr<MultiplyGradNode> grad_node;
// Set grad_node before API Call
if (require_any_grad) {
paddle::platform::RecordEvent node_creation_record_event(
"multiply node_creation",
paddle::platform::TracerEventType::OperatorInner,
1);

grad_node = std::shared_ptr<MultiplyGradNode>(new MultiplyGradNode(1, 2));
// Set for forward trace
if (FLAGS_check_nan_inf) {
grad_node->SetForwardTrace(egr::Controller::Instance().GetPythonStack());
}
// SetAttributes if needed
grad_node->SetAttributeaxis(-1);
// Set TensorWrappers for Forward Inputs if needed
auto x_clone = paddle::experimental::assign(x);
grad_node->SetTensorWrapperx(x_clone);
grad_node->SetTensorWrappery(y);
}

// Forward API Call
auto& api_result = paddle::experimental::multiply_(x, y);
// Check NaN and Inf if needed
Expand All @@ -275,10 +301,6 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x, // NOLINT

// Get Output AutoGradMeta
egr::AutogradMeta* out_autograd_meta = egr::EagerUtils::autograd_meta(&out);
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(
trace_backward, x_autograd_meta, y_autograd_meta);

// Check Inplace if needed

egr::EagerUtils::CheckInplace(x, x_autograd_meta, require_any_grad);
Expand All @@ -289,25 +311,7 @@ paddle::Tensor& multiply__ad_func(paddle::Tensor& x, // NOLINT

// Node Creation
if (require_any_grad) {
paddle::platform::RecordEvent node_creation_record_event(
"multiply node_creation",
paddle::platform::TracerEventType::OperatorInner,
1);

egr::EagerUtils::PassStopGradient(false, out_autograd_meta);

// Node Construction
auto grad_node =
std::shared_ptr<MultiplyGradNode>(new MultiplyGradNode(1, 2));
// Set for forward trace
if (FLAGS_check_nan_inf) {
grad_node->SetForwardTrace(egr::Controller::Instance().GetPythonStack());
}
// SetAttributes if needed
grad_node->SetAttributeaxis(-1);
// Set TensorWrappers for Forward Inputs if needed
grad_node->SetTensorWrapperx(x);
grad_node->SetTensorWrappery(y);
// SetGradOutMeta & SetEdges
grad_node->SetGradOutMeta(x, 0);
grad_node->SetGradOutMeta(y, 1);
Expand Down Expand Up @@ -429,7 +433,6 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
input_str += input_y_str;
VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str);
}

// Forward API Call
auto api_result = paddle::experimental::sparse::multiply(x, y);
// Check NaN and Inf if needed
Expand Down
13 changes: 11 additions & 2 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,14 @@

- op : cast
args : (Tensor x, DataType dtype)
output : Tensor
output : Tensor(out)
infer_meta :
func : CastInferMeta
kernel :
func : cast
param : [x, dtype]
data_type : x
inplace: (x -> out)
backward : cast_grad

- op : channel_shuffle
Expand Down Expand Up @@ -202,11 +203,12 @@

- op : divide
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : divide
inplace: (x -> out)
backward : divide_grad

- op : dropout
Expand Down Expand Up @@ -293,6 +295,7 @@
func : CompareInferMeta
kernel :
func : equal
inplace: (x -> out)

- op : exponential_
args : (Tensor x, float lam)
Expand Down Expand Up @@ -324,6 +327,7 @@
func : ElementwiseInferMeta
kernel :
func : floor_divide
inplace: (x -> out)

- op : frobenius_norm
args : (Tensor x, int64_t[] axis, bool keep_dim, bool reduce_all)
Expand Down Expand Up @@ -424,6 +428,7 @@
func : CompareInferMeta
kernel :
func : greater_equal
inplace: (x -> out)

- op : greater_than
args : (Tensor x, Tensor y)
Expand All @@ -432,6 +437,7 @@
func : CompareInferMeta
kernel :
func : greater_than
inplace: (x -> out)

- op : hardswish
args : (Tensor x)
Expand Down Expand Up @@ -470,6 +476,7 @@
func : CompareInferMeta
kernel :
func : less_equal
inplace: (x -> out)

- op : less_than
args : (Tensor x, Tensor y)
Expand All @@ -478,6 +485,7 @@
func : CompareInferMeta
kernel :
func : less_than
inplace: (x -> out)

- op : linspace
args : (Tensor start, Tensor stop, Tensor number, DataType dtype, Place place)
Expand Down Expand Up @@ -646,6 +654,7 @@
func : CompareInferMeta
kernel :
func : not_equal
inplace: (x -> out)

- op : one_hot
args : (Tensor x, Scalar(int) num_classes)
Expand Down
16 changes: 14 additions & 2 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@
kernel :
func : bitwise_and
backend : x
inplace: (x -> out)

- op : bitwise_not
args : (Tensor x)
Expand All @@ -339,6 +340,7 @@
kernel :
func : bitwise_not
backend : x
inplace: (x -> out)

- op : bitwise_or
args : (Tensor x, Tensor y)
Expand All @@ -348,6 +350,7 @@
kernel :
func : bitwise_or
backend : x
inplace: (x -> out)

- op : bitwise_xor
args : (Tensor x, Tensor y)
Expand All @@ -357,6 +360,7 @@
kernel :
func : bitwise_xor
backend : x
inplace: (x -> out)

- op : bmm
args : (Tensor x, Tensor y)
Expand Down Expand Up @@ -618,6 +622,7 @@
func : UnchangedInferMetaCheckAxis
kernel :
func : cumprod
inplace: (x -> out)
backward : cumprod_grad

- op : cumsum
Expand All @@ -628,6 +633,7 @@
kernel :
func : cumsum
data_type : x
inplace: (x -> out)
backward : cumsum_grad

- op : data
Expand Down Expand Up @@ -1524,6 +1530,7 @@
func : logical_and
data_type : x
backend : x
inplace: (x -> out)

- op : logical_not
args : (Tensor x)
Expand All @@ -1534,6 +1541,7 @@
func : logical_not
data_type : x
backend : x
inplace: (x -> out)

- op : logical_or
args : (Tensor x, Tensor y)
Expand All @@ -1544,6 +1552,7 @@
func : logical_or
data_type : x
backend : x
inplace: (x -> out)

- op : logical_xor
args : (Tensor x, Tensor y)
Expand All @@ -1554,6 +1563,7 @@
func : logical_xor
data_type : x
backend : x
inplace: (x -> out)

- op : logit
args : (Tensor x, float eps = 1e-6f)
Expand Down Expand Up @@ -2073,12 +2083,13 @@

- op : renorm
args : (Tensor x, float p, int axis, float max_norm)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : renorm
inplace: (x -> out)
backward : renorm_grad

- op : reverse
Expand Down Expand Up @@ -2788,11 +2799,12 @@

- op : where
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : WhereInferMeta
kernel :
func : where
inplace: (x -> out)
backward : where_grad

- op : yolo_box
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/core/meta_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ bool MetaTensor::is_selected_rows() const {
}
bool MetaTensor::is_tensor_array() const { return false; }

bool MetaTensor::is_same_tensor(const MetaTensor& meta_tensor) const {
return tensor_ != nullptr && tensor_ == meta_tensor.tensor();
}

void MetaTensor::share_dims(const MetaTensor& meta_tensor) {
ValidCheck(*this);
bool is_dense_tensor = phi::DenseTensor::classof(tensor_);
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/core/meta_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class MetaTensor {
// and it will be deleted in the future.
virtual bool is_tensor_array() const;

virtual bool is_same_tensor(const MetaTensor& meta_tensor) const;

virtual operator unspecified_bool_type() const {
return tensor_ == nullptr ? 0 : unspecified_bool_true;
}
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,9 @@ void CompareRawInferMeta(const MetaTensor& x,
out->set_dims(make_ddim(out_dims_array));
out->share_lod(x);
}

out->set_dtype(DataType::BOOL);
if (!out->is_same_tensor(x)) {
out->set_dtype(DataType::BOOL);
}
}

void CompareInferMeta(const MetaTensor& x,
Expand Down
7 changes: 6 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,14 @@ void BatchSizeLikeInferMeta(const MetaTensor& x,

void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(out_dtype);
out->set_layout(x.layout());
out->share_lod(x);
// In inpalce case, setting the dtype of out will reset the dtype of x at the
// same time, which will cause bugs, so move the dtype setting of out to the
// kernel
if (!(out->is_same_tensor(x))) {
out->set_dtype(out_dtype);
}
}

void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/cpu/cast_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ void CastGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
PD_VISIT_ALL_TYPES(x.dtype(), "CastKernelImpl", ([&] {
CastKernelImpl<T, data_t>(dev_ctx, out_grad, x_grad);
CastKernelImpl<T, data_t>(
dev_ctx, out_grad, x_grad->dtype(), x_grad);
}));
}

Expand Down
23 changes: 23 additions & 0 deletions paddle/phi/kernels/cpu/cast_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,35 @@ struct CastOpTransformFunctor {
template <typename InT, typename OutT>
void CastKernelImpl(const CPUContext& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
auto* in_begin = x.data<InT>();
auto numel = x.numel();
auto* in_end = in_begin + numel;

auto* out_begin = dev_ctx.Alloc<OutT>(out);
out->set_type(out_dtype);

phi::Transform<CPUContext> trans;
trans(dev_ctx,
in_begin,
in_end,
out_begin,
CastOpTransformFunctor<InT, OutT>());
}

template <typename InT, typename OutT>
void CastInplaceKernelImpl(const CPUContext& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
auto x_origin = x;
auto* in_begin = x_origin.data<InT>();
auto numel = x_origin.numel();
auto* in_end = in_begin + numel;

auto* out_begin = dev_ctx.Alloc<OutT>(out);
out->set_type(out_dtype);

phi::Transform<CPUContext> trans;
trans(dev_ctx,
Expand Down
13 changes: 10 additions & 3 deletions paddle/phi/kernels/cpu/cast_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@ void CastKernel(const Context& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DenseTensor* out) {
PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
CastKernelImpl<T, data_t>(dev_ctx, x, out);
}));
if (out->IsSharedWith(x)) {
PD_VISIT_ALL_TYPES(out_dtype, "CastInplaceKernelImpl", ([&] {
CastInplaceKernelImpl<T, data_t>(
dev_ctx, x, out_dtype, out);
}));
} else {
PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
CastKernelImpl<T, data_t>(dev_ctx, x, out_dtype, out);
}));
}
}

} // namespace phi
Expand Down
Loading