Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error message for extensions #3397

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions src/binder/bind/bind_extension.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,54 @@
#include "binder/binder.h"
#include "binder/bound_extension_statement.h"
#include "common/cast.h"
#include "common/exception/binder.h"
#include "common/file_system/local_file_system.h"
#include "extension/extension.h"
#include "parser/extension_statement.h"

using namespace kuzu::parser;

namespace kuzu {
namespace binder {

static void bindInstallExtension(const ExtensionStatement* extensionStatement) {
if (extensionStatement->getAction() == ExtensionAction::INSTALL) {
auto extensionName = extensionStatement->getPath();
if (!ExtensionUtils::isOfficialExtension(extensionName)) {
throw common::BinderException(common::stringFormat(
"{} is not an official extension.\nNon-official extensions "
"can be installed directly by: `LOAD EXTENSION [EXTENSION_PATH]`.",
extensionName));
}
}
}

static void bindLoadExtension(const ExtensionStatement* extensionStatement) {
if (ExtensionUtils::isOfficialExtension(extensionStatement->getPath())) {
return;
}
auto localFileSystem = common::LocalFileSystem();
if (!localFileSystem.fileOrPathExists(extensionStatement->getPath())) {
throw common::BinderException(
common::stringFormat("The extension {} is neither an official extension, nor does "
"the extension path: '{}' exists.",
extensionStatement->getPath(), extensionStatement->getPath()));
}
}

std::unique_ptr<BoundStatement> Binder::bindExtension(const Statement& statement) {
auto extensionStatement =
common::ku_dynamic_cast<const Statement&, const ExtensionStatement&>(statement);
return std::make_unique<BoundExtensionStatement>(extensionStatement.getAction(),
extensionStatement.getPath());
auto extensionStatement = statement.constPtrCast<ExtensionStatement>();
switch (extensionStatement->getAction()) {
case ExtensionAction::INSTALL:
bindInstallExtension(extensionStatement);
break;
case ExtensionAction::LOAD:
bindLoadExtension(extensionStatement);
break;
default:
KU_UNREACHABLE;
}
return std::make_unique<BoundExtensionStatement>(extensionStatement->getAction(),
extensionStatement->getPath());
}

} // namespace binder
Expand Down
10 changes: 10 additions & 0 deletions src/extension/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ void ExtensionUtils::registerTableFunction(main::Database& database,
catalog->checkpointInMemory();
}

bool ExtensionUtils::isOfficialExtension(const std::string& extension) {
auto extensionUpperCase = common::StringUtils::getUpper(extension);
for (auto& officialExtension : OFFICIAL_EXTENSION) {
if (officialExtension == extensionUpperCase) {
return true;
}
}
return false;
}

void ExtensionOptions::addExtensionOption(std::string name, common::LogicalTypeID type,
common::Value defaultValue) {
common::StringUtils::toLower(name);
Expand Down
4 changes: 4 additions & 0 deletions src/include/extension/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ struct ExtensionUtils {
static constexpr const char* EXTENSION_REPO =
"http://extension.kuzudb.com/v{}/{}/lib{}.kuzu_extension";

static constexpr const char* OFFICIAL_EXTENSION[] = {"HTTPFS", "POSTGRES", "DUCKDB"};

static std::string getExtensionPath(const std::string& extensionDir, const std::string& name);

static bool isFullPath(const std::string& extension);
Expand All @@ -38,6 +40,8 @@ struct ExtensionUtils {

KUZU_API static void registerTableFunction(main::Database& database,
std::unique_ptr<function::TableFunction> function);

static bool isOfficialExtension(const std::string& extension);
};

struct ExtensionOptions {
Expand Down
12 changes: 12 additions & 0 deletions test/test_files/extension/extension.test
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,15 @@ Waterloo|150000
-STATEMENT INSTALL http;
---- error
IO exception: HTTP Returns: 404, Failed to download extension: "http" from extension.kuzudb.com/v0.1.0/osx_arm64/libhttp.kuzu_extension.

-CASE InstallUnofficialExtensions
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
-STATEMENT INSTALL sqlitescanner;
---- error
Binder exception: sqlitescanner is not an official extension.
Non-official extensions can be installed directly by: `LOAD EXTENSION [EXTENSION_PATH]`.
-STATEMENT LOAD EXTENSION sqlitescanner;
---- error
Binder exception: The extension sqlitescanner is neither an official extension, nor does the extension path: 'sqlitescanner' exists.
-STATEMENT LOAD EXTENSION '/tmp/iceberg';
---- error
Binder exception: The extension /tmp/iceberg is neither an official extension, nor does the extension path: '/tmp/iceberg' exists.
5 changes: 2 additions & 3 deletions tools/shell/embedded_shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "common/task_system/progress_bar.h"
#include "transaction/transaction.h"
#include "utf8proc.h"
#include "utf8proc_wrapper.h"
Expand Down Expand Up @@ -52,7 +51,7 @@ struct ShellCommand {

const char* TAB = " ";

const std::array<const char*, 104> keywordList = {"CALL", "CREATE", "DELETE", "DETACH", "EXISTS",
const std::array<const char*, 106> keywordList = {"CALL", "CREATE", "DELETE", "DETACH", "EXISTS",
"FOREACH", "LOAD", "MATCH", "MERGE", "OPTIONAL", "REMOVE", "RETURN", "SET", "START", "UNION",
"UNWIND", "WITH", "LIMIT", "ORDER", "SKIP", "WHERE", "YIELD", "ASC", "ASCENDING", "ASSERT",
"BY", "CSV", "DESC", "DESCENDING", "ON", "ALL", "CASE", "ELSE", "END", "THEN", "WHEN", "AND",
Expand All @@ -63,7 +62,7 @@ const std::array<const char*, 104> keywordList = {"CALL", "CREATE", "DELETE", "D
"MINUS", "COUNT", "PRIMARY", "COPY", "RDFGRAPH", "ALTER", "RENAME", "COMMENT", "MACRO", "GLOB",
"COLUMN", "GROUP", "DEFAULT", "TO", "BEGIN", "TRANSACTION", "READ", "ONLY", "WRITE",
"COMMIT_SKIP_CHECKPOINT", "ROLLBACK", "ROLLBACK_SKIP_CHECKPOINT", "INSTALL", "EXTENSION",
"SHORTEST", "ATTACH"};
"SHORTEST", "ATTACH", "IMPORT", "EXPORT"};

const char* keywordColorPrefix = "\033[32m\033[1m";
const char* keywordResetPostfix = "\033[39m\033[22m";
Expand Down
Loading