Skip to content

Commit

Permalink
Implement table function framework
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Jun 28, 2023
1 parent faada3f commit aab7589
Show file tree
Hide file tree
Showing 49 changed files with 5,555 additions and 2,325 deletions.
3 changes: 2 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ kU_CopyNPY
: COPY SP oC_SchemaName SP FROM SP '(' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ')' SP BY SP COLUMN ;

kU_Call
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;
: CALL SP ( ( oC_SymbolicName SP? '=' SP? oC_Literal )
| ( oC_FunctionName SP? '(' oC_Literal? ')' ) );

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

Expand Down
30 changes: 25 additions & 5 deletions src/binder/bind/bind_call.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
#include "binder/binder.h"
#include "binder/call/bound_call.h"
#include "common/string_utils.h"
#include "binder/call/bound_call_config.h"
#include "binder/call/bound_call_table_func.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression/variable_expression.h"
#include "parser/call/call.h"

namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCallClause(const parser::Statement& statement) {
std::unique_ptr<BoundStatement> Binder::bindCallTableFunc(const parser::Statement& statement) {
auto& callStatement = reinterpret_cast<const parser::Call&>(statement);
auto tableFunctionDefinition =
catalog.getBuiltInTableOperation()->mathTableOperation(callStatement.getOptionName());
auto boundExpr = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
auto inputValue = reinterpret_cast<LiteralExpression*>(boundExpr.get())->getValue();
auto bindData = tableFunctionDefinition->bindFunc(
function::TableFuncBindInput{std::vector<common::Value>{*inputValue}},
catalog.getReadOnlyVersion());
auto statementResult = std::make_unique<BoundStatementResult>();
for (auto i = 0u; i < bindData->returnColumnNames.size(); i++) {
auto expr = std::make_shared<VariableExpression>(bindData->returnTypes[i],
bindData->returnColumnNames[i], bindData->returnColumnNames[i]);
statementResult->addColumn(expr, std::vector<std::shared_ptr<Expression>>{expr});
}
return std::make_unique<BoundCallTableFunc>(
tableFunctionDefinition->tableFunc, std::move(bindData), std::move(statementResult));
}

std::unique_ptr<BoundStatement> Binder::bindCallConfig(const parser::Statement& statement) {
auto& callStatement = reinterpret_cast<const parser::Call&>(statement);
auto option = main::DBConfig::getOptionByName(callStatement.getOptionName());
if (option == nullptr) {
Expand All @@ -16,8 +37,7 @@ std::unique_ptr<BoundStatement> Binder::bindCallClause(const parser::Statement&
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionBinder::validateExpectedDataType(*optionValue, option->parameterType);
auto boundCall = std::make_unique<BoundCall>(*option, std::move(optionValue));
return boundCall;
return std::make_unique<BoundCallConfig>(*option, std::move(optionValue));
}

} // namespace binder
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(

std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
common::ExpressionType expressionType, const expression_vector& children) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
auto builtInFunctions = binder->catalog.getBuiltInVectorOperation();
auto functionName = expressionTypeToString(expressionType);
std::vector<common::LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes);
auto function = builtInFunctions->matchVectorOperation(functionName, childrenTypes);
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
childrenAfterCast.push_back(
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
auto builtInFunctions = binder->catalog.getBuiltInVectorOperation();
std::vector<LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes);
auto function = builtInFunctions->matchVectorOperation(functionName, childrenTypes);
if (builtInFunctions->canApplyStaticEvaluation(functionName, children)) {
return staticEvaluate(functionName, children);
}
Expand Down
7 changes: 5 additions & 2 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
case StatementType::QUERY: {
return bindQuery((const RegularQuery&)statement);
}
case StatementType::CALL: {
return bindCallClause(statement);
case StatementType::CALL_CONFIG: {
return bindCallConfig(statement);
}
case StatementType::CALL_TABLE_FUNC: {
return bindCallTableFunc(statement);
}
default:
assert(false);
Expand Down
7 changes: 5 additions & 2 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement)
case StatementType::COPY: {
visitCopy(statement);
} break;
case StatementType::CALL: {
visitCall(statement);
case StatementType::CALL_CONFIG: {
visitCallConfig(statement);
} break;
case StatementType::CALL_TABLE_FUNC: {
visitCallTableFunc(statement);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visit");
Expand Down
2 changes: 2 additions & 0 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,14 @@ Catalog::Catalog() : wal{nullptr} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>();
builtInVectorOperations = std::make_unique<function::BuiltInVectorOperations>();
builtInAggregateFunctions = std::make_unique<function::BuiltInAggregateFunctions>();
builtInTableOperations = std::make_unique<function::BuiltInTableOperations>();
}

Catalog::Catalog(WAL* wal) : wal{wal} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>(wal->getDirectory());
builtInVectorOperations = std::make_unique<function::BuiltInVectorOperations>();
builtInAggregateFunctions = std::make_unique<function::BuiltInAggregateFunctions>();
builtInTableOperations = std::make_unique<function::BuiltInTableOperations>();
}

void Catalog::prepareCommitOrRollback(TransactionAction action) {
Expand Down
2 changes: 2 additions & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ add_library(kuzu_function
base_lower_upper_operation.cpp
built_in_aggregate_functions.cpp
built_in_vector_operations.cpp
built_in_table_operations.cpp
find_operation.cpp
table_operations.cpp
vector_arithmetic_operations.cpp
vector_boolean_operations.cpp
vector_cast_operations.cpp
Expand Down
24 changes: 24 additions & 0 deletions src/function/built_in_table_operations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "function/built_in_table_operations.h"

#include "common/expression_type.h"
#include "common/string_utils.h"

namespace kuzu {
namespace function {

void BuiltInTableOperations::registerTableOperations() {
tableOperations.insert({common::TABLE_INFO_FUNC_NAME, TableInfoOperation::getDefinitions()});
}

TableOperationDefinition* BuiltInTableOperations::mathTableOperation(const std::string& name) {
auto upperName = name;
common::StringUtils::toUpper(upperName);
if (!tableOperations.contains(upperName)) {
throw common::BinderException{
"Cannot match a built-in function for given function " + name + "."};

Check warning on line 18 in src/function/built_in_table_operations.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/built_in_table_operations.cpp#L17-L18

Added lines #L17 - L18 were not covered by tests
}
return tableOperations.at(upperName).get();
}

} // namespace function
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ bool BuiltInVectorOperations::canApplyStaticEvaluation(
return false;
}

VectorOperationDefinition* BuiltInVectorOperations::matchFunction(
VectorOperationDefinition* BuiltInVectorOperations::matchVectorOperation(
const std::string& name, const std::vector<LogicalType>& inputTypes) {
auto& functionDefinitions = vectorOperations.at(name);
bool isOverload = functionDefinitions.size() > 1;
Expand Down
56 changes: 56 additions & 0 deletions src/function/table_operations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include "function/table_operations.h"

#include "catalog/catalog.h"

namespace kuzu {
namespace function {

std::unique_ptr<TableInfoBindData> TableInfoOperation::bindFunc(
kuzu::function::TableFuncBindInput input, catalog::CatalogContent* catalog) {
std::vector<std::string> returnColumnNames;
std::vector<common::LogicalType> returnTypes;
auto tableName = input.inputs[0].getValue<std::string>();
auto tableID = catalog->getTableID(tableName);
auto schema = catalog->getTableSchema(tableID);
returnColumnNames.emplace_back("property id");
returnTypes.emplace_back(common::LogicalTypeID::INT64);
returnColumnNames.emplace_back("name");
returnTypes.emplace_back(common::LogicalTypeID::STRING);
returnColumnNames.emplace_back("type");
returnTypes.emplace_back(common::LogicalTypeID::STRING);
if (schema->isNodeTable) {
returnColumnNames.emplace_back("pk");
returnTypes.emplace_back(common::LogicalTypeID::BOOL);
}
return std::make_unique<TableInfoBindData>(
schema, std::move(returnTypes), std::move(returnColumnNames), schema->getNumProperties());
}

void TableInfoOperation::tableFunc(std::pair<common::offset_t, common::offset_t> morsel,
function::TableFuncBindData* bindData, std::vector<common::ValueVector*> outputVectors) {
auto tableSchema = reinterpret_cast<function::TableInfoBindData*>(bindData)->tableSchema;
auto numPropertiesToOutput = morsel.second - morsel.first;
auto outVectorPos = 0;
for (auto i = 0u; i < numPropertiesToOutput; i++) {
auto property = tableSchema->properties[morsel.first + i];
if (!tableSchema->isNodeTable && property.name == common::InternalKeyword::ID) {
continue;
}
outputVectors[0]->setValue(outVectorPos, (int64_t)property.propertyID);
outputVectors[1]->setValue(outVectorPos, property.name);
outputVectors[2]->setValue(
outVectorPos, common::LogicalTypeUtils::dataTypeToString(property.dataType));
if (tableSchema->isNodeTable &&
reinterpret_cast<catalog::NodeTableSchema*>(tableSchema)->primaryKeyPropertyID ==
property.propertyID) {
outputVectors[3]->setValue(outVectorPos, true /* isPrimaryKey */);
}
outVectorPos++;
}
for (auto& outputVector : outputVectors) {
outputVector->state->selVector->selectedSize = outVectorPos;
}
}

} // namespace function
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ class Binder {
std::unique_ptr<BoundQueryPart> bindQueryPart(const parser::QueryPart& queryPart);

/*** bind call ***/
std::unique_ptr<BoundStatement> bindCallClause(const parser::Statement& statement);
std::unique_ptr<BoundStatement> bindCallTableFunc(const parser::Statement& statement);
std::unique_ptr<BoundStatement> bindCallConfig(const parser::Statement& statement);

/*** bind reading clause ***/
std::unique_ptr<BoundReadingClause> bindReadingClause(
Expand Down
3 changes: 2 additions & 1 deletion src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class BoundStatementVisitor {
virtual void visitDropProperty(const BoundStatement& statement) {}
virtual void visitRenameProperty(const BoundStatement& statement) {}
virtual void visitCopy(const BoundStatement& statement) {}
virtual void visitCall(const BoundStatement& statement) {}
virtual void visitCallConfig(const BoundStatement& statement) {}
virtual void visitCallTableFunc(const BoundStatement& statement) {}

void visitReadingClause(const BoundReadingClause& readingClause);
virtual void visitMatch(const BoundReadingClause& readingClause) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
namespace kuzu {
namespace binder {

class BoundCall : public BoundStatement {
class BoundCallConfig : public BoundStatement {
public:
BoundCall(main::ConfigurationOption option, std::shared_ptr<Expression> optionValue)
: BoundStatement{common::StatementType::CALL, BoundStatementResult::createEmptyResult()},
BoundCallConfig(main::ConfigurationOption option, std::shared_ptr<Expression> optionValue)
: BoundStatement{common::StatementType::CALL_CONFIG,
BoundStatementResult::createEmptyResult()},
option{option}, optionValue{optionValue} {}

inline main::ConfigurationOption getOption() const { return option; }
Expand Down
27 changes: 27 additions & 0 deletions src/include/binder/call/bound_call_table_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "binder/expression/expression.h"
#include "function/table_operations.h"

namespace kuzu {
namespace binder {

class BoundCallTableFunc : public BoundStatement {
public:
BoundCallTableFunc(function::table_func_t tableFunc,
std::unique_ptr<function::TableFuncBindData> bindData,
std::unique_ptr<BoundStatementResult> statementResult)
: BoundStatement{common::StatementType::CALL_TABLE_FUNC, std::move(statementResult)},
tableFunc{std::move(tableFunc)}, bindData{std::move(bindData)} {}

inline function::table_func_t getTableFunc() const { return tableFunc; }

inline function::TableFuncBindData* getBindData() const { return bindData.get(); }

private:
function::table_func_t tableFunc;
std::unique_ptr<function::TableFuncBindData> bindData;
};

} // namespace binder
} // namespace kuzu
7 changes: 6 additions & 1 deletion src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "common/ser_deser.h"
#include "common/utils.h"
#include "function/aggregate/built_in_aggregate_functions.h"
#include "function/built_in_table_operations.h"
#include "function/built_in_vector_operations.h"
#include "storage/storage_info.h"
#include "storage/wal/wal.h"
Expand Down Expand Up @@ -164,12 +165,15 @@ class Catalog {
inline CatalogContent* getReadOnlyVersion() const { return catalogContentForReadOnlyTrx.get(); }
inline CatalogContent* getWriteVersion() const { return catalogContentForWriteTrx.get(); }

inline function::BuiltInVectorOperations* getBuiltInScalarFunctions() const {
inline function::BuiltInVectorOperations* getBuiltInVectorOperation() const {
return builtInVectorOperations.get();
}
inline function::BuiltInAggregateFunctions* getBuiltInAggregateFunction() const {
return builtInAggregateFunctions.get();
}
inline function::BuiltInTableOperations* getBuiltInTableOperation() const {
return builtInTableOperations.get();
}

void prepareCommitOrRollback(transaction::TransactionAction action);
void checkpointInMemory();
Expand Down Expand Up @@ -217,6 +221,7 @@ class Catalog {
protected:
std::unique_ptr<function::BuiltInVectorOperations> builtInVectorOperations;
std::unique_ptr<function::BuiltInAggregateFunctions> builtInAggregateFunctions;
std::unique_ptr<function::BuiltInTableOperations> builtInTableOperations;
std::unique_ptr<CatalogContent> catalogContentForReadOnlyTrx;
std::unique_ptr<CatalogContent> catalogContentForWriteTrx;
storage::WAL* wal;
Expand Down
3 changes: 3 additions & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ const std::string OCTET_LENGTH_FUNC_NAME = "OCTET_LENGTH";
const std::string ENCODE_FUNC_NAME = "ENCODE";
const std::string DECODE_FUNC_NAME = "DECODE";

// TABLE functions
const std::string TABLE_INFO_FUNC_NAME = "TABLE_INFO";

enum ExpressionType : uint8_t {

// Boolean Connection Expressions
Expand Down
3 changes: 2 additions & 1 deletion src/include/common/statement_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ enum class StatementType : uint8_t {
DROP_PROPERTY = 7,
RENAME_PROPERTY = 8,
COPY = 20,
CALL = 21,
CALL_CONFIG = 21,
CALL_TABLE_FUNC = 22,
};

class StatementTypeUtils {
Expand Down
24 changes: 24 additions & 0 deletions src/include/function/built_in_table_operations.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "table_operations.h"

namespace kuzu {
namespace function {

class BuiltInTableOperations {

public:
BuiltInTableOperations() { registerTableOperations(); }

TableOperationDefinition* mathTableOperation(const std::string& name);

private:
void registerTableOperations();

private:
// TODO(Ziyi): Refactor vectorOperation/tableOperation to inherit from the same base class.
std::unordered_map<std::string, std::unique_ptr<TableOperationDefinition>> tableOperations;
};

} // namespace function
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/include/function/built_in_vector_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class BuiltInVectorOperations {
bool canApplyStaticEvaluation(
const std::string& functionName, const binder::expression_vector& children);

VectorOperationDefinition* matchFunction(
VectorOperationDefinition* matchVectorOperation(
const std::string& name, const std::vector<common::LogicalType>& inputTypes);

std::vector<std::string> getFunctionNames();
Expand Down Expand Up @@ -78,6 +78,7 @@ class BuiltInVectorOperations {
void registerNodeRelOperations();

private:
// TODO(Ziyi): Refactor vectorOperation/tableOperation to inherit from the same base class.
std::unordered_map<std::string, vector_operation_definitions> vectorOperations;
};

Expand Down
Loading

0 comments on commit aab7589

Please sign in to comment.