Skip to content

Commit

Permalink
Implement function catalog-entry
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Feb 19, 2024
1 parent 0d57b0b commit 67a814e
Show file tree
Hide file tree
Showing 42 changed files with 1,649 additions and 1,167 deletions.
10 changes: 5 additions & 5 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
for (auto& val : inputValues) {
inputTypes.push_back(*val.getDataType());
}
auto func = catalog.getBuiltInFunctions(clientContext->getTx())
->matchFunction(functionExpr->getFunctionName(), inputTypes);
auto func = BuiltInFunctionsUtils::matchFunction(
functionExpr->getFunctionName(), inputTypes, catalog.getFunctions(clientContext->getTx()));
auto tableFunc = ku_dynamic_cast<function::Function*, function::TableFunction*>(func);
auto bindInput = std::make_unique<function::TableFuncBindInput>();
bindInput->inputs = std::move(inputValues);
Expand Down Expand Up @@ -157,9 +157,9 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(const ReadingClause& re
auto objectExpr = expressionBinder.bindVariableExpression(objectName);
auto literalExpr =
ku_dynamic_cast<const Expression*, const LiteralExpression*>(objectExpr.get());
auto func = catalog.getBuiltInFunctions(clientContext->getTx())
->matchFunction(READ_PANDAS_FUNC_NAME,
std::vector<LogicalType>{objectExpr->getDataType()});
auto func = BuiltInFunctionsUtils::matchFunction(READ_PANDAS_FUNC_NAME,
std::vector<LogicalType>{objectExpr->getDataType()},
catalog.getFunctions(clientContext->getTx()));
scanFunction = ku_dynamic_cast<Function*, TableFunction*>(func);
bindInput = std::make_unique<function::TableFuncBindInput>();
bindInput->inputs.push_back(*literalExpr->getValue());
Expand Down
26 changes: 16 additions & 10 deletions src/binder/bind/copy/bind_copy_rdf_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&,
std::unique_ptr<ReaderConfig> config, RDFGraphCatalogEntry* rdfGraphEntry) {
auto functions = catalog.getBuiltInFunctions(clientContext->getTx());
auto functions = catalog.getFunctions(clientContext->getTx());
auto offset = expressionBinder.createVariableExpression(
*LogicalType::INT64(), InternalKeyword::ROW_OFFSET);
auto r = expressionBinder.createVariableExpression(*LogicalType::STRING(), rdf::IRI);
Expand All @@ -36,15 +36,16 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&
Function* func;
// Bind file scan;
auto inMemory = RdfReaderConfig::construct(config->options).inMemory;
func = functions->matchFunction(READ_RDF_ALL_TRIPLE_FUNC_NAME);
func = BuiltInFunctionsUtils::matchFunction(READ_RDF_ALL_TRIPLE_FUNC_NAME, functions);
auto scanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto bindData =
scanFunc->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog, storageManager);
auto scanInfo = std::make_unique<BoundFileScanInfo>(
scanFunc, bindData->copy(), expression_vector{}, offset);
// Bind copy resource.
func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_RESOURCE_FUNC_NAME) :
functions->matchFunction(READ_RDF_RESOURCE_FUNC_NAME);
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(IN_MEM_READ_RDF_RESOURCE_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_FUNC_NAME, functions);
auto rScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rColumns = expression_vector{r};
auto rScanInfo = std::make_unique<BoundFileScanInfo>(
Expand All @@ -53,8 +54,9 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&
auto rSchema = catalog.getTableCatalogEntry(clientContext->getTx(), rTableID);
auto rCopyInfo = BoundCopyFromInfo(rSchema, std::move(rScanInfo), false, nullptr);
// Bind copy literal.
func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_LITERAL_FUNC_NAME) :
functions->matchFunction(READ_RDF_LITERAL_FUNC_NAME);
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(IN_MEM_READ_RDF_LITERAL_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_FUNC_NAME, functions);
auto lScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto lColumns = expression_vector{l, lang};
auto lScanInfo = std::make_unique<BoundFileScanInfo>(
Expand All @@ -63,8 +65,10 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&
auto lSchema = catalog.getTableCatalogEntry(clientContext->getTx(), lTableID);
auto lCopyInfo = BoundCopyFromInfo(lSchema, std::move(lScanInfo), true, nullptr);
// Bind copy resource triples
func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME) :
functions->matchFunction(READ_RDF_RESOURCE_TRIPLE_FUNC_NAME);
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(
IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, functions);
auto rrrScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rrrColumns = expression_vector{s, p, o};
auto rrrScanInfo =
Expand All @@ -83,8 +87,10 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(const parser::Statement&
auto rrrCopyInfo =
BoundCopyFromInfo(rrrSchema, std::move(rrrScanInfo), false, std::move(rrrExtraInfo));
// Bind copy literal triples
func = inMemory ? functions->matchFunction(IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME) :
functions->matchFunction(READ_RDF_LITERAL_TRIPLE_FUNC_NAME);
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(
IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_TRIPLE_FUNC_NAME, functions);
auto rrlScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rrlColumns = expression_vector{s, p, oOffset};
auto rrlScanInfo =
Expand Down
5 changes: 3 additions & 2 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(

std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
ExpressionType expressionType, const expression_vector& children) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions(binder->clientContext->getTx());
auto builtInFunctions = binder->catalog.getFunctions(binder->clientContext->getTx());
auto functionName = expressionTypeToString(expressionType);
std::vector<LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = ku_dynamic_cast<function::Function*, function::ScalarFunction*>(
builtInFunctions->matchFunction(functionName, childrenTypes));
function::BuiltInFunctionsUtils::matchFunction(
functionName, childrenTypes, builtInFunctions));
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
childrenAfterCast.push_back(
Expand Down
10 changes: 5 additions & 5 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(

std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
const expression_vector& children, const std::string& functionName) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions(binder->clientContext->getTx());
std::vector<LogicalType> childrenTypes;
for (auto& child : children) {
childrenTypes.push_back(child->dataType);
}
auto function = ku_dynamic_cast<function::Function*, function::ScalarFunction*>(
builtInFunctions->matchFunction(functionName, childrenTypes));
function::BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes,
binder->catalog.getFunctions(binder->clientContext->getTx())));
expression_vector childrenAfterCast;
std::unique_ptr<function::FunctionBindData> bindData;
if (functionName == CAST_FUNC_NAME) {
Expand Down Expand Up @@ -98,7 +98,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(

std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
const ParsedExpression& parsedExpression, const std::string& functionName, bool isDistinct) {
auto builtInFunctions = binder->catalog.getBuiltInFunctions(binder->clientContext->getTx());
std::vector<LogicalType> childrenTypes;
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
Expand All @@ -111,8 +110,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
auto function =
builtInFunctions->matchAggregateFunction(functionName, childrenTypes, isDistinct)->clone();
auto function = function::BuiltInFunctionsUtils::matchAggregateFunction(functionName,
childrenTypes, isDistinct, binder->catalog.getFunctions(binder->clientContext->getTx()))
->clone();
if (function->paramRewriteFunc) {
function->paramRewriteFunc(children);
}
Expand Down
7 changes: 3 additions & 4 deletions src/binder/bind_expression/bind_subquery_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#include "binder/expression/subquery_expression.h"
#include "binder/expression_binder.h"
#include "common/types/value/value.h"
#include "main/client_context.h"
#include "parser/expression/parsed_subquery_expression.h"

using namespace kuzu::parser;
Expand Down Expand Up @@ -32,9 +31,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindSubqueryExpression(
std::move(boundGraphPattern.queryGraphCollection), uniqueName, std::move(rawName));
boundSubqueryExpr->setWhereExpression(boundGraphPattern.where);
// Bind projection
auto function =
binder->catalog.getBuiltInFunctions(binder->clientContext->getTx())
->matchAggregateFunction(COUNT_STAR_FUNC_NAME, std::vector<LogicalType>{}, false);
auto function = BuiltInFunctionsUtils::matchAggregateFunction(COUNT_STAR_FUNC_NAME,
std::vector<LogicalType>{}, false,
binder->catalog.getFunctions(binder->clientContext->getTx()));
auto bindData =
std::make_unique<FunctionBindData>(std::make_unique<LogicalType>(function->returnTypeID));
auto countStarExpr = std::make_shared<AggregateFunctionExpression>(COUNT_STAR_FUNC_NAME,
Expand Down
12 changes: 7 additions & 5 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,19 +209,21 @@ function::TableFunction* Binder::getScanFunction(FileType fileType, const Reader
auto stringType = LogicalType(LogicalTypeID::STRING);
std::vector<LogicalType> inputTypes;
inputTypes.push_back(stringType);
auto functions = catalog.getBuiltInFunctions(clientContext->getTx());
auto functions = catalog.getFunctions(clientContext->getTx());
switch (fileType) {
case FileType::PARQUET: {
func = functions->matchFunction(READ_PARQUET_FUNC_NAME, inputTypes);
func = function::BuiltInFunctionsUtils::matchFunction(
READ_PARQUET_FUNC_NAME, inputTypes, functions);
} break;
case FileType::NPY: {
func = functions->matchFunction(READ_NPY_FUNC_NAME, inputTypes);
func = function::BuiltInFunctionsUtils::matchFunction(
READ_NPY_FUNC_NAME, inputTypes, functions);
} break;
case FileType::CSV: {
auto csvConfig = CSVReaderConfig::construct(config.options);
func = functions->matchFunction(
func = function::BuiltInFunctionsUtils::matchFunction(
csvConfig.parallel ? READ_CSV_PARALLEL_FUNC_NAME : READ_CSV_SERIAL_FUNC_NAME,
inputTypes);
inputTypes, functions);
} break;
default:
KU_UNREACHABLE;
Expand Down
10 changes: 8 additions & 2 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ std::vector<TableCatalogEntry*> Catalog::getTableSchemas(
return result;
}

// Returns a non-owning pointer to the function catalog set held by the catalog
// version that getVersion() selects for `tx`.
CatalogSet* Catalog::getFunctions(Transaction* tx) const {
    const auto& functionSet = getVersion(tx)->functions;
    return functionSet.get();
}

void Catalog::prepareCommitOrRollback(TransactionAction action) {
if (hasUpdates()) {
wal->logCatalogRecord();
Expand Down Expand Up @@ -265,8 +269,10 @@ void Catalog::addScalarMacroFunction(

std::vector<std::string> Catalog::getMacroNames(transaction::Transaction* tx) const {
std::vector<std::string> macroNames;
for (auto& macro : getVersion(tx)->macros) {
macroNames.push_back(macro.first);
for (auto& [_, function] : getVersion(tx)->functions->getEntries()) {
if (function->getType() == CatalogEntryType::SCALAR_MACRO_ENTRY) {
macroNames.push_back(function->getName());
}
}
return macroNames;
}
Expand Down
53 changes: 31 additions & 22 deletions src/catalog/catalog_content.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "catalog/catalog_entry/rdf_graph_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "catalog/catalog_entry/scalar_function_catalog_entry.h"
#include "catalog/catalog_entry/scalar_macro_catalog_entry.h"
#include "common/cast.h"
#include "common/exception/catalog.h"
#include "common/exception/runtime.h"
Expand All @@ -15,7 +17,6 @@
#include "common/serializer/deserializer.h"
#include "common/serializer/serializer.h"
#include "common/string_format.h"
#include "common/string_utils.h"
#include "storage/storage_info.h"
#include "storage/storage_utils.h"

Expand All @@ -27,8 +28,9 @@ namespace kuzu {
namespace catalog {

CatalogContent::CatalogContent(common::VirtualFileSystem* vfs) : nextTableID{0}, vfs{vfs} {
registerBuiltInFunctions();
tables = std::make_unique<CatalogSet>();
functions = std::make_unique<CatalogSet>();
registerBuiltInFunctions();
}

CatalogContent::CatalogContent(const std::string& directory, VirtualFileSystem* vfs) : vfs{vfs} {
Expand Down Expand Up @@ -180,7 +182,7 @@ void CatalogContent::saveToFile(const std::string& directory, FileVersionType db
serializer.serializeValue(StorageVersionInfo::getStorageVersion());
tables->serialize(serializer);
serializer.serializeValue(nextTableID);
serializer.serializeUnorderedMap(macros);
functions->serialize(serializer);
}

void CatalogContent::readFromFile(const std::string& directory, FileVersionType dbFileType) {
Expand All @@ -197,47 +199,54 @@ void CatalogContent::readFromFile(const std::string& directory, FileVersionType
ku_dynamic_cast<CatalogEntry*, TableCatalogEntry*>(entry.get())->getTableID();
}
deserializer.deserializeValue(nextTableID);
deserializer.deserializeUnorderedMap(macros);
functions = CatalogSet::deserialize(deserializer);
}

ExpressionType CatalogContent::getFunctionType(const std::string& name) const {
auto normalizedName = StringUtils::getUpper(name);
if (macros.contains(normalizedName)) {
return ExpressionType::MACRO;
if (!functions->containsEntry(name)) {
throw CatalogException{common::stringFormat("function {} does not exist.", name)};
}
auto functionType = builtInFunctions->getFunctionType(name);
switch (functionType) {
case function::FunctionType::SCALAR:
auto functionEntry = functions->getEntry(name);
switch (functionEntry->getType()) {
case CatalogEntryType::SCALAR_MACRO_ENTRY:
return ExpressionType::MACRO;
case CatalogEntryType::SCALAR_FUNCTION_ENTRY:
return ExpressionType::FUNCTION;
case function::FunctionType::AGGREGATE:
case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY:
return ExpressionType::AGGREGATE_FUNCTION;
default:
KU_UNREACHABLE;
}
}

void CatalogContent::addFunction(std::string name, function::function_set definitions) {
StringUtils::toUpper(name);
builtInFunctions->addFunction(std::move(name), std::move(definitions));
if (functions->containsEntry(name)) {
throw CatalogException{common::stringFormat("function {} already exists.", name)};
}
functions->createEntry(
std::make_unique<ScalarFunctionCatalogEntry>(std::move(name), std::move(definitions)));
}

void CatalogContent::addScalarMacroFunction(
std::string name, std::unique_ptr<function::ScalarMacroFunction> macro) {
StringUtils::toUpper(name);
macros.emplace(std::move(name), std::move(macro));
functions->createEntry(
std::make_unique<ScalarMacroCatalogEntry>(std::move(name), std::move(macro)));
}

// Fetches the catalog entry registered under `name` in the function set, casts it
// to a scalar-macro entry and returns the macro function that entry exposes.
// NOTE(review): no existence or entry-type check is performed here; presumably
// callers have already established that `name` refers to a macro — confirm.
function::ScalarMacroFunction* CatalogContent::getScalarMacroFunction(
    const std::string& name) const {
    auto entry = functions->getEntry(name);
    auto macroEntry = ku_dynamic_cast<CatalogEntry*, ScalarMacroCatalogEntry*>(entry);
    return macroEntry->getMacroFunction();
}

std::unique_ptr<CatalogContent> CatalogContent::copy() const {
std::unordered_map<std::string, std::unique_ptr<function::ScalarMacroFunction>> macrosToCopy;
for (auto& macro : macros) {
macrosToCopy.emplace(macro.first, macro.second->copy());
}
return std::make_unique<CatalogContent>(tables->copy(), tableNameToIDMap, nextTableID,
builtInFunctions->copy(), std::move(macrosToCopy), vfs);
return std::make_unique<CatalogContent>(
tables->copy(), tableNameToIDMap, nextTableID, functions->copy(), vfs);
}

void CatalogContent::registerBuiltInFunctions() {
builtInFunctions = std::make_unique<function::BuiltInFunctions>();
function::BuiltInFunctionsUtils::createFunctions(functions.get());
}

bool CatalogContent::containsTable(const std::string& tableName) const {
Expand All @@ -264,7 +273,7 @@ CatalogEntry* CatalogContent::getTableCatalogEntry(table_id_t tableID) const {
return table.get();
}
}
KU_ASSERT(false);
KU_UNREACHABLE;
}

common::table_id_t CatalogContent::getTableID(const std::string& tableName) const {
Expand Down
7 changes: 6 additions & 1 deletion src/catalog/catalog_entry/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
add_library(kuzu_catalog_entry
OBJECT
aggregate_function_catalog_entry.cpp
catalog_entry.cpp
function_catalog_entry.cpp
table_catalog_entry.cpp
node_table_catalog_entry.cpp
rel_table_catalog_entry.cpp
rel_group_catalog_entry.cpp
rdf_graph_catalog_entry.cpp)
rdf_graph_catalog_entry.cpp
scalar_macro_catalog_entry.cpp
scalar_function_catalog_entry.cpp
table_function_catalog_entry.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_catalog_entry>
Expand Down
19 changes: 19 additions & 0 deletions src/catalog/catalog_entry/aggregate_function_catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "catalog/catalog_entry/aggregate_function_catalog_entry.h"

#include "common/utils.h"

namespace kuzu {
namespace catalog {

// Constructs a catalog entry for an aggregate function.
// `name` and `functionSet` are taken by value and moved into the FunctionCatalogEntry
// base, which is tagged with CatalogEntryType::AGGREGATE_FUNCTION_ENTRY.
AggregateFunctionCatalogEntry::AggregateFunctionCatalogEntry(
std::string name, function::function_set functionSet)
: FunctionCatalogEntry{
CatalogEntryType::AGGREGATE_FUNCTION_ENTRY, std::move(name), std::move(functionSet)} {}

std::unique_ptr<CatalogEntry> AggregateFunctionCatalogEntry::copy() const {
return std::make_unique<AggregateFunctionCatalogEntry>(
getName(), common::copyVector(functionSet));
}

} // namespace catalog
} // namespace kuzu
8 changes: 6 additions & 2 deletions src/catalog/catalog_entry/catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "catalog/catalog_entry/catalog_entry.h"

#include "catalog/catalog_entry/scalar_macro_catalog_entry.h"
#include "catalog/catalog_entry/table_catalog_entry.h"

namespace kuzu {
Expand All @@ -20,9 +21,12 @@ std::unique_ptr<CatalogEntry> CatalogEntry::deserialize(common::Deserializer& de
case CatalogEntryType::NODE_TABLE_ENTRY:
case CatalogEntryType::REL_TABLE_ENTRY:
case CatalogEntryType::REL_GROUP_ENTRY:
case CatalogEntryType::RDF_GRAPH_ENTRY:
case CatalogEntryType::RDF_GRAPH_ENTRY: {
entry = TableCatalogEntry::deserialize(deserializer, type);
break;
} break;
case CatalogEntryType::SCALAR_MACRO_ENTRY: {
entry = ScalarMacroCatalogEntry::deserialize(deserializer);
} break;
default:
KU_UNREACHABLE;
}
Expand Down
Loading

0 comments on commit 67a814e

Please sign in to comment.