Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement table function framework #1731

Merged
merged 1 commit into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ kU_CopyNPY
: COPY SP oC_SchemaName SP FROM SP '(' SP? StringLiteral ( SP? ',' SP? StringLiteral )* ')' SP BY SP COLUMN ;

kU_Call
: CALL SP oC_SymbolicName SP? '=' SP? oC_Literal ;
: CALL SP ( ( oC_SymbolicName SP? '=' SP? oC_Literal )
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
| ( oC_FunctionName SP? '(' oC_Literal? ')' ) );

CALL : ( 'C' | 'c' ) ( 'A' | 'a' ) ( 'L' | 'l' ) ( 'L' | 'l' ) ;

Expand Down
30 changes: 25 additions & 5 deletions src/binder/bind/bind_call.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
#include "binder/binder.h"
#include "binder/call/bound_call.h"
#include "common/string_utils.h"
#include "binder/call/bound_call_config.h"
#include "binder/call/bound_call_table_func.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression/variable_expression.h"
#include "parser/call/call.h"

namespace kuzu {
namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCallClause(const parser::Statement& statement) {
/**
 * Binds a `CALL <table function>(<literal>)` statement.
 *
 * Resolves the table function by the name carried in the parsed statement, binds the single
 * literal argument, runs the function's bind step against the read-only catalog version, and
 * exposes every returned column as a variable expression in the statement result.
 *
 * @param statement parsed statement; treated as a parser::Call.
 * @return a BoundCallTableFunc owning the bind data and the result schema.
 */
std::unique_ptr<BoundStatement> Binder::bindCallTableFunc(const parser::Statement& statement) {
    auto& call = reinterpret_cast<const parser::Call&>(statement);
    // The parsed call stores the function name in its "option name" slot.
    auto definition = catalog.getBuiltInTableOperation()->mathTableOperation(call.getOptionName());

    // The argument is bound as a literal so that its value can be handed to the bind function.
    auto boundArgument = expressionBinder.bindLiteralExpression(*call.getOptionValue());
    auto argumentValue = reinterpret_cast<LiteralExpression*>(boundArgument.get())->getValue();
    std::vector<common::Value> arguments{*argumentValue};
    auto bindData = definition->bindFunc(
        function::TableFuncBindInput{std::move(arguments)}, catalog.getReadOnlyVersion());

    // One output column per (name, type) pair reported by the bind step; the column name is
    // used both as the expression's unique name and as its variable name.
    auto result = std::make_unique<BoundStatementResult>();
    const auto numColumns = bindData->returnColumnNames.size();
    for (auto columnIdx = 0u; columnIdx < numColumns; ++columnIdx) {
        const auto& columnName = bindData->returnColumnNames[columnIdx];
        auto column = std::make_shared<VariableExpression>(
            bindData->returnTypes[columnIdx], columnName, columnName);
        result->addColumn(column, expression_vector{column});
    }
    return std::make_unique<BoundCallTableFunc>(
        definition->tableFunc, std::move(bindData), std::move(result));
}

std::unique_ptr<BoundStatement> Binder::bindCallConfig(const parser::Statement& statement) {
auto& callStatement = reinterpret_cast<const parser::Call&>(statement);
auto option = main::DBConfig::getOptionByName(callStatement.getOptionName());
if (option == nullptr) {
Expand All @@ -16,8 +37,7 @@ std::unique_ptr<BoundStatement> Binder::bindCallClause(const parser::Statement&
auto optionValue = expressionBinder.bindLiteralExpression(*callStatement.getOptionValue());
// TODO(Ziyi): add casting rule for option value.
ExpressionBinder::validateExpectedDataType(*optionValue, option->parameterType);
auto boundCall = std::make_unique<BoundCall>(*option, std::move(optionValue));
return boundCall;
return std::make_unique<BoundCallConfig>(*option, std::move(optionValue));
}

} // namespace binder
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(

std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
common::ExpressionType expressionType, const expression_vector& children) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
auto builtInFunctions = binder->catalog.getBuiltInVectorOperation();
auto functionName = expressionTypeToString(expressionType);
std::vector<common::LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes);
auto function = builtInFunctions->matchVectorOperation(functionName, childrenTypes);
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
childrenAfterCast.push_back(
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInScalarFunctions();
auto builtInFunctions = binder->catalog.getBuiltInVectorOperation();
std::vector<LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes);
auto function = builtInFunctions->matchVectorOperation(functionName, childrenTypes);
if (builtInFunctions->canApplyStaticEvaluation(functionName, children)) {
return staticEvaluate(functionName, children);
}
Expand Down
7 changes: 5 additions & 2 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {
case StatementType::QUERY: {
return bindQuery((const RegularQuery&)statement);
}
case StatementType::CALL: {
return bindCallClause(statement);
case StatementType::CALL_CONFIG: {
return bindCallConfig(statement);
}
case StatementType::CALL_TABLE_FUNC: {
return bindCallTableFunc(statement);
}
default:
assert(false);
Expand Down
7 changes: 5 additions & 2 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement)
case StatementType::COPY: {
visitCopy(statement);
} break;
case StatementType::CALL: {
visitCall(statement);
case StatementType::CALL_CONFIG: {
visitCallConfig(statement);
} break;
case StatementType::CALL_TABLE_FUNC: {
visitCallTableFunc(statement);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visit");
Expand Down
2 changes: 2 additions & 0 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,14 @@ Catalog::Catalog() : wal{nullptr} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>();
builtInVectorOperations = std::make_unique<function::BuiltInVectorOperations>();
builtInAggregateFunctions = std::make_unique<function::BuiltInAggregateFunctions>();
builtInTableOperations = std::make_unique<function::BuiltInTableOperations>();
}

// Constructs a catalog bound to the given WAL.
// The read-only catalog content is built from wal->getDirectory(), and fresh registries for the
// built-in vector operations, aggregate functions and table operations are created.
// `wal` is dereferenced unconditionally here, so callers must pass a non-null pointer.
Catalog::Catalog(WAL* wal) : wal{wal} {
catalogContentForReadOnlyTrx = std::make_unique<CatalogContent>(wal->getDirectory());
builtInVectorOperations = std::make_unique<function::BuiltInVectorOperations>();
builtInAggregateFunctions = std::make_unique<function::BuiltInAggregateFunctions>();
builtInTableOperations = std::make_unique<function::BuiltInTableOperations>();
}

void Catalog::prepareCommitOrRollback(TransactionAction action) {
Expand Down
2 changes: 2 additions & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ add_library(kuzu_function
base_lower_upper_operation.cpp
built_in_aggregate_functions.cpp
built_in_vector_operations.cpp
built_in_table_operations.cpp
find_operation.cpp
table_operations.cpp
vector_arithmetic_operations.cpp
vector_boolean_operations.cpp
vector_cast_operations.cpp
Expand Down
24 changes: 24 additions & 0 deletions src/function/built_in_table_operations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "function/built_in_table_operations.h"

#include "common/expression_type.h"
#include "common/string_utils.h"

namespace kuzu {
namespace function {

/**
 * Populates the registry of built-in table operations.
 * Currently registers only the table-info operation, keyed by TABLE_INFO_FUNC_NAME.
 */
void BuiltInTableOperations::registerTableOperations() {
    tableOperations.emplace(common::TABLE_INFO_FUNC_NAME, TableInfoOperation::getDefinitions());
}

/**
 * Looks up a built-in table operation by name.
 *
 * The lookup key is a copy of `name` normalised with StringUtils::toUpper, so callers may use
 * any letter case (the registry key for the table-info operation is upper case).
 *
 * @param name function name as written by the user.
 * @return non-owning pointer to the registered definition; the registry keeps ownership.
 * @throws common::BinderException if no table operation is registered under that name.
 */
TableOperationDefinition* BuiltInTableOperations::mathTableOperation(const std::string& name) {
    auto upperName = name;
    common::StringUtils::toUpper(upperName);
    // A single find() replaces the former contains() + at() pair (two hash lookups).
    const auto entry = tableOperations.find(upperName);
    if (entry == tableOperations.end()) {
        // NOTE(review): the coverage report attached to this change flagged this throw path as
        // not exercised by tests.
        // The original (unnormalised) name is reported so the message matches the user's input.
        throw common::BinderException{
            "Cannot match a built-in function for given function " + name + "."};
    }
    return entry->second.get();
}

} // namespace function
} // namespace kuzu
2 changes: 1 addition & 1 deletion src/function/built_in_vector_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ bool BuiltInVectorOperations::canApplyStaticEvaluation(
return false;
}

VectorOperationDefinition* BuiltInVectorOperations::matchFunction(
VectorOperationDefinition* BuiltInVectorOperations::matchVectorOperation(
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
const std::string& name, const std::vector<LogicalType>& inputTypes) {
auto& functionDefinitions = vectorOperations.at(name);
bool isOverload = functionDefinitions.size() > 1;
Expand Down
56 changes: 56 additions & 0 deletions src/function/table_operations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#include "function/table_operations.h"

#include "catalog/catalog.h"

namespace kuzu {
namespace function {

/**
 * Bind step of the table-info table function.
 *
 * Reads the table name from the first input value, resolves its schema through the given
 * catalog content, and declares the output columns: "property id" (INT64), "name" (STRING),
 * "type" (STRING) and, for node tables only, "primary key" (BOOL).
 *
 * @param input   bind input; inputs[0] is read as a std::string table name.
 * @param catalog catalog content used to resolve the table ID and schema.
 * @return bind data holding the schema, the column types and names, and the schema's
 *         property count.
 */
std::unique_ptr<TableInfoBindData> TableInfoOperation::bindFunc(
    kuzu::function::TableFuncBindInput input, catalog::CatalogContent* catalog) {
    auto tableName = input.inputs[0].getValue<std::string>();
    auto schema = catalog->getTableSchema(catalog->getTableID(tableName));

    std::vector<std::string> columnNames;
    std::vector<common::LogicalType> columnTypes;
    // Keeps the name list and the type list in lock-step.
    const auto addColumn = [&columnNames, &columnTypes](
                               const char* columnName, common::LogicalTypeID typeID) {
        columnNames.emplace_back(columnName);
        columnTypes.emplace_back(typeID);
    };
    addColumn("property id", common::LogicalTypeID::INT64);
    addColumn("name", common::LogicalTypeID::STRING);
    addColumn("type", common::LogicalTypeID::STRING);
    // Only node tables get the extra primary-key column.
    if (schema->isNodeTable) {
        addColumn("primary key", common::LogicalTypeID::BOOL);
    }
    return std::make_unique<TableInfoBindData>(
        schema, std::move(columnTypes), std::move(columnNames), schema->getNumProperties());
}

/**
 * Execution step of the table-info table function.
 *
 * Writes one output row per property in the half-open range [morsel.first, morsel.second) of
 * the bound table schema. For tables that are not node tables, the property whose name equals
 * InternalKeyword::ID is skipped.
 *
 * Output layout: vector 0 = property id (int64), vector 1 = property name, vector 2 = data
 * type rendered as a string, and - for node tables only - vector 3 = whether the property is
 * the primary key.
 *
 * @param morsel        [first, second) range of property indices to emit.
 * @param bindData      bind data produced by bindFunc; treated as TableInfoBindData.
 * @param outputVectors destination vectors; the selected size of each is set to the number
 *                      of rows actually written.
 */
void TableInfoOperation::tableFunc(std::pair<common::offset_t, common::offset_t> morsel,
    function::TableFuncBindData* bindData, std::vector<common::ValueVector*> outputVectors) {
    auto tableSchema = reinterpret_cast<function::TableInfoBindData*>(bindData)->tableSchema;
    const bool isNodeTable = tableSchema->isNodeTable;
    auto numPropertiesToOutput = morsel.second - morsel.first;
    auto outVectorPos = 0;
    for (auto i = 0u; i < numPropertiesToOutput; i++) {
        // Bind by reference: the property is only read, so a per-iteration copy is unnecessary.
        const auto& property = tableSchema->properties[morsel.first + i];
        if (!isNodeTable && property.name == common::InternalKeyword::ID) {
            continue;
        }
        outputVectors[0]->setValue(outVectorPos, static_cast<int64_t>(property.propertyID));
        outputVectors[1]->setValue(outVectorPos, property.name);
        outputVectors[2]->setValue(
            outVectorPos, common::LogicalTypeUtils::dataTypeToString(property.dataType));
        if (isNodeTable) {
            // Always write the flag. Previously only the primary-key row was written, which
            // left this slot untouched for every other property, so its content was whatever
            // the vector already held.
            const bool isPrimaryKey =
                reinterpret_cast<catalog::NodeTableSchema*>(tableSchema)->primaryKeyPropertyID ==
                property.propertyID;
            outputVectors[3]->setValue(outVectorPos, isPrimaryKey);
        }
        outVectorPos++;
    }
    // Skipped properties make the row count smaller than the morsel size, so publish the
    // number of rows actually written.
    for (auto& outputVector : outputVectors) {
        outputVector->state->selVector->selectedSize = outVectorPos;
    }
}

} // namespace function
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ class Binder {
std::unique_ptr<BoundQueryPart> bindQueryPart(const parser::QueryPart& queryPart);

/*** bind call ***/
std::unique_ptr<BoundStatement> bindCallClause(const parser::Statement& statement);
std::unique_ptr<BoundStatement> bindCallTableFunc(const parser::Statement& statement);
std::unique_ptr<BoundStatement> bindCallConfig(const parser::Statement& statement);

/*** bind reading clause ***/
std::unique_ptr<BoundReadingClause> bindReadingClause(
Expand Down
3 changes: 2 additions & 1 deletion src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ class BoundStatementVisitor {
virtual void visitDropProperty(const BoundStatement& statement) {}
virtual void visitRenameProperty(const BoundStatement& statement) {}
virtual void visitCopy(const BoundStatement& statement) {}
virtual void visitCall(const BoundStatement& statement) {}
virtual void visitCallConfig(const BoundStatement& statement) {}
virtual void visitCallTableFunc(const BoundStatement& statement) {}

void visitReadingClause(const BoundReadingClause& readingClause);
virtual void visitMatch(const BoundReadingClause& readingClause) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
namespace kuzu {
namespace binder {

class BoundCall : public BoundStatement {
class BoundCallConfig : public BoundStatement {
public:
BoundCall(main::ConfigurationOption option, std::shared_ptr<Expression> optionValue)
: BoundStatement{common::StatementType::CALL, BoundStatementResult::createEmptyResult()},
BoundCallConfig(main::ConfigurationOption option, std::shared_ptr<Expression> optionValue)
: BoundStatement{common::StatementType::CALL_CONFIG,
BoundStatementResult::createEmptyResult()},
option{option}, optionValue{optionValue} {}

inline main::ConfigurationOption getOption() const { return option; }
Expand Down
27 changes: 27 additions & 0 deletions src/include/binder/call/bound_call_table_func.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#pragma once

#include "binder/expression/expression.h"
#include "function/table_operations.h"

namespace kuzu {
namespace binder {

/**
 * Bound form of a CALL statement that invokes a table function.
 * Holds the function to execute together with the bind data produced for it; the statement
 * type is CALL_TABLE_FUNC and the result schema is supplied by the caller.
 */
class BoundCallTableFunc : public BoundStatement {
public:
    /**
     * @param tableFunc       table function to run at execution time.
     * @param bindData        bind data for the function; ownership is taken.
     * @param statementResult result columns of the statement; ownership is passed to the base.
     */
    BoundCallTableFunc(function::table_func_t tableFunc,
        std::unique_ptr<function::TableFuncBindData> bindData,
        std::unique_ptr<BoundStatementResult> statementResult)
        : BoundStatement{common::StatementType::CALL_TABLE_FUNC, std::move(statementResult)},
          tableFunc{std::move(tableFunc)}, bindData{std::move(bindData)} {}

    // Returns the stored table function by value.
    function::table_func_t getTableFunc() const { return tableFunc; }

    // Returns a non-owning pointer to the bind data; this object keeps ownership.
    function::TableFuncBindData* getBindData() const { return bindData.get(); }

private:
    function::table_func_t tableFunc;
    std::unique_ptr<function::TableFuncBindData> bindData;
};

} // namespace binder
} // namespace kuzu
7 changes: 6 additions & 1 deletion src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "common/ser_deser.h"
#include "common/utils.h"
#include "function/aggregate/built_in_aggregate_functions.h"
#include "function/built_in_table_operations.h"
#include "function/built_in_vector_operations.h"
#include "storage/storage_info.h"
#include "storage/wal/wal.h"
Expand Down Expand Up @@ -164,12 +165,15 @@ class Catalog {
inline CatalogContent* getReadOnlyVersion() const { return catalogContentForReadOnlyTrx.get(); }
inline CatalogContent* getWriteVersion() const { return catalogContentForWriteTrx.get(); }

inline function::BuiltInVectorOperations* getBuiltInScalarFunctions() const {
inline function::BuiltInVectorOperations* getBuiltInVectorOperation() const {
return builtInVectorOperations.get();
}
inline function::BuiltInAggregateFunctions* getBuiltInAggregateFunction() const {
return builtInAggregateFunctions.get();
}
inline function::BuiltInTableOperations* getBuiltInTableOperation() const {
return builtInTableOperations.get();
}

void prepareCommitOrRollback(transaction::TransactionAction action);
void checkpointInMemory();
Expand Down Expand Up @@ -217,6 +221,7 @@ class Catalog {
protected:
std::unique_ptr<function::BuiltInVectorOperations> builtInVectorOperations;
std::unique_ptr<function::BuiltInAggregateFunctions> builtInAggregateFunctions;
std::unique_ptr<function::BuiltInTableOperations> builtInTableOperations;
std::unique_ptr<CatalogContent> catalogContentForReadOnlyTrx;
std::unique_ptr<CatalogContent> catalogContentForWriteTrx;
storage::WAL* wal;
Expand Down
3 changes: 3 additions & 0 deletions src/include/common/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ const std::string OCTET_LENGTH_FUNC_NAME = "OCTET_LENGTH";
const std::string ENCODE_FUNC_NAME = "ENCODE";
const std::string DECODE_FUNC_NAME = "DECODE";

// TABLE functions
const std::string TABLE_INFO_FUNC_NAME = "TABLE_INFO";

enum ExpressionType : uint8_t {

// Boolean Connection Expressions
Expand Down
3 changes: 2 additions & 1 deletion src/include/common/statement_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ enum class StatementType : uint8_t {
DROP_PROPERTY = 7,
RENAME_PROPERTY = 8,
COPY = 20,
CALL = 21,
CALL_CONFIG = 21,
CALL_TABLE_FUNC = 22,
};

class StatementTypeUtils {
Expand Down
24 changes: 24 additions & 0 deletions src/include/function/built_in_table_operations.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include "table_operations.h"

namespace kuzu {
namespace function {

/**
 * Registry of the built-in table operations, keyed by function name.
 * All operations are registered when the registry is constructed.
 */
class BuiltInTableOperations {
public:
    BuiltInTableOperations() { registerTableOperations(); }

    // Returns the definition registered for `name`; the registry retains ownership.
    TableOperationDefinition* mathTableOperation(const std::string& name);

private:
    // Fills `tableOperations`; called once from the constructor.
    void registerTableOperations();

    // TODO(Ziyi): Refactor vectorOperation/tableOperation to inherit from the same base class.
    std::unordered_map<std::string, std::unique_ptr<TableOperationDefinition>> tableOperations;
};

} // namespace function
} // namespace kuzu
3 changes: 2 additions & 1 deletion src/include/function/built_in_vector_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class BuiltInVectorOperations {
bool canApplyStaticEvaluation(
const std::string& functionName, const binder::expression_vector& children);

VectorOperationDefinition* matchFunction(
VectorOperationDefinition* matchVectorOperation(
const std::string& name, const std::vector<common::LogicalType>& inputTypes);

std::vector<std::string> getFunctionNames();
Expand Down Expand Up @@ -78,6 +78,7 @@ class BuiltInVectorOperations {
void registerNodeRelOperations();

private:
// TODO(Ziyi): Refactor vectorOperation/tableOperation to inherit from the same base class.
std::unordered_map<std::string, vector_operation_definitions> vectorOperations;
};

Expand Down
Loading
Loading