Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue-3004 #3036

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/binder/bind/bind_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ std::unique_ptr<BoundStatement> Binder::bindCopyToClause(const Statement& statem
}
auto csvConfig =
CSVReaderConfig::construct(bindParsingOptions(copyToStatement.getParsingOptionsRef()));
return std::make_unique<BoundCopyTo>(boundFilePath, fileType, std::move(columnNames),
std::move(columnTypes), std::move(query), csvConfig.option.copy());
return std::make_unique<BoundCopyTo>(
boundFilePath, fileType, std::move(query), csvConfig.option.copy());
}

std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& statement) {
Expand Down
10 changes: 0 additions & 10 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ expression_vector Binder::bindProjectionExpressions(
result.push_back(expressionBinder.bindExpression(*expression));
}
}
resolveAnyDataTypeWithDefaultType(result);
validateProjectionColumnNamesAreUnique(result);
return result;
}
Expand All @@ -224,7 +223,6 @@ expression_vector Binder::bindOrderByExpressions(
}
boundOrderByExpressions.push_back(std::move(boundExpression));
}
resolveAnyDataTypeWithDefaultType(boundOrderByExpressions);
return boundOrderByExpressions;
}

Expand Down Expand Up @@ -264,13 +262,5 @@ void Binder::addExpressionsToScope(const expression_vector& projectionExpression
}
}

// Gives every expression whose data type is still ANY a default type by
// requesting an implicit cast to STRING. Expressions with any other type
// are skipped.
void Binder::resolveAnyDataTypeWithDefaultType(const expression_vector& expressions) {
    for (auto& candidate : expressions) {
        if (candidate->dataType.getLogicalTypeID() != LogicalTypeID::ANY) {
            continue;
        }
        ExpressionBinder::implicitCastIfNecessary(candidate, LogicalTypeID::STRING);
    }
}

} // namespace binder
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/binder/bound_statement_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "binder/rewriter/match_clause_pattern_label_rewriter.h"
#include "binder/rewriter/with_clause_projection_rewriter.h"
#include "binder/visitor/default_type_solver.h"

namespace kuzu {
namespace binder {
Expand All @@ -13,6 +14,9 @@ void BoundStatementRewriter::rewrite(

auto matchClausePatternLabelRewriter = MatchClausePatternLabelRewriter(catalog);
matchClausePatternLabelRewriter.visit(boundStatement);

auto defaultTypeSolver = DefaultTypeSolver();
defaultTypeSolver.visit(boundStatement);
}

} // namespace binder
Expand Down
16 changes: 16 additions & 0 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "binder/bound_statement_visitor.h"

#include "binder/bound_explain.h"
#include "binder/copy/bound_copy_from.h"
#include "binder/copy/bound_copy_to.h"
#include "binder/query/bound_regular_query.h"
#include "common/cast.h"

Expand Down Expand Up @@ -68,6 +70,20 @@ void BoundStatementVisitor::visitUnsafe(BoundStatement& statement) {
}
}

// Visits the statement nested inside a COPY FROM when its scan source is a
// query. Scan sources of any other ScanSourceType are not visited.
void BoundStatementVisitor::visitCopyFrom(const BoundStatement& statement) {
    auto& boundCopyFrom = ku_dynamic_cast<const BoundStatement&, const BoundCopyFrom&>(statement);
    auto* scanSource = boundCopyFrom.getInfo()->source.get();
    if (scanSource->type != ScanSourceType::QUERY) {
        return;
    }
    auto queryScanSource =
        ku_dynamic_cast<BoundBaseScanSource*, BoundQueryScanSource*>(scanSource);
    visit(*queryScanSource->statement);
}

// Visits the query wrapped by a COPY TO statement by forwarding it to
// visitRegularQuery.
void BoundStatementVisitor::visitCopyTo(const BoundStatement& statement) {
    const auto& boundCopyTo =
        ku_dynamic_cast<const BoundStatement&, const BoundCopyTo&>(statement);
    const auto* wrappedQuery = boundCopyTo.getRegularQuery();
    visitRegularQuery(*wrappedQuery);
}

void BoundStatementVisitor::visitRegularQuery(const BoundStatement& statement) {
auto& regularQuery =
ku_dynamic_cast<const BoundStatement&, const BoundRegularQuery&>(statement);
Expand Down
4 changes: 3 additions & 1 deletion src/binder/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ add_library(
case_expression.cpp
expression.cpp
expression_util.cpp
function_expression.cpp)
function_expression.cpp
literal_expression.cpp
parameter_expression.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_expression>
Expand Down
9 changes: 9 additions & 0 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
#include "binder/expression/expression.h"

#include "common/exception/binder.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

// Base implementation of cast(): it ignores the requested type (the parameter
// is unnamed) and always throws a BinderException whose message names this
// expression via toString(). The throw is excluded from coverage reporting by
// the LCOV markers.
void Expression::cast(const LogicalType&) {

Check warning on line 10 in src/binder/expression/expression.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression.cpp#L10

Added line #L10 was not covered by tests
// LCOV_EXCL_START
throw BinderException(
stringFormat("Data type of expression {} should not be modified.", toString()));
// LCOV_EXCL_STOP
}

expression_vector Expression::splitOnAND() {
expression_vector result;
if (ExpressionType::AND == expressionType) {
Expand Down
23 changes: 23 additions & 0 deletions src/binder/expression/literal_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "binder/expression/literal_expression.h"

#include "common/exception/binder.h"

namespace kuzu {
using namespace common;

namespace binder {

// Assigns a concrete type to a literal whose data type is still ANY: both the
// expression's dataType and the held value are set to `type`.
// Throws BinderException when the literal already has a non-ANY type.
void LiteralExpression::cast(const LogicalType& type) {
    if (dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
        dataType = type;
        value->setDataType(type);
        return;
    }
    // LCOV_EXCL_START
    throw BinderException(
        stringFormat("Cannot change literal expression data type from {} to {}.",
            dataType.toString(), type.toString()));
    // LCOV_EXCL_STOP
}

} // namespace binder
} // namespace kuzu
23 changes: 23 additions & 0 deletions src/binder/expression/parameter_expression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "binder/expression/parameter_expression.h"

#include "common/exception/binder.h"

namespace kuzu {
using namespace common;

namespace binder {

// Assigns a concrete type to a parameter whose data type is still ANY: both
// the expression's dataType and the held value are set to `type`.
// Throws BinderException when the parameter already has a non-ANY type.
void ParameterExpression::cast(const LogicalType& type) {
    if (dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
        dataType = type;
        value->setDataType(type);
        return;
    }
    // LCOV_EXCL_START
    throw BinderException(
        stringFormat("Cannot change parameter expression data type from {} to {}.",
            dataType.toString(), type.toString()));
    // LCOV_EXCL_STOP
}

} // namespace binder
} // namespace kuzu
13 changes: 1 addition & 12 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

#include "binder/binder.h"
#include "binder/expression/function_expression.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression/parameter_expression.h"
#include "binder/expression_visitor.h"
#include "common/exception/binder.h"
#include "common/exception/not_implemented.h"
Expand Down Expand Up @@ -104,7 +102,7 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
return expression;
}
if (expression->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
resolveAnyDataType(*expression, targetType);
expression->cast(targetType);
return expression;
}
return implicitCast(expression, targetType);
Expand All @@ -127,15 +125,6 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCast(
}
}

// Sets the data type of an expression to `targetType`.
// Only two kinds of expression are handled: a PARAMETER expression, and
// otherwise a LITERAL expression (asserted; the original comment describes it
// as a null literal). Named casts replace the previous C-style reference
// casts so the downcast is checked against the class hierarchy at compile
// time instead of silently degrading to a reinterpret-style cast.
void ExpressionBinder::resolveAnyDataType(Expression& expression, const LogicalType& targetType) {
    if (expression.expressionType == ExpressionType::PARAMETER) { // expression is parameter
        static_cast<ParameterExpression&>(expression).setDataType(targetType);
    } else { // expression is null literal
        KU_ASSERT(expression.expressionType == ExpressionType::LITERAL);
        static_cast<LiteralExpression&>(expression).setDataType(targetType);
    }
}

void ExpressionBinder::validateExpectedDataType(
const Expression& expression, const std::vector<LogicalTypeID>& targets) {
auto dataType = expression.dataType;
Expand Down
1 change: 1 addition & 0 deletions src/binder/visitor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_library(
kuzu_binder_visitor
OBJECT
default_type_solver.cpp
property_collector.cpp)

set(ALL_OBJECT_FILES
Expand Down
25 changes: 25 additions & 0 deletions src/binder/visitor/default_type_solver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "binder/visitor/default_type_solver.h"

using namespace kuzu::common;

namespace kuzu {
namespace binder {

// Casts an expression whose data type is still ANY to STRING; expressions
// with any other type are left as they are.
static void resolveAnyType(Expression& expr) {
    const bool stillUnresolved = expr.getDataType().getLogicalTypeID() == LogicalTypeID::ANY;
    if (stillUnresolved) {
        expr.cast(*LogicalType::STRING());
    }
}

// Applies resolveAnyType to every projection expression of the body, then to
// every ORDER BY expression.
void DefaultTypeSolver::visitProjectionBody(const BoundProjectionBody& projectionBody) {
    auto resolveEach = [](const auto& expressions) {
        for (auto& expression : expressions) {
            resolveAnyType(*expression);
        }
    };
    resolveEach(projectionBody.getProjectionExpressions());
    resolveEach(projectionBody.getOrderByExpressions());
}

} // namespace binder
} // namespace kuzu
7 changes: 7 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ namespace common {

ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryManager)
: dataType{std::move(dataType)} {
if (this->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
// LCOV_EXCL_START
// Alternatively we can assign
throw RuntimeException("Trying to a create a vector with ANY type. This should not happen. "
"Data type is expected to be resolved during binding.");
// LCOV_EXCL_STOP
}
numBytesPerValue = getDataTypeSize(this->dataType);
initializeValueBuffer();
nullMask = std::make_unique<NullMask>();
Expand Down
3 changes: 1 addition & 2 deletions src/function/vector_list_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "function/list/vector_list_functions.h"

#include "binder/expression_binder.h"
#include "common/exception/binder.h"
#include "common/exception/runtime.h"
#include "function/list/functions/list_any_value_function.h"
Expand Down Expand Up @@ -76,7 +75,7 @@ std::unique_ptr<FunctionBindData> ListCreationFunction::bindFunc(
auto& parameterType = argument->getDataTypeReference();
if (parameterType != childType) {
if (parameterType.getLogicalTypeID() == LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(*argument, childType);
argument->cast(childType);
} else {
throw BinderException(getListFunctionIncompatibleChildrenTypeErrorMsg(
LIST_CREATION_FUNC_NAME, arguments[0]->getDataType(), argument->getDataType()));
Expand Down
3 changes: 1 addition & 2 deletions src/function/vector_struct_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ std::unique_ptr<FunctionBindData> StructPackFunctions::bindFunc(
std::vector<StructField> fields;
for (auto& argument : arguments) {
if (argument->getDataType().getLogicalTypeID() == LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*argument, LogicalType{LogicalTypeID::STRING});
argument->cast(*LogicalType::STRING());
}
fields.emplace_back(argument->getAlias(), argument->getDataType().copy());
}
Expand Down
4 changes: 1 addition & 3 deletions src/function/vector_union_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "function/union/vector_union_functions.h"

#include "binder/expression_binder.h"
#include "function/struct/vector_struct_functions.h"
#include "function/union/functions/union_tag.h"

Expand All @@ -25,8 +24,7 @@ std::unique_ptr<FunctionBindData> UnionValueFunction::bindFunc(
fields.emplace_back(
UnionType::TAG_FIELD_NAME, std::make_unique<LogicalType>(UnionType::TAG_FIELD_TYPE));
if (arguments[0]->getDataType().getLogicalTypeID() == common::LogicalTypeID::ANY) {
binder::ExpressionBinder::resolveAnyDataType(
*arguments[0], LogicalType(LogicalTypeID::STRING));
arguments[0]->cast(*LogicalType::STRING());
}
fields.emplace_back(arguments[0]->getAlias(), arguments[0]->getDataType().copy());
auto resultType = LogicalType::UNION(std::move(fields));
Expand Down
1 change: 0 additions & 1 deletion src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ class Binder {
uint64_t bindSkipLimitExpression(const parser::ParsedExpression& expression);

void addExpressionsToScope(const expression_vector& projectionExpressions);
void resolveAnyDataTypeWithDefaultType(const expression_vector& expressions);

/*** bind graph pattern ***/
BoundGraphPattern bindGraphPattern(const std::vector<parser::PatternElement>& graphPattern);
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class BoundStatementVisitor {
virtual void visitCreateTable(const BoundStatement&) {}
virtual void visitDropTable(const BoundStatement&) {}
virtual void visitAlter(const BoundStatement&) {}
virtual void visitCopyFrom(const BoundStatement&) {}
virtual void visitCopyTo(const BoundStatement&) {}
virtual void visitCopyFrom(const BoundStatement&);
virtual void visitCopyTo(const BoundStatement&);
virtual void visitExportDatabase(const BoundStatement&) {}
virtual void visitImportDatabase(const BoundStatement&) {}
virtual void visitStandaloneCall(const BoundStatement&) {}
Expand Down
18 changes: 6 additions & 12 deletions src/include/binder/copy/bound_copy_to.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include "binder/query/bound_regular_query.h"
#include "binder/bound_statement.h"
#include "common/copier_config/csv_reader_config.h"
#include "common/copier_config/reader_config.h"

Expand All @@ -10,27 +10,21 @@ namespace binder {
class BoundCopyTo : public BoundStatement {
public:
BoundCopyTo(std::string filePath, common::FileType fileType,
std::vector<std::string> columnNames, std::vector<common::LogicalType> columnTypes,
std::unique_ptr<BoundRegularQuery> regularQuery, common::CSVOption csvOption)
std::unique_ptr<BoundStatement> query, common::CSVOption csvOption)
: BoundStatement{common::StatementType::COPY_TO, BoundStatementResult::createEmptyResult()},
filePath{std::move(filePath)}, fileType{fileType}, columnNames{std::move(columnNames)},
columnTypes{std::move(columnTypes)},
regularQuery{std::move(regularQuery)}, csvOption{std::move(csvOption)} {}
filePath{std::move(filePath)}, fileType{fileType}, query{std::move(query)},
csvOption{std::move(csvOption)} {}

inline std::string getFilePath() const { return filePath; }
inline common::FileType getFileType() const { return fileType; }
inline std::vector<std::string> getColumnNames() const { return columnNames; }
inline const std::vector<common::LogicalType>& getColumnTypesRef() const { return columnTypes; }

inline const BoundRegularQuery* getRegularQuery() const { return regularQuery.get(); }
inline const BoundStatement* getRegularQuery() const { return query.get(); }
inline const common::CSVOption* getCopyOption() const { return &csvOption; }

private:
std::string filePath;
common::FileType fileType;
std::vector<std::string> columnNames;
std::vector<common::LogicalType> columnTypes;
std::unique_ptr<BoundRegularQuery> regularQuery;
std::unique_ptr<BoundStatement> query;
common::CSVOption csvOption;
};

Expand Down
11 changes: 6 additions & 5 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,17 @@ class Expression : public std::enable_shared_from_this<Expression> {
DELETE_COPY_DEFAULT_MOVE(Expression);
virtual ~Expression() = default;

inline void setAlias(const std::string& name) { alias = name; }
void setAlias(const std::string& name) { alias = name; }

inline void setUniqueName(const std::string& name) { uniqueName = name; }
inline std::string getUniqueName() const {
void setUniqueName(const std::string& name) { uniqueName = name; }
std::string getUniqueName() const {
KU_ASSERT(!uniqueName.empty());
return uniqueName;
}

inline common::LogicalType getDataType() const { return dataType; }
inline common::LogicalType& getDataTypeReference() { return dataType; }
virtual void cast(const common::LogicalType& type);
common::LogicalType getDataType() const { return dataType; }
common::LogicalType& getDataTypeReference() { return dataType; }

inline bool hasAlias() const { return !alias.empty(); }
inline std::string getAlias() const { return alias; }
Expand Down
Loading
Loading