Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change recursive rel physical type #1674

Merged
merged 1 commit into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,23 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
" are not connected through rel " + parsedName + ".");
}
}
auto dataType = isVariableLength ?
common::LogicalType(common::LogicalTypeID::RECURSIVE_REL,
std::make_unique<VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID))) :
common::LogicalType(common::LogicalTypeID::REL);
common::LogicalType dataType;
if (isVariableLength) {
std::vector<std::unique_ptr<StructField>> structFields;
auto varListTypeInfo = std::make_unique<common::VarListTypeInfo>(
std::make_unique<LogicalType>(LogicalTypeID::INTERNAL_ID));
auto nodeStructField = std::make_unique<StructField>(InternalKeyword::NODES,
std::make_unique<LogicalType>(LogicalTypeID::VAR_LIST, varListTypeInfo->copy()));
auto relStructField = std::make_unique<StructField>(InternalKeyword::RELS,
std::make_unique<LogicalType>(LogicalTypeID::VAR_LIST, varListTypeInfo->copy()));
structFields.push_back(std::move(nodeStructField));
structFields.push_back(std::move(relStructField));
auto structTypeInfo = std::make_unique<StructTypeInfo>(std::move(structFields));
dataType =
common::LogicalType(common::LogicalTypeID::RECURSIVE_REL, std::move(structTypeInfo));
} else {
dataType = common::LogicalType(common::LogicalTypeID::REL);
}
auto queryRel = make_shared<RelExpression>(dataType, getUniqueExpressionName(parsedName),
parsedName, tableIDs, srcNode, dstNode, directionType, relPattern.getRelType());
if (isVariableLength) {
Expand Down
6 changes: 3 additions & 3 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalNodeIDExpression(
propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID});
}
return std::make_unique<PropertyExpression>(LogicalType(LogicalTypeID::INTERNAL_ID),
INTERNAL_ID_SUFFIX, node, std::move(propertyIDPerTable), false /* isPrimaryKey */);
InternalKeyword::ID, node, std::move(propertyIDPerTable), false /* isPrimaryKey */);
}

std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
Expand All @@ -165,7 +165,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
return node.getInternalIDProperty();
}
case common::LogicalTypeID::REL: {
return bindRelPropertyExpression(expression, INTERNAL_ID_SUFFIX);
return bindRelPropertyExpression(expression, InternalKeyword::ID);
}
default:
throw NotImplementedException("ExpressionBinder::bindInternalIDExpression");
Expand Down Expand Up @@ -240,7 +240,7 @@ std::unique_ptr<Expression> ExpressionBinder::createInternalLengthExpression(
propertyIDPerTable.insert({tableID, INVALID_PROPERTY_ID});
}
return std::make_unique<PropertyExpression>(LogicalType(common::LogicalTypeID::INT64),
INTERNAL_LENGTH_SUFFIX, rel, std::move(propertyIDPerTable), false /* isPrimaryKey */);
InternalKeyword::LENGTH, rel, std::move(propertyIDPerTable), false /* isPrimaryKey */);
}

std::shared_ptr<Expression> ExpressionBinder::bindRecursiveJoinLengthFunction(
Expand Down
2 changes: 1 addition & 1 deletion src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ table_id_t CatalogContent::addRelTableSchema(std::string tableName, RelMultiplic
nodeTableSchemas[srcTableID]->addFwdRelTableID(tableID);
nodeTableSchemas[dstTableID]->addBwdRelTableID(tableID);
auto relInternalIDProperty =
Property(INTERNAL_ID_SUFFIX, LogicalType{LogicalTypeID::INTERNAL_ID});
Property(InternalKeyword::ID, LogicalType{LogicalTypeID::INTERNAL_ID});
properties.insert(properties.begin(), relInternalIDProperty);
for (auto i = 0u; i < properties.size(); ++i) {
properties[i].propertyID = i;
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,13 @@ void LogicalType::setPhysicalType() {
case LogicalTypeID::STRING: {
physicalType = PhysicalTypeID::STRING;
} break;
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::MAP:
case LogicalTypeID::VAR_LIST: {
physicalType = PhysicalTypeID::VAR_LIST;
} break;
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::STRUCT: {
physicalType = PhysicalTypeID::STRUCT;
} break;
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ std::string Value::toString() const {
}
return result;
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::VAR_LIST:
case LogicalTypeID::FIXED_LIST: {
std::string result = "[";
Expand All @@ -306,6 +305,7 @@ std::string Value::toString() const {
result += "]";
return result;
}
case LogicalTypeID::RECURSIVE_REL:
case LogicalTypeID::STRUCT: {
std::string result = "{";
auto fieldNames = StructType::getFieldNames(&dataType);
Expand Down
1 change: 0 additions & 1 deletion src/common/vector/auxiliary_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ list_entry_t ListAuxiliaryBuffer::addList(uint64_t listSize) {
while (size + listSize > capacity) {
capacity *= 2;
}
auto numBytesPerElement = dataVector->getNumBytesPerValue();
if (needResizeDataVector) {
resizeDataVector(dataVector.get());
}
Expand Down
8 changes: 8 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ void ValueVector::resetAuxiliaryBuffer() {
reinterpret_cast<ListAuxiliaryBuffer*>(auxiliaryBuffer.get())->resetSize();
return;
}
case PhysicalTypeID::STRUCT: {
auto structAuxiliaryBuffer =
reinterpret_cast<StructAuxiliaryBuffer*>(auxiliaryBuffer.get());
for (auto& vector : structAuxiliaryBuffer->getChildrenVectors()) {
vector->resetAuxiliaryBuffer();
}
return;
}
default:
return;
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/expression/property_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class PropertyExpression : public Expression {
return propertyIDPerTable.at(tableID);
}

inline bool isInternalID() const { return getPropertyName() == common::INTERNAL_ID_SUFFIX; }
inline bool isInternalID() const { return getPropertyName() == common::InternalKeyword::ID; }

inline std::unique_ptr<Expression> copy() const override {
return make_unique<PropertyExpression>(*this);
Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class RelExpression : public NodeOrRelExpression {
inline RelDirectionType getDirectionType() const { return directionType; }

inline std::shared_ptr<Expression> getInternalIDProperty() const {
return getPropertyExpression(common::INTERNAL_ID_SUFFIX);
return getPropertyExpression(common::InternalKeyword::ID);
}

inline void setRecursiveInfo(std::unique_ptr<RecursiveInfo> recursiveInfo_) {
Expand Down
4 changes: 2 additions & 2 deletions src/include/catalog/catalog_structs.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct TableSchema {
virtual ~TableSchema() = default;

static inline bool isReservedPropertyName(const std::string& propertyName) {
return propertyName == common::INTERNAL_ID_SUFFIX;
return propertyName == common::InternalKeyword::ID;
}

inline uint32_t getNumProperties() const { return properties.size(); }
Expand Down Expand Up @@ -132,7 +132,7 @@ struct RelTableSchema : TableSchema {

inline Property& getRelIDDefinition() {
for (auto& property : properties) {
if (property.name == common::INTERNAL_ID_SUFFIX) {
if (property.name == common::InternalKeyword::ID) {
return property;
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@ constexpr uint64_t THREAD_SLEEP_TIME_WHEN_WAITING_IN_MICROS = 500;

constexpr uint64_t DEFAULT_CHECKPOINT_WAIT_TIMEOUT_FOR_TRANSACTIONS_TO_LEAVE_IN_MICROS = 5000000;

const std::string INTERNAL_ID_SUFFIX = "_id";
const std::string INTERNAL_LENGTH_SUFFIX = "_length";
struct InternalKeyword {
static constexpr char ID[] = "_id";
static constexpr char LENGTH[] = "_length";
static constexpr char NODES[] = "_nodes";
static constexpr char RELS[] = "_rels";
};

enum PageSizeClass : uint8_t {
PAGE_4KB = 0,
Expand Down
2 changes: 1 addition & 1 deletion src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ struct StructType {
return structTypeInfo->getStructFields();
}

static inline struct_field_idx_t getFieldIdx(const LogicalType* type, std::string& key) {
static inline struct_field_idx_t getFieldIdx(const LogicalType* type, const std::string& key) {
assert(type->getPhysicalType() == PhysicalTypeID::STRUCT);
auto structTypeInfo = reinterpret_cast<StructTypeInfo*>(type->extraTypeInfo.get());
return structTypeInfo->getStructFieldIdx(key);
Expand Down
48 changes: 18 additions & 30 deletions src/include/processor/operator/recursive_extend/frontier_scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
namespace kuzu {
namespace processor {

struct RecursiveJoinVectors;
/*
* BaseFrontierScanner scans all dst nodes from k'th frontier. To identify the
* destination nodes in the k'th frontier, we use a semi mask that marks the destination nodes (or
Expand All @@ -19,22 +20,20 @@ class BaseFrontierScanner {
currentDstNodeID{common::INVALID_OFFSET, common::INVALID_TABLE_ID} {}
virtual ~BaseFrontierScanner() = default;

size_t scan(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos);
size_t scan(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos);

void resetState(const BaseBFSState& bfsState);

protected:
virtual void initScanFromDstOffset() = 0;
virtual void scanFromDstOffset(common::ValueVector* pathVector,
common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector,
common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) = 0;
virtual void scanFromDstOffset(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos) = 0;

inline void writeDstNodeOffsetAndLength(common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos) {
dstNodeIDVector->setValue<common::nodeID_t>(offsetVectorPos, currentDstNodeID);
pathLengthVector->setValue<int64_t>(offsetVectorPos, (int64_t)k);
common::ValueVector* pathLengthVector, common::sel_t& vectorPos) {
dstNodeIDVector->setValue<common::nodeID_t>(vectorPos, currentDstNodeID);
pathLengthVector->setValue<int64_t>(vectorPos, (int64_t)k);
}

protected:
Expand All @@ -57,13 +56,8 @@ class DstNodeScanner : public BaseFrontierScanner {

private:
inline void initScanFromDstOffset() final {}
inline void scanFromDstOffset(common::ValueVector* pathVector,
common::ValueVector* dstNodeIDVector, common::ValueVector* pathLengthVector,
common::sel_t& offsetVectorPos, common::sel_t& dataVectorPos) final {
assert(offsetVectorPos < common::DEFAULT_VECTOR_CAPACITY);
writeDstNodeOffsetAndLength(dstNodeIDVector, pathLengthVector, offsetVectorPos);
offsetVectorPos++;
}
void scanFromDstOffset(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos) final;
};

/*
Expand All @@ -78,7 +72,6 @@ class PathScanner : public BaseFrontierScanner {

public:
PathScanner(TargetDstNodes* targetDstNodes, size_t k) : BaseFrontierScanner{targetDstNodes, k} {
listEntrySize = 2 * k + 1;
nodeIDs.resize(k + 1);
relIDs.resize(k + 1);
}
Expand All @@ -89,20 +82,17 @@ class PathScanner : public BaseFrontierScanner {
initDfs(std::make_pair(currentDstNodeID, dummyRelID), k);
}
// Scan current stacks until exhausted or vector is filled up.
void scanFromDstOffset(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos) final;
void scanFromDstOffset(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos) final;

// Initialize stacks for given offset.
void initDfs(const frontier::node_rel_id_t& nodeAndRelID, size_t currentDepth);

void writePathToVector(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos);
void writePathToVector(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos);

private:
// DFS states
size_t listEntrySize;
std::vector<common::nodeID_t> nodeIDs;
std::vector<common::relID_t> relIDs;
std::stack<nbrs_t> nbrsStack;
Expand All @@ -120,9 +110,8 @@ class DstNodeWithMultiplicityScanner : public BaseFrontierScanner {

private:
inline void initScanFromDstOffset() final {}
void scanFromDstOffset(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos) final;
void scanFromDstOffset(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos) final;
};

/*
Expand All @@ -142,9 +131,8 @@ struct FrontiersScanner {
explicit FrontiersScanner(std::vector<std::unique_ptr<BaseFrontierScanner>> scanners)
: scanners{std::move(scanners)}, cursor{0} {}

void scan(common::ValueVector* pathVector, common::ValueVector* dstNodeIDVector,
common::ValueVector* pathLengthVector, common::sel_t& offsetVectorPos,
common::sel_t& dataVectorPos);
void scan(RecursiveJoinVectors* vectors, common::sel_t& vectorPos,
common::sel_t& nodeIDDataVectorPos, common::sel_t& relIDDataVectorPos);

inline void resetState(const BaseBFSState& bfsState) {
cursor = 0;
Expand Down
26 changes: 14 additions & 12 deletions src/include/processor/operator/recursive_extend/recursive_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ struct RecursiveJoinDataInfo {
}
};

struct RecursiveJoinVectors {
common::ValueVector* srcNodeIDVector = nullptr;
common::ValueVector* dstNodeIDVector = nullptr;
common::ValueVector* pathLengthVector = nullptr;
common::ValueVector* pathVector = nullptr;
common::ValueVector* pathNodeIDVector = nullptr;
common::ValueVector* pathRelIDVector = nullptr;

common::ValueVector* recursiveEdgeIDVector = nullptr;
common::ValueVector* recursiveDstNodeIDVector = nullptr;
};

class RecursiveJoin : public PhysicalOperator {
public:
RecursiveJoin(uint8_t lowerBound, uint8_t upperBound, common::QueryRelType queryRelType,
Expand Down Expand Up @@ -100,7 +112,7 @@ class RecursiveJoin : public PhysicalOperator {

void updateVisitedNodes(common::nodeID_t boundNodeID);

protected:
private:
uint8_t lowerBound;
uint8_t upperBound;
common::QueryRelType queryRelType;
Expand All @@ -114,17 +126,7 @@ class RecursiveJoin : public PhysicalOperator {
std::unique_ptr<PhysicalOperator> recursiveRoot;
ScanFrontier* scanFrontier;

// Vectors
std::vector<common::ValueVector*> vectorsToScan;
common::ValueVector* srcNodeIDVector;
common::ValueVector* dstNodeIDVector;
common::ValueVector* pathLengthVector;
common::ValueVector* pathVector;

// temporary recursive join result.
common::ValueVector* recursiveEdgeIDVector;
common::ValueVector* recursiveDstNodeIDVector;

std::unique_ptr<RecursiveJoinVectors> vectors;
std::unique_ptr<BaseBFSState> bfsState;
std::unique_ptr<FrontiersScanner> frontiersScanner;
std::unique_ptr<TargetDstNodes> targetDstNodes;
Expand Down
Loading
Loading