Skip to content

Commit

Permalink
fix custom operator backward=None (PaddlePaddle#48656)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahy0825 committed Dec 20, 2022
1 parent 753fdcc commit 9fafa06
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions paddle/fluid/eager/custom_operator/custom_operator_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,18 +217,20 @@ RunCustomOpNode::operator()(
VLOG(6) << "Prepare Grad outputs for size: " << grad_outputs_names.size();
for (size_t i = 0; i < OutputMeta().size(); i++) {
if (map[0][0].find(i) != map[0][0].end()) {
int grad_output_idx = map[0][0][i];
VLOG(7) << "Insert grad outputs: " << i
<< " with size: " << OutputMeta()[i].size()
<< " to tmp_outputs: " << map[0][0][i];
for (size_t j = 0; j < OutputMeta()[i].size(); j++) {
outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
egr::EagerUtils::autograd_meta(&(outs[i][j]));
<< " with size: " << OutputMeta()[grad_output_idx].size()
<< " to tmp_outputs: " << grad_output_idx;
for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
outs[grad_output_idx]
.emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
egr::EagerUtils::autograd_meta(&(outs[grad_output_idx][j]));
}
tmp_outs[map[0][0][i]] = outs[i];
tmp_outs[grad_output_idx] = outs[grad_output_idx];
}
}
for (size_t i = 0; i < tmp_outs.size(); i++) {
Expand Down

0 comments on commit 9fafa06

Please sign in to comment.