Skip to content

Commit

Permalink
remove dummy transactions (#3106)
Browse files Browse the repository at this point in the history
remove dummy transactions
  • Loading branch information
hououou authored and ray6080 committed Mar 23, 2024
1 parent 12e3e18 commit 8d6b5bc
Show file tree
Hide file tree
Showing 27 changed files with 125 additions and 124 deletions.
2 changes: 1 addition & 1 deletion src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ std::vector<BoundInsertInfo> Binder::bindInsertInfos(
const QueryGraphCollection& queryGraphCollection, const expression_set& nodeRelScope_) {
auto nodeRelScope = nodeRelScope_;
std::vector<BoundInsertInfo> result;
auto analyzer = QueryGraphLabelAnalyzer(*clientContext->getCatalog());
auto analyzer = QueryGraphLabelAnalyzer(*clientContext->getCatalog(), *clientContext);
for (auto i = 0u; i < queryGraphCollection.getNumQueryGraphs(); ++i) {
auto queryGraph = queryGraphCollection.getQueryGraph(i);
// Ensure query graph does not violate declared schema.
Expand Down
2 changes: 1 addition & 1 deletion src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
KU_UNREACHABLE;
}
}
BoundStatementRewriter::rewrite(*boundStatement, *clientContext->getCatalog());
BoundStatementRewriter::rewrite(*boundStatement, *clientContext->getCatalog(), *clientContext);
return boundStatement;
}

Expand Down
6 changes: 3 additions & 3 deletions src/binder/bound_statement_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
namespace kuzu {
namespace binder {

void BoundStatementRewriter::rewrite(
BoundStatement& boundStatement, const catalog::Catalog& catalog) {
void BoundStatementRewriter::rewrite(BoundStatement& boundStatement,
const catalog::Catalog& catalog, const main::ClientContext& clientContext) {
auto withClauseProjectionRewriter = WithClauseProjectionRewriter();
withClauseProjectionRewriter.visitUnsafe(boundStatement);

auto matchClausePatternLabelRewriter = MatchClausePatternLabelRewriter(catalog);
auto matchClausePatternLabelRewriter = MatchClausePatternLabelRewriter(catalog, clientContext);
matchClausePatternLabelRewriter.visit(boundStatement);

auto defaultTypeSolver = DefaultTypeSolver();
Expand Down
7 changes: 3 additions & 4 deletions src/binder/query/query_graph_label_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "common/cast.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "transaction/transaction.h"

using namespace kuzu::common;
using namespace kuzu::catalog;
Expand Down Expand Up @@ -32,7 +31,7 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression&
std::unordered_set<std::string> candidateNamesSet;
auto isSrcConnect = *queryRel->getSrcNode() == node;
auto isDstConnect = *queryRel->getDstNode() == node;
auto tx = &DUMMY_READ_TRANSACTION;
auto tx = clientContext.getTx();
if (queryRel->getDirectionType() == RelDirectionType::BOTH) {
if (isSrcConnect || isDstConnect) {
for (auto relTableID : queryRel->getTableIDs()) {
Expand Down Expand Up @@ -109,7 +108,7 @@ void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) {
}
for (auto& relTableID : rel.getTableIDs()) {
auto relTableSchema = ku_dynamic_cast<CatalogEntry*, RelTableCatalogEntry*>(
catalog.getTableCatalogEntry(&DUMMY_READ_TRANSACTION, relTableID));
catalog.getTableCatalogEntry(clientContext.getTx(), relTableID));
auto srcTableID = relTableSchema->getSrcTableID();
auto dstTableID = relTableSchema->getDstTableID();
if (!boundTableIDSet.contains(srcTableID) || !boundTableIDSet.contains(dstTableID)) {
Expand All @@ -122,7 +121,7 @@ void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) {
auto dstTableIDSet = rel.getDstNode()->getTableIDsSet();
for (auto& relTableID : rel.getTableIDs()) {
auto relTableSchema = ku_dynamic_cast<CatalogEntry*, RelTableCatalogEntry*>(
catalog.getTableCatalogEntry(&DUMMY_READ_TRANSACTION, relTableID));
catalog.getTableCatalogEntry(clientContext.getTx(), relTableID));
auto srcTableID = relTableSchema->getSrcTableID();
auto dstTableID = relTableSchema->getDstTableID();
if (!srcTableIDSet.contains(srcTableID) || !dstTableIDSet.contains(dstTableID)) {
Expand Down
2 changes: 1 addition & 1 deletion src/function/table/call/storage_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ static void appendColumnChunkStorageInfo(node_group_idx_t nodeGroupIdx,

static void appendStorageInfoForColumn(StorageInfoLocalState* localState, std::string tableType,
const Column* column, DataChunk& outputChunk, ClientContext* context) {
auto numNodeGroups = column->getNumNodeGroups(&transaction::DUMMY_READ_TRANSACTION);
auto numNodeGroups = column->getNumNodeGroups(context->getTx());
for (auto nodeGroupIdx = 0u; nodeGroupIdx < numNodeGroups; nodeGroupIdx++) {
if (outputChunk.state->selVector->selectedSize == DEFAULT_VECTOR_CAPACITY) {
localState->dataChunkCollection->append(outputChunk);
Expand Down
4 changes: 3 additions & 1 deletion src/include/binder/bound_statement_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ namespace binder {
// Perform semantic rewrite over bound statement.
class BoundStatementRewriter {
public:
static void rewrite(BoundStatement& boundStatement, const catalog::Catalog& catalog);
// TODO(Jiamin): remove catalog
static void rewrite(BoundStatement& boundStatement, const catalog::Catalog& catalog,
const main::ClientContext& clientContext);
};

} // namespace binder
Expand Down
6 changes: 5 additions & 1 deletion src/include/binder/query/query_graph_label_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ namespace binder {

class QueryGraphLabelAnalyzer {
public:
explicit QueryGraphLabelAnalyzer(const catalog::Catalog& catalog) : catalog{catalog} {}
// TODO(Jiamin): remove catalog
explicit QueryGraphLabelAnalyzer(
const catalog::Catalog& catalog, const main::ClientContext& clientContext)
: catalog{catalog}, clientContext{clientContext} {}

void pruneLabel(const QueryGraph& graph);

Expand All @@ -18,6 +21,7 @@ class QueryGraphLabelAnalyzer {

private:
const catalog::Catalog& catalog;
const main::ClientContext& clientContext;
};

} // namespace binder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ namespace binder {

class MatchClausePatternLabelRewriter : public BoundStatementVisitor {
public:
explicit MatchClausePatternLabelRewriter(const catalog::Catalog& catalog) : analyzer{catalog} {}
// TODO(Jiamin): remove catalog
explicit MatchClausePatternLabelRewriter(
const catalog::Catalog& catalog, const main::ClientContext& clientContext)
: analyzer{catalog, clientContext} {}

void visitMatch(const BoundReadingClause& readingClause) final;

Expand Down
4 changes: 2 additions & 2 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ class ClientContext {

// Database component getters.
KUZU_API Database* getDatabase() const { return database; }
storage::StorageManager* getStorageManager();
storage::StorageManager* getStorageManager() const;
KUZU_API storage::MemoryManager* getMemoryManager();
catalog::Catalog* getCatalog();
catalog::Catalog* getCatalog() const;
common::VirtualFileSystem* getVFSUnsafe() const;
common::RandomEngine* getRandomEngine();

Expand Down
16 changes: 9 additions & 7 deletions src/include/planner/join_order/cardinality_estimator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class CardinalityEstimator {
DELETE_COPY_DEFAULT_MOVE(CardinalityEstimator);

// TODO(Xiyang): revisit this init at some point. Maybe we should init while enumerating.
void initNodeIDDom(const binder::QueryGraph& queryGraph);
void addNodeIDDom(
const binder::Expression& nodeID, const std::vector<common::table_id_t>& tableIDs);
void initNodeIDDom(const binder::QueryGraph& queryGraph, transaction::Transaction* transaction);
void addNodeIDDom(const binder::Expression& nodeID,
const std::vector<common::table_id_t>& tableIDs, transaction::Transaction* transaction);

uint64_t estimateScanNode(LogicalOperator* op);
uint64_t estimateHashJoin(const binder::expression_vector& joinKeys,
Expand All @@ -30,8 +30,8 @@ class CardinalityEstimator {
uint64_t estimateFlatten(const LogicalPlan& childPlan, f_group_pos groupPosToFlatten);
uint64_t estimateFilter(const LogicalPlan& childPlan, const binder::Expression& predicate);

double getExtensionRate(
const binder::RelExpression& rel, const binder::NodeExpression& boundNode);
double getExtensionRate(const binder::RelExpression& rel,
const binder::NodeExpression& boundNode, transaction::Transaction* transaction);

private:
// Clamps a count to a minimum of one: zero becomes 1, every other value is returned unchanged.
inline uint64_t atLeastOne(uint64_t x) {
    if (x != 0) {
        return x;
    }
    return 1;
}
Expand All @@ -40,9 +40,11 @@ class CardinalityEstimator {
KU_ASSERT(nodeIDName2dom.contains(nodeIDName));
return nodeIDName2dom.at(nodeIDName);
}
uint64_t getNumNodes(const std::vector<common::table_id_t>& tableIDs);
uint64_t getNumNodes(
const std::vector<common::table_id_t>& tableIDs, transaction::Transaction* transaction);

uint64_t getNumRels(const std::vector<common::table_id_t>& tableIDs);
uint64_t getNumRels(
const std::vector<common::table_id_t>& tableIDs, transaction::Transaction* transaction);

private:
const storage::NodesStoreStatsAndDeletedIDs* nodesStatistics;
Expand Down
5 changes: 4 additions & 1 deletion src/include/planner/planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ struct LogicalSetPropertyInfo;

class Planner {
public:
Planner(catalog::Catalog* catalog, storage::StorageManager* storageManager);
// TODO(Jiamin): Remove catalog and storageManager
Planner(catalog::Catalog* catalog, storage::StorageManager* storageManager,
main::ClientContext* clientContext);
DELETE_COPY_AND_MOVE(Planner);

std::unique_ptr<LogicalPlan> getBestPlan(const binder::BoundStatement& statement);
Expand Down Expand Up @@ -281,6 +283,7 @@ class Planner {

private:
catalog::Catalog* catalog;
main::ClientContext* clientContext;
storage::StorageManager* storageManager;
binder::expression_vector propertiesToScan;
CardinalityEstimator cardinalityEstimator;
Expand Down
9 changes: 2 additions & 7 deletions src/include/storage/stats/nodes_store_statistics.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "storage/stats/node_table_statistics.h"
#include "storage/stats/table_statistics_collection.h"
#include "storage/storage_utils.h"
#include "transaction/transaction.h"

namespace kuzu {
namespace storage {
Expand All @@ -27,8 +26,8 @@ class NodesStoreStatsAndDeletedIDs : public TablesStatistics {
}

inline NodeTableStatsAndDeletedIDs* getNodeStatisticsAndDeletedIDs(
common::table_id_t tableID) const {
return getNodeTableStats(transaction::TransactionType::READ_ONLY, tableID);
transaction::Transaction* transaction, common::table_id_t tableID) const {
return getNodeTableStats(transaction->getType(), tableID);
}

static inline void saveInitialNodesStatisticsAndDeletedIDsToFile(
Expand Down Expand Up @@ -71,10 +70,6 @@ class NodesStoreStatsAndDeletedIDs : public TablesStatistics {
getNodeTableStats(transaction::TransactionType::WRITE, tableID)->deleteNode(nodeOffset);
}

// This function is only used by storageManager to construct relsStore during start-up, so
// we can just safely return the maxNodeOffsetPerTable for readOnlyVersion.
std::map<common::table_id_t, common::offset_t> getMaxNodeOffsetPerTable() const;

void setDeletedNodeOffsetsForMorsel(transaction::Transaction* transaction,
const std::shared_ptr<common::ValueVector>& nodeOffsetVector, common::table_id_t tableID);

Expand Down
6 changes: 3 additions & 3 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,15 @@ std::string ClientContext::getExtensionDir() const {
return common::stringFormat("{}/.kuzu/extension", config.homeDirectory);
}

storage::StorageManager* ClientContext::getStorageManager() {
storage::StorageManager* ClientContext::getStorageManager() const {
return database->storageManager.get();
}

// Returns the raw pointer obtained from `database->memoryManager.get()`.
// NOTE(review): unlike getStorageManager() and getCatalog() nearby, this getter is not
// const-qualified; making it const would also require changing its declaration in the header.
storage::MemoryManager* ClientContext::getMemoryManager() {
return database->memoryManager.get();
}

catalog::Catalog* ClientContext::getCatalog() {
catalog::Catalog* ClientContext::getCatalog() const {
return database->catalog.get();
}

Expand Down Expand Up @@ -295,7 +295,7 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
preparedStatement->statementResult =
std::make_unique<BoundStatementResult>(boundStatement->getStatementResult()->copy());
// planning
auto planner = Planner(database->catalog.get(), database->storageManager.get());
auto planner = Planner(database->catalog.get(), database->storageManager.get(), this);
std::vector<std::unique_ptr<LogicalPlan>> plans;
if (enumerateAllPlans) {
plans = planner.getAllPlans(*boundStatement);
Expand Down
2 changes: 1 addition & 1 deletion src/main/storage_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ uint64_t StorageDriver::getNumNodes(const std::string& nodeName) {
auto nodeTableID = catalog->getTableID(&DUMMY_READ_TRANSACTION, nodeName);
auto nodeStatistics =
storageManager->getNodesStatisticsAndDeletedIDs()->getNodeStatisticsAndDeletedIDs(
nodeTableID);
&DUMMY_READ_TRANSACTION, nodeTableID);
return nodeStatistics->getNumTuples();
}

Expand Down
33 changes: 17 additions & 16 deletions src/planner/join_order/cardinality_estimator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,30 @@

using namespace kuzu::binder;
using namespace kuzu::common;
using namespace kuzu::transaction;

namespace kuzu {
namespace planner {

void CardinalityEstimator::initNodeIDDom(const QueryGraph& queryGraph) {
void CardinalityEstimator::initNodeIDDom(const QueryGraph& queryGraph, Transaction* transaction) {
for (auto i = 0u; i < queryGraph.getNumQueryNodes(); ++i) {
auto node = queryGraph.getQueryNode(i).get();
addNodeIDDom(*node->getInternalID(), node->getTableIDs());
addNodeIDDom(*node->getInternalID(), node->getTableIDs(), transaction);
}
for (auto i = 0u; i < queryGraph.getNumQueryRels(); ++i) {
auto rel = queryGraph.getQueryRel(i);
if (QueryRelTypeUtils::isRecursive(rel->getRelType())) {
auto node = rel->getRecursiveInfo()->node.get();
addNodeIDDom(*node->getInternalID(), node->getTableIDs());
addNodeIDDom(*node->getInternalID(), node->getTableIDs(), transaction);
}
}
}

void CardinalityEstimator::addNodeIDDom(
const binder::Expression& nodeID, const std::vector<common::table_id_t>& tableIDs) {
void CardinalityEstimator::addNodeIDDom(const binder::Expression& nodeID,
const std::vector<common::table_id_t>& tableIDs, Transaction* transaction) {
auto key = nodeID.getUniqueName();
if (!nodeIDName2dom.contains(key)) {
nodeIDName2dom.insert({key, getNumNodes(tableIDs)});
nodeIDName2dom.insert({key, getNumNodes(tableIDs, transaction)});
}
}

Expand Down Expand Up @@ -102,29 +103,29 @@ uint64_t CardinalityEstimator::estimateFilter(
}
}

// Adds up getNumTuples() over the node-table statistics of every ID in `tableIDs` and passes the
// total through atLeastOne() before returning it.
uint64_t CardinalityEstimator::getNumNodes(const std::vector<common::table_id_t>& tableIDs) {
uint64_t CardinalityEstimator::getNumNodes(
const std::vector<common::table_id_t>& tableIDs, Transaction* transaction) {
// NOTE(review): `0u` makes `auto` deduce unsigned int (typically 32-bit), which is narrower than
// the uint64_t return type. If getNumTuples() yields a 64-bit count (not visible here -- confirm)
// the running sum can wrap; consider declaring the accumulator as uint64_t.
auto numNodes = 0u;
for (auto& tableID : tableIDs) {
numNodes += nodesStatistics->getNodeStatisticsAndDeletedIDs(tableID)->getNumTuples();
numNodes +=
nodesStatistics->getNodeStatisticsAndDeletedIDs(transaction, tableID)->getNumTuples();
}
return atLeastOne(numNodes);
}

// Adds up getNumTuples() over the rel-table statistics of every ID in `tableIDs` and passes the
// total through atLeastOne() before returning it.
uint64_t CardinalityEstimator::getNumRels(const std::vector<common::table_id_t>& tableIDs) {
uint64_t CardinalityEstimator::getNumRels(
const std::vector<common::table_id_t>& tableIDs, Transaction* transaction) {
// NOTE(review): `0u` makes `auto` deduce unsigned int (typically 32-bit), which is narrower than
// the uint64_t return type. If getNumTuples() yields a 64-bit count (not visible here -- confirm)
// the running sum can wrap; consider declaring the accumulator as uint64_t.
auto numRels = 0u;
for (auto tableID : tableIDs) {
numRels +=
relsStatistics
->getRelStatistics(tableID, transaction::Transaction::getDummyReadOnlyTrx().get())
->getNumTuples();
numRels += relsStatistics->getRelStatistics(tableID, transaction)->getNumTuples();
}
return atLeastOne(numRels);
}

double CardinalityEstimator::getExtensionRate(
const RelExpression& rel, const NodeExpression& boundNode) {
auto numBoundNodes = (double)getNumNodes(boundNode.getTableIDs());
auto numRels = (double)getNumRels(rel.getTableIDs());
const RelExpression& rel, const NodeExpression& boundNode, Transaction* transaction) {
auto numBoundNodes = (double)getNumNodes(boundNode.getTableIDs(), transaction);
auto numRels = (double)getNumRels(rel.getTableIDs(), transaction);
auto oneHopExtensionRate = numRels / numBoundNodes;
switch (rel.getRelType()) {
case QueryRelType::NON_RECURSIVE: {
Expand Down
Loading

0 comments on commit 8d6b5bc

Please sign in to comment.