Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement scalar macro #1836

Merged
merged 1 commit into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ kU_StandaloneCall

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

kU_CreateMacro
: CREATE SP MACRO SP oC_FunctionName SP? '(' SP? kU_PositionalArgs? SP? kU_DefaultArg? ( SP? ',' SP? kU_DefaultArg )* SP? ')' SP AS SP oC_Expression ;

kU_PositionalArgs
: oC_SymbolicName ( SP? ',' SP? oC_SymbolicName )* ;

kU_DefaultArg
: oC_SymbolicName SP? ':' '=' SP? oC_Literal ;

MACRO : ( 'M' | 'm' ) ( 'A' | 'a' ) ( 'C' | 'c' ) ( 'R' | 'r' ) ( 'O' | 'o' ) ;

kU_FilePaths
: '[' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ']'
| StringLiteral
Expand Down Expand Up @@ -140,7 +151,8 @@ oC_Statement
| kU_DDL
| kU_CopyNPY
| kU_CopyCSV
| kU_StandaloneCall ;
| kU_StandaloneCall
| kU_CreateMacro ;

oC_Query
: oC_RegularQuery ;
Expand Down
1 change: 1 addition & 0 deletions src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_library(
kuzu_binder_bind
OBJECT
bind_standalone_call.cpp
bind_create_macro.cpp
bind_copy.cpp
bind_ddl.cpp
bind_explain.cpp
Expand Down
27 changes: 27 additions & 0 deletions src/binder/bind/bind_create_macro.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "binder/binder.h"
#include "binder/macro/bound_create_macro.h"
#include "common/string_utils.h"
#include "parser/macro/create_macro.h"

namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCreateMacro(const parser::Statement& statement) {
auto& createMacro = reinterpret_cast<const parser::CreateMacro&>(statement);
auto macroName = createMacro.getMacroName();
if (catalog.getReadOnlyVersion()->containMacro(macroName)) {
throw common::BinderException{
common::StringUtils::string_format("Macro {} already exists.", macroName)};

Check warning on line 14 in src/binder/bind/bind_create_macro.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_create_macro.cpp#L13-L14

Added lines #L13 - L14 were not covered by tests
}
parser::default_macro_args defaultArgs;
for (auto& defaultArg : createMacro.getDefaultArgs()) {
defaultArgs.emplace_back(defaultArg.first, defaultArg.second->copy());
}
auto scalarMacro =
std::make_unique<function::ScalarMacroFunction>(createMacro.getMacroExpression()->copy(),
createMacro.getPositionalArgs(), std::move(defaultArgs));
return std::make_unique<BoundCreateMacro>(macroName, std::move(scalarMacro));
}

} // namespace binder
} // namespace kuzu
36 changes: 33 additions & 3 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "common/string_utils.h"
#include "function/schema/vector_label_functions.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/parsed_expression_visitor.h"

using namespace kuzu::common;
using namespace kuzu::parser;
Expand All @@ -22,12 +23,16 @@
return result;
}
auto functionType = binder->catalog.getFunctionType(functionName);
if (functionType == FUNCTION) {
switch (functionType) {
case common::FUNCTION:
return bindScalarFunctionExpression(parsedExpression, functionName);
} else {
assert(functionType == AGGREGATE_FUNCTION);
case common::AGGREGATE_FUNCTION:
return bindAggregateFunctionExpression(
parsedExpression, functionName, parsedFunctionExpression.getIsDistinct());
case common::MACRO:
return bindMacroExpression(parsedExpression, functionName);
default:
throw NotImplementedException{"ExpressionBinder::bindFunctionExpression"};

Check warning on line 35 in src/binder/bind_expression/bind_function_expression.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind_expression/bind_function_expression.cpp#L34-L35

Added lines #L34 - L35 were not covered by tests
}
}

Expand Down Expand Up @@ -104,6 +109,31 @@
std::move(children), function->aggregateFunction->clone(), uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::bindMacroExpression(
const ParsedExpression& parsedExpression, const std::string& macroName) {
auto scalarMacroFunction = binder->catalog.getScalarMacroFunction(macroName);
auto macroExpr = scalarMacroFunction->expression->copy();
auto parameterVals = scalarMacroFunction->getDefaultParameterVals();
auto& parsedFuncExpr = reinterpret_cast<const ParsedFunctionExpression&>(parsedExpression);
auto positionalArgs = scalarMacroFunction->getPositionalArgs();
if (parsedFuncExpr.getNumChildren() > scalarMacroFunction->getNumArgs() ||
parsedFuncExpr.getNumChildren() < positionalArgs.size()) {
throw BinderException{"Invalid number of arguments for macro " + macroName + "."};
}
// Bind positional arguments.
for (auto i = 0u; i < positionalArgs.size(); i++) {
parameterVals[positionalArgs[i]] = parsedFuncExpr.getChild(i);
}
// Bind arguments with default values.
for (auto i = positionalArgs.size(); i < parsedFuncExpr.getNumChildren(); i++) {
auto parameterName =
scalarMacroFunction->getDefaultParameterName(i - positionalArgs.size());
parameterVals[parameterName] = parsedFuncExpr.getChild(i);
}
auto macroParameterReplacer = std::make_unique<MacroParameterReplacer>(parameterVals);
return bindExpression(*macroParameterReplacer->visit(std::move(macroExpr)));
}

std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(
const std::string& functionName, const expression_vector& children) {
assert(children[0]->expressionType == common::LITERAL);
Expand Down
3 changes: 3 additions & 0 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
case StatementType::EXPLAIN: {
return bindExplain(statement);
}
case StatementType::CREATE_MACRO: {
return bindCreateMacro(statement);
}
default:
assert(false);
}
Expand Down
3 changes: 3 additions & 0 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement)
case StatementType::EXPLAIN: {
visitExplain(statement);
} break;
case StatementType::CREATE_MACRO: {
visitCreateMacro(statement);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visit");
}
Expand Down
54 changes: 40 additions & 14 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,15 @@ namespace catalog {

CatalogContent::CatalogContent() : nextTableID{0} {
logger = LoggerUtils::getLogger(LoggerConstants::LoggerEnum::CATALOG);
registerBuiltInFunctions();
}

CatalogContent::CatalogContent(const std::string& directory) {
logger = LoggerUtils::getLogger(LoggerConstants::LoggerEnum::CATALOG);
logger->info("Initializing catalog.");
readFromFile(directory, DBFileType::ORIGINAL);
logger->info("Initializing catalog done.");
registerBuiltInFunctions();
}

CatalogContent::CatalogContent(const CatalogContent& other) {
Expand All @@ -192,6 +194,7 @@ CatalogContent::CatalogContent(const CatalogContent& other) {
nodeTableNameToIDMap = other.nodeTableNameToIDMap;
relTableNameToIDMap = other.relTableNameToIDMap;
nextTableID = other.nextTableID;
registerBuiltInFunctions();
}

table_id_t CatalogContent::addNodeTableSchema(
Expand Down Expand Up @@ -325,6 +328,31 @@ void CatalogContent::readFromFile(const std::string& directory, DBFileType dbFil
SerDeser::deserializeValue<table_id_t>(nextTableID, fileInfo.get(), offset);
}

ExpressionType CatalogContent::getFunctionType(const std::string& name) const {
auto upperCaseName = StringUtils::getUpper(name);
if (builtInVectorFunctions->containsFunction(upperCaseName)) {
return FUNCTION;
} else if (builtInAggregateFunctions->containsFunction(upperCaseName)) {
return AGGREGATE_FUNCTION;
} else if (macros.contains(upperCaseName)) {
return MACRO;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably do a refactor of ExpressionType at some point. All these 3 things seem to be of FUNCTION EXPRESSION TYPE but 3 different FUNCTION TYPE. Open an issue for me maybe

Copy link
Collaborator Author

@acquamarin acquamarin Jul 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

open issue #1838

} else {
throw CatalogException(name + " function does not exist.");
}
}

void CatalogContent::addVectorFunction(
std::string name, function::vector_function_definitions definitions) {
StringUtils::toUpper(name);
builtInVectorFunctions->addFunction(std::move(name), std::move(definitions));
}

void CatalogContent::addScalarMacroFunction(
std::string name, std::unique_ptr<function::ScalarMacroFunction> macro) {
StringUtils::toUpper(name);
macros.emplace(std::move(name), std::move(macro));
}

void CatalogContent::validateStorageVersion(storage_version_t savedStorageVersion) const {
auto storageVersion = StorageVersionInfo::getStorageVersion();
if (savedStorageVersion != storageVersion) {
Expand Down Expand Up @@ -355,18 +383,18 @@ void CatalogContent::writeMagicBytes(FileInfo* fileInfo, offset_t& offset) const
}
}

Catalog::Catalog() : wal{nullptr} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>();
void CatalogContent::registerBuiltInFunctions() {
builtInVectorFunctions = std::make_unique<function::BuiltInVectorFunctions>();
builtInAggregateFunctions = std::make_unique<function::BuiltInAggregateFunctions>();
builtInTableFunctions = std::make_unique<function::BuiltInTableFunctions>();
}

Catalog::Catalog() : wal{nullptr} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>();
}

Catalog::Catalog(WAL* wal) : wal{wal} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>(wal->getDirectory());
builtInVectorFunctions = std::make_unique<function::BuiltInVectorFunctions>();
builtInAggregateFunctions = std::make_unique<function::BuiltInAggregateFunctions>();
builtInTableFunctions = std::make_unique<function::BuiltInTableFunctions>();
}

void Catalog::prepareCommitOrRollback(TransactionAction action) {
Expand All @@ -386,13 +414,7 @@ void Catalog::checkpointInMemory() {
}

ExpressionType Catalog::getFunctionType(const std::string& name) const {
if (builtInVectorFunctions->containsFunction(name)) {
return FUNCTION;
} else if (builtInAggregateFunctions->containsFunction(name)) {
return AGGREGATE_FUNCTION;
} else {
throw CatalogException(name + " function does not exist.");
}
return catalogContentForReadOnlyTrx->getFunctionType(name);
}

table_id_t Catalog::addNodeTableSchema(
Expand Down Expand Up @@ -461,8 +483,12 @@ std::unordered_set<RelTableSchema*> Catalog::getAllRelTableSchemasContainBoundTa

void Catalog::addVectorFunction(
std::string name, function::vector_function_definitions definitions) {
common::StringUtils::toUpper(name);
builtInVectorFunctions->addFunction(std::move(name), std::move(definitions));
catalogContentForReadOnlyTrx->addVectorFunction(std::move(name), std::move(definitions));
}

void Catalog::addScalarMacroFunction(
std::string name, std::unique_ptr<function::ScalarMacroFunction> macro) {
catalogContentForReadOnlyTrx->addScalarMacroFunction(std::move(name), std::move(macro));
}

} // namespace catalog
Expand Down
1 change: 1 addition & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_library(kuzu_function
built_in_table_functions.cpp
comparison_functions.cpp
find_function.cpp
scalar_macro_function.cpp
table_functions.cpp
vector_arithmetic_functions.cpp
vector_boolean_functions.cpp
Expand Down
24 changes: 24 additions & 0 deletions src/function/scalar_macro_function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "function/scalar_macro_function.h"

namespace kuzu {
namespace function {

macro_parameter_value_map ScalarMacroFunction::getDefaultParameterVals() const {
macro_parameter_value_map defaultArgsToReturn;
for (auto& defaultArg : defaultArgs) {
defaultArgsToReturn.emplace(defaultArg.first, defaultArg.second.get());
}
return defaultArgsToReturn;
}

std::unique_ptr<ScalarMacroFunction> ScalarMacroFunction::copy() const {
parser::default_macro_args defaultArgsCopy;
for (auto& defaultArg : defaultArgs) {
defaultArgsCopy.emplace_back(defaultArg.first, defaultArg.second->copy());
}
return std::make_unique<ScalarMacroFunction>(
expression->copy(), positionalArgs, std::move(defaultArgsCopy));
}

} // namespace function
} // namespace kuzu
3 changes: 3 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class Binder {
/*** bind call ***/
std::unique_ptr<BoundStatement> bindStandaloneCall(const parser::Statement& statement);

/*** bind create macro ***/
std::unique_ptr<BoundStatement> bindCreateMacro(const parser::Statement& statement);

/*** bind explain ***/
std::unique_ptr<BoundStatement> bindExplain(const parser::Statement& statement);

Expand Down
1 change: 1 addition & 0 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BoundStatementVisitor {
virtual void visitCopy(const BoundStatement& statement) {}
virtual void visitStandaloneCall(const BoundStatement& statement) {}
virtual void visitExplain(const BoundStatement& statement);
virtual void visitCreateMacro(const BoundStatement& statement) {}

void visitReadingClause(const BoundReadingClause& readingClause);
virtual void visitMatch(const BoundReadingClause& readingClause) {}
Expand Down
2 changes: 2 additions & 0 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class ExpressionBinder {
std::shared_ptr<Expression> bindAggregateFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName,
bool isDistinct);
std::shared_ptr<Expression> bindMacroExpression(
const parser::ParsedExpression& parsedExpression, const std::string& macroName);
std::shared_ptr<Expression> staticEvaluate(
const std::string& functionName, const expression_vector& children);

Expand Down
27 changes: 27 additions & 0 deletions src/include/binder/macro/bound_create_macro.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "binder/bound_statement.h"
#include "function/scalar_macro_function.h"

namespace kuzu {
namespace binder {

class BoundCreateMacro : public BoundStatement {
public:
explicit BoundCreateMacro(
std::string macroName, std::unique_ptr<function::ScalarMacroFunction> macro)
: BoundStatement{common::StatementType::CREATE_MACRO,
BoundStatementResult::createSingleStringColumnResult("result" /* columnName */)},
macroName{std::move(macroName)}, macro{std::move(macro)} {}

inline std::string getMacroName() const { return macroName; }

inline std::unique_ptr<function::ScalarMacroFunction> getMacro() const { return macro->copy(); }

private:
std::string macroName;
std::unique_ptr<function::ScalarMacroFunction> macro;
};

} // namespace binder
} // namespace kuzu
Loading
Loading