Skip to content

Commit

Permalink
Implement union functions
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Jun 13, 2023
1 parent 708ac89 commit 5536970
Show file tree
Hide file tree
Showing 26 changed files with 2,919 additions and 2,413 deletions.
8 changes: 7 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,17 @@ oC_ParenthesizedExpression

oC_FunctionInvocation
: oC_FunctionName SP? '(' SP? '*' SP? ')'
| oC_FunctionName SP? '(' SP? ( DISTINCT SP? )? ( oC_Expression SP? ( ',' SP? oC_Expression SP? )* )? ')' ;
| oC_FunctionName SP? '(' SP? ( DISTINCT SP? )? ( oC_FunctionParameter SP? ( ',' SP? oC_FunctionParameter SP? )* )? ')' ;

// Name of the function in a function invocation. Any symbolic name is
// accepted by this rule.
oC_FunctionName
: oC_SymbolicName ;

// A single argument of a function invocation: an expression, optionally
// prefixed with a parameter name followed by the tokens ':' '='.
// Optional whitespace (SP) is permitted after the name and after '=', but this
// rule allows no SP token between ':' and '='.
oC_FunctionParameter
: ( oC_FunctionParameterName SP? ':' '=' SP? )? oC_Expression ;

// Name used to label a function argument (the optional prefix in
// oC_FunctionParameter). Any symbolic name is accepted by this rule.
oC_FunctionParameterName
: oC_SymbolicName ;

// Existential subquery: the EXISTS keyword followed by a brace-delimited body
// containing MATCH, a pattern, and an optional WHERE clause. Whitespace (SP)
// between the elements is optional.
oC_ExistentialSubquery
: EXISTS SP? '{' SP? MATCH SP? oC_Pattern ( SP? oC_Where )? SP? '}' ;

Expand Down
15 changes: 15 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ void LogicalType::setPhysicalType() {
} break;
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::UNION:
case LogicalTypeID::STRUCT: {
physicalType = PhysicalTypeID::STRUCT;
} break;
Expand Down Expand Up @@ -338,6 +339,18 @@ std::string LogicalTypeUtils::dataTypeToString(const LogicalType& dataType) {
return dataTypeToString(*fixedListTypeInfo->getChildType()) + "[" +
std::to_string(fixedListTypeInfo->getNumElementsInList()) + "]";
}
case LogicalTypeID::UNION: {

Check warning on line 342 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L342

Added line #L342 was not covered by tests
auto unionTypeInfo = reinterpret_cast<StructTypeInfo*>(dataType.extraTypeInfo.get());
std::string dataTypeStr = dataTypeToString(dataType.typeID) + "(";
auto numFields = unionTypeInfo->getChildrenTypes().size();
auto fieldNames = unionTypeInfo->getChildrenNames();
for (auto i = 1u; i < numFields; i++) {
dataTypeStr += fieldNames[i] + ":";
dataTypeStr += dataTypeToString(*unionTypeInfo->getChildrenTypes()[i]);
dataTypeStr += (i == numFields - 1 ? ")" : ", ");

Check warning on line 350 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L344-L350

Added lines #L344 - L350 were not covered by tests
}
return dataTypeStr;
}

Check warning on line 353 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L353

Added line #L353 was not covered by tests
case LogicalTypeID::STRUCT: {
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(dataType.extraTypeInfo.get());
std::string dataTypeStr = dataTypeToString(dataType.typeID) + "(";
Expand Down Expand Up @@ -416,6 +429,8 @@ std::string LogicalTypeUtils::dataTypeToString(LogicalTypeID dataTypeID) {
return "SERIAL";
case LogicalTypeID::MAP:
return "MAP";
case LogicalTypeID::UNION:
return "UNION";

Check warning on line 433 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L433

Added line #L433 was not covered by tests
default:
throw NotImplementedException("LogicalTypeUtils::dataTypeToString.");
}
Expand Down
33 changes: 29 additions & 4 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Value Value::createDefaultValue(const LogicalType& dataType) {
case LogicalTypeID::MAP:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::UNION:
case LogicalTypeID::STRUCT:
return Value(dataType, std::vector<std::unique_ptr<Value>>{});
default:
Expand Down Expand Up @@ -306,6 +307,10 @@ std::string Value::toString() const {
result += "]";
return result;
}
case LogicalTypeID::UNION: {
// Only one member in the union can be active at a time.
return nestedTypeVal[0]->toString();
}
case LogicalTypeID::STRUCT: {
std::string result = "{";
auto fieldNames = StructType::getFieldNames(&dataType);
Expand Down Expand Up @@ -389,15 +394,35 @@ std::vector<std::unique_ptr<Value>> Value::convertKUStructToVector(const uint8_t
auto numFields = childrenTypes.size();
auto structNullValues = kuStruct;
auto structValues = structNullValues + NullBuffer::getNumBytesForNullValues(numFields);
for (auto i = 0; i < numFields; i++) {
auto childValue = std::make_unique<Value>(Value::createDefaultValue(*childrenTypes[i]));
if (NullBuffer::isNull(structNullValues, i)) {
if (dataType.getLogicalTypeID() == LogicalTypeID::UNION) {
// For union dataType, only one member can be active at a time. So we don't need to copy all
// union fields into value.
auto activeMemberIdx = *(union_field_idx_t*)structValues + 1;
auto childValue =
std::make_unique<Value>(Value::createDefaultValue(*childrenTypes[activeMemberIdx]));
auto curMemberIdx = 0u;
// Seek to the current active member value.
while (curMemberIdx < activeMemberIdx) {
structValues += storage::StorageUtils::getDataTypeSize(*childrenTypes[curMemberIdx]);
curMemberIdx++;
}
if (NullBuffer::isNull(structNullValues, activeMemberIdx)) {
childValue->setNull(true);
} else {
childValue->copyValueFrom(structValues);
}
structVal.emplace_back(std::move(childValue));
structValues += storage::StorageUtils::getDataTypeSize(*childrenTypes[i]);
} else {
for (auto i = 0; i < numFields; i++) {
auto childValue = std::make_unique<Value>(Value::createDefaultValue(*childrenTypes[i]));
if (NullBuffer::isNull(structNullValues, i)) {
childValue->setNull(true);
} else {
childValue->copyValueFrom(structValues);
}
structVal.emplace_back(std::move(childValue));
structValues += storage::StorageUtils::getDataTypeSize(*childrenTypes[i]);
}
}
return structVal;
}
Expand Down
3 changes: 3 additions & 0 deletions src/expression_evaluator/base_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ void BaseExpressionEvaluator::init(

void BaseExpressionEvaluator::resolveResultStateFromChildren(
const std::vector<BaseExpressionEvaluator*>& inputEvaluators) {
if (resultVector->state != nullptr) {
return;
}
for (auto& input : inputEvaluators) {
if (!input->isResultFlat()) {
isResultFlat_ = false;
Expand Down
17 changes: 15 additions & 2 deletions src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ void FunctionExpressionEvaluator::evaluate() {
for (auto& child : children) {
child->evaluate();
}
execFunc(parameters, *resultVector);
if (execFunc != nullptr) {
// Some functions are evaluated at compile time (e.g. struct_extract).
execFunc(parameters, *resultVector);
}
}

bool FunctionExpressionEvaluator::select(SelectionVector& selVector) {
Expand Down Expand Up @@ -61,18 +64,21 @@ std::unique_ptr<BaseExpressionEvaluator> FunctionExpressionEvaluator::clone() {
void FunctionExpressionEvaluator::resolveResultVector(
const ResultSet& resultSet, MemoryManager* memoryManager) {
auto& functionExpression = (binder::ScalarFunctionExpression&)*expression;
if (functionExpression.getFunctionName() == STRUCT_EXTRACT_FUNC_NAME) {
auto functionName = functionExpression.getFunctionName();
if (functionName == STRUCT_EXTRACT_FUNC_NAME || functionName == UNION_EXTRACT_FUNC_NAME) {
auto& bindData = (function::StructExtractBindData&)*functionExpression.getBindData();
resultVector =
StructVector::getFieldVector(children[0]->resultVector.get(), bindData.childIdx);
} else {
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
}
std::vector<BaseExpressionEvaluator*> inputEvaluators;
inputEvaluators.reserve(children.size());
for (auto& child : children) {
inputEvaluators.push_back(child.get());
}
resolveResultStateFromChildren(inputEvaluators);
// TODO(Ziyi): We should move result valueVector state resolution to each function.
if (functionExpression.getFunctionName() == STRUCT_PACK_FUNC_NAME) {
// Our goal is to make the state of the resultVector consistent with its children vectors.
// If the resultVector and inputVector are in different dataChunks, we should create a new
Expand All @@ -85,6 +91,13 @@ void FunctionExpressionEvaluator::resolveResultVector(
resultVector.get(), i, inputEvaluator->resultVector);
}
}
} else if (functionExpression.getFunctionName() == UNION_VALUE_FUNC_NAME) {
assert(inputEvaluators.size() == 1);
resultVector->setState(inputEvaluators[0]->resultVector->state);
common::UnionVector::getTagVector(resultVector.get())
->setState(inputEvaluators[0]->resultVector->state);
common::UnionVector::referenceVector(
resultVector.get(), 0 /* fieldIdx */, inputEvaluators[0]->resultVector);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ add_library(kuzu_function
vector_string_operations.cpp
vector_timestamp_operations.cpp
vector_struct_operations.cpp
vector_map_operation.cpp)
vector_map_operation.cpp
vector_union_operations.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_function>
Expand Down
9 changes: 9 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "function/string/vector_string_operations.h"
#include "function/struct/vector_struct_operations.h"
#include "function/timestamp/vector_timestamp_operations.h"
#include "function/union/vector_union_operations.h"

using namespace kuzu::common;

Expand All @@ -28,6 +29,7 @@ void BuiltInVectorOperations::registerVectorOperations() {
registerListOperations();
registerStructOperations();
registerMapOperations();
registerUnionOperations();
// register internal offset operation
vectorOperations.insert({OFFSET_FUNC_NAME, OffsetVectorOperation::getDefinitions()});
}
Expand Down Expand Up @@ -496,5 +498,12 @@ void BuiltInVectorOperations::registerMapOperations() {
vectorOperations.insert({MAP_VALUES_FUNC_NAME, MapValuesVectorOperations::getDefinitions()});
}

// Registers the union built-in functions in `vectorOperations`, keyed by function name.
// Each of UNION_VALUE_FUNC_NAME, UNION_TAG_FUNC_NAME and UNION_EXTRACT_FUNC_NAME is inserted
// together with the definitions returned by the corresponding *VectorOperations class.
// Invoked from registerVectorOperations() alongside the other register*Operations() helpers.
void BuiltInVectorOperations::registerUnionOperations() {
vectorOperations.insert({UNION_VALUE_FUNC_NAME, UnionValueVectorOperations::getDefinitions()});
vectorOperations.insert({UNION_TAG_FUNC_NAME, UnionTagVectorOperations::getDefinitions()});
vectorOperations.insert(
{UNION_EXTRACT_FUNC_NAME, UnionExtractVectorOperations::getDefinitions()});
}

} // namespace function
} // namespace kuzu
Loading

0 comments on commit 5536970

Please sign in to comment.