Skip to content

Commit

Permalink
Refactor table function
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Nov 12, 2023
1 parent 89e1ab4 commit 4227a14
Show file tree
Hide file tree
Showing 32 changed files with 561 additions and 491 deletions.
185 changes: 70 additions & 115 deletions src/binder/bind/bind_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ std::unique_ptr<BoundStatement> Binder::bindCopyToClause(const Statement& statem
if (fileType != FileType::CSV && fileType != FileType::PARQUET) {
throw BinderException(ExceptionMessage::validateCopyToCSVParquetExtensionsException());
}
auto readerConfig = std::make_unique<ReaderConfig>(
fileType, std::vector<std::string>{boundFilePath}, columnNames, std::move(columnTypes));
return std::make_unique<BoundCopyTo>(std::move(readerConfig), std::move(query));
return std::make_unique<BoundCopyTo>(
boundFilePath, fileType, std::move(columnNames), std::move(columnTypes), std::move(query));
}

// As a temporary constraint, we require npy files loaded with COPY FROM BY COLUMN keyword.
Expand Down Expand Up @@ -93,15 +92,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
auto fileType = bindFileType(filePaths);
auto readerConfig =
std::make_unique<ReaderConfig>(fileType, std::move(filePaths), std::move(csvReaderConfig));
auto inputType = std::make_unique<LogicalType>(LogicalTypeID::STRING);
std::vector<LogicalType*> inputTypes;
inputTypes.push_back(inputType.get());
auto scanFunction =
getScanFunction(fileType, std::move(inputTypes), readerConfig->csvReaderConfig->parallel);
std::vector<std::unique_ptr<Value>> inputValues;
inputValues.push_back(std::make_unique<Value>(Value::createValue(readerConfig->filePaths[0])));
auto tableFuncBindInput =
std::make_unique<function::ScanTableFuncBindInput>(*readerConfig, memoryManager);
auto scanFunction = getScanFunction(fileType, readerConfig->csvReaderConfig->parallel);
validateByColumnKeyword(readerConfig->fileType, copyStatement.byColumn());
if (readerConfig->fileType == FileType::NPY) {
validateCopyNpyNotForRelTables(tableSchema);
Expand Down Expand Up @@ -132,14 +123,21 @@ std::unique_ptr<BoundStatement> Binder::bindCopyNodeFrom(function::TableFunction
std::unique_ptr<common::ReaderConfig> readerConfig, TableSchema* tableSchema) {
// For table with SERIAL columns, we need to read in serial from files.
auto containsSerial = bindContainsSerial(tableSchema);
auto columns = bindExpectedNodeFileColumns(tableSchema, *readerConfig);
auto tableFuncBindInput =
std::make_unique<function::ScanTableFuncBindInput>(std::move(*readerConfig), memoryManager);
auto copyFuncBindData =
copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion());
auto nodeID = createVariable(std::string(Property::INTERNAL_ID_NAME), LogicalTypeID::INT64);
std::vector<std::string> expectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> expectedColumnTypes;
bindExpectedNodeColumns(tableSchema, expectedColumnNames, expectedColumnTypes);
auto bindInput = std::make_unique<function::ScanTableFuncBindInput>(memoryManager,
*readerConfig, std::move(expectedColumnNames), std::move(expectedColumnTypes));
auto bindData =
copyFunc->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion());
expression_vector columns;
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i]));
}
auto offset = expressionBinder.createVariableExpression(
LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS);
auto boundFileScanInfo = std::make_unique<BoundFileScanInfo>(
copyFunc, copyFuncBindData->copy(), columns, std::move(nodeID), TableType::NODE);
copyFunc, std::move(bindData), columns, std::move(offset), TableType::NODE);
auto boundCopyFromInfo = std::make_unique<BoundCopyFromInfo>(tableSchema,
std::move(boundFileScanInfo), containsSerial, std::move(columns), nullptr /* extraInfo */);
return std::make_unique<BoundCopyFrom>(std::move(boundCopyFromInfo));
Expand All @@ -150,31 +148,38 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRelFrom(function::TableFunction*
// For table with SERIAL columns, we need to read in serial from files.
auto containsSerial = bindContainsSerial(tableSchema);
KU_ASSERT(containsSerial == false);
auto columnsToRead = bindExpectedRelFileColumns(tableSchema, *readerConfig);
auto tableFuncBindInput =
std::make_unique<function::ScanTableFuncBindInput>(std::move(*readerConfig), memoryManager);
auto copyFuncBindData =
copyFunc->bindFunc(clientContext, tableFuncBindInput.get(), catalog.getReadOnlyVersion());
auto relID = createVariable(std::string(Property::INTERNAL_ID_NAME), LogicalTypeID::INT64);
std::vector<std::string> expectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> expectedColumnTypes;
bindExpectedRelColumns(tableSchema, expectedColumnNames, expectedColumnTypes);
auto bindInput = std::make_unique<function::ScanTableFuncBindInput>(memoryManager,
std::move(*readerConfig), std::move(expectedColumnNames), std::move(expectedColumnTypes));
auto bindData =
copyFunc->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion());
expression_vector columns;
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i]));
}
auto offset = expressionBinder.createVariableExpression(
LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS);
auto boundFileScanInfo = std::make_unique<BoundFileScanInfo>(
copyFunc, copyFuncBindData->copy(), columnsToRead, relID->copy(), TableType::REL);
copyFunc, std::move(bindData), columns, offset, TableType::REL);
auto relTableSchema = reinterpret_cast<RelTableSchema*>(tableSchema);
auto srcTableSchema =
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID());
auto dstTableSchema =
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID());
auto srcKey = columnsToRead[0];
auto dstKey = columnsToRead[1];
auto srcKey = columns[0];
auto dstKey = columns[1];
auto srcNodeID =
createVariable(std::string(Property::REL_BOUND_OFFSET_NAME), LogicalTypeID::INT64);
auto dstNodeID =
createVariable(std::string(Property::REL_NBR_OFFSET_NAME), LogicalTypeID::INT64);
auto extraCopyRelInfo = std::make_unique<ExtraBoundCopyRelInfo>(
srcTableSchema, dstTableSchema, srcNodeID, dstNodeID, srcKey, dstKey);
// Skip the first two columns.
expression_vector columnsToCopy{std::move(srcNodeID), std::move(dstNodeID), std::move(relID)};
for (auto i = NUM_COLUMNS_TO_SKIP_IN_REL_FILE; i < columnsToRead.size(); i++) {
columnsToCopy.push_back(std::move(columnsToRead[i]));
expression_vector columnsToCopy{std::move(srcNodeID), std::move(dstNodeID), std::move(offset)};
for (auto i = NUM_COLUMNS_TO_SKIP_IN_REL_FILE; i < columns.size(); i++) {
columnsToCopy.push_back(columns[i]);
}
auto boundCopyFromInfo =
std::make_unique<BoundCopyFromInfo>(tableSchema, std::move(boundFileScanInfo),
Expand All @@ -187,97 +192,47 @@ static bool skipPropertyInFile(const Property& property) {
TableSchema::isReservedPropertyName(property.getName());
}

expression_vector Binder::bindExpectedNodeFileColumns(
TableSchema* tableSchema, ReaderConfig& readerConfig) {
// Resolve expected columns.
std::vector<std::string> expectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> expectedColumnTypes;
switch (readerConfig.fileType) {
case FileType::NPY:
case FileType::PARQUET:
case FileType::CSV: {
for (auto& property : tableSchema->properties) {
if (skipPropertyInFile(*property)) {
continue;
}
expectedColumnNames.push_back(property->getName());
expectedColumnTypes.push_back(property->getDataType()->copy());
void Binder::bindExpectedNodeColumns(catalog::TableSchema* tableSchema,
std::vector<std::string>& columnNames,
std::vector<std::unique_ptr<common::LogicalType>>& columnTypes) {
for (auto& property : tableSchema->properties) {
if (skipPropertyInFile(*property)) {
continue;
}
} break;
default: {
KU_UNREACHABLE;
}
}
// Detect columns from file.
std::vector<std::string> detectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> detectedColumnTypes;
sniffFiles(readerConfig, detectedColumnNames, detectedColumnTypes);
// Validate.
validateNumColumns(expectedColumnTypes.size(), detectedColumnTypes.size());
if (readerConfig.fileType == common::FileType::PARQUET) {
// HACK(Ziyi): We should allow casting in Parquet reader.
validateColumnTypes(expectedColumnNames, expectedColumnTypes, detectedColumnTypes);
columnNames.push_back(property->getName());
columnTypes.push_back(property->getDataType()->copy());
}
return createColumnExpressions(readerConfig, expectedColumnNames, expectedColumnTypes);
}

expression_vector Binder::bindExpectedRelFileColumns(
TableSchema* tableSchema, ReaderConfig& readerConfig) {
void Binder::bindExpectedRelColumns(catalog::TableSchema* tableSchema,
std::vector<std::string>& columnNames,
std::vector<std::unique_ptr<common::LogicalType>>& columnTypes) {
auto relTableSchema = reinterpret_cast<RelTableSchema*>(tableSchema);
expression_vector columns;
switch (readerConfig.fileType) {
case FileType::CSV:
case FileType::PARQUET:
case FileType::NPY: {
auto srcColumnName = std::string(Property::REL_FROM_PROPERTY_NAME);
auto dstColumnName = std::string(Property::REL_TO_PROPERTY_NAME);
readerConfig.columnNames.push_back(srcColumnName);
readerConfig.columnNames.push_back(dstColumnName);
auto srcTable =
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID());
KU_ASSERT(srcTable->tableType == TableType::NODE);
auto dstTable =
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID());
KU_ASSERT(dstTable->tableType == TableType::NODE);
auto srcPKColumnType =
reinterpret_cast<NodeTableSchema*>(srcTable)->getPrimaryKey()->getDataType()->copy();
if (srcPKColumnType->getLogicalTypeID() == LogicalTypeID::SERIAL) {
srcPKColumnType = std::make_unique<LogicalType>(LogicalTypeID::INT64);
}
auto dstPKColumnType =
reinterpret_cast<NodeTableSchema*>(dstTable)->getPrimaryKey()->getDataType()->copy();
if (dstPKColumnType->getLogicalTypeID() == LogicalTypeID::SERIAL) {
dstPKColumnType = std::make_unique<LogicalType>(LogicalTypeID::INT64);
}
columns.push_back(createVariable(srcColumnName, *srcPKColumnType));
columns.push_back(createVariable(dstColumnName, *dstPKColumnType));
readerConfig.columnTypes.push_back(std::move(srcPKColumnType));
readerConfig.columnTypes.push_back(std::move(dstPKColumnType));
for (auto& property : tableSchema->properties) {
if (skipPropertyInFile(*property)) {
continue;
}
readerConfig.columnNames.push_back(property->getName());
auto columnType = property->getDataType()->copy();
columns.push_back(createVariable(property->getName(), *columnType));
readerConfig.columnTypes.push_back(std::move(columnType));
auto srcTable = reinterpret_cast<NodeTableSchema*>(
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID()));
auto dstTable = reinterpret_cast<NodeTableSchema*>(
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID()));
auto srcColumnName = std::string(Property::REL_FROM_PROPERTY_NAME);
auto dstColumnName = std::string(Property::REL_TO_PROPERTY_NAME);
columnNames.push_back(srcColumnName);
columnNames.push_back(dstColumnName);
auto srcPKColumnType = srcTable->getPrimaryKey()->getDataType()->copy();
if (srcPKColumnType->getLogicalTypeID() == LogicalTypeID::SERIAL) {
srcPKColumnType = std::make_unique<LogicalType>(LogicalTypeID::INT64);
}
auto dstPKColumnType = dstTable->getPrimaryKey()->getDataType()->copy();
if (dstPKColumnType->getLogicalTypeID() == LogicalTypeID::SERIAL) {
dstPKColumnType = std::make_unique<LogicalType>(LogicalTypeID::INT64);
}
columnTypes.push_back(std::move(srcPKColumnType));
columnTypes.push_back(std::move(dstPKColumnType));
for (auto& property : tableSchema->properties) {
if (skipPropertyInFile(*property)) {
continue;
}
} break;
default: {
KU_UNREACHABLE;
}
}
// Detect columns from file.
std::vector<std::string> detectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> detectedColumnTypes;
sniffFiles(readerConfig, detectedColumnNames, detectedColumnTypes);
// Validate number of columns.
validateNumColumns(readerConfig.getNumColumns(), detectedColumnTypes.size());
if (readerConfig.fileType == common::FileType::PARQUET) {
validateColumnTypes(
readerConfig.columnNames, readerConfig.columnTypes, detectedColumnTypes);
columnNames.push_back(property->getName());
columnTypes.push_back(property->getDataType()->copy());
}
return columns;
}

} // namespace binder
Expand Down
Loading

0 comments on commit 4227a14

Please sign in to comment.