Skip to content

Commit

Permalink
Merge pull request #1658 from kuzudb/ftable-scan-rework
Browse files Browse the repository at this point in the history
Append FTableScan before RecursiveJoin
  • Loading branch information
andyfengHKU committed Jun 9, 2023
2 parents 976d035 + 0ffde6f commit 43c7026
Show file tree
Hide file tree
Showing 21 changed files with 222 additions and 243 deletions.
47 changes: 34 additions & 13 deletions src/include/processor/operator/cross_product.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,60 @@
namespace kuzu {
namespace processor {

class CrossProduct;
// Scan state owned by a single CrossProduct operator instance.
// Holds a shared handle to the factorized table being scanned, the maximum morsel size, and
// the current scan cursor (startIdx). CrossProduct is a friend and reads/writes these members
// directly; maxMorselSize is only stored here, not used by this class itself.
class CrossProductLocalState {
    friend class CrossProduct;

public:
    // Takes shared ownership of `table` and records `maxMorselSize`.
    // The scan cursor starts at 0 through its default member initializer.
    CrossProductLocalState(std::shared_ptr<FactorizedTable> table, uint64_t maxMorselSize)
        : table{std::move(table)}, maxMorselSize{maxMorselSize} {}

    // Positions the scan cursor at the table's current tuple count.
    // NOTE(review): presumably this marks the table as "fully scanned" so the operator fetches
    // a probe-side tuple before its first scan — confirm against
    // CrossProduct::getNextTuplesInternal.
    void init() {
        const auto numTuplesInTable = table->getNumTuples();
        startIdx = numTuplesInTable;
    }

    // Returns a new state that shares the same table and morsel size, with its cursor back at 0.
    std::unique_ptr<CrossProductLocalState> copy() const {
        auto duplicate = std::make_unique<CrossProductLocalState>(table, maxMorselSize);
        return duplicate;
    }

private:
    std::shared_ptr<FactorizedTable> table; // table scanned by the owning operator
    uint64_t maxMorselSize;                 // upper bound on tuples per scan step
    uint64_t startIdx = 0u;                 // index of the next tuple to scan
};

class CrossProduct : public PhysicalOperator {
public:
CrossProduct(std::shared_ptr<FTableSharedState> sharedState, std::vector<DataPos> outVecPos,
std::vector<uint32_t> colIndicesToScan, std::unique_ptr<PhysicalOperator> probeChild,
std::unique_ptr<PhysicalOperator> buildChild, uint32_t id, const std::string& paramsString)
CrossProduct(std::vector<DataPos> outVecPos, std::vector<uint32_t> colIndicesToScan,
std::unique_ptr<CrossProductLocalState> localState,
std::unique_ptr<PhysicalOperator> probeChild, std::unique_ptr<PhysicalOperator> buildChild,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(probeChild),
std::move(buildChild), id, paramsString},
sharedState{std::move(sharedState)}, outVecPos{std::move(outVecPos)},
colIndicesToScan{std::move(colIndicesToScan)} {}
outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)},
localState{std::move(localState)} {}

// Clone only.
CrossProduct(std::shared_ptr<FTableSharedState> sharedState, std::vector<DataPos> outVecPos,
std::vector<uint32_t> colIndicesToScan, std::unique_ptr<PhysicalOperator> child,
CrossProduct(std::vector<DataPos> outVecPos, std::vector<uint32_t> colIndicesToScan,
std::unique_ptr<CrossProductLocalState> localState, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(child), id, paramsString},
sharedState{std::move(sharedState)}, outVecPos{std::move(outVecPos)},
colIndicesToScan{std::move(colIndicesToScan)} {}
outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)},
localState{std::move(localState)} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal(ExecutionContext* context) override;

std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<CrossProduct>(
sharedState, outVecPos, colIndicesToScan, children[0]->clone(), id, paramsString);
return std::make_unique<CrossProduct>(outVecPos, colIndicesToScan, localState->copy(),
children[0]->clone(), id, paramsString);
}

private:
std::shared_ptr<FTableSharedState> sharedState;
std::vector<DataPos> outVecPos;
std::vector<uint32_t> colIndicesToScan;

uint64_t startIdx = 0u;
std::unique_ptr<CrossProductLocalState> localState;

std::vector<common::ValueVector*> vectorsToScan;
};

Expand Down
59 changes: 18 additions & 41 deletions src/include/processor/operator/recursive_extend/recursive_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,13 @@ namespace processor {
class ScanFrontier;

struct RecursiveJoinSharedState {
std::shared_ptr<FTableSharedState> inputFTableSharedState;
std::vector<std::unique_ptr<NodeOffsetSemiMask>> semiMasks;

RecursiveJoinSharedState(std::shared_ptr<FTableSharedState> inputFTableSharedState,
const std::vector<storage::NodeTable*>& nodeTables)
: inputFTableSharedState{std::move(inputFTableSharedState)} {
for (auto nodeTable : nodeTables) {
semiMasks.push_back(std::make_unique<NodeOffsetSemiMask>(nodeTable));
}
}
RecursiveJoinSharedState(std::vector<std::unique_ptr<NodeOffsetSemiMask>> semiMasks)
: semiMasks{std::move(semiMasks)} {}
};

struct RecursiveJoinDataInfo {
// Scanning input table info.
std::vector<DataPos> vectorsToScanPos;
std::vector<ft_col_idx_t> colIndicesToScan;
// Join input info.
DataPos srcNodePos;
// Join output info.
Expand All @@ -38,43 +29,37 @@ struct RecursiveJoinDataInfo {
// Recursive join info.
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor;
DataPos recursiveDstNodeIDPos;
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs; // TODO: move this out?
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs;
DataPos recursiveEdgeIDPos;
// Path info
DataPos pathPos;

RecursiveJoinDataInfo(std::vector<DataPos> vectorsToScanPos,
std::vector<ft_col_idx_t> colIndicesToScan, const DataPos& srcNodePos,
const DataPos& dstNodePos, std::unordered_set<common::table_id_t> dstNodeTableIDs,
const DataPos& pathLengthPos, std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
RecursiveJoinDataInfo(const DataPos& srcNodePos, const DataPos& dstNodePos,
std::unordered_set<common::table_id_t> dstNodeTableIDs, const DataPos& pathLengthPos,
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos)
: RecursiveJoinDataInfo{std::move(vectorsToScanPos), std::move(colIndicesToScan),
srcNodePos, dstNodePos, std::move(dstNodeTableIDs), pathLengthPos,
: RecursiveJoinDataInfo{srcNodePos, dstNodePos, std::move(dstNodeTableIDs), pathLengthPos,
std::move(localResultSetDescriptor), recursiveDstNodeIDPos,
std::move(recursiveDstNodeTableIDs), recursiveEdgeIDPos, DataPos()} {}
RecursiveJoinDataInfo(std::vector<DataPos> vectorsToScanPos,
std::vector<ft_col_idx_t> colIndicesToScan, const DataPos& srcNodePos,
const DataPos& dstNodePos, std::unordered_set<common::table_id_t> dstNodeTableIDs,
const DataPos& pathLengthPos, std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
RecursiveJoinDataInfo(const DataPos& srcNodePos, const DataPos& dstNodePos,
std::unordered_set<common::table_id_t> dstNodeTableIDs, const DataPos& pathLengthPos,
std::unique_ptr<ResultSetDescriptor> localResultSetDescriptor,
const DataPos& recursiveDstNodeIDPos,
std::unordered_set<common::table_id_t> recursiveDstNodeTableIDs,
const DataPos& recursiveEdgeIDPos, const DataPos& pathPos)
: vectorsToScanPos{std::move(vectorsToScanPos)},
colIndicesToScan{std::move(colIndicesToScan)}, srcNodePos{srcNodePos},
dstNodePos{dstNodePos}, dstNodeTableIDs{std::move(dstNodeTableIDs)},
pathLengthPos{pathLengthPos}, localResultSetDescriptor{std::move(
localResultSetDescriptor)},
: srcNodePos{srcNodePos}, dstNodePos{dstNodePos},
dstNodeTableIDs{std::move(dstNodeTableIDs)}, pathLengthPos{pathLengthPos},
localResultSetDescriptor{std::move(localResultSetDescriptor)},
recursiveDstNodeIDPos{recursiveDstNodeIDPos}, recursiveDstNodeTableIDs{std::move(
recursiveDstNodeTableIDs)},
recursiveEdgeIDPos{recursiveEdgeIDPos}, pathPos{pathPos} {}

inline std::unique_ptr<RecursiveJoinDataInfo> copy() {
return std::make_unique<RecursiveJoinDataInfo>(vectorsToScanPos, colIndicesToScan,
srcNodePos, dstNodePos, dstNodeTableIDs, pathLengthPos,
localResultSetDescriptor->copy(), recursiveDstNodeIDPos, recursiveDstNodeTableIDs,
recursiveEdgeIDPos, pathPos);
return std::make_unique<RecursiveJoinDataInfo>(srcNodePos, dstNodePos, dstNodeTableIDs,
pathLengthPos, localResultSetDescriptor->copy(), recursiveDstNodeIDPos,
recursiveDstNodeTableIDs, recursiveEdgeIDPos, pathPos);
}
};

Expand All @@ -91,15 +76,6 @@ class RecursiveJoin : public PhysicalOperator {
joinType{joinType}, sharedState{std::move(sharedState)}, dataInfo{std::move(dataInfo)},
recursiveRoot{std::move(recursiveRoot)} {}

RecursiveJoin(uint8_t lowerBound, uint8_t upperBound, common::QueryRelType queryRelType,
planner::RecursiveJoinType joinType, std::shared_ptr<RecursiveJoinSharedState> sharedState,
std::unique_ptr<RecursiveJoinDataInfo> dataInfo, uint32_t id,
const std::string& paramsString, std::unique_ptr<PhysicalOperator> recursiveRoot)
: PhysicalOperator{PhysicalOperatorType::RECURSIVE_JOIN, id, paramsString},
lowerBound{lowerBound}, upperBound{upperBound}, queryRelType{queryRelType},
joinType{joinType}, sharedState{std::move(sharedState)}, dataInfo{std::move(dataInfo)},
recursiveRoot{std::move(recursiveRoot)} {}

inline RecursiveJoinSharedState* getSharedState() const { return sharedState.get(); }

void initLocalStateInternal(ResultSet* resultSet_, ExecutionContext* context) final;
Expand All @@ -108,7 +84,8 @@ class RecursiveJoin : public PhysicalOperator {

inline std::unique_ptr<PhysicalOperator> clone() final {
return std::make_unique<RecursiveJoin>(lowerBound, upperBound, queryRelType, joinType,
sharedState, dataInfo->copy(), id, paramsString, recursiveRoot->clone());
sharedState, dataInfo->copy(), children[0]->clone(), id, paramsString,
recursiveRoot->clone());
}

private:
Expand Down
53 changes: 26 additions & 27 deletions src/include/processor/operator/result_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ namespace kuzu {
namespace processor {

struct FTableScanMorsel {

FTableScanMorsel(FactorizedTable* table, uint64_t startTupleIdx, uint64_t numTuples)
: table{table}, startTupleIdx{startTupleIdx}, numTuples{numTuples} {}

Expand All @@ -18,8 +17,18 @@ struct FTableScanMorsel {

class FTableSharedState {
public:
void initTable(
storage::MemoryManager* memoryManager, std::unique_ptr<FactorizedTableSchema> tableSchema);
FTableSharedState(std::shared_ptr<FactorizedTable> table, uint64_t maxMorselSize)
: table{std::move(table)}, maxMorselSize{maxMorselSize} {}
FTableSharedState(storage::MemoryManager* memoryManager,
std::unique_ptr<FactorizedTableSchema> tableSchema, uint64_t maxMorselSize)
: maxMorselSize{maxMorselSize}, nextTupleIdxToScan{0} {
table = std::make_shared<FactorizedTable>(memoryManager, std::move(tableSchema));
}

// We want to control the morsel granularity; e.g., in a recursive join pipeline, we always
// want to scan one src at a time.
inline void setMaxMorselSize(uint64_t size) { maxMorselSize = size; }
inline uint64_t getMaxMorselSize() const { return maxMorselSize; }

inline void mergeLocalTable(FactorizedTable& localTable) {
std::lock_guard<std::mutex> lck{mtx};
Expand All @@ -28,55 +37,45 @@ class FTableSharedState {

inline std::shared_ptr<FactorizedTable> getTable() { return table; }

inline void setTable(std::shared_ptr<FactorizedTable> other) { table = other; }

inline uint64_t getMaxMorselSize() {
std::lock_guard<std::mutex> lck{mtx};
return table->hasUnflatCol() ? 1 : common::DEFAULT_VECTOR_CAPACITY;
}
std::unique_ptr<FTableScanMorsel> getMorsel(uint64_t maxMorselSize);
std::unique_ptr<FTableScanMorsel> getMorsel();

private:
std::mutex mtx;
std::shared_ptr<FactorizedTable> table;
uint64_t maxMorselSize;

uint64_t nextTupleIdxToScan = 0u;
};

class ResultCollector : public Sink {
public:
ResultCollector(std::unique_ptr<ResultSetDescriptor> resultSetDescriptor,
std::vector<std::pair<DataPos, common::LogicalType>> payloadsPosAndType,
std::vector<bool> isPayloadFlat, std::shared_ptr<FTableSharedState> sharedState,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
std::unique_ptr<FactorizedTableSchema> tableSchema, std::vector<DataPos> payloadsPos,
std::shared_ptr<FTableSharedState> sharedState, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString)
: Sink{std::move(resultSetDescriptor), PhysicalOperatorType::RESULT_COLLECTOR,
std::move(child), id, paramsString},
payloadsPosAndType{std::move(payloadsPosAndType)},
isPayloadFlat{std::move(isPayloadFlat)}, sharedState{std::move(sharedState)} {}
tableSchema{std::move(tableSchema)}, payloadsPos{std::move(payloadsPos)},
sharedState{std::move(sharedState)} {}

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

void executeInternal(ExecutionContext* context) override;

std::unique_ptr<PhysicalOperator> clone() override {
return make_unique<ResultCollector>(resultSetDescriptor->copy(), payloadsPosAndType,
isPayloadFlat, sharedState, children[0]->clone(), id, paramsString);
}

inline std::shared_ptr<FTableSharedState> getSharedState() { return sharedState; }
inline std::shared_ptr<FactorizedTable> getResultFactorizedTable() {
return sharedState->getTable();
}

private:
void initGlobalStateInternal(ExecutionContext* context) override;

std::unique_ptr<FactorizedTableSchema> populateTableSchema();
std::unique_ptr<PhysicalOperator> clone() override {
return make_unique<ResultCollector>(resultSetDescriptor->copy(), tableSchema->copy(),
payloadsPos, sharedState, children[0]->clone(), id, paramsString);
}

private:
std::vector<std::pair<DataPos, common::LogicalType>> payloadsPosAndType;
std::vector<bool> isPayloadFlat;
std::vector<common::ValueVector*> vectorsToCollect;
std::unique_ptr<FactorizedTableSchema> tableSchema;
std::vector<DataPos> payloadsPos;
std::vector<common::ValueVector*> payloadVectors;
std::shared_ptr<FTableSharedState> sharedState;
std::unique_ptr<FactorizedTable> localTable;
};
Expand Down
8 changes: 3 additions & 5 deletions src/include/processor/operator/table_scan/base_table_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ class BaseTableScan : public PhysicalOperator {
BaseTableScan(PhysicalOperatorType operatorType, std::vector<DataPos> outVecPositions,
std::vector<uint32_t> colIndicesToScan, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{operatorType, std::move(child), id, paramsString}, maxMorselSize{0},
: PhysicalOperator{operatorType, std::move(child), id, paramsString},
outVecPositions{std::move(outVecPositions)}, colIndicesToScan{
std::move(colIndicesToScan)} {}

// For factorized table scan of some columns
BaseTableScan(PhysicalOperatorType operatorType, std::vector<DataPos> outVecPositions,
std::vector<uint32_t> colIndicesToScan, uint32_t id, const std::string& paramsString)
: PhysicalOperator{operatorType, id, paramsString}, maxMorselSize{0},
: PhysicalOperator{operatorType, id, paramsString},
outVecPositions{std::move(outVecPositions)}, colIndicesToScan{
std::move(colIndicesToScan)} {}

Expand All @@ -27,21 +27,19 @@ class BaseTableScan : public PhysicalOperator {
std::vector<uint32_t> colIndicesToScan,
std::vector<std::unique_ptr<PhysicalOperator>> children, uint32_t id,
const std::string& paramsString)
: PhysicalOperator{operatorType, std::move(children), id, paramsString}, maxMorselSize{0},
: PhysicalOperator{operatorType, std::move(children), id, paramsString},
outVecPositions{std::move(outVecPositions)}, colIndicesToScan{
std::move(colIndicesToScan)} {}

inline bool isSource() const override { return true; }

virtual void setMaxMorselSize() = 0;
virtual std::unique_ptr<FTableScanMorsel> getMorsel() = 0;

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

bool getNextTuplesInternal(ExecutionContext* context) override;

protected:
uint64_t maxMorselSize;
std::vector<DataPos> outVecPositions;
std::vector<uint32_t> colIndicesToScan;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,11 @@ class FactorizedTableScan : public BaseTableScan {
std::move(colIndicesToScan), id, paramsString},
sharedState{std::move(sharedState)} {}

inline void setSharedState(std::shared_ptr<FTableSharedState> state) {
sharedState = std::move(state);
}
inline void setMaxMorselSize() override { maxMorselSize = sharedState->getMaxMorselSize(); }
inline std::unique_ptr<FTableScanMorsel> getMorsel() override {
return sharedState->getMorsel(maxMorselSize);
return sharedState->getMorsel();
}

inline std::unique_ptr<PhysicalOperator> clone() override {
assert(sharedState != nullptr);
return make_unique<FactorizedTableScan>(
outVecPositions, colIndicesToScan, sharedState, id, paramsString);
}
Expand Down
6 changes: 2 additions & 4 deletions src/include/processor/operator/table_scan/union_all_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ class UnionAllScanSharedState {
std::vector<std::shared_ptr<FTableSharedState>> fTableSharedStates)
: fTableSharedStates{std::move(fTableSharedStates)}, fTableToScanIdx{0} {}

uint64_t getMaxMorselSize() const;
std::unique_ptr<FTableScanMorsel> getMorsel(uint64_t maxMorselSize);
std::unique_ptr<FTableScanMorsel> getMorsel();

private:
std::mutex mtx;
Expand All @@ -39,9 +38,8 @@ class UnionAllScan : public BaseTableScan {
std::move(colIndicesToScan), id, paramsString},
sharedState{std::move(sharedState)} {}

inline void setMaxMorselSize() override { maxMorselSize = sharedState->getMaxMorselSize(); }
inline std::unique_ptr<FTableScanMorsel> getMorsel() override {
return sharedState->getMorsel(maxMorselSize);
return sharedState->getMorsel();
}

std::unique_ptr<PhysicalOperator> clone() override {
Expand Down
4 changes: 4 additions & 0 deletions src/include/processor/result/factorized_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ class FactorizedTableSchema {
bool operator==(const FactorizedTableSchema& other) const;
inline bool operator!=(const FactorizedTableSchema& other) const { return !(*this == other); }

inline std::unique_ptr<FactorizedTableSchema> copy() const {
return std::make_unique<FactorizedTableSchema>(*this);
}

private:
std::vector<std::unique_ptr<ColumnSchema>> columns;
uint32_t numBytesForDataPerTuple = 0;
Expand Down
1 change: 1 addition & 0 deletions src/processor/mapper/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_library(kuzu_processor_mapper
map_multiplicity_reducer.cpp
map_order_by.cpp
map_projection.cpp
map_recursive_extend.cpp
map_scan_frontier.cpp
map_scan_node.cpp
map_scan_node_property.cpp
Expand Down
7 changes: 5 additions & 2 deletions src/processor/mapper/map_cross_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalCrossProductToPhysical(
outVecPos.emplace_back(outSchema->getExpressionPos(*expression));
colIndicesToScan.push_back(i);
}
return make_unique<CrossProduct>(resultCollector->getSharedState(), std::move(outVecPos),
std::move(colIndicesToScan), std::move(probeSidePrevOperator), std::move(resultCollector),
auto sharedState = resultCollector->getSharedState();
auto localState = std::make_unique<CrossProductLocalState>(
sharedState->getTable(), sharedState->getMaxMorselSize());
return make_unique<CrossProduct>(std::move(outVecPos), std::move(colIndicesToScan),
std::move(localState), std::move(probeSidePrevOperator), std::move(resultCollector),
getOperatorID(), logicalCrossProduct->getExpressionsForPrinting());
}

Expand Down
Loading

0 comments on commit 43c7026

Please sign in to comment.