Skip to content

Commit

Permalink
X
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 7, 2023
1 parent 4262884 commit fec4d9f
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,8 @@ class LogicalCreateRel : public LogicalUpdateRel {
setItemsPerRel{std::move(setItemsPerRel)} {}
~LogicalCreateRel() override = default;

inline void computeFactorizedSchema() override {
copyChildSchema(0);
}
inline void computeFlatSchema() override {
copyChildSchema(0);
}
inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

inline f_group_pos_set getGroupsPosToFlatten() {
auto childSchema = children[0]->getSchema();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LogicalFilter : public LogicalOperator {
expression)} {}

inline void computeFactorizedSchema() override { copyChildSchema(0); }
inline void computeFlatSchema() override {copyChildSchema(0); }
inline void computeFlatSchema() override { copyChildSchema(0); }

f_group_pos_set getGroupsPosToFlatten();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,13 @@
namespace kuzu {
namespace planner {

struct LogicalIntersectBuildInfo {
LogicalIntersectBuildInfo(
std::shared_ptr<binder::Expression> keyNodeID, binder::expression_vector expressions)
: keyNodeID{std::move(keyNodeID)}, expressionsToMaterialize{std::move(expressions)} {}

inline void setExpressionsToMaterialize(binder::expression_vector expressions) {
expressionsToMaterialize = std::move(expressions);
}
inline binder::expression_vector getExpressionsToMaterialize() const {
return expressionsToMaterialize;
}

inline std::unique_ptr<LogicalIntersectBuildInfo> copy() {
return make_unique<LogicalIntersectBuildInfo>(keyNodeID, expressionsToMaterialize);
}

std::shared_ptr<binder::Expression> keyNodeID;
binder::expression_vector expressionsToMaterialize;
};

class LogicalIntersect : public LogicalOperator {
public:
LogicalIntersect(std::shared_ptr<binder::Expression> intersectNodeID,
std::shared_ptr<LogicalOperator> probeChild,
std::vector<std::shared_ptr<LogicalOperator>> buildChildren,
std::vector<std::unique_ptr<LogicalIntersectBuildInfo>> buildInfos)
binder::expression_vector keyNodeIDs, std::shared_ptr<LogicalOperator> probeChild,
std::vector<std::shared_ptr<LogicalOperator>> buildChildren)
: LogicalOperator{LogicalOperatorType::INTERSECT, std::move(probeChild)},
intersectNodeID{std::move(intersectNodeID)}, buildInfos{std::move(buildInfos)} {
intersectNodeID{std::move(intersectNodeID)}, keyNodeIDs{std::move(keyNodeIDs)} {
for (auto& child : buildChildren) {
children.push_back(std::move(child));
}
Expand All @@ -51,16 +30,17 @@ class LogicalIntersect : public LogicalOperator {
inline std::shared_ptr<binder::Expression> getIntersectNodeID() const {
return intersectNodeID;
}
inline LogicalIntersectBuildInfo* getBuildInfo(uint32_t idx) const {
return buildInfos[idx].get();

inline uint32_t getNumBuilds() const { return keyNodeIDs.size(); }
inline std::shared_ptr<binder::Expression> getKeyNodeID(uint32_t idx) const {
return keyNodeIDs[idx];
}
inline uint32_t getNumBuilds() const { return buildInfos.size(); }

std::unique_ptr<LogicalOperator> copy() override;

private:
std::shared_ptr<binder::Expression> intersectNodeID;
std::vector<std::unique_ptr<LogicalIntersectBuildInfo>> buildInfos;
binder::expression_vector keyNodeIDs;
};

} // namespace planner
Expand Down
15 changes: 8 additions & 7 deletions src/optimizer/projection_push_down_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,30 +71,31 @@ void ProjectionPushDownOptimizer::visitIntersect(planner::LogicalOperator* op) {
auto intersect = (LogicalIntersect*)op;
collectPropertiesInUse(intersect->getIntersectNodeID());
for (auto i = 0u; i < intersect->getNumBuilds(); ++i) {
auto buildInfo = intersect->getBuildInfo(i);
collectPropertiesInUse(buildInfo->keyNodeID);
auto childIdx = i + 1; // skip probe
auto keyNodeID = intersect->getKeyNodeID(i);
collectPropertiesInUse(keyNodeID);
// Note: we have a potential bug under intersect.cpp. The following code ensures build key
// and intersect key always appear as the first and second column. Should be removed once
// the bug is fixed.
expression_vector expressionsBeforePruning;
expression_vector expressionsAfterPruning;
for (auto& expression : buildInfo->expressionsToMaterialize) {
for (auto& expression :
intersect->getChild(childIdx)->getSchema()->getExpressionsInScope()) {
if (expression->getUniqueName() == intersect->getIntersectNodeID()->getUniqueName() ||
expression->getUniqueName() == buildInfo->keyNodeID->getUniqueName()) {
expression->getUniqueName() == keyNodeID->getUniqueName()) {
continue;
}
expressionsBeforePruning.push_back(expression);
}
expressionsAfterPruning.push_back(buildInfo->keyNodeID);
expressionsAfterPruning.push_back(keyNodeID);
expressionsAfterPruning.push_back(intersect->getIntersectNodeID());
for (auto& expression : pruneExpressions(expressionsBeforePruning)) {
expressionsAfterPruning.push_back(expression);
}
if (expressionsBeforePruning.size() == expressionsAfterPruning.size()) {
return;
}
buildInfo->setExpressionsToMaterialize(expressionsAfterPruning);
auto childIdx = i + 1; // skip probe

preAppendProjection(op, childIdx, expressionsAfterPruning);
}
}
Expand Down
16 changes: 6 additions & 10 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,17 +644,13 @@ void JoinOrderEnumerator::appendIntersect(const std::shared_ptr<Expression>& int
std::vector<std::unique_ptr<LogicalPlan>>& buildPlans) {
assert(boundNodeIDs.size() == buildPlans.size());
std::vector<std::shared_ptr<LogicalOperator>> buildChildren;
std::vector<std::unique_ptr<LogicalIntersectBuildInfo>> buildInfos;
binder::expression_vector keyNodeIDs;
for (auto i = 0u; i < buildPlans.size(); ++i) {
auto boundNodeID = boundNodeIDs[i];
auto buildPlan = buildPlans[i].get();
auto buildInfo = std::make_unique<LogicalIntersectBuildInfo>(
boundNodeID, buildPlan->getSchema()->getExpressionsInScope());
buildChildren.push_back(buildPlan->getLastOperator());
buildInfos.push_back(std::move(buildInfo));
}
auto intersect = make_shared<LogicalIntersect>(intersectNodeID, probePlan.getLastOperator(),
std::move(buildChildren), std::move(buildInfos));
keyNodeIDs.push_back(boundNodeIDs[i]);
buildChildren.push_back(buildPlans[i]->getLastOperator());
}
auto intersect = make_shared<LogicalIntersect>(intersectNodeID, std::move(keyNodeIDs),
probePlan.getLastOperator(), std::move(buildChildren));
QueryPlanner::appendFlattens(intersect->getGroupsPosToFlattenOnProbeSide(), probePlan);
intersect->setChild(0, probePlan.getLastOperator());
for (auto i = 0u; i < buildPlans.size(); ++i) {
Expand Down
6 changes: 4 additions & 2 deletions src/planner/operator/logical_hash_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ binder::expression_vector LogicalHashJoin::getExpressionsToMaterialize() const {
case common::JoinType::INNER:
case common::JoinType::LEFT: {
return children[1]->getSchema()->getExpressionsInScope();
} case common::JoinType::MARK: {
}
case common::JoinType::MARK: {
return binder::expression_vector{};
} default:
}
default:
throw common::NotImplementedException("HashJoin::getExpressionsToMaterialize");
}
}
Expand Down
21 changes: 9 additions & 12 deletions src/planner/operator/logical_intersect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ namespace planner {

f_group_pos_set LogicalIntersect::getGroupsPosToFlattenOnProbeSide() {
f_group_pos_set result;
for (auto& buildInfo : buildInfos) {
result.insert(children[0]->getSchema()->getGroupPos(*buildInfo->keyNodeID));
for (auto& keyNodeID : keyNodeIDs) {
result.insert(children[0]->getSchema()->getGroupPos(*keyNodeID));
}
return result;
}

f_group_pos_set LogicalIntersect::getGroupsPosToFlattenOnBuildSide(uint32_t buildIdx) {
f_group_pos_set result;
auto childIdx = buildIdx + 1; // skip probe
result.insert(children[childIdx]->getSchema()->getGroupPos(*buildInfos[buildIdx]->keyNodeID));
result.insert(children[childIdx]->getSchema()->getGroupPos(*keyNodeIDs[buildIdx]));
return result;
}

Expand All @@ -26,11 +26,11 @@ void LogicalIntersect::computeFactorizedSchema() {
schema->insertToGroupAndScope(intersectNodeID, outGroupPos);
for (auto i = 1; i < children.size(); ++i) {
auto buildSchema = children[i]->getSchema();
auto buildInfo = buildInfos[i - 1].get();
auto keyNodeID = keyNodeIDs[i - 1];
// Write rel properties into output group.
for (auto& expression : buildSchema->getExpressionsInScope()) {
if (expression->getUniqueName() == intersectNodeID->getUniqueName() ||
expression->getUniqueName() == buildInfo->keyNodeID->getUniqueName()) {
expression->getUniqueName() == keyNodeID->getUniqueName()) {
continue;
}
schema->insertToGroupAndScope(expression, outGroupPos);
Expand All @@ -45,10 +45,10 @@ void LogicalIntersect::computeFlatSchema() {
schema->insertToGroupAndScope(intersectNodeID, 0);
for (auto i = 1; i < children.size(); ++i) {
auto buildSchema = children[i]->getSchema();
auto buildInfo = buildInfos[i - 1].get();
auto keyNodeID = keyNodeIDs[i - 1];
for (auto& expression : buildSchema->getExpressionsInScope()) {
if (expression->getUniqueName() == intersectNodeID->getUniqueName() ||
expression->getUniqueName() == buildInfo->keyNodeID->getUniqueName()) {
expression->getUniqueName() == keyNodeID->getUniqueName()) {
continue;
}
schema->insertToGroupAndScope(expression, 0);
Expand All @@ -58,14 +58,11 @@ void LogicalIntersect::computeFlatSchema() {

std::unique_ptr<LogicalOperator> LogicalIntersect::copy() {
std::vector<std::shared_ptr<LogicalOperator>> buildChildren;
std::vector<std::unique_ptr<LogicalIntersectBuildInfo>> buildInfos_;
for (auto i = 1u; i < children.size(); ++i) {
buildChildren.push_back(children[i]->copy());
buildInfos_.push_back(buildInfos[i - 1]->copy());
}
auto result = make_unique<LogicalIntersect>(
intersectNodeID, children[0]->copy(), std::move(buildChildren), std::move(buildInfos_));
return result;
return make_unique<LogicalIntersect>(
intersectNodeID, keyNodeIDs, children[0]->copy(), std::move(buildChildren));
}

} // namespace planner
Expand Down
11 changes: 5 additions & 6 deletions src/processor/mapper/map_intersect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalIntersectToPhysical(
std::vector<IntersectDataInfo> intersectDataInfos;
// Map build side children.
for (auto i = 1u; i < logicalIntersect->getNumChildren(); i++) {
auto buildInfo = logicalIntersect->getBuildInfo(i - 1);
auto keyNodeID = logicalIntersect->getKeyNodeID(i - 1);
auto buildSchema = logicalIntersect->getChild(i)->getSchema();
auto buildSidePrevOperator = mapLogicalOperatorToPhysical(logicalIntersect->getChild(i));
std::vector<DataPos> payloadsDataPos;
auto buildDataInfo = generateBuildDataInfo(
*buildSchema, {buildInfo->keyNodeID}, buildInfo->expressionsToMaterialize);
auto buildDataInfo =
generateBuildDataInfo(*buildSchema, {keyNodeID}, buildSchema->getExpressionsInScope());
for (auto& [dataPos, _] : buildDataInfo.payloadsPosAndType) {
auto expression = buildSchema->getGroup(dataPos.dataChunkPos)
->getExpressions()[dataPos.valueVectorPos];
Expand All @@ -38,9 +38,8 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalIntersectToPhysical(
sharedStates.push_back(sharedState);
children.push_back(make_unique<IntersectBuild>(
std::make_unique<ResultSetDescriptor>(*buildSchema), sharedState, buildDataInfo,
std::move(buildSidePrevOperator), getOperatorID(), buildInfo->keyNodeID->toString()));
IntersectDataInfo info{
DataPos(outSchema->getExpressionPos(*buildInfo->keyNodeID)), payloadsDataPos};
std::move(buildSidePrevOperator), getOperatorID(), keyNodeID->toString()));
IntersectDataInfo info{DataPos(outSchema->getExpressionPos(*keyNodeID)), payloadsDataPos};
intersectDataInfos.push_back(info);
}
// Map intersect.
Expand Down

0 comments on commit fec4d9f

Please sign in to comment.