diff --git a/src/binder/bind/bind_projection_clause.cpp b/src/binder/bind/bind_projection_clause.cpp index 2c7461798e..8e3ab4faa9 100644 --- a/src/binder/bind/bind_projection_clause.cpp +++ b/src/binder/bind/bind_projection_clause.cpp @@ -227,14 +227,30 @@ expression_vector Binder::bindOrderByExpressions( uint64_t Binder::bindSkipLimitExpression(const ParsedExpression& expression) { auto boundExpression = expressionBinder.bindExpression(expression); - // We currently do not support the number of rows to skip/limit written as an expression (eg. - // SKIP 3 + 2 is not supported). - if (expression.getExpressionType() != LITERAL || - ((LiteralExpression&)(*boundExpression)).getDataType().getLogicalTypeID() != - LogicalTypeID::INT64) { - throw BinderException("The number of rows to skip/limit must be a non-negative integer."); + auto errorMsg = "The number of rows to skip/limit must be a non-negative integer."; + if (!ExpressionVisitor::isConstant(*boundExpression)) { + throw BinderException(errorMsg); } - return ((LiteralExpression&)(*boundExpression)).value->getValue(); + auto value = ((LiteralExpression&)(*boundExpression)).value.get(); + int64_t num = 0; + // TODO: replace the following switch with value.cast() + switch (value->getDataType()->getLogicalTypeID()) { + case LogicalTypeID::INT64: { + num = value->getValue(); + } break; + case LogicalTypeID::INT32: { + num = value->getValue(); + } break; + case LogicalTypeID::INT16: { + num = value->getValue(); + } break; + default: + throw BinderException(errorMsg); + } + if (num < 0) { + throw BinderException(errorMsg); + } + return num; } void Binder::addExpressionsToScope(const expression_vector& projectionExpressions) { diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index 9ea1d3784a..4ea44eac17 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -55,9 +55,6 @@ std::shared_ptr ExpressionBinder::bindScalarFunctionExpression( childrenTypes.push_back(child->dataType); } auto function = builtInFunctions->matchVectorFunction(functionName, childrenTypes); - if (builtInFunctions->canApplyStaticEvaluation(functionName, children)) { - return staticEvaluate(functionName, children); - } expression_vector childrenAfterCast; for (auto i = 0u; i < children.size(); ++i) { auto targetType = @@ -138,22 +135,6 @@ std::shared_ptr ExpressionBinder::bindMacroExpression( return bindExpression(*macroParameterReplacer->visit(std::move(macroExpr))); } -std::shared_ptr ExpressionBinder::staticEvaluate( - const std::string& functionName, const expression_vector& children) { - assert(children[0]->expressionType == LITERAL); - auto strVal = ((LiteralExpression*)children[0].get())->getValue()->getValue(); - std::unique_ptr value; - if (functionName == CAST_TO_DATE_FUNC_NAME) { - value = std::make_unique(Date::fromCString(strVal.c_str(), strVal.length())); - } else if (functionName == CAST_TO_TIMESTAMP_FUNC_NAME) { - value = std::make_unique(Timestamp::fromCString(strVal.c_str(), strVal.length())); - } else { - assert(functionName == CAST_TO_INTERVAL_FUNC_NAME); - value = std::make_unique(Interval::fromCString(strVal.c_str(), strVal.length())); - } - return createLiteralExpression(std::move(value)); -} - // Function rewriting happens when we need to expose internal property access through function so // that it becomes read-only or the function involves catalog information. Currently we write // Before | After diff --git a/src/binder/bind_expression/bind_property_expression.cpp b/src/binder/bind_expression/bind_property_expression.cpp index a3d9b05dbe..77b26668d7 100644 --- a/src/binder/bind_expression/bind_property_expression.cpp +++ b/src/binder/bind_expression/bind_property_expression.cpp @@ -80,7 +80,6 @@ std::shared_ptr ExpressionBinder::bindPropertyExpression( } else if (ExpressionUtil::isRelVariable(*child)) { return bindRelPropertyExpression(*child, propertyName); } else { - assert(child->expressionType == FUNCTION); return bindStructPropertyExpression(child, propertyName); } } diff --git a/src/binder/expression/case_expression.cpp b/src/binder/expression/case_expression.cpp index 8ec0eb280c..771f38eac7 100644 --- a/src/binder/expression/case_expression.cpp +++ b/src/binder/expression/case_expression.cpp @@ -3,7 +3,7 @@ namespace kuzu { namespace binder { -std::string CaseExpression::toString() const { +std::string CaseExpression::toStringInternal() const { std::string result = "CASE "; for (auto& caseAlternative : caseAlternatives) { result += "WHEN " + caseAlternative->whenExpression->toString() + " THEN " + diff --git a/src/binder/expression/function_expression.cpp b/src/binder/expression/function_expression.cpp index aa2d7bcb42..f3f4dd21f5 100644 --- a/src/binder/expression/function_expression.cpp +++ b/src/binder/expression/function_expression.cpp @@ -15,7 +15,7 @@ std::string ScalarFunctionExpression::getUniqueName( return result; } -std::string ScalarFunctionExpression::toString() const { +std::string ScalarFunctionExpression::toStringInternal() const { auto result = functionName + "("; result += ExpressionUtil::toString(children); result += ")"; @@ -35,7 +35,7 @@ std::string AggregateFunctionExpression::getUniqueName( return result; } -std::string AggregateFunctionExpression::toString() const { +std::string AggregateFunctionExpression::toStringInternal() const { auto result = functionName + "("; if (isDistinct()) { result += "DISTINCT "; diff --git a/src/binder/expression_binder.cpp b/src/binder/expression_binder.cpp index da99e9e4c4..c4bde52fbc 100644 --- a/src/binder/expression_binder.cpp +++ b/src/binder/expression_binder.cpp @@ -5,6 +5,7 @@ #include "binder/expression/literal_expression.h" #include "binder/expression/parameter_expression.h" #include "binder/expression_visitor.h" +#include "expression_evaluator/expression_evaluator_utils.h" #include "function/cast/vector_cast_functions.h" using namespace kuzu::common; @@ -47,9 +48,29 @@ std::shared_ptr ExpressionBinder::bindExpression( if (isExpressionAggregate(expression->expressionType)) { validateAggregationExpressionIsNotNested(*expression); } + if (ExpressionVisitor::needFold(*expression)) { + return foldExpression(expression); + } return expression; } +std::shared_ptr ExpressionBinder::foldExpression( + std::shared_ptr expression) { + auto value = evaluator::ExpressionEvaluatorUtils::evaluateConstantExpression( + expression, binder->memoryManager); + auto result = createLiteralExpression(std::move(value)); + // Fold result should preserve the alias original expression. E.g. + // RETURN 2, 1 + 1 AS x + // Once folded, 1 + 1 will become 2 and have the same identifier as the first RETURN element. + // We preserve alias (x) to avoid such conflict. + if (expression->hasAlias()) { + result->setAlias(expression->getAlias()); + } else { + result->setAlias(expression->toString()); + } + return result; +} + std::shared_ptr ExpressionBinder::implicitCastIfNecessary( const std::shared_ptr& expression, const LogicalType& targetType) { if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) { @@ -127,7 +148,7 @@ void ExpressionBinder::validateAggregationExpressionIsNotNested(const Expression if (expression.getNumChildren() == 0) { return; } - if (ExpressionVisitor::hasAggregateExpression(*expression.getChild(0))) { + if (ExpressionVisitor::hasAggregate(*expression.getChild(0))) { throw BinderException( "Expression " + expression.toString() + " contains nested aggregation."); } diff --git a/src/binder/expression_visitor.cpp b/src/binder/expression_visitor.cpp index 6e8c7b54c8..6d91c38026 100644 --- a/src/binder/expression_visitor.cpp +++ b/src/binder/expression_visitor.cpp @@ -84,13 +84,29 @@ expression_vector ExpressionChildrenCollector::collectRelChildren(const Expressi return result; } -bool ExpressionVisitor::hasExpression( +// isConstant requires all children to be constant. +bool ExpressionVisitor::isConstant(const Expression& expression) { + if (expression.expressionType == ExpressionType::AGGREGATE_FUNCTION) { + return false; // We don't have a framework to fold aggregated constant. + } + if (expression.getNumChildren() == 0) { + return expression.expressionType == ExpressionType::LITERAL; + } + for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) { + if (!isConstant(*child)) { + return false; + } + } + return true; +} + +bool ExpressionVisitor::satisfyAny( const Expression& expression, const std::function& condition) { if (condition(expression)) { return true; } for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) { - if (hasExpression(*child, condition)) { + if (satisfyAny(*child, condition)) { return true; } } diff --git a/src/common/vector/value_vector.cpp b/src/common/vector/value_vector.cpp index 3f2ebdb39e..cd716bb2cb 100644 --- a/src/common/vector/value_vector.cpp +++ b/src/common/vector/value_vector.cpp @@ -124,6 +124,11 @@ void ValueVector::copyFromVectorData( } void ValueVector::copyFromValue(uint64_t pos, const Value& value) { + if (value.isNull()) { + setNull(pos, true); + return; + } + setNull(pos, false); auto dstValue = valueBuffer.get() + pos * numBytesPerValue; switch (dataType.getPhysicalType()) { case PhysicalTypeID::INT64: { @@ -171,6 +176,63 @@ void ValueVector::copyFromValue(uint64_t pos, const Value& value) { } } +std::unique_ptr ValueVector::getAsValue(uint64_t pos) { + if (isNull(pos)) { + return Value::createNullValue(dataType).copy(); + } + auto value = Value::createDefaultValue(dataType).copy(); + switch (dataType.getPhysicalType()) { + case PhysicalTypeID::INT64: { + value->val.int64Val = getValue(pos); + } break; + case PhysicalTypeID::INT32: { + value->val.int32Val = getValue(pos); + } break; + case PhysicalTypeID::INT16: { + value->val.int16Val = getValue(pos); + } break; + case PhysicalTypeID::DOUBLE: { + value->val.doubleVal = getValue(pos); + } break; + case PhysicalTypeID::FLOAT: { + value->val.floatVal = getValue(pos); + } break; + case PhysicalTypeID::BOOL: { + value->val.booleanVal = getValue(pos); + } break; + case PhysicalTypeID::INTERVAL: { + value->val.intervalVal = getValue(pos); + } break; + case PhysicalTypeID::STRING: { + value->strVal = getValue(pos).getAsString(); + } break; + case PhysicalTypeID::VAR_LIST: { + auto dataVector = ListVector::getDataVector(this); + auto listEntry = getValue(pos); + std::vector> children; + children.reserve(listEntry.size); + for (auto i = 0u; i < listEntry.size; ++i) { + children.push_back(dataVector->getAsValue(listEntry.offset + i)); + } + value->childrenSize = children.size(); + value->children = std::move(children); + } break; + case PhysicalTypeID::STRUCT: { + auto& fieldVectors = StructVector::getFieldVectors(this); + std::vector> children; + children.reserve(fieldVectors.size()); + for (auto& fieldVector : fieldVectors) { + children.push_back(fieldVector->getAsValue(pos)); + } + value->childrenSize = children.size(); + value->children = std::move(children); + } break; + default: + throw NotImplementedException("ValueVector::getAsValue"); + } + return value; +} + void ValueVector::resetAuxiliaryBuffer() { switch (dataType.getPhysicalType()) { case PhysicalTypeID::STRING: { diff --git a/src/expression_evaluator/CMakeLists.txt b/src/expression_evaluator/CMakeLists.txt index 41af093e08..b7873dbbfc 100644 --- a/src/expression_evaluator/CMakeLists.txt +++ b/src/expression_evaluator/CMakeLists.txt @@ -1,7 +1,8 @@ add_library(kuzu_expression_evaluator OBJECT - base_evaluator.cpp case_evaluator.cpp + expression_evaluator.cpp + expression_evaluator_utils.cpp function_evaluator.cpp literal_evaluator.cpp node_rel_evaluator.cpp diff --git a/src/expression_evaluator/base_evaluator.cpp b/src/expression_evaluator/expression_evaluator.cpp similarity index 94% rename from src/expression_evaluator/base_evaluator.cpp rename to src/expression_evaluator/expression_evaluator.cpp index cf5a14a433..2f205be1f7 100644 --- a/src/expression_evaluator/base_evaluator.cpp +++ b/src/expression_evaluator/expression_evaluator.cpp @@ -1,4 +1,4 @@ -#include "expression_evaluator/base_evaluator.h" +#include "expression_evaluator/expression_evaluator.h" using namespace kuzu::common; diff --git a/src/expression_evaluator/expression_evaluator_utils.cpp b/src/expression_evaluator/expression_evaluator_utils.cpp new file mode 100644 index 0000000000..2273b7f28e --- /dev/null +++ b/src/expression_evaluator/expression_evaluator_utils.cpp @@ -0,0 +1,23 @@ +#include "expression_evaluator/expression_evaluator_utils.h" + +#include "processor/expression_mapper.h" + +using namespace kuzu::common; +using namespace kuzu::processor; + +namespace kuzu { +namespace evaluator { + +std::unique_ptr ExpressionEvaluatorUtils::evaluateConstantExpression( + const std::shared_ptr& expression, storage::MemoryManager* memoryManager) { + auto evaluator = ExpressionMapper::getConstantEvaluator(expression); + auto emptyResultSet = std::make_unique(0); + evaluator->init(*emptyResultSet, memoryManager); + evaluator->evaluate(); + auto selVector = evaluator->resultVector->state->selVector.get(); + assert(selVector->selectedSize == 1); + return evaluator->resultVector->getAsValue(selVector->selectedPositions[0]); +} + +} // namespace evaluator +} // namespace kuzu diff --git a/src/expression_evaluator/path_evaluator.cpp b/src/expression_evaluator/path_evaluator.cpp index fa75548072..a223954974 100644 --- a/src/expression_evaluator/path_evaluator.cpp +++ b/src/expression_evaluator/path_evaluator.cpp @@ -40,8 +40,9 @@ void PathExpressionEvaluator::init( for (auto& fieldVector : StructVector::getFieldVectors(resultRelsDataVector)) { resultRelsFieldVectors.push_back(fieldVector.get()); } - for (auto i = 0u; i < pathExpression->getNumChildren(); ++i) { - auto child = pathExpression->getChild(i).get(); + auto pathExpression = (PathExpression*)expression.get(); + for (auto i = 0u; i < expression->getNumChildren(); ++i) { + auto child = expression->getChild(i).get(); auto vectors = std::make_unique(); vectors->input = children[i]->resultVector.get(); switch (child->dataType.getLogicalTypeID()) { @@ -99,8 +100,8 @@ static inline uint32_t getCurrentPos(ValueVector* vector, uint32_t pos) { void PathExpressionEvaluator::copyNodes(sel_t resultPos) { auto listSize = 0u; // Calculate list size. - for (auto i = 0; i < pathExpression->getNumChildren(); ++i) { - auto child = pathExpression->getChild(i).get(); + for (auto i = 0; i < expression->getNumChildren(); ++i) { + auto child = expression->getChild(i).get(); switch (child->dataType.getLogicalTypeID()) { case LogicalTypeID::NODE: { listSize++; @@ -119,8 +120,8 @@ void PathExpressionEvaluator::copyNodes(sel_t resultPos) { resultNodesVector->setValue(resultPos, entry); // Copy field vectors offset_t resultDataPos = entry.offset; - for (auto i = 0; i < pathExpression->getNumChildren(); ++i) { - auto child = pathExpression->getChild(i).get(); + for (auto i = 0; i < expression->getNumChildren(); ++i) { + auto child = expression->getChild(i).get(); auto vectors = inputVectorsPerChild[i].get(); auto inputPos = getCurrentPos(vectors->input, resultPos); switch (child->dataType.getLogicalTypeID()) { @@ -144,8 +145,8 @@ void PathExpressionEvaluator::copyNodes(sel_t resultPos) { void PathExpressionEvaluator::copyRels(sel_t resultPos) { auto listSize = 0u; // Calculate list size. - for (auto i = 0; i < pathExpression->getNumChildren(); ++i) { - auto child = pathExpression->getChild(i).get(); + for (auto i = 0; i < expression->getNumChildren(); ++i) { + auto child = expression->getChild(i).get(); switch (child->dataType.getLogicalTypeID()) { case LogicalTypeID::REL: { listSize++; @@ -164,8 +165,8 @@ void PathExpressionEvaluator::copyRels(sel_t resultPos) { resultRelsVector->setValue(resultPos, entry); // Copy field vectors offset_t resultDataPos = entry.offset; - for (auto i = 0; i < pathExpression->getNumChildren(); ++i) { - auto child = pathExpression->getChild(i).get(); + for (auto i = 0; i < expression->getNumChildren(); ++i) { + auto child = expression->getChild(i).get(); auto vectors = inputVectorsPerChild[i].get(); auto inputPos = getCurrentPos(vectors->input, resultPos); switch (child->dataType.getLogicalTypeID()) { @@ -206,7 +207,7 @@ void PathExpressionEvaluator::copyFieldVectors(offset_t inputVectorPos, void PathExpressionEvaluator::resolveResultVector( const processor::ResultSet& resultSet, storage::MemoryManager* memoryManager) { - resultVector = std::make_shared(pathExpression->getDataType(), memoryManager); + resultVector = std::make_shared(expression->getDataType(), memoryManager); std::vector inputEvaluators; inputEvaluators.reserve(children.size()); for (auto& child : children) { diff --git a/src/function/built_in_vector_functions.cpp b/src/function/built_in_vector_functions.cpp index acaffdbf65..d33f96ecf1 100644 --- a/src/function/built_in_vector_functions.cpp +++ b/src/function/built_in_vector_functions.cpp @@ -38,17 +38,6 @@ void BuiltInVectorFunctions::registerVectorFunctions() { registerBlobFunctions(); } -bool BuiltInVectorFunctions::canApplyStaticEvaluation( - const std::string& functionName, const binder::expression_vector& children) { - if ((functionName == CAST_TO_DATE_FUNC_NAME || functionName == CAST_TO_TIMESTAMP_FUNC_NAME || - functionName == CAST_TO_INTERVAL_FUNC_NAME) && - children[0]->expressionType == LITERAL && - children[0]->dataType.getLogicalTypeID() == LogicalTypeID::STRING) { - return true; // bind as literal - } - return false; -} - VectorFunctionDefinition* BuiltInVectorFunctions::matchVectorFunction( const std::string& name, const std::vector& inputTypes) { auto& functionDefinitions = vectorFunctions.at(name); diff --git a/src/include/binder/binder.h b/src/include/binder/binder.h index 56f3d38957..c4aa4f6bec 100644 --- a/src/include/binder/binder.h +++ b/src/include/binder/binder.h @@ -57,9 +57,11 @@ class Binder { friend class ExpressionBinder; public: - explicit Binder(const catalog::Catalog& catalog, main::ClientContext* clientContext) - : catalog{catalog}, lastExpressionId{0}, scope{std::make_unique()}, - expressionBinder{this}, clientContext{clientContext} {} + explicit Binder(const catalog::Catalog& catalog, storage::MemoryManager* memoryManager, + main::ClientContext* clientContext) + : catalog{catalog}, memoryManager{memoryManager}, lastExpressionId{0}, + scope{std::make_unique()}, expressionBinder{this}, clientContext{ + clientContext} {} std::unique_ptr bind(const parser::Statement& statement); @@ -241,6 +243,7 @@ class Binder { private: const catalog::Catalog& catalog; + storage::MemoryManager* memoryManager; uint32_t lastExpressionId; std::unique_ptr scope; ExpressionBinder expressionBinder; diff --git a/src/include/binder/expression/case_expression.h b/src/include/binder/expression/case_expression.h index f9c83ddded..01deaafb2d 100644 --- a/src/include/binder/expression/case_expression.h +++ b/src/include/binder/expression/case_expression.h @@ -32,7 +32,7 @@ class CaseExpression : public Expression { inline std::shared_ptr getElseExpression() const { return elseExpression; } - std::string toString() const override; + std::string toStringInternal() const final; private: std::vector> caseAlternatives; diff --git a/src/include/binder/expression/existential_subquery_expression.h b/src/include/binder/expression/existential_subquery_expression.h index ae13edbc9b..54cb59ebc6 100644 --- a/src/include/binder/expression/existential_subquery_expression.h +++ b/src/include/binder/expression/existential_subquery_expression.h @@ -27,7 +27,7 @@ class ExistentialSubqueryExpression : public Expression { return hasWhereExpression() ? whereExpression->splitOnAND() : expression_vector{}; } - std::string toString() const override { return rawName; } + std::string toStringInternal() const final { return rawName; } private: std::unique_ptr queryGraphCollection; diff --git a/src/include/binder/expression/expression.h b/src/include/binder/expression/expression.h index 293f490072..90dfd0ab02 100644 --- a/src/include/binder/expression/expression.h +++ b/src/include/binder/expression/expression.h @@ -33,29 +33,24 @@ class Expression : public std::enable_shared_from_this { expression_vector children, std::string uniqueName) : expressionType{expressionType}, dataType{std::move(dataType)}, uniqueName{std::move(uniqueName)}, children{std::move(children)} {} - // Create binary expression. Expression(common::ExpressionType expressionType, common::LogicalType dataType, const std::shared_ptr& left, const std::shared_ptr& right, std::string uniqueName) : Expression{expressionType, std::move(dataType), expression_vector{left, right}, std::move(uniqueName)} {} - // Create unary expression. Expression(common::ExpressionType expressionType, common::LogicalType dataType, const std::shared_ptr& child, std::string uniqueName) : Expression{expressionType, std::move(dataType), expression_vector{child}, std::move(uniqueName)} {} - // Create leaf expression Expression( common::ExpressionType expressionType, common::LogicalType dataType, std::string uniqueName) : Expression{ expressionType, std::move(dataType), expression_vector{}, std::move(uniqueName)} {} - virtual ~Expression() = default; -public: inline void setAlias(const std::string& name) { alias = name; } inline std::string getUniqueName() const { @@ -67,14 +62,13 @@ class Expression : public std::enable_shared_from_this { inline common::LogicalType& getDataTypeReference() { return dataType; } inline bool hasAlias() const { return !alias.empty(); } - inline std::string getAlias() const { return alias; } inline uint32_t getNumChildren() const { return children.size(); } - inline std::shared_ptr getChild(common::vector_idx_t idx) const { return children[idx]; } + inline expression_vector getChildren() const { return children; } inline void setChild(common::vector_idx_t idx, std::shared_ptr child) { children[idx] = std::move(child); } @@ -83,12 +77,15 @@ class Expression : public std::enable_shared_from_this { inline bool operator==(const Expression& rhs) const { return uniqueName == rhs.uniqueName; } - virtual std::string toString() const = 0; + std::string toString() const { return hasAlias() ? alias : toStringInternal(); } virtual std::unique_ptr copy() const { throw common::InternalException("Unimplemented expression copy()."); } +protected: + virtual std::string toStringInternal() const = 0; + public: common::ExpressionType expressionType; common::LogicalType dataType; diff --git a/src/include/binder/expression/function_expression.h b/src/include/binder/expression/function_expression.h index f715e559e0..d6a6aa35bb 100644 --- a/src/include/binder/expression/function_expression.h +++ b/src/include/binder/expression/function_expression.h @@ -31,7 +31,7 @@ class FunctionExpression : public Expression { inline std::string getFunctionName() const { return functionName; } inline function::FunctionBindData* getBindData() const { return bindData.get(); } - std::string toString() const override = 0; + std::string toStringInternal() const override = 0; protected: std::string functionName; @@ -59,7 +59,7 @@ class ScalarFunctionExpression : public FunctionExpression { static std::string getUniqueName(const std::string& functionName, expression_vector& children); - std::string toString() const override; + std::string toStringInternal() const final; public: function::scalar_exec_func execFunc; @@ -89,7 +89,7 @@ class AggregateFunctionExpression : public FunctionExpression { inline bool isDistinct() const { return aggregateFunction->isFunctionDistinct(); } - std::string toString() const override; + std::string toStringInternal() const final; public: std::unique_ptr aggregateFunction; diff --git a/src/include/binder/expression/literal_expression.h b/src/include/binder/expression/literal_expression.h index ce51964fdf..2e4b3dae37 100644 --- a/src/include/binder/expression/literal_expression.h +++ b/src/include/binder/expression/literal_expression.h @@ -21,7 +21,7 @@ class LiteralExpression : public Expression { inline common::Value* getValue() const { return value.get(); } - std::string toString() const override { return value->toString(); } + std::string toStringInternal() const final { return value->toString(); } public: std::unique_ptr value; diff --git a/src/include/binder/expression/node_rel_expression.h b/src/include/binder/expression/node_rel_expression.h index 745e73ebfe..8c7d11ce4f 100644 --- a/src/include/binder/expression/node_rel_expression.h +++ b/src/include/binder/expression/node_rel_expression.h @@ -59,7 +59,7 @@ class NodeOrRelExpression : public Expression { } inline std::shared_ptr getLabelExpression() const { return labelExpression; } - std::string toString() const override { return variableName; } + inline std::string toStringInternal() const final { return variableName; } protected: std::string variableName; diff --git a/src/include/binder/expression/parameter_expression.h b/src/include/binder/expression/parameter_expression.h index 6ad577d6aa..f1377cf8e9 100644 --- a/src/include/binder/expression/parameter_expression.h +++ b/src/include/binder/expression/parameter_expression.h @@ -22,7 +22,7 @@ class ParameterExpression : public Expression { inline std::shared_ptr getLiteral() const { return value; } - std::string toString() const override { return "$" + parameterName; } + inline std::string toStringInternal() const final { return "$" + parameterName; } private: inline static std::string createUniqueName(const std::string& input) { return "$" + input; } diff --git a/src/include/binder/expression/path_expression.h b/src/include/binder/expression/path_expression.h index 0d60b2c89d..cf972af430 100644 --- a/src/include/binder/expression/path_expression.h +++ b/src/include/binder/expression/path_expression.h @@ -18,7 +18,7 @@ class PathExpression : public Expression { inline std::shared_ptr getNode() const { return node; } inline std::shared_ptr getRel() const { return rel; } - inline std::string toString() const override { return variableName; } + inline std::string toStringInternal() const final { return variableName; } private: std::string variableName; diff --git a/src/include/binder/expression/property_expression.h b/src/include/binder/expression/property_expression.h index 7e96cb08cb..d2a0ce8194 100644 --- a/src/include/binder/expression/property_expression.h +++ b/src/include/binder/expression/property_expression.h @@ -44,7 +44,9 @@ class PropertyExpression : public Expression { return make_unique(*this); } - inline std::string toString() const override { return rawVariableName + "." + propertyName; } + inline std::string toStringInternal() const final { + return rawVariableName + "." + propertyName; + } private: bool isPrimaryKey_ = false; diff --git a/src/include/binder/expression/variable_expression.h b/src/include/binder/expression/variable_expression.h index c1d89bf3aa..a69bc9f110 100644 --- a/src/include/binder/expression/variable_expression.h +++ b/src/include/binder/expression/variable_expression.h @@ -12,7 +12,7 @@ class VariableExpression : public Expression { : Expression{common::VARIABLE, dataType, std::move(uniqueName)}, variableName{std::move( variableName)} {} - std::string toString() const override { return variableName; } + inline std::string toStringInternal() const final { return variableName; } private: std::string variableName; diff --git a/src/include/binder/expression_binder.h b/src/include/binder/expression_binder.h index 1805d6fb7d..5eae97d5f9 100644 --- a/src/include/binder/expression_binder.h +++ b/src/include/binder/expression_binder.h @@ -22,6 +22,9 @@ class ExpressionBinder { static void resolveAnyDataType(Expression& expression, const common::LogicalType& targetType); private: + // TODO(Xiyang): move to an expression rewriter + std::shared_ptr foldExpression(std::shared_ptr expression); + // Boolean expressions. std::shared_ptr bindBooleanExpression( const parser::ParsedExpression& parsedExpression); @@ -67,8 +70,6 @@ class ExpressionBinder { bool isDistinct); std::shared_ptr bindMacroExpression( const parser::ParsedExpression& parsedExpression, const std::string& macroName); - std::shared_ptr staticEvaluate( - const std::string& functionName, const expression_vector& children); std::shared_ptr rewriteFunctionExpression( const parser::ParsedExpression& parsedExpression, const std::string& functionName); diff --git a/src/include/binder/expression_visitor.h b/src/include/binder/expression_visitor.h index 7ac518ddd5..7829d2f706 100644 --- a/src/include/binder/expression_visitor.h +++ b/src/include/binder/expression_visitor.h @@ -21,20 +21,28 @@ class ExpressionChildrenCollector { class ExpressionVisitor { public: - static bool hasAggregateExpression(const Expression& expression) { - return hasExpression(expression, [&](const Expression& expression) { + static inline bool hasAggregate(const Expression& expression) { + return satisfyAny(expression, [&](const Expression& expression) { return common::isExpressionAggregate(expression.expressionType); }); } - static bool hasSubqueryExpression(const Expression& expression) { - return hasExpression(expression, [&](const Expression& expression) { + static inline bool hasSubquery(const Expression& expression) { + return satisfyAny(expression, [&](const Expression& expression) { return common::isExpressionSubquery(expression.expressionType); }); } + static inline bool needFold(const Expression& expression) { + return expression.expressionType == common::ExpressionType::LITERAL ? + false : + isConstant(expression); + } + + static bool isConstant(const Expression& expression); + private: - static bool hasExpression( + static bool satisfyAny( const Expression& expression, const std::function& condition); }; diff --git a/src/include/common/types/value.h b/src/include/common/types/value.h index 9d05735c1e..613de6c153 100644 --- a/src/include/common/types/value.h +++ b/src/include/common/types/value.h @@ -17,6 +17,7 @@ class FileInfo; class NestedVal; class RecursiveRelVal; class ArrowRowBatch; +class ValueVector; class Value { friend class NodeVal; @@ -24,6 +25,7 @@ class Value { friend class NestedVal; friend class RecursiveRelVal; friend class ArrowRowBatch; + friend class ValueVector; public: /** diff --git a/src/include/common/vector/value_vector.h b/src/include/common/vector/value_vector.h index 9258524846..9476688f7e 100644 --- a/src/include/common/vector/value_vector.h +++ b/src/include/common/vector/value_vector.h @@ -69,6 +69,7 @@ class ValueVector { void copyFromVectorData(uint64_t dstPos, const ValueVector* srcVector, uint64_t srcPos); void copyFromValue(uint64_t pos, const Value& value); + std::unique_ptr getAsValue(uint64_t pos); inline uint8_t* getData() const { return valueBuffer.get(); } diff --git a/src/include/expression_evaluator/case_evaluator.h b/src/include/expression_evaluator/case_evaluator.h index 76cbe62cd6..c192b17815 100644 --- a/src/include/expression_evaluator/case_evaluator.h +++ b/src/include/expression_evaluator/case_evaluator.h @@ -2,8 +2,8 @@ #include -#include "base_evaluator.h" #include "common/exception.h" +#include "expression_evaluator.h" namespace kuzu { namespace evaluator { diff --git a/src/include/expression_evaluator/base_evaluator.h b/src/include/expression_evaluator/expression_evaluator.h similarity index 99% rename from src/include/expression_evaluator/base_evaluator.h rename to src/include/expression_evaluator/expression_evaluator.h index a1374e5b20..680ca483ea 100644 --- a/src/include/expression_evaluator/base_evaluator.h +++ b/src/include/expression_evaluator/expression_evaluator.h @@ -13,10 +13,8 @@ class ExpressionEvaluator { ExpressionEvaluator() = default; // Leaf evaluators (reference or literal) explicit ExpressionEvaluator(bool isResultFlat) : isResultFlat_{isResultFlat} {} - explicit ExpressionEvaluator(std::vector> children) : children{std::move(children)} {} - virtual ~ExpressionEvaluator() = default; inline bool isResultFlat() const { return isResultFlat_; } diff --git a/src/include/expression_evaluator/expression_evaluator_utils.h b/src/include/expression_evaluator/expression_evaluator_utils.h new file mode 100644 index 0000000000..b2ab5bd242 --- /dev/null +++ b/src/include/expression_evaluator/expression_evaluator_utils.h @@ -0,0 +1,14 @@ +#include "binder/expression/expression.h" +#include "expression_evaluator.h" + +namespace kuzu { +namespace evaluator { + +struct ExpressionEvaluatorUtils { + static std::unique_ptr evaluateConstantExpression( + const std::shared_ptr& expression, + storage::MemoryManager* memoryManager); +}; + +} // namespace evaluator +} // namespace kuzu diff --git a/src/include/expression_evaluator/function_evaluator.h b/src/include/expression_evaluator/function_evaluator.h index 475383a683..d10fc8b36f 100644 --- a/src/include/expression_evaluator/function_evaluator.h +++ b/src/include/expression_evaluator/function_evaluator.h @@ -1,6 +1,6 @@ #pragma once -#include "base_evaluator.h" +#include "expression_evaluator.h" #include "function/vector_functions.h" namespace kuzu { diff --git a/src/include/expression_evaluator/literal_evaluator.h b/src/include/expression_evaluator/literal_evaluator.h index 0213cfccb4..bbc1770110 100644 --- a/src/include/expression_evaluator/literal_evaluator.h +++ b/src/include/expression_evaluator/literal_evaluator.h @@ -1,6 +1,6 @@ #pragma once -#include "base_evaluator.h" +#include "expression_evaluator.h" namespace kuzu { namespace evaluator { diff --git a/src/include/expression_evaluator/node_rel_evaluator.h b/src/include/expression_evaluator/node_rel_evaluator.h index a82eaaafb7..f4db6ed2a5 100644 --- a/src/include/expression_evaluator/node_rel_evaluator.h +++ b/src/include/expression_evaluator/node_rel_evaluator.h @@ -1,7 +1,7 @@ #pragma once -#include "base_evaluator.h" #include "binder/expression/expression.h" +#include "expression_evaluator.h" namespace kuzu { namespace evaluator { diff --git a/src/include/expression_evaluator/path_evaluator.h b/src/include/expression_evaluator/path_evaluator.h index fa7cf86d10..3a24a0f0a0 100644 --- a/src/include/expression_evaluator/path_evaluator.h +++ b/src/include/expression_evaluator/path_evaluator.h @@ -1,16 +1,15 @@ #pragma once -#include "base_evaluator.h" -#include "binder/expression/path_expression.h" +#include "expression_evaluator.h" namespace kuzu { namespace evaluator { class PathExpressionEvaluator : public ExpressionEvaluator { public: - PathExpressionEvaluator(std::shared_ptr pathExpression, + PathExpressionEvaluator(std::shared_ptr expression, std::vector> children) - : ExpressionEvaluator{std::move(children)}, pathExpression{std::move(pathExpression)} {} + : ExpressionEvaluator{std::move(children)}, expression{std::move(expression)} {} void init(const processor::ResultSet& resultSet, storage::MemoryManager* memoryManager) final; @@ -25,7 +24,7 @@ class PathExpressionEvaluator : public ExpressionEvaluator { for (auto& child : children) { clonedChildren.push_back(child->clone()); } - return make_unique(pathExpression, std::move(clonedChildren)); + return make_unique(expression, std::move(clonedChildren)); } private: @@ -57,7 +56,7 @@ class PathExpressionEvaluator : public ExpressionEvaluator { const std::vector& resultFieldVectors); private: - std::shared_ptr pathExpression; + std::shared_ptr expression; std::vector> inputVectorsPerChild; common::ValueVector* resultNodesVector; // LIST[NODE] common::ValueVector* resultRelsVector; // LIST[REL] diff --git a/src/include/expression_evaluator/reference_evaluator.h b/src/include/expression_evaluator/reference_evaluator.h index a6e50ede62..96938ae7bd 100644 --- a/src/include/expression_evaluator/reference_evaluator.h +++ b/src/include/expression_evaluator/reference_evaluator.h @@ -1,6 +1,6 @@ #pragma once -#include "base_evaluator.h" +#include "expression_evaluator.h" namespace kuzu { namespace evaluator { diff --git a/src/include/function/built_in_vector_functions.h b/src/include/function/built_in_vector_functions.h index a9947c939d..94432315c3 100644 --- a/src/include/function/built_in_vector_functions.h +++ b/src/include/function/built_in_vector_functions.h @@ -14,13 +14,6 @@ class BuiltInVectorFunctions { return vectorFunctions.contains(functionName); } - /** - * Certain function can be evaluated statically and thus avoid runtime execution. - * E.g. date("2021-01-01") can be evaluated as date literal statically. - */ - bool canApplyStaticEvaluation( - const std::string& functionName, const binder::expression_vector& children); - VectorFunctionDefinition* matchVectorFunction( const std::string& name, const std::vector& inputTypes); diff --git a/src/include/processor/expression_mapper.h b/src/include/processor/expression_mapper.h index eba7d3b3da..8258b4a504 100644 --- a/src/include/processor/expression_mapper.h +++ b/src/include/processor/expression_mapper.h @@ -1,7 +1,7 @@ #pragma once #include "binder/expression/expression.h" -#include "expression_evaluator/base_evaluator.h" +#include "expression_evaluator/expression_evaluator.h" #include "processor/execution_context.h" #include "processor/result/result_set.h" #include "processor/result/result_set_descriptor.h" @@ -9,38 +9,40 @@ namespace kuzu { namespace processor { -class PlanMapper; - class ExpressionMapper { - public: - std::unique_ptr mapExpression( - const std::shared_ptr& expression, const planner::Schema& schema); + static std::unique_ptr getEvaluator( + const std::shared_ptr& expression, const planner::Schema* schema); + static std::unique_ptr getConstantEvaluator( + const std::shared_ptr& expression); private: - std::unique_ptr mapLiteralExpression( - const std::shared_ptr& expression); + static std::unique_ptr getLiteralEvaluator( + const binder::Expression& expression); - std::unique_ptr mapParameterExpression( - const std::shared_ptr& expression); + static std::unique_ptr getParameterEvaluator( + const binder::Expression& expression); + + static std::unique_ptr getReferenceEvaluator( + std::shared_ptr expression, const planner::Schema* schema); - std::unique_ptr mapReferenceExpression( - const std::shared_ptr& expression, const planner::Schema& schema); + static std::unique_ptr getCaseEvaluator( + std::shared_ptr expression, const planner::Schema* schema); - std::unique_ptr mapCaseExpression( - const std::shared_ptr& expression, const planner::Schema& schema); + static std::unique_ptr getFunctionEvaluator( + std::shared_ptr expression, const planner::Schema* schema); - std::unique_ptr mapFunctionExpression( - const std::shared_ptr& expression, const planner::Schema& schema); + static std::unique_ptr getNodeEvaluator( + std::shared_ptr expression, const planner::Schema* schema); - std::unique_ptr mapNodeExpression( - const std::shared_ptr& expression, const planner::Schema& schema); + static std::unique_ptr getRelEvaluator( + std::shared_ptr expression, const planner::Schema* schema); - std::unique_ptr mapRelExpression( - const std::shared_ptr& expression, const planner::Schema& schema); + static std::unique_ptr getPathEvaluator( + std::shared_ptr expression, const planner::Schema* schema); - std::unique_ptr mapPathExpression( - const std::shared_ptr& expression, const planner::Schema& schema); + static std::vector> getEvaluators( + const binder::expression_vector& expressions, const planner::Schema* schema); }; } // namespace processor diff --git a/src/include/processor/operator/ddl/add_property.h b/src/include/processor/operator/ddl/add_property.h index 54f72067a4..d65e2c7549 100644 --- a/src/include/processor/operator/ddl/add_property.h +++ b/src/include/processor/operator/ddl/add_property.h @@ -1,7 +1,7 @@ #pragma once #include "ddl.h" -#include "expression_evaluator/base_evaluator.h" +#include "expression_evaluator/expression_evaluator.h" #include "storage/storage_manager.h" namespace kuzu { diff --git a/src/include/processor/operator/filter.h b/src/include/processor/operator/filter.h index 67d1146062..424b7c029a 100644 --- a/src/include/processor/operator/filter.h +++ b/src/include/processor/operator/filter.h @@ -1,6 +1,6 @@ #pragma once -#include "expression_evaluator/base_evaluator.h" +#include "expression_evaluator/expression_evaluator.h" #include "processor/operator/filtering_operator.h" #include "processor/operator/physical_operator.h" diff --git a/src/include/processor/operator/index_scan.h b/src/include/processor/operator/index_scan.h index ef310c909d..d1290b61e0 100644 --- a/src/include/processor/operator/index_scan.h +++ b/src/include/processor/operator/index_scan.h @@ -1,6 +1,6 @@ #pragma once -#include "expression_evaluator/base_evaluator.h" +#include "expression_evaluator/expression_evaluator.h" #include "physical_operator.h" #include "processor/operator/filtering_operator.h" #include "storage/index/hash_index.h" diff --git a/src/include/processor/operator/persistent/insert_executor.h b/src/include/processor/operator/persistent/insert_executor.h index 956a92b156..be83b4538b 100644 --- a/src/include/processor/operator/persistent/insert_executor.h +++ b/src/include/processor/operator/persistent/insert_executor.h @@ -1,6 +1,6 @@ #pragma once -#include "expression_evaluator/base_evaluator.h" +#include "expression_evaluator/expression_evaluator.h" #include "processor/execution_context.h" #include "storage/store/node_table.h" diff --git a/src/include/processor/operator/persistent/set_executor.h b/src/include/processor/operator/persistent/set_executor.h index d6790d0363..801479acec 100644 --- a/src/include/processor/operator/persistent/set_executor.h +++ b/src/include/processor/operator/persistent/set_executor.h @@ -1,6 +1,6 @@ #pragma once -#include "expression_evaluator/base_evaluator.h" +#include "expression_evaluator/expression_evaluator.h" #include "processor/execution_context.h" #include "processor/result/result_set.h" #include "storage/store/node_table.h" diff --git a/src/include/processor/operator/projection.h b/src/include/processor/operator/projection.h index b1feafc419..ba7907d908 100644 --- a/src/include/processor/operator/projection.h +++ b/src/include/processor/operator/projection.h @@ -2,7 +2,7 @@ #include -#include "expression_evaluator/base_evaluator.h" +#include "expression_evaluator/expression_evaluator.h" #include "processor/operator/physical_operator.h" namespace kuzu { diff --git a/src/include/processor/operator/unwind.h b/src/include/processor/operator/unwind.h index f8f3eee411..948ab6ba3b 100644 --- a/src/include/processor/operator/unwind.h +++ b/src/include/processor/operator/unwind.h @@ -1,7 +1,7 @@ #pragma once #include "binder/expression/expression.h" -#include "expression_evaluator/base_evaluator.h" +#include "expression_evaluator/expression_evaluator.h" #include "processor/operator/physical_operator.h" #include "processor/result/result_set.h" diff --git a/src/main/connection.cpp b/src/main/connection.cpp index e7b41f1d0f..67107a2a05 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -161,7 +161,8 @@ std::unique_ptr Connection::prepareNoLock( // parsing auto statement = Parser::parseQuery(query); // binding - auto binder = Binder(*database->catalog, clientContext.get()); + auto binder = + Binder(*database->catalog, database->memoryManager.get(), clientContext.get()); auto boundStatement = binder.bind(*statement); preparedStatement->preparedSummary.statementType = boundStatement->getStatementType(); preparedStatement->readOnly = diff --git a/src/planner/plan/plan_subquery.cpp b/src/planner/plan/plan_subquery.cpp index 1568d7d07a..c061e3cd6f 100644 --- a/src/planner/plan/plan_subquery.cpp +++ b/src/planner/plan/plan_subquery.cpp @@ -121,7 +121,7 @@ void QueryPlanner::planExistsSubquery( void QueryPlanner::planSubqueryIfNecessary( const std::shared_ptr& expression, LogicalPlan& plan) { - if (ExpressionVisitor::hasSubqueryExpression(*expression)) { + if (ExpressionVisitor::hasSubquery(*expression)) { auto expressionCollector = std::make_unique(); for (auto& expr : expressionCollector->collectTopLevelSubqueryExpressions(expression)) { planExistsSubquery(expr, plan); diff --git a/src/processor/map/expression_mapper.cpp b/src/processor/map/expression_mapper.cpp index 4b2482300a..60cab1b5b0 100644 --- a/src/processor/map/expression_mapper.cpp +++ b/src/processor/map/expression_mapper.cpp @@ -6,6 +6,8 @@ #include "binder/expression/parameter_expression.h" #include "binder/expression/path_expression.h" #include "binder/expression/rel_expression.h" +#include "binder/expression_visitor.h" +#include "common/string_utils.h" #include "expression_evaluator/case_evaluator.h" #include "expression_evaluator/function_evaluator.h" #include "expression_evaluator/literal_evaluator.h" @@ -22,108 +24,160 @@ using namespace kuzu::planner; namespace kuzu { namespace processor { -std::unique_ptr ExpressionMapper::mapExpression( - const std::shared_ptr& expression, const Schema& schema) { +static bool canEvaluateAsFunction(ExpressionType expressionType) { + switch (expressionType) { + case ExpressionType::OR: + case ExpressionType::XOR: + case ExpressionType::AND: + case ExpressionType::NOT: + case ExpressionType::EQUALS: + case ExpressionType::NOT_EQUALS: + case ExpressionType::GREATER_THAN: + case ExpressionType::GREATER_THAN_EQUALS: + case ExpressionType::LESS_THAN: + case ExpressionType::LESS_THAN_EQUALS: + case ExpressionType::IS_NULL: + case ExpressionType::IS_NOT_NULL: + case ExpressionType::FUNCTION: + return true; + default: + return false; + } +} + +std::unique_ptr ExpressionMapper::getEvaluator( + const std::shared_ptr& expression, const Schema* schema) { + if (schema == nullptr) { + return getConstantEvaluator(expression); + } auto expressionType = expression->expressionType; - if (schema.isExpressionInScope(*expression)) { - return mapReferenceExpression(expression, schema); + if (schema->isExpressionInScope(*expression)) { + return getReferenceEvaluator(expression, schema); } else if (isExpressionLiteral(expressionType)) { - return mapLiteralExpression(expression); + return getLiteralEvaluator(*expression); } else if (ExpressionUtil::isNodeVariable(*expression)) { - return mapNodeExpression(expression, schema); + return getNodeEvaluator(expression, schema); } else if (ExpressionUtil::isRelVariable(*expression)) { - return mapRelExpression(expression, schema); + return getRelEvaluator(expression, schema); } else if (expressionType == ExpressionType::PATH) { - return mapPathExpression(expression, schema); + return getPathEvaluator(expression, schema); } else if (expressionType == ExpressionType::PARAMETER) { - return mapParameterExpression(expression); + return getParameterEvaluator(*expression); } else if (CASE_ELSE == expressionType) { - return mapCaseExpression(expression, schema); + return getCaseEvaluator(expression, schema); + } else if (canEvaluateAsFunction(expressionType)) { + return getFunctionEvaluator(expression, schema); } else { - return mapFunctionExpression(expression, schema); + throw NotImplementedException(StringUtils::string_format( + "Cannot evaluate expression with type {}.", expressionTypeToString(expressionType))); } } -std::unique_ptr ExpressionMapper::mapLiteralExpression( +std::unique_ptr ExpressionMapper::getConstantEvaluator( const std::shared_ptr& expression) { - auto& literalExpression = (LiteralExpression&)*expression; + assert(ExpressionVisitor::isConstant(*expression)); + auto expressionType = expression->expressionType; + if (isExpressionLiteral(expressionType)) { + return getLiteralEvaluator(*expression); + } else if (CASE_ELSE == expressionType) { + return getCaseEvaluator(expression, nullptr); + } else if (canEvaluateAsFunction(expressionType)) { + return getFunctionEvaluator(expression, nullptr); + } else { + throw NotImplementedException(StringUtils::string_format( + "Cannot evaluate expression with type {}.", expressionTypeToString(expressionType))); + } +} + +std::unique_ptr ExpressionMapper::getLiteralEvaluator( + const Expression& expression) { + auto& literalExpression = (LiteralExpression&)expression; return std::make_unique( std::make_shared(*literalExpression.getValue())); } -std::unique_ptr ExpressionMapper::mapParameterExpression( - const std::shared_ptr& expression) { - auto& parameterExpression = (ParameterExpression&)*expression; +std::unique_ptr ExpressionMapper::getParameterEvaluator( + const Expression& expression) { + auto& parameterExpression = (ParameterExpression&)expression; assert(parameterExpression.getLiteral() != nullptr); return std::make_unique(parameterExpression.getLiteral()); } -std::unique_ptr ExpressionMapper::mapReferenceExpression( - const std::shared_ptr& expression, const Schema& schema) { - auto vectorPos = DataPos(schema.getExpressionPos(*expression)); - auto expressionGroup = schema.getGroup(expression->getUniqueName()); +std::unique_ptr ExpressionMapper::getReferenceEvaluator( + std::shared_ptr expression, const Schema* schema) { + assert(schema != nullptr); + auto vectorPos = DataPos(schema->getExpressionPos(*expression)); + auto expressionGroup = schema->getGroup(expression->getUniqueName()); return std::make_unique(vectorPos, expressionGroup->isFlat()); } -std::unique_ptr ExpressionMapper::mapCaseExpression( - const std::shared_ptr& expression, const Schema& schema) { - auto& caseExpression = (CaseExpression&)*expression; +std::unique_ptr ExpressionMapper::getCaseEvaluator( + std::shared_ptr expression, const Schema* schema) { + auto caseExpression = reinterpret_cast(expression.get()); std::vector> alternativeEvaluators; - for (auto i = 0u; i < caseExpression.getNumCaseAlternatives(); ++i) { - auto alternative = caseExpression.getCaseAlternative(i); - auto whenEvaluator = mapExpression(alternative->whenExpression, schema); - auto thenEvaluator = mapExpression(alternative->thenExpression, schema); + for (auto i = 0u; i < caseExpression->getNumCaseAlternatives(); ++i) { + auto alternative = caseExpression->getCaseAlternative(i); + auto whenEvaluator = getEvaluator(alternative->whenExpression, schema); + auto thenEvaluator = getEvaluator(alternative->thenExpression, schema); alternativeEvaluators.push_back(std::make_unique( std::move(whenEvaluator), std::move(thenEvaluator))); } - auto elseEvaluator = mapExpression(caseExpression.getElseExpression(), schema); + auto elseEvaluator = getEvaluator(caseExpression->getElseExpression(), schema); return std::make_unique( - expression, std::move(alternativeEvaluators), std::move(elseEvaluator)); + std::move(expression), std::move(alternativeEvaluators), std::move(elseEvaluator)); } -std::unique_ptr ExpressionMapper::mapFunctionExpression( - const std::shared_ptr& expression, const Schema& schema) { - std::vector> children; - for (auto i = 0u; i < expression->getNumChildren(); ++i) { - children.push_back(mapExpression(expression->getChild(i), schema)); - } - return std::make_unique(expression, std::move(children)); +std::unique_ptr ExpressionMapper::getFunctionEvaluator( + std::shared_ptr expression, const Schema* schema) { + auto childrenEvaluators = getEvaluators(expression->getChildren(), schema); + return std::make_unique( + std::move(expression), std::move(childrenEvaluators)); } -std::unique_ptr ExpressionMapper::mapNodeExpression( - const std::shared_ptr& expression, const planner::Schema& schema) { +std::unique_ptr ExpressionMapper::getNodeEvaluator( + std::shared_ptr expression, const Schema* schema) { auto node = (NodeExpression*)expression.get(); - std::vector> children; - children.push_back(mapExpression(node->getInternalIDProperty(), schema)); - children.push_back(mapExpression(node->getLabelExpression(), schema)); + expression_vector children; + children.push_back(node->getInternalIDProperty()); + children.push_back(node->getLabelExpression()); for (auto& property : node->getPropertyExpressions()) { - children.push_back(mapExpression(property->copy(), schema)); + children.push_back(property->copy()); } - return std::make_unique(expression, std::move(children)); + auto childrenEvaluators = getEvaluators(children, schema); + return std::make_unique( + std::move(expression), std::move(childrenEvaluators)); } -std::unique_ptr ExpressionMapper::mapRelExpression( - const std::shared_ptr& expression, const planner::Schema& schema) { +std::unique_ptr ExpressionMapper::getRelEvaluator( + std::shared_ptr expression, const Schema* schema) { auto rel = (RelExpression*)expression.get(); - std::vector> children; - children.push_back(mapExpression(rel->getSrcNode()->getInternalIDProperty(), schema)); - children.push_back(mapExpression(rel->getDstNode()->getInternalIDProperty(), schema)); - children.push_back(mapExpression(rel->getLabelExpression(), schema)); + expression_vector children; + children.push_back(rel->getSrcNode()->getInternalIDProperty()); + children.push_back(rel->getDstNode()->getInternalIDProperty()); + children.push_back(rel->getLabelExpression()); for (auto& property : rel->getPropertyExpressions()) { - children.push_back(mapExpression(property->copy(), schema)); + children.push_back(property->copy()); } - return std::make_unique(expression, std::move(children)); + auto childrenEvaluators = getEvaluators(children, schema); + return std::make_unique( + std::move(expression), std::move(childrenEvaluators)); +} + +std::unique_ptr ExpressionMapper::getPathEvaluator( + std::shared_ptr expression, const Schema* schema) { + auto childrenEvaluators = getEvaluators(expression->getChildren(), schema); + return std::make_unique( + std::move(expression), std::move(childrenEvaluators)); } -std::unique_ptr ExpressionMapper::mapPathExpression( - const std::shared_ptr& expression, const planner::Schema& schema) { - auto pathExpression = std::static_pointer_cast(expression); - std::vector> children; - children.reserve(pathExpression->getNumChildren()); - for (auto i = 0u; i < pathExpression->getNumChildren(); ++i) { - children.push_back(mapExpression(pathExpression->getChild(i), schema)); +std::vector> ExpressionMapper::getEvaluators( + const binder::expression_vector& expressions, const Schema* schema) { + std::vector> evaluators; + evaluators.reserve(expressions.size()); + for (auto& expression : expressions) { + evaluators.push_back(getEvaluator(expression, schema)); } - return std::make_unique(pathExpression, std::move(children)); + return evaluators; } } // namespace processor diff --git a/src/processor/map/map_create.cpp b/src/processor/map/map_create.cpp index d8216ee69c..ec175db8ff 100644 --- a/src/processor/map/map_create.cpp +++ b/src/processor/map/map_create.cpp @@ -44,7 +44,7 @@ std::unique_ptr PlanMapper::getNodeInsertExecutor( for (auto i = 0u; i < info->setItems.size(); ++i) { auto& [lhs, rhs] = info->setItems[i]; auto propertyExpression = (binder::PropertyExpression*)lhs.get(); - evaluators.push_back(expressionMapper.mapExpression(rhs, inSchema)); + evaluators.push_back(ExpressionMapper::getEvaluator(rhs, &inSchema)); propertyIDToVectorIdx.insert({propertyExpression->getPropertyID(nodeTableID), i}); } return std::make_unique(table, std::move(relTablesToInit), nodeIDPos, @@ -78,7 +78,7 @@ std::unique_ptr PlanMapper::getRelInsertExecutor(storage::Rel auto lhsVectorPositions = populateLhsVectorPositions(info->setItems, outSchema); std::vector> evaluators; for (auto& [lhs, rhs] : info->setItems) { - evaluators.push_back(expressionMapper.mapExpression(rhs, inSchema)); + evaluators.push_back(ExpressionMapper::getEvaluator(rhs, &inSchema)); } return std::make_unique(relsStore->getRelsStatistics(), table, srcNodePos, dstNodePos, std::move(lhsVectorPositions), std::move(evaluators)); diff --git a/src/processor/map/map_ddl.cpp b/src/processor/map/map_ddl.cpp index 512e192a1f..0279dd7ab4 100644 --- a/src/processor/map/map_ddl.cpp +++ b/src/processor/map/map_ddl.cpp @@ -67,7 +67,7 @@ std::unique_ptr PlanMapper::mapRenameTable(LogicalOperator* lo std::unique_ptr PlanMapper::mapAddProperty(LogicalOperator* logicalOperator) { auto addProperty = (LogicalAddProperty*)logicalOperator; auto expressionEvaluator = - expressionMapper.mapExpression(addProperty->getDefaultValue(), *addProperty->getSchema()); + ExpressionMapper::getEvaluator(addProperty->getDefaultValue(), addProperty->getSchema()); auto tableSchema = catalog->getReadOnlyVersion()->getTableSchema(addProperty->getTableID()); switch (tableSchema->getTableType()) { case catalog::TableType::NODE: diff --git a/src/processor/map/map_dummy_scan.cpp b/src/processor/map/map_dummy_scan.cpp index 6295200438..2f491b8067 100644 --- a/src/processor/map/map_dummy_scan.cpp +++ b/src/processor/map/map_dummy_scan.cpp @@ -19,7 +19,7 @@ std::unique_ptr PlanMapper::mapDummyScan(LogicalOperator* logi tableSchema->appendColumn( std::make_unique(false, 0 /* all expressions are in the same datachunk */, LogicalTypeUtils::getRowLayoutSize(expression->dataType))); - auto expressionEvaluator = expressionMapper.mapExpression(expression, *inSchema); + auto expressionEvaluator = ExpressionMapper::getEvaluator(expression, inSchema.get()); // expression can be evaluated statically and does not require an actual resultset to init expressionEvaluator->init(ResultSet(0) /* dummy resultset */, memoryManager); expressionEvaluator->evaluate(); diff --git a/src/processor/map/map_filter.cpp b/src/processor/map/map_filter.cpp index 39250e3268..d43c855bb9 100644 --- a/src/processor/map/map_filter.cpp +++ b/src/processor/map/map_filter.cpp @@ -11,7 +11,7 @@ std::unique_ptr PlanMapper::mapFilter(LogicalOperator* logical auto& logicalFilter = (const LogicalFilter&)*logicalOperator; auto inSchema = logicalFilter.getChild(0)->getSchema(); auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); - auto physicalRootExpr = expressionMapper.mapExpression(logicalFilter.getPredicate(), *inSchema); + auto physicalRootExpr = ExpressionMapper::getEvaluator(logicalFilter.getPredicate(), inSchema); return make_unique(std::move(physicalRootExpr), logicalFilter.getGroupPosToSelect(), std::move(prevOperator), getOperatorID(), logicalFilter.getExpressionsForPrinting()); } diff --git a/src/processor/map/map_projection.cpp b/src/processor/map/map_projection.cpp index 7f7f565632..16c4dbd989 100644 --- a/src/processor/map/map_projection.cpp +++ b/src/processor/map/map_projection.cpp @@ -15,7 +15,7 @@ std::unique_ptr PlanMapper::mapProjection(LogicalOperator* log std::vector> expressionEvaluators; std::vector expressionsOutputPos; for (auto& expression : logicalProjection.getExpressionsToProject()) { - expressionEvaluators.push_back(expressionMapper.mapExpression(expression, *inSchema)); + expressionEvaluators.push_back(ExpressionMapper::getEvaluator(expression, inSchema)); expressionsOutputPos.emplace_back(outSchema->getExpressionPos(*expression)); } return make_unique(std::move(expressionEvaluators), std::move(expressionsOutputPos), diff --git a/src/processor/map/map_scan_node.cpp b/src/processor/map/map_scan_node.cpp index 7e09cd85f3..90ab0a1370 100644 --- a/src/processor/map/map_scan_node.cpp +++ b/src/processor/map/map_scan_node.cpp @@ -31,7 +31,7 @@ std::unique_ptr PlanMapper::mapIndexScanNode(LogicalOperator* auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); auto nodeTable = storageManager.getNodesStore().getNodeTable(node->getSingleTableID()); auto evaluator = - expressionMapper.mapExpression(logicalIndexScan->getIndexExpression(), *inSchema); + ExpressionMapper::getEvaluator(logicalIndexScan->getIndexExpression(), inSchema); auto outDataPos = DataPos(outSchema->getExpressionPos(*node->getInternalIDProperty())); return make_unique(nodeTable->getTableID(), nodeTable->getPKIndex(), std::move(evaluator), outDataPos, std::move(prevOperator), getOperatorID(), diff --git a/src/processor/map/map_set.cpp b/src/processor/map/map_set.cpp index e28381abd0..2a7501c4ac 100644 --- a/src/processor/map/map_set.cpp +++ b/src/processor/map/map_set.cpp @@ -20,7 +20,7 @@ std::unique_ptr PlanMapper::getNodeSetExecutor(storage::NodesSt if (inSchema.isExpressionInScope(*property)) { propertyPos = DataPos(inSchema.getExpressionPos(*property)); } - auto evaluator = expressionMapper.mapExpression(info->setItem.second, inSchema); + auto evaluator = ExpressionMapper::getEvaluator(info->setItem.second, &inSchema); if (node->isMultiLabeled()) { std::unordered_map tableIDToSetInfo; for (auto tableID : node->getTableIDs()) { @@ -68,7 +68,7 @@ std::unique_ptr PlanMapper::getRelSetExecutor(storage::RelsStore if (inSchema.isExpressionInScope(*property)) { propertyPos = DataPos(inSchema.getExpressionPos(*property)); } - auto evaluator = expressionMapper.mapExpression(info->setItem.second, inSchema); + auto evaluator = ExpressionMapper::getEvaluator(info->setItem.second, &inSchema); if (rel->isMultiLabeled()) { std::unordered_map> tableIDToTableAndPropertyID; diff --git a/src/processor/map/map_unwind.cpp b/src/processor/map/map_unwind.cpp index 06e2b6888c..bf7e534689 100644 --- a/src/processor/map/map_unwind.cpp +++ b/src/processor/map/map_unwind.cpp @@ -15,7 +15,7 @@ std::unique_ptr PlanMapper::mapUnwind(LogicalOperator* logical auto inSchema = unwind->getChild(0)->getSchema(); auto prevOperator = mapOperator(logicalOperator->getChild(0).get()); auto dataPos = DataPos(outSchema->getExpressionPos(*unwind->getAliasExpression())); - auto expressionEvaluator = expressionMapper.mapExpression(unwind->getExpression(), *inSchema); + auto expressionEvaluator = ExpressionMapper::getEvaluator(unwind->getExpression(), inSchema); return std::make_unique(*VarListType::getChildType(&unwind->getExpression()->dataType), dataPos, std::move(expressionEvaluator), std::move(prevOperator), getOperatorID(), unwind->getExpressionsForPrinting()); diff --git a/test/graph_test/graph_test.cpp b/test/graph_test/graph_test.cpp index fd80863a59..81ea1b548f 100644 --- a/test/graph_test/graph_test.cpp +++ b/test/graph_test/graph_test.cpp @@ -72,7 +72,8 @@ void BaseGraphTest::validateQueryBestPlanJoinOrder( auto catalog = getCatalog(*database); auto statement = parser::Parser::parseQuery(query); auto parsedQuery = (parser::RegularQuery*)statement.get(); - auto boundQuery = Binder(*catalog, conn->clientContext.get()).bind(*parsedQuery); + auto boundQuery = Binder(*catalog, database->memoryManager.get(), conn->clientContext.get()) + .bind(*parsedQuery); auto plan = Planner::getBestPlan(*catalog, getStorageManager(*database)->getNodesStore().getNodesStatisticsAndDeletedIDs(), getStorageManager(*database)->getRelsStore().getRelsStatistics(), *boundQuery); diff --git a/test/test_files/exceptions/binder/binder_error.test b/test/test_files/exceptions/binder/binder_error.test index a5e0a067e9..b508080321 100644 --- a/test/test_files/exceptions/binder/binder_error.test +++ b/test/test_files/exceptions/binder/binder_error.test @@ -38,12 +38,12 @@ Binder exception: Expression in WITH must be aliased (use AS). -LOG BindToDifferentVariableType1 -STATEMENT MATCH (a:person)-[e1:knows]->(b:person) WITH e1 AS a MATCH (a) RETURN * ---- error -Binder exception: e1 has data type REL. (NODE) was expected. +Binder exception: a has data type REL. (NODE) was expected. -LOG BindToDifferentVariableType2 -STATEMENT MATCH (a:person)-[e1:knows]->(b:person) WITH a.age + 1 AS a MATCH (a) RETURN * ---- error -Binder exception: +(a.age,1) has data type INT64. (NODE) was expected. +Binder exception: a has data type INT64. (NODE) was expected. -LOG BindEmptyStar -STATEMENT RETURN * diff --git a/test/test_files/tck/match/match1.test b/test/test_files/tck/match/match1.test index 3358f34e8c..3cceb7057f 100644 --- a/test/test_files/tck/match/match1.test +++ b/test/test_files/tck/match/match1.test @@ -327,22 +327,22 @@ Binder exception: r has data type RECURSIVE_REL. (NODE) was expected. ---- ok -STATEMENT WITH true AS n MATCH (n) RETURN n; ---- error -Binder exception: True has data type BOOL. (NODE) was expected. +Binder exception: n has data type BOOL. (NODE) was expected. -STATEMENT WITH 123 AS n MATCH (n) RETURN n; ---- error -Binder exception: 123 has data type INT64. (NODE) was expected. +Binder exception: n has data type INT64. (NODE) was expected. -STATEMENT WITH 123.4 AS n MATCH (n) RETURN n; ---- error -Binder exception: 123.400000 has data type DOUBLE. (NODE) was expected. +Binder exception: n has data type DOUBLE. (NODE) was expected. -STATEMENT WITH 'foo' AS n MATCH (n) RETURN n; ---- error -Binder exception: foo has data type STRING. (NODE) was expected. +Binder exception: n has data type STRING. (NODE) was expected. -STATEMENT WITH [10] AS n MATCH (n) RETURN n; ---- error -Binder exception: LIST_CREATION(10) has data type VAR_LIST. (NODE) was expected. +Binder exception: n has data type VAR_LIST. (NODE) was expected. -STATEMENT WITH {x: 1} AS n MATCH (n) RETURN n; ---- error -Binder exception: STRUCT_PACK(1) has data type STRUCT. (NODE) was expected. +Binder exception: n has data type STRUCT. (NODE) was expected. -STATEMENT WITH {x: [1]} AS n MATCH (n) RETURN n; ---- error -Binder exception: STRUCT_PACK(LIST_CREATION(1)) has data type STRUCT. (NODE) was expected. +Binder exception: n has data type STRUCT. (NODE) was expected. diff --git a/test/test_files/tck/match/match2.test b/test/test_files/tck/match/match2.test index b2a75450a6..4c23ec9cd4 100644 --- a/test/test_files/tck/match/match2.test +++ b/test/test_files/tck/match/match2.test @@ -369,22 +369,22 @@ Binder exception: Bind relationship r to relationship with same name is not supp ---- ok -STATEMENT WITH true AS r MATCH ()-[r]-() RETURN r; ---- error -Binder exception: True has data type BOOL. (REL) was expected. +Binder exception: r has data type BOOL. (REL) was expected. -STATEMENT WITH 123 AS r MATCH ()-[r]-() RETURN r; ---- error -Binder exception: 123 has data type INT64. (REL) was expected. +Binder exception: r has data type INT64. (REL) was expected. -STATEMENT WITH 123.4 AS r MATCH ()-[r]-() RETURN r; ---- error -Binder exception: 123.400000 has data type DOUBLE. (REL) was expected. +Binder exception: r has data type DOUBLE. (REL) was expected. -STATEMENT WITH 'foo' AS r MATCH ()-[r]-() RETURN r; ---- error -Binder exception: foo has data type STRING. (REL) was expected. +Binder exception: r has data type STRING. (REL) was expected. -STATEMENT WITH [10] AS r MATCH ()-[r]-() RETURN r; ---- error -Binder exception: LIST_CREATION(10) has data type VAR_LIST. (REL) was expected. +Binder exception: r has data type VAR_LIST. (REL) was expected. -STATEMENT WITH {x: 1} AS r MATCH ()-[r]-() RETURN r; ---- error -Binder exception: STRUCT_PACK(1) has data type STRUCT. (REL) was expected. +Binder exception: r has data type STRUCT. (REL) was expected. -STATEMENT WITH {x: [1]} AS r MATCH ()-[r]-() RETURN r; ---- error -Binder exception: STRUCT_PACK(LIST_CREATION(1)) has data type STRUCT. (REL) was expected. +Binder exception: r has data type STRUCT. (REL) was expected. diff --git a/test/test_files/tinysnb/function/union.test b/test/test_files/tinysnb/function/union.test index c1bdff3b10..2376f376c5 100644 --- a/test/test_files/tinysnb/function/union.test +++ b/test/test_files/tinysnb/function/union.test @@ -11,7 +11,7 @@ 36 -LOG UnionValueOnExpr --STATEMENT MATCH (p:person)-[:knows]->(p1:person) return union_value(age := p.age), union_value(age := p1.age) +-STATEMENT MATCH (p:person)-[:knows]->(p1:person) return union_value(age := p.age) AS u1, union_value(age := p1.age) AS u2 ---- 14 35|30 35|45 @@ -74,7 +74,7 @@ age|PERSON_id 100 -LOG UnionExtractOnFlatUnflatExpr --STATEMENT MATCH (p:person)-[:knows]->(p1:person) return union_extract(union_value(age := p.age), 'age'), union_extract(union_value(age := p1.age), 'age') +-STATEMENT MATCH (p:person)-[:knows]->(p1:person) return union_extract(union_value(age := p.age), 'age') AS a, union_extract(union_value(age := p1.age), 'age') AS b ---- 14 35|30 35|45 diff --git a/test/test_files/tinysnb/projection/skip_limit.test b/test/test_files/tinysnb/projection/skip_limit.test index 453dde56bf..7102700791 100644 --- a/test/test_files/tinysnb/projection/skip_limit.test +++ b/test/test_files/tinysnb/projection/skip_limit.test @@ -34,6 +34,13 @@ Farooq Greg Hubert Blaine Wolfeschlegelsteinhausenbergerdorff +-LOG BasicLimitTest3 +-STATEMENT MATCH (a:person) RETURN a.fName ORDER BY a.fName LIMIT 1 + 1 +-CHECK_ORDER +---- 2 +Alice +Bob + -LOG BasicSkipLimitTest -STATEMENT MATCH (a:person) RETURN a.fName SKIP 1 LIMIT 2 ---- 2