Skip to content

Commit

Permalink
Obtain string slices in first pass and remove CountSubstrings functio…
Browse files Browse the repository at this point in the history
…n entirely
  • Loading branch information
adityagoel4512 committed Jan 9, 2024
1 parent c333ad5 commit d303e5b
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 57 deletions.
95 changes: 38 additions & 57 deletions onnxruntime/core/providers/cpu/nn/string_split.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "string_split.h"
#include "core/providers/cpu/nn/string_split.h"
#include <algorithm>
#include <limits>
#include <string>
#include "core/common/common.h"
namespace onnxruntime {

Expand All @@ -13,75 +15,47 @@ ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20,
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int64_t>()),
StringSplit);

/// Count the number of instances of substring ``substr`` in ``str``. If ``substr`` is an empty string it counts the
/// number of whitespace delimited words.
int64_t CountSubstrings(std::string_view str, std::string_view substr) {
if (substr.empty()) {
// Count consecutive whitespace as one delimiter
int64_t count = 0;
size_t pos = str.find_first_not_of(" ");
while (pos != std::string::npos) {
++count;
pos = str.find_first_not_of(" ", str.find_first_of(" ", pos));
}
return count;
} else {
int64_t count = 0;
size_t pos = 0;
while (pos != std::string::npos) {
++count;
pos = str.find(substr, pos);
if (pos != std::string::npos) {
pos += substr.length();
}
}
return count;
}
}

/// Fill substrings of ``str`` based on split delimiter ``delimiter`` into ``output`` span. Restrict maximum number of
/// generated substrings to ``max_tokens``. The function returns the number of substrings generated (this is less or
/// equal to ``max_tokens``).
int64_t FillSubstrings(std::string_view str, std::string_view delimiter,
gsl::details::span_iterator<std::string> output, size_t max_tokens) {
InlinedVector<std::string_view> FillSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits) {
InlinedVector<std::string_view> output;
if (str.empty()) {
return 0;
return output;
}
if (delimiter.empty()) {
// Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored.
size_t pos = str.find_first_not_of(" ");
int64_t token_count = 0;
while (pos != std::string::npos) {
if (++token_count == max_tokens) {
if (token_count++ == max_splits) {
// trim down last substring as required in specification
size_t next_pos = str.length() - 1;
while (str[next_pos] == ' ') {
next_pos--;
}
*output = str.substr(pos, next_pos - pos + 1);
output.push_back(str.substr(pos, next_pos - pos + 1));
break;
} else {
auto next_pos = str.find_first_of(" ", pos);
*output = str.substr(pos, next_pos - pos);
output.push_back(str.substr(pos, next_pos - pos));
pos = str.find_first_not_of(" ", next_pos);
}

output++;
}
return token_count;
return output;
} else {
size_t pos = 0;
int64_t token_count = 0;
while (pos != std::string::npos) {
auto next_pos = token_count == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos);
*output++ = str.substr(pos, next_pos - pos);
token_count++;
if (next_pos == std::string::npos) {
auto next_pos = str.find(delimiter, pos);
if (token_count++ == max_splits || next_pos == std::string::npos) {
output.push_back(str.substr(pos));
break;
}
output.push_back(str.substr(pos, next_pos - pos));
pos = next_pos + delimiter.size();
}
return token_count;
return output;
}
}

Expand All @@ -94,29 +68,36 @@ Status StringSplit::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
auto input_data = input->template DataAsSpan<std::string>();

int64_t last_dim = 0;
for (const auto& str : input_data) {
last_dim = std::max(last_dim, CountSubstrings(str, delimiter_));
// Set up number of tokens output
auto num_tokens_data = context->Output(1, input->Shape())->template MutableDataAsSpan<int64_t>();
auto num_tokens_iter = num_tokens_data.begin();

int64_t last_dim = 1;

InlinedVector<InlinedVector<std::string_view>> input_slices;
input_slices.reserve(input_data.size());
auto input_slice_iterator = input_slices.begin();
for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, input_slice_iterator++, num_tokens_iter++) {
auto substrs = FillSubstrings(*input_iter, delimiter_, maxsplit_);
auto substr_count = static_cast<int64_t>(substrs.size());
input_slices.push_back(substrs);
last_dim = std::max(last_dim, substr_count);
*num_tokens_iter = substr_count;
}

last_dim = std::min(last_dim, maxsplit_ + 1);

// Set up splits output
auto splits_shape = input->Shape().AsShapeVector();
if (last_dim > 0) {
splits_shape.push_back(last_dim);
}
auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan<std::string>();
auto output_splits_iter = splits_data.begin();

// Set up number of tokens output
auto* num_substrings = context->Output(1, input->Shape());
auto num_substrings_data = num_substrings->template MutableDataAsSpan<int64_t>();
auto output_num_tokens_iter = num_substrings_data.begin();
splits_shape.push_back(last_dim);

for (auto input_iter = input_data.begin(); input_iter != input_data.end();
input_iter++, output_splits_iter += last_dim, output_num_tokens_iter++) {
*output_num_tokens_iter = FillSubstrings(*input_iter, delimiter_, output_splits_iter, last_dim);
auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan<std::string>();
auto slices_iter = input_slices.begin();
for (auto output_splits_iter = splits_data.begin(); output_splits_iter != splits_data.end(); output_splits_iter += last_dim, slices_iter++) {
const auto output_slices = *slices_iter;
std::copy(output_slices.begin(), output_slices.end(), output_splits_iter);
}

return Status::OK();
}

Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/test/providers/cpu/nn/string_split_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,22 @@ TEST(StringSplit, EmptyInputTest) {
test.Run();
}

TEST(StringSplit, OnlyEmptyInputTest) {
OpTester test("StringSplit", 20);
test.AddAttribute<std::string>("delimiter", "*");
test.AddInput<std::string>("X", {1, 2, 1}, {"", ""});
test.AddOutput<std::string>("Y", {1, 2, 1, 1}, {"", ""});
test.AddOutput<int64_t>("Z", {1, 2, 1}, {0, 0});
test.Run();
}

TEST(StringSplit, OnlyEmptyNoDelimiterInputTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 2, 1}, {"", ""});
test.AddOutput<std::string>("Y", {1, 2, 1, 1}, {"", ""});
test.AddOutput<int64_t>("Z", {1, 2, 1}, {0, 0});
test.Run();
}

} // namespace test
} // namespace onnxruntime

0 comments on commit d303e5b

Please sign in to comment.