Skip to content

Commit

Permalink
Merge pull request #2370 from kuzudb/reader-function
Browse files Browse the repository at this point in the history
Implement copy function framework
  • Loading branch information
acquamarin committed Nov 9, 2023
2 parents 6731798 + b9e1045 commit cc7f6ac
Show file tree
Hide file tree
Showing 48 changed files with 1,037 additions and 996 deletions.
39 changes: 29 additions & 10 deletions src/binder/bind/bind_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/string_format.h"
#include "function/table_functions.h"
#include "function/table_functions/bind_input.h"
#include "parser/copy.h"

using namespace kuzu::binder;
Expand Down Expand Up @@ -91,22 +93,31 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
auto fileType = bindFileType(filePaths);
auto readerConfig =
std::make_unique<ReaderConfig>(fileType, std::move(filePaths), std::move(csvReaderConfig));
auto inputType = std::make_unique<LogicalType>(LogicalTypeID::STRING);
std::vector<LogicalType*> inputTypes;
inputTypes.push_back(inputType.get());
auto scanFunction =
getScanFunction(fileType, std::move(inputTypes), readerConfig->csvReaderConfig->parallel);
std::vector<std::unique_ptr<Value>> inputValues;
inputValues.push_back(std::make_unique<Value>(Value::createValue(readerConfig->filePaths[0])));
auto tableFuncBindInput =
std::make_unique<function::ScanTableFuncBindInput>(*readerConfig, memoryManager);
validateByColumnKeyword(readerConfig->fileType, copyStatement.byColumn());
if (readerConfig->fileType == FileType::NPY) {
validateCopyNpyNotForRelTables(tableSchema);
}
switch (tableSchema->tableType) {
case TableType::NODE:
if (readerConfig->fileType == FileType::TURTLE) {
return bindCopyRdfNodeFrom(std::move(readerConfig), tableSchema);
return bindCopyRdfNodeFrom(scanFunction, std::move(readerConfig), tableSchema);
} else {
return bindCopyNodeFrom(std::move(readerConfig), tableSchema);
return bindCopyNodeFrom(scanFunction, std::move(readerConfig), tableSchema);
}
case TableType::REL: {
if (readerConfig->fileType == FileType::TURTLE) {
return bindCopyRdfRelFrom(std::move(readerConfig), tableSchema);
return bindCopyRdfRelFrom(scanFunction, std::move(readerConfig), tableSchema);
} else {
return bindCopyRelFrom(std::move(readerConfig), tableSchema);
return bindCopyRelFrom(scanFunction, std::move(readerConfig), tableSchema);
}
}
// LCOV_EXCL_START
Expand All @@ -117,28 +128,36 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
}
}

std::unique_ptr<BoundStatement> Binder::bindCopyNodeFrom(
std::unique_ptr<ReaderConfig> readerConfig, TableSchema* tableSchema) {
std::unique_ptr<BoundStatement> Binder::bindCopyNodeFrom(function::TableFunction* copyFunc,
std::unique_ptr<common::ReaderConfig> readerConfig, TableSchema* tableSchema) {
// For table with SERIAL columns, we need to read in serial from files.
auto containsSerial = bindContainsSerial(tableSchema);
auto columns = bindExpectedNodeFileColumns(tableSchema, *readerConfig);
auto tableFuncBindInput =
std::make_unique<function::ScanTableFuncBindInput>(std::move(*readerConfig), memoryManager);
auto copyFuncBindData =
copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion());
auto nodeID = createVariable(std::string(Property::INTERNAL_ID_NAME), LogicalTypeID::INT64);
auto boundFileScanInfo = std::make_unique<BoundFileScanInfo>(
std::move(readerConfig), columns, std::move(nodeID), TableType::NODE);
copyFunc, copyFuncBindData->copy(), columns, std::move(nodeID), TableType::NODE);
auto boundCopyFromInfo = std::make_unique<BoundCopyFromInfo>(tableSchema,
std::move(boundFileScanInfo), containsSerial, std::move(columns), nullptr /* extraInfo */);
return std::make_unique<BoundCopyFrom>(std::move(boundCopyFromInfo));
}

std::unique_ptr<BoundStatement> Binder::bindCopyRelFrom(
std::unique_ptr<ReaderConfig> readerConfig, TableSchema* tableSchema) {
std::unique_ptr<BoundStatement> Binder::bindCopyRelFrom(function::TableFunction* copyFunc,
std::unique_ptr<common::ReaderConfig> readerConfig, TableSchema* tableSchema) {
// For table with SERIAL columns, we need to read in serial from files.
auto containsSerial = bindContainsSerial(tableSchema);
KU_ASSERT(containsSerial == false);
auto columnsToRead = bindExpectedRelFileColumns(tableSchema, *readerConfig);
auto tableFuncBindInput =
std::make_unique<function::ScanTableFuncBindInput>(std::move(*readerConfig), memoryManager);
auto copyFuncBindData =
copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion());
auto relID = createVariable(std::string(Property::INTERNAL_ID_NAME), LogicalTypeID::INT64);
auto boundFileScanInfo = std::make_unique<BoundFileScanInfo>(
std::move(readerConfig), columnsToRead, relID->copy(), TableType::REL);
copyFunc, copyFuncBindData->copy(), columnsToRead, relID->copy(), TableType::REL);
auto relTableSchema = reinterpret_cast<RelTableSchema*>(tableSchema);
auto srcTableSchema =
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID());
Expand Down
37 changes: 29 additions & 8 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/variable_expression.h"
#include "binder/query/reading_clause/bound_in_query_call.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
Expand All @@ -14,7 +15,7 @@
#include "parser/query/reading_clause/match_clause.h"
#include "parser/query/reading_clause/unwind_clause.h"
#include "processor/operator/persistent/reader/csv/serial_csv_reader.h"
#include "processor/operator/persistent/reader/npy_reader.h"
#include "processor/operator/persistent/reader/npy/npy_reader.h"
#include "processor/operator/persistent/reader/parquet/parquet_reader.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -117,15 +118,20 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
StringUtils::toUpper(funcNameToMatch);
auto tableFunction = reinterpret_cast<function::TableFunction*>(
catalog.getBuiltInFunctions()->matchScalarFunction(std::move(funcNameToMatch), inputTypes));
auto bindData = tableFunction->bindFunc(clientContext,
function::TableFuncBindInput{std::move(inputValues)}, catalog.getReadOnlyVersion());
std::unique_ptr<function::TableFuncBindInput> tableFuncBindInput;
tableFuncBindInput = std::make_unique<function::TableFuncBindInput>(std::move(inputValues));
auto bindData = tableFunction->bindFunc(
clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion());
expression_vector outputExpressions;
for (auto i = 0u; i < bindData->returnColumnNames.size(); i++) {
outputExpressions.push_back(
createVariable(bindData->returnColumnNames[i], bindData->returnTypes[i]));
createVariable(bindData->returnColumnNames[i], *bindData->returnTypes[i]));
}
return std::make_unique<BoundInQueryCall>(
std::move(tableFunction), std::move(bindData), std::move(outputExpressions));
return std::make_unique<BoundInQueryCall>(std::move(tableFunction), std::move(bindData),
std::move(outputExpressions),
std::make_shared<VariableExpression>(LogicalType{LogicalTypeID::INT64},
getUniqueExpressionName(CopyConstants::ROW_IDX_COLUMN_NAME),
CopyConstants::ROW_IDX_COLUMN_NAME));
}

static std::unique_ptr<LogicalType> bindFixedListType(
Expand Down Expand Up @@ -172,8 +178,23 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(
}
columns = createColumnExpressions(*readerConfig, expectedColumnNames, expectedColumnTypes);
}
auto info = std::make_unique<BoundFileScanInfo>(
std::move(readerConfig), std::move(columns), nullptr /* offset */, TableType::UNKNOWN);
auto inputType = std::make_unique<LogicalType>(LogicalTypeID::STRING);
std::vector<LogicalType*> inputTypes;
inputTypes.push_back(inputType.get());
auto scanFunction = getScanFunction(
readerConfig->fileType, std::move(inputTypes), readerConfig->csvReaderConfig->parallel);
std::vector<std::unique_ptr<Value>> inputValues;
inputValues.push_back(std::make_unique<Value>(Value::createValue(readerConfig->filePaths[0])));
std::unique_ptr<function::TableFuncBindInput> tableFuncBindInput =
std::make_unique<function::ScanTableFuncBindInput>(*readerConfig, memoryManager);
auto bindData = scanFunction->bindFunc(
clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion());
auto info =
std::make_unique<BoundFileScanInfo>(scanFunction, std::move(bindData), std::move(columns),
std::make_shared<VariableExpression>(LogicalType{LogicalTypeID::INT64},
getUniqueExpressionName(CopyConstants::ROW_IDX_COLUMN_NAME),
CopyConstants::ROW_IDX_COLUMN_NAME),
TableType::UNKNOWN);
auto boundLoadFrom = std::make_unique<BoundLoadFrom>(std::move(info));
if (loadFrom.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*loadFrom.getWherePredicate());
Expand Down
17 changes: 13 additions & 4 deletions src/binder/bind/copy/bind_copy_rdf_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "common/constants.h"
#include "common/keyword/rdf_keyword.h"
#include "common/types/rdf_variant_type.h"
#include "function/table_functions/bind_input.h"

using namespace kuzu::binder;
using namespace kuzu::catalog;
Expand All @@ -13,7 +14,7 @@ using namespace kuzu::parser;
namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCopyRdfNodeFrom(
std::unique_ptr<BoundStatement> Binder::bindCopyRdfNodeFrom(function::TableFunction* copyFunc,
std::unique_ptr<ReaderConfig> readerConfig, TableSchema* tableSchema) {
bool containsSerial;
auto stringType = LogicalType{LogicalTypeID::STRING};
Expand All @@ -35,14 +36,18 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfNodeFrom(
std::make_unique<RdfReaderConfig>(RdfReaderMode::LITERAL, nullptr /* index */);
columns.push_back(createVariable(columnName, *RdfVariantType::getType()));
}
auto tableFuncBindInput =
std::make_unique<function::ScanTableFuncBindInput>(*readerConfig, memoryManager);
auto bindData =
copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion());
auto boundFileScanInfo = std::make_unique<BoundFileScanInfo>(
std::move(readerConfig), columns, std::move(nodeID), TableType::NODE);
copyFunc, std::move(bindData), columns, std::move(nodeID), TableType::NODE);
auto boundCopyFromInfo = std::make_unique<BoundCopyFromInfo>(tableSchema,
std::move(boundFileScanInfo), containsSerial, std::move(columns), nullptr /* extraInfo */);
return std::make_unique<BoundCopyFrom>(std::move(boundCopyFromInfo));
}

std::unique_ptr<BoundStatement> Binder::bindCopyRdfRelFrom(
std::unique_ptr<BoundStatement> Binder::bindCopyRdfRelFrom(function::TableFunction* copyFunc,
std::unique_ptr<ReaderConfig> readerConfig, TableSchema* tableSchema) {
auto containsSerial = false;
auto offsetType = std::make_unique<LogicalType>(LogicalTypeID::INT64);
Expand All @@ -64,8 +69,12 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfRelFrom(
std::make_unique<RdfReaderConfig>(RdfReaderMode::LITERAL_TRIPLE, index);
}
auto relID = createVariable(std::string(Property::INTERNAL_ID_NAME), LogicalTypeID::INT64);
auto tableFuncBindInput =
std::make_unique<function::ScanTableFuncBindInput>(*readerConfig, memoryManager);
auto bindData =
copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion());
auto boundFileScanInfo = std::make_unique<BoundFileScanInfo>(
std::move(readerConfig), columns, relID, TableType::REL);
copyFunc, std::move(bindData), columns, relID, TableType::REL);
auto extraInfo = std::make_unique<ExtraBoundCopyRdfRelInfo>(columns[0], columns[1], columns[2]);
columns.push_back(std::move(relID));
auto boundCopyFromInfo = std::make_unique<BoundCopyFromInfo>(tableSchema,
Expand Down
25 changes: 25 additions & 0 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,5 +209,30 @@ void Binder::restoreScope(std::unique_ptr<BinderScope> prevVariableScope) {
scope = std::move(prevVariableScope);
}

function::TableFunction* Binder::getScanFunction(
common::FileType fileType, std::vector<common::LogicalType*> inputTypes, bool isParallel) {
switch (fileType) {
case common::FileType::PARQUET:
return reinterpret_cast<function::TableFunction*>(
catalog.getBuiltInFunctions()->matchScalarFunction(
READ_PARQUET_FUNC_NAME, std::move(inputTypes)));
case common::FileType::NPY:
return reinterpret_cast<function::TableFunction*>(
catalog.getBuiltInFunctions()->matchScalarFunction(
READ_NPY_FUNC_NAME, std::move(inputTypes)));
case common::FileType::CSV:
return reinterpret_cast<function::TableFunction*>(
catalog.getBuiltInFunctions()->matchScalarFunction(
isParallel ? READ_CSV_PARALLEL_FUNC_NAME : READ_CSV_SERIAL_FUNC_NAME,
std::move(inputTypes)));
case common::FileType::TURTLE:
return reinterpret_cast<function::TableFunction*>(
catalog.getBuiltInFunctions()->matchScalarFunction(
READ_RDF_FUNC_NAME, std::move(inputTypes)));
default:
KU_UNREACHABLE;
}
}

} // namespace binder
} // namespace kuzu
10 changes: 10 additions & 0 deletions src/function/built_in_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
#include "function/table_functions/call_functions.h"
#include "function/timestamp/vector_timestamp_functions.h"
#include "function/union/vector_union_functions.h"
#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h"
#include "processor/operator/persistent/reader/csv/serial_csv_reader.h"
#include "processor/operator/persistent/reader/npy/npy_reader.h"
#include "processor/operator/persistent/reader/parquet/parquet_reader.h"
#include "processor/operator/persistent/reader/rdf/rdf_reader.h"

using namespace kuzu::common;

Expand Down Expand Up @@ -754,6 +759,11 @@ void BuiltInFunctions::registerTableFunctions() {
functions.insert({SHOW_TABLES_FUNC_NAME, ShowTablesFunction::getFunctionSet()});
functions.insert({TABLE_INFO_FUNC_NAME, TableInfoFunction::getFunctionSet()});
functions.insert({SHOW_CONNECTION_FUNC_NAME, ShowConnectionFunction::getFunctionSet()});
functions.insert({READ_PARQUET_FUNC_NAME, processor::ParquetScanFunction::getFunctionSet()});
functions.insert({READ_NPY_FUNC_NAME, processor::NpyScanFunction::getFunctionSet()});
functions.insert({READ_CSV_SERIAL_FUNC_NAME, processor::SerialCSVScan::getFunctionSet()});
functions.insert({READ_CSV_PARALLEL_FUNC_NAME, processor::ParallelCSVScan::getFunctionSet()});
functions.insert({READ_RDF_FUNC_NAME, processor::RdfScan::getFunctionSet()});
}

void BuiltInFunctions::addFunction(std::string name, function::function_set definitions) {
Expand Down
3 changes: 2 additions & 1 deletion src/function/table_functions/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(kuzu_table_function
OBJECT
call_functions.cpp)
call_functions.cpp
scan_functions.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_table_function>
Expand Down
Loading

0 comments on commit cc7f6ac

Please sign in to comment.