diff --git a/CMakeLists.txt b/CMakeLists.txt index d37ba66b06..cfdd71d603 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -217,7 +217,7 @@ add_subdirectory(third_party) if(${BUILD_KUZU}) add_definitions(-DKUZU_ROOT_DIRECTORY="${PROJECT_SOURCE_DIR}") add_definitions(-DKUZU_CMAKE_VERSION="${CMAKE_PROJECT_VERSION}") -add_definitions(-DKUZU_EXTENSION_VERSION="0.2.9") +add_definitions(-DKUZU_EXTENSION_VERSION="0.3.0") include_directories(src/include) diff --git a/extension/duckdb/src/duckdb_functions.cpp b/extension/duckdb/src/duckdb_functions.cpp index 5253e5037f..4c6f8be0d6 100644 --- a/extension/duckdb/src/duckdb_functions.cpp +++ b/extension/duckdb/src/duckdb_functions.cpp @@ -10,32 +10,32 @@ using namespace kuzu::common; namespace kuzu { namespace duckdb_extension { -static offset_t DuckDBClearCacheTableFunc(TableFuncInput& input, TableFuncOutput& output) { +static offset_t clearCacheTableFunc(TableFuncInput& input, TableFuncOutput& output) { auto& dataChunk = output.dataChunk; auto sharedState = input.sharedState->ptrCast(); auto morsel = sharedState->getMorsel(); if (!morsel.hasMoreToOutput()) { return 0; } - auto bindData = input.bindData->constPtrCast(); - bindData->databaseManager->invalidateCache(DuckDBStorageExtension::dbType); + auto bindData = input.bindData->constPtrCast(); + bindData->databaseManager->invalidateCache(); dataChunk.getValueVector(0)->setValue(0, - "All attached duckdb database caches have been cleared."); + "All attached database caches have been cleared."); return 1; } -static std::unique_ptr DuckDBClearCacheBindFunc(ClientContext* context, +static std::unique_ptr clearCacheBindFunc(ClientContext* context, TableFuncBindInput*) { std::vector columnNames; std::vector columnTypes; columnNames.emplace_back("message"); columnTypes.emplace_back(*LogicalType::STRING()); - return std::make_unique(context->getDatabaseManager(), + return std::make_unique(context->getDatabaseManager(), std::move(columnTypes), std::move(columnNames), 1 /* maxOffset */); } -DuckDBClearCacheFunction::DuckDBClearCacheFunction() - : TableFunction{name, DuckDBClearCacheTableFunc, DuckDBClearCacheBindFunc, +ClearCacheFunction::ClearCacheFunction() + : TableFunction{name, clearCacheTableFunc, clearCacheBindFunc, function::CallFunction::initSharedState, function::CallFunction::initEmptyLocalState, std::vector{}} {} diff --git a/extension/duckdb/src/duckdb_storage.cpp b/extension/duckdb/src/duckdb_storage.cpp index 5a197ca036..bd331a21f8 100644 --- a/extension/duckdb/src/duckdb_storage.cpp +++ b/extension/duckdb/src/duckdb_storage.cpp @@ -40,7 +40,7 @@ std::unique_ptr attachDuckDB(std::string dbName, std::st DuckDBStorageExtension::DuckDBStorageExtension(main::Database* database) : StorageExtension{attachDuckDB} { - auto duckDBClearCacheFunction = std::make_unique(); + auto duckDBClearCacheFunction = std::make_unique(); extension::ExtensionUtils::registerTableFunction(*database, std::move(duckDBClearCacheFunction)); } diff --git a/extension/duckdb/src/include/duckdb_functions.h b/extension/duckdb/src/include/duckdb_functions.h index 73b81080b3..4b37a8bb85 100644 --- a/extension/duckdb/src/include/duckdb_functions.h +++ b/extension/duckdb/src/include/duckdb_functions.h @@ -8,25 +8,25 @@ namespace kuzu { namespace duckdb_extension { -struct DuckDBClearCacheBindData : public function::CallTableFuncBindData { +struct ClearCacheBindData : public function::CallTableFuncBindData { main::DatabaseManager* databaseManager; - DuckDBClearCacheBindData(main::DatabaseManager* databaseManager, + ClearCacheBindData(main::DatabaseManager* databaseManager, std::vector returnTypes, std::vector returnColumnNames, common::offset_t maxOffset) : CallTableFuncBindData{std::move(returnTypes), std::move(returnColumnNames), maxOffset}, databaseManager{databaseManager} {} std::unique_ptr copy() const override { - return std::make_unique(databaseManager, columnTypes, columnNames, + return std::make_unique(databaseManager, columnTypes, columnNames, maxOffset); } }; -struct DuckDBClearCacheFunction final : public function::TableFunction { - static constexpr const char* name = "duckdb_clear_cache"; +struct ClearCacheFunction final : public function::TableFunction { + static constexpr const char* name = "clear_attached_db_cache"; - DuckDBClearCacheFunction(); + ClearCacheFunction(); }; } // namespace duckdb_extension diff --git a/extension/duckdb/test/test_files/duckdb.test b/extension/duckdb/test/test_files/duckdb.test index 9cb89b85d4..c762b269b9 100644 --- a/extension/duckdb/test/test_files/duckdb.test +++ b/extension/duckdb/test/test_files/duckdb.test @@ -54,12 +54,12 @@ Attached database successfully. Attached database successfully. -STATEMENT CALL SHOW_TABLES() RETURN *; ---- 6 -movies|FOREIGN|tinysnb(DUCKDB)| -organisation|FOREIGN|tinysnb(DUCKDB)| -person|FOREIGN|Other1(DUCKDB)| -person|FOREIGN|tinysnb(DUCKDB)| -tableOfTypes1|FOREIGN|tinysnb(DUCKDB)| -tableOfTypes|FOREIGN|tinysnb(DUCKDB)| +movies|ATTACHED|tinysnb(DUCKDB)| +organisation|ATTACHED|tinysnb(DUCKDB)| +person|ATTACHED|Other1(DUCKDB)| +person|ATTACHED|tinysnb(DUCKDB)| +tableOfTypes1|ATTACHED|tinysnb(DUCKDB)| +tableOfTypes|ATTACHED|tinysnb(DUCKDB)| -STATEMENT LOAD FROM other1.person RETURN *; ---- 4 1 @@ -88,7 +88,7 @@ Used database successfully. -STATEMENT LOAD FROM other1.person RETURN count(*); ---- 1 4 --STATEMENT CALL DUCKDB_CLEAR_CACHE() RETURN *; +-STATEMENT CALL CLEAR_ATTACHED_DB_CACHE() RETURN *; ---- ok -STATEMENT LOAD FROM other1.person RETURN count(*); ---- 1 diff --git a/extension/postgres/CMakeLists.txt b/extension/postgres/CMakeLists.txt index 8cced4f631..4e152f4a26 100644 --- a/extension/postgres/CMakeLists.txt +++ b/extension/postgres/CMakeLists.txt @@ -16,10 +16,10 @@ add_library(postgres_extension ../duckdb/src/duckdb_catalog.cpp ../duckdb/src/duckdb_table_catalog_entry.cpp ../duckdb/src/duckdb_type_converter.cpp + ../duckdb/src/duckdb_functions.cpp src/postgres_extension.cpp src/postgres_storage.cpp - src/postgres_catalog.cpp - src/postgres_functions.cpp) + src/postgres_catalog.cpp) include_directories( src/include diff --git a/extension/postgres/src/include/postgres_functions.h b/extension/postgres/src/include/postgres_functions.h deleted file mode 100644 index 4293a71c0f..0000000000 --- a/extension/postgres/src/include/postgres_functions.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include "function/table_functions.h" - -namespace kuzu { -namespace postgres_extension { - -struct PostgresClearCacheFunction final : public function::TableFunction { - static constexpr const char* name = "postgres_clear_cache"; - - PostgresClearCacheFunction(); -}; - -} // namespace postgres_extension -} // namespace kuzu diff --git a/extension/postgres/src/postgres_functions.cpp b/extension/postgres/src/postgres_functions.cpp deleted file mode 100644 index f94359b627..0000000000 --- a/extension/postgres/src/postgres_functions.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include "postgres_functions.h" - -#include "duckdb_functions.h" -#include "postgres_storage.h" - -using namespace kuzu::function; -using namespace kuzu::main; -using namespace kuzu::common; - -namespace kuzu { -namespace postgres_extension { - -struct PostgresClearCacheBindData : public duckdb_extension::DuckDBClearCacheBindData { - - PostgresClearCacheBindData(DatabaseManager* databaseManager, - std::vector returnTypes, std::vector returnColumnNames, - offset_t maxOffset) - : DuckDBClearCacheBindData{databaseManager, std::move(returnTypes), - std::move(returnColumnNames), maxOffset} {} - - std::unique_ptr copy() const override { - return std::make_unique(databaseManager, columnTypes, - columnNames, maxOffset); - } -}; - -static offset_t postgresClearCacheTableFunc(TableFuncInput& input, TableFuncOutput& output) { - auto& dataChunk = output.dataChunk; - auto sharedState = input.sharedState->ptrCast(); - auto morsel = sharedState->getMorsel(); - if (!morsel.hasMoreToOutput()) { - return 0; - } - auto bindData = input.bindData->constPtrCast(); - bindData->databaseManager->invalidateCache(PostgresStorageExtension::dbType); - dataChunk.getValueVector(0)->setValue(0, - "All attached duckdb database caches have been cleared."); - return 1; -} - -static std::unique_ptr postgresClearCacheBindFunc(ClientContext* context, - TableFuncBindInput*) { - std::vector columnNames; - std::vector columnTypes; - columnNames.emplace_back("message"); - columnTypes.emplace_back(*LogicalType::STRING()); - return std::make_unique(context->getDatabaseManager(), - std::move(columnTypes), std::move(columnNames), 1 /* maxOffset */); -} - -PostgresClearCacheFunction::PostgresClearCacheFunction() - : TableFunction{name, postgresClearCacheTableFunc, postgresClearCacheBindFunc, - function::CallFunction::initSharedState, function::CallFunction::initEmptyLocalState, - std::vector{}} {} - -} // namespace postgres_extension -} // namespace kuzu diff --git a/extension/postgres/src/postgres_storage.cpp b/extension/postgres/src/postgres_storage.cpp index befd3197f7..174fa3142b 100644 --- a/extension/postgres/src/postgres_storage.cpp +++ b/extension/postgres/src/postgres_storage.cpp @@ -2,12 +2,11 @@ #include -#include "catalog/catalog_entry/table_catalog_entry.h" #include "common/exception/runtime.h" #include "common/string_utils.h" +#include "duckdb_functions.h" #include "extension/extension.h" #include "postgres_catalog.h" -#include "postgres_functions.h" namespace kuzu { namespace postgres_extension { @@ -35,9 +34,8 @@ std::unique_ptr attachPostgres(std::string dbName, std:: PostgresStorageExtension::PostgresStorageExtension(main::Database* database) : StorageExtension{attachPostgres} { - auto postgresClearCacheFunction = std::make_unique(); - extension::ExtensionUtils::registerTableFunction(*database, - std::move(postgresClearCacheFunction)); + auto clearCacheFunction = std::make_unique(); + extension::ExtensionUtils::registerTableFunction(*database, std::move(clearCacheFunction)); } bool PostgresStorageExtension::canHandleDB(std::string dbType_) const { diff --git a/extension/postgres/test/test_files/postgres.test b/extension/postgres/test/test_files/postgres.test index 010967c84d..e2a21e5b95 100644 --- a/extension/postgres/test/test_files/postgres.test +++ b/extension/postgres/test/test_files/postgres.test @@ -39,7 +39,7 @@ Binder exception: No database named tinysnb1 has been attached. -STATEMENT LOAD FROM pgscan.movies RETURN count(*); ---- 1 3 --STATEMENT CALL POSTGRES_CLEAR_CACHE() RETURN *; +-STATEMENT CALL CLEAR_ATTACHED_DB_CACHE() RETURN *; ---- ok -STATEMENT LOAD FROM pgscan.movies where length > 2500 RETURN name; ---- 1 @@ -55,14 +55,14 @@ The 😂😃🧘🏻‍♂️🌍🌦️🍞🚗 movie # be displayed. -STATEMENT CALL SHOW_TABLES() RETURN *; ---- 8 -movies|FOREIGN|pgscan(POSTGRES)| -movies|FOREIGN|tinysnb(POSTGRES)| -organisation|FOREIGN|pgscan(POSTGRES)| -organisation|FOREIGN|tinysnb(POSTGRES)| -persontest|FOREIGN|pgscan(POSTGRES)| -persontest|FOREIGN|tinysnb(POSTGRES)| -person|FOREIGN|pgscan(POSTGRES)| -person|FOREIGN|tinysnb(POSTGRES)| +movies|ATTACHED|pgscan(POSTGRES)| +movies|ATTACHED|tinysnb(POSTGRES)| +organisation|ATTACHED|pgscan(POSTGRES)| +organisation|ATTACHED|tinysnb(POSTGRES)| +persontest|ATTACHED|pgscan(POSTGRES)| +persontest|ATTACHED|tinysnb(POSTGRES)| +person|ATTACHED|pgscan(POSTGRES)| +person|ATTACHED|tinysnb(POSTGRES)| -STATEMENT CALL TABLE_INFO('pgscan.person') RETURN *; ---- 14 diff --git a/src/catalog/catalog_content.cpp b/src/catalog/catalog_content.cpp index 1fa15b1c74..422e52c47b 100644 --- a/src/catalog/catalog_content.cpp +++ b/src/catalog/catalog_content.cpp @@ -10,7 +10,6 @@ #include "catalog/catalog_entry/rel_group_catalog_entry.h" #include "catalog/catalog_entry/rel_table_catalog_entry.h" #include "catalog/catalog_entry/scalar_macro_catalog_entry.h" -#include "common/cast.h" #include "common/exception/catalog.h" #include "common/exception/runtime.h" #include "common/file_system/virtual_file_system.h" @@ -65,9 +64,7 @@ table_id_t CatalogContent::createTable(const BoundCreateTableInfo& info) { std::unique_ptr CatalogContent::createNodeTableEntry(table_id_t tableID, const BoundCreateTableInfo& info) const { - auto extraInfo = - ku_dynamic_cast( - info.extraInfo.get()); + auto extraInfo = info.extraInfo->constPtrCast(); auto nodeTableEntry = std::make_unique(info.tableName, tableID, extraInfo->primaryKeyIdx); for (auto& propertyInfo : extraInfo->propertyInfos) { @@ -78,13 +75,11 @@ std::unique_ptr CatalogContent::createNodeTableEntry(table_id_t ta std::unique_ptr CatalogContent::createRelTableEntry(table_id_t tableID, const BoundCreateTableInfo& info) const { - auto extraInfo = - ku_dynamic_cast( - info.extraInfo.get()); - auto srcTableEntry = ku_dynamic_cast( - getTableCatalogEntry(extraInfo->srcTableID)); - auto dstTableEntry = ku_dynamic_cast( - getTableCatalogEntry(extraInfo->dstTableID)); + auto extraInfo = info.extraInfo->constPtrCast(); + auto srcTableEntry = + getTableCatalogEntry(extraInfo->srcTableID)->ptrCast(); + auto dstTableEntry = + getTableCatalogEntry(extraInfo->dstTableID)->ptrCast(); srcTableEntry->addFwdRelTableID(tableID); dstTableEntry->addBWdRelTableID(tableID); auto relTableEntry = @@ -98,9 +93,7 @@ std::unique_ptr CatalogContent::createRelTableEntry(table_id_t tab std::unique_ptr CatalogContent::createRelTableGroupEntry(table_id_t tableID, const BoundCreateTableInfo& info) { - auto extraInfo = - ku_dynamic_cast( - info.extraInfo.get()); + auto extraInfo = info.extraInfo->constPtrCast(); std::vector relTableIDs; relTableIDs.reserve(extraInfo->infos.size()); for (auto& childInfo : extraInfo->infos) { @@ -111,19 +104,15 @@ std::unique_ptr CatalogContent::createRelTableGroupEntry(table_id_ std::unique_ptr CatalogContent::createRdfGraphEntry(table_id_t tableID, const BoundCreateTableInfo& info) { - auto extraInfo = - ku_dynamic_cast( - info.extraInfo.get()); + auto extraInfo = info.extraInfo->constPtrCast(); auto& resourceInfo = extraInfo->resourceInfo; auto& literalInfo = extraInfo->literalInfo; auto& resourceTripleInfo = extraInfo->resourceTripleInfo; auto& literalTripleInfo = extraInfo->literalTripleInfo; auto resourceTripleExtraInfo = - ku_dynamic_cast( - resourceTripleInfo.extraInfo.get()); + resourceTripleInfo.extraInfo->ptrCast(); auto literalTripleExtraInfo = - ku_dynamic_cast( - literalTripleInfo.extraInfo.get()); + literalTripleInfo.extraInfo->ptrCast(); // Resource table auto resourceTableID = createTable(resourceInfo); // Literal table @@ -145,7 +134,7 @@ std::unique_ptr CatalogContent::createRdfGraphEntry(table_id_t tab void CatalogContent::dropTable(table_id_t tableID) { auto tableEntry = getTableCatalogEntry(tableID); if (tableEntry->getType() == CatalogEntryType::REL_GROUP_ENTRY) { - auto relGroupEntry = ku_dynamic_cast(tableEntry); + auto relGroupEntry = tableEntry->ptrCast(); for (auto& relTableID : relGroupEntry->getRelTableIDs()) { dropTable(relTableID); } @@ -156,34 +145,23 @@ void CatalogContent::dropTable(table_id_t tableID) { void CatalogContent::alterTable(const BoundAlterInfo& info) { switch (info.alterType) { case AlterType::RENAME_TABLE: { - auto& renameInfo = - ku_dynamic_cast( - *info.extraInfo); - renameTable(info.tableID, renameInfo.newName); + auto renameInfo = info.extraInfo->constPtrCast(); + renameTable(info.tableID, renameInfo->newName); } break; case AlterType::ADD_PROPERTY: { - auto& addPropInfo = - ku_dynamic_cast( - *info.extraInfo); - auto tableEntry = - ku_dynamic_cast(getTableCatalogEntry(info.tableID)); - tableEntry->addProperty(addPropInfo.propertyName, addPropInfo.dataType.copy()); + auto addPropInfo = info.extraInfo->constPtrCast(); + auto tableEntry = getTableCatalogEntry(info.tableID)->ptrCast(); + tableEntry->addProperty(addPropInfo->propertyName, addPropInfo->dataType.copy()); } break; case AlterType::RENAME_PROPERTY: { - auto& renamePropInfo = - ku_dynamic_cast( - *info.extraInfo); - auto tableEntry = - ku_dynamic_cast(getTableCatalogEntry(info.tableID)); - tableEntry->renameProperty(renamePropInfo.propertyID, renamePropInfo.newName); + auto renamePropInfo = info.extraInfo->constPtrCast(); + auto tableEntry = getTableCatalogEntry(info.tableID)->ptrCast(); + tableEntry->renameProperty(renamePropInfo->propertyID, renamePropInfo->newName); } break; case AlterType::DROP_PROPERTY: { - auto& dropPropInfo = - ku_dynamic_cast( - *info.extraInfo); - auto tableEntry = - ku_dynamic_cast(getTableCatalogEntry(info.tableID)); - tableEntry->dropProperty(dropPropInfo.propertyID); + auto dropPropInfo = info.extraInfo->constPtrCast(); + auto tableEntry = getTableCatalogEntry(info.tableID)->ptrCast(); + tableEntry->dropProperty(dropPropInfo->propertyID); } break; default: { KU_UNREACHABLE; @@ -195,7 +173,7 @@ void CatalogContent::renameTable(table_id_t tableID, const std::string& newName) // TODO(Xiyang/Ziyi): Do we allow renaming of rel table groups? auto tableEntry = getTableCatalogEntry(tableID); if (tableEntry->getType() == CatalogEntryType::RDF_GRAPH_ENTRY) { - auto rdfGraphEntry = ku_dynamic_cast(tableEntry); + auto rdfGraphEntry = tableEntry->constPtrCast(); renameTable(rdfGraphEntry->getResourceTableID(), RDFGraphCatalogEntry::getResourceTableName(newName)); renameTable(rdfGraphEntry->getLiteralTableID(), @@ -276,8 +254,7 @@ void CatalogContent::addFunction(CatalogEntryType entryType, std::string name, function::ScalarMacroFunction* CatalogContent::getScalarMacroFunction( const std::string& name) const { - return ku_dynamic_cast(functions->getEntry(name)) - ->getMacroFunction(); + return functions->getEntry(name)->constPtrCast()->getMacroFunction(); } std::vector CatalogContent::getTableEntries() const { @@ -315,7 +292,7 @@ std::string CatalogContent::getTableName(table_id_t tableID) const { CatalogEntry* CatalogContent::getTableCatalogEntry(table_id_t tableID) const { for (auto& [name, table] : tables->getEntries()) { - auto tableEntry = ku_dynamic_cast(table.get()); + auto tableEntry = table->constPtrCast(); if (tableEntry->getTableID() == tableID) { return table.get(); } @@ -324,18 +301,16 @@ CatalogEntry* CatalogContent::getTableCatalogEntry(table_id_t tableID) const { } table_id_t CatalogContent::getTableID(const std::string& tableName) const { - if (tables->containsEntry(tableName)) { - return ku_dynamic_cast(tables->getEntry(tableName)) - ->getTableID(); - } else { + if (!tables->containsEntry(tableName)) { throw CatalogException{stringFormat("Table: {} does not exist.", tableName)}; } + return tables->getEntry(tableName)->constPtrCast()->getTableID(); } std::vector CatalogContent::getTableIDs(CatalogEntryType catalogType) const { std::vector tableIDs; for (auto& [_, entry] : tables->getEntries()) { - auto tableEntry = ku_dynamic_cast(entry.get()); + auto tableEntry = entry->constPtrCast(); if (tableEntry->getType() == catalogType) { tableIDs.push_back(tableEntry->getTableID()); } diff --git a/src/common/enums/table_type.cpp b/src/common/enums/table_type.cpp index 055ef5551d..f591db6c47 100644 --- a/src/common/enums/table_type.cpp +++ b/src/common/enums/table_type.cpp @@ -23,7 +23,7 @@ std::string TableTypeUtils::toString(TableType tableType) { return "REL_GROUP"; } case TableType::FOREIGN: { - return "FOREIGN"; + return "ATTACHED"; } default: KU_UNREACHABLE; diff --git a/src/extension/extension.cpp b/src/extension/extension.cpp index beddf40050..3692352052 100644 --- a/src/extension/extension.cpp +++ b/src/extension/extension.cpp @@ -64,6 +64,10 @@ void ExtensionUtils::registerTableFunction(main::Database& database, function::function_set functionSet; functionSet.push_back(std::move(function)); auto catalog = database.getCatalog(); + if (catalog->getFunctions(transaction::Transaction::getDummyReadOnlyTrx().get()) + ->containsEntry(name)) { + return; + } catalog->addFunction(catalog::CatalogEntryType::TABLE_FUNCTION_ENTRY, std::move(name), std::move(functionSet)); catalog->checkpointInMemory(); diff --git a/src/include/binder/ddl/bound_alter_info.h b/src/include/binder/ddl/bound_alter_info.h index 9fa5515af9..330982f03b 100644 --- a/src/include/binder/ddl/bound_alter_info.h +++ b/src/include/binder/ddl/bound_alter_info.h @@ -9,6 +9,11 @@ namespace binder { struct BoundExtraAlterInfo { virtual ~BoundExtraAlterInfo() = default; + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast(this); + } + virtual std::unique_ptr copy() const = 0; }; diff --git a/src/include/binder/ddl/bound_create_table_info.h b/src/include/binder/ddl/bound_create_table_info.h index 197d9f5fac..20b3364e8e 100644 --- a/src/include/binder/ddl/bound_create_table_info.h +++ b/src/include/binder/ddl/bound_create_table_info.h @@ -13,6 +13,18 @@ namespace binder { struct BoundExtraCreateCatalogEntryInfo { virtual ~BoundExtraCreateCatalogEntryInfo() = default; + + template + const TARGET* constPtrCast() const { + return common::ku_dynamic_cast( + this); + } + + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + virtual inline std::unique_ptr copy() const = 0; }; diff --git a/src/include/catalog/catalog_entry/catalog_entry.h b/src/include/catalog/catalog_entry/catalog_entry.h index bec303e19d..7a6600b87a 100644 --- a/src/include/catalog/catalog_entry/catalog_entry.h +++ b/src/include/catalog/catalog_entry/catalog_entry.h @@ -38,6 +38,11 @@ class KUZU_API CatalogEntry { return common::ku_dynamic_cast(this); } + template + TARGET* ptrCast() { + return common::ku_dynamic_cast(this); + } + private: CatalogEntryType type; std::string name; diff --git a/src/include/main/database_manager.h b/src/include/main/database_manager.h index b603ac9760..780c27acf5 100644 --- a/src/include/main/database_manager.h +++ b/src/include/main/database_manager.h @@ -16,7 +16,7 @@ class DatabaseManager { bool hasDefaultDatabase() const { return defaultDatabase != ""; } void setDefaultDatabase(const std::string& databaseName); std::vector getAttachedDatabases() const; - void KUZU_API invalidateCache(const std::string& dbType); + void KUZU_API invalidateCache(); private: std::vector> attachedDatabases; diff --git a/src/main/database_manager.cpp b/src/main/database_manager.cpp index dc10abb031..bfbed2baa8 100644 --- a/src/main/database_manager.cpp +++ b/src/main/database_manager.cpp @@ -60,11 +60,9 @@ std::vector DatabaseManager::getAttachedDatabases() const { return attachedDatabasesPtr; } -void DatabaseManager::invalidateCache(const std::string& dbType) { +void DatabaseManager::invalidateCache() { for (auto& attachedDatabase : attachedDatabases) { - if (attachedDatabase->getDBType() == dbType) { - attachedDatabase->invalidateCache(); - } + attachedDatabase->invalidateCache(); } } diff --git a/tools/python_api/test/test_scan_pandas.py b/tools/python_api/test/test_scan_pandas.py index c7573991f5..05ec338dd8 100644 --- a/tools/python_api/test/test_scan_pandas.py +++ b/tools/python_api/test/test_scan_pandas.py @@ -334,3 +334,32 @@ def test_scan_all_null(tmp_path: Path) -> None: assert result.get_next() == [None] assert result.get_next() == [None] assert result.get_next() == [None] + + +def test_copy_from_scan_pandas_result(tmp_path: Path) -> None: + db = kuzu.Database(tmp_path) + conn = kuzu.Connection(db) + df = pd.DataFrame({ + "name": ["Adam", "Karissa", "Zhang", "Noura"], + "age": [30, 40, 50, 25] + }) + conn.execute("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY (name));") + conn.execute("COPY Person FROM (LOAD FROM df WHERE age < 30 RETURN *);") + result = conn.execute("match (p:Person) return p.*") + assert result.get_next() == ["Noura", 25] + assert result.has_next() is False + + +def test_scan_from_py_arrow_pandas(tmp_path: Path) -> None: + db = kuzu.Database(tmp_path) + conn = kuzu.Connection(db) + df = pd.DataFrame({ + "name": ["Adam", "Karissa", "Zhang", "Noura"], + "age": [30, 40, 50, 25] + }).convert_dtypes(dtype_backend="pyarrow") + result = conn.execute("LOAD FROM df RETURN *;") + assert result.get_next() == ["Adam", 30] + assert result.get_next() == ["Karissa", 40] + assert result.get_next() == ["Zhang", 50] + assert result.get_next() == ["Noura", 25] + assert result.has_next() is False diff --git a/tools/shell/embedded_shell.cpp b/tools/shell/embedded_shell.cpp index 5c309543b3..138d8231e1 100644 --- a/tools/shell/embedded_shell.cpp +++ b/tools/shell/embedded_shell.cpp @@ -51,7 +51,7 @@ struct ShellCommand { const char* TAB = " "; -const std::array keywordList = {"CALL", "CREATE", "DELETE", "DETACH", "EXISTS", +const std::array keywordList = {"CALL", "CREATE", "DELETE", "DETACH", "EXISTS", "FOREACH", "LOAD", "MATCH", "MERGE", "OPTIONAL", "REMOVE", "RETURN", "SET", "START", "UNION", "UNWIND", "WITH", "LIMIT", "ORDER", "SKIP", "WHERE", "YIELD", "ASC", "ASCENDING", "ASSERT", "BY", "CSV", "DESC", "DESCENDING", "ON", "ALL", "CASE", "ELSE", "END", "THEN", "WHEN", "AND", @@ -62,7 +62,7 @@ const std::array keywordList = {"CALL", "CREATE", "DELETE", "D "MINUS", "COUNT", "PRIMARY", "COPY", "RDFGRAPH", "ALTER", "RENAME", "COMMENT", "MACRO", "GLOB", "COLUMN", "GROUP", "DEFAULT", "TO", "BEGIN", "TRANSACTION", "READ", "ONLY", "WRITE", "COMMIT_SKIP_CHECKPOINT", "ROLLBACK", "ROLLBACK_SKIP_CHECKPOINT", "INSTALL", "EXTENSION", - "SHORTEST", "ATTACH", "IMPORT", "EXPORT"}; + "SHORTEST", "ATTACH", "IMPORT", "EXPORT", "USE"}; const char* keywordColorPrefix = "\033[32m\033[1m"; const char* keywordResetPostfix = "\033[39m\033[22m";