Skip to content

Commit

Permalink
Implemented undirected query for single-label queries
Browse files Browse the repository at this point in the history
  • Loading branch information
aziz-mu committed May 17, 2023
1 parent e9e27a1 commit a63026a
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 58 deletions.
24 changes: 12 additions & 12 deletions src/include/processor/operator/scan/generic_scan_rel_tables.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@
namespace kuzu {
namespace processor {

class RelTableCollection {
class RelTableDataCollection {
public:
RelTableCollection(std::vector<storage::DirectedRelTableData*> tables,
RelTableDataCollection(std::vector<storage::DirectedRelTableData*> relTableDatas,
std::vector<std::unique_ptr<storage::RelTableScanState>> tableScanStates)
: tables{std::move(tables)}, tableScanStates{std::move(tableScanStates)} {}
: relTableDatas{std::move(relTableDatas)}, tableScanStates{std::move(tableScanStates)} {}

void resetState();
inline uint32_t getNumTablesInCollection() { return tables.size(); }
inline uint32_t getNumTablesInCollection() { return relTableDatas.size(); }

bool scan(common::ValueVector* inVector, const std::vector<common::ValueVector*>& outputVectors,
transaction::Transaction* transaction);

std::unique_ptr<RelTableCollection> clone() const;
std::unique_ptr<RelTableDataCollection> clone() const;

private:
std::vector<storage::DirectedRelTableData*> tables;
std::vector<storage::DirectedRelTableData*> relTableDatas;
std::vector<std::unique_ptr<storage::RelTableScanState>> tableScanStates;

uint32_t currentRelTableIdxToScan = UINT32_MAX;
Expand All @@ -32,7 +32,7 @@ class RelTableCollection {
class GenericScanRelTables : public ScanRelTable {
public:
GenericScanRelTables(const DataPos& inNodeIDVectorPos, std::vector<DataPos> outputVectorsPos,
std::unordered_map<common::table_id_t, std::unique_ptr<RelTableCollection>>
std::unordered_map<common::table_id_t, std::unique_ptr<RelTableDataCollection>>
relTableCollectionPerNodeTable,
std::unique_ptr<PhysicalOperator> child, uint32_t id, const std::string& paramsString)
: ScanRelTable{inNodeIDVectorPos, std::move(outputVectorsPos),
Expand All @@ -44,7 +44,7 @@ class GenericScanRelTables : public ScanRelTable {
bool getNextTuplesInternal(ExecutionContext* context) override;

std::unique_ptr<PhysicalOperator> clone() override {
std::unordered_map<common::table_id_t, std::unique_ptr<RelTableCollection>>
std::unordered_map<common::table_id_t, std::unique_ptr<RelTableDataCollection>>
clonedCollections;
for (auto& [tableID, propertyCollection] : relTableCollectionPerNodeTable) {
clonedCollections.insert({tableID, propertyCollection->clone()});
Expand All @@ -54,13 +54,13 @@ class GenericScanRelTables : public ScanRelTable {
}

private:
bool scanCurrentRelTableCollection();
void initCurrentRelTableCollection(const common::nodeID_t& nodeID);
bool scanCurrentRelTableDataCollection();
void initCurrentRelTableDataCollection(const common::nodeID_t& nodeID);

private:
std::unordered_map<common::table_id_t, std::unique_ptr<RelTableCollection>>
std::unordered_map<common::table_id_t, std::unique_ptr<RelTableDataCollection>>
relTableCollectionPerNodeTable;
RelTableCollection* currentRelTableCollection = nullptr;
RelTableDataCollection* currentRelTableDataCollection = nullptr;
};

} // namespace processor
Expand Down
5 changes: 3 additions & 2 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,9 @@ void JoinOrderEnumerator::appendNonRecursiveExtend(std::shared_ptr<NodeExpressio
std::shared_ptr<NodeExpression> nbrNode, std::shared_ptr<RelExpression> rel,
ExtendDirection direction, const expression_vector& properties, LogicalPlan& plan) {
auto hasAtMostOneNbr = extendHasAtMostOneNbrGuarantee(*rel, *boundNode, direction, catalog);
auto extend = make_shared<LogicalExtend>(
boundNode, nbrNode, rel, direction, properties, hasAtMostOneNbr, plan.getLastOperator());
auto extend = make_shared<LogicalExtend>(boundNode, nbrNode, rel,
rel->isDirected() ? direction : ExtendDirection::BOTH, properties, hasAtMostOneNbr,
plan.getLastOperator());
queryPlanner->appendFlattens(extend->getGroupsPosToFlatten(), plan);
extend->setChild(0, plan.getLastOperator());
extend->computeFactorizedSchema();
Expand Down
102 changes: 75 additions & 27 deletions src/processor/mapper/map_extend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,77 @@ static std::vector<property_id_t> populatePropertyIds(
return outputColumns;
}

static std::unique_ptr<RelTableCollection> populateRelTableCollection(table_id_t boundNodeTableID,
const RelExpression& rel, RelDataDirection direction, const expression_vector& properties,
const RelsStore& relsStore, const catalog::Catalog& catalog) {
std::vector<DirectedRelTableData*> tables;
static std::pair<DirectedRelTableData*, std::unique_ptr<RelTableScanState>>
getRelTableDataAndScanState(RelDataDirection direction, catalog::RelTableSchema* relTableSchema,
table_id_t boundNodeTableID, const RelsStore& relsStore, table_id_t relTableID,
const expression_vector& properties) {
if (relTableSchema->getBoundTableID(direction) != boundNodeTableID) {
// No data stored for given direction and boundNode.
return std::make_pair(nullptr, nullptr);
}
auto relData = relsStore.getRelTable(relTableID)->getDirectedTableData(direction);
std::vector<property_id_t> propertyIds;
for (auto& property : properties) {
auto propertyExpression = reinterpret_cast<PropertyExpression*>(property.get());
propertyIds.push_back(propertyExpression->hasPropertyID(relTableID) ?
propertyExpression->getPropertyID(relTableID) :
INVALID_PROPERTY_ID);
}
auto scanState = make_unique<RelTableScanState>(
std::move(propertyIds), relsStore.isSingleMultiplicityInDirection(direction, relTableID) ?
RelTableDataType::COLUMNS :
RelTableDataType::LISTS);
return std::make_pair(relData, std::move(scanState));
}

static std::unique_ptr<RelTableDataCollection> populateRelTableDataCollection(
table_id_t boundNodeTableID, const RelExpression& rel, ExtendDirection extendDirection,
const expression_vector& properties, const RelsStore& relsStore,
const catalog::Catalog& catalog) {
std::vector<DirectedRelTableData*> relTableDatas;
std::vector<std::unique_ptr<RelTableScanState>> tableScanStates;
for (auto relTableID : rel.getTableIDs()) {
auto relTableSchema = catalog.getReadOnlyVersion()->getRelTableSchema(relTableID);
if (relTableSchema->getBoundTableID(direction) != boundNodeTableID) {
continue;
switch (extendDirection) {
case ExtendDirection::FWD: {
auto [relTableData, scanState] = getRelTableDataAndScanState(RelDataDirection::FWD,
relTableSchema, boundNodeTableID, relsStore, relTableID, properties);
if (relTableData != nullptr && scanState != nullptr) {
relTableDatas.push_back(relTableData);
tableScanStates.push_back(std::move(scanState));
}
break;
}
case ExtendDirection::BWD: {
auto [relTableData, scanState] = getRelTableDataAndScanState(RelDataDirection::BWD,
relTableSchema, boundNodeTableID, relsStore, relTableID, properties);
if (relTableData != nullptr && scanState != nullptr) {
relTableDatas.push_back(relTableData);
tableScanStates.push_back(std::move(scanState));
}
break;
}
case ExtendDirection::BOTH: {
auto [relTableDataFWD, scanStateFWD] =
getRelTableDataAndScanState(RelDataDirection::FWD, relTableSchema, boundNodeTableID,
relsStore, relTableID, properties);
if (relTableDataFWD != nullptr && scanStateFWD != nullptr) {
relTableDatas.push_back(relTableDataFWD);
tableScanStates.push_back(std::move(scanStateFWD));
}
auto [relTableDataBWD, scanStateBWD] =
getRelTableDataAndScanState(RelDataDirection::BWD, relTableSchema, boundNodeTableID,
relsStore, relTableID, properties);
if (relTableDataBWD != nullptr && scanStateBWD != nullptr) {
relTableDatas.push_back(relTableDataBWD);
tableScanStates.push_back(std::move(scanStateBWD));
}
break;
}
tables.push_back(relsStore.getRelTable(relTableID)->getDirectedTableData(direction));
std::vector<property_id_t> propertyIds;
for (auto& property : properties) {
auto propertyExpression = reinterpret_cast<PropertyExpression*>(property.get());
propertyIds.push_back(propertyExpression->hasPropertyID(relTableID) ?
propertyExpression->getPropertyID(relTableID) :
INVALID_PROPERTY_ID);
}
tableScanStates.push_back(make_unique<RelTableScanState>(std::move(propertyIds),
relsStore.isSingleMultiplicityInDirection(direction, relTableID) ?
RelTableDataType::COLUMNS :
RelTableDataType::LISTS));
}
return std::make_unique<RelTableCollection>(std::move(tables), std::move(tableScanStates));
return std::make_unique<RelTableDataCollection>(
std::move(relTableDatas), std::move(tableScanStates));
}

std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
Expand All @@ -64,7 +111,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
auto boundNode = extend->getBoundNode();
auto nbrNode = extend->getNbrNode();
auto rel = extend->getRel();
auto direction = getRelDataDirection(extend->getDirection());
auto extendDirection = extend->getDirection();
auto prevOperator = mapLogicalOperatorToPhysical(logicalOperator->getChild(0));
auto inNodeIDVectorPos =
DataPos(inSchema->getExpressionPos(*boundNode->getInternalIDProperty()));
Expand All @@ -76,28 +123,29 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapLogicalExtendToPhysical(
outputVectorsPos.emplace_back(outSchema->getExpressionPos(*expression));
}
auto& relsStore = storageManager.getRelsStore();
if (!rel->isMultiLabeled() && !boundNode->isMultiLabeled()) {
if (!rel->isMultiLabeled() && !boundNode->isMultiLabeled() && rel->isDirected()) {
auto relDataDirection = getRelDataDirection(extendDirection);
auto relTableID = rel->getSingleTableID();
if (relsStore.isSingleMultiplicityInDirection(direction, relTableID)) {
if (relsStore.isSingleMultiplicityInDirection(relDataDirection, relTableID)) {
auto propertyIds = populatePropertyIds(relTableID, extend->getProperties());
return make_unique<ScanRelTableColumns>(
relsStore.getRelTable(relTableID)->getDirectedTableData(direction),
relsStore.getRelTable(relTableID)->getDirectedTableData(relDataDirection),
std::move(propertyIds), inNodeIDVectorPos, std::move(outputVectorsPos),
std::move(prevOperator), getOperatorID(), extend->getExpressionsForPrinting());
} else {
assert(!relsStore.isSingleMultiplicityInDirection(direction, relTableID));
assert(!relsStore.isSingleMultiplicityInDirection(relDataDirection, relTableID));
auto propertyIds = populatePropertyIds(relTableID, extend->getProperties());
return make_unique<ScanRelTableLists>(
relsStore.getRelTable(relTableID)->getDirectedTableData(direction),
relsStore.getRelTable(relTableID)->getDirectedTableData(relDataDirection),
std::move(propertyIds), inNodeIDVectorPos, std::move(outputVectorsPos),
std::move(prevOperator), getOperatorID(), extend->getExpressionsForPrinting());
}
} else { // map to generic extend
std::unordered_map<table_id_t, std::unique_ptr<RelTableCollection>>
std::unordered_map<table_id_t, std::unique_ptr<RelTableDataCollection>>
relTableCollectionPerNodeTable;
for (auto boundNodeTableID : boundNode->getTableIDs()) {
auto relTableCollection = populateRelTableCollection(
boundNodeTableID, *rel, direction, extend->getProperties(), relsStore, *catalog);
auto relTableCollection = populateRelTableDataCollection(boundNodeTableID, *rel,
extendDirection, extend->getProperties(), relsStore, *catalog);
if (relTableCollection->getNumTablesInCollection() > 0) {
relTableCollectionPerNodeTable.insert(
{boundNodeTableID, std::move(relTableCollection)});
Expand Down
34 changes: 17 additions & 17 deletions src/processor/operator/scan/generic_scan_rel_tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ using namespace kuzu::transaction;
namespace kuzu {
namespace processor {

void RelTableCollection::resetState() {
void RelTableDataCollection::resetState() {
currentRelTableIdxToScan = 0;
nextRelTableIdxToScan = 0;
}

bool RelTableCollection::scan(ValueVector* inVector, const std::vector<ValueVector*>& outputVectors,
Transaction* transaction) {
bool RelTableDataCollection::scan(ValueVector* inVector,
const std::vector<ValueVector*>& outputVectors, Transaction* transaction) {
do {
if (tableScanStates[currentRelTableIdxToScan]->hasMoreAndSwitchSourceIfNecessary()) {
assert(tableScanStates[currentRelTableIdxToScan]->relTableDataType ==
storage::RelTableDataType::LISTS);
tables[currentRelTableIdxToScan]->scan(
relTableDatas[currentRelTableIdxToScan]->scan(
transaction, *tableScanStates[currentRelTableIdxToScan], inVector, outputVectors);
} else {
currentRelTableIdxToScan = nextRelTableIdxToScan;
Expand All @@ -33,31 +33,31 @@ bool RelTableCollection::scan(ValueVector* inVector, const std::vector<ValueVect
} else {
tableScanStates[currentRelTableIdxToScan]->syncState->resetState();
}
tables[currentRelTableIdxToScan]->scan(
relTableDatas[currentRelTableIdxToScan]->scan(
transaction, *tableScanStates[currentRelTableIdxToScan], inVector, outputVectors);
nextRelTableIdxToScan++;
}
} while (outputVectors[0]->state->selVector->selectedSize == 0);
return true;
}

std::unique_ptr<RelTableCollection> RelTableCollection::clone() const {
std::unique_ptr<RelTableDataCollection> RelTableDataCollection::clone() const {
std::vector<std::unique_ptr<RelTableScanState>> clonedScanStates;
for (auto& scanState : tableScanStates) {
clonedScanStates.push_back(
make_unique<RelTableScanState>(scanState->propertyIds, scanState->relTableDataType));
}
return make_unique<RelTableCollection>(tables, std::move(clonedScanStates));
return make_unique<RelTableDataCollection>(relTableDatas, std::move(clonedScanStates));
}

void GenericScanRelTables::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) {
ScanRelTable::initLocalStateInternal(resultSet, context);
currentRelTableCollection = nullptr;
currentRelTableDataCollection = nullptr;
}

bool GenericScanRelTables::getNextTuplesInternal(ExecutionContext* context) {
while (true) {
if (scanCurrentRelTableCollection()) {
if (scanCurrentRelTableDataCollection()) {
metrics->numOutputTuple.increase(outputVectors[0]->state->selVector->selectedSize);
return true;
}
Expand All @@ -70,23 +70,23 @@ bool GenericScanRelTables::getNextTuplesInternal(ExecutionContext* context) {
continue;
}
auto nodeID = inNodeIDVector->getValue<nodeID_t>(currentIdx);
initCurrentRelTableCollection(nodeID);
initCurrentRelTableDataCollection(nodeID);
}
}

bool GenericScanRelTables::scanCurrentRelTableCollection() {
if (currentRelTableCollection == nullptr) {
bool GenericScanRelTables::scanCurrentRelTableDataCollection() {
if (currentRelTableDataCollection == nullptr) {
return false;
}
return currentRelTableCollection->scan(inNodeIDVector, outputVectors, transaction);
return currentRelTableDataCollection->scan(inNodeIDVector, outputVectors, transaction);
}

void GenericScanRelTables::initCurrentRelTableCollection(const nodeID_t& nodeID) {
void GenericScanRelTables::initCurrentRelTableDataCollection(const nodeID_t& nodeID) {
if (relTableCollectionPerNodeTable.contains(nodeID.tableID)) {
currentRelTableCollection = relTableCollectionPerNodeTable.at(nodeID.tableID).get();
currentRelTableCollection->resetState();
currentRelTableDataCollection = relTableCollectionPerNodeTable.at(nodeID.tableID).get();
currentRelTableDataCollection->resetState();
} else {
currentRelTableCollection = nullptr;
currentRelTableDataCollection = nullptr;
}
}

Expand Down
12 changes: 12 additions & 0 deletions test/test_files/demo_db/demo_db.test
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,15 @@ Adam|30
-QUERY MATCH (a:User) WITH a ORDER BY a.age DESC LIMIT 1 MATCH (a)-[:Follows]->(b:User) RETURN *;
---- 1
(label:User, 0:3, {name:Noura, age:25})|(label:User, 0:2, {name:Zhang, age:50})

-NAME Undir1
-QUERY MATCH (a:User)-[:Follows]-(b:User) RETURN a.name, b.age;
---- 8
Adam|40
Adam|50
Karissa|50
Zhang|25
Karissa|30
Zhang|30
Zhang|40
Noura|50
21 changes: 21 additions & 0 deletions test/test_files/tinysnb/match/undirected.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
-NAME UndirKnows1
-QUERY MATCH (a:person)-[:knows]-(b:person) WHERE b.fName = "Bob" RETURN a.fName;
---- 6
Alice
Carol
Dan
Alice
Carol
Dan

-NAME UndirKnows2
-QUERY MATCH (a:person)-[:knows]-(b:person)-[:knows]-(c:person) WHERE a.gender = 1 AND b.gender = 2 AND c.fName = "Bob" RETURN a.fName, b.fName;
---- 8
Alice|Dan
Carol|Dan
Alice|Dan
Carol|Dan
Alice|Dan
Carol|Dan
Alice|Dan
Carol|Dan

0 comments on commit a63026a

Please sign in to comment.