From a8ea1fc2ba633aacd66cabedae7ec9d313991023 Mon Sep 17 00:00:00 2001 From: ziyi chen Date: Mon, 29 Apr 2024 09:22:49 +0800 Subject: [PATCH] Better error message for extensions (#3397) --- src/binder/bind/bind_extension.cpp | 46 +++++++++++++++++++++--- src/extension/extension.cpp | 10 ++++++ src/include/extension/extension.h | 4 +++ test/test_files/extension/extension.test | 12 +++++++ tools/shell/embedded_shell.cpp | 5 ++- 5 files changed, 69 insertions(+), 8 deletions(-) diff --git a/src/binder/bind/bind_extension.cpp b/src/binder/bind/bind_extension.cpp index a5b5b0b4d66..a53728a73a7 100644 --- a/src/binder/bind/bind_extension.cpp +++ b/src/binder/bind/bind_extension.cpp @@ -1,6 +1,8 @@ #include "binder/binder.h" #include "binder/bound_extension_statement.h" -#include "common/cast.h" +#include "common/exception/binder.h" +#include "common/file_system/local_file_system.h" +#include "extension/extension.h" #include "parser/extension_statement.h" using namespace kuzu::parser; @@ -8,11 +10,45 @@ using namespace kuzu::parser; namespace kuzu { namespace binder { +static void bindInstallExtension(const ExtensionStatement* extensionStatement) { + if (extensionStatement->getAction() == ExtensionAction::INSTALL) { + auto extensionName = extensionStatement->getPath(); + if (!ExtensionUtils::isOfficialExtension(extensionName)) { + throw common::BinderException(common::stringFormat( + "{} is not an official extension.\nNon-official extensions " + "can be installed directly by: `LOAD EXTENSION [EXTENSION_PATH]`.", + extensionName)); + } + } +} + +static void bindLoadExtension(const ExtensionStatement* extensionStatement) { + if (ExtensionUtils::isOfficialExtension(extensionStatement->getPath())) { + return; + } + auto localFileSystem = common::LocalFileSystem(); + if (!localFileSystem.fileOrPathExists(extensionStatement->getPath())) { + throw common::BinderException( + common::stringFormat("The extension {} is neither an official extension, nor does " + "the extension path: '{}' exists.", + extensionStatement->getPath(), extensionStatement->getPath())); + } +} + std::unique_ptr Binder::bindExtension(const Statement& statement) { - auto extensionStatement = - common::ku_dynamic_cast(statement); - return std::make_unique(extensionStatement.getAction(), - extensionStatement.getPath()); + auto extensionStatement = statement.constPtrCast(); + switch (extensionStatement->getAction()) { + case ExtensionAction::INSTALL: + bindInstallExtension(extensionStatement); + break; + case ExtensionAction::LOAD: + bindLoadExtension(extensionStatement); + break; + default: + KU_UNREACHABLE; + } + return std::make_unique(extensionStatement->getAction(), + extensionStatement->getPath()); } } // namespace binder diff --git a/src/extension/extension.cpp b/src/extension/extension.cpp index d0e4c8648f8..beddf40050f 100644 --- a/src/extension/extension.cpp +++ b/src/extension/extension.cpp @@ -69,6 +69,16 @@ void ExtensionUtils::registerTableFunction(main::Database& database, catalog->checkpointInMemory(); } +bool ExtensionUtils::isOfficialExtension(const std::string& extension) { + auto extensionUpperCase = common::StringUtils::getUpper(extension); + for (auto& officialExtension : OFFICIAL_EXTENSION) { + if (officialExtension == extensionUpperCase) { + return true; + } + } + return false; +} + void ExtensionOptions::addExtensionOption(std::string name, common::LogicalTypeID type, common::Value defaultValue) { common::StringUtils::toLower(name); diff --git a/src/include/extension/extension.h b/src/include/extension/extension.h index 30e229ef6e1..f14025a0344 100644 --- a/src/include/extension/extension.h +++ b/src/include/extension/extension.h @@ -30,6 +30,8 @@ struct ExtensionUtils { static constexpr const char* EXTENSION_REPO = "http://extension.kuzudb.com/v{}/{}/lib{}.kuzu_extension"; + static constexpr const char* OFFICIAL_EXTENSION[] = {"HTTPFS", "POSTGRES", "DUCKDB"}; + static std::string getExtensionPath(const std::string& extensionDir, const std::string& name); static bool isFullPath(const std::string& extension); @@ -38,6 +40,8 @@ struct ExtensionUtils { KUZU_API static void registerTableFunction(main::Database& database, std::unique_ptr function); + + static bool isOfficialExtension(const std::string& extension); }; struct ExtensionOptions { diff --git a/test/test_files/extension/extension.test b/test/test_files/extension/extension.test index aa16402a30d..9c9a5a9258e 100644 --- a/test/test_files/extension/extension.test +++ b/test/test_files/extension/extension.test @@ -25,3 +25,15 @@ Waterloo|150000 -STATEMENT INSTALL http; ---- error IO exception: HTTP Returns: 404, Failed to download extension: "http" from extension.kuzudb.com/v0.1.0/osx_arm64/libhttp.kuzu_extension. + +-CASE InstallUnofficialExtensions +-STATEMENT INSTALL sqlitescanner; +---- error +Binder exception: sqlitescanner is not an official extension. +Non-official extensions can be installed directly by: `LOAD EXTENSION [EXTENSION_PATH]`. +-STATEMENT LOAD EXTENSION sqlitescanner; +---- error +Binder exception: The extension sqlitescanner is neither an official extension, nor does the extension path: 'sqlitescanner' exists. +-STATEMENT LOAD EXTENSION '/tmp/iceberg'; +---- error +Binder exception: The extension /tmp/iceberg is neither an official extension, nor does the extension path: '/tmp/iceberg' exists. diff --git a/tools/shell/embedded_shell.cpp b/tools/shell/embedded_shell.cpp index d9c555fcd41..d216d0d564b 100644 --- a/tools/shell/embedded_shell.cpp +++ b/tools/shell/embedded_shell.cpp @@ -16,7 +16,6 @@ #include "catalog/catalog.h" #include "catalog/catalog_entry/node_table_catalog_entry.h" #include "catalog/catalog_entry/rel_table_catalog_entry.h" -#include "common/task_system/progress_bar.h" #include "transaction/transaction.h" #include "utf8proc.h" #include "utf8proc_wrapper.h" @@ -52,7 +51,7 @@ struct ShellCommand { const char* TAB = " "; -const std::array keywordList = {"CALL", "CREATE", "DELETE", "DETACH", "EXISTS", +const std::array keywordList = {"CALL", "CREATE", "DELETE", "DETACH", "EXISTS", "FOREACH", "LOAD", "MATCH", "MERGE", "OPTIONAL", "REMOVE", "RETURN", "SET", "START", "UNION", "UNWIND", "WITH", "LIMIT", "ORDER", "SKIP", "WHERE", "YIELD", "ASC", "ASCENDING", "ASSERT", "BY", "CSV", "DESC", "DESCENDING", "ON", "ALL", "CASE", "ELSE", "END", "THEN", "WHEN", "AND", @@ -63,7 +62,7 @@ const std::array keywordList = {"CALL", "CREATE", "DELETE", "D "MINUS", "COUNT", "PRIMARY", "COPY", "RDFGRAPH", "ALTER", "RENAME", "COMMENT", "MACRO", "GLOB", "COLUMN", "GROUP", "DEFAULT", "TO", "BEGIN", "TRANSACTION", "READ", "ONLY", "WRITE", "COMMIT_SKIP_CHECKPOINT", "ROLLBACK", "ROLLBACK_SKIP_CHECKPOINT", "INSTALL", "EXTENSION", - "SHORTEST", "ATTACH"}; + "SHORTEST", "ATTACH", "IMPORT", "EXPORT"}; const char* keywordColorPrefix = "\033[32m\033[1m"; const char* keywordResetPostfix = "\033[39m\033[22m";