Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue 2469, 2986 & 3185 #3284

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/binder/bind/read/bind_in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
#include "binder/expression/literal_expression.h"
#include "binder/query/reading_clause/bound_in_query_call.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "function/built_in_function_utils.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/query/reading_clause/in_query_call_clause.h"

using namespace kuzu::common;
using namespace kuzu::parser;
using namespace kuzu::function;
using namespace kuzu::catalog;

namespace kuzu {
namespace binder {
Expand All @@ -18,6 +20,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
auto& call = readingClause.constCast<InQueryCallClause>();
auto expr = call.getFunctionExpression();
auto functionExpr = expr->constPtrCast<ParsedFunctionExpression>();
auto functionName = functionExpr->getFunctionName();
expression_vector params;
for (auto i = 0u; i < functionExpr->getNumChildren(); i++) {
auto child = functionExpr->getChild(i);
Expand All @@ -32,10 +35,14 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
inputTypes.push_back(literalExpr->getDataType());
inputValues.push_back(*literalExpr->getValue());
}
auto functions = clientContext->getCatalog()->getFunctions(clientContext->getTx());
auto catalogSet = clientContext->getCatalog()->getFunctions(clientContext->getTx());
auto functionEntry = BuiltInFunctionsUtils::getFunctionCatalogEntry(functionName, catalogSet);
if (functionEntry->getType() != CatalogEntryType::TABLE_FUNCTION_ENTRY) {
throw BinderException(stringFormat("{} is not a table function.", functionName));
}
auto func = BuiltInFunctionsUtils::matchFunction(functionExpr->getFunctionName(), inputTypes,
functions);
tableFunction = *ku_dynamic_cast<function::Function*, function::TableFunction*>(func);
functionEntry);
tableFunction = *func->constPtrCast<TableFunction>();
auto bindInput = function::TableFuncBindInput();
bindInput.inputs = std::move(inputValues);
auto bindData = tableFunction.bindFunc(clientContext, &bindInput);
Expand Down
11 changes: 6 additions & 5 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ namespace kuzu {
namespace binder {

std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(const ParsedExpression& expr) {
auto& funcExpr =
ku_dynamic_cast<const ParsedExpression&, const ParsedFunctionExpression&>(expr);
auto functionName = funcExpr.getNormalizedFunctionName();
auto funcExpr = expr.constPtrCast<ParsedFunctionExpression>();
auto functionName = funcExpr->getNormalizedFunctionName();
auto result = rewriteFunctionExpression(expr, functionName);
if (result != nullptr) {
return result;
Expand All @@ -41,11 +40,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(const Parse
case CatalogEntryType::REWRITE_FUNCTION_ENTRY:
return bindRewriteFunctionExpression(expr);
case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY:
return bindAggregateFunctionExpression(expr, functionName, funcExpr.getIsDistinct());
return bindAggregateFunctionExpression(expr, functionName, funcExpr->getIsDistinct());
case CatalogEntryType::SCALAR_MACRO_ENTRY:
return bindMacroExpression(expr, functionName);
default:
KU_UNREACHABLE;
throw BinderException(
stringFormat("{} is a {}. Scalar function, aggregate function or macro was expected. ",
functionName, CatalogEntryTypeUtils::toString(entry->getType())));
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ std::shared_ptr<Expression> Binder::createVariable(const std::string& name,
}

std::unique_ptr<LogicalType> Binder::bindDataType(const std::string& dataType) {
auto boundType = LogicalTypeUtils::dataTypeFromString(dataType);
auto boundType = LogicalType::fromString(dataType);
if (boundType.getLogicalTypeID() == LogicalTypeID::ARRAY) {
auto numElementsInArray = ArrayType::getNumElements(&boundType);
if (numElementsInArray == 0) {
Expand Down
10 changes: 6 additions & 4 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,17 @@ void Catalog::logAlterTableToWAL(const BoundAlterInfo& info) {
}
}

void Catalog::addFunction(std::string name, function::function_set functionSet) {
void Catalog::addFunction(CatalogEntryType entryType, std::string name,
function::function_set functionSet) {
initCatalogContentForWriteTrxIfNecessary();
KU_ASSERT(readWriteVersion != nullptr);
setToUpdated();
readWriteVersion->addFunction(std::move(name), std::move(functionSet));
readWriteVersion->addFunction(entryType, std::move(name), std::move(functionSet));
}

void Catalog::addBuiltInFunction(std::string name, function::function_set functionSet) {
readOnlyVersion->addFunction(std::move(name), std::move(functionSet));
void Catalog::addBuiltInFunction(CatalogEntryType entryType, std::string name,
function::function_set functionSet) {
readOnlyVersion->addFunction(entryType, std::move(name), std::move(functionSet));
}

CatalogSet* Catalog::getFunctions(Transaction* tx) const {
Expand Down
7 changes: 4 additions & 3 deletions src/catalog/catalog_content.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,13 @@ void CatalogContent::readFromFile(const std::string& directory, FileVersionType
functions = CatalogSet::deserialize(deserializer);
}

void CatalogContent::addFunction(std::string name, function::function_set definitions) {
void CatalogContent::addFunction(CatalogEntryType entryType, std::string name,
function::function_set definitions) {
if (functions->containsEntry(name)) {
throw CatalogException{stringFormat("function {} already exists.", name)};
}
functions->createEntry(std::make_unique<FunctionCatalogEntry>(
CatalogEntryType::SCALAR_FUNCTION_ENTRY, std::move(name), std::move(definitions)));
functions->createEntry(
std::make_unique<FunctionCatalogEntry>(entryType, std::move(name), std::move(definitions)));
}

function::ScalarMacroFunction* CatalogContent::getScalarMacroFunction(
Expand Down
1 change: 1 addition & 0 deletions src/catalog/catalog_entry/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(kuzu_catalog_entry
OBJECT
catalog_entry.cpp
catalog_entry_type.cpp
function_catalog_entry.cpp
table_catalog_entry.cpp
node_table_catalog_entry.cpp
Expand Down
36 changes: 36 additions & 0 deletions src/catalog/catalog_entry/catalog_entry_type.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "catalog/catalog_entry/catalog_entry_type.h"

#include "common/assert.h"

namespace kuzu {
namespace catalog {

std::string CatalogEntryTypeUtils::toString(CatalogEntryType type) {
switch (type) {
case CatalogEntryType::NODE_TABLE_ENTRY:
return "NODE_TABLE_ENTRY";
case CatalogEntryType::REL_TABLE_ENTRY:
return "REL_TABLE_ENTRY";
case CatalogEntryType::REL_GROUP_ENTRY:
return "REL_GROUP_ENTRY";
case CatalogEntryType::RDF_GRAPH_ENTRY:
return "RDF_GRAPH_ENTRY";
case CatalogEntryType::SCALAR_MACRO_ENTRY:
return "SCALAR_MACRO_ENTRY";
case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY:
return "AGGREGATE_FUNCTION_ENTRY";
case CatalogEntryType::SCALAR_FUNCTION_ENTRY:
return "SCALAR_FUNCTION_ENTRY";
case CatalogEntryType::REWRITE_FUNCTION_ENTRY:
return "REWRITE_FUNCTION_ENTRY";
case CatalogEntryType::TABLE_FUNCTION_ENTRY:
return "TABLE_FUNCTION_ENTRY";
case CatalogEntryType::FOREIGN_TABLE_ENTRY:
return "FOREIGN_TABLE_ENTRY";
default:
KU_UNREACHABLE;
}
}

} // namespace catalog
} // namespace kuzu
131 changes: 70 additions & 61 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,36 @@ std::string LogicalType::toString() const {
}
}

static LogicalTypeID strToLogicalTypeID(const std::string& trimmedStr);
static std::vector<std::string> parseStructFields(const std::string& structTypeStr);
static std::unique_ptr<LogicalType> parseListType(const std::string& trimmedStr);
static std::unique_ptr<LogicalType> parseArrayType(const std::string& trimmedStr);
static std::vector<StructField> parseStructTypeInfo(const std::string& structTypeStr);
static std::unique_ptr<LogicalType> parseStructType(const std::string& trimmedStr);
static std::unique_ptr<LogicalType> parseMapType(const std::string& trimmedStr);
static std::unique_ptr<LogicalType> parseUnionType(const std::string& trimmedStr);

LogicalType LogicalType::fromString(const std::string& str) {
LogicalType dataType;
auto trimmedStr = StringUtils::ltrim(StringUtils::rtrim(str));
auto upperDataTypeString = StringUtils::getUpper(trimmedStr);
if (upperDataTypeString.ends_with("[]")) {
dataType = *parseListType(trimmedStr);
} else if (upperDataTypeString.ends_with("]")) {
dataType = *parseArrayType(trimmedStr);
} else if (upperDataTypeString.starts_with("STRUCT")) {
dataType = *parseStructType(trimmedStr);
} else if (upperDataTypeString.starts_with("MAP")) {
dataType = *parseMapType(trimmedStr);
} else if (upperDataTypeString.starts_with("UNION")) {
dataType = *parseUnionType(trimmedStr);
} else {
dataType.typeID = strToLogicalTypeID(upperDataTypeString);
}
dataType.physicalType = LogicalType::getPhysicalType(dataType.typeID);
return dataType;
}

void LogicalType::serialize(Serializer& serializer) const {
serializer.serializeValue(typeID);
serializer.serializeValue(physicalType);
Expand Down Expand Up @@ -544,79 +574,58 @@ PhysicalTypeID LogicalType::getPhysicalType(LogicalTypeID typeID) {
}
}

LogicalType LogicalTypeUtils::dataTypeFromString(const std::string& dataTypeString) {
LogicalType dataType;
auto trimmedStr = StringUtils::ltrim(StringUtils::rtrim(dataTypeString));
auto upperDataTypeString = StringUtils::getUpper(trimmedStr);
if (upperDataTypeString.ends_with("[]")) {
dataType = *parseListType(trimmedStr);
} else if (upperDataTypeString.ends_with("]")) {
dataType = *parseArrayType(trimmedStr);
} else if (upperDataTypeString.starts_with("STRUCT")) {
dataType = *parseStructType(trimmedStr);
} else if (upperDataTypeString.starts_with("MAP")) {
dataType = *parseMapType(trimmedStr);
} else if (upperDataTypeString.starts_with("UNION")) {
dataType = *parseUnionType(trimmedStr);
} else {
dataType.typeID = dataTypeIDFromString(upperDataTypeString);
}
dataType.physicalType = LogicalType::getPhysicalType(dataType.typeID);
return dataType;
}

LogicalTypeID LogicalTypeUtils::dataTypeIDFromString(const std::string& dataTypeIDString) {
auto upperDataTypeIDString = StringUtils::getUpper(dataTypeIDString);
if ("INTERNAL_ID" == upperDataTypeIDString) {
LogicalTypeID strToLogicalTypeID(const std::string& str) {
auto upperStr = StringUtils::getUpper(str);
if ("INTERNAL_ID" == upperStr) {
return LogicalTypeID::INTERNAL_ID;
} else if ("INT64" == upperDataTypeIDString) {
} else if ("INT64" == upperStr) {
return LogicalTypeID::INT64;
} else if ("INT32" == upperDataTypeIDString || "INT" == upperDataTypeIDString) {
} else if ("INT32" == upperStr || "INT" == upperStr) {
return LogicalTypeID::INT32;
} else if ("INT16" == upperDataTypeIDString) {
} else if ("INT16" == upperStr) {
return LogicalTypeID::INT16;
} else if ("INT8" == upperDataTypeIDString) {
} else if ("INT8" == upperStr) {
return LogicalTypeID::INT8;
} else if ("UINT64" == upperDataTypeIDString) {
} else if ("UINT64" == upperStr) {
return LogicalTypeID::UINT64;
} else if ("UINT32" == upperDataTypeIDString) {
} else if ("UINT32" == upperStr) {
return LogicalTypeID::UINT32;
} else if ("UINT16" == upperDataTypeIDString) {
} else if ("UINT16" == upperStr) {
return LogicalTypeID::UINT16;
} else if ("UINT8" == upperDataTypeIDString) {
} else if ("UINT8" == upperStr) {
return LogicalTypeID::UINT8;
} else if ("INT128" == upperDataTypeIDString) {
} else if ("INT128" == upperStr) {
return LogicalTypeID::INT128;
} else if ("DOUBLE" == upperDataTypeIDString) {
} else if ("DOUBLE" == upperStr || "FLOAT8" == upperStr) {
return LogicalTypeID::DOUBLE;
} else if ("FLOAT" == upperDataTypeIDString) {
} else if ("FLOAT" == upperStr || "FLOAT4" == upperStr || "REAL" == upperStr) {
return LogicalTypeID::FLOAT;
} else if ("BOOLEAN" == upperDataTypeIDString || "BOOL" == upperDataTypeIDString) {
} else if ("BOOLEAN" == upperStr || "BOOL" == upperStr) {
return LogicalTypeID::BOOL;
} else if ("BYTEA" == upperDataTypeIDString || "BLOB" == upperDataTypeIDString) {
} else if ("BYTEA" == upperStr || "BLOB" == upperStr) {
return LogicalTypeID::BLOB;
} else if ("UUID" == upperDataTypeIDString) {
} else if ("UUID" == upperStr) {
return LogicalTypeID::UUID;
} else if ("STRING" == upperDataTypeIDString) {
} else if ("STRING" == upperStr) {
return LogicalTypeID::STRING;
} else if ("DATE" == upperDataTypeIDString) {
} else if ("DATE" == upperStr) {
return LogicalTypeID::DATE;
} else if ("TIMESTAMP" == upperDataTypeIDString) {
} else if ("TIMESTAMP" == upperStr) {
return LogicalTypeID::TIMESTAMP;
} else if ("TIMESTAMP_NS" == upperDataTypeIDString) {
} else if ("TIMESTAMP_NS" == upperStr) {
return LogicalTypeID::TIMESTAMP_NS;
} else if ("TIMESTAMP_MS" == upperDataTypeIDString) {
} else if ("TIMESTAMP_MS" == upperStr) {
return LogicalTypeID::TIMESTAMP_MS;
} else if ("TIMESTAMP_SEC" == upperDataTypeIDString || "TIMESTAMP_S" == upperDataTypeIDString) {
} else if ("TIMESTAMP_SEC" == upperStr || "TIMESTAMP_S" == upperStr) {
return LogicalTypeID::TIMESTAMP_SEC;
} else if ("TIMESTAMP_TZ" == upperDataTypeIDString) {
} else if ("TIMESTAMP_TZ" == upperStr) {
return LogicalTypeID::TIMESTAMP_TZ;
} else if ("INTERVAL" == upperDataTypeIDString) {
} else if ("INTERVAL" == upperStr || "DURATION" == upperStr) {
return LogicalTypeID::INTERVAL;
} else if ("SERIAL" == upperDataTypeIDString) {
} else if ("SERIAL" == upperStr) {
return LogicalTypeID::SERIAL;
} else {
throw NotImplementedException("Cannot parse dataTypeID: " + dataTypeIDString);
throw NotImplementedException("Cannot parse dataTypeID: " + str);
}
}

Expand Down Expand Up @@ -822,7 +831,7 @@ std::vector<LogicalTypeID> LogicalTypeUtils::getAllValidLogicTypes() {
LogicalTypeID::RDF_VARIANT};
}

std::vector<std::string> LogicalTypeUtils::parseStructFields(const std::string& structTypeStr) {
std::vector<std::string> parseStructFields(const std::string& structTypeStr) {
std::vector<std::string> structFieldsStr;
auto startPos = 0u;
auto curPos = 0u;
Expand Down Expand Up @@ -853,22 +862,22 @@ std::vector<std::string> LogicalTypeUtils::parseStructFields(const std::string&
return structFieldsStr;
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseListType(const std::string& trimmedStr) {
return LogicalType::LIST(dataTypeFromString(trimmedStr.substr(0, trimmedStr.size() - 2)));
std::unique_ptr<LogicalType> parseListType(const std::string& trimmedStr) {
return LogicalType::LIST(LogicalType::fromString(trimmedStr.substr(0, trimmedStr.size() - 2)));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseArrayType(const std::string& trimmedStr) {
std::unique_ptr<LogicalType> parseArrayType(const std::string& trimmedStr) {
auto leftBracketPos = trimmedStr.find_last_of('[');
auto rightBracketPos = trimmedStr.find_last_of(']');
auto childType =
std::make_unique<LogicalType>(dataTypeFromString(trimmedStr.substr(0, leftBracketPos)));
auto childType = std::make_unique<LogicalType>(
LogicalType::fromString(trimmedStr.substr(0, leftBracketPos)));
auto numElements = std::strtoll(
trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1).c_str(),
nullptr, 0 /* base */);
return LogicalType::ARRAY(std::move(childType), numElements);
}

std::vector<StructField> LogicalTypeUtils::parseStructTypeInfo(const std::string& structTypeStr) {
std::vector<StructField> parseStructTypeInfo(const std::string& structTypeStr) {
auto leftBracketPos = structTypeStr.find('(');
auto rightBracketPos = structTypeStr.find_last_of(')');
if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) {
Expand All @@ -884,28 +893,28 @@ std::vector<StructField> LogicalTypeUtils::parseStructTypeInfo(const std::string
auto fieldName = structFieldStr.substr(0, pos);
auto fieldTypeString = structFieldStr.substr(pos + 1);
structFields.emplace_back(fieldName,
std::make_unique<LogicalType>(dataTypeFromString(fieldTypeString)));
std::make_unique<LogicalType>(LogicalType::fromString(fieldTypeString)));
}
return structFields;
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseStructType(const std::string& trimmedStr) {
std::unique_ptr<LogicalType> parseStructType(const std::string& trimmedStr) {
return LogicalType::STRUCT(parseStructTypeInfo(trimmedStr));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseMapType(const std::string& trimmedStr) {
std::unique_ptr<LogicalType> parseMapType(const std::string& trimmedStr) {
auto leftBracketPos = trimmedStr.find('(');
auto rightBracketPos = trimmedStr.find_last_of(')');
if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) {
throw Exception("Cannot parse map type: " + trimmedStr);
}
auto mapTypeStr = trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1);
auto keyValueTypes = StringUtils::splitComma(mapTypeStr);
return LogicalType::MAP(dataTypeFromString(keyValueTypes[0]),
dataTypeFromString(keyValueTypes[1]));
return LogicalType::MAP(LogicalType::fromString(keyValueTypes[0]),
LogicalType::fromString(keyValueTypes[1]));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseUnionType(const std::string& trimmedStr) {
std::unique_ptr<LogicalType> parseUnionType(const std::string& trimmedStr) {
auto unionFields = parseStructTypeInfo(trimmedStr);
auto unionTagField = StructField(UnionType::TAG_FIELD_NAME,
std::make_unique<LogicalType>(UnionType::TAG_FIELD_TYPE));
Expand Down
Loading
Loading