Skip to content

Commit

Permalink
Refactor extend direction
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed May 22, 2023
1 parent 06b1b9d commit a51064d
Show file tree
Hide file tree
Showing 26 changed files with 220 additions and 115 deletions.
33 changes: 27 additions & 6 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,31 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
" to relationship with same name is not supported.");
}
auto tableIDs = bindRelTableIDs(relPattern.getTableNames());
// bind node to rel
auto isLeftNodeSrc = RIGHT == relPattern.getDirection();
auto srcNode = isLeftNodeSrc ? leftNode : rightNode;
auto dstNode = isLeftNodeSrc ? rightNode : leftNode;
// bind src & dst node
RelDirectionType directionType;
std::shared_ptr<NodeExpression> srcNode;
std::shared_ptr<NodeExpression> dstNode;
switch (relPattern.getDirection()) {
case ArrowDirection::LEFT: {
srcNode = rightNode;
dstNode = leftNode;
directionType = RelDirectionType::SINGLE;
} break;
case ArrowDirection::RIGHT: {
srcNode = leftNode;
dstNode = rightNode;
directionType = RelDirectionType::SINGLE;
} break;
case ArrowDirection::BOTH: {
// For both direction, left and right will be written with the same label set. So either one
// being src will be correct.
srcNode = leftNode;
dstNode = rightNode;
directionType = RelDirectionType::BOTH;
} break;
default:
throw common::NotImplementedException("Binder::bindQueryRel");
}
if (srcNode->getUniqueName() == dstNode->getUniqueName()) {
throw BinderException("Self-loop rel " + parsedName + " is not supported.");
}
Expand All @@ -160,8 +181,8 @@ void Binder::bindQueryRel(const RelPattern& relPattern,
auto dataType = isVariableLength ? common::LogicalType(common::LogicalTypeID::RECURSIVE_REL) :
common::LogicalType(common::LogicalTypeID::REL);
auto queryRel = make_shared<RelExpression>(std::move(dataType),
getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode,
relPattern.getDirection() != BOTH, relPattern.getRelType(), lowerBound, upperBound);
getUniqueExpressionName(parsedName), parsedName, tableIDs, srcNode, dstNode, directionType,
relPattern.getRelType(), lowerBound, upperBound);
queryRel->setAlias(parsedName);
if (isVariableLength) {
queryRel->setInternalLengthExpression(
Expand Down
1 change: 1 addition & 0 deletions src/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_subdirectory(vector)
add_library(kuzu_common
OBJECT
assert.cpp
rel_direction.cpp
expression_type.cpp
file_utils.cpp
in_mem_overflow_buffer.cpp
Expand Down
22 changes: 22 additions & 0 deletions src/common/rel_direction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "common/rel_direction.h"

#include "common/exception.h"

namespace kuzu {
namespace common {

std::string RelDataDirectionUtils::relDataDirectionToString(RelDataDirection direction) {
switch (direction) {
case RelDataDirection::FWD: {
return "forward";
}
case RelDataDirection::BWD: {
return "backward";
}
default:
throw NotImplementedException("RelDataDirectionUtils::relDataDirectionToString");
}
}

} // namespace common
} // namespace kuzu
8 changes: 0 additions & 8 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,14 +475,6 @@ void LogicalType::setPhysicalType() {
}
}

RelDataDirection operator!(RelDataDirection& direction) {
return (FWD == direction) ? BWD : FWD;
}

std::string getRelDataDirectionAsString(RelDataDirection direction) {
return (FWD == direction) ? "forward" : "backward";
}

// Specialized Ser/Deser functions for logical dataTypes.
template<>
uint64_t SerDeser::serializeValue(
Expand Down
15 changes: 10 additions & 5 deletions src/include/binder/expression/rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@
namespace kuzu {
namespace binder {

enum class RelDirectionType : uint8_t {
SINGLE = 0,
BOTH = 1,
};

class RelExpression : public NodeOrRelExpression {
public:
RelExpression(common::LogicalType dataType, std::string uniqueName, std::string variableName,
std::vector<common::table_id_t> tableIDs, std::shared_ptr<NodeExpression> srcNode,
std::shared_ptr<NodeExpression> dstNode, bool directed, common::QueryRelType relType,
uint64_t lowerBound, uint64_t upperBound)
std::shared_ptr<NodeExpression> dstNode, RelDirectionType directionType,
common::QueryRelType relType, uint64_t lowerBound, uint64_t upperBound)
: NodeOrRelExpression{std::move(dataType), std::move(uniqueName), std::move(variableName),
std::move(tableIDs)},
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directed{directed},
srcNode{std::move(srcNode)}, dstNode{std::move(dstNode)}, directionType{directionType},
relType{relType}, lowerBound{lowerBound}, upperBound{upperBound} {}

inline bool isBoundByMultiLabeledNode() const {
Expand All @@ -31,7 +36,7 @@ class RelExpression : public NodeOrRelExpression {
inline uint64_t getLowerBound() const { return lowerBound; }
inline uint64_t getUpperBound() const { return upperBound; }

inline bool isDirected() const { return directed; }
inline RelDirectionType getDirectionType() const { return directionType; }

inline bool hasInternalIDProperty() const {
return hasPropertyExpression(common::INTERNAL_ID_SUFFIX);
Expand All @@ -50,7 +55,7 @@ class RelExpression : public NodeOrRelExpression {
private:
std::shared_ptr<NodeExpression> srcNode;
std::shared_ptr<NodeExpression> dstNode;
bool directed;
RelDirectionType directionType;
common::QueryRelType relType;
uint64_t lowerBound;
uint64_t upperBound;
Expand Down
1 change: 1 addition & 0 deletions src/include/catalog/catalog_structs.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "common/constants.h"
#include "common/exception.h"
#include "common/rel_direction.h"
#include "common/types/types_include.h"

namespace kuzu {
Expand Down
21 changes: 21 additions & 0 deletions src/include/common/rel_direction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include <cstdint>
#include <string>
#include <vector>

namespace kuzu {
namespace common {

enum RelDataDirection : uint8_t { FWD = 0, BWD = 1 };

struct RelDataDirectionUtils {
static inline std::vector<RelDataDirection> getRelDataDirections() {
return std::vector<RelDataDirection>{RelDataDirection::FWD, RelDataDirection::BWD};
}

static std::string relDataDirectionToString(RelDataDirection direction);
};

} // namespace common
} // namespace kuzu
8 changes: 0 additions & 8 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,6 @@ class LogicalTypeUtils {
static LogicalTypeID dataTypeIDFromString(const std::string& dataTypeIDString);
};

// RelDataDirection
enum RelDataDirection : uint8_t { FWD = 0, BWD = 1 };
const std::vector<RelDataDirection> REL_DIRECTIONS = {FWD, BWD};
RelDataDirection operator!(RelDataDirection& direction);
std::string getRelDataDirectionAsString(RelDataDirection relDirection);

enum class ExtendDirection : uint8_t { FWD = 0, BWD = 1, BOTH = 2 };

enum class DBFileType : uint8_t { ORIGINAL = 0, WAL_VERSION = 1 };

} // namespace common
Expand Down
3 changes: 1 addition & 2 deletions src/include/parser/query/graph_pattern/rel_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
namespace kuzu {
namespace parser {

enum ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 };

enum class ArrowDirection : uint8_t { LEFT = 0, RIGHT = 1, BOTH = 2 };
/**
* RelationshipPattern represents "-[relName:RelTableName+]-"
*/
Expand Down
11 changes: 6 additions & 5 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "catalog/catalog.h"
#include "common/join_type.h"
#include "planner/join_order_enumerator_context.h"
#include "planner/logical_plan/logical_operator/extend_direction.h"
#include "storage/store/nodes_statistics_and_deleted_ids.h"

namespace kuzu {
Expand Down Expand Up @@ -59,8 +60,9 @@ class JoinOrderEnumerator {
void planBaseTableScan();
void planNodeScan(uint32_t nodePos);
void planRelScan(uint32_t relPos);
void appendExtendAndFilter(std::shared_ptr<RelExpression> rel,
common::ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan);
void appendExtendAndFilter(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
ExtendDirection direction, const expression_vector& predicates, LogicalPlan& plan);

void planLevel(uint32_t level);
void planLevelExactly(uint32_t level);
Expand All @@ -84,11 +86,10 @@ class JoinOrderEnumerator {

void appendNonRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::ExtendDirection direction, const binder::expression_vector& properties,
LogicalPlan& plan);
ExtendDirection direction, const binder::expression_vector& properties, LogicalPlan& plan);
void appendRecursiveExtend(std::shared_ptr<NodeExpression> boundNode,
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
common::ExtendDirection direction, LogicalPlan& plan);
ExtendDirection direction, LogicalPlan& plan);

void planJoin(const binder::expression_vector& joinNodeIDs, common::JoinType joinType,
std::shared_ptr<Expression> mark, LogicalPlan& probePlan, LogicalPlan& buildPlan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "base_logical_operator.h"
#include "binder/expression/rel_expression.h"
#include "extend_direction.h"

namespace kuzu {
namespace planner {
Expand All @@ -11,14 +12,14 @@ class BaseLogicalExtend : public LogicalOperator {
BaseLogicalExtend(LogicalOperatorType operatorType,
std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::ExtendDirection direction, std::shared_ptr<LogicalOperator> child)
ExtendDirection direction, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{operatorType, std::move(child)}, boundNode{std::move(boundNode)},
nbrNode{std::move(nbrNode)}, rel{std::move(rel)}, direction{direction} {}

inline std::shared_ptr<binder::NodeExpression> getBoundNode() const { return boundNode; }
inline std::shared_ptr<binder::NodeExpression> getNbrNode() const { return nbrNode; }
inline std::shared_ptr<binder::RelExpression> getRel() const { return rel; }
inline common::ExtendDirection getDirection() const { return direction; }
inline ExtendDirection getDirection() const { return direction; }
virtual f_group_pos_set getGroupsPosToFlatten() = 0;

std::string getExpressionsForPrinting() const override;
Expand All @@ -27,7 +28,7 @@ class BaseLogicalExtend : public LogicalOperator {
std::shared_ptr<binder::NodeExpression> boundNode;
std::shared_ptr<binder::NodeExpression> nbrNode;
std::shared_ptr<binder::RelExpression> rel;
common::ExtendDirection direction;
ExtendDirection direction;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include <cstdint>

#include "binder/expression/rel_expression.h"
#include "common/rel_direction.h"

namespace kuzu {
namespace planner {

enum class ExtendDirection : uint8_t { FWD = 0, BWD = 1, BOTH = 2 };

struct ExtendDirectionUtils {
static inline ExtendDirection getExtendDirection(
const binder::RelExpression& relExpression, const binder::NodeExpression& boundNode) {
if (relExpression.getDirectionType() == binder::RelDirectionType::BOTH) {
return ExtendDirection::BOTH;
}
if (relExpression.getSrcNodeName() == boundNode.getUniqueName()) {
return ExtendDirection::FWD;
} else {
return ExtendDirection::BWD;
}
}

static inline common::RelDataDirection getRelDataDirection(ExtendDirection extendDirection) {
assert(extendDirection != ExtendDirection::BOTH);
return extendDirection == ExtendDirection::FWD ? common::RelDataDirection::FWD :
common::RelDataDirection::BWD;
}
};

} // namespace planner
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ class LogicalExtend : public BaseLogicalExtend {
public:
LogicalExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::ExtendDirection direction, binder::expression_vector properties,
bool hasAtMostOneNbr, std::shared_ptr<LogicalOperator> child)
ExtendDirection direction, binder::expression_vector properties, bool hasAtMostOneNbr,
std::shared_ptr<LogicalOperator> child)
: BaseLogicalExtend{LogicalOperatorType::EXTEND, std::move(boundNode), std::move(nbrNode),
std::move(rel), direction, std::move(child)},
properties{std::move(properties)}, hasAtMostOneNbr{hasAtMostOneNbr} {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ class LogicalRecursiveExtend : public BaseLogicalExtend {
public:
LogicalRecursiveExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::ExtendDirection direction, std::shared_ptr<LogicalOperator> child)
ExtendDirection direction, std::shared_ptr<LogicalOperator> child)
: LogicalRecursiveExtend{std::move(boundNode), std::move(nbrNode), std::move(rel),
direction, RecursiveJoinType::TRACK_PATH, std::move(child)} {}
LogicalRecursiveExtend(std::shared_ptr<binder::NodeExpression> boundNode,
std::shared_ptr<binder::NodeExpression> nbrNode, std::shared_ptr<binder::RelExpression> rel,
common::ExtendDirection direction, RecursiveJoinType joinType,
ExtendDirection direction, RecursiveJoinType joinType,
std::shared_ptr<LogicalOperator> child)
: BaseLogicalExtend{LogicalOperatorType::RECURSIVE_EXTEND, std::move(boundNode),
std::move(nbrNode), std::move(rel), direction, std::move(child)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "catalog/catalog_structs.h"
#include "common/data_chunk/data_chunk.h"
#include "common/rel_direction.h"
#include "common/types/types.h"
#include "processor/result/factorized_table.h"
#include "storage/storage_structure/lists/list_handle.h"
Expand Down Expand Up @@ -78,8 +79,9 @@ class ListsUpdatesStore {
}
initListsUpdatesPerTablePerDirection();
}
inline ListsUpdatesPerChunk& getListsUpdatesPerChunk(common::RelDataDirection relDirection) {
return listsUpdatesPerDirection[relDirection];
inline ListsUpdatesPerChunk& getListsUpdatesPerChunk(
common::RelDataDirection relDataDirection) {
return listsUpdatesPerDirection[relDataDirection];
}

void updateSchema(catalog::RelTableSchema& relTableSchema);
Expand Down
1 change: 1 addition & 0 deletions src/include/storage/wal/wal_record.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "common/rel_direction.h"
#include "common/types/types_include.h"
#include "common/utils.h"

Expand Down
Loading

0 comments on commit a51064d

Please sign in to comment.