Use PadAndUnflatten to replace GatherGrad for restore (microsoft#16429)
pengwa committed Jun 27, 2023
1 parent ae6da03 commit 403bebf
Showing 16 changed files with 491 additions and 199 deletions.
12 changes: 6 additions & 6 deletions orttraining/orttraining/core/graph/gradient_builder.cc
@@ -752,14 +752,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
                              SrcNodeAttributes())};
 }
 
-IMPLEMENT_GRADIENT_BUILDER(GetGatherGradGradient) {
-  // TODO: Strictly speaking, GatherGrad's gradient is not alway Gather when the indices have repeated values.
-  // Since GatherGrad in foward path is only used by embed sparsity feature in which case the indices are unique,
-  // we can safely use Gather here. But we will adress this issue as soon as possible.
+IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
   return std::vector<NodeDef>{
+      NodeDef(OpDef("Reshape"),
+              {GO(0), O(1)},
+              {IA("GO_reshaped")}),
       NodeDef(OpDef{"Gather", kOnnxDomain, 1},
-              {GO(0), I(1)},
-              {GI(2)},
+              {IA("GO_reshaped"), I(1)},
+              {GI(0)},
               SrcNodeAttributes())};
 }

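The new gradient reads naturally once PadAndUnflatten is viewed as a scatter: the forward op copies row `i` of its input to row `indices[i]` of a zero-filled `[M1, M2, ...]` output, so the backward pass only needs to flatten the incoming gradient back to `[M1*M2, ...]` (the `Reshape` of `GO(0)` using the op's second output, the flattened shape) and gather the rows at `indices` (the `Gather` producing `GI(0)`). Below is a minimal sketch of that backward computation for 2-D data, a hypothetical helper rather than the ORT kernel, assuming unique indices as the schema added later in this commit requires:

```cpp
#include <cstdint>
#include <vector>

// Backward of PadAndUnflatten, mirroring the Reshape + Gather pair above:
// dY of shape [M1, M2, d2] is viewed flat as [M1*M2, d2] (the Reshape step),
// then dX[i, :] = dY_flat[indices[i], :] (the Gather step).
// Illustrative sketch only, not ONNX Runtime code.
std::vector<float> PadAndUnflattenGrad(const std::vector<float>& dy_flat,    // [M1*M2 * d2], row-major
                                       const std::vector<int64_t>& indices,  // [d1], unique values
                                       int64_t d2) {
  std::vector<float> dx(indices.size() * d2);
  for (size_t i = 0; i < indices.size(); ++i) {
    for (int64_t j = 0; j < d2; ++j) {
      dx[i * d2 + j] = dy_flat[indices[i] * d2 + j];
    }
  }
  return dx;  // shape [d1, d2]
}
```

The uniqueness requirement is what makes a plain Gather sufficient here: with repeated indices the forward scatter would overwrite rows and the true gradient would need a sum, which is exactly the caveat the removed TODO comment described.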
2 changes: 1 addition & 1 deletion orttraining/orttraining/core/graph/gradient_builder.h
@@ -39,7 +39,7 @@ DECLARE_GRADIENT_BUILDER(GetPoolGradient)
 DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
 DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
 DECLARE_GRADIENT_BUILDER(GetGatherGradient)
-DECLARE_GRADIENT_BUILDER(GetGatherGradGradient)
+DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
 DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
 DECLARE_GRADIENT_BUILDER(GetConvGradient)
 DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
@@ -70,7 +70,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
   REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
   REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
-  REGISTER_GRADIENT_BUILDER("GatherGrad", GetGatherGradGradient);
+  REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
   REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
   REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
   REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
39 changes: 39 additions & 0 deletions orttraining/orttraining/core/graph/training_op_defs.cc
@@ -4571,6 +4571,45 @@ Return true if all elements are true and false otherwise.
         updateOutputShape(ctx, 6, {num_directions, three * hidden_size});
       }
     });
+
+  ONNX_CONTRIB_OPERATOR_SCHEMA(PadAndUnflatten)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .SetDoc(
+          "PadAndUnflatten operator pads zeros on the first axis, and unflattens that axis into two axes according"
+          " to the given unflatten_dims. This is used by the padding elimination graph transformer."
+          " For each index in indices, the corresponding value in output comes from input."
+          " For all other indices, the corresponding value in output is zero."
+
+          " Duplicated values in indices are not allowed: there is no runtime check (to avoid the performance cost),"
+          " and the output is undefined if duplicates are present."
+
+          " An example:"
+          "   input: [[1, 2, 3, 4], [5, 6, 7, 8]], shape is [2, 4]"
+          "   indices: [0, 5], shape is [2]"
+          "   unflatten_dims: [2, 3], shape is [2]"
+
+          "   output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]],"
+          "   shape is [2, 3, 4]"
+          "   flatten_output_shape: [6, 4], shape is [2]")
+      .Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T")
+      .Input(1, "indices", "1D tensor of int32/int64 indices, shape is [d1]; each value is in the range [0, M1*M2).",
+             "T_INDEX")
+      .Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT")
+      .Output(0, "output", "output data of rank N+1, shape is [M1, M2, d2, ..., dN]", "T")
+      .Output(1, "flatten_output_shape", "1D tensor holding the flattened output shape, [M1*M2, d2, ..., dN]", "T_INT")
+      .TypeConstraint(
+          "T_INT",
+          {"tensor(int32)", "tensor(int64)"},
+          "Constrain shape to integer tensors.")
+      .TypeConstraint(
+          "T",
+          {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
+          "Constrain input and output types to float tensors.")
+      .TypeConstraint(
+          "T_INDEX",
+          {"tensor(int32)", "tensor(int64)"},
+          "Constrain indices to integer types.");
 }
 
 }  // namespace training
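To make the documented example concrete, here is a minimal reference sketch of the forward semantics for a 2-D input, under the schema's constraints (unique indices, zero fill); the function is illustrative only, not the ORT kernel:

```cpp
#include <cstdint>
#include <vector>

// Forward PadAndUnflatten for input of shape [d1, d2]: scatter row i of input
// to row indices[i] of a zero-initialized buffer of shape [M1*M2, d2], which
// is then viewed as [M1, M2, d2]. Illustrative sketch only.
std::vector<float> PadAndUnflatten(const std::vector<float>& input,      // [d1 * d2], row-major
                                   const std::vector<int64_t>& indices,  // [d1], unique, in [0, m1*m2)
                                   int64_t m1, int64_t m2, int64_t d2) {
  std::vector<float> output(m1 * m2 * d2, 0.0f);  // zero padding everywhere
  for (size_t i = 0; i < indices.size(); ++i) {
    for (int64_t j = 0; j < d2; ++j) {
      output[indices[i] * d2 + j] = input[i * d2 + j];
    }
  }
  return output;  // flatten_output_shape would be {m1 * m2, d2}
}
```

With `input = {1, 2, 3, 4, 5, 6, 7, 8}`, `indices = {0, 5}`, `m1 = 2`, `m2 = 3`, and `d2 = 4`, rows 0 and 5 of the flat `[6, 4]` buffer are filled and everything else stays zero, reproducing the example in the doc string.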
@@ -61,34 +61,6 @@ NodeArg* GetDimsValue(Graph& graph, NodeArg* input, NodeArg* indices_arg, Node&
   return gather_out_args[0];
 }
 
-// Insert Shape + ScatterElements to get an updated shape of input with index of indices_arg updated to
-// the value of update_value.
-// Such as, if the indices_arg is a initializer of [0] and the original shape of input is [valid_token_count, a, b, c],
-// this function will return a shape of [update_value, a, b, c]
-NodeArg* UpdateShape(Graph& graph, NodeArg* input, NodeArg* update_value, NodeArg* indices_arg, Node& node) {
-  InlinedVector<NodeArg*> shape_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("shape_result"),
-                                                                      nullptr)};
-  Node& shape_node = graph.AddNode(graph.GenerateNodeName("shape"), "Shape", "", {input},
-                                   shape_output_args, nullptr, kOnnxDomain);
-  ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(shape_node), "Failed to get shape for " + shape_node.Name());
-  shape_node.SetExecutionProviderType(node.GetExecutionProviderType());
-
-  InlinedVector<NodeArg*> scatter_input_args;
-  scatter_input_args.push_back(shape_output_args[0]);
-  scatter_input_args.push_back(indices_arg);
-  scatter_input_args.push_back(update_value);
-
-  InlinedVector<NodeArg*> scatter_out_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("scatter_result"),
-                                                                     nullptr)};
-
-  Node& scatter_node = graph.AddNode(graph.GenerateNodeName("update_dim"), "ScatterElements", "", scatter_input_args,
-                                     scatter_out_args, nullptr, kOnnxDomain);
-  ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(scatter_node), "Failed to update shape for " + scatter_node.Name());
-  scatter_node.SetExecutionProviderType(node.GetExecutionProviderType());
-
-  return scatter_out_args[0];
-}
-
 // Insert Reshape + ShrunkenGather to flatten the in_index-th input of node.
 // The gather_index_arg is the indices of the elements that are not padding.
 NodeArg* InsertNodesForInput(Graph& graph,
@@ -125,7 +97,8 @@ NodeArg* InsertNodesForInput(Graph& graph,
 
   InlinedVector<NodeArg*> reshape_output_args;
   reshape_output_args.push_back(
-      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"), node.MutableInputDefs()[in_index]->TypeAsProto()));
+      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
+                                node.MutableInputDefs()[in_index]->TypeAsProto()));
 
   Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
       graph, node,
@@ -175,108 +148,50 @@ NodeArg* InsertNodesForInput(Graph& graph,
   return gather_out_arg;
 }
 
-// Insert GatherGrad + Reshape to unflatten the shape of the in_index-th input of node.
+// Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
 // The gathergrad_index_arg is the indices of the elements that are not padding.
-// The new_shape_arg is the shape of [batch_size * seqlen, ...]
-// gathergrad_index_arg and new_shape_arg are the arguments needed by GatherGrad.
 NodeArg* InsertNodesForOutput(Graph& graph,
                               Node& node,
                               uint32_t in_index,
                               NodeArg* gathergrad_index_arg,
-                              NodeArg* new_shape_arg,
                               NodeArg* first_two_dims_arg,
                               const logging::Logger& logger) {
-  std::vector<int64_t> other_indices;
-  auto input_shape = node.InputDefs()[in_index]->Shape();
-  for (int k = 2; k < input_shape->dim_size(); k++) {
-    // When executing, Shape of node here has been flattened, so the indices should be k-1.
-    other_indices.push_back(int64_t(k) - 1);
-  }
-
-  // Construct the unflattened_shape_arg of [batch_size, seqlen, ...]
-  NodeArg* unflattened_shape_arg = nullptr;
-  if (other_indices.empty()) {
-    unflattened_shape_arg = first_two_dims_arg;
-  } else {
-    // If the shape size of the in_index-th input of node is larger than 2 dims, we need to concat the first two dims
-    // of [batch_size, seqlen] and the other dims together.
-    ONNX_NAMESPACE::TensorProto other_indices_const_tensor;
-    other_indices_const_tensor.set_name(graph.GenerateNodeArgName("other_shape"));
-    other_indices_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-    other_indices_const_tensor.add_dims(other_indices.size());
-    other_indices_const_tensor.set_raw_data(other_indices.data(), other_indices.size() * sizeof(int64_t));
-    NodeArg* other_indices_arg = &graph_utils::AddInitializer(graph, other_indices_const_tensor);
-    NodeArg* other_dims_arg = GetDimsValue(graph, node.MutableInputDefs()[in_index], other_indices_arg, node);
-
-    InlinedVector<NodeArg*> concat_input_args;
-    concat_input_args.push_back(first_two_dims_arg);
-    concat_input_args.push_back(other_dims_arg);
-
-    InlinedVector<NodeArg*> concat_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("concat_shape_result"),
-                                                                         nullptr)};
-
-    onnxruntime::NodeAttributes attributes;
-    attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", int64_t(0));
-
-    Node& concat_node = graph.AddNode(graph.GenerateNodeName("concat_shape"), "Concat", "", concat_input_args,
-                                      concat_output_args, &attributes, kOnnxDomain);
-    ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(concat_node), "Failed to concat shape for " + concat_node.Name());
-    concat_node.SetExecutionProviderType(node.GetExecutionProviderType());
-    unflattened_shape_arg = concat_output_args[0];
-  }
-
-  InlinedVector<NodeArg*> gathergrad_input_args;
-  gathergrad_input_args.reserve(3);
-  gathergrad_input_args.push_back(new_shape_arg);
-  gathergrad_input_args.push_back(gathergrad_index_arg);
-  gathergrad_input_args.push_back(node.MutableInputDefs()[in_index]);
-
-  InlinedVector<NodeArg*> gathergrad_output_args;
-  gathergrad_output_args.push_back(
-      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_recover_result"),
-                                nullptr));
+  InlinedVector<NodeArg*> pad_node_input_args;
+  pad_node_input_args.reserve(3);
+  pad_node_input_args.push_back(node.MutableInputDefs()[in_index]);
+  pad_node_input_args.push_back(gathergrad_index_arg);
+  pad_node_input_args.push_back(first_two_dims_arg);
+
+  InlinedVector<NodeArg*> pad_node_output_args;
+  pad_node_output_args.push_back(
+      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
+                                nullptr));
+  pad_node_output_args.push_back(
+      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
+                                nullptr));
 
   Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
       graph, node,
       in_index,
-      2,
-      0,
+      0 /* new_node_input_index*/,
+      0 /* new_node_output_index*/,
       graph.GenerateNodeName("PaddingRecover"),
-      "GatherGrad",
-      "GatherGrad node to recover invalid tokens.",
-      gathergrad_input_args,
-      gathergrad_output_args,
+      "PadAndUnflatten",
+      "PadAndUnflatten node to recover invalid tokens.",
+      pad_node_input_args,
+      pad_node_output_args,
       {},
       kMSDomain,
       logger);
 
   new_gathergrad_node->SetExecutionProviderType(node.GetExecutionProviderType());
-  auto gathergrad_out_arg = new_gathergrad_node->MutableOutputDefs()[0];
-
-  InlinedVector<NodeArg*> reshape_input_args;
-  reshape_input_args.push_back(gathergrad_out_arg);
-  reshape_input_args.push_back(unflattened_shape_arg);
-  InlinedVector<NodeArg*> reshape_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("reshape_result"),
-                                                                        nullptr)};
-  Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
-      graph, node,
-      in_index,
-      0,
-      0,
-      graph.GenerateNodeName("RecoverShape"),
-      "Reshape",
-      "Reshape node to recover invalid tokens.",
-      reshape_input_args,
-      reshape_output_args,
-      {},
-      kOnnxDomain,
-      logger);
-  new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
-  return new_reshape_node->MutableOutputDefs()[0];
+  return new_gathergrad_node->MutableOutputDefs()[0];
 }
 
 // Iterate the subgraph beginning from the start_node, and put all node args into 'subgraph'
-// Also put all candidate input nodes and cantidate output nodes of the subgraph into candidate_inputs and
+// Also put all candidate input nodes and candidate output nodes of the subgraph into candidate_inputs and
 // candidate_outputs respectively.
 void IterateSubgraphFromNode(Graph& graph,
                              Node* start_node,
@@ -368,25 +283,25 @@ void IterateSubgraphFromNode(Graph& graph,
     } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13})) {
       if (subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()) {
         // If shape of [batch_size, seqlen, ...] is propagated from the first argument of MatMul.
-        // The dim size of the first argument must larger than 2 to propagete the first two dims to the output.
+        // The dim size of the first argument must be larger than 2 to propagate the first two dims to the output.
         // Or else the first two dims of the output will not be [batch_size, seqlen] and this MatMul will be added
         // to candidate_outputs as the output of the subgraph.
         if (cur->InputDefs()[0]->Shape()->dim_size() > 2) {
           subgraph.insert(cur->MutableOutputDefs()[0]);
           PushAllOutputNode(graph, to_visit, cur, visited);
         } else {
           LOG_DEBUG_INFO(logger,
-                         "PaddingElimination::dim size of left input of matmul smaller than 3 and \
-                         this matmul would be output of subgraph.");
+                         "PaddingElimination::dim size of left input of MatMul is smaller than 3, so \
+                         this MatMul would be the output of the subgraph.");
           candidate_outputs.insert(cur);
           continue;
         }
       } else if (subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end()) {
-        LOG_DEBUG_INFO(logger, "PaddingElimination::right edge of matmul would not included.");
+        LOG_DEBUG_INFO(logger, "PaddingElimination::right edge of MatMul would not be included.");
         candidate_outputs.insert(cur);
         continue;
       } else {
-        ORT_THROW("PaddingElimination::found matmul node without input in subgraph.");
+        ORT_THROW("PaddingElimination::found MatMul node without input in subgraph.");
       }
     } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "PythonOp", {1}, kMSDomain)) {
       if (subgraph.find(cur->MutableInputDefs()[0]) == subgraph.end()) {
@@ -451,7 +366,8 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
                     ") due to embedding input is not in the sparse embedding input list.");
       continue;
     }
-    const ONNX_NAMESPACE::TensorProto* padding_initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
+    const ONNX_NAMESPACE::TensorProto* padding_initializer =
+        graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
     if (padding_initializer != nullptr &&
         padding_initializer->dims_size() == 0 &&
         ((padding_initializer->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) ||
@@ -526,18 +442,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
       }
     }
 
-  std::vector<int64_t> first_indices;
-  first_indices.push_back(0);
-  ONNX_NAMESPACE::TensorProto first_indice_const_tensor;
-  first_indice_const_tensor.set_name(graph.GenerateNodeArgName("indices"));
-  first_indice_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-  first_indice_const_tensor.add_dims(first_indices.size());
-  first_indice_const_tensor.set_raw_data(first_indices.data(), first_indices.size() * sizeof(int64_t));
-  NodeArg* first_index_arg = &graph_utils::AddInitializer(graph, first_indice_const_tensor);
-
-  // Get the first dim value of flattened input_ids which is batch_size * seq_len
-  NodeArg* first_dim = GetDimsValue(graph, reshape_output_args[0], first_index_arg, *embedding_node);
-
   std::vector<int64_t> first_two_indices{0, 1};
   ONNX_NAMESPACE::TensorProto first_two_indices_const_tensor;
   first_two_indices_const_tensor.set_name(graph.GenerateNodeArgName("first_two_indices"));
@@ -553,11 +457,7 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
   for (const auto& node : candidate_outputs) {
     for (uint32_t i = 0; i < node->InputDefs().size(); ++i) {
       if (subgraph.find(node->MutableInputDefs()[i]) != subgraph.end()) {
-        // Get a shape of the i-th input of the node with first index updated to value of first_dim
-        // which is batch_size * seq_len. This shape arg will be used as the shape input of GatherGrad
-        NodeArg* shape_arg_for_gather_grad = UpdateShape(
-            graph, node->MutableInputDefs()[i], first_dim, first_index_arg, *node);
-        InsertNodesForOutput(graph, *node, i, squeeze_out_arg, shape_arg_for_gather_grad, first_two_dims_arg, logger);
+        InsertNodesForOutput(graph, *node, i, squeeze_out_arg, first_two_dims_arg, logger);
         handled_output_count++;
       }
     }
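Taken together, the transformer's contract is a flatten/restore round trip: `Reshape` + `ShrunkenGather` drop the padded token rows on the way into the subgraph, and `PadAndUnflatten` scatters the results back into the full `[batch_size, seqlen, ...]` layout on the way out. A small sketch of that invariant, assuming unique non-padding indices and zero rows at padding positions (illustrative helpers, not ORT code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Flatten side (Reshape + ShrunkenGather): keep only the non-padding rows.
std::vector<float> GatherValidRows(const std::vector<float>& flat,     // [batch*seq, d], row-major
                                   const std::vector<int64_t>& valid,  // indices of non-padding tokens
                                   int64_t d) {
  std::vector<float> out(valid.size() * d);
  for (size_t i = 0; i < valid.size(); ++i)
    for (int64_t j = 0; j < d; ++j) out[i * d + j] = flat[valid[i] * d + j];
  return out;
}

// Restore side (PadAndUnflatten): scatter rows back, zeros elsewhere.
std::vector<float> RestoreRows(const std::vector<float>& packed,  // [valid_count, d]
                               const std::vector<int64_t>& valid,
                               int64_t batch_times_seq, int64_t d) {
  std::vector<float> out(batch_times_seq * d, 0.0f);
  for (size_t i = 0; i < valid.size(); ++i)
    for (int64_t j = 0; j < d; ++j) out[valid[i] * d + j] = packed[i * d + j];
  return out;
}

int main() {
  // Two sequences of length 3, hidden size 2; tokens 2 and 5 are padding (zero rows).
  const std::vector<float> x = {1, 1, 2, 2, 0, 0, 3, 3, 4, 4, 0, 0};
  const std::vector<int64_t> valid = {0, 1, 3, 4};
  auto packed = GatherValidRows(x, valid, 2);        // [4, 2]: padding rows removed
  auto restored = RestoreRows(packed, valid, 6, 2);  // [6, 2]: equals x again
  assert(restored == x);
  return 0;
}
```

The saving comes from running every op between the embedding and the restore point on `[valid_token_count, ...]` tensors instead of `[batch_size * seqlen, ...]` ones; this commit's change is that the restore step is now a single `PadAndUnflatten` node instead of the previous `GatherGrad` + `Reshape` pair plus its `Shape`/`ScatterElements`/`Concat` shape plumbing.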