Skip to content

Commit

Permalink
Forbid control flow related ops to constant folding (#62206)
Browse files Browse the repository at this point in the history
* forbid control flow ops to constant folding

* refine
  • Loading branch information
yuanlehome authored Feb 29, 2024
1 parent 08d2b79 commit 7d84d55
Showing 1 changed file with 39 additions and 3 deletions.
42 changes: 39 additions & 3 deletions paddle/fluid/framework/ir/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/constant_folding_pass.h"

#include <string>
#include <vector>
#include "glog/logging.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

#include "paddle/fluid/framework/convert_utils.h"

namespace paddle {
namespace framework {
namespace ir {
Expand All @@ -51,6 +53,37 @@ struct ConstantFolding : public PatternBase {
};
} // namespace patterns

namespace {
std::unordered_set<std::string> GetControlFlowVarNames(ir::Graph *graph) {
std::unordered_set<std::string> control_flow_ops{"while",
"conditional_block"};
std::unordered_set<std::string> control_flow_var_names;
for (auto *node : graph->Nodes()) {
if (!node->IsOp() || control_flow_ops.count(node->Op()->Type()) == 0)
continue;
for (auto const &in_names : node->Op()->Inputs()) {
auto var_names = in_names.second;
control_flow_var_names.insert(var_names.begin(), var_names.end());
}
for (auto const &out_names : node->Op()->Outputs()) {
auto var_names = out_names.second;
control_flow_var_names.insert(var_names.begin(), var_names.end());
}
}
return control_flow_var_names;
}

bool OutputUsedByControlFlow(ir::Node *node,
const std::unordered_set<std::string> &cf_vars) {
for (auto out_node : node->outputs) {
if (cf_vars.count(out_node->Name())) {
return true;
}
}
return false;
}
} // namespace

ConstantFoldingPass::ConstantFoldingPass() = default;

void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
Expand All @@ -69,6 +102,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
"save",
"quantize_linear",
"dequantize_linear"};
const auto cf_vars = GetControlFlowVarNames(graph);
int folded_op_num = 0;

auto op_node_sorted = framework::ir::TopologyVariantSort(
Expand All @@ -78,7 +112,9 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
if (std::find(blacklist.begin(), blacklist.end(), op_node->Name()) !=
blacklist.end())
continue;

if (OutputUsedByControlFlow(op_node, cf_vars)) {
continue;
}
bool input_persis = true;
// map is used to record how many time a name string occurs in the whole
// graph's nodes
Expand Down

0 comments on commit 7d84d55

Please sign in to comment.