From c92c04622950d0a446d85e56643074991edb790a Mon Sep 17 00:00:00 2001 From: andyfeng Date: Wed, 29 Nov 2023 18:18:59 +0800 Subject: [PATCH] hide catalog content (#2502) --- src/binder/bind/bind_comment_on.cpp | 6 +- src/binder/bind/bind_copy.cpp | 18 +- src/binder/bind/bind_create_macro.cpp | 3 +- src/binder/bind/bind_ddl.cpp | 32 ++-- src/binder/bind/bind_graph_pattern.cpp | 49 +++-- src/binder/bind/bind_reading_clause.cpp | 6 +- src/binder/bind/bind_updating_clause.cpp | 5 +- src/binder/bind/copy/bind_copy_rdf_graph.cpp | 4 +- .../bind_function_expression.cpp | 26 +-- src/binder/binder.cpp | 10 +- .../query/query_graph_label_analyzer.cpp | 12 +- src/catalog/catalog.cpp | 151 +++++++++++---- src/catalog/catalog_content.cpp | 104 +++++------ .../table_functions/call_functions.cpp | 44 +++-- src/include/catalog/catalog.h | 49 +++-- src/include/catalog/catalog_content.h | 104 +++-------- src/include/function/table_functions.h | 6 +- .../function/table_functions/call_functions.h | 26 +-- src/include/main/client_context.h | 2 +- src/include/parser/statement.h | 4 - .../operator/ddl/add_node_property.h | 4 +- .../processor/operator/ddl/add_property.h | 2 +- .../processor/operator/ddl/add_rel_property.h | 6 +- .../operator/ddl/create_node_table.h | 2 +- .../processor/operator/ddl/create_rdf_graph.h | 8 +- .../processor/operator/ddl/create_rel_table.h | 4 +- .../operator/ddl/create_rel_table_group.h | 2 +- src/include/processor/operator/ddl/ddl.h | 2 +- .../processor/operator/ddl/drop_property.h | 2 +- .../processor/operator/ddl/drop_table.h | 11 +- .../processor/operator/ddl/rename_property.h | 4 +- .../processor/operator/ddl/rename_table.h | 6 +- .../reader/csv/parallel_csv_reader.h | 2 +- .../persistent/reader/csv/serial_csv_reader.h | 2 +- .../persistent/reader/npy/npy_reader.h | 2 +- .../reader/parquet/parquet_reader.h | 2 +- .../persistent/reader/rdf/rdf_reader.h | 2 +- src/include/processor/operator/scan_node_id.h | 2 +- .../stats/table_statistics_collection.h | 3 +- src/include/storage/storage_manager.h | 7 +- src/main/client_context.cpp | 2 +- src/main/connection.cpp | 56 ++++-- src/main/storage_driver.cpp | 12 +- src/planner/plan/append_extend.cpp | 8 +- src/processor/map/map_ddl.cpp | 7 +- src/processor/map/map_delete.cpp | 8 +- src/processor/map/map_extend.cpp | 5 +- src/processor/map/map_insert.cpp | 3 +- src/processor/map/map_recursive_extend.cpp | 3 +- src/processor/map/map_scan_node_property.cpp | 7 +- src/processor/map/map_set.cpp | 10 +- .../operator/ddl/add_node_property.cpp | 4 +- .../operator/ddl/add_rel_property.cpp | 7 +- .../operator/ddl/create_node_table.cpp | 6 +- .../operator/ddl/create_rdf_graph.cpp | 14 +- .../operator/ddl/create_rel_table.cpp | 4 +- .../operator/ddl/create_rel_table_group.cpp | 8 +- src/processor/operator/ddl/ddl.cpp | 4 +- src/processor/operator/ddl/drop_property.cpp | 9 +- src/processor/operator/ddl/drop_table.cpp | 6 +- src/processor/operator/index_lookup.cpp | 2 +- .../operator/persistent/copy_node.cpp | 8 +- .../operator/persistent/delete_executor.cpp | 18 +- .../reader/csv/parallel_csv_reader.cpp | 2 +- .../reader/csv/serial_csv_reader.cpp | 2 +- .../persistent/reader/npy/npy_reader.cpp | 2 +- .../reader/parquet/parquet_reader.cpp | 2 +- .../persistent/reader/rdf/rdf_reader.cpp | 2 +- .../operator/persistent/set_executor.cpp | 12 +- src/processor/operator/physical_operator.cpp | 2 +- .../operator/scan/scan_rel_csr_columns.cpp | 2 +- src/storage/storage_manager.cpp | 10 +- src/storage/wal_replayer.cpp | 22 ++- src/transaction/transaction_context.cpp | 4 +- test/runner/e2e_copy_transaction_test.cpp | 16 +- test/runner/e2e_ddl_test.cpp | 172 +++++++----------- test/test_files/ddl/ddl.test | 4 +- .../src_cpp/include/pandas/pandas_scan.h | 2 +- .../python_api/src_cpp/pandas/pandas_scan.cpp | 2 +- tools/python_api/src_cpp/py_connection.cpp | 9 +- tools/python_api/src_py/connection.py | 2 + tools/shell/embedded_shell.cpp | 5 +- 82 files changed, 622 insertions(+), 599 deletions(-) diff --git a/src/binder/bind/bind_comment_on.cpp b/src/binder/bind/bind_comment_on.cpp index 4b7e0ab9b3..74098ee6d9 100644 --- a/src/binder/bind/bind_comment_on.cpp +++ b/src/binder/bind/bind_comment_on.cpp @@ -1,5 +1,6 @@ #include "binder/binder.h" #include "binder/bound_comment_on.h" +#include "main/client_context.h" #include "parser/comment_on.h" namespace kuzu { @@ -9,11 +10,8 @@ std::unique_ptr Binder::bindCommentOn(const parser::Statement& s auto& commentOnStatement = reinterpret_cast(statement); auto tableName = commentOnStatement.getTable(); auto comment = commentOnStatement.getComment(); - validateTableExist(tableName); - auto catalogContent = catalog.getReadOnlyVersion(); - auto tableID = catalogContent->getTableID(tableName); - + auto tableID = catalog.getTableID(clientContext->getTx(), tableName); return std::make_unique(tableID, tableName, comment); } diff --git a/src/binder/bind/bind_copy.cpp b/src/binder/bind/bind_copy.cpp index 9f80076448..62101ba5d4 100644 --- a/src/binder/bind/bind_copy.cpp +++ b/src/binder/bind/bind_copy.cpp @@ -8,6 +8,7 @@ #include "common/exception/message.h" #include "common/string_format.h" #include "function/table_functions/bind_input.h" +#include "main/client_context.h" #include "parser/copy.h" using namespace kuzu::binder; @@ -62,12 +63,11 @@ static void validateCopyNpyNotForRelTables(TableSchema* schema) { std::unique_ptr Binder::bindCopyFromClause(const Statement& statement) { auto& copyStatement = reinterpret_cast(statement); - auto catalogContent = catalog.getReadOnlyVersion(); auto tableName = copyStatement.getTableName(); validateTableExist(tableName); // Bind to table schema. - auto tableID = catalogContent->getTableID(tableName); - auto tableSchema = catalogContent->getTableSchema(tableID); + auto tableID = catalog.getTableID(clientContext->getTx(), tableName); + auto tableSchema = catalog.getTableSchema(clientContext->getTx(), tableID); switch (tableSchema->tableType) { case TableType::REL_GROUP: case TableType::RDF: { @@ -120,7 +120,7 @@ std::unique_ptr Binder::bindCopyNodeFrom(const Statement& statem tableSchema, copyStatement.getColumnNames(), expectedColumnNames, expectedColumnTypes); auto bindInput = std::make_unique( memoryManager, *config, std::move(expectedColumnNames), std::move(expectedColumnTypes)); - auto bindData = func->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + auto bindData = func->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog); expression_vector columns; for (auto i = 0u; i < bindData->columnTypes.size(); i++) { columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); @@ -147,7 +147,7 @@ std::unique_ptr Binder::bindCopyRelFrom(const parser::Statement& tableSchema, copyStatement.getColumnNames(), expectedColumnNames, expectedColumnTypes); auto bindInput = std::make_unique(memoryManager, std::move(*config), std::move(expectedColumnNames), std::move(expectedColumnTypes)); - auto bindData = func->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + auto bindData = func->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog); expression_vector columns; for (auto i = 0u; i < bindData->columnTypes.size(); i++) { columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); @@ -158,9 +158,9 @@ std::unique_ptr Binder::bindCopyRelFrom(const parser::Statement& std::make_unique(func, std::move(bindData), columns, offset); auto relTableSchema = reinterpret_cast(tableSchema); auto srcTableSchema = - catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID()); + catalog.getTableSchema(clientContext->getTx(), relTableSchema->getSrcTableID()); auto dstTableSchema = - catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID()); + catalog.getTableSchema(clientContext->getTx(), relTableSchema->getDstTableID()); auto srcKey = columns[0]; auto dstKey = columns[1]; auto srcNodeID = createVariable(std::string(InternalKeyword::SRC_OFFSET), LogicalTypeID::INT64); @@ -228,9 +228,9 @@ void Binder::bindExpectedRelColumns(TableSchema* tableSchema, KU_ASSERT(columnNames.empty() && columnTypes.empty()); auto relTableSchema = reinterpret_cast(tableSchema); auto srcTable = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID())); + catalog.getTableSchema(clientContext->getTx(), relTableSchema->getSrcTableID())); auto dstTable = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID())); + catalog.getTableSchema(clientContext->getTx(), relTableSchema->getDstTableID())); columnNames.push_back("from"); columnNames.push_back("to"); auto srcPKColumnType = srcTable->getPrimaryKey()->getDataType()->copy(); diff --git a/src/binder/bind/bind_create_macro.cpp b/src/binder/bind/bind_create_macro.cpp index 4347655f26..21dcc59b24 100644 --- a/src/binder/bind/bind_create_macro.cpp +++ b/src/binder/bind/bind_create_macro.cpp @@ -3,6 +3,7 @@ #include "common/exception/binder.h" #include "common/string_format.h" #include "common/string_utils.h" +#include "main/client_context.h" #include "parser/create_macro.h" using namespace kuzu::common; @@ -15,7 +16,7 @@ std::unique_ptr Binder::bindCreateMacro(const Statement& stateme auto& createMacro = reinterpret_cast(statement); auto macroName = createMacro.getMacroName(); StringUtils::toUpper(macroName); - if (catalog.getReadOnlyVersion()->containMacro(macroName)) { + if (catalog.containsMacro(clientContext->getTx(), macroName)) { throw BinderException{stringFormat("Macro {} already exists.", macroName)}; } parser::default_macro_args defaultArgs; diff --git a/src/binder/bind/bind_ddl.cpp b/src/binder/bind/bind_ddl.cpp index 6e705f512b..4b73bfb305 100644 --- a/src/binder/bind/bind_ddl.cpp +++ b/src/binder/bind/bind_ddl.cpp @@ -8,6 +8,7 @@ #include "common/exception/binder.h" #include "common/string_format.h" #include "common/string_utils.h" +#include "main/client_context.h" #include "parser/ddl/alter.h" #include "parser/ddl/create_table.h" #include "parser/ddl/drop.h" @@ -155,7 +156,7 @@ std::unique_ptr Binder::bindCreateRelTableGroupInfo( std::unique_ptr Binder::bindCreateTable(const parser::Statement& statement) { auto& createTable = reinterpret_cast(statement); auto tableName = createTable.getTableName(); - if (catalog.getReadOnlyVersion()->containsTable(tableName)) { + if (catalog.containsTable(clientContext->getTx(), tableName)) { throw BinderException(tableName + " already exists in catalog."); } auto boundCreateInfo = bindCreateTableInfo(createTable.getInfo()); @@ -166,12 +167,11 @@ std::unique_ptr Binder::bindDropTable(const Statement& statement auto& dropTable = reinterpret_cast(statement); auto tableName = dropTable.getTableName(); validateTableExist(tableName); - auto catalogContent = catalog.getReadOnlyVersion(); - auto tableID = catalogContent->getTableID(tableName); - auto tableSchema = catalogContent->getTableSchema(tableID); + auto tableID = catalog.getTableID(clientContext->getTx(), tableName); + auto tableSchema = catalog.getTableSchema(clientContext->getTx(), tableID); switch (tableSchema->tableType) { case TableType::NODE: { - for (auto& schema : catalogContent->getRelTableSchemas()) { + for (auto& schema : catalog.getRelTableSchemas(clientContext->getTx())) { auto relTableSchema = reinterpret_cast(schema); if (relTableSchema->isSrcOrDstTable(tableID)) { throw BinderException( @@ -181,7 +181,7 @@ std::unique_ptr Binder::bindDropTable(const Statement& statement } } break; case TableType::REL: { - for (auto& schema : catalogContent->getRelTableGroupSchemas()) { + for (auto& schema : catalog.getRelTableGroupSchemas(clientContext->getTx())) { auto relTableGroupSchema = reinterpret_cast(schema); for (auto& relTableID : relTableGroupSchema->getRelTableIDs()) { if (relTableID == tableSchema->getTableID()) { @@ -226,12 +226,11 @@ std::unique_ptr Binder::bindRenameTable(const Statement& stateme auto extraInfo = reinterpret_cast(info->extraInfo.get()); auto tableName = info->tableName; auto newName = extraInfo->newName; - auto catalogContent = catalog.getReadOnlyVersion(); validateTableExist(tableName); - if (catalogContent->containsTable(newName)) { + if (catalog.containsTable(clientContext->getTx(), newName)) { throw BinderException("Table: " + newName + " already exists."); } - auto tableID = catalogContent->getTableID(tableName); + auto tableID = catalog.getTableID(clientContext->getTx(), tableName); auto boundExtraInfo = std::make_unique(newName); auto boundInfo = std::make_unique( AlterType::RENAME_TABLE, tableName, tableID, std::move(boundExtraInfo)); @@ -273,9 +272,8 @@ std::unique_ptr Binder::bindAddProperty(const Statement& stateme auto dataType = bindDataType(extraInfo->dataType); auto propertyName = extraInfo->propertyName; validateTableExist(tableName); - auto catalogContent = catalog.getReadOnlyVersion(); - auto tableID = catalogContent->getTableID(tableName); - auto tableSchema = catalogContent->getTableSchema(tableID); + auto tableID = catalog.getTableID(clientContext->getTx(), tableName); + auto tableSchema = catalog.getTableSchema(clientContext->getTx(), tableID); validatePropertyDDLOnTable(tableSchema, "add"); validatePropertyNotExist(tableSchema, propertyName); if (dataType->getLogicalTypeID() == LogicalTypeID::SERIAL) { @@ -297,9 +295,8 @@ std::unique_ptr Binder::bindDropProperty(const Statement& statem auto tableName = info->tableName; auto propertyName = extraInfo->propertyName; validateTableExist(tableName); - auto catalogContent = catalog.getReadOnlyVersion(); - auto tableID = catalogContent->getTableID(tableName); - auto tableSchema = catalogContent->getTableSchema(tableID); + auto tableID = catalog.getTableID(clientContext->getTx(), tableName); + auto tableSchema = catalog.getTableSchema(clientContext->getTx(), tableID); validatePropertyDDLOnTable(tableSchema, "drop"); validatePropertyExist(tableSchema, propertyName); auto propertyID = tableSchema->getPropertyID(propertyName); @@ -321,9 +318,8 @@ std::unique_ptr Binder::bindRenameProperty(const Statement& stat auto propertyName = extraInfo->propertyName; auto newName = extraInfo->newName; validateTableExist(tableName); - auto catalogContent = catalog.getReadOnlyVersion(); - auto tableID = catalogContent->getTableID(tableName); - auto tableSchema = catalogContent->getTableSchema(tableID); + auto tableID = catalog.getTableID(clientContext->getTx(), tableName); + auto tableSchema = catalog.getTableSchema(clientContext->getTx(), tableID); validatePropertyDDLOnTable(tableSchema, "rename"); validatePropertyExist(tableSchema, propertyName); auto propertyID = tableSchema->getPropertyID(propertyName); diff --git a/src/binder/bind/bind_graph_pattern.cpp b/src/binder/bind/bind_graph_pattern.cpp index 4e27d2b8cf..871e4f339f 100644 --- a/src/binder/bind/bind_graph_pattern.cpp +++ b/src/binder/bind/bind_graph_pattern.cpp @@ -262,18 +262,19 @@ std::shared_ptr Binder::createNonRecursiveQueryRel(const std::str } auto extraInfo = std::make_unique(fieldNames, fieldTypes); RelType::setExtraTypeInfo(queryRel->getDataTypeReference(), std::move(extraInfo)); - auto readVersion = catalog.getReadOnlyVersion(); - if (readVersion->getTableSchema(tableIDs[0])->getTableType() == TableType::RDF) { + auto tableSchema = catalog.getTableSchema(clientContext->getTx(), tableIDs[0]); + if (tableSchema->getTableType() == TableType::RDF) { auto predicateID = expressionBinder.bindNodeOrRelPropertyExpression(*queryRel, std::string(rdf::PID)); std::vector resourceTableIDs; std::vector resourceTableSchemas; for (auto& tableID : tableIDs) { - auto rdfGraphSchema = - reinterpret_cast(readVersion->getTableSchema(tableID)); + auto rdfGraphSchema = reinterpret_cast( + catalog.getTableSchema(clientContext->getTx(), tableID)); auto resourceTableID = rdfGraphSchema->getResourceTableID(); resourceTableIDs.push_back(resourceTableID); - resourceTableSchemas.push_back(readVersion->getTableSchema(resourceTableID)); + resourceTableSchemas.push_back( + catalog.getTableSchema(clientContext->getTx(), resourceTableID)); } auto predicateIRI = createPropertyExpression(std::string(rdf::IRI), queryRel->getUniqueName(), queryRel->getVariableName(), resourceTableSchemas); @@ -305,7 +306,7 @@ std::shared_ptr Binder::createRecursiveQueryRel(const parser::Rel std::unordered_set nodeTableIDs; for (auto relTableID : relTableIDs) { auto relTableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(relTableID)); + catalog.getTableSchema(clientContext->getTx(), relTableID)); nodeTableIDs.insert(relTableSchema->getSrcTableID()); nodeTableIDs.insert(relTableSchema->getDstTableID()); } @@ -463,7 +464,7 @@ std::pair Binder::bindVariableLengthRelBound( void Binder::bindQueryRelProperties(RelExpression& rel) { std::vector tableSchemas; for (auto tableID : rel.getTableIDs()) { - tableSchemas.push_back(catalog.getReadOnlyVersion()->getTableSchema(tableID)); + tableSchemas.push_back(catalog.getTableSchema(clientContext->getTx(), tableID)); } auto propertyNames = getPropertyNames(tableSchemas); for (auto& propertyName : propertyNames) { @@ -539,7 +540,7 @@ std::shared_ptr Binder::createQueryNode( } void Binder::bindQueryNodeProperties(NodeExpression& node) { - auto tableSchemas = catalog.getReadOnlyVersion()->getTableSchemas(node.getTableIDs()); + auto tableSchemas = catalog.getTableSchemas(clientContext->getTx(), node.getTableIDs()); auto propertyNames = getPropertyNames(tableSchemas); for (auto& propertyName : propertyNames) { auto property = createPropertyExpression( @@ -550,26 +551,26 @@ void Binder::bindQueryNodeProperties(NodeExpression& node) { std::vector Binder::bindTableIDs( const std::vector& tableNames, bool nodePattern) { - auto catalogContent = catalog.getReadOnlyVersion(); + auto tx = clientContext->getTx(); std::unordered_set tableIDSet; if (tableNames.empty()) { // Rewrite empty table names as all tables. - if (catalogContent->containsRdfGraph()) { + if (catalog.containsRdfGraph(tx)) { // If catalog contains rdf graph then it should NOT have any property graph table. - for (auto tableID : catalogContent->getRdfGraphIDs()) { + for (auto tableID : catalog.getRdfGraphIDs(tx)) { tableIDSet.insert(tableID); } } else if (nodePattern) { - if (!catalogContent->containsNodeTable()) { + if (!catalog.containsNodeTable(tx)) { throw BinderException("No node table exists in database."); } - for (auto tableID : catalogContent->getNodeTableIDs()) { + for (auto tableID : catalog.getNodeTableIDs(tx)) { tableIDSet.insert(tableID); } } else { // rel - if (!catalogContent->containsRelTable()) { + if (!catalog.containsRelTable(tx)) { throw BinderException("No rel table exists in database."); } - for (auto tableID : catalogContent->getRelTableIDs()) { + for (auto tableID : catalog.getRelTableIDs(tx)) { tableIDSet.insert(tableID); } } @@ -584,14 +585,13 @@ std::vector Binder::bindTableIDs( } std::vector Binder::getNodeTableIDs(const std::vector& tableIDs) { - auto readVersion = catalog.getReadOnlyVersion(); - auto tableType = readVersion->getTableSchema(tableIDs[0])->getTableType(); + auto tableType = catalog.getTableSchema(clientContext->getTx(), tableIDs[0])->getTableType(); switch (tableType) { case TableType::RDF: { // extract node table ID from rdf graph schema. std::vector result; for (auto& tableID : tableIDs) { - auto rdfGraphSchema = - reinterpret_cast(readVersion->getTableSchema(tableID)); + auto rdfGraphSchema = reinterpret_cast( + catalog.getTableSchema(clientContext->getTx(), tableID)); result.push_back(rdfGraphSchema->getResourceTableID()); result.push_back(rdfGraphSchema->getLiteralTableID()); } @@ -606,14 +606,13 @@ std::vector Binder::getNodeTableIDs(const std::vector& t } std::vector Binder::getRelTableIDs(const std::vector& tableIDs) { - auto readVersion = catalog.getReadOnlyVersion(); - auto tableType = readVersion->getTableSchema(tableIDs[0])->getTableType(); + auto tableType = catalog.getTableSchema(clientContext->getTx(), tableIDs[0])->getTableType(); switch (tableType) { case TableType::RDF: { // extract rel table ID from rdf graph schema. std::vector result; for (auto& tableID : tableIDs) { - auto rdfGraphSchema = - reinterpret_cast(readVersion->getTableSchema(tableID)); + auto rdfGraphSchema = reinterpret_cast( + catalog.getTableSchema(clientContext->getTx(), tableID)); result.push_back(rdfGraphSchema->getResourceTripleTableID()); result.push_back(rdfGraphSchema->getLiteralTripleTableID()); } @@ -622,8 +621,8 @@ std::vector Binder::getRelTableIDs(const std::vector& ta case TableType::REL_GROUP: { // extract rel table ID from rel group schema. std::vector result; for (auto& tableID : tableIDs) { - auto relGroupSchema = - reinterpret_cast(readVersion->getTableSchema(tableID)); + auto relGroupSchema = reinterpret_cast( + catalog.getTableSchema(clientContext->getTx(), tableID)); for (auto& relTableID : relGroupSchema->getRelTableIDs()) { result.push_back(relTableID); } diff --git a/src/binder/bind/bind_reading_clause.cpp b/src/binder/bind/bind_reading_clause.cpp index e95e8356b8..dce431b541 100644 --- a/src/binder/bind/bind_reading_clause.cpp +++ b/src/binder/bind/bind_reading_clause.cpp @@ -127,8 +127,7 @@ std::unique_ptr Binder::bindInQueryCall(const ReadingClause& catalog.getBuiltInFunctions()->matchScalarFunction(funcName, inputTypes)); auto bindInput = std::make_unique(); bindInput->inputs = std::move(inputValues); - auto bindData = - tableFunction->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + auto bindData = tableFunction->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog); expression_vector columns; for (auto i = 0u; i < bindData->columnTypes.size(); i++) { columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); @@ -175,8 +174,7 @@ std::unique_ptr Binder::bindLoadFrom( getScanFunction(readerConfig->fileType, readerConfig->csvReaderConfig->parallel); auto bindInput = std::make_unique(memoryManager, *readerConfig, std::move(expectedColumnNames), std::move(expectedColumnTypes)); - auto bindData = - scanFunction->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + auto bindData = scanFunction->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog); expression_vector columns; for (auto i = 0u; i < bindData->columnTypes.size(); i++) { columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); diff --git a/src/binder/bind/bind_updating_clause.cpp b/src/binder/bind/bind_updating_clause.cpp index 0305bcded1..2cf706137e 100644 --- a/src/binder/bind/bind_updating_clause.cpp +++ b/src/binder/bind/bind_updating_clause.cpp @@ -9,6 +9,7 @@ #include "common/assert.h" #include "common/exception/binder.h" #include "common/string_format.h" +#include "main/client_context.h" #include "parser/query/updating_clause/delete_clause.h" #include "parser/query/updating_clause/insert_clause.h" #include "parser/query/updating_clause/merge_clause.h" @@ -149,7 +150,7 @@ std::unique_ptr Binder::bindInsertNodeInfo( "Create node " + node->toString() + " with multiple node labels is not supported."); } auto tableID = node->getSingleTableID(); - auto tableSchema = catalog.getReadOnlyVersion()->getTableSchema(tableID); + auto tableSchema = catalog.getTableSchema(clientContext->getTx(), tableID); validatePrimaryKeyExistence(collection, tableSchema, node); auto setItems = bindSetItems(collection, tableSchema, node); return std::make_unique( @@ -168,7 +169,7 @@ std::unique_ptr Binder::bindInsertRelInfo( common::stringFormat("Cannot create recursive rel {}.", rel->toString())); } auto relTableID = rel->getSingleTableID(); - auto tableSchema = catalog.getReadOnlyVersion()->getTableSchema(relTableID); + auto tableSchema = catalog.getTableSchema(clientContext->getTx(), relTableID); auto setItems = bindSetItems(collection, tableSchema, rel); return std::make_unique( UpdateTableType::REL, std::move(rel), std::move(setItems)); diff --git a/src/binder/bind/copy/bind_copy_rdf_graph.cpp b/src/binder/bind/copy/bind_copy_rdf_graph.cpp index 3a20633526..1faa8a20d2 100644 --- a/src/binder/bind/copy/bind_copy_rdf_graph.cpp +++ b/src/binder/bind/copy/bind_copy_rdf_graph.cpp @@ -36,7 +36,7 @@ std::unique_ptr Binder::bindCopyRdfNodeFrom(const Statement& /*s } auto bindInput = std::make_unique( memoryManager, *config, columnNames, std::move(columnTypes)); - auto bindData = func->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + auto bindData = func->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog); expression_vector columns; for (auto i = 0u; i < bindData->columnTypes.size(); i++) { columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); @@ -75,7 +75,7 @@ std::unique_ptr Binder::bindCopyRdfRelFrom(const Statement& /*st } auto bindInput = std::make_unique( memoryManager, *config, columnNames, std::move(columnTypes)); - auto bindData = func->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + auto bindData = func->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog); expression_vector columns; for (auto i = 0u; i < bindData->columnTypes.size(); i++) { columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index c8f21724c6..0157d6b95c 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -6,6 +6,7 @@ #include "common/exception/binder.h" #include "common/string_utils.h" #include "function/schema/vector_label_functions.h" +#include "main/client_context.h" #include "parser/expression/parsed_function_expression.h" #include "parser/parsed_expression_visitor.h" @@ -199,8 +200,8 @@ std::shared_ptr ExpressionBinder::bindInternalIDExpression( STRUCT_EXTRACT_FUNC_NAME); } -static std::vector> populateLabelValues( - std::vector tableIDs, const catalog::CatalogContent& catalogContent) { +static std::vector> populateLabelValues(std::vector tableIDs, + const catalog::Catalog& catalog, transaction::Transaction* tx) { auto tableIDsSet = std::unordered_set(tableIDs.begin(), tableIDs.end()); table_id_t maxTableID = *std::max_element(tableIDsSet.begin(), tableIDsSet.end()); std::vector> labels; @@ -208,7 +209,7 @@ static std::vector> populateLabelValues( for (auto i = 0; i < labels.size(); ++i) { if (tableIDsSet.contains(i)) { labels[i] = std::make_unique( - LogicalType{LogicalTypeID::STRING}, catalogContent.getTableName(i)); + LogicalType{LogicalTypeID::STRING}, catalog.getTableName(tx, i)); } else { // TODO(Xiyang/Guodong): change to null literal once we support null in LIST type. labels[i] = @@ -219,7 +220,6 @@ static std::vector> populateLabelValues( } std::shared_ptr ExpressionBinder::bindLabelFunction(const Expression& expression) { - auto catalogContent = binder->catalog.getReadOnlyVersion(); auto varListTypeInfo = std::make_unique(LogicalType::STRING()); auto listType = std::make_unique(LogicalTypeID::VAR_LIST, std::move(varListTypeInfo)); @@ -228,27 +228,29 @@ std::shared_ptr ExpressionBinder::bindLabelFunction(const Expression case LogicalTypeID::NODE: { auto& node = (NodeExpression&)expression; if (!node.isMultiLabeled()) { - auto labelName = catalogContent->getTableName(node.getSingleTableID()); + auto labelName = binder->catalog.getTableName( + binder->clientContext->getTx(), node.getSingleTableID()); return createLiteralExpression( std::make_unique(LogicalType{LogicalTypeID::STRING}, labelName)); } - auto nodeTableIDs = catalogContent->getNodeTableIDs(); + auto nodeTableIDs = binder->catalog.getNodeTableIDs(binder->clientContext->getTx()); children.push_back(node.getInternalID()); - auto labelsValue = - std::make_unique(*listType, populateLabelValues(nodeTableIDs, *catalogContent)); + auto labelsValue = std::make_unique(*listType, + populateLabelValues(nodeTableIDs, binder->catalog, binder->clientContext->getTx())); children.push_back(createLiteralExpression(std::move(labelsValue))); } break; case LogicalTypeID::REL: { auto& rel = (RelExpression&)expression; if (!rel.isMultiLabeled()) { - auto labelName = catalogContent->getTableName(rel.getSingleTableID()); + auto labelName = binder->catalog.getTableName( + binder->clientContext->getTx(), rel.getSingleTableID()); return createLiteralExpression( std::make_unique(LogicalType{LogicalTypeID::STRING}, labelName)); } - auto relTableIDs = catalogContent->getRelTableIDs(); + auto relTableIDs = binder->catalog.getRelTableIDs(binder->clientContext->getTx()); children.push_back(rel.getInternalIDProperty()); - auto labelsValue = - std::make_unique(*listType, populateLabelValues(relTableIDs, *catalogContent)); + auto labelsValue = std::make_unique(*listType, + populateLabelValues(relTableIDs, binder->catalog, binder->clientContext->getTx())); children.push_back(createLiteralExpression(std::move(labelsValue))); } break; default: diff --git a/src/binder/binder.cpp b/src/binder/binder.cpp index ecbae17c29..57eb4a51d8 100644 --- a/src/binder/binder.cpp +++ b/src/binder/binder.cpp @@ -4,6 +4,7 @@ #include "common/exception/binder.h" #include "common/string_format.h" #include "function/table_functions.h" +#include "main/client_context.h" using namespace kuzu::common; using namespace kuzu::parser; @@ -63,11 +64,10 @@ std::shared_ptr Binder::bindWhereExpression(const ParsedExpression& } common::table_id_t Binder::bindTableID(const std::string& tableName) const { - auto catalogContent = catalog.getReadOnlyVersion(); - if (!catalogContent->containsTable(tableName)) { + if (!catalog.containsTable(clientContext->getTx(), tableName)) { throw BinderException(common::stringFormat("Table {} does not exist.", tableName)); } - return catalogContent->getTableID(tableName); + return catalog.getTableID(clientContext->getTx(), tableName); } std::shared_ptr Binder::createVariable( @@ -186,13 +186,13 @@ void Binder::validateReadNotFollowUpdate(const NormalizedSingleQuery& singleQuer } void Binder::validateTableType(table_id_t tableID, TableType expectedTableType) { - if (catalog.getReadOnlyVersion()->getTableSchema(tableID)->tableType != expectedTableType) { + if (catalog.getTableSchema(clientContext->getTx(), tableID)->tableType != expectedTableType) { throw BinderException("aa"); } } void Binder::validateTableExist(const std::string& tableName) { - if (!catalog.getReadOnlyVersion()->containsTable(tableName)) { + if (!catalog.containsTable(clientContext->getTx(), tableName)) { throw BinderException("Table " + tableName + " does not exist."); } } diff --git a/src/binder/query/query_graph_label_analyzer.cpp b/src/binder/query/query_graph_label_analyzer.cpp index e661d41541..03e14279f5 100644 --- a/src/binder/query/query_graph_label_analyzer.cpp +++ b/src/binder/query/query_graph_label_analyzer.cpp @@ -3,9 +3,11 @@ #include "catalog/rel_table_schema.h" #include "common/exception/binder.h" #include "common/string_format.h" +#include "transaction/transaction.h" using namespace kuzu::common; using namespace kuzu::catalog; +using namespace kuzu::transaction; namespace kuzu { namespace binder { @@ -32,7 +34,7 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression& if (isSrcConnect || isDstConnect) { for (auto relTableID : queryRel->getTableIDs()) { auto relTableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(relTableID)); + catalog.getTableSchema(&DUMMY_READ_TRANSACTION, relTableID)); candidates.insert(relTableSchema->getSrcTableID()); candidates.insert(relTableSchema->getDstTableID()); } @@ -41,13 +43,13 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression& if (isSrcConnect) { for (auto relTableID : queryRel->getTableIDs()) { auto relTableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(relTableID)); + catalog.getTableSchema(&DUMMY_READ_TRANSACTION, relTableID)); candidates.insert(relTableSchema->getSrcTableID()); } } else if (isDstConnect) { for (auto relTableID : queryRel->getTableIDs()) { auto relTableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(relTableID)); + catalog.getTableSchema(&DUMMY_READ_TRANSACTION, relTableID)); candidates.insert(relTableSchema->getDstTableID()); } } @@ -86,7 +88,7 @@ void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) { } for (auto& relTableID : rel.getTableIDs()) { auto relTableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(relTableID)); + catalog.getTableSchema(&DUMMY_READ_TRANSACTION, relTableID)); auto srcTableID = relTableSchema->getSrcTableID(); auto dstTableID = relTableSchema->getDstTableID(); if (!boundTableIDSet.contains(srcTableID) || !boundTableIDSet.contains(dstTableID)) { @@ -99,7 +101,7 @@ void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) { auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); for (auto& relTableID : rel.getTableIDs()) { auto relTableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(relTableID)); + catalog.getTableSchema(&DUMMY_READ_TRANSACTION, relTableID)); auto srcTableID = relTableSchema->getSrcTableID(); auto dstTableID = relTableSchema->getDstTableID(); if (!srcTableIDSet.contains(srcTableID) || !dstTableIDSet.contains(dstTableID)) { diff --git a/src/catalog/catalog.cpp b/src/catalog/catalog.cpp index 6f28b13f12..9f6d46c3a1 100644 --- a/src/catalog/catalog.cpp +++ b/src/catalog/catalog.cpp @@ -2,6 +2,7 @@ #include "catalog/rel_table_group_schema.h" #include "storage/wal/wal.h" +#include "transaction/transaction.h" #include "transaction/transaction_action.h" using namespace kuzu::common; @@ -12,118 +13,196 @@ namespace kuzu { namespace catalog { Catalog::Catalog() : wal{nullptr} { - catalogContentForReadOnlyTrx = std::make_unique(); + readOnlyVersion = std::make_unique(); } Catalog::Catalog(WAL* wal) : wal{wal} { - catalogContentForReadOnlyTrx = std::make_unique(wal->getDirectory()); + readOnlyVersion = std::make_unique(wal->getDirectory()); +} + +uint64_t Catalog::getTableCount(Transaction* tx) const { + return getVersion(tx)->getTableCount(); +} + +bool Catalog::containsNodeTable(Transaction* tx) const { + return getVersion(tx)->containsTable(TableType::NODE); +} + +bool Catalog::containsRelTable(Transaction* tx) const { + return getVersion(tx)->containsTable(TableType::REL); +} + +bool Catalog::containsRdfGraph(Transaction* tx) const { + return getVersion(tx)->containsTable(TableType::RDF); +} + +bool Catalog::containsTable(Transaction* tx, const std::string& tableName) const { + return getVersion(tx)->containsTable(tableName); +} + +table_id_t Catalog::getTableID(Transaction* tx, const std::string& tableName) const { + return getVersion(tx)->getTableID(tableName); +} + +std::vector Catalog::getNodeTableIDs(Transaction* tx) const { + return getVersion(tx)->getTableIDs(TableType::NODE); +} + +std::vector Catalog::getRelTableIDs(Transaction* tx) const { + return getVersion(tx)->getTableIDs(TableType::REL); +} + +std::vector Catalog::getRdfGraphIDs(Transaction* tx) const { + return getVersion(tx)->getTableIDs(TableType::RDF); +} + +std::string Catalog::getTableName(Transaction* tx, table_id_t tableID) const { + return getVersion(tx)->getTableName(tableID); +} + +TableSchema* Catalog::getTableSchema(Transaction* tx, table_id_t tableID) const { + return getVersion(tx)->getTableSchema(tableID); +} + +std::vector Catalog::getNodeTableSchemas(Transaction* tx) const { + return getVersion(tx)->getTableSchemas(TableType::NODE); +} + +std::vector Catalog::getRelTableSchemas(Transaction* tx) const { + return getVersion(tx)->getTableSchemas(TableType::REL); +} + +std::vector Catalog::getRelTableGroupSchemas(Transaction* tx) const { + return getVersion(tx)->getTableSchemas(TableType::REL_GROUP); +} + +std::vector Catalog::getTableSchemas(Transaction* tx) const { + std::vector result; + for (auto& [_, schema] : getVersion(tx)->tableSchemas) { + result.push_back(schema.get()); + } + return result; +} + +std::vector Catalog::getTableSchemas( + Transaction* tx, const table_id_vector_t& tableIDs) const { + std::vector result; + for (auto tableID : tableIDs) { + result.push_back(getVersion(tx)->getTableSchema(tableID)); + } + return result; } void Catalog::prepareCommitOrRollback(TransactionAction action) { if (hasUpdates()) { wal->logCatalogRecord(); if (action == TransactionAction::COMMIT) { - catalogContentForWriteTrx->saveToFile( - wal->getDirectory(), FileVersionType::WAL_VERSION); + readWriteVersion->saveToFile(wal->getDirectory(), FileVersionType::WAL_VERSION); } } } void Catalog::checkpointInMemory() { if (hasUpdates()) { - catalogContentForReadOnlyTrx = std::move(catalogContentForWriteTrx); + readOnlyVersion = std::move(readWriteVersion); } } ExpressionType Catalog::getFunctionType(const std::string& name) const { - return catalogContentForReadOnlyTrx->getFunctionType(name); + return readOnlyVersion->getFunctionType(name); } table_id_t Catalog::addNodeTableSchema(const binder::BoundCreateTableInfo& info) { - initCatalogContentForWriteTrxIfNecessary(); - return catalogContentForWriteTrx->addNodeTableSchema(info); + KU_ASSERT(readWriteVersion != nullptr); + return readWriteVersion->addNodeTableSchema(info); } table_id_t Catalog::addRelTableSchema(const binder::BoundCreateTableInfo& info) { - initCatalogContentForWriteTrxIfNecessary(); - auto tableID = catalogContentForWriteTrx->addRelTableSchema(info); - return tableID; + KU_ASSERT(readWriteVersion != nullptr); + return readWriteVersion->addRelTableSchema(info); } common::table_id_t Catalog::addRelTableGroupSchema(const binder::BoundCreateTableInfo& info) { - initCatalogContentForWriteTrxIfNecessary(); - auto tableID = catalogContentForWriteTrx->addRelTableGroupSchema(info); + KU_ASSERT(readWriteVersion != nullptr); + auto tableID = readWriteVersion->addRelTableGroupSchema(info); return tableID; } table_id_t Catalog::addRdfGraphSchema(const binder::BoundCreateTableInfo& info) { - initCatalogContentForWriteTrxIfNecessary(); - return catalogContentForWriteTrx->addRdfGraphSchema(info); + KU_ASSERT(readWriteVersion != nullptr); + return readWriteVersion->addRdfGraphSchema(info); } void Catalog::dropTableSchema(table_id_t tableID) { - initCatalogContentForWriteTrxIfNecessary(); - auto tableSchema = catalogContentForWriteTrx->getTableSchema(tableID); + KU_ASSERT(readWriteVersion != nullptr); + auto tableSchema = readWriteVersion->getTableSchema(tableID); switch (tableSchema->tableType) { case TableType::REL_GROUP: { auto relTableGroupSchema = reinterpret_cast(tableSchema); auto relTableIDs = relTableGroupSchema->getRelTableIDs(); - catalogContentForWriteTrx->dropTableSchema(tableID); + readWriteVersion->dropTableSchema(tableID); for (auto relTableID : relTableIDs) { wal->logDropTableRecord(relTableID); } } break; default: { - catalogContentForWriteTrx->dropTableSchema(tableID); + readWriteVersion->dropTableSchema(tableID); wal->logDropTableRecord(tableID); } } } void Catalog::renameTable(table_id_t tableID, const std::string& newName) { - initCatalogContentForWriteTrxIfNecessary(); - catalogContentForWriteTrx->renameTable(tableID, newName); + KU_ASSERT(readWriteVersion != nullptr); + readWriteVersion->renameTable(tableID, newName); } void Catalog::addNodeProperty( table_id_t tableID, const std::string& propertyName, std::unique_ptr dataType) { - initCatalogContentForWriteTrxIfNecessary(); - catalogContentForWriteTrx->getTableSchema(tableID)->addNodeProperty( - propertyName, std::move(dataType)); + KU_ASSERT(readWriteVersion != nullptr); + readWriteVersion->getTableSchema(tableID)->addNodeProperty(propertyName, std::move(dataType)); } void Catalog::addRelProperty( table_id_t tableID, const std::string& propertyName, std::unique_ptr dataType) { - initCatalogContentForWriteTrxIfNecessary(); - catalogContentForWriteTrx->getTableSchema(tableID)->addRelProperty( - propertyName, std::move(dataType)); + KU_ASSERT(readWriteVersion != nullptr); + readWriteVersion->getTableSchema(tableID)->addRelProperty(propertyName, std::move(dataType)); } void Catalog::dropProperty(table_id_t tableID, property_id_t propertyID) { - initCatalogContentForWriteTrxIfNecessary(); - catalogContentForWriteTrx->getTableSchema(tableID)->dropProperty(propertyID); + KU_ASSERT(readWriteVersion != nullptr); + readWriteVersion->getTableSchema(tableID)->dropProperty(propertyID); wal->logDropPropertyRecord(tableID, propertyID); } void Catalog::renameProperty( table_id_t tableID, property_id_t propertyID, const std::string& newName) { - initCatalogContentForWriteTrxIfNecessary(); - catalogContentForWriteTrx->getTableSchema(tableID)->renameProperty(propertyID, newName); + KU_ASSERT(readWriteVersion != nullptr); + readWriteVersion->getTableSchema(tableID)->renameProperty(propertyID, newName); } void Catalog::addFunction(std::string name, function::function_set functionSet) { - catalogContentForReadOnlyTrx->addFunction(std::move(name), std::move(functionSet)); + readOnlyVersion->addFunction(std::move(name), std::move(functionSet)); +} + +bool Catalog::containsMacro(Transaction* tx, const std::string& macroName) const { + return getVersion(tx)->containMacro(macroName); } void Catalog::addScalarMacroFunction( std::string name, std::unique_ptr macro) { - initCatalogContentForWriteTrxIfNecessary(); - catalogContentForWriteTrx->addScalarMacroFunction(std::move(name), std::move(macro)); + KU_ASSERT(readWriteVersion != nullptr); + readWriteVersion->addScalarMacroFunction(std::move(name), std::move(macro)); } void Catalog::setTableComment(table_id_t tableID, const std::string& comment) { - initCatalogContentForWriteTrxIfNecessary(); - catalogContentForWriteTrx->getTableSchema(tableID)->setComment(comment); + KU_ASSERT(readWriteVersion != nullptr); + readWriteVersion->getTableSchema(tableID)->setComment(comment); +} + +CatalogContent* Catalog::getVersion(Transaction* tx) const { + return tx->getType() == TransactionType::READ_ONLY ? readOnlyVersion.get() : + readWriteVersion.get(); } } // namespace catalog diff --git a/src/catalog/catalog_content.cpp b/src/catalog/catalog_content.cpp index 2f6867d08d..4aa0941a4e 100644 --- a/src/catalog/catalog_content.cpp +++ b/src/catalog/catalog_content.cpp @@ -12,6 +12,7 @@ #include "common/serializer/serializer.h" #include "common/string_format.h" #include "common/string_utils.h" +#include "storage/storage_info.h" #include "storage/storage_utils.h" using namespace kuzu::binder; @@ -126,25 +127,6 @@ table_id_t CatalogContent::addRdfGraphSchema(const BoundCreateTableInfo& info) { return rdfGraphID; } -std::vector CatalogContent::getTableSchemas() const { - std::vector result; - result.reserve(tableSchemas.size()); - for (auto&& [_, schema] : tableSchemas) { - result.push_back(schema.get()); - } - return result; -} - -std::vector CatalogContent::getTableSchemas( - const std::vector& tableIDs) const { - std::vector result; - for (auto tableID : tableIDs) { - KU_ASSERT(tableSchemas.contains(tableID)); - result.push_back(tableSchemas.at(tableID).get()); - } - return result; -} - void CatalogContent::dropTableSchema(table_id_t tableID) { auto tableSchema = getTableSchema(tableID); switch (tableSchema->tableType) { @@ -168,6 +150,37 @@ void CatalogContent::renameTable(table_id_t tableID, const std::string& newName) tableSchema->updateTableName(newName); } +static void validateStorageVersion(storage_version_t savedStorageVersion) { + auto storageVersion = StorageVersionInfo::getStorageVersion(); + if (savedStorageVersion != storageVersion) { + // LCOV_EXCL_START + throw RuntimeException( + stringFormat("Trying to read a database file with a different version. " + "Database file version: {}, Current build storage version: {}", + savedStorageVersion, storageVersion)); + // LCOV_EXCL_STOP + } +} + +static void validateMagicBytes(Deserializer& deserializer) { + auto numMagicBytes = strlen(StorageVersionInfo::MAGIC_BYTES); + uint8_t magicBytes[4]; + for (auto i = 0u; i < numMagicBytes; i++) { + deserializer.deserializeValue(magicBytes[i]); + } + if (memcmp(magicBytes, StorageVersionInfo::MAGIC_BYTES, numMagicBytes) != 0) { + throw RuntimeException( + "This is not a valid Kuzu database directory for the current version of Kuzu."); + } +} + +static void writeMagicBytes(Serializer& serializer) { + auto numMagicBytes = strlen(StorageVersionInfo::MAGIC_BYTES); + for (auto i = 0u; i < numMagicBytes; i++) { + serializer.serializeValue(StorageVersionInfo::MAGIC_BYTES[i]); + } +} + void CatalogContent::saveToFile(const std::string& directory, FileVersionType dbFileType) { auto catalogPath = StorageUtils::getCatalogFilePath(directory, dbFileType); Serializer serializer( @@ -249,47 +262,12 @@ std::unique_ptr CatalogContent::copy() const { nextTableID, builtInFunctions->copy(), std::move(macrosToCopy)); } -void CatalogContent::validateStorageVersion(storage_version_t savedStorageVersion) { - auto storageVersion = StorageVersionInfo::getStorageVersion(); - if (savedStorageVersion != storageVersion) { - // LCOV_EXCL_START - throw RuntimeException( - stringFormat("Trying to read a database file with a different version. " - "Database file version: {}, Current build storage version: {}", - savedStorageVersion, storageVersion)); - // LCOV_EXCL_STOP - } -} - -void CatalogContent::validateMagicBytes(Deserializer& deserializer) { - auto numMagicBytes = strlen(StorageVersionInfo::MAGIC_BYTES); - uint8_t magicBytes[4]; - for (auto i = 0u; i < numMagicBytes; i++) { - deserializer.deserializeValue(magicBytes[i]); - } - if (memcmp(magicBytes, StorageVersionInfo::MAGIC_BYTES, numMagicBytes) != 0) { - throw RuntimeException( - "This is not a valid Kuzu database directory for the current version of Kuzu."); - } -} - -void CatalogContent::writeMagicBytes(Serializer& serializer) { - auto numMagicBytes = strlen(StorageVersionInfo::MAGIC_BYTES); - for (auto i = 0u; i < numMagicBytes; i++) { - serializer.serializeValue(StorageVersionInfo::MAGIC_BYTES[i]); - } -} - void CatalogContent::registerBuiltInFunctions() { builtInFunctions = std::make_unique(); } -bool CatalogContent::containsTable(const std::string& tableName, TableType tableType) const { - if (!tableNameToIDMap.contains(tableName)) { - return false; - } - auto tableID = getTableID(tableName); - return tableSchemas.at(tableID)->tableType == tableType; +bool CatalogContent::containsTable(const std::string& tableName) const { + return tableNameToIDMap.contains(tableName); } bool CatalogContent::containsTable(common::TableType tableType) const { @@ -301,6 +279,15 @@ bool CatalogContent::containsTable(common::TableType tableType) const { return false; } +std::string CatalogContent::getTableName(table_id_t tableID) const { + return getTableSchema(tableID)->tableName; +} + +TableSchema* CatalogContent::getTableSchema(table_id_t tableID) const { + KU_ASSERT(tableSchemas.contains(tableID)); + return tableSchemas.at(tableID).get(); +} + std::vector CatalogContent::getTableSchemas(TableType tableType) const { std::vector result; for (auto& [id, schema] : tableSchemas) { @@ -311,6 +298,11 @@ std::vector CatalogContent::getTableSchemas(TableType tableType) c return result; } +table_id_t CatalogContent::getTableID(const std::string& tableName) const { + KU_ASSERT(tableNameToIDMap.contains(tableName)); + return tableNameToIDMap.at(tableName); +} + std::vector CatalogContent::getTableIDs(TableType tableType) const { std::vector tableIDs; for (auto& [id, schema] : tableSchemas) { diff --git a/src/function/table_functions/call_functions.cpp b/src/function/table_functions/call_functions.cpp index abf113ffcf..aa53971231 100644 --- a/src/function/table_functions/call_functions.cpp +++ b/src/function/table_functions/call_functions.cpp @@ -1,17 +1,21 @@ #include "function/table_functions/call_functions.h" +#include "catalog/catalog.h" #include "catalog/node_table_schema.h" #include "catalog/rel_table_group_schema.h" #include "catalog/rel_table_schema.h" #include "common/exception/binder.h" +#include "main/client_context.h" +#include "transaction/transaction.h" -namespace kuzu { -namespace function { - +using namespace kuzu::transaction; using namespace kuzu::common; using namespace kuzu::catalog; using namespace kuzu::main; +namespace kuzu { +namespace function { + std::unique_ptr initLocalState( TableFunctionInitInput& /*input*/, TableFuncSharedState* /*state*/) { return std::make_unique(); @@ -56,7 +60,7 @@ void CurrentSettingFunction::tableFunc(TableFunctionInput& data, DataChunk& outp } std::unique_ptr CurrentSettingFunction::bindFunc( - ClientContext* context, TableFuncBindInput* input, CatalogContent* /*catalog*/) { + ClientContext* context, TableFuncBindInput* input, Catalog* /*catalog*/) { auto optionName = input->inputs[0]->getValue(); std::vector returnColumnNames; std::vector> returnTypes; @@ -87,7 +91,7 @@ void DBVersionFunction::tableFunc(TableFunctionInput& input, DataChunk& outputCh } std::unique_ptr DBVersionFunction::bindFunc( - ClientContext* /*context*/, TableFuncBindInput* /*input*/, CatalogContent* /*catalog*/) { + ClientContext* /*context*/, TableFuncBindInput* /*input*/, Catalog* /*catalog*/) { std::vector returnColumnNames; std::vector> returnTypes; returnColumnNames.emplace_back("version"); @@ -123,7 +127,7 @@ void ShowTablesFunction::tableFunc(TableFunctionInput& input, DataChunk& outputC } std::unique_ptr ShowTablesFunction::bindFunc( - ClientContext* /*context*/, TableFuncBindInput* /*input*/, CatalogContent* catalog) { + ClientContext* context, TableFuncBindInput* /*input*/, Catalog* catalog) { std::vector returnColumnNames; std::vector> returnTypes; returnColumnNames.emplace_back("name"); @@ -132,8 +136,9 @@ std::unique_ptr ShowTablesFunction::bindFunc( returnTypes.emplace_back(LogicalType::STRING()); returnColumnNames.emplace_back("comment"); returnTypes.emplace_back(LogicalType::STRING()); - return std::make_unique(catalog->getTableSchemas(), std::move(returnTypes), - std::move(returnColumnNames), catalog->getTableCount()); + return std::make_unique(catalog->getTableSchemas(context->getTx()), + std::move(returnTypes), std::move(returnColumnNames), + catalog->getTableCount(context->getTx())); } function_set TableInfoFunction::getFunctionSet() { @@ -174,12 +179,12 @@ void TableInfoFunction::tableFunc(TableFunctionInput& input, DataChunk& outputCh } std::unique_ptr TableInfoFunction::bindFunc( - ClientContext* /*context*/, TableFuncBindInput* input, CatalogContent* catalog) { + ClientContext* context, TableFuncBindInput* input, Catalog* catalog) { std::vector returnColumnNames; std::vector> returnTypes; auto tableName = input->inputs[0]->getValue(); - auto tableID = catalog->getTableID(tableName); - auto schema = catalog->getTableSchema(tableID); + auto tableID = catalog->getTableID(context->getTx(), tableName); + auto schema = catalog->getTableSchema(context->getTx(), tableID); returnColumnNames.emplace_back("property id"); returnTypes.push_back(LogicalType::INT64()); returnColumnNames.emplace_back("name"); @@ -202,14 +207,15 @@ function_set ShowConnectionFunction::getFunctionSet() { } void ShowConnectionFunction::outputRelTableConnection(ValueVector* srcTableNameVector, - ValueVector* dstTableNameVector, uint64_t outputPos, CatalogContent* catalog, - table_id_t tableID) { - auto tableSchema = catalog->getTableSchema(tableID); + ValueVector* dstTableNameVector, uint64_t outputPos, Catalog* catalog, table_id_t tableID) { + auto tableSchema = catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID); KU_ASSERT(tableSchema->tableType == TableType::REL); auto srcTableID = reinterpret_cast(tableSchema)->getSrcTableID(); auto dstTableID = reinterpret_cast(tableSchema)->getDstTableID(); - srcTableNameVector->setValue(outputPos, catalog->getTableName(srcTableID)); - dstTableNameVector->setValue(outputPos, catalog->getTableName(dstTableID)); + srcTableNameVector->setValue( + outputPos, catalog->getTableName(&DUMMY_READ_TRANSACTION, srcTableID)); + dstTableNameVector->setValue( + outputPos, catalog->getTableName(&DUMMY_READ_TRANSACTION, dstTableID)); } void ShowConnectionFunction::tableFunc(TableFunctionInput& input, DataChunk& outputChunk) { @@ -246,12 +252,12 @@ void ShowConnectionFunction::tableFunc(TableFunctionInput& input, DataChunk& out } std::unique_ptr ShowConnectionFunction::bindFunc( - ClientContext* /*context*/, TableFuncBindInput* input, CatalogContent* catalog) { + ClientContext* context, TableFuncBindInput* input, Catalog* catalog) { std::vector returnColumnNames; std::vector> returnTypes; auto tableName = input->inputs[0]->getValue(); - auto tableID = catalog->getTableID(tableName); - auto schema = catalog->getTableSchema(tableID); + auto tableID = catalog->getTableID(context->getTx(), tableName); + auto schema = catalog->getTableSchema(context->getTx(), tableID); auto tableType = schema->getTableType(); if (tableType != TableType::REL && tableType != TableType::REL_GROUP) { throw BinderException{"Show connection can only be called on a rel table!"}; diff --git a/src/include/catalog/catalog.h b/src/include/catalog/catalog.h index f937a49c84..2c19479df4 100644 --- a/src/include/catalog/catalog.h +++ b/src/include/catalog/catalog.h @@ -10,7 +10,8 @@ class WAL; } namespace transaction { enum class TransactionAction : uint8_t; -} +class Transaction; +} // namespace transaction namespace catalog { class Catalog { @@ -19,19 +20,39 @@ class Catalog { explicit Catalog(storage::WAL* wal); - // TODO(Guodong): Get rid of these two functions. - inline CatalogContent* getReadOnlyVersion() const { return catalogContentForReadOnlyTrx.get(); } - inline CatalogContent* getWriteVersion() const { return catalogContentForWriteTrx.get(); } + // TODO(Guodong): Get rid of the following. + inline CatalogContent* getReadOnlyVersion() const { return readOnlyVersion.get(); } + + uint64_t getTableCount(transaction::Transaction* tx) const; + + bool containsNodeTable(transaction::Transaction* tx) const; + bool containsRelTable(transaction::Transaction* tx) const; + bool containsRdfGraph(transaction::Transaction* tx) const; + bool containsTable(transaction::Transaction* tx, const std::string& tableName) const; + + common::table_id_t getTableID(transaction::Transaction* tx, const std::string& tableName) const; + std::vector getNodeTableIDs(transaction::Transaction* tx) const; + std::vector getRelTableIDs(transaction::Transaction* tx) const; + std::vector getRdfGraphIDs(transaction::Transaction* tx) const; + + std::string getTableName(transaction::Transaction* tx, common::table_id_t tableID) const; + TableSchema* getTableSchema(transaction::Transaction* tx, common::table_id_t tableID) const; + std::vector getNodeTableSchemas(transaction::Transaction* tx) const; + std::vector getRelTableSchemas(transaction::Transaction* tx) const; + std::vector getRelTableGroupSchemas(transaction::Transaction* tx) const; + std::vector getTableSchemas(transaction::Transaction* tx) const; + std::vector getTableSchemas( + transaction::Transaction* tx, const common::table_id_vector_t& tableIDs) const; inline function::BuiltInFunctions* getBuiltInFunctions() const { - return catalogContentForReadOnlyTrx->builtInFunctions.get(); + return readOnlyVersion->builtInFunctions.get(); } void prepareCommitOrRollback(transaction::TransactionAction action); void checkpointInMemory(); inline void initCatalogContentForWriteTrxIfNecessary() { - if (!catalogContentForWriteTrx) { - catalogContentForWriteTrx = catalogContentForReadOnlyTrx->copy(); + if (!readWriteVersion) { + readWriteVersion = readOnlyVersion->copy(); } } @@ -46,9 +67,7 @@ class Catalog { common::table_id_t addRelTableSchema(const binder::BoundCreateTableInfo& info); common::table_id_t addRelTableGroupSchema(const binder::BoundCreateTableInfo& info); common::table_id_t addRdfGraphSchema(const binder::BoundCreateTableInfo& info); - void dropTableSchema(common::table_id_t tableID); - void renameTable(common::table_id_t tableID, const std::string& newName); void addNodeProperty(common::table_id_t tableID, const std::string& propertyName, @@ -62,7 +81,7 @@ class Catalog { common::table_id_t tableID, common::property_id_t propertyID, const std::string& newName); void addFunction(std::string name, function::function_set functionSet); - + bool containsMacro(transaction::Transaction* tx, const std::string& macroName) const; void addScalarMacroFunction( std::string name, std::unique_ptr macro); @@ -70,15 +89,17 @@ class Catalog { // TODO(Ziyi): pass transaction pointer here. inline function::ScalarMacroFunction* getScalarMacroFunction(const std::string& name) const { - return catalogContentForReadOnlyTrx->macros.at(name).get(); + return readOnlyVersion->macros.at(name).get(); } private: - inline bool hasUpdates() { return catalogContentForWriteTrx != nullptr; } + inline CatalogContent* getVersion(transaction::Transaction* tx) const; + + inline bool hasUpdates() { return readWriteVersion != nullptr; } protected: - std::unique_ptr catalogContentForReadOnlyTrx; - std::unique_ptr catalogContentForWriteTrx; + std::unique_ptr readOnlyVersion; + std::unique_ptr readWriteVersion; storage::WAL* wal; }; diff --git a/src/include/catalog/catalog_content.h b/src/include/catalog/catalog_content.h index 18c588e02b..d3db394401 100644 --- a/src/include/catalog/catalog_content.h +++ b/src/include/catalog/catalog_content.h @@ -3,7 +3,6 @@ #include "binder/ddl/bound_create_table_info.h" #include "function/built_in_function.h" #include "function/scalar_macro_function.h" -#include "storage/storage_info.h" #include "table_schema.h" namespace kuzu { @@ -31,105 +30,46 @@ class CatalogContent { nextTableID{nextTableID}, builtInFunctions{std::move(builtInFunctions)}, macros{std::move( macros)} {} - /* - * Single schema lookup. - * */ - inline bool containsTable(const std::string& tableName) const { - return tableNameToIDMap.contains(tableName); - } - inline bool containsNodeTable(const std::string& tableName) const { - return containsTable(tableName, common::TableType::NODE); - } - inline bool containsRelTable(const std::string& tableName) const { - return containsTable(tableName, common::TableType::REL); - } - inline std::string getTableName(common::table_id_t tableID) const { - KU_ASSERT(tableSchemas.contains(tableID)); - return getTableSchema(tableID)->tableName; - } - inline TableSchema* getTableSchema(common::table_id_t tableID) const { - KU_ASSERT(tableSchemas.contains(tableID)); - return tableSchemas.at(tableID).get(); - } - inline common::table_id_t getTableID(const std::string& tableName) const { - KU_ASSERT(tableNameToIDMap.contains(tableName)); - return tableNameToIDMap.at(tableName); - } + void saveToFile(const std::string& directory, common::FileVersionType dbFileType); + void readFromFile(const std::string& directory, common::FileVersionType dbFileType); - /* - * Batch schema lookup. - * */ - inline uint64_t getTableCount() const { return tableSchemas.size(); } - inline bool containsNodeTable() const { return containsTable(common::TableType::NODE); } - inline bool containsRelTable() const { return containsTable(common::TableType::REL); } - inline bool containsRdfGraph() const { return containsTable(common::TableType::RDF); } - inline std::vector getNodeTableIDs() const { - return getTableIDs(common::TableType::NODE); - } - inline std::vector getRelTableIDs() const { - return getTableIDs(common::TableType::REL); - } - inline std::vector getRdfGraphIDs() const { - return getTableIDs(common::TableType::RDF); - } - inline std::vector getNodeTableSchemas() const { - return getTableSchemas(common::TableType::NODE); - } - inline std::vector getRelTableSchemas() const { - return getTableSchemas(common::TableType::REL); - } - inline std::vector getRelTableGroupSchemas() const { - return getTableSchemas(common::TableType::REL_GROUP); - } + std::unique_ptr copy() const; - std::vector getTableSchemas() const; - std::vector getTableSchemas( - const std::vector& tableIDs) const; +private: + // ----------------------------- Functions ---------------------------- + common::ExpressionType getFunctionType(const std::string& name) const; - /** - * Add schema. - */ - common::table_id_t addNodeTableSchema(const binder::BoundCreateTableInfo& info); - common::table_id_t addRelTableSchema(const binder::BoundCreateTableInfo& info); - common::table_id_t addRelTableGroupSchema(const binder::BoundCreateTableInfo& info); - common::table_id_t addRdfGraphSchema(const binder::BoundCreateTableInfo& info); + void registerBuiltInFunctions(); inline bool containMacro(const std::string& macroName) const { return macros.contains(macroName); } - - void dropTableSchema(common::table_id_t tableID); - - void renameTable(common::table_id_t tableID, const std::string& newName); - - void saveToFile(const std::string& directory, common::FileVersionType dbFileType); - void readFromFile(const std::string& directory, common::FileVersionType dbFileType); - - common::ExpressionType getFunctionType(const std::string& name) const; - void addFunction(std::string name, function::function_set definitions); - void addScalarMacroFunction( std::string name, std::unique_ptr macro); - std::unique_ptr copy() const; - -private: + // ----------------------------- Table Schemas ---------------------------- inline common::table_id_t assignNextTableID() { return nextTableID++; } + inline uint64_t getTableCount() const { return tableSchemas.size(); } - static void validateStorageVersion(storage::storage_version_t savedStorageVersion); - - static void validateMagicBytes(common::Deserializer& deserializer); - - static void writeMagicBytes(common::Serializer& serializer); + bool containsTable(const std::string& tableName) const; + bool containsTable(common::TableType tableType) const; - void registerBuiltInFunctions(); + std::string getTableName(common::table_id_t tableID) const; - bool containsTable(const std::string& tableName, common::TableType tableType) const; - bool containsTable(common::TableType tableType) const; + TableSchema* getTableSchema(common::table_id_t tableID) const; std::vector getTableSchemas(common::TableType tableType) const; + + common::table_id_t getTableID(const std::string& tableName) const; std::vector getTableIDs(common::TableType tableType) const; + common::table_id_t addNodeTableSchema(const binder::BoundCreateTableInfo& info); + common::table_id_t addRelTableSchema(const binder::BoundCreateTableInfo& info); + common::table_id_t addRelTableGroupSchema(const binder::BoundCreateTableInfo& info); + common::table_id_t addRdfGraphSchema(const binder::BoundCreateTableInfo& info); + void dropTableSchema(common::table_id_t tableID); + void renameTable(common::table_id_t tableID, const std::string& newName); + private: // TODO(Guodong): I don't think it's necessary to keep separate maps for node and rel tables. std::unordered_map> tableSchemas; diff --git a/src/include/function/table_functions.h b/src/include/function/table_functions.h index 2c7bf2f1da..32fbf51274 100644 --- a/src/include/function/table_functions.h +++ b/src/include/function/table_functions.h @@ -5,7 +5,7 @@ namespace kuzu { namespace catalog { -class CatalogContent; +class Catalog; } // namespace catalog namespace common { class ValueVector; @@ -46,8 +46,8 @@ struct TableFunctionInitInput { virtual ~TableFunctionInitInput() = default; }; -typedef std::unique_ptr (*table_func_bind_t)(main::ClientContext* /*context*/, - TableFuncBindInput* /*input*/, catalog::CatalogContent* /*catalog*/); +typedef std::unique_ptr (*table_func_bind_t)( + main::ClientContext* /*context*/, TableFuncBindInput* /*input*/, catalog::Catalog* /*catalog*/); typedef void (*table_func_t)(TableFunctionInput& data, common::DataChunk& output); typedef std::unique_ptr (*table_func_init_shared_t)( TableFunctionInitInput& input); diff --git a/src/include/function/table_functions/call_functions.h b/src/include/function/table_functions/call_functions.h index 9bfef0fadd..b61a8a611f 100644 --- a/src/include/function/table_functions/call_functions.h +++ b/src/include/function/table_functions/call_functions.h @@ -72,8 +72,8 @@ struct CurrentSettingFunction : public CallFunction { static void tableFunc(TableFunctionInput& data, common::DataChunk& outputChunk); - static std::unique_ptr bindFunc(main::ClientContext* context, - TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/); + static std::unique_ptr bindFunc( + main::ClientContext* context, TableFuncBindInput* input, catalog::Catalog* /*catalog*/); }; struct DBVersionFunction : public CallFunction { @@ -82,7 +82,7 @@ struct DBVersionFunction : public CallFunction { static void tableFunc(TableFunctionInput& input, common::DataChunk& outputChunk); static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - TableFuncBindInput* /*input*/, catalog::CatalogContent* /*catalog*/); + TableFuncBindInput* /*input*/, catalog::Catalog* /*catalog*/); }; struct ShowTablesBindData : public CallTableFuncBindData { @@ -105,8 +105,8 @@ struct ShowTablesFunction : public CallFunction { static void tableFunc(TableFunctionInput& input, common::DataChunk& outputChunk); - static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - TableFuncBindInput* /*input*/, catalog::CatalogContent* catalog); + static std::unique_ptr bindFunc( + main::ClientContext* /*context*/, TableFuncBindInput* /*input*/, catalog::Catalog* catalog); }; struct TableInfoBindData : public CallTableFuncBindData { @@ -129,14 +129,14 @@ struct TableInfoFunction : public CallFunction { static void tableFunc(TableFunctionInput& input, common::DataChunk& outputChunk); - static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - TableFuncBindInput* input, catalog::CatalogContent* catalog); + static std::unique_ptr bindFunc( + main::ClientContext* /*context*/, TableFuncBindInput* input, catalog::Catalog* catalog); }; struct ShowConnectionBindData : public TableInfoBindData { - catalog::CatalogContent* catalog; + catalog::Catalog* catalog; - ShowConnectionBindData(catalog::CatalogContent* catalog, catalog::TableSchema* tableSchema, + ShowConnectionBindData(catalog::Catalog* catalog, catalog::TableSchema* tableSchema, std::vector> returnTypes, std::vector returnColumnNames, common::offset_t maxOffset) : catalog{catalog}, TableInfoBindData{tableSchema, std::move(returnTypes), @@ -152,13 +152,13 @@ struct ShowConnectionFunction : public CallFunction { static function_set getFunctionSet(); static void outputRelTableConnection(common::ValueVector* srcTableNameVector, - common::ValueVector* dstTableNameVector, uint64_t outputPos, - catalog::CatalogContent* catalog, common::table_id_t tableID); + common::ValueVector* dstTableNameVector, uint64_t outputPos, catalog::Catalog* catalog, + common::table_id_t tableID); static void tableFunc(TableFunctionInput& input, common::DataChunk& outputChunk); - static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - TableFuncBindInput* input, catalog::CatalogContent* catalog); + static std::unique_ptr bindFunc( + main::ClientContext* /*context*/, TableFuncBindInput* input, catalog::Catalog* catalog); }; } // namespace function diff --git a/src/include/main/client_context.h b/src/include/main/client_context.h index ab40f859e3..230e4a6d76 100644 --- a/src/include/main/client_context.h +++ b/src/include/main/client_context.h @@ -61,7 +61,7 @@ class ClientContext { std::string getCurrentSetting(const std::string& optionName); - transaction::Transaction* getActiveTransaction() const; + transaction::Transaction* getTx() const; transaction::TransactionContext* getTransactionContext() const; inline void setReplaceFunc(replace_func_t replaceFunc) { diff --git a/src/include/parser/statement.h b/src/include/parser/statement.h index b090cac78c..3cb56749f5 100644 --- a/src/include/parser/statement.h +++ b/src/include/parser/statement.h @@ -13,12 +13,8 @@ class Statement { inline common::StatementType getStatementType() const { return statementType; } - inline void enableProfile() { profile = true; } - inline bool isProfile() const { return profile; } - private: common::StatementType statementType; - bool profile = false; }; } // namespace parser diff --git a/src/include/processor/operator/ddl/add_node_property.h b/src/include/processor/operator/ddl/add_node_property.h index c290be0571..1f93a2d3cf 100644 --- a/src/include/processor/operator/ddl/add_node_property.h +++ b/src/include/processor/operator/ddl/add_node_property.h @@ -5,7 +5,7 @@ namespace kuzu { namespace processor { -class AddNodeProperty : public AddProperty { +class AddNodeProperty final : public AddProperty { public: AddNodeProperty(catalog::Catalog* catalog, common::table_id_t tableID, std::string propertyName, std::unique_ptr dataType, @@ -15,7 +15,7 @@ class AddNodeProperty : public AddProperty { : AddProperty{catalog, tableID, std::move(propertyName), std::move(dataType), std::move(defaultValueEvaluator), storageManager, outputPos, id, paramsString} {} - void executeDDLInternal() final; + void executeDDLInternal(ExecutionContext* context) final; std::unique_ptr clone() override { return make_unique(catalog, tableID, propertyName, dataType->copy(), diff --git a/src/include/processor/operator/ddl/add_property.h b/src/include/processor/operator/ddl/add_property.h index d65e2c7549..f7d3ba06d6 100644 --- a/src/include/processor/operator/ddl/add_property.h +++ b/src/include/processor/operator/ddl/add_property.h @@ -23,7 +23,7 @@ class AddProperty : public DDL { defaultValueEvaluator->init(*resultSet, context->memoryManager); } - void executeDDLInternal() override = 0; + void executeDDLInternal(ExecutionContext* context) override = 0; std::string getOutputMsg() override { return {"Add Succeed."}; } diff --git a/src/include/processor/operator/ddl/add_rel_property.h b/src/include/processor/operator/ddl/add_rel_property.h index 6f881d5f7c..fd9c4e0641 100644 --- a/src/include/processor/operator/ddl/add_rel_property.h +++ b/src/include/processor/operator/ddl/add_rel_property.h @@ -5,9 +5,7 @@ namespace kuzu { namespace processor { -class AddRelProperty; - -class AddRelProperty : public AddProperty { +class AddRelProperty final : public AddProperty { public: AddRelProperty(catalog::Catalog* catalog, common::table_id_t tableID, std::string propertyName, std::unique_ptr dataType, @@ -17,7 +15,7 @@ class AddRelProperty : public AddProperty { : AddProperty(catalog, tableID, std::move(propertyName), std::move(dataType), std::move(expressionEvaluator), storageManager, outputPos, id, paramsString) {} - void executeDDLInternal() override; + void executeDDLInternal(ExecutionContext* context) override; std::unique_ptr clone() override { return make_unique(catalog, tableID, propertyName, dataType->copy(), diff --git a/src/include/processor/operator/ddl/create_node_table.h b/src/include/processor/operator/ddl/create_node_table.h index 83a506bfda..bffc714d28 100644 --- a/src/include/processor/operator/ddl/create_node_table.h +++ b/src/include/processor/operator/ddl/create_node_table.h @@ -13,7 +13,7 @@ class CreateNodeTable : public DDL { : DDL{PhysicalOperatorType::CREATE_NODE_TABLE, catalog, outputPos, id, paramsString}, storageManager{storageManager}, info{std::move(info)} {} - void executeDDLInternal() final; + void executeDDLInternal(ExecutionContext* context) final; std::string getOutputMsg() final; diff --git a/src/include/processor/operator/ddl/create_rdf_graph.h b/src/include/processor/operator/ddl/create_rdf_graph.h index 019d4c147e..5f676e0484 100644 --- a/src/include/processor/operator/ddl/create_rdf_graph.h +++ b/src/include/processor/operator/ddl/create_rdf_graph.h @@ -6,7 +6,7 @@ namespace kuzu { namespace processor { -class CreateRdfGraph : public DDL { +class CreateRdfGraph final : public DDL { public: CreateRdfGraph(catalog::Catalog* catalog, storage::StorageManager* storageManager, std::unique_ptr info, const DataPos& outputPos, uint32_t id, @@ -16,11 +16,11 @@ class CreateRdfGraph : public DDL { nodesStatistics{storageManager->getNodesStatisticsAndDeletedIDs()}, relsStatistics{storageManager->getRelsStatistics()}, info{std::move(info)} {} - void executeDDLInternal() final; + void executeDDLInternal(ExecutionContext* context) override; - std::string getOutputMsg() final; + std::string getOutputMsg() override; - inline std::unique_ptr clone() final { + inline std::unique_ptr clone() override { return std::make_unique( catalog, storageManager, info->copy(), outputPos, id, paramsString); } diff --git a/src/include/processor/operator/ddl/create_rel_table.h b/src/include/processor/operator/ddl/create_rel_table.h index b54344ed32..d34037287b 100644 --- a/src/include/processor/operator/ddl/create_rel_table.h +++ b/src/include/processor/operator/ddl/create_rel_table.h @@ -5,7 +5,7 @@ namespace kuzu { namespace processor { -class CreateRelTable : public DDL { +class CreateRelTable final : public DDL { public: CreateRelTable(catalog::Catalog* catalog, storage::StorageManager* storageManager, std::unique_ptr info, const DataPos& outputPos, uint32_t id, @@ -13,7 +13,7 @@ class CreateRelTable : public DDL { : DDL{PhysicalOperatorType::CREATE_REL_TABLE, catalog, outputPos, id, paramsString}, storageManager{storageManager}, info{std::move(info)} {} - void executeDDLInternal() override; + void executeDDLInternal(ExecutionContext* context) override; std::string getOutputMsg() override; diff --git a/src/include/processor/operator/ddl/create_rel_table_group.h b/src/include/processor/operator/ddl/create_rel_table_group.h index 760c407622..c9445af6cb 100644 --- a/src/include/processor/operator/ddl/create_rel_table_group.h +++ b/src/include/processor/operator/ddl/create_rel_table_group.h @@ -13,7 +13,7 @@ class CreateRelTableGroup : public DDL { : DDL{PhysicalOperatorType::CREATE_REL_TABLE, catalog, outputPos, id, paramsString}, info{std::move(info)}, storageManager{storageManager} {} - void executeDDLInternal() override; + void executeDDLInternal(ExecutionContext* context) override; std::string getOutputMsg() override; diff --git a/src/include/processor/operator/ddl/ddl.h b/src/include/processor/operator/ddl/ddl.h index 8627127703..c715d34503 100644 --- a/src/include/processor/operator/ddl/ddl.h +++ b/src/include/processor/operator/ddl/ddl.h @@ -22,7 +22,7 @@ class DDL : public PhysicalOperator { protected: virtual std::string getOutputMsg() = 0; - virtual void executeDDLInternal() = 0; + virtual void executeDDLInternal(ExecutionContext* context) = 0; protected: catalog::Catalog* catalog; diff --git a/src/include/processor/operator/ddl/drop_property.h b/src/include/processor/operator/ddl/drop_property.h index 1da3ad395e..13d88308be 100644 --- a/src/include/processor/operator/ddl/drop_property.h +++ b/src/include/processor/operator/ddl/drop_property.h @@ -14,7 +14,7 @@ class DropProperty : public DDL { : DDL{PhysicalOperatorType::DROP_PROPERTY, catalog, outputPos, id, paramsString}, storageManager{storageManager}, tableID{tableID}, propertyID{propertyID} {} - void executeDDLInternal() final; + void executeDDLInternal(ExecutionContext* context) final; std::string getOutputMsg() final { return {"Drop succeed."}; } diff --git a/src/include/processor/operator/ddl/drop_table.h b/src/include/processor/operator/ddl/drop_table.h index 96e3615ec1..76df09d0d6 100644 --- a/src/include/processor/operator/ddl/drop_table.h +++ b/src/include/processor/operator/ddl/drop_table.h @@ -7,20 +7,21 @@ namespace processor { class DropTable : public DDL { public: - DropTable(catalog::Catalog* catalog, common::table_id_t tableID, const DataPos& outputPos, - uint32_t id, const std::string& paramsString) + DropTable(catalog::Catalog* catalog, std::string tableName, common::table_id_t tableID, + const DataPos& outputPos, uint32_t id, const std::string& paramsString) : DDL{PhysicalOperatorType::DROP_TABLE, catalog, outputPos, id, paramsString}, - tableID{tableID} {} + tableName{std::move(tableName)}, tableID{tableID} {} - void executeDDLInternal() override; + void executeDDLInternal(ExecutionContext* context) override; std::string getOutputMsg() override; std::unique_ptr clone() override { - return make_unique(catalog, tableID, outputPos, id, paramsString); + return make_unique(catalog, tableName, tableID, outputPos, id, paramsString); } protected: + std::string tableName; common::table_id_t tableID; }; diff --git a/src/include/processor/operator/ddl/rename_property.h b/src/include/processor/operator/ddl/rename_property.h index 94cda72e7c..ac3ad075ca 100644 --- a/src/include/processor/operator/ddl/rename_property.h +++ b/src/include/processor/operator/ddl/rename_property.h @@ -13,7 +13,9 @@ class RenameProperty : public DDL { : DDL{PhysicalOperatorType::RENAME_PROPERTY, catalog, outputPos, id, paramsString}, tableID{tableID}, propertyID{propertyID}, newName{std::move(newName)} {} - void executeDDLInternal() override { catalog->renameProperty(tableID, propertyID, newName); } + void executeDDLInternal(ExecutionContext* /*context*/) override { + catalog->renameProperty(tableID, propertyID, newName); + } std::string getOutputMsg() override { return "Property renamed"; } diff --git a/src/include/processor/operator/ddl/rename_table.h b/src/include/processor/operator/ddl/rename_table.h index 228b737911..42b0b2b1ec 100644 --- a/src/include/processor/operator/ddl/rename_table.h +++ b/src/include/processor/operator/ddl/rename_table.h @@ -5,14 +5,16 @@ namespace kuzu { namespace processor { -class RenameTable : public DDL { +class RenameTable final : public DDL { public: RenameTable(catalog::Catalog* catalog, common::table_id_t tableID, std::string newName, const DataPos& outputPos, uint32_t id, const std::string& paramsString) : DDL{PhysicalOperatorType::RENAME_TABLE, catalog, outputPos, id, paramsString}, tableID{tableID}, newName{std::move(newName)} {} - void executeDDLInternal() override { catalog->renameTable(tableID, newName); } + void executeDDLInternal(ExecutionContext* /*context*/) override { + catalog->renameTable(tableID, newName); + } std::string getOutputMsg() override { return "Table renamed"; } diff --git a/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h b/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h index 88e6797df6..8abb35f11d 100644 --- a/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h +++ b/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h @@ -52,7 +52,7 @@ struct ParallelCSVScan { static void tableFunc(function::TableFunctionInput& input, common::DataChunk& outputChunk); static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - function::TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/); + function::TableFuncBindInput* input, catalog::Catalog* /*catalog*/); static std::unique_ptr initSharedState( function::TableFunctionInitInput& input); diff --git a/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h b/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h index 9ab8be2cf8..e3363cb6cb 100644 --- a/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h +++ b/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h @@ -44,7 +44,7 @@ struct SerialCSVScan { static void tableFunc(function::TableFunctionInput& input, common::DataChunk& outputChunk); static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - function::TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/); + function::TableFuncBindInput* input, catalog::Catalog* /*catalog*/); static std::unique_ptr initSharedState( function::TableFunctionInitInput& input); diff --git a/src/include/processor/operator/persistent/reader/npy/npy_reader.h b/src/include/processor/operator/persistent/reader/npy/npy_reader.h index 810c28efb3..9b7bdc0f3e 100644 --- a/src/include/processor/operator/persistent/reader/npy/npy_reader.h +++ b/src/include/processor/operator/persistent/reader/npy/npy_reader.h @@ -71,7 +71,7 @@ struct NpyScanFunction { static void tableFunc(function::TableFunctionInput& input, common::DataChunk& outputChunk); static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - function::TableFuncBindInput* input, catalog::CatalogContent* catalog); + function::TableFuncBindInput* input, catalog::Catalog* catalog); static std::unique_ptr initSharedState( function::TableFunctionInitInput& input); diff --git a/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h b/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h index 3414843e3b..14bb72ec4e 100644 --- a/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h +++ b/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h @@ -113,7 +113,7 @@ struct ParquetScanFunction { static void tableFunc(function::TableFunctionInput& input, common::DataChunk& outputChunk); static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - function::TableFuncBindInput* input, catalog::CatalogContent* catalog); + function::TableFuncBindInput* input, catalog::Catalog* catalog); static std::unique_ptr initSharedState( function::TableFunctionInitInput& input); diff --git a/src/include/processor/operator/persistent/reader/rdf/rdf_reader.h b/src/include/processor/operator/persistent/reader/rdf/rdf_reader.h index a45ea121cc..3b22fab7c8 100644 --- a/src/include/processor/operator/persistent/reader/rdf/rdf_reader.h +++ b/src/include/processor/operator/persistent/reader/rdf/rdf_reader.h @@ -60,7 +60,7 @@ struct RdfScan { static void tableFunc(function::TableFunctionInput& input, common::DataChunk& outputChunk); static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - function::TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/); + function::TableFuncBindInput* input, catalog::Catalog* /*catalog*/); static std::unique_ptr initSharedState( function::TableFunctionInitInput& input); diff --git a/src/include/processor/operator/scan_node_id.h b/src/include/processor/operator/scan_node_id.h index 2b586e24a9..6ec3edd325 100644 --- a/src/include/processor/operator/scan_node_id.h +++ b/src/include/processor/operator/scan_node_id.h @@ -74,7 +74,7 @@ class ScanNodeID : public PhysicalOperator { private: inline void initGlobalStateInternal(ExecutionContext* context) override { - sharedState->initialize(context->clientContext->getActiveTransaction()); + sharedState->initialize(context->clientContext->getTx()); } void setSelVector( diff --git a/src/include/storage/stats/table_statistics_collection.h b/src/include/storage/stats/table_statistics_collection.h index b64648a363..4bccd62bdb 100644 --- a/src/include/storage/stats/table_statistics_collection.h +++ b/src/include/storage/stats/table_statistics_collection.h @@ -65,6 +65,8 @@ class TablesStatistics { static std::unique_ptr createMetadataDAHInfo( const common::LogicalType& dataType, BMFileHandle& metadataFH, BufferManager* bm, WAL* wal); + void initTableStatisticsForWriteTrx(); + protected: virtual std::unique_ptr constructTableStatistic( catalog::TableSchema* tableSchema) = 0; @@ -81,7 +83,6 @@ class TablesStatistics { void saveToFile(const std::string& directory, common::FileVersionType dbFileType, transaction::TransactionType transactionType); - void initTableStatisticsForWriteTrx(); void initTableStatisticsForWriteTrxNoLock(); protected: diff --git a/src/include/storage/storage_manager.h b/src/include/storage/storage_manager.h index 8abc2a9b32..9aca25fdb8 100644 --- a/src/include/storage/storage_manager.h +++ b/src/include/storage/storage_manager.h @@ -16,7 +16,8 @@ class StorageManager { StorageManager(bool readOnly, const catalog::Catalog& catalog, MemoryManager& memoryManager, WAL* wal, bool enableCompression); - void createTable(common::table_id_t tableID, catalog::Catalog* catalog); + void createTable(common::table_id_t tableID, catalog::Catalog* catalog, + transaction::Transaction* transaction); void dropTable(common::table_id_t tableID); void prepareCommit(transaction::Transaction* transaction); @@ -41,6 +42,10 @@ class StorageManager { inline WAL* getWAL() const { return wal; } inline BMFileHandle* getDataFH() const { return dataFH.get(); } inline BMFileHandle* getMetadataFH() const { return metadataFH.get(); } + inline void initStatistics() { + nodesStatisticsAndDeletedIDs->initTableStatisticsForWriteTrx(); + relsStatistics->initTableStatisticsForWriteTrx(); + } inline NodesStoreStatsAndDeletedIDs* getNodesStatisticsAndDeletedIDs() { return nodesStatisticsAndDeletedIDs.get(); } diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index 6931a63b62..ef113e2485 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -43,7 +43,7 @@ std::string ClientContext::getCurrentSetting(const std::string& optionName) { return option->getSetting(this); } -Transaction* ClientContext::getActiveTransaction() const { +transaction::Transaction* ClientContext::getTx() const { return transactionContext->getActiveTransaction(); } diff --git a/src/main/connection.cpp b/src/main/connection.cpp index af104fe7d6..eaba59e198 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -93,20 +93,42 @@ std::unique_ptr Connection::prepareNoLock( } auto compilingTimer = TimeMetric(true /* enable */); compilingTimer.start(); - std::unique_ptr executionContext; - std::unique_ptr logicalPlan; + std::unique_ptr statement; try { - // parsing - auto statement = Parser::parseQuery(query); + statement = Parser::parseQuery(query); + preparedStatement->preparedSummary.statementType = statement->getStatementType(); preparedStatement->readOnly = parser::StatementReadWriteAnalyzer().isReadOnly(*statement); if (database->systemConfig.readOnly && !preparedStatement->isReadOnly()) { throw ConnectionException("Cannot execute write operations in a read-only database!"); } + } catch (std::exception& exception) { + preparedStatement->success = false; + preparedStatement->errMsg = exception.what(); + compilingTimer.stop(); + preparedStatement->preparedSummary.compilingTime = compilingTimer.getElapsedTimeMS(); + return preparedStatement; + } + std::unique_ptr executionContext; + std::unique_ptr logicalPlan; + try { + // parsing + if (statement->getStatementType() != StatementType::TRANSACTION) { + auto txContext = clientContext->transactionContext.get(); + if (txContext->isAutoTransaction()) { + txContext->beginAutoTransaction(preparedStatement->readOnly); + } else { + txContext->validateManualTransaction( + preparedStatement->allowActiveTransaction(), preparedStatement->readOnly); + } + if (!clientContext->getTx()->isReadOnly()) { + database->catalog->initCatalogContentForWriteTrxIfNecessary(); + database->storageManager->initStatistics(); + } + } // binding auto binder = Binder(*database->catalog, database->memoryManager.get(), database->storageManager.get(), clientContext.get()); auto boundStatement = binder.bind(*statement); - preparedStatement->preparedSummary.statementType = boundStatement->getStatementType(); preparedStatement->parameterMap = binder.getParameterMap(); preparedStatement->statementResult = boundStatement->getStatementResult()->copy(); // planning @@ -141,6 +163,7 @@ std::unique_ptr Connection::prepareNoLock( } catch (std::exception& exception) { preparedStatement->success = false; preparedStatement->errMsg = exception.what(); + clientContext->transactionContext->rollback(); } compilingTimer.stop(); preparedStatement->preparedSummary.compilingTime = compilingTimer.getElapsedTimeMS(); @@ -200,6 +223,17 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement, std::unique_ptr Connection::executeAndAutoCommitIfNecessaryNoLock( PreparedStatement* preparedStatement, uint32_t planIdx) { + if (!preparedStatement->isSuccess()) { + return queryResultWithError(preparedStatement->errMsg); + } + if (preparedStatement->preparedSummary.statementType != common::StatementType::TRANSACTION && + clientContext->getTx() == nullptr) { + clientContext->transactionContext->beginAutoTransaction(preparedStatement->isReadOnly()); + if (!preparedStatement->readOnly) { + database->catalog->initCatalogContentForWriteTrxIfNecessary(); + database->storageManager->initStatistics(); + } + } clientContext->resetActiveQuery(); clientContext->startTimingIfEnabled(); auto mapper = PlanMapper( @@ -211,14 +245,10 @@ std::unique_ptr Connection::executeAndAutoCommitIfNecessaryNoLock( mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[planIdx].get(), preparedStatement->statementResult->getColumns()); } catch (std::exception& exception) { - preparedStatement->success = false; - preparedStatement->errMsg = exception.what(); + clientContext->transactionContext->rollback(); + return queryResultWithError(exception.what()); } } - if (!preparedStatement->isSuccess()) { - clientContext->transactionContext->rollback(); - return queryResultWithError(preparedStatement->errMsg); - } auto queryResult = std::make_unique(preparedStatement->preparedSummary); auto profiler = std::make_unique(); auto executionContext = @@ -234,14 +264,10 @@ std::unique_ptr Connection::executeAndAutoCommitIfNecessaryNoLock( database->queryProcessor->execute(physicalPlan.get(), executionContext.get()); } else { if (clientContext->transactionContext->isAutoTransaction()) { - clientContext->transactionContext->beginAutoTransaction( - preparedStatement->isReadOnly()); resultFT = database->queryProcessor->execute(physicalPlan.get(), executionContext.get()); clientContext->transactionContext->commit(); } else { - clientContext->transactionContext->validateManualTransaction( - preparedStatement->allowActiveTransaction(), preparedStatement->isReadOnly()); resultFT = database->queryProcessor->execute(physicalPlan.get(), executionContext.get()); } diff --git a/src/main/storage_driver.cpp b/src/main/storage_driver.cpp index 7bc9a68d86..863ba5a67a 100644 --- a/src/main/storage_driver.cpp +++ b/src/main/storage_driver.cpp @@ -16,9 +16,9 @@ StorageDriver::StorageDriver(Database* database) void StorageDriver::scan(const std::string& nodeName, const std::string& propertyName, offset_t* offsets, size_t size, uint8_t* result, size_t numThreads) { // Resolve files to read from - auto catalogContent = catalog->getReadOnlyVersion(); - auto nodeTableID = catalogContent->getTableID(nodeName); - auto propertyID = catalogContent->getTableSchema(nodeTableID)->getPropertyID(propertyName); + auto nodeTableID = catalog->getTableID(&DUMMY_READ_TRANSACTION, nodeName); + auto propertyID = + catalog->getTableSchema(&DUMMY_READ_TRANSACTION, nodeTableID)->getPropertyID(propertyName); auto nodeTable = storageManager->getNodeTable(nodeTableID); auto column = nodeTable->getColumn(propertyID); auto current_buffer = result; @@ -40,8 +40,7 @@ void StorageDriver::scan(const std::string& nodeName, const std::string& propert } uint64_t StorageDriver::getNumNodes(const std::string& nodeName) { - auto catalogContent = catalog->getReadOnlyVersion(); - auto nodeTableID = catalogContent->getTableID(nodeName); + auto nodeTableID = catalog->getTableID(&DUMMY_READ_TRANSACTION, nodeName); auto nodeStatistics = storageManager->getNodesStatisticsAndDeletedIDs()->getNodeStatisticsAndDeletedIDs( nodeTableID); @@ -49,8 +48,7 @@ uint64_t StorageDriver::getNumNodes(const std::string& nodeName) { } uint64_t StorageDriver::getNumRels(const std::string& relName) { - auto catalogContent = catalog->getReadOnlyVersion(); - auto relTableID = catalogContent->getTableID(relName); + auto relTableID = catalog->getTableID(&DUMMY_READ_TRANSACTION, relName); auto relStatistics = storageManager->getRelsStatistics()->getRelStatistics( relTableID, Transaction::getDummyReadOnlyTrx().get()); return relStatistics->getNumTuples(); diff --git a/src/planner/plan/append_extend.cpp b/src/planner/plan/append_extend.cpp index c222c74882..fe83f63341 100644 --- a/src/planner/plan/append_extend.cpp +++ b/src/planner/plan/append_extend.cpp @@ -9,10 +9,12 @@ #include "planner/operator/extend/logical_recursive_extend.h" #include "planner/operator/logical_node_label_filter.h" #include "planner/query_planner.h" +#include "transaction/transaction.h" using namespace kuzu::common; using namespace kuzu::binder; using namespace kuzu::catalog; +using namespace kuzu::transaction; namespace kuzu { namespace planner { @@ -30,7 +32,7 @@ static bool extendHasAtMostOneNbrGuarantee(RelExpression& rel, NodeExpression& b } auto relDirection = ExtendDirectionUtils::getRelDataDirection(direction); auto relTableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(rel.getSingleTableID())); + catalog.getTableSchema(&DUMMY_READ_TRANSACTION, rel.getSingleTableID())); return relTableSchema->isSingleMultiplicityInDirection(relDirection); } @@ -39,7 +41,7 @@ static std::unordered_set getBoundNodeTableIDSet( std::unordered_set result; for (auto tableID : rel.getTableIDs()) { auto tableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(tableID)); + catalog.getTableSchema(&DUMMY_READ_TRANSACTION, tableID)); switch (extendDirection) { case ExtendDirection::FWD: { result.insert(tableSchema->getBoundTableID(RelDataDirection::FWD)); @@ -63,7 +65,7 @@ static std::unordered_set getNbrNodeTableIDSet( std::unordered_set result; for (auto tableID : rel.getTableIDs()) { auto tableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(tableID)); + catalog.getTableSchema(&DUMMY_READ_TRANSACTION, tableID)); switch (extendDirection) { case ExtendDirection::FWD: { result.insert(tableSchema->getNbrTableID(RelDataDirection::FWD)); diff --git a/src/processor/map/map_ddl.cpp b/src/processor/map/map_ddl.cpp index 9ea4060582..594a9d83d8 100644 --- a/src/processor/map/map_ddl.cpp +++ b/src/processor/map/map_ddl.cpp @@ -12,6 +12,7 @@ #include "processor/operator/ddl/rename_property.h" #include "processor/operator/ddl/rename_table.h" #include "processor/plan_mapper.h" +#include "transaction/transaction.h" using namespace kuzu::binder; using namespace kuzu::common; @@ -77,8 +78,8 @@ std::unique_ptr PlanMapper::mapCreateRdfGraph(LogicalOperator* std::unique_ptr PlanMapper::mapDropTable(LogicalOperator* logicalOperator) { auto dropTable = (LogicalDropTable*)logicalOperator; - return std::make_unique(catalog, dropTable->getTableID(), getOutputPos(dropTable), - getOperatorID(), dropTable->getExpressionsForPrinting()); + return std::make_unique(catalog, dropTable->getTableName(), dropTable->getTableID(), + getOutputPos(dropTable), getOperatorID(), dropTable->getExpressionsForPrinting()); } std::unique_ptr PlanMapper::mapAlter(LogicalOperator* logicalOperator) { @@ -115,7 +116,7 @@ std::unique_ptr PlanMapper::mapAddProperty(LogicalOperator* lo auto extraInfo = reinterpret_cast(info->extraInfo.get()); auto expressionEvaluator = ExpressionMapper::getEvaluator(extraInfo->defaultValue, alter->getSchema()); - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(info->tableID); + auto tableSchema = catalog->getTableSchema(&transaction::DUMMY_READ_TRANSACTION, info->tableID); switch (tableSchema->getTableType()) { case TableType::NODE: return std::make_unique(catalog, info->tableID, extraInfo->propertyName, diff --git a/src/processor/map/map_delete.cpp b/src/processor/map/map_delete.cpp index 6fa658e0fb..5c08da3e31 100644 --- a/src/processor/map/map_delete.cpp +++ b/src/processor/map/map_delete.cpp @@ -1,6 +1,7 @@ #include "planner/operator/persistent/logical_delete.h" #include "processor/operator/persistent/delete.h" #include "processor/plan_mapper.h" +#include "transaction/transaction.h" using namespace kuzu::binder; using namespace kuzu::common; @@ -18,7 +19,8 @@ static std::unique_ptr getNodeDeleteExecutor(catalog::Catalo std::unordered_map> tableIDToFwdRelTablesMap; std::unordered_map> tableIDToBwdRelTablesMap; for (auto tableID : info->node->getTableIDs()) { - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(tableID); + auto tableSchema = + catalog->getTableSchema(&transaction::DUMMY_READ_TRANSACTION, tableID); auto nodeTableSchema = ku_dynamic_cast(tableSchema); auto table = storageManager.getNodeTable(tableID); @@ -41,8 +43,8 @@ static std::unique_ptr getNodeDeleteExecutor(catalog::Catalo info->deleteType, nodeIDPos); } else { auto table = storageManager.getNodeTable(info->node->getSingleTableID()); - auto tableSchema = - catalog->getReadOnlyVersion()->getTableSchema(info->node->getSingleTableID()); + auto tableSchema = catalog->getTableSchema( + &transaction::DUMMY_READ_TRANSACTION, info->node->getSingleTableID()); auto nodeTableSchema = ku_dynamic_cast(tableSchema); auto fwdRelTableIDs = nodeTableSchema->getFwdRelTableIDSet(); diff --git a/src/processor/map/map_extend.cpp b/src/processor/map/map_extend.cpp index 46393e43db..c348652c39 100644 --- a/src/processor/map/map_extend.cpp +++ b/src/processor/map/map_extend.cpp @@ -4,6 +4,7 @@ #include "processor/operator/scan/scan_rel_csr_columns.h" #include "processor/operator/scan/scan_rel_regular_columns.h" #include "processor/plan_mapper.h" +#include "transaction/transaction.h" using namespace kuzu::binder; using namespace kuzu::common; @@ -36,7 +37,7 @@ static std::unique_ptr populateRelTableCollectionScan std::vector> scanInfos; for (auto relTableID : rel.getTableIDs()) { auto relTableSchema = reinterpret_cast( - catalog.getReadOnlyVersion()->getTableSchema(relTableID)); + catalog.getTableSchema(&transaction::DUMMY_READ_TRANSACTION, relTableID)); switch (extendDirection) { case ExtendDirection::FWD: { if (relTableSchema->getBoundTableID(RelDataDirection::FWD) == boundNodeTableID) { @@ -101,7 +102,7 @@ std::unique_ptr PlanMapper::mapExtend(LogicalOperator* logical if (!rel->isMultiLabeled() && !boundNode->isMultiLabeled() && extendDirection != ExtendDirection::BOTH) { auto tableSchema = dynamic_cast( - catalog->getReadOnlyVersion()->getTableSchema(rel->getSingleTableID())); + catalog->getTableSchema(&transaction::DUMMY_READ_TRANSACTION, rel->getSingleTableID())); auto relDataDirection = ExtendDirectionUtils::getRelDataDirection(extendDirection); auto scanInfo = getRelTableScanInfo( tableSchema, relDataDirection, storageManager, extend->getProperties()); diff --git a/src/processor/map/map_insert.cpp b/src/processor/map/map_insert.cpp index dc13a35254..dfbdb45719 100644 --- a/src/processor/map/map_insert.cpp +++ b/src/processor/map/map_insert.cpp @@ -2,6 +2,7 @@ #include "planner/operator/persistent/logical_insert.h" #include "processor/operator/persistent/insert.h" #include "processor/plan_mapper.h" +#include "transaction/transaction.h" using namespace kuzu::evaluator; using namespace kuzu::planner; @@ -32,7 +33,7 @@ std::unique_ptr PlanMapper::getNodeInsertExecutor( auto table = storageManager.getNodeTable(nodeTableID); std::unordered_set fwdRelTablesToInit; std::unordered_set bwdRelTablesToInit; - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(nodeTableID); + auto tableSchema = catalog->getTableSchema(&transaction::DUMMY_READ_TRANSACTION, nodeTableID); auto nodeTableSchema = common::ku_dynamic_cast(tableSchema); auto fwdRelTableIDs = nodeTableSchema->getFwdRelTableIDSet(); diff --git a/src/processor/map/map_recursive_extend.cpp b/src/processor/map/map_recursive_extend.cpp index 29e7bfce10..726b4b706b 100644 --- a/src/processor/map/map_recursive_extend.cpp +++ b/src/processor/map/map_recursive_extend.cpp @@ -1,6 +1,7 @@ #include "planner/operator/extend/logical_recursive_extend.h" #include "processor/operator/recursive_extend/recursive_join.h" #include "processor/plan_mapper.h" +#include "transaction/transaction.h" using namespace kuzu::binder; using namespace kuzu::planner; @@ -48,7 +49,7 @@ std::unique_ptr PlanMapper::mapRecursiveExtend( pathPos = DataPos(outSchema->getExpressionPos(*rel)); } std::unordered_map tableIDToName; - for (auto& schema : catalog->getReadOnlyVersion()->getTableSchemas()) { + for (auto& schema : catalog->getTableSchemas(&transaction::DUMMY_READ_TRANSACTION)) { tableIDToName.insert({schema->getTableID(), schema->tableName}); } auto dataInfo = std::make_unique(boundNodeIDPos, nbrNodeIDPos, diff --git a/src/processor/map/map_scan_node_property.cpp b/src/processor/map/map_scan_node_property.cpp index 14770ae433..60ec076432 100644 --- a/src/processor/map/map_scan_node_property.cpp +++ b/src/processor/map/map_scan_node_property.cpp @@ -2,6 +2,7 @@ #include "planner/operator/scan/logical_scan_node_property.h" #include "processor/operator/scan/scan_multi_node_tables.h" #include "processor/plan_mapper.h" +#include "transaction/transaction.h" using namespace kuzu::binder; using namespace kuzu::common; @@ -32,8 +33,8 @@ std::unique_ptr PlanMapper::mapScanNodeProperty( columns.push_back(UINT32_MAX); } else { columns.push_back( - catalog->getReadOnlyVersion()->getTableSchema(tableID)->getColumnID( - property->getPropertyID(tableID))); + catalog->getTableSchema(&transaction::DUMMY_READ_TRANSACTION, tableID) + ->getColumnID(property->getPropertyID(tableID))); } } tables.insert({tableID, std::make_unique( @@ -44,7 +45,7 @@ std::unique_ptr PlanMapper::mapScanNodeProperty( scanProperty.getExpressionsForPrinting()); } else { auto tableID = tableIDs[0]; - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(tableID); + auto tableSchema = catalog->getTableSchema(&transaction::DUMMY_READ_TRANSACTION, tableID); std::vector columnIDs; for (auto& expression : scanProperty.getProperties()) { auto property = static_pointer_cast(expression); diff --git a/src/processor/map/map_set.cpp b/src/processor/map/map_set.cpp index e9c4f528fb..a145e22048 100644 --- a/src/processor/map/map_set.cpp +++ b/src/processor/map/map_set.cpp @@ -3,11 +3,13 @@ #include "planner/operator/persistent/logical_set.h" #include "processor/operator/persistent/set.h" #include "processor/plan_mapper.h" +#include "transaction/transaction.h" using namespace kuzu::binder; using namespace kuzu::common; using namespace kuzu::planner; using namespace kuzu::evaluator; +using namespace kuzu::transaction; namespace kuzu { namespace processor { @@ -31,7 +33,7 @@ std::unique_ptr PlanMapper::getNodeSetExecutor( auto propertyID = property->getPropertyID(tableID); auto table = storageManager.getNodeTable(tableID); auto columnID = - catalog->getReadOnlyVersion()->getTableSchema(tableID)->getColumnID(propertyID); + catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID)->getColumnID(propertyID); tableIDToSetInfo.insert({tableID, NodeSetInfo{table, columnID}}); } return std::make_unique( @@ -43,7 +45,7 @@ std::unique_ptr PlanMapper::getNodeSetExecutor( if (property->hasPropertyID(tableID)) { auto propertyID = property->getPropertyID(tableID); columnID = - catalog->getReadOnlyVersion()->getTableSchema(tableID)->getColumnID(propertyID); + catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID)->getColumnID(propertyID); } return std::make_unique( NodeSetInfo{table, columnID}, nodeIDPos, propertyPos, std::move(evaluator)); @@ -84,7 +86,7 @@ std::unique_ptr PlanMapper::getRelSetExecutor( auto table = storageManager.getRelTable(tableID); auto propertyID = property->getPropertyID(tableID); auto columnID = - catalog->getReadOnlyVersion()->getTableSchema(tableID)->getColumnID(propertyID); + catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID)->getColumnID(propertyID); tableIDToTableAndColumnID.insert({tableID, std::make_pair(table, columnID)}); } return std::make_unique(std::move(tableIDToTableAndColumnID), @@ -96,7 +98,7 @@ std::unique_ptr PlanMapper::getRelSetExecutor( if (property->hasPropertyID(tableID)) { auto propertyID = property->getPropertyID(tableID); columnID = - catalog->getReadOnlyVersion()->getTableSchema(tableID)->getColumnID(propertyID); + catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID)->getColumnID(propertyID); } return std::make_unique( table, columnID, srcNodePos, dstNodePos, relIDPos, propertyPos, std::move(evaluator)); diff --git a/src/processor/operator/ddl/add_node_property.cpp b/src/processor/operator/ddl/add_node_property.cpp index 670f0a1562..13ee8277db 100644 --- a/src/processor/operator/ddl/add_node_property.cpp +++ b/src/processor/operator/ddl/add_node_property.cpp @@ -3,9 +3,9 @@ namespace kuzu { namespace processor { -void AddNodeProperty::executeDDLInternal() { +void AddNodeProperty::executeDDLInternal(ExecutionContext* context) { catalog->addNodeProperty(tableID, propertyName, std::move(dataType)); - auto schema = catalog->getWriteVersion()->getTableSchema(tableID); + auto schema = catalog->getTableSchema(context->clientContext->getTx(), tableID); auto addedPropID = schema->getPropertyID(propertyName); auto addedProp = schema->getProperty(addedPropID); storageManager.getNodeTable(tableID)->addColumn(transaction, *addedProp, getDefaultValVector()); diff --git a/src/processor/operator/ddl/add_rel_property.cpp b/src/processor/operator/ddl/add_rel_property.cpp index b91b83f7e7..87e110c5c3 100644 --- a/src/processor/operator/ddl/add_rel_property.cpp +++ b/src/processor/operator/ddl/add_rel_property.cpp @@ -7,12 +7,11 @@ using namespace kuzu::common; namespace kuzu { namespace processor { -void AddRelProperty::executeDDLInternal() { +void AddRelProperty::executeDDLInternal(ExecutionContext* context) { catalog->addRelProperty(tableID, propertyName, dataType->copy()); - auto tableSchema = catalog->getWriteVersion()->getTableSchema(tableID); + auto tableSchema = catalog->getTableSchema(context->clientContext->getTx(), tableID); auto addedPropertyID = tableSchema->getPropertyID(propertyName); - auto addedProp = - catalog->getWriteVersion()->getTableSchema(tableID)->getProperty(addedPropertyID); + auto addedProp = tableSchema->getProperty(addedPropertyID); storageManager.getRelTable(tableID)->addColumn(transaction, *addedProp, getDefaultValVector()); storageManager.getWAL()->logAddPropertyRecord(tableID, addedProp->getPropertyID()); } diff --git a/src/processor/operator/ddl/create_node_table.cpp b/src/processor/operator/ddl/create_node_table.cpp index aa597985e2..56ea78cc92 100644 --- a/src/processor/operator/ddl/create_node_table.cpp +++ b/src/processor/operator/ddl/create_node_table.cpp @@ -10,10 +10,10 @@ using namespace kuzu::common; namespace kuzu { namespace processor { -void CreateNodeTable::executeDDLInternal() { +void CreateNodeTable::executeDDLInternal(ExecutionContext* context) { auto newTableID = catalog->addNodeTableSchema(*info); - auto newNodeTableSchema = - reinterpret_cast(catalog->getWriteVersion()->getTableSchema(newTableID)); + auto newNodeTableSchema = reinterpret_cast( + catalog->getTableSchema(context->clientContext->getTx(), newTableID)); storageManager->getNodesStatisticsAndDeletedIDs()->addNodeStatisticsAndDeletedIDs( newNodeTableSchema); storageManager->getWAL()->logCreateNodeTableRecord(newTableID); diff --git a/src/processor/operator/ddl/create_rdf_graph.cpp b/src/processor/operator/ddl/create_rdf_graph.cpp index 5b96fa252a..4c711eebe3 100644 --- a/src/processor/operator/ddl/create_rdf_graph.cpp +++ b/src/processor/operator/ddl/create_rdf_graph.cpp @@ -9,26 +9,26 @@ using namespace kuzu::catalog; namespace kuzu { namespace processor { -void CreateRdfGraph::executeDDLInternal() { +void CreateRdfGraph::executeDDLInternal(ExecutionContext* context) { + auto tx = context->clientContext->getTx(); auto newRdfGraphID = catalog->addRdfGraphSchema(*info); - auto writeCatalog = catalog->getWriteVersion(); auto rdfGraphSchema = - reinterpret_cast(writeCatalog->getTableSchema(newRdfGraphID)); + reinterpret_cast(catalog->getTableSchema(tx, newRdfGraphID)); auto resourceTableID = rdfGraphSchema->getResourceTableID(); auto resourceTableSchema = - reinterpret_cast(writeCatalog->getTableSchema(resourceTableID)); + reinterpret_cast(catalog->getTableSchema(tx, resourceTableID)); nodesStatistics->addNodeStatisticsAndDeletedIDs(resourceTableSchema); auto literalTableID = rdfGraphSchema->getLiteralTableID(); auto literalTableSchema = - reinterpret_cast(writeCatalog->getTableSchema(literalTableID)); + reinterpret_cast(catalog->getTableSchema(tx, literalTableID)); nodesStatistics->addNodeStatisticsAndDeletedIDs(literalTableSchema); auto resourceTripleTableID = rdfGraphSchema->getResourceTripleTableID(); auto resourceTripleTableSchema = - reinterpret_cast(writeCatalog->getTableSchema(resourceTripleTableID)); + reinterpret_cast(catalog->getTableSchema(tx, resourceTripleTableID)); relsStatistics->addTableStatistic(resourceTripleTableSchema); auto literalTripleTableID = rdfGraphSchema->getLiteralTripleTableID(); auto literalTripleTableSchema = - reinterpret_cast(writeCatalog->getTableSchema(literalTripleTableID)); + reinterpret_cast(catalog->getTableSchema(tx, literalTripleTableID)); relsStatistics->addTableStatistic(literalTripleTableSchema); storageManager->getWAL()->logRdfGraphRecord(newRdfGraphID, resourceTableID, literalTableID, resourceTripleTableID, literalTripleTableID); diff --git a/src/processor/operator/ddl/create_rel_table.cpp b/src/processor/operator/ddl/create_rel_table.cpp index 7a46b48f93..bf8fa01f9d 100644 --- a/src/processor/operator/ddl/create_rel_table.cpp +++ b/src/processor/operator/ddl/create_rel_table.cpp @@ -11,10 +11,10 @@ using namespace kuzu::binder; namespace kuzu { namespace processor { -void CreateRelTable::executeDDLInternal() { +void CreateRelTable::executeDDLInternal(ExecutionContext* context) { auto newRelTableID = catalog->addRelTableSchema(*info); auto newRelTableSchema = reinterpret_cast( - catalog->getWriteVersion()->getTableSchema(newRelTableID)); + catalog->getTableSchema(context->clientContext->getTx(), newRelTableID)); storageManager->getRelsStatistics()->addTableStatistic(newRelTableSchema); storageManager->getWAL()->logCreateRelTableRecord(newRelTableID); } diff --git a/src/processor/operator/ddl/create_rel_table_group.cpp b/src/processor/operator/ddl/create_rel_table_group.cpp index d96573fca2..7a1df0365e 100644 --- a/src/processor/operator/ddl/create_rel_table_group.cpp +++ b/src/processor/operator/ddl/create_rel_table_group.cpp @@ -11,13 +11,13 @@ using namespace kuzu::catalog; namespace kuzu { namespace processor { -void CreateRelTableGroup::executeDDLInternal() { +void CreateRelTableGroup::executeDDLInternal(ExecutionContext* context) { auto newRelTableGroupID = catalog->addRelTableGroupSchema(*info); - auto writeCatalog = catalog->getWriteVersion(); + auto tx = context->clientContext->getTx(); auto newRelTableGroupSchema = - (RelTableGroupSchema*)writeCatalog->getTableSchema(newRelTableGroupID); + reinterpret_cast(catalog->getTableSchema(tx, newRelTableGroupID)); for (auto& relTableID : newRelTableGroupSchema->getRelTableIDs()) { - auto newRelTableSchema = writeCatalog->getTableSchema(relTableID); + auto newRelTableSchema = catalog->getTableSchema(tx, relTableID); storageManager->getRelsStatistics()->addTableStatistic((RelTableSchema*)newRelTableSchema); } // TODO(Ziyi): remove this when we can log variable size record. See also wal_record.h diff --git a/src/processor/operator/ddl/ddl.cpp b/src/processor/operator/ddl/ddl.cpp index 70eeedcaa2..f081583032 100644 --- a/src/processor/operator/ddl/ddl.cpp +++ b/src/processor/operator/ddl/ddl.cpp @@ -7,12 +7,12 @@ void DDL::initLocalStateInternal(ResultSet* resultSet, ExecutionContext* /*conte outputVector = resultSet->getValueVector(outputPos).get(); } -bool DDL::getNextTuplesInternal(ExecutionContext* /*context*/) { +bool DDL::getNextTuplesInternal(ExecutionContext* context) { if (hasExecuted) { return false; } hasExecuted = true; - executeDDLInternal(); + executeDDLInternal(context); outputVector->setValue(0, getOutputMsg()); metrics->numOutputTuple.increase(1); return true; diff --git a/src/processor/operator/ddl/drop_property.cpp b/src/processor/operator/ddl/drop_property.cpp index 680dfa1999..f85ec7c018 100644 --- a/src/processor/operator/ddl/drop_property.cpp +++ b/src/processor/operator/ddl/drop_property.cpp @@ -3,15 +3,16 @@ namespace kuzu { namespace processor { -void DropProperty::executeDDLInternal() { - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(tableID); +void DropProperty::executeDDLInternal(ExecutionContext* context) { + auto tableSchema = catalog->getTableSchema(context->clientContext->getTx(), tableID); + auto columnID = tableSchema->getColumnID(propertyID); catalog->dropProperty(tableID, propertyID); if (tableSchema->tableType == common::TableType::NODE) { auto nodesStats = storageManager.getNodesStatisticsAndDeletedIDs(); - nodesStats->removeMetadataDAHInfo(tableID, tableSchema->getColumnID(propertyID)); + nodesStats->removeMetadataDAHInfo(tableID, columnID); } else { auto relsStats = storageManager.getRelsStatistics(); - relsStats->removeMetadataDAHInfo(tableID, tableSchema->getColumnID(propertyID)); + relsStats->removeMetadataDAHInfo(tableID, columnID); } } diff --git a/src/processor/operator/ddl/drop_table.cpp b/src/processor/operator/ddl/drop_table.cpp index 4838421501..67d9919c3b 100644 --- a/src/processor/operator/ddl/drop_table.cpp +++ b/src/processor/operator/ddl/drop_table.cpp @@ -8,14 +8,12 @@ using namespace kuzu::common; namespace kuzu { namespace processor { -void DropTable::executeDDLInternal() { +void DropTable::executeDDLInternal(ExecutionContext* /*context*/) { catalog->dropTableSchema(tableID); } std::string DropTable::getOutputMsg() { - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(tableID); - return stringFormat("{} table: {} has been dropped.", - tableSchema->tableType == TableType::NODE ? "Node" : "Rel", tableSchema->tableName); + return stringFormat("Table: {} has been dropped.", tableName); } } // namespace processor diff --git a/src/processor/operator/index_lookup.cpp b/src/processor/operator/index_lookup.cpp index 49c280c48d..9ece2bcde4 100644 --- a/src/processor/operator/index_lookup.cpp +++ b/src/processor/operator/index_lookup.cpp @@ -15,7 +15,7 @@ bool IndexLookup::getNextTuplesInternal(ExecutionContext* context) { } for (auto& info : infos) { KU_ASSERT(info); - indexLookup(context->clientContext->getActiveTransaction(), *info); + indexLookup(context->clientContext->getTx(), *info); } return true; } diff --git a/src/processor/operator/persistent/copy_node.cpp b/src/processor/operator/persistent/copy_node.cpp index 7ee5242ce9..21a9eec1b2 100644 --- a/src/processor/operator/persistent/copy_node.cpp +++ b/src/processor/operator/persistent/copy_node.cpp @@ -169,12 +169,12 @@ void CopyNode::finalize(ExecutionContext* context) { sharedState->table->getNodeStatisticsAndDeletedIDs()->setNumTuplesForTable( sharedState->table->getTableID(), numNodes); for (auto relTable : info->fwdRelTables) { - relTable->resizeColumns(context->clientContext->getActiveTransaction(), - RelDataDirection::FWD, sharedState->getCurNodeGroupIdx()); + relTable->resizeColumns(context->clientContext->getTx(), RelDataDirection::FWD, + sharedState->getCurNodeGroupIdx()); } for (auto relTable : info->bwdRelTables) { - relTable->resizeColumns(context->clientContext->getActiveTransaction(), - RelDataDirection::BWD, sharedState->getCurNodeGroupIdx()); + relTable->resizeColumns(context->clientContext->getTx(), RelDataDirection::BWD, + sharedState->getCurNodeGroupIdx()); } auto outputMsg = stringFormat( "{} number of tuples has been copied to table: {}.", numNodes, info->tableName.c_str()); diff --git a/src/processor/operator/persistent/delete_executor.cpp b/src/processor/operator/persistent/delete_executor.cpp index 8aeb288410..2341aa3a57 100644 --- a/src/processor/operator/persistent/delete_executor.cpp +++ b/src/processor/operator/persistent/delete_executor.cpp @@ -25,12 +25,12 @@ static void deleteFromRelTable(ExecutionContext* context, DeleteNodeType deleteT RelDetachDeleteState* detachDeleteState) { switch (deleteType) { case DeleteNodeType::DETACH_DELETE: { - relTable->detachDelete(context->clientContext->getActiveTransaction(), direction, - nodeIDVector, detachDeleteState); + relTable->detachDelete( + context->clientContext->getTx(), direction, nodeIDVector, detachDeleteState); } break; case DeleteNodeType::DELETE: { if (relTable->checkIfNodeHasRels( - context->clientContext->getActiveTransaction(), direction, nodeIDVector)) { + context->clientContext->getTx(), direction, nodeIDVector)) { throw RuntimeException( stringFormat("Deleted nodes has connected edges in the {} direction.", RelDataDirectionUtils::relDirectionToString(direction))); @@ -57,7 +57,7 @@ void SingleLabelNodeDeleteExecutor::delete_(ExecutionContext* context) { deleteFromRelTable(context, deleteType, RelDataDirection::BWD, relTable, nodeIDVector, detachDeleteState.get()); } - table->delete_(context->clientContext->getActiveTransaction(), nodeIDVector, pkVector.get()); + table->delete_(context->clientContext->getTx(), nodeIDVector, pkVector.get()); } void MultiLabelNodeDeleteExecutor::init(ResultSet* resultSet, ExecutionContext* context) { @@ -88,8 +88,8 @@ void MultiLabelNodeDeleteExecutor::delete_(ExecutionContext* context) { deleteFromRelTable(context, deleteType, RelDataDirection::BWD, relTable, nodeIDVector, detachDeleteState.get()); } - table->delete_(context->clientContext->getActiveTransaction(), nodeIDVector, - pkVectors.at(nodeID.tableID).get()); + table->delete_( + context->clientContext->getTx(), nodeIDVector, pkVectors.at(nodeID.tableID).get()); } void RelDeleteExecutor::init(ResultSet* resultSet, ExecutionContext* /*context*/) { @@ -99,8 +99,7 @@ void RelDeleteExecutor::init(ResultSet* resultSet, ExecutionContext* /*context*/ } void SingleLabelRelDeleteExecutor::delete_(ExecutionContext* context) { - table->delete_(context->clientContext->getActiveTransaction(), srcNodeIDVector, dstNodeIDVector, - relIDVector); + table->delete_(context->clientContext->getTx(), srcNodeIDVector, dstNodeIDVector, relIDVector); } void MultiLabelRelDeleteExecutor::delete_(ExecutionContext* context) { @@ -109,8 +108,7 @@ void MultiLabelRelDeleteExecutor::delete_(ExecutionContext* context) { auto relID = relIDVector->getValue(pos); KU_ASSERT(tableIDToTableMap.contains(relID.tableID)); auto table = tableIDToTableMap.at(relID.tableID); - table->delete_(context->clientContext->getActiveTransaction(), srcNodeIDVector, dstNodeIDVector, - relIDVector); + table->delete_(context->clientContext->getTx(), srcNodeIDVector, dstNodeIDVector, relIDVector); } } // namespace processor diff --git a/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp b/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp index f6ff3fc3fb..d26e113d9d 100644 --- a/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp +++ b/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp @@ -159,7 +159,7 @@ void ParallelCSVScan::tableFunc(TableFunctionInput& input, common::DataChunk& ou std::unique_ptr ParallelCSVScan::bindFunc( main::ClientContext* /*context*/, function::TableFuncBindInput* input, - catalog::CatalogContent* /*catalog*/) { + catalog::Catalog* /*catalog*/) { auto scanInput = reinterpret_cast(input); std::vector detectedColumnNames; std::vector> detectedColumnTypes; diff --git a/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp b/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp index aece2c4e9c..1754bf1874 100644 --- a/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp +++ b/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp @@ -81,7 +81,7 @@ void SerialCSVScan::tableFunc(TableFunctionInput& input, DataChunk& outputChunk) std::unique_ptr SerialCSVScan::bindFunc( main::ClientContext* /*context*/, function::TableFuncBindInput* input, - catalog::CatalogContent* /*catalog*/) { + catalog::Catalog* /*catalog*/) { auto scanInput = reinterpret_cast(input); std::vector detectedColumnNames; std::vector> detectedColumnTypes; diff --git a/src/processor/operator/persistent/reader/npy/npy_reader.cpp b/src/processor/operator/persistent/reader/npy/npy_reader.cpp index a3dac5a062..1c8abaab05 100644 --- a/src/processor/operator/persistent/reader/npy/npy_reader.cpp +++ b/src/processor/operator/persistent/reader/npy/npy_reader.cpp @@ -252,7 +252,7 @@ void NpyScanFunction::tableFunc(TableFunctionInput& input, DataChunk& outputChun std::unique_ptr NpyScanFunction::bindFunc( main::ClientContext* /*context*/, function::TableFuncBindInput* input, - catalog::CatalogContent* /*catalog*/) { + catalog::Catalog* /*catalog*/) { auto scanInput = reinterpret_cast(input); std::vector detectedColumnNames; diff --git a/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp b/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp index 8b1bf1ac1b..ff82314ffd 100644 --- a/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp +++ b/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp @@ -622,7 +622,7 @@ void ParquetScanFunction::tableFunc(TableFunctionInput& input, DataChunk& output std::unique_ptr ParquetScanFunction::bindFunc( main::ClientContext* /*context*/, function::TableFuncBindInput* input, - catalog::CatalogContent* /*catalog*/) { + catalog::Catalog* /*catalog*/) { auto scanInput = reinterpret_cast(input); std::vector detectedColumnNames; std::vector> detectedColumnTypes; diff --git a/src/processor/operator/persistent/reader/rdf/rdf_reader.cpp b/src/processor/operator/persistent/reader/rdf/rdf_reader.cpp index a9e16b5bee..c6cf67c6db 100644 --- a/src/processor/operator/persistent/reader/rdf/rdf_reader.cpp +++ b/src/processor/operator/persistent/reader/rdf/rdf_reader.cpp @@ -289,7 +289,7 @@ void RdfScan::tableFunc(function::TableFunctionInput& input, common::DataChunk& } std::unique_ptr RdfScan::bindFunc(main::ClientContext* /*context*/, - function::TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/) { + function::TableFuncBindInput* input, catalog::Catalog* /*catalog*/) { auto scanInput = reinterpret_cast(input); return std::make_unique( common::LogicalType::copy(scanInput->expectedColumnTypes), scanInput->expectedColumnNames, diff --git a/src/processor/operator/persistent/set_executor.cpp b/src/processor/operator/persistent/set_executor.cpp index 3550438f81..13e6c87984 100644 --- a/src/processor/operator/persistent/set_executor.cpp +++ b/src/processor/operator/persistent/set_executor.cpp @@ -52,7 +52,7 @@ void SingleLabelNodeSetExecutor::set(ExecutionContext* context) { auto lhsPos = nodeIDVector->state->selVector->selectedPositions[0]; auto rhsPos = rhsVector->state->selVector->selectedPositions[0]; setInfo.table->update( - context->clientContext->getActiveTransaction(), setInfo.columnID, nodeIDVector, rhsVector); + context->clientContext->getTx(), setInfo.columnID, nodeIDVector, rhsVector); if (lhsVector != nullptr) { writeToPropertyVector(nodeIDVector, lhsVector, lhsPos, rhsVector, rhsPos); } @@ -73,7 +73,7 @@ void MultiLabelNodeSetExecutor::set(ExecutionContext* context) { auto rhsPos = rhsVector->state->selVector->selectedPositions[0]; auto& setInfo = tableIDToSetInfo.at(nodeID.tableID); setInfo.table->update( - context->clientContext->getActiveTransaction(), setInfo.columnID, nodeIDVector, rhsVector); + context->clientContext->getTx(), setInfo.columnID, nodeIDVector, rhsVector); if (lhsVector != nullptr) { KU_ASSERT(lhsVector->state->selVector->selectedSize == 1); writeToPropertyVector(nodeIDVector, lhsVector, lhsPos, rhsVector, rhsPos); @@ -120,8 +120,8 @@ void SingleLabelRelSetExecutor::set(ExecutionContext* context) { return; } evaluator->evaluate(); - table->update(context->clientContext->getActiveTransaction(), columnID, srcNodeIDVector, - dstNodeIDVector, relIDVector, rhsVector); + table->update(context->clientContext->getTx(), columnID, srcNodeIDVector, dstNodeIDVector, + relIDVector, rhsVector); if (lhsVector != nullptr) { writeToPropertyVector(relIDVector, lhsVector, rhsVector); } @@ -139,8 +139,8 @@ void MultiLabelRelSetExecutor::set(ExecutionContext* context) { return; } auto [table, propertyID] = tableIDToTableAndColumnID.at(relID.tableID); - table->update(context->clientContext->getActiveTransaction(), propertyID, srcNodeIDVector, - dstNodeIDVector, relIDVector, rhsVector); + table->update(context->clientContext->getTx(), propertyID, srcNodeIDVector, dstNodeIDVector, + relIDVector, rhsVector); if (lhsVector != nullptr) { writeToPropertyVector(relIDVector, lhsVector, rhsVector); } diff --git a/src/processor/operator/physical_operator.cpp b/src/processor/operator/physical_operator.cpp index 6ee6167b85..98712f1a30 100644 --- a/src/processor/operator/physical_operator.cpp +++ b/src/processor/operator/physical_operator.cpp @@ -240,7 +240,7 @@ void PhysicalOperator::initLocalState(ResultSet* resultSet_, ExecutionContext* c if (!isSource()) { children[0]->initLocalState(resultSet_, context); } - transaction = context->clientContext->getActiveTransaction(); + transaction = context->clientContext->getTx(); resultSet = resultSet_; registerProfilingMetrics(context->profiler); initLocalStateInternal(resultSet_, context); diff --git a/src/processor/operator/scan/scan_rel_csr_columns.cpp b/src/processor/operator/scan/scan_rel_csr_columns.cpp index ab6c69bf96..b915b6ec5d 100644 --- a/src/processor/operator/scan/scan_rel_csr_columns.cpp +++ b/src/processor/operator/scan/scan_rel_csr_columns.cpp @@ -5,7 +5,7 @@ namespace processor { bool ScanRelCSRColumns::getNextTuplesInternal(ExecutionContext* context) { while (true) { - if (scanState->hasMoreToRead(context->clientContext->getActiveTransaction())) { + if (scanState->hasMoreToRead(context->clientContext->getTx())) { info->table->read(transaction, *scanState, inVector, outVectors); return true; } diff --git a/src/storage/storage_manager.cpp b/src/storage/storage_manager.cpp index 5755d625fa..771d4c13ea 100644 --- a/src/storage/storage_manager.cpp +++ b/src/storage/storage_manager.cpp @@ -5,6 +5,7 @@ using namespace kuzu::catalog; using namespace kuzu::common; +using namespace kuzu::transaction; namespace kuzu { namespace storage { @@ -30,14 +31,14 @@ StorageManager::StorageManager(bool readOnly, const Catalog& catalog, MemoryMana } void StorageManager::loadTables(bool readOnly, const catalog::Catalog& catalog) { - for (auto& schema : catalog.getReadOnlyVersion()->getNodeTableSchemas()) { + for (auto& schema : catalog.getNodeTableSchemas(&DUMMY_READ_TRANSACTION)) { KU_ASSERT(!tables.contains(schema->tableID)); auto nodeTableSchema = reinterpret_cast(schema); tables[schema->tableID] = std::make_unique(dataFH.get(), metadataFH.get(), nodeTableSchema, nodesStatisticsAndDeletedIDs.get(), &memoryManager, wal, readOnly, enableCompression); } - for (auto schema : catalog.getReadOnlyVersion()->getRelTableSchemas()) { + for (auto schema : catalog.getRelTableSchemas(&DUMMY_READ_TRANSACTION)) { KU_ASSERT(!tables.contains(schema->tableID)); auto relTableSchema = dynamic_cast(schema); tables[schema->tableID] = std::make_unique(dataFH.get(), metadataFH.get(), @@ -45,8 +46,9 @@ void StorageManager::loadTables(bool readOnly, const catalog::Catalog& catalog) } } -void StorageManager::createTable(common::table_id_t tableID, catalog::Catalog* catalog) { - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(tableID); +void StorageManager::createTable( + common::table_id_t tableID, Catalog* catalog, Transaction* transaction) { + auto tableSchema = catalog->getTableSchema(transaction, tableID); switch (tableSchema->tableType) { case TableType::NODE: { auto nodeTableSchema = reinterpret_cast(tableSchema); diff --git a/src/storage/wal_replayer.cpp b/src/storage/wal_replayer.cpp index 1dcd4b6539..ccfd85f3e6 100644 --- a/src/storage/wal_replayer.cpp +++ b/src/storage/wal_replayer.cpp @@ -5,10 +5,12 @@ #include "storage/storage_manager.h" #include "storage/storage_utils.h" #include "storage/wal_replayer_utils.h" +#include "transaction/transaction.h" using namespace kuzu::catalog; using namespace kuzu::common; using namespace kuzu::storage; +using namespace kuzu::transaction; namespace kuzu { namespace storage { @@ -163,9 +165,9 @@ void WALReplayer::replayCreateTableRecord(const WALRecord& walRecord) { // record. auto catalogForCheckpointing = getCatalogForRecovery(FileVersionType::WAL_VERSION); if (walRecord.copyTableRecord.tableType == TableType::NODE) { - auto nodeTableSchema = reinterpret_cast( - catalogForCheckpointing->getReadOnlyVersion()->getTableSchema( - walRecord.createTableRecord.tableID)); + auto nodeTableSchema = + reinterpret_cast(catalogForCheckpointing->getTableSchema( + &DUMMY_READ_TRANSACTION, walRecord.createTableRecord.tableID)); WALReplayerUtils::createEmptyHashIndexFiles(nodeTableSchema, wal->getDirectory()); } if (!isRecovering) { @@ -173,8 +175,8 @@ void WALReplayer::replayCreateTableRecord(const WALRecord& walRecord) { // then we need to create the NodeTable object for the newly created node table. // Therefore, this effectively fixes the in-memory data structures (i.e., performs // the in-memory checkpointing). - storageManager->createTable( - walRecord.createTableRecord.tableID, catalogForCheckpointing.get()); + storageManager->createTable(walRecord.createTableRecord.tableID, + catalogForCheckpointing.get(), &DUMMY_READ_TRANSACTION); } } else { // Since DDL statements are single statements that are auto committed, it is @@ -243,7 +245,7 @@ void WALReplayer::replayCopyTableRecord(const kuzu::storage::WALRecord& walRecor // have likely changed they need to reconstruct their page locks). if (walRecord.copyTableRecord.tableType == TableType::NODE) { auto nodeTableSchema = reinterpret_cast( - catalog->getReadOnlyVersion()->getTableSchema(tableID)); + catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID)); storageManager->getNodeTable(tableID)->initializePKIndex( nodeTableSchema, false /* readOnly */); } @@ -264,7 +266,7 @@ void WALReplayer::replayDropTableRecord(const kuzu::storage::WALRecord& walRecor if (isCheckpoint) { auto tableID = walRecord.dropTableRecord.tableID; if (!isRecovering) { - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(tableID); + auto tableSchema = catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID); switch (tableSchema->getTableType()) { case TableType::NODE: { storageManager->dropTable(tableID); @@ -286,7 +288,7 @@ void WALReplayer::replayDropTableRecord(const kuzu::storage::WALRecord& walRecor return; } auto catalogForRecovery = getCatalogForRecovery(FileVersionType::ORIGINAL); - auto tableSchema = catalogForRecovery->getReadOnlyVersion()->getTableSchema(tableID); + auto tableSchema = catalogForRecovery->getTableSchema(&DUMMY_READ_TRANSACTION, tableID); switch (tableSchema->getTableType()) { case TableType::NODE: { // TODO(Guodong): Do nothing for now. Should remove metaDA and reclaim free pages. @@ -311,7 +313,7 @@ void WALReplayer::replayDropPropertyRecord(const kuzu::storage::WALRecord& walRe auto tableID = walRecord.dropPropertyRecord.tableID; auto propertyID = walRecord.dropPropertyRecord.propertyID; if (!isRecovering) { - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(tableID); + auto tableSchema = catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID); switch (tableSchema->getTableType()) { case TableType::NODE: { storageManager->getNodeTable(tableID)->dropColumn( @@ -343,7 +345,7 @@ void WALReplayer::replayAddPropertyRecord(const kuzu::storage::WALRecord& walRec auto tableID = walRecord.addPropertyRecord.tableID; auto propertyID = walRecord.addPropertyRecord.propertyID; if (!isCheckpoint) { - auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(tableID); + auto tableSchema = catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID); switch (tableSchema->getTableType()) { case TableType::NODE: { storageManager->getNodeTable(tableID)->dropColumn(tableSchema->getColumnID(propertyID)); diff --git a/src/transaction/transaction_context.cpp b/src/transaction/transaction_context.cpp index 60708f458f..395496c4f5 100644 --- a/src/transaction/transaction_context.cpp +++ b/src/transaction/transaction_context.cpp @@ -32,7 +32,9 @@ void TransactionContext::beginWriteTransaction() { } void TransactionContext::beginAutoTransaction(bool readOnlyStatement) { - KU_ASSERT(!hasActiveTransaction() && mode == TransactionMode::AUTO); + if (mode == TransactionMode::AUTO && hasActiveTransaction()) { + activeTransaction.reset(); + } beginTransactionInternal( readOnlyStatement ? TransactionType::READ_ONLY : TransactionType::WRITE); } diff --git a/test/runner/e2e_copy_transaction_test.cpp b/test/runner/e2e_copy_transaction_test.cpp index 61e1c81dd7..0bcd15383d 100644 --- a/test/runner/e2e_copy_transaction_test.cpp +++ b/test/runner/e2e_copy_transaction_test.cpp @@ -4,6 +4,7 @@ #include "graph_test/graph_test.h" #include "processor/plan_mapper.h" #include "processor/processor.h" +#include "transaction/transaction.h" using namespace kuzu::catalog; using namespace kuzu::common; @@ -11,6 +12,7 @@ using namespace kuzu::main; using namespace kuzu::processor; using namespace kuzu::storage; using namespace kuzu::testing; +using namespace kuzu::transaction; namespace kuzu { namespace testing { @@ -44,8 +46,6 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { } void validateDatabaseStateBeforeCheckPointCopyNode(table_id_t tableID) { - auto nodeTableSchema = - (NodeTableSchema*)catalog->getReadOnlyVersion()->getTableSchema(tableID); ASSERT_EQ(std::make_unique(database.get()) ->query("MATCH (p:person) return *") ->getNumTuples(), @@ -56,8 +56,6 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { } void validateDatabaseStateAfterCheckPointCopyNode(table_id_t tableID) { - auto nodeTableSchema = - (NodeTableSchema*)catalog->getReadOnlyVersion()->getTableSchema(tableID); validateTinysnbPersonAgeProperty(); ASSERT_EQ(getStorageManager(*database)->getNodesStatisticsAndDeletedIDs()->getMaxNodeOffset( &transaction::DUMMY_READ_TRANSACTION, tableID), @@ -67,7 +65,6 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { void copyNodeCSVCommitAndRecoveryTest(TransactionTestType transactionTestType) { conn->query(createPersonTableCMD); auto preparedStatement = conn->prepare(copyPersonTableCMD); - conn->query("BEGIN TRANSACTION"); auto mapper = PlanMapper( *getStorageManager(*database), getMemoryManager(*database), getCatalog(*database)); auto physicalPlan = @@ -75,7 +72,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { preparedStatement->statementResult->getColumns()); executionContext->clientContext->resetActiveQuery(); getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get()); - auto tableID = catalog->getReadOnlyVersion()->getTableID("person"); + auto tableID = catalog->getTableID(&DUMMY_READ_TRANSACTION, "person"); validateDatabaseStateBeforeCheckPointCopyNode(tableID); if (transactionTestType == TransactionTestType::RECOVERY) { conn->query("COMMIT_SKIP_CHECKPOINT"); @@ -115,8 +112,6 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { } void validateDatabaseStateBeforeCheckPointCopyRel(table_id_t tableID) { - auto relTableSchema = - (RelTableSchema*)catalog->getReadOnlyVersion()->getTableSchema(tableID); auto dummyWriteTrx = transaction::Transaction::getDummyWriteTrx(); ASSERT_EQ(getStorageManager(*database)->getRelsStatistics()->getNextRelOffset( dummyWriteTrx.get(), tableID), @@ -124,8 +119,6 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { } void validateDatabaseStateAfterCheckPointCopyRel(table_id_t knowsTableID) { - auto relTableSchema = - (RelTableSchema*)catalog->getReadOnlyVersion()->getTableSchema(knowsTableID); validateTinysnbKnowsDateProperty(); auto relsStatistics = getStorageManager(*database)->getRelsStatistics(); auto dummyWriteTrx = transaction::Transaction::getDummyWriteTrx(); @@ -142,7 +135,6 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { conn->query(copyPersonTableCMD); conn->query(createKnowsTableCMD); auto preparedStatement = conn->prepare(copyKnowsTableCMD); - conn->query("BEGIN TRANSACTION"); auto mapper = PlanMapper( *getStorageManager(*database), getMemoryManager(*database), getCatalog(*database)); auto physicalPlan = @@ -150,7 +142,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { preparedStatement->statementResult->getColumns()); executionContext->clientContext->resetActiveQuery(); getQueryProcessor(*database)->execute(physicalPlan.get(), executionContext.get()); - auto tableID = catalog->getReadOnlyVersion()->getTableID("knows"); + auto tableID = catalog->getTableID(&DUMMY_READ_TRANSACTION, "knows"); validateDatabaseStateBeforeCheckPointCopyRel(tableID); if (transactionTestType == TransactionTestType::RECOVERY) { conn->query("COMMIT_SKIP_CHECKPOINT"); diff --git a/test/runner/e2e_ddl_test.cpp b/test/runner/e2e_ddl_test.cpp index e66b599c27..3543a94996 100644 --- a/test/runner/e2e_ddl_test.cpp +++ b/test/runner/e2e_ddl_test.cpp @@ -26,8 +26,8 @@ class TinySnbDDLTest : public DBTest { memoryManager = std::make_unique(bufferManager.get()); executionContext = std::make_unique(1 /* numThreads */, profiler.get(), memoryManager.get(), bufferManager.get(), conn->clientContext.get()); - personTableID = catalog->getReadOnlyVersion()->getTableID("person"); - studyAtTableID = catalog->getReadOnlyVersion()->getTableID("studyAt"); + personTableID = catalog->getTableID(&DUMMY_READ_TRANSACTION, "person"); + studyAtTableID = catalog->getTableID(&DUMMY_READ_TRANSACTION, "studyAt"); } void initWithoutLoadingGraph() { @@ -40,7 +40,7 @@ class TinySnbDDLTest : public DBTest { } void validateDatabaseStateAfterCommitCreateNodeTable() { - ASSERT_TRUE(catalog->getReadOnlyVersion()->containsNodeTable("EXAM_PAPER")); + ASSERT_TRUE(catalog->containsTable(&DUMMY_READ_TRANSACTION, "EXAM_PAPER")); ASSERT_EQ(getStorageManager(*database) ->getNodesStatisticsAndDeletedIDs() ->getNumNodeStatisticsAndDeleteIDsPerTable(), @@ -49,35 +49,32 @@ class TinySnbDDLTest : public DBTest { // Since DDL statements are in an auto-commit transaction, we can't use the query interface to // test the recovery algorithm and parallel read. - void createNodeTable(TransactionTestType transactionTestType) { - executeQueryWithoutCommit( - "CREATE NODE TABLE EXAM_PAPER(STUDENT_ID INT64, MARK DOUBLE, PRIMARY KEY(STUDENT_ID))"); - ASSERT_FALSE(catalog->getReadOnlyVersion()->containsNodeTable("EXAM_PAPER")); - if (transactionTestType == TransactionTestType::RECOVERY) { - conn->query("COMMIT_SKIP_CHECKPOINT"); - ASSERT_FALSE(catalog->getReadOnlyVersion()->containsNodeTable("EXAM_PAPER")); - ASSERT_EQ(getStorageManager(*database) - ->getNodesStatisticsAndDeletedIDs() - ->getNumNodeStatisticsAndDeleteIDsPerTable(), - 3); - initWithoutLoadingGraph(); - } else { - conn->query("COMMIT"); + void createTable(TableType tableType, TransactionTestType transactionTestType) { + std::string tableName; + switch (tableType) { + case TableType::NODE: { + tableName = "EXAM_PAPER"; + executeQueryWithoutCommit("CREATE NODE TABLE EXAM_PAPER(STUDENT_ID INT64, MARK DOUBLE, " + "PRIMARY KEY(STUDENT_ID))"); + } break; + case TableType::REL: { + tableName = "likes"; + executeQueryWithoutCommit( + "CREATE REL TABLE likes(FROM person TO organisation, RATING INT64, MANY_ONE)"); + } break; + default: { + KU_UNREACHABLE; } - validateDatabaseStateAfterCommitCreateNodeTable(); - } - - void createRelTable(TransactionTestType transactionTestType) { - executeQueryWithoutCommit( - "CREATE REL TABLE likes(FROM person TO organisation, RATING INT64, MANY_ONE)"); - ASSERT_FALSE(catalog->getReadOnlyVersion()->containsRelTable("likes")); + } + ASSERT_FALSE(catalog->containsTable(&DUMMY_READ_TRANSACTION, tableName)); if (transactionTestType == TransactionTestType::RECOVERY) { conn->query("COMMIT_SKIP_CHECKPOINT"); - ASSERT_FALSE(catalog->getReadOnlyVersion()->containsRelTable("likes")); + ASSERT_FALSE(catalog->containsTable(&DUMMY_READ_TRANSACTION, tableName)); initWithoutLoadingGraph(); } else { conn->query("COMMIT"); } + ASSERT_TRUE(catalog->containsTable(&DUMMY_READ_TRANSACTION, tableName)); } void validateBelongsRelTable() { @@ -106,69 +103,25 @@ class TinySnbDDLTest : public DBTest { "Binder exception: Nodes a and b are not connected through rel e."); } - void createRelMixedRelationCommitAndRecoveryTest(TransactionTestType transactionTestType) { - conn->query("CREATE NODE TABLE country(id INT64, PRIMARY KEY(id));"); - conn->query("CREATE (c:country{id: 0});"); - executeQueryWithoutCommit( - "CREATE REL TABLE belongs(FROM person TO organisation, FROM organisation TO country);"); - ASSERT_FALSE(catalog->getReadOnlyVersion()->containsRelTable("belongs")); - if (transactionTestType == TransactionTestType::RECOVERY) { - conn->query("COMMIT_SKIP_CHECKPOINT"); - initWithoutLoadingGraph(); - ASSERT_TRUE(catalog->getReadOnlyVersion()->containsRelTable("belongs")); - } else { - conn->query("COMMIT"); - ASSERT_TRUE(catalog->getReadOnlyVersion()->containsRelTable("belongs")); - } - executeQueryWithoutCommit("COPY belongs FROM \"" + - TestHelper::appendKuzuRootPath("dataset/tinysnb/eBelongs.csv\"")); - if (transactionTestType == TransactionTestType::RECOVERY) { - conn->query("COMMIT_SKIP_CHECKPOINT"); - initWithoutLoadingGraph(); - } else { - conn->query("COMMIT"); - } - validateBelongsRelTable(); - } - - void dropNodeTableCommitAndRecoveryTest(TransactionTestType transactionTestType) { - conn->query("CREATE NODE TABLE university(address STRING, PRIMARY KEY(address));"); - auto tableSchema = - catalog->getReadOnlyVersion() - ->getTableSchema(catalog->getReadOnlyVersion()->getTableID("university")) - ->copy(); - executeQueryWithoutCommit("DROP TABLE university"); - ASSERT_TRUE(catalog->getReadOnlyVersion()->containsNodeTable("university")); + void dropTableCommitAndRecoveryTest( + std::string tableName, TransactionTestType transactionTestType) { + auto tableID = catalog->getTableID(&DUMMY_READ_TRANSACTION, tableName); + auto tableSchema = catalog->getTableSchema(&DUMMY_READ_TRANSACTION, tableID)->copy(); + executeQueryWithoutCommit(stringFormat("DROP TABLE {}", tableName)); + ASSERT_TRUE(catalog->containsTable(&DUMMY_READ_TRANSACTION, tableName)); if (transactionTestType == TransactionTestType::RECOVERY) { conn->query("COMMIT_SKIP_CHECKPOINT"); - ASSERT_TRUE(catalog->getReadOnlyVersion()->containsNodeTable("university")); + ASSERT_TRUE(catalog->containsTable(&DUMMY_READ_TRANSACTION, tableName)); initWithoutLoadingGraph(); } else { conn->query("COMMIT"); } - ASSERT_FALSE(catalog->getReadOnlyVersion()->containsNodeTable("university")); - } - - void dropRelTableCommitAndRecoveryTest(TransactionTestType transactionTestType) { - auto tableSchema = catalog->getReadOnlyVersion() - ->getTableSchema(catalog->getReadOnlyVersion()->getTableID("knows")) - ->copy(); - executeQueryWithoutCommit("DROP TABLE knows"); - ASSERT_TRUE(catalog->getReadOnlyVersion()->containsRelTable("knows")); - if (transactionTestType == TransactionTestType::RECOVERY) { - conn->query("COMMIT_SKIP_CHECKPOINT"); - ASSERT_TRUE(catalog->getReadOnlyVersion()->containsRelTable("knows")); - initWithoutLoadingGraph(); - } else { - conn->query("COMMIT"); - } - ASSERT_FALSE(catalog->getReadOnlyVersion()->containsRelTable("knows")); + ASSERT_FALSE(catalog->containsTable(&DUMMY_READ_TRANSACTION, tableName)); } void dropNodeTableProperty(TransactionTestType transactionTestType) { executeQueryWithoutCommit("ALTER TABLE person DROP gender"); - ASSERT_TRUE(catalog->getReadOnlyVersion() - ->getTableSchema(personTableID) + ASSERT_TRUE(catalog->getTableSchema(&DUMMY_READ_TRANSACTION, personTableID) ->containProperty("gender")); if (transactionTestType == TransactionTestType::RECOVERY) { conn->query("COMMIT_SKIP_CHECKPOINT"); @@ -177,8 +130,7 @@ class TinySnbDDLTest : public DBTest { } else { conn->query("COMMIT"); } - ASSERT_FALSE(catalog->getReadOnlyVersion() - ->getTableSchema(personTableID) + ASSERT_FALSE(catalog->getTableSchema(&DUMMY_READ_TRANSACTION, personTableID) ->containProperty("gender")); auto result = conn->query("MATCH (p:person) RETURN * ORDER BY p.ID LIMIT 1"); ASSERT_EQ(TestHelper::convertResultToString(*result), @@ -192,8 +144,7 @@ class TinySnbDDLTest : public DBTest { void dropRelTableProperty(TransactionTestType transactionTestType) { executeQueryWithoutCommit("ALTER TABLE studyAt DROP places"); - ASSERT_TRUE(catalog->getReadOnlyVersion() - ->getTableSchema(studyAtTableID) + ASSERT_TRUE(catalog->getTableSchema(&DUMMY_READ_TRANSACTION, studyAtTableID) ->containProperty("places")); if (transactionTestType == TransactionTestType::RECOVERY) { conn->query("COMMIT_SKIP_CHECKPOINT"); @@ -202,8 +153,7 @@ class TinySnbDDLTest : public DBTest { } else { conn->query("COMMIT"); } - ASSERT_FALSE(catalog->getReadOnlyVersion() - ->getTableSchema(personTableID) + ASSERT_FALSE(catalog->getTableSchema(&DUMMY_READ_TRANSACTION, personTableID) ->containProperty("places")); auto result = conn->query( "MATCH (:person)-[s:studyAt]->(:organisation) RETURN * ORDER BY s.year DESC LIMIT 1"); @@ -216,7 +166,6 @@ class TinySnbDDLTest : public DBTest { void executeQueryWithoutCommit(std::string query) { auto preparedStatement = conn->prepare(query); - conn->query("BEGIN TRANSACTION"); auto mapper = PlanMapper( *getStorageManager(*database), getMemoryManager(*database), getCatalog(*database)); auto physicalPlan = @@ -297,21 +246,22 @@ class TinySnbDDLTest : public DBTest { void renameTable(TransactionTestType transactionTestType) { executeQueryWithoutCommit("ALTER TABLE person RENAME TO student"); - ASSERT_EQ(catalog->getWriteVersion()->getTableSchema(personTableID)->tableName, "student"); ASSERT_EQ( - catalog->getReadOnlyVersion()->getTableSchema(personTableID)->tableName, "person"); + catalog->getTableSchema(&DUMMY_WRITE_TRANSACTION, personTableID)->tableName, "student"); + ASSERT_EQ( + catalog->getTableSchema(&DUMMY_READ_TRANSACTION, personTableID)->tableName, "person"); if (transactionTestType == TransactionTestType::RECOVERY) { conn->query("COMMIT_SKIP_CHECKPOINT"); - ASSERT_EQ( - catalog->getWriteVersion()->getTableSchema(personTableID)->tableName, "student"); - ASSERT_EQ( - catalog->getReadOnlyVersion()->getTableSchema(personTableID)->tableName, "person"); + ASSERT_EQ(catalog->getTableSchema(&DUMMY_WRITE_TRANSACTION, personTableID)->tableName, + "student"); + ASSERT_EQ(catalog->getTableSchema(&DUMMY_READ_TRANSACTION, personTableID)->tableName, + "person"); initWithoutLoadingGraph(); } else { conn->query("COMMIT"); } ASSERT_EQ( - catalog->getReadOnlyVersion()->getTableSchema(personTableID)->tableName, "student"); + catalog->getTableSchema(&DUMMY_READ_TRANSACTION, personTableID)->tableName, "student"); auto result = conn->query("MATCH (s:student) return s.age order by s.age"); ASSERT_EQ(TestHelper::convertResultToString(*result), std::vector({"20", "20", "25", "30", "35", "40", "45", "83"})); @@ -319,23 +269,22 @@ class TinySnbDDLTest : public DBTest { void renameProperty(TransactionTestType transactionTestType) { executeQueryWithoutCommit("ALTER TABLE person RENAME fName TO name"); - ASSERT_TRUE( - catalog->getWriteVersion()->getTableSchema(personTableID)->containProperty("name")); - ASSERT_TRUE( - catalog->getReadOnlyVersion()->getTableSchema(personTableID)->containProperty("fName")); + ASSERT_TRUE(catalog->getTableSchema(&DUMMY_WRITE_TRANSACTION, personTableID) + ->containProperty("name")); + ASSERT_TRUE(catalog->getTableSchema(&DUMMY_READ_TRANSACTION, personTableID) + ->containProperty("fName")); if (transactionTestType == TransactionTestType::RECOVERY) { conn->query("COMMIT_SKIP_CHECKPOINT"); - ASSERT_TRUE( - catalog->getWriteVersion()->getTableSchema(personTableID)->containProperty("name")); - ASSERT_TRUE(catalog->getReadOnlyVersion() - ->getTableSchema(personTableID) + ASSERT_TRUE(catalog->getTableSchema(&DUMMY_WRITE_TRANSACTION, personTableID) + ->containProperty("name")); + ASSERT_TRUE(catalog->getTableSchema(&DUMMY_READ_TRANSACTION, personTableID) ->containProperty("fName")); initWithoutLoadingGraph(); } else { conn->query("COMMIT"); } - ASSERT_TRUE( - catalog->getReadOnlyVersion()->getTableSchema(personTableID)->containProperty("name")); + ASSERT_TRUE(catalog->getTableSchema(&DUMMY_READ_TRANSACTION, personTableID) + ->containProperty("name")); auto result = conn->query("MATCH (p:person) return p.name order by p.name"); ASSERT_EQ(TestHelper::convertResultToString(*result), std::vector({"Alice", "Bob", "Carol", "Dan", "Elizabeth", "Farooq", "Greg", @@ -352,35 +301,35 @@ class TinySnbDDLTest : public DBTest { }; TEST_F(TinySnbDDLTest, CreateNodeTableCommitNormalExecution) { - createNodeTable(TransactionTestType::NORMAL_EXECUTION); + createTable(TableType::NODE, TransactionTestType::NORMAL_EXECUTION); } TEST_F(TinySnbDDLTest, CreateNodeTableCommitRecovery) { - createNodeTable(TransactionTestType::RECOVERY); + createTable(TableType::NODE, TransactionTestType::RECOVERY); } TEST_F(TinySnbDDLTest, CreateRelTableCommitNormalExecution) { - createRelTable(TransactionTestType::NORMAL_EXECUTION); + createTable(TableType::REL, TransactionTestType::NORMAL_EXECUTION); } TEST_F(TinySnbDDLTest, CreateRelTableCommitRecovery) { - createRelTable(TransactionTestType::RECOVERY); + createTable(TableType::REL, TransactionTestType::RECOVERY); } TEST_F(TinySnbDDLTest, DropNodeTableCommitNormalExecution) { - dropNodeTableCommitAndRecoveryTest(TransactionTestType::NORMAL_EXECUTION); + dropTableCommitAndRecoveryTest("movies", TransactionTestType::NORMAL_EXECUTION); } TEST_F(TinySnbDDLTest, DropNodeTableCommitRecovery) { - dropNodeTableCommitAndRecoveryTest(TransactionTestType::RECOVERY); + dropTableCommitAndRecoveryTest("movies", TransactionTestType::RECOVERY); } TEST_F(TinySnbDDLTest, DropRelTableCommitNormalExecution) { - dropRelTableCommitAndRecoveryTest(TransactionTestType::NORMAL_EXECUTION); + dropTableCommitAndRecoveryTest("knows", TransactionTestType::NORMAL_EXECUTION); } TEST_F(TinySnbDDLTest, DropRelTableCommitRecovery) { - dropRelTableCommitAndRecoveryTest(TransactionTestType::RECOVERY); + dropTableCommitAndRecoveryTest("knows", TransactionTestType::RECOVERY); } TEST_F(TinySnbDDLTest, DropNodeTablePropertyNormalExecution) { @@ -667,7 +616,8 @@ TEST_F(TinySnbDDLTest, AddListOfListOfStringPropertyToStudyAtTableWithDefaultVal TEST_F(TinySnbDDLTest, AddListOfListOfStringPropertyToStudyAtTableWithDefaultValueRecovery) { addPropertyToStudyAtTableWithDefaultValue("STRING[][]" /* propertyType */, - "[['hello','good','long long string test'],['6'],['very very long string']]" /* defaultVal*/ + "[['hello','good','long long string test'],['6'],['very very long string']]" /* + defaultVal*/ , TransactionTestType::RECOVERY); } diff --git a/test/test_files/ddl/ddl.test b/test/test_files/ddl/ddl.test index 4d7b08e695..2966e85553 100644 --- a/test/test_files/ddl/ddl.test +++ b/test/test_files/ddl/ddl.test @@ -42,10 +42,10 @@ Node table: university has been created. Rel table: nearTo has been created. -STATEMENT DROP TABLE nearTo; ---- 1 -Rel table: nearTo has been dropped. +Table: nearTo has been dropped. -STATEMENT DROP TABLE university ---- 1 -Node table: university has been dropped. +Table: university has been dropped. -STATEMENT ALTER TABLE person DROP fName ---- 1 Drop succeed. diff --git a/tools/python_api/src_cpp/include/pandas/pandas_scan.h b/tools/python_api/src_cpp/include/pandas/pandas_scan.h index 3ed952f251..49100c2e24 100644 --- a/tools/python_api/src_cpp/include/pandas/pandas_scan.h +++ b/tools/python_api/src_cpp/include/pandas/pandas_scan.h @@ -29,7 +29,7 @@ struct PandasScanFunction { static void tableFunc(function::TableFunctionInput& input, common::DataChunk& outputChunk); static std::unique_ptr bindFunc(main::ClientContext* /*context*/, - function::TableFuncBindInput* input, catalog::CatalogContent* catalog); + function::TableFuncBindInput* input, catalog::Catalog* catalog); static std::unique_ptr initSharedState( function::TableFunctionInitInput& input); diff --git a/tools/python_api/src_cpp/pandas/pandas_scan.cpp b/tools/python_api/src_cpp/pandas/pandas_scan.cpp index e5e3c02df2..39ca96492a 100644 --- a/tools/python_api/src_cpp/pandas/pandas_scan.cpp +++ b/tools/python_api/src_cpp/pandas/pandas_scan.cpp @@ -20,7 +20,7 @@ function_set PandasScanFunction::getFunctionSet() { } std::unique_ptr PandasScanFunction::bindFunc( - main::ClientContext* /*context*/, TableFuncBindInput* input, CatalogContent* /*catalog*/) { + main::ClientContext* /*context*/, TableFuncBindInput* input, Catalog* /*catalog*/) { py::gil_scoped_acquire acquire; py::handle df(reinterpret_cast(input->inputs[0]->getValue())); std::vector> columnBindData; diff --git a/tools/python_api/src_cpp/py_connection.cpp b/tools/python_api/src_cpp/py_connection.cpp index 6abc474573..d1a678b4d7 100644 --- a/tools/python_api/src_cpp/py_connection.cpp +++ b/tools/python_api/src_cpp/py_connection.cpp @@ -42,11 +42,11 @@ void PyConnection::setQueryTimeout(uint64_t timeoutInMS) { } static std::unordered_map> transformPythonParameters( - const py::dict& params); + const py::dict& params, Connection* conn); std::unique_ptr PyConnection::execute( PyPreparedStatement* preparedStatement, const py::dict& params) { - auto parameters = transformPythonParameters(params); + auto parameters = transformPythonParameters(params, conn.get()); py::gil_scoped_release release; auto queryResult = conn->executeWithParams(preparedStatement->preparedStatement.get(), std::move(parameters)); @@ -156,11 +156,12 @@ bool PyConnection::isPandasDataframe(const py::object& object) { static Value transformPythonValue(py::handle val); -std::unordered_map> transformPythonParameters( - const py::dict& params) { +std::unordered_map> transformPythonParameters(const py::dict& params, Connection* conn) { std::unordered_map> result; for (auto& [key, value] : params) { if (!py::isinstance(key)) { + //TODO(Chang): remove ROLLBACK once we can guarantee database is deleted after conn + conn->query("ROLLBACK"); throw std::runtime_error("Parameter name must be of type string but get " + py::str(key.get_type()).cast()); } diff --git a/tools/python_api/src_py/connection.py b/tools/python_api/src_py/connection.py index 6446b86697..102073d86d 100644 --- a/tools/python_api/src_py/connection.py +++ b/tools/python_api/src_py/connection.py @@ -74,6 +74,8 @@ def execute(self, query, parameters={}): """ self.init_connection() if type(parameters) != dict: + # TODO(Chang): remove ROLLBACK once we can guarantee database is deleted after conn + self._connection.execute(self.prepare("ROLLBACK")._prepared_statement, {}) raise RuntimeError("Parameters must be a dict") prepared_statement = self.prepare( query) if type(query) == str else query diff --git a/tools/shell/embedded_shell.cpp b/tools/shell/embedded_shell.cpp index 92911ce4c2..c6c9a59b3b 100644 --- a/tools/shell/embedded_shell.cpp +++ b/tools/shell/embedded_shell.cpp @@ -10,6 +10,7 @@ #include "common/string_utils.h" #include "utf8proc.h" #include "utf8proc_wrapper.h" +#include "transaction/transaction.h" using namespace kuzu::common; using namespace kuzu::utf8proc; @@ -63,10 +64,10 @@ static Connection* globalConnection; void EmbeddedShell::updateTableNames() { nodeTableNames.clear(); relTableNames.clear(); - for (auto& tableSchema : database->catalog->getReadOnlyVersion()->getNodeTableSchemas()) { + for (auto& tableSchema : database->catalog->getNodeTableSchemas(&transaction::DUMMY_READ_TRANSACTION)) { nodeTableNames.push_back(tableSchema->tableName); } - for (auto& tableSchema : database->catalog->getReadOnlyVersion()->getRelTableSchemas()) { + for (auto& tableSchema : database->catalog->getRelTableSchemas(&transaction::DUMMY_READ_TRANSACTION)) { relTableNames.push_back(tableSchema->tableName); } }