Skip to content

Commit

Permalink
Merge pull request #2571 from kuzudb/logical-type-cleanup
Browse files Browse the repository at this point in the history
Cleanup logical type construction
  • Loading branch information
benjaminwinger committed Dec 14, 2023
2 parents ac33484 + 3799d81 commit 2af6412
Show file tree
Hide file tree
Showing 45 changed files with 432 additions and 410 deletions.
78 changes: 31 additions & 47 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,12 @@ std::unique_ptr<QueryGraph> Binder::bindPatternElement(

static std::unique_ptr<LogicalType> getRecursiveRelLogicalType(
const LogicalType& nodeType, const LogicalType& relType) {
auto nodesType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(nodeType.copy()));
auto relsType = std::make_unique<LogicalType>(
LogicalTypeID::VAR_LIST, std::make_unique<VarListTypeInfo>(relType.copy()));
std::vector<std::unique_ptr<StructField>> recursiveRelFields;
recursiveRelFields.push_back(
std::make_unique<StructField>(InternalKeyword::NODES, std::move(nodesType)));
recursiveRelFields.push_back(
std::make_unique<StructField>(InternalKeyword::RELS, std::move(relsType)));
return std::make_unique<LogicalType>(LogicalTypeID::RECURSIVE_REL,
auto nodesType = LogicalType::VAR_LIST(nodeType.copy());
auto relsType = LogicalType::VAR_LIST(relType.copy());
std::vector<StructField> recursiveRelFields;
recursiveRelFields.emplace_back(InternalKeyword::NODES, std::move(nodesType));
recursiveRelFields.emplace_back(InternalKeyword::RELS, std::move(relsType));
return LogicalType::RECURSIVE_REL(
std::make_unique<StructTypeInfo>(std::move(recursiveRelFields)));
}

Expand Down Expand Up @@ -119,9 +115,9 @@ std::shared_ptr<Expression> Binder::createPath(
}
}
auto nodeExtraInfo = std::make_unique<StructTypeInfo>(nodeFieldNames, nodeFieldTypes);
auto nodeType = std::make_unique<LogicalType>(LogicalTypeID::NODE, std::move(nodeExtraInfo));
auto nodeType = LogicalType::NODE(std::move(nodeExtraInfo));
auto relExtraInfo = std::make_unique<StructTypeInfo>(relFieldNames, relFieldTypes);
auto relType = std::make_unique<LogicalType>(LogicalTypeID::REL, std::move(relExtraInfo));
auto relType = LogicalType::REL(std::move(relExtraInfo));
auto uniqueName = getUniqueExpressionName(pathName);
return std::make_shared<PathExpression>(*getRecursiveRelLogicalType(*nodeType, *relType),
uniqueName, pathName, std::move(nodeType), std::move(relType), children);
Expand Down Expand Up @@ -244,23 +240,19 @@ std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::str
std::move(dstNode), directionType, QueryRelType::NON_RECURSIVE);
queryRel->setAlias(parsedName);
bindQueryRelProperties(*queryRel);
std::vector<std::string> fieldNames;
std::vector<std::unique_ptr<LogicalType>> fieldTypes;
fieldNames.emplace_back(InternalKeyword::SRC);
fieldNames.emplace_back(InternalKeyword::DST);
fieldTypes.push_back(LogicalType::INTERNAL_ID());
fieldTypes.push_back(LogicalType::INTERNAL_ID());
std::vector<StructField> fields;
fields.emplace_back(InternalKeyword::SRC, LogicalType::INTERNAL_ID());
fields.emplace_back(InternalKeyword::DST, LogicalType::INTERNAL_ID());
// Bind internal expressions.
queryRel->setLabelExpression(expressionBinder.bindLabelFunction(*queryRel));
fieldNames.emplace_back(InternalKeyword::LABEL);
fieldTypes.push_back(queryRel->getLabelExpression()->getDataType().copy());
fields.emplace_back(
InternalKeyword::LABEL, queryRel->getLabelExpression()->getDataType().copy());
// Bind properties.
for (auto& expression : queryRel->getPropertyExpressions()) {
auto property = reinterpret_cast<PropertyExpression*>(expression.get());
fieldNames.push_back(property->getPropertyName());
fieldTypes.push_back(property->getDataType().copy());
fields.emplace_back(property->getPropertyName(), property->getDataType().copy());
}
auto extraInfo = std::make_unique<StructTypeInfo>(fieldNames, fieldTypes);
auto extraInfo = std::make_unique<StructTypeInfo>(std::move(fields));
RelType::setExtraTypeInfo(queryRel->getDataTypeReference(), std::move(extraInfo));
auto tableSchema = catalog.getTableSchema(clientContext->getTx(), tableIDs[0]);
if (tableSchema->getTableType() == TableType::RDF) {
Expand All @@ -286,16 +278,15 @@ std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::str
return queryRel;
}

static void bindRecursiveRelProjectionList(const expression_vector& projectionList,
std::vector<std::string>& fieldNames, std::vector<std::unique_ptr<LogicalType>>& fieldTypes) {
static void bindRecursiveRelProjectionList(
const expression_vector& projectionList, std::vector<StructField>& fields) {
for (auto& expression : projectionList) {
if (expression->expressionType != common::ExpressionType::PROPERTY) {
throw BinderException(stringFormat(
"Unsupported projection item {} on recursive rel.", expression->toString()));
}
auto property = reinterpret_cast<PropertyExpression*>(expression.get());
fieldNames.push_back(property->getPropertyName());
fieldTypes.push_back(property->getDataType().copy());
fields.emplace_back(property->getPropertyName(), property->getDataType().copy());
}
}

Expand All @@ -317,12 +308,10 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
auto node = createQueryNode(recursivePatternInfo->nodeName,
std::vector<table_id_t>{nodeTableIDs.begin(), nodeTableIDs.end()});
scope->addExpression(node->toString(), node);
std::vector<std::string> nodeFieldNames;
std::vector<std::unique_ptr<LogicalType>> nodeFieldTypes;
nodeFieldNames.emplace_back(InternalKeyword::ID);
nodeFieldNames.emplace_back(InternalKeyword::LABEL);
nodeFieldTypes.push_back(node->getInternalID()->getDataType().copy());
nodeFieldTypes.push_back(node->getLabelExpression()->getDataType().copy());
std::vector<StructField> nodeFields;
nodeFields.emplace_back(InternalKeyword::ID, node->getInternalID()->getDataType().copy());
nodeFields.emplace_back(
InternalKeyword::LABEL, node->getLabelExpression()->getDataType().copy());
expression_vector nodeProjectionList;
if (!recursivePatternInfo->hasProjection) {
for (auto& expression : node->getPropertyExpressions()) {
Expand All @@ -333,8 +322,8 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
nodeProjectionList.push_back(expressionBinder.bindExpression(*expression));
}
}
bindRecursiveRelProjectionList(nodeProjectionList, nodeFieldNames, nodeFieldTypes);
auto nodeExtraInfo = std::make_unique<StructTypeInfo>(nodeFieldNames, nodeFieldTypes);
bindRecursiveRelProjectionList(nodeProjectionList, nodeFields);
auto nodeExtraInfo = std::make_unique<StructTypeInfo>(std::move(nodeFields));
node->getDataTypeReference().setExtraTypeInfo(std::move(nodeExtraInfo));
auto nodeCopy = createQueryNode(recursivePatternInfo->nodeName,
std::vector<table_id_t>{nodeTableIDs.begin(), nodeTableIDs.end()});
Expand All @@ -355,18 +344,13 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
relProjectionList.push_back(expressionBinder.bindExpression(*expression));
}
}
std::vector<std::string> relFieldNames;
std::vector<std::unique_ptr<LogicalType>> relFieldTypes;
relFieldNames.emplace_back(InternalKeyword::SRC);
relFieldNames.emplace_back(InternalKeyword::DST);
relFieldNames.emplace_back(InternalKeyword::LABEL);
relFieldNames.emplace_back(InternalKeyword::ID);
relFieldTypes.push_back(LogicalType::INTERNAL_ID());
relFieldTypes.push_back(LogicalType::INTERNAL_ID());
relFieldTypes.push_back(rel->getLabelExpression()->getDataType().copy());
relFieldTypes.push_back(LogicalType::INTERNAL_ID());
bindRecursiveRelProjectionList(relProjectionList, relFieldNames, relFieldTypes);
auto relExtraInfo = std::make_unique<StructTypeInfo>(relFieldNames, relFieldTypes);
std::vector<StructField> relFields;
relFields.emplace_back(InternalKeyword::SRC, LogicalType::INTERNAL_ID());
relFields.emplace_back(InternalKeyword::DST, LogicalType::INTERNAL_ID());
relFields.emplace_back(InternalKeyword::LABEL, rel->getLabelExpression()->getDataType().copy());
relFields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID());
bindRecursiveRelProjectionList(relProjectionList, relFields);
auto relExtraInfo = std::make_unique<StructTypeInfo>(std::move(relFields));
rel->getDataTypeReference().setExtraTypeInfo(std::move(relExtraInfo));
// Bind predicates in {}, e.g. [e* {date=1999-01-01}]
std::shared_ptr<Expression> relPredicate;
Expand Down
2 changes: 1 addition & 1 deletion src/binder/bind_expression/bind_boolean_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindBooleanExpression(
function::scalar_select_func selectFunc;
function::VectorBooleanFunction::bindSelectFunction(
expressionType, childrenAfterCast, selectFunc);
auto bindData = std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::BOOL));
auto bindData = std::make_unique<function::FunctionBindData>(LogicalType::BOOL());
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(functionName, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(functionName, expressionType, std::move(bindData),
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
childrenAfterCast.push_back(
implicitCastIfNecessary(children[i], function->parameterTypeIDs[i]));
}
auto bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
auto bindData = std::make_unique<function::FunctionBindData>(
std::make_unique<LogicalType>(function->returnTypeID));
auto uniqueExpressionName =
ScalarFunctionExpression::getUniqueName(function->name, childrenAfterCast);
return make_shared<ScalarFunctionExpression>(functionName, expressionType, std::move(bindData),
Expand Down
32 changes: 13 additions & 19 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
if (function->bindFunc) {
bindData = function->bindFunc(childrenAfterCast, function);
} else {
bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
bindData = std::make_unique<function::FunctionBindData>(
std::make_unique<LogicalType>(function->returnTypeID));
}
}
auto uniqueExpressionName =
Expand Down Expand Up @@ -114,8 +114,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
if (function->bindFunc) {
bindData = function->bindFunc(children, function.get());
} else {
bindData =
std::make_unique<function::FunctionBindData>(LogicalType(function->returnTypeID));
bindData = std::make_unique<function::FunctionBindData>(
std::make_unique<LogicalType>(function->returnTypeID));
}
return make_shared<AggregateFunctionExpression>(functionName, std::move(bindData),
std::move(children), std::move(function), uniqueExpressionName);
Expand Down Expand Up @@ -193,8 +193,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindInternalIDExpression(
return bindNodeOrRelPropertyExpression(*expression, InternalKeyword::ID);
}
KU_ASSERT(expression->dataType.getPhysicalType() == PhysicalTypeID::STRUCT);
auto stringValue =
std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, InternalKeyword::ID);
auto stringValue = std::make_unique<Value>(LogicalType::STRING(), InternalKeyword::ID);
return bindScalarFunctionExpression(
expression_vector{expression, createLiteralExpression(std::move(stringValue))},
STRUCT_EXTRACT_FUNC_NAME);
Expand All @@ -208,21 +207,17 @@ static std::vector<std::unique_ptr<Value>> populateLabelValues(std::vector<table
labels.resize(maxTableID + 1);
for (auto i = 0u; i < labels.size(); ++i) {
if (tableIDsSet.contains(i)) {
labels[i] = std::make_unique<Value>(
LogicalType{LogicalTypeID::STRING}, catalog.getTableName(tx, i));
labels[i] = std::make_unique<Value>(LogicalType::STRING(), catalog.getTableName(tx, i));
} else {
// TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
labels[i] =
std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, std::string(""));
labels[i] = std::make_unique<Value>(LogicalType::STRING(), std::string(""));
}
}
return labels;
}

std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression& expression) {
auto varListTypeInfo = std::make_unique<VarListTypeInfo>(LogicalType::STRING());
auto listType =
std::make_unique<LogicalType>(LogicalTypeID::VAR_LIST, std::move(varListTypeInfo));
auto listType = LogicalType::VAR_LIST(LogicalType::STRING());
expression_vector children;
switch (expression.getDataType().getLogicalTypeID()) {
case LogicalTypeID::NODE: {
Expand All @@ -231,11 +226,11 @@ std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression
auto labelName = binder->catalog.getTableName(
binder->clientContext->getTx(), node.getSingleTableID());
return createLiteralExpression(
std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, labelName));
std::make_unique<Value>(LogicalType::STRING(), labelName));
}
auto nodeTableIDs = binder->catalog.getNodeTableIDs(binder->clientContext->getTx());
children.push_back(node.getInternalID());
auto labelsValue = std::make_unique<Value>(*listType,
auto labelsValue = std::make_unique<Value>(std::move(listType),
populateLabelValues(nodeTableIDs, binder->catalog, binder->clientContext->getTx()));
children.push_back(createLiteralExpression(std::move(labelsValue)));
} break;
Expand All @@ -245,20 +240,19 @@ std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression
auto labelName = binder->catalog.getTableName(
binder->clientContext->getTx(), rel.getSingleTableID());
return createLiteralExpression(
std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, labelName));
std::make_unique<Value>(LogicalType::STRING(), labelName));
}
auto relTableIDs = binder->catalog.getRelTableIDs(binder->clientContext->getTx());
children.push_back(rel.getInternalIDProperty());
auto labelsValue = std::make_unique<Value>(*listType,
auto labelsValue = std::make_unique<Value>(std::move(listType),
populateLabelValues(relTableIDs, binder->catalog, binder->clientContext->getTx()));
children.push_back(createLiteralExpression(std::move(labelsValue)));
} break;
default:
KU_UNREACHABLE;
}
auto execFunc = function::LabelFunction::execFunction;
auto bindData =
std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::STRING));
auto bindData = std::make_unique<function::FunctionBindData>(LogicalType::STRING());
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(LABEL_FUNC_NAME, children);
return std::make_shared<ScalarFunctionExpression>(LABEL_FUNC_NAME, ExpressionType::FUNCTION,
std::move(bindData), std::move(children), execFunc, nullptr, uniqueExpressionName);
Expand Down
2 changes: 1 addition & 1 deletion src/binder/bind_expression/bind_literal_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ std::shared_ptr<Expression> ExpressionBinder::createLiteralExpression(

std::shared_ptr<Expression> ExpressionBinder::createStringLiteralExpression(
const std::string& strVal) {
auto value = std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, strVal);
auto value = std::make_unique<Value>(LogicalType::STRING(), strVal);
return createLiteralExpression(std::move(value));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ std::shared_ptr<Expression> ExpressionBinder::bindNullOperatorExpression(
function::VectorNullFunction::bindExecFunction(expressionType, children, execFunc);
function::scalar_select_func selectFunc;
function::VectorNullFunction::bindSelectFunction(expressionType, children, selectFunc);
auto bindData = std::make_unique<function::FunctionBindData>(LogicalType(LogicalTypeID::BOOL));
auto bindData = std::make_unique<function::FunctionBindData>(LogicalType::BOOL());
auto uniqueExpressionName = ScalarFunctionExpression::getUniqueName(functionName, children);
return make_shared<ScalarFunctionExpression>(functionName, expressionType, std::move(bindData),
std::move(children), std::move(execFunc), std::move(selectFunc), uniqueExpressionName);
Expand Down
3 changes: 2 additions & 1 deletion src/binder/bind_expression/bind_subquery_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindSubqueryExpression(
// Bind projection
auto function = binder->catalog.getBuiltInFunctions()->matchAggregateFunction(
COUNT_STAR_FUNC_NAME, std::vector<LogicalType*>{}, false);
auto bindData = std::make_unique<FunctionBindData>(LogicalType(function->returnTypeID));
auto bindData =
std::make_unique<FunctionBindData>(std::make_unique<LogicalType>(function->returnTypeID));
auto countStarExpr = std::make_shared<AggregateFunctionExpression>(COUNT_STAR_FUNC_NAME,
std::move(bindData), expression_vector{}, function->clone(),
binder->getUniqueExpressionName(COUNT_STAR_FUNC_NAME));
Expand Down
2 changes: 1 addition & 1 deletion src/binder/bound_statement_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace binder {
std::unique_ptr<BoundStatementResult> BoundStatementResult::createSingleStringColumnResult(
const std::string& columnName) {
auto result = std::make_unique<BoundStatementResult>();
auto value = std::make_unique<Value>(LogicalType{LogicalTypeID::STRING}, columnName);
auto value = std::make_unique<Value>(LogicalType::STRING(), columnName);
auto stringColumn = std::make_shared<LiteralExpression>(std::move(value), columnName);
result->addColumn(stringColumn);
return result;
Expand Down
2 changes: 1 addition & 1 deletion src/binder/expression/function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ std::string ScalarFunctionExpression::toStringInternal() const {
result += ExpressionUtil::toString(children);
if (functionName == "CAST") {
result += ", ";
result += bindData->resultType.toString();
result += bindData->resultType->toString();
}
result += ")";
return result;
Expand Down
2 changes: 1 addition & 1 deletion src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCast(
if (CastFunction::hasImplicitCast(expression->dataType, targetType)) {
auto functionName = stringFormat("CAST_TO({})", targetType.toString());
auto children = expression_vector{expression};
auto bindData = std::make_unique<FunctionBindData>(targetType);
auto bindData = std::make_unique<FunctionBindData>(targetType.copy());
auto scalarFunction = CastFunction::bindCastFunction(
functionName, expression->dataType.getLogicalTypeID(), targetType.getLogicalTypeID());
auto uniqueName = ScalarFunctionExpression::getUniqueName(functionName, children);
Expand Down
11 changes: 10 additions & 1 deletion src/c_api/data_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@

using namespace kuzu::common;

namespace kuzu::common {
// Helper used by the C API in this file to construct a LogicalType from a type ID plus
// extra type info.
// NOTE(review): presumably this struct exists because the
// LogicalType(LogicalTypeID, std::unique_ptr<ExtraTypeInfo>) constructor is no longer
// publicly accessible after this cleanup and CAPIHelper is granted access to it —
// confirm against the LogicalType declaration, which is not visible here.
struct CAPIHelper {
// Allocates a new LogicalType on the heap with `new`, forwarding `typeID` and moving
// `extraTypeInfo` into it.
// @param typeID         logical type ID passed to the LogicalType constructor.
// @param extraTypeInfo  extra type info; ownership is moved into the new LogicalType.
// @return raw pointer to the newly allocated LogicalType. It is not wrapped in a smart
//         pointer, so whoever stores it is responsible for eventually deleting it
//         (the release path is not visible in this view).
static inline LogicalType* createLogicalType(
LogicalTypeID typeID, std::unique_ptr<ExtraTypeInfo> extraTypeInfo) {
return new LogicalType(typeID, std::move(extraTypeInfo));
}
};
} // namespace kuzu::common

kuzu_logical_type* kuzu_data_type_create(
kuzu_data_type_id id, kuzu_logical_type* child_type, uint64_t fixed_num_elements_in_list) {
auto* c_data_type = (kuzu_logical_type*)malloc(sizeof(kuzu_logical_type));
Expand All @@ -18,7 +27,7 @@ kuzu_logical_type* kuzu_data_type_create(
std::make_unique<FixedListTypeInfo>(
std::move(child_type_pty), fixed_num_elements_in_list) :
std::make_unique<VarListTypeInfo>(std::move(child_type_pty));
data_type = new LogicalType(logicalTypeID, std::move(extraTypeInfo));
data_type = CAPIHelper::createLogicalType(logicalTypeID, std::move(extraTypeInfo));
}
c_data_type->_data_type = data_type;
return c_data_type;
Expand Down
Loading

0 comments on commit 2af6412

Please sign in to comment.