Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve default any type #3374

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ std::shared_ptr<RelExpression> Binder::createNonRecursiveQueryRel(const std::str
fields.emplace_back(property->getPropertyName(), property->getDataType().copy());
}
auto extraInfo = std::make_unique<StructTypeInfo>(std::move(fields));
RelType::setExtraTypeInfo(queryRel->getDataTypeReference(), std::move(extraInfo));
queryRel->setExtraTypeInfo(std::move(extraInfo));
return queryRel;
}

Expand Down Expand Up @@ -357,7 +357,7 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
}
bindRecursiveRelProjectionList(nodeProjectionList, nodeFields);
auto nodeExtraInfo = std::make_unique<StructTypeInfo>(std::move(nodeFields));
node->getDataTypeReference().setExtraTypeInfo(std::move(nodeExtraInfo));
node->setExtraTypeInfo(std::move(nodeExtraInfo));
auto nodeCopy = createQueryNode(recursivePatternInfo->nodeName,
std::vector<table_id_t>{nodeTableIDs.begin(), nodeTableIDs.end()});
// Bind intermediate rel
Expand All @@ -384,7 +384,7 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
relFields.emplace_back(InternalKeyword::ID, LogicalType::INTERNAL_ID());
bindRecursiveRelProjectionList(relProjectionList, relFields);
auto relExtraInfo = std::make_unique<StructTypeInfo>(std::move(relFields));
rel->getDataTypeReference().setExtraTypeInfo(std::move(relExtraInfo));
rel->setExtraTypeInfo(std::move(relExtraInfo));
// Bind predicates in {}, e.g. [e* {date=1999-01-01}]
std::shared_ptr<Expression> relPredicate;
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
Expand Down Expand Up @@ -569,7 +569,7 @@ std::shared_ptr<NodeExpression> Binder::createQueryNode(const std::string& parse
fieldTypes.emplace_back(property->dataType.copy());
}
auto extraInfo = std::make_unique<StructTypeInfo>(fieldNames, fieldTypes);
NodeType::setExtraTypeInfo(queryNode->getDataTypeReference(), std::move(extraInfo));
queryNode->setExtraTypeInfo(std::move(extraInfo));
return queryNode;
}

Expand Down
15 changes: 13 additions & 2 deletions src/binder/bind/read/bind_unwind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,23 @@ using namespace kuzu::common;
namespace kuzu {
namespace binder {

// Returns true when `expr` is a parameter whose data type is still ANY, e.g. `UNWIND $1`.
// Such an expression cannot be checked against LIST at bind time because its type is only
// known once the actual parameter value is supplied.
static bool skipDataTypeValidation(const Expression& expr) {
    if (expr.expressionType != ExpressionType::PARAMETER) {
        return false;
    }
    return expr.getDataType().getLogicalTypeID() == LogicalTypeID::ANY;
}

std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause& readingClause) {
auto& unwindClause = readingClause.constCast<UnwindClause>();
auto boundExpression = expressionBinder.bindExpression(*unwindClause.getExpression());
ExpressionUtil::validateDataType(*boundExpression, LogicalTypeID::LIST);
auto aliasName = unwindClause.getAlias();
auto alias = createVariable(aliasName, *ListType::getChildType(&boundExpression->dataType));
std::shared_ptr<Expression> alias;
if (!skipDataTypeValidation(*boundExpression)) {
ExpressionUtil::validateDataType(*boundExpression, LogicalTypeID::LIST);
alias = createVariable(aliasName, *ListType::getChildType(&boundExpression->dataType));
} else {
alias = createVariable(aliasName, *LogicalType::ANY());
}
std::shared_ptr<Expression> idExpr = nullptr;
if (scope.hasMemorizedTableIDs(boundExpression->getAlias())) {
auto tableIDs = scope.getMemorizedTableIDs(boundExpression->getAlias());
Expand Down
3 changes: 2 additions & 1 deletion src/binder/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ add_library(
expression_util.cpp
function_expression.cpp
literal_expression.cpp
parameter_expression.cpp)
parameter_expression.cpp
variable_expression.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_expression>
Expand Down
2 changes: 1 addition & 1 deletion src/binder/expression/parameter_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using namespace common;
namespace binder {

void ParameterExpression::cast(const LogicalType& type) {
if (dataType.getLogicalTypeID() != LogicalTypeID::ANY) {
if (!dataType.containsAny()) {
// LCOV_EXCL_START
throw BinderException(
stringFormat("Cannot change parameter expression data type from {} to {}.",
Expand Down
22 changes: 22 additions & 0 deletions src/binder/expression/variable_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "binder/expression/variable_expression.h"

#include "common/exception/binder.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

// Replaces this variable's data type with `type`.
// The replacement is only permitted while the current data type still contains ANY;
// otherwise a BinderException is thrown.
void VariableExpression::cast(const LogicalType& type) {
    if (dataType.containsAny()) {
        dataType = type;
        return;
    }
    // LCOV_EXCL_START
    throw BinderException(stringFormat("Cannot change variable expression data type from {} to {}.",
        dataType.toString(), type.toString()));
    // LCOV_EXCL_STOP
}

} // namespace binder
} // namespace kuzu
42 changes: 40 additions & 2 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,50 @@ static std::string unsupportedImplicitCastException(const Expression& expression
expression.toString(), expression.dataType.toString(), targetTypeStr);
}

// Decides whether `type` can be turned into `target` by filling in its ANY parts.
//  - ANY is compatible with every target.
//  - Otherwise both types must share the same logical type ID.
//  - LIST and ARRAY recurse into their child types; STRUCT requires the same number of fields
//    and compares field types position by position.
//  - RDF_VARIANT, UNION, MAP, NODE, REL and RECURSIVE_REL are never considered compatible.
//  - Every other type with a matching ID is compatible.
// NOTE(review): the ARRAY case compares child types only, not the number of elements, and the
// STRUCT case does not compare field names - confirm this is intended.
static bool compatible(const LogicalType& type, const LogicalType& target) {
    const auto typeID = type.getLogicalTypeID();
    if (typeID == LogicalTypeID::ANY) {
        return true;
    }
    if (typeID != target.getLogicalTypeID()) {
        return false;
    }
    switch (typeID) {
    case LogicalTypeID::LIST:
        return compatible(*ListType::getChildType(&type), *ListType::getChildType(&target));
    case LogicalTypeID::ARRAY:
        return compatible(*ArrayType::getChildType(&type), *ArrayType::getChildType(&target));
    case LogicalTypeID::STRUCT: {
        const auto numFields = StructType::getNumFields(&type);
        if (numFields != StructType::getNumFields(&target)) {
            return false;
        }
        // Fields are matched by position.
        for (auto idx = 0u; idx < numFields; ++idx) {
            const auto& fieldType = *StructType::getField(&type, idx)->getType();
            const auto& targetFieldType = *StructType::getField(&target, idx)->getType();
            if (!compatible(fieldType, targetFieldType)) {
                return false;
            }
        }
        return true;
    }
    case LogicalTypeID::RDF_VARIANT:
    case LogicalTypeID::UNION:
    case LogicalTypeID::MAP:
    case LogicalTypeID::NODE:
    case LogicalTypeID::REL:
    case LogicalTypeID::RECURSIVE_REL:
        return false;
    default:
        return true;
    }
}

std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) {
if (expression->dataType == targetType || targetType.containsAny()) { // No need to cast.
return expression;
}
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
if (compatible(expression->getDataType(), targetType)) {
expression->cast(targetType);
return expression;
}
Expand Down
24 changes: 24 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ uint32_t PhysicalTypeUtils::getFixedTypeSize(PhysicalTypeID physicalType) {
}
}

// A list type contains ANY exactly when its child type does.
bool ListTypeInfo::containsAny() const {
    const bool childHasAny = childType->containsAny();
    return childHasAny;
}

bool ListTypeInfo::operator==(const ExtraTypeInfo& other) const {
auto otherListTypeInfo = ku_dynamic_cast<const ExtraTypeInfo*, const ListTypeInfo*>(&other);
if (otherListTypeInfo) {
Expand Down Expand Up @@ -147,6 +151,10 @@ void ArrayTypeInfo::serializeInternal(Serializer& serializer) const {
serializer.serializeValue(numElements);
}

// A struct field contains ANY exactly when its type does.
bool StructField::containsAny() const {
    const bool fieldTypeHasAny = type->containsAny();
    return fieldTypeHasAny;
}

// Two struct fields compare equal when their types are equal; only `type` is compared here.
bool StructField::operator==(const StructField& other) const {
    const auto& lhsType = *type;
    const auto& rhsType = *other.type;
    return lhsType == rhsType;
}
Expand Down Expand Up @@ -238,6 +246,15 @@ std::vector<const StructField*> StructTypeInfo::getStructFields() const {
return structFields;
}

bool StructTypeInfo::containsAny() const {
for (auto& field : fields) {
if (field.containsAny()) {
return true;
}
}
return false;
}

bool StructTypeInfo::operator==(const ExtraTypeInfo& other) const {
auto otherStructTypeInfo = ku_dynamic_cast<const ExtraTypeInfo*, const StructTypeInfo*>(&other);
if (otherStructTypeInfo) {
Expand Down Expand Up @@ -314,6 +331,13 @@ LogicalType::LogicalType(const LogicalType& other) {
}
}

// Reports whether this type is, or contains, ANY.
// Types without extra type info are checked by their own type ID; types that carry extra
// type info delegate the decision to it.
bool LogicalType::containsAny() const {
    if (extraTypeInfo == nullptr) {
        return typeID == LogicalTypeID::ANY;
    }
    return extraTypeInfo->containsAny();
}

LogicalType& LogicalType::operator=(const LogicalType& other) {
// Reuse the copy constructor and move assignment operator.
LogicalType copy(other);
Expand Down
2 changes: 1 addition & 1 deletion src/common/types/value/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ bool Value::operator==(const Value& rhs) const {
}

void Value::setDataType(const LogicalType& dataType_) {
KU_ASSERT(dataType->getLogicalTypeID() == LogicalTypeID::ANY);
KU_ASSERT(dataType->containsAny());
dataType = dataType_.copy();
}

Expand Down
3 changes: 2 additions & 1 deletion src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ class Expression : public std::enable_shared_from_this<Expression> {
}

virtual void cast(const common::LogicalType& type);
// NOTE: Avoid using the following unsafe getter. It is meant for resolving ANY data type only.
common::LogicalType& getDataTypeUnsafe() { return dataType; }
common::LogicalType getDataType() const { return dataType; }
common::LogicalType& getDataTypeReference() { return dataType; }

bool hasAlias() const { return !alias.empty(); }
std::string getAlias() const { return alias; }
Expand Down
49 changes: 26 additions & 23 deletions src/include/binder/expression/node_rel_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ class NodeOrRelExpression : public Expression {
variableName(std::move(variableName)), tableIDs{std::move(tableIDs)} {}
~NodeOrRelExpression() override = default;

inline std::string getVariableName() const { return variableName; }

inline void setTableIDs(common::table_id_vector_t tableIDs) {
this->tableIDs = std::move(tableIDs);
// NOTE: Ideally this setter would be removed, but for now the type has to be created after the
// expression.
void setExtraTypeInfo(std::unique_ptr<common::ExtraTypeInfo> info) {
dataType.setExtraTypeInfo(std::move(info));
}
inline void addTableIDs(const common::table_id_vector_t& tableIDsToAdd) {

std::string getVariableName() const { return variableName; }

void setTableIDs(common::table_id_vector_t tableIDs_) { tableIDs = std::move(tableIDs_); }
void addTableIDs(const common::table_id_vector_t& tableIDsToAdd) {
auto tableIDsSet = getTableIDsSet();
for (auto tableID : tableIDsToAdd) {
if (!tableIDsSet.contains(tableID)) {
Expand All @@ -29,63 +33,62 @@ class NodeOrRelExpression : public Expression {
}
}

inline bool isMultiLabeled() const { return tableIDs.size() > 1; }
inline uint32_t getNumTableIDs() const { return tableIDs.size(); }
inline std::vector<common::table_id_t> getTableIDs() const { return tableIDs; }
inline std::unordered_set<common::table_id_t> getTableIDsSet() const {
bool isMultiLabeled() const { return tableIDs.size() > 1; }
uint32_t getNumTableIDs() const { return tableIDs.size(); }
std::vector<common::table_id_t> getTableIDs() const { return tableIDs; }
std::unordered_set<common::table_id_t> getTableIDsSet() const {
return {tableIDs.begin(), tableIDs.end()};
}
inline common::table_id_t getSingleTableID() const {
common::table_id_t getSingleTableID() const {
KU_ASSERT(tableIDs.size() == 1);
return tableIDs[0];
}

inline void addPropertyExpression(const std::string& propertyName,
void addPropertyExpression(const std::string& propertyName,
std::unique_ptr<Expression> property) {
KU_ASSERT(!propertyNameToIdx.contains(propertyName));
propertyNameToIdx.insert({propertyName, propertyExprs.size()});
propertyExprs.push_back(std::move(property));
}
inline bool hasPropertyExpression(const std::string& propertyName) const {
bool hasPropertyExpression(const std::string& propertyName) const {
return propertyNameToIdx.contains(propertyName);
}
inline std::shared_ptr<Expression> getPropertyExpression(
const std::string& propertyName) const {
std::shared_ptr<Expression> getPropertyExpression(const std::string& propertyName) const {
KU_ASSERT(propertyNameToIdx.contains(propertyName));
return propertyExprs[propertyNameToIdx.at(propertyName)]->copy();
}
inline const std::vector<std::unique_ptr<Expression>>& getPropertyExprsRef() const {
const std::vector<std::unique_ptr<Expression>>& getPropertyExprsRef() const {
return propertyExprs;
}
inline expression_vector getPropertyExprs() const {
expression_vector getPropertyExprs() const {
expression_vector result;
for (auto& expr : propertyExprs) {
result.push_back(expr->copy());
}
return result;
}

inline void setLabelExpression(std::shared_ptr<Expression> expression) {
void setLabelExpression(std::shared_ptr<Expression> expression) {
labelExpression = std::move(expression);
}
inline std::shared_ptr<Expression> getLabelExpression() const { return labelExpression; }
std::shared_ptr<Expression> getLabelExpression() const { return labelExpression; }

inline void addPropertyDataExpr(std::string propertyName, std::shared_ptr<Expression> expr) {
void addPropertyDataExpr(std::string propertyName, std::shared_ptr<Expression> expr) {
propertyDataExprs.insert({propertyName, expr});
}
inline const std::unordered_map<std::string, std::shared_ptr<Expression>>&
const std::unordered_map<std::string, std::shared_ptr<Expression>>&
getPropertyDataExprRef() const {
return propertyDataExprs;
}
inline bool hasPropertyDataExpr(const std::string& propertyName) const {
bool hasPropertyDataExpr(const std::string& propertyName) const {
return propertyDataExprs.contains(propertyName);
}
inline std::shared_ptr<Expression> getPropertyDataExpr(const std::string& propertyName) const {
std::shared_ptr<Expression> getPropertyDataExpr(const std::string& propertyName) const {
KU_ASSERT(propertyDataExprs.contains(propertyName));
return propertyDataExprs.at(propertyName);
}

inline std::string toStringInternal() const final { return variableName; }
std::string toStringInternal() const final { return variableName; }

protected:
std::string variableName;
Expand Down
8 changes: 5 additions & 3 deletions src/include/binder/expression/variable_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ class VariableExpression : public Expression {
: Expression{common::ExpressionType::VARIABLE, std::move(dataType), std::move(uniqueName)},
variableName{std::move(variableName)} {}

inline std::string getVariableName() const { return variableName; }
std::string getVariableName() const { return variableName; }

inline std::string toStringInternal() const final { return variableName; }
void cast(const common::LogicalType& type) override;

inline std::unique_ptr<Expression> copy() const final {
std::string toStringInternal() const final { return variableName; }

std::unique_ptr<Expression> copy() const final {
return std::make_unique<VariableExpression>(*dataType.copy(), uniqueName, variableName);
}

Expand Down
Loading
Loading