Skip to content

Commit

Permalink
Remove unused code (#3335)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Apr 22, 2024
1 parent 8231ed3 commit 9594367
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 148 deletions.
2 changes: 1 addition & 1 deletion src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ std::unique_ptr<BoundStatement> Binder::bindAddProperty(const Statement& stateme
if (dataType->getLogicalTypeID() == LogicalTypeID::SERIAL) {
throw BinderException("Serial property in node table must be the primary key.");
}
auto defaultVal = ExpressionBinder::implicitCastIfNecessary(
auto defaultVal = expressionBinder.implicitCastIfNecessary(
expressionBinder.bindExpression(*extraInfo->defaultValue), *dataType);
auto boundExtraInfo =
std::make_unique<BoundExtraAddPropertyInfo>(propertyName, *dataType, std::move(defaultVal));
Expand Down
6 changes: 3 additions & 3 deletions src/binder/bind/bind_graph_pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ std::shared_ptr<RelExpression> Binder::bindQueryRel(const RelPattern& relPattern
auto boundLhs =
expressionBinder.bindNodeOrRelPropertyExpression(*queryRel, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
boundRhs = expressionBinder.implicitCastIfNecessary(boundRhs, boundLhs->dataType);
queryRel->addPropertyDataExpr(propertyName, std::move(boundRhs));
}
}
Expand Down Expand Up @@ -390,7 +390,7 @@ std::shared_ptr<RelExpression> Binder::createRecursiveQueryRel(const parser::Rel
for (auto& [propertyName, rhs] : relPattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindNodeOrRelPropertyExpression(*rel, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
boundRhs = expressionBinder.implicitCastIfNecessary(boundRhs, boundLhs->dataType);
auto predicate = expressionBinder.createEqualityComparisonExpression(boundLhs, boundRhs);
relPredicate = expressionBinder.combineBooleanExpressions(ExpressionType::AND, relPredicate,
predicate);
Expand Down Expand Up @@ -534,7 +534,7 @@ std::shared_ptr<NodeExpression> Binder::bindQueryNode(const NodePattern& nodePat
for (auto& [propertyName, rhs] : nodePattern.getPropertyKeyVals()) {
auto boundLhs = expressionBinder.bindNodeOrRelPropertyExpression(*queryNode, propertyName);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
boundRhs = expressionBinder.implicitCastIfNecessary(boundRhs, boundLhs->dataType);
queryNode->addPropertyDataExpr(propertyName, std::move(boundRhs));
}
queryGraph.addQueryNode(queryNode);
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ expression_vector Binder::bindInsertColumnDataExprs(
} else {
rhs = expressionBinder.createNullLiteralExpression();
}
rhs = ExpressionBinder::implicitCastIfNecessary(rhs, *property.getDataType());
rhs = expressionBinder.implicitCastIfNecessary(rhs, *property.getDataType());
result.push_back(std::move(rhs));
}
return result;
Expand Down Expand Up @@ -309,7 +309,7 @@ BoundSetPropertyInfo Binder::bindSetPropertyInfo(parser::ParsedExpression* lhs,
expression_pair Binder::bindSetItem(parser::ParsedExpression* lhs, parser::ParsedExpression* rhs) {
auto boundLhs = expressionBinder.bindExpression(*lhs);
auto boundRhs = expressionBinder.bindExpression(*rhs);
boundRhs = ExpressionBinder::implicitCastIfNecessary(boundRhs, boundLhs->dataType);
boundRhs = expressionBinder.implicitCastIfNecessary(boundRhs, boundLhs->dataType);
return make_pair(std::move(boundLhs), std::move(boundRhs));
}

Expand Down
2 changes: 1 addition & 1 deletion src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ std::unique_ptr<BoundStatement> Binder::bind(const Statement& statement) {

std::shared_ptr<Expression> Binder::bindWhereExpression(const ParsedExpression& parsedExpression) {
auto whereExpression = expressionBinder.bindExpression(parsedExpression);
ExpressionBinder::implicitCastIfNecessary(whereExpression, LogicalTypeID::BOOL);
expressionBinder.implicitCastIfNecessary(whereExpression, LogicalTypeID::BOOL);
return whereExpression;
}

Expand Down
41 changes: 4 additions & 37 deletions src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "binder/expression_binder.h"

#include "binder/binder.h"
#include "binder/expression/function_expression.h"
#include "binder/expression_visitor.h"
#include "common/exception/binder.h"
#include "common/exception/not_implemented.h"
Expand All @@ -16,6 +15,8 @@ using namespace kuzu::function;
namespace kuzu {
namespace binder {

static void validateAggregationExpressionIsNotNested(const Expression& expression);

std::shared_ptr<Expression> ExpressionBinder::bindExpression(
const parser::ParsedExpression& parsedExpression) {
std::shared_ptr<Expression> expression;
Expand Down Expand Up @@ -112,15 +113,7 @@ std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
std::shared_ptr<Expression> ExpressionBinder::implicitCast(
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (CastFunction::hasImplicitCast(expression->dataType, targetType)) {
auto functionName = stringFormat("CAST_TO({})", targetType.toString());
auto children = expression_vector{expression};
auto bindData = std::make_unique<FunctionBindData>(targetType.copy());
auto scalarFunction = CastFunction::bindCastFunction(functionName,
expression->dataType.getLogicalTypeID(), targetType.getLogicalTypeID());
auto uniqueName = ScalarFunctionExpression::getUniqueName(functionName, children);
return std::make_shared<ScalarFunctionExpression>(functionName, ExpressionType::FUNCTION,
std::move(bindData), std::move(children), scalarFunction->execFunc,
nullptr /* selectFunc */, std::move(uniqueName));
return forceCast(expression, targetType);
} else {
throw BinderException(unsupportedImplicitCastException(*expression, targetType.toString()));
}
Expand All @@ -135,33 +128,7 @@ std::shared_ptr<Expression> ExpressionBinder::forceCast(
return bindScalarFunctionExpression(children, functionName);
}

void ExpressionBinder::validateExpectedDataType(const Expression& expression,
const std::vector<LogicalTypeID>& targets) {
auto dataType = expression.dataType;
auto targetsSet = std::unordered_set<LogicalTypeID>{targets.begin(), targets.end()};
if (!targetsSet.contains(dataType.getLogicalTypeID())) {
throw BinderException(stringFormat("{} has data type {} but {} was expected.",
expression.toString(), LogicalTypeUtils::toString(dataType.getLogicalTypeID()),
LogicalTypeUtils::toString(targets)));
}
}

void ExpressionBinder::validateDataType(const Expression& expr, const LogicalType& expectedType) {
if (expr.getDataType() != expectedType) {
throw BinderException(stringFormat("{} has data type {} but {} was expected.",
expr.toString(), expr.getDataType().toString(), expectedType.toString()));
}
}

void ExpressionBinder::validateDataType(const Expression& expr, LogicalTypeID expectedTypeID) {
if (expr.getDataType().getLogicalTypeID() != expectedTypeID) {
throw BinderException(
stringFormat("{} has data type {} but {} was expected.", expr.toString(),
expr.getDataType().toString(), LogicalTypeUtils::toString(expectedTypeID)));
}
}

void ExpressionBinder::validateAggregationExpressionIsNotNested(const Expression& expression) {
void validateAggregationExpressionIsNotNested(const Expression& expression) {
if (expression.getNumChildren() == 0) {
return;
}
Expand Down
52 changes: 27 additions & 25 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,33 @@ void StructTypeInfo::serializeInternal(Serializer& serializer) const {
serializer.serializeVector(fields);
}

static std::string getIncompleteTypeErrMsg(LogicalTypeID id) {
return "Trying to create nested type " + LogicalTypeUtils::toString(id) +
" without child information.";
}

LogicalType::LogicalType(LogicalTypeID typeID) : typeID{typeID}, extraTypeInfo{nullptr} {
physicalType = getPhysicalType(typeID);
// Complex types should not use this constructor as they need extra type information
KU_ASSERT(physicalType != PhysicalTypeID::LIST);
KU_ASSERT(physicalType != PhysicalTypeID::ARRAY);
// Node/Rel types are exempted due to some complex code in bind_graph_pattern.cpp
KU_ASSERT(physicalType != PhysicalTypeID::STRUCT || typeID == LogicalTypeID::NODE ||
typeID == LogicalTypeID::REL || typeID == LogicalTypeID::RECURSIVE_REL);
// LCOV_EXCL_START
switch (physicalType) {
case PhysicalTypeID::LIST:
case PhysicalTypeID::ARRAY:
throw BinderException(getIncompleteTypeErrMsg(typeID));
case PhysicalTypeID::STRUCT: {
switch (typeID) {
// Node/Rel types are exempted due to some complex code in bind_graph_pattern.cpp
case LogicalTypeID::NODE:
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL:
return;
default:
throw BinderException(getIncompleteTypeErrMsg(typeID));
}
}
default:
return;
}
// LCOV_EXCL_STOP
}

LogicalType::LogicalType(LogicalTypeID typeID, std::unique_ptr<ExtraTypeInfo> extraTypeInfo)
Expand Down Expand Up @@ -416,6 +435,8 @@ LogicalType LogicalType::fromString(const std::string& str) {
dataType = *parseMapType(trimmedStr);
} else if (upperDataTypeString.starts_with("UNION")) {
dataType = *parseUnionType(trimmedStr);
} else if (upperDataTypeString == "RDF_VARIANT") {
dataType = *LogicalType::RDF_VARIANT();
} else {
dataType.typeID = strToLogicalTypeID(upperDataTypeString);
}
Expand Down Expand Up @@ -1394,25 +1415,6 @@ bool LogicalTypeUtils::tryGetMaxLogicalType(const LogicalType& left, const Logic
return true;
}

LogicalTypeID LogicalTypeUtils::getMaxLogicalTypeID(const LogicalTypeID& left,
const LogicalTypeID& right) {
LogicalTypeID result;
if (!tryGetMaxLogicalTypeID(left, right, result)) {
throw common::BinderException(stringFormat("Cannot combine logical types {} and {}",
toString(left), toString(right)));
}
return result;
}

LogicalType LogicalTypeUtils::getMaxLogicalType(const LogicalType& left, const LogicalType& right) {
LogicalType result;
if (!tryGetMaxLogicalType(left, right, result)) {
throw common::BinderException(stringFormat("Cannot combine logical types {} and {}",
left.toString(), right.toString()));
}
return result;
}

bool LogicalTypeUtils::tryGetMaxLogicalType(const std::vector<LogicalType>& types,
LogicalType& result) {
LogicalType combinedType(LogicalTypeID::ANY);
Expand Down
15 changes: 3 additions & 12 deletions src/include/binder/expression_binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,27 +104,18 @@ class ExpressionBinder {
const parser::ParsedExpression& parsedExpression);

/****** cast *****/
static std::shared_ptr<Expression> implicitCastIfNecessary(
std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, common::LogicalTypeID targetTypeID);
static std::shared_ptr<Expression> implicitCastIfNecessary(
std::shared_ptr<Expression> implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const common::LogicalType& targetType);
// Use implicitCast to cast to types you have obtained through known implicit casting rules.
// Use forceCast to cast to types you have obtained through other means, for example,
// through a maxLogicalType function
static std::shared_ptr<Expression> implicitCast(const std::shared_ptr<Expression>& expression,
std::shared_ptr<Expression> implicitCast(const std::shared_ptr<Expression>& expression,
const common::LogicalType& targetType);
std::shared_ptr<Expression> forceCast(const std::shared_ptr<Expression>& expression,
const common::LogicalType& targetType);

/****** validation *****/
// E.g. SUM(SUM(a.age)) is not allowed
static void validateAggregationExpressionIsNotNested(const Expression& expression);

void validateExpectedDataType(const Expression& expression,
const std::vector<common::LogicalTypeID>& targets);
void validateDataType(const Expression& expr, const common::LogicalType& expectedType);
void validateDataType(const Expression& expr, common::LogicalTypeID expectedTypeID);

private:
Binder* binder;
main::ClientContext* context;
Expand Down
12 changes: 0 additions & 12 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,19 +596,7 @@ struct LogicalTypeUtils {
LogicalTypeID& result);
static bool tryGetMaxLogicalType(const LogicalType& left, const LogicalType& right,
LogicalType& result);
static LogicalTypeID getMaxLogicalTypeID(const LogicalTypeID& left, const LogicalTypeID& right);
static LogicalType getMaxLogicalType(const LogicalType& left, const LogicalType& right);
static bool tryGetMaxLogicalType(const std::vector<LogicalType>& types, LogicalType& result);

private:
static LogicalTypeID dataTypeIDFromString(const std::string& trimmedStr);
static std::vector<std::string> parseStructFields(const std::string& structTypeStr);
static std::unique_ptr<LogicalType> parseListType(const std::string& trimmedStr);
static std::unique_ptr<LogicalType> parseArrayType(const std::string& trimmedStr);
static std::vector<StructField> parseStructTypeInfo(const std::string& structTypeStr);
static std::unique_ptr<LogicalType> parseStructType(const std::string& trimmedStr);
static std::unique_ptr<LogicalType> parseMapType(const std::string& trimmedStr);
static std::unique_ptr<LogicalType> parseUnionType(const std::string& trimmedStr);
};

enum class FileVersionType : uint8_t { ORIGINAL = 0, WAL_VERSION = 1 };
Expand Down
24 changes: 12 additions & 12 deletions src/include/function/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,23 @@ using aggr_combine_function_t =
using aggr_finalize_function_t = std::function<void(uint8_t* state)>;

struct AggregateFunction final : public BaseScalarFunction {
bool isDistinct;
aggr_initialize_function_t initializeFunc;
aggr_update_all_function_t updateAllFunc;
aggr_update_pos_function_t updatePosFunc;
aggr_combine_function_t combineFunc;
aggr_finalize_function_t finalizeFunc;
std::unique_ptr<AggregateState> initialNullAggregateState;
// Rewrite aggregate on NODE/REL, e.g. COUNT(a) -> COUNT(a._id)
param_rewrite_function_t paramRewriteFunc;

AggregateFunction(std::string name, std::vector<common::LogicalTypeID> parameterTypeIDs,
common::LogicalTypeID returnTypeID, aggr_initialize_function_t initializeFunc,
aggr_update_all_function_t updateAllFunc, aggr_update_pos_function_t updatePosFunc,
aggr_combine_function_t combineFunc, aggr_finalize_function_t finalizeFunc, bool isDistinct,
scalar_bind_func bindFunc = nullptr, param_rewrite_function_t paramRewriteFunc = nullptr)
: BaseScalarFunction{FunctionType::AGGREGATE, std::move(name), std::move(parameterTypeIDs),
returnTypeID, std::move(bindFunc)},
: BaseScalarFunction{std::move(name), std::move(parameterTypeIDs), returnTypeID,
std::move(bindFunc)},
isDistinct{isDistinct}, initializeFunc{std::move(initializeFunc)},
updateAllFunc{std::move(updateAllFunc)}, updatePosFunc{std::move(updatePosFunc)},
combineFunc{std::move(combineFunc)}, finalizeFunc{std::move(finalizeFunc)},
Expand Down Expand Up @@ -91,16 +101,6 @@ struct AggregateFunction final : public BaseScalarFunction {
initializeFunc, updateAllFunc, updatePosFunc, combineFunc, finalizeFunc, isDistinct,
bindFunc, paramRewriteFunc);
}

bool isDistinct;
aggr_initialize_function_t initializeFunc;
aggr_update_all_function_t updateAllFunc;
aggr_update_pos_function_t updatePosFunc;
aggr_combine_function_t combineFunc;
aggr_finalize_function_t finalizeFunc;
std::unique_ptr<AggregateState> initialNullAggregateState;
// Rewrite aggregate on NODE/REL, e.g. COUNT(a) -> COUNT(a._id)
param_rewrite_function_t paramRewriteFunc;
};

class AggregateFunctionUtil {
Expand Down
47 changes: 19 additions & 28 deletions src/include/function/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@ using function_set = std::vector<std::unique_ptr<Function>>;
using scalar_bind_func = std::function<std::unique_ptr<FunctionBindData>(
const binder::expression_vector&, Function* definition)>;

enum class FunctionType : uint8_t {
UNKNOWN = 0,
SCALAR = 1,
REWRITE = 2,
AGGREGATE = 3,
TABLE = 4
};

struct Function {
Function() : type{FunctionType::UNKNOWN}, isVarLength{false} {};
Function(FunctionType type, std::string name,
std::vector<common::LogicalTypeID> parameterTypeIDs)
: type{type}, name{std::move(name)}, parameterTypeIDs{std::move(parameterTypeIDs)},
isVarLength{false} {}
std::string name;
std::vector<common::LogicalTypeID> parameterTypeIDs;
// Currently we only one variable-length function which is list creation. The expectation is
// that all parameters must have the same type as parameterTypes[0].
bool isVarLength;

Function() : isVarLength{false} {};
Function(std::string name, std::vector<common::LogicalTypeID> parameterTypeIDs)
: name{std::move(name)}, parameterTypeIDs{std::move(parameterTypeIDs)}, isVarLength{false} {
}
Function(const Function& other)
: name{other.name}, parameterTypeIDs{other.parameterTypeIDs},
isVarLength{other.isVarLength} {}

virtual ~Function() = default;

Expand All @@ -44,35 +44,26 @@ struct Function {

virtual std::unique_ptr<Function> copy() const = 0;

// TODO(Ziyi): Move to catalog entry once we have implemented the catalog entry.
FunctionType type;
std::string name;
std::vector<common::LogicalTypeID> parameterTypeIDs;
// Currently we only one variable-length function which is list creation. The expectation is
// that all parameters must have the same type as parameterTypes[0].
bool isVarLength;

template<class TARGET>
const TARGET* constPtrCast() const {
return common::ku_dynamic_cast<const Function*, const TARGET*>(this);
}
};

struct BaseScalarFunction : public Function {
BaseScalarFunction(FunctionType type, std::string name,
std::vector<common::LogicalTypeID> parameterTypeIDs, common::LogicalTypeID returnTypeID,
scalar_bind_func bindFunc)
: Function{type, std::move(name), std::move(parameterTypeIDs)}, returnTypeID{returnTypeID},
common::LogicalTypeID returnTypeID;
scalar_bind_func bindFunc;

BaseScalarFunction(std::string name, std::vector<common::LogicalTypeID> parameterTypeIDs,
common::LogicalTypeID returnTypeID, scalar_bind_func bindFunc)
: Function{std::move(name), std::move(parameterTypeIDs)}, returnTypeID{returnTypeID},
bindFunc{std::move(bindFunc)} {}

std::string signatureToString() const override {
auto result = Function::signatureToString();
result += " -> " + common::LogicalTypeUtils::toString(returnTypeID);
return result;
}

common::LogicalTypeID returnTypeID;
scalar_bind_func bindFunc;
};

} // namespace function
Expand Down
8 changes: 4 additions & 4 deletions src/include/function/rewrite_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ using rewrite_func_rewrite_t = std::function<std::shared_ptr<binder::Expression>
// We write for the following functions
// ID(n) -> n._id
struct RewriteFunction final : public Function {
rewrite_func_rewrite_t rewriteFunc;

RewriteFunction(std::string name, std::vector<common::LogicalTypeID> parameterTypeIDs,
rewrite_func_rewrite_t rewriteFunc)
: Function{FunctionType::REWRITE, name, std::move(parameterTypeIDs)},
rewriteFunc{rewriteFunc} {}
: Function{name, std::move(parameterTypeIDs)}, rewriteFunc{rewriteFunc} {}
RewriteFunction(const RewriteFunction& other)
: Function{other}, rewriteFunc{other.rewriteFunc} {}

std::unique_ptr<Function> copy() const override {
return std::make_unique<RewriteFunction>(*this);
}

rewrite_func_rewrite_t rewriteFunc;
};

} // namespace function
Expand Down
Loading

0 comments on commit 9594367

Please sign in to comment.