Skip to content

Commit

Permalink
Merge pull request #2407 from kuzudb/partial-column-copy
Browse files Browse the repository at this point in the history
Partial column copy
  • Loading branch information
andyfengHKU committed Nov 17, 2023
2 parents ef27e49 + 4fa1c3b commit cd61648
Show file tree
Hide file tree
Showing 37 changed files with 3,230 additions and 2,757 deletions.
5 changes: 4 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ oC_Statement
| kU_Transaction ;

kU_CopyFrom
: COPY SP oC_SchemaName SP FROM SP kU_FilePaths ( SP? '(' SP? kU_ParsingOptions SP? ')' )? ;
: COPY SP oC_SchemaName ( ( SP? '(' SP? kU_ColumnNames SP? ')' SP? ) | SP ) FROM SP kU_FilePaths ( SP? '(' SP? kU_ParsingOptions SP? ')' )? ;

kU_ColumnNames
: oC_SchemaName ( SP? ',' SP? oC_SchemaName )* ;

kU_CopyFromByColumn
: COPY SP oC_SchemaName SP FROM SP '(' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ')' SP BY SP COLUMN ;
Expand Down
159 changes: 83 additions & 76 deletions src/binder/bind/bind_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/string_format.h"
#include "function/table_functions.h"
#include "function/table_functions/bind_input.h"
#include "parser/copy.h"

Expand All @@ -19,8 +18,6 @@ using namespace kuzu::parser;
namespace kuzu {
namespace binder {

static constexpr uint64_t NUM_COLUMNS_TO_SKIP_IN_REL_FILE = 2;

std::unique_ptr<BoundStatement> Binder::bindCopyToClause(const Statement& statement) {
auto& copyToStatement = reinterpret_cast<const CopyTo&>(statement);
auto boundFilePath = copyToStatement.getFilePath();
Expand Down Expand Up @@ -63,19 +60,8 @@ static void validateCopyNpyNotForRelTables(TableSchema* schema) {
}
}

static bool bindContainsSerial(TableSchema* tableSchema) {
bool containsSerial = false;
for (auto& property : tableSchema->properties) {
if (property->getDataType()->getLogicalTypeID() == LogicalTypeID::SERIAL) {
containsSerial = true;
break;
}
}
return containsSerial;
}

std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& statement) {
auto& copyStatement = (CopyFrom&)statement;
auto& copyStatement = reinterpret_cast<const CopyFrom&>(statement);
auto catalogContent = catalog.getReadOnlyVersion();
auto tableName = copyStatement.getTableName();
validateTableExist(tableName);
Expand All @@ -96,23 +82,22 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
auto fileType = bindFileType(filePaths);
auto readerConfig =
std::make_unique<ReaderConfig>(fileType, std::move(filePaths), std::move(csvReaderConfig));
auto scanFunction = getScanFunction(fileType, readerConfig->csvReaderConfig->parallel);
validateByColumnKeyword(readerConfig->fileType, copyStatement.byColumn());
if (readerConfig->fileType == FileType::NPY) {
validateCopyNpyNotForRelTables(tableSchema);
}
switch (tableSchema->tableType) {
case TableType::NODE:
if (readerConfig->fileType == FileType::TURTLE) {
return bindCopyRdfNodeFrom(scanFunction, std::move(readerConfig), tableSchema);
return bindCopyRdfNodeFrom(statement, std::move(readerConfig), tableSchema);
} else {
return bindCopyNodeFrom(scanFunction, std::move(readerConfig), tableSchema);
return bindCopyNodeFrom(statement, std::move(readerConfig), tableSchema);
}
case TableType::REL: {
if (readerConfig->fileType == FileType::TURTLE) {
return bindCopyRdfRelFrom(scanFunction, std::move(readerConfig), tableSchema);
return bindCopyRdfRelFrom(statement, std::move(readerConfig), tableSchema);
} else {
return bindCopyRelFrom(scanFunction, std::move(readerConfig), tableSchema);
return bindCopyRelFrom(statement, std::move(readerConfig), tableSchema);
}
}
// LCOV_EXCL_START
Expand All @@ -123,71 +108,67 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
}
}

std::unique_ptr<BoundStatement> Binder::bindCopyNodeFrom(function::TableFunction* copyFunc,
std::unique_ptr<common::ReaderConfig> readerConfig, TableSchema* tableSchema) {
std::unique_ptr<BoundStatement> Binder::bindCopyNodeFrom(const Statement& statement,
std::unique_ptr<common::ReaderConfig> config, TableSchema* tableSchema) {
auto& copyStatement = reinterpret_cast<const CopyFrom&>(statement);
auto func = getScanFunction(config->fileType, config->csvReaderConfig->parallel);
// For table with SERIAL columns, we need to read in serial from files.
auto containsSerial = bindContainsSerial(tableSchema);
auto containsSerial = tableSchema->containsColumnType(LogicalType(LogicalTypeID::SERIAL));
std::vector<std::string> expectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> expectedColumnTypes;
bindExpectedNodeColumns(tableSchema, expectedColumnNames, expectedColumnTypes);
auto bindInput = std::make_unique<function::ScanTableFuncBindInput>(memoryManager,
*readerConfig, std::move(expectedColumnNames), std::move(expectedColumnTypes));
auto bindData =
copyFunc->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion());
bindExpectedNodeColumns(
tableSchema, copyStatement.getColumnNames(), expectedColumnNames, expectedColumnTypes);
auto bindInput = std::make_unique<function::ScanTableFuncBindInput>(
memoryManager, *config, std::move(expectedColumnNames), std::move(expectedColumnTypes));
auto bindData = func->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion());
expression_vector columns;
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i]));
}
auto offset = expressionBinder.createVariableExpression(
LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS);
auto boundFileScanInfo = std::make_unique<BoundFileScanInfo>(
copyFunc, std::move(bindData), columns, std::move(offset), TableType::NODE);
auto boundCopyFromInfo = std::make_unique<BoundCopyFromInfo>(tableSchema,
std::move(boundFileScanInfo), containsSerial, std::move(columns), nullptr /* extraInfo */);
LogicalType(LogicalTypeID::INT64), InternalKeyword::ANONYMOUS);
auto boundFileScanInfo =
std::make_unique<BoundFileScanInfo>(func, std::move(bindData), columns, std::move(offset));
auto boundCopyFromInfo = std::make_unique<BoundCopyFromInfo>(
tableSchema, std::move(boundFileScanInfo), containsSerial, nullptr /* extraInfo */);
return std::make_unique<BoundCopyFrom>(std::move(boundCopyFromInfo));
}

std::unique_ptr<BoundStatement> Binder::bindCopyRelFrom(function::TableFunction* copyFunc,
std::unique_ptr<common::ReaderConfig> readerConfig, TableSchema* tableSchema) {
std::unique_ptr<BoundStatement> Binder::bindCopyRelFrom(const parser::Statement& statement,
std::unique_ptr<common::ReaderConfig> config, TableSchema* tableSchema) {
auto& copyStatement = reinterpret_cast<const CopyFrom&>(statement);
auto func = getScanFunction(config->fileType, config->csvReaderConfig->parallel);
// For table with SERIAL columns, we need to read in serial from files.
auto containsSerial = bindContainsSerial(tableSchema);
auto containsSerial = tableSchema->containsColumnType(LogicalType(LogicalTypeID::SERIAL));
KU_ASSERT(containsSerial == false);
std::vector<std::string> expectedColumnNames;
std::vector<std::unique_ptr<common::LogicalType>> expectedColumnTypes;
bindExpectedRelColumns(tableSchema, expectedColumnNames, expectedColumnTypes);
bindExpectedRelColumns(
tableSchema, copyStatement.getColumnNames(), expectedColumnNames, expectedColumnTypes);
auto bindInput = std::make_unique<function::ScanTableFuncBindInput>(memoryManager,
std::move(*readerConfig), std::move(expectedColumnNames), std::move(expectedColumnTypes));
auto bindData =
copyFunc->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion());
std::move(*config), std::move(expectedColumnNames), std::move(expectedColumnTypes));
auto bindData = func->bindFunc(clientContext, bindInput.get(), catalog.getReadOnlyVersion());
expression_vector columns;
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i]));
}
auto offset = expressionBinder.createVariableExpression(
LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS);
auto boundFileScanInfo = std::make_unique<BoundFileScanInfo>(
copyFunc, std::move(bindData), columns, offset, TableType::REL);
LogicalType(LogicalTypeID::INT64), std::string(InternalKeyword::ROW_OFFSET));
auto boundFileScanInfo =
std::make_unique<BoundFileScanInfo>(func, std::move(bindData), columns, offset);
auto relTableSchema = reinterpret_cast<RelTableSchema*>(tableSchema);
auto srcTableSchema =
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID());
auto dstTableSchema =
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID());
auto srcKey = columns[0];
auto dstKey = columns[1];
auto srcNodeID =
createVariable(std::string(Property::REL_BOUND_OFFSET_NAME), LogicalTypeID::INT64);
auto dstNodeID =
createVariable(std::string(Property::REL_NBR_OFFSET_NAME), LogicalTypeID::INT64);
auto srcNodeID = createVariable(std::string(InternalKeyword::SRC_OFFSET), LogicalTypeID::INT64);
auto dstNodeID = createVariable(std::string(InternalKeyword::DST_OFFSET), LogicalTypeID::INT64);
auto extraCopyRelInfo = std::make_unique<ExtraBoundCopyRelInfo>(
srcTableSchema, dstTableSchema, srcNodeID, dstNodeID, srcKey, dstKey);
// Skip the first two columns.
expression_vector columnsToCopy{std::move(srcNodeID), std::move(dstNodeID), std::move(offset)};
for (auto i = NUM_COLUMNS_TO_SKIP_IN_REL_FILE; i < columns.size(); i++) {
columnsToCopy.push_back(columns[i]);
}
auto boundCopyFromInfo =
std::make_unique<BoundCopyFromInfo>(tableSchema, std::move(boundFileScanInfo),
containsSerial, std::move(columnsToCopy), std::move(extraCopyRelInfo));
auto boundCopyFromInfo = std::make_unique<BoundCopyFromInfo>(
tableSchema, std::move(boundFileScanInfo), containsSerial, std::move(extraCopyRelInfo));
return std::make_unique<BoundCopyFrom>(std::move(boundCopyFromInfo));
}

Expand All @@ -196,30 +177,62 @@ static bool skipPropertyInFile(const Property& property) {
TableSchema::isReservedPropertyName(property.getName());
}

void Binder::bindExpectedNodeColumns(catalog::TableSchema* tableSchema,
std::vector<std::string>& columnNames,
std::vector<std::unique_ptr<common::LogicalType>>& columnTypes) {
for (auto& property : tableSchema->properties) {
if (skipPropertyInFile(*property)) {
continue;
static void bindExpectedColumns(TableSchema* tableSchema,
const std::vector<std::string>& inputColumnNames, std::vector<std::string>& columnNames,
logical_types_t& columnTypes) {
if (!inputColumnNames.empty()) {
std::unordered_set<std::string> inputColumnNamesSet;
for (auto& columName : inputColumnNames) {
if (inputColumnNamesSet.contains(columName)) {
throw BinderException(
stringFormat("Detect duplicate column name {} during COPY.", columName));
}
inputColumnNamesSet.insert(columName);
}
// Search column data type for each input column.
for (auto& columnName : inputColumnNames) {
if (!tableSchema->containProperty(columnName)) {
throw BinderException(stringFormat(
"Table {} does not contain column {}.", tableSchema->tableName, columnName));
}
auto propertyID = tableSchema->getPropertyID(columnName);
auto property = tableSchema->getProperty(propertyID);
if (skipPropertyInFile(*property)) {
continue;
}
columnNames.push_back(columnName);
columnTypes.push_back(property->getDataType()->copy());
}
} else {
// No column specified. Fall back to schema columns.
for (auto& property : tableSchema->properties) {
if (skipPropertyInFile(*property)) {
continue;
}
columnNames.push_back(property->getName());
columnTypes.push_back(property->getDataType()->copy());
}
columnNames.push_back(property->getName());
columnTypes.push_back(property->getDataType()->copy());
}
}

void Binder::bindExpectedRelColumns(catalog::TableSchema* tableSchema,
std::vector<std::string>& columnNames,
void Binder::bindExpectedNodeColumns(catalog::TableSchema* tableSchema,
const std::vector<std::string>& inputColumnNames, std::vector<std::string>& columnNames,
std::vector<std::unique_ptr<common::LogicalType>>& columnTypes) {
KU_ASSERT(columnNames.empty() && columnTypes.empty());
bindExpectedColumns(tableSchema, inputColumnNames, columnNames, columnTypes);
}

void Binder::bindExpectedRelColumns(TableSchema* tableSchema,
const std::vector<std::string>& inputColumnNames, std::vector<std::string>& columnNames,
std::vector<std::unique_ptr<common::LogicalType>>& columnTypes) {
KU_ASSERT(columnNames.empty() && columnTypes.empty());
auto relTableSchema = reinterpret_cast<RelTableSchema*>(tableSchema);
auto srcTable = reinterpret_cast<NodeTableSchema*>(
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getSrcTableID()));
auto dstTable = reinterpret_cast<NodeTableSchema*>(
catalog.getReadOnlyVersion()->getTableSchema(relTableSchema->getDstTableID()));
auto srcColumnName = std::string(Property::REL_FROM_PROPERTY_NAME);
auto dstColumnName = std::string(Property::REL_TO_PROPERTY_NAME);
columnNames.push_back(srcColumnName);
columnNames.push_back(dstColumnName);
columnNames.push_back("from");
columnNames.push_back("to");
auto srcPKColumnType = srcTable->getPrimaryKey()->getDataType()->copy();
if (srcPKColumnType->getLogicalTypeID() == LogicalTypeID::SERIAL) {
srcPKColumnType = LogicalType::INT64();
Expand All @@ -230,13 +243,7 @@ void Binder::bindExpectedRelColumns(catalog::TableSchema* tableSchema,
}
columnTypes.push_back(std::move(srcPKColumnType));
columnTypes.push_back(std::move(dstPKColumnType));
for (auto& property : tableSchema->properties) {
if (skipPropertyInFile(*property)) {
continue;
}
columnNames.push_back(property->getName());
columnTypes.push_back(property->getDataType()->copy());
}
bindExpectedColumns(tableSchema, inputColumnNames, columnNames, columnTypes);
}

} // namespace binder
Expand Down
8 changes: 4 additions & 4 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i]));
}
auto offset = expressionBinder.createVariableExpression(
LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS);
*LogicalType::INT64(), std::string(InternalKeyword::ROW_OFFSET));
auto boundInQueryCall = std::make_unique<BoundInQueryCall>(
std::move(tableFunction), std::move(bindData), std::move(columns), offset);
if (call.hasWherePredicate()) {
Expand Down Expand Up @@ -182,9 +182,9 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(
columns.push_back(createVariable(bindData->columnNames[i], *bindData->columnTypes[i]));
}
auto offset = expressionBinder.createVariableExpression(
LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS);
auto info = std::make_unique<BoundFileScanInfo>(scanFunction, std::move(bindData),
std::move(columns), std::move(offset), TableType::UNKNOWN);
LogicalType(LogicalTypeID::INT64), std::string(InternalKeyword::ROW_OFFSET));
auto info = std::make_unique<BoundFileScanInfo>(
scanFunction, std::move(bindData), std::move(columns), std::move(offset));
auto boundLoadFrom = std::make_unique<BoundLoadFrom>(std::move(info));
if (loadFrom.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*loadFrom.getWherePredicate());
Expand Down
Loading

0 comments on commit cd61648

Please sign in to comment.