
[Make FLAGS_einsum_opt the default] Einsum memory optimization #43397

Merged · 18 commits · Jun 14, 2022
1 change: 1 addition & 0 deletions paddle/fluid/eager/nan_inf_utils.cc
@@ -114,6 +114,7 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
}

} // namespace egr
3 changes: 2 additions & 1 deletion paddle/fluid/eager/nan_inf_utils.h
@@ -31,7 +31,8 @@ using TupleOfFourTensors = std::tuple<Tensor, Tensor, Tensor, Tensor>;
using TupleOfFiveTensors = std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfSixTensors =
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfTensorAndVector = std::tuple<Tensor, std::vector<Tensor>>;
using TupleOfTensorAndVector =
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>>;

void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);

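Context for the tuple change above: einsum's forward now returns (Out, InnerCache, XShape), so the eager NaN/Inf checker must visit a third tuple element. A loose Python analogue of the extended check — purely illustrative, not Paddle API:

```python
import math

def has_nan_or_inf(values):
    # flat list of floats -> True if any entry is NaN or Inf
    return any(math.isnan(v) or math.isinf(v) for v in values)

def check_tuple_has_nan_or_inf(api_name, result):
    # mirrors std::get<0>/<1>/<2> on TupleOfTensorAndVector
    out, inner_cache, xshape = result
    for t in [out, *inner_cache, *xshape]:
        if has_nan_or_inf(t):
            raise RuntimeError(f'{api_name} produced NaN/Inf')

check_tuple_has_nan_or_inf('einsum', ([1.0, 2.0], [[0.5]], [[0.0]]))
```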
19 changes: 15 additions & 4 deletions paddle/fluid/operators/einsum_op.cc
@@ -41,6 +41,10 @@ class EinsumOpMaker : public framework::OpProtoAndCheckerMaker {
.AsExtra()
.AsIntermediate();

AddOutput("XShape", "(Tensor), The cache of the x_shape of: A and B.")
.AsDuplicable()
.AsExtra()
.AsIntermediate();
AddAttr<std::string>("equation",
"(string) A einsum equation. such as `ij,jk->ik`"
"There must have `->` and the number of operands in "
@@ -59,8 +63,8 @@ class EinsumGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
auto x_name = "Operands";
auto x_grad_name = framework::GradVarName(x_name);
ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim(x_name));
ctx->ShareAllLoD(x_name, x_grad_name);
ctx->SetOutputsDim(x_grad_name, ctx->GetInputsDim("Operands"));
ctx->ShareAllLoD("Operands", x_grad_name);
Review comment (Contributor): Why is x_name no longer used directly here?

}

protected:
@@ -79,8 +83,15 @@ class EinsumGradMaker : public framework::SingleGradOpMaker<T> {

void Apply(GradOpPtr<T> retv) const override {
retv->SetType("einsum_grad");
retv->SetInput("Operands", this->Input("Operands"));
retv->SetInput("InnerCache", this->Output("InnerCache"));
if (this->HasOutput("InnerCache")) {
retv->SetInput("InnerCache", this->Output("InnerCache"));
}
if (this->HasOutput("XShape")) {
// add if for compatibility.
retv->SetInput("Operands", this->Output("XShape")); // for memory save.
} else {
retv->SetInput("Operands", this->Input("Operands"));
}
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetAttrMap(this->Attrs());
retv->SetOutput(framework::GradVarName("Operands"),
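A note on the grad-maker branch above: when the forward op emits XShape, the grad op consumes it in place of the real Operands ("for memory save"), letting the forward inputs be freed early; the else branch keeps programs saved by v2.3.0 loadable. A minimal dygraph sanity check, assuming standard paddle.einsum usage (shapes taken from the unit test below; not part of this diff):

```python
import paddle

x = paddle.rand([10, 10, 20], dtype='float64')
y = paddle.rand([20, 6], dtype='float64')
x.stop_gradient = False
y.stop_gradient = False

# 'mij,jk->ki' sums over m and contracts j, producing a [6, 10] tensor
out = paddle.einsum('mij,jk->ki', x, y)
out.sum().backward()  # exercises einsum_grad
print(out.shape, x.grad.shape, y.grad.shape)
```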
9 changes: 8 additions & 1 deletion paddle/phi/infermeta/unary.cc
@@ -402,7 +402,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache) {
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape) {
// collect the following informations to prepare einsum.
LabelMap labelshape(0);
LabelMap labeltype(LabelType::Reduction);
@@ -439,6 +440,12 @@ void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape);
out->set_dims(make_ddim(output_dims));
out->set_dtype(inputs[0]->dtype());
for (size_t i = 0; i < xshape.size(); ++i) {
if (xshape[i] != nullptr) {
xshape[i]->set_dims(inputs[i]->dims());
xshape[i]->set_dtype(inputs[i]->dtype());
}
}
}

void ExpandInferMeta(const MetaTensor& x,
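What the new loop records, concretely: each non-null xshape[i] mirrors the i-th input's dims and dtype, so shape information survives to the backward pass even if the input tensors themselves are dropped. A rough Python rendering of that bookkeeping (names illustrative, not Paddle API):

```python
import numpy as np

# stand-ins for the operands of 'mij,jk->ki'
operands = [np.empty((10, 10, 20), np.float64), np.empty((20, 6), np.float64)]

# per-entry metadata EinsumInferMeta stores for XShape: dims + dtype, no data
xshape_meta = [{'dims': t.shape, 'dtype': t.dtype} for t in operands]
print(xshape_meta[0])  # {'dims': (10, 10, 20), 'dtype': dtype('float64')}
```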
3 changes: 2 additions & 1 deletion paddle/phi/infermeta/unary.h
@@ -83,7 +83,8 @@ void EighInferMeta(const MetaTensor& x,
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
std::vector<MetaTensor*> inner_cache);
std::vector<MetaTensor*> inner_cache,
std::vector<MetaTensor*> xshape);

void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
3 changes: 2 additions & 1 deletion paddle/phi/kernels/einsum_kernel.h
@@ -29,6 +29,7 @@ void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache);
std::vector<DenseTensor*> inner_cache,
std::vector<DenseTensor*> xshape);

} // namespace phi
1 change: 0 additions & 1 deletion paddle/phi/kernels/impl/einsum_grad_impl.h
@@ -177,7 +177,6 @@ void EinsumGradKernel(const Context& dev_ctx,
cache[0].ShareBufferWith(*(inner_cache[0]));
cache[1].ShareBufferWith(*(inner_cache[1]));
}

EinsumKernelImpl<T, Context>(dev_ctx,
all_labels,
operands_for_A,
12 changes: 9 additions & 3 deletions paddle/phi/kernels/impl/einsum_impl.h
@@ -459,7 +459,7 @@ DenseTensor PerformContraction(
}
// reduction
DenseTensor trans_t;
if (FLAGS_einsum_opt && use_cache && cache[operand_idx] != nullptr &&
if (use_cache && cache[operand_idx] != nullptr &&
cache[operand_idx]->IsInitialized()) {
trans_t.ShareBufferWith(*(cache[operand_idx]));
VLOG(5) << "Cache Used!";
@@ -468,7 +468,7 @@
dev_ctx, t, perm, all_labels, ellipsis, label2type);
trans_t = PerformTranspose<T, Context>(
dev_ctx, reduct_t, perm, reordered_all_labels, ellipsis, label2type);
if (FLAGS_einsum_opt && cache[operand_idx] != nullptr)
if (cache[operand_idx] != nullptr)
cache[operand_idx]->ShareBufferWith(trans_t);
}
auto mul_dims = GetShapeByType<int>(all_labels,
@@ -599,6 +611,11 @@ void EinsumKernelImpl(const Context& dev_ctx,
out);
// Reshape Procedure
} else if (inputs.size() == 1) {
if (cache[0] != nullptr) { // For compatibility, may be cache is nullptr if
// loading the program from v2.3.0
(*cache[0]) = *(inputs[0]); // ShareBuffer for backward, because backward
// we can only see cached tensor.
}
auto reduce_A = PerformReduction<T, Context>(dev_ctx,
*inputs[0],
label2perms[0],
@@ -627,7 +632,8 @@ void EinsumKernelRaw(const Context& dev_ctx,
const std::vector<const DenseTensor*>& inputs,
const std::string& equation,
DenseTensor* out,
std::vector<DenseTensor*> cache) {
std::vector<DenseTensor*> cache,
std::vector<DenseTensor*> xshape) {
std::vector<char> tmp;
// for the sake of compatibility, we may load and run v2.3 EinsumOp. Output
// may have nullptr and the cache.size() is not equal to inputs.size(). refer
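The hunks above remove the FLAGS_einsum_opt guard, so the forward kernel now always fills the inner cache that einsum_grad later reuses. A hedged Python sketch of that handoff — perform_reduction and perform_transpose are trivial stand-ins for the kernel helpers, not real APIs:

```python
def perform_reduction(t):
    return t  # stand-in for PerformReduction

def perform_transpose(t):
    return t  # stand-in for PerformTranspose

def contract_operand(t, cache_slot, use_cache):
    """Sketch of the now-unconditional cache path in PerformContraction."""
    if use_cache and cache_slot is not None and cache_slot.get('init'):
        return cache_slot['tensor']      # "Cache Used!": reuse the forward buffer
    trans_t = perform_transpose(perform_reduction(t))
    if cache_slot is not None:
        cache_slot['tensor'] = trans_t   # shared with the backward pass
        cache_slot['init'] = True
    return trans_t

cache = [{}]
contract_operand([1.0, 2.0], cache[0], use_cache=False)  # fills the cache
contract_operand([1.0, 2.0], cache[0], use_cache=True)   # hits it
```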
2 changes: 1 addition & 1 deletion paddle/phi/ops/compat/einsum_sig.cc
@@ -18,7 +18,7 @@ namespace phi {

KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache"});
"einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"});
}

KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
7 changes: 4 additions & 3 deletions python/paddle/fluid/tests/unittests/test_einsum_op.py
@@ -39,7 +39,9 @@ def setUp(self):
'Out':
out,
"InnerCache": [('cache_' + str(i), np.array([1.0]))
for i in range(len(self.operands))]
for i in range(len(self.operands))],
"XShape": [('xshape_' + str(i), np.array([1.0]))
for i in range(len(self.operands))],
}

def init_input(self):
@@ -48,14 +50,13 @@ def init_input(self):
self.inputs.append(np.random.random(s).astype(t))

def set_mandatory(self):
self.disable = False
self.shapes = [(10, 10, 20), (20, 6)]
self.types = [np.float64, np.float64]
self.equation = "mij,jk->ki"

def test_check_output(self):
if not self.disable:
self.check_output(no_check_set=["InnerCache"])
self.check_output(no_check_set=["InnerCache", "XShape"])

def test_grad(self):
if not self.disable:
11 changes: 8 additions & 3 deletions python/paddle/tensor/einsum.py
@@ -807,9 +807,9 @@ def gen_einsum_op(equation, *operands):

if _in_legacy_dygraph():
# dygraph
return _C_ops.einsum(operands, len(operands), 'equation', equation)[0]
return _C_ops.einsum(operands, len(operands), len(operands), 'equation',
equation)[0]

# static graph
for inp in operands:
check_variable_and_dtype(inp, 'dtype', ['float32', 'float64'], 'einsum')
check_type(equation, 'equation', str, 'einsum')
@@ -821,11 +821,16 @@ def gen_einsum_op(equation, *operands):
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
xshape = [
helper.create_variable_for_type_inference(dtype=operands[0].dtype)
for i in range(len(operands))
]
helper.append_op(type='einsum',
inputs={'Operands': operands},
outputs={
'Out': out,
"InnerCache": caches
"InnerCache": caches,
"XShape": xshape
},
attrs=attrs)
return out
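The static-graph branch above now creates one XShape variable per operand next to the InnerCache ones. A hedged end-to-end check of that path using the standard paddle.static API (not part of this diff):

```python
import numpy as np
import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    a = paddle.static.data('a', [10, 10, 20], 'float64')
    b = paddle.static.data('b', [20, 6], 'float64')
    out = paddle.einsum('mij,jk->ki', a, b)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
res, = exe.run(main,
               feed={'a': np.random.rand(10, 10, 20),
                     'b': np.random.rand(20, 6)},
               fetch_list=[out])
print(res.shape)  # (6, 10)
```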
2 changes: 1 addition & 1 deletion python/paddle/utils/code_gen/api.yaml
@@ -602,7 +602,7 @@

- api : einsum
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
infer_meta :
func : EinsumInferMeta
param : [x, equation]
17 changes: 14 additions & 3 deletions python/paddle/utils/code_gen/backward.yaml
@@ -1,3 +1,14 @@
#- backward_api : einsum_grad

#forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
#args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
#output : Tensor[](x_grad){x.size()}
#infer_meta :
#func : UnchangedMultiInferMeta
#param : [x]
#kernel :
#func : einsum_grad
Review comment (Contributor): Please remove this block of commented-out code.


- backward_api : abs_double_grad
forward : abs_grad (Tensor x, Tensor grad_out) -> Tensor(grad_x)
args : (Tensor x, Tensor grad_x_grad)
@@ -611,12 +622,12 @@
skip_transform : out_w, out_w_grad

- backward_api : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache)
args : (Tensor[] x, Tensor[] inner_cache, Tensor out_grad, str equation)
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [x]
param : [x_shape]
kernel :
func : einsum_grad
