Skip to content

Commit

Permalink
Implement scalar macro
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Jul 19, 2023
1 parent 1e9d4fd commit f3c0d5b
Show file tree
Hide file tree
Showing 61 changed files with 5,272 additions and 3,858 deletions.
14 changes: 13 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ kU_StandaloneCall

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

kU_CreateMacro
: CREATE SP MACRO SP oC_FunctionName SP? '(' SP? kU_PositionalArgs? SP? kU_DefaultArg? ( SP? ',' SP? kU_DefaultArg )* SP? ')' SP AS SP oC_Expression ;

kU_PositionalArgs
: oC_SymbolicName ( SP? ',' SP? oC_SymbolicName )* ;

kU_DefaultArg
: oC_SymbolicName SP? ':' '=' SP? oC_Literal ;

MACRO : ( 'M' | 'm' ) ( 'A' | 'a' ) ( 'C' | 'c' ) ( 'R' | 'r' ) ( 'O' | 'o' ) ;

kU_FilePaths
: '[' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ']'
| StringLiteral
Expand Down Expand Up @@ -140,7 +151,8 @@ oC_Statement
| kU_DDL
| kU_CopyNPY
| kU_CopyCSV
| kU_StandaloneCall ;
| kU_StandaloneCall
| kU_CreateMacro ;

oC_Query
: oC_RegularQuery ;
Expand Down
1 change: 1 addition & 0 deletions src/binder/bind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_library(
kuzu_binder_bind
OBJECT
bind_standalone_call.cpp
bind_create_macro.cpp
bind_copy.cpp
bind_ddl.cpp
bind_explain.cpp
Expand Down
27 changes: 27 additions & 0 deletions src/binder/bind/bind_create_macro.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#include "binder/binder.h"
#include "binder/macro/bound_create_macro.h"
#include "common/string_utils.h"
#include "parser/macro/create_macro.h"

namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCreateMacro(const parser::Statement& statement) {
auto& createMacro = reinterpret_cast<const parser::CreateMacro&>(statement);
auto macroName = createMacro.getMacroName();
if (catalog.getReadOnlyVersion()->containMacro(macroName)) {
throw common::BinderException{
common::StringUtils::string_format("Macro {} already exists.", macroName)};

Check warning on line 14 in src/binder/bind/bind_create_macro.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_create_macro.cpp#L13-L14

Added lines #L13 - L14 were not covered by tests
}
parser::default_macro_args defaultArgs;
for (auto& defaultArg : createMacro.getDefaultArgs()) {
defaultArgs.emplace_back(defaultArg.first, defaultArg.second->copy());
}
auto scalarMacro =
std::make_unique<function::ScalarMacroFunction>(createMacro.getMacroExpression()->copy(),
createMacro.getPositionalArgs(), std::move(defaultArgs));
return std::make_unique<BoundCreateMacro>(macroName, std::move(scalarMacro));
}

} // namespace binder
} // namespace kuzu
36 changes: 33 additions & 3 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "common/string_utils.h"
#include "function/schema/vector_label_functions.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/parsed_expression_visitor.h"

using namespace kuzu::common;
using namespace kuzu::parser;
Expand All @@ -22,12 +23,16 @@ std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(
return result;
}
auto functionType = binder->catalog.getFunctionType(functionName);
if (functionType == FUNCTION) {
switch (functionType) {
case common::FUNCTION:
return bindScalarFunctionExpression(parsedExpression, functionName);
} else {
assert(functionType == AGGREGATE_FUNCTION);
case common::AGGREGATE_FUNCTION:
return bindAggregateFunctionExpression(
parsedExpression, functionName, parsedFunctionExpression.getIsDistinct());
case common::MACRO:
return bindMacroExpression(parsedExpression, functionName);
default:
throw NotImplementedException{"ExpressionBinder::bindFunctionExpression"};

Check warning on line 35 in src/binder/bind_expression/bind_function_expression.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind_expression/bind_function_expression.cpp#L34-L35

Added lines #L34 - L35 were not covered by tests
}
}

Expand Down Expand Up @@ -104,6 +109,31 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
std::move(children), function->aggregateFunction->clone(), uniqueExpressionName);
}

std::shared_ptr<Expression> ExpressionBinder::bindMacroExpression(
const ParsedExpression& parsedExpression, const std::string& macroName) {
auto scalarMacroFunction = binder->catalog.getScalarMacroFunction(macroName);
auto macroExpr = scalarMacroFunction->expression->copy();
auto parameterVals = scalarMacroFunction->getDefaultParameterVals();
auto& parsedFuncExpr = reinterpret_cast<const ParsedFunctionExpression&>(parsedExpression);
auto positionalArgs = scalarMacroFunction->getPositionalArgs();
if (parsedFuncExpr.getNumChildren() > scalarMacroFunction->getNumArgs() ||
parsedFuncExpr.getNumChildren() < positionalArgs.size()) {
throw BinderException{"Invalid number of arguments for macro " + macroName + "."};
}
// Bind positional arguments.
for (auto i = 0u; i < positionalArgs.size(); i++) {
parameterVals[positionalArgs[i]] = parsedFuncExpr.getChild(i);
}
// Bind arguments with default values.
for (auto i = positionalArgs.size(); i < parsedFuncExpr.getNumChildren(); i++) {
auto parameterName =
scalarMacroFunction->getDefaultParameterName(i - positionalArgs.size());
parameterVals[parameterName] = parsedFuncExpr.getChild(i);
}
auto macroParameterReplacer = std::make_unique<MacroParameterReplacer>(parameterVals);
return bindExpression(*macroParameterReplacer->visit(std::move(macroExpr)));
}

std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(
const std::string& functionName, const expression_vector& children) {
assert(children[0]->expressionType == common::LITERAL);
Expand Down
3 changes: 3 additions & 0 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
case StatementType::EXPLAIN: {
return bindExplain(statement);
}
case StatementType::CREATE_MACRO: {
return bindCreateMacro(statement);
}
default:
assert(false);
}
Expand Down
3 changes: 3 additions & 0 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement)
case StatementType::EXPLAIN: {
visitExplain(statement);
} break;
case StatementType::CREATE_MACRO: {
visitCreateMacro(statement);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visit");
}
Expand Down
54 changes: 40 additions & 14 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,15 @@ namespace catalog {

CatalogContent::CatalogContent() : nextTableID{0} {
logger = LoggerUtils::getLogger(LoggerConstants::LoggerEnum::CATALOG);
registerBuiltInFunctions();
}

CatalogContent::CatalogContent(const std::string& directory) {
logger = LoggerUtils::getLogger(LoggerConstants::LoggerEnum::CATALOG);
logger->info("Initializing catalog.");
readFromFile(directory, DBFileType::ORIGINAL);
logger->info("Initializing catalog done.");
registerBuiltInFunctions();
}

CatalogContent::CatalogContent(const CatalogContent& other) {
Expand All @@ -192,6 +194,7 @@ CatalogContent::CatalogContent(const CatalogContent& other) {
nodeTableNameToIDMap = other.nodeTableNameToIDMap;
relTableNameToIDMap = other.relTableNameToIDMap;
nextTableID = other.nextTableID;
registerBuiltInFunctions();
}

table_id_t CatalogContent::addNodeTableSchema(
Expand Down Expand Up @@ -325,6 +328,31 @@ void CatalogContent::readFromFile(const std::string& directory, DBFileType dbFil
SerDeser::deserializeValue<table_id_t>(nextTableID, fileInfo.get(), offset);
}

ExpressionType CatalogContent::getFunctionType(const std::string& name) const {
auto upperCaseName = StringUtils::getUpper(name);
if (builtInVectorFunctions->containsFunction(upperCaseName)) {
return FUNCTION;
} else if (builtInAggregateFunctions->containsFunction(upperCaseName)) {
return AGGREGATE_FUNCTION;
} else if (macros.contains(upperCaseName)) {
return MACRO;
} else {
throw CatalogException(name + " function does not exist.");
}
}

void CatalogContent::addVectorFunction(
std::string name, function::vector_function_definitions definitions) {
StringUtils::toUpper(name);
builtInVectorFunctions->addFunction(std::move(name), std::move(definitions));
}

void CatalogContent::addScalarMacroFunction(
std::string name, std::unique_ptr<function::ScalarMacroFunction> macro) {
StringUtils::toUpper(name);
macros.emplace(std::move(name), std::move(macro));
}

void CatalogContent::validateStorageVersion(storage_version_t savedStorageVersion) const {
auto storageVersion = StorageVersionInfo::getStorageVersion();
if (savedStorageVersion != storageVersion) {
Expand Down Expand Up @@ -355,18 +383,18 @@ void CatalogContent::writeMagicBytes(FileInfo* fileInfo, offset_t& offset) const
}
}

Catalog::Catalog() : wal{nullptr} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>();
void CatalogContent::registerBuiltInFunctions() {
builtInVectorFunctions = std::make_unique<function::BuiltInVectorFunctions>();
builtInAggregateFunctions = std::make_unique<function::BuiltInAggregateFunctions>();
builtInTableFunctions = std::make_unique<function::BuiltInTableFunctions>();
}

Catalog::Catalog() : wal{nullptr} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>();
}

Catalog::Catalog(WAL* wal) : wal{wal} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>(wal->getDirectory());
builtInVectorFunctions = std::make_unique<function::BuiltInVectorFunctions>();
builtInAggregateFunctions = std::make_unique<function::BuiltInAggregateFunctions>();
builtInTableFunctions = std::make_unique<function::BuiltInTableFunctions>();
}

void Catalog::prepareCommitOrRollback(TransactionAction action) {
Expand All @@ -386,13 +414,7 @@ void Catalog::checkpointInMemory() {
}

ExpressionType Catalog::getFunctionType(const std::string& name) const {
if (builtInVectorFunctions->containsFunction(name)) {
return FUNCTION;
} else if (builtInAggregateFunctions->containsFunction(name)) {
return AGGREGATE_FUNCTION;
} else {
throw CatalogException(name + " function does not exist.");
}
return catalogContentForReadOnlyTrx->getFunctionType(name);
}

table_id_t Catalog::addNodeTableSchema(
Expand Down Expand Up @@ -461,8 +483,12 @@ std::unordered_set<RelTableSchema*> Catalog::getAllRelTableSchemasContainBoundTa

void Catalog::addVectorFunction(
std::string name, function::vector_function_definitions definitions) {
common::StringUtils::toUpper(name);
builtInVectorFunctions->addFunction(std::move(name), std::move(definitions));
catalogContentForReadOnlyTrx->addVectorFunction(std::move(name), std::move(definitions));
}

void Catalog::addScalarMacroFunction(
std::string name, std::unique_ptr<function::ScalarMacroFunction> macro) {
catalogContentForReadOnlyTrx->addScalarMacroFunction(std::move(name), std::move(macro));
}

} // namespace catalog
Expand Down
1 change: 1 addition & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_library(kuzu_function
built_in_table_functions.cpp
comparison_functions.cpp
find_function.cpp
scalar_macro_function.cpp
table_functions.cpp
vector_arithmetic_functions.cpp
vector_boolean_functions.cpp
Expand Down
24 changes: 24 additions & 0 deletions src/function/scalar_macro_function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "function/scalar_macro_function.h"

namespace kuzu {
namespace function {

macro_parameter_value_map ScalarMacroFunction::getDefaultParameterVals() const {
macro_parameter_value_map defaultArgsToReturn;
for (auto& defaultArg : defaultArgs) {
defaultArgsToReturn.emplace(defaultArg.first, defaultArg.second.get());
}
return defaultArgsToReturn;
}

std::unique_ptr<ScalarMacroFunction> ScalarMacroFunction::copy() const {
parser::default_macro_args defaultArgsCopy;
for (auto& defaultArg : defaultArgs) {
defaultArgsCopy.emplace_back(defaultArg.first, defaultArg.second->copy());
}
return std::make_unique<ScalarMacroFunction>(
expression->copy(), positionalArgs, std::move(defaultArgsCopy));
}

} // namespace function
} // namespace kuzu
3 changes: 3 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class Binder {
/*** bind call ***/
std::unique_ptr<BoundStatement> bindStandaloneCall(const parser::Statement& statement);

/*** bind create macro ***/
std::unique_ptr<BoundStatement> bindCreateMacro(const parser::Statement& statement);

/*** bind explain ***/
std::unique_ptr<BoundStatement> bindExplain(const parser::Statement& statement);

Expand Down
1 change: 1 addition & 0 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BoundStatementVisitor {
virtual void visitCopy(const BoundStatement& statement) {}
virtual void visitStandaloneCall(const BoundStatement& statement) {}
virtual void visitExplain(const BoundStatement& statement);
virtual void visitCreateMacro(const BoundStatement& statement) {}

void visitReadingClause(const BoundReadingClause& readingClause);
virtual void visitMatch(const BoundReadingClause& readingClause) {}
Expand Down
2 changes: 2 additions & 0 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class ExpressionBinder {
std::shared_ptr<Expression> bindAggregateFunctionExpression(
const parser::ParsedExpression& parsedExpression, const std::string& functionName,
bool isDistinct);
std::shared_ptr<Expression> bindMacroExpression(
const parser::ParsedExpression& parsedExpression, const std::string& macroName);
std::shared_ptr<Expression> staticEvaluate(
const std::string& functionName, const expression_vector& children);

Expand Down
27 changes: 27 additions & 0 deletions src/include/binder/macro/bound_create_macro.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "binder/bound_statement.h"
#include "function/scalar_macro_function.h"

namespace kuzu {
namespace binder {

class BoundCreateMacro : public BoundStatement {
public:
explicit BoundCreateMacro(
std::string macroName, std::unique_ptr<function::ScalarMacroFunction> macro)
: BoundStatement{common::StatementType::CREATE_MACRO,
BoundStatementResult::createSingleStringColumnResult("result" /* columnName */)},
macroName{std::move(macroName)}, macro{std::move(macro)} {}

inline std::string getMacroName() const { return macroName; }

inline std::unique_ptr<function::ScalarMacroFunction> getMacro() const { return macro->copy(); }

private:
std::string macroName;
std::unique_ptr<function::ScalarMacroFunction> macro;
};

} // namespace binder
} // namespace kuzu
Loading

0 comments on commit f3c0d5b

Please sign in to comment.