Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make table functions as part of the reading clause #1737

Merged
merged 1 commit into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ kU_CopyCSV
kU_CopyNPY
: COPY SP oC_SchemaName SP FROM SP '(' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ')' SP BY SP COLUMN ;

kU_Call
: CALL SP ( ( oC_SymbolicName SP? '=' SP? oC_Literal )
| ( oC_FunctionName SP? '(' oC_Literal? ')' ) );
kU_StandaloneCall
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

Expand Down Expand Up @@ -141,7 +140,7 @@ oC_Statement
| kU_DDL
| kU_CopyNPY
| kU_CopyCSV
| kU_Call ;
| kU_StandaloneCall ;

oC_Query
: oC_RegularQuery ;
Expand Down Expand Up @@ -185,8 +184,12 @@ oC_UpdatingClause
oC_ReadingClause
: oC_Match
| oC_Unwind
| kU_InQueryCall
;

kU_InQueryCall
: CALL SP oC_FunctionName SP? '(' oC_Literal? ')' ;

oC_Match
: ( OPTIONAL SP )? MATCH SP? oC_Pattern (SP? oC_Where)? ;

Expand Down
2 changes: 1 addition & 1 deletion src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
add_library(
kuzu_binder_bind
OBJECT
bind_call.cpp
bind_standalone_call.cpp
bind_copy.cpp
bind_ddl.cpp
bind_graph_pattern.cpp
Expand Down
44 changes: 0 additions & 44 deletions src/binder/bind/bind_call.cpp

This file was deleted.

28 changes: 26 additions & 2 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "binder/binder.h"
#include "binder/call/bound_in_query_call.h"
#include "binder/expression/literal_expression.h"
#include "binder/query/reading_clause/bound_unwind_clause.h"
#include "parser/call/in_query_call.h"
#include "parser/query/reading_clause/unwind_clause.h"

using namespace kuzu::common;
Expand All @@ -16,13 +19,16 @@ std::unique_ptr<BoundReadingClause> Binder::bindReadingClause(const ReadingClaus
case ClauseType::UNWIND: {
return bindUnwindClause((UnwindClause&)readingClause);
}
case ClauseType::InQueryCall: {
return bindInQueryCall((InQueryCall&)readingClause);
}
default:
throw NotImplementedException("bindReadingClause().");
}
}

std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause& readingClause) {
auto& matchClause = (MatchClause&)readingClause;
auto& matchClause = reinterpret_cast<const MatchClause&>(readingClause);
auto [queryGraphCollection, propertyCollection] =
bindGraphPattern(matchClause.getPatternElements());
auto boundMatchClause =
Expand All @@ -42,7 +48,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindMatchClause(const ReadingClause&
}

std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause& readingClause) {
auto& unwindClause = (UnwindClause&)readingClause;
auto& unwindClause = reinterpret_cast<const UnwindClause&>(readingClause);
auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression());
boundExpression =
ExpressionBinder::implicitCastIfNecessary(boundExpression, LogicalTypeID::VAR_LIST);
Expand All @@ -51,5 +57,23 @@ std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause
return make_unique<BoundUnwindClause>(std::move(boundExpression), std::move(aliasExpression));
}

std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause& readingClause) {
auto& callStatement = reinterpret_cast<const parser::InQueryCall&>(readingClause);
auto tableFunctionDefinition =
catalog.getBuiltInTableOperation()->mathTableOperation(callStatement.getFuncName());
auto boundExpr = expressionBinder.bindLiteralExpression(*callStatement.getParameter());
auto inputValue = reinterpret_cast<LiteralExpression*>(boundExpr.get())->getValue();
auto bindData = tableFunctionDefinition->bindFunc(
function::TableFuncBindInput{std::vector<common::Value>{*inputValue}},
catalog.getReadOnlyVersion());
expression_vector outputExpressions;
for (auto i = 0u; i < bindData->returnColumnNames.size(); i++) {
outputExpressions.push_back(
createVariable(bindData->returnColumnNames[i], bindData->returnTypes[i]));
}
return std::make_unique<BoundInQueryCall>(
tableFunctionDefinition->tableFunc, std::move(bindData), std::move(outputExpressions));
}

} // namespace binder
} // namespace kuzu
22 changes: 22 additions & 0 deletions src/binder/bind/bind_standalone_call.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "binder/binder.h"
#include "binder/call/bound_standalone_call.h"
#include "parser/call/standalone_call.h"

namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindStandaloneCall(const parser::Statement& statement) {
auto& callStatement = reinterpret_cast<const parser::StandaloneCall&>(statement);
auto option = main::DBConfig::getOptionByName(callStatement.getOptionName());
if (option == nullptr) {
throw common::BinderException{
"Invalid option name: " + callStatement.getOptionName() + "."};
}
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionBinder::validateExpectedDataType(*optionValue, option->parameterType);
return std::make_unique<BoundStandaloneCall>(*option, std::move(optionValue));
}

} // namespace binder
} // namespace kuzu
7 changes: 2 additions & 5 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,8 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
case StatementType::QUERY: {
return bindQuery((const RegularQuery&)statement);
}
case StatementType::CALL_CONFIG: {
return bindCallConfig(statement);
}
case StatementType::CALL_TABLE_FUNC: {
return bindCallTableFunc(statement);
case StatementType::StandaloneCall: {
return bindStandaloneCall(statement);
}
default:
assert(false);
Expand Down
10 changes: 5 additions & 5 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,8 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement)
case StatementType::COPY: {
visitCopy(statement);
} break;
case StatementType::CALL_CONFIG: {
visitCallConfig(statement);
} break;
case StatementType::CALL_TABLE_FUNC: {
visitCallTableFunc(statement);
case StatementType::StandaloneCall: {
visitStandaloneCall(statement);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visit");
Expand Down Expand Up @@ -80,6 +77,9 @@ void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& reading
case common::ClauseType::UNWIND: {
visitUnwind(readingClause);
} break;
case common::ClauseType::InQueryCall: {
visitInQueryCall(readingClause);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visitReadingClause");
}
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ class Binder {
std::unique_ptr<BoundQueryPart> bindQueryPart(const parser::QueryPart& queryPart);

/*** bind call ***/
std::unique_ptr<BoundStatement> bindCallTableFunc(const parser::Statement& statement);
std::unique_ptr<BoundStatement> bindCallConfig(const parser::Statement& statement);
std::unique_ptr<BoundStatement> bindStandaloneCall(const parser::Statement& statement);

/*** bind reading clause ***/
std::unique_ptr<BoundReadingClause> bindReadingClause(
const parser::ReadingClause& readingClause);
std::unique_ptr<BoundReadingClause> bindMatchClause(const parser::ReadingClause& readingClause);
std::unique_ptr<BoundReadingClause> bindUnwindClause(
const parser::ReadingClause& readingClause);
std::unique_ptr<BoundReadingClause> bindInQueryCall(const parser::ReadingClause& readingClause);

/*** bind updating clause ***/
std::unique_ptr<BoundUpdatingClause> bindUpdatingClause(
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ class BoundStatementVisitor {
virtual void visitDropProperty(const BoundStatement& statement) {}
virtual void visitRenameProperty(const BoundStatement& statement) {}
virtual void visitCopy(const BoundStatement& statement) {}
virtual void visitCallConfig(const BoundStatement& statement) {}
virtual void visitCallTableFunc(const BoundStatement& statement) {}
virtual void visitStandaloneCall(const BoundStatement& statement) {}

void visitReadingClause(const BoundReadingClause& readingClause);
virtual void visitMatch(const BoundReadingClause& readingClause) {}
virtual void visitUnwind(const BoundReadingClause& readingClause) {}
virtual void visitInQueryCall(const BoundReadingClause& statement) {}
void visitUpdatingClause(const BoundUpdatingClause& updatingClause);
virtual void visitSet(const BoundUpdatingClause& updatingClause) {}
virtual void visitDelete(const BoundUpdatingClause& updatingClause) {}
Expand Down
27 changes: 0 additions & 27 deletions src/include/binder/call/bound_call_table_func.h

This file was deleted.

34 changes: 34 additions & 0 deletions src/include/binder/call/bound_in_query_call.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

#include "binder/expression/expression.h"
#include "binder/query/reading_clause/bound_reading_clause.h"
#include "function/table_operations.h"

namespace kuzu {
namespace binder {

class BoundInQueryCall : public BoundReadingClause {
public:
BoundInQueryCall(function::table_func_t tableFunc,
std::unique_ptr<function::TableFuncBindData> bindData, expression_vector outputExpressions)
: BoundReadingClause{common::ClauseType::InQueryCall}, tableFunc{std::move(tableFunc)},
bindData{std::move(bindData)}, outputExpressions{std::move(outputExpressions)} {}

inline function::table_func_t getTableFunc() const { return tableFunc; }

inline function::TableFuncBindData* getBindData() const { return bindData.get(); }

inline expression_vector getOutputExpressions() const { return outputExpressions; }

inline std::unique_ptr<BoundReadingClause> copy() override {
return std::make_unique<BoundInQueryCall>(tableFunc, bindData->copy(), outputExpressions);
}

private:
function::table_func_t tableFunc;
std::unique_ptr<function::TableFuncBindData> bindData;
expression_vector outputExpressions;
};

} // namespace binder
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
namespace kuzu {
namespace binder {

class BoundCallConfig : public BoundStatement {
class BoundStandaloneCall : public BoundStatement {
public:
BoundCallConfig(main::ConfigurationOption option, std::shared_ptr<Expression> optionValue)
: BoundStatement{common::StatementType::CALL_CONFIG,
BoundStandaloneCall(main::ConfigurationOption option, std::shared_ptr<Expression> optionValue)
: BoundStatement{common::StatementType::StandaloneCall,
BoundStatementResult::createEmptyResult()},
option{option}, optionValue{optionValue} {}

Expand Down
3 changes: 2 additions & 1 deletion src/include/common/clause_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ enum class ClauseType : uint8_t {
CREATE = 2,
// reading clause
MATCH = 3,
UNWIND = 4
UNWIND = 4,
InQueryCall = 5,
};

} // namespace common
Expand Down
3 changes: 1 addition & 2 deletions src/include/common/statement_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ enum class StatementType : uint8_t {
DROP_PROPERTY = 7,
RENAME_PROPERTY = 8,
COPY = 20,
CALL_CONFIG = 21,
CALL_TABLE_FUNC = 22,
StandaloneCall = 21,
};

class StatementTypeUtils {
Expand Down
25 changes: 25 additions & 0 deletions src/include/parser/call/in_query_call.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "parser/expression/parsed_expression.h"
#include "parser/query/reading_clause/reading_clause.h"

namespace kuzu {
namespace parser {

class InQueryCall : public ReadingClause {
public:
InQueryCall(std::string optionName, std::unique_ptr<ParsedExpression> optionValue)
: ReadingClause{common::ClauseType::InQueryCall}, funcName{std::move(optionName)},
parameter{std::move(optionValue)} {}

inline std::string getFuncName() const { return funcName; }

inline ParsedExpression* getParameter() const { return parameter.get(); }

private:
std::string funcName;
std::unique_ptr<ParsedExpression> parameter;
};

} // namespace parser
} // namespace kuzu
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
namespace kuzu {
namespace parser {

class Call : public Statement {
class StandaloneCall : public Statement {
public:
explicit Call(common::StatementType statementType, std::string optionName,
std::unique_ptr<ParsedExpression> optionValue)
: Statement{statementType}, optionName{std::move(optionName)}, optionValue{std::move(
optionValue)} {}
explicit StandaloneCall(std::string optionName, std::unique_ptr<ParsedExpression> optionValue)
: Statement{common::StatementType::StandaloneCall}, optionName{std::move(optionName)},
optionValue{std::move(optionValue)} {}

inline std::string getOptionName() const { return optionName; }

Expand Down
4 changes: 3 additions & 1 deletion src/include/parser/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class Transformer {

std::unique_ptr<ReadingClause> transformUnwind(CypherParser::OC_UnwindContext& ctx);

std::unique_ptr<ReadingClause> transformInQueryCall(CypherParser::KU_InQueryCallContext& ctx);

std::unique_ptr<UpdatingClause> transformCreate(CypherParser::OC_CreateContext& ctx);

std::unique_ptr<UpdatingClause> transformSet(CypherParser::OC_SetContext& ctx);
Expand Down Expand Up @@ -248,7 +250,7 @@ class Transformer {

std::unique_ptr<Statement> transformCopyNPY(CypherParser::KU_CopyNPYContext& ctx);

std::unique_ptr<Statement> transformCall(CypherParser::KU_CallContext& ctx);
std::unique_ptr<Statement> transformStandaloneCall(CypherParser::KU_StandaloneCallContext& ctx);

std::vector<std::string> transformFilePaths(
std::vector<antlr4::tree::TerminalNode*> stringLiteral);
Expand Down
Loading
Loading