Optimize KahnsTopologicalSort and PriorityNodeCompare #19475

Merged
merged 10 commits into from Feb 16, 2024

Changes from 9 commits
37 changes: 29 additions & 8 deletions onnxruntime/core/graph/graph.cc
@@ -1818,16 +1818,36 @@
   }
 }

+template <typename T>
+struct VisitorPriorityQueue {
+  using ComparatorType = std::function<bool(T, T)>;
+  std::list<T> list_;
+  const ComparatorType comparator_ = nullptr;
+  VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {}

[GitHub Actions / Lint C++, cpplint warning on line 1826: Single-parameter constructors should be marked explicit. [runtime/explicit] [5]]

+
+  void push(T node) {
+    list_.insert(
+        std::upper_bound(list_.begin(), list_.end(), node, comparator_),
+        node);
+  }
+  bool empty() { return list_.empty(); }
+  T top() { return list_.back(); }
+  void pop() { list_.pop_back(); }
+};
+
 #if !defined(ORT_MINIMAL_BUILD)
 void Graph::KahnsTopologicalSort(const std::function<void(const Node*)>& enter,
                                  const std::function<bool(const Node*, const Node*)>& comp) const {
-  std::unordered_map<NodeIndex, size_t> in_degree;
-  std::priority_queue<const Node*, std::vector<const Node*>, decltype(comp)> to_visit(comp);
-  std::vector<NodeIndex> topo_order;
+  InlinedVector<size_t> in_degree(MaxNodeIndex(), 0);
+  InlinedVector<NodeIndex> topo_order;
+  VisitorPriorityQueue<const Node*> to_visit(comp);
+
+  auto number_of_nodes = NumberOfNodes();
+  topo_order.reserve(number_of_nodes);

   for (auto& node : Nodes()) {
     size_t input_edge_count = node.GetInputEdgesCount();
-    in_degree.insert({node.Index(), input_edge_count});
+    in_degree[node.Index()] = input_edge_count;
     if (input_edge_count == 0) {
       to_visit.push(&node);
     }
@@ -1844,16 +1864,17 @@
     }

     for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) {
-      in_degree[node_it->Index()]--;
+      auto& node_in_degree = in_degree[node_it->Index()];
+      node_in_degree--;

-      if (in_degree[node_it->Index()] == 0) {
+      if (node_in_degree == 0) {
         to_visit.push(&*node_it);
       }
     }
     topo_order.push_back(current->Index());
   }

-  if (NumberOfNodes() != static_cast<int>(topo_order.size())) {
+  if (number_of_nodes != static_cast<int>(topo_order.size())) {
     ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle.");
   }
 }
@@ -2843,7 +2864,7 @@

   const gsl::not_null<TensorProto*> tensor_added{graph_proto_->add_initializer()};
   *(tensor_added) = tensor;
-  name_to_initial_tensor_[tensor.name()] = tensor_added;
+  name_to_initial_tensor_.emplace(tensor.name(), tensor_added);
   SetGraphResolveNeeded();
   if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) {
     // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs.
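Taken together, this diff replaces a hash map keyed by NodeIndex with a flat vector indexed by it, and std::priority_queue (a binary heap over std::vector) with a list kept sorted on insertion. A minimal standalone sketch of the pattern outside of ONNX Runtime follows; VisitorPriorityQueue matches the struct added above (with the constructor marked explicit, as the cpplint annotation requests), while the int instantiation, toy DAG, and driver are illustrative stand-ins, not ORT code:

    #include <algorithm>
    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <list>
    #include <vector>

    template <typename T>
    struct VisitorPriorityQueue {
      using ComparatorType = std::function<bool(T, T)>;
      std::list<T> list_;
      const ComparatorType comparator_ = nullptr;
      explicit VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {}

      // Keep list_ sorted ascending under comparator_, so top()/pop() at the
      // back yield the same element a std::priority_queue would.
      void push(T node) {
        list_.insert(std::upper_bound(list_.begin(), list_.end(), node, comparator_), node);
      }
      bool empty() const { return list_.empty(); }
      T top() { return list_.back(); }
      void pop() { list_.pop_back(); }
    };

    int main() {
      // Tiny DAG: 0 -> 1, 0 -> 2, 1 -> 3, 2 -> 3.
      std::vector<std::vector<size_t>> out_edges = {{1, 2}, {3}, {3}, {}};
      std::vector<size_t> in_degree(out_edges.size(), 0);
      for (const auto& edges : out_edges)
        for (size_t dst : edges) ++in_degree[dst];

      // "Return a > b" means the smaller index is output first, mirroring
      // the convention PriorityNodeCompare uses for Node indices.
      VisitorPriorityQueue<size_t> to_visit([](size_t a, size_t b) { return a > b; });
      for (size_t i = 0; i < in_degree.size(); ++i)
        if (in_degree[i] == 0) to_visit.push(i);

      // Kahn's algorithm: repeatedly emit a ready node and release its edges.
      while (!to_visit.empty()) {
        size_t current = to_visit.top();
        to_visit.pop();
        std::cout << current << ' ';  // prints: 0 1 2 3
        for (size_t dst : out_edges[current])
          if (--in_degree[dst] == 0) to_visit.push(dst);
      }
      std::cout << '\n';
    }

Note the trade being made: std::upper_bound over a std::list still walks linearly, so push costs O(n) comparisons while pop is O(1); for the small ready-frontiers typical of this sort that can beat heap maintenance. upper_bound also places a new node after existing equal nodes, so ties among equal-priority nodes resolve in a well-defined order rather than the unspecified order of a binary heap.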
18 changes: 12 additions & 6 deletions onnxruntime/core/graph/graph_viewer.cc
@@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const {
 struct PriorityNodeCompare {
   inline bool IsHighPri(const Node* n) const {
     // local statics so we can compare std::strings in the checks
-    static const std::string shape_op("Shape");
-    static const std::string size_op("Size");
+    static constexpr std::string_view shape_op("Shape");
+    static constexpr std::string_view size_op("Size");

     const auto& op_type = n->OpType();
     return op_type == shape_op || op_type == size_op;
@@ -26,15 +26,20 @@ struct PriorityNodeCompare {
   // If return true, n2 will be output first
   bool operator()(const Node* n1, const Node* n2) const {
     // nodes in global high priority list will be output first
-    if (IsHighPri(n1) != IsHighPri(n2)) {
-      return IsHighPri(n2);
+    const bool isN1HighPri = IsHighPri(n1);
+    const bool isN2HighPri = IsHighPri(n2);
+    if (isN1HighPri != isN2HighPri) {
+      return isN2HighPri;
     }

     // nodes with lower priority value will be output first
-    if (n1->Priority() != n2->Priority()) {
-      return n1->Priority() > n2->Priority();
+    const auto n1_priority = n1->Priority();
+    const auto n2_priority = n2->Priority();
+    if (n1_priority != n2_priority) {
+      return n1_priority > n2_priority;
     }

+#ifdef ENABLE_TRAINING
     // nodes of forward pass will be output first
     auto n1_attrs = n1->GetAttributes();
     auto n2_attrs = n2->GetAttributes();
@@ -45,6 +50,7 @@
     if (n1_is_forward != n2_is_forward) {
       return n2_is_forward > n1_is_forward;
     }
+#endif

     // otherwise, nodes with lower index will be output first
     return n1->Index() > n2->Index();
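The queue's contract makes this comparator read backwards at first glance: top() returns the element that sorts last, so returning true means n2 is scheduled first. A hedged standalone sketch of the tie-break chain, where FakeNode and its fields are stand-ins for the Node accessors and the ENABLE_TRAINING tier is omitted:

    #include <cassert>
    #include <cstddef>
    #include <string_view>

    // Hypothetical stand-in for onnxruntime::Node, for illustration only.
    struct FakeNode {
      std::string_view op_type;
      int priority;   // lower value => scheduled earlier
      size_t index;   // lower index => scheduled earlier on ties
    };

    // Mirrors PriorityNodeCompare: returning true means n2 is output first.
    struct FakePriorityNodeCompare {
      static bool IsHighPri(const FakeNode* n) {
        return n->op_type == "Shape" || n->op_type == "Size";
      }
      bool operator()(const FakeNode* n1, const FakeNode* n2) const {
        const bool is_n1_high_pri = IsHighPri(n1);
        const bool is_n2_high_pri = IsHighPri(n2);
        if (is_n1_high_pri != is_n2_high_pri) return is_n2_high_pri;  // Shape/Size first
        if (n1->priority != n2->priority) return n1->priority > n2->priority;
        return n1->index > n2->index;  // lower node index breaks remaining ties
      }
    };

    int main() {
      FakeNode shape{"Shape", 0, 5};
      FakeNode add{"Add", 0, 1};
      FakePriorityNodeCompare cmp;
      assert(cmp(&add, &shape));   // true: the Shape node is output first
      assert(!cmp(&shape, &add));
    }

Caching the IsHighPri and Priority() results in locals, as the diff does, simply avoids calling each accessor twice per comparison, and the static constexpr std::string_view constants drop the thread-safe initialization guard that the function-local static std::string objects required.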
73 changes: 43 additions & 30 deletions onnxruntime/core/optimizer/noop_elimination.cc
@@ -42,49 +42,62 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con

   // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule,
   // but it won't happen if the case is accepted, thus reject it
-  auto initializer_rank = initializer->dims().size();
+  const auto& dims = initializer->dims();
+  auto initializer_rank = dims.size();
   const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 1 : 0]->Shape();
   if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) {
     return false;
   }

-  int32_t data_type = initializer->data_type();
-  Initializer add_init(*initializer, graph.ModelPath());
-  if (add_init.size() > 1) {
+  int64_t tensor_size = 1;
+  for (auto i : dims) {
+    tensor_size *= i;
+  }
+
+  if (tensor_size > 1) {
     return false;
   }

   // handle edge case where the total size of the initializer is 0
-  if (add_init.size() == 0) {
+  if (tensor_size == 0) {
     return true;
   }

-  float value = 0.0f;
-  switch (data_type) {
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
-      value = *add_init.data<float>();
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
-      value = math::halfToFloat(add_init.data<MLFloat16>()->val);
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
-      value = static_cast<float>(*add_init.data<double>());
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_INT32:
-      value = static_cast<float>(*add_init.data<int32_t>());
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_INT64:
-      value = static_cast<float>(*add_init.data<int64_t>());
-      break;
-    default:
-      return false;
-  }
+  if (op_type == "Add" ||
+      op_type == "Sub" ||
+      op_type == "Mul" ||
+      op_type == "Div") {
+    int32_t data_type = initializer->data_type();
+    Initializer add_init(*initializer, graph.ModelPath());

-  if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) {
-    return false;
-  }
+    float value = 0.0f;
+    switch (data_type) {
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+        value = *add_init.data<float>();
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+        value = math::halfToFloat(add_init.data<MLFloat16>()->val);
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
+        value = static_cast<float>(*add_init.data<double>());
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
+        value = static_cast<float>(*add_init.data<int32_t>());
+        break;
+      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
+        value = static_cast<float>(*add_init.data<int64_t>());
+        break;
+      default:
+        return false;
+    }

-  if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) {
-    return false;
+    if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) {
+      return false;
+    }
+
+    if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) {
+      return false;
+    }
   }

   // reject node output is graph output for now
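The size check now multiplies the dims recorded in the TensorProto instead of constructing an Initializer, which unpacks the tensor data, just to call size(); the value decoding is then gated behind the Add/Sub/Mul/Div check so other op types bail out before touching the data at all. A minimal sketch of the dims arithmetic, where ElementCount is a hypothetical helper rather than ORT code:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Element count is the product of the dims: a scalar has no dims and an
    // empty product of 1, and any zero dim makes the whole tensor empty.
    int64_t ElementCount(const std::vector<int64_t>& dims) {
      int64_t tensor_size = 1;
      for (auto d : dims) {
        tensor_size *= d;
      }
      return tensor_size;
    }

    int main() {
      std::cout << ElementCount({}) << '\n';         // 1: scalar, eligible
      std::cout << ElementCount({1, 1, 1}) << '\n';  // 1: single element, eligible
      std::cout << ElementCount({2, 3}) << '\n';     // 6: rejected by the > 1 check
      std::cout << ElementCount({2, 0}) << '\n';     // 0: the empty-tensor edge case
    }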
@@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef {
   const auto& graph_outputs = graph_.GetOutputs();
   graph_outputs_.reserve(graph_outputs.size());
   for (const auto* output : graph_outputs) {
-    graph_outputs_.insert(output->Name());
+    graph_outputs_.emplace(output->Name());
   }
 }
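Both emplace changes in this PR lean on the same standard-library behavior: map operator[] default-constructs the mapped value and then assigns over it, overwriting any existing entry, while emplace constructs the element in place and leaves an existing entry alone. A small hedged demonstration with types unrelated to ORT:

    #include <iostream>
    #include <string>
    #include <unordered_map>

    int main() {
      std::unordered_map<std::string, const int*> m;
      int v = 42;

      // operator[] default-constructs the mapped pointer (nullptr), then
      // assigns; calling it again would silently overwrite.
      m["weights"] = &v;

      // emplace builds the pair in place and reports via the bool whether an
      // insertion actually happened; an existing key is left untouched.
      auto [it, inserted] = m.emplace("weights", nullptr);
      std::cout << std::boolalpha << inserted << '\n';  // false
      std::cout << (m["weights"] == &v) << '\n';        // true: not overwritten
    }

For the unordered_set of graph output names, insert and emplace both copy the string here, so that change is mostly a matter of consistent call style; the map change in AddInitializedTensor is the one that skips the default-construct-then-assign round trip.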