Skip to content

Commit

Permalink
Rename RE_MATCH function to REGEXP_FULL_MATCH
Browse files Browse the repository at this point in the history
The latter is more descriptive and compliant with
duckdb's naming convention.

Introduce regexp utils based on re2

Refactor regex_full_match implementation

Add string regex functions

Functions added:

1. regexp_matches(string, regex)
Returns true if a part of string matches the
regex.

2. regexp_replace(string, regex, replacement)
Replaces the first occurrence of regex with the
replacement.

3. regexp_extract(string, regex[, group = 0])
Split the string along the regex and extract
first occurrence of group.

4. regexp_extract_all(string, regex[, group = 0])
Split the string along the regex and extract
all occurrences of group.
  • Loading branch information
gaurav8297 committed May 6, 2023
1 parent bfc1b0d commit 13ace76
Show file tree
Hide file tree
Showing 15 changed files with 503 additions and 76 deletions.
3 changes: 2 additions & 1 deletion src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ add_library(kuzu_common
profiler.cpp
type_utils.cpp
utils.cpp
string_utils.cpp)
string_utils.cpp
re2_regex.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_common>
Expand Down
90 changes: 90 additions & 0 deletions src/common/re2_regex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#include "common/re2_regex.h"

#include <regex>

namespace kuzu {
namespace common {

// Compiles `pattern` into an RE2 object.
//
// pattern: the regular expression source text.
// options: RegexOptions::CASE_INSENSITIVE requests case-insensitive matching; any other value
//          keeps matching case sensitive.
Regex::Regex(const std::string& pattern, RegexOptions options) {
    regex::RE2::Options o;
    // set_case_sensitive(true) means "match case sensitively", so the flag must be true for every
    // option except CASE_INSENSITIVE. The previous comparison used `==`, which inverted the
    // meaning: the default (NONE) silently became case-insensitive.
    o.set_case_sensitive(options != RegexOptions::CASE_INSENSITIVE);
    regex = std::make_shared<regex::RE2>(regex::StringPiece(pattern), o);
}

bool RegexSearchInternal(const char* input, Match& match, const Regex& r, regex::RE2::Anchor anchor,
size_t start, size_t end) {
auto& regex = r.getRegex();
std::vector<regex::StringPiece> targetSubMatches;
auto submatchCount = regex.NumberOfCapturingGroups() + 1;
targetSubMatches.resize(submatchCount);
match.clear();
if (!regex.Match(regex::StringPiece(input), start, end, anchor, targetSubMatches.data(),
submatchCount)) {
return false;
}
for (auto& subMatch : targetSubMatches) {
SubMatch group_match(subMatch.ToString(), subMatch.data() - input);
match.emplace_back(group_match);
}
return true;
}

bool RegexFullMatch(const std::string& input, const Regex& regex) {
Match nopMatch;
return RegexSearchInternal(
input.c_str(), nopMatch, regex, regex::RE2::ANCHOR_BOTH, 0, input.size());
}

bool RegexPartialMatch(const std::string& input, const Regex& regex) {
Match nopMatch;
return RegexSearchInternal(
input.c_str(), nopMatch, regex, regex::RE2::UNANCHORED, 0, input.size());
}

// Returns a copy of `input` after handing it to RE2::Replace together with `regex` and
// `replacement`. `input` itself is never modified; when nothing matches the copy is returned
// unchanged.
std::string RegexReplace(
    const std::string& input, const Regex& regex, const std::string& replacement) {
    // RE2::Replace edits its target in place, so work on a private copy.
    std::string rewritten(input);
    RE2::Replace(&rewritten, regex.getRegex(), replacement);
    return rewritten;
}

// Searches `input` for the first place `regex` matches (unanchored) and stores the whole match
// plus its capturing groups in `match`. Returns whether a match was found.
bool RegexExtract(const std::string& input, Match& match, const Regex& regex) {
    const char* text = input.c_str();
    const size_t length = input.size();
    return RegexSearchInternal(text, match, regex, regex::RE2::UNANCHORED, 0, length);
}

// Tells whether `c` can start a UTF-8 encoded character, i.e. whether it is anything other than
// a continuation byte. Continuation bytes are exactly those whose two highest bits are `10`.
static inline bool IsCharacter(char c) {
    const auto byte = static_cast<unsigned char>(c);
    const unsigned topTwoBits = byte >> 6;
    return topTwoBits != 0x2;
}

// Collects every non-overlapping match of `regex` in `input`, scanning left to right.
//
// Each element of the result holds the whole match (index 0) and its capturing groups. An empty
// match does not stall the scan: the cursor is moved forward to the start of the next UTF-8
// character before searching again.
std::vector<Match> RegexExtractAll(const std::string& input, const Regex& regex) {
    std::vector<Match> matches;
    size_t position = 0;
    Match match;
    // After an empty match at the very end of the input the cursor becomes size() + 1. Stop there
    // instead of asking the engine to search a range whose start lies beyond its end.
    while (position <= input.size() &&
           RegexSearchInternal(
               input.c_str(), match, regex, regex::RE2::UNANCHORED, position, input.size())) {
        size_t newPosition = match.getPosition(0) + match.length(0);
        if (newPosition == position) {
            // Empty match found, increment the position manually and skip any UTF-8 continuation
            // bytes so that the next search starts on a character boundary.
            newPosition++;
            while (newPosition < input.length() && !IsCharacter(input[newPosition])) {
                newPosition++;
            }
        }
        position = newPosition;
        // `match` is cleared by RegexSearchInternal before it is filled again, so moving out of
        // it here is safe.
        matches.emplace_back(std::move(match));
    }
    return matches;
}

// Converts a pattern as written in Cypher into the form the regex engine expects.
//
// Cypher parses escape characters with 2 backslashes, e.g. expressing '.' requires '\\.'. The
// regex engine wants a single backslash ('\.'), so every pair of consecutive backslashes is
// collapsed into one. Pairs are consumed left to right without overlap; a lone backslash is kept
// as is.
//
// This is done with one linear pass rather than std::regex_replace, which had to compile a
// std::regex on every call for a fixed two-character substitution.
std::string ParseCypherRegex(const std::string& input) {
    std::string result;
    result.reserve(input.size());
    for (size_t i = 0; i < input.size(); ++i) {
        result.push_back(input[i]);
        // The first backslash of a pair was just emitted; skip its partner.
        if (input[i] == '\\' && i + 1 < input.size() && input[i + 1] == '\\') {
            ++i;
        }
    }
    return result;
}

} // namespace common
} // namespace kuzu
8 changes: 7 additions & 1 deletion src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ void BuiltInVectorOperations::registerStringOperations() {
vectorOperations.insert({CONCAT_FUNC_NAME, ConcatVectorOperation::getDefinitions()});
vectorOperations.insert({CONTAINS_FUNC_NAME, ContainsVectorOperation::getDefinitions()});
vectorOperations.insert({ENDS_WITH_FUNC_NAME, EndsWithVectorOperation::getDefinitions()});
vectorOperations.insert({RE_MATCH_FUNC_NAME, REMatchVectorOperation::getDefinitions()});
vectorOperations.insert({LCASE_FUNC_NAME, LowerVectorOperation::getDefinitions()});
vectorOperations.insert({LEFT_FUNC_NAME, LeftVectorOperation::getDefinitions()});
vectorOperations.insert({LENGTH_FUNC_NAME, LengthVectorOperation::getDefinitions()});
Expand All @@ -378,6 +377,13 @@ void BuiltInVectorOperations::registerStringOperations() {
vectorOperations.insert({TRIM_FUNC_NAME, TrimVectorOperation::getDefinitions()});
vectorOperations.insert({UCASE_FUNC_NAME, UpperVectorOperation::getDefinitions()});
vectorOperations.insert({UPPER_FUNC_NAME, UpperVectorOperation::getDefinitions()});
vectorOperations.insert(
{REGEXP_FULL_MATCH_FUNC_NAME, RegexpFullMatchVectorOperation::getDefinitions()});
vectorOperations.insert({REGEXP_MATCHES_FUNC_NAME, RegexpMatchesOperation::getDefinitions()});
vectorOperations.insert({REGEXP_REPLACE_FUNC_NAME, RegexpReplaceOperation::getDefinitions()});
vectorOperations.insert({REGEXP_EXTRACT_FUNC_NAME, RegexpExtractOperation::getDefinitions()});
vectorOperations.insert(
{REGEXP_EXTRACT_ALL_FUNC_NAME, RegexpExtractAllOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerCastOperations() {
Expand Down
82 changes: 71 additions & 11 deletions src/function/vector_string_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
#include "function/string/operations/left_operation.h"
#include "function/string/operations/length_operation.h"
#include "function/string/operations/lpad_operation.h"
#include "function/string/operations/reg_expr_operation.h"
#include "function/string/operations/regexp_extract_all_operation.h"
#include "function/string/operations/regexp_extract_operation.h"
#include "function/string/operations/regexp_full_match_operation.h"
#include "function/string/operations/regexp_matches_operation.h"
#include "function/string/operations/regexp_replace_operation.h"
#include "function/string/operations/repeat_operation.h"
#include "function/string/operations/right_operation.h"
#include "function/string/operations/rpad_operation.h"
Expand Down Expand Up @@ -58,16 +62,6 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> EndsWithVectorOperation:
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> REMatchVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(RE_MATCH_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, BOOL,
BinaryExecFunction<ku_string_t, ku_string_t, uint8_t, operation::REMatch>,
BinarySelectFunction<ku_string_t, ku_string_t, operation::REMatch>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> LeftVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(LEFT_FUNC_NAME,
Expand Down Expand Up @@ -141,5 +135,71 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> SubStrVectorOperation::g
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
RegexpFullMatchVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_FULL_MATCH_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, BOOL,
BinaryExecFunction<ku_string_t, ku_string_t, uint8_t, operation::RegexpFullMatch>,
BinarySelectFunction<ku_string_t, ku_string_t, operation::RegexpFullMatch>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> RegexpMatchesOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_MATCHES_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, BOOL,
BinaryExecFunction<ku_string_t, ku_string_t, uint8_t, operation::RegexpMatches>,
BinarySelectFunction<ku_string_t, ku_string_t, operation::RegexpMatches>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> RegexpReplaceOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_REPLACE_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING, STRING}, STRING,
TernaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, ku_string_t,
operation::RegexpReplace>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> RegexpExtractOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, STRING,
BinaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, operation::RegexpExtract>,
false /* isVarLength */));
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING, INT64}, STRING,
TernaryStringExecFunction<ku_string_t, ku_string_t, int64_t, ku_string_t,
operation::RegexpExtract>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
RegexpExtractAllOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, VAR_LIST,
BinaryStringExecFunction<ku_string_t, ku_string_t, list_entry_t,
operation::RegexpExtractAll>,
nullptr, bindFunc, false /* isVarLength */));
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING, INT64}, VAR_LIST,
TernaryStringExecFunction<ku_string_t, ku_string_t, int64_t, list_entry_t,
operation::RegexpExtractAll>,
nullptr, bindFunc, false /* isVarLength */));
return definitions;
}

// Produces the bind data for REGEXP_EXTRACT_ALL. Neither parameter is read: the same bind data is
// returned for every call site. The DataType is built from a STRING child type.
// NOTE(review): presumably this denotes a VAR_LIST of STRING - confirm against DataType.
std::unique_ptr<FunctionBindData> RegexpExtractAllOperation::bindFunc(
    const binder::expression_vector& arguments, FunctionDefinition* definition) {
    auto childType = std::make_unique<DataType>(STRING);
    return std::make_unique<FunctionBindData>(DataType(std::move(childType)));
}

} // namespace function
} // namespace kuzu
6 changes: 5 additions & 1 deletion src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ const std::string ARRAY_EXTRACT_FUNC_NAME = "ARRAY_EXTRACT";
const std::string CONCAT_FUNC_NAME = "CONCAT";
const std::string CONTAINS_FUNC_NAME = "CONTAINS";
const std::string ENDS_WITH_FUNC_NAME = "ENDS_WITH";
const std::string RE_MATCH_FUNC_NAME = "RE_MATCH";
const std::string LCASE_FUNC_NAME = "LCASE";
const std::string LEFT_FUNC_NAME = "LEFT";
const std::string LENGTH_FUNC_NAME = "LENGTH";
Expand All @@ -137,6 +136,11 @@ const std::string SUFFIX_FUNC_NAME = "SUFFIX";
const std::string TRIM_FUNC_NAME = "TRIM";
const std::string UCASE_FUNC_NAME = "UCASE";
const std::string UPPER_FUNC_NAME = "UPPER";
// Names of the regular-expression string functions.
const std::string REGEXP_FULL_MATCH_FUNC_NAME = "REGEXP_FULL_MATCH";
const std::string REGEXP_MATCHES_FUNC_NAME = "REGEXP_MATCHES";
const std::string REGEXP_REPLACE_FUNC_NAME = "REGEXP_REPLACE";
const std::string REGEXP_EXTRACT_FUNC_NAME = "REGEXP_EXTRACT";
const std::string REGEXP_EXTRACT_ALL_FUNC_NAME = "REGEXP_EXTRACT_ALL";

// Date functions.
const std::string DATE_PART_FUNC_NAME = "DATE_PART";
Expand Down
69 changes: 69 additions & 0 deletions src/include/common/re2_regex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#pragma once

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "re2.h"

namespace kuzu {
namespace common {
// Options accepted by the Regex constructor. NONE is the constructor's default;
// CASE_INSENSITIVE is the value the constructor checks for when configuring case sensitivity.
enum class RegexOptions : uint8_t { NONE, CASE_INSENSITIVE };

// Wrapper around an RE2 object. The RE2 instance is held through a shared_ptr, so copies of a
// Regex refer to the same underlying object.
class Regex {
public:
// Builds the RE2 object for `pattern`. The constructor is not marked explicit, so a std::string
// converts implicitly to a Regex wherever a `const Regex&` is expected.
Regex(const std::string& pattern, RegexOptions options = RegexOptions::NONE);
// Gives read-only access to the wrapped RE2 object.
const kuzu::regex::RE2& getRegex() const { return *regex; }

private:
// The RE2 instance created by the constructor.
std::shared_ptr<kuzu::regex::RE2> regex;
};

/**
 * This represents a single submatch coming from RE2::Match: the matched text together with the
 * position at which it was found.
 */
class SubMatch {
public:
    // `text` is taken by value and moved into the member, so a caller that passes a temporary
    // string pays for no extra copy. `position` is a plain integer and is passed by value.
    SubMatch(std::string text, uint32_t position) : text{std::move(text)}, position{position} {}
    // The matched text.
    const std::string& getText() const { return text; }
    // The position recorded for this submatch.
    uint32_t getPosition() const { return position; }

private:
    std::string text;
    uint32_t position;
};

/**
 * This represents a list of submatches in a single iteration of RE2::Match
 * i.e. matching Regex("([a-z ]+)_?") on "hello_world" will have two
 * submatches, "hello_" and "hello".
 *
 * Index 0 is the whole match; index i is the i-th capturing group. Every accessor that takes an
 * index throws std::runtime_error when the index is not smaller than the number of stored
 * submatches.
 */
class Match {
public:
    // Mutable access to the submatch at `index`.
    SubMatch& getSubMatch(uint64_t index) {
        if (index >= subMatches.size()) {
            throw std::runtime_error("RE2: Match index is out of range");
        }
        return subMatches[index];
    }
    // Read-only access to the submatch at `index`; allows a Match to be inspected through a
    // const reference.
    const SubMatch& getSubMatch(uint64_t index) const {
        if (index >= subMatches.size()) {
            throw std::runtime_error("RE2: Match index is out of range");
        }
        return subMatches[index];
    }
    // Text of the submatch at `index`.
    std::string getText(uint64_t index) const { return getSubMatch(index).getText(); }
    // Position of the submatch at `index`.
    uint64_t getPosition(uint64_t index) const { return getSubMatch(index).getPosition(); }
    // Length in bytes of the text of the submatch at `index`.
    uint64_t length(uint64_t index) const { return getSubMatch(index).getText().size(); }
    // Appends a copy of `group`. Taking a const reference accepts temporaries and const objects
    // in addition to the non-const lvalues accepted before.
    void emplace_back(const SubMatch& group) { subMatches.emplace_back(group); }
    // Removes all stored submatches.
    void clear() { subMatches.clear(); }

private:
    std::vector<SubMatch> subMatches;
};

// Returns true when `regex` matches `input` anchored at both ends.
bool RegexFullMatch(const std::string& input, const Regex& regex);
// Returns true when `regex` matches somewhere in `input` (unanchored).
bool RegexPartialMatch(const std::string& input, const Regex& regex);
// Returns a copy of `input` after applying RE2::Replace with `regex` and `replacement`.
std::string RegexReplace(
const std::string& input, const Regex& regex, const std::string& replacement);
// Performs an unanchored search of `input`; on success `match` holds the submatches. Returns
// whether a match was found.
bool RegexExtract(const std::string& input, Match& match, const Regex& regex);
// Repeats the unanchored search across `input` and returns one Match per match found.
std::vector<Match> RegexExtractAll(const std::string& input, const Regex& regex);
// Replaces each pair of consecutive backslashes in `input` with a single backslash.
std::string ParseCypherRegex(const std::string& input);

} // namespace common
} // namespace kuzu
26 changes: 0 additions & 26 deletions src/include/function/string/operations/reg_expr_operation.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include "common/re2_regex.h"
#include "common/types/ku_list.h"
#include "common/types/ku_string.h"
#include "common/vector/value_vector.h"
#include "common/vector/value_vector_utils.h"

namespace kuzu::function::operation {
// Implements REGEXP_EXTRACT_ALL: finds every match of `pattern` in `value` and stores the text of
// the requested group of each match as one element of a list in `resultVector`.
struct RegexpExtractAll {
    // value:        the string to search.
    // pattern:      the pattern as written in Cypher; it goes through ParseCypherRegex before
    //               being used as a regex.
    // group:        index of the submatch to extract from every match (0 is the whole match).
    //               Match::getText throws std::runtime_error when the index is out of range.
    // result:       receives the list entry added to `resultVector` for this row.
    // resultVector: the list vector that owns the produced elements.
    static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern,
        std::int64_t& group, common::list_entry_t& result, common::ValueVector& resultVector) {
        std::vector<common::Match> matches = common::RegexExtractAll(
            value.getAsString(), common::ParseCypherRegex(pattern.getAsString()));

        result = common::ListVector::addList(&resultVector, matches.size());
        auto resultValues = common::ListVector::getListValues(&resultVector, result);
        auto resultDataVector = common::ListVector::getDataVector(&resultVector);
        auto numBytesPerValue = resultDataVector->getNumBytesPerValue();

        // Iterate by reference: a Match owns a vector of strings, so iterating by value would
        // deep-copy every match only to read one string out of it.
        for (auto& match : matches) {
            common::ku_string_t kuString;
            copyToKuzuString(match.getText(group), kuString, *resultDataVector);
            common::ValueVectorUtils::copyValue(resultValues, *resultDataVector,
                reinterpret_cast<uint8_t*>(&kuString), *resultDataVector);
            // Step to the slot of the next list element.
            resultValues += numBytesPerValue;
        }
    }

    // Overload without an explicit group: extracts group 0 of every match.
    static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern,
        common::list_entry_t& result, common::ValueVector& resultVector) {
        int64_t defaultGroup = 0;
        operation(value, pattern, defaultGroup, result, resultVector);
    }

    // Fills `kuString` with `value`. When `value` is not a short string, space for it is first
    // allocated from the in-memory overflow buffer of `valueVector`.
    static void copyToKuzuString(
        const std::string& value, common::ku_string_t& kuString, common::ValueVector& valueVector) {
        if (!common::ku_string_t::isShortString(value.length())) {
            kuString.overflowPtr = reinterpret_cast<uint64_t>(
                common::StringVector::getInMemOverflowBuffer(&valueVector)
                    ->allocateSpace(value.length()));
        }
        kuString.set(value);
    }
};
} // namespace kuzu::function::operation
Loading

0 comments on commit 13ace76

Please sign in to comment.