Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added progress for in_query_call operators #3120

Merged
merged 8 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/common/task_system/progress_bar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,23 @@
return;
}
numPipelinesFinished++;
if (printing) {
std::cout << "\033[1A\033[2K\033[1B";

Check warning on line 34 in src/common/task_system/progress_bar.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/task_system/progress_bar.cpp#L34

Added line #L34 was not covered by tests
MSebanc marked this conversation as resolved.
Show resolved Hide resolved
}
// This ensures that the progress bar is updated back to 0% after a pipeline is finished.
prevCurPipelineProgress = -0.01;
updateProgress(0.0);
}

void ProgressBar::updateProgress(double curPipelineProgress) {
// Only update the progress bar if the progress has changed by at least 1%.
if (!trackProgress || curPipelineProgress - prevCurPipelineProgress < 0.01) {
if (!trackProgress) {
return;
}
std::lock_guard<std::mutex> lock(progressBarLock);
// Only update the progress bar if the progress has changed by at least 1%.
if (curPipelineProgress - prevCurPipelineProgress < 0.01) {
return;
}
prevCurPipelineProgress = curPipelineProgress;
if (printing) {
std::cout << "\033[2A";
Expand Down
3 changes: 2 additions & 1 deletion src/include/function/table/scan_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ struct ScanSharedState : public BaseScanSharedState {

struct ScanFileSharedState : public ScanSharedState {
main::ClientContext* context;
uint64_t totalSize;

ScanFileSharedState(
common::ReaderConfig readerConfig, uint64_t numRows, main::ClientContext* context)
: ScanSharedState{std::move(readerConfig), numRows}, context{context} {}
: ScanSharedState{std::move(readerConfig), numRows}, context{context}, totalSize{0} {}
};

} // namespace function
Expand Down
8 changes: 8 additions & 0 deletions src/include/function/table_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ using table_func_init_shared_t =
using table_func_init_local_t = std::function<std::unique_ptr<TableFuncLocalState>(
TableFunctionInitInput&, TableFuncSharedState*, storage::MemoryManager*)>;
using table_func_can_parallel_t = std::function<bool()>;
using table_func_progress_t = std::function<double(TableFuncSharedState* sharedState)>;

struct TableFunction final : public Function {
table_func_t tableFunc;
table_func_bind_t bindFunc;
table_func_init_shared_t initSharedStateFunc;
table_func_init_local_t initLocalStateFunc;
table_func_can_parallel_t canParallelFunc = [] { return true; };
table_func_progress_t progressFunc = [](TableFuncSharedState* /*sharedState*/) { return 0.0; };

TableFunction()
: Function{}, tableFunc{nullptr}, bindFunc{nullptr}, initSharedStateFunc{nullptr},
Expand All @@ -80,6 +82,12 @@ struct TableFunction final : public Function {
: Function{FunctionType::TABLE, std::move(name), std::move(inputTypes)},
tableFunc{tableFunc}, bindFunc{bindFunc}, initSharedStateFunc{initSharedFunc},
initLocalStateFunc{initLocalFunc} {}
TableFunction(std::string name, table_func_t tableFunc, table_func_bind_t bindFunc,
table_func_init_shared_t initSharedFunc, table_func_init_local_t initLocalFunc,
table_func_progress_t progressFunc, std::vector<common::LogicalTypeID> inputTypes)
: Function{FunctionType::TABLE, std::move(name), std::move(inputTypes)},
tableFunc{tableFunc}, bindFunc{bindFunc}, initSharedStateFunc{initSharedFunc},
initLocalStateFunc{initLocalFunc}, progressFunc{progressFunc} {}

inline std::string signatureToString() const override {
return common::LogicalTypeUtils::toString(parameterTypeIDs);
Expand Down
2 changes: 2 additions & 0 deletions src/include/processor/operator/call/in_query_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class InQueryCall : public PhysicalOperator {

bool getNextTuplesInternal(ExecutionContext* context) override;

double getProgress(ExecutionContext* context) const override;

std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<InQueryCall>(info.copy(), sharedState, id, paramsString);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class BaseCSVReader {

uint64_t countRows();
bool isEOF() const;
uint64_t getFileSize();
// Get the file offset of the current buffer position.
uint64_t getFileOffset() const;

protected:
template<typename Driver>
Expand Down Expand Up @@ -56,8 +59,6 @@ class BaseCSVReader {

inline bool isNewLine(char c) { return c == '\n' || c == '\r'; }

// Get the file offset of the current buffer position.
uint64_t getFileOffset() const;
uint64_t getLineNumber();

protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ struct ParallelCSVScanSharedState final : public function::ScanFileSharedState {
explicit ParallelCSVScanSharedState(common::ReaderConfig readerConfig, uint64_t numRows,
uint64_t numColumns, main::ClientContext* context, common::CSVReaderConfig csvReaderConfig)
: ScanFileSharedState{std::move(readerConfig), numRows, context}, numColumns{numColumns},
csvReaderConfig{std::move(csvReaderConfig)} {}
numBlocksReadByFiles{0}, csvReaderConfig{std::move(csvReaderConfig)} {}

void setFileComplete(uint64_t completedFileIdx);

uint64_t numColumns;
uint64_t numBlocksReadByFiles = 0;
common::CSVReaderConfig csvReaderConfig;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ class SerialCSVReader final : public BaseCSVReader {
struct SerialCSVScanSharedState final : public function::ScanFileSharedState {
std::unique_ptr<SerialCSVReader> reader;
uint64_t numColumns;
uint64_t totalReadSizeByFile;
common::CSVReaderConfig csvReaderConfig;

SerialCSVScanSharedState(common::ReaderConfig readerConfig, uint64_t numRows,
uint64_t numColumns, common::CSVReaderConfig csvReaderConfig, main::ClientContext* context)
: ScanFileSharedState{std::move(readerConfig), numRows, context}, numColumns{numColumns},
csvReaderConfig{std::move(csvReaderConfig)} {
totalReadSizeByFile{0}, csvReaderConfig{std::move(csvReaderConfig)} {
initReader(context);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ struct ParquetScanSharedState final : public function::ScanFileSharedState {
const common::ReaderConfig readerConfig, uint64_t numRows, main::ClientContext* context);

std::vector<std::unique_ptr<ParquetReader>> readers;
uint64_t totalRowsGroups;
uint64_t numBlocksReadByFiles;
};

struct ParquetScanLocalState final : public function::TableFuncLocalState {
Expand Down
4 changes: 4 additions & 0 deletions src/processor/operator/call/in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,9 @@ bool InQueryCall::getNextTuplesInternal(ExecutionContext*) {
return numTuplesScanned != 0;
}

// Reports the progress of this call operator by forwarding the table function's shared state
// to the progress callback registered on the bound function. The execution context is unused.
double InQueryCall::getProgress(ExecutionContext* /*context*/) const {
    auto* funcSharedState = sharedState->funcState.get();
    const auto& progressCallback = info.function.progressFunc;
    return progressCallback(funcSharedState);
}

} // namespace processor
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ bool BaseCSVReader::isEOF() const {
return getFileOffset() >= fileInfo->getFileSize();
}

// Returns the size of the file backing this reader, as reported by fileInfo. This is the same
// quantity isEOF() compares the current file offset against.
uint64_t BaseCSVReader::getFileSize() {
    const uint64_t fileSize = fileInfo->getFileSize();
    return fileSize;
}

template<typename Driver>
void BaseCSVReader::addValue(Driver& driver, uint64_t rowNum, column_id_t columnIdx,
std::string_view strVal, std::vector<uint64_t>& escapePositions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,18 @@ bool ParallelCSVReader::finishedBlock() const {
// Marks the file at completedFileIdx as finished. The whole update runs under the shared-state
// lock. If that file is still the current one, its block count is folded into
// numBlocksReadByFiles, the block index is reset, and the scan moves on to the next file.
// A call for a file that is no longer the current one changes nothing.
void ParallelCSVScanSharedState::setFileComplete(uint64_t completedFileIdx) {
    std::lock_guard<std::mutex> stateGuard{lock};
    if (completedFileIdx != fileIdx) {
        // Another caller already advanced past this file.
        return;
    }
    // Accumulate before resetting so the finished file's blocks stay counted.
    numBlocksReadByFiles += blockIdx;
    blockIdx = 0;
    ++fileIdx;
}

static offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output) {
auto& outputChunk = output.dataChunk;
auto parallelCSVLocalState = reinterpret_cast<ParallelCSVLocalState*>(input.localState);
auto parallelCSVSharedState = reinterpret_cast<ParallelCSVScanSharedState*>(input.sharedState);
auto parallelCSVLocalState =
ku_dynamic_cast<TableFuncLocalState*, ParallelCSVLocalState*>(input.localState);
auto parallelCSVSharedState =
ku_dynamic_cast<TableFuncSharedState*, ParallelCSVScanSharedState*>(input.sharedState);
do {
if (parallelCSVLocalState->reader != nullptr &&
parallelCSVLocalState->reader->hasMoreToRead()) {
Expand Down Expand Up @@ -152,7 +155,7 @@ static offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output) {

static std::unique_ptr<TableFuncBindData> bindFunc(
main::ClientContext* /*context*/, TableFuncBindInput* input) {
auto scanInput = reinterpret_cast<ScanTableFuncBindInput*>(input);
auto scanInput = ku_dynamic_cast<TableFuncBindInput*, ScanTableFuncBindInput*>(input);
std::vector<std::string> detectedColumnNames;
std::vector<LogicalType> detectedColumnTypes;
SerialCSVScan::bindColumns(scanInput, detectedColumnNames, detectedColumnTypes);
Expand All @@ -165,28 +168,48 @@ static std::unique_ptr<TableFuncBindData> bindFunc(
}

static std::unique_ptr<TableFuncSharedState> initSharedState(TableFunctionInitInput& input) {
auto bindData = reinterpret_cast<ScanBindData*>(input.bindData);
auto bindData = ku_dynamic_cast<TableFuncBindData*, ScanBindData*>(input.bindData);
auto csvConfig = CSVReaderConfig::construct(bindData->config.options);
row_idx_t numRows = 0;
return std::make_unique<ParallelCSVScanSharedState>(bindData->config.copy(), numRows,
bindData->columnNames.size(), bindData->context, csvConfig.copy());
auto sharedState = std::make_unique<ParallelCSVScanSharedState>(bindData->config.copy(),
numRows, bindData->columnNames.size(), bindData->context, csvConfig.copy());
for (auto filePath : sharedState->readerConfig.filePaths) {
auto reader = std::make_unique<ParallelCSVReader>(filePath,
sharedState->csvReaderConfig.option.copy(), sharedState->numColumns,
sharedState->context);
sharedState->totalSize += reader->getFileSize();
}
return sharedState;
}

static std::unique_ptr<TableFuncLocalState> initLocalState(TableFunctionInitInput& /*input*/,
TableFuncSharedState* state, storage::MemoryManager* /*mm*/) {
auto localState = std::make_unique<ParallelCSVLocalState>();
auto sharedState = reinterpret_cast<ParallelCSVScanSharedState*>(state);
auto sharedState = ku_dynamic_cast<TableFuncSharedState*, ParallelCSVScanSharedState*>(state);
localState->reader = std::make_unique<ParallelCSVReader>(sharedState->readerConfig.filePaths[0],
sharedState->csvReaderConfig.option.copy(), sharedState->numColumns, sharedState->context);
localState->fileIdx = 0;
return localState;
}

// Estimates the progress of the parallel CSV scan as a fraction of the total input size.
//
// Returns 1.0 once every file has been consumed (fileIdx has moved past the last file) and
// 0.0 when the summed file size is zero. Otherwise it returns the number of blocks counted so
// far (blocks of already completed files plus the current file's block index) times
// PARALLEL_BLOCK_SIZE, divided by totalSize, capped at 1.0.
//
// NOTE(review): the shared-state fields are read here without taking the lock that
// setFileComplete() holds while updating them; confirm that a slightly stale estimate is
// acceptable to callers.
static double progressFunc(TableFuncSharedState* sharedState) {
    auto state = ku_dynamic_cast<TableFuncSharedState*, ParallelCSVScanSharedState*>(sharedState);
    if (state->fileIdx >= state->readerConfig.getNumFiles()) {
        return 1.0;
    }
    if (state->totalSize == 0) {
        return 0.0;
    }
    // Every block is counted as a full PARALLEL_BLOCK_SIZE, while totalSize is the sum of the
    // actual file sizes. The estimate can therefore overshoot whenever a file does not fill its
    // last block.
    uint64_t totalReadSize =
        (state->numBlocksReadByFiles + state->blockIdx) * CopyConstants::PARALLEL_BLOCK_SIZE;
    if (totalReadSize >= state->totalSize) {
        // Cap the estimate so the reported progress never exceeds 100%.
        return 1.0;
    }
    return static_cast<double>(totalReadSize) / state->totalSize;
}

function_set ParallelCSVScan::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>(READ_CSV_PARALLEL_FUNC_NAME, tableFunc, bindFunc,
initSharedState, initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
functionSet.push_back(std::make_unique<TableFunction>(READ_CSV_PARALLEL_FUNC_NAME, tableFunc,
bindFunc, initSharedState, initLocalState, progressFunc,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down
34 changes: 27 additions & 7 deletions src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void SerialCSVScanSharedState::read(DataChunk& outputChunk) {
if (numRows > 0) {
return;
}
totalReadSizeByFile += reader->getFileSize();
fileIdx++;
initReader(context);
} while (true);
Expand All @@ -68,7 +69,8 @@ void SerialCSVScanSharedState::initReader(main::ClientContext* context) {
}

static common::offset_t tableFunc(TableFuncInput& input, TableFuncOutput& output) {
auto serialCSVScanSharedState = reinterpret_cast<SerialCSVScanSharedState*>(input.sharedState);
auto serialCSVScanSharedState =
ku_dynamic_cast<TableFuncSharedState*, SerialCSVScanSharedState*>(input.sharedState);
serialCSVScanSharedState->read(output.dataChunk);
return output.dataChunk.state->selVector->selectedSize;
}
Expand Down Expand Up @@ -99,7 +101,7 @@ void SerialCSVScan::bindColumns(const ScanTableFuncBindInput* bindInput,

static std::unique_ptr<TableFuncBindData> bindFunc(
main::ClientContext* /*context*/, TableFuncBindInput* input) {
auto scanInput = reinterpret_cast<ScanTableFuncBindInput*>(input);
auto scanInput = ku_dynamic_cast<TableFuncBindInput*, ScanTableFuncBindInput*>(input);
std::vector<std::string> detectedColumnNames;
std::vector<LogicalType> detectedColumnTypes;
SerialCSVScan::bindColumns(scanInput, detectedColumnNames, detectedColumnTypes);
Expand All @@ -112,23 +114,41 @@ static std::unique_ptr<TableFuncBindData> bindFunc(
}

static std::unique_ptr<TableFuncSharedState> initSharedState(TableFunctionInitInput& input) {
auto bindData = reinterpret_cast<ScanBindData*>(input.bindData);
auto bindData = ku_dynamic_cast<TableFuncBindData*, ScanBindData*>(input.bindData);
auto csvConfig = CSVReaderConfig::construct(bindData->config.options);
row_idx_t numRows = 0;
return std::make_unique<SerialCSVScanSharedState>(bindData->config.copy(), numRows,
auto sharedState = std::make_unique<SerialCSVScanSharedState>(bindData->config.copy(), numRows,
bindData->columnNames.size(), csvConfig.copy(), bindData->context);
for (auto filePath : sharedState->readerConfig.filePaths) {
auto reader =
std::make_unique<SerialCSVReader>(filePath, sharedState->csvReaderConfig.option.copy(),
sharedState->numColumns, sharedState->context);
sharedState->totalSize += reader->getFileSize();
}
return sharedState;
}

// Creates the local state for the serial CSV scan: a plain TableFuncLocalState. The init
// input, the shared state and the memory manager are all unused.
static std::unique_ptr<TableFuncLocalState> initLocalState(TableFunctionInitInput& /*input*/,
    TableFuncSharedState* /*state*/, storage::MemoryManager* /*mm*/) {
    auto localState = std::make_unique<TableFuncLocalState>();
    return localState;
}

// Estimates the progress of the serial CSV scan as a fraction of the total input size.
//
// Returns 1.0 once every file has been consumed (fileIdx has moved past the last file), even
// when the inputs were all empty, and 0.0 when the summed file size is zero but files remain.
// Otherwise it returns the size of the fully read files plus the current reader's file
// offset, divided by totalSize.
static double progressFunc(TableFuncSharedState* sharedState) {
    auto state = ku_dynamic_cast<TableFuncSharedState*, SerialCSVScanSharedState*>(sharedState);
    // Check completion first so a finished scan over empty files reports 100% rather than 0%.
    if (state->fileIdx >= state->readerConfig.getNumFiles()) {
        return 1.0;
    }
    if (state->totalSize == 0) {
        return 0.0;
    }
    // totalReadSizeByFile covers the files already finished; the reader's offset adds the
    // position reached in the file currently being read.
    uint64_t totalReadSize = state->totalReadSizeByFile + state->reader->getFileOffset();
    return static_cast<double>(totalReadSize) / state->totalSize;
}

function_set SerialCSVScan::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>(READ_CSV_SERIAL_FUNC_NAME, tableFunc, bindFunc,
initSharedState, initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
functionSet.push_back(std::make_unique<TableFunction>(READ_CSV_SERIAL_FUNC_NAME, tableFunc,
bindFunc, initSharedState, initLocalState, progressFunc,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down
Loading
Loading