Skip to content

Commit

Permalink
Change recurisve rel to struct of list
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jun 14, 2023
1 parent 708ac89 commit aa21181
Show file tree
Hide file tree
Showing 21 changed files with 214 additions and 178 deletions.
22 changes: 17 additions & 5 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,23 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
" are not connected through rel " + parsedName + ".");
}
}
auto dataType = isVariableLength ?
common::LogicalType(common::LogicalTypeID::RECURSIVE_REL,
std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID))) :
common::LogicalType(common::LogicalTypeID::REL);
common::LogicalType dataType;
if (isVariableLength) {
std::vector<std::unique_ptr<StructField>> structFields;
auto varListTypeInfo = std::make_unique<common::VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID));
auto nodeStructField = std::make_unique<StructField>(InternalKeyword::NODES,
std::make_unique<LogicalType>(LogicalTypeID::VAR_LIST, varListTypeInfo->copy()));
auto relStructField = std::make_unique<StructField>(InternalKeyword::RELS,
std::make_unique<LogicalType>(LogicalTypeID::VAR_LIST, varListTypeInfo->copy()));
structFields.push_back(std::move(nodeStructField));
structFields.push_back(std::move(relStructField));
auto structTypeInfo = std::make_unique<StructTypeInfo>(std::move(structFields));
dataType =
common::LogicalType(common::LogicalTypeID::RECURSIVE_REL, std::move(structTypeInfo));
} else {
dataType = common::LogicalType(common::LogicalTypeID::REL);
}
auto queryRel = make_shared<RelExpression>(dataType, getUniqueExpressionName(parsedName),
parsedName, tableIDs, srcNode, dstNode, directionType, relPattern.getRelType());
if (isVariableLength) {
Expand Down
6 changes: 3 additions & 3 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID});
}
return std::make_unique<PropertyExpression>(LogicalType(LogicalTypeID::INTERNAL_ID),
INTERNAL_ID_SUFFIX, node, std::move(propertyIDPerTable), false /* isPrimaryKey */);
InternalKeyword::ID, node, std::move(propertyIDPerTable), false /* isPrimaryKey */);
}

std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
Expand All @@ -165,7 +165,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
return node.getInternalIDProperty();
}
case common::LogicalTypeID::REL: {
return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX);
return bindRelPropertyExpression(expression, InternalKeyword::ID);
}
default:
throw NotImplementedException("ExpressionBinder::bindInternalIDExpression");
Expand Down Expand Up @@ -240,7 +240,7 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalLengthExpression(
propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID});
}
return std::make_unique<PropertyExpression>(LogicalType(common::LogicalTypeID::INT64),
INTERNAL_LENGTH_SUFFIX, rel, std::move(propertyIDPerTable), false /* isPrimaryKey */);
InternalKeyword::LENGTH, rel, std::move(propertyIDPerTable), false /* isPrimaryKey */);
}

std::shared_ptr<Expression> ExpressionBinder::bindRecursiveJoinLengthFunction(
Expand Down
2 changes: 1 addition & 1 deletion src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ table_id_t CatalogContent::addRelTableSchema(std::string tableName, RelMultiplic
nodeTableSchemas[srcTableID]->addFwdRelTableID(tableID);
nodeTableSchemas[dstTableID]->addBwdRelTableID(tableID);
auto relInternalIDProperty =
Property(INTERNAL_ID_SUFFIX, LogicalType{LogicalTypeID::INTERNAL_ID});
Property(InternalKeyword::ID, LogicalType{LogicalTypeID::INTERNAL_ID});
properties.insert(properties.begin(), relInternalIDProperty);
for (auto i = 0u; i < properties.size(); ++i) {
properties[i].propertyID = i;
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,13 @@ void LogicalType::setPhysicalType() {
case LogicalTypeID::STRING: {
physicalType = PhysicalTypeID::STRING;
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::MAP:
case LogicalTypeID::VAR_LIST: {
physicalType = PhysicalTypeID::VAR_LIST;
} break;
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::STRUCT: {
physicalType = PhysicalTypeID::STRUCT;
} break;
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ std::string Value::toString() const {
}
return result;
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST: {
std::string result = "[";
Expand All @@ -306,6 +305,7 @@ std::string Value::toString() const {
result += "]";
return result;
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::STRUCT: {
std::string result = "{";
auto fieldNames = StructType::getFieldNames(&dataType);
Expand Down
1 change: 0 additions & 1 deletion src/common/vector/auxiliary_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ list_entry_t ListAuxiliaryBuffer::addList(uint64_t listSize) {
while (size + listSize > capacity) {
capacity *= 2;
}
auto numBytesPerElement = dataVector->getNumBytesPerValue();
if (needResizeDataVector) {
resizeDataVector(dataVector.get());
}
Expand Down
8 changes: 8 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ void ValueVector::resetAuxiliaryBuffer() {
reinterpret_cast<ListAuxiliaryBuffer*>(auxiliaryBuffer.get())->resetSize();
return;
}
case PhysicalTypeID::STRUCT: {
auto structAuxiliaryBuffer =
reinterpret_cast<StructAuxiliaryBuffer*>(auxiliaryBuffer.get());
for (auto& vector : structAuxiliaryBuffer->getChildrenVectors()) {
vector->resetAuxiliaryBuffer();
}
return;
}
default:
return;
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/expression/property_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class PropertyExpression : public Expression {
return propertyIDPerTable.at(tableID);
}

inline bool isInternalID() const { return getPropertyName() == common::INTERNAL_ID_SUFFIX; }
inline bool isInternalID() const { return getPropertyName() == common::InternalKeyword::ID; }

inline std::unique_ptr<Expression> copy() const override {
return make_unique<PropertyExpression>(*this);
Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class RelExpression : public NodeOrRelExpression {
inline RelDirectionType getDirectionType() const { return directionType; }

inline std::shared_ptr<Expression> getInternalIDProperty() const {
return getPropertyExpression(common::INTERNAL_ID_SUFFIX);
return getPropertyExpression(common::InternalKeyword::ID);
}

inline void setRecursiveInfo(std::unique_ptr<RecursiveInfo> recursiveInfo_) {
Expand Down
4 changes: 2 additions & 2 deletions src/include/catalog/catalog_structs.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct TableSchema {
virtual ~TableSchema() = default;

static inline bool isReservedPropertyName(const std::string& propertyName) {
return propertyName == common::INTERNAL_ID_SUFFIX;
return propertyName == common::InternalKeyword::ID;
}

inline uint32_t getNumProperties() const { return properties.size(); }
Expand Down Expand Up @@ -132,7 +132,7 @@ struct RelTableSchema : TableSchema {

inline Property& getRelIDDefinition() {
for (auto& property : properties) {
if (property.name == common::INTERNAL_ID_SUFFIX) {
if (property.name == common::InternalKeyword::ID) {
return property;
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ constexpr uint64_t THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS = 500;

constexpr uint64_t DEFAULT_CHECKPOINT_WAIT_TIMEOUT_FOR_TRANSACTIONS_TO_LEAVE_IN_MICROS = 5000000;

const std::string INTERNAL_ID_SUFFIX = "_id";
const std::string INTERNAL_LENGTH_SUFFIX = "_length";
struct InternalKeyword {
static constexpr char ID[] = "_id";
static constexpr char LENGTH[] = "_length";
static constexpr char NODES[] = "_nodes";
static constexpr char RELS[] = "_rels";
};

enum PageSizeClass : uint8_t {
PAGE_4KB = 0,
Expand Down
2 changes: 1 addition & 1 deletion src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ struct StructType {
return structTypeInfo->getStructFields();
}

static inline struct_field_idx_t getFieldIdx(const LogicalType* type, std::string& key) {
static inline struct_field_idx_t getFieldIdx(const LogicalType* type, const std::string& key) {
assert(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(type->extraTypeInfo.get());
return structTypeInfo->getStructFieldIdx(key);
Expand Down
48 changes: 18 additions & 30 deletions src/include/processor/operator/recursive_extend/frontier_scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
namespace kuzu {
namespace processor {

struct RecursiveJoinVectors;
/*
* BaseFrontierScanner scans all dst nodes from k'th frontier. To identify the
* destination nodes in the k'th frontier, we use a semi mask that marks the destination nodes (or
Expand All @@ -19,22 +20,20 @@ class BaseFrontierScanner {
currentDstNodeID{common::INVALID_OFFSET, common::INVALID_TABLE_ID} {}
virtual ~BaseFrontierScanner() = default;

size_t scan(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos);
size_t scan(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos);

void resetState(const BaseBFSState& bfsState);

protected:
virtual void initScanFromDstOffset() = 0;
virtual void scanFromDstOffset(common::ValueVector* pathVector,
common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector,
common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) = 0;
virtual void scanFromDstOffset(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos) = 0;

inline void writeDstNodeOffsetAndLength(common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos) {
dstNodeIDVector->setValue<common::nodeID_t>(offsetVectorPos, currentDstNodeID);
pathLengthVector->setValue<int64_t>(offsetVectorPos, (int64_t)k);
common::ValueVector* pathLengthVector, common::sel_t& vectorPos) {
dstNodeIDVector->setValue<common::nodeID_t>(vectorPos, currentDstNodeID);
pathLengthVector->setValue<int64_t>(vectorPos, (int64_t)k);
}

protected:
Expand All @@ -57,13 +56,8 @@ class DstNodeScanner : public BaseFrontierScanner {

private:
inline void initScanFromDstOffset() final {}
inline void scanFromDstOffset(common::ValueVector* pathVector,
common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector,
common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) final {
assert(offsetVectorPos < common::DEFAULT_VECTOR_CAPACITY);
writeDstNodeOffsetAndLength(dstNodeIDVector, pathLengthVector, offsetVectorPos);
offsetVectorPos++;
}
void scanFromDstOffset(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos) final;
};

/*
Expand All @@ -78,7 +72,6 @@ class PathScanner : public BaseFrontierScanner {

public:
PathScanner(TargetDstNodes* targetDstNodes, size_t k) : BaseFrontierScanner{targetDstNodes, k} {
listEntrySize = 2 * k + 1;
nodeIDs.resize(k + 1);
relIDs.resize(k + 1);
}
Expand All @@ -89,20 +82,17 @@ class PathScanner : public BaseFrontierScanner {
initDfs(std::make_pair(currentDstNodeID, dummyRelID), k);
}
// Scan current stacks until exhausted or vector is filled up.
void scanFromDstOffset(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos) final;
void scanFromDstOffset(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos) final;

// Initialize stacks for given offset.
void initDfs(const frontier::node_rel_id_t& nodeAndRelID, size_t currentDepth);

void writePathToVector(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos);
void writePathToVector(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos);

private:
// DFS states
size_t listEntrySize;
std::vector<common::nodeID_t> nodeIDs;
std::vector<common::relID_t> relIDs;
std::stack<nbrs_t> nbrsStack;
Expand All @@ -120,9 +110,8 @@ class DstNodeWithMultiplicityScanner : public BaseFrontierScanner {

private:
inline void initScanFromDstOffset() final {}
void scanFromDstOffset(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos) final;
void scanFromDstOffset(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos) final;
};

/*
Expand All @@ -142,9 +131,8 @@ struct FrontiersScanner {
explicit FrontiersScanner(std::vector<std::unique_ptr<BaseFrontierScanner>> scanners)
: scanners{std::move(scanners)}, cursor{0} {}

void scan(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos);
void scan(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos);

inline void resetState(const BaseBFSState& bfsState) {
cursor = 0;
Expand Down
26 changes: 14 additions & 12 deletions src/include/processor/operator/recursive_extend/recursive_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ struct RecursiveJoinDataInfo {
}
};

struct RecursiveJoinVectors {
common::ValueVector* srcNodeIDVector = nullptr;
common::ValueVector* dstNodeIDVector = nullptr;
common::ValueVector* pathLengthVector = nullptr;
common::ValueVector* pathVector = nullptr;
common::ValueVector* pathNodeIDVector = nullptr;
common::ValueVector* pathRelIDVector = nullptr;

common::ValueVector* recursiveEdgeIDVector = nullptr;
common::ValueVector* recursiveDstNodeIDVector = nullptr;
};

class RecursiveJoin : public PhysicalOperator {
public:
RecursiveJoin(uint8_t lowerBound, uint8_t upperBound, common::QueryRelType queryRelType,
Expand Down Expand Up @@ -100,7 +112,7 @@ class RecursiveJoin : public PhysicalOperator {

void updateVisitedNodes(common::nodeID_t boundNodeID);

protected:
private:
uint8_t lowerBound;
uint8_t upperBound;
common::QueryRelType queryRelType;
Expand All @@ -114,17 +126,7 @@ class RecursiveJoin : public PhysicalOperator {
std::unique_ptr<PhysicalOperator> recursiveRoot;
ScanFrontier* scanFrontier;

// Vectors
std::vector<common::ValueVector*> vectorsToScan;
common::ValueVector* srcNodeIDVector;
common::ValueVector* dstNodeIDVector;
common::ValueVector* pathLengthVector;
common::ValueVector* pathVector;

// temporary recursive join result.
common::ValueVector* recursiveEdgeIDVector;
common::ValueVector* recursiveDstNodeIDVector;

std::unique_ptr<RecursiveJoinVectors> vectors;
std::unique_ptr<BaseBFSState> bfsState;
std::unique_ptr<FrontiersScanner> frontiersScanner;
std::unique_ptr<TargetDstNodes> targetDstNodes;
Expand Down
Loading

0 comments on commit aa21181

Please sign in to comment.