Skip to content

Commit

Permalink
fix rel update/insert/delete, enable both directions
Browse files Browse the repository at this point in the history
  • Loading branch information
ray6080 committed Mar 18, 2024
1 parent 09e132a commit 8702527
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 33 deletions.
1 change: 0 additions & 1 deletion src/include/processor/operator/scan/scan_rel_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class ScanRelTable : public ScanTable {
ScanTable::initLocalStateInternal(resultSet, executionContext);
scanState = std::make_unique<storage::RelTableReadState>(
*inVector, info->columnIDs, outVectors, info->direction);
scanState->dataReadState = std::make_unique<storage::RelDataReadState>();
}

bool getNextTuplesInternal(ExecutionContext* context) override;
Expand Down
3 changes: 0 additions & 3 deletions src/include/storage/store/node_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ class NodeTable final : public Table {
inline void initializeReadState(transaction::Transaction* transaction,
std::vector<common::column_id_t> columnIDs, const common::ValueVector& inNodeIDVector,
TableReadState& readState) {
if (!readState.dataReadState) {
readState.dataReadState = std::make_unique<TableDataReadState>();
}
tableData->initializeReadState(
transaction, std::move(columnIDs), inNodeIDVector, *readState.dataReadState);
}
Expand Down
4 changes: 3 additions & 1 deletion src/include/storage/store/rel_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ struct RelTableReadState : public TableReadState {
RelTableReadState(const common::ValueVector& nodeIDVector,
const std::vector<common::column_id_t>& columnIDs,
const std::vector<common::ValueVector*>& outputVectors, common::RelDataDirection direction)
: TableReadState{nodeIDVector, columnIDs, outputVectors}, direction{direction} {}
: TableReadState{nodeIDVector, columnIDs, outputVectors}, direction{direction} {
dataReadState = std::make_unique<RelDataReadState>();
}

bool hasMoreToRead(transaction::Transaction* transaction) const {
auto relDataReadState =
Expand Down
4 changes: 3 additions & 1 deletion src/include/storage/store/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ struct TableReadState {
const std::vector<common::column_id_t>& columnIDs,
const std::vector<common::ValueVector*>& outputVectors)
: nodeIDVector{nodeIDVector}, columnIDs{std::move(columnIDs)}, outputVectors{
outputVectors} {}
outputVectors} {
dataReadState = std::make_unique<TableDataReadState>();
}
virtual ~TableReadState() = default;
};

Expand Down
1 change: 0 additions & 1 deletion src/processor/operator/scan/scan_multi_rel_tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ void RelTableCollectionScanner::init(
auto scanInfo = scanInfos[i].get();
readStates[i] = std::make_unique<RelTableReadState>(
*inVector, scanInfo->columnIDs, outputVectors, scanInfo->direction);
readStates[i]->dataReadState = std::make_unique<RelDataReadState>();
}
}

Expand Down
37 changes: 29 additions & 8 deletions src/storage/local_storage/local_rel_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,25 +219,46 @@ LocalRelTable::LocalRelTable(Table& table) : LocalTable{table} {

bool LocalRelTable::insert(TableInsertState& state) {
auto& insertState = ku_dynamic_cast<TableInsertState&, RelTableInsertState&>(state);
auto IDVectors =
auto fwdIDVectors =
std::vector<ValueVector*>{const_cast<ValueVector*>(&insertState.srcNodeIDVector),
const_cast<ValueVector*>(&insertState.dstNodeIDVector)};
return getTableData(RelDataDirection::FWD)->insert(IDVectors, insertState.propertyVectors);
auto bwdIDVectors =
std::vector<ValueVector*>{const_cast<ValueVector*>(&insertState.dstNodeIDVector),
const_cast<ValueVector*>(&insertState.srcNodeIDVector)};
auto fwdInserted =
getTableData(RelDataDirection::FWD)->insert(fwdIDVectors, insertState.propertyVectors);
auto bwdInserted =
getTableData(RelDataDirection::BWD)->insert(bwdIDVectors, insertState.propertyVectors);
KU_ASSERT(fwdInserted == bwdInserted);
return fwdInserted && bwdInserted;
}

bool LocalRelTable::update(TableUpdateState& updateState) {
auto& state = ku_dynamic_cast<TableUpdateState&, RelTableUpdateState&>(updateState);
auto IDVectors = std::vector<ValueVector*>{const_cast<ValueVector*>(&state.srcNodeIDVector),
auto fwdIDVectors = std::vector<ValueVector*>{const_cast<ValueVector*>(&state.srcNodeIDVector),
const_cast<ValueVector*>(&state.relIDVector)};
return getTableData(RelDataDirection::FWD)
->update(IDVectors, state.columnID, const_cast<ValueVector*>(&state.propertyVector));
auto bwdIDVectors = std::vector<ValueVector*>{const_cast<ValueVector*>(&state.dstNodeIDVector),
const_cast<ValueVector*>(&state.relIDVector)};
auto fwdUpdated =
getTableData(RelDataDirection::FWD)
->update(fwdIDVectors, state.columnID, const_cast<ValueVector*>(&state.propertyVector));
auto bwdUpdated =
getTableData(RelDataDirection::BWD)
->update(bwdIDVectors, state.columnID, const_cast<ValueVector*>(&state.propertyVector));
KU_ASSERT(fwdUpdated == bwdUpdated);
return fwdUpdated && bwdUpdated;
}

bool LocalRelTable::delete_(TableDeleteState& deleteState) {
auto& state = ku_dynamic_cast<TableDeleteState&, RelTableDeleteState&>(deleteState);
return getTableData(RelDataDirection::FWD)
->delete_(const_cast<ValueVector*>(&state.srcNodeIDVector),
const_cast<ValueVector*>(&state.relIDVector));
auto fwdDeleted = getTableData(RelDataDirection::FWD)
->delete_(const_cast<ValueVector*>(&state.srcNodeIDVector),
const_cast<ValueVector*>(&state.relIDVector));
auto bwdDeleted = getTableData(RelDataDirection::BWD)
->delete_(const_cast<ValueVector*>(&state.dstNodeIDVector),
const_cast<ValueVector*>(&state.relIDVector));
KU_ASSERT(fwdDeleted == bwdDeleted);
return fwdDeleted && bwdDeleted;
}

void LocalRelTable::scan(TableReadState&) {
Expand Down
8 changes: 0 additions & 8 deletions src/storage/store/rel_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,6 @@ row_idx_t RelTable::detachDeleteForCSRRels(Transaction* transaction, RelTableDat
row_idx_t numRelsDeleted = 0;
auto tempState = deleteState->dstNodeIDVector->state.get();
while (relDataReadState->hasMoreToRead(transaction)) {
// auto relIDColumns = {REL_ID_COLUMN_ID};
// auto relIDVectors = std::vector<ValueVector*>{
// deleteState->dstNodeIDVector.get(), deleteState->relIDVector.get()};
// auto relIDReadState = std::make_unique<RelTableReadState>(
// *srcNodeIDVector, relIDColumns, relIDVectors, RelDataDirection::FWD);
// initializeReadState(
// transaction, RelDataDirection::FWD, relIDColumns, *srcNodeIDVector,
// *relIDReadState);
scan(transaction, *relDataReadState);
auto numRelsScanned = tempState->selVector->selectedSize;
tempState->selVector->resetSelectorToValuePosBufferWithSize(1);
Expand Down
27 changes: 17 additions & 10 deletions test/transaction/transaction_manager_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ class TransactionManagerTest : public EmptyDBTest {

public:
void runTwoCommitRollback(TransactionType type, bool firstIsCommit, bool secondIsCommit) {
std::unique_ptr<Transaction> trx = TransactionType::WRITE == type ?
transactionManager->beginWriteTransaction() :
transactionManager->beginReadOnlyTransaction();
std::unique_ptr<Transaction> trx =
TransactionType::WRITE == type ?
transactionManager->beginWriteTransaction(*getClientContext(*conn)) :
transactionManager->beginReadOnlyTransaction(*getClientContext(*conn));
if (firstIsCommit) {
transactionManager->commit(trx.get());
} else {
Expand All @@ -51,9 +52,10 @@ class TransactionManagerTest : public EmptyDBTest {
};

TEST_F(TransactionManagerTest, MultipleWriteTransactionsErrors) {
std::unique_ptr<Transaction> trx1 = transactionManager->beginWriteTransaction();
std::unique_ptr<Transaction> trx1 =
transactionManager->beginWriteTransaction(*getClientContext(*conn));
try {
transactionManager->beginWriteTransaction();
transactionManager->beginWriteTransaction(*getClientContext(*conn));
FAIL();
} catch (TransactionManagerException& e) {}
}
Expand Down Expand Up @@ -96,9 +98,12 @@ TEST_F(TransactionManagerTest, BasicOneWriteMultipleReadOnlyTransactions) {
// before and after commits or rollbacks under concurrent transactions. Specifically we test:
// that transaction IDs increase incrementally, the states of activeReadOnlyTransactionIDs set,
// and activeWriteTransactionID.
std::unique_ptr<Transaction> trx1 = transactionManager->beginReadOnlyTransaction();
std::unique_ptr<Transaction> trx2 = transactionManager->beginWriteTransaction();
std::unique_ptr<Transaction> trx3 = transactionManager->beginReadOnlyTransaction();
std::unique_ptr<Transaction> trx1 =
transactionManager->beginReadOnlyTransaction(*getClientContext(*conn));
std::unique_ptr<Transaction> trx2 =
transactionManager->beginWriteTransaction(*getClientContext(*conn));
std::unique_ptr<Transaction> trx3 =
transactionManager->beginReadOnlyTransaction(*getClientContext(*conn));
ASSERT_EQ(TransactionType::READ_ONLY, trx1->getType());
ASSERT_EQ(TransactionType::WRITE, trx2->getType());
ASSERT_EQ(TransactionType::READ_ONLY, trx3->getType());
Expand All @@ -120,8 +125,10 @@ TEST_F(TransactionManagerTest, BasicOneWriteMultipleReadOnlyTransactions) {
ASSERT_EQ(
expectedReadOnlyTransactionSet, transactionManager->getActiveReadOnlyTransactionIDs());

std::unique_ptr<Transaction> trx4 = transactionManager->beginWriteTransaction();
std::unique_ptr<Transaction> trx5 = transactionManager->beginReadOnlyTransaction();
std::unique_ptr<Transaction> trx4 =
transactionManager->beginWriteTransaction(*getClientContext(*conn));
std::unique_ptr<Transaction> trx5 =
transactionManager->beginReadOnlyTransaction(*getClientContext(*conn));
ASSERT_EQ(trx3->getID() + 1, trx4->getID());
ASSERT_EQ(trx4->getID() + 1, trx5->getID());
ASSERT_EQ(trx4->getID(), transactionManager->getActiveWriteTransactionID());
Expand Down

0 comments on commit 8702527

Please sign in to comment.