From 2b1f9c4aa0aad252a3e0d180e0b9e1455b9e7faf Mon Sep 17 00:00:00 2001 From: xiyang Date: Fri, 8 Mar 2024 15:15:05 -0500 Subject: [PATCH] Abstract client config --- src/binder/bind/bind_graph_pattern.cpp | 14 +-- src/common/task_system/task_scheduler.cpp | 2 +- src/include/common/constants.h | 6 +- src/include/main/client_config.h | 24 ++++++ src/include/main/client_context.h | 85 ++++++++----------- src/include/main/settings.h | 25 +++--- .../processor/operator/physical_operator.h | 10 +-- src/main/client_context.cpp | 79 +++++++++-------- src/optimizer/optimizer.cpp | 2 +- src/processor/operator/physical_operator.cpp | 10 +++ 10 files changed, 139 insertions(+), 118 deletions(-) create mode 100644 src/include/main/client_config.h diff --git a/src/binder/bind/bind_graph_pattern.cpp b/src/binder/bind/bind_graph_pattern.cpp index c3bd5c68ced..51dc25ac33d 100644 --- a/src/binder/bind/bind_graph_pattern.cpp +++ b/src/binder/bind/bind_graph_pattern.cpp @@ -454,20 +454,20 @@ std::pair Binder::bindVariableLengthRelBound( function::CastString::operation( ku_string_t{recursiveInfo->lowerBound.c_str(), recursiveInfo->lowerBound.length()}, lowerBound); - auto upperBound = clientContext->varLengthExtendMaxDepth; + auto maxDepth = clientContext->getClientConfig()->varLengthMaxDepth; + auto upperBound = maxDepth; if (!recursiveInfo->upperBound.empty()) { function::CastString::operation( ku_string_t{recursiveInfo->upperBound.c_str(), recursiveInfo->upperBound.length()}, upperBound); } if (lowerBound > upperBound) { - throw BinderException( - "Lower bound of rel " + relPattern.getVariableName() + " is greater than upperBound."); + throw BinderException(stringFormat( + "Lower bound of rel {} is greater than upperBound.", relPattern.getVariableName())); } - if (upperBound > clientContext->varLengthExtendMaxDepth) { - throw BinderException( - "Upper bound of rel " + relPattern.getVariableName() + - " exceeds maximum: " + std::to_string(clientContext->varLengthExtendMaxDepth) + "."); + if (upperBound > maxDepth) { + throw BinderException(stringFormat("Upper bound of rel {} exceeds maximum: {}.", + relPattern.getVariableName(), std::to_string(maxDepth))); } if ((relPattern.getRelType() == QueryRelType::ALL_SHORTEST || relPattern.getRelType() == QueryRelType::SHORTEST) && diff --git a/src/common/task_system/task_scheduler.cpp b/src/common/task_system/task_scheduler.cpp index a01a4114584..a107c920994 100644 --- a/src/common/task_system/task_scheduler.cpp +++ b/src/common/task_system/task_scheduler.cpp @@ -40,7 +40,7 @@ void TaskScheduler::scheduleTaskAndWaitOrError( taskLck.unlock(); break; } - if (context->clientContext->isTimeOutEnabled()) { + if (context->clientContext->hasTimeout()) { timeout = context->clientContext->getTimeoutRemainingInMS(); if (timeout == 0) { context->clientContext->interrupt(); diff --git a/src/include/common/constants.h b/src/include/common/constants.h index d769cf0cb49..96e6091a89e 100644 --- a/src/include/common/constants.h +++ b/src/include/common/constants.h @@ -173,9 +173,11 @@ struct PlannerKnobs { static constexpr uint64_t SIP_RATIO = 5; }; -struct ClientContextConstants { - // We disable query timeout by default. +struct ClientConfigDefault { + // 0 means timeout is disabled by default. static constexpr uint64_t TIMEOUT_IN_MS = 0; + static constexpr uint32_t VAR_LENGTH_MAX_DEPTH = 30; + static constexpr bool ENABLE_SEMI_MASK = true; }; struct OrderByConstants { diff --git a/src/include/main/client_config.h b/src/include/main/client_config.h new file mode 100644 index 00000000000..8649c5bd2c0 --- /dev/null +++ b/src/include/main/client_config.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace kuzu { +namespace main { + +struct ClientConfig { + // System home directory. + std::string homeDirectory; + // File search path. + std::string fileSearchPath; + // If using semi mask in join. + bool enableSemiMask; + // Number of threads for execution. + uint64_t numThreads; + // Timeout (milliseconds) + uint64_t timeoutInMS; + // variable length maximum depth + uint32_t varLengthMaxDepth; +}; + +} // namespace main +} // namespace kuzu diff --git a/src/include/main/client_context.h b/src/include/main/client_context.h index 032350a949e..8c784ea9200 100644 --- a/src/include/main/client_context.h +++ b/src/include/main/client_context.h @@ -6,6 +6,7 @@ #include #include +#include "client_config.h" #include "common/timer.h" #include "common/types/value/value.h" #include "function/scalar_function.h" @@ -47,77 +48,59 @@ class ClientContext { friend class Connection; friend class binder::Binder; friend class binder::ExpressionBinder; - friend class testing::TinySnbDDLTest; - friend class testing::TinySnbCopyCSVTransactionTest; - friend struct ThreadsSetting; - friend struct TimeoutSetting; - friend struct VarLengthExtendMaxDepthSetting; - friend struct EnableSemiMaskSetting; - friend struct HomeDirectorySetting; - friend struct FileSearchPathSetting; public: explicit ClientContext(Database* database); - inline void interrupt() { activeQuery.interrupted = true; } - - bool isInterrupted() const { return activeQuery.interrupted; } - - inline bool isTimeOutEnabled() const { return timeoutInMS != 0; } - - inline uint64_t getTimeoutRemainingInMS() { - KU_ASSERT(isTimeOutEnabled()); - auto elapsed = activeQuery.timer.getElapsedTimeInMS(); - return elapsed >= timeoutInMS ? 0 : timeoutInMS - elapsed; - } - - inline bool isEnableSemiMask() const { return enableSemiMask; } - - void startTimingIfEnabled(); - + // Client config + const ClientConfig* getClientConfig() const { return &config; } + ClientConfig* getClientConfigUnsafe() { return &config; } KUZU_API common::Value getCurrentSetting(const std::string& optionName); + // Timer and timeout + void interrupt() { activeQuery.interrupted = true; } + bool interrupted() const { return activeQuery.interrupted; } + bool hasTimeout() const { return config.timeoutInMS != 0; } + void setQueryTimeOut(uint64_t timeoutInMS); + uint64_t getQueryTimeOut(); + void startTimer(); + uint64_t getTimeoutRemainingInMS(); + void resetActiveQuery() { activeQuery.reset(); } + // Parallelism + void setMaxNumThreadForExec(uint64_t numThreads); + uint64_t getMaxNumThreadForExec(); + + // Transaction. transaction::Transaction* getTx() const; KUZU_API transaction::TransactionContext* getTransactionContext() const; + // Replace function. inline bool hasReplaceFunc() { return replaceFunc != nullptr; } inline void setReplaceFunc(replace_func_t func) { replaceFunc = func; } + // Extension KUZU_API void setExtensionOption(std::string name, common::Value value); - - common::RandomEngine* getRandomEngine() { return randomEngine.get(); } - - common::VirtualFileSystem* getVFSUnsafe() const; - std::string getExtensionDir() const; + // Environment. + KUZU_API std::string getEnvVariable(const std::string& name); + + // Database component getters. KUZU_API Database* getDatabase() const { return database; } storage::StorageManager* getStorageManager(); storage::MemoryManager* getMemoryManager(); catalog::Catalog* getCatalog(); + common::VirtualFileSystem* getVFSUnsafe() const; + common::RandomEngine* getRandomEngine(); - KUZU_API std::string getEnvVariable(const std::string& name); - + // Query. std::unique_ptr prepare(std::string_view query); - - void setQueryTimeOut(uint64_t timeoutInMS); - - uint64_t getQueryTimeOut(); - - void setMaxNumThreadForExec(uint64_t numThreads); - - uint64_t getMaxNumThreadForExec(); - KUZU_API std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, std::unordered_map> inputParams); - std::unique_ptr query(std::string_view queryStatement); - void runQuery(std::string query); private: - inline void resetActiveQuery() { activeQuery.reset(); } - std::unique_ptr query( std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans = true); @@ -152,17 +135,19 @@ class ClientContext { void commitUDFTrx(bool isAutoCommitTrx); - uint64_t numThreadsForExecution; + // Client side configurable settings. + ClientConfig config; + // Current query. ActiveQuery activeQuery; - uint64_t timeoutInMS; - uint32_t varLengthExtendMaxDepth; + // Transaction context. std::unique_ptr transactionContext; - bool enableSemiMask; + // Replace external object as pointer Value; replace_func_t replaceFunc; + // Extension configurable settings. std::unordered_map extensionOptionValues; + // Random generator for UUID. std::unique_ptr randomEngine; - std::string homeDirectory; - std::string fileSearchPath; + // Attached database. Database* database; std::mutex mtx; }; diff --git a/src/include/main/settings.h b/src/include/main/settings.h index 3e3b6a0a32c..a31ddd1f621 100644 --- a/src/include/main/settings.h +++ b/src/include/main/settings.h @@ -11,10 +11,10 @@ struct ThreadsSetting { static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64; static void setContext(ClientContext* context, const common::Value& parameter) { KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::INT64); - context->numThreadsForExecution = parameter.getValue(); + context->getClientConfigUnsafe()->numThreads = parameter.getValue(); } static common::Value getSetting(ClientContext* context) { - return common::Value(context->numThreadsForExecution); + return common::Value(context->getClientConfig()->numThreads); } }; @@ -23,11 +23,10 @@ struct TimeoutSetting { static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64; static void setContext(ClientContext* context, const common::Value& parameter) { KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::INT64); - context->timeoutInMS = parameter.getValue(); - context->startTimingIfEnabled(); + context->getClientConfigUnsafe()->timeoutInMS = parameter.getValue(); } static common::Value getSetting(ClientContext* context) { - return common::Value(context->timeoutInMS); + return common::Value(context->getClientConfig()->timeoutInMS); } }; @@ -36,10 +35,10 @@ struct VarLengthExtendMaxDepthSetting { static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64; static void setContext(ClientContext* context, const common::Value& parameter) { KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::INT64); - context->varLengthExtendMaxDepth = parameter.getValue(); + context->getClientConfigUnsafe()->varLengthMaxDepth = parameter.getValue(); } static common::Value getSetting(ClientContext* context) { - return common::Value(context->varLengthExtendMaxDepth); + return common::Value(context->getClientConfig()->varLengthMaxDepth); } }; @@ -48,10 +47,10 @@ struct EnableSemiMaskSetting { static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::BOOL; static void setContext(ClientContext* context, const common::Value& parameter) { KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::BOOL); - context->enableSemiMask = parameter.getValue(); + context->getClientConfigUnsafe()->enableSemiMask = parameter.getValue(); } static common::Value getSetting(ClientContext* context) { - return common::Value(context->enableSemiMask); + return common::Value(context->getClientConfig()->enableSemiMask); } }; @@ -60,10 +59,10 @@ struct HomeDirectorySetting { static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::STRING; static void setContext(ClientContext* context, const common::Value& parameter) { KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::STRING); - context->homeDirectory = parameter.getValue(); + context->getClientConfigUnsafe()->homeDirectory = parameter.getValue(); } static common::Value getSetting(ClientContext* context) { - return common::Value::createValue(context->homeDirectory); + return common::Value::createValue(context->getClientConfig()->homeDirectory); } }; @@ -72,10 +71,10 @@ struct FileSearchPathSetting { static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::STRING; static void setContext(ClientContext* context, const common::Value& parameter) { KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::STRING); - context->fileSearchPath = parameter.getValue(); + context->getClientConfigUnsafe()->fileSearchPath = parameter.getValue(); } static common::Value getSetting(ClientContext* context) { - return common::Value::createValue(context->fileSearchPath); + return common::Value::createValue(context->getClientConfig()->fileSearchPath); } }; diff --git a/src/include/processor/operator/physical_operator.h b/src/include/processor/operator/physical_operator.h index 9076d65047a..8f04be21722 100644 --- a/src/include/processor/operator/physical_operator.h +++ b/src/include/processor/operator/physical_operator.h @@ -130,15 +130,7 @@ class PhysicalOperator { // Local state is initialized for each thread. void initLocalState(ResultSet* resultSet, ExecutionContext* context); - inline bool getNextTuple(ExecutionContext* context) { - if (context->clientContext->isInterrupted()) { - throw common::InterruptException{}; - } - metrics->executionTime.start(); - auto result = getNextTuplesInternal(context); - metrics->executionTime.stop(); - return result; - } + bool getNextTuple(ExecutionContext* context); std::unordered_map getProfilerKeyValAttributes( common::Profiler& profiler) const; diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index 87d2e63b9b6..c9b9f65c8a0 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -42,27 +42,51 @@ void ActiveQuery::reset() { timer = Timer(); } -ClientContext::ClientContext(Database* database) - : numThreadsForExecution{database->systemConfig.maxNumThreads}, - timeoutInMS{ClientContextConstants::TIMEOUT_IN_MS}, - varLengthExtendMaxDepth{DEFAULT_VAR_LENGTH_EXTEND_MAX_DEPTH}, - enableSemiMask{DEFAULT_ENABLE_SEMI_MASK}, database{database} { +ClientContext::ClientContext(Database* database) : database{database} { transactionContext = std::make_unique(database); randomEngine = std::make_unique(); - fileSearchPath = ""; #if defined(_WIN32) - homeDirectory = getEnvVariable("USERPROFILE"); + config.homeDirectory = getEnvVariable("USERPROFILE"); #else - homeDirectory = getEnvVariable("HOME"); + config.homeDirectory = getEnvVariable("HOME"); #endif + config.fileSearchPath = ""; + config.enableSemiMask = ClientConfigDefault::ENABLE_SEMI_MASK; + config.numThreads = database->systemConfig.maxNumThreads; + config.timeoutInMS = ClientConfigDefault::TIMEOUT_IN_MS; + config.varLengthMaxDepth = ClientConfigDefault::VAR_LENGTH_MAX_DEPTH; } -void ClientContext::startTimingIfEnabled() { - if (isTimeOutEnabled()) { +uint64_t ClientContext::getTimeoutRemainingInMS() { + KU_ASSERT(hasTimeout()); + auto elapsed = activeQuery.timer.getElapsedTimeInMS(); + return elapsed >= config.timeoutInMS ? 0 : config.timeoutInMS - elapsed; +} + +void ClientContext::startTimer() { + if (hasTimeout()) { activeQuery.timer.start(); } } +void ClientContext::setQueryTimeOut(uint64_t timeoutInMS) { + lock_t lck{mtx}; + config.timeoutInMS = timeoutInMS; +} + +uint64_t ClientContext::getQueryTimeOut() { + return config.timeoutInMS; +} + +void ClientContext::setMaxNumThreadForExec(uint64_t numThreads) { + lock_t lck{mtx}; + config.numThreads = numThreads; +} + +uint64_t ClientContext::getMaxNumThreadForExec() { + return config.numThreads; +} + Value ClientContext::getCurrentSetting(const std::string& optionName) { auto lowerCaseOptionName = optionName; StringUtils::toLower(lowerCaseOptionName); @@ -96,12 +120,8 @@ void ClientContext::setExtensionOption(std::string name, common::Value value) { extensionOptionValues.insert_or_assign(name, std::move(value)); } -VirtualFileSystem* ClientContext::getVFSUnsafe() const { - return database->vfs.get(); -} - std::string ClientContext::getExtensionDir() const { - return common::stringFormat("{}/.kuzu/extension", homeDirectory); + return common::stringFormat("{}/.kuzu/extension", config.homeDirectory); } storage::StorageManager* ClientContext::getStorageManager() { @@ -116,6 +136,14 @@ catalog::Catalog* ClientContext::getCatalog() { return database->catalog.get(); } +VirtualFileSystem* ClientContext::getVFSUnsafe() const { + return database->vfs.get(); +} + +common::RandomEngine* ClientContext::getRandomEngine() { + return randomEngine.get(); +} + std::string ClientContext::getEnvVariable(const std::string& name) { #if defined(_WIN32) auto envValue = common::WindowsUtils::utf8ToUnicode(name.c_str()); @@ -133,15 +161,6 @@ std::string ClientContext::getEnvVariable(const std::string& name) { #endif } -void ClientContext::setMaxNumThreadForExec(uint64_t numThreads) { - numThreadsForExecution = numThreads; -} - -uint64_t ClientContext::getMaxNumThreadForExec() { - std::unique_lock lck{mtx}; - return numThreadsForExecution; -} - std::unique_ptr ClientContext::prepare(std::string_view query) { auto preparedStatement = std::unique_ptr(); std::unique_lock lck{mtx}; @@ -297,16 +316,6 @@ std::vector> ClientContext::parseQuery(std::string_vi return statements; } -void ClientContext::setQueryTimeOut(uint64_t timeoutInMS) { - lock_t lck{mtx}; - this->timeoutInMS = timeoutInMS; -} - -uint64_t ClientContext::getQueryTimeOut() { - lock_t lck{mtx}; - return this->timeoutInMS; -} - std::unique_ptr ClientContext::executeWithParams(PreparedStatement* preparedStatement, std::unordered_map> inputParams) { // NOLINT(performance-unnecessary-value-param): It doesn't make sense to pass @@ -358,7 +367,7 @@ std::unique_ptr ClientContext::executeAndAutoCommitIfNecessaryNoLoc } } this->resetActiveQuery(); - this->startTimingIfEnabled(); + this->startTimer(); auto mapper = PlanMapper( *database->storageManager, database->memoryManager.get(), database->catalog.get(), this); std::unique_ptr physicalPlan; diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index c89786f97e3..86d10f3249b 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -31,7 +31,7 @@ void Optimizer::optimize(planner::LogicalPlan* plan, main::ClientContext* client auto projectionPushDownOptimizer = ProjectionPushDownOptimizer(); projectionPushDownOptimizer.rewrite(plan); - if (client->isEnableSemiMask()) { + if (client->getClientConfig()->enableSemiMask) { // HashJoinSIPOptimizer should be applied after optimizers that manipulate hash join. auto hashJoinSIPOptimizer = HashJoinSIPOptimizer(); hashJoinSIPOptimizer.rewrite(plan); diff --git a/src/processor/operator/physical_operator.cpp b/src/processor/operator/physical_operator.cpp index ebae013256a..e3a4868a71a 100644 --- a/src/processor/operator/physical_operator.cpp +++ b/src/processor/operator/physical_operator.cpp @@ -180,6 +180,16 @@ void PhysicalOperator::initLocalState(ResultSet* resultSet_, ExecutionContext* c initLocalStateInternal(resultSet_, context); } +bool PhysicalOperator::getNextTuple(ExecutionContext* context) { + if (context->clientContext->interrupted()) { + throw InterruptException{}; + } + metrics->executionTime.start(); + auto result = getNextTuplesInternal(context); + metrics->executionTime.stop(); + return result; +} + void PhysicalOperator::registerProfilingMetrics(Profiler* profiler) { auto executionTime = profiler->registerTimeMetric(getTimeMetricKey()); auto numOutputTuple = profiler->registerNumericMetric(getNumTupleMetricKey());