Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expression Constant Folding #1989

Merged
merged 1 commit into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,30 @@

uint64_t Binder::bindSkipLimitExpression(const ParsedExpression& expression) {
auto boundExpression = expressionBinder.bindExpression(expression);
// We currently do not support the number of rows to skip/limit written as an expression (eg.
// SKIP 3 + 2 is not supported).
if (expression.getExpressionType() != LITERAL ||
((LiteralExpression&)(*boundExpression)).getDataType().getLogicalTypeID() !=
LogicalTypeID::INT64) {
throw BinderException("The number of rows to skip/limit must be a non-negative integer.");
auto errorMsg = "The number of rows to skip/limit must be a non-negative integer.";
if (!ExpressionVisitor::isConstant(*boundExpression)) {
throw BinderException(errorMsg);

Check warning on line 232 in src/binder/bind/bind_projection_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_projection_clause.cpp#L232

Added line #L232 was not covered by tests
}
return ((LiteralExpression&)(*boundExpression)).value->getValue<int64_t>();
auto value = ((LiteralExpression&)(*boundExpression)).value.get();
int64_t num = 0;
// TODO: replace the following switch with value.cast()
switch (value->getDataType()->getLogicalTypeID()) {
case LogicalTypeID::INT64: {
num = value->getValue<int64_t>();
} break;
case LogicalTypeID::INT32: {
num = value->getValue<int32_t>();
} break;

Check warning on line 243 in src/binder/bind/bind_projection_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_projection_clause.cpp#L242-L243

Added lines #L242 - L243 were not covered by tests
case LogicalTypeID::INT16: {
num = value->getValue<int16_t>();
} break;

Check warning on line 246 in src/binder/bind/bind_projection_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_projection_clause.cpp#L245-L246

Added lines #L245 - L246 were not covered by tests
default:
throw BinderException(errorMsg);
}
if (num < 0) {
throw BinderException(errorMsg);
}
return num;
}

void Binder::addExpressionsToScope(const expression_vector& projectionExpressions) {
Expand Down
19 changes: 0 additions & 19 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchVectorFunction(functionName, childrenTypes);
if (builtInFunctions->canApplyStaticEvaluation(functionName, children)) {
return staticEvaluate(functionName, children);
}
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
Expand Down Expand Up @@ -138,22 +135,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindMacroExpression(
return bindExpression(*macroParameterReplacer->visit(std::move(macroExpr)));
}

std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(
const std::string& functionName, const expression_vector& children) {
assert(children[0]->expressionType == LITERAL);
auto strVal = ((LiteralExpression*)children[0].get())->getValue()->getValue<std::string>();
std::unique_ptr<Value> value;
if (functionName == CAST_TO_DATE_FUNC_NAME) {
value = std::make_unique<Value>(Date::fromCString(strVal.c_str(), strVal.length()));
} else if (functionName == CAST_TO_TIMESTAMP_FUNC_NAME) {
value = std::make_unique<Value>(Timestamp::fromCString(strVal.c_str(), strVal.length()));
} else {
assert(functionName == CAST_TO_INTERVAL_FUNC_NAME);
value = std::make_unique<Value>(Interval::fromCString(strVal.c_str(), strVal.length()));
}
return createLiteralExpression(std::move(value));
}

// Function rewriting happens when we need to expose internal property access through function so
// that it becomes read-only or the function involves catalog information. Currently we write
// Before | After
Expand Down
1 change: 0 additions & 1 deletion src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindPropertyExpression(
} else if (ExpressionUtil::isRelVariable(*child)) {
return bindRelPropertyExpression(*child, propertyName);
} else {
assert(child->expressionType == FUNCTION);
return bindStructPropertyExpression(child, propertyName);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/binder/expression/case_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace kuzu {
namespace binder {

std::string CaseExpression::toString() const {
std::string CaseExpression::toStringInternal() const {
std::string result = "CASE ";
for (auto& caseAlternative : caseAlternatives) {
result += "WHEN " + caseAlternative->whenExpression->toString() + " THEN " +
Expand Down
4 changes: 2 additions & 2 deletions src/binder/expression/function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ std::string ScalarFunctionExpression::getUniqueName(
return result;
}

std::string ScalarFunctionExpression::toString() const {
std::string ScalarFunctionExpression::toStringInternal() const {
auto result = functionName + "(";
result += ExpressionUtil::toString(children);
result += ")";
Expand All @@ -35,7 +35,7 @@ std::string AggregateFunctionExpression::getUniqueName(
return result;
}

std::string AggregateFunctionExpression::toString() const {
std::string AggregateFunctionExpression::toStringInternal() const {
auto result = functionName + "(";
if (isDistinct()) {
result += "DISTINCT ";
Expand Down
23 changes: 22 additions & 1 deletion src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "binder/expression/literal_expression.h"
#include "binder/expression/parameter_expression.h"
#include "binder/expression_visitor.h"
#include "expression_evaluator/expression_evaluator_utils.h"
#include "function/cast/vector_cast_functions.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -47,9 +48,29 @@ std::shared_ptr<Expression> ExpressionBinder::bindExpression(
if (isExpressionAggregate(expression->expressionType)) {
validateAggregationExpressionIsNotNested(*expression);
}
if (ExpressionVisitor::needFold(*expression)) {
return foldExpression(expression);
}
return expression;
}

std::shared_ptr<Expression> ExpressionBinder::foldExpression(
std::shared_ptr<Expression> expression) {
auto value = evaluator::ExpressionEvaluatorUtils::evaluateConstantExpression(
expression, binder->memoryManager);
auto result = createLiteralExpression(std::move(value));
// Fold result should preserve the alias original expression. E.g.
// RETURN 2, 1 + 1 AS x
// Once folded, 1 + 1 will become 2 and have the same identifier as the first RETURN element.
// We preserve alias (x) to avoid such conflict.
if (expression->hasAlias()) {
result->setAlias(expression->getAlias());
} else {
result->setAlias(expression->toString());
}
return result;
}

std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) {
Expand Down Expand Up @@ -127,7 +148,7 @@ void ExpressionBinder::validateAggregationExpressionIsNotNested(const Expression
if (expression.getNumChildren() == 0) {
return;
}
if (ExpressionVisitor::hasAggregateExpression(*expression.getChild(0))) {
if (ExpressionVisitor::hasAggregate(*expression.getChild(0))) {
throw BinderException(
"Expression " + expression.toString() + " contains nested aggregation.");
}
Expand Down
20 changes: 18 additions & 2 deletions src/binder/expression_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,29 @@ expression_vector ExpressionChildrenCollector::collectRelChildren(const Expressi
return result;
}

bool ExpressionVisitor::hasExpression(
// isConstant requires all children to be constant.
bool ExpressionVisitor::isConstant(const Expression& expression) {
if (expression.expressionType == ExpressionType::AGGREGATE_FUNCTION) {
return false; // We don't have a framework to fold aggregated constant.
}
if (expression.getNumChildren() == 0) {
return expression.expressionType == ExpressionType::LITERAL;
}
for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) {
if (!isConstant(*child)) {
return false;
}
}
return true;
}

bool ExpressionVisitor::satisfyAny(
const Expression& expression, const std::function<bool(const Expression&)>& condition) {
if (condition(expression)) {
return true;
}
for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) {
if (hasExpression(*child, condition)) {
if (satisfyAny(*child, condition)) {
return true;
}
}
Expand Down
62 changes: 62 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@
}

void ValueVector::copyFromValue(uint64_t pos, const Value& value) {
if (value.isNull()) {
setNull(pos, true);
return;
}
setNull(pos, false);
auto dstValue = valueBuffer.get() + pos * numBytesPerValue;
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::INT64: {
Expand Down Expand Up @@ -171,6 +176,63 @@
}
}

std::unique_ptr<Value> ValueVector::getAsValue(uint64_t pos) {
if (isNull(pos)) {
return Value::createNullValue(dataType).copy();
}
auto value = Value::createDefaultValue(dataType).copy();
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::INT64: {
value->val.int64Val = getValue<int64_t>(pos);
} break;
case PhysicalTypeID::INT32: {
value->val.int32Val = getValue<int32_t>(pos);
} break;
case PhysicalTypeID::INT16: {
value->val.int16Val = getValue<int16_t>(pos);
} break;
case PhysicalTypeID::DOUBLE: {
value->val.doubleVal = getValue<double_t>(pos);
} break;
case PhysicalTypeID::FLOAT: {
value->val.floatVal = getValue<float_t>(pos);
} break;
case PhysicalTypeID::BOOL: {
value->val.booleanVal = getValue<bool>(pos);
} break;
case PhysicalTypeID::INTERVAL: {
value->val.intervalVal = getValue<interval_t>(pos);
} break;
case PhysicalTypeID::STRING: {
value->strVal = getValue<ku_string_t>(pos).getAsString();
} break;
case PhysicalTypeID::VAR_LIST: {
auto dataVector = ListVector::getDataVector(this);
auto listEntry = getValue<list_entry_t>(pos);
std::vector<std::unique_ptr<Value>> children;
children.reserve(listEntry.size);
for (auto i = 0u; i < listEntry.size; ++i) {
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
children.push_back(dataVector->getAsValue(listEntry.offset + i));
}
value->childrenSize = children.size();
value->children = std::move(children);
} break;
case PhysicalTypeID::STRUCT: {
auto& fieldVectors = StructVector::getFieldVectors(this);
std::vector<std::unique_ptr<Value>> children;
children.reserve(fieldVectors.size());
for (auto& fieldVector : fieldVectors) {
children.push_back(fieldVector->getAsValue(pos));
}
value->childrenSize = children.size();
value->children = std::move(children);
} break;
default:
throw NotImplementedException("ValueVector::getAsValue");

Check warning on line 231 in src/common/vector/value_vector.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/vector/value_vector.cpp#L230-L231

Added lines #L230 - L231 were not covered by tests
}
return value;
}

void ValueVector::resetAuxiliaryBuffer() {
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::STRING: {
Expand Down
3 changes: 2 additions & 1 deletion src/expression_evaluator/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
add_library(kuzu_expression_evaluator
OBJECT
base_evaluator.cpp
case_evaluator.cpp
expression_evaluator.cpp
expression_evaluator_utils.cpp
function_evaluator.cpp
literal_evaluator.cpp
node_rel_evaluator.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "expression_evaluator/base_evaluator.h"
#include "expression_evaluator/expression_evaluator.h"

using namespace kuzu::common;

Expand Down
23 changes: 23 additions & 0 deletions src/expression_evaluator/expression_evaluator_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "expression_evaluator/expression_evaluator_utils.h"

#include "processor/expression_mapper.h"

using namespace kuzu::common;
using namespace kuzu::processor;

namespace kuzu {
namespace evaluator {

std::unique_ptr<Value> ExpressionEvaluatorUtils::evaluateConstantExpression(
const std::shared_ptr<binder::Expression>& expression, storage::MemoryManager* memoryManager) {
auto evaluator = ExpressionMapper::getConstantEvaluator(expression);
auto emptyResultSet = std::make_unique<ResultSet>(0);
evaluator->init(*emptyResultSet, memoryManager);
evaluator->evaluate();
auto selVector = evaluator->resultVector->state->selVector.get();
assert(selVector->selectedSize == 1);
return evaluator->resultVector->getAsValue(selVector->selectedPositions[0]);
}

} // namespace evaluator
} // namespace kuzu
23 changes: 12 additions & 11 deletions src/expression_evaluator/path_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ void PathExpressionEvaluator::init(
for (auto& fieldVector : StructVector::getFieldVectors(resultRelsDataVector)) {
resultRelsFieldVectors.push_back(fieldVector.get());
}
for (auto i = 0u; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
auto pathExpression = (PathExpression*)expression.get();
for (auto i = 0u; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
auto vectors = std::make_unique<InputVectors>();
vectors->input = children[i]->resultVector.get();
switch (child->dataType.getLogicalTypeID()) {
Expand Down Expand Up @@ -99,8 +100,8 @@ static inline uint32_t getCurrentPos(ValueVector* vector, uint32_t pos) {
void PathExpressionEvaluator::copyNodes(sel_t resultPos) {
auto listSize = 0u;
// Calculate list size.
for (auto i = 0; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
for (auto i = 0; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
switch (child->dataType.getLogicalTypeID()) {
case LogicalTypeID::NODE: {
listSize++;
Expand All @@ -119,8 +120,8 @@ void PathExpressionEvaluator::copyNodes(sel_t resultPos) {
resultNodesVector->setValue(resultPos, entry);
// Copy field vectors
offset_t resultDataPos = entry.offset;
for (auto i = 0; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
for (auto i = 0; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
auto vectors = inputVectorsPerChild[i].get();
auto inputPos = getCurrentPos(vectors->input, resultPos);
switch (child->dataType.getLogicalTypeID()) {
Expand All @@ -144,8 +145,8 @@ void PathExpressionEvaluator::copyNodes(sel_t resultPos) {
void PathExpressionEvaluator::copyRels(sel_t resultPos) {
auto listSize = 0u;
// Calculate list size.
for (auto i = 0; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
for (auto i = 0; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
switch (child->dataType.getLogicalTypeID()) {
case LogicalTypeID::REL: {
listSize++;
Expand All @@ -164,8 +165,8 @@ void PathExpressionEvaluator::copyRels(sel_t resultPos) {
resultRelsVector->setValue(resultPos, entry);
// Copy field vectors
offset_t resultDataPos = entry.offset;
for (auto i = 0; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
for (auto i = 0; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
auto vectors = inputVectorsPerChild[i].get();
auto inputPos = getCurrentPos(vectors->input, resultPos);
switch (child->dataType.getLogicalTypeID()) {
Expand Down Expand Up @@ -206,7 +207,7 @@ void PathExpressionEvaluator::copyFieldVectors(offset_t inputVectorPos,

void PathExpressionEvaluator::resolveResultVector(
const processor::ResultSet& resultSet, storage::MemoryManager* memoryManager) {
resultVector = std::make_shared<ValueVector>(pathExpression->getDataType(), memoryManager);
resultVector = std::make_shared<ValueVector>(expression->getDataType(), memoryManager);
std::vector<ExpressionEvaluator*> inputEvaluators;
inputEvaluators.reserve(children.size());
for (auto& child : children) {
Expand Down
11 changes: 0 additions & 11 deletions src/function/built_in_vector_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,6 @@ void BuiltInVectorFunctions::registerVectorFunctions() {
registerBlobFunctions();
}

bool BuiltInVectorFunctions::canApplyStaticEvaluation(
const std::string& functionName, const binder::expression_vector& children) {
if ((functionName == CAST_TO_DATE_FUNC_NAME || functionName == CAST_TO_TIMESTAMP_FUNC_NAME ||
functionName == CAST_TO_INTERVAL_FUNC_NAME) &&
children[0]->expressionType == LITERAL &&
children[0]->dataType.getLogicalTypeID() == LogicalTypeID::STRING) {
return true; // bind as literal
}
return false;
}

VectorFunctionDefinition* BuiltInVectorFunctions::matchVectorFunction(
const std::string& name, const std::vector<LogicalType>& inputTypes) {
auto& functionDefinitions = vectorFunctions.at(name);
Expand Down
Loading