Skip to content

Commit

Permalink
Make table functions as part of the reading clause
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Jun 29, 2023
1 parent 184d8df commit 5846c5e
Show file tree
Hide file tree
Showing 24 changed files with 2,892 additions and 2,722 deletions.
11 changes: 7 additions & 4 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ kU_CopyCSV
kU_CopyNPY
: COPY SP oC_SchemaName SP FROM SP '(' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ')' SP BY SP COLUMN ;

kU_Call
: CALL SP ( ( oC_SymbolicName SP? '=' SP? oC_Literal )
| ( oC_FunctionName SP? '(' oC_Literal? ')' ) );
kU_CallConfig
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

Expand Down Expand Up @@ -141,7 +140,7 @@ oC_Statement
| kU_DDL
| kU_CopyNPY
| kU_CopyCSV
| kU_Call ;
| kU_CallConfig ;

oC_Query
: oC_RegularQuery ;
Expand Down Expand Up @@ -185,8 +184,12 @@ oC_UpdatingClause
oC_ReadingClause
: oC_Match
| oC_Unwind
| oC_CallTableFunc
;

oC_CallTableFunc
: CALL SP oC_FunctionName SP? '(' oC_Literal? ')' ;

oC_Match
: ( OPTIONAL SP )? MATCH SP? oC_Pattern (SP? oC_Where)? ;

Expand Down
2 changes: 1 addition & 1 deletion src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_library(
kuzu_binder_bind
OBJECT
bind_call.cpp
bind_call_config.cpp
bind_copy.cpp
bind_ddl.cpp
bind_graph_pattern.cpp
Expand Down
44 changes: 0 additions & 44 deletions src/binder/bind/bind_call.cpp

This file was deleted.

22 changes: 22 additions & 0 deletions src/binder/bind/bind_call_config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "binder/binder.h"
#include "binder/call/bound_call_config.h"
#include "parser/call/call_config.h"

namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCallConfig(const parser::Statement& statement) {
auto& callStatement = reinterpret_cast<const parser::CallConfig&>(statement);
auto option = main::DBConfig::getOptionByName(callStatement.getOptionName());
if (option == nullptr) {
throw common::BinderException{
"Invalid option name: " + callStatement.getOptionName() + "."};
}
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionBinder::validateExpectedDataType(*optionValue, option->parameterType);
return std::make_unique<BoundCallConfig>(*option, std::move(optionValue));
}

} // namespace binder
} // namespace kuzu
28 changes: 26 additions & 2 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "binder/binder.h"
#include "binder/call/bound_call_table_func.h"
#include "binder/expression/literal_expression.h"
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "parser/call/call_table_func.h"
#include "parser/query/reading_clause/unwind_clause.h"

using namespace kuzu::common;
Expand All @@ -16,13 +19,16 @@ std::unique_ptr<BoundReadingClause> Binder::bindReadingClause(const ReadingClaus
case ClauseType::UNWIND: {
return bindUnwindClause((UnwindClause&)readingClause);
}
case ClauseType::CALL_TABLE_FUNC: {
return bindCallTableFunc((CallTableFunc&)readingClause);
}
default:
throw NotImplementedException("bindReadingClause().");
}
}

std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause& readingClause) {
auto& matchClause = (MatchClause&)readingClause;
auto& matchClause = reinterpret_cast<const MatchClause&>(readingClause);
auto [queryGraphCollection, propertyCollection] =
bindGraphPattern(matchClause.getPatternElements());
auto boundMatchClause =
Expand All @@ -42,7 +48,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
}

std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause& readingClause) {
auto& unwindClause = (UnwindClause&)readingClause;
auto& unwindClause = reinterpret_cast<const UnwindClause&>(readingClause);
auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression());
boundExpression =
ExpressionBinder::implicitCastIfNecessary(boundExpression, LogicalTypeID::VAR_LIST);
Expand All @@ -51,5 +57,23 @@ std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause
return make_unique<BoundUnwindClause>(std::move(boundExpression), std::move(aliasExpression));
}

std::unique_ptr<BoundReadingClause> Binder::bindCallTableFunc(const ReadingClause& readingClause) {
auto& callStatement = reinterpret_cast<const parser::CallTableFunc&>(readingClause);
auto tableFunctionDefinition =
catalog.getBuiltInTableOperation()->mathTableOperation(callStatement.getFuncName());
auto boundExpr = expressionBinder.bindLiteralExpression(*callStatement.getParameter());
auto inputValue = reinterpret_cast<LiteralExpression*>(boundExpr.get())->getValue();
auto bindData = tableFunctionDefinition->bindFunc(
function::TableFuncBindInput{std::vector<common::Value>{*inputValue}},
catalog.getReadOnlyVersion());
expression_vector outputExpressions;
for (auto i = 0u; i < bindData->returnColumnNames.size(); i++) {
outputExpressions.push_back(
createVariable(bindData->returnColumnNames[i], bindData->returnTypes[i]));
}
return std::make_unique<BoundCallTableFunc>(
tableFunctionDefinition->tableFunc, std::move(bindData), std::move(outputExpressions));
}

} // namespace binder
} // namespace kuzu
3 changes: 0 additions & 3 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
case StatementType::CALL_CONFIG: {
return bindCallConfig(statement);
}
case StatementType::CALL_TABLE_FUNC: {
return bindCallTableFunc(statement);
}
default:
assert(false);
}
Expand Down
6 changes: 3 additions & 3 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement)
case StatementType::CALL_CONFIG: {
visitCallConfig(statement);
} break;
case StatementType::CALL_TABLE_FUNC: {
visitCallTableFunc(statement);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visit");
}
Expand Down Expand Up @@ -80,6 +77,9 @@ void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& reading
case common::ClauseType::UNWIND: {
visitUnwind(readingClause);
} break;
case common::ClauseType::CALL_TABLE_FUNC: {
visitCallTableFunc(readingClause);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visitReadingClause");
}
Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ class Binder {
std::unique_ptr<BoundQueryPart> bindQueryPart(const parser::QueryPart& queryPart);

/*** bind call ***/
std::unique_ptr<BoundStatement> bindCallTableFunc(const parser::Statement& statement);
std::unique_ptr<BoundStatement> bindCallConfig(const parser::Statement& statement);

/*** bind reading clause ***/
Expand All @@ -118,6 +117,7 @@ class Binder {
std::unique_ptr<BoundReadingClause> bindMatchClause(const parser::ReadingClause& readingClause);
std::unique_ptr<BoundReadingClause> bindUnwindClause(
const parser::ReadingClause& readingClause);
std::unique_ptr<BoundReadingClause> bindCallTableFunc(const parser::ReadingClause& statement);

/*** bind updating clause ***/
std::unique_ptr<BoundUpdatingClause> bindUpdatingClause(
Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ class BoundStatementVisitor {
virtual void visitRenameProperty(const BoundStatement& statement) {}
virtual void visitCopy(const BoundStatement& statement) {}
virtual void visitCallConfig(const BoundStatement& statement) {}
virtual void visitCallTableFunc(const BoundStatement& statement) {}

void visitReadingClause(const BoundReadingClause& readingClause);
virtual void visitMatch(const BoundReadingClause& readingClause) {}
virtual void visitUnwind(const BoundReadingClause& readingClause) {}
virtual void visitCallTableFunc(const BoundReadingClause& statement) {}
void visitUpdatingClause(const BoundUpdatingClause& updatingClause);
virtual void visitSet(const BoundUpdatingClause& updatingClause) {}
virtual void visitDelete(const BoundUpdatingClause& updatingClause) {}
Expand Down
17 changes: 12 additions & 5 deletions src/include/binder/call/bound_call_table_func.h
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
#pragma once

#include "binder/expression/expression.h"
#include "binder/query/reading_clause/bound_reading_clause.h"
#include "function/table_operations.h"

namespace kuzu {
namespace binder {

class BoundCallTableFunc : public BoundStatement {
class BoundCallTableFunc : public BoundReadingClause {
public:
BoundCallTableFunc(function::table_func_t tableFunc,
std::unique_ptr<function::TableFuncBindData> bindData,
std::unique_ptr<BoundStatementResult> statementResult)
: BoundStatement{common::StatementType::CALL_TABLE_FUNC, std::move(statementResult)},
tableFunc{std::move(tableFunc)}, bindData{std::move(bindData)} {}
std::unique_ptr<function::TableFuncBindData> bindData, expression_vector outputExpressions)
: BoundReadingClause{common::ClauseType::CALL_TABLE_FUNC}, tableFunc{std::move(tableFunc)},
bindData{std::move(bindData)}, outputExpressions{std::move(outputExpressions)} {}

inline function::table_func_t getTableFunc() const { return tableFunc; }

inline function::TableFuncBindData* getBindData() const { return bindData.get(); }

inline expression_vector getOutputExpressions() const { return outputExpressions; }

inline std::unique_ptr<BoundReadingClause> copy() override {
return std::make_unique<BoundCallTableFunc>(tableFunc, bindData->copy(), outputExpressions);
}

private:
function::table_func_t tableFunc;
std::unique_ptr<function::TableFuncBindData> bindData;
expression_vector outputExpressions;
};

} // namespace binder
Expand Down
3 changes: 2 additions & 1 deletion src/include/common/clause_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ enum class ClauseType : uint8_t {
CREATE = 2,
// reading clause
MATCH = 3,
UNWIND = 4
UNWIND = 4,
CALL_TABLE_FUNC = 5,
};

} // namespace common
Expand Down
1 change: 0 additions & 1 deletion src/include/common/statement_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ enum class StatementType : uint8_t {
RENAME_PROPERTY = 8,
COPY = 20,
CALL_CONFIG = 21,
CALL_TABLE_FUNC = 22,
};

class StatementTypeUtils {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
namespace kuzu {
namespace parser {

class Call : public Statement {
class CallConfig : public Statement {
public:
explicit Call(common::StatementType statementType, std::string optionName,
std::unique_ptr<ParsedExpression> optionValue)
: Statement{statementType}, optionName{std::move(optionName)}, optionValue{std::move(
optionValue)} {}
explicit CallConfig(std::string optionName, std::unique_ptr<ParsedExpression> optionValue)
: Statement{common::StatementType::CALL_CONFIG}, optionName{std::move(optionName)},
optionValue{std::move(optionValue)} {}

inline std::string getOptionName() const { return optionName; }

Expand Down
25 changes: 25 additions & 0 deletions src/include/parser/call/call_table_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "parser/expression/parsed_expression.h"
#include "parser/query/reading_clause/reading_clause.h"

namespace kuzu {
namespace parser {

class CallTableFunc : public ReadingClause {
public:
CallTableFunc(std::string optionName, std::unique_ptr<ParsedExpression> optionValue)
: ReadingClause{common::ClauseType::CALL_TABLE_FUNC}, funcName{std::move(optionName)},
parameter{std::move(optionValue)} {}

inline std::string getFuncName() const { return funcName; }

inline ParsedExpression* getParameter() const { return parameter.get(); }

private:
std::string funcName;
std::unique_ptr<ParsedExpression> parameter;
};

} // namespace parser
} // namespace kuzu
5 changes: 4 additions & 1 deletion src/include/parser/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class Transformer {

std::unique_ptr<ReadingClause> transformUnwind(CypherParser::OC_UnwindContext& ctx);

std::unique_ptr<ReadingClause> transformCallTableFunc(
CypherParser::OC_CallTableFuncContext& ctx);

std::unique_ptr<UpdatingClause> transformCreate(CypherParser::OC_CreateContext& ctx);

std::unique_ptr<UpdatingClause> transformSet(CypherParser::OC_SetContext& ctx);
Expand Down Expand Up @@ -248,7 +251,7 @@ class Transformer {

std::unique_ptr<Statement> transformCopyNPY(CypherParser::KU_CopyNPYContext& ctx);

std::unique_ptr<Statement> transformCall(CypherParser::KU_CallContext& ctx);
std::unique_ptr<Statement> transformCallConfig(CypherParser::KU_CallConfigContext& ctx);

std::vector<std::string> transformFilePaths(
std::vector<antlr4::tree::TerminalNode*> stringLiteral);
Expand Down
2 changes: 1 addition & 1 deletion src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class JoinOrderEnumerator {
inline void planCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan) {
appendCrossProduct(probePlan, buildPlan);
}
void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan);

private:
std::vector<std::unique_ptr<LogicalPlan>> planCrossProduct(
Expand Down Expand Up @@ -109,7 +110,6 @@ class JoinOrderEnumerator {
void appendIntersect(const std::shared_ptr<Expression>& intersectNodeID,
binder::expression_vector& boundNodeIDs, LogicalPlan& probePlan,
std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan);

static binder::expression_vector getNewlyMatchedExpressions(const SubqueryGraph& prevSubgraph,
const SubqueryGraph& newSubgraph, const binder::expression_vector& expressions) {
Expand Down
2 changes: 0 additions & 2 deletions src/include/planner/planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ class Planner {
const catalog::Catalog& catalog, const BoundStatement& statement);

static std::unique_ptr<LogicalPlan> planCallConfig(const BoundStatement& statement);

static std::unique_ptr<LogicalPlan> planCallTableFunc(const BoundStatement& statement);
};

} // namespace planner
Expand Down
5 changes: 5 additions & 0 deletions src/include/planner/query_planner.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "binder/bound_statement.h"
#include "binder/call/bound_call_table_func.h"
#include "binder/expression/existential_subquery_expression.h"
#include "join_order_enumerator.h"
#include "planner/join_order/cardinality_estimator.h"
Expand Down Expand Up @@ -44,6 +45,8 @@ class QueryPlanner {
BoundReadingClause* boundReadingClause, std::vector<std::unique_ptr<LogicalPlan>>& plans);
void planUnwindClause(
BoundReadingClause* boundReadingClause, std::vector<std::unique_ptr<LogicalPlan>>& plans);
void planCallTableFunc(
BoundReadingClause* boundReadingClause, std::vector<std::unique_ptr<LogicalPlan>>& plans);

// CTE & subquery planning
void planOptionalMatch(const QueryGraphCollection& queryGraphCollection,
Expand All @@ -61,6 +64,8 @@ class QueryPlanner {

void appendUnwind(BoundUnwindClause& boundUnwindClause, LogicalPlan& plan);

void appendCallTableFunc(BoundCallTableFunc& boundCallTableFunc, LogicalPlan& plan);

void appendFlattens(const f_group_pos_set& groupsPos, LogicalPlan& plan);
void appendFlattenIfNecessary(f_group_pos groupPos, LogicalPlan& plan);

Expand Down
Loading

0 comments on commit 5846c5e

Please sign in to comment.