Skip to content

Commit

Permalink
Add string regex functions
Browse files Browse the repository at this point in the history
Rename RE_MATCH function to REGEXP_FULL_MATCH

The latter is more descriptive and complaint with
duckdb's naming convention.

Introduce regexp utils based on re2

Refactor regex_full_match implementation

Functions added:

1.regexp_matches(string, regex)
Returns true if a part of string matches the
regex.

2. regexp_replace(string, regex, replacement)
Replaces the first occurrence of regex with the
replacement,

3. regexp_extract(string, regex[, group = 0])
Split the string along the regex and extract
first occurrence of group.

4. regexp_extract_all(string, regex[, group = 0])
Split the string along the regex and extract
all occurrences of group.
  • Loading branch information
gaurav8297 committed May 8, 2023
1 parent 3108a21 commit a682d24
Show file tree
Hide file tree
Showing 13 changed files with 405 additions and 75 deletions.
8 changes: 7 additions & 1 deletion src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,6 @@ void BuiltInVectorOperations::registerStringOperations() {
vectorOperations.insert({CONCAT_FUNC_NAME, ConcatVectorOperation::getDefinitions()});
vectorOperations.insert({CONTAINS_FUNC_NAME, ContainsVectorOperation::getDefinitions()});
vectorOperations.insert({ENDS_WITH_FUNC_NAME, EndsWithVectorOperation::getDefinitions()});
vectorOperations.insert({RE_MATCH_FUNC_NAME, REMatchVectorOperation::getDefinitions()});
vectorOperations.insert({LCASE_FUNC_NAME, LowerVectorOperation::getDefinitions()});
vectorOperations.insert({LEFT_FUNC_NAME, LeftVectorOperation::getDefinitions()});
vectorOperations.insert({LENGTH_FUNC_NAME, LengthVectorOperation::getDefinitions()});
Expand All @@ -378,6 +377,13 @@ void BuiltInVectorOperations::registerStringOperations() {
vectorOperations.insert({TRIM_FUNC_NAME, TrimVectorOperation::getDefinitions()});
vectorOperations.insert({UCASE_FUNC_NAME, UpperVectorOperation::getDefinitions()});
vectorOperations.insert({UPPER_FUNC_NAME, UpperVectorOperation::getDefinitions()});
vectorOperations.insert(
{REGEXP_FULL_MATCH_FUNC_NAME, RegexpFullMatchVectorOperation::getDefinitions()});
vectorOperations.insert({REGEXP_MATCHES_FUNC_NAME, RegexpMatchesOperation::getDefinitions()});
vectorOperations.insert({REGEXP_REPLACE_FUNC_NAME, RegexpReplaceOperation::getDefinitions()});
vectorOperations.insert({REGEXP_EXTRACT_FUNC_NAME, RegexpExtractOperation::getDefinitions()});
vectorOperations.insert(
{REGEXP_EXTRACT_ALL_FUNC_NAME, RegexpExtractAllOperation::getDefinitions()});
}

void BuiltInVectorOperations::registerCastOperations() {
Expand Down
84 changes: 73 additions & 11 deletions src/function/vector_string_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
#include "function/string/operations/left_operation.h"
#include "function/string/operations/length_operation.h"
#include "function/string/operations/lpad_operation.h"
#include "function/string/operations/reg_expr_operation.h"
#include "function/string/operations/regexp_extract_all_operation.h"
#include "function/string/operations/regexp_extract_operation.h"
#include "function/string/operations/regexp_full_match_operation.h"
#include "function/string/operations/regexp_matches_operation.h"
#include "function/string/operations/regexp_replace_operation.h"
#include "function/string/operations/repeat_operation.h"
#include "function/string/operations/right_operation.h"
#include "function/string/operations/rpad_operation.h"
Expand Down Expand Up @@ -58,16 +62,6 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> EndsWithVectorOperation:
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> REMatchVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(RE_MATCH_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, BOOL,
BinaryExecFunction<ku_string_t, ku_string_t, uint8_t, operation::REMatch>,
BinarySelectFunction<ku_string_t, ku_string_t, operation::REMatch>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> LeftVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(LEFT_FUNC_NAME,
Expand Down Expand Up @@ -141,5 +135,73 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> SubStrVectorOperation::g
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
RegexpFullMatchVectorOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_FULL_MATCH_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, BOOL,
BinaryExecFunction<ku_string_t, ku_string_t, uint8_t, operation::RegexpFullMatch>,
BinarySelectFunction<ku_string_t, ku_string_t, operation::RegexpFullMatch>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> RegexpMatchesOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_MATCHES_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, BOOL,
BinaryExecFunction<ku_string_t, ku_string_t, uint8_t, operation::RegexpMatches>,
BinarySelectFunction<ku_string_t, ku_string_t, operation::RegexpMatches>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> RegexpReplaceOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
// Todo: Implement a function with modifiers
// regexp_replace(string, regex, replacement, modifiers)
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_REPLACE_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING, STRING}, STRING,
TernaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, ku_string_t,
operation::RegexpReplace>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>> RegexpExtractOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, STRING,
BinaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, operation::RegexpExtract>,
false /* isVarLength */));
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING, INT64}, STRING,
TernaryStringExecFunction<ku_string_t, ku_string_t, int64_t, ku_string_t,
operation::RegexpExtract>,
false /* isVarLength */));
return definitions;
}

std::vector<std::unique_ptr<VectorOperationDefinition>>
RegexpExtractAllOperation::getDefinitions() {
std::vector<std::unique_ptr<VectorOperationDefinition>> definitions;
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING}, VAR_LIST,
BinaryStringExecFunction<ku_string_t, ku_string_t, list_entry_t,
operation::RegexpExtractAll>,
nullptr, bindFunc, false /* isVarLength */));
definitions.emplace_back(make_unique<VectorOperationDefinition>(REGEXP_EXTRACT_FUNC_NAME,
std::vector<DataTypeID>{STRING, STRING, INT64}, VAR_LIST,
TernaryStringExecFunction<ku_string_t, ku_string_t, int64_t, list_entry_t,
operation::RegexpExtractAll>,
nullptr, bindFunc, false /* isVarLength */));
return definitions;
}

std::unique_ptr<FunctionBindData> RegexpExtractAllOperation::bindFunc(
const binder::expression_vector& arguments, FunctionDefinition* definition) {
return std::make_unique<FunctionBindData>(DataType(std::make_unique<DataType>(STRING)));
}

} // namespace function
} // namespace kuzu
6 changes: 5 additions & 1 deletion src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ const std::string ARRAY_EXTRACT_FUNC_NAME = "ARRAY_EXTRACT";
const std::string CONCAT_FUNC_NAME = "CONCAT";
const std::string CONTAINS_FUNC_NAME = "CONTAINS";
const std::string ENDS_WITH_FUNC_NAME = "ENDS_WITH";
const std::string RE_MATCH_FUNC_NAME = "RE_MATCH";
const std::string LCASE_FUNC_NAME = "LCASE";
const std::string LEFT_FUNC_NAME = "LEFT";
const std::string LENGTH_FUNC_NAME = "LENGTH";
Expand All @@ -137,6 +136,11 @@ const std::string SUFFIX_FUNC_NAME = "SUFFIX";
const std::string TRIM_FUNC_NAME = "TRIM";
const std::string UCASE_FUNC_NAME = "UCASE";
const std::string UPPER_FUNC_NAME = "UPPER";
const std::string REGEXP_FULL_MATCH_FUNC_NAME = "REGEXP_FULL_MATCH";
const std::string REGEXP_MATCHES_FUNC_NAME = "REGEXP_MATCHES";
const std::string REGEXP_REPLACE_FUNC_NAME = "REGEXP_REPLACE";
const std::string REGEXP_EXTRACT_FUNC_NAME = "REGEXP_EXTRACT";
const std::string REGEXP_EXTRACT_ALL_FUNC_NAME = "REGEXP_EXTRACT_ALL";

// Date functions.
const std::string DATE_PART_FUNC_NAME = "DATE_PART";
Expand Down
32 changes: 32 additions & 0 deletions src/include/function/string/operations/base_regexp_operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once

#include <regex>

#include "common/vector/value_vector.h"

namespace kuzu {
namespace function {
namespace operation {

struct BaseRegexpOperation {
static inline std::string parseCypherPatten(const std::string& pattern) {
// Cypher parses escape characters with 2 backslash eg. for expressing '.' requires '\\.'
// Since Regular Expression requires only 1 backslash '\.' we need to replace double slash
// with single
return std::regex_replace(pattern, std::regex(R"(\\\\)"), "\\");
}

static inline void copyToKuzuString(
const std::string& value, common::ku_string_t& kuString, common::ValueVector& valueVector) {
if (!common::ku_string_t::isShortString(value.length())) {
kuString.overflowPtr = reinterpret_cast<uint64_t>(
common::StringVector::getInMemOverflowBuffer(&valueVector)
->allocateSpace(value.length()));
}
kuString.set(value);
}
};

} // namespace operation
} // namespace function
} // namespace kuzu
26 changes: 0 additions & 26 deletions src/include/function/string/operations/reg_expr_operation.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#pragma once

#include "base_regexp_operation.h"
#include "common/vector/value_vector_utils.h"
#include "re2.h"

namespace kuzu {
namespace function {
namespace operation {

struct RegexpExtractAll : BaseRegexpOperation {
static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern,
std::int64_t& group, common::list_entry_t& result, common::ValueVector& resultVector) {
std::vector<std::string> matches =
regexExtractAll(value.getAsString(), pattern.getAsString(), group);
result = common::ListVector::addList(&resultVector, matches.size());
auto resultValues = common::ListVector::getListValues(&resultVector, result);
auto resultDataVector = common::ListVector::getDataVector(&resultVector);
auto numBytesPerValue = resultDataVector->getNumBytesPerValue();
for (const auto& match : matches) {
common::ku_string_t kuString;
copyToKuzuString(match, kuString, *resultDataVector);
common::ValueVectorUtils::copyValue(resultValues, *resultDataVector,
reinterpret_cast<uint8_t*>(&kuString), *resultDataVector);
resultValues += numBytesPerValue;
}
}

static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern,
common::list_entry_t& result, common::ValueVector& resultVector) {
int64_t defaultGroup = 0;
operation(value, pattern, defaultGroup, result, resultVector);
}

static std::vector<std::string> regexExtractAll(
const std::string& value, const std::string& pattern, std::int64_t& group) {
RE2 regex(parseCypherPatten(pattern));
auto submatchCount = regex.NumberOfCapturingGroups() + 1;
if (group > submatchCount) {
throw common::RuntimeException("Regex match group index is out of range");
}

regex::StringPiece input(value);
std::vector<regex::StringPiece> targetSubMatches;
targetSubMatches.resize(submatchCount);
uint64_t startPos = 0;

std::vector<std::string> matches;
while (regex.Match(input, startPos, input.length(), RE2::UNANCHORED,
targetSubMatches.data(), submatchCount)) {
uint64_t consumed =
static_cast<size_t>(targetSubMatches[0].end() - (input.begin() + startPos));
if (!consumed) {
// Empty match found, increment the position manually
consumed++;
while (startPos + consumed < input.length() &&
!IsCharacter(input[startPos + consumed])) {
consumed++;
}
}
startPos += consumed;
matches.emplace_back(targetSubMatches[group]);
}

return matches;
}

static inline bool IsCharacter(char c) {
// Check if this character is not the middle of utf-8 character i.e. it shouldn't begin with
// 10 XX XX XX
return (c & 0xc0) != 0x80;
}
};

} // namespace operation
} // namespace function
} // namespace kuzu
46 changes: 46 additions & 0 deletions src/include/function/string/operations/regexp_extract_operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#pragma once

#include "common/types/ku_string.h"
#include "common/vector/value_vector.h"
#include "re2.h"

namespace kuzu {
namespace function {
namespace operation {

struct RegexpExtract : BaseRegexpOperation {
static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern,
std::int64_t& group, common::ku_string_t& result, common::ValueVector& resultValueVector) {
regexExtract(value.getAsString(), pattern.getAsString(), group, result, resultValueVector);
}

static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern,
common::ku_string_t& result, common::ValueVector& resultValueVector) {
int64_t defaultGroup = 0;
regexExtract(
value.getAsString(), pattern.getAsString(), defaultGroup, result, resultValueVector);
}

static void regexExtract(const std::string& input, const std::string& pattern,
std::int64_t& group, common::ku_string_t& result, common::ValueVector& resultValueVector) {
RE2 regex(parseCypherPatten(pattern));
auto submatchCount = regex.NumberOfCapturingGroups() + 1;
if (group > submatchCount) {
throw common::RuntimeException("Regex match group index is out of range");
}

std::vector<regex::StringPiece> targetSubMatches;
targetSubMatches.resize(submatchCount);

if (!regex.Match(regex::StringPiece(input), 0, input.length(), RE2::UNANCHORED,
targetSubMatches.data(), submatchCount)) {
return;
}

copyToKuzuString(targetSubMatches[group].ToString(), result, resultValueVector);
}
};

} // namespace operation
} // namespace function
} // namespace kuzu
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include "common/types/ku_string.h"
#include "re2.h"

namespace kuzu {
namespace function {
namespace operation {

struct RegexpFullMatch : BaseRegexpOperation {
static inline void operation(
common::ku_string_t& left, common::ku_string_t& right, uint8_t& result) {
result = RE2::FullMatch(left.getAsString(), parseCypherPatten(right.getAsString()));
}
};

} // namespace operation
} // namespace function
} // namespace kuzu
19 changes: 19 additions & 0 deletions src/include/function/string/operations/regexp_matches_operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include "common/types/ku_string.h"
#include "re2.h"

namespace kuzu {
namespace function {
namespace operation {

struct RegexpMatches : BaseRegexpOperation {
static inline void operation(
common::ku_string_t& left, common::ku_string_t& right, uint8_t& result) {
result = RE2::PartialMatch(left.getAsString(), parseCypherPatten(right.getAsString()));
}
};

} // namespace operation
} // namespace function
} // namespace kuzu
23 changes: 23 additions & 0 deletions src/include/function/string/operations/regexp_replace_operation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#pragma once

#include "common/types/ku_string.h"
#include "re2.h"

namespace kuzu {
namespace function {
namespace operation {

struct RegexpReplace : BaseRegexpOperation {
static inline void operation(common::ku_string_t& value, common::ku_string_t& pattern,
common::ku_string_t& replacement, common::ku_string_t& result,
common::ValueVector& resultValueVector) {
std::string resultStr = value.getAsString();
RE2::Replace(
&resultStr, parseCypherPatten(pattern.getAsString()), replacement.getAsString());
copyToKuzuString(resultStr, result, resultValueVector);
}
};

} // namespace operation
} // namespace function
} // namespace kuzu
Loading

0 comments on commit a682d24

Please sign in to comment.