Skip to content

Commit

Permalink
Fix issue 2469, 2986 & 3185 (#3284)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Apr 16, 2024
1 parent 12e3b1b commit 53b687f
Show file tree
Hide file tree
Showing 24 changed files with 235 additions and 126 deletions.
13 changes: 10 additions & 3 deletions src/binder/bind/read/bind_in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
#include "binder/expression/literal_expression.h"
#include "binder/query/reading_clause/bound_in_query_call.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "function/built_in_function_utils.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/query/reading_clause/in_query_call_clause.h"

using namespace kuzu::common;
using namespace kuzu::parser;
using namespace kuzu::function;
using namespace kuzu::catalog;

namespace kuzu {
namespace binder {
Expand All @@ -18,6 +20,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
auto& call = readingClause.constCast<InQueryCallClause>();
auto expr = call.getFunctionExpression();
auto functionExpr = expr->constPtrCast<ParsedFunctionExpression>();
auto functionName = functionExpr->getFunctionName();
expression_vector params;
for (auto i = 0u; i < functionExpr->getNumChildren(); i++) {
auto child = functionExpr->getChild(i);
Expand All @@ -32,10 +35,14 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
inputTypes.push_back(literalExpr->getDataType());
inputValues.push_back(*literalExpr->getValue());
}
auto functions = clientContext->getCatalog()->getFunctions(clientContext->getTx());
auto catalogSet = clientContext->getCatalog()->getFunctions(clientContext->getTx());
auto functionEntry = BuiltInFunctionsUtils::getFunctionCatalogEntry(functionName, catalogSet);
if (functionEntry->getType() != CatalogEntryType::TABLE_FUNCTION_ENTRY) {
throw BinderException(stringFormat("{} is not a table function.", functionName));
}
auto func = BuiltInFunctionsUtils::matchFunction(functionExpr->getFunctionName(), inputTypes,
functions);
tableFunction = *ku_dynamic_cast<function::Function*, function::TableFunction*>(func);
functionEntry);
tableFunction = *func->constPtrCast<TableFunction>();
auto bindInput = function::TableFuncBindInput();
bindInput.inputs = std::move(inputValues);
auto bindData = tableFunction.bindFunc(clientContext, &bindInput);
Expand Down
11 changes: 6 additions & 5 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ namespace kuzu {
namespace binder {

std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(const ParsedExpression& expr) {
auto& funcExpr =
ku_dynamic_cast<const ParsedExpression&, const ParsedFunctionExpression&>(expr);
auto functionName = funcExpr.getNormalizedFunctionName();
auto funcExpr = expr.constPtrCast<ParsedFunctionExpression>();
auto functionName = funcExpr->getNormalizedFunctionName();
auto result = rewriteFunctionExpression(expr, functionName);
if (result != nullptr) {
return result;
Expand All @@ -41,11 +40,13 @@ std::shared_ptr<Expression> ExpressionBinder::bindFunctionExpression(const Parse
case CatalogEntryType::REWRITE_FUNCTION_ENTRY:
return bindRewriteFunctionExpression(expr);
case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY:
return bindAggregateFunctionExpression(expr, functionName, funcExpr.getIsDistinct());
return bindAggregateFunctionExpression(expr, functionName, funcExpr->getIsDistinct());
case CatalogEntryType::SCALAR_MACRO_ENTRY:
return bindMacroExpression(expr, functionName);
default:
KU_UNREACHABLE;
throw BinderException(
stringFormat("{} is a {}. Scalar function, aggregate function or macro was expected. ",
functionName, CatalogEntryTypeUtils::toString(entry->getType())));
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ std::shared_ptr<Expression> Binder::createVariable(const std::string& name,
}

std::unique_ptr<LogicalType> Binder::bindDataType(const std::string& dataType) {
auto boundType = LogicalTypeUtils::dataTypeFromString(dataType);
auto boundType = LogicalType::fromString(dataType);
if (boundType.getLogicalTypeID() == LogicalTypeID::ARRAY) {
auto numElementsInArray = ArrayType::getNumElements(&boundType);
if (numElementsInArray == 0) {
Expand Down
10 changes: 6 additions & 4 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,15 +220,17 @@ void Catalog::logAlterTableToWAL(const BoundAlterInfo& info) {
}
}

void Catalog::addFunction(std::string name, function::function_set functionSet) {
void Catalog::addFunction(CatalogEntryType entryType, std::string name,
function::function_set functionSet) {
initCatalogContentForWriteTrxIfNecessary();
KU_ASSERT(readWriteVersion != nullptr);
setToUpdated();
readWriteVersion->addFunction(std::move(name), std::move(functionSet));
readWriteVersion->addFunction(entryType, std::move(name), std::move(functionSet));
}

void Catalog::addBuiltInFunction(std::string name, function::function_set functionSet) {
readOnlyVersion->addFunction(std::move(name), std::move(functionSet));
void Catalog::addBuiltInFunction(CatalogEntryType entryType, std::string name,
function::function_set functionSet) {
readOnlyVersion->addFunction(entryType, std::move(name), std::move(functionSet));
}

CatalogSet* Catalog::getFunctions(Transaction* tx) const {
Expand Down
7 changes: 4 additions & 3 deletions src/catalog/catalog_content.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,13 @@ void CatalogContent::readFromFile(const std::string& directory, FileVersionType
functions = CatalogSet::deserialize(deserializer);
}

void CatalogContent::addFunction(std::string name, function::function_set definitions) {
void CatalogContent::addFunction(CatalogEntryType entryType, std::string name,
function::function_set definitions) {
if (functions->containsEntry(name)) {
throw CatalogException{stringFormat("function {} already exists.", name)};
}
functions->createEntry(std::make_unique<FunctionCatalogEntry>(
CatalogEntryType::SCALAR_FUNCTION_ENTRY, std::move(name), std::move(definitions)));
functions->createEntry(
std::make_unique<FunctionCatalogEntry>(entryType, std::move(name), std::move(definitions)));
}

function::ScalarMacroFunction* CatalogContent::getScalarMacroFunction(
Expand Down
1 change: 1 addition & 0 deletions src/catalog/catalog_entry/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(kuzu_catalog_entry
OBJECT
catalog_entry.cpp
catalog_entry_type.cpp
function_catalog_entry.cpp
table_catalog_entry.cpp
node_table_catalog_entry.cpp
Expand Down
36 changes: 36 additions & 0 deletions src/catalog/catalog_entry/catalog_entry_type.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "catalog/catalog_entry/catalog_entry_type.h"

#include "common/assert.h"

namespace kuzu {
namespace catalog {

// Returns the enumerator name of `type` as text (for example "NODE_TABLE_ENTRY").
// A value that matches none of the cases below is routed to KU_UNREACHABLE.
std::string CatalogEntryTypeUtils::toString(CatalogEntryType type) {
    std::string typeName;
    switch (type) {
    case CatalogEntryType::NODE_TABLE_ENTRY: {
        typeName = "NODE_TABLE_ENTRY";
    } break;
    case CatalogEntryType::REL_TABLE_ENTRY: {
        typeName = "REL_TABLE_ENTRY";
    } break;
    case CatalogEntryType::REL_GROUP_ENTRY: {
        typeName = "REL_GROUP_ENTRY";
    } break;
    case CatalogEntryType::RDF_GRAPH_ENTRY: {
        typeName = "RDF_GRAPH_ENTRY";
    } break;
    case CatalogEntryType::SCALAR_MACRO_ENTRY: {
        typeName = "SCALAR_MACRO_ENTRY";
    } break;
    case CatalogEntryType::AGGREGATE_FUNCTION_ENTRY: {
        typeName = "AGGREGATE_FUNCTION_ENTRY";
    } break;
    case CatalogEntryType::SCALAR_FUNCTION_ENTRY: {
        typeName = "SCALAR_FUNCTION_ENTRY";
    } break;
    case CatalogEntryType::REWRITE_FUNCTION_ENTRY: {
        typeName = "REWRITE_FUNCTION_ENTRY";
    } break;
    case CatalogEntryType::TABLE_FUNCTION_ENTRY: {
        typeName = "TABLE_FUNCTION_ENTRY";
    } break;
    case CatalogEntryType::FOREIGN_TABLE_ENTRY: {
        typeName = "FOREIGN_TABLE_ENTRY";
    } break;
    default:
        KU_UNREACHABLE;
    }
    return typeName;
}

} // namespace catalog
} // namespace kuzu
131 changes: 70 additions & 61 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,36 @@ std::string LogicalType::toString() const {
}
}

// File-local (static) helpers for turning a type string into a LogicalType.
// They are forward-declared here so that LogicalType::fromString, defined right below,
// can call them ahead of their definitions.
// - strToLogicalTypeID: maps a type name to its LogicalTypeID.
// - parseListType / parseArrayType / parseStructType / parseMapType / parseUnionType:
//   one parser per nested type syntax; each returns a newly built LogicalType.
// - parseStructFields / parseStructTypeInfo: not called by fromString directly; by their
//   names they support the struct parsing (NOTE(review): confirm against their definitions).
static LogicalTypeID strToLogicalTypeID(const std::string& trimmedStr);
static std::vector<std::string> parseStructFields(const std::string& structTypeStr);
static std::unique_ptr<LogicalType> parseListType(const std::string& trimmedStr);
static std::unique_ptr<LogicalType> parseArrayType(const std::string& trimmedStr);
static std::vector<StructField> parseStructTypeInfo(const std::string& structTypeStr);
static std::unique_ptr<LogicalType> parseStructType(const std::string& trimmedStr);
static std::unique_ptr<LogicalType> parseMapType(const std::string& trimmedStr);
static std::unique_ptr<LogicalType> parseUnionType(const std::string& trimmedStr);

// Builds a LogicalType from its textual form.
//
// The input is stripped of leading/trailing whitespace first. The syntax is then chosen by
// inspecting an upper-cased copy of the text, in this order:
//   1. ends with "[]"        -> list type
//   2. ends with "]"         -> array type
//   3. starts with "STRUCT"  -> struct type
//   4. starts with "MAP"     -> map type
//   5. starts with "UNION"   -> union type
//   6. anything else         -> looked up as a plain type name via strToLogicalTypeID
// The nested-type parsers receive the trimmed text in its original case; only the dispatch
// (and the plain type-name lookup) uses the upper-cased copy.
//
// @param str textual description of the type.
// @return the parsed LogicalType, with physicalType derived from the resulting typeID.
LogicalType LogicalType::fromString(const std::string& str) {
    const auto text = StringUtils::ltrim(StringUtils::rtrim(str));
    const auto upperText = StringUtils::getUpper(text);

    LogicalType parsedType;
    // The "[]" check must come before the "]" check: every list suffix also ends with "]".
    if (upperText.ends_with("[]")) {
        parsedType = *parseListType(text);
    } else if (upperText.ends_with("]")) {
        parsedType = *parseArrayType(text);
    } else if (upperText.starts_with("STRUCT")) {
        parsedType = *parseStructType(text);
    } else if (upperText.starts_with("MAP")) {
        parsedType = *parseMapType(text);
    } else if (upperText.starts_with("UNION")) {
        parsedType = *parseUnionType(text);
    } else {
        parsedType.typeID = strToLogicalTypeID(upperText);
    }

    // Always recompute the physical type from the final logical type id.
    parsedType.physicalType = LogicalType::getPhysicalType(parsedType.typeID);
    return parsedType;
}

void LogicalType::serialize(Serializer& serializer) const {
serializer.serializeValue(typeID);
serializer.serializeValue(physicalType);
Expand Down Expand Up @@ -544,79 +574,58 @@ PhysicalTypeID LogicalType::getPhysicalType(LogicalTypeID typeID) {
}
}

LogicalType LogicalTypeUtils::dataTypeFromString(const std::string& dataTypeString) {
LogicalType dataType;
auto trimmedStr = StringUtils::ltrim(StringUtils::rtrim(dataTypeString));
auto upperDataTypeString = StringUtils::getUpper(trimmedStr);
if (upperDataTypeString.ends_with("[]")) {
dataType = *parseListType(trimmedStr);
} else if (upperDataTypeString.ends_with("]")) {
dataType = *parseArrayType(trimmedStr);
} else if (upperDataTypeString.starts_with("STRUCT")) {
dataType = *parseStructType(trimmedStr);
} else if (upperDataTypeString.starts_with("MAP")) {
dataType = *parseMapType(trimmedStr);
} else if (upperDataTypeString.starts_with("UNION")) {
dataType = *parseUnionType(trimmedStr);
} else {
dataType.typeID = dataTypeIDFromString(upperDataTypeString);
}
dataType.physicalType = LogicalType::getPhysicalType(dataType.typeID);
return dataType;
}

LogicalTypeID LogicalTypeUtils::dataTypeIDFromString(const std::string& dataTypeIDString) {
auto upperDataTypeIDString = StringUtils::getUpper(dataTypeIDString);
if ("INTERNAL_ID" == upperDataTypeIDString) {
LogicalTypeID strToLogicalTypeID(const std::string& str) {
auto upperStr = StringUtils::getUpper(str);
if ("INTERNAL_ID" == upperStr) {
return LogicalTypeID::INTERNAL_ID;
} else if ("INT64" == upperDataTypeIDString) {
} else if ("INT64" == upperStr) {
return LogicalTypeID::INT64;
} else if ("INT32" == upperDataTypeIDString || "INT" == upperDataTypeIDString) {
} else if ("INT32" == upperStr || "INT" == upperStr) {
return LogicalTypeID::INT32;
} else if ("INT16" == upperDataTypeIDString) {
} else if ("INT16" == upperStr) {
return LogicalTypeID::INT16;
} else if ("INT8" == upperDataTypeIDString) {
} else if ("INT8" == upperStr) {
return LogicalTypeID::INT8;
} else if ("UINT64" == upperDataTypeIDString) {
} else if ("UINT64" == upperStr) {
return LogicalTypeID::UINT64;
} else if ("UINT32" == upperDataTypeIDString) {
} else if ("UINT32" == upperStr) {
return LogicalTypeID::UINT32;
} else if ("UINT16" == upperDataTypeIDString) {
} else if ("UINT16" == upperStr) {
return LogicalTypeID::UINT16;
} else if ("UINT8" == upperDataTypeIDString) {
} else if ("UINT8" == upperStr) {
return LogicalTypeID::UINT8;
} else if ("INT128" == upperDataTypeIDString) {
} else if ("INT128" == upperStr) {
return LogicalTypeID::INT128;
} else if ("DOUBLE" == upperDataTypeIDString) {
} else if ("DOUBLE" == upperStr || "FLOAT8" == upperStr) {
return LogicalTypeID::DOUBLE;
} else if ("FLOAT" == upperDataTypeIDString) {
} else if ("FLOAT" == upperStr || "FLOAT4" == upperStr || "REAL" == upperStr) {
return LogicalTypeID::FLOAT;
} else if ("BOOLEAN" == upperDataTypeIDString || "BOOL" == upperDataTypeIDString) {
} else if ("BOOLEAN" == upperStr || "BOOL" == upperStr) {
return LogicalTypeID::BOOL;
} else if ("BYTEA" == upperDataTypeIDString || "BLOB" == upperDataTypeIDString) {
} else if ("BYTEA" == upperStr || "BLOB" == upperStr) {
return LogicalTypeID::BLOB;
} else if ("UUID" == upperDataTypeIDString) {
} else if ("UUID" == upperStr) {
return LogicalTypeID::UUID;
} else if ("STRING" == upperDataTypeIDString) {
} else if ("STRING" == upperStr) {
return LogicalTypeID::STRING;
} else if ("DATE" == upperDataTypeIDString) {
} else if ("DATE" == upperStr) {
return LogicalTypeID::DATE;
} else if ("TIMESTAMP" == upperDataTypeIDString) {
} else if ("TIMESTAMP" == upperStr) {
return LogicalTypeID::TIMESTAMP;
} else if ("TIMESTAMP_NS" == upperDataTypeIDString) {
} else if ("TIMESTAMP_NS" == upperStr) {
return LogicalTypeID::TIMESTAMP_NS;
} else if ("TIMESTAMP_MS" == upperDataTypeIDString) {
} else if ("TIMESTAMP_MS" == upperStr) {
return LogicalTypeID::TIMESTAMP_MS;
} else if ("TIMESTAMP_SEC" == upperDataTypeIDString || "TIMESTAMP_S" == upperDataTypeIDString) {
} else if ("TIMESTAMP_SEC" == upperStr || "TIMESTAMP_S" == upperStr) {
return LogicalTypeID::TIMESTAMP_SEC;
} else if ("TIMESTAMP_TZ" == upperDataTypeIDString) {
} else if ("TIMESTAMP_TZ" == upperStr) {
return LogicalTypeID::TIMESTAMP_TZ;
} else if ("INTERVAL" == upperDataTypeIDString) {
} else if ("INTERVAL" == upperStr || "DURATION" == upperStr) {
return LogicalTypeID::INTERVAL;
} else if ("SERIAL" == upperDataTypeIDString) {
} else if ("SERIAL" == upperStr) {
return LogicalTypeID::SERIAL;
} else {
throw NotImplementedException("Cannot parse dataTypeID: " + dataTypeIDString);
throw NotImplementedException("Cannot parse dataTypeID: " + str);
}
}

Expand Down Expand Up @@ -822,7 +831,7 @@ std::vector<LogicalTypeID> LogicalTypeUtils::getAllValidLogicTypes() {
LogicalTypeID::RDF_VARIANT};
}

std::vector<std::string> LogicalTypeUtils::parseStructFields(const std::string& structTypeStr) {
std::vector<std::string> parseStructFields(const std::string& structTypeStr) {
std::vector<std::string> structFieldsStr;
auto startPos = 0u;
auto curPos = 0u;
Expand Down Expand Up @@ -853,22 +862,22 @@ std::vector<std::string> LogicalTypeUtils::parseStructFields(const std::string&
return structFieldsStr;
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseListType(const std::string& trimmedStr) {
return LogicalType::LIST(dataTypeFromString(trimmedStr.substr(0, trimmedStr.size() - 2)));
std::unique_ptr<LogicalType> parseListType(const std::string& trimmedStr) {
return LogicalType::LIST(LogicalType::fromString(trimmedStr.substr(0, trimmedStr.size() - 2)));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseArrayType(const std::string& trimmedStr) {
std::unique_ptr<LogicalType> parseArrayType(const std::string& trimmedStr) {
auto leftBracketPos = trimmedStr.find_last_of('[');
auto rightBracketPos = trimmedStr.find_last_of(']');
auto childType =
std::make_unique<LogicalType>(dataTypeFromString(trimmedStr.substr(0, leftBracketPos)));
auto childType = std::make_unique<LogicalType>(
LogicalType::fromString(trimmedStr.substr(0, leftBracketPos)));
auto numElements = std::strtoll(
trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1).c_str(),
nullptr, 0 /* base */);
return LogicalType::ARRAY(std::move(childType), numElements);
}

std::vector<StructField> LogicalTypeUtils::parseStructTypeInfo(const std::string& structTypeStr) {
std::vector<StructField> parseStructTypeInfo(const std::string& structTypeStr) {
auto leftBracketPos = structTypeStr.find('(');
auto rightBracketPos = structTypeStr.find_last_of(')');
if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) {
Expand All @@ -884,28 +893,28 @@ std::vector<StructField> LogicalTypeUtils::parseStructTypeInfo(const std::string
auto fieldName = structFieldStr.substr(0, pos);
auto fieldTypeString = structFieldStr.substr(pos + 1);
structFields.emplace_back(fieldName,
std::make_unique<LogicalType>(dataTypeFromString(fieldTypeString)));
std::make_unique<LogicalType>(LogicalType::fromString(fieldTypeString)));
}
return structFields;
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseStructType(const std::string& trimmedStr) {
std::unique_ptr<LogicalType> parseStructType(const std::string& trimmedStr) {
return LogicalType::STRUCT(parseStructTypeInfo(trimmedStr));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseMapType(const std::string& trimmedStr) {
std::unique_ptr<LogicalType> parseMapType(const std::string& trimmedStr) {
auto leftBracketPos = trimmedStr.find('(');
auto rightBracketPos = trimmedStr.find_last_of(')');
if (leftBracketPos == std::string::npos || rightBracketPos == std::string::npos) {
throw Exception("Cannot parse map type: " + trimmedStr);
}
auto mapTypeStr = trimmedStr.substr(leftBracketPos + 1, rightBracketPos - leftBracketPos - 1);
auto keyValueTypes = StringUtils::splitComma(mapTypeStr);
return LogicalType::MAP(dataTypeFromString(keyValueTypes[0]),
dataTypeFromString(keyValueTypes[1]));
return LogicalType::MAP(LogicalType::fromString(keyValueTypes[0]),
LogicalType::fromString(keyValueTypes[1]));
}

std::unique_ptr<LogicalType> LogicalTypeUtils::parseUnionType(const std::string& trimmedStr) {
std::unique_ptr<LogicalType> parseUnionType(const std::string& trimmedStr) {
auto unionFields = parseStructTypeInfo(trimmedStr);
auto unionTagField = StructField(UnionType::TAG_FIELD_NAME,
std::make_unique<LogicalType>(UnionType::TAG_FIELD_TYPE));
Expand Down
Loading

0 comments on commit 53b687f

Please sign in to comment.