diff --git a/src/common/task_system/progress_bar.cpp b/src/common/task_system/progress_bar.cpp index 3efb778a1e..526dde33f4 100644 --- a/src/common/task_system/progress_bar.cpp +++ b/src/common/task_system/progress_bar.cpp @@ -9,6 +9,13 @@ void ProgressBar::startProgress() { } std::lock_guard lock(progressBarLock); printProgressBar(0.0); + printing = true; +} + +void ProgressBar::endProgress() { + std::lock_guard lock(progressBarLock); + resetProgressBar(); + printing = false; } void ProgressBar::addPipeline() { @@ -23,21 +30,23 @@ void ProgressBar::finishPipeline() { return; } numPipelinesFinished++; - prevCurPipelineProgress = 0.0; + // This ensures that the progress bar is updated back to 0% after a pipeline is finished. + prevCurPipelineProgress = -0.01; updateProgress(0.0); - if (numPipelines == numPipelinesFinished) { - resetProgressBar(); - } } void ProgressBar::updateProgress(double curPipelineProgress) { + // Only update the progress bar if the progress has changed by at least 1%. if (!trackProgress || curPipelineProgress - prevCurPipelineProgress < 0.01) { return; } std::lock_guard lock(progressBarLock); prevCurPipelineProgress = curPipelineProgress; - std::cout << "\033[2A\033[2K\033[1B\033[2K\033[1A"; + if (printing) { + std::cout << "\033[2A"; + } printProgressBar(curPipelineProgress); + printing = true; } void ProgressBar::printProgressBar(double curPipelineProgress) const { @@ -55,9 +64,10 @@ void ProgressBar::printProgressBar(double curPipelineProgress) const { } void ProgressBar::resetProgressBar() { - std::lock_guard lock(progressBarLock); - std::cout << "\033[2A\033[2K\033[1B\033[2K\033[1A"; - std::cout.flush(); + if (printing) { + std::cout << "\033[2A\033[2K\033[1B\033[2K\033[1A"; + std::cout.flush(); + } numPipelines = 0; numPipelinesFinished = 0; prevCurPipelineProgress = 0.0; diff --git a/src/include/common/constants.h b/src/include/common/constants.h index 96e6091a89..2a1674387e 100644 --- a/src/include/common/constants.h +++ b/src/include/common/constants.h @@ -173,13 +173,6 @@ struct PlannerKnobs { static constexpr uint64_t SIP_RATIO = 5; }; -struct ClientConfigDefault { - // 0 means timeout is disabled by default. - static constexpr uint64_t TIMEOUT_IN_MS = 0; - static constexpr uint32_t VAR_LENGTH_MAX_DEPTH = 30; - static constexpr bool ENABLE_SEMI_MASK = true; -}; - struct OrderByConstants { static constexpr uint64_t NUM_BYTES_FOR_PAYLOAD_IDX = 8; static constexpr uint64_t MIN_SIZE_TO_REDUCE = common::DEFAULT_VECTOR_CAPACITY * 5; diff --git a/src/include/common/task_system/progress_bar.h b/src/include/common/task_system/progress_bar.h index bfd0050e3a..50d65eb5fc 100644 --- a/src/include/common/task_system/progress_bar.h +++ b/src/include/common/task_system/progress_bar.h @@ -7,19 +7,22 @@ namespace kuzu { namespace common { /** - * TODO: PUT DESCRIPTION HERE + * @brief Progress bar for tracking the progress of a pipeline. Prints the progress of each query + * pipeline and the overall progress. */ class ProgressBar { public: ProgressBar() - : numPipelines{0}, numPipelinesFinished{0}, prevCurPipelineProgress{0.0}, trackProgress{ - false} {}; + : numPipelines{0}, numPipelinesFinished{0}, prevCurPipelineProgress{0.0}, + trackProgress{false}, printing{false} {}; void addPipeline(); void finishPipeline(); + void endProgress(); + void addJobsToPipeline(int jobs); void finishJobsInPipeline(int jobs); @@ -45,6 +48,7 @@ class ProgressBar { double prevCurPipelineProgress; std::mutex progressBarLock; bool trackProgress; + bool printing; }; } // namespace common diff --git a/src/include/main/client_config.h b/src/include/main/client_config.h index 8649c5bd2c..7ca23607c0 100644 --- a/src/include/main/client_config.h +++ b/src/include/main/client_config.h @@ -18,6 +18,16 @@ struct ClientConfig { uint64_t timeoutInMS; // variable length maximum depth uint32_t varLengthMaxDepth; + // If using progress bar + bool enableProgressBar; +}; + +struct ClientConfigDefault { + // 0 means timeout is disabled by default. + static constexpr uint64_t TIMEOUT_IN_MS = 0; + static constexpr uint32_t VAR_LENGTH_MAX_DEPTH = 30; + static constexpr bool ENABLE_SEMI_MASK = true; + static constexpr bool ENABLE_PROGRESS_BAR = true; }; } // namespace main diff --git a/src/include/main/client_context.h b/src/include/main/client_context.h index 0cf4083029..de11e5bd05 100644 --- a/src/include/main/client_context.h +++ b/src/include/main/client_context.h @@ -79,6 +79,9 @@ class ClientContext { transaction::Transaction* getTx() const; KUZU_API transaction::TransactionContext* getTransactionContext() const; + // Progress bar + common::ProgressBar* getProgressBar() const; + // Replace function. inline bool hasReplaceFunc() { return replaceFunc != nullptr; } inline void setReplaceFunc(replace_func_t func) { replaceFunc = func; } @@ -106,10 +109,6 @@ class ClientContext { std::unique_ptr query(std::string_view queryStatement); void runQuery(std::string query); - void setProgressBarPrinting(bool progressBarPrinting); - - common::ProgressBar* getProgressBar() const { return progressBar.get(); } - private: std::unique_ptr query( std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans = true); diff --git a/src/include/main/settings.h b/src/include/main/settings.h index a31ddd1f62..93796cca1a 100644 --- a/src/include/main/settings.h +++ b/src/include/main/settings.h @@ -30,6 +30,19 @@ struct TimeoutSetting { } }; +struct ProgressBarSetting { + static constexpr const char* name = "progress_bar"; + static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::BOOL; + static void setContext(ClientContext* context, const common::Value& parameter) { + KU_ASSERT(parameter.getDataType()->getLogicalTypeID() == common::LogicalTypeID::BOOL); + context->getClientConfigUnsafe()->enableProgressBar = parameter.getValue(); + context->getProgressBar()->toggleProgressBarPrinting(parameter.getValue()); + } + static common::Value getSetting(ClientContext* context) { + return common::Value(context->getClientConfig()->enableProgressBar); + } +}; + struct VarLengthExtendMaxDepthSetting { static constexpr const char* name = "var_length_extend_max_depth"; static constexpr const common::LogicalTypeID inputType = common::LogicalTypeID::INT64; diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index 1eac0c94ed..e6a3553c6d 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -3,7 +3,6 @@ #include #include "binder/binder.h" -#include "common/constants.h" #include "common/exception/connection.h" #include "common/exception/runtime.h" #include "common/random_engine.h" @@ -56,6 +55,7 @@ ClientContext::ClientContext(Database* database) : database{database} { config.numThreads = database->systemConfig.maxNumThreads; config.timeoutInMS = ClientConfigDefault::TIMEOUT_IN_MS; config.varLengthMaxDepth = ClientConfigDefault::VAR_LENGTH_MAX_DEPTH; + config.enableProgressBar = ClientConfigDefault::ENABLE_PROGRESS_BAR; } uint64_t ClientContext::getTimeoutRemainingInMS() const { @@ -116,6 +116,10 @@ TransactionContext* ClientContext::getTransactionContext() const { return transactionContext.get(); } +common::ProgressBar* ClientContext::getProgressBar() const { + return progressBar.get(); +} + void ClientContext::setExtensionOption(std::string name, common::Value value) { StringUtils::toLower(name); extensionOptionValues.insert_or_assign(name, std::move(value)); diff --git a/src/main/db_config.cpp b/src/main/db_config.cpp index ad567a4320..dfcfe1756e 100644 --- a/src/main/db_config.cpp +++ b/src/main/db_config.cpp @@ -14,7 +14,8 @@ namespace main { static ConfigurationOption options[] = { // NOLINT(cert-err58-cpp): GET_CONFIGURATION(ThreadsSetting), GET_CONFIGURATION(TimeoutSetting), GET_CONFIGURATION(VarLengthExtendMaxDepthSetting), GET_CONFIGURATION(EnableSemiMaskSetting), - GET_CONFIGURATION(HomeDirectorySetting), GET_CONFIGURATION(FileSearchPathSetting)}; + GET_CONFIGURATION(HomeDirectorySetting), GET_CONFIGURATION(FileSearchPathSetting), + GET_CONFIGURATION(ProgressBarSetting)}; ConfigurationOption* DBConfig::getOptionByName(const std::string& optionName) { auto lOptionName = optionName; diff --git a/src/processor/processor.cpp b/src/processor/processor.cpp index 382aa737be..c255bd03c6 100644 --- a/src/processor/processor.cpp +++ b/src/processor/processor.cpp @@ -27,6 +27,7 @@ std::shared_ptr QueryProcessor::execute( initTask(task.get()); context->clientContext->getProgressBar()->startProgress(); taskScheduler->scheduleTaskAndWaitOrError(task, context); + context->clientContext->getProgressBar()->endProgress(); return resultCollector->getResultFactorizedTable(); } diff --git a/test/test_files/tinysnb/call/call.test b/test/test_files/tinysnb/call/call.test index d9668f8d26..94a4898817 100644 --- a/test/test_files/tinysnb/call/call.test +++ b/test/test_files/tinysnb/call/call.test @@ -42,6 +42,18 @@ Binder exception: Upper bound of rel exceeds maximum: 10. ---- 1 354290 +-LOG SetGetProgressBar +-STATEMENT CALL progress_bar=true +---- ok +-STATEMENT CALL current_setting('progress_bar') RETURN * +---- 1 +True +-STATEMENT CALL progress_bar=false +---- ok +-STATEMENT CALL current_setting('progress_bar') RETURN * +---- 1 +False + -LOG disableSemihMaskOptimization -STATEMENT CALL enable_semi_mask=true ---- ok diff --git a/tools/shell/embedded_shell.cpp b/tools/shell/embedded_shell.cpp index 68d6ccade2..a82dd9ff05 100644 --- a/tools/shell/embedded_shell.cpp +++ b/tools/shell/embedded_shell.cpp @@ -51,8 +51,7 @@ struct ShellCommand { const char* QUIT = ":quit"; const char* MAX_ROWS = ":max_rows"; const char* MAX_WIDTH = ":max_width"; - const char* PROGRESS_BAR = ":progress_bar"; - const std::array commandList = {HELP, CLEAR, QUIT, MAX_ROWS, MAX_WIDTH, PROGRESS_BAR}; + const std::array commandList = {HELP, CLEAR, QUIT, MAX_ROWS, MAX_WIDTH}; } shellCommand; const char* TAB = " "; @@ -285,8 +284,6 @@ int EmbeddedShell::processShellCommands(std::string lineStr) { setMaxRows(lineStr.substr(strlen(shellCommand.MAX_ROWS))); } else if (lineStr.rfind(shellCommand.MAX_WIDTH) == 0) { setMaxWidth(lineStr.substr(strlen(shellCommand.MAX_WIDTH))); - } else if (lineStr.rfind(shellCommand.PROGRESS_BAR) == 0) { - toggleProgressBar(lineStr.substr(strlen(shellCommand.PROGRESS_BAR))); } else { printf("Error: Unknown command: \"%s\". Enter \":help\" for help\n", lineStr.c_str()); printf("Did you mean: \"%s\"?\n", findClosestCommand(lineStr).c_str()); @@ -430,20 +427,6 @@ void EmbeddedShell::setMaxWidth(const std::string& maxWidthString) { printf("maxWidth set as %d\n", parsedMaxWidth); } -void EmbeddedShell::toggleProgressBar(const std::string& state) { - std::string stateTrimmed = state; - stateTrimmed = stateTrimmed.erase(0, state.find_first_not_of(" \t\n\r\f\v")); - if (stateTrimmed == "on") { - conn->setProgressBarPrinting(true); - printf("Turned progress bar on.\n"); - } else if (stateTrimmed == "off") { - conn->setProgressBarPrinting(false); - printf("Turned progress bar off\n"); - } else { - printf("Cannot parse '%s' as progress bar state. Expect on|off.\n", stateTrimmed.c_str()); - } -} - void EmbeddedShell::printHelp() { printf("%s%s %sget command list\n", TAB, shellCommand.HELP, TAB); printf("%s%s %sclear shell\n", TAB, shellCommand.CLEAR, TAB); @@ -452,11 +435,10 @@ void EmbeddedShell::printHelp() { shellCommand.MAX_ROWS, TAB); printf("%s%s [max_width] %sset maximum width in characters for display\n", TAB, shellCommand.MAX_WIDTH, TAB); - printf("%s%s [on|off] %stoggle progress bar for queries\n", TAB, shellCommand.PROGRESS_BAR, TAB); printf("\n"); printf("%sNote: you can change and see several system configurations, such as num-threads, \n", TAB); - printf("%s%s timeout, and logging_level using Cypher CALL statements.\n", TAB, TAB); + printf("%s%s timeout, and progress_bar using Cypher CALL statements.\n", TAB, TAB); printf("%s%s e.g. CALL THREADS=5; or CALL current_setting('threads') return *;\n", TAB, TAB); printf("%s%s See: https://docs.kuzudb.com/cypher/configuration\n", TAB, TAB); } diff --git a/tools/shell/include/embedded_shell.h b/tools/shell/include/embedded_shell.h index 8aa44e6c04..7b5171f455 100644 --- a/tools/shell/include/embedded_shell.h +++ b/tools/shell/include/embedded_shell.h @@ -32,8 +32,6 @@ class EmbeddedShell { void setMaxWidth(const std::string& maxWidthString); - void toggleProgressBar(const std::string& state); - private: std::unique_ptr database; std::unique_ptr conn; diff --git a/tools/shell/test/test_shell_commands.py b/tools/shell/test/test_shell_commands.py index f402eba89c..ed610c2dbd 100644 --- a/tools/shell/test/test_shell_commands.py +++ b/tools/shell/test/test_shell_commands.py @@ -11,10 +11,9 @@ def test_help(temp_db) -> None: " :quit exit from shell", " :max_rows [max_rows] set maximum number of rows for display (default: 20)", " :max_width [max_width] set maximum width in characters for display", - " :progress_bar [on|off] toggle progress bar for queries", "", " Note: you can change and see several system configurations, such as num-threads, ", - " timeout, and logging_level using Cypher CALL statements.", + " timeout, and progress_bar using Cypher CALL statements.", " e.g. CALL THREADS=5; or CALL current_setting('threads') return *;", " See: https://docs.kuzudb.com/cypher/configuration", ],