Skip to content

Commit

Permalink
Split bind_copy and bind_read into multiple files, remove CALL read_p…
Browse files Browse the repository at this point in the history
…andas
  • Loading branch information
andyfengHKU authored and ray6080 committed Apr 11, 2024
1 parent b9112f4 commit d8879be
Show file tree
Hide file tree
Showing 38 changed files with 535 additions and 467 deletions.
2 changes: 1 addition & 1 deletion src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
add_subdirectory(copy)
add_subdirectory(ddl)
add_subdirectory(read)

add_library(
kuzu_binder_bind
OBJECT
bind_attach_database.cpp
bind_comment_on.cpp
bind_copy.cpp
bind_create_macro.cpp
bind_ddl.cpp
bind_detach_database.cpp
Expand Down
2 changes: 1 addition & 1 deletion src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ std::unique_ptr<BoundStatement> Binder::bindRenameTable(const Statement& stateme
static void validatePropertyExist(TableCatalogEntry* tableEntry, const std::string& propertyName) {
if (!tableEntry->containProperty(propertyName)) {
throw BinderException(
tableEntry->getName() + " table doesn't have property " + propertyName + ".");
tableEntry->getName() + " table does not have property " + propertyName + ".");
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/binder/bind/bind_file_scan.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "binder/binder.h"
#include "binder/bound_scan_source.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/literal_expression.h"
#include "common/exception/binder.h"
#include "common/exception/copy.h"
Expand Down Expand Up @@ -102,7 +103,7 @@ std::unique_ptr<BoundBaseScanSource> Binder::bindScanSource(BaseScanSource* sour
columns.size(), expectedColumnNames.size()));
}
for (auto i = 0u; i < columns.size(); ++i) {
expressionBinder.validateDataType(*columns[i], expectedColumnTypes[i]);
ExpressionUtil::validateDataType(*columns[i], expectedColumnTypes[i]);
columns[i]->setAlias(expectedColumnNames[i]);
}
return std::make_unique<BoundQueryScanSource>(std::move(boundStatement));
Expand Down
2 changes: 1 addition & 1 deletion src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ std::shared_ptr<RelExpression> Binder::bindQueryRel(const RelPattern& relPattern
auto expectedDataType = QueryRelTypeUtils::isRecursive(relPattern.getRelType()) ?
LogicalTypeID::RECURSIVE_REL :
LogicalTypeID::REL;
ExpressionBinder::validateExpectedDataType(*prevVariable, expectedDataType);
ExpressionUtil::validateDataType(*prevVariable, expectedDataType);
throw BinderException("Bind relationship " + parsedName +
" to relationship with same name is not supported.");
}
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_query.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/query/return_with_clause/bound_return_clause.h"
#include "binder/query/return_with_clause/bound_with_clause.h"
#include "common/exception/binder.h"
Expand All @@ -23,8 +24,7 @@ void validateUnionColumnsOfTheSameType(
// Check whether the dataTypes in union expressions are exactly the same in each single
// query.
for (auto j = 0u; j < columns.size(); j++) {
ExpressionBinder::validateExpectedDataType(*otherColumns[j],
columns[j]->dataType.getLogicalTypeID());
ExpressionUtil::validateDataType(*otherColumns[j], columns[j]->getDataType());
}
}
}
Expand Down
243 changes: 0 additions & 243 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,31 +1,7 @@
#include "binder/binder.h"
#include "binder/expression/literal_expression.h"
#include "binder/query/reading_clause/bound_in_query_call.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/string_format.h"
#include "common/string_utils.h"
#include "function/built_in_function_utils.h"
#include "function/table/bind_input.h"
#include "main/attached_database.h"
#include "main/database.h"
#include "main/database_manager.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/expression/parsed_variable_expression.h"
#include "parser/query/reading_clause/in_query_call_clause.h"
#include "parser/query/reading_clause/load_from.h"
#include "parser/query/reading_clause/match_clause.h"
#include "parser/query/reading_clause/unwind_clause.h"

using namespace kuzu::common;
using namespace kuzu::parser;
using namespace kuzu::catalog;
using namespace kuzu::function;

namespace kuzu {
namespace binder {
Expand All @@ -49,224 +25,5 @@ std::unique_ptr<BoundReadingClause> Binder::bindReadingClause(const ReadingClaus
}
}

std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause& readingClause) {
auto& matchClause = ku_dynamic_cast<const ReadingClause&, const MatchClause&>(readingClause);
auto boundGraphPattern = bindGraphPattern(matchClause.getPatternElementsRef());
if (matchClause.hasWherePredicate()) {
boundGraphPattern.where = bindWhereExpression(*matchClause.getWherePredicate());
}
rewriteMatchPattern(boundGraphPattern);
auto boundMatch = std::make_unique<BoundMatchClause>(
std::move(boundGraphPattern.queryGraphCollection), matchClause.getMatchClauseType());
boundMatch->setPredicate(boundGraphPattern.where);
return boundMatch;
}

void Binder::rewriteMatchPattern(BoundGraphPattern& boundGraphPattern) {
// Rewrite self loop edge
// e.g. rewrite (a)-[e]->(a) as [a]-[e]->(b) WHERE id(a) = id(b)
expression_vector selfLoopEdgePredicates;
auto& graphCollection = boundGraphPattern.queryGraphCollection;
for (auto i = 0u; i < graphCollection.getNumQueryGraphs(); ++i) {
auto queryGraph = graphCollection.getQueryGraphUnsafe(i);
for (auto& queryRel : queryGraph->getQueryRels()) {
if (!queryRel->isSelfLoop()) {
continue;
}
auto src = queryRel->getSrcNode();
auto dst = queryRel->getDstNode();
auto newDst = createQueryNode(dst->getVariableName(), dst->getTableIDs());
queryGraph->addQueryNode(newDst);
queryRel->setDstNode(newDst);
auto predicate = expressionBinder.createEqualityComparisonExpression(
src->getInternalID(), newDst->getInternalID());
selfLoopEdgePredicates.push_back(std::move(predicate));
}
}
auto where = boundGraphPattern.where;
for (auto& predicate : selfLoopEdgePredicates) {
where = expressionBinder.combineBooleanExpressions(ExpressionType::AND, predicate, where);
}
// Rewrite key value pairs in MATCH clause as predicate
for (auto i = 0u; i < graphCollection.getNumQueryGraphs(); ++i) {
auto queryGraph = graphCollection.getQueryGraphUnsafe(i);
for (auto& pattern : queryGraph->getAllPatterns()) {
for (auto& [propertyName, rhs] : pattern->getPropertyDataExprRef()) {
auto propertyExpr =
expressionBinder.bindNodeOrRelPropertyExpression(*pattern, propertyName);
auto predicate =
expressionBinder.createEqualityComparisonExpression(propertyExpr, rhs);
where = expressionBinder.combineBooleanExpressions(ExpressionType::AND, predicate,
where);
}
}
}
boundGraphPattern.where = std::move(where);
}

std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause& readingClause) {
auto& unwindClause = ku_dynamic_cast<const ReadingClause&, const UnwindClause&>(readingClause);
auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression());
ExpressionBinder::validateDataType(*boundExpression, LogicalTypeID::LIST);
auto aliasName = unwindClause.getAlias();
auto alias = createVariable(aliasName, *ListType::getChildType(&boundExpression->dataType));
std::shared_ptr<Expression> idExpr = nullptr;
if (scope.hasMemorizedTableIDs(boundExpression->getAlias())) {
auto tableIDs = scope.getMemorizedTableIDs(boundExpression->getAlias());
auto node = createQueryNode(aliasName, tableIDs);
idExpr = node->getInternalID();
scope.addNodeReplacement(node);
}
return make_unique<BoundUnwindClause>(std::move(boundExpression), std::move(alias),
std::move(idExpr));
}

std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause& readingClause) {
auto& call = ku_dynamic_cast<const ReadingClause&, const InQueryCallClause&>(readingClause);
auto expr = call.getFunctionExpression();
auto functionExpr =
ku_dynamic_cast<const ParsedExpression*, const ParsedFunctionExpression*>(expr);
// Bind parameters
std::unique_ptr<ScanReplacementData> replacementData;
expression_vector params;
for (auto i = 0u; i < functionExpr->getNumChildren(); i++) {
auto child = functionExpr->getChild(i);
try {
params.push_back(expressionBinder.bindExpression(*child));
} catch (BinderException& exception) {
if (child->getExpressionType() != ExpressionType::VARIABLE) {
throw BinderException(exception.what()); // Cannot replace. Rethrow.
}
// Try replacement.
auto varExpr = ku_dynamic_cast<ParsedExpression*, ParsedVariableExpression*>(child);
auto var = varExpr->getVariableName();
replacementData = clientContext->tryReplace(var);
if (replacementData == nullptr) { // Replacement fail.
throw BinderException(ExceptionMessage::variableNotInScope(var));
}
}
}
TableFunction tableFunction;
std::unique_ptr<TableFuncBindData> bindData;
if (replacementData) {
tableFunction = replacementData->func;
bindData = tableFunction.bindFunc(clientContext, &replacementData->bindInput);
} else {
std::vector<Value> inputValues;
std::vector<LogicalType> inputTypes;
for (auto& param : params) {
if (param->expressionType != ExpressionType::LITERAL) {
throw BinderException{
stringFormat("Cannot evaluate {} as a literal.", param->toString())};
}
auto literalExpr =
ku_dynamic_cast<const Expression*, const LiteralExpression*>(param.get());
inputTypes.push_back(literalExpr->getDataType());
inputValues.push_back(*literalExpr->getValue());
}
auto functions = clientContext->getCatalog()->getFunctions(clientContext->getTx());
auto func = BuiltInFunctionsUtils::matchFunction(functionExpr->getFunctionName(),
inputTypes, functions);
tableFunction = *ku_dynamic_cast<function::Function*, function::TableFunction*>(func);
auto bindInput = function::TableFuncBindInput();
bindInput.inputs = std::move(inputValues);
bindData = tableFunction.bindFunc(clientContext, &bindInput);
}
expression_vector columns;
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
columns.push_back(createVariable(bindData->columnNames[i], bindData->columnTypes[i]));
}
auto offset = expressionBinder.createVariableExpression(*LogicalType::INT64(),
std::string(InternalKeyword::ROW_OFFSET));
auto boundInQueryCall = std::make_unique<BoundInQueryCall>(tableFunction, std::move(bindData),
std::move(columns), offset);
if (call.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*call.getWherePredicate());
boundInQueryCall->setPredicate(std::move(wherePredicate));
}
return boundInQueryCall;
}

std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(const ReadingClause& readingClause) {
auto& loadFrom = ku_dynamic_cast<const ReadingClause&, const LoadFrom&>(readingClause);
TableFunction scanFunction;
std::unique_ptr<TableFuncBindData> bindData;
auto source = loadFrom.getSource();
switch (source->type) {
case ScanSourceType::OBJECT: {
auto objectSource = ku_dynamic_cast<BaseScanSource*, ObjectScanSource*>(source);
auto objectName = objectSource->objectName;
if (objectName.find("_") == std::string::npos) {
// Bind table
auto replacementData = clientContext->tryReplace(objectName);
if (replacementData == nullptr) {
throw BinderException(ExceptionMessage::variableNotInScope(objectName));
}
scanFunction = replacementData->func;
bindData = scanFunction.bindFunc(clientContext, &replacementData->bindInput);
} else {
auto dbName = common::StringUtils::split(objectName, "_")[0];
auto attachedDB =
clientContext->getDatabase()->getDatabaseManagerUnsafe()->getAttachedDatabase(
dbName);
if (attachedDB == nullptr) {
throw BinderException{
common::stringFormat("No database named {} has been attached.", dbName)};
}
auto tableName = common::StringUtils::split(objectName, "_")[1];
auto tableID = attachedDB->getCatalogContent()->getTableID(tableName);
auto tableCatalogEntry = ku_dynamic_cast<CatalogEntry*, TableCatalogEntry*>(
attachedDB->getCatalogContent()->getTableCatalogEntry(tableID));
scanFunction = tableCatalogEntry->getScanFunction();
auto bindInput = function::TableFuncBindInput();
bindData = scanFunction.bindFunc(clientContext, &bindInput);
}
} break;
case ScanSourceType::FILE: {
auto fileSource = ku_dynamic_cast<BaseScanSource*, FileScanSource*>(source);
auto filePaths = bindFilePaths(fileSource->filePaths);
auto fileType = bindFileType(filePaths);
auto readerConfig = std::make_unique<ReaderConfig>(fileType, std::move(filePaths));
readerConfig->options = bindParsingOptions(loadFrom.getParsingOptionsRef());
if (readerConfig->getNumFiles() > 1) {
throw BinderException("Load from multiple files is not supported.");
}
switch (fileType) {
case common::FileType::CSV:
case common::FileType::PARQUET:
case common::FileType::NPY:
break;
default:
throw BinderException(
stringFormat("Cannot load from file type {}.", FileTypeUtils::toString(fileType)));
}
// Bind columns from input.
std::vector<std::string> expectedColumnNames;
std::vector<LogicalType> expectedColumnTypes;
for (auto& [name, type] : loadFrom.getColumnNameDataTypesRef()) {
expectedColumnNames.push_back(name);
expectedColumnTypes.push_back(*bindDataType(type));
}
scanFunction = getScanFunction(readerConfig->fileType, *readerConfig);
auto bindInput = ScanTableFuncBindInput(readerConfig->copy(),
std::move(expectedColumnNames), std::move(expectedColumnTypes), clientContext);
bindData = scanFunction.bindFunc(clientContext, &bindInput);
} break;
default:
throw BinderException(stringFormat("LOAD FROM subquery is not supported."));
}
expression_vector columns;
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
columns.push_back(createVariable(bindData->columnNames[i], bindData->columnTypes[i]));
}
auto info = BoundFileScanInfo(scanFunction, std::move(bindData), std::move(columns));
auto boundLoadFrom = std::make_unique<BoundLoadFrom>(std::move(info));
if (loadFrom.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*loadFrom.getWherePredicate());
boundLoadFrom->setPredicate(std::move(wherePredicate));
}
return boundLoadFrom;
}

} // namespace binder
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/binder/bind/bind_standalone_call.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "binder/binder.h"
#include "binder/bound_standalone_call.h"
#include "binder/expression/expression_util.h"
#include "common/exception/binder.h"
#include "extension/extension.h"
#include "main/db_config.h"
Expand All @@ -23,7 +24,7 @@ std::unique_ptr<BoundStatement> Binder::bindStandaloneCall(const parser::Stateme
}
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionBinder::validateExpectedDataType(*optionValue, option->parameterType);
ExpressionUtil::validateDataType(*optionValue, option->parameterType);
return std::make_unique<BoundStandaloneCall>(option, std::move(optionValue));
}

Expand Down
4 changes: 3 additions & 1 deletion src/binder/bind/copy/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
add_library(kuzu_binder_bind_copy
OBJECT
bind_copy_rdf_graph.cpp)
bind_copy_rdf_graph.cpp
bind_copy_to.cpp
bind_copy_from.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_bind_copy>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "binder/binder.h"
#include "binder/copy/bound_copy_from.h"
#include "binder/copy/bound_copy_to.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rdf_graph_catalog_entry.h"
Expand All @@ -21,33 +20,6 @@ using namespace kuzu::function;
namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCopyToClause(const Statement& statement) {
auto& copyToStatement = ku_dynamic_cast<const Statement&, const CopyTo&>(statement);
auto boundFilePath = copyToStatement.getFilePath();
auto fileType = bindFileType(boundFilePath);
std::vector<std::string> columnNames;
std::vector<LogicalType> columnTypes;
auto parsedQuery =
ku_dynamic_cast<const Statement*, const RegularQuery*>(copyToStatement.getStatement());
auto query = bindQuery(*parsedQuery);
auto columns = query->getStatementResult()->getColumns();
for (auto& column : columns) {
auto columnName = column->hasAlias() ? column->getAlias() : column->toString();
columnNames.push_back(columnName);
columnTypes.push_back(column->getDataType());
}
if (fileType != FileType::CSV && fileType != FileType::PARQUET) {
throw BinderException("COPY TO currently only supports csv and parquet files.");
}
if (fileType != FileType::CSV && copyToStatement.getParsingOptionsRef().size() != 0) {
throw BinderException{"Only copy to csv can have options."};
}
auto csvConfig =
CSVReaderConfig::construct(bindParsingOptions(copyToStatement.getParsingOptionsRef()));
return std::make_unique<BoundCopyTo>(boundFilePath, fileType, std::move(query),
csvConfig.option.copy());
}

std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& statement) {
auto& copyStatement = ku_dynamic_cast<const Statement&, const CopyFrom&>(statement);
auto tableName = copyStatement.getTableName();
Expand Down
Loading

0 comments on commit d8879be

Please sign in to comment.