diff --git a/src/binder/bind/bind_copy.cpp b/src/binder/bind/bind_copy.cpp index 257570c95f3..9a91b2a0767 100644 --- a/src/binder/bind/bind_copy.cpp +++ b/src/binder/bind/bind_copy.cpp @@ -37,9 +37,8 @@ std::unique_ptr Binder::bindCopyToClause(const Statement& statem if (fileType != FileType::CSV && fileType != FileType::PARQUET) { throw BinderException(ExceptionMessage::validateCopyToCSVParquetExtensionsException()); } - auto readerConfig = std::make_unique( - fileType, std::vector{boundFilePath}, columnNames, std::move(columnTypes)); - return std::make_unique(std::move(readerConfig), std::move(query)); + return std::make_unique( + boundFilePath, fileType, std::move(columnNames), std::move(columnTypes), std::move(query)); } // As a temporary constraint, we require npy files loaded with COPY FROM BY COLUMN keyword. @@ -93,15 +92,7 @@ std::unique_ptr Binder::bindCopyFromClause(const Statement& stat auto fileType = bindFileType(filePaths); auto readerConfig = std::make_unique(fileType, std::move(filePaths), std::move(csvReaderConfig)); - auto inputType = std::make_unique(LogicalTypeID::STRING); - std::vector inputTypes; - inputTypes.push_back(inputType.get()); - auto scanFunction = - getScanFunction(fileType, std::move(inputTypes), readerConfig->csvReaderConfig->parallel); - std::vector> inputValues; - inputValues.push_back(std::make_unique(Value::createValue(readerConfig->filePaths[0]))); - auto tableFuncBindInput = - std::make_unique(*readerConfig, memoryManager); + auto scanFunction = getScanFunction(fileType, readerConfig->csvReaderConfig->parallel); validateByColumnKeyword(readerConfig->fileType, copyStatement.byColumn()); if (readerConfig->fileType == FileType::NPY) { validateCopyNpyNotForRelTables(tableSchema); @@ -132,14 +123,21 @@ std::unique_ptr Binder::bindCopyNodeFrom(function::TableFunction std::unique_ptr readerConfig, TableSchema* tableSchema) { // For table with SERIAL columns, we need to read in serial from files. auto containsSerial = bindContainsSerial(tableSchema); - auto columns = bindExpectedNodeFileColumns(tableSchema, *readerConfig); - auto tableFuncBindInput = - std::make_unique(std::move(*readerConfig), memoryManager); - auto copyFuncBindData = - copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion()); - auto nodeID = createVariable(std::string(Property::INTERNAL_ID_NAME), LogicalTypeID::INT64); + std::vector expectedColumnNames; + std::vector> expectedColumnTypes; + bindExpectedNodeColumns(tableSchema, expectedColumnNames, expectedColumnTypes); + auto bindInput = std::make_unique(memoryManager, + *readerConfig, std::move(expectedColumnNames), std::move(expectedColumnTypes)); + auto bindData = + copyFunc->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + expression_vector columns; + for (auto i = 0u; i < bindData->columnTypes.size(); i++) { + columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); + } + auto offset = expressionBinder.createVariableExpression( + LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS); auto boundFileScanInfo = std::make_unique( - copyFunc, copyFuncBindData->copy(), columns, std::move(nodeID), TableType::NODE); + copyFunc, std::move(bindData), columns, std::move(offset), TableType::NODE); auto boundCopyFromInfo = std::make_unique(tableSchema, std::move(boundFileScanInfo), containsSerial, std::move(columns), nullptr /* extraInfo */); return std::make_unique(std::move(boundCopyFromInfo)); @@ -150,21 +148,28 @@ std::unique_ptr Binder::bindCopyRelFrom(function::TableFunction* // For table with SERIAL columns, we need to read in serial from files. auto containsSerial = bindContainsSerial(tableSchema); KU_ASSERT(containsSerial == false); - auto columnsToRead = bindExpectedRelFileColumns(tableSchema, *readerConfig); - auto tableFuncBindInput = - std::make_unique(std::move(*readerConfig), memoryManager); - auto copyFuncBindData = - copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion()); - auto relID = createVariable(std::string(Property::INTERNAL_ID_NAME), LogicalTypeID::INT64); + std::vector expectedColumnNames; + std::vector> expectedColumnTypes; + bindExpectedRelColumns(tableSchema, expectedColumnNames, expectedColumnTypes); + auto bindInput = std::make_unique(memoryManager, + std::move(*readerConfig), std::move(expectedColumnNames), std::move(expectedColumnTypes)); + auto bindData = + copyFunc->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + expression_vector columns; + for (auto i = 0u; i < bindData->columnTypes.size(); i++) { + columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); + } + auto offset = expressionBinder.createVariableExpression( + LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS); auto boundFileScanInfo = std::make_unique( - copyFunc, copyFuncBindData->copy(), columnsToRead, relID->copy(), TableType::REL); + copyFunc, std::move(bindData), columns, offset, TableType::REL); auto relTableSchema = reinterpret_cast(tableSchema); auto srcTableSchema = catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID()); auto dstTableSchema = catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID()); - auto srcKey = columnsToRead[0]; - auto dstKey = columnsToRead[1]; + auto srcKey = columns[0]; + auto dstKey = columns[1]; auto srcNodeID = createVariable(std::string(Property::REL_BOUND_OFFSET_NAME), LogicalTypeID::INT64); auto dstNodeID = @@ -172,9 +177,9 @@ std::unique_ptr Binder::bindCopyRelFrom(function::TableFunction* auto extraCopyRelInfo = std::make_unique( srcTableSchema, dstTableSchema, srcNodeID, dstNodeID, srcKey, dstKey); // Skip the first two columns. - expression_vector columnsToCopy{std::move(srcNodeID), std::move(dstNodeID), std::move(relID)}; - for (auto i = NUM_COLUMNS_TO_SKIP_IN_REL_FILE; i < columnsToRead.size(); i++) { - columnsToCopy.push_back(std::move(columnsToRead[i])); + expression_vector columnsToCopy{std::move(srcNodeID), std::move(dstNodeID), std::move(offset)}; + for (auto i = NUM_COLUMNS_TO_SKIP_IN_REL_FILE; i < columns.size(); i++) { + columnsToCopy.push_back(columns[i]); } auto boundCopyFromInfo = std::make_unique(tableSchema, std::move(boundFileScanInfo), @@ -187,97 +192,47 @@ static bool skipPropertyInFile(const Property& property) { TableSchema::isReservedPropertyName(property.getName()); } -expression_vector Binder::bindExpectedNodeFileColumns( - TableSchema* tableSchema, ReaderConfig& readerConfig) { - // Resolve expected columns. - std::vector expectedColumnNames; - std::vector> expectedColumnTypes; - switch (readerConfig.fileType) { - case FileType::NPY: - case FileType::PARQUET: - case FileType::CSV: { - for (auto& property : tableSchema->properties) { - if (skipPropertyInFile(*property)) { - continue; - } - expectedColumnNames.push_back(property->getName()); - expectedColumnTypes.push_back(property->getDataType()->copy()); +void Binder::bindExpectedNodeColumns(catalog::TableSchema* tableSchema, + std::vector& columnNames, + std::vector>& columnTypes) { + for (auto& property : tableSchema->properties) { + if (skipPropertyInFile(*property)) { + continue; } - } break; - default: { - KU_UNREACHABLE; - } - } - // Detect columns from file. - std::vector detectedColumnNames; - std::vector> detectedColumnTypes; - sniffFiles(readerConfig, detectedColumnNames, detectedColumnTypes); - // Validate. - validateNumColumns(expectedColumnTypes.size(), detectedColumnTypes.size()); - if (readerConfig.fileType == common::FileType::PARQUET) { - // HACK(Ziyi): We should allow casting in Parquet reader. - validateColumnTypes(expectedColumnNames, expectedColumnTypes, detectedColumnTypes); + columnNames.push_back(property->getName()); + columnTypes.push_back(property->getDataType()->copy()); } - return createColumnExpressions(readerConfig, expectedColumnNames, expectedColumnTypes); } -expression_vector Binder::bindExpectedRelFileColumns( - TableSchema* tableSchema, ReaderConfig& readerConfig) { +void Binder::bindExpectedRelColumns(catalog::TableSchema* tableSchema, + std::vector& columnNames, + std::vector>& columnTypes) { auto relTableSchema = reinterpret_cast(tableSchema); - expression_vector columns; - switch (readerConfig.fileType) { - case FileType::CSV: - case FileType::PARQUET: - case FileType::NPY: { - auto srcColumnName = std::string(Property::REL_FROM_PROPERTY_NAME); - auto dstColumnName = std::string(Property::REL_TO_PROPERTY_NAME); - readerConfig.columnNames.push_back(srcColumnName); - readerConfig.columnNames.push_back(dstColumnName); - auto srcTable = - catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID()); - KU_ASSERT(srcTable->tableType == TableType::NODE); - auto dstTable = - catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID()); - KU_ASSERT(dstTable->tableType == TableType::NODE); - auto srcPKColumnType = - reinterpret_cast(srcTable)->getPrimaryKey()->getDataType()->copy(); - if (srcPKColumnType->getLogicalTypeID() == LogicalTypeID::SERIAL) { - srcPKColumnType = std::make_unique(LogicalTypeID::INT64); - } - auto dstPKColumnType = - reinterpret_cast(dstTable)->getPrimaryKey()->getDataType()->copy(); - if (dstPKColumnType->getLogicalTypeID() == LogicalTypeID::SERIAL) { - dstPKColumnType = std::make_unique(LogicalTypeID::INT64); - } - columns.push_back(createVariable(srcColumnName, *srcPKColumnType)); - columns.push_back(createVariable(dstColumnName, *dstPKColumnType)); - readerConfig.columnTypes.push_back(std::move(srcPKColumnType)); - readerConfig.columnTypes.push_back(std::move(dstPKColumnType)); - for (auto& property : tableSchema->properties) { - if (skipPropertyInFile(*property)) { - continue; - } - readerConfig.columnNames.push_back(property->getName()); - auto columnType = property->getDataType()->copy(); - columns.push_back(createVariable(property->getName(), *columnType)); - readerConfig.columnTypes.push_back(std::move(columnType)); + auto srcTable = reinterpret_cast( + catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID())); + auto dstTable = reinterpret_cast( + catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID())); + auto srcColumnName = std::string(Property::REL_FROM_PROPERTY_NAME); + auto dstColumnName = std::string(Property::REL_TO_PROPERTY_NAME); + columnNames.push_back(srcColumnName); + columnNames.push_back(dstColumnName); + auto srcPKColumnType = srcTable->getPrimaryKey()->getDataType()->copy(); + if (srcPKColumnType->getLogicalTypeID() == LogicalTypeID::SERIAL) { + srcPKColumnType = std::make_unique(LogicalTypeID::INT64); + } + auto dstPKColumnType = dstTable->getPrimaryKey()->getDataType()->copy(); + if (dstPKColumnType->getLogicalTypeID() == LogicalTypeID::SERIAL) { + dstPKColumnType = std::make_unique(LogicalTypeID::INT64); + } + columnTypes.push_back(std::move(srcPKColumnType)); + columnTypes.push_back(std::move(dstPKColumnType)); + for (auto& property : tableSchema->properties) { + if (skipPropertyInFile(*property)) { + continue; } - } break; - default: { - KU_UNREACHABLE; - } - } - // Detect columns from file. - std::vector detectedColumnNames; - std::vector> detectedColumnTypes; - sniffFiles(readerConfig, detectedColumnNames, detectedColumnTypes); - // Validate number of columns. - validateNumColumns(readerConfig.getNumColumns(), detectedColumnTypes.size()); - if (readerConfig.fileType == common::FileType::PARQUET) { - validateColumnTypes( - readerConfig.columnNames, readerConfig.columnTypes, detectedColumnTypes); + columnNames.push_back(property->getName()); + columnTypes.push_back(property->getDataType()->copy()); } - return columns; } } // namespace binder diff --git a/src/binder/bind/bind_reading_clause.cpp b/src/binder/bind/bind_reading_clause.cpp index 214d886e248..03750837d62 100644 --- a/src/binder/bind/bind_reading_clause.cpp +++ b/src/binder/bind/bind_reading_clause.cpp @@ -1,5 +1,4 @@ #include "binder/binder.h" -#include "binder/expression/variable_expression.h" #include "binder/query/reading_clause/bound_in_query_call.h" #include "binder/query/reading_clause/bound_load_from.h" #include "binder/query/reading_clause/bound_match_clause.h" @@ -14,13 +13,9 @@ #include "parser/query/reading_clause/load_from.h" #include "parser/query/reading_clause/match_clause.h" #include "parser/query/reading_clause/unwind_clause.h" -#include "processor/operator/persistent/reader/csv/serial_csv_reader.h" -#include "processor/operator/persistent/reader/npy/npy_reader.h" -#include "processor/operator/persistent/reader/parquet/parquet_reader.h" using namespace kuzu::common; using namespace kuzu::parser; -using namespace kuzu::processor; using namespace kuzu::catalog; namespace kuzu { @@ -114,35 +109,23 @@ std::unique_ptr Binder::bindInQueryCall(const ReadingClause& inputTypes.push_back(expressionValue->getDataType()); inputValues.push_back(expressionValue->copy()); } - auto funcNameToMatch = funcExpr->getFunctionName(); - StringUtils::toUpper(funcNameToMatch); + auto funcName = funcExpr->getFunctionName(); + StringUtils::toUpper(funcName); + // TODO: this is dangerous because we could match to a scan function. auto tableFunction = reinterpret_cast( - catalog.getBuiltInFunctions()->matchScalarFunction(std::move(funcNameToMatch), inputTypes)); - std::unique_ptr tableFuncBindInput; - tableFuncBindInput = std::make_unique(std::move(inputValues)); - auto bindData = tableFunction->bindFunc( - clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion()); - expression_vector outputExpressions; - for (auto i = 0u; i < bindData->returnColumnNames.size(); i++) { - outputExpressions.push_back( - createVariable(bindData->returnColumnNames[i], *bindData->returnTypes[i])); - } - return std::make_unique(std::move(tableFunction), std::move(bindData), - std::move(outputExpressions), - std::make_shared(LogicalType{LogicalTypeID::INT64}, - getUniqueExpressionName(CopyConstants::ROW_IDX_COLUMN_NAME), - CopyConstants::ROW_IDX_COLUMN_NAME)); -} - -static std::unique_ptr bindFixedListType( - const std::vector& shape, LogicalTypeID typeID) { - if (shape.size() == 1) { - return std::make_unique(typeID); + catalog.getBuiltInFunctions()->matchScalarFunction(std::move(funcName), inputTypes)); + auto bindInput = std::make_unique(); + bindInput->inputs = std::move(inputValues); + auto bindData = + tableFunction->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + expression_vector columns; + for (auto i = 0u; i < bindData->columnTypes.size(); i++) { + columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); } - auto childShape = std::vector{shape.begin() + 1, shape.end()}; - auto childType = bindFixedListType(childShape, typeID); - auto extraInfo = std::make_unique(std::move(childType), (uint32_t)shape[0]); - return std::make_unique(LogicalTypeID::FIXED_LIST, std::move(extraInfo)); + auto offset = expressionBinder.createVariableExpression( + LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS); + return std::make_unique( + std::move(tableFunction), std::move(bindData), std::move(columns), offset); } std::unique_ptr Binder::bindLoadFrom( @@ -156,6 +139,15 @@ std::unique_ptr Binder::bindLoadFrom( if (readerConfig->getNumFiles() > 1) { throw BinderException("Load from multiple files is not supported."); } + switch (fileType) { + case common::FileType::CSV: + case common::FileType::PARQUET: + case common::FileType::NPY: + break; + default: + throw BinderException( + stringFormat("Cannot load from file type {}.", FileTypeUtils::toString(fileType))); + } // Bind columns from input. std::vector expectedColumnNames; std::vector> expectedColumnTypes; @@ -163,38 +155,20 @@ std::unique_ptr Binder::bindLoadFrom( expectedColumnNames.push_back(name); expectedColumnTypes.push_back(bindDataType(type)); } - // Detect columns from file. - std::vector detectedColumnNames; - std::vector> detectedColumnTypes; - sniffFiles(*readerConfig, detectedColumnNames, detectedColumnTypes); - // Validate and resolve columns to use. + auto scanFunction = + getScanFunction(readerConfig->fileType, readerConfig->csvReaderConfig->parallel); + auto bindInput = std::make_unique(memoryManager, + *readerConfig, std::move(expectedColumnNames), std::move(expectedColumnTypes)); + auto bindData = + scanFunction->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); expression_vector columns; - if (expectedColumnTypes.empty()) { // Input is empty. Use detected columns. - columns = createColumnExpressions(*readerConfig, detectedColumnNames, detectedColumnTypes); - } else { - validateNumColumns(expectedColumnTypes.size(), detectedColumnTypes.size()); - if (fileType == common::FileType::PARQUET) { - validateColumnTypes(expectedColumnNames, expectedColumnTypes, detectedColumnTypes); - } - columns = createColumnExpressions(*readerConfig, expectedColumnNames, expectedColumnTypes); + for (auto i = 0u; i < bindData->columnTypes.size(); i++) { + columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); } - auto inputType = std::make_unique(LogicalTypeID::STRING); - std::vector inputTypes; - inputTypes.push_back(inputType.get()); - auto scanFunction = getScanFunction( - readerConfig->fileType, std::move(inputTypes), readerConfig->csvReaderConfig->parallel); - std::vector> inputValues; - inputValues.push_back(std::make_unique(Value::createValue(readerConfig->filePaths[0]))); - std::unique_ptr tableFuncBindInput = - std::make_unique(*readerConfig, memoryManager); - auto bindData = scanFunction->bindFunc( - clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion()); - auto info = - std::make_unique(scanFunction, std::move(bindData), std::move(columns), - std::make_shared(LogicalType{LogicalTypeID::INT64}, - getUniqueExpressionName(CopyConstants::ROW_IDX_COLUMN_NAME), - CopyConstants::ROW_IDX_COLUMN_NAME), - TableType::UNKNOWN); + auto offset = expressionBinder.createVariableExpression( + LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS); + auto info = std::make_unique(scanFunction, std::move(bindData), + std::move(columns), std::move(offset), TableType::UNKNOWN); auto boundLoadFrom = std::make_unique(std::move(info)); if (loadFrom.hasWherePredicate()) { auto wherePredicate = expressionBinder.bindExpression(*loadFrom.getWherePredicate()); @@ -203,105 +177,5 @@ std::unique_ptr Binder::bindLoadFrom( return boundLoadFrom; } -expression_vector Binder::createColumnExpressions(common::ReaderConfig& readerConfig, - const std::vector& columnNames, - const std::vector>& columnTypes) { - expression_vector columns; - for (auto i = 0u; i < columnTypes.size(); ++i) { - auto columnName = columnNames[i]; - auto columnType = columnTypes[i].get(); - readerConfig.columnNames.push_back(columnName); - readerConfig.columnTypes.push_back(columnType->copy()); - columns.push_back(createVariable(columnName, *columnType)); - } - return columns; -} - -void Binder::validateColumnTypes(const std::vector& columnNames, - const std::vector>& expectedColumnTypes, - const std::vector>& detectedColumnTypes) { - KU_ASSERT(expectedColumnTypes.size() == detectedColumnTypes.size()); - for (auto i = 0; i < expectedColumnTypes.size(); ++i) { - if (*expectedColumnTypes[i] != *detectedColumnTypes[i]) { - throw BinderException( - stringFormat("Column `{}` type mismatch. Expected {} but got {}.", columnNames[i], - expectedColumnTypes[i]->toString(), detectedColumnTypes[i]->toString())); - } - } -} - -void Binder::validateNumColumns(uint32_t expectedNumber, uint32_t detectedNumber) { - if (detectedNumber == 0) { - return; // Empty CSV. Continue processing. - } - if (expectedNumber != detectedNumber) { - throw BinderException(stringFormat( - "Number of columns mismatch. Expected {} but got {}.", expectedNumber, detectedNumber)); - } -} - -void Binder::sniffFiles(const common::ReaderConfig& readerConfig, - std::vector& columnNames, - std::vector>& columnTypes) { - KU_ASSERT(readerConfig.getNumFiles() > 0); - sniffFile(readerConfig, 0, columnNames, columnTypes); - for (auto i = 1; i < readerConfig.getNumFiles(); ++i) { - std::vector tmpColumnNames; - std::vector> tmpColumnTypes; - sniffFile(readerConfig, i, tmpColumnNames, tmpColumnTypes); - switch (readerConfig.fileType) { - case FileType::CSV: { - validateNumColumns(columnTypes.size(), tmpColumnTypes.size()); - } - case FileType::PARQUET: { - validateNumColumns(columnTypes.size(), tmpColumnTypes.size()); - validateColumnTypes(columnNames, columnTypes, tmpColumnTypes); - } break; - case FileType::NPY: { - validateNumColumns(1, tmpColumnTypes.size()); - columnNames.push_back(tmpColumnNames[0]); - columnTypes.push_back(tmpColumnTypes[0]->copy()); - } break; - case FileType::TURTLE: - break; - default: - KU_UNREACHABLE; - } - } -} - -void Binder::sniffFile(const common::ReaderConfig& readerConfig, uint32_t fileIdx, - std::vector& columnNames, std::vector>& columnTypes) { - switch (readerConfig.fileType) { - case FileType::CSV: { - auto csvReader = SerialCSVReader(readerConfig.filePaths[fileIdx], readerConfig); - auto sniffedColumns = csvReader.sniffCSV(); - for (auto& [name, type] : sniffedColumns) { - columnNames.push_back(name); - columnTypes.push_back(type.copy()); - } - } break; - case FileType::PARQUET: { - auto reader = ParquetReader(readerConfig.filePaths[fileIdx], memoryManager); - auto state = std::make_unique(); - reader.initializeScan(*state, std::vector{}); - for (auto i = 0u; i < reader.getNumColumns(); ++i) { - columnNames.push_back(reader.getColumnName(i)); - columnTypes.push_back(reader.getColumnType(i)->copy()); - } - } break; - case FileType::NPY: { - auto reader = NpyReader(readerConfig.filePaths[0]); - auto columnName = std::string("column" + std::to_string(fileIdx)); - auto columnType = bindFixedListType(reader.getShape(), reader.getType()); - columnNames.push_back(columnName); - columnTypes.push_back(columnType->copy()); - } break; - default: - throw BinderException(stringFormat( - "Cannot sniff header of file type {}", FileTypeUtils::toString(readerConfig.fileType))); - } -} - } // namespace binder } // namespace kuzu diff --git a/src/binder/bind/copy/bind_copy_rdf_graph.cpp b/src/binder/bind/copy/bind_copy_rdf_graph.cpp index 46a17c9e2f8..2592cb537bf 100644 --- a/src/binder/bind/copy/bind_copy_rdf_graph.cpp +++ b/src/binder/bind/copy/bind_copy_rdf_graph.cpp @@ -18,30 +18,33 @@ std::unique_ptr Binder::bindCopyRdfNodeFrom(function::TableFunct std::unique_ptr readerConfig, TableSchema* tableSchema) { bool containsSerial; auto stringType = LogicalType{LogicalTypeID::STRING}; - auto nodeID = createVariable(std::string(Property::INTERNAL_ID_NAME), LogicalTypeID::INT64); - expression_vector columns; - auto columnName = std::string(InternalKeyword::ANONYMOUS); - readerConfig->columnNames.push_back(columnName); + std::vector columnNames; + std::vector> columnTypes; + columnNames.push_back(std::string(InternalKeyword::ANONYMOUS)); if (tableSchema->tableName.ends_with(rdf::RESOURCE_TABLE_SUFFIX)) { containsSerial = false; - readerConfig->columnTypes.push_back(stringType.copy()); + columnTypes.push_back(stringType.copy()); readerConfig->rdfReaderConfig = std::make_unique(RdfReaderMode::RESOURCE, nullptr /* index */); - columns.push_back(createVariable(columnName, stringType)); } else { KU_ASSERT(tableSchema->tableName.ends_with(rdf::LITERAL_TABLE_SUFFIX)); containsSerial = true; - readerConfig->columnTypes.push_back(RdfVariantType::getType()); + columnTypes.push_back(RdfVariantType::getType()); readerConfig->rdfReaderConfig = std::make_unique(RdfReaderMode::LITERAL, nullptr /* index */); - columns.push_back(createVariable(columnName, *RdfVariantType::getType())); } - auto tableFuncBindInput = - std::make_unique(*readerConfig, memoryManager); + auto bindInput = std::make_unique( + memoryManager, *readerConfig, columnNames, std::move(columnTypes)); auto bindData = - copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion()); + copyFunc->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + expression_vector columns; + for (auto i = 0u; i < bindData->columnTypes.size(); i++) { + columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); + } + auto offset = expressionBinder.createVariableExpression( + LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS); auto boundFileScanInfo = std::make_unique( - copyFunc, std::move(bindData), columns, std::move(nodeID), TableType::NODE); + copyFunc, std::move(bindData), columns, std::move(offset), TableType::NODE); auto boundCopyFromInfo = std::make_unique(tableSchema, std::move(boundFileScanInfo), containsSerial, std::move(columns), nullptr /* extraInfo */); return std::make_unique(std::move(boundCopyFromInfo)); @@ -50,13 +53,12 @@ std::unique_ptr Binder::bindCopyRdfNodeFrom(function::TableFunct std::unique_ptr Binder::bindCopyRdfRelFrom(function::TableFunction* copyFunc, std::unique_ptr readerConfig, TableSchema* tableSchema) { auto containsSerial = false; - auto offsetType = std::make_unique(LogicalTypeID::INT64); - expression_vector columns; + std::vector columnNames; + std::vector> columnTypes; for (auto i = 0u; i < 3; ++i) { auto columnName = std::string(InternalKeyword::ANONYMOUS) + std::to_string(i); - readerConfig->columnNames.push_back(columnName); - readerConfig->columnTypes.push_back(offsetType->copy()); - columns.push_back(createVariable(columnName, *offsetType)); + columnNames.push_back(columnName); + columnTypes.push_back(std::make_unique(LogicalTypeID::INT64)); } auto relTableSchema = reinterpret_cast(tableSchema); auto resourceTableID = relTableSchema->getSrcTableID(); @@ -68,15 +70,20 @@ std::unique_ptr Binder::bindCopyRdfRelFrom(function::TableFuncti readerConfig->rdfReaderConfig = std::make_unique(RdfReaderMode::LITERAL_TRIPLE, index); } - auto relID = createVariable(std::string(Property::INTERNAL_ID_NAME), LogicalTypeID::INT64); - auto tableFuncBindInput = - std::make_unique(*readerConfig, memoryManager); + auto bindInput = std::make_unique( + memoryManager, *readerConfig, columnNames, std::move(columnTypes)); auto bindData = - copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion()); + copyFunc->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion()); + expression_vector columns; + for (auto i = 0u; i < bindData->columnTypes.size(); i++) { + columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i])); + } + auto offset = expressionBinder.createVariableExpression( + LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS); auto boundFileScanInfo = std::make_unique( - copyFunc, std::move(bindData), columns, relID, TableType::REL); + copyFunc, std::move(bindData), columns, offset, TableType::REL); auto extraInfo = std::make_unique(columns[0], columns[2]); - expression_vector columnsToCopy = {columns[0], columns[2], relID, columns[1]}; + expression_vector columnsToCopy = {columns[0], columns[2], offset, columns[1]}; auto boundCopyFromInfo = std::make_unique(tableSchema, std::move(boundFileScanInfo), containsSerial, std::move(columnsToCopy), std::move(extraInfo)); diff --git a/src/binder/binder.cpp b/src/binder/binder.cpp index 4f998bcba6f..fe9ba50ada2 100644 --- a/src/binder/binder.cpp +++ b/src/binder/binder.cpp @@ -208,29 +208,34 @@ void Binder::restoreScope(std::unique_ptr prevVariableScope) { scope = std::move(prevVariableScope); } -function::TableFunction* Binder::getScanFunction( - common::FileType fileType, std::vector inputTypes, bool isParallel) { +function::TableFunction* Binder::getScanFunction(common::FileType fileType, bool isParallel) { + function::Function* func; + auto stringType = LogicalType(LogicalTypeID::STRING); + std::vector inputTypes; + inputTypes.push_back(&stringType); switch (fileType) { - case common::FileType::PARQUET: - return reinterpret_cast( - catalog.getBuiltInFunctions()->matchScalarFunction( - READ_PARQUET_FUNC_NAME, std::move(inputTypes))); - case common::FileType::NPY: - return reinterpret_cast( - catalog.getBuiltInFunctions()->matchScalarFunction( - READ_NPY_FUNC_NAME, std::move(inputTypes))); - case common::FileType::CSV: - return reinterpret_cast( - catalog.getBuiltInFunctions()->matchScalarFunction( - isParallel ? READ_CSV_PARALLEL_FUNC_NAME : READ_CSV_SERIAL_FUNC_NAME, - std::move(inputTypes))); - case common::FileType::TURTLE: - return reinterpret_cast( - catalog.getBuiltInFunctions()->matchScalarFunction( - READ_RDF_FUNC_NAME, std::move(inputTypes))); + case common::FileType::PARQUET: { + func = + catalog.getBuiltInFunctions()->matchScalarFunction(READ_PARQUET_FUNC_NAME, inputTypes); + } break; + + case common::FileType::NPY: { + func = catalog.getBuiltInFunctions()->matchScalarFunction( + READ_NPY_FUNC_NAME, std::move(inputTypes)); + } break; + case common::FileType::CSV: { + func = catalog.getBuiltInFunctions()->matchScalarFunction( + isParallel ? READ_CSV_PARALLEL_FUNC_NAME : READ_CSV_SERIAL_FUNC_NAME, + std::move(inputTypes)); + } break; + case common::FileType::TURTLE: { + func = catalog.getBuiltInFunctions()->matchScalarFunction( + READ_RDF_FUNC_NAME, std::move(inputTypes)); + } break; default: KU_UNREACHABLE; } + return reinterpret_cast(func); } } // namespace binder diff --git a/src/common/copier_config/copier_config.cpp b/src/common/copier_config/copier_config.cpp index 48206e342e3..a21700fa600 100644 --- a/src/common/copier_config/copier_config.cpp +++ b/src/common/copier_config/copier_config.cpp @@ -1,5 +1,7 @@ #include "common/copier_config/copier_config.h" +#include + #include "common/assert.h" #include "common/exception/copy.h" diff --git a/src/include/binder/binder.h b/src/include/binder/binder.h index b9bb9681b42..7851c0f5c52 100644 --- a/src/include/binder/binder.h +++ b/src/include/binder/binder.h @@ -128,10 +128,13 @@ class Binder { std::unique_ptr readerConfig, catalog::TableSchema* tableSchema); std::unique_ptr bindCopyRdfRelFrom(function::TableFunction* copyFunc, std::unique_ptr readerConfig, catalog::TableSchema* tableSchema); - expression_vector bindExpectedNodeFileColumns( - catalog::TableSchema* tableSchema, common::ReaderConfig& readerConfig); - expression_vector bindExpectedRelFileColumns( - catalog::TableSchema* tableSchema, common::ReaderConfig& readerConfig); + void bindExpectedNodeColumns(catalog::TableSchema* tableSchema, + std::vector& columnNames, + std::vector>& columnTypes); + void bindExpectedRelColumns(catalog::TableSchema* tableSchema, + std::vector& columnNames, + std::vector>& columnTypes); + std::unique_ptr bindCopyToClause(const parser::Statement& statement); /*** bind file scan ***/ @@ -169,18 +172,6 @@ class Binder { const parser::ReadingClause& readingClause); std::unique_ptr bindInQueryCall(const parser::ReadingClause& readingClause); std::unique_ptr bindLoadFrom(const parser::ReadingClause& readingClause); - expression_vector createColumnExpressions(common::ReaderConfig& readerConfig, - const std::vector& columnNames, - const std::vector>& columnTypes); - void sniffFiles(const common::ReaderConfig& readerConfig, std::vector& columnNames, - std::vector>& columnTypes); - void sniffFile(const common::ReaderConfig& readerConfig, uint32_t fileIdx, - std::vector& columnNames, - std::vector>& columnTypes); - static void validateNumColumns(uint32_t expectedNumber, uint32_t detectedNumber); - static void validateColumnTypes(const std::vector& expectedColumnNames, - const std::vector>& expectedColumnTypes, - const std::vector>& detectedColumnTypes); /*** bind updating clause ***/ // TODO(Guodong/Xiyang): Is update clause an accurate name? How about (data)modificationClause? @@ -292,8 +283,7 @@ class Binder { std::unique_ptr saveScope(); void restoreScope(std::unique_ptr prevVariableScope); - function::TableFunction* getScanFunction( - common::FileType fileType, std::vector inputTypes, bool isParallel); + function::TableFunction* getScanFunction(common::FileType fileType, bool isParallel); private: const catalog::Catalog& catalog; diff --git a/src/include/binder/copy/bound_copy_to.h b/src/include/binder/copy/bound_copy_to.h index a3746506094..e1fecef84b4 100644 --- a/src/include/binder/copy/bound_copy_to.h +++ b/src/include/binder/copy/bound_copy_to.h @@ -8,17 +8,28 @@ namespace binder { class BoundCopyTo : public BoundStatement { public: - BoundCopyTo(std::unique_ptr config, + BoundCopyTo(std::string filePath, common::FileType fileType, + std::vector columnNames, + std::vector> columnTypes, std::unique_ptr regularQuery) : BoundStatement{common::StatementType::COPY_TO, BoundStatementResult::createEmptyResult()}, - config{std::move(config)}, regularQuery{std::move(regularQuery)} {} + filePath{std::move(filePath)}, fileType{fileType}, columnNames{std::move(columnNames)}, + columnTypes{std::move(columnTypes)}, regularQuery{std::move(regularQuery)} {} - inline common::ReaderConfig* getConfig() const { return config.get(); } + inline std::string getFilePath() const { return filePath; } + inline common::FileType getFileType() const { return fileType; } + inline std::vector getColumnNames() const { return columnNames; } + inline const std::vector>& getColumnTypesRef() const { + return columnTypes; + } inline BoundRegularQuery* getRegularQuery() const { return regularQuery.get(); } private: - std::unique_ptr config; + std::string filePath; + common::FileType fileType; + std::vector columnNames; + std::vector> columnTypes; std::unique_ptr regularQuery; }; diff --git a/src/include/common/copier_config/copier_config.h b/src/include/common/copier_config/copier_config.h index 43aa6e84651..afc6af51a5f 100644 --- a/src/include/common/copier_config/copier_config.h +++ b/src/include/common/copier_config/copier_config.h @@ -4,7 +4,6 @@ #include #include "common/constants.h" -#include "common/types/types.h" #include "rdf_config.h" namespace kuzu { @@ -43,8 +42,6 @@ struct FileTypeUtils { struct ReaderConfig { FileType fileType = FileType::UNKNOWN; std::vector filePaths; - std::vector columnNames; - std::vector> columnTypes; std::unique_ptr csvReaderConfig = nullptr; // NOTE: Do not try to refactor this with CSVReaderConfig. We might remove this. std::unique_ptr rdfReaderConfig; @@ -53,15 +50,8 @@ struct ReaderConfig { std::unique_ptr csvReaderConfig) : fileType{fileType}, filePaths{std::move(filePaths)}, csvReaderConfig{ std::move(csvReaderConfig)} {} - ReaderConfig(FileType fileType, std::vector filePaths, - std::vector columnNames, - std::vector> columnTypes) - : fileType{fileType}, filePaths{std::move(filePaths)}, columnNames{std::move(columnNames)}, - columnTypes{std::move(columnTypes)} {} - ReaderConfig(const ReaderConfig& other) - : fileType{other.fileType}, filePaths{other.filePaths}, columnNames{other.columnNames}, - columnTypes{LogicalType::copy(other.columnTypes)} { + ReaderConfig(const ReaderConfig& other) : fileType{other.fileType}, filePaths{other.filePaths} { if (other.csvReaderConfig != nullptr) { this->csvReaderConfig = other.csvReaderConfig->copy(); } @@ -70,12 +60,7 @@ struct ReaderConfig { } } - inline bool parallelRead() const { - return (fileType != FileType::CSV || csvReaderConfig->parallel) && - fileType != FileType::TURTLE; - } inline uint32_t getNumFiles() const { return filePaths.size(); } - inline uint32_t getNumColumns() const { return columnNames.size(); } inline std::unique_ptr copy() const { return std::make_unique(*this); diff --git a/src/include/function/table_functions/bind_data.h b/src/include/function/table_functions/bind_data.h index 430e6c830e4..8a36339a18c 100644 --- a/src/include/function/table_functions/bind_data.h +++ b/src/include/function/table_functions/bind_data.h @@ -8,12 +8,12 @@ namespace kuzu { namespace function { struct TableFuncBindData { - std::vector> returnTypes; - std::vector returnColumnNames; + std::vector> columnTypes; + std::vector columnNames; - TableFuncBindData(std::vector> returnTypes, - std::vector returnColumnNames) - : returnTypes{std::move(returnTypes)}, returnColumnNames{std::move(returnColumnNames)} {} + TableFuncBindData(std::vector> columnTypes, + std::vector columnNames) + : columnTypes{std::move(columnTypes)}, columnNames{std::move(columnNames)} {} virtual ~TableFuncBindData() = default; @@ -21,18 +21,18 @@ struct TableFuncBindData { }; struct ScanBindData : public TableFuncBindData { - common::ReaderConfig config; storage::MemoryManager* mm; + common::ReaderConfig config; - ScanBindData(std::vector> returnTypes, - std::vector returnColumnNames, const common::ReaderConfig config, - storage::MemoryManager* mm) - : TableFuncBindData{std::move(returnTypes), std::move(returnColumnNames)}, config{config}, - mm{mm} {} + ScanBindData(std::vector> columnTypes, + std::vector columnNames, storage::MemoryManager* mm, + const common::ReaderConfig& config) + : TableFuncBindData{std::move(columnTypes), std::move(columnNames)}, mm{mm}, config{ + config} {} std::unique_ptr copy() override { return std::make_unique( - common::LogicalType::copy(returnTypes), returnColumnNames, config, mm); + common::LogicalType::copy(columnTypes), columnNames, mm, config); } }; diff --git a/src/include/function/table_functions/bind_input.h b/src/include/function/table_functions/bind_input.h index 38705249016..1007fbfd265 100644 --- a/src/include/function/table_functions/bind_input.h +++ b/src/include/function/table_functions/bind_input.h @@ -13,19 +13,22 @@ struct TableFuncBindInput { std::vector> inputs; TableFuncBindInput() = default; - explicit TableFuncBindInput(std::vector> inputs) - : inputs{std::move(inputs)} {} virtual ~TableFuncBindInput() = default; }; struct ScanTableFuncBindInput final : public TableFuncBindInput { - const common::ReaderConfig config; storage::MemoryManager* mm; - - ScanTableFuncBindInput(const common::ReaderConfig config, storage::MemoryManager* mm) - : TableFuncBindInput{}, config{config}, mm{mm} { - inputs.push_back( - std::make_unique(common::Value::createValue(this->config.filePaths[0]))); + common::ReaderConfig config; + std::vector expectedColumnNames; + std::vector> expectedColumnTypes; + + ScanTableFuncBindInput(storage::MemoryManager* mm, const common::ReaderConfig& config, + std::vector expectedColumnNames, + std::vector> expectedColumnTypes) + : TableFuncBindInput{}, mm{mm}, config{config}, + expectedColumnNames{std::move(expectedColumnNames)}, expectedColumnTypes{ + std::move(expectedColumnTypes)} { + inputs.push_back(common::Value::createValue(config.filePaths[0]).copy()); } }; diff --git a/src/include/function/table_functions/call_functions.h b/src/include/function/table_functions/call_functions.h index 71537722a34..c141cd13a2c 100644 --- a/src/include/function/table_functions/call_functions.h +++ b/src/include/function/table_functions/call_functions.h @@ -44,7 +44,7 @@ struct CallTableFuncBindData : public TableFuncBindData { inline std::unique_ptr copy() override { return std::make_unique( - common::LogicalType::copy(returnTypes), returnColumnNames, maxOffset); + common::LogicalType::copy(columnTypes), columnNames, maxOffset); } }; @@ -63,7 +63,7 @@ struct CurrentSettingBindData : public CallTableFuncBindData { inline std::unique_ptr copy() override { return std::make_unique( - result, common::LogicalType::copy(returnTypes), returnColumnNames, maxOffset); + result, common::LogicalType::copy(columnTypes), columnNames, maxOffset); } }; @@ -96,7 +96,7 @@ struct ShowTablesBindData : public CallTableFuncBindData { inline std::unique_ptr copy() override { return std::make_unique( - tables, common::LogicalType::copy(returnTypes), returnColumnNames, maxOffset); + tables, common::LogicalType::copy(columnTypes), columnNames, maxOffset); } }; @@ -120,7 +120,7 @@ struct TableInfoBindData : public CallTableFuncBindData { inline std::unique_ptr copy() override { return std::make_unique( - tableSchema, common::LogicalType::copy(returnTypes), returnColumnNames, maxOffset); + tableSchema, common::LogicalType::copy(columnTypes), columnNames, maxOffset); } }; @@ -143,8 +143,8 @@ struct ShowConnectionBindData : public TableInfoBindData { std::move(returnColumnNames), maxOffset} {} inline std::unique_ptr copy() override { - return std::make_unique(catalog, tableSchema, - common::LogicalType::copy(returnTypes), returnColumnNames, maxOffset); + return std::make_unique( + catalog, tableSchema, common::LogicalType::copy(columnTypes), columnNames, maxOffset); } }; diff --git a/src/include/planner/operator/persistent/logical_copy_to.h b/src/include/planner/operator/persistent/logical_copy_to.h index 9b80a4695da..b0beeea33a3 100644 --- a/src/include/planner/operator/persistent/logical_copy_to.h +++ b/src/include/planner/operator/persistent/logical_copy_to.h @@ -8,28 +8,39 @@ namespace planner { class LogicalCopyTo : public LogicalOperator { public: - LogicalCopyTo( - std::shared_ptr child, std::unique_ptr config) - : LogicalOperator{LogicalOperatorType::COPY_TO, std::move(child)}, config{ - std::move(config)} {} + LogicalCopyTo(std::string filePath, common::FileType fileType, + std::vector columnNames, + std::vector> columnTypes, + std::shared_ptr child) + : LogicalOperator{LogicalOperatorType::COPY_TO, std::move(child)}, filePath{std::move( + filePath)}, + fileType{fileType}, columnNames{std::move(columnNames)}, columnTypes{ + std::move(columnTypes)} {} f_group_pos_set getGroupsPosToFlatten(); inline std::string getExpressionsForPrinting() const override { return std::string{}; } - inline common::ReaderConfig* getConfig() const { return config.get(); } - void computeFactorizedSchema() override; - void computeFlatSchema() override; + inline std::string getFilePath() const { return filePath; } + inline common::FileType getFileType() const { return fileType; } + inline std::vector getColumnNames() const { return columnNames; } + inline const std::vector>& getColumnTypesRef() const { + return columnTypes; + } + inline std::unique_ptr copy() override { - return make_unique(children[0]->copy(), config->copy()); + return make_unique(filePath, fileType, columnNames, + common::LogicalType::copy(columnTypes), children[0]->copy()); } private: - std::shared_ptr outputExpression; - std::unique_ptr config; + std::string filePath; + common::FileType fileType; + std::vector columnNames; + std::vector> columnTypes; }; } // namespace planner diff --git a/src/include/processor/operator/persistent/reader/csv/base_csv_reader.h b/src/include/processor/operator/persistent/reader/csv/base_csv_reader.h index e67a9f8fa7e..97ec7a05654 100644 --- a/src/include/processor/operator/persistent/reader/csv/base_csv_reader.h +++ b/src/include/processor/operator/persistent/reader/csv/base_csv_reader.h @@ -15,7 +15,8 @@ class BaseCSVReader { friend class ParsingDriver; public: - BaseCSVReader(const std::string& filePath, const common::ReaderConfig& readerConfig); + BaseCSVReader( + const std::string& filePath, const common::ReaderConfig& readerConfig, uint64_t numColumns); virtual ~BaseCSVReader(); @@ -64,8 +65,7 @@ class BaseCSVReader { std::string filePath; common::CSVReaderConfig csvReaderConfig; - uint64_t expectedNumColumns; - uint64_t numColumnsDetected; + uint64_t numColumns; int fd; common::block_idx_t currentBlockIdx; diff --git a/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h b/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h index f01ed367f64..7ec6e2ff0a5 100644 --- a/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h +++ b/src/include/processor/operator/persistent/reader/csv/parallel_csv_reader.h @@ -16,7 +16,8 @@ class ParallelCSVReader final : public BaseCSVReader { friend class ParallelParsingDriver; public: - ParallelCSVReader(const std::string& filePath, const common::ReaderConfig& readerConfig); + ParallelCSVReader( + const std::string& filePath, const common::ReaderConfig& readerConfig, uint64_t numColumns); bool hasMoreToRead() const; uint64_t parseBlock(common::block_idx_t blockIdx, common::DataChunk& resultChunk) override; @@ -36,9 +37,13 @@ struct ParallelCSVLocalState final : public function::TableFuncLocalState { }; struct ParallelCSVScanSharedState final : public function::ScanSharedTableFuncState { - explicit ParallelCSVScanSharedState(const common::ReaderConfig readerConfig, uint64_t numRows); + explicit ParallelCSVScanSharedState( + const common::ReaderConfig readerConfig, uint64_t numRows, uint64_t numColumns) + : ScanSharedTableFuncState{std::move(readerConfig), numRows}, numColumns{numColumns} {} void setFileComplete(uint64_t completedFileIdx); + + uint64_t numColumns; }; struct ParallelCSVScan { diff --git a/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h b/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h index 2a04515e146..f1b74ad334c 100644 --- a/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h +++ b/src/include/processor/operator/persistent/reader/csv/serial_csv_reader.h @@ -12,7 +12,8 @@ namespace processor { //! Serial CSV reader is a class that reads values from a stream in a single thread. class SerialCSVReader final : public BaseCSVReader { public: - SerialCSVReader(const std::string& filePath, const common::ReaderConfig& readerConfig); + SerialCSVReader( + const std::string& filePath, const common::ReaderConfig& readerConfig, uint64_t numColumns); //! Sniffs CSV dialect and determines skip rows, header row, column types and column names std::vector> sniffCSV(); @@ -23,8 +24,9 @@ class SerialCSVReader final : public BaseCSVReader { }; struct SerialCSVScanSharedState final : public function::ScanSharedTableFuncState { - explicit SerialCSVScanSharedState(const common::ReaderConfig readerConfig, uint64_t numRows) - : ScanSharedTableFuncState{std::move(readerConfig), numRows} { + explicit SerialCSVScanSharedState( + const common::ReaderConfig readerConfig, uint64_t numRows, uint64_t numColumns) + : ScanSharedTableFuncState{std::move(readerConfig), numRows}, numColumns{numColumns} { initReader(); } @@ -33,6 +35,7 @@ struct SerialCSVScanSharedState final : public function::ScanSharedTableFuncStat void initReader(); std::unique_ptr reader; + uint64_t numColumns; }; struct SerialCSVScan { @@ -48,6 +51,13 @@ struct SerialCSVScan { static std::unique_ptr initLocalState( function::TableFunctionInitInput& /*input*/, function::TableFuncSharedState* /*state*/); + + static void bindColumns(const common::ReaderConfig& readerConfig, + std::vector& columnNames, + std::vector>& columnTypes); + static void bindColumns(const common::ReaderConfig& readerConfig, uint32_t fileIdx, + std::vector& columnNames, + std::vector>& columnTypes); }; } // namespace processor diff --git a/src/include/processor/operator/persistent/reader/npy/npy_reader.h b/src/include/processor/operator/persistent/reader/npy/npy_reader.h index 7829dd0fc75..be26330f765 100644 --- a/src/include/processor/operator/persistent/reader/npy/npy_reader.h +++ b/src/include/processor/operator/persistent/reader/npy/npy_reader.h @@ -78,6 +78,13 @@ struct NpyScanFunction { static std::unique_ptr initLocalState( function::TableFunctionInitInput& /*input*/, function::TableFuncSharedState* /*state*/); + + static void bindColumns(const common::ReaderConfig& readerConfig, + std::vector& columnNames, + std::vector>& columnTypes); + static void bindColumns(const common::ReaderConfig& readerConfig, uint32_t fileIdx, + std::vector& columnNames, + std::vector>& columnTypes); }; } // namespace processor diff --git a/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h b/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h index 61a42a9b0ad..e45418ae820 100644 --- a/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h +++ b/src/include/processor/operator/persistent/reader/parquet/parquet_reader.h @@ -120,6 +120,13 @@ struct ParquetScanFunction { static std::unique_ptr initLocalState( function::TableFunctionInitInput& input, function::TableFuncSharedState* state); + + static void bindColumns(const common::ReaderConfig& readerConfig, storage::MemoryManager* mm, + std::vector& columnNames, + std::vector>& columnTypes); + static void bindColumns(const common::ReaderConfig& readerConfig, uint32_t fileIdx, + storage::MemoryManager* mm, std::vector& columnNames, + std::vector>& columnTypes); }; } // namespace processor diff --git a/src/include/processor/operator/persistent/reader/reader_bind_utils.h b/src/include/processor/operator/persistent/reader/reader_bind_utils.h new file mode 100644 index 00000000000..b10d847a313 --- /dev/null +++ b/src/include/processor/operator/persistent/reader/reader_bind_utils.h @@ -0,0 +1,22 @@ +#pragma once + +#include "common/types/types.h" + +namespace kuzu { +namespace processor { + +struct ReaderBindUtils { + static void validateNumColumns(uint32_t expectedNumber, uint32_t detectedNumber); + static void validateColumnTypes(const std::vector& columnNames, + const std::vector>& expectedColumnTypes, + const std::vector>& detectedColumnTypes); + static void resolveColumns(const std::vector& expectedColumnNames, + const std::vector& detectedColumnNames, + std::vector& resultColumnNames, + const std::vector>& expectedColumnTypes, + const std::vector>& detectedColumnTypes, + std::vector>& resultColumnTypes); +}; + +} // namespace processor +} // namespace kuzu diff --git a/src/planner/operator/persistent/logical_copy_to.cpp b/src/planner/operator/persistent/logical_copy_to.cpp index c7c4ad7a53c..fff11142008 100644 --- a/src/planner/operator/persistent/logical_copy_to.cpp +++ b/src/planner/operator/persistent/logical_copy_to.cpp @@ -15,12 +15,7 @@ void LogicalCopyTo::computeFlatSchema() { f_group_pos_set LogicalCopyTo::getGroupsPosToFlatten() { auto childSchema = children[0]->getSchema(); - f_group_pos_set dependentGroupsPos; - for (auto& expression : childSchema->getExpressionsInScope()) { - for (auto& grouPos : childSchema->getDependentGroupsPos(expression)) { - dependentGroupsPos.insert(grouPos); - } - } + auto dependentGroupsPos = childSchema->getGroupsPosInScope(); return factorization::FlattenAllButOne::getGroupsPosToFlatten(dependentGroupsPos, childSchema); } diff --git a/src/planner/plan/plan_copy.cpp b/src/planner/plan/plan_copy.cpp index bb5387b620d..03656f15054 100644 --- a/src/planner/plan/plan_copy.cpp +++ b/src/planner/plan/plan_copy.cpp @@ -109,13 +109,14 @@ std::unique_ptr Planner::planCopyFrom(const BoundStatement& stateme std::unique_ptr Planner::planCopyTo(const Catalog& catalog, const NodesStoreStatsAndDeletedIDs& nodesStatistics, const RelsStoreStats& relsStatistics, const BoundStatement& statement) { - auto& copyClause = reinterpret_cast(statement); - auto regularQuery = copyClause.getRegularQuery(); + auto& boundCopy = reinterpret_cast(statement); + auto regularQuery = boundCopy.getRegularQuery(); KU_ASSERT(regularQuery->getStatementType() == StatementType::QUERY); auto plan = QueryPlanner(catalog, nodesStatistics, relsStatistics).getBestPlan(*regularQuery); - auto logicalCopyTo = - make_shared(plan->getLastOperator(), copyClause.getConfig()->copy()); - plan->setLastOperator(std::move(logicalCopyTo)); + auto copyTo = make_shared(boundCopy.getFilePath(), boundCopy.getFileType(), + boundCopy.getColumnNames(), LogicalType::copy(boundCopy.getColumnTypesRef()), + plan->getLastOperator()); + plan->setLastOperator(std::move(copyTo)); return plan; } diff --git a/src/processor/map/map_copy_to.cpp b/src/processor/map/map_copy_to.cpp index fab6bedf37a..b7e08be3e13 100644 --- a/src/processor/map/map_copy_to.cpp +++ b/src/processor/map/map_copy_to.cpp @@ -10,10 +10,11 @@ using namespace kuzu::storage; namespace kuzu { namespace processor { -std::unique_ptr getCopyToInfo(Schema* childSchema, ReaderConfig* config, +std::unique_ptr getCopyToInfo(Schema* childSchema, const std::string& filePath, + common::FileType fileType, std::vector columnNames, std::vector> columnsTypes, std::vector vectorsToCopyPos, std::vector isFlat) { - switch (config->fileType) { + switch (fileType) { case FileType::PARQUET: { auto copyToSchema = std::make_unique(); auto copyToExpressions = childSchema->getExpressionsInScope(); @@ -31,11 +32,11 @@ std::unique_ptr getCopyToInfo(Schema* childSchema, ReaderConfig* con copyToSchema->appendColumn(std::move(columnSchema)); } return std::make_unique(std::move(copyToSchema), std::move(columnsTypes), - config->columnNames, vectorsToCopyPos, config->filePaths[0]); + columnNames, vectorsToCopyPos, filePath); } case FileType::CSV: { return std::make_unique( - config->columnNames, vectorsToCopyPos, config->filePaths[0], std::move(isFlat)); + columnNames, vectorsToCopyPos, filePath, std::move(isFlat)); } // LCOV_EXCL_START default: @@ -44,8 +45,8 @@ std::unique_ptr getCopyToInfo(Schema* childSchema, ReaderConfig* con } } -std::shared_ptr getCopyToSharedState(ReaderConfig* config) { - switch (config->fileType) { +static std::shared_ptr getCopyToSharedState(FileType fileType) { + switch (fileType) { case FileType::CSV: return std::make_shared(); case FileType::PARQUET: @@ -59,13 +60,8 @@ std::shared_ptr getCopyToSharedState(ReaderConfig* config) { std::unique_ptr PlanMapper::mapCopyTo(LogicalOperator* logicalOperator) { auto copy = (LogicalCopyTo*)logicalOperator; - auto config = copy->getConfig(); - std::vector> columnsTypes; - std::vector columnNames; - columnsTypes.reserve(config->columnTypes.size()); - for (auto& type : config->columnTypes) { - columnsTypes.push_back(type->copy()); - } + auto columnNames = copy->getColumnNames(); + auto columnTypes = LogicalType::copy(copy->getColumnTypesRef()); auto childSchema = logicalOperator->getChild(0)->getSchema(); auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); std::vector vectorsToCopyPos; @@ -74,11 +70,12 @@ std::unique_ptr PlanMapper::mapCopyTo(LogicalOperator* logical vectorsToCopyPos.emplace_back(childSchema->getExpressionPos(*expression)); isFlat.push_back(childSchema->getGroup(expression)->isFlat()); } - std::unique_ptr copyToInfo = getCopyToInfo(childSchema, config, - std::move(columnsTypes), std::move(vectorsToCopyPos), std::move(isFlat)); - auto sharedState = getCopyToSharedState(config); + std::unique_ptr copyToInfo = + getCopyToInfo(childSchema, copy->getFilePath(), copy->getFileType(), std::move(columnNames), + std::move(columnTypes), std::move(vectorsToCopyPos), std::move(isFlat)); + auto sharedState = getCopyToSharedState(copy->getFileType()); std::unique_ptr copyTo; - if (copy->getConfig()->fileType == common::FileType::PARQUET) { + if (copy->getFileType() == common::FileType::PARQUET) { copyTo = std::make_unique(std::make_unique(childSchema), std::move(copyToInfo), std::move(sharedState), std::move(prevOperator), getOperatorID(), copy->getExpressionsForPrinting()); diff --git a/src/processor/operator/persistent/CMakeLists.txt b/src/processor/operator/persistent/CMakeLists.txt index 14270bec689..cc21cb429b9 100644 --- a/src/processor/operator/persistent/CMakeLists.txt +++ b/src/processor/operator/persistent/CMakeLists.txt @@ -1,7 +1,4 @@ -add_subdirectory(reader/parquet) -add_subdirectory(reader/csv) -add_subdirectory(reader/rdf) -add_subdirectory(reader/npy) +add_subdirectory(reader) add_subdirectory(writer/parquet) add_library(kuzu_processor_operator_persistent diff --git a/src/processor/operator/persistent/reader/CMakeLists.txt b/src/processor/operator/persistent/reader/CMakeLists.txt new file mode 100644 index 00000000000..c0a22950488 --- /dev/null +++ b/src/processor/operator/persistent/reader/CMakeLists.txt @@ -0,0 +1,12 @@ +add_subdirectory(csv) +add_subdirectory(npy) +add_subdirectory(parquet) +add_subdirectory(rdf) + +add_library(kuzu_processor_operator_persistent_reader + OBJECT + reader_bind_utils.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/src/processor/operator/persistent/reader/csv/base_csv_reader.cpp b/src/processor/operator/persistent/reader/csv/base_csv_reader.cpp index b12964fd35b..01b86db3ed9 100644 --- a/src/processor/operator/persistent/reader/csv/base_csv_reader.cpp +++ b/src/processor/operator/persistent/reader/csv/base_csv_reader.cpp @@ -19,10 +19,10 @@ using namespace kuzu::common; namespace kuzu { namespace processor { -BaseCSVReader::BaseCSVReader(const std::string& filePath, const common::ReaderConfig& readerConfig) - : filePath{filePath}, csvReaderConfig{*readerConfig.csvReaderConfig}, - expectedNumColumns(readerConfig.getNumColumns()), numColumnsDetected(-1), fd(-1), - buffer(nullptr), bufferSize(0), position(0), rowEmpty(false) { +BaseCSVReader::BaseCSVReader( + const std::string& filePath, const common::ReaderConfig& readerConfig, uint64_t numColumns) + : filePath{filePath}, csvReaderConfig{*readerConfig.csvReaderConfig}, numColumns(numColumns), + fd(-1), buffer(nullptr), bufferSize(0), position(0), rowEmpty(false) { // TODO(Ziyi): should we wrap this fd using kuzu file handler? fd = open(filePath.c_str(), O_RDONLY #ifdef _WIN32 @@ -54,7 +54,7 @@ uint64_t BaseCSVReader::countRows() { } // If the number of columns is 1, every line start indicates a row. - if (expectedNumColumns == 1) { + if (numColumns == 1) { rows++; } @@ -66,7 +66,7 @@ uint64_t BaseCSVReader::countRows() { goto line_start; } else { // If we have more than one column, every non-empty line is a row. - if (expectedNumColumns != 1) { + if (numColumns != 1) { rows++; } goto normal; diff --git a/src/processor/operator/persistent/reader/csv/driver.cpp b/src/processor/operator/persistent/reader/csv/driver.cpp index 7daa5e09430..18e71891ade 100644 --- a/src/processor/operator/persistent/reader/csv/driver.cpp +++ b/src/processor/operator/persistent/reader/csv/driver.cpp @@ -26,14 +26,14 @@ void ParsingDriver::addValue( rowEmpty = false; } BaseCSVReader* reader = getReader(); - if (columnIdx == reader->expectedNumColumns && length == 0) { + if (columnIdx == reader->numColumns && length == 0) { // skip a single trailing delimiter in last columnIdx return; } - if (columnIdx >= reader->expectedNumColumns) { + if (columnIdx >= reader->numColumns) { throw CopyException( stringFormat("Error in file {}, on line {}: expected {} values per row, but got more.", - reader->filePath, reader->getLineNumber(), reader->expectedNumColumns)); + reader->filePath, reader->getLineNumber(), reader->numColumns)); } function::CastString::copyStringToVector( @@ -44,16 +44,16 @@ bool ParsingDriver::addRow(uint64_t /*rowNum*/, common::column_id_t columnCount) BaseCSVReader* reader = getReader(); if (rowEmpty) { rowEmpty = false; - if (reader->expectedNumColumns != 1) { + if (reader->numColumns != 1) { return false; } // Otherwise, treat it as null. } - if (columnCount < reader->expectedNumColumns) { + if (columnCount < reader->numColumns) { // Column number mismatch. - throw CopyException(stringFormat( - "Error in file {} on line {}: expected {} values per row, but got {}", reader->filePath, - reader->getLineNumber(), reader->expectedNumColumns, columnCount)); + throw CopyException( + stringFormat("Error in file {} on line {}: expected {} values per row, but got {}", + reader->filePath, reader->getLineNumber(), reader->numColumns, columnCount)); } return true; } diff --git a/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp b/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp index 79452340fc0..0d403d2bfef 100644 --- a/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp +++ b/src/processor/operator/persistent/reader/csv/parallel_csv_reader.cpp @@ -1,5 +1,8 @@ #include "processor/operator/persistent/reader/csv/parallel_csv_reader.h" +#include "processor/operator/persistent/reader/csv/serial_csv_reader.h" +#include "processor/operator/persistent/reader/reader_bind_utils.h" + #if defined(_WIN32) #include #else @@ -19,8 +22,8 @@ namespace kuzu { namespace processor { ParallelCSVReader::ParallelCSVReader( - const std::string& filePath, const common::ReaderConfig& readerConfig) - : BaseCSVReader{filePath, readerConfig} {} + const std::string& filePath, const common::ReaderConfig& readerConfig, uint64_t numColumns) + : BaseCSVReader{filePath, readerConfig, numColumns} {} bool ParallelCSVReader::hasMoreToRead() const { // If we haven't started the first block yet or are done our block, get the next block. @@ -105,10 +108,6 @@ bool ParallelCSVReader::finishedBlock() const { return getFileOffset() > (currentBlockIdx + 1) * CopyConstants::PARALLEL_BLOCK_SIZE; } -ParallelCSVScanSharedState::ParallelCSVScanSharedState( - const common::ReaderConfig readerConfig, uint64_t numRows) - : ScanSharedTableFuncState{std::move(readerConfig), numRows} {} - void ParallelCSVScanSharedState::setFileComplete(uint64_t completedFileIdx) { std::lock_guard guard{lock}; if (completedFileIdx == fileIdx) { @@ -145,7 +144,7 @@ void ParallelCSVScan::tableFunc(TableFunctionInput& input, common::DataChunk& ou parallelCSVLocalState->fileIdx = fileIdx; parallelCSVLocalState->reader = std::make_unique( parallelCSVSharedState->readerConfig.filePaths[fileIdx], - parallelCSVSharedState->readerConfig); + parallelCSVSharedState->readerConfig, parallelCSVSharedState->numColumns); } auto numRowsRead = parallelCSVLocalState->reader->parseBlock(blockIdx, outputChunk); outputChunk.state->selVector->selectedSize = numRowsRead; @@ -162,10 +161,16 @@ void ParallelCSVScan::tableFunc(TableFunctionInput& input, common::DataChunk& ou std::unique_ptr ParallelCSVScan::bindFunc( main::ClientContext* /*context*/, function::TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/) { - auto csvScanBindInput = reinterpret_cast(input); - return std::make_unique( - common::LogicalType::copy(csvScanBindInput->config.columnTypes), - csvScanBindInput->config.columnNames, csvScanBindInput->config, csvScanBindInput->mm); + auto scanInput = reinterpret_cast(input); + std::vector detectedColumnNames; + std::vector> detectedColumnTypes; + SerialCSVScan::bindColumns(scanInput->config, detectedColumnNames, detectedColumnTypes); + std::vector resultColumnNames; + std::vector> resultColumnTypes; + ReaderBindUtils::resolveColumns(scanInput->expectedColumnNames, detectedColumnNames, + resultColumnNames, scanInput->expectedColumnTypes, detectedColumnTypes, resultColumnTypes); + return std::make_unique(std::move(resultColumnTypes), + std::move(resultColumnNames), scanInput->mm, scanInput->config); } std::unique_ptr ParallelCSVScan::initSharedState( @@ -173,18 +178,21 @@ std::unique_ptr ParallelCSVScan::initSharedState auto bindData = reinterpret_cast(input.bindData); common::row_idx_t numRows = 0; for (const auto& path : bindData->config.filePaths) { - auto reader = make_unique(path, bindData->config); + auto reader = + make_unique(path, bindData->config, bindData->columnNames.size()); numRows += reader->countRows(); } - return std::make_unique(bindData->config, numRows); + return std::make_unique( + bindData->config, numRows, bindData->columnNames.size()); } std::unique_ptr ParallelCSVScan::initLocalState( function::TableFunctionInitInput& /*input*/, function::TableFuncSharedState* state) { auto localState = std::make_unique(); - auto scanSharedState = reinterpret_cast(state); - localState->reader = std::make_unique( - scanSharedState->readerConfig.filePaths[0], scanSharedState->readerConfig); + auto scanSharedState = reinterpret_cast(state); + localState->reader = + std::make_unique(scanSharedState->readerConfig.filePaths[0], + scanSharedState->readerConfig, scanSharedState->numColumns); localState->fileIdx = 0; return localState; } diff --git a/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp b/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp index 69b7e282336..aece2c4e9ca 100644 --- a/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp +++ b/src/processor/operator/persistent/reader/csv/serial_csv_reader.cpp @@ -2,6 +2,7 @@ #include "common/string_format.h" #include "processor/operator/persistent/reader/csv/driver.h" +#include "processor/operator/persistent/reader/reader_bind_utils.h" using namespace kuzu::common; using namespace kuzu::function; @@ -10,12 +11,11 @@ namespace kuzu { namespace processor { SerialCSVReader::SerialCSVReader( - const std::string& filePath, const common::ReaderConfig& readerConfig) - : BaseCSVReader{filePath, readerConfig} {} + const std::string& filePath, const common::ReaderConfig& readerConfig, uint64_t numColumns) + : BaseCSVReader{filePath, readerConfig, numColumns} {} std::vector> SerialCSVReader::sniffCSV() { readBOM(); - numColumnsDetected = 0; if (csvReaderConfig.hasHeader) { SniffCSVNameAndTypeDriver driver; @@ -61,7 +61,8 @@ void SerialCSVScanSharedState::read(common::DataChunk& outputChunk) { void SerialCSVScanSharedState::initReader() { if (fileIdx < readerConfig.getNumFiles()) { - reader = std::make_unique(readerConfig.filePaths[fileIdx], readerConfig); + reader = std::make_unique( + readerConfig.filePaths[fileIdx], readerConfig, numColumns); } } @@ -81,10 +82,16 @@ void SerialCSVScan::tableFunc(TableFunctionInput& input, DataChunk& outputChunk) std::unique_ptr SerialCSVScan::bindFunc( main::ClientContext* /*context*/, function::TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/) { - auto csvScanBindInput = reinterpret_cast(input); - return std::make_unique( - common::LogicalType::copy(csvScanBindInput->config.columnTypes), - csvScanBindInput->config.columnNames, csvScanBindInput->config, csvScanBindInput->mm); + auto scanInput = reinterpret_cast(input); + std::vector detectedColumnNames; + std::vector> detectedColumnTypes; + bindColumns(scanInput->config, detectedColumnNames, detectedColumnTypes); + std::vector resultColumnNames; + std::vector> resultColumnTypes; + ReaderBindUtils::resolveColumns(scanInput->expectedColumnNames, detectedColumnNames, + resultColumnNames, scanInput->expectedColumnTypes, detectedColumnTypes, resultColumnTypes); + return std::make_unique(std::move(resultColumnTypes), + std::move(resultColumnNames), scanInput->mm, scanInput->config); } std::unique_ptr SerialCSVScan::initSharedState( @@ -92,10 +99,12 @@ std::unique_ptr SerialCSVScan::initSharedState( auto bindData = reinterpret_cast(input.bindData); common::row_idx_t numRows = 0; for (const auto& path : bindData->config.filePaths) { - auto reader = make_unique(path, bindData->config); + auto reader = + make_unique(path, bindData->config, bindData->columnNames.size()); numRows += reader->countRows(); } - return std::make_unique(bindData->config, numRows); + return std::make_unique( + bindData->config, numRows, bindData->columnNames.size()); } std::unique_ptr SerialCSVScan::initLocalState( @@ -103,5 +112,29 @@ std::unique_ptr SerialCSVScan::initLocalState( return std::make_unique(); } +void SerialCSVScan::bindColumns(const ReaderConfig& readerConfig, + std::vector& columnNames, std::vector>& columnTypes) { + KU_ASSERT(readerConfig.getNumFiles() > 0); + bindColumns(readerConfig, 0, columnNames, columnTypes); + for (auto i = 1; i < readerConfig.getNumFiles(); ++i) { + std::vector tmpColumnNames; + std::vector> tmpColumnTypes; + bindColumns(readerConfig, i, tmpColumnNames, tmpColumnTypes); + ReaderBindUtils::validateNumColumns(columnTypes.size(), tmpColumnTypes.size()); + } +} + +void SerialCSVScan::bindColumns(const common::ReaderConfig& readerConfig, uint32_t fileIdx, + std::vector& columnNames, + std::vector>& columnTypes) { + auto csvReader = + SerialCSVReader(readerConfig.filePaths[fileIdx], readerConfig, 0 /* numColumns */); + auto sniffedColumns = csvReader.sniffCSV(); + for (auto& [name, type] : sniffedColumns) { + columnNames.push_back(name); + columnTypes.push_back(type.copy()); + } +} + } // namespace processor } // namespace kuzu diff --git a/src/processor/operator/persistent/reader/npy/npy_reader.cpp b/src/processor/operator/persistent/reader/npy/npy_reader.cpp index a9b9675bba2..33ab13b515b 100644 --- a/src/processor/operator/persistent/reader/npy/npy_reader.cpp +++ b/src/processor/operator/persistent/reader/npy/npy_reader.cpp @@ -3,6 +3,8 @@ #include #include +#include "processor/operator/persistent/reader/reader_bind_utils.h" + #ifdef _WIN32 #include "common/exception/buffer_manager.h" #include @@ -193,9 +195,8 @@ void NpyReader::validate(const LogicalType& type_, offset_t numRows) { filePath)); } if (getNumElementsPerRow() != FixedListType::getNumElementsInList(&type_)) { - throw CopyException(stringFormat("The shape of {} does not match the length of the " - "fixed list property.", - filePath)); + throw CopyException( + stringFormat("The shape of {} does not match {}.", filePath, type_.toString())); } return; } else { @@ -252,19 +253,27 @@ void NpyScanFunction::tableFunc(TableFunctionInput& input, DataChunk& outputChun std::unique_ptr NpyScanFunction::bindFunc( main::ClientContext* /*context*/, function::TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/) { - auto bindInput = reinterpret_cast(input); - auto config = bindInput->config; - KU_ASSERT(!config.filePaths.empty() && config.getNumFiles() == config.getNumColumns()); + auto scanInput = reinterpret_cast(input); + + std::vector detectedColumnNames; + std::vector> detectedColumnTypes; + bindColumns(scanInput->config, detectedColumnNames, detectedColumnTypes); + std::vector resultColumnNames; + std::vector> resultColumnTypes; + ReaderBindUtils::resolveColumns(scanInput->expectedColumnNames, detectedColumnNames, + resultColumnNames, scanInput->expectedColumnTypes, detectedColumnTypes, resultColumnTypes); + auto config = scanInput->config; + KU_ASSERT(!config.filePaths.empty() && config.getNumFiles() == resultColumnNames.size()); row_idx_t numRows; for (auto i = 0u; i < config.getNumFiles(); i++) { auto reader = make_unique(config.filePaths[i]); if (i == 0) { numRows = reader->getNumRows(); } - reader->validate(*config.columnTypes[i], numRows); + reader->validate(*resultColumnTypes[i], numRows); } - return std::make_unique( - common::LogicalType::copy(config.columnTypes), config.columnNames, config, bindInput->mm); + return std::make_unique(std::move(resultColumnTypes), + std::move(resultColumnNames), scanInput->mm, scanInput->config); } std::unique_ptr NpyScanFunction::initSharedState( @@ -279,5 +288,41 @@ std::unique_ptr NpyScanFunction::initLocalState( return std::make_unique(); } +void NpyScanFunction::bindColumns(const common::ReaderConfig& readerConfig, + std::vector& columnNames, + std::vector>& columnTypes) { + KU_ASSERT(readerConfig.getNumFiles() > 0); + bindColumns(readerConfig, 0, columnNames, columnTypes); + for (auto i = 1; i < readerConfig.getNumFiles(); ++i) { + std::vector tmpColumnNames; + std::vector> tmpColumnTypes; + bindColumns(readerConfig, i, tmpColumnNames, tmpColumnTypes); + ReaderBindUtils::validateNumColumns(1, tmpColumnTypes.size()); + columnNames.push_back(tmpColumnNames[0]); + columnTypes.push_back(tmpColumnTypes[0]->copy()); + } +} + +static std::unique_ptr bindFixedListType( + const std::vector& shape, LogicalTypeID typeID) { + if (shape.size() == 1) { + return std::make_unique(typeID); + } + auto childShape = std::vector{shape.begin() + 1, shape.end()}; + auto childType = bindFixedListType(childShape, typeID); + auto extraInfo = std::make_unique(std::move(childType), (uint32_t)shape[0]); + return std::make_unique(LogicalTypeID::FIXED_LIST, std::move(extraInfo)); +} + +void NpyScanFunction::bindColumns(const common::ReaderConfig& readerConfig, uint32_t fileIdx, + std::vector& columnNames, + std::vector>& columnTypes) { + auto reader = NpyReader(readerConfig.filePaths[fileIdx]); // TODO: double check + auto columnName = std::string("column" + std::to_string(fileIdx)); + auto columnType = bindFixedListType(reader.getShape(), reader.getType()); + columnNames.push_back(columnName); + columnTypes.push_back(columnType->copy()); +} + } // namespace processor } // namespace kuzu diff --git a/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp b/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp index 3d3f3b62562..bea10578fc5 100644 --- a/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp +++ b/src/processor/operator/persistent/reader/parquet/parquet_reader.cpp @@ -6,6 +6,7 @@ #include "processor/operator/persistent/reader/parquet/list_column_reader.h" #include "processor/operator/persistent/reader/parquet/struct_column_reader.h" #include "processor/operator/persistent/reader/parquet/thrift_tools.h" +#include "processor/operator/persistent/reader/reader_bind_utils.h" using namespace kuzu_parquet::format; @@ -622,11 +623,20 @@ void ParquetScanFunction::tableFunc(TableFunctionInput& input, DataChunk& output std::unique_ptr ParquetScanFunction::bindFunc( main::ClientContext* /*context*/, function::TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/) { - auto parquetScanBindInput = reinterpret_cast(input); - return std::make_unique( - common::LogicalType::copy(parquetScanBindInput->config.columnTypes), - parquetScanBindInput->config.columnNames, parquetScanBindInput->config, - parquetScanBindInput->mm); + auto scanInput = reinterpret_cast(input); + std::vector detectedColumnNames; + std::vector> detectedColumnTypes; + bindColumns(scanInput->config, scanInput->mm, detectedColumnNames, detectedColumnTypes); + std::vector resultColumnNames; + std::vector> resultColumnTypes; + ReaderBindUtils::resolveColumns(scanInput->expectedColumnNames, detectedColumnNames, + resultColumnNames, scanInput->expectedColumnTypes, detectedColumnTypes, resultColumnTypes); + if (!scanInput->expectedColumnTypes.empty()) { + ReaderBindUtils::validateColumnTypes( + scanInput->expectedColumnNames, scanInput->expectedColumnTypes, detectedColumnTypes); + } + return std::make_unique(std::move(resultColumnTypes), + std::move(resultColumnNames), scanInput->mm, scanInput->config); } std::unique_ptr ParquetScanFunction::initSharedState( @@ -651,5 +661,31 @@ std::unique_ptr ParquetScanFunction::initLocalSta return localState; } +void ParquetScanFunction::bindColumns(const common::ReaderConfig& readerConfig, + storage::MemoryManager* mm, std::vector& columnNames, + std::vector>& columnTypes) { + KU_ASSERT(readerConfig.getNumFiles() > 0); + bindColumns(readerConfig, 0, mm, columnNames, columnTypes); + for (auto i = 1; i < readerConfig.getNumFiles(); ++i) { + std::vector tmpColumnNames; + std::vector> tmpColumnTypes; + bindColumns(readerConfig, i, mm, tmpColumnNames, tmpColumnTypes); + ReaderBindUtils::validateNumColumns(columnTypes.size(), tmpColumnTypes.size()); + ReaderBindUtils::validateColumnTypes(columnNames, columnTypes, tmpColumnTypes); + } +} + +void ParquetScanFunction::bindColumns(const common::ReaderConfig& readerConfig, uint32_t fileIdx, + storage::MemoryManager* mm, std::vector& columnNames, + std::vector>& columnTypes) { + auto reader = ParquetReader(readerConfig.filePaths[fileIdx], mm); + auto state = std::make_unique(); + reader.initializeScan(*state, std::vector{}); + for (auto i = 0u; i < reader.getNumColumns(); ++i) { + columnNames.push_back(reader.getColumnName(i)); + columnTypes.push_back(reader.getColumnType(i)->copy()); + } +} + } // namespace processor } // namespace kuzu diff --git a/src/processor/operator/persistent/reader/rdf/rdf_reader.cpp b/src/processor/operator/persistent/reader/rdf/rdf_reader.cpp index 5006318c30a..92f16b08f4a 100644 --- a/src/processor/operator/persistent/reader/rdf/rdf_reader.cpp +++ b/src/processor/operator/persistent/reader/rdf/rdf_reader.cpp @@ -293,10 +293,10 @@ void RdfScan::tableFunc(function::TableFunctionInput& input, common::DataChunk& std::unique_ptr RdfScan::bindFunc(main::ClientContext* /*context*/, function::TableFuncBindInput* input, catalog::CatalogContent* /*catalog*/) { - auto rdfScanBindData = reinterpret_cast(input); + auto scanInput = reinterpret_cast(input); return std::make_unique( - common::LogicalType::copy(rdfScanBindData->config.columnTypes), - rdfScanBindData->config.columnNames, rdfScanBindData->config, rdfScanBindData->mm); + common::LogicalType::copy(scanInput->expectedColumnTypes), scanInput->expectedColumnNames, + scanInput->mm, scanInput->config); } std::unique_ptr RdfScan::initSharedState( diff --git a/src/processor/operator/persistent/reader/reader_bind_utils.cpp b/src/processor/operator/persistent/reader/reader_bind_utils.cpp new file mode 100644 index 00000000000..1ffd22c9307 --- /dev/null +++ b/src/processor/operator/persistent/reader/reader_bind_utils.cpp @@ -0,0 +1,51 @@ +#include "processor/operator/persistent/reader/reader_bind_utils.h" + +#include "common/exception/binder.h" +#include "common/string_format.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace processor { + +void ReaderBindUtils::validateNumColumns(uint32_t expectedNumber, uint32_t detectedNumber) { + if (detectedNumber == 0) { + return; // Empty CSV. Continue processing. + } + if (expectedNumber != detectedNumber) { + throw common::BinderException(common::stringFormat( + "Number of columns mismatch. Expected {} but got {}.", expectedNumber, detectedNumber)); + } +} + +void ReaderBindUtils::validateColumnTypes(const std::vector& columnNames, + const std::vector>& expectedColumnTypes, + const std::vector>& detectedColumnTypes) { + KU_ASSERT(expectedColumnTypes.size() == detectedColumnTypes.size()); + for (auto i = 0; i < expectedColumnTypes.size(); ++i) { + if (*expectedColumnTypes[i] != *detectedColumnTypes[i]) { + throw common::BinderException(common::stringFormat( + "Column `{}` type mismatch. Expected {} but got {}.", columnNames[i], + expectedColumnTypes[i]->toString(), detectedColumnTypes[i]->toString())); + } + } +} + +void ReaderBindUtils::resolveColumns(const std::vector& expectedColumnNames, + const std::vector& detectedColumnNames, + std::vector& resultColumnNames, + const std::vector>& expectedColumnTypes, + const std::vector>& detectedColumnTypes, + std::vector>& resultColumnTypes) { + if (expectedColumnTypes.empty()) { + resultColumnNames = detectedColumnNames; + resultColumnTypes = LogicalType::copy(detectedColumnTypes); + } else { + validateNumColumns(expectedColumnTypes.size(), detectedColumnTypes.size()); + resultColumnNames = expectedColumnNames; + resultColumnTypes = LogicalType::copy(expectedColumnTypes); + } +} + +} // namespace processor +} // namespace kuzu diff --git a/test/test_files/tinysnb/load_from/load_from.test b/test/test_files/tinysnb/load_from/load_from.test index 0b2cc577451..6872fe16d62 100644 --- a/test/test_files/tinysnb/load_from/load_from.test +++ b/test/test_files/tinysnb/load_from/load_from.test @@ -61,7 +61,7 @@ Greg|0|1994-01-12 Hubert Blaine Wolfeschlegelsteinhausenbergerdorff|0|1994-01-12 -STATEMENT LOAD FROM "${KUZU_ROOT_DIRECTORY}/dataset/copy-test/rdf/taxonomy.ttl" RETURN COUNT(*) ---- error -Binder exception: Cannot sniff header of file type TURTLE +Binder exception: Cannot load from file type TURTLE. -STATEMENT LOAD WITH HEADERS (a INT64, b INT64) FROM "${KUZU_ROOT_DIRECTORY}/dataset/demo-db/parquet/user.parquet" RETURN *; ---- error Binder exception: Column `a` type mismatch. Expected INT64 but got STRING.