Skip to content

Commit

Permalink
Merge pull request #2210 from kuzudb/issue-2139
Browse files Browse the repository at this point in the history
Validate file header for LOAD and COPY
  • Loading branch information
ray6080 committed Oct 14, 2023
2 parents 1420ead + 80e9905 commit 0e3f995
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 103 deletions.
56 changes: 37 additions & 19 deletions src/binder/bind/bind_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,6 @@ static void validateByColumnKeyword(FileType fileType, bool byColumn) {
}
}

static void validateCopyNpyFilesMatchSchema(uint32_t numFiles, TableSchema* schema) {
if (schema->properties.size() != numFiles) {
throw BinderException(StringUtils::string_format(
"Number of npy files is not equal to number of properties in table {}.",
schema->tableName));
}
}

static void validateCopyNpyNotForRelTables(TableSchema* schema) {
if (schema->tableType == TableType::REL) {
throw BinderException(
Expand Down Expand Up @@ -99,7 +91,6 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
std::make_unique<ReaderConfig>(fileType, std::move(filePaths), std::move(csvReaderConfig));
validateByColumnKeyword(readerConfig->fileType, copyStatement.byColumn());
if (readerConfig->fileType == FileType::NPY) {
validateCopyNpyFilesMatchSchema(readerConfig->getNumFiles(), tableSchema);
validateCopyNpyNotForRelTables(tableSchema);
}
switch (tableSchema->tableType) {
Expand Down Expand Up @@ -194,16 +185,16 @@ static bool skipPropertyInFile(const Property& property) {

expression_vector Binder::bindExpectedNodeFileColumns(
TableSchema* tableSchema, ReaderConfig& readerConfig) {
expression_vector columns;
// Resolve expected columns.
std::vector<std::string> expectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> expectedColumnTypes;
switch (readerConfig.fileType) {
case FileType::TURTLE: {
auto stringType = LogicalType{LogicalTypeID::STRING};
auto columnNames = std::vector<std::string>{
expectedColumnNames = {
std::string(RDF_SUBJECT), std::string(RDF_PREDICATE), std::string(RDF_OBJECT)};
for (auto& columnName : columnNames) {
readerConfig.columnNames.push_back(columnName);
readerConfig.columnTypes.push_back(stringType.copy());
columns.push_back(createVariable(columnName, stringType));
for (auto _ : expectedColumnNames) {
expectedColumnTypes.push_back(stringType.copy());
}
} break;
case FileType::NPY:
Expand All @@ -213,16 +204,29 @@ expression_vector Binder::bindExpectedNodeFileColumns(
if (skipPropertyInFile(*property)) {
continue;
}
readerConfig.columnNames.push_back(property->getName());
readerConfig.columnTypes.push_back(property->getDataType()->copy());
columns.push_back(createVariable(property->getName(), *property->getDataType()));
expectedColumnNames.push_back(property->getName());
expectedColumnTypes.push_back(property->getDataType()->copy());
}
} break;
default: {
throw NotImplementedException{"Binder::bindCopyNodeColumns"};
}
}
return columns;
if (readerConfig.fileType == common::FileType::TURTLE) {
// Nothing to validate for turtle
return createColumnExpressions(readerConfig, expectedColumnNames, expectedColumnTypes);
}
// Detect columns from file.
std::vector<std::string> detectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> detectedColumnTypes;
sniffFiles(readerConfig, detectedColumnNames, detectedColumnTypes);
// Validate.
validateNumColumns(expectedColumnTypes.size(), detectedColumnTypes.size());
if (readerConfig.fileType == common::FileType::PARQUET) {
// HACK(Ziyi): We should allow casting in Parquet reader.
validateColumnTypes(expectedColumnNames, expectedColumnTypes, detectedColumnTypes);
}
return createColumnExpressions(readerConfig, expectedColumnNames, expectedColumnTypes);
}

expression_vector Binder::bindExpectedRelFileColumns(
Expand Down Expand Up @@ -273,6 +277,20 @@ expression_vector Binder::bindExpectedRelFileColumns(
throw NotImplementedException{"Binder::bindCopyRelColumns"};
}
}
if (readerConfig.fileType == common::FileType::TURTLE) {
// Nothing to validate for turtle
return columns;
}
// Detect columns from file.
std::vector<std::string> detectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> detectedColumnTypes;
sniffFiles(readerConfig, detectedColumnNames, detectedColumnTypes);
// Validate number of columns.
validateNumColumns(readerConfig.getNumColumns(), detectedColumnTypes.size());
if (readerConfig.fileType == common::FileType::PARQUET) {
validateColumnTypes(
readerConfig.columnNames, readerConfig.columnTypes, detectedColumnTypes);
}
return columns;
}

Expand Down
171 changes: 112 additions & 59 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,86 +137,139 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(
if (readerConfig->getNumFiles() > 1) {
throw BinderException("Load from multiple files is not supported.");
}
std::vector<std::string> inputColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> inputColumnTypes;
// Bind columns from input.
std::vector<std::string> expectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> expectedColumnTypes;
for (auto& [name, type] : loadFrom.getColumnNameDataTypesRef()) {
inputColumnNames.push_back(name);
inputColumnTypes.push_back(bindDataType(type));
expectedColumnNames.push_back(name);
expectedColumnTypes.push_back(bindDataType(type));
}
// Detect columns from file.
std::vector<std::string> detectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> detectedColumnTypes;
switch (fileType) {
sniffFiles(*readerConfig, detectedColumnNames, detectedColumnTypes);
// Validate and resolve columns to use.
expression_vector columns;
if (expectedColumnTypes.empty()) { // Input is empty. Use detected columns.
columns = createColumnExpressions(*readerConfig, detectedColumnNames, detectedColumnTypes);
} else {
validateNumColumns(expectedColumnTypes.size(), detectedColumnTypes.size());
if (fileType == common::FileType::PARQUET) {
validateColumnTypes(expectedColumnNames, expectedColumnTypes, detectedColumnTypes);
}
columns = createColumnExpressions(*readerConfig, expectedColumnNames, expectedColumnTypes);
}
auto info = std::make_unique<BoundFileScanInfo>(
std::move(readerConfig), std::move(columns), nullptr /* offset */, TableType::UNKNOWN);
auto boundLoadFrom = std::make_unique<BoundLoadFrom>(std::move(info));
if (loadFrom.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*loadFrom.getWherePredicate());
boundLoadFrom->setWherePredicate(std::move(wherePredicate));
}
return boundLoadFrom;
}

expression_vector Binder::createColumnExpressions(common::ReaderConfig& readerConfig,
const std::vector<std::string>& columnNames,
const std::vector<std::unique_ptr<common::LogicalType>>& columnTypes) {
expression_vector columns;
for (auto i = 0u; i < columnTypes.size(); ++i) {
auto columnName = columnNames[i];
auto columnType = columnTypes[i].get();
readerConfig.columnNames.push_back(columnName);
readerConfig.columnTypes.push_back(columnType->copy());
columns.push_back(createVariable(columnName, *columnType));
}
return columns;
}

void Binder::validateColumnTypes(const std::vector<std::string>& columnNames,
const std::vector<std::unique_ptr<LogicalType>>& expectedColumnTypes,
const std::vector<std::unique_ptr<LogicalType>>& detectedColumnTypes) {
assert(expectedColumnTypes.size() == detectedColumnTypes.size());
for (auto i = 0; i < expectedColumnTypes.size(); ++i) {
if (*expectedColumnTypes[i] != *detectedColumnTypes[i]) {
throw BinderException(
StringUtils::string_format("Column `{}` type mismatch. Expected {} but got {}.",
columnNames[i], LogicalTypeUtils::dataTypeToString(*expectedColumnTypes[i]),
LogicalTypeUtils::dataTypeToString(*detectedColumnTypes[i])));
}
}
}

void Binder::validateNumColumns(uint32_t expectedNumber, uint32_t detectedNumber) {
if (detectedNumber == 0) {
return; // Empty CSV. Continue processing.
}
if (expectedNumber != detectedNumber) {
throw BinderException(StringUtils::string_format(
"Number of columns mismatch. Expected {} but got {}.", expectedNumber, detectedNumber));
}
}

void Binder::sniffFiles(const common::ReaderConfig& readerConfig,
std::vector<std::string>& columnNames,
std::vector<std::unique_ptr<common::LogicalType>>& columnTypes) {
assert(readerConfig.getNumFiles() > 0);
sniffFile(readerConfig, 0, columnNames, columnTypes);
for (auto i = 1; i < readerConfig.getNumFiles(); ++i) {
std::vector<std::string> tmpColumnNames;
std::vector<std::unique_ptr<LogicalType>> tmpColumnTypes;
sniffFile(readerConfig, i, tmpColumnNames, tmpColumnTypes);
switch (readerConfig.fileType) {
case FileType::CSV: {
validateNumColumns(columnTypes.size(), tmpColumnTypes.size());
}
case FileType::PARQUET: {
validateNumColumns(columnTypes.size(), tmpColumnTypes.size());
validateColumnTypes(columnNames, columnTypes, tmpColumnTypes);
} break;
case FileType::NPY: {
validateNumColumns(1, tmpColumnTypes.size());
columnNames.push_back(tmpColumnNames[0]);
columnTypes.push_back(tmpColumnTypes[0]->copy());
} break;
case FileType::TURTLE:
break;
default:
// LCOV_EXCL_START
throw NotImplementedException("Binder::sniffFiles");
// LCOV_EXCL_END
}
}
}

void Binder::sniffFile(const common::ReaderConfig& readerConfig, uint32_t fileIdx,
std::vector<std::string>& columnNames, std::vector<std::unique_ptr<LogicalType>>& columnTypes) {
switch (readerConfig.fileType) {
case FileType::CSV: {
auto csvReader = SerialCSVReader(readerConfig->filePaths[0], *readerConfig);
auto csvReader = SerialCSVReader(readerConfig.filePaths[fileIdx], readerConfig);
auto sniffedColumns = csvReader.sniffCSV();
for (auto& [name, type] : sniffedColumns) {
detectedColumnNames.push_back(name);
detectedColumnTypes.push_back(type.copy());
columnNames.push_back(name);
columnTypes.push_back(type.copy());
}
} break;
case FileType::PARQUET: {
auto reader = ParquetReader(readerConfig->filePaths[0], memoryManager);
auto reader = ParquetReader(readerConfig.filePaths[fileIdx], memoryManager);
auto state = std::make_unique<processor::ParquetReaderScanState>();
reader.initializeScan(*state, std::vector<uint64_t>{});
for (auto i = 0u; i < reader.getNumColumns(); ++i) {
detectedColumnNames.push_back(reader.getColumnName(i));
detectedColumnTypes.push_back(reader.getColumnType(i)->copy());
columnNames.push_back(reader.getColumnName(i));
columnTypes.push_back(reader.getColumnType(i)->copy());
}
} break;
case FileType::NPY: {
auto reader = NpyReader(readerConfig->filePaths[0]);
auto columnName = std::string("column0");
auto reader = NpyReader(readerConfig.filePaths[0]);
auto columnName = std::string("column" + std::to_string(fileIdx));
auto columnType = bindFixedListType(reader.getShape(), reader.getType());
detectedColumnNames.push_back(columnName);
detectedColumnTypes.push_back(columnType->copy());
columnNames.push_back(columnName);
columnTypes.push_back(columnType->copy());
} break;
default:
throw BinderException(StringUtils::string_format(
"Load from {} file is not supported.", FileTypeUtils::toString(fileType)));
"Cannot sniff header of file type {}", FileTypeUtils::toString(readerConfig.fileType)));
}
expression_vector columns;
if (inputColumnTypes.empty()) {
for (auto i = 0u; i < detectedColumnTypes.size(); ++i) {
auto columnName = detectedColumnNames[i];
auto columnType = detectedColumnTypes[i].get();
readerConfig->columnNames.push_back(columnName);
readerConfig->columnTypes.push_back(columnType->copy());
columns.push_back(createVariable(columnName, *columnType));
}
} else {
if (inputColumnTypes.size() != detectedColumnTypes.size()) {
throw BinderException(
StringUtils::string_format("Number of columns mismatch. Detect {} but expect {}.",
detectedColumnTypes.size(), inputColumnTypes.size()));
}
if (fileType == common::FileType::PARQUET) {
for (auto i = 0u; i < inputColumnTypes.size(); ++i) {
auto inputType = inputColumnTypes[i].get();
auto detectType = detectedColumnTypes[i].get();
if (*inputType != *detectType) {
throw BinderException(StringUtils::string_format(
"Column {} data type mismatch. Detect {} but expect {}.",
inputColumnNames[i], LogicalTypeUtils::dataTypeToString(*detectType),
LogicalTypeUtils::dataTypeToString(*inputType)));
}
}
}
for (auto i = 0u; i < inputColumnTypes.size(); ++i) {
auto columnName = inputColumnNames[i];
auto columnType = inputColumnTypes[i].get();
readerConfig->columnNames.push_back(columnName);
readerConfig->columnTypes.push_back(columnType->copy());
columns.push_back(createVariable(columnName, *columnType));
}
}
auto info = std::make_unique<BoundFileScanInfo>(
std::move(readerConfig), std::move(columns), nullptr, TableType::UNKNOWN);
auto boundLoadFrom = std::make_unique<BoundLoadFrom>(std::move(info));
if (loadFrom.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*loadFrom.getWherePredicate());
boundLoadFrom->setWherePredicate(std::move(wherePredicate));
}
return boundLoadFrom;
}

} // namespace binder
Expand Down
12 changes: 12 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ class Binder {
const parser::ReadingClause& readingClause);
std::unique_ptr<BoundReadingClause> bindInQueryCall(const parser::ReadingClause& readingClause);
std::unique_ptr<BoundReadingClause> bindLoadFrom(const parser::ReadingClause& readingClause);
expression_vector createColumnExpressions(common::ReaderConfig& readerConfig,
const std::vector<std::string>& columnNames,
const std::vector<std::unique_ptr<common::LogicalType>>& columnTypes);
void sniffFiles(const common::ReaderConfig& readerConfig, std::vector<std::string>& columnNames,
std::vector<std::unique_ptr<common::LogicalType>>& columnTypes);
void sniffFile(const common::ReaderConfig& readerConfig, uint32_t fileIdx,
std::vector<std::string>& columnNames,
std::vector<std::unique_ptr<common::LogicalType>>& columnTypes);
static void validateNumColumns(uint32_t expectedNumber, uint32_t detectedNumber);
static void validateColumnTypes(const std::vector<std::string>& expectedColumnNames,
const std::vector<std::unique_ptr<common::LogicalType>>& expectedColumnTypes,
const std::vector<std::unique_ptr<common::LogicalType>>& detectedColumnTypes);

/*** bind updating clause ***/
// TODO(Guodong/Xiyang): Is update clause an accurate name? How about (data)modificationClause?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ class BaseCSVReader {
void addValue(Driver&, uint64_t rowNum, common::column_id_t columnIdx, std::string_view strVal,
std::vector<uint64_t>& escapePositions);

template<typename Driver>
bool addRow(Driver&, uint64_t rowNum, common::column_id_t column_count);

//! Read BOM and header.
void handleFirstBlock();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct SniffCSVNameAndTypeDriver {
};

struct SniffCSVColumnCountDriver {
bool emptyRow = true;
uint64_t numColumns = 0;

bool done(uint64_t rowNum);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,6 @@ void BaseCSVReader::addValue(Driver& driver, uint64_t rowNum, column_id_t column
}
}

template<typename Driver>
bool BaseCSVReader::addRow(Driver& driver, uint64_t rowNum, column_id_t column) {
return driver.addRow(rowNum, column);
}

void BaseCSVReader::handleFirstBlock() {
readBOM();
if (csvReaderConfig.hasHeader) {
Expand Down Expand Up @@ -308,7 +303,7 @@ add_row : {
std::string_view(buffer.get() + start, position - start - hasQuotes), escapePositions);
column++;

rowNum += addRow(driver, rowNum, column);
rowNum += driver.addRow(rowNum, column);

column = 0;
position++;
Expand Down Expand Up @@ -423,7 +418,7 @@ add_row : {
column++;
}
if (column > 0) {
rowNum += addRow(driver, rowNum, column);
rowNum += driver.addRow(rowNum, column);
}
return rowNum;
}
Expand Down
14 changes: 12 additions & 2 deletions src/processor/operator/persistent/reader/csv/driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,14 +612,24 @@ bool SniffCSVNameAndTypeDriver::addRow(uint64_t, common::column_id_t) {
}

bool SniffCSVColumnCountDriver::done(uint64_t) {
return true;
return !emptyRow;
}

void SniffCSVColumnCountDriver::addValue(uint64_t, common::column_id_t, std::string_view value) {
void SniffCSVColumnCountDriver::addValue(
uint64_t, common::column_id_t columnIdx, std::string_view value) {
if (value != "" || columnIdx > 0) {
emptyRow = false;
}
numColumns++;
}

bool SniffCSVColumnCountDriver::addRow(uint64_t, common::column_id_t) {
if (emptyRow) {
// If this is the last row, we just return zero: we don't know how many columns there are
// supposed to be.
numColumns = 0;
return false;
}
return true;
}

Expand Down
Loading

0 comments on commit 0e3f995

Please sign in to comment.