Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Binder copy read rework #3251

Merged
merged 2 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
add_subdirectory(copy)
add_subdirectory(ddl)
add_subdirectory(read)

add_library(
kuzu_binder_bind
OBJECT
bind_attach_database.cpp
bind_comment_on.cpp
bind_copy.cpp
bind_create_macro.cpp
bind_ddl.cpp
bind_detach_database.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ std::unique_ptr<BoundStatement> Binder::bindRenameTable(const Statement& stateme
// Checks that `tableEntry` has a property called `propertyName`.
// Throws a BinderException whose message is built from the table name and the property name
// when tableEntry->containProperty(propertyName) is false; returns normally otherwise.
static void validatePropertyExist(TableCatalogEntry* tableEntry, const std::string& propertyName) {
if (!tableEntry->containProperty(propertyName)) {
throw BinderException(
tableEntry->getName() + " table doesn't have property " + propertyName + ".");
tableEntry->getName() + " table does not have property " + propertyName + ".");
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/binder/bind/bind_file_scan.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "binder/binder.h"
#include "binder/bound_scan_source.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/literal_expression.h"
#include "common/exception/binder.h"
#include "common/exception/copy.h"
Expand Down Expand Up @@ -102,7 +103,7 @@ std::unique_ptr<BoundBaseScanSource> Binder::bindScanSource(BaseScanSource* sour
columns.size(), expectedColumnNames.size()));
}
for (auto i = 0u; i < columns.size(); ++i) {
expressionBinder.validateDataType(*columns[i], expectedColumnTypes[i]);
ExpressionUtil::validateDataType(*columns[i], expectedColumnTypes[i]);
columns[i]->setAlias(expectedColumnNames[i]);
}
return std::make_unique<BoundQueryScanSource>(std::move(boundStatement));
Expand Down
2 changes: 1 addition & 1 deletion src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ std::shared_ptr<RelExpression> Binder::bindQueryRel(const RelPattern& relPattern
auto expectedDataType = QueryRelTypeUtils::isRecursive(relPattern.getRelType()) ?
LogicalTypeID::RECURSIVE_REL :
LogicalTypeID::REL;
ExpressionBinder::validateExpectedDataType(*prevVariable, expectedDataType);
ExpressionUtil::validateDataType(*prevVariable, expectedDataType);
throw BinderException("Bind relationship " + parsedName +
" to relationship with same name is not supported.");
}
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_query.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/query/return_with_clause/bound_return_clause.h"
#include "binder/query/return_with_clause/bound_with_clause.h"
#include "common/exception/binder.h"
Expand All @@ -23,8 +24,7 @@ void validateUnionColumnsOfTheSameType(
// Check whether the dataTypes in union expressions are exactly the same in each single
// query.
for (auto j = 0u; j < columns.size(); j++) {
ExpressionBinder::validateExpectedDataType(*otherColumns[j],
columns[j]->dataType.getLogicalTypeID());
ExpressionUtil::validateDataType(*otherColumns[j], columns[j]->getDataType());
}
}
}
Expand Down
243 changes: 0 additions & 243 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,31 +1,7 @@
#include "binder/binder.h"
#include "binder/expression/literal_expression.h"
#include "binder/query/reading_clause/bound_in_query_call.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/string_format.h"
#include "common/string_utils.h"
#include "function/built_in_function_utils.h"
#include "function/table/bind_input.h"
#include "main/attached_database.h"
#include "main/database.h"
#include "main/database_manager.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/expression/parsed_variable_expression.h"
#include "parser/query/reading_clause/in_query_call_clause.h"
#include "parser/query/reading_clause/load_from.h"
#include "parser/query/reading_clause/match_clause.h"
#include "parser/query/reading_clause/unwind_clause.h"

using namespace kuzu::common;
using namespace kuzu::parser;
using namespace kuzu::catalog;
using namespace kuzu::function;

namespace kuzu {
namespace binder {
Expand All @@ -49,224 +25,5 @@ std::unique_ptr<BoundReadingClause> Binder::bindReadingClause(const ReadingClaus
}
}

std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause& readingClause) {
auto& matchClause = ku_dynamic_cast<const ReadingClause&, const MatchClause&>(readingClause);
auto boundGraphPattern = bindGraphPattern(matchClause.getPatternElementsRef());
if (matchClause.hasWherePredicate()) {
boundGraphPattern.where = bindWhereExpression(*matchClause.getWherePredicate());
}
rewriteMatchPattern(boundGraphPattern);
auto boundMatch = std::make_unique<BoundMatchClause>(
std::move(boundGraphPattern.queryGraphCollection), matchClause.getMatchClauseType());
boundMatch->setPredicate(boundGraphPattern.where);
return boundMatch;
}

// Rewrites a bound graph pattern in place and folds the generated predicates into
// boundGraphPattern.where (each new predicate is AND-combined in front of the current one):
//  1. Self-loop rels: the destination of every rel for which isSelfLoop() holds is replaced by
//     a newly created query node with the same variable name and table IDs, and an equality
//     between the internal IDs of the source and the new destination is recorded.
//     E.g. (a)-[e]->(a) is rewritten as (a)-[e]->(b) WHERE id(a) = id(b).
//  2. Property key-value pairs of every pattern become equality predicates between the bound
//     property expression and the pair's right-hand side.
void Binder::rewriteMatchPattern(BoundGraphPattern& boundGraphPattern) {
    auto& graphs = boundGraphPattern.queryGraphCollection;

    // Pass 1: split self loops, remembering the ID equality each one requires.
    expression_vector idEqualities;
    for (auto graphIdx = 0u; graphIdx < graphs.getNumQueryGraphs(); ++graphIdx) {
        auto graph = graphs.getQueryGraphUnsafe(graphIdx);
        for (auto& rel : graph->getQueryRels()) {
            if (rel->isSelfLoop()) {
                auto srcNode = rel->getSrcNode();
                auto oldDstNode = rel->getDstNode();
                auto freshDstNode =
                    createQueryNode(oldDstNode->getVariableName(), oldDstNode->getTableIDs());
                graph->addQueryNode(freshDstNode);
                rel->setDstNode(freshDstNode);
                idEqualities.push_back(expressionBinder.createEqualityComparisonExpression(
                    srcNode->getInternalID(), freshDstNode->getInternalID()));
            }
        }
    }

    // Running conjunction, seeded with whatever predicate the pattern already carries.
    auto combined = boundGraphPattern.where;
    const auto andWith = [&](const auto& conjunct) {
        combined =
            expressionBinder.combineBooleanExpressions(ExpressionType::AND, conjunct, combined);
    };
    for (auto& idEquality : idEqualities) {
        andWith(idEquality);
    }

    // Pass 2: turn the key-value pairs written inside MATCH patterns into predicates.
    for (auto graphIdx = 0u; graphIdx < graphs.getNumQueryGraphs(); ++graphIdx) {
        auto graph = graphs.getQueryGraphUnsafe(graphIdx);
        for (auto& pattern : graph->getAllPatterns()) {
            for (auto& [propertyName, rhs] : pattern->getPropertyDataExprRef()) {
                auto lhs = expressionBinder.bindNodeOrRelPropertyExpression(*pattern, propertyName);
                auto equality = expressionBinder.createEqualityComparisonExpression(lhs, rhs);
                andWith(equality);
            }
        }
    }
    boundGraphPattern.where = std::move(combined);
}

// Binds an UNWIND clause.
//
// The unwound expression is bound and validated against LogicalTypeID::LIST; the alias becomes
// a new variable typed with the list's child type. When the scope has memorized table IDs for
// the expression's alias, a query node named after the UNWIND alias is created over those
// table IDs, registered as a node replacement, and its internal ID is handed to the bound
// clause; otherwise the ID expression stays null.
std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause& readingClause) {
    auto& unwindClause = ku_dynamic_cast<const ReadingClause&, const UnwindClause&>(readingClause);
    auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression());
    ExpressionBinder::validateDataType(*boundExpression, LogicalTypeID::LIST);
    auto aliasName = unwindClause.getAlias();
    auto alias = createVariable(aliasName, *ListType::getChildType(&boundExpression->dataType));
    std::shared_ptr<Expression> idExpr = nullptr;
    if (scope.hasMemorizedTableIDs(boundExpression->getAlias())) {
        auto tableIDs = scope.getMemorizedTableIDs(boundExpression->getAlias());
        auto node = createQueryNode(aliasName, tableIDs);
        idExpr = node->getInternalID();
        scope.addNodeReplacement(node);
    }
    // Qualified on purpose: an unqualified make_unique<T>(...) only resolves through ADL with
    // explicit template arguments (C++20), and the rest of this file spells std::make_unique.
    return std::make_unique<BoundUnwindClause>(std::move(boundExpression), std::move(alias),
        std::move(idExpr));
}

std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause& readingClause) {
auto& call = ku_dynamic_cast<const ReadingClause&, const InQueryCallClause&>(readingClause);
auto expr = call.getFunctionExpression();
auto functionExpr =
ku_dynamic_cast<const ParsedExpression*, const ParsedFunctionExpression*>(expr);
// Bind parameters
std::unique_ptr<ScanReplacementData> replacementData;
expression_vector params;
for (auto i = 0u; i < functionExpr->getNumChildren(); i++) {
auto child = functionExpr->getChild(i);
try {
params.push_back(expressionBinder.bindExpression(*child));
} catch (BinderException& exception) {
if (child->getExpressionType() != ExpressionType::VARIABLE) {
throw BinderException(exception.what()); // Cannot replace. Rethrow.
}
// Try replacement.
auto varExpr = ku_dynamic_cast<ParsedExpression*, ParsedVariableExpression*>(child);
auto var = varExpr->getVariableName();
replacementData = clientContext->tryReplace(var);
if (replacementData == nullptr) { // Replacement fail.
throw BinderException(ExceptionMessage::variableNotInScope(var));
}
}
}
TableFunction tableFunction;
std::unique_ptr<TableFuncBindData> bindData;
if (replacementData) {
tableFunction = replacementData->func;
bindData = tableFunction.bindFunc(clientContext, &replacementData->bindInput);
} else {
std::vector<Value> inputValues;
std::vector<LogicalType> inputTypes;
for (auto& param : params) {
if (param->expressionType != ExpressionType::LITERAL) {
throw BinderException{
stringFormat("Cannot evaluate {} as a literal.", param->toString())};
}
auto literalExpr =
ku_dynamic_cast<const Expression*, const LiteralExpression*>(param.get());
inputTypes.push_back(literalExpr->getDataType());
inputValues.push_back(*literalExpr->getValue());
}
auto functions = clientContext->getCatalog()->getFunctions(clientContext->getTx());
auto func = BuiltInFunctionsUtils::matchFunction(functionExpr->getFunctionName(),
inputTypes, functions);
tableFunction = *ku_dynamic_cast<function::Function*, function::TableFunction*>(func);
auto bindInput = function::TableFuncBindInput();
bindInput.inputs = std::move(inputValues);
bindData = tableFunction.bindFunc(clientContext, &bindInput);
}
expression_vector columns;
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
columns.push_back(createVariable(bindData->columnNames[i], bindData->columnTypes[i]));
}
auto offset = expressionBinder.createVariableExpression(*LogicalType::INT64(),
std::string(InternalKeyword::ROW_OFFSET));
auto boundInQueryCall = std::make_unique<BoundInQueryCall>(tableFunction, std::move(bindData),
std::move(columns), offset);
if (call.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*call.getWherePredicate());
boundInQueryCall->setPredicate(std::move(wherePredicate));
}
return boundInQueryCall;
}

// Binds a LOAD FROM clause into a BoundLoadFrom.
//
// Supported sources:
//  - OBJECT: a name without '_' is resolved through clientContext->tryReplace(). A name
//    containing '_' is split on '_'; the first part is looked up as an attached database and
//    the second part as a table in that database's catalog, whose scan function is used.
//  - FILE: exactly one CSV, PARQUET or NPY file. Column names/types declared in the clause are
//    passed to the scan function as the expected columns.
// Any other source type is rejected.
//
// The columns reported by the scan function's bind data are registered as variables and the
// optional WHERE predicate is bound last.
//
// Throws BinderException for unresolved objects, unattached databases, malformed
// "<database>_<table>" names, multiple files, unsupported file types and unsupported sources.
std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(const ReadingClause& readingClause) {
    auto& loadFrom = ku_dynamic_cast<const ReadingClause&, const LoadFrom&>(readingClause);
    TableFunction scanFunction;
    std::unique_ptr<TableFuncBindData> bindData;
    auto source = loadFrom.getSource();
    switch (source->type) {
    case ScanSourceType::OBJECT: {
        auto objectSource = ku_dynamic_cast<BaseScanSource*, ObjectScanSource*>(source);
        auto objectName = objectSource->objectName;
        if (objectName.find('_') == std::string::npos) {
            // Bind table
            auto replacementData = clientContext->tryReplace(objectName);
            if (replacementData == nullptr) {
                throw BinderException(ExceptionMessage::variableNotInScope(objectName));
            }
            scanFunction = replacementData->func;
            bindData = scanFunction.bindFunc(clientContext, &replacementData->bindInput);
        } else {
            // Split once and check the part count before indexing: names such as "_" or
            // "db_" may not yield both a database part and a table part.
            const auto nameParts = common::StringUtils::split(objectName, "_");
            if (nameParts.size() < 1) {
                throw BinderException{common::stringFormat(
                    "Cannot resolve {} as a table of an attached database.", objectName)};
            }
            const auto& dbName = nameParts[0];
            auto attachedDB =
                clientContext->getDatabase()->getDatabaseManagerUnsafe()->getAttachedDatabase(
                    dbName);
            if (attachedDB == nullptr) {
                throw BinderException{
                    common::stringFormat("No database named {} has been attached.", dbName)};
            }
            if (nameParts.size() < 2) {
                throw BinderException{common::stringFormat(
                    "Cannot resolve {} as a table of an attached database.", objectName)};
            }
            const auto& tableName = nameParts[1];
            auto tableID = attachedDB->getCatalogContent()->getTableID(tableName);
            auto tableCatalogEntry = ku_dynamic_cast<CatalogEntry*, TableCatalogEntry*>(
                attachedDB->getCatalogContent()->getTableCatalogEntry(tableID));
            scanFunction = tableCatalogEntry->getScanFunction();
            auto bindInput = function::TableFuncBindInput();
            bindData = scanFunction.bindFunc(clientContext, &bindInput);
        }
    } break;
    case ScanSourceType::FILE: {
        auto fileSource = ku_dynamic_cast<BaseScanSource*, FileScanSource*>(source);
        auto filePaths = bindFilePaths(fileSource->filePaths);
        auto fileType = bindFileType(filePaths);
        auto readerConfig = std::make_unique<ReaderConfig>(fileType, std::move(filePaths));
        readerConfig->options = bindParsingOptions(loadFrom.getParsingOptionsRef());
        if (readerConfig->getNumFiles() > 1) {
            throw BinderException("Load from multiple files is not supported.");
        }
        switch (fileType) {
        case common::FileType::CSV:
        case common::FileType::PARQUET:
        case common::FileType::NPY:
            break;
        default:
            throw BinderException(
                stringFormat("Cannot load from file type {}.", FileTypeUtils::toString(fileType)));
        }
        // Bind columns from input.
        std::vector<std::string> expectedColumnNames;
        std::vector<LogicalType> expectedColumnTypes;
        for (auto& [name, type] : loadFrom.getColumnNameDataTypesRef()) {
            expectedColumnNames.push_back(name);
            expectedColumnTypes.push_back(*bindDataType(type));
        }
        scanFunction = getScanFunction(readerConfig->fileType, *readerConfig);
        auto bindInput = ScanTableFuncBindInput(readerConfig->copy(),
            std::move(expectedColumnNames), std::move(expectedColumnTypes), clientContext);
        bindData = scanFunction.bindFunc(clientContext, &bindInput);
    } break;
    default:
        throw BinderException(stringFormat("LOAD FROM subquery is not supported."));
    }
    // Expose each output column of the scan as a variable.
    expression_vector columns;
    for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
        columns.push_back(createVariable(bindData->columnNames[i], bindData->columnTypes[i]));
    }
    auto info = BoundFileScanInfo(scanFunction, std::move(bindData), std::move(columns));
    auto boundLoadFrom = std::make_unique<BoundLoadFrom>(std::move(info));
    if (loadFrom.hasWherePredicate()) {
        auto wherePredicate = expressionBinder.bindExpression(*loadFrom.getWherePredicate());
        boundLoadFrom->setPredicate(std::move(wherePredicate));
    }
    return boundLoadFrom;
}

} // namespace binder
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/binder/bind/bind_standalone_call.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "binder/binder.h"
#include "binder/bound_standalone_call.h"
#include "binder/expression/expression_util.h"
#include "common/exception/binder.h"
#include "extension/extension.h"
#include "main/db_config.h"
Expand All @@ -23,7 +24,7 @@ std::unique_ptr<BoundStatement> Binder::bindStandaloneCall(const parser::Stateme
}
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionBinder::validateExpectedDataType(*optionValue, option->parameterType);
ExpressionUtil::validateDataType(*optionValue, option->parameterType);
return std::make_unique<BoundStandaloneCall>(option, std::move(optionValue));
}

Expand Down
4 changes: 3 additions & 1 deletion src/binder/bind/copy/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
add_library(kuzu_binder_bind_copy
OBJECT
bind_copy_rdf_graph.cpp)
bind_copy_rdf_graph.cpp
bind_copy_to.cpp
bind_copy_from.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_bind_copy>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "binder/binder.h"
#include "binder/copy/bound_copy_from.h"
#include "binder/copy/bound_copy_to.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rdf_graph_catalog_entry.h"
Expand All @@ -21,33 +20,6 @@ using namespace kuzu::function;
namespace kuzu {
namespace binder {

// Binds a COPY TO statement into a BoundCopyTo.
//
// The target file type is derived from the file path and the wrapped regular query is bound.
// Only CSV and PARQUET targets are accepted, and parsing options are accepted for CSV only.
// The parsing options are turned into a CSVReaderConfig, and a copy of its option is passed
// to the bound statement.
//
// Throws BinderException for unsupported file types and for options on a non-CSV target.
std::unique_ptr<BoundStatement> Binder::bindCopyToClause(const Statement& statement) {
    auto& copyToStatement = ku_dynamic_cast<const Statement&, const CopyTo&>(statement);
    auto boundFilePath = copyToStatement.getFilePath();
    auto fileType = bindFileType(boundFilePath);
    auto parsedQuery =
        ku_dynamic_cast<const Statement*, const RegularQuery*>(copyToStatement.getStatement());
    // The query is bound before the file type checks, so an error raised while binding the
    // query takes precedence over the file type errors below.
    auto query = bindQuery(*parsedQuery);
    if (fileType != FileType::CSV && fileType != FileType::PARQUET) {
        throw BinderException("COPY TO currently only supports csv and parquet files.");
    }
    if (fileType != FileType::CSV && copyToStatement.getParsingOptionsRef().size() != 0) {
        throw BinderException{"Only copy to csv can have options."};
    }
    auto csvConfig =
        CSVReaderConfig::construct(bindParsingOptions(copyToStatement.getParsingOptionsRef()));
    return std::make_unique<BoundCopyTo>(boundFilePath, fileType, std::move(query),
        csvConfig.option.copy());
}

std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& statement) {
auto& copyStatement = ku_dynamic_cast<const Statement&, const CopyFrom&>(statement);
auto tableName = copyStatement.getTableName();
Expand Down
Loading
Loading