Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recursive rel logical type #1553

Merged: 1 commit was merged on May 20, 2023.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,8 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
// bind variable length
auto [lowerBound, upperBound] = bindVariableLengthRelBound(relPattern);
auto isVariableLength = !(lowerBound == 1 && upperBound == 1);
auto dataType = isVariableLength ?
common::LogicalType(LogicalTypeID::VAR_LIST,
std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID))) :
common::LogicalType(common::LogicalTypeID::REL);
auto dataType = isVariableLength ? common::LogicalType(common::LogicalTypeID::RECURSIVE_REL) :
common::LogicalType(common::LogicalTypeID::REL);
auto queryRel = make_shared<RelExpression>(std::move(dataType),
getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode,
relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound);
Expand Down
8 changes: 0 additions & 8 deletions src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,6 @@ static std::unordered_map<table_id_t, property_id_t> populatePropertyIDPerTable(
std::shared_ptr<Expression> ExpressionBinder::bindRelPropertyExpression(
const Expression& expression, const std::string& propertyName) {
auto& rel = (RelExpression&)expression;
switch (rel.getRelType()) {
case common::QueryRelType::VARIABLE_LENGTH:
case common::QueryRelType::SHORTEST:
throw BinderException(
"Cannot read property of variable length rel " + rel.toString() + ".");
default:
break;
}
if (!rel.hasPropertyExpression(propertyName)) {
throw BinderException(
"Cannot find property " + propertyName + " for " + expression.toString() + ".");
Expand Down
10 changes: 6 additions & 4 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ std::string LogicalTypeUtils::dataTypeToString(const LogicalType& dataType) {
case LogicalTypeID::ANY:
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::INTERNAL_ID:
case LogicalTypeID::BOOL:
case LogicalTypeID::INT64:
Expand All @@ -317,8 +318,7 @@ std::string LogicalTypeUtils::dataTypeToString(const LogicalType& dataType) {
case LogicalTypeID::STRING:
return dataTypeToString(dataType.typeID);
default:
throw NotImplementedException(
"Unsupported DataType: " + LogicalTypeUtils::dataTypeToString(dataType) + ".");
throw NotImplementedException("LogicalTypeUtils::dataTypeToString.");
}
}

Expand All @@ -330,6 +330,8 @@ std::string LogicalTypeUtils::dataTypeToString(LogicalTypeID dataTypeID) {
return "NODE";
case LogicalTypeID::REL:
return "REL";
case LogicalTypeID::RECURSIVE_REL:
return "RECURSIVE_REL";
case LogicalTypeID::INTERNAL_ID:
return "INTERNAL_ID";
case LogicalTypeID::BOOL:
Expand Down Expand Up @@ -361,8 +363,7 @@ std::string LogicalTypeUtils::dataTypeToString(LogicalTypeID dataTypeID) {
case LogicalTypeID::SERIAL:
return "SERIAL";
default:
throw NotImplementedException(
"Unsupported DataType: " + LogicalTypeUtils::dataTypeToString(dataTypeID) + ".");
throw NotImplementedException("LogicalTypeUtils::dataTypeToString.");
}
}

Expand Down Expand Up @@ -460,6 +461,7 @@ void LogicalType::setPhysicalType() {
case LogicalTypeID::STRING: {
physicalType = PhysicalTypeID::STRING;
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
physicalType = PhysicalTypeID::VAR_LIST;
} break;
Expand Down
18 changes: 13 additions & 5 deletions src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Value Value::createDefaultValue(const LogicalType& dataType) {
return Value(std::string(""));
case LogicalTypeID::FLOAT:
return Value((float_t)0);
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::STRUCT:
Expand Down Expand Up @@ -180,8 +181,13 @@ void Value::copyValueFrom(const uint8_t* value) {
case LogicalTypeID::STRING: {
strVal = ((ku_string_t*)value)->getAsString();
} break;
case LogicalTypeID::RECURSIVE_REL: {
nestedTypeVal =
convertKUVarListToVector(*(ku_list_t*)value, LogicalType(LogicalTypeID::INTERNAL_ID));
} break;
case LogicalTypeID::VAR_LIST: {
nestedTypeVal = convertKUVarListToVector(*(ku_list_t*)value);
nestedTypeVal =
convertKUVarListToVector(*(ku_list_t*)value, *VarListType::getChildType(&dataType));
} break;
case LogicalTypeID::FIXED_LIST: {
nestedTypeVal = convertKUFixedListToVector(value);
Expand Down Expand Up @@ -236,6 +242,7 @@ void Value::copyValueFrom(const Value& other) {
case LogicalTypeID::STRING: {
strVal = other.strVal;
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST:
case LogicalTypeID::STRUCT: {
Expand Down Expand Up @@ -287,6 +294,7 @@ std::string Value::toString() const {
return TypeUtils::toString(val.internalIDVal);
case LogicalTypeID::STRING:
return strVal;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST: {
std::string result = "[";
Expand Down Expand Up @@ -328,15 +336,15 @@ Value::Value() : dataType{LogicalTypeID::ANY}, isNull_{true} {}

Value::Value(LogicalType dataType) : dataType{std::move(dataType)}, isNull_{true} {}

std::vector<std::unique_ptr<Value>> Value::convertKUVarListToVector(ku_list_t& list) const {
std::vector<std::unique_ptr<Value>> Value::convertKUVarListToVector(
ku_list_t& list, const LogicalType& childType) const {
std::vector<std::unique_ptr<Value>> listResultValue;
auto childType = VarListType::getChildType(&dataType);
auto numBytesPerElement = storage::StorageUtils::getDataTypeSize(*childType);
auto numBytesPerElement = storage::StorageUtils::getDataTypeSize(childType);
auto listNullBytes = reinterpret_cast<uint8_t*>(list.overflowPtr);
auto numBytesForNullValues = NullBuffer::getNumBytesForNullValues(list.size);
auto listValues = listNullBytes + numBytesForNullValues;
for (auto i = 0; i < list.size; i++) {
auto childValue = std::make_unique<Value>(Value::createDefaultValue(*childType));
auto childValue = std::make_unique<Value>(Value::createDefaultValue(childType));
if (NullBuffer::isNull(listNullBytes, i)) {
childValue->setNull();
} else {
Expand Down
3 changes: 3 additions & 0 deletions src/common/vector/auxiliary_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ std::unique_ptr<AuxiliaryBuffer> AuxiliaryBufferFactory::getAuxiliaryBuffer(
return std::make_unique<StringAuxiliaryBuffer>(memoryManager);
case LogicalTypeID::STRUCT:
return std::make_unique<StructAuxiliaryBuffer>(type, memoryManager);
case LogicalTypeID::RECURSIVE_REL:
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
return std::make_unique<ListAuxiliaryBuffer>(
common::LogicalType(common::LogicalTypeID::INTERNAL_ID), memoryManager);
case LogicalTypeID::VAR_LIST:
return std::make_unique<ListAuxiliaryBuffer>(
*VarListType::getChildType(&type), memoryManager);
Expand Down
1 change: 1 addition & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ uint32_t ValueVector::getDataTypeSize(const LogicalType& type) {
case LogicalTypeID::STRUCT: {
return sizeof(struct_entry_t);
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
return sizeof(list_entry_t);
}
Expand Down
3 changes: 3 additions & 0 deletions src/common/vector/value_vector_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeIntoPos(
structValues += processor::FactorizedTable::getDataTypeSize(structField->dataType);
}
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
auto srcKuList = *(ku_list_t*)srcData;
auto srcNullBytes = reinterpret_cast<uint8_t*>(srcKuList.overflowPtr);
Expand Down Expand Up @@ -76,6 +77,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&
structValues += processor::FactorizedTable::getDataTypeSize(structField->dataType);
}
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
auto srcListEntry = srcVector.getValue<list_entry_t>(pos);
auto srcListDataVector = common::ListVector::getDataVector(&srcVector);
Expand Down Expand Up @@ -114,6 +116,7 @@ void ValueVectorUtils::copyNonNullDataWithSameTypeOutFromPos(const ValueVector&
void ValueVectorUtils::copyValue(uint8_t* dstValue, common::ValueVector& dstVector,
const uint8_t* srcValue, const common::ValueVector& srcVector) {
switch (srcVector.dataType.getLogicalTypeID()) {
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
auto srcList = reinterpret_cast<const common::list_entry_t*>(srcValue);
auto dstList = reinterpret_cast<common::list_entry_t*>(dstValue);
Expand Down
3 changes: 3 additions & 0 deletions src/function/vector_list_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ std::vector<std::unique_ptr<VectorOperationDefinition>> ListLenVectorOperation::
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_LEN_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::VAR_LIST}, LogicalTypeID::INT64, execFunc,
true /* isVarlength*/));
result.push_back(std::make_unique<VectorOperationDefinition>(LIST_LEN_FUNC_NAME,
std::vector<LogicalTypeID>{LogicalTypeID::RECURSIVE_REL}, LogicalTypeID::INT64, execFunc,
true /* isVarlength*/));
return result;
}

Expand Down
6 changes: 4 additions & 2 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ KUZU_API enum class LogicalTypeID : uint8_t {
ANY = 0,
NODE = 10,
REL = 11,
RECURSIVE_REL = 12,
// SERIAL is a special data type that is used to represent a sequence of INT64 values that are
// incremented by 1 starting from 0.
SERIAL = 12,
SERIAL = 13,

// fixed size types
BOOL = 22,
Expand Down Expand Up @@ -232,7 +233,8 @@ class LogicalType {

struct VarListType {
static inline LogicalType* getChildType(const LogicalType* type) {
assert(type->getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(type->getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
type->getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
auto varListTypeInfo = reinterpret_cast<VarListTypeInfo*>(type->extraTypeInfo.get());
return varListTypeInfo->getChildType();
}
Expand Down
3 changes: 2 additions & 1 deletion src/include/common/types/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ class Value {
}
}

std::vector<std::unique_ptr<Value>> convertKUVarListToVector(ku_list_t& list) const;
std::vector<std::unique_ptr<Value>> convertKUVarListToVector(
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
ku_list_t& list, const LogicalType& childType) const;
std::vector<std::unique_ptr<Value>> convertKUFixedListToVector(const uint8_t* fixedList) const;
std::vector<std::unique_ptr<Value>> convertKUStructToVector(const uint8_t* kuStruct) const;

Expand Down
15 changes: 10 additions & 5 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,28 +106,33 @@ class StringVector {
class ListVector {
public:
static inline ValueVector* getDataVector(const ValueVector* vector) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->getDataVector();
}
static inline uint8_t* getListValues(const ValueVector* vector, const list_entry_t& listEntry) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
auto dataVector = getDataVector(vector);
return dataVector->getData() + dataVector->getNumBytesPerValue() * listEntry.offset;
}
static inline uint8_t* getListValuesWithOffset(const ValueVector* vector,
const list_entry_t& listEntry, common::offset_t elementOffsetInList) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
return getListValues(vector, listEntry) +
elementOffsetInList * getDataVector(vector)->getNumBytesPerValue();
}
static inline list_entry_t addList(ValueVector* vector, uint64_t listSize) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
return reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->addList(listSize);
}
static inline void resetListAuxiliaryBuffer(ValueVector* vector) {
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST);
assert(vector->dataType.getLogicalTypeID() == LogicalTypeID::VAR_LIST ||
vector->dataType.getLogicalTypeID() == LogicalTypeID::RECURSIVE_REL);
reinterpret_cast<ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())->resetSize();
}
};
Expand Down
2 changes: 2 additions & 0 deletions src/processor/result/factorized_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ uint32_t FactorizedTable::getDataTypeSize(const common::LogicalType& type) {
return getDataTypeSize(*FixedListType::getChildType(&type)) *
FixedListType::getNumElementsInList(&type);
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
return sizeof(ku_list_t);
}
Expand Down Expand Up @@ -682,6 +683,7 @@ void FactorizedTable::copyOverflowIfNecessary(
*stringToWriteFrom, *(ku_string_t*)dst);
}
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST: {
diskOverflowFile->writeListOverflowAndUpdateOverflowPtr(
*(ku_list_t*)src, *(ku_list_t*)dst, type);
Expand Down
4 changes: 2 additions & 2 deletions test/runner/e2e_exception_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ class TinySnbExceptionTest : public DBTest {

TEST_F(TinySnbExceptionTest, ReadVarlengthRelPropertyTest1) {
auto result = conn->query("MATCH (a:person)-[e:knows*1..3]->(b:person) RETURN e.age;");
ASSERT_STREQ("Binder exception: e has data type VAR_LIST. (STRUCT,REL,NODE) was expected.",
ASSERT_STREQ("Binder exception: e has data type RECURSIVE_REL. (STRUCT,REL,NODE) was expected.",
result->getErrorMessage().c_str());
}

TEST_F(TinySnbExceptionTest, ReadVarlengthRelPropertyTest2) {
auto result =
conn->query("MATCH (a:person)-[e:knows*1..3]->(b:person) WHERE ID(e) = 0 RETURN COUNT(*);");
ASSERT_STREQ("Binder exception: e has data type VAR_LIST. (REL,NODE) was expected.",
ASSERT_STREQ("Binder exception: e has data type RECURSIVE_REL. (REL,NODE) was expected.",
result->getErrorMessage().c_str());
}

Expand Down