Skip to content

Commit

Permalink
clean
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jun 9, 2023
1 parent 729328e commit 526176f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 27 deletions.
48 changes: 34 additions & 14 deletions src/include/processor/operator/cross_product.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,60 @@
namespace kuzu {
namespace processor {

class CrossProduct;

// Scan state used by a CrossProduct operator: the table to scan, the maximum number of tuples
// to scan per batch, and the index of the next tuple to scan. CrossProduct is a friend and
// reads/writes these members directly.
class CrossProductLocalState {
    friend class CrossProduct;

public:
    // startIdx is left at its default member initializer (0) until init() is called.
    CrossProductLocalState(std::shared_ptr<FactorizedTable> table, uint64_t maxMorselSize)
        : table{std::move(table)}, maxMorselSize{maxMorselSize} {}

    // Moves the cursor to the end of the table, i.e. startIdx == table->getNumTuples().
    void init() { startIdx = table->getNumTuples(); }

    // Returns a new state that refers to the same table and batch size; its cursor starts at 0
    // and is independent of this state's cursor.
    std::unique_ptr<CrossProductLocalState> copy() const {
        auto duplicate = std::make_unique<CrossProductLocalState>(table, maxMorselSize);
        return duplicate;
    }

private:
    std::shared_ptr<FactorizedTable> table; // shared with every state produced by copy()
    uint64_t maxMorselSize;                 // upper bound on tuples scanned per batch
    uint64_t startIdx = 0u;                 // index of the next tuple to scan from table
};

// Physical operator of type CROSS_PRODUCT. getNextTuplesInternal() pulls tuples from children[0]
// and, after each pull, scans the local state's table in batches into the value vectors located
// at outVecPos.
class CrossProduct : public PhysicalOperator {
public:
// Constructs the operator with both a probe child and a build child.
CrossProduct(std::shared_ptr<FTableSharedState> sharedState, std::vector<DataPos> outVecPos,
std::vector<uint32_t> colIndicesToScan, std::unique_ptr<PhysicalOperator> probeChild,
std::unique_ptr<PhysicalOperator> buildChild, uint32_t id, const std::string& paramsString)
CrossProduct(std::vector<DataPos> outVecPos, std::vector<uint32_t> colIndicesToScan,
std::unique_ptr<CrossProductLocalState> localState,
std::unique_ptr<PhysicalOperator> probeChild, std::unique_ptr<PhysicalOperator> buildChild,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(probeChild),
std::move(buildChild), id, paramsString},
sharedState{std::move(sharedState)}, outVecPos{std::move(outVecPos)},
colIndicesToScan{std::move(colIndicesToScan)} {}
outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)},
localState{std::move(localState)} {}

// Clone only.
// Takes a single child; clone() passes children[0]->clone() to this overload.
CrossProduct(std::shared_ptr<FTableSharedState> sharedState, std::vector<DataPos> outVecPos,
std::vector<uint32_t> colIndicesToScan, std::unique_ptr<PhysicalOperator> child,
CrossProduct(std::vector<DataPos> outVecPos, std::vector<uint32_t> colIndicesToScan,
std::unique_ptr<CrossProductLocalState> localState, std::unique_ptr<PhysicalOperator> child,
uint32_t id, const std::string& paramsString)
: PhysicalOperator{PhysicalOperatorType::CROSS_PRODUCT, std::move(child), id, paramsString},
sharedState{std::move(sharedState)}, outVecPos{std::move(outVecPos)},
colIndicesToScan{std::move(colIndicesToScan)} {}
outVecPos{std::move(outVecPos)}, colIndicesToScan{std::move(colIndicesToScan)},
localState{std::move(localState)} {}

// Collects the value vectors at outVecPos into vectorsToScan and initializes the scan state.
void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

// Scans the next batch from the table; returns false when the table is empty or children[0]
// has no more tuples.
bool getNextTuplesInternal(ExecutionContext* context) override;

// Builds a new CrossProduct from a clone of children[0] and a copy() of the local state, so
// the clone gets its own scan cursor.
std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<CrossProduct>(
sharedState, outVecPos, colIndicesToScan, children[0]->clone(), id, paramsString);
return std::make_unique<CrossProduct>(outVecPos, colIndicesToScan, localState->copy(),
children[0]->clone(), id, paramsString);
}

private:
// TODO(Xiyang): make this a local state
std::shared_ptr<FTableSharedState> sharedState;
// Positions used to look up the output value vectors in the result set.
std::vector<DataPos> outVecPos;
// Column indices handed to table->scan().
std::vector<uint32_t> colIndicesToScan;

uint64_t startIdx = 0u;
// Table handle, per-batch tuple limit and scan cursor (see CrossProductLocalState).
std::unique_ptr<CrossProductLocalState> localState;

// Raw pointers to the value vectors at outVecPos; filled in initLocalStateInternal().
std::vector<common::ValueVector*> vectorsToScan;
};

Expand Down
7 changes: 5 additions & 2 deletions src/processor/mapper/map_cross_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalCrossProductToPhysical(
outVecPos.emplace_back(outSchema->getExpressionPos(*expression));
colIndicesToScan.push_back(i);
}
return make_unique<CrossProduct>(resultCollector->getSharedState(), std::move(outVecPos),
std::move(colIndicesToScan), std::move(probeSidePrevOperator), std::move(resultCollector),
auto sharedState = resultCollector->getSharedState();
auto localState = std::make_unique<CrossProductLocalState>(
sharedState->getTable(), sharedState->getMaxMorselSize());
return make_unique<CrossProduct>(std::move(outVecPos), std::move(colIndicesToScan),
std::move(localState), std::move(probeSidePrevOperator), std::move(resultCollector),
getOperatorID(), logicalCrossProduct->getExpressionsForPrinting());
}

Expand Down
21 changes: 10 additions & 11 deletions src/processor/operator/cross_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,30 @@ namespace kuzu {
namespace processor {

// Looks up the value vector for every position in outVecPos, stores raw pointers to them in
// vectorsToScan, and primes the scan cursor.
void CrossProduct::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
for (auto pos : outVecPos) {
auto vector = resultSet->getValueVector(pos);
vectorsToScan.push_back(vector.get());
for (auto& pos : outVecPos) {
vectorsToScan.push_back(resultSet->getValueVector(pos).get());
}
startIdx = sharedState->getTable()->getNumTuples();
// init() sets startIdx to the table's tuple count, so the first getNextTuplesInternal() call
// takes the "no more to scan from right" branch and fetches a left tuple before scanning.
localState->init();
}

// Scans the next batch of at most maxMorselSize tuples from the table into vectorsToScan.
// When the cursor has reached the end of the table, a new tuple is first fetched from
// children[0] and the cursor is reset to 0.
// Returns false if the table is empty or children[0] has no more tuples; true otherwise.
bool CrossProduct::getNextTuplesInternal(ExecutionContext* context) {
// Note: we should NOT morselize right table scanning (i.e. calling sharedState.getMorsel)
// because every thread should scan its own table.
auto table = sharedState->getTable();
auto table = localState->table.get();
// An empty table yields no output, so children[0] is never pulled.
if (table->getNumTuples() == 0) {
return false;
}
if (startIdx == table->getNumTuples()) { // no more to scan from right
if (!children[0]->getNextTuple(context)) { // fetch a new left tuple
if (localState->startIdx == table->getNumTuples()) { // no more to scan from right
if (!children[0]->getNextTuple(context)) { // fetch a new left tuple
return false;
}
startIdx = 0; // reset right table scanning for a new left tuple
localState->startIdx = 0; // reset right table scanning for a new left tuple
}
// scan from right table if there is tuple left
// The batch is capped both by maxMorselSize and by the tuples remaining after startIdx.
auto numTuplesToScan =
std::min(sharedState->getMaxMorselSize(), table->getNumTuples() - startIdx);
table->scan(vectorsToScan, startIdx, numTuplesToScan, colIndicesToScan);
startIdx += numTuplesToScan;
std::min(localState->maxMorselSize, table->getNumTuples() - localState->startIdx);
table->scan(vectorsToScan, localState->startIdx, numTuplesToScan, colIndicesToScan);
localState->startIdx += numTuplesToScan;
// Record the scanned tuples in this operator's output-tuple metric.
metrics->numOutputTuple.increase(numTuplesToScan);
return true;
}
Expand Down

0 comments on commit 526176f

Please sign in to comment.