Skip to content

Commit

Permalink
Refactor table functions
Browse files Browse the repository at this point in the history
  • Loading branch information
manh9203 committed Apr 1, 2024
1 parent 6b1d45a commit f1a6a54
Show file tree
Hide file tree
Showing 33 changed files with 206 additions and 299 deletions.
23 changes: 10 additions & 13 deletions src/binder/bind/copy/bind_copy_rdf_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "function/table/bind_input.h"
#include "main/client_context.h"
#include "parser/copy.h"
#include "processor/operator/persistent/reader/rdf/rdf_scan.h"

using namespace kuzu::binder;
using namespace kuzu::catalog;
Expand Down Expand Up @@ -50,13 +51,12 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(
Function* func;
// Bind file scan;
auto inMemory = RdfReaderConfig::construct(config->options).inMemory;
func = BuiltInFunctionsUtils::matchFunction(READ_RDF_ALL_TRIPLE_FUNC_NAME, functions);
func = BuiltInFunctionsUtils::matchFunction(RdfAllTripleScan::name, functions);
auto scanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto bindData = scanFunc->bindFunc(clientContext, bindInput.get());
// Bind copy resource.
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(IN_MEM_READ_RDF_RESOURCE_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_FUNC_NAME, functions);
func = inMemory ? BuiltInFunctionsUtils::matchFunction(RdfResourceInMemScan::name, functions) :
BuiltInFunctionsUtils::matchFunction(RdfResourceScan::name, functions);
auto rScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rColumns = expression_vector{r};
auto rFileScanInfo = BoundFileScanInfo(*rScanFunc, bindData->copy(), std::move(rColumns));
Expand All @@ -65,9 +65,8 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(
auto rEntry = catalog->getTableCatalogEntry(clientContext->getTx(), rTableID);
auto rCopyInfo = BoundCopyFromInfo(rEntry, std::move(rSource), offset, nullptr /* extraInfo */);
// Bind copy literal.
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(IN_MEM_READ_RDF_LITERAL_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_FUNC_NAME, functions);
func = inMemory ? BuiltInFunctionsUtils::matchFunction(RdfLiteralInMemScan::name, functions) :
BuiltInFunctionsUtils::matchFunction(RdfLiteralScan::name, functions);
auto lScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto lColumns = expression_vector{l, lang};
auto lFileScanInfo = BoundFileScanInfo(*lScanFunc, bindData->copy(), std::move(lColumns));
Expand All @@ -77,9 +76,8 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(
auto lCopyInfo = BoundCopyFromInfo(lEntry, std::move(lSource), offset, nullptr /* extraInfo */);
// Bind copy resource triples
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(
IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, functions);
BuiltInFunctionsUtils::matchFunction(RdfResourceTripleInMemScan::name, functions) :
BuiltInFunctionsUtils::matchFunction(RdfResourceTripleScan::name, functions);
auto rrrScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rrrColumns = expression_vector{s, p, o};
auto rrrFileScanInfo = BoundFileScanInfo(*rrrScanFunc, bindData->copy(), rrrColumns);
Expand All @@ -99,9 +97,8 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(
BoundCopyFromInfo(rrrEntry, std::move(rrrSource), offset, std::move(rrrExtraInfo));
// Bind copy literal triples
func = inMemory ?
BuiltInFunctionsUtils::matchFunction(
IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME, functions) :
BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_TRIPLE_FUNC_NAME, functions);
BuiltInFunctionsUtils::matchFunction(RdfLiteralTripleInMemScan::name, functions) :
BuiltInFunctionsUtils::matchFunction(RdfLiteralTripleScan::name, functions);
auto rrlScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rrlColumns = expression_vector{s, p, oOffset};
auto rrlFileScanInfo = BoundFileScanInfo(*rrlScanFunc, bindData->copy(), rrlColumns);
Expand Down
13 changes: 9 additions & 4 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@
#include "common/string_utils.h"
#include "function/built_in_function_utils.h"
#include "function/table_functions.h"
#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h"
#include "processor/operator/persistent/reader/csv/serial_csv_reader.h"
#include "processor/operator/persistent/reader/npy/npy_reader.h"
#include "processor/operator/persistent/reader/parquet/parquet_reader.h"

using namespace kuzu::catalog;
using namespace kuzu::common;
using namespace kuzu::parser;
using namespace kuzu::processor;

namespace kuzu {
namespace binder {
Expand Down Expand Up @@ -201,17 +206,17 @@ function::TableFunction Binder::getScanFunction(FileType fileType, const ReaderC
switch (fileType) {
case FileType::PARQUET: {
func = function::BuiltInFunctionsUtils::matchFunction(
READ_PARQUET_FUNC_NAME, inputTypes, functions);
ParquetScanFunction::name, inputTypes, functions);
} break;
case FileType::NPY: {
func = function::BuiltInFunctionsUtils::matchFunction(
READ_NPY_FUNC_NAME, inputTypes, functions);
NpyScanFunction::name, inputTypes, functions);
} break;
case FileType::CSV: {
auto csvConfig = CSVReaderConfig::construct(config.options);
func = function::BuiltInFunctionsUtils::matchFunction(
csvConfig.parallel ? READ_CSV_PARALLEL_FUNC_NAME : READ_CSV_SERIAL_FUNC_NAME,
inputTypes, functions);
csvConfig.parallel ? ParallelCSVScan::name : SerialCSVScan::name, inputTypes,
functions);
} break;
default:
KU_UNREACHABLE;
Expand Down
3 changes: 1 addition & 2 deletions src/catalog/catalog_entry/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ add_library(kuzu_catalog_entry
rel_table_catalog_entry.cpp
rel_group_catalog_entry.cpp
rdf_graph_catalog_entry.cpp
scalar_macro_catalog_entry.cpp
table_function_catalog_entry.cpp)
scalar_macro_catalog_entry.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_catalog_entry>
Expand Down
18 changes: 0 additions & 18 deletions src/catalog/catalog_entry/table_function_catalog_entry.cpp

This file was deleted.

71 changes: 7 additions & 64 deletions src/function/built_in_function_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
#include "function/built_in_function_utils.h"

#include "catalog/catalog_entry/table_function_catalog_entry.h"
#include "catalog/catalog_entry/function_catalog_entry.h"
#include "catalog/catalog_set.h"
#include "common/exception/binder.h"
#include "common/exception/catalog.h"
#include "function/aggregate_function.h"
#include "function/arithmetic/vector_arithmetic_functions.h"
#include "function/function_collection.h"
#include "function/scalar_function.h"
#include "function/table/call_functions.h"
#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h"
#include "processor/operator/persistent/reader/csv/serial_csv_reader.h"
#include "processor/operator/persistent/reader/npy/npy_reader.h"
#include "processor/operator/persistent/reader/parquet/parquet_reader.h"
#include "processor/operator/persistent/reader/rdf/rdf_scan.h"
#include "processor/operator/table_scan/ftable_scan_function.h"

using namespace kuzu::common;
using namespace kuzu::catalog;
Expand All @@ -31,9 +24,12 @@ static void validateNonEmptyCandidateFunctions(std::vector<Function*>& candidate
function::function_set& set);

void BuiltInFunctionsUtils::createFunctions(CatalogSet* catalogSet) {
registerTableFunctions(catalogSet);

registerFunctions(catalogSet);
auto functions = FunctionCollection::getFunctions();
for (auto i = 0u; functions[i].name != nullptr; ++i) {
auto functionSet = functions[i].getFunctionSetFunc();
catalogSet->createEntry(std::make_unique<FunctionCatalogEntry>(
functions[i].catalogEntryType, functions[i].name, std::move(functionSet)));
}
}

Function* BuiltInFunctionsUtils::matchFunction(const std::string& name, CatalogSet* catalogSet) {
Expand Down Expand Up @@ -489,59 +485,6 @@ void BuiltInFunctionsUtils::validateSpecialCases(std::vector<Function*>& candida
}
}

void BuiltInFunctionsUtils::registerTableFunctions(CatalogSet* catalogSet) {
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
CURRENT_SETTING_FUNC_NAME, CurrentSettingFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
DB_VERSION_FUNC_NAME, DBVersionFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
SHOW_TABLES_FUNC_NAME, ShowTablesFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
TABLE_INFO_FUNC_NAME, TableInfoFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
SHOW_CONNECTION_FUNC_NAME, ShowConnectionFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
STORAGE_INFO_FUNC_NAME, StorageInfoFunction::getFunctionSet()));
// Read functions
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_PARQUET_FUNC_NAME, ParquetScanFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_NPY_FUNC_NAME, NpyScanFunction::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_CSV_SERIAL_FUNC_NAME, SerialCSVScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_CSV_PARALLEL_FUNC_NAME, ParallelCSVScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_RDF_RESOURCE_FUNC_NAME, RdfResourceScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_RDF_LITERAL_FUNC_NAME, RdfLiteralScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, RdfResourceTripleScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_RDF_LITERAL_TRIPLE_FUNC_NAME, RdfLiteralTripleScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_RDF_ALL_TRIPLE_FUNC_NAME, RdfAllTripleScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
IN_MEM_READ_RDF_RESOURCE_FUNC_NAME, RdfResourceInMemScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
IN_MEM_READ_RDF_LITERAL_FUNC_NAME, RdfLiteralInMemScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
IN_MEM_READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, RdfResourceTripleInMemScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
IN_MEM_READ_RDF_LITERAL_TRIPLE_FUNC_NAME, RdfLiteralTripleInMemScan::getFunctionSet()));
catalogSet->createEntry(std::make_unique<TableFunctionCatalogEntry>(
READ_FTABLE_FUNC_NAME, FTableScan::getFunctionSet()));
}

void BuiltInFunctionsUtils::registerFunctions(catalog::CatalogSet* catalogSet) {
auto functions = FunctionCollection::getFunctions();
for (auto i = 0u; functions[i].name != nullptr; ++i) {
auto functionSet = functions[i].getFunctionSetFunc();
catalogSet->createEntry(std::make_unique<FunctionCatalogEntry>(
functions[i].catalogEntryType, functions[i].name, std::move(functionSet)));
}
}

static std::string getFunctionMatchFailureMsg(const std::string name,
const std::vector<LogicalType>& inputTypes, const std::string& supportedInputs,
bool isDistinct = false) {
Expand Down
25 changes: 25 additions & 0 deletions src/function/function_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,18 @@
#include "function/schema/vector_node_rel_functions.h"
#include "function/string/vector_string_functions.h"
#include "function/struct/vector_struct_functions.h"
#include "function/table/call_functions.h"
#include "function/timestamp/vector_timestamp_functions.h"
#include "function/union/vector_union_functions.h"
#include "function/uuid/vector_uuid_functions.h"
#include "processor/operator/persistent/reader/csv/parallel_csv_reader.h"
#include "processor/operator/persistent/reader/csv/serial_csv_reader.h"
#include "processor/operator/persistent/reader/npy/npy_reader.h"
#include "processor/operator/persistent/reader/parquet/parquet_reader.h"
#include "processor/operator/persistent/reader/rdf/rdf_scan.h"
#include "processor/operator/table_scan/ftable_scan_function.h"

using namespace kuzu::processor;

namespace kuzu {
namespace function {
Expand All @@ -32,6 +41,8 @@ namespace function {
{ _PARAM::getFunctionSet, _PARAM::name, CatalogEntryType::REWRITE_FUNCTION_ENTRY }
#define AGGREGATE_FUNCTION(_PARAM) \
{ _PARAM::getFunctionSet, _PARAM::name, CatalogEntryType::AGGREGATE_FUNCTION_ENTRY }
#define TABLE_FUNCTION(_PARAM) \
{ _PARAM::getFunctionSet, _PARAM::name, CatalogEntryType::TABLE_FUNCTION_ENTRY }
#define FINAL_FUNCTION \
{ nullptr, nullptr, CatalogEntryType::SCALAR_FUNCTION_ENTRY }

Expand Down Expand Up @@ -173,6 +184,20 @@ FunctionCollection* FunctionCollection::getFunctions() {
AGGREGATE_FUNCTION(AggregateMinFunction), AGGREGATE_FUNCTION(AggregateMaxFunction),
AGGREGATE_FUNCTION(CollectFunction),

// Table functions
TABLE_FUNCTION(CurrentSettingFunction), TABLE_FUNCTION(DBVersionFunction),
TABLE_FUNCTION(ShowTablesFunction), TABLE_FUNCTION(TableInfoFunction),
TABLE_FUNCTION(ShowConnectionFunction), TABLE_FUNCTION(StorageInfoFunction),

// Read functions
TABLE_FUNCTION(ParquetScanFunction), TABLE_FUNCTION(NpyScanFunction),
TABLE_FUNCTION(SerialCSVScan), TABLE_FUNCTION(ParallelCSVScan),
TABLE_FUNCTION(RdfResourceScan), TABLE_FUNCTION(RdfLiteralScan),
TABLE_FUNCTION(RdfResourceTripleScan), TABLE_FUNCTION(RdfLiteralTripleScan),
TABLE_FUNCTION(RdfAllTripleScan), TABLE_FUNCTION(RdfResourceInMemScan),
TABLE_FUNCTION(RdfLiteralInMemScan), TABLE_FUNCTION(RdfResourceTripleInMemScan),
TABLE_FUNCTION(RdfLiteralTripleInMemScan), TABLE_FUNCTION(FTableScan),

// End of array
FINAL_FUNCTION};

Expand Down
5 changes: 2 additions & 3 deletions src/function/table/call/current_setting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ static std::unique_ptr<TableFuncBindData> bindFunc(

function_set CurrentSettingFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(std::make_unique<TableFunction>(CURRENT_SETTING_FUNC_NAME, tableFunc,
bindFunc, initSharedState, initEmptyLocalState,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
functionSet.push_back(std::make_unique<TableFunction>(name, tableFunc, bindFunc,
initSharedState, initEmptyLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down
2 changes: 1 addition & 1 deletion src/function/table/call/db_version.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext*, TableFuncBind

function_set DBVersionFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(std::make_unique<TableFunction>(DB_VERSION_FUNC_NAME, tableFunc, bindFunc,
functionSet.push_back(std::make_unique<TableFunction>(name, tableFunc, bindFunc,
initSharedState, initEmptyLocalState, std::vector<LogicalTypeID>{}));
return functionSet;
}
Expand Down
4 changes: 2 additions & 2 deletions src/function/table/call/show_tables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ static std::unique_ptr<TableFuncBindData> bindFunc(

function_set ShowTablesFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(std::make_unique<TableFunction>(SHOW_TABLES_FUNC_NAME, tableFunc,
bindFunc, initSharedState, initEmptyLocalState, std::vector<LogicalTypeID>{}));
functionSet.push_back(std::make_unique<TableFunction>(name, tableFunc, bindFunc,
initSharedState, initEmptyLocalState, std::vector<LogicalTypeID>{}));
return functionSet;
}

Expand Down
6 changes: 3 additions & 3 deletions src/function/table/call/storage_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ static std::unique_ptr<TableFuncBindData> bindFunc(

function_set StorageInfoFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(std::make_unique<TableFunction>(STORAGE_INFO_FUNC_NAME, tableFunc,
bindFunc, initStorageInfoSharedState, initLocalState,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
functionSet.push_back(
std::make_unique<TableFunction>(name, tableFunc, bindFunc, initStorageInfoSharedState,
initLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}

Expand Down
2 changes: 1 addition & 1 deletion src/function/table/call/table_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(

function_set TableInfoFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(std::make_unique<TableFunction>(TABLE_INFO_FUNC_NAME, tableFunc, bindFunc,
functionSet.push_back(std::make_unique<TableFunction>(name, tableFunc, bindFunc,
initSharedState, initEmptyLocalState, std::vector<LogicalTypeID>{LogicalTypeID::STRING}));
return functionSet;
}
Expand Down
23 changes: 0 additions & 23 deletions src/include/catalog/catalog_entry/table_function_catalog_entry.h

This file was deleted.

Loading

0 comments on commit f1a6a54

Please sign in to comment.