diff --git a/src/include/storage/store/rel_table.h b/src/include/storage/store/rel_table.h index cc84f090ca..d8b30b7efe 100644 --- a/src/include/storage/store/rel_table.h +++ b/src/include/storage/store/rel_table.h @@ -91,11 +91,14 @@ class DirectedRelTableData { inline uint32_t getNumPropertyLists(table_id_t boundNodeTableID) { return propertyLists.at(boundNodeTableID).size(); } + // Returns the list offset of the given relID if the relID stored as list in the current + // direction, otherwise it returns UINT64_MAX. inline list_offset_t getListOffset(nodeID_t nodeID, int64_t relID) { - return ((RelIDList*)(propertyLists - .at(nodeID.tableID)[RelTableSchema::INTERNAL_REL_ID_PROPERTY_IDX] - .get())) - ->getListOffset(nodeID.offset, relID); + return propertyLists.contains(nodeID.tableID) ? + ((RelIDList*)getPropertyLists( + nodeID.tableID, RelTableSchema::INTERNAL_REL_ID_PROPERTY_IDX)) + ->getListOffset(nodeID.offset, relID) : + UINT64_MAX; } void initializeData(RelTableSchema* tableSchema, BufferManager& bufferManager, WAL* wal); @@ -119,10 +122,12 @@ class DirectedRelTableData { } } - void insertRel(table_id_t boundTableID, const shared_ptr& boundVector, + void insertRel(const shared_ptr& boundVector, const shared_ptr& nbrVector, const vector>& relPropertyVectors); - void deleteRel(table_id_t boundTableID, const shared_ptr& boundVector); + void deleteRel(const shared_ptr& boundVector); + void updateRel(const shared_ptr& boundVector, property_id_t propertyID, + const shared_ptr& propertyVector); void performOpOnListsWithUpdates(const std::function& opOnListsWithUpdates); unique_ptr getListsUpdateIteratorsForDirection( table_id_t boundNodeTableID); diff --git a/src/storage/store/rel_table.cpp b/src/storage/store/rel_table.cpp index 6f97ccb824..b40934c6c2 100644 --- a/src/storage/store/rel_table.cpp +++ b/src/storage/store/rel_table.cpp @@ -154,9 +154,12 @@ void DirectedRelTableData::scanLists(Transaction* transaction, RelTableScanState } } -void DirectedRelTableData::insertRel(table_id_t boundTableID, - const shared_ptr& boundVector, const shared_ptr& nbrVector, +void DirectedRelTableData::insertRel(const shared_ptr& boundVector, + const shared_ptr& nbrVector, const vector>& relPropertyVectors) { + auto boundTableID = + boundVector->getValue(boundVector->state->selVector->selectedPositions[0]) + .tableID; if (!adjColumns.contains(boundTableID)) { return; } @@ -177,20 +180,31 @@ void DirectedRelTableData::insertRel(table_id_t boundTableID, } } -void DirectedRelTableData::deleteRel( - table_id_t boundTableID, const shared_ptr& boundVector) { - if (!adjColumns.contains(boundTableID)) { +void DirectedRelTableData::deleteRel(const shared_ptr& boundVector) { + auto boundNode = + boundVector->getValue(boundVector->state->selVector->selectedPositions[0]); + if (!adjColumns.contains(boundNode.tableID)) { return; } - auto adjColumn = adjColumns.at(boundTableID).get(); + auto adjColumn = adjColumns.at(boundNode.tableID).get(); auto nodeOffset = boundVector->readNodeOffset(boundVector->state->selVector->selectedPositions[0]); adjColumn->setNodeOffsetToNull(nodeOffset); - for (auto& [_, propertyColumn] : propertyColumns.at(boundTableID)) { + for (auto& [_, propertyColumn] : propertyColumns.at(boundNode.tableID)) { propertyColumn->setNodeOffsetToNull(nodeOffset); } } +void DirectedRelTableData::updateRel(const shared_ptr& boundVector, + property_id_t propertyID, const shared_ptr& propertyVector) { + auto boundNode = + boundVector->getValue(boundVector->state->selVector->selectedPositions[0]); + if (!adjColumns.contains(boundNode.tableID)) { + return; + } + propertyColumns.at(boundNode.tableID).at(propertyID)->writeValues(boundVector, propertyVector); +} + void DirectedRelTableData::performOpOnListsWithUpdates( const std::function& opOnListsWithUpdates) { for (auto& [boundNodeTableID, listsUpdatePerTable] : @@ -284,14 +298,8 @@ void RelTable::insertRel(const shared_ptr& srcNodeIDVector, const shared_ptr& dstNodeIDVector, const vector>& relPropertyVectors) { assert(srcNodeIDVector->state->isFlat() && dstNodeIDVector->state->isFlat()); - auto srcTableID = - srcNodeIDVector->getValue(srcNodeIDVector->state->selVector->selectedPositions[0]) - .tableID; - auto dstTableID = - dstNodeIDVector->getValue(dstNodeIDVector->state->selVector->selectedPositions[0]) - .tableID; - fwdRelTableData->insertRel(srcTableID, srcNodeIDVector, dstNodeIDVector, relPropertyVectors); - bwdRelTableData->insertRel(dstTableID, dstNodeIDVector, srcNodeIDVector, relPropertyVectors); + fwdRelTableData->insertRel(srcNodeIDVector, dstNodeIDVector, relPropertyVectors); + bwdRelTableData->insertRel(dstNodeIDVector, srcNodeIDVector, relPropertyVectors); listsUpdatesStore->insertRelIfNecessary(srcNodeIDVector, dstNodeIDVector, relPropertyVectors); } @@ -299,14 +307,8 @@ void RelTable::deleteRel(const shared_ptr& srcNodeIDVector, const shared_ptr& dstNodeIDVector, const shared_ptr& relIDVector) { assert(srcNodeIDVector->state->isFlat() && dstNodeIDVector->state->isFlat() && relIDVector->state->isFlat()); - auto srcTableID = - srcNodeIDVector->getValue(srcNodeIDVector->state->selVector->selectedPositions[0]) - .tableID; - auto dstTableID = - dstNodeIDVector->getValue(dstNodeIDVector->state->selVector->selectedPositions[0]) - .tableID; - fwdRelTableData->deleteRel(srcTableID, srcNodeIDVector); - bwdRelTableData->deleteRel(dstTableID, dstNodeIDVector); + fwdRelTableData->deleteRel(srcNodeIDVector); + bwdRelTableData->deleteRel(dstNodeIDVector); listsUpdatesStore->deleteRelIfNecessary(srcNodeIDVector, dstNodeIDVector, relIDVector); } @@ -319,6 +321,8 @@ void RelTable::updateRel(const shared_ptr& srcNodeIDVector, srcNodeIDVector->state->selVector->selectedPositions[0]); auto dstNode = dstNodeIDVector->getValue( dstNodeIDVector->state->selVector->selectedPositions[0]); + fwdRelTableData->updateRel(srcNodeIDVector, propertyID, propertyVector); + bwdRelTableData->updateRel(dstNodeIDVector, propertyID, propertyVector); auto relID = relIDVector->getValue(relIDVector->state->selVector->selectedPositions[0]); ListsUpdateInfo listsUpdateInfo = ListsUpdateInfo{propertyVector, propertyID, relID, diff --git a/test/runner/e2e_update_rel_test.cpp b/test/runner/e2e_update_rel_test.cpp index e18c5b8950..e34a9536ae 100644 --- a/test/runner/e2e_update_rel_test.cpp +++ b/test/runner/e2e_update_rel_test.cpp @@ -243,6 +243,41 @@ class UpdateRelTest : public DBTest { sortAndCheckTestResults(actualResult, expectedResult); } + void updateManyToOneRelTable(bool isCommit, TransactionTestType transactionTestType) { + conn->beginWriteTransaction(); + ASSERT_TRUE(conn->query(getUpdateRelQuery("person" /* srcTable */, "person" /* dstTable */, + "teaches" /* relation */, 21 /* srcID */, 2 /* dstID */, "SET e.length=null"))); + ASSERT_TRUE(conn->query(getUpdateRelQuery("person" /* srcTable */, "person" /* dstTable */, + "teaches" /* relation */, 32 /* srcID */, 3 /* dstID */, "SET e.length = 512"))); + ASSERT_TRUE(conn->query(getUpdateRelQuery("person" /* srcTable */, "person" /* relation */, + "teaches" /* relation */, 33 /* srcID */, 3 /* dstID */, "SET e.length = 312"))); + commitOrRollbackConnectionAndInitDBIfNecessary(isCommit, transactionTestType); + auto expectedResult = isCommit ? vector{"11", "", "22", "31", "512", "312"} : + vector{"11", "21", "22", "31", "32", "33"}; + auto result = conn->query("MATCH (p:person)-[e:teaches]->(:person) RETURN e.length"); + auto actualResult = TestHelper::convertResultToString(*result); + sortAndCheckTestResults(actualResult, expectedResult); + } + + void updateOneToOneRelTable(bool isCommit, TransactionTestType transactionTestType) { + conn->beginWriteTransaction(); + ASSERT_TRUE(conn->query(getUpdateRelQuery("animal" /* srcTable */, "person" /* dstTable */, + "hasOwner" /* relation */, 2 /* srcID */, 52 /* dstID */, "SET e.place='kuzu'"))); + ASSERT_TRUE(conn->query(getUpdateRelQuery("animal" /* srcTable */, "person" /* dstTable */, + "hasOwner" /* relation */, 4 /* srcID */, 54 /* dstID */, "SET e.place='db'"))); + ASSERT_TRUE(conn->query(getUpdateRelQuery("animal" /* srcTable */, "person" /* relation */, + "hasOwner" /* relation */, 8 /* srcID */, 58 /* dstID */, "SET e.place=null"))); + commitOrRollbackConnectionAndInitDBIfNecessary(isCommit, transactionTestType); + auto expectedResult = + isCommit ? vector{"1999", "kuzu", "1997", "db", "1995", "199419941994", "1993", + "", "1991", "1989"} : + vector{"1999", "199819981998", "1997", "199619961996", "1995", + "199419941994", "1993", "199219921992", "1991", "1989"}; + auto result = conn->query("MATCH (:animal)-[e:hasOwner]->(:person) RETURN e.place"); + auto actualResult = TestHelper::convertResultToString(*result); + sortAndCheckTestResults(actualResult, expectedResult); + } + static constexpr uint64_t NUM_PERSON_KNOWS_PERSON_RELS = 2500; }; @@ -392,3 +427,35 @@ TEST_F(UpdateRelTest, InsertAndUpdateRelsForNewlyAddedNodeRollbackNormalExecutio TEST_F(UpdateRelTest, InsertAndUpdateRelsForNewlyAddedNodeRollbackRecovery) { insertAndUpdateRelsForNewlyAddedNode(false /* isCommit */, TransactionTestType::RECOVERY); } + +TEST_F(UpdateRelTest, UpdateManyToOneRelTableCommitNormalExecution) { + updateManyToOneRelTable(true /* isCommit */, TransactionTestType::NORMAL_EXECUTION); +} + +TEST_F(UpdateRelTest, UpdateManyToOneRelTableCommitRecovery) { + updateManyToOneRelTable(true /* isCommit */, TransactionTestType::RECOVERY); +} + +TEST_F(UpdateRelTest, UpdateManyToOneRelTableRollbackNormalExecution) { + updateManyToOneRelTable(false /* isCommit */, TransactionTestType::NORMAL_EXECUTION); +} + +TEST_F(UpdateRelTest, UpdateManyToOneRelTableRollbackRecovery) { + updateManyToOneRelTable(false /* isCommit */, TransactionTestType::RECOVERY); +} + +TEST_F(UpdateRelTest, UpdateOneToOneRelTableCommitNormalExecution) { + updateOneToOneRelTable(true /* isCommit */, TransactionTestType::NORMAL_EXECUTION); +} + +TEST_F(UpdateRelTest, UpdateOneToOneRelTableCommitRecovery) { + updateOneToOneRelTable(true /* isCommit */, TransactionTestType::RECOVERY); +} + +TEST_F(UpdateRelTest, UpdateOneToOneRelTableRollbackNormalExecution) { + updateOneToOneRelTable(false /* isCommit */, TransactionTestType::NORMAL_EXECUTION); +} + +TEST_F(UpdateRelTest, UpdateOneToOneRelTableRollbackRecovery) { + updateOneToOneRelTable(false /* isCommit */, TransactionTestType::RECOVERY); +}