Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor extend direction & rel data direction #1560

Merged
merged 1 commit into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,31 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
" to relationship with same name is not supported.");
}
auto tableIDs = bindRelTableIDs(relPattern.getTableNames());
// bind node to rel
auto isLeftNodeSrc = RIGHT == relPattern.getDirection();
auto srcNode = isLeftNodeSrc ? leftNode : rightNode;
auto dstNode = isLeftNodeSrc ? rightNode : leftNode;
// bind src & dst node
RelDirectionType directionType;
std::shared_ptr<NodeExpression> srcNode;
std::shared_ptr<NodeExpression> dstNode;
switch (relPattern.getDirection()) {
case ArrowDirection::LEFT: {
srcNode = rightNode;
dstNode = leftNode;
directionType = RelDirectionType::SINGLE;
} break;
case ArrowDirection::RIGHT: {
srcNode = leftNode;
dstNode = rightNode;
directionType = RelDirectionType::SINGLE;
} break;
case ArrowDirection::BOTH: {
// For both direction, left and right will be written with the same label set. So either one
// being src will be correct.
srcNode = leftNode;
dstNode = rightNode;
directionType = RelDirectionType::BOTH;
} break;
default:
throw common::NotImplementedException("Binder::bindQueryRel");
}
if (srcNode->getUniqueName() == dstNode->getUniqueName()) {
throw BinderException("Self-loop rel " + parsedName + " is not supported.");
}
Expand All @@ -160,8 +181,8 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
auto dataType = isVariableLength ? common::LogicalType(common::LogicalTypeID::RECURSIVE_REL) :
common::LogicalType(common::LogicalTypeID::REL);
auto queryRel = make_shared<RelExpression>(std::move(dataType),
getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode,
relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound);
getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode, directionType,
relPattern.getRelType(), lowerBound, upperBound);
queryRel->setAlias(parsedName);
if (isVariableLength) {
queryRel->setInternalLengthExpression(
Expand Down
1 change: 1 addition & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_subdirectory(vector)
add_library(kuzu_common
OBJECT
assert.cpp
rel_direction.cpp
expression_type.cpp
file_utils.cpp
in_mem_overflow_buffer.cpp
Expand Down
22 changes: 22 additions & 0 deletions src/common/rel_direction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "common/rel_direction.h"

#include "common/exception.h"

namespace kuzu {
namespace common {

std::string RelDataDirectionUtils::relDataDirectionToString(RelDataDirection direction) {
switch (direction) {
case RelDataDirection::FWD: {
return "forward";
}
case RelDataDirection::BWD: {
return "backward";
}
default:
throw NotImplementedException("RelDataDirectionUtils::relDataDirectionToString");
}
}

} // namespace common
} // namespace kuzu
8 changes: 0 additions & 8 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,14 +475,6 @@ void LogicalType::setPhysicalType() {
}
}

RelDataDirection operator!(RelDataDirection& direction) {
return (FWD == direction) ? BWD : FWD;
}

std::string getRelDataDirectionAsString(RelDataDirection direction) {
return (FWD == direction) ? "forward" : "backward";
}

// Specialized Ser/Deser functions for logical dataTypes.
template<>
uint64_t SerDeser::serializeValue(
Expand Down
15 changes: 10 additions & 5 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@
namespace kuzu {
namespace binder {

enum class RelDirectionType : uint8_t {
SINGLE = 0,
BOTH = 1,
};

class RelExpression : public NodeOrRelExpression {
public:
RelExpression(common::LogicalType dataType, std::string uniqueName, std::string variableName,
std::vector<common::table_id_t> tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, bool directed, common::QueryRelType relType,
uint64_t lowerBound, uint64_t upperBound)
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType,
common::QueryRelType relType, uint64_t lowerBound, uint64_t upperBound)
: NodeOrRelExpression{std::move(dataType), std::move(uniqueName), std::move(variableName),
std::move(tableIDs)},
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directed{directed},
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directionType{directionType},
relType{relType}, lowerBound{lowerBound}, upperBound{upperBound} {}

inline bool isBoundByMultiLabeledNode() const {
Expand All @@ -31,7 +36,7 @@ class RelExpression : public NodeOrRelExpression {
inline uint64_t getLowerBound() const { return lowerBound; }
inline uint64_t getUpperBound() const { return upperBound; }

inline bool isDirected() const { return directed; }
inline RelDirectionType getDirectionType() const { return directionType; }

inline bool hasInternalIDProperty() const {
return hasPropertyExpression(common::INTERNAL_ID_SUFFIX);
Expand All @@ -50,7 +55,7 @@ class RelExpression : public NodeOrRelExpression {
private:
std::shared_ptr<NodeExpression> srcNode;
std::shared_ptr<NodeExpression> dstNode;
bool directed;
RelDirectionType directionType;
common::QueryRelType relType;
uint64_t lowerBound;
uint64_t upperBound;
Expand Down
1 change: 1 addition & 0 deletions src/include/catalog/catalog_structs.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "common/constants.h"
#include "common/exception.h"
#include "common/rel_direction.h"
#include "common/types/types_include.h"

namespace kuzu {
Expand Down
21 changes: 21 additions & 0 deletions src/include/common/rel_direction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include <cstdint>
#include <string>
#include <vector>

namespace kuzu {
namespace common {

enum RelDataDirection : uint8_t { FWD = 0, BWD = 1 };

struct RelDataDirectionUtils {
static inline std::vector<RelDataDirection> getRelDataDirections() {
return std::vector<RelDataDirection>{RelDataDirection::FWD, RelDataDirection::BWD};
}

static std::string relDataDirectionToString(RelDataDirection direction);
};

} // namespace common
} // namespace kuzu
8 changes: 0 additions & 8 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,6 @@ class LogicalTypeUtils {
static LogicalTypeID dataTypeIDFromString(const std::string& dataTypeIDString);
};

// RelDataDirection
enum RelDataDirection : uint8_t { FWD = 0, BWD = 1 };
const std::vector<RelDataDirection> REL_DIRECTIONS = {FWD, BWD};
RelDataDirection operator!(RelDataDirection& direction);
std::string getRelDataDirectionAsString(RelDataDirection relDirection);

enum class ExtendDirection : uint8_t { FWD = 0, BWD = 1, BOTH = 2 };

enum class DBFileType : uint8_t { ORIGINAL = 0, WAL_VERSION = 1 };

} // namespace common
Expand Down
3 changes: 1 addition & 2 deletions src/include/parser/query/graph_pattern/rel_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
namespace kuzu {
namespace parser {

enum ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 };

enum class ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 };
/**
* RelationshipPattern represents "-[relName:RelTableName+]-"
*/
Expand Down
11 changes: 6 additions & 5 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "catalog/catalog.h"
#include "common/join_type.h"
#include "planner/join_order_enumerator_context.h"
#include "planner/logical_plan/logical_operator/extend_direction.h"
#include "storage/store/nodes_statistics_and_deleted_ids.h"

namespace kuzu {
Expand Down Expand Up @@ -59,8 +60,9 @@ class JoinOrderEnumerator {
void planBaseTableScan();
void planNodeScan(uint32_t nodePos);
void planRelScan(uint32_t relPos);
void appendExtendAndFilter(std::shared_ptr<RelExpression> rel,
common::ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan);
void appendExtendAndFilter(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan);

void planLevel(uint32_t level);
void planLevelExactly(uint32_t level);
Expand All @@ -84,11 +86,10 @@ class JoinOrderEnumerator {

void appendNonRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::ExtendDirection direction, const binder::expression_vector& properties,
LogicalPlan& plan);
ExtendDirection direction, const binder::expression_vector& properties, LogicalPlan& plan);
void appendRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::ExtendDirection direction, LogicalPlan& plan);
ExtendDirection direction, LogicalPlan& plan);

void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "base_logical_operator.h"
#include "binder/expression/rel_expression.h"
#include "extend_direction.h"

namespace kuzu {
namespace planner {
Expand All @@ -11,14 +12,14 @@ class BaseLogicalExtend : public LogicalOperator {
BaseLogicalExtend(LogicalOperatorType operatorType,
std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::ExtendDirection direction, std::shared_ptr<LogicalOperator> child)
ExtendDirection direction, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{operatorType, std::move(child)}, boundNode{std::move(boundNode)},
nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction} {}

inline std::shared_ptr<binder::NodeExpression> getBoundNode() const { return boundNode; }
inline std::shared_ptr<binder::NodeExpression> getNbrNode() const { return nbrNode; }
inline std::shared_ptr<binder::RelExpression> getRel() const { return rel; }
inline common::ExtendDirection getDirection() const { return direction; }
inline ExtendDirection getDirection() const { return direction; }
virtual f_group_pos_set getGroupsPosToFlatten() = 0;

std::string getExpressionsForPrinting() const override;
Expand All @@ -27,7 +28,7 @@ class BaseLogicalExtend : public LogicalOperator {
std::shared_ptr<binder::NodeExpression> boundNode;
std::shared_ptr<binder::NodeExpression> nbrNode;
std::shared_ptr<binder::RelExpression> rel;
common::ExtendDirection direction;
ExtendDirection direction;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include <cstdint>

#include "binder/expression/rel_expression.h"
#include "common/rel_direction.h"

namespace kuzu {
namespace planner {

enum class ExtendDirection : uint8_t { FWD = 0, BWD = 1, BOTH = 2 };

struct ExtendDirectionUtils {
static inline ExtendDirection getExtendDirection(
const binder::RelExpression& relExpression, const binder::NodeExpression& boundNode) {
if (relExpression.getDirectionType() == binder::RelDirectionType::BOTH) {
return ExtendDirection::BOTH;
}
if (relExpression.getSrcNodeName() == boundNode.getUniqueName()) {
return ExtendDirection::FWD;
} else {
return ExtendDirection::BWD;
}
}

static inline common::RelDataDirection getRelDataDirection(ExtendDirection extendDirection) {
assert(extendDirection != ExtendDirection::BOTH);
return extendDirection == ExtendDirection::FWD ? common::RelDataDirection::FWD :
common::RelDataDirection::BWD;
}
};

} // namespace planner
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class LogicalExtend : public BaseLogicalExtend {
public:
LogicalExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::ExtendDirection direction, binder::expression_vector properties,
bool hasAtMostOneNbr, std::shared_ptr<LogicalOperator> child)
ExtendDirection direction, binder::expression_vector properties, bool hasAtMostOneNbr,
std::shared_ptr<LogicalOperator> child)
: BaseLogicalExtend{LogicalOperatorType::EXTEND, std::move(boundNode), std::move(nbrNode),
std::move(rel), direction, std::move(child)},
properties{std::move(properties)}, hasAtMostOneNbr{hasAtMostOneNbr} {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ class LogicalRecursiveExtend : public BaseLogicalExtend {
public:
LogicalRecursiveExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::ExtendDirection direction, std::shared_ptr<LogicalOperator> child)
ExtendDirection direction, std::shared_ptr<LogicalOperator> child)
: LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel),
direction, RecursiveJoinType::TRACK_PATH, std::move(child)} {}
LogicalRecursiveExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::ExtendDirection direction, RecursiveJoinType joinType,
ExtendDirection direction, RecursiveJoinType joinType,
std::shared_ptr<LogicalOperator> child)
: BaseLogicalExtend{LogicalOperatorType::RECURSIVE_EXTEND, std::move(boundNode),
std::move(nbrNode), std::move(rel), direction, std::move(child)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "catalog/catalog_structs.h"
#include "common/data_chunk/data_chunk.h"
#include "common/rel_direction.h"
#include "common/types/types.h"
#include "processor/result/factorized_table.h"
#include "storage/storage_structure/lists/list_handle.h"
Expand Down Expand Up @@ -78,8 +79,9 @@ class ListsUpdatesStore {
}
initListsUpdatesPerTablePerDirection();
}
inline ListsUpdatesPerChunk& getListsUpdatesPerChunk(common::RelDataDirection relDirection) {
return listsUpdatesPerDirection[relDirection];
inline ListsUpdatesPerChunk& getListsUpdatesPerChunk(
common::RelDataDirection relDataDirection) {
return listsUpdatesPerDirection[relDataDirection];
}

void updateSchema(catalog::RelTableSchema& relTableSchema);
Expand Down
1 change: 1 addition & 0 deletions src/include/storage/wal/wal_record.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "common/rel_direction.h"
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
#include "common/types/types_include.h"
#include "common/utils.h"

Expand Down
Loading