diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index 48f29cbc02d6..5576e8ccb1a5 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -752,14 +752,14 @@ IMPLEMENT_GRADIENT_BUILDER(GetGatherGradient) {
                   SrcNodeAttributes())};
 }
 
-IMPLEMENT_GRADIENT_BUILDER(GetGatherGradGradient) {
-  // TODO: Strictly speaking, GatherGrad's gradient is not alway Gather when the indices have repeated values.
-  // Since GatherGrad in foward path is only used by embed sparsity feature in which case the indices are unique,
-  // we can safely use Gather here. But we will adress this issue as soon as possible.
+IMPLEMENT_GRADIENT_BUILDER(GetPadAndUnflattenGradient) {
   return std::vector<NodeDef>{
+      NodeDef(OpDef("Reshape"),
+              {GO(0), O(1)},
+              {IA("GO_reshaped")}),
       NodeDef(OpDef{"Gather", kOnnxDomain, 1},
-              {GO(0), I(1)},
-              {GI(2)},
+              {IA("GO_reshaped"), I(1)},
+              {GI(0)},
               SrcNodeAttributes())};
 }
 
diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h
index 31a51d71198f..30da6e53ba2d 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.h
+++ b/orttraining/orttraining/core/graph/gradient_builder.h
@@ -39,7 +39,7 @@ DECLARE_GRADIENT_BUILDER(GetPoolGradient)
 DECLARE_GRADIENT_BUILDER(GetAveragePoolGradient)
 DECLARE_GRADIENT_BUILDER(GetMaxPoolGradient)
 DECLARE_GRADIENT_BUILDER(GetGatherGradient)
-DECLARE_GRADIENT_BUILDER(GetGatherGradGradient)
+DECLARE_GRADIENT_BUILDER(GetPadAndUnflattenGradient)
 DECLARE_GRADIENT_BUILDER(GetShrunkenGatherGradient)
 DECLARE_GRADIENT_BUILDER(GetConvGradient)
 DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
index 9970c69e6ef7..ca0a397cedb3 100755
--- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
@@ -70,7 +70,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
   REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
   REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
-  REGISTER_GRADIENT_BUILDER("GatherGrad", GetGatherGradGradient);
+  REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
   REGISTER_GRADIENT_BUILDER("ShrunkenGather", GetShrunkenGatherGradient);
   REGISTER_GRADIENT_BUILDER("Conv", GetConvGradient);
   REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc
index 466de301fce9..a15964079e1f 100644
--- a/orttraining/orttraining/core/graph/training_op_defs.cc
+++ b/orttraining/orttraining/core/graph/training_op_defs.cc
@@ -4571,6 +4571,45 @@ Return true if all elements are true and false otherwise.
         updateOutputShape(ctx, 6, {num_directions, three * hidden_size});
       }
     });
+
+  ONNX_CONTRIB_OPERATOR_SCHEMA(PadAndUnflatten)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .SetDoc(
+          "PadAndUnflatten operator pads zeros along the first axis, and unflattens that axis into two axes"
+          " according to the given unflatten_dims. This is used by the padding elimination graph transformer."
+          " For each index in indices, the corresponding value in output comes from input."
+          " For all other indices, the corresponding value in output is padded with zero."
+ + "The indices don't allow duplicated index values, otherwise, though there is no runtime check" + "(in case of performance concern), the behaviour of output is undefined." + + "An example:" + " input: [[1, 2, 3, 4], [5, 6, 7, 8]], shape is [2, 4]" + " indices: [0, 5], shape is [2]" + " unflatten_dims: [2, 3], shape is [2]" + + " output: [[[1, 2, 3, 4], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [5, 6, 7, 8]]]," + " shape is [2, 3, 4]" + " flatten_output_shape: [6, 4], shape is [2]") + .Input(0, "input", "input data of rank N, shape is [d1, d2, ..., dN]", "T") + .Input(1, "indices", "1D Tensor of int32/int64 indices, shape is [d1], each element's value ranges in [0, M1*M2).", + "T_INDEX") + .Input(2, "unflatten_dims", "1D tensor with two values, [M1, M2].", "T_INT") + .Output(0, "output", "output data of rank N+1, [M1, M2, d2, ..., dN]", "T") + .Output(1, "flatten_output_shape", "1D tensor with output shape, [M1*M2, d2, ..., dN]", "T_INT") + .TypeConstraint( + "T_INT", + {"tensor(int32)", "tensor(int64)"}, + "Constrain shape to integer tensors.") + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint( + "T_INDEX", + {"tensor(int32)", "tensor(int64)"}, + "Constrain indices to integer types"); } } // namespace training diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc index fe2f86edd015..b6b6d62ec2f9 100644 --- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc +++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc @@ -61,34 +61,6 @@ NodeArg* GetDimsValue(Graph& graph, NodeArg* input, NodeArg* indices_arg, Node& return gather_out_args[0]; } -// Insert Shape + ScatterElements to get an updated shape of input with index of indices_arg updated to -// the value of update_value. -// Such as, if the indices_arg is a initializer of [0] and the original shape of input is [valid_token_count, a, b, c], -// this function will return a shape of [update_value, a, b, c] -NodeArg* UpdateShape(Graph& graph, NodeArg* input, NodeArg* update_value, NodeArg* indices_arg, Node& node) { - InlinedVector shape_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("shape_result"), - nullptr)}; - Node& shape_node = graph.AddNode(graph.GenerateNodeName("shape"), "Shape", "", {input}, - shape_output_args, nullptr, kOnnxDomain); - ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(shape_node), "Failed to get shape for " + shape_node.Name()); - shape_node.SetExecutionProviderType(node.GetExecutionProviderType()); - - InlinedVector scatter_input_args; - scatter_input_args.push_back(shape_output_args[0]); - scatter_input_args.push_back(indices_arg); - scatter_input_args.push_back(update_value); - - InlinedVector scatter_out_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("scatter_result"), - nullptr)}; - - Node& scatter_node = graph.AddNode(graph.GenerateNodeName("update_dim"), "ScatterElements", "", scatter_input_args, - scatter_out_args, nullptr, kOnnxDomain); - ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(scatter_node), "Failed to update shape for " + scatter_node.Name()); - scatter_node.SetExecutionProviderType(node.GetExecutionProviderType()); - - return scatter_out_args[0]; -} - // Insert Reshape + ShrunkenGather to flatten the in_index-th input of node. 
diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc
index fe2f86edd015..b6b6d62ec2f9 100644
--- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc
+++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.cc
@@ -61,34 +61,6 @@ NodeArg* GetDimsValue(Graph& graph, NodeArg* input, NodeArg* indices_arg, Node&
   return gather_out_args[0];
 }
 
-// Insert Shape + ScatterElements to get an updated shape of input with index of indices_arg updated to
-// the value of update_value.
-// Such as, if the indices_arg is a initializer of [0] and the original shape of input is [valid_token_count, a, b, c],
-// this function will return a shape of [update_value, a, b, c]
-NodeArg* UpdateShape(Graph& graph, NodeArg* input, NodeArg* update_value, NodeArg* indices_arg, Node& node) {
-  InlinedVector<NodeArg*> shape_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("shape_result"),
-                                                                      nullptr)};
-  Node& shape_node = graph.AddNode(graph.GenerateNodeName("shape"), "Shape", "", {input},
-                                   shape_output_args, nullptr, kOnnxDomain);
-  ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(shape_node), "Failed to get shape for " + shape_node.Name());
-  shape_node.SetExecutionProviderType(node.GetExecutionProviderType());
-
-  InlinedVector<NodeArg*> scatter_input_args;
-  scatter_input_args.push_back(shape_output_args[0]);
-  scatter_input_args.push_back(indices_arg);
-  scatter_input_args.push_back(update_value);
-
-  InlinedVector<NodeArg*> scatter_out_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("scatter_result"),
-                                                                     nullptr)};
-
-  Node& scatter_node = graph.AddNode(graph.GenerateNodeName("update_dim"), "ScatterElements", "", scatter_input_args,
-                                     scatter_out_args, nullptr, kOnnxDomain);
-  ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(scatter_node), "Failed to update shape for " + scatter_node.Name());
-  scatter_node.SetExecutionProviderType(node.GetExecutionProviderType());
-
-  return scatter_out_args[0];
-}
-
 // Insert Reshape + ShrunkenGather to flatten the in_index-th input of node.
 // The gather_index_arg is the indices of the elements that are not padding.
 NodeArg* InsertNodesForInput(Graph& graph,
@@ -125,7 +97,8 @@ NodeArg* InsertNodesForInput(Graph& graph,
   InlinedVector<NodeArg*> reshape_output_args;
   reshape_output_args.push_back(
-      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"), node.MutableInputDefs()[in_index]->TypeAsProto()));
+      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("inputs_reshape_result"),
+                                node.MutableInputDefs()[in_index]->TypeAsProto()));
 
   Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
       graph, node,
@@ -175,7 +148,7 @@ NodeArg* InsertNodesForInput(Graph& graph,
   return gather_out_arg;
 }
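Reviewer note: the flattening half that InsertNodesForInput builds out of Reshape + ShrunkenGather reduces, in NumPy terms, to roughly the following. The helper is hypothetical and for intuition only; valid_token_index comes from the NonZero + Squeeze chain the transformer also inserts:

    import numpy as np

    def flatten_and_shrink(x, valid_token_index):
        # Reshape [batch_size, seq_length, ...] -> [batch_size * seq_length, ...],
        # then keep only the non-padding rows (what ShrunkenGather does).
        flat = x.reshape((-1,) + x.shape[2:])
        return flat[valid_token_index]

    ids = np.array([[101, 7, 8, 0], [101, 9, 0, 0]])          # 0 is the padding idx
    valid_token_index = np.flatnonzero(ids.reshape(-1) != 0)  # NonZero + Squeeze in the graph
    assert flatten_and_shrink(ids, valid_token_index).tolist() == [101, 7, 8, 101, 9]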
 
-// Insert GatherGrad + Reshape to unflatten the shape of the in_index-th input of node.
+// Insert PadAndUnflatten to unflatten the shape of the in_index-th input of node.
 // The gathergrad_index_arg is the indices of the elements that are not padding.
-// The new_shape_arg is the shape of [batch_size * seqlen, ...]
-// gathergrad_index_arg and new_shape_arg are the arguments needed by GatherGrad.
+// The first_two_dims_arg is the [batch_size, seqlen] pair; it and gathergrad_index_arg
+// are the inputs needed by PadAndUnflatten.
@@ -183,100 +156,42 @@ NodeArg* InsertNodesForOutput(Graph& graph,
                               Node& node,
                               uint32_t in_index,
                               NodeArg* gathergrad_index_arg,
-                              NodeArg* new_shape_arg,
                               NodeArg* first_two_dims_arg,
                               const logging::Logger& logger) {
-  std::vector<int64_t> other_indices;
-  auto input_shape = node.InputDefs()[in_index]->Shape();
-  for (int k = 2; k < input_shape->dim_size(); k++) {
-    // When executing, Shape of node here has been flattened, so the indices should be k-1.
-    other_indices.push_back(int64_t(k) - 1);
-  }
-
-  // Construct the unflattened_shape_arg of [batch_size, seqlen, ...]
-  NodeArg* unflattened_shape_arg = nullptr;
-  if (other_indices.empty()) {
-    unflattened_shape_arg = first_two_dims_arg;
-  } else {
-    // If the shape size of the in_index-th input of node is larger than 2 dims, we need to concat the first two dims
-    // of [batch_size, seqlen] and the other dims together.
-    ONNX_NAMESPACE::TensorProto other_indices_const_tensor;
-    other_indices_const_tensor.set_name(graph.GenerateNodeArgName("other_shape"));
-    other_indices_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-    other_indices_const_tensor.add_dims(other_indices.size());
-    other_indices_const_tensor.set_raw_data(other_indices.data(), other_indices.size() * sizeof(int64_t));
-    NodeArg* other_indices_arg = &graph_utils::AddInitializer(graph, other_indices_const_tensor);
-    NodeArg* other_dims_arg = GetDimsValue(graph, node.MutableInputDefs()[in_index], other_indices_arg, node);
-
-    InlinedVector<NodeArg*> concat_input_args;
-    concat_input_args.push_back(first_two_dims_arg);
-    concat_input_args.push_back(other_dims_arg);
-
-    InlinedVector<NodeArg*> concat_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("concat_shape_result"),
-                                                                         nullptr)};
-
-    onnxruntime::NodeAttributes attributes;
-    attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", int64_t(0));
-
-    Node& concat_node = graph.AddNode(graph.GenerateNodeName("concat_shape"), "Concat", "", concat_input_args,
-                                      concat_output_args, &attributes, kOnnxDomain);
-    ORT_ENFORCE(graph.SetOpSchemaFromRegistryForNode(concat_node), "Failed to concat shape for " + concat_node.Name());
-    concat_node.SetExecutionProviderType(node.GetExecutionProviderType());
-    unflattened_shape_arg = concat_output_args[0];
-  }
-
-  InlinedVector<NodeArg*> gathergrad_input_args;
-  gathergrad_input_args.reserve(3);
-  gathergrad_input_args.push_back(new_shape_arg);
-  gathergrad_input_args.push_back(gathergrad_index_arg);
-  gathergrad_input_args.push_back(node.MutableInputDefs()[in_index]);
-
-  InlinedVector<NodeArg*> gathergrad_output_args;
-  gathergrad_output_args.push_back(
-      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padding_recover_result"),
+  InlinedVector<NodeArg*> pad_node_input_args;
+  pad_node_input_args.reserve(3);
+  pad_node_input_args.push_back(node.MutableInputDefs()[in_index]);
+  pad_node_input_args.push_back(gathergrad_index_arg);
+  pad_node_input_args.push_back(first_two_dims_arg);
+
+  InlinedVector<NodeArg*> pad_node_output_args;
+  pad_node_output_args.push_back(
+      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_result"),
+                                nullptr));
+  pad_node_output_args.push_back(
+      &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("padded_d1xd2_shape"),
                                 nullptr));
 
   Node* new_gathergrad_node = InsertIntermediateNodeOnDestInput(
       graph, node,
       in_index,
-      2,
-      0,
+      0 /* new_node_input_index */,
+      0 /* new_node_output_index */,
       graph.GenerateNodeName("PaddingRecover"),
-      "GatherGrad",
-      "GatherGrad node to recover invalid tokens.",
-      gathergrad_input_args,
-      gathergrad_output_args,
+      "PadAndUnflatten",
+      "PadAndUnflatten node to recover invalid tokens.",
+      pad_node_input_args,
+      pad_node_output_args,
       {},
       kMSDomain,
       logger);
 
   new_gathergrad_node->SetExecutionProviderType(node.GetExecutionProviderType());
-  auto gathergrad_out_arg = new_gathergrad_node->MutableOutputDefs()[0];
-
-  InlinedVector<NodeArg*> reshape_input_args;
-  reshape_input_args.push_back(gathergrad_out_arg);
-  reshape_input_args.push_back(unflattened_shape_arg);
-  InlinedVector<NodeArg*> reshape_output_args{&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("reshape_result"),
-                                                                        nullptr)};
-  Node* new_reshape_node = InsertIntermediateNodeOnDestInput(
-      graph, node,
-      in_index,
-      0,
-      0,
-      graph.GenerateNodeName("RecoverShape"),
-      "Reshape",
-      "Reshape node to recover invalid tokens.",
-      reshape_input_args,
-      reshape_output_args,
-      {},
-      kOnnxDomain,
-      logger);
-
-  new_reshape_node->SetExecutionProviderType(node.GetExecutionProviderType());
-  return new_reshape_node->MutableOutputDefs()[0];
+  return new_gathergrad_node->MutableOutputDefs()[0];
 }
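Reviewer note: the rewrite above collapses the old GatherGrad + Reshape recovery pair into one PadAndUnflatten node. A NumPy sketch of why the two paths agree, assuming unique indices as elsewhere in this PR (values borrowed from the gradient test added below; helper names are hypothetical):

    import numpy as np

    def old_path(flat_valid, indices, new_shape, unflatten_dims):
        # GatherGrad(new_shape, indices, x): scatter-add into a zero buffer...
        recovered = np.zeros(new_shape, dtype=flat_valid.dtype)
        np.add.at(recovered, indices, flat_valid)
        # ...followed by a separate Reshape to [batch_size, seq_length, ...].
        return recovered.reshape(tuple(unflatten_dims) + recovered.shape[1:])

    def new_path(flat_valid, indices, unflatten_dims):
        # PadAndUnflatten fuses the zero-padding scatter and the unflatten;
        # with unique indices, scatter-add degenerates to plain assignment.
        m1, m2 = unflatten_dims
        out = np.zeros((m1, m2) + flat_valid.shape[1:], dtype=flat_valid.dtype)
        out.reshape((m1 * m2,) + flat_valid.shape[1:])[indices] = flat_valid
        return out

    x = np.arange(12, dtype=np.float32).reshape(4, 3)
    idx = np.array([3, 5, 0, 1])
    assert (old_path(x, idx, (10, 3), (5, 2)) == new_path(x, idx, (5, 2))).all()

Fusing the pair also removes the need to materialize the [batch_size * seq_length, ...] shape tensor in the graph, which is why UpdateShape and its Shape + ScatterElements nodes could be deleted above.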
 
 // Iterate the subgraph beginning from the start_node, and put all node args into 'subgraph'
-// Also put all candidate input nodes and cantidate output nodes of the subgraph into candidate_inputs and
+// Also put all candidate input nodes and candidate output nodes of the subgraph into candidate_inputs and
 // candidate_outputs respectively.
 void IterateSubgraphFromNode(Graph& graph,
                              Node* start_node,
@@ -368,7 +283,7 @@ void IterateSubgraphFromNode(Graph& graph,
     } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "MatMul", {1, 9, 13})) {
       if (subgraph.find(cur->MutableInputDefs()[0]) != subgraph.end()) {
         // If shape of [batch_size, seqlen, ...] is propagated from the first argument of MatMul.
-        // The dim size of the first argument must larger than 2 to propagete the first two dims to the output.
+        // The dim size of the first argument must be larger than 2 to propagate the first two dims to the output.
         // Or else the first two dims of the output will not be [batch_size, seqlen] and this MatMul will be added
         // to candidate_outputs as the output of the subgraph.
         if (cur->InputDefs()[0]->Shape()->dim_size() > 2) {
@@ -376,17 +291,17 @@ void IterateSubgraphFromNode(Graph& graph,
           PushAllOutputNode(graph, to_visit, cur, visited);
         } else {
           LOG_DEBUG_INFO(logger,
-                         "PaddingElimination::dim size of left input of matmul smaller than 3 and \
-                         this matmul would be output of subgraph.");
+                         "PaddingElimination::dim size of left input of MatMul is smaller than 3, so \
+                         this MatMul would be the output of the subgraph.");
           candidate_outputs.insert(cur);
           continue;
         }
       } else if (subgraph.find(cur->MutableInputDefs()[1]) != subgraph.end()) {
-        LOG_DEBUG_INFO(logger, "PaddingElimination::right edge of matmul would not included.");
+        LOG_DEBUG_INFO(logger, "PaddingElimination::right edge of MatMul would not be included.");
         candidate_outputs.insert(cur);
         continue;
       } else {
-        ORT_THROW("PaddingElimination::found matmul node without input in subgraph.");
+        ORT_THROW("PaddingElimination::found MatMul node without input in subgraph.");
       }
     } else if (graph_utils::IsSupportedOptypeVersionAndDomain(*cur, "PythonOp", {1}, kMSDomain)) {
       if (subgraph.find(cur->MutableInputDefs()[0]) == subgraph.end()) {
@@ -451,7 +366,8 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
             ") due to embedding input is not in the sparse embedding input list.");
         continue;
       }
-      const ONNX_NAMESPACE::TensorProto* padding_initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
+      const ONNX_NAMESPACE::TensorProto* padding_initializer =
+          graph_utils::GetConstantInitializer(graph, node.InputDefs()[2]->Name());
       if (padding_initializer != nullptr &&
           padding_initializer->dims_size() == 0 &&
           ((padding_initializer->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32) ||
@@ -526,18 +442,6 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
     }
   }
 
-  std::vector<int64_t> first_indices;
-  first_indices.push_back(0);
-  ONNX_NAMESPACE::TensorProto first_indice_const_tensor;
-  first_indice_const_tensor.set_name(graph.GenerateNodeArgName("indices"));
-  first_indice_const_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-  first_indice_const_tensor.add_dims(first_indices.size());
-  first_indice_const_tensor.set_raw_data(first_indices.data(), first_indices.size() * sizeof(int64_t));
-  NodeArg* first_index_arg = &graph_utils::AddInitializer(graph, first_indice_const_tensor);
-
-  // Get the first dim value of flattened input_ids which is batch_size * seq_len
-  NodeArg* first_dim = GetDimsValue(graph, reshape_output_args[0], first_index_arg, *embedding_node);
-
   std::vector<int64_t> first_two_indices{0, 1};
   ONNX_NAMESPACE::TensorProto first_two_indices_const_tensor;
   first_two_indices_const_tensor.set_name(graph.GenerateNodeArgName("first_two_indices"));
@@ -553,11 +457,7 @@ Status PaddingElimination::ApplyImpl(Graph& graph, bool& modified, int graph_lev
   for (const auto& node : candidate_outputs) {
     for (uint32_t i = 0; i < node->InputDefs().size(); ++i) {
       if (subgraph.find(node->MutableInputDefs()[i]) != subgraph.end()) {
-        // Get a shape of the i-th input of the node with first index updated to value of first_dim
-        // which is batch_size * seq_len. This shape arg will be used as the shape input of GatherGrad
-        NodeArg* shape_arg_for_gather_grad = UpdateShape(
-            graph, node->MutableInputDefs()[i], first_dim, first_index_arg, *node);
-        InsertNodesForOutput(graph, *node, i, squeeze_out_arg, shape_arg_for_gather_grad, first_two_dims_arg, logger);
+        InsertNodesForOutput(graph, *node, i, squeeze_out_arg, first_two_dims_arg, logger);
         handled_output_count++;
       }
     }
diff --git a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h
index 3136210b8ac6..c4f283c30fdd 100644
--- a/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h
+++ b/orttraining/orttraining/core/optimizer/compute_optimizer/padding_elimination.h
@@ -16,15 +16,18 @@ namespace onnxruntime {
 * @Class PaddingElimination
 *
 * @brief Graph transformer that eliminates unnecessary padding computation caused by embedding sparsity.
+ *
+ * In transformer training, input_ids is usually padded to the same length (the max sequence length),
+ * so its shape is [batch_size, sequence_length] or [sequence_length, batch_size]. This graph transformer
+ * tries to MERGE the leading two dimensions and REMOVE the padding on the merged
+ * dimension, i.e., [batch_size, sequence_length, ...] -> [batch_size * sequence_length, ...] ->
+ * [valid_token, ...].
 *
 * This transformer is implemented in the following steps:
 * 1. Iterate the graph and find the Embedding node that matches these requirements:
- *    (1) Its 2nd input is a graph input and its rank > 2 with the first two dimensions are dim_params which are
- *        actually batch_size and sequence_length.
- *        Note: Now only support the case of the first two dimensions to merged and remove the padding on the merged
- *        dimension, i.e, [batch_size, sequence_length, ...] -> [batch_size * sequence_length, ...] ->
- *        [valid_token, ... ]. In the future, we may support the case of any two consecutive dimensions to merged,
- *        such as [..., batch_size, sequence_length, ...].
- *    (2) Its 3nd input is a scalar constant initializer which is the padding idx that should >= 0.
+ *    1.1 The 2nd input is a graph input with rank > 2, and its first two dimensions are
+ *        [batch_size, sequence_length]. Both dimensions can be symbolic or concrete dim values.
+ *    1.2 The 3rd input (padding idx) is a scalar constant initializer, and its value should be >= 0.
 * 2. Append embedding node in node_to_scan_list.
 *    Iterate the node_to_scan_list, for each node,
 *    2.1 Check if it is supported for pad elimination (from a pre-defined op list).
 *        If no, record this node as output
@@ -42,6 +45,8 @@ namespace onnxruntime {
 *    This is needed to ensure not to affect subsequent computations
 *
 * For example, given the following graph:
+ *   1. `input_0` is a tensor that is an indirect output of the ATen embedding node.
+ *   2. `input_1` is a tensor that is NOT a direct or indirect output of the ATen embedding node.
 *
 *    embed.weight   input_ids [batch_size, seq_length]   padding_idx [1]   scale_grad_by_freq   sparse
 *         \               \                                    /                  /                /
 *          \               \                                  /                  /                /
 *           \_________________\_________________________/________________/______________________/
 *                                   |
- *                             Aten:embedding
+ *                             ATen:embedding
 *                                   |
- *                                  |
- *           input                  |
- *                \                 |
+ *             - - - - - - - - - - -|
+ *             |                    |
+ *         input_0                  |                 input_1
+ *              \                   |                   /
+ *               \__________        |       ___________/
+ *                          \       |      /
 *                               Subgraph
 *
 *                                   |
@@ -83,40 +91,36 @@ namespace onnxruntime {
 *      \______________________\________________________________/__________________/________________/
 *                                   |
 *                             Aten:embedding
- *          _ _ _ _ _ __ _ _ _ __ _ |
- *         /                        |
- *   input_node                     |
- *         \  [batch_size, seq_length]                |
- *          \                       |
- *           \  [-1]                |
- *            \  /                  |
- *          Reshape  (valid_token_index)              |
- *               \     /            |
- *          ShrunkenGather          |  shape:[valid_token, ...]
- *                 \                |
- *   shape:[valid_token]  \         |
- *                     \            |
- *          candidate_input_node    |
- *                     \            |
- *                      \           |
+ *       - - - - - - - - - - - - - - - - - - - - |
+ *       |                                       |
+ *   input_0                                     |                      input_1
+ *       \  [batch_size, seq_length, ...]        |                         |
+ *        \                                      |  [batch_size, seq_length, ...]
+ *         \  [-1]                               |                         |
+ *          \  /                                 |                         |
+ *        Reshape  (valid_token_index)           |     Reshape  (valid_token_index)
+ *             \     /                           |         \     /
+ *        ShrunkenGather     shape:[valid_token, ...]    ShrunkenGather
+ *               \                               |           /
+ *   shape:[valid_token, ...]  \                 |          /
+ *                        \                      |         /
+ *             candidate_input_node              |   candidate_input_node
+ *                         \                     |       /
+ *                          \                    |      /
 *
 *                               Subgraph
 *
 *                                   |
- *                                  |  shape:[valid_token]
 *                                   |
+ *                                  |     (valid_token_index)
+ *                                  |    / ________________  (unflatten_dims), shape:[2],
+ *                                  |   / /                   value:[batch_size, seq_length]
+ *                                  |  / /
+ *                           PadAndUnflatten
 *                                   |
- *   [batch_size*seq_length]  (valid_token_index)     |
- *                    \             |                /
- *                     \            |               /
- *                      \           |              /
- *
- *                             GatherGrad
- *                                  |
- *                               Reshape
- *                                  |
- *                                  |  [batch_size, valid_token]
- *                      candidate_output_node
+ *                                  |  [batch_size, seq_length, ...]
+ *                       candidate_output_node
 *
 *
 *
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index b5655e12c1c6..64e44f796f5f 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -177,8 +177,12 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       transformers.emplace_back(std::make_unique(compatible_eps));
       transformers.emplace_back(std::make_unique(compatible_eps, config.sparse_label_input_names));
+#if defined(USE_CUDA) || defined(USE_ROCM)
+      // Keep this under the CUDA/ROCM guard, as it depends on the PadAndUnflatten CUDA/ROCM kernel.
+      // Once we have a CPU kernel for PadAndUnflatten, we can remove the guard.
+      transformers.emplace_back(std::make_unique<PaddingElimination>(compatible_eps, config.sparse_embedding_input_names));
+#endif
       }
     } break;
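Reviewer note: end to end, the pattern in the diagram above computes, in NumPy terms, something like the sketch below. This is an illustration of the inserted pattern, not code from the repository, and the function name is hypothetical:

    import numpy as np

    def padding_elimination_pattern(embed_weight, input_ids, padding_idx):
        batch_size, seq_length = input_ids.shape
        # Inserted before the subgraph: flatten and drop padding tokens.
        flat_ids = input_ids.reshape(-1)
        valid_token_index = np.flatnonzero(flat_ids != padding_idx)
        hidden = embed_weight[flat_ids[valid_token_index]]  # [valid_token, hidden]

        # ... the dense subgraph now runs on [valid_token, hidden] only ...

        # Inserted before each candidate output: PadAndUnflatten restores
        # [batch_size, seq_length, hidden] with zeros at padding positions.
        out = np.zeros((batch_size * seq_length,) + hidden.shape[1:], hidden.dtype)
        out[valid_token_index] = hidden
        return out.reshape((batch_size, seq_length) + hidden.shape[1:])

    weight = np.random.rand(20, 4).astype(np.float32)
    ids = np.array([[1, 7, 8, 0], [9, 0, 0, 0]])  # padding_idx == 0
    restored = padding_elimination_pattern(weight, ids, padding_idx=0)
    assert restored.shape == (2, 4, 4) and not restored[1, 1].any()

The payoff is that everything between the ShrunkenGather and the PadAndUnflatten runs on valid_token rows instead of batch_size * seq_length rows.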
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index 69f168390f53..b6f2639512cf 100755
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -1901,23 +1901,6 @@ TEST(GradientCheckerTest, GatherGrad) {
   }
 }
 
-TEST(GradientCheckerTest, GatherGradGrad) {
-  float max_error;
-  GradientChecker<float, float, float> gradient_checker;
-  OpDef op_def{"GatherGrad", kMSDomain, 1};
-  TensorInfo shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
-  TensorInfo indices_info({2, 2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
-  TensorInfo x_info({2, 2, 3});
-  std::vector<std::vector<float>> x_datas = {{6, 3}, {3, 5, 0, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}};
-
-  TensorShape y_shape{6, 3};
-  int64_t axis = 0;
-
-  ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {shape_info, indices_info, x_info}, {y_shape}, &max_error,
-                                                         x_datas, {MakeAttribute("axis", axis)}));
-  EXPECT_IS_TINY(max_error);
-}
-
 void TestDropoutOp(float ratio, TensorShape& x_shape, bool default_ratio = true) {
   OpTester test("Dropout", 12, kOnnxDomain, false);
   if (default_ratio) ratio = 0.5f;
@@ -3016,6 +2999,34 @@ TEST(GradientCheckerTest, TriluGrad) {
   }
 }
 
+// TODO: enable this test on ROCm once the cause of the failure there is found.
+#if defined(USE_CUDA)
+TEST(GradientCheckerTest, PadAndUnflattenGrad) {
+  float max_error;
+  GradientChecker<float, float, float> gradient_checker;
+  OpDef op_def{"PadAndUnflatten", kMSDomain, 1};
+  TensorInfo shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
+  TensorInfo indices_info({4}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
+  TensorInfo x_info({4, 3});
+  std::vector<std::vector<float>> x_datas = {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {3, 5, 0, 1}, {5, 2}};
+
+  TensorInfo padded_out_info({5, 2, 3}, true);
+  TensorInfo out_shape_info({2}, false, nullptr, DataTypeImpl::GetTensorType<int64_t>());
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
+  execution_providers.emplace_back(DefaultCudaExecutionProvider());
+#elif USE_ROCM
+  execution_providers.emplace_back(DefaultRocmExecutionProvider());
+#endif
+
+  ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, indices_info, shape_info},
+                                                         {padded_out_info, out_shape_info}, &max_error,
+                                                         x_datas, {}, true, false, &execution_providers));
+  EXPECT_IS_TINY(max_error);
+}
+#endif
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index f7d1c17beb0e..3c1887f83438 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -5886,12 +5886,12 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
     assert len([node.op_type for node in training_model.graph.node if node.op_type == "Sub"]) == 1
     assert len([node.op_type for node in training_model.graph.node if node.op_type == "NonZero"]) == 1
     assert len([node.op_type for node in training_model.graph.node if node.op_type == "Squeeze"]) == 1
-    assert len([node.op_type for node in training_model.graph.node if node.op_type == "GatherGrad"]) == 1
+    assert len([node.op_type for node in training_model.graph.node if node.op_type == "PadAndUnflatten"]) == 1
     if case == 2:
         assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 2
node.op_type == "ShrunkenGather"]) == 2 else: assert len([node.op_type for node in training_model.graph.node if node.op_type == "ShrunkenGather"]) == 1 - gathergrad_node = [node for node in training_model.graph.node if node.op_type == "GatherGrad"][0] + gathergrad_node = [node for node in training_model.graph.node if node.op_type == "PadAndUnflatten"][0] def find_input_node_type(model, arg): result = [] @@ -6057,5 +6057,5 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): training_model = ort_model._torch_module._execution_manager(True)._onnx_models.optimized_model assert "ShrunkenGather" in [node.op_type for node in training_model.graph.node] - assert "GatherGrad" in [node.op_type for node in training_model.graph.node] + assert "PadAndUnflatten" in [node.op_type for node in training_model.graph.node] del os.environ["ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER"] diff --git a/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc new file mode 100644 index 000000000000..a800f17e59ae --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/pad_and_unflatten_test.cc @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/common/tensor_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +#if defined(USE_CUDA) || defined(USE_ROCM) + +TEST(PadAndUnflattenTest, FloatType1D) { + std::vector input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f}; + std::vector indices = {1, 3, 5, 7, 9, 11}; + std::vector unflatten_dims = {5, 3}; + + std::vector output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, + 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; + + std::vector full_flatten_dims = {15}; + + OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain); + test.AddInput("input", {6}, input); + test.AddInput("indices", {6}, indices); + test.AddInput("unflatten_dims", {2}, unflatten_dims); + test.AddOutput("output", {5, 3}, output); + test.AddOutput("full_flatten_dims", {1}, full_flatten_dims); + test.Run(); +} + +TEST(PadAndUnflattenTest, FloatType2D) { + std::vector input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f}; + std::vector indices = {1, 3, 4}; + std::vector unflatten_dims = {2, 3}; + + std::vector output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f, + 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f}; + + std::vector full_flatten_dims = {6, 3}; + + OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain); + test.AddInput("input", {3, 3}, input); + test.AddInput("indices", {3}, indices); + test.AddInput("unflatten_dims", {2}, unflatten_dims); + test.AddOutput("output", {2, 3, 3}, output); + test.AddOutput("full_flatten_dims", {2}, full_flatten_dims); + test.Run(); +} + +TEST(PadAndUnflattenTest, MLFloat16Type1D) { + std::vector input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f}; + std::vector indices = {1, 3, 5, 7, 9, 11}; + std::vector unflatten_dims = {5, 3}; + + std::vector output = {0.0f, 1.0f, 0.0f, 2.0f, 0.0f, 3.0f, 0.0f, 4.0f, + 0.0f, 5.0f, 0.0f, 6.0f, 0.0f, 0.0f, 0.0f}; + + std::vector full_flatten_dims = {15}; + + std::vector input_half; + input_half.resize(input.size()); + ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size())); + std::vector output_half; + output_half.resize(output.size()); + ConvertFloatToMLFloat16(output.data(), output_half.data(), int(output.size())); + + OpTester test("PadAndUnflatten", 
+  test.AddInput<MLFloat16>("input", {6}, input_half);
+  test.AddInput<int64_t>("indices", {6}, indices);
+  test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.AddOutput<MLFloat16>("output", {5, 3}, output_half);
+  test.AddOutput<int64_t>("full_flatten_dims", {1}, full_flatten_dims);
+  test.Run();
+}
+
+TEST(PadAndUnflattenTest, MLFloat16Type2D) {
+  std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.f, 7.f, 8.f, 9.f};
+  std::vector<int64_t> indices = {1, 3, 4};
+  std::vector<int64_t> unflatten_dims = {2, 3};
+
+  std::vector<float> output = {0.0f, 0.0f, 0.0f, 1.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f,
+                               4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 0.0f, 0.0f, 0.0f};
+
+  std::vector<int64_t> full_flatten_dims = {6, 3};
+
+  std::vector<MLFloat16> input_half;
+  input_half.resize(input.size());
+  ConvertFloatToMLFloat16(input.data(), input_half.data(), int(input.size()));
+  std::vector<MLFloat16> output_half;
+  output_half.resize(output.size());
+  ConvertFloatToMLFloat16(output.data(), output_half.data(), int(output.size()));
+
+  OpTester test("PadAndUnflatten", 1, onnxruntime::kMSDomain);
+  test.AddInput<MLFloat16>("input", {3, 3}, input_half);
+  test.AddInput<int64_t>("indices", {3}, indices);
+  test.AddInput<int64_t>("unflatten_dims", {2}, unflatten_dims);
+  test.AddOutput<MLFloat16>("output", {2, 3, 3}, output_half);
+  test.AddOutput<int64_t>("full_flatten_dims", {2}, full_flatten_dims);
+  test.Run();
+}
+
+#endif
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
index 78724f32b916..f9161cd77d71 100644
--- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
@@ -197,6 +197,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inpl
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten);
 
 // the kernels within the following ifdef are not included in a build with
 // --enable_training_ops but without --enable_training
@@ -437,7 +438,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
                                                      kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
                                                      kCudaExecutionProvider, kMSDomain, 1, float, FakeQuantGrad)>,
-
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
       // the kernels within the following ifdef are not included in a build with
       // --enable_training_ops but without --enable_training
 #ifdef ENABLE_TRAINING
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc
new file mode 100644
index 000000000000..caf89ef840e0
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.cc
@@ -0,0 +1,99 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten.h" +#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + PadAndUnflatten, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraints()) + .TypeConstraint("T_INT", DataTypeImpl::GetTensorType()) + .TypeConstraint("T_INDEX", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 2) + .OutputMemoryType(OrtMemTypeCPUOutput, 1), + PadAndUnflatten); + +// Put implementation in the anonymous namespace to avoid name collision in the global namespace. +namespace { + +template +struct PadAndUnflattenFunctor { + void operator()(cudaStream_t stream, + const int64_t input_element_count, + const fast_divmod output_element_stride_fdm, + const int64_t index_value_upper_bound, + const Tensor& input_tensor, + const Tensor& indices_tensor, + Tensor& output_tensor) const { + typedef typename ToCudaType::MappedType CudaT; + const CudaT* input_data = reinterpret_cast(input_tensor.Data()); + + CUDA_CALL_THROW(cudaMemset(output_tensor.MutableDataRaw(), 0, output_tensor.Shape().Size() * sizeof(CudaT))); + PadAndUnflattenImpl(stream, input_element_count, output_element_stride_fdm, index_value_upper_bound, + input_data, indices_tensor.Data(), + reinterpret_cast(output_tensor.MutableData())); + } +}; + +} // namespace + +Status PadAndUnflatten::ComputeInternal(OpKernelContext* context) const { + const Tensor* input_tensor = context->Input(0); + const Tensor* indices_tensor = context->Input(1); + const Tensor* unflatten_dims_tensor = context->Input(2); // Parse the 1-D shape tensor. + ORT_ENFORCE(unflatten_dims_tensor->Shape().NumDimensions() == 1, + "unflatten_dims_tensor tensor must be 1-D.", unflatten_dims_tensor->Shape().NumDimensions()); + ORT_ENFORCE(unflatten_dims_tensor->Shape().Size() == 2, + "unflatten_dims_tensor tensor must contain 2 values.", unflatten_dims_tensor->Shape().Size()); + + const int64_t* dims_ptr = unflatten_dims_tensor->Data(); + const auto& input_shape = input_tensor->Shape(); + ORT_ENFORCE(input_shape[0] == indices_tensor->Shape()[0], + "The first dimension of input and indices must be the same."); + + std::vector output_shape_vec; + output_shape_vec.push_back(dims_ptr[0]); + output_shape_vec.push_back(dims_ptr[1]); + + std::vector full_size_flatten_shape_vec; + const int64_t flatten_dim_factor = dims_ptr[0] * dims_ptr[1]; + full_size_flatten_shape_vec.push_back(flatten_dim_factor); + + int64_t element_stride = 1; + for (size_t i = 1; i < input_shape.NumDimensions(); ++i) { + output_shape_vec.push_back(input_shape[i]); + full_size_flatten_shape_vec.push_back(input_shape[i]); + element_stride *= input_shape[i]; + } + + fast_divmod output_element_stride_fdm(static_cast(element_stride)); + auto output_shape = TensorShape(output_shape_vec); + Tensor* output_tensor = context->Output(0, output_shape); + + utils::MLTypeCallDispatcher t_disp(input_tensor->GetElementType()); + t_disp.Invoke(Stream(context), + input_shape.Size(), + output_element_stride_fdm, + flatten_dim_factor, + *input_tensor, + *indices_tensor, + *output_tensor); + + // Set input shape output tensor. 
+  size_t rank = full_size_flatten_shape_vec.size();
+  Tensor* input_shape_tensor = context->Output(1, {static_cast<int64_t>(rank)});
+  TensorShape(full_size_flatten_shape_vec).CopyDims(input_shape_tensor->MutableData<int64_t>(), rank);
+
+  return Status::OK();
+}
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.h b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.h
new file mode 100644
index 000000000000..e86bf5e57490
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/common.h"
+
+namespace onnxruntime {
+namespace cuda {
+
+class PadAndUnflatten final : public CudaKernel {
+ public:
+  PadAndUnflatten(const OpKernelInfo& info) : CudaKernel(info) {
+  }
+
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.cu
new file mode 100644
index 000000000000..22a4f518dfa4
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.cu
@@ -0,0 +1,81 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+
+namespace onnxruntime {
+namespace cuda {
+
+constexpr int kBlockSize = 256;
+constexpr int kNumUnroll = 4;
+
+template <typename T>
+__global__ void FillOutputWithIndexKernel(const CUDA_LONG N,
+                                          const fast_divmod output_element_stride_fdm,
+                                          const int64_t index_value_upper_bound,
+                                          const T* input_data,
+                                          const int64_t* indices_data,
+                                          T* output_data) {
+  CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
+  CUDA_LONG id = idx * kNumUnroll;
+
+  T input[kNumUnroll];
+  if (id < N) {
+#pragma unroll
+    for (int i = 0; i < kNumUnroll; ++i) {
+      CUDA_LONG li = id + i;
+      if (li < N) {
+        input[i] = input_data[li];
+      }
+    }
+  }
+
+#pragma unroll
+  for (int i = 0; i < kNumUnroll; ++i) {
+    CUDA_LONG li = id + i;
+    if (li < N) {
+      int row_index, col_index;
+      output_element_stride_fdm.divmod(li, row_index, col_index);
+      assert(indices_data[row_index] < index_value_upper_bound);
+      output_data[indices_data[row_index] * output_element_stride_fdm.d_ + col_index] = input[i];
+    }
+  }
+}
+
+template <typename T>
+void PadAndUnflattenImpl(cudaStream_t stream,
+                         const int64_t total_element_count,
+                         const fast_divmod output_element_stride_fdm,
+                         const int64_t index_value_upper_bound,
+                         const T* input_data,
+                         const int64_t* indices_data,
+                         T* output_data) {
+  const int blocksPerGrid = static_cast<int>(CeilDiv(total_element_count, kBlockSize * kNumUnroll));
+  FillOutputWithIndexKernel<T><<<blocksPerGrid, kBlockSize, 0, stream>>>(
+      static_cast<CUDA_LONG>(total_element_count),
+      output_element_stride_fdm,
+      index_value_upper_bound,
+      input_data,
+      indices_data,
+      output_data);
+}
+
+#define SPECIALIZED_RESTORE_FROM_MASK_IMPL(T)                                        \
+  template void PadAndUnflattenImpl<T>(cudaStream_t stream,                          \
+                                       const int64_t total_element_count,           \
+                                       const fast_divmod output_element_stride_fdm, \
+                                       const int64_t index_value_upper_bound,       \
+                                       const T* input_data,                         \
+                                       const int64_t* indices_data,                 \
+                                       T* output_data);
+
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(float)
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(double)
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(half)
+SPECIALIZED_RESTORE_FROM_MASK_IMPL(BFloat16)
+
+#undef SPECIALIZED_RESTORE_FROM_MASK_IMPL
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h
new file mode 100644
index 000000000000..8b015179cebd
--- /dev/null
+++ b/orttraining/orttraining/training_ops/cuda/tensor/pad_and_unflatten_impl.h
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#ifdef USE_ROCM
+#include "core/providers/rocm/shared_inc/rocm_utils.h"
+#else
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+#endif
+
+namespace onnxruntime {
+namespace cuda {
+
+template <typename T>
+void PadAndUnflattenImpl(cudaStream_t stream,
+                         const int64_t total_element_count,
+                         const fast_divmod output_element_stride_fdm,
+                         const int64_t index_value_upper_bound,
+                         const T* input_data,
+                         const int64_t* indices_data,
+                         T* output_data);
+
+}  // namespace cuda
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
index daa43fdde99d..82631fc04ff0 100644
--- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
@@ -182,6 +182,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten);
 
 #if defined(ORT_USE_NCCL) || defined(USE_MPI)
 // P2P communication operators.
@@ -378,6 +379,7 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten)>,
 
       // P2P communication operators.
 #if defined(ORT_USE_NCCL) || defined(USE_MPI)
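Reviewer note: the per-element index math of FillOutputWithIndexKernel (fast_divmod by the element stride, then a strided scatter into the zero-filled output) can be mimicked on the CPU as follows. This is a plain Python sketch of the CUDA logic that ignores the 4-way unrolling; one loop iteration corresponds to one element handled by one thread:

    import numpy as np

    def fill_output_with_index(input_flat, indices, element_stride, output_flat):
        for li in range(input_flat.size):
            # row = li / stride, col = li % stride (fast_divmod in the kernel),
            # then write to indices[row] * stride + col in the pre-zeroed output.
            row, col = divmod(li, element_stride)
            output_flat[indices[row] * element_stride + col] = input_flat[li]

    x = np.arange(12.0)        # input [4, 3] flattened
    out = np.zeros(5 * 2 * 3)  # output [5, 2, 3] flattened, zeroed like the cudaMemset
    fill_output_with_index(x, np.array([3, 5, 0, 1]), 3, out)
    assert (out.reshape(10, 3)[[3, 5, 0, 1]] == x.reshape(4, 3)).all()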