Skip to content

Commit

Permalink
Better error message for extensions (#3397)
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin authored and manh9203 committed Apr 29, 2024
1 parent 10a1b7f commit a8ea1fc
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 8 deletions.
46 changes: 41 additions & 5 deletions src/binder/bind/bind_extension.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,54 @@
#include "binder/binder.h"
#include "binder/bound_extension_statement.h"
#include "common/cast.h"
#include "common/exception/binder.h"
#include "common/file_system/local_file_system.h"
#include "extension/extension.h"
#include "parser/extension_statement.h"

using namespace kuzu::parser;

namespace kuzu {
namespace binder {

static void bindInstallExtension(const ExtensionStatement* extensionStatement) {
if (extensionStatement->getAction() == ExtensionAction::INSTALL) {
auto extensionName = extensionStatement->getPath();
if (!ExtensionUtils::isOfficialExtension(extensionName)) {
throw common::BinderException(common::stringFormat(
"{} is not an official extension.\nNon-official extensions "
"can be installed directly by: `LOAD EXTENSION [EXTENSION_PATH]`.",
extensionName));
}
}
}

static void bindLoadExtension(const ExtensionStatement* extensionStatement) {
if (ExtensionUtils::isOfficialExtension(extensionStatement->getPath())) {
return;
}
auto localFileSystem = common::LocalFileSystem();
if (!localFileSystem.fileOrPathExists(extensionStatement->getPath())) {
throw common::BinderException(
common::stringFormat("The extension {} is neither an official extension, nor does "
"the extension path: '{}' exists.",
extensionStatement->getPath(), extensionStatement->getPath()));
}
}

std::unique_ptr<BoundStatement> Binder::bindExtension(const Statement& statement) {
auto extensionStatement =
common::ku_dynamic_cast<const Statement&, const ExtensionStatement&>(statement);
return std::make_unique<BoundExtensionStatement>(extensionStatement.getAction(),
extensionStatement.getPath());
auto extensionStatement = statement.constPtrCast<ExtensionStatement>();
switch (extensionStatement->getAction()) {
case ExtensionAction::INSTALL:
bindInstallExtension(extensionStatement);
break;
case ExtensionAction::LOAD:
bindLoadExtension(extensionStatement);
break;
default:
KU_UNREACHABLE;
}
return std::make_unique<BoundExtensionStatement>(extensionStatement->getAction(),
extensionStatement->getPath());
}

} // namespace binder
Expand Down
10 changes: 10 additions & 0 deletions src/extension/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ void ExtensionUtils::registerTableFunction(main::Database& database,
catalog->checkpointInMemory();
}

bool ExtensionUtils::isOfficialExtension(const std::string& extension) {
auto extensionUpperCase = common::StringUtils::getUpper(extension);
for (auto& officialExtension : OFFICIAL_EXTENSION) {
if (officialExtension == extensionUpperCase) {
return true;
}
}
return false;
}

void ExtensionOptions::addExtensionOption(std::string name, common::LogicalTypeID type,
common::Value defaultValue) {
common::StringUtils::toLower(name);
Expand Down
4 changes: 4 additions & 0 deletions src/include/extension/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ struct ExtensionUtils {
static constexpr const char* EXTENSION_REPO =
"http://extension.kuzudb.com/v{}/{}/lib{}.kuzu_extension";

static constexpr const char* OFFICIAL_EXTENSION[] = {"HTTPFS", "POSTGRES", "DUCKDB"};

static std::string getExtensionPath(const std::string& extensionDir, const std::string& name);

static bool isFullPath(const std::string& extension);
Expand All @@ -38,6 +40,8 @@ struct ExtensionUtils {

KUZU_API static void registerTableFunction(main::Database& database,
std::unique_ptr<function::TableFunction> function);

static bool isOfficialExtension(const std::string& extension);
};

struct ExtensionOptions {
Expand Down
12 changes: 12 additions & 0 deletions test/test_files/extension/extension.test
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,15 @@ Waterloo|150000
-STATEMENT INSTALL http;
---- error
IO exception: HTTP Returns: 404, Failed to download extension: "http" from extension.kuzudb.com/v0.1.0/osx_arm64/libhttp.kuzu_extension.

-CASE InstallUnofficialExtensions
-STATEMENT INSTALL sqlitescanner;
---- error
Binder exception: sqlitescanner is not an official extension.
Non-official extensions can be installed directly by: `LOAD EXTENSION [EXTENSION_PATH]`.
-STATEMENT LOAD EXTENSION sqlitescanner;
---- error
Binder exception: The extension sqlitescanner is neither an official extension, nor does the extension path: 'sqlitescanner' exists.
-STATEMENT LOAD EXTENSION '/tmp/iceberg';
---- error
Binder exception: The extension /tmp/iceberg is neither an official extension, nor does the extension path: '/tmp/iceberg' exists.
5 changes: 2 additions & 3 deletions tools/shell/embedded_shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "common/task_system/progress_bar.h"
#include "transaction/transaction.h"
#include "utf8proc.h"
#include "utf8proc_wrapper.h"
Expand Down Expand Up @@ -52,7 +51,7 @@ struct ShellCommand {

const char* TAB = " ";

const std::array<const char*, 104> keywordList = {"CALL", "CREATE", "DELETE", "DETACH", "EXISTS",
const std::array<const char*, 106> keywordList = {"CALL", "CREATE", "DELETE", "DETACH", "EXISTS",
"FOREACH", "LOAD", "MATCH", "MERGE", "OPTIONAL", "REMOVE", "RETURN", "SET", "START", "UNION",
"UNWIND", "WITH", "LIMIT", "ORDER", "SKIP", "WHERE", "YIELD", "ASC", "ASCENDING", "ASSERT",
"BY", "CSV", "DESC", "DESCENDING", "ON", "ALL", "CASE", "ELSE", "END", "THEN", "WHEN", "AND",
Expand All @@ -63,7 +62,7 @@ const std::array<const char*, 104> keywordList = {"CALL", "CREATE", "DELETE", "D
"MINUS", "COUNT", "PRIMARY", "COPY", "RDFGRAPH", "ALTER", "RENAME", "COMMENT", "MACRO", "GLOB",
"COLUMN", "GROUP", "DEFAULT", "TO", "BEGIN", "TRANSACTION", "READ", "ONLY", "WRITE",
"COMMIT_SKIP_CHECKPOINT", "ROLLBACK", "ROLLBACK_SKIP_CHECKPOINT", "INSTALL", "EXTENSION",
"SHORTEST", "ATTACH"};
"SHORTEST", "ATTACH", "IMPORT", "EXPORT"};

const char* keywordColorPrefix = "\033[32m\033[1m";
const char* keywordResetPostfix = "\033[39m\033[22m";
Expand Down

0 comments on commit a8ea1fc

Please sign in to comment.