Skip to content

Commit

Permalink
Add string regex functions
Browse files Browse the repository at this point in the history
Rename RE_MATCH function to REGEXP_FULL_MATCH

The latter is more descriptive and compliant with
DuckDB's naming convention.

Introduce regexp utils based on re2

Refactor regex_full_match implementation

Functions added:

1. regexp_matches(string, regex)
Returns true if a part of string matches the
regex.

2. regexp_replace(string, regex, replacement)
Replaces the first occurrence of regex with the
replacement.

3. regexp_extract(string, regex[, group = 0])
Split the string along the regex and extract
first occurrence of group.

4. regexp_extract_all(string, regex[, group = 0])
Split the string along the regex and extract
all occurrences of group.
  • Loading branch information
gaurav8297 committed May 6, 2023
1 parent 3108a21 commit 9e3b399
Show file tree
Hide file tree
Showing 15 changed files with 503 additions and 76 deletions.
3 changes: 2 additions & 1 deletion src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ add_library(kuzu_common
profiler.cpp
type_utils.cpp
utils.cpp
string_utils.cpp)
string_utils.cpp
re2_regex.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_common>
Expand Down
90 changes: 90 additions & 0 deletions src/common/re2_regex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#include "common/re2_regex.h"

#include <regex>

namespace kuzu {
namespace common {

/**
 * Compiles `pattern` into an RE2 regular expression and stores it in `regex`.
 *
 * @param pattern regular expression source text.
 * @param options RegexOptions::CASE_INSENSITIVE requests case-insensitive matching; any
 *                other value keeps matching case sensitive.
 */
Regex::Regex(const std::string& pattern, RegexOptions options) {
    regex::RE2::Options o;
    // Matching stays case sensitive unless the caller explicitly asked otherwise. The flag
    // used to be computed with `==`, which swapped the two behaviours.
    o.set_case_sensitive(options != RegexOptions::CASE_INSENSITIVE);
    regex = std::make_shared<regex::RE2>(regex::StringPiece(pattern), o);
}

/**
 * Runs `r` once against `input`, restricted to the byte range [start, end), and stores the
 * resulting submatches in `match`.
 *
 * @param input  NUL-terminated text to search; submatch positions are offsets from it.
 * @param match  cleared first, then filled with one SubMatch per slot (slot 0 is the whole
 *               match, slots 1..N are the capturing groups). Left empty when nothing matches.
 * @param r      compiled pattern.
 * @param anchor RE2 anchoring mode forwarded to RE2::Match.
 * @param start  first byte offset of the search range.
 * @param end    one-past-last byte offset of the search range.
 * @return true when RE2 reports a match, false otherwise.
 */
bool RegexSearchInternal(const char* input, Match& match, const Regex& r, regex::RE2::Anchor anchor,
    size_t start, size_t end) {
    auto& regex = r.getRegex();
    // Slot 0 is the whole match, followed by one slot per capturing group.
    auto submatchCount = regex.NumberOfCapturingGroups() + 1;
    std::vector<regex::StringPiece> targetSubMatches(submatchCount);
    match.clear();
    if (!regex.Match(regex::StringPiece(input), start, end, anchor, targetSubMatches.data(),
            submatchCount)) {
        return false;
    }
    for (auto& subMatch : targetSubMatches) {
        // A group that did not take part in the match (e.g. an optional group) comes back
        // with a null data pointer. Subtracting `input` from it is undefined and produced a
        // garbage offset, so such groups are recorded at offset 0 with their (empty) text.
        const char* groupBegin = subMatch.data();
        uint32_t offset = groupBegin == nullptr ? 0u : static_cast<uint32_t>(groupBegin - input);
        SubMatch group_match(subMatch.ToString(), offset);
        match.emplace_back(group_match);
    }
    return true;
}

bool RegexFullMatch(const std::string& input, const Regex& regex) {
Match nopMatch;
return RegexSearchInternal(
input.c_str(), nopMatch, regex, regex::RE2::ANCHOR_BOTH, 0, input.size());
}

bool RegexPartialMatch(const std::string& input, const Regex& regex) {
Match nopMatch;
return RegexSearchInternal(
input.c_str(), nopMatch, regex, regex::RE2::UNANCHORED, 0, input.size());
}

// Returns a copy of `input` after handing it to RE2::Replace together with `regex` and
// `replacement`; `input` itself is never modified.
std::string RegexReplace(
    const std::string& input, const Regex& regex, const std::string& replacement) {
    std::string rewritten(input);
    const auto& compiled = regex.getRegex();
    RE2::Replace(&rewritten, compiled, replacement);
    return rewritten;
}

// Searches `input` for the first unanchored match of `regex`. On success `match` holds the
// submatches and true is returned; otherwise `match` is left empty and false is returned.
bool RegexExtract(const std::string& input, Match& match, const Regex& regex) {
    const char* text = input.c_str();
    const size_t textLength = input.size();
    return RegexSearchInternal(text, match, regex, regex::RE2::UNANCHORED, 0, textLength);
}

// Tells whether `c` can start a UTF-8 character. Continuation bytes have the bit pattern
// 10xxxxxx, so a byte is a character start exactly when its two top bits are not `10`.
static inline bool IsCharacter(char c) {
    const auto topTwoBits = static_cast<unsigned char>(c) >> 6;
    return topTwoBits != 0x2;
}

/**
 * Collects every successive, non-overlapping match of `regex` in `input`.
 *
 * The search restarts right after the previous match. When a match is empty the cursor is
 * advanced by one whole UTF-8 character so the loop always makes progress.
 *
 * @param input text to search.
 * @param regex compiled pattern.
 * @return one Match (whole match plus capturing groups) per occurrence, in input order.
 */
std::vector<Match> RegexExtractAll(const std::string& input, const Regex& regex) {
    std::vector<Match> matches;
    size_t position = 0;
    Match match;
    // An empty match at the very end of the input pushes `position` to input.size() + 1.
    // Stop there instead of asking RE2 to search a range whose start lies beyond its end
    // (NOTE(review): RE2::Match appears to reject such a range and may log an error -
    // confirm against the bundled RE2).
    while (position <= input.size() &&
           RegexSearchInternal(
               input.c_str(), match, regex, regex::RE2::UNANCHORED, position, input.size())) {
        size_t newPosition = match.getPosition(0) + match.length(0);
        if (newPosition == position) {
            // Empty match found, increment the position manually and skip any UTF-8
            // continuation bytes so the next search starts on a character boundary.
            newPosition++;
            while (newPosition < input.length() && !IsCharacter(input[newPosition])) {
                newPosition++;
            }
        }
        position = newPosition;
        // `match` is reset by RegexSearchInternal before it is filled again, so moving
        // out of it here is safe.
        matches.emplace_back(std::move(match));
    }
    return matches;
}

/**
 * Converts a Cypher-escaped regular expression into the form the regex engine expects.
 *
 * Cypher writes escape sequences with two backslashes (a literal '.' is spelled '\\.'),
 * while the regex engine needs only one ('\.'). Every non-overlapping pair of consecutive
 * backslashes, scanning left to right, is therefore collapsed into a single backslash; all
 * other characters are copied unchanged.
 *
 * This is a single linear scan rather than a std::regex substitution, because the function
 * runs for every evaluated row and compiling a std::regex on each call is needlessly costly.
 *
 * @param input pattern text as written in the Cypher query.
 * @return the pattern with each backslash pair reduced to one backslash.
 */
std::string ParseCypherRegex(const std::string& input) {
    std::string result;
    result.reserve(input.size());
    for (size_t i = 0; i < input.size(); ++i) {
        result.push_back(input[i]);
        const bool startsBackslashPair =
            input[i] == '\\' && i + 1 < input.size() && input[i + 1] == '\\';
        if (startsBackslashPair) {
            // The first backslash was just emitted; drop the second one of the pair.
            ++i;
        }
    }
    return result;
}

} // namespace common
} // namespace kuzu
8 changes: 7 additions & 1 deletion src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ void BuiltInVectorOperations::registerStringOperations() {
vectorOperations.insert({CONCAT_FUNC_NAME, ConcatVectorOperation::getDefinitions()});
vectorOperations.insert({CONTAINS_FUNC_NAME, ContainsVectorOperation::getDefinitions()});
vectorOperations.insert({ENDS_WITH_FUNC_NAME, EndsWithVectorOperation::getDefinitions()});
vectorOperations.insert({RE_MATCH_FUNC_NAME, REMatchVectorOperation::getDefinitions()});
vectorOperations.insert({LCASE_FUNC_NAME, LowerVectorOperation::getDefinitions()});
vectorOperations.insert({LEFT_FUNC_NAME, LeftVectorOperation::getDefinitions()});
vectorOperations.insert({LENGTH_FUNC_NAME, LengthVectorOperation::getDefinitions()});
Expand All @@ -378,6 +377,13 @@ void BuiltInVectorOperations::registerStringOperations() {
vectorOperations.insert({TRIM_FUNC_NAME, TrimVectorOperation::getDefinitions()});
vectorOperations.insert({UCASE_FUNC_NAME, UpperVectorOperation::getDefinitions()});
vectorOperations.insert({UPPER_FUNC_NAME, UpperVectorOperation::getDefinitions()});
vectorOperations.insert(
{REGEXP_FULL_MATCH_FUNC_NAME, RegexpFullMatchVectorOperation::getDefinitions()});
vectorOperations.insert({REGEXP_MATCHES_FUNC_NAME, RegexpMatchesOperation::getDefinitions()});
vectorOperations.insert({REGEXP_REPLACE_FUNC_NAME, RegexpReplaceOperation::getDefinitions()});
vectorOperations.insert({REGEXP_EXTRACT_FUNC_NAME, RegexpExtractOperation::getDefinitions()});
vectorOperations.insert(
{REGEXP_EXTRACT_ALL_FUNC_NAME, RegexpExtractAllOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerCastOperations() {
Expand Down
82 changes: 71 additions & 11 deletions src/function/vector_string_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
#include "function/string/operations/left_operation.h"
#include "function/string/operations/length_operation.h"
#include "function/string/operations/lpad_operation.h"
#include "function/string/operations/reg_expr_operation.h"
#include "function/string/operations/regexp_extract_all_operation.h"
#include "function/string/operations/regexp_extract_operation.h"
#include "function/string/operations/regexp_full_match_operation.h"
#include "function/string/operations/regexp_matches_operation.h"
#include "function/string/operations/regexp_replace_operation.h"
#include "function/string/operations/repeat_operation.h"
#include "function/string/operations/right_operation.h"
#include "function/string/operations/rpad_operation.h"
Expand Down Expand Up @@ -58,16 +62,6 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> EndsWithVectorOperation:
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> REMatchVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(RE_MATCH_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, BOOL,
BinaryExecFunction<ku_string_t, ku_string_t, uint8_t, operation::REMatch>,
BinarySelectFunction<ku_string_t, ku_string_t, operation::REMatch>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> LeftVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(LEFT_FUNC_NAME,
Expand Down Expand Up @@ -141,5 +135,71 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> SubStrVectorOperation::g
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
RegexpFullMatchVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_FULL_MATCH_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, BOOL,
BinaryExecFunction<ku_string_t, ku_string_t, uint8_t, operation::RegexpFullMatch>,
BinarySelectFunction<ku_string_t, ku_string_t, operation::RegexpFullMatch>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> RegexpMatchesOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_MATCHES_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, BOOL,
BinaryExecFunction<ku_string_t, ku_string_t, uint8_t, operation::RegexpMatches>,
BinarySelectFunction<ku_string_t, ku_string_t, operation::RegexpMatches>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> RegexpReplaceOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_REPLACE_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING, STRING}, STRING,
TernaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, ku_string_t,
operation::RegexpReplace>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> RegexpExtractOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, STRING,
BinaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, operation::RegexpExtract>,
false /* isVarLength */));
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING, INT64}, STRING,
TernaryStringExecFunction<ku_string_t, ku_string_t, int64_t, ku_string_t,
operation::RegexpExtract>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
RegexpExtractAllOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, VAR_LIST,
BinaryStringExecFunction<ku_string_t, ku_string_t, list_entry_t,
operation::RegexpExtractAll>,
nullptr, bindFunc, false /* isVarLength */));
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING, INT64}, VAR_LIST,
TernaryStringExecFunction<ku_string_t, ku_string_t, int64_t, list_entry_t,
operation::RegexpExtractAll>,
nullptr, bindFunc, false /* isVarLength */));
return definitions;
}

// Produces the bind data for REGEXP_EXTRACT_ALL: a DataType built from a STRING child type.
// Neither `arguments` nor `definition` is consulted.
// NOTE(review): presumably this is the VAR_LIST-of-STRING result type - confirm against the
// DataType(std::unique_ptr<DataType>) constructor.
std::unique_ptr<FunctionBindData> RegexpExtractAllOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto elementType = std::make_unique<DataType>(STRING);
    DataType resultType(std::move(elementType));
    return std::make_unique<FunctionBindData>(std::move(resultType));
}

} // namespace function
} // namespace kuzu
6 changes: 5 additions & 1 deletion src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ const std::string ARRAY_EXTRACT_FUNC_NAME = "ARRAY_EXTRACT";
const std::string CONCAT_FUNC_NAME = "CONCAT";
const std::string CONTAINS_FUNC_NAME = "CONTAINS";
const std::string ENDS_WITH_FUNC_NAME = "ENDS_WITH";
const std::string RE_MATCH_FUNC_NAME = "RE_MATCH";
const std::string LCASE_FUNC_NAME = "LCASE";
const std::string LEFT_FUNC_NAME = "LEFT";
const std::string LENGTH_FUNC_NAME = "LENGTH";
Expand All @@ -137,6 +136,11 @@ const std::string SUFFIX_FUNC_NAME = "SUFFIX";
const std::string TRIM_FUNC_NAME = "TRIM";
const std::string UCASE_FUNC_NAME = "UCASE";
const std::string UPPER_FUNC_NAME = "UPPER";
// Regular-expression string functions. Each name is the key under which
// BuiltInVectorOperations::registerStringOperations registers the matching definitions.
const std::string REGEXP_FULL_MATCH_FUNC_NAME = "REGEXP_FULL_MATCH";
const std::string REGEXP_MATCHES_FUNC_NAME = "REGEXP_MATCHES";
const std::string REGEXP_REPLACE_FUNC_NAME = "REGEXP_REPLACE";
const std::string REGEXP_EXTRACT_FUNC_NAME = "REGEXP_EXTRACT";
const std::string REGEXP_EXTRACT_ALL_FUNC_NAME = "REGEXP_EXTRACT_ALL";

// Date functions.
const std::string DATE_PART_FUNC_NAME = "DATE_PART";
Expand Down
69 changes: 69 additions & 0 deletions src/include/common/re2_regex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#pragma once

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "re2.h"

namespace kuzu {
namespace common {
// Options a caller can pass when constructing a Regex.
enum class RegexOptions : uint8_t { NONE, CASE_INSENSITIVE };

/**
 * Holds a compiled RE2 pattern through a shared_ptr.
 */
class Regex {
public:
// Compiles `pattern` with the given `options`. Deliberately not `explicit`: callers such as
// RegexpExtractAll::operation pass a std::string where a `const Regex&` is expected and rely
// on the implicit conversion.
Regex(const std::string& pattern, RegexOptions options = RegexOptions::NONE);
// Returns the underlying RE2 object by reference (dereferences `regex`).
const kuzu::regex::RE2& getRegex() const { return *regex; }

private:
// Compiled pattern, assigned by the constructor.
std::shared_ptr<kuzu::regex::RE2> regex;
};

/**
 * This represents a single submatch coming from RE2::Match: the matched text together with
 * the offset at which it starts in the searched input.
 */
class SubMatch {
public:
    /**
     * @param text     the characters covered by this submatch.
     * @param position offset of the submatch from the start of the searched input.
     */
    SubMatch(const std::string& text, uint32_t position) : text{text}, position{position} {}

    /** Returns the matched text. */
    const std::string& getText() const { return text; }

    /** Returns the offset of the submatch in the searched input. */
    uint32_t getPosition() const { return position; }

private:
    std::string text;
    uint32_t position;
};

/**
* This represents a list of submatches in a single iteration of RE2::Match
* i.e. matching Regex("([a-z ]+)_?") on "hello_world" will have two
* submatches, "hello_" and "hello".
*/
class Match {
public:
SubMatch& getSubMatch(uint64_t index) {
if (index >= subMatches.size()) {
throw std::runtime_error("RE2: Match index is out of range");
}
return subMatches[index];
}
std::string getText(uint64_t index) { return getSubMatch(index).getText(); }
uint64_t getPosition(uint64_t index) { return getSubMatch(index).getPosition(); }
uint64_t length(uint64_t index) { return getSubMatch(index).getText().size(); }
void emplace_back(SubMatch& group) { subMatches.emplace_back(group); }
void clear() { subMatches.clear(); }

private:
std::vector<SubMatch> subMatches;
};

// Returns true when `regex` matches all of `input` (search anchored at both ends).
bool RegexFullMatch(const std::string& input, const Regex& regex);
// Returns true when `regex` matches anywhere in `input` (unanchored search).
bool RegexPartialMatch(const std::string& input, const Regex& regex);
// Returns a copy of `input` after RE2::Replace has been applied to it with `regex` and
// `replacement`.
std::string RegexReplace(
const std::string& input, const Regex& regex, const std::string& replacement);
// Runs one unanchored search of `regex` over `input`; fills `match` with the submatches and
// returns true on success, leaves `match` empty and returns false otherwise.
bool RegexExtract(const std::string& input, Match& match, const Regex& regex);
// Returns one Match per successive match of `regex` in `input`, in input order.
std::vector<Match> RegexExtractAll(const std::string& input, const Regex& regex);
// Replaces each pair of consecutive backslashes in `input` with a single backslash, turning
// a Cypher-escaped pattern into the form the regex engine expects.
std::string ParseCypherRegex(const std::string& input);

} // namespace common
} // namespace kuzu
26 changes: 0 additions & 26 deletions src/include/function/string/operations/reg_expr_operation.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include "common/re2_regex.h"
#include "common/types/ku_list.h"
#include "common/types/ku_string.h"
#include "common/vector/value_vector.h"
#include "common/vector/value_vector_utils.h"

namespace kuzu::function::operation {
// Implements REGEXP_EXTRACT_ALL: collects the text of one capturing group from every match
// of `pattern` in `value` and writes the texts as a list into `resultVector`.
struct RegexpExtractAll {
    // Three-argument form.
    //   value        - string to search.
    //   pattern      - Cypher-escaped regular expression; it is passed through
    //                  ParseCypherRegex before being compiled.
    //   group        - index of the group to extract from each match (0 is the whole match).
    //                  An index that a match does not contain makes Match::getText throw
    //                  std::runtime_error.
    //   result       - receives the list entry allocated in `resultVector`.
    //   resultVector - list vector that owns the produced strings.
    static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern,
        std::int64_t& group, common::list_entry_t& result, common::ValueVector& resultVector) {
        std::vector<common::Match> matches = common::RegexExtractAll(
            value.getAsString(), common::ParseCypherRegex(pattern.getAsString()));

        result = common::ListVector::addList(&resultVector, matches.size());
        auto resultValues = common::ListVector::getListValues(&resultVector, result);
        auto resultDataVector = common::ListVector::getDataVector(&resultVector);
        auto numBytesPerValue = resultDataVector->getNumBytesPerValue();

        // Iterate by reference: a Match owns a vector of strings, so copying each one per
        // iteration would be a needless deep copy.
        for (auto& match : matches) {
            common::ku_string_t kuString;
            copyToKuzuString(match.getText(group), kuString, *resultDataVector);
            common::ValueVectorUtils::copyValue(resultValues, *resultDataVector,
                reinterpret_cast<uint8_t*>(&kuString), *resultDataVector);
            // Advance to the slot of the next list element.
            resultValues += numBytesPerValue;
        }
    }

    // Two-argument form: extracts group 0, i.e. the whole match.
    static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern,
        common::list_entry_t& result, common::ValueVector& resultVector) {
        int64_t defaultGroup = 0;
        operation(value, pattern, defaultGroup, result, resultVector);
    }

    // Stores `value` in `kuString`. For a string that is not short, space of value.length()
    // bytes is first allocated from the in-memory overflow buffer of `valueVector` and
    // assigned to kuString.overflowPtr.
    static void copyToKuzuString(
        const std::string& value, common::ku_string_t& kuString, common::ValueVector& valueVector) {
        if (!common::ku_string_t::isShortString(value.length())) {
            kuString.overflowPtr = reinterpret_cast<uint64_t>(
                common::StringVector::getInMemOverflowBuffer(&valueVector)
                    ->allocateSpace(value.length()));
        }
        kuString.set(value);
    }
};
} // namespace kuzu::function::operation
Loading

0 comments on commit 9e3b399

Please sign in to comment.