Skip to content

Commit

Permalink
fix issue-606
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Dec 3, 2022
1 parent a81da23 commit 9e057b5
Show file tree
Hide file tree
Showing 32 changed files with 127 additions and 180 deletions.
2 changes: 1 addition & 1 deletion src/expression_evaluator/function_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ void FunctionExpressionEvaluator::init(const ResultSet& resultSet, MemoryManager
selectFunc = ((ScalarFunctionExpression&)*expression).selectFunc;
}
resultVector = make_shared<ValueVector>(expression->dataType, memoryManager);
if (children.empty()) {
if (children.empty()) { // const function, e.g. PI()
resultVector->state = DataChunkState::getSingleValueDataChunkState();
}
for (auto& child : children) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,29 @@ namespace planner {

class LogicalAccumulate : public LogicalOperator {
public:
LogicalAccumulate(expression_vector expressions, vector<uint64_t> flatOutputGroupPositions,
unique_ptr<Schema> schemaBeforeSink, shared_ptr<LogicalOperator> child)
: LogicalOperator{move(child)}, expressions{move(expressions)},
flatOutputGroupPositions{move(flatOutputGroupPositions)}, schemaBeforeSink{
move(schemaBeforeSink)} {}
LogicalAccumulate(expression_vector expressions, unique_ptr<Schema> schemaBeforeSink,
shared_ptr<LogicalOperator> child)
: LogicalOperator{std::move(child)}, expressions{std::move(expressions)},
schemaBeforeSink{std::move(schemaBeforeSink)} {}

LogicalOperatorType getLogicalOperatorType() const override {
return LogicalOperatorType::LOGICAL_ACCUMULATE;
}

string getExpressionsForPrinting() const override {
string result;
for (auto& expression : expressions) {
result += expression->getRawName() + ",";
}
return result;
return ExpressionUtil::toString(expressions);
}

inline expression_vector getExpressions() const { return expressions; }
inline vector<uint64_t> getFlatOutputGroupPositions() const { return flatOutputGroupPositions; }
inline Schema* getSchemaBeforeSink() const { return schemaBeforeSink.get(); }

unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAccumulate>(
expressions, flatOutputGroupPositions, schemaBeforeSink->copy(), children[0]->copy());
expressions, schemaBeforeSink->copy(), children[0]->copy());
}

private:
expression_vector expressions;
// TODO(Xiyang): remove this when fixing issue #606
vector<uint64_t> flatOutputGroupPositions;
unique_ptr<Schema> schemaBeforeSink;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ namespace planner {
class LogicalCrossProduct : public LogicalOperator {
public:
LogicalCrossProduct(unique_ptr<Schema> buildSideSchema,
vector<uint64_t> flatOutputGroupPositions, shared_ptr<LogicalOperator> probeSideChild,
shared_ptr<LogicalOperator> buildSideChild)
shared_ptr<LogicalOperator> probeSideChild, shared_ptr<LogicalOperator> buildSideChild)
: LogicalOperator{std::move(probeSideChild), std::move(buildSideChild)},
buildSideSchema{std::move(buildSideSchema)}, flatOutputGroupPositions{
std::move(flatOutputGroupPositions)} {}
buildSideSchema{std::move(buildSideSchema)} {}

inline LogicalOperatorType getLogicalOperatorType() const override {
return LogicalOperatorType::LOGICAL_CROSS_PRODUCT;
Expand All @@ -21,16 +19,14 @@ class LogicalCrossProduct : public LogicalOperator {
inline string getExpressionsForPrinting() const override { return string(); }

inline Schema* getBuildSideSchema() const { return buildSideSchema.get(); }
inline vector<uint64_t> getFlatOutputGroupPositions() const { return flatOutputGroupPositions; }

inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalCrossProduct>(buildSideSchema->copy(), flatOutputGroupPositions,
children[0]->copy(), children[1]->copy());
return make_unique<LogicalCrossProduct>(
buildSideSchema->copy(), children[0]->copy(), children[1]->copy());
}

private:
unique_ptr<Schema> buildSideSchema;
vector<uint64_t> flatOutputGroupPositions;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class LogicalExtend : public LogicalOperator {
if (!extendToNewGroup) {
nbrGroupPos = boundGroupPos;
} else {
assert(schema.getGroup(boundGroupPos)->getIsFlat());
assert(schema.getGroup(boundGroupPos)->isFlat());
nbrGroupPos = schema.createGroup();
}
schema.insertToGroupAndScope(nbrNode->getInternalIDProperty(), nbrGroupPos);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ class LogicalFlatten : public LogicalOperator {
inline shared_ptr<Expression> getExpression() const { return expression; }

inline void computeSchema(Schema& schema) {
auto group = schema.getGroup(expression);
assert(!group->getIsFlat());
group->setIsFlat(true);
auto groupPos = schema.getGroupPos(expression->getUniqueName());
schema.flattenGroup(groupPos);
}

unique_ptr<LogicalOperator> copy() override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ using namespace kuzu::binder;

class LogicalFTableScan : public LogicalOperator {
public:
LogicalFTableScan(expression_vector expressionsToScan, expression_vector expressionsAccumulated,
vector<uint64_t> flatOutputGroupPositions)
: expressionsToScan{std::move(expressionsToScan)}, expressionsAccumulated{std::move(
expressionsAccumulated)},
flatOutputGroupPositions{std::move(flatOutputGroupPositions)} {}
LogicalFTableScan(expression_vector expressionsToScan, expression_vector expressionsAccumulated)
: expressionsToScan{std::move(expressionsToScan)}, expressionsAccumulated{
std::move(expressionsAccumulated)} {}

inline LogicalOperatorType getLogicalOperatorType() const override {
return LogicalOperatorType::LOGICAL_FTABLE_SCAN;
Expand All @@ -25,19 +23,16 @@ class LogicalFTableScan : public LogicalOperator {

inline expression_vector getExpressionsToScan() const { return expressionsToScan; }
inline expression_vector getExpressionsAccumulated() const { return expressionsAccumulated; }
inline vector<uint64_t> getFlatOutputGroupPositions() const { return flatOutputGroupPositions; }

unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalFTableScan>(
expressionsToScan, expressionsAccumulated, flatOutputGroupPositions);
return make_unique<LogicalFTableScan>(expressionsToScan, expressionsAccumulated);
}

private:
expression_vector expressionsToScan;
// expressionsToScan can be a subset of expressionsAccumulated (i.e. partially scan a factorized
// table).
expression_vector expressionsAccumulated;
vector<uint64_t> flatOutputGroupPositions;
};

} // namespace planner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,29 @@ class LogicalHashJoin : public LogicalOperator {
// Inner and left join.
LogicalHashJoin(vector<shared_ptr<NodeExpression>> joinNodes, JoinType joinType,
bool isProbeAcc, unique_ptr<Schema> buildSideSchema,
vector<uint64_t> flatOutputGroupPositions, expression_vector expressionsToMaterialize,
shared_ptr<LogicalOperator> probeSideChild, shared_ptr<LogicalOperator> buildSideChild)
expression_vector expressionsToMaterialize, shared_ptr<LogicalOperator> probeSideChild,
shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodes), joinType, nullptr, isProbeAcc,
std::move(buildSideSchema), std::move(flatOutputGroupPositions),
std::move(expressionsToMaterialize), std::move(probeSideChild),
std::move(buildSideChild)} {}
std::move(buildSideSchema), std::move(expressionsToMaterialize),
std::move(probeSideChild), std::move(buildSideChild)} {}

// Mark join.
LogicalHashJoin(vector<shared_ptr<NodeExpression>> joinNodes, shared_ptr<Expression> mark,
bool isProbeAcc, unique_ptr<Schema> buildSideSchema,
shared_ptr<LogicalOperator> probeSideChild, shared_ptr<LogicalOperator> buildSideChild)
: LogicalHashJoin{std::move(joinNodes), JoinType::MARK, std::move(mark), isProbeAcc,
std::move(buildSideSchema), vector<uint64_t>{} /* flatOutputGroupPositions */,
expression_vector{} /* expressionsToMaterialize */, std::move(probeSideChild),
std::move(buildSideChild)} {}
std::move(buildSideSchema), expression_vector{} /* expressionsToMaterialize */,
std::move(probeSideChild), std::move(buildSideChild)} {}

LogicalHashJoin(vector<shared_ptr<NodeExpression>> joinNodes, JoinType joinType,
shared_ptr<Expression> mark, bool isProbeAcc, unique_ptr<Schema> buildSideSchema,
vector<uint64_t> flatOutputGroupPositions, expression_vector expressionsToMaterialize,
shared_ptr<LogicalOperator> probeSideChild, shared_ptr<LogicalOperator> buildSideChild)
expression_vector expressionsToMaterialize, shared_ptr<LogicalOperator> probeSideChild,
shared_ptr<LogicalOperator> buildSideChild)
: LogicalOperator{std::move(probeSideChild), std::move(buildSideChild)},
joinNodes(std::move(joinNodes)), joinType{joinType}, mark{std::move(mark)},
isProbeAcc{isProbeAcc},
buildSideSchema(std::move(buildSideSchema)), flatOutputGroupPositions{std::move(
flatOutputGroupPositions)},
expressionsToMaterialize{std::move(expressionsToMaterialize)} {}
buildSideSchema(std::move(buildSideSchema)), expressionsToMaterialize{
std::move(expressionsToMaterialize)} {}

inline LogicalOperatorType getLogicalOperatorType() const override {
return LogicalOperatorType::LOGICAL_HASH_JOIN;
Expand All @@ -67,12 +64,11 @@ class LogicalHashJoin : public LogicalOperator {
}
inline bool getIsProbeAcc() const { return isProbeAcc; }
inline Schema* getBuildSideSchema() const { return buildSideSchema.get(); }
inline vector<uint64_t> getFlatOutputGroupPositions() const { return flatOutputGroupPositions; }

inline unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalHashJoin>(joinNodes, joinType, mark, isProbeAcc,
buildSideSchema->copy(), flatOutputGroupPositions, expressionsToMaterialize,
children[0]->copy(), children[1]->copy());
buildSideSchema->copy(), expressionsToMaterialize, children[0]->copy(),
children[1]->copy());
}

private:
Expand All @@ -81,8 +77,6 @@ class LogicalHashJoin : public LogicalOperator {
shared_ptr<Expression> mark; // when joinType is Mark
bool isProbeAcc;
unique_ptr<Schema> buildSideSchema;
// TODO(Xiyang): solve this with issue 606
vector<uint64_t> flatOutputGroupPositions;
expression_vector expressionsToMaterialize;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class LogicalIndexScanNode : public LogicalScanNode {
inline void computeSchema(Schema& schema) override {
LogicalScanNode::computeSchema(schema);
auto groupPos = schema.getGroupPos(node->getInternalIDPropertyName());
schema.getGroup(groupPos)->setIsFlat(true);
schema.setGroupAsSingleState(groupPos);
}

inline shared_ptr<Expression> getIndexExpression() const { return indexExpression; }
Expand Down
25 changes: 18 additions & 7 deletions src/include/planner/logical_plan/logical_operator/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@ class FactorizationGroup {
friend class Schema;

public:
FactorizationGroup() : isFlat{false}, cardinalityMultiplier{1} {}
FactorizationGroup() : flat{false}, singleState{false}, cardinalityMultiplier{1} {}
FactorizationGroup(const FactorizationGroup& other)
: isFlat{other.isFlat}, cardinalityMultiplier{other.cardinalityMultiplier},
expressions{other.expressions} {}
: flat{other.flat}, singleState{other.singleState},
cardinalityMultiplier{other.cardinalityMultiplier}, expressions{other.expressions} {}

inline void setIsFlat(bool flag) { isFlat = flag; }
inline bool getIsFlat() const { return isFlat; }
inline void setFlat() {
assert(!flat);
flat = true;
}
inline bool isFlat() const { return flat; }
inline void setSingleState() {
assert(!singleState);
singleState = true;
setFlat();
}
inline bool isSingleState() const { return singleState; }

inline void setMultiplier(uint64_t multiplier) { cardinalityMultiplier = multiplier; }
inline uint64_t getMultiplier() const { return cardinalityMultiplier; }
Expand All @@ -31,7 +40,8 @@ class FactorizationGroup {
inline expression_vector getExpressions() const { return expressions; }

private:
bool isFlat;
bool flat;
bool singleState;
uint64_t cardinalityMultiplier;
expression_vector expressions;
};
Expand Down Expand Up @@ -67,7 +77,8 @@ class Schema {
return expressionNameToGroupPos.at(expressionName);
}

inline void flattenGroup(uint32_t pos) { groups[pos]->isFlat = true; }
inline void flattenGroup(uint32_t pos) { groups[pos]->setFlat(); }
inline void setGroupAsSingleState(uint32_t pos) { groups[pos]->setSingleState(); }

bool isExpressionInScope(const Expression& expression) const;

Expand Down
22 changes: 9 additions & 13 deletions src/include/processor/operator/cross_product.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,20 @@ class CrossProduct : public PhysicalOperator {
public:
CrossProduct(shared_ptr<FTableSharedState> sharedState,
vector<pair<DataPos, DataType>> outVecPosAndTypePairs, vector<uint32_t> colIndicesToScan,
vector<uint64_t> flatDataChunkPositions, unique_ptr<PhysicalOperator> probeChild,
unique_ptr<PhysicalOperator> buildChild, uint32_t id, const string& paramsString)
unique_ptr<PhysicalOperator> probeChild, unique_ptr<PhysicalOperator> buildChild,
uint32_t id, const string& paramsString)
: PhysicalOperator{std::move(probeChild), std::move(buildChild), id, paramsString},
sharedState{std::move(sharedState)}, outVecPosAndTypePairs{std::move(
outVecPosAndTypePairs)},
colIndicesToScan{std::move(colIndicesToScan)}, flatDataChunkPositions{
std::move(flatDataChunkPositions)} {}
sharedState{std::move(sharedState)},
outVecPosAndTypePairs{std::move(outVecPosAndTypePairs)}, colIndicesToScan{std::move(
colIndicesToScan)} {}

// Clone only.
CrossProduct(shared_ptr<FTableSharedState> sharedState,
vector<pair<DataPos, DataType>> outVecPosAndTypePairs, vector<uint32_t> colIndicesToScan,
vector<uint64_t> flatDataChunkPositions, unique_ptr<PhysicalOperator> child, uint32_t id,
const string& paramsString)
unique_ptr<PhysicalOperator> child, uint32_t id, const string& paramsString)
: PhysicalOperator{std::move(child), id, paramsString}, sharedState{std::move(sharedState)},
outVecPosAndTypePairs{std::move(outVecPosAndTypePairs)},
colIndicesToScan{std::move(colIndicesToScan)}, flatDataChunkPositions{
std::move(flatDataChunkPositions)} {}
outVecPosAndTypePairs{std::move(outVecPosAndTypePairs)}, colIndicesToScan{std::move(
colIndicesToScan)} {}

PhysicalOperatorType getOperatorType() override { return PhysicalOperatorType::CROSS_PRODUCT; }

Expand All @@ -35,14 +32,13 @@ class CrossProduct : public PhysicalOperator {

unique_ptr<PhysicalOperator> clone() override {
return make_unique<CrossProduct>(sharedState, outVecPosAndTypePairs, colIndicesToScan,
flatDataChunkPositions, children[0]->clone(), id, paramsString);
children[0]->clone(), id, paramsString);
}

private:
shared_ptr<FTableSharedState> sharedState;
vector<pair<DataPos, DataType>> outVecPosAndTypePairs;
vector<uint32_t> colIndicesToScan;
vector<uint64_t> flatDataChunkPositions;

uint64_t startIdx = 0u;
vector<shared_ptr<ValueVector>> vectorsToScan;
Expand Down
20 changes: 8 additions & 12 deletions src/include/processor/operator/hash_join/hash_join_probe.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,20 @@ struct ProbeDataInfo {
class HashJoinProbe : public PhysicalOperator, FilteringOperator {
public:
HashJoinProbe(shared_ptr<HashJoinSharedState> sharedState, JoinType joinType,
vector<uint64_t> flatDataChunkPositions, const ProbeDataInfo& probeDataInfo,
unique_ptr<PhysicalOperator> probeChild, unique_ptr<PhysicalOperator> buildChild,
uint32_t id, const string& paramsString)
const ProbeDataInfo& probeDataInfo, unique_ptr<PhysicalOperator> probeChild,
unique_ptr<PhysicalOperator> buildChild, uint32_t id, const string& paramsString)
: PhysicalOperator{std::move(probeChild), std::move(buildChild), id, paramsString},
FilteringOperator{probeDataInfo.keysDataPos.size()},
sharedState{std::move(sharedState)}, joinType{joinType},
flatDataChunkPositions{std::move(flatDataChunkPositions)}, probeDataInfo{probeDataInfo} {}
sharedState{std::move(sharedState)}, joinType{joinType}, probeDataInfo{probeDataInfo} {}

// This constructor is used for cloning only.
// HashJoinProbe do not need to clone hashJoinBuild which is on a different pipeline.
HashJoinProbe(shared_ptr<HashJoinSharedState> sharedState, JoinType joinType,
vector<uint64_t> flatDataChunkPositions, const ProbeDataInfo& probeDataInfo,
unique_ptr<PhysicalOperator> probeChild, uint32_t id, const string& paramsString)
const ProbeDataInfo& probeDataInfo, unique_ptr<PhysicalOperator> probeChild, uint32_t id,
const string& paramsString)
: PhysicalOperator{std::move(probeChild), id, paramsString},
FilteringOperator{probeDataInfo.keysDataPos.size()},
sharedState{std::move(sharedState)}, joinType{joinType},
flatDataChunkPositions{std::move(flatDataChunkPositions)}, probeDataInfo{probeDataInfo} {}
sharedState{std::move(sharedState)}, joinType{joinType}, probeDataInfo{probeDataInfo} {}

inline PhysicalOperatorType getOperatorType() override { return HASH_JOIN_PROBE; }

Expand All @@ -76,8 +73,8 @@ class HashJoinProbe : public PhysicalOperator, FilteringOperator {
bool getNextTuplesInternal() override;

inline unique_ptr<PhysicalOperator> clone() override {
return make_unique<HashJoinProbe>(sharedState, joinType, flatDataChunkPositions,
probeDataInfo, children[0]->clone(), id, paramsString);
return make_unique<HashJoinProbe>(
sharedState, joinType, probeDataInfo, children[0]->clone(), id, paramsString);
}

private:
Expand All @@ -93,7 +90,6 @@ class HashJoinProbe : public PhysicalOperator, FilteringOperator {
private:
shared_ptr<HashJoinSharedState> sharedState;
JoinType joinType;
vector<uint64_t> flatDataChunkPositions;

ProbeDataInfo probeDataInfo;
vector<shared_ptr<ValueVector>> vectorsToReadInto;
Expand Down
6 changes: 5 additions & 1 deletion src/include/processor/operator/source_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ class SourceOperator {
auto resultSet = make_shared<ResultSet>(numDataChunks);
for (auto i = 0u; i < numDataChunks; ++i) {
auto dataChunkDescriptor = resultSetDescriptor->getDataChunkDescriptor(i);
resultSet->insert(i, make_shared<DataChunk>(dataChunkDescriptor->getNumValueVectors()));
auto dataChunk = make_shared<DataChunk>(dataChunkDescriptor->getNumValueVectors());
if (dataChunkDescriptor->isSingleState()) {
dataChunk->state = DataChunkState::getSingleValueDataChunkState();
}
resultSet->insert(i, dataChunk);
}
return resultSet;
}
Expand Down
Loading

0 comments on commit 9e057b5

Please sign in to comment.