Skip to content

Commit

Permalink
Resolve default any type (#3374)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Apr 25, 2024
1 parent 7d632f3 commit 6d763fe
Show file tree
Hide file tree
Showing 15 changed files with 321 additions and 83 deletions.
8 changes: 4 additions & 4 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::str
fields.emplace_back(property->getPropertyName(), property->getDataType().copy());
}
auto extraInfo = std::make_unique<StructTypeInfo>(std::move(fields));
RelType::setExtraTypeInfo(queryRel->getDataTypeReference(), std::move(extraInfo));
queryRel->setExtraTypeInfo(std::move(extraInfo));
return queryRel;
}

Expand Down Expand Up @@ -357,7 +357,7 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
}
bindRecursiveRelProjectionList(nodeProjectionList, nodeFields);
auto nodeExtraInfo = std::make_unique<StructTypeInfo>(std::move(nodeFields));
node->getDataTypeReference().setExtraTypeInfo(std::move(nodeExtraInfo));
node->setExtraTypeInfo(std::move(nodeExtraInfo));
auto nodeCopy = createQueryNode(recursivePatternInfo->nodeName,
std::vector<table_id_t>{nodeTableIDs.begin(), nodeTableIDs.end()});
// Bind intermediate rel
Expand All @@ -384,7 +384,7 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
relFields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID());
bindRecursiveRelProjectionList(relProjectionList, relFields);
auto relExtraInfo = std::make_unique<StructTypeInfo>(std::move(relFields));
rel->getDataTypeReference().setExtraTypeInfo(std::move(relExtraInfo));
rel->setExtraTypeInfo(std::move(relExtraInfo));
// Bind predicates in {}, e.g. [e* {date=1999-01-01}]
std::shared_ptr<Expression> relPredicate;
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
Expand Down Expand Up @@ -569,7 +569,7 @@ std::shared_ptr<NodeExpression> Binder::createQueryNode(const std::string& parse
fieldTypes.emplace_back(property->dataType.copy());
}
auto extraInfo = std::make_unique<StructTypeInfo>(fieldNames, fieldTypes);
NodeType::setExtraTypeInfo(queryNode->getDataTypeReference(), std::move(extraInfo));
queryNode->setExtraTypeInfo(std::move(extraInfo));
return queryNode;
}

Expand Down
15 changes: 13 additions & 2 deletions src/binder/bind/read/bind_unwind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,23 @@ using namespace kuzu::common;
namespace kuzu {
namespace binder {

// A parameter whose data type is still ANY (e.g. UNWIND $1) cannot be validated as a LIST while
// binding; that check has to wait until the actual parameter is seen.
static bool skipDataTypeValidation(const Expression& expr) {
    if (expr.expressionType != ExpressionType::PARAMETER) {
        return false;
    }
    const auto typeID = expr.getDataType().getLogicalTypeID();
    return typeID == LogicalTypeID::ANY;
}

std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause& readingClause) {
auto& unwindClause = readingClause.constCast<UnwindClause>();
auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression());
ExpressionUtil::validateDataType(*boundExpression, LogicalTypeID::LIST);
auto aliasName = unwindClause.getAlias();
auto alias = createVariable(aliasName, *ListType::getChildType(&boundExpression->dataType));
std::shared_ptr<Expression> alias;
if (!skipDataTypeValidation(*boundExpression)) {
ExpressionUtil::validateDataType(*boundExpression, LogicalTypeID::LIST);
alias = createVariable(aliasName, *ListType::getChildType(&boundExpression->dataType));
} else {
alias = createVariable(aliasName, *LogicalType::ANY());
}
std::shared_ptr<Expression> idExpr = nullptr;
if (scope.hasMemorizedTableIDs(boundExpression->getAlias())) {
auto tableIDs = scope.getMemorizedTableIDs(boundExpression->getAlias());
Expand Down
3 changes: 2 additions & 1 deletion src/binder/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ add_library(
expression_util.cpp
function_expression.cpp
literal_expression.cpp
parameter_expression.cpp)
parameter_expression.cpp
variable_expression.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_expression>
Expand Down
2 changes: 1 addition & 1 deletion src/binder/expression/parameter_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using namespace common;
namespace binder {

void ParameterExpression::cast(const LogicalType& type) {
if (dataType.getLogicalTypeID() != LogicalTypeID::ANY) {
if (!dataType.containsAny()) {
// LCOV_EXCL_START
throw BinderException(
stringFormat("Cannot change parameter expression data type from {} to {}.",
Expand Down
22 changes: 22 additions & 0 deletions src/binder/expression/variable_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "binder/expression/variable_expression.h"

#include "common/exception/binder.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

// Overwrites this variable's data type with `type`. This is only permitted while the current
// data type still contains ANY; otherwise a BinderException is thrown.
void VariableExpression::cast(const LogicalType& type) {
    if (dataType.containsAny()) {
        dataType = type;
        return;
    }
    // LCOV_EXCL_START
    throw BinderException(
        stringFormat("Cannot change variable expression data type from {} to {}.",
            dataType.toString(), type.toString()));
    // LCOV_EXCL_STOP
}

} // namespace binder
} // namespace kuzu
42 changes: 40 additions & 2 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,50 @@ static std::string unsupportedImplicitCastException(const Expression& expression
expression.toString(), expression.dataType.toString(), targetTypeStr);
}

// Decides whether `type` can be resolved to `target`.
// - ANY is compatible with every target.
// - Otherwise both logical type IDs must match.
// - LIST and ARRAY are compared recursively on their child types.
// - STRUCT requires the same number of fields and pairwise compatible field types, by position.
// - RDF_VARIANT, UNION, MAP, NODE, REL and RECURSIVE_REL are never considered compatible.
// - Any other type with a matching ID is compatible.
static bool compatible(const LogicalType& type, const LogicalType& target) {
    const auto typeID = type.getLogicalTypeID();
    if (typeID == LogicalTypeID::ANY) {
        return true;
    }
    if (typeID != target.getLogicalTypeID()) {
        return false;
    }
    if (typeID == LogicalTypeID::LIST) {
        return compatible(*ListType::getChildType(&type), *ListType::getChildType(&target));
    }
    if (typeID == LogicalTypeID::ARRAY) {
        return compatible(*ArrayType::getChildType(&type), *ArrayType::getChildType(&target));
    }
    if (typeID == LogicalTypeID::STRUCT) {
        const auto numFields = StructType::getNumFields(&type);
        if (numFields != StructType::getNumFields(&target)) {
            return false;
        }
        for (auto fieldIdx = 0u; fieldIdx < numFields; ++fieldIdx) {
            const auto& fieldType = *StructType::getField(&type, fieldIdx)->getType();
            const auto& targetFieldType = *StructType::getField(&target, fieldIdx)->getType();
            if (!compatible(fieldType, targetFieldType)) {
                return false;
            }
        }
        return true;
    }
    // Remaining nested / special types are rejected; everything else only needs matching IDs.
    const bool isRejected =
        typeID == LogicalTypeID::RDF_VARIANT || typeID == LogicalTypeID::UNION ||
        typeID == LogicalTypeID::MAP || typeID == LogicalTypeID::NODE ||
        typeID == LogicalTypeID::REL || typeID == LogicalTypeID::RECURSIVE_REL;
    return !isRejected;
}

std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) {
if (expression->dataType == targetType || targetType.containsAny()) { // No need to cast.
return expression;
}
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
if (compatible(expression->getDataType(), targetType)) {
expression->cast(targetType);
return expression;
}
Expand Down
24 changes: 24 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ uint32_t PhysicalTypeUtils::getFixedTypeSize(PhysicalTypeID physicalType) {
}
}

// A list type contains ANY exactly when its child type does.
bool ListTypeInfo::containsAny() const {
    const bool childHasAny = childType->containsAny();
    return childHasAny;
}

bool ListTypeInfo::operator==(const ExtraTypeInfo& other) const {
auto otherListTypeInfo = ku_dynamic_cast<const ExtraTypeInfo*, const ListTypeInfo*>(&other);
if (otherListTypeInfo) {
Expand Down Expand Up @@ -147,6 +151,10 @@ void ArrayTypeInfo::serializeInternal(Serializer& serializer) const {
serializer.serializeValue(numElements);
}

// A struct field contains ANY exactly when its own type does.
bool StructField::containsAny() const {
    const bool fieldTypeHasAny = type->containsAny();
    return fieldTypeHasAny;
}

bool StructField::operator==(const StructField& other) const {
return *type == *other.type;
}
Expand Down Expand Up @@ -238,6 +246,15 @@ std::vector<const StructField*> StructTypeInfo::getStructFields() const {
return structFields;
}

bool StructTypeInfo::containsAny() const {
for (auto& field : fields) {
if (field.containsAny()) {
return true;
}
}
return false;
}

bool StructTypeInfo::operator==(const ExtraTypeInfo& other) const {
auto otherStructTypeInfo = ku_dynamic_cast<const ExtraTypeInfo*, const StructTypeInfo*>(&other);
if (otherStructTypeInfo) {
Expand Down Expand Up @@ -314,6 +331,13 @@ LogicalType::LogicalType(const LogicalType& other) {
}
}

// Reports whether this type is, or contains, ANY. When extra type info is attached the answer is
// delegated to it; otherwise the type's own ID is compared against ANY.
bool LogicalType::containsAny() const {
    if (extraTypeInfo == nullptr) {
        return typeID == LogicalTypeID::ANY;
    }
    return extraTypeInfo->containsAny();
}

LogicalType& LogicalType::operator=(const LogicalType& other) {
// Reuse the copy constructor and move assignment operator.
LogicalType copy(other);
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/value/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ bool Value::operator==(const Value& rhs) const {
}

void Value::setDataType(const LogicalType& dataType_) {
KU_ASSERT(dataType->getLogicalTypeID() == LogicalTypeID::ANY);
KU_ASSERT(dataType->containsAny());
dataType = dataType_.copy();
}

Expand Down
3 changes: 2 additions & 1 deletion src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ class Expression : public std::enable_shared_from_this<Expression> {
}

virtual void cast(const common::LogicalType& type);
// NOTE: Avoid using the following unsafe getter. It is meant for resolving ANY data type only.
common::LogicalType& getDataTypeUnsafe() { return dataType; }
common::LogicalType getDataType() const { return dataType; }
common::LogicalType& getDataTypeReference() { return dataType; }

bool hasAlias() const { return !alias.empty(); }
std::string getAlias() const { return alias; }
Expand Down
49 changes: 26 additions & 23 deletions src/include/binder/expression/node_rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ class NodeOrRelExpression : public Expression {
variableName(std::move(variableName)), tableIDs{std::move(tableIDs)} {}
~NodeOrRelExpression() override = default;

inline std::string getVariableName() const { return variableName; }

inline void setTableIDs(common::table_id_vector_t tableIDs) {
this->tableIDs = std::move(tableIDs);
// Attaches extra type info to this expression's data type, taking ownership of `info`.
// Note: ideally this setter would be removed, but for now the type has to be created after the
// expression.
void setExtraTypeInfo(std::unique_ptr<common::ExtraTypeInfo> info) {
    this->dataType.setExtraTypeInfo(std::move(info));
}
inline void addTableIDs(const common::table_id_vector_t& tableIDsToAdd) {

std::string getVariableName() const { return variableName; }

void setTableIDs(common::table_id_vector_t tableIDs_) { tableIDs = std::move(tableIDs_); }
void addTableIDs(const common::table_id_vector_t& tableIDsToAdd) {
auto tableIDsSet = getTableIDsSet();
for (auto tableID : tableIDsToAdd) {
if (!tableIDsSet.contains(tableID)) {
Expand All @@ -29,63 +33,62 @@ class NodeOrRelExpression : public Expression {
}
}

inline bool isMultiLabeled() const { return tableIDs.size() > 1; }
inline uint32_t getNumTableIDs() const { return tableIDs.size(); }
inline std::vector<common::table_id_t> getTableIDs() const { return tableIDs; }
inline std::unordered_set<common::table_id_t> getTableIDsSet() const {
bool isMultiLabeled() const { return tableIDs.size() > 1; }
uint32_t getNumTableIDs() const { return tableIDs.size(); }
std::vector<common::table_id_t> getTableIDs() const { return tableIDs; }
std::unordered_set<common::table_id_t> getTableIDsSet() const {
return {tableIDs.begin(), tableIDs.end()};
}
inline common::table_id_t getSingleTableID() const {
common::table_id_t getSingleTableID() const {
KU_ASSERT(tableIDs.size() == 1);
return tableIDs[0];
}

inline void addPropertyExpression(const std::string& propertyName,
void addPropertyExpression(const std::string& propertyName,
std::unique_ptr<Expression> property) {
KU_ASSERT(!propertyNameToIdx.contains(propertyName));
propertyNameToIdx.insert({propertyName, propertyExprs.size()});
propertyExprs.push_back(std::move(property));
}
inline bool hasPropertyExpression(const std::string& propertyName) const {
bool hasPropertyExpression(const std::string& propertyName) const {
return propertyNameToIdx.contains(propertyName);
}
inline std::shared_ptr<Expression> getPropertyExpression(
const std::string& propertyName) const {
std::shared_ptr<Expression> getPropertyExpression(const std::string& propertyName) const {
KU_ASSERT(propertyNameToIdx.contains(propertyName));
return propertyExprs[propertyNameToIdx.at(propertyName)]->copy();
}
inline const std::vector<std::unique_ptr<Expression>>& getPropertyExprsRef() const {
const std::vector<std::unique_ptr<Expression>>& getPropertyExprsRef() const {
return propertyExprs;
}
inline expression_vector getPropertyExprs() const {
expression_vector getPropertyExprs() const {
expression_vector result;
for (auto& expr : propertyExprs) {
result.push_back(expr->copy());
}
return result;
}

inline void setLabelExpression(std::shared_ptr<Expression> expression) {
void setLabelExpression(std::shared_ptr<Expression> expression) {
labelExpression = std::move(expression);
}
inline std::shared_ptr<Expression> getLabelExpression() const { return labelExpression; }
std::shared_ptr<Expression> getLabelExpression() const { return labelExpression; }

inline void addPropertyDataExpr(std::string propertyName, std::shared_ptr<Expression> expr) {
void addPropertyDataExpr(std::string propertyName, std::shared_ptr<Expression> expr) {
propertyDataExprs.insert({propertyName, expr});
}
inline const std::unordered_map<std::string, std::shared_ptr<Expression>>&
const std::unordered_map<std::string, std::shared_ptr<Expression>>&
getPropertyDataExprRef() const {
return propertyDataExprs;
}
inline bool hasPropertyDataExpr(const std::string& propertyName) const {
bool hasPropertyDataExpr(const std::string& propertyName) const {
return propertyDataExprs.contains(propertyName);
}
inline std::shared_ptr<Expression> getPropertyDataExpr(const std::string& propertyName) const {
std::shared_ptr<Expression> getPropertyDataExpr(const std::string& propertyName) const {
KU_ASSERT(propertyDataExprs.contains(propertyName));
return propertyDataExprs.at(propertyName);
}

inline std::string toStringInternal() const final { return variableName; }
std::string toStringInternal() const final { return variableName; }

protected:
std::string variableName;
Expand Down
8 changes: 5 additions & 3 deletions src/include/binder/expression/variable_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ class VariableExpression : public Expression {
: Expression{common::ExpressionType::VARIABLE, std::move(dataType), std::move(uniqueName)},
variableName{std::move(variableName)} {}

inline std::string getVariableName() const { return variableName; }
std::string getVariableName() const { return variableName; }

inline std::string toStringInternal() const final { return variableName; }
void cast(const common::LogicalType& type) override;

inline std::unique_ptr<Expression> copy() const final {
std::string toStringInternal() const final { return variableName; }

std::unique_ptr<Expression> copy() const final {
return std::make_unique<VariableExpression>(*dataType.copy(), uniqueName, variableName);
}

Expand Down
Loading

0 comments on commit 6d763fe

Please sign in to comment.