Skip to content

Commit

Permalink
Fix issue-3004
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 13, 2024
1 parent 7a3ff60 commit f830e7e
Show file tree
Hide file tree
Showing 27 changed files with 185 additions and 71 deletions.
4 changes: 2 additions & 2 deletions src/binder/bind/bind_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ std::unique_ptr<BoundStatement> Binder::bindCopyToClause(const Statement& statem
}
auto csvConfig =
CSVReaderConfig::construct(bindParsingOptions(copyToStatement.getParsingOptionsRef()));
return std::make_unique<BoundCopyTo>(boundFilePath, fileType, std::move(columnNames),
std::move(columnTypes), std::move(query), csvConfig.option.copy());
return std::make_unique<BoundCopyTo>(
boundFilePath, fileType, std::move(query), csvConfig.option.copy());
}

std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& statement) {
Expand Down
10 changes: 0 additions & 10 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ expression_vector Binder::bindProjectionExpressions(
result.push_back(expressionBinder.bindExpression(*expression));
}
}
resolveAnyDataTypeWithDefaultType(result);
validateProjectionColumnNamesAreUnique(result);
return result;
}
Expand All @@ -224,7 +223,6 @@ expression_vector Binder::bindOrderByExpressions(
}
boundOrderByExpressions.push_back(std::move(boundExpression));
}
resolveAnyDataTypeWithDefaultType(boundOrderByExpressions);
return boundOrderByExpressions;
}

Expand Down Expand Up @@ -264,13 +262,5 @@ void Binder::addExpressionsToScope(const expression_vector& projectionExpression
}
}

void Binder::resolveAnyDataTypeWithDefaultType(const expression_vector& expressions) {
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
ExpressionBinder::implicitCastIfNecessary(expression, LogicalTypeID::STRING);
}
}
}

} // namespace binder
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/binder/bound_statement_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "binder/rewriter/match_clause_pattern_label_rewriter.h"
#include "binder/rewriter/with_clause_projection_rewriter.h"
#include "binder/visitor/default_type_solver.h"

namespace kuzu {
namespace binder {
Expand All @@ -13,6 +14,9 @@ void BoundStatementRewriter::rewrite(

auto matchClausePatternLabelRewriter = MatchClausePatternLabelRewriter(catalog);
matchClausePatternLabelRewriter.visit(boundStatement);

auto defaultTypeSolver = DefaultTypeSolver();
defaultTypeSolver.visit(boundStatement);
}

} // namespace binder
Expand Down
16 changes: 16 additions & 0 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "binder/bound_statement_visitor.h"

#include "binder/bound_explain.h"
#include "binder/copy/bound_copy_from.h"
#include "binder/copy/bound_copy_to.h"
#include "binder/query/bound_regular_query.h"
#include "common/cast.h"

Expand Down Expand Up @@ -68,6 +70,20 @@ void BoundStatementVisitor::visitUnsafe(BoundStatement& statement) {
}
}

// Visits the statement nested inside a COPY FROM whose scan source is of type QUERY.
// For every other scan source type this is a no-op.
void BoundStatementVisitor::visitCopyFrom(const BoundStatement& statement) {
    auto& boundCopyFrom = ku_dynamic_cast<const BoundStatement&, const BoundCopyFrom&>(statement);
    const bool readsFromQuery = boundCopyFrom.getInfo()->source->type == ScanSourceType::QUERY;
    if (!readsFromQuery) {
        return;
    }
    // The source type was checked above, so the downcast to the query scan source applies.
    auto queryScanSource = ku_dynamic_cast<BoundBaseScanSource*, BoundQueryScanSource*>(
        boundCopyFrom.getInfo()->source.get());
    visit(*queryScanSource->statement);
}

// Downcasts the statement to BoundCopyTo and forwards the query it wraps to visitRegularQuery.
void BoundStatementVisitor::visitCopyTo(const BoundStatement& statement) {
    const auto* wrappedQuery =
        ku_dynamic_cast<const BoundStatement&, const BoundCopyTo&>(statement).getRegularQuery();
    visitRegularQuery(*wrappedQuery);
}

void BoundStatementVisitor::visitRegularQuery(const BoundStatement& statement) {
auto& regularQuery =
ku_dynamic_cast<const BoundStatement&, const BoundRegularQuery&>(statement);
Expand Down
4 changes: 3 additions & 1 deletion src/binder/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ add_library(
case_expression.cpp
expression.cpp
expression_util.cpp
function_expression.cpp)
function_expression.cpp
literal_expression.cpp
parameter_expression.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_expression>
Expand Down
9 changes: 9 additions & 0 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
#include "binder/expression/expression.h"

#include "common/exception/binder.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

// Base implementation of cast(): ignores the requested type and always throws a
// BinderException naming this expression (via toString()). The throw is wrapped in
// LCOV_EXCL markers.
void Expression::cast(const LogicalType&) {

Check warning on line 10 in src/binder/expression/expression.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression.cpp#L10

Added line #L10 was not covered by tests
// LCOV_EXCL_START
throw BinderException(
stringFormat("Data type of expression {} should not be modified.", toString()));
// LCOV_EXCL_STOP
}

expression_vector Expression::splitOnAND() {
expression_vector result;
if (ExpressionType::AND == expressionType) {
Expand Down
23 changes: 23 additions & 0 deletions src/binder/expression/literal_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "binder/expression/literal_expression.h"

#include "common/exception/binder.h"

namespace kuzu {
using namespace common;

namespace binder {

// Gives a literal whose data type is still ANY the requested type, updating both the
// expression's dataType and the wrapped value. A literal that already has a non-ANY type is
// rejected with a BinderException.
void LiteralExpression::cast(const LogicalType& type) {
    if (dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
        dataType = type;
        value->setDataType(type);
        return;
    }
    // LCOV_EXCL_START
    throw BinderException(stringFormat("Cannot change literal expression data type from {} to {}.",
        dataType.toString(), type.toString()));
    // LCOV_EXCL_STOP
}

} // namespace binder
} // namespace kuzu
23 changes: 23 additions & 0 deletions src/binder/expression/parameter_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "binder/expression/parameter_expression.h"

#include "common/exception/binder.h"

namespace kuzu {
using namespace common;

namespace binder {

// Resolves a parameter whose data type is still ANY to the requested type, keeping the
// expression's dataType and the wrapped value in sync. Throws a BinderException if the
// parameter already has a non-ANY type.
void ParameterExpression::cast(const LogicalType& type) {
    const bool alreadyTyped = dataType.getLogicalTypeID() != LogicalTypeID::ANY;
    if (alreadyTyped) {
        // LCOV_EXCL_START
        const auto currentTypeName = dataType.toString();
        const auto requestedTypeName = type.toString();
        throw BinderException(
            stringFormat("Cannot change parameter expression data type from {} to {}.",
                currentTypeName, requestedTypeName));
        // LCOV_EXCL_STOP
    }
    dataType = type;
    value->setDataType(type);
}

} // namespace binder
} // namespace kuzu
11 changes: 1 addition & 10 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
return expression;
}
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
resolveAnyDataType(*expression, targetType);
expression->cast(targetType);
return expression;
}
return implicitCast(expression, targetType);
Expand All @@ -127,15 +127,6 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCast(
}
}

// Sets the data type of an expression to targetType. Two expression kinds are handled:
// PARAMETER expressions, and LITERAL expressions (asserted in the fall-through path).
void ExpressionBinder::resolveAnyDataType(Expression& expression, const LogicalType& targetType) {
    if (expression.expressionType == ExpressionType::PARAMETER) {
        // expression is parameter
        static_cast<ParameterExpression&>(expression).setDataType(targetType);
        return;
    }
    // expression is null literal
    KU_ASSERT(expression.expressionType == ExpressionType::LITERAL);
    static_cast<LiteralExpression&>(expression).setDataType(targetType);
}

void ExpressionBinder::validateExpectedDataType(
const Expression& expression, const std::vector<LogicalTypeID>& targets) {
auto dataType = expression.dataType;
Expand Down
1 change: 1 addition & 0 deletions src/binder/visitor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(
kuzu_binder_visitor
OBJECT
default_type_solver.cpp
property_collector.cpp)

set(ALL_OBJECT_FILES
Expand Down
25 changes: 25 additions & 0 deletions src/binder/visitor/default_type_solver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "binder/visitor/default_type_solver.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

// Casts an expression whose data type is still ANY to STRING; expressions with any other
// type are left as they are.
static void resolveAnyType(Expression& expr) {
    const auto typeID = expr.getDataType().getLogicalTypeID();
    if (typeID == LogicalTypeID::ANY) {
        expr.cast(*LogicalType::STRING());
    }
}

// Applies resolveAnyType to every projection expression of the body, then to every
// order-by expression, in that order.
void DefaultTypeSolver::visitProjectionBody(const BoundProjectionBody& projectionBody) {
    const auto resolveEach = [](const auto& exprs) {
        for (auto& item : exprs) {
            resolveAnyType(*item);
        }
    };
    resolveEach(projectionBody.getProjectionExpressions());
    resolveEach(projectionBody.getOrderByExpressions());
}

} // namespace binder
} // namespace kuzu
7 changes: 7 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ namespace common {

ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager)
: dataType{std::move(dataType)} {
if (this->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
// LCOV_EXCL_START
// Alternatively we can assign
throw RuntimeException("Trying to a create a vector with ANY type. This should not happen. "
"Data type is expected to be resolved during binding.");
// LCOV_EXCL_STOP
}
numBytesPerValue = getDataTypeSize(this->dataType);
initializeValueBuffer();
nullMask = std::make_unique<NullMask>();
Expand Down
2 changes: 1 addition & 1 deletion src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
auto& parameterType = argument->getDataTypeReference();
if (parameterType != childType) {
if (parameterType.getLogicalTypeID() == LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(*argument, childType);
argument->cast(childType);
} else {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_CREATION_FUNC_NAME, arguments[0]->getDataType(), argument->getDataType()));
Expand Down
3 changes: 1 addition & 2 deletions src/function/vector_struct_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ std::unique_ptr<FunctionBindData> StructPackFunctions::bindFunc(
std::vector<StructField> fields;
for (auto& argument : arguments) {
if (argument->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*argument, LogicalType{LogicalTypeID::STRING});
argument->cast(*LogicalType::STRING());
}
fields.emplace_back(argument->getAlias(), argument->getDataType().copy());
}
Expand Down
3 changes: 1 addition & 2 deletions src/function/vector_union_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ std::unique_ptr<FunctionBindData> UnionValueFunction::bindFunc(
fields.emplace_back(
UnionType::TAG_FIELD_NAME, std::make_unique<LogicalType>(UnionType::TAG_FIELD_TYPE));
if (arguments[0]->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*arguments[0], LogicalType(LogicalTypeID::STRING));
arguments[0]->cast(*LogicalType::STRING());
}
fields.emplace_back(arguments[0]->getAlias(), arguments[0]->getDataType().copy());
auto resultType = LogicalType::UNION(std::move(fields));
Expand Down
1 change: 0 additions & 1 deletion src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ class Binder {
uint64_t bindSkipLimitExpression(const parser::ParsedExpression& expression);

void addExpressionsToScope(const expression_vector& projectionExpressions);
void resolveAnyDataTypeWithDefaultType(const expression_vector& expressions);

/*** bind graph pattern ***/
BoundGraphPattern bindGraphPattern(const std::vector<parser::PatternElement>& graphPattern);
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class BoundStatementVisitor {
virtual void visitCreateTable(const BoundStatement&) {}
virtual void visitDropTable(const BoundStatement&) {}
virtual void visitAlter(const BoundStatement&) {}
virtual void visitCopyFrom(const BoundStatement&) {}
virtual void visitCopyTo(const BoundStatement&) {}
virtual void visitCopyFrom(const BoundStatement&);
virtual void visitCopyTo(const BoundStatement&);
virtual void visitExportDatabase(const BoundStatement&) {}
virtual void visitImportDatabase(const BoundStatement&) {}
virtual void visitStandaloneCall(const BoundStatement&) {}
Expand Down
18 changes: 6 additions & 12 deletions src/include/binder/copy/bound_copy_to.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include "binder/query/bound_regular_query.h"
#include "binder/bound_statement.h"
#include "common/copier_config/csv_reader_config.h"
#include "common/copier_config/reader_config.h"

Expand All @@ -10,27 +10,21 @@ namespace binder {
class BoundCopyTo : public BoundStatement {
public:
BoundCopyTo(std::string filePath, common::FileType fileType,
std::vector<std::string> columnNames, std::vector<common::LogicalType> columnTypes,
std::unique_ptr<BoundRegularQuery> regularQuery, common::CSVOption csvOption)
std::unique_ptr<BoundStatement> query, common::CSVOption csvOption)
: BoundStatement{common::StatementType::COPY_TO, BoundStatementResult::createEmptyResult()},
filePath{std::move(filePath)}, fileType{fileType}, columnNames{std::move(columnNames)},
columnTypes{std::move(columnTypes)},
regularQuery{std::move(regularQuery)}, csvOption{std::move(csvOption)} {}
filePath{std::move(filePath)}, fileType{fileType}, query{std::move(query)},
csvOption{std::move(csvOption)} {}

inline std::string getFilePath() const { return filePath; }
inline common::FileType getFileType() const { return fileType; }
inline std::vector<std::string> getColumnNames() const { return columnNames; }
inline const std::vector<common::LogicalType>& getColumnTypesRef() const { return columnTypes; }

inline const BoundRegularQuery* getRegularQuery() const { return regularQuery.get(); }
inline const BoundStatement* getRegularQuery() const { return query.get(); }
inline const common::CSVOption* getCopyOption() const { return &csvOption; }

private:
std::string filePath;
common::FileType fileType;
std::vector<std::string> columnNames;
std::vector<common::LogicalType> columnTypes;
std::unique_ptr<BoundRegularQuery> regularQuery;
std::unique_ptr<BoundStatement> query;
common::CSVOption csvOption;
};

Expand Down
11 changes: 6 additions & 5 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,17 @@ class Expression : public std::enable_shared_from_this<Expression> {
DELETE_COPY_DEFAULT_MOVE(Expression);
virtual ~Expression() = default;

inline void setAlias(const std::string& name) { alias = name; }
void setAlias(const std::string& name) { alias = name; }

inline void setUniqueName(const std::string& name) { uniqueName = name; }
inline std::string getUniqueName() const {
void setUniqueName(const std::string& name) { uniqueName = name; }
std::string getUniqueName() const {
KU_ASSERT(!uniqueName.empty());
return uniqueName;
}

inline common::LogicalType getDataType() const { return dataType; }
inline common::LogicalType& getDataTypeReference() { return dataType; }
virtual void cast(const common::LogicalType& type);
common::LogicalType getDataType() const { return dataType; }
common::LogicalType& getDataTypeReference() { return dataType; }

inline bool hasAlias() const { return !alias.empty(); }
inline std::string getAlias() const { return alias; }
Expand Down
14 changes: 5 additions & 9 deletions src/include/binder/expression/literal_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,21 @@
namespace kuzu {
namespace binder {

class LiteralExpression : public Expression {
class LiteralExpression final : public Expression {
public:
LiteralExpression(std::unique_ptr<common::Value> value, const std::string& uniqueName)
: Expression{common::ExpressionType::LITERAL, *value->getDataType(), uniqueName},
value{std::move(value)} {}

inline bool isNull() const { return value->isNull(); }
bool isNull() const { return value->isNull(); }

inline void setDataType(const common::LogicalType& targetType) {
KU_ASSERT(dataType.getLogicalTypeID() == common::LogicalTypeID::ANY && isNull());
dataType = targetType;
value->setDataType(targetType);
}
void cast(const common::LogicalType& type) override;

inline common::Value* getValue() const { return value.get(); }
common::Value* getValue() const { return value.get(); }

std::string toStringInternal() const final { return value->toString(); }

inline std::unique_ptr<Expression> copy() const final {
std::unique_ptr<Expression> copy() const final {

Check warning on line 23 in src/include/binder/expression/literal_expression.h

View check run for this annotation

Codecov / codecov/patch

src/include/binder/expression/literal_expression.h#L23

Added line #L23 was not covered by tests
return std::make_unique<LiteralExpression>(value->copy(), uniqueName);
}

Expand Down
Loading

0 comments on commit f830e7e

Please sign in to comment.