Skip to content

Commit

Permalink
Add recursive rel data type
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed May 20, 2023
1 parent bd235a4 commit 3d680f2
Show file tree
Hide file tree
Showing 13 changed files with 51 additions and 32 deletions.
7 changes: 2 additions & 5 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,8 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
// bind variable length
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto isVariableLength = !(lowerBound == 1 && upperBound == 1);
auto dataType = isVariableLength ?
common::LogicalType(LogicalTypeID::VAR_LIST,
std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID))) :
common::LogicalType(common::LogicalTypeID::REL);
auto dataType = isVariableLength ? common::LogicalType(common::LogicalTypeID::RECURSIVE_REL) :
common::LogicalType(common::LogicalTypeID::REL);
auto queryRel = make_shared<RelExpression>(std::move(dataType),
getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode,
relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound);
Expand Down
8 changes: 0 additions & 8 deletions src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,6 @@ static std::unordered_map<table_id_t, property_id_t> populatePropertyIDPerTable(
std::shared_ptr<Expression> ExpressionBinder::bindRelPropertyExpression(
const Expression& expression, const std::string& propertyName) {
auto& rel = (RelExpression&)expression;
switch (rel.getRelType()) {
case common::QueryRelType::VARIABLE_LENGTH:
case common::QueryRelType::SHORTEST:
throw BinderException(
"Cannot read property of variable length rel " + rel.toString() + ".");
default:
break;
}
if (!rel.hasPropertyExpression(propertyName)) {
throw BinderException(
"Cannot find property " + propertyName + " for " + expression.toString() + ".");
Expand Down
10 changes: 6 additions & 4 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ std::string LogicalTypeUtils::dataTypeToString(const LogicalType& dataType) {
case LogicalTypeID::ANY:
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::INTERNAL_ID:
case LogicalTypeID::BOOL:
case LogicalTypeID::INT64:
Expand All @@ -317,8 +318,7 @@ std::string LogicalTypeUtils::dataTypeToString(const LogicalType& dataType) {
case LogicalTypeID::STRING:
return dataTypeToString(dataType.typeID);
default:
throw NotImplementedException(
"Unsupported DataType: " + LogicalTypeUtils::dataTypeToString(dataType) + ".");
throw NotImplementedException("LogicalTypeUtils::dataTypeToString.");
}
}

Expand All @@ -330,6 +330,8 @@ std::string LogicalTypeUtils::dataTypeToString(LogicalTypeID dataTypeID) {
return "NODE";
case LogicalTypeID::REL:
return "REL";
case LogicalTypeID::RECURSIVE_REL:
return "RECURSIVE_REL";
case LogicalTypeID::INTERNAL_ID:
return "INTERNAL_ID";
case LogicalTypeID::BOOL:
Expand Down Expand Up @@ -361,8 +363,7 @@ std::string LogicalTypeUtils::dataTypeToString(LogicalTypeID dataTypeID) {
case LogicalTypeID::SERIAL:
return "SERIAL";
default:
throw NotImplementedException(
"Unsupported DataType: " + LogicalTypeUtils::dataTypeToString(dataTypeID) + ".");
throw NotImplementedException("LogicalTypeUtils::dataTypeToString.");
}
}

Expand Down Expand Up @@ -460,6 +461,7 @@ void LogicalType::setPhysicalType() {
case LogicalTypeID::STRING: {
physicalType = PhysicalTypeID::STRING;
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
physicalType = PhysicalTypeID::VAR_LIST;
} break;
Expand Down
18 changes: 13 additions & 5 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Value Value::createDefaultValue(const LogicalType& dataType) {
return Value(std::string(""));
case LogicalTypeID::FLOAT:
return Value((float_t)0);
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::STRUCT:
Expand Down Expand Up @@ -180,8 +181,13 @@ void Value::copyValueFrom(const uint8_t* value) {
case LogicalTypeID::STRING: {
strVal = ((ku_string_t*)value)->getAsString();
} break;
case LogicalTypeID::RECURSIVE_REL: {
nestedTypeVal =
convertKUVarListToVector(*(ku_list_t*)value, LogicalType(LogicalTypeID::INTERNAL_ID));
} break;
case LogicalTypeID::VAR_LIST: {
nestedTypeVal = convertKUVarListToVector(*(ku_list_t*)value);
nestedTypeVal =
convertKUVarListToVector(*(ku_list_t*)value, *VarListType::getChildType(&dataType));
} break;
case LogicalTypeID::FIXED_LIST: {
nestedTypeVal = convertKUFixedListToVector(value);
Expand Down Expand Up @@ -236,6 +242,7 @@ void Value::copyValueFrom(const Value& other) {
case LogicalTypeID::STRING: {
strVal = other.strVal;
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::STRUCT: {
Expand Down Expand Up @@ -287,6 +294,7 @@ std::string Value::toString() const {
return TypeUtils::toString(val.internalIDVal);
case LogicalTypeID::STRING:
return strVal;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST: {
std::string result = "[";
Expand Down Expand Up @@ -328,15 +336,15 @@ Value::Value() : dataType{LogicalTypeID::ANY}, isNull_{true} {}

Value::Value(LogicalType dataType) : dataType{std::move(dataType)}, isNull_{true} {}

std::vector<std::unique_ptr<Value>> Value::convertKUVarListToVector(ku_list_t& list) const {
std::vector<std::unique_ptr<Value>> Value::convertKUVarListToVector(
ku_list_t& list, const LogicalType& childType) const {
std::vector<std::unique_ptr<Value>> listResultValue;
auto childType = VarListType::getChildType(&dataType);
auto numBytesPerElement = storage::StorageUtils::getDataTypeSize(*childType);
auto numBytesPerElement = storage::StorageUtils::getDataTypeSize(childType);
auto listNullBytes = reinterpret_cast<uint8_t*>(list.overflowPtr);
auto numBytesForNullValues = NullBuffer::getNumBytesForNullValues(list.size);
auto listValues = listNullBytes + numBytesForNullValues;
for (auto i = 0; i < list.size; i++) {
auto childValue = std::make_unique<Value>(Value::createDefaultValue(*childType));
auto childValue = std::make_unique<Value>(Value::createDefaultValue(childType));
if (NullBuffer::isNull(listNullBytes, i)) {
childValue->setNull();
} else {
Expand Down
3 changes: 3 additions & 0 deletions src/common/vector/auxiliary_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ std::unique_ptr<AuxiliaryBuffer> AuxiliaryBufferFactory::getAuxiliaryBuffer(
return std::make_unique<StringAuxiliaryBuffer>(memoryManager);
case LogicalTypeID::STRUCT:
return std::make_unique<StructAuxiliaryBuffer>(type, memoryManager);
case LogicalTypeID::RECURSIVE_REL:
return std::make_unique<ListAuxiliaryBuffer>(
common::LogicalType(common::LogicalTypeID::INTERNAL_ID), memoryManager);
case LogicalTypeID::VAR_LIST:
return std::make_unique<ListAuxiliaryBuffer>(
*VarListType::getChildType(&type), memoryManager);
Expand Down
1 change: 1 addition & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ uint32_t ValueVector::getDataTypeSize(const LogicalType& type) {
case LogicalTypeID::STRUCT: {
return sizeof(struct_entry_t);
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
return sizeof(list_entry_t);
}
Expand Down
3 changes: 3 additions & 0 deletions src/common/vector/value_vector_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
structValues += processor::FactorizedTable::getDataTypeSize(structField->dataType);
}
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
auto srcKuList = *(ku_list_t*)srcData;
auto srcNullBytes = reinterpret_cast<uint8_t*>(srcKuList.overflowPtr);
Expand Down Expand Up @@ -76,6 +77,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&
structValues += processor::FactorizedTable::getDataTypeSize(structField->dataType);
}
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
auto srcListEntry = srcVector.getValue<list_entry_t>(pos);
auto srcListDataVector = common::ListVector::getDataVector(&srcVector);
Expand Down Expand Up @@ -114,6 +116,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&
void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVector,
const uint8_t* srcValue, const common::ValueVector& srcVector) {
switch (srcVector.dataType.getLogicalTypeID()) {
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
auto srcList = reinterpret_cast<const common::list_entry_t*>(srcValue);
auto dstList = reinterpret_cast<common::list_entry_t*>(dstValue);
Expand Down
3 changes: 3 additions & 0 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> ListLenVectorOperation::
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_LEN_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, execFunc,
true /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_LEN_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::RECURSIVE_REL}, LogicalTypeID::INT64, execFunc,
true /* isVarlength*/));
return result;
}

Expand Down
6 changes: 4 additions & 2 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ KUZU_API enum class LogicalTypeID : uint8_t {
ANY = 0,
NODE = 10,
REL = 11,
RECURSIVE_REL = 12,
// SERIAL is a special data type that is used to represent a sequence of INT64 values that are
// incremented by 1 starting from 0.
SERIAL = 12,
SERIAL = 13,

// fixed size types
BOOL = 22,
Expand Down Expand Up @@ -232,7 +233,8 @@ class LogicalType {

struct VarListType {
static inline LogicalType* getChildType(const LogicalType* type) {
assert(type->getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(type->getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
type->getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
auto varListTypeInfo = reinterpret_cast<VarListTypeInfo*>(type->extraTypeInfo.get());
return varListTypeInfo->getChildType();
}
Expand Down
3 changes: 2 additions & 1 deletion src/include/common/types/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ class Value {
}
}

std::vector<std::unique_ptr<Value>> convertKUVarListToVector(ku_list_t& list) const;
std::vector<std::unique_ptr<Value>> convertKUVarListToVector(
ku_list_t& list, const LogicalType& childType) const;
std::vector<std::unique_ptr<Value>> convertKUFixedListToVector(const uint8_t* fixedList) const;
std::vector<std::unique_ptr<Value>> convertKUStructToVector(const uint8_t* kuStruct) const;

Expand Down
15 changes: 10 additions & 5 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,28 +106,33 @@ class StringVector {
class ListVector {
public:
static inline ValueVector* getDataVector(const ValueVector* vector) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->getDataVector();
}
static inline uint8_t* getListValues(const ValueVector* vector, const list_entry_t& listEntry) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
auto dataVector = getDataVector(vector);
return dataVector->getData() + dataVector->getNumBytesPerValue() * listEntry.offset;
}
static inline uint8_t* getListValuesWithOffset(const ValueVector* vector,
const list_entry_t& listEntry, common::offset_t elementOffsetInList) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
return getListValues(vector, listEntry) +
elementOffsetInList * getDataVector(vector)->getNumBytesPerValue();
}
static inline list_entry_t addList(ValueVector* vector, uint64_t listSize) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->addList(listSize);
}
static inline void resetListAuxiliaryBuffer(ValueVector* vector) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())->resetSize();
}
};
Expand Down
2 changes: 2 additions & 0 deletions src/processor/result/factorized_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ uint32_t FactorizedTable::getDataTypeSize(const common::LogicalType& type) {
return getDataTypeSize(*FixedListType::getChildType(&type)) *
FixedListType::getNumElementsInList(&type);
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
return sizeof(ku_list_t);
}
Expand Down Expand Up @@ -682,6 +683,7 @@ void FactorizedTable::copyOverflowIfNecessary(
*stringToWriteFrom, *(ku_string_t*)dst);
}
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
diskOverflowFile->writeListOverflowAndUpdateOverflowPtr(
*(ku_list_t*)src, *(ku_list_t*)dst, type);
Expand Down
4 changes: 2 additions & 2 deletions test/runner/e2e_exception_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ class TinySnbExceptionTest : public DBTest {

TEST_F(TinySnbExceptionTest, ReadVarlengthRelPropertyTest1) {
auto result = conn->query("MATCH (a:person)-[e:knows*1..3]->(b:person) RETURN e.age;");
ASSERT_STREQ("Binder exception: e has data type VAR_LIST. (STRUCT,REL,NODE) was expected.",
ASSERT_STREQ("Binder exception: e has data type RECURSIVE_REL. (STRUCT,REL,NODE) was expected.",
result->getErrorMessage().c_str());
}

TEST_F(TinySnbExceptionTest, ReadVarlengthRelPropertyTest2) {
auto result =
conn->query("MATCH (a:person)-[e:knows*1..3]->(b:person) WHERE ID(e) = 0 RETURN COUNT(*);");
ASSERT_STREQ("Binder exception: e has data type VAR_LIST. (REL,NODE) was expected.",
ASSERT_STREQ("Binder exception: e has data type RECURSIVE_REL. (REL,NODE) was expected.",
result->getErrorMessage().c_str());
}

Expand Down

0 comments on commit 3d680f2

Please sign in to comment.