Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement function catalog-entry #2910

Merged
merged 1 commit into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
for (auto& val : inputValues) {
inputTypes.push_back(*val.getDataType());
}
auto func = catalog.getBuiltInFunctions(clientContext->getTx())
->matchFunction(functionExpr->getFunctionName(), inputTypes);
auto func = BuiltInFunctionsUtils::matchFunction(
functionExpr->getFunctionName(), inputTypes, catalog.getFunctions(clientContext->getTx()));
auto tableFunc = ku_dynamic_cast<function::Function*, function::TableFunction*>(func);
auto bindInput = std::make_unique<function::TableFuncBindInput>();
bindInput->inputs = std::move(inputValues);
Expand Down Expand Up @@ -157,9 +157,9 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(const ReadingClause& re
auto objectExpr = expressionBinder.bindVariableExpression(objectName);
auto literalExpr =
ku_dynamic_cast<const Expression*, const LiteralExpression*>(objectExpr.get());
auto func = catalog.getBuiltInFunctions(clientContext->getTx())
->matchFunction(READ_PANDAS_FUNC_NAME,
std::vector<LogicalType>{objectExpr->getDataType()});
auto func = BuiltInFunctionsUtils::matchFunction(READ_PANDAS_FUNC_NAME,
std::vector<LogicalType>{objectExpr->getDataType()},
catalog.getFunctions(clientContext->getTx()));
scanFunction = ku_dynamic_cast<Function*, TableFunction*>(func);
bindInput = std::make_unique<function::TableFuncBindInput>();
bindInput->inputs.push_back(*literalExpr->getValue());
Expand Down
26 changes: 16 additions & 10 deletions src/binder/bind/copy/bind_copy_rdf_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&,
std::unique_ptr<ReaderConfig> config, RDFGraphCatalogEntry* rdfGraphEntry) {
auto functions = catalog.getBuiltInFunctions(clientContext->getTx());
auto functions = catalog.getFunctions(clientContext->getTx());
auto offset = expressionBinder.createVariableExpression(
*LogicalType::INT64(), InternalKeyword::ROW_OFFSET);
auto r = expressionBinder.createVariableExpression(*LogicalType::STRING(), rdf::IRI);
Expand All @@ -36,15 +36,16 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&
Function* func;
// Bind file scan;
auto inMemory = RdfReaderConfig::construct(config->options).inMemory;
func = functions->matchFunction(READ_RDF_ALL_TRIPLE_FUNC_NAME);
func = BuiltInFunctionsUtils::matchFunction(READ_RDF_ALL_TRIPLE_FUNC_NAME, functions);
auto scanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto bindData =
scanFunc->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog, storageManager);
auto scanInfo = std::make_unique<BoundFileScanInfo>(
scanFunc, bindData->copy(), expression_vector{}, offset);
// Bind copy resource.
func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_RESOURCE_FUNC_NAME) :
functions->matchFunction(READ_RDF_RESOURCE_FUNC_NAME);
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(IN_MEM_READ_RDF_RESOURCE_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_FUNC_NAME, functions);
auto rScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rColumns = expression_vector{r};
auto rScanInfo = std::make_unique<BoundFileScanInfo>(
Expand All @@ -53,8 +54,9 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&
auto rSchema = catalog.getTableCatalogEntry(clientContext->getTx(), rTableID);
auto rCopyInfo = BoundCopyFromInfo(rSchema, std::move(rScanInfo), false, nullptr);
// Bind copy literal.
func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_LITERAL_FUNC_NAME) :
functions->matchFunction(READ_RDF_LITERAL_FUNC_NAME);
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(IN_MEM_READ_RDF_LITERAL_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_FUNC_NAME, functions);
auto lScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto lColumns = expression_vector{l, lang};
auto lScanInfo = std::make_unique<BoundFileScanInfo>(
Expand All @@ -63,8 +65,10 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&
auto lSchema = catalog.getTableCatalogEntry(clientContext->getTx(), lTableID);
auto lCopyInfo = BoundCopyFromInfo(lSchema, std::move(lScanInfo), true, nullptr);
// Bind copy resource triples
func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME) :
functions->matchFunction(READ_RDF_RESOURCE_TRIPLE_FUNC_NAME);
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(
IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, functions);
auto rrrScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rrrColumns = expression_vector{s, p, o};
auto rrrScanInfo =
Expand All @@ -83,8 +87,10 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&
auto rrrCopyInfo =
BoundCopyFromInfo(rrrSchema, std::move(rrrScanInfo), false, std::move(rrrExtraInfo));
// Bind copy literal triples
func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME) :
functions->matchFunction(READ_RDF_LITERAL_TRIPLE_FUNC_NAME);
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(
IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_TRIPLE_FUNC_NAME, functions);
auto rrlScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rrlColumns = expression_vector{s, p, oOffset};
auto rrlScanInfo =
Expand Down
5 changes: 3 additions & 2 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(

std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
ExpressionType expressionType, const expression_vector& children) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions(binder->clientContext->getTx());
auto builtInFunctions = binder->catalog.getFunctions(binder->clientContext->getTx());
auto functionName = expressionTypeToString(expressionType);
std::vector<LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = ku_dynamic_cast<function::Function*, function::ScalarFunction*>(
builtInFunctions->matchFunction(functionName, childrenTypes));
function::BuiltInFunctionsUtils::matchFunction(
functionName, childrenTypes, builtInFunctions));
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
childrenAfterCast.push_back(
Expand Down
10 changes: 5 additions & 5 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions(binder->clientContext->getTx());
std::vector<LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = ku_dynamic_cast<function::Function*, function::ScalarFunction*>(
builtInFunctions->matchFunction(functionName, childrenTypes));
function::BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes,
binder->catalog.getFunctions(binder->clientContext->getTx())));
expression_vector childrenAfterCast;
std::unique_ptr<function::FunctionBindData> bindData;
if (functionName == CAST_FUNC_NAME) {
Expand Down Expand Up @@ -98,7 +98,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(

std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions(binder->clientContext->getTx());
std::vector<LogicalType> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
Expand All @@ -111,8 +110,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
auto function =
builtInFunctions->matchAggregateFunction(functionName, childrenTypes, isDistinct)->clone();
auto function = function::BuiltInFunctionsUtils::matchAggregateFunction(functionName,
childrenTypes, isDistinct, binder->catalog.getFunctions(binder->clientContext->getTx()))
->clone();
if (function->paramRewriteFunc) {
function->paramRewriteFunc(children);
}
Expand Down
7 changes: 3 additions & 4 deletions src/binder/bind_expression/bind_subquery_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "binder/expression/subquery_expression.h"
#include "binder/expression_binder.h"
#include "common/types/value/value.h"
#include "main/client_context.h"
#include "parser/expression/parsed_subquery_expression.h"

using namespace kuzu::parser;
Expand Down Expand Up @@ -32,9 +31,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindSubqueryExpression(
std::move(boundGraphPattern.queryGraphCollection), uniqueName, std::move(rawName));
boundSubqueryExpr->setWhereExpression(boundGraphPattern.where);
// Bind projection
auto function =
binder->catalog.getBuiltInFunctions(binder->clientContext->getTx())
->matchAggregateFunction(COUNT_STAR_FUNC_NAME, std::vector<LogicalType>{}, false);
auto function = BuiltInFunctionsUtils::matchAggregateFunction(COUNT_STAR_FUNC_NAME,
std::vector<LogicalType>{}, false,
binder->catalog.getFunctions(binder->clientContext->getTx()));
auto bindData =
std::make_unique<FunctionBindData>(std::make_unique<LogicalType>(function->returnTypeID));
auto countStarExpr = std::make_shared<AggregateFunctionExpression>(COUNT_STAR_FUNC_NAME,
Expand Down
12 changes: 7 additions & 5 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,19 +209,21 @@ function::TableFunction* Binder::getScanFunction(FileType fileType, const Reader
auto stringType = LogicalType(LogicalTypeID::STRING);
std::vector<LogicalType> inputTypes;
inputTypes.push_back(stringType);
auto functions = catalog.getBuiltInFunctions(clientContext->getTx());
auto functions = catalog.getFunctions(clientContext->getTx());
switch (fileType) {
case FileType::PARQUET: {
func = functions->matchFunction(READ_PARQUET_FUNC_NAME, inputTypes);
func = function::BuiltInFunctionsUtils::matchFunction(
READ_PARQUET_FUNC_NAME, inputTypes, functions);
} break;
case FileType::NPY: {
func = functions->matchFunction(READ_NPY_FUNC_NAME, inputTypes);
func = function::BuiltInFunctionsUtils::matchFunction(
READ_NPY_FUNC_NAME, inputTypes, functions);
} break;
case FileType::CSV: {
auto csvConfig = CSVReaderConfig::construct(config.options);
func = functions->matchFunction(
func = function::BuiltInFunctionsUtils::matchFunction(
csvConfig.parallel ? READ_CSV_PARALLEL_FUNC_NAME : READ_CSV_SERIAL_FUNC_NAME,
inputTypes);
inputTypes, functions);
} break;
default:
KU_UNREACHABLE;
Expand Down
10 changes: 8 additions & 2 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ std::vector<TableCatalogEntry*> Catalog::getTableSchemas(
return result;
}

// Returns the function catalog set of the catalog version selected for `tx`.
// The result is the raw pointer obtained from that version's `functions` member via get().
CatalogSet* Catalog::getFunctions(Transaction* tx) const {
    auto* version = getVersion(tx);
    return version->functions.get();
}

void Catalog::prepareCommitOrRollback(TransactionAction action) {
if (hasUpdates()) {
wal->logCatalogRecord();
Expand Down Expand Up @@ -265,8 +269,10 @@ void Catalog::addScalarMacroFunction(

std::vector<std::string> Catalog::getMacroNames(transaction::Transaction* tx) const {
std::vector<std::string> macroNames;
for (auto& macro : getVersion(tx)->macros) {
macroNames.push_back(macro.first);
for (auto& [_, function] : getVersion(tx)->functions->getEntries()) {
if (function->getType() == CatalogEntryType::SCALAR_MACRO_ENTRY) {
macroNames.push_back(function->getName());
}
}
return macroNames;
}
Expand Down
53 changes: 31 additions & 22 deletions src/catalog/catalog_content.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "catalog/catalog_entry/rdf_graph_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "catalog/catalog_entry/scalar_function_catalog_entry.h"
#include "catalog/catalog_entry/scalar_macro_catalog_entry.h"
#include "common/cast.h"
#include "common/exception/catalog.h"
#include "common/exception/runtime.h"
Expand All @@ -15,7 +17,6 @@
#include "common/serializer/deserializer.h"
#include "common/serializer/serializer.h"
#include "common/string_format.h"
#include "common/string_utils.h"
#include "storage/storage_info.h"
#include "storage/storage_utils.h"

Expand All @@ -27,8 +28,9 @@ namespace kuzu {
namespace catalog {

CatalogContent::CatalogContent(common::VirtualFileSystem* vfs) : nextTableID{0}, vfs{vfs} {
registerBuiltInFunctions();
tables = std::make_unique<CatalogSet>();
functions = std::make_unique<CatalogSet>();
registerBuiltInFunctions();
}

CatalogContent::CatalogContent(const std::string& directory, VirtualFileSystem* vfs) : vfs{vfs} {
Expand Down Expand Up @@ -180,7 +182,7 @@ void CatalogContent::saveToFile(const std::string& directory, FileVersionType db
serializer.serializeValue(StorageVersionInfo::getStorageVersion());
tables->serialize(serializer);
serializer.serializeValue(nextTableID);
serializer.serializeUnorderedMap(macros);
functions->serialize(serializer);
}

void CatalogContent::readFromFile(const std::string& directory, FileVersionType dbFileType) {
Expand All @@ -197,47 +199,54 @@ void CatalogContent::readFromFile(const std::string& directory, FileVersionType
ku_dynamic_cast<CatalogEntry*, TableCatalogEntry*>(entry.get())->getTableID();
}
deserializer.deserializeValue(nextTableID);
deserializer.deserializeUnorderedMap(macros);
functions = CatalogSet::deserialize(deserializer);
}

ExpressionType CatalogContent::getFunctionType(const std::string& name) const {
auto normalizedName = StringUtils::getUpper(name);
if (macros.contains(normalizedName)) {
return ExpressionType::MACRO;
if (!functions->containsEntry(name)) {
throw CatalogException{common::stringFormat("function {} does not exist.", name)};
}
auto functionType = builtInFunctions->getFunctionType(name);
switch (functionType) {
case function::FunctionType::SCALAR:
auto functionEntry = functions->getEntry(name);
switch (functionEntry->getType()) {
case CatalogEntryType::SCALAR_MACRO_ENTRY:
return ExpressionType::MACRO;
case CatalogEntryType::SCALAR_FUNCTION_ENTRY:
return ExpressionType::FUNCTION;
case function::FunctionType::AGGREGATE:
case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY:
return ExpressionType::AGGREGATE_FUNCTION;
default:
KU_UNREACHABLE;
}
}

void CatalogContent::addFunction(std::string name, function::function_set definitions) {
StringUtils::toUpper(name);
builtInFunctions->addFunction(std::move(name), std::move(definitions));
if (functions->containsEntry(name)) {
throw CatalogException{common::stringFormat("function {} already exists.", name)};
}
functions->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(std::move(name), std::move(definitions)));
}

void CatalogContent::addScalarMacroFunction(
std::string name, std::unique_ptr<function::ScalarMacroFunction> macro) {
StringUtils::toUpper(name);
macros.emplace(std::move(name), std::move(macro));
functions->createEntry(
std::make_unique<ScalarMacroCatalogEntry>(std::move(name), std::move(macro)));
}

// Looks up the entry registered under `name` in `functions`, casts it to a
// ScalarMacroCatalogEntry and returns the macro function it exposes.
// NOTE(review): no check is made here that the entry exists or is a macro entry —
// presumably callers verify that first; confirm against call sites.
function::ScalarMacroFunction* CatalogContent::getScalarMacroFunction(
    const std::string& name) const {
    auto* entry = functions->getEntry(name);
    auto* macroEntry = ku_dynamic_cast<CatalogEntry*, ScalarMacroCatalogEntry*>(entry);
    return macroEntry->getMacroFunction();
}

std::unique_ptr<CatalogContent> CatalogContent::copy() const {
std::unordered_map<std::string, std::unique_ptr<function::ScalarMacroFunction>> macrosToCopy;
for (auto& macro : macros) {
macrosToCopy.emplace(macro.first, macro.second->copy());
}
return std::make_unique<CatalogContent>(tables->copy(), tableNameToIDMap, nextTableID,
builtInFunctions->copy(), std::move(macrosToCopy), vfs);
return std::make_unique<CatalogContent>(
tables->copy(), tableNameToIDMap, nextTableID, functions->copy(), vfs);
}

void CatalogContent::registerBuiltInFunctions() {
builtInFunctions = std::make_unique<function::BuiltInFunctions>();
function::BuiltInFunctionsUtils::createFunctions(functions.get());
}

bool CatalogContent::containsTable(const std::string& tableName) const {
Expand All @@ -264,7 +273,7 @@ CatalogEntry* CatalogContent::getTableCatalogEntry(table_id_t tableID) const {
return table.get();
}
}
KU_ASSERT(false);
KU_UNREACHABLE;
}

common::table_id_t CatalogContent::getTableID(const std::string& tableName) const {
Expand Down
7 changes: 6 additions & 1 deletion src/catalog/catalog_entry/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
add_library(kuzu_catalog_entry
OBJECT
aggregate_function_catalog_entry.cpp
catalog_entry.cpp
function_catalog_entry.cpp
table_catalog_entry.cpp
node_table_catalog_entry.cpp
rel_table_catalog_entry.cpp
rel_group_catalog_entry.cpp
rdf_graph_catalog_entry.cpp)
rdf_graph_catalog_entry.cpp
scalar_macro_catalog_entry.cpp
scalar_function_catalog_entry.cpp
table_function_catalog_entry.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_catalog_entry>
Expand Down
19 changes: 19 additions & 0 deletions src/catalog/catalog_entry/aggregate_function_catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "catalog/catalog_entry/aggregate_function_catalog_entry.h"

#include "common/utils.h"

namespace kuzu {
namespace catalog {

// Constructs an aggregate-function catalog entry by forwarding `name` and `functionSet`
// to the FunctionCatalogEntry base, tagged with CatalogEntryType::AGGREGATE_FUNCTION_ENTRY.
AggregateFunctionCatalogEntry::AggregateFunctionCatalogEntry(
    std::string name, function::function_set functionSet)
    : FunctionCatalogEntry{CatalogEntryType::AGGREGATE_FUNCTION_ENTRY,
          std::move(name),
          std::move(functionSet)} {}

// Creates a new AggregateFunctionCatalogEntry carrying this entry's name and a copy of its
// function set (produced by common::copyVector).
std::unique_ptr<CatalogEntry> AggregateFunctionCatalogEntry::copy() const {
    auto copiedFunctionSet = common::copyVector(functionSet);
    return std::make_unique<AggregateFunctionCatalogEntry>(
        getName(), std::move(copiedFunctionSet));
}

} // namespace catalog
} // namespace kuzu
8 changes: 6 additions & 2 deletions src/catalog/catalog_entry/catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "catalog/catalog_entry/catalog_entry.h"

#include "catalog/catalog_entry/scalar_macro_catalog_entry.h"
#include "catalog/catalog_entry/table_catalog_entry.h"

namespace kuzu {
Expand All @@ -20,9 +21,12 @@ std::unique_ptr<CatalogEntry> CatalogEntry::deserialize(common::Deserializer& de
case CatalogEntryType::NODE_TABLE_ENTRY:
case CatalogEntryType::REL_TABLE_ENTRY:
case CatalogEntryType::REL_GROUP_ENTRY:
case CatalogEntryType::RDF_GRAPH_ENTRY:
case CatalogEntryType::RDF_GRAPH_ENTRY: {
entry = TableCatalogEntry::deserialize(deserializer, type);
break;
} break;
case CatalogEntryType::SCALAR_MACRO_ENTRY: {
entry = ScalarMacroCatalogEntry::deserialize(deserializer);
} break;
default:
KU_UNREACHABLE;
}
Expand Down
Loading
Loading