【inplace api】batch add inplace api paddle.log_, paddle.i0_, paddle.nn.functional.leaky_relu_... (#55576)

* batch add inplace api

* add inplace test

* add activation inplace

* fix test

* remove atan2 ge, gt, le, lt, nq

* remove atan2 ge, gt, le, lt, nq

* fix windows ci error

* rerun ci

* fix typo

* fix bugs

---------

Co-authored-by: zhangrui34 <v_zhangrui34@baidu.com>
GGBond8488 and zhangrui34 authored Jul 27, 2023
1 parent da25896 commit 58a03d4
Showing 10 changed files with 434 additions and 14 deletions.
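The new APIs follow Paddle's trailing-underscore convention for in-place operations. As a rough usage sketch (assuming the usual dynamic-graph behaviour, where an in-place op writes its result into the input Tensor and returns that same Tensor):

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
y = paddle.log_(x)      # natural log computed into x's own storage
# y is expected to be the same Tensor object as x, now holding log([1., 2., 3.])

z = paddle.to_tensor([0.5, 1.5, 2.5])
paddle.i0_(z)           # modified Bessel function I0, also in place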
2 changes: 2 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -907,6 +907,7 @@
func : TrilInferMeta
kernel :
func : tril
inplace: (x -> out)
backward : tril_grad

- op : tril_indices
@@ -928,6 +929,7 @@
func : TriuInferMeta
kernel :
func : triu
inplace: (x -> out)
backward : triu_grad

- op : triu_indices
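The tril and triu entries above are now marked as in-place capable (the input x may be reused as out), which backs the paddle.tril_ and paddle.triu_ wrappers exported later in this commit. A minimal sketch, assuming the same diagonal semantics as the out-of-place versions:

import paddle

m = paddle.ones([3, 3])
paddle.tril_(m)   # zero everything above the main diagonal, writing into m
paddle.triu_(m)   # then zero everything below it; only the diagonal remains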
33 changes: 23 additions & 10 deletions paddle/phi/api/yaml/ops.yaml
@@ -665,11 +665,12 @@

- op : digamma
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : digamma
inplace: (x -> out)
backward : digamma_grad

- op : dirichlet
@@ -1107,12 +1108,13 @@

- op : hardtanh
args : (Tensor x, float t_min=0, float t_max=24)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : hardtanh
inplace: (x -> out)
backward : hardtanh_grad

- op : heaviside
@@ -1149,6 +1151,7 @@
func : UnchangedInferMeta
kernel :
func : i0
inplace: (x -> out)
backward : i0_grad

- op : i0e
@@ -1361,12 +1364,13 @@

- op : leaky_relu
args : (Tensor x, float negative_slope = 0.02f)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : leaky_relu
inplace: (x -> out)
backward : leaky_relu_grad

- op : lerp
@@ -1386,6 +1390,7 @@
func : UnchangedInferMeta
kernel :
func : lgamma
inplace: (x -> out)
backward : lgamma_grad

- op : linear_interp
@@ -1413,38 +1418,42 @@

- op : log
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : log
inplace: (x -> out)
backward: log_grad

- op : log10
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : log10
inplace: (x -> out)
backward: log10_grad

- op : log1p
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : log1p
inplace: (x -> out)
backward: log1p_grad

- op : log2
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : log2
inplace: (x -> out)
backward: log2_grad

- op : log_loss
@@ -1517,12 +1526,13 @@

- op : logit
args : (Tensor x, float eps = 1e-6f)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : logit
inplace: (x -> out)
backward : logit_grad

- op : logsigmoid
@@ -1895,6 +1905,7 @@
param: [x]
kernel :
func : polygamma
inplace: (x -> out)
backward : polygamma_grad

- op : pow
@@ -2494,12 +2505,13 @@

- op : thresholded_relu
args : (Tensor x, float threshold = 1.0)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : thresholded_relu
inplace: (x -> out)
backward : thresholded_relu_grad

- op : topk
@@ -2546,11 +2558,12 @@

- op : trunc
args : (Tensor input)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
func : trunc
inplace: (input -> out)
backward : trunc_grad

- op : unbind
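Each changed ops.yaml entry gains a named output Tensor(out) and an inplace: (x -> out) mapping, which is what lets the op-code generator emit a trailing-underscore variant that reuses the input buffer. A hedged sketch of the expected equivalence with the out-of-place ops (Tensor.clone and paddle.allclose are assumed available, as in dynamic graph mode):

import paddle

x = paddle.to_tensor([0.25, 0.5, 0.75])
ref = paddle.log(x.clone())   # out-of-place reference result

paddle.log_(x)                # same values, computed into x's storage
assert paddle.allclose(x, ref)

t = paddle.to_tensor([1.7, -2.3])
paddle.trunc_(t)              # expected to leave [1., -2.] in t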
8 changes: 4 additions & 4 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -2032,7 +2032,7 @@ struct LogFunctor : public BaseActivationFunctor<T> {

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.template cast<U>().unaryExpr(Log<U>());
out.device(d) = x.template cast<U>().unaryExpr(Log<U>()).eval();
}
};

@@ -2076,7 +2076,7 @@ struct Log2Functor : public BaseActivationFunctor<T> {

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.template cast<U>().unaryExpr(Log2<U>());
out.device(d) = x.template cast<U>().unaryExpr(Log2<U>()).eval();
}
};

@@ -2121,7 +2121,7 @@ struct Log10Functor : public BaseActivationFunctor<T> {

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.template cast<U>().unaryExpr(Log10<U>());
out.device(d) = x.template cast<U>().unaryExpr(Log10<U>()).eval();
}
};

@@ -2166,7 +2166,7 @@ struct Log1pFunctor : public BaseActivationFunctor<T> {

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.template cast<U>().unaryExpr(Log1p<U>());
out.device(d) = x.template cast<U>().unaryExpr(Log1p<U>()).eval();
}
};

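The only functional change in this file is the trailing .eval() on the Eigen expressions in the Log, Log2, Log10 and Log1p functors. A plausible reading (not stated in the commit message itself) is that .eval() materialises the expression into a temporary before assignment, so that when the op runs in place and x and out alias the same memory, the kernel does not read elements it has already overwritten. A quick Python-level consistency check along those lines might look like:

import paddle

x = paddle.rand([4, 4]) + 0.1        # keep values strictly positive for log
expected = paddle.log(x.clone())     # out-of-place reference

paddle.log_(x)                       # in-place path through the patched CPU functor
assert paddle.allclose(x, expected)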
27 changes: 27 additions & 0 deletions python/paddle/__init__.py
@@ -110,7 +110,9 @@
from .tensor.creation import full # noqa: F401
from .tensor.creation import full_like # noqa: F401
from .tensor.creation import triu # noqa: F401
from .tensor.creation import triu_ # noqa: F401
from .tensor.creation import tril # noqa: F401
from .tensor.creation import tril_ # noqa: F401
from .tensor.creation import meshgrid # noqa: F401
from .tensor.creation import empty # noqa: F401
from .tensor.creation import empty_like # noqa: F401
@@ -224,14 +226,18 @@
from .tensor.math import cumprod # noqa: F401
from .tensor.math import logcumsumexp # noqa: F401
from .tensor.math import logit # noqa: F401
from .tensor.math import logit_ # noqa: F401
from .tensor.math import exp # noqa: F401
from .tensor.math import expm1 # noqa: F401
from .tensor.math import expm1_ # noqa: F401
from .tensor.math import floor # noqa: F401
from .tensor.math import increment # noqa: F401
from .tensor.math import log # noqa: F401
from .tensor.math import log_ # noqa: F401
from .tensor.math import log2_ # noqa: F401
from .tensor.math import log2 # noqa: F401
from .tensor.math import log10 # noqa: F401
from .tensor.math import log10_ # noqa: F401
from .tensor.math import multiplex # noqa: F401
from .tensor.math import pow # noqa: F401
from .tensor.math import pow_ # noqa: F401
@@ -279,6 +285,7 @@
from .tensor.math import logaddexp # noqa: F401
from .tensor.math import inverse # noqa: F401
from .tensor.math import log1p # noqa: F401
from .tensor.math import log1p_ # noqa: F401
from .tensor.math import erf # noqa: F401
from .tensor.math import erf_ # noqa: F401
from .tensor.math import addmm # noqa: F401
@@ -294,9 +301,13 @@
from .tensor.math import broadcast_shape # noqa: F401
from .tensor.math import conj # noqa: F401
from .tensor.math import trunc # noqa: F401
from .tensor.math import trunc_ # noqa: F401
from .tensor.math import digamma # noqa: F401
from .tensor.math import digamma_ # noqa: F401
from .tensor.math import neg # noqa: F401
from .tensor.math import neg_ # noqa: F401
from .tensor.math import lgamma # noqa: F401
from .tensor.math import lgamma_ # noqa: F401
from .tensor.math import acosh # noqa: F401
from .tensor.math import acosh_ # noqa: F401
from .tensor.math import asinh # noqa: F401
@@ -317,6 +328,7 @@
from .tensor.math import outer # noqa: F401
from .tensor.math import heaviside # noqa: F401
from .tensor.math import frac # noqa: F401
from .tensor.math import frac_ # noqa: F401
from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401
from .tensor.math import frexp # noqa: F401
@@ -326,10 +338,12 @@
from .tensor.math import vander # noqa: F401
from .tensor.math import nextafter # noqa: F401
from .tensor.math import i0 # noqa: F401
from .tensor.math import i0_ # noqa: F401
from .tensor.math import i0e # noqa: F401
from .tensor.math import i1 # noqa: F401
from .tensor.math import i1e # noqa: F401
from .tensor.math import polygamma # noqa: F401
from .tensor.math import polygamma_ # noqa: F401

from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
@@ -473,6 +487,7 @@
'logaddexp',
'logcumsumexp',
'logit',
'logit_',
'LazyGuard',
'sign',
'is_empty',
@@ -561,6 +576,7 @@
'rand',
'less_equal',
'triu',
'triu_',
'sin',
'sin_',
'dist',
@@ -582,6 +598,7 @@
'abs',
'abs_',
'tril',
'tril_',
'pow',
'pow_',
'zeros_like',
@@ -608,7 +625,9 @@
'broadcast_shape',
'conj',
'neg',
'neg_',
'lgamma',
'lgamma_',
'lerp',
'erfinv',
'inner',
@@ -693,13 +712,19 @@
'floor',
'cosh',
'log',
'log_',
'log2',
'log2_',
'log10',
'log10_',
'concat',
'check_shape',
'trunc',
'trunc_',
'frac',
'frac_',
'digamma',
'digamma_',
'standard_normal',
'diagonal',
'broadcast_tensors',
Expand Down Expand Up @@ -741,8 +766,10 @@
'unflatten',
'nextafter',
'i0',
'i0_',
'i0e',
'i1',
'i1e',
'polygamma',
'polygamma_',
]
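The additions to __all__ make the new in-place variants importable directly from the top-level paddle namespace. A brief sketch using a few of them (the values noted are what the out-of-place counterparts would produce):

import paddle

v = paddle.to_tensor([2.5, -1.25])
paddle.frac_(v)      # fractional part in place: [0.5, -0.25]
paddle.neg_(v)       # negate in place: [-0.5, 0.25]

w = paddle.to_tensor([1.0, 2.0, 3.0])
paddle.lgamma_(w)    # log-gamma in place: [0., 0., ~0.693]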
6 changes: 6 additions & 0 deletions python/paddle/nn/functional/__init__.py
@@ -21,9 +21,11 @@
from .activation import gelu # noqa: F401
from .activation import hardshrink # noqa: F401
from .activation import hardtanh # noqa: F401
from .activation import hardtanh_ # noqa: F401
from .activation import hardsigmoid # noqa: F401
from .activation import hardswish # noqa: F401
from .activation import leaky_relu # noqa: F401
from .activation import leaky_relu_ # noqa: F401
from .activation import log_sigmoid # noqa: F401
from .activation import maxout # noqa: F401
from .activation import prelu # noqa: F401
@@ -44,6 +46,7 @@
from .activation import tanh_ # noqa: F401
from .activation import tanhshrink # noqa: F401
from .activation import thresholded_relu # noqa: F401
from .activation import thresholded_relu_ # noqa: F401
from .activation import log_softmax # noqa: F401
from .activation import glu # noqa: F401
from .activation import gumbel_softmax # noqa: F401
@@ -153,9 +156,11 @@
'gelu',
'hardshrink',
'hardtanh',
'hardtanh_',
'hardsigmoid',
'hardswish',
'leaky_relu',
'leaky_relu_',
'log_sigmoid',
'maxout',
'prelu',
@@ -176,6 +181,7 @@
'tanh_',
'tanhshrink',
'thresholded_relu',
'thresholded_relu_',
'log_softmax',
'glu',
'gumbel_softmax',
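With these exports, the in-place activations become callable from paddle.nn.functional. A small sketch (default arguments are assumed to match the out-of-place versions):

import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-1.0, 0.5, 2.0])
F.leaky_relu_(x)          # negative entries scaled by the slope, in place
F.hardtanh_(x)            # then clipped to the hardtanh range, still in place
F.thresholded_relu_(x)    # entries at or below the threshold zeroed, in place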
(Diffs for the remaining changed files were not loaded in this view.)
