Skip to content

Commit

Permalink
Implement union functions
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Jun 14, 2023
1 parent 904c585 commit d4dfe64
Show file tree
Hide file tree
Showing 34 changed files with 2,969 additions and 2,442 deletions.
5 changes: 4 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,14 @@ oC_ParenthesizedExpression

oC_FunctionInvocation
: oC_FunctionName SP? '(' SP? '*' SP? ')'
| oC_FunctionName SP? '(' SP? ( DISTINCT SP? )? ( oC_Expression SP? ( ',' SP? oC_Expression SP? )* )? ')' ;
| oC_FunctionName SP? '(' SP? ( DISTINCT SP? )? ( kU_FunctionParameter SP? ( ',' SP? kU_FunctionParameter SP? )* )? ')' ;

oC_FunctionName
: oC_SymbolicName ;

kU_FunctionParameter
: ( oC_SymbolicName SP? ':' '=' SP? )? oC_Expression ;

oC_ExistentialSubquery
: EXISTS SP? '{' SP? MATCH SP? oC_Pattern ( SP? oC_Where )? SP? '}' ;

Expand Down
15 changes: 15 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ void LogicalType::setPhysicalType() {
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::UNION:
case LogicalTypeID::STRUCT: {
physicalType = PhysicalTypeID::STRUCT;
} break;
Expand Down Expand Up @@ -338,6 +339,18 @@ std::string LogicalTypeUtils::dataTypeToString(const LogicalType& dataType) {
return dataTypeToString(*fixedListTypeInfo->getChildType()) + "[" +
std::to_string(fixedListTypeInfo->getNumElementsInList()) + "]";
}
case LogicalTypeID::UNION: {

Check warning on line 342 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L342

Added line #L342 was not covered by tests
auto unionTypeInfo = reinterpret_cast<StructTypeInfo*>(dataType.extraTypeInfo.get());
std::string dataTypeStr = dataTypeToString(dataType.typeID) + "(";
auto numFields = unionTypeInfo->getChildrenTypes().size();
auto fieldNames = unionTypeInfo->getChildrenNames();
for (auto i = 1u; i < numFields; i++) {
dataTypeStr += fieldNames[i] + ":";
dataTypeStr += dataTypeToString(*unionTypeInfo->getChildrenTypes()[i]);
dataTypeStr += (i == numFields - 1 ? ")" : ", ");

Check warning on line 350 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L344-L350

Added lines #L344 - L350 were not covered by tests
}
return dataTypeStr;
}

Check warning on line 353 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L353

Added line #L353 was not covered by tests
case LogicalTypeID::STRUCT: {
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(dataType.extraTypeInfo.get());
std::string dataTypeStr = dataTypeToString(dataType.typeID) + "(";
Expand Down Expand Up @@ -416,6 +429,8 @@ std::string LogicalTypeUtils::dataTypeToString(LogicalTypeID dataTypeID) {
return "SERIAL";
case LogicalTypeID::MAP:
return "MAP";
case LogicalTypeID::UNION:
return "UNION";

Check warning on line 433 in src/common/types/types.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/types.cpp#L433

Added line #L433 was not covered by tests
default:
throw NotImplementedException("LogicalTypeUtils::dataTypeToString.");
}
Expand Down
67 changes: 54 additions & 13 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Value Value::createDefaultValue(const LogicalType& dataType) {
case LogicalTypeID::MAP:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::UNION:
case LogicalTypeID::STRUCT:
return Value(dataType, std::vector<std::unique_ptr<Value>>{});
default:
Expand Down Expand Up @@ -147,42 +148,52 @@ Value::Value(const Value& other) : dataType{other.dataType}, isNull_{other.isNul
}

void Value::copyValueFrom(const uint8_t* value) {
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::INT64: {
switch (dataType.getLogicalTypeID()) {
case LogicalTypeID::SERIAL:
case LogicalTypeID::TIMESTAMP:
case LogicalTypeID::INT64: {
val.int64Val = *((int64_t*)value);
} break;
case PhysicalTypeID::INT32: {
case LogicalTypeID::DATE:
case LogicalTypeID::INT32: {
val.int32Val = *((int32_t*)value);
} break;
case PhysicalTypeID::INT16: {
case LogicalTypeID::INT16: {
val.int16Val = *((int16_t*)value);
} break;
case PhysicalTypeID::BOOL: {
case LogicalTypeID::BOOL: {
val.booleanVal = *((bool*)value);
} break;
case PhysicalTypeID::DOUBLE: {
case LogicalTypeID::DOUBLE: {
val.doubleVal = *((double*)value);
} break;
case PhysicalTypeID::FLOAT: {
case LogicalTypeID::FLOAT: {
val.floatVal = *((float_t*)value);
} break;
case PhysicalTypeID::INTERVAL: {
case LogicalTypeID::INTERVAL: {
val.intervalVal = *((interval_t*)value);
} break;
case PhysicalTypeID::INTERNAL_ID: {
case LogicalTypeID::INTERNAL_ID: {
val.internalIDVal = *((nodeID_t*)value);
} break;
case PhysicalTypeID::STRING: {
case LogicalTypeID::STRING: {
strVal = ((ku_string_t*)value)->getAsString();
} break;
case PhysicalTypeID::VAR_LIST: {
case LogicalTypeID::MAP:
case LogicalTypeID::VAR_LIST: {
nestedTypeVal =
convertKUVarListToVector(*(ku_list_t*)value, *VarListType::getChildType(&dataType));
} break;
case PhysicalTypeID::FIXED_LIST: {
case LogicalTypeID::FIXED_LIST: {
nestedTypeVal = convertKUFixedListToVector(value);
} break;
case PhysicalTypeID::STRUCT: {
case LogicalTypeID::UNION: {
nestedTypeVal = convertKUUnionToVector(value);
} break;
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::STRUCT: {
nestedTypeVal = convertKUStructToVector(value);
} break;
default:
Expand Down Expand Up @@ -305,6 +316,11 @@ std::string Value::toString() const {
result += "]";
return result;
}
case LogicalTypeID::UNION: {
// Only one member in the union can be active at a time and that member is always stored
// at index 0.
return nestedTypeVal[0]->toString();
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::STRUCT: {
std::string result = "{";
Expand Down Expand Up @@ -402,6 +418,31 @@ std::vector<std::unique_ptr<Value>> Value::convertKUStructToVector(const uint8_t
return structVal;
}

std::vector<std::unique_ptr<Value>> Value::convertKUUnionToVector(const uint8_t* kuUnion) const {
std::vector<std::unique_ptr<Value>> unionVal;
auto childrenTypes = StructType::getFieldTypes(&dataType);
auto unionNullValues = kuUnion;
auto unionValues = unionNullValues + NullBuffer::getNumBytesForNullValues(childrenTypes.size());
// For union dataType, only one member can be active at a time. So we don't need to copy all
// union fields into value.
auto activeMemberIdx = UnionType::getInternalFieldIdx(*(union_field_idx_t*)unionValues);
auto childValue =
std::make_unique<Value>(Value::createDefaultValue(*childrenTypes[activeMemberIdx]));
auto curMemberIdx = 0u;
// Seek to the current active member value.
while (curMemberIdx < activeMemberIdx) {
unionValues += storage::StorageUtils::getDataTypeSize(*childrenTypes[curMemberIdx]);
curMemberIdx++;
}
if (NullBuffer::isNull(unionNullValues, activeMemberIdx)) {
childValue->setNull(true);

Check warning on line 438 in src/common/types/value.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/value.cpp#L438

Added line #L438 was not covered by tests
} else {
childValue->copyValueFrom(unionValues);
}
unionVal.emplace_back(std::move(childValue));
return unionVal;
}

static std::string propertiesToString(
const std::vector<std::pair<std::string, std::unique_ptr<Value>>>& properties) {
std::string result = "{";
Expand Down
4 changes: 2 additions & 2 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void ValueVector::resetAuxiliaryBuffer() {
case PhysicalTypeID::STRUCT: {
auto structAuxiliaryBuffer =
reinterpret_cast<StructAuxiliaryBuffer*>(auxiliaryBuffer.get());
for (auto& vector : structAuxiliaryBuffer->getChildrenVectors()) {
for (auto& vector : structAuxiliaryBuffer->getFieldVectors()) {
vector->resetAuxiliaryBuffer();
}
return;
Expand Down Expand Up @@ -105,7 +105,7 @@ uint32_t ValueVector::getDataTypeSize(const LogicalType& type) {

void ValueVector::initializeValueBuffer() {
valueBuffer = std::make_unique<uint8_t[]>(numBytesPerValue * DEFAULT_VECTOR_CAPACITY);
if (dataType.getLogicalTypeID() == LogicalTypeID::STRUCT) {
if (dataType.getPhysicalType() == PhysicalTypeID::STRUCT) {
// For struct valueVectors, each struct_entry_t stores its current position in the
// valueVector.
StructVector::initializeEntries(this);
Expand Down
3 changes: 3 additions & 0 deletions src/expression_evaluator/base_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ void BaseExpressionEvaluator::init(

void BaseExpressionEvaluator::resolveResultStateFromChildren(
const std::vector<BaseExpressionEvaluator*>& inputEvaluators) {
if (resultVector->state != nullptr) {
return;
}
for (auto& input : inputEvaluators) {
if (!input->isResultFlat()) {
isResultFlat_ = false;
Expand Down
17 changes: 15 additions & 2 deletions src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ void FunctionExpressionEvaluator::evaluate() {
for (auto& child : children) {
child->evaluate();
}
execFunc(parameters, *resultVector);
if (execFunc != nullptr) {
// Some functions are evaluated at compile time (e.g. struct_extract).
execFunc(parameters, *resultVector);
}
}

bool FunctionExpressionEvaluator::select(SelectionVector& selVector) {
Expand Down Expand Up @@ -61,18 +64,21 @@ std::unique_ptr<BaseExpressionEvaluator> FunctionExpressionEvaluator::clone() {
void FunctionExpressionEvaluator::resolveResultVector(
const ResultSet& resultSet, MemoryManager* memoryManager) {
auto& functionExpression = (binder::ScalarFunctionExpression&)*expression;
if (functionExpression.getFunctionName() == STRUCT_EXTRACT_FUNC_NAME) {
auto functionName = functionExpression.getFunctionName();
if (functionName == STRUCT_EXTRACT_FUNC_NAME || functionName == UNION_EXTRACT_FUNC_NAME) {
auto& bindData = (function::StructExtractBindData&)*functionExpression.getBindData();
resultVector =
StructVector::getFieldVector(children[0]->resultVector.get(), bindData.childIdx);
} else {
resultVector = std::make_shared<ValueVector>(expression->dataType, memoryManager);
}
std::vector<BaseExpressionEvaluator*> inputEvaluators;
inputEvaluators.reserve(children.size());
for (auto& child : children) {
inputEvaluators.push_back(child.get());
}
resolveResultStateFromChildren(inputEvaluators);
// TODO(Ziyi): We should move result valueVector state resolution to each function.
if (functionExpression.getFunctionName() == STRUCT_PACK_FUNC_NAME) {
// Our goal is to make the state of the resultVector consistent with its children vectors.
// If the resultVector and inputVector are in different dataChunks, we should create a new
Expand All @@ -85,6 +91,13 @@ void FunctionExpressionEvaluator::resolveResultVector(
resultVector.get(), i, inputEvaluator->resultVector);
}
}
} else if (functionExpression.getFunctionName() == UNION_VALUE_FUNC_NAME) {
assert(inputEvaluators.size() == 1);
resultVector->setState(inputEvaluators[0]->resultVector->state);
common::UnionVector::getTagVector(resultVector.get())
->setState(inputEvaluators[0]->resultVector->state);
common::UnionVector::referenceVector(
resultVector.get(), 0 /* fieldIdx */, inputEvaluators[0]->resultVector);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ add_library(kuzu_function
vector_string_operations.cpp
vector_timestamp_operations.cpp
vector_struct_operations.cpp
vector_map_operation.cpp)
vector_map_operation.cpp
vector_union_operations.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_function>
Expand Down
9 changes: 9 additions & 0 deletions src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "function/string/vector_string_operations.h"
#include "function/struct/vector_struct_operations.h"
#include "function/timestamp/vector_timestamp_operations.h"
#include "function/union/vector_union_operations.h"

using namespace kuzu::common;

Expand All @@ -28,6 +29,7 @@ void BuiltInVectorOperations::registerVectorOperations() {
registerListOperations();
registerStructOperations();
registerMapOperations();
registerUnionOperations();
// register internal offset operation
vectorOperations.insert({OFFSET_FUNC_NAME, OffsetVectorOperation::getDefinitions()});
}
Expand Down Expand Up @@ -496,5 +498,12 @@ void BuiltInVectorOperations::registerMapOperations() {
vectorOperations.insert({MAP_VALUES_FUNC_NAME, MapValuesVectorOperations::getDefinitions()});
}

void BuiltInVectorOperations::registerUnionOperations() {
vectorOperations.insert({UNION_VALUE_FUNC_NAME, UnionValueVectorOperations::getDefinitions()});
vectorOperations.insert({UNION_TAG_FUNC_NAME, UnionTagVectorOperations::getDefinitions()});
vectorOperations.insert(
{UNION_EXTRACT_FUNC_NAME, UnionExtractVectorOperations::getDefinitions()});
}

} // namespace function
} // namespace kuzu
Loading

0 comments on commit d4dfe64

Please sign in to comment.