Skip to content

Commit

Permalink
compute schema based on expressionsToMaterialize
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jan 4, 2023
1 parent 928c1cc commit e159b7a
Show file tree
Hide file tree
Showing 22 changed files with 210 additions and 245 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,27 @@ using namespace kuzu::binder;
class LogicalHashJoin : public LogicalOperator {
public:
// Inner and left join.
LogicalHashJoin(vector<shared_ptr<NodeExpression>> joinNodes, JoinType joinType,
bool isProbeAcc, expression_vector expressionsToMaterialize,
shared_ptr<LogicalOperator> probeSideChild, shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodes), joinType, nullptr, UINT32_MAX, isProbeAcc,
LogicalHashJoin(expression_vector joinNodeIDs, JoinType joinType, bool isProbeAcc,
expression_vector expressionsToMaterialize, shared_ptr<LogicalOperator> probeSideChild,
shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodeIDs), joinType, nullptr, UINT32_MAX, isProbeAcc,
std::move(expressionsToMaterialize), std::move(probeSideChild),
std::move(buildSideChild)} {}

// Mark join.
LogicalHashJoin(vector<shared_ptr<NodeExpression>> joinNodes, shared_ptr<Expression> mark,
uint32_t markPos, bool isProbeAcc, shared_ptr<LogicalOperator> probeSideChild,
LogicalHashJoin(expression_vector joinNodeIDs, shared_ptr<Expression> mark, uint32_t markPos,
bool isProbeAcc, shared_ptr<LogicalOperator> probeSideChild,
shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodes), JoinType::MARK, std::move(mark), markPos,
: LogicalHashJoin{std::move(joinNodeIDs), JoinType::MARK, std::move(mark), markPos,
isProbeAcc, expression_vector{} /* expressionsToMaterialize */,
std::move(probeSideChild), std::move(buildSideChild)} {}

LogicalHashJoin(vector<shared_ptr<NodeExpression>> joinNodes, JoinType joinType,
shared_ptr<Expression> mark, uint32_t markPos, bool isProbeAcc,
expression_vector expressionsToMaterialize, shared_ptr<LogicalOperator> probeSideChild,
shared_ptr<LogicalOperator> buildSideChild)
LogicalHashJoin(expression_vector joinNodeIDs, JoinType joinType, shared_ptr<Expression> mark,
uint32_t markPos, bool isProbeAcc, expression_vector expressionsToMaterialize,
shared_ptr<LogicalOperator> probeSideChild, shared_ptr<LogicalOperator> buildSideChild)
: LogicalOperator{LogicalOperatorType::HASH_JOIN, std::move(probeSideChild),
std::move(buildSideChild)},
joinNodes(std::move(joinNodes)), joinType{joinType}, mark{std::move(mark)},
joinNodeIDs(std::move(joinNodeIDs)), joinType{joinType}, mark{std::move(mark)},
markPos{markPos}, isProbeAcc{isProbeAcc}, expressionsToMaterialize{
std::move(expressionsToMaterialize)} {}

Expand All @@ -46,7 +45,7 @@ class LogicalHashJoin : public LogicalOperator {
inline expression_vector getExpressionsToMaterialize() const {
return expressionsToMaterialize;
}
inline vector<shared_ptr<NodeExpression>> getJoinNodes() const { return joinNodes; }
inline expression_vector getJoinNodeIDs() const { return joinNodeIDs; }
inline JoinType getJoinType() const { return joinType; }

inline shared_ptr<Expression> getMark() const {
Expand All @@ -57,12 +56,12 @@ class LogicalHashJoin : public LogicalOperator {
inline Schema* getBuildSideSchema() const { return children[1]->getSchema(); }

inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalHashJoin>(joinNodes, joinType, mark, markPos, isProbeAcc,
return make_unique<LogicalHashJoin>(joinNodeIDs, joinType, mark, markPos, isProbeAcc,
expressionsToMaterialize, children[0]->copy(), children[1]->copy());
}

private:
vector<shared_ptr<NodeExpression>> joinNodes;
expression_vector joinNodeIDs;
JoinType joinType;
shared_ptr<Expression> mark; // when joinType is Mark
uint32_t markPos;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,34 @@ namespace kuzu {
namespace planner {

// Pairs a key expression with the expressions to materialize; LogicalIntersect
// stores a vector of these and exposes them by index via getBuildInfo(idx).
// NOTE(review): this listing appears to be a rendered diff with the +/- markers
// stripped, so each adjacent near-duplicate pair below (key vs. keyNodeID) is
// presumably the before/after form of one declaration — confirm against the commit.
struct LogicalIntersectBuildInfo {
// Both parameters are taken by value and moved into the members.
LogicalIntersectBuildInfo(shared_ptr<NodeExpression> key, expression_vector expressions)
: key{std::move(key)}, expressionsToMaterialize{std::move(expressions)} {}
LogicalIntersectBuildInfo(shared_ptr<Expression> keyNodeID, expression_vector expressions)
: keyNodeID{std::move(keyNodeID)}, expressionsToMaterialize{std::move(expressions)} {}

// Returns a new heap-allocated info built from the same key pointer and a copy of
// the expression vector (the shared_ptr is shared, not deep-copied).
inline unique_ptr<LogicalIntersectBuildInfo> copy() {
return make_unique<LogicalIntersectBuildInfo>(key, expressionsToMaterialize);
return make_unique<LogicalIntersectBuildInfo>(keyNodeID, expressionsToMaterialize);
}

// Key of this build side: a NodeExpression in one variant, a plain Expression
// (named as a node ID) in the other.
shared_ptr<NodeExpression> key;
shared_ptr<Expression> keyNodeID;
// Expressions to materialize alongside the key.
expression_vector expressionsToMaterialize;
};

class LogicalIntersect : public LogicalOperator {
public:
LogicalIntersect(shared_ptr<NodeExpression> intersectNode,
shared_ptr<LogicalOperator> probeChild, vector<shared_ptr<LogicalOperator>> buildChildren,
LogicalIntersect(shared_ptr<Expression> intersectNodeID, shared_ptr<LogicalOperator> probeChild,
vector<shared_ptr<LogicalOperator>> buildChildren,
vector<unique_ptr<LogicalIntersectBuildInfo>> buildInfos)
: LogicalOperator{LogicalOperatorType::INTERSECT, std::move(probeChild)},
intersectNode{std::move(intersectNode)}, buildInfos{std::move(buildInfos)} {
intersectNodeID{std::move(intersectNodeID)}, buildInfos{std::move(buildInfos)} {
for (auto& child : buildChildren) {
children.push_back(std::move(child));
}
}

void computeSchema() override;

string getExpressionsForPrinting() const override { return intersectNode->getRawName(); }
string getExpressionsForPrinting() const override { return intersectNodeID->getRawName(); }

inline shared_ptr<NodeExpression> getIntersectNode() const { return intersectNode; }
inline shared_ptr<Expression> getIntersectNodeID() const { return intersectNodeID; }
inline LogicalIntersectBuildInfo* getBuildInfo(uint32_t idx) const {
return buildInfos[idx].get();
}
Expand All @@ -44,7 +44,7 @@ class LogicalIntersect : public LogicalOperator {
unique_ptr<LogicalOperator> copy() override;

private:
shared_ptr<NodeExpression> intersectNode;
shared_ptr<Expression> intersectNodeID;
vector<unique_ptr<LogicalIntersectBuildInfo>> buildInfos;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class LogicalOrderBy : public LogicalOperator {

void computeSchema() override;

string getExpressionsForPrinting() const override;
inline string getExpressionsForPrinting() const override {
return ExpressionUtil::toString(expressionsToOrderBy);
}

inline expression_vector getExpressionsToOrderBy() const { return expressionsToOrderBy; }
inline vector<bool> getIsAscOrders() const { return isAscOrders; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,22 @@ namespace planner {

// Logical operator of type UNION_ALL over a vector of child operators.
// NOTE(review): this listing appears to be a rendered diff with the +/- markers
// stripped. The variant that takes and stores `schemasBeforeUnion` and the variant
// that reads the schema from the child are presumably the before/after forms of the
// same members — confirm against the commit.
class LogicalUnion : public LogicalOperator {
public:
// `expressions` are the expressions to union; `children` are moved into the
// LogicalOperator base.
LogicalUnion(expression_vector expressions, vector<unique_ptr<Schema>> schemasBeforeUnion,
vector<shared_ptr<LogicalOperator>> children)
LogicalUnion(expression_vector expressions, vector<shared_ptr<LogicalOperator>> children)
: LogicalOperator{LogicalOperatorType::UNION_ALL, std::move(children)},
expressionsToUnion{std::move(expressions)}, schemasBeforeUnion{
std::move(schemasBeforeUnion)} {}
expressionsToUnion{std::move(expressions)} {}

void computeSchema() override;

// This operator prints no expressions: always returns an empty string.
inline string getExpressionsForPrinting() const override { return string(); }

// Returns a copy of the expressions to union.
inline expression_vector getExpressionsToUnion() { return expressionsToUnion; }

// Schema of the idx-th input: one variant reads a separately stored schema, the
// other asks child `idx` directly. No bounds check is performed on `idx`.
inline Schema* getSchemaBeforeUnion(uint32_t idx) { return schemasBeforeUnion[idx].get(); }
inline Schema* getSchemaBeforeUnion(uint32_t idx) { return children[idx]->getSchema(); }

unique_ptr<LogicalOperator> copy() override;

private:
expression_vector expressionsToUnion;
vector<unique_ptr<Schema>> schemasBeforeUnion;
};

} // namespace planner
Expand Down
26 changes: 17 additions & 9 deletions src/include/planner/logical_plan/logical_operator/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using namespace kuzu::common;
namespace kuzu {
namespace planner {

// Index of a factorization group inside a Schema (the type Schema uses for group
// positions and for the number of groups).
using f_group_pos = uint32_t;

class FactorizationGroup {
friend class Schema;

Expand Down Expand Up @@ -56,7 +58,7 @@ class FactorizationGroup {

class Schema {
public:
inline uint32_t getNumGroups() const { return groups.size(); }
inline f_group_pos getNumGroups() const { return groups.size(); }

inline FactorizationGroup* getGroup(shared_ptr<Expression> expression) const {
return getGroup(getGroupPos(expression->getUniqueName()));
Expand All @@ -68,36 +70,36 @@ class Schema {

inline FactorizationGroup* getGroup(uint32_t pos) const { return groups[pos].get(); }

uint32_t createGroup();
f_group_pos createGroup();

void insertToScope(const shared_ptr<Expression>& expression, uint32_t groupPos);

void insertToGroupAndScope(const shared_ptr<Expression>& expression, uint32_t groupPos);

void insertToGroupAndScope(const expression_vector& expressions, uint32_t groupPos);

inline uint32_t getGroupPos(const Expression& expression) const {
inline f_group_pos getGroupPos(const Expression& expression) const {
return getGroupPos(expression.getUniqueName());
}

inline uint32_t getGroupPos(const string& expressionName) const {
inline f_group_pos getGroupPos(const string& expressionName) const {
assert(expressionNameToGroupPos.contains(expressionName));
return expressionNameToGroupPos.at(expressionName);
}

inline pair<uint32_t, uint32_t> getExpressionPos(const Expression& expression) const {
inline pair<f_group_pos, uint32_t> getExpressionPos(const Expression& expression) const {
auto groupPos = getGroupPos(expression);
return make_pair(groupPos, groups[groupPos]->getExpressionPos(expression));
}

inline void flattenGroup(uint32_t pos) { groups[pos]->setFlat(); }
inline void setGroupAsSingleState(uint32_t pos) { groups[pos]->setSingleState(); }
inline void flattenGroup(f_group_pos pos) { groups[pos]->setFlat(); }
inline void setGroupAsSingleState(f_group_pos pos) { groups[pos]->setSingleState(); }

bool isExpressionInScope(const Expression& expression) const;

inline expression_vector getExpressionsInScope() const { return expressionsInScope; }

expression_vector getExpressionsInScope(uint32_t pos) const;
expression_vector getExpressionsInScope(f_group_pos pos) const;

expression_vector getSubExpressionsInScope(const shared_ptr<Expression>& expression);

Expand All @@ -109,7 +111,7 @@ class Schema {
}

// Get the group positions containing at least one expression in scope.
unordered_set<uint32_t> getGroupsPosInScope() const;
unordered_set<f_group_pos> getGroupsPosInScope() const;

unique_ptr<Schema> copy() const;

Expand All @@ -123,5 +125,11 @@ class Schema {
expression_vector expressionsInScope;
};

// Static-only helper functions that operate on a Schema.
class SchemaUtils {
public:
// Declaration only; the definition is not part of this header.
// NOTE(review): from the name and return type this presumably partitions
// `expressions` by the factorization group each belongs to in `schema`, yielding
// one expression_vector per group — confirm against the implementation.
static vector<expression_vector> getExpressionsPerGroup(
const expression_vector& expressions, const Schema& schema);
};

} // namespace planner
} // namespace kuzu
35 changes: 9 additions & 26 deletions src/include/planner/logical_plan/logical_operator/sink_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,22 @@

namespace kuzu {
namespace planner {
using namespace kuzu::binder;

// This class contains the logic for re-computing factorization structure after
// This class contains the logic for re-computing factorization structure after sinking
// Static helpers for re-computing the factorization structure (Schema) after
// sinking. All members are declarations only, except three inline private wrappers
// that forward to sibling overloads.
// NOTE(review): this listing appears to be a rendered diff with the +/- markers
// stripped. The overloads keyed by `vector<string>` / `unordered_set<uint32_t>` and
// the overloads taking an `expression_vector` are presumably the before/after API of
// this class rather than one coexisting overload set — confirm against the commit.
class SinkOperatorUtil {
public:
// Merge `inputSchema` into a result schema; variants differ in how the content to
// merge is selected (string keys, an explicit expression list, or neither).
static void mergeSchema(const Schema& inputSchema, Schema& result, const vector<string>& keys);
static void mergeSchema(const Schema& inputSchema, const expression_vector& expressionsToMerge,
Schema& resultSchema);

static void mergeSchema(const Schema& inputSchema, Schema& result);

static void recomputeSchema(const Schema& inputSchema, Schema& result);

static unordered_set<uint32_t> getGroupsPosIgnoringKeyGroups(
const Schema& schema, const vector<string>& keys);
static void recomputeSchema(const Schema& inputSchema,
const expression_vector& expressionsToMerge, Schema& resultSchema);

private:
static void mergeKeyGroup(const Schema& inputSchema, Schema& resultSchema, uint32_t keyGroupPos,
const vector<string>& keysInGroup);

// Flat payloads restricted to the group positions returned by
// getGroupsPosIgnoringKeyGroups(schema, keys).
static inline expression_vector getFlatPayloadsIgnoringKeyGroup(
const Schema& schema, const vector<string>& keys) {
return getFlatPayloads(schema, getGroupsPosIgnoringKeyGroups(schema, keys));
}
// Flat payloads over every group position that is in scope for `schema`.
static inline expression_vector getFlatPayloads(const Schema& schema) {
return getFlatPayloads(schema, schema.getGroupsPosInScope());
}
static expression_vector getFlatPayloads(
const Schema& schema, const unordered_set<uint32_t>& payloadGroupsPos);
static unordered_map<f_group_pos, expression_vector> getUnFlatPayloadsPerGroup(
const Schema& schema, const expression_vector& payloads);

// Checks for an unflat payload over every group position in scope for `schema`.
static inline bool hasUnFlatPayload(const Schema& schema) {
return hasUnFlatPayload(schema, schema.getGroupsPosInScope());
}
static bool hasUnFlatPayload(
const Schema& schema, const unordered_set<uint32_t>& payloadGroupsPos);
static expression_vector getFlatPayloads(
const Schema& schema, const expression_vector& payloads);

// NOTE(review): the uint32_t returned is presumably the position of the newly
// created group — confirm against the implementation.
static uint32_t appendPayloadsToNewGroup(Schema& schema, expression_vector& payloads);
};
Expand Down
2 changes: 1 addition & 1 deletion src/include/processor/mapper/plan_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class PlanMapper {
const Schema& outSchema, vector<bool>& isInputGroupByHashKeyVectorFlat);

static BuildDataInfo generateBuildDataInfo(const Schema& buildSideSchema,
const vector<shared_ptr<NodeExpression>>& keys, const expression_vector& payloads);
const expression_vector& keys, const expression_vector& payloads);

public:
StorageManager& storageManager;
Expand Down
11 changes: 5 additions & 6 deletions src/include/processor/operator/scan_node_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,23 +118,22 @@ class ScanNodeIDSharedState {

class ScanNodeID : public PhysicalOperator {
public:
ScanNodeID(string nodeName, const DataPos& outDataPos,
ScanNodeID(string nodeID, const DataPos& outDataPos,
shared_ptr<ScanNodeIDSharedState> sharedState, uint32_t id, const string& paramsString)
: PhysicalOperator{PhysicalOperatorType::SCAN_NODE_ID, id, paramsString},
nodeName{std::move(nodeName)}, outDataPos{outDataPos}, sharedState{
std::move(sharedState)} {}
nodeID{std::move(nodeID)}, outDataPos{outDataPos}, sharedState{std::move(sharedState)} {}

bool isSource() const override { return true; }

inline string getNodeName() const { return nodeName; }
inline string getNodeID() const { return nodeID; }
inline ScanNodeIDSharedState* getSharedState() const { return sharedState.get(); }

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal() override;

inline unique_ptr<PhysicalOperator> clone() override {
return make_unique<ScanNodeID>(nodeName, outDataPos, sharedState, id, paramsString);
return make_unique<ScanNodeID>(nodeID, outDataPos, sharedState, id, paramsString);
}

private:
Expand All @@ -144,7 +143,7 @@ class ScanNodeID : public PhysicalOperator {
ScanTableNodeIDSharedState* tableState, node_offset_t startOffset, node_offset_t endOffset);

private:
string nodeName;
string nodeID;
DataPos outDataPos;
shared_ptr<ScanNodeIDSharedState> sharedState;

Expand Down
Loading

0 comments on commit e159b7a

Please sign in to comment.