Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added progress for in_query_call operators #3120

Merged
merged 8 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/common/task_system/progress_bar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,23 @@
return;
}
numPipelinesFinished++;
if (printing) {
std::cout << "\033[1A\033[2K\033[1B";

Check warning on line 34 in src/common/task_system/progress_bar.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/task_system/progress_bar.cpp#L34

Added line #L34 was not covered by tests
MSebanc marked this conversation as resolved.
Show resolved Hide resolved
}
// This ensures that the progress bar is updated back to 0% after a pipeline is finished.
prevCurPipelineProgress = -0.01;
updateProgress(0.0);
}

void ProgressBar::updateProgress(double curPipelineProgress) {
// Only update the progress bar if the progress has changed by at least 1%.
if (!trackProgress || curPipelineProgress - prevCurPipelineProgress < 0.01) {
if (!trackProgress) {
return;
}
std::lock_guard<std::mutex> lock(progressBarLock);
if (curPipelineProgress - prevCurPipelineProgress < 0.01) {
return;
}
prevCurPipelineProgress = curPipelineProgress;
if (printing) {
std::cout << "\033[2A";
Expand Down
3 changes: 2 additions & 1 deletion src/include/function/table/scan_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ struct ScanSharedState : public BaseScanSharedState {

struct ScanFileSharedState : public ScanSharedState {
main::ClientContext* context;
uint64_t totalSize;

ScanFileSharedState(
common::ReaderConfig readerConfig, uint64_t numRows, main::ClientContext* context)
: ScanSharedState{std::move(readerConfig), numRows}, context{context} {}
: ScanSharedState{std::move(readerConfig), numRows}, context{context}, totalSize{0} {}
};

} // namespace function
Expand Down
8 changes: 8 additions & 0 deletions src/include/function/table_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,15 @@ using table_func_init_shared_t =
using table_func_init_local_t = std::function<std::unique_ptr<TableFuncLocalState>(
TableFunctionInitInput&, TableFuncSharedState*, storage::MemoryManager*)>;
using table_func_can_parallel_t = std::function<bool()>;
using table_func_progress_t = std::function<double(TableFuncSharedState* sharedState)>;

struct TableFunction final : public Function {
table_func_t tableFunc;
table_func_bind_t bindFunc;
table_func_init_shared_t initSharedStateFunc;
table_func_init_local_t initLocalStateFunc;
table_func_can_parallel_t canParallelFunc = [] { return true; };
table_func_progress_t progressFunc = [](TableFuncSharedState* /*sharedState*/) { return 0.0; };

TableFunction()
: Function{}, tableFunc{nullptr}, bindFunc{nullptr}, initSharedStateFunc{nullptr},
Expand All @@ -80,6 +82,12 @@ struct TableFunction final : public Function {
: Function{FunctionType::TABLE, std::move(name), std::move(inputTypes)},
tableFunc{tableFunc}, bindFunc{bindFunc}, initSharedStateFunc{initSharedFunc},
initLocalStateFunc{initLocalFunc} {}
TableFunction(std::string name, table_func_t tableFunc, table_func_bind_t bindFunc,
table_func_init_shared_t initSharedFunc, table_func_init_local_t initLocalFunc,
table_func_progress_t progressFunc, std::vector<common::LogicalTypeID> inputTypes)
: Function{FunctionType::TABLE, std::move(name), std::move(inputTypes)},
tableFunc{tableFunc}, bindFunc{bindFunc}, initSharedStateFunc{initSharedFunc},
initLocalStateFunc{initLocalFunc}, progressFunc{progressFunc} {}

inline std::string signatureToString() const override {
return common::LogicalTypeUtils::toString(parameterTypeIDs);
Expand Down
2 changes: 2 additions & 0 deletions src/include/processor/operator/call/in_query_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ class InQueryCall : public PhysicalOperator {

bool getNextTuplesInternal(ExecutionContext* context) override;

double getProgress(ExecutionContext* context) const override;

std::unique_ptr<PhysicalOperator> clone() override {
return std::make_unique<InQueryCall>(info.copy(), sharedState, id, paramsString);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class BaseCSVReader {

uint64_t countRows();
bool isEOF() const;
uint64_t getFileSize();
// Get the file offset of the current buffer position.
uint64_t getFileOffset() const;

protected:
template<typename Driver>
Expand Down Expand Up @@ -56,8 +59,6 @@ class BaseCSVReader {

inline bool isNewLine(char c) { return c == '\n' || c == '\r'; }

// Get the file offset of the current buffer position.
uint64_t getFileOffset() const;
uint64_t getLineNumber();

protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ struct ParquetScanSharedState final : public function::ScanFileSharedState {
const common::ReaderConfig readerConfig, uint64_t numRows, main::ClientContext* context);

std::vector<std::unique_ptr<ParquetReader>> readers;
uint64_t totalRowsGroups;
};

struct ParquetScanLocalState final : public function::TableFuncLocalState {
Expand Down
4 changes: 4 additions & 0 deletions src/processor/operator/call/in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,9 @@ bool InQueryCall::getNextTuplesInternal(ExecutionContext*) {
return numTuplesScanned != 0;
}

// Reports this operator's progress by handing the table function's shared
// state to the progress callback registered on the bound table function.
// The execution context is not consulted.
double InQueryCall::getProgress(ExecutionContext* /*context*/) const {
    auto* funcSharedState = sharedState->funcState.get();
    return info.function.progressFunc(funcSharedState);
}

} // namespace processor
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ bool BaseCSVReader::isEOF() const {
return getFileOffset() >= fileInfo->getFileSize();
}

// Returns the size reported by the reader's underlying file handle (fileInfo).
uint64_t BaseCSVReader::getFileSize() {
    const auto fileSize = fileInfo->getFileSize();
    return fileSize;
}

template<typename Driver>
void BaseCSVReader::addValue(Driver& driver, uint64_t rowNum, column_id_t columnIdx,
std::string_view strVal, std::vector<uint64_t>& escapePositions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,15 @@ static std::unique_ptr<TableFuncSharedState> initSharedState(TableFunctionInitIn
auto bindData = reinterpret_cast<ScanBindData*>(input.bindData);
auto csvConfig = CSVReaderConfig::construct(bindData->config.options);
row_idx_t numRows = 0;
return std::make_unique<ParallelCSVScanSharedState>(bindData->config.copy(), numRows,
bindData->columnNames.size(), bindData->context, csvConfig.copy());
auto sharedState = std::make_unique<ParallelCSVScanSharedState>(bindData->config.copy(),
numRows, bindData->columnNames.size(), bindData->context, csvConfig.copy());
for (auto filePath : sharedState->readerConfig.filePaths) {
auto reader = std::make_unique<ParallelCSVReader>(filePath,
sharedState->csvReaderConfig.option.copy(), sharedState->numColumns,
sharedState->context);
sharedState->totalSize += reader->getFileSize();
}
return sharedState;
}

static std::unique_ptr<TableFuncLocalState> initLocalState(TableFunctionInitInput& /*input*/,
Expand All @@ -182,11 +189,24 @@ static std::unique_ptr<TableFuncLocalState> initLocalState(TableFunctionInitInpu
return localState;
}

// Estimates how far the parallel CSV scan has progressed, as a fraction in
// [0, 1] of the total size accumulated in the shared state.
static double progressFunc(TableFuncSharedState* sharedState) {
    auto* scanState = reinterpret_cast<ParallelCSVScanSharedState*>(sharedState);
    const auto totalSize = scanState->totalSize;
    if (totalSize == 0) {
        // Nothing to measure against; this also avoids dividing by zero below.
        return 0.0;
    }
    // Blocks handed out so far times the block size, scaled by the 1-based file index.
    // NOTE(review): the (fileIdx + 1) scaling looks like a rough heuristic rather than
    // an exact count of consumed input when files differ in size — confirm.
    const auto estimatedRead =
        (scanState->blockIdx * CopyConstants::PARALLEL_BLOCK_SIZE) * (scanState->fileIdx + 1);
    if (estimatedRead >= totalSize) {
        // The estimate can overshoot, so cap the reported progress at 100%.
        return 1.0;
    }
    return static_cast<double>(estimatedRead) / totalSize;
}

function_set ParallelCSVScan::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>(READ_CSV_PARALLEL_FUNC_NAME, tableFunc, bindFunc,
initSharedState, initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
functionSet.push_back(std::make_unique<TableFunction>(READ_CSV_PARALLEL_FUNC_NAME, tableFunc,
bindFunc, initSharedState, initLocalState, progressFunc,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down
30 changes: 26 additions & 4 deletions src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,42 @@
auto bindData = reinterpret_cast<ScanBindData*>(input.bindData);
auto csvConfig = CSVReaderConfig::construct(bindData->config.options);
row_idx_t numRows = 0;
return std::make_unique<SerialCSVScanSharedState>(bindData->config.copy(), numRows,
auto sharedState = std::make_unique<SerialCSVScanSharedState>(bindData->config.copy(), numRows,
bindData->columnNames.size(), csvConfig.copy(), bindData->context);
for (auto filePath : sharedState->readerConfig.filePaths) {
auto reader =
std::make_unique<SerialCSVReader>(filePath, sharedState->csvReaderConfig.option.copy(),
sharedState->numColumns, sharedState->context);
sharedState->totalSize += reader->getFileSize();
}
return sharedState;
}

static std::unique_ptr<TableFuncLocalState> initLocalState(TableFunctionInitInput& /*input*/,
TableFuncSharedState* /*state*/, storage::MemoryManager* /*mm*/) {
return std::make_unique<TableFuncLocalState>();
}

// Estimates how far the serial CSV scan has progressed, as a fraction in [0, 1]
// of the total size accumulated in the shared state.
static double progressFunc(TableFuncSharedState* sharedState) {
    auto* scanState = reinterpret_cast<SerialCSVScanSharedState*>(sharedState);
    if (scanState->totalSize == 0) {
        // Nothing to measure against; this also avoids dividing by zero below.
        return 0.0;
    }
    if (scanState->fileIdx >= scanState->readerConfig.getNumFiles()) {
        // Every file has already been handed out.
        return 1.0;
    }
    // Every file before fileIdx is counted using the size reported by the
    // *current* reader, plus the offset reached inside the current file.
    // NOTE(review): this is only exact when all files have the same size; an exact
    // figure would need the sizes of the earlier files, which are not kept — confirm.
    uint64_t totalReadSize = scanState->fileIdx * scanState->reader->getFileSize();
    totalReadSize += scanState->reader->getFileOffset();
    const double progress = static_cast<double>(totalReadSize) / scanState->totalSize;
    // Because of the approximation above the ratio can overshoot, so cap it at 100%.
    return progress > 1.0 ? 1.0 : progress;
}

function_set SerialCSVScan::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>(READ_CSV_SERIAL_FUNC_NAME, tableFunc, bindFunc,
initSharedState, initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
functionSet.push_back(std::make_unique<TableFunction>(READ_CSV_SERIAL_FUNC_NAME, tableFunc,
bindFunc, initSharedState, initLocalState, progressFunc,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,11 @@
: ScanFileSharedState{std::move(readerConfig), numRows, context} {
readers.push_back(
std::make_unique<ParquetReader>(this->readerConfig.filePaths[fileIdx], context));
totalRowsGroups = 0;
for (auto i = 0u; i < readerConfig.getNumFiles(); i++) {
auto reader = std::make_unique<ParquetReader>(readerConfig.filePaths[i], context);
totalRowsGroups += reader->getNumRowsGroups();
}

Check warning on line 570 in src/processor/operator/persistent/reader/parquet/parquet_reader.cpp

View check run for this annotation

Codecov / codecov/patch

src/processor/operator/persistent/reader/parquet/parquet_reader.cpp#L568-L570

Added lines #L568 - L570 were not covered by tests
}

static bool parquetSharedStateNext(
Expand Down Expand Up @@ -674,11 +679,23 @@
return localState;
}

// Estimates how far the parquet scan has progressed, as a fraction in [0, 1]
// of the total row-group count stored in the shared state.
static double progressFunc(TableFuncSharedState* state) {
    auto* scanState = reinterpret_cast<ParquetScanSharedState*>(state);
    if (scanState->fileIdx >= scanState->readerConfig.getNumFiles()) {
        // Every file has already been handed out.
        return 1.0;
    }
    if (scanState->totalRowsGroups == 0) {
        // Nothing to measure against; this also avoids dividing by zero below.
        return 0.0;
    }
    // Block index scaled by the 1-based file index.
    // NOTE(review): this looks like a rough heuristic rather than an exact count of
    // finished row groups — confirm.
    const auto estimatedDone = scanState->blockIdx * (scanState->fileIdx + 1);
    const double progress = static_cast<double>(estimatedDone) / scanState->totalRowsGroups;
    // The estimate can overshoot the total, so cap the reported progress at 100%.
    return progress > 1.0 ? 1.0 : progress;
}

function_set ParquetScanFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>(READ_PARQUET_FUNC_NAME, tableFunc, bindFunc,
initSharedState, initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
functionSet.push_back(std::make_unique<TableFunction>(READ_PARQUET_FUNC_NAME, tableFunc,
bindFunc, initSharedState, initLocalState, progressFunc,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down
4 changes: 3 additions & 1 deletion tools/python_api/src_cpp/include/pandas/pandas_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ struct PandasScanLocalState : public function::TableFuncLocalState {
};

struct PandasScanSharedState : public function::BaseScanSharedState {
explicit PandasScanSharedState(uint64_t numRows) : BaseScanSharedState{numRows}, position{0} {}
explicit PandasScanSharedState(uint64_t numRows)
: BaseScanSharedState{numRows}, position{0}, numReadRows{0} {}

std::mutex lock;
uint64_t position;
uint64_t numReadRows;
};

struct PandasScanFunction {
Expand Down
12 changes: 11 additions & 1 deletion tools/python_api/src_cpp/pandas/pandas_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ offset_t tableFunc(
TableFuncInput& input, TableFuncOutput& output) {
auto pandasScanData = reinterpret_cast<PandasScanFunctionData*>(input.bindData);
auto pandasLocalState = reinterpret_cast<PandasScanLocalState*>(input.localState);
auto pandasSharedState = reinterpret_cast<PandasScanSharedState*>(input.sharedState);

if (pandasLocalState->start >= pandasLocalState->end) {
if (!sharedStateNext(input.bindData, pandasLocalState, input.sharedState)) {
Expand All @@ -97,6 +98,7 @@ offset_t tableFunc(
}
output.dataChunk.state->selVector->selectedSize = numValuesToOutput;
pandasLocalState->start += numValuesToOutput;
pandasSharedState->numReadRows += numValuesToOutput;
return numValuesToOutput;
}

Expand All @@ -109,9 +111,17 @@ std::vector<std::unique_ptr<PandasColumnBindData>> PandasScanFunctionData::copyC
return result;
}

// Reports pandas scan progress as rows read so far divided by the total row
// count held in the shared state; returns 0 when there are no rows.
static double progressFunc(TableFuncSharedState* sharedState) {
    const auto* scanState = reinterpret_cast<PandasScanSharedState*>(sharedState);
    const auto totalRows = scanState->numRows;
    // The zero check avoids dividing by zero for an empty input.
    return totalRows == 0 ? 0.0 : static_cast<double>(scanState->numReadRows) / totalRows;
}

static TableFunction getFunction() {
return TableFunction(READ_PANDAS_FUNC_NAME, tableFunc, bindFunc, initSharedState,
initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::POINTER});
initLocalState, progressFunc, std::vector<LogicalTypeID>{LogicalTypeID::POINTER});
}

function_set PandasScanFunction::getFunctionSet() {
Expand Down
13 changes: 11 additions & 2 deletions tools/python_api/src_cpp/pyarrow/pyarrow_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,27 @@
return len;
}

// Reports pyarrow table scan progress as the current chunk index divided by
// the number of chunks. Returns 0 when the shared state is not a
// PyArrowTableScanSharedState or when there are no chunks.
double progressFunc(function::TableFuncSharedState* sharedState) {
    auto* state = dynamic_cast<PyArrowTableScanSharedState*>(sharedState);
    // dynamic_cast yields nullptr on a type mismatch (or a null input); check it
    // before dereferencing. The size check avoids dividing by zero.
    if (state == nullptr || state->chunks.size() == 0) {
        return 0.0;
    }
    return static_cast<double>(state->currentChunk) / state->chunks.size();
}

function::function_set PyArrowTableScanFunction::getFunctionSet() {

function_set functionSet;
functionSet.push_back(
std::make_unique<TableFunction>(READ_PYARROW_FUNC_NAME, tableFunc, bindFunc,
initSharedState, initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::POINTER}));
initSharedState, initLocalState, progressFunc, std::vector<LogicalTypeID>{LogicalTypeID::POINTER}));

Check warning on line 102 in tools/python_api/src_cpp/pyarrow/pyarrow_scan.cpp

View check run for this annotation

Codecov / codecov/patch

tools/python_api/src_cpp/pyarrow/pyarrow_scan.cpp#L102

Added line #L102 was not covered by tests
return functionSet;
}

TableFunction PyArrowTableScanFunction::getFunction() {
return TableFunction(READ_PYARROW_FUNC_NAME, tableFunc, bindFunc, initSharedState,
initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::POINTER});
initLocalState, progressFunc, std::vector<LogicalTypeID>{LogicalTypeID::POINTER});
}

} // namespace kuzu
Loading