diff --git a/src/include/processor/operator/scan/scan_rel_table.h b/src/include/processor/operator/scan/scan_rel_table.h index 55b00751279..23f60164246 100644 --- a/src/include/processor/operator/scan/scan_rel_table.h +++ b/src/include/processor/operator/scan/scan_rel_table.h @@ -36,7 +36,6 @@ class ScanRelTable : public ScanTable { ScanTable::initLocalStateInternal(resultSet, executionContext); scanState = std::make_unique( *inVector, info->columnIDs, outVectors, info->direction); - scanState->dataReadState = std::make_unique(); } bool getNextTuplesInternal(ExecutionContext* context) override; diff --git a/src/include/storage/store/node_table.h b/src/include/storage/store/node_table.h index 71991aa6679..a60c57e1841 100644 --- a/src/include/storage/store/node_table.h +++ b/src/include/storage/store/node_table.h @@ -65,9 +65,6 @@ class NodeTable final : public Table { inline void initializeReadState(transaction::Transaction* transaction, std::vector columnIDs, const common::ValueVector& inNodeIDVector, TableReadState& readState) { - if (!readState.dataReadState) { - readState.dataReadState = std::make_unique(); - } tableData->initializeReadState( transaction, std::move(columnIDs), inNodeIDVector, *readState.dataReadState); } diff --git a/src/include/storage/store/rel_table.h b/src/include/storage/store/rel_table.h index 6ecd1d301d1..9db61a92b7d 100644 --- a/src/include/storage/store/rel_table.h +++ b/src/include/storage/store/rel_table.h @@ -13,7 +13,9 @@ struct RelTableReadState : public TableReadState { RelTableReadState(const common::ValueVector& nodeIDVector, const std::vector& columnIDs, const std::vector& outputVectors, common::RelDataDirection direction) - : TableReadState{nodeIDVector, columnIDs, outputVectors}, direction{direction} {} + : TableReadState{nodeIDVector, columnIDs, outputVectors}, direction{direction} { + dataReadState = std::make_unique(); + } bool hasMoreToRead(transaction::Transaction* transaction) const { auto relDataReadState = diff --git a/src/include/storage/store/table.h b/src/include/storage/store/table.h index 6c37fbe680d..6d1c3a00682 100644 --- a/src/include/storage/store/table.h +++ b/src/include/storage/store/table.h @@ -17,7 +17,9 @@ struct TableReadState { const std::vector& columnIDs, const std::vector& outputVectors) : nodeIDVector{nodeIDVector}, columnIDs{std::move(columnIDs)}, outputVectors{ - outputVectors} {} + outputVectors} { + dataReadState = std::make_unique(); + } virtual ~TableReadState() = default; }; diff --git a/src/processor/operator/scan/scan_multi_rel_tables.cpp b/src/processor/operator/scan/scan_multi_rel_tables.cpp index 8d8730417f7..95c68b60c55 100644 --- a/src/processor/operator/scan/scan_multi_rel_tables.cpp +++ b/src/processor/operator/scan/scan_multi_rel_tables.cpp @@ -14,7 +14,6 @@ void RelTableCollectionScanner::init( auto scanInfo = scanInfos[i].get(); readStates[i] = std::make_unique( *inVector, scanInfo->columnIDs, outputVectors, scanInfo->direction); - readStates[i]->dataReadState = std::make_unique(); } } diff --git a/src/storage/local_storage/local_rel_table.cpp b/src/storage/local_storage/local_rel_table.cpp index f4dc9e66d2a..08e4f1248be 100644 --- a/src/storage/local_storage/local_rel_table.cpp +++ b/src/storage/local_storage/local_rel_table.cpp @@ -219,25 +219,46 @@ LocalRelTable::LocalRelTable(Table& table) : LocalTable{table} { bool LocalRelTable::insert(TableInsertState& state) { auto& insertState = ku_dynamic_cast(state); - auto IDVectors = + auto fwdIDVectors = std::vector{const_cast(&insertState.srcNodeIDVector), const_cast(&insertState.dstNodeIDVector)}; - return getTableData(RelDataDirection::FWD)->insert(IDVectors, insertState.propertyVectors); + auto bwdIDVectors = + std::vector{const_cast(&insertState.dstNodeIDVector), + const_cast(&insertState.srcNodeIDVector)}; + auto fwdInserted = + getTableData(RelDataDirection::FWD)->insert(fwdIDVectors, insertState.propertyVectors); + auto bwdInserted = + getTableData(RelDataDirection::BWD)->insert(bwdIDVectors, insertState.propertyVectors); + KU_ASSERT(fwdInserted == bwdInserted); + return fwdInserted && bwdInserted; } bool LocalRelTable::update(TableUpdateState& updateState) { auto& state = ku_dynamic_cast(updateState); - auto IDVectors = std::vector{const_cast(&state.srcNodeIDVector), + auto fwdIDVectors = std::vector{const_cast(&state.srcNodeIDVector), const_cast(&state.relIDVector)}; - return getTableData(RelDataDirection::FWD) - ->update(IDVectors, state.columnID, const_cast(&state.propertyVector)); + auto bwdIDVectors = std::vector{const_cast(&state.dstNodeIDVector), + const_cast(&state.relIDVector)}; + auto fwdUpdated = + getTableData(RelDataDirection::FWD) + ->update(fwdIDVectors, state.columnID, const_cast(&state.propertyVector)); + auto bwdUpdated = + getTableData(RelDataDirection::BWD) + ->update(bwdIDVectors, state.columnID, const_cast(&state.propertyVector)); + KU_ASSERT(fwdUpdated == bwdUpdated); + return fwdUpdated && bwdUpdated; } bool LocalRelTable::delete_(TableDeleteState& deleteState) { auto& state = ku_dynamic_cast(deleteState); - return getTableData(RelDataDirection::FWD) - ->delete_(const_cast(&state.srcNodeIDVector), - const_cast(&state.relIDVector)); + auto fwdDeleted = getTableData(RelDataDirection::FWD) + ->delete_(const_cast(&state.srcNodeIDVector), + const_cast(&state.relIDVector)); + auto bwdDeleted = getTableData(RelDataDirection::BWD) + ->delete_(const_cast(&state.dstNodeIDVector), + const_cast(&state.relIDVector)); + KU_ASSERT(fwdDeleted == bwdDeleted); + return fwdDeleted && bwdDeleted; } void LocalRelTable::scan(TableReadState&) { diff --git a/src/storage/store/rel_table.cpp b/src/storage/store/rel_table.cpp index 9c640ae216d..2766f9def57 100644 --- a/src/storage/store/rel_table.cpp +++ b/src/storage/store/rel_table.cpp @@ -101,14 +101,6 @@ row_idx_t RelTable::detachDeleteForCSRRels(Transaction* transaction, RelTableDat row_idx_t numRelsDeleted = 0; auto tempState = deleteState->dstNodeIDVector->state.get(); while (relDataReadState->hasMoreToRead(transaction)) { - // auto relIDColumns = {REL_ID_COLUMN_ID}; - // auto relIDVectors = std::vector{ - // deleteState->dstNodeIDVector.get(), deleteState->relIDVector.get()}; - // auto relIDReadState = std::make_unique( - // *srcNodeIDVector, relIDColumns, relIDVectors, RelDataDirection::FWD); - // initializeReadState( - // transaction, RelDataDirection::FWD, relIDColumns, *srcNodeIDVector, - // *relIDReadState); scan(transaction, *relDataReadState); auto numRelsScanned = tempState->selVector->selectedSize; tempState->selVector->resetSelectorToValuePosBufferWithSize(1); diff --git a/test/transaction/transaction_manager_test.cpp b/test/transaction/transaction_manager_test.cpp index 7207dd3c919..941132f5d29 100644 --- a/test/transaction/transaction_manager_test.cpp +++ b/test/transaction/transaction_manager_test.cpp @@ -27,9 +27,10 @@ class TransactionManagerTest : public EmptyDBTest { public: void runTwoCommitRollback(TransactionType type, bool firstIsCommit, bool secondIsCommit) { - std::unique_ptr trx = TransactionType::WRITE == type ? - transactionManager->beginWriteTransaction() : - transactionManager->beginReadOnlyTransaction(); + std::unique_ptr trx = + TransactionType::WRITE == type ? + transactionManager->beginWriteTransaction(*getClientContext(*conn)) : + transactionManager->beginReadOnlyTransaction(*getClientContext(*conn)); if (firstIsCommit) { transactionManager->commit(trx.get()); } else { @@ -51,9 +52,10 @@ class TransactionManagerTest : public EmptyDBTest { }; TEST_F(TransactionManagerTest, MultipleWriteTransactionsErrors) { - std::unique_ptr trx1 = transactionManager->beginWriteTransaction(); + std::unique_ptr trx1 = + transactionManager->beginWriteTransaction(*getClientContext(*conn)); try { - transactionManager->beginWriteTransaction(); + transactionManager->beginWriteTransaction(*getClientContext(*conn)); FAIL(); } catch (TransactionManagerException& e) {} } @@ -96,9 +98,12 @@ TEST_F(TransactionManagerTest, BasicOneWriteMultipleReadOnlyTransactions) { // before and after commits or rollbacks under concurrent transactions. Specifically we test: // that transaction IDs increase incrementally, the states of activeReadOnlyTransactionIDs set, // and activeWriteTransactionID. - std::unique_ptr trx1 = transactionManager->beginReadOnlyTransaction(); - std::unique_ptr trx2 = transactionManager->beginWriteTransaction(); - std::unique_ptr trx3 = transactionManager->beginReadOnlyTransaction(); + std::unique_ptr trx1 = + transactionManager->beginReadOnlyTransaction(*getClientContext(*conn)); + std::unique_ptr trx2 = + transactionManager->beginWriteTransaction(*getClientContext(*conn)); + std::unique_ptr trx3 = + transactionManager->beginReadOnlyTransaction(*getClientContext(*conn)); ASSERT_EQ(TransactionType::READ_ONLY, trx1->getType()); ASSERT_EQ(TransactionType::WRITE, trx2->getType()); ASSERT_EQ(TransactionType::READ_ONLY, trx3->getType()); @@ -120,8 +125,10 @@ TEST_F(TransactionManagerTest, BasicOneWriteMultipleReadOnlyTransactions) { ASSERT_EQ( expectedReadOnlyTransactionSet, transactionManager->getActiveReadOnlyTransactionIDs()); - std::unique_ptr trx4 = transactionManager->beginWriteTransaction(); - std::unique_ptr trx5 = transactionManager->beginReadOnlyTransaction(); + std::unique_ptr trx4 = + transactionManager->beginWriteTransaction(*getClientContext(*conn)); + std::unique_ptr trx5 = + transactionManager->beginReadOnlyTransaction(*getClientContext(*conn)); ASSERT_EQ(trx3->getID() + 1, trx4->getID()); ASSERT_EQ(trx4->getID() + 1, trx5->getID()); ASSERT_EQ(trx4->getID(), transactionManager->getActiveWriteTransactionID());