Skip to content

Commit

Permalink
Add constant folding
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Sep 4, 2023
1 parent 3c2086b commit 57fef9a
Show file tree
Hide file tree
Showing 62 changed files with 417 additions and 225 deletions.
30 changes: 23 additions & 7 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,30 @@ expression_vector Binder::bindOrderByExpressions(

uint64_t Binder::bindSkipLimitExpression(const ParsedExpression& expression) {
auto boundExpression = expressionBinder.bindExpression(expression);
// We currently do not support the number of rows to skip/limit written as an expression (eg.
// SKIP 3 + 2 is not supported).
if (expression.getExpressionType() != LITERAL ||
((LiteralExpression&)(*boundExpression)).getDataType().getLogicalTypeID() !=
LogicalTypeID::INT64) {
throw BinderException("The number of rows to skip/limit must be a non-negative integer.");
auto errorMsg = "The number of rows to skip/limit must be a non-negative integer.";
if (!ExpressionVisitor::isConstant(*boundExpression)) {
throw BinderException(errorMsg);

Check warning on line 232 in src/binder/bind/bind_projection_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_projection_clause.cpp#L232

Added line #L232 was not covered by tests
}
return ((LiteralExpression&)(*boundExpression)).value->getValue<int64_t>();
auto value = ((LiteralExpression&)(*boundExpression)).value.get();
int64_t num = 0;
// TODO: replace the following switch with value.cast()
switch (value->getDataType()->getLogicalTypeID()) {
case LogicalTypeID::INT64: {
num = value->getValue<int64_t>();
} break;
case LogicalTypeID::INT32: {
num = value->getValue<int32_t>();
} break;

Check warning on line 243 in src/binder/bind/bind_projection_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_projection_clause.cpp#L242-L243

Added lines #L242 - L243 were not covered by tests
case LogicalTypeID::INT16: {
num = value->getValue<int16_t>();
} break;

Check warning on line 246 in src/binder/bind/bind_projection_clause.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/bind/bind_projection_clause.cpp#L245-L246

Added lines #L245 - L246 were not covered by tests
default:
throw BinderException(errorMsg);
}
if (num < 0) {
throw BinderException(errorMsg);
}
return num;
}

void Binder::addExpressionsToScope(const expression_vector& projectionExpressions) {
Expand Down
19 changes: 0 additions & 19 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
childrenTypes.push_back(child->dataType);
}
auto function = builtInFunctions->matchVectorFunction(functionName, childrenTypes);
if (builtInFunctions->canApplyStaticEvaluation(functionName, children)) {
return staticEvaluate(functionName, children);
}
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
auto targetType =
Expand Down Expand Up @@ -138,22 +135,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindMacroExpression(
return bindExpression(*macroParameterReplacer->visit(std::move(macroExpr)));
}

std::shared_ptr<Expression> ExpressionBinder::staticEvaluate(
const std::string& functionName, const expression_vector& children) {
assert(children[0]->expressionType == LITERAL);
auto strVal = ((LiteralExpression*)children[0].get())->getValue()->getValue<std::string>();
std::unique_ptr<Value> value;
if (functionName == CAST_TO_DATE_FUNC_NAME) {
value = std::make_unique<Value>(Date::fromCString(strVal.c_str(), strVal.length()));
} else if (functionName == CAST_TO_TIMESTAMP_FUNC_NAME) {
value = std::make_unique<Value>(Timestamp::fromCString(strVal.c_str(), strVal.length()));
} else {
assert(functionName == CAST_TO_INTERVAL_FUNC_NAME);
value = std::make_unique<Value>(Interval::fromCString(strVal.c_str(), strVal.length()));
}
return createLiteralExpression(std::move(value));
}

// Function rewriting happens when we need to expose internal property access through function so
// that it becomes read-only or the function involves catalog information. Currently we write
// Before | After
Expand Down
1 change: 0 additions & 1 deletion src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindPropertyExpression(
} else if (ExpressionUtil::isRelVariable(*child)) {
return bindRelPropertyExpression(*child, propertyName);
} else {
assert(child->expressionType == FUNCTION);
return bindStructPropertyExpression(child, propertyName);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/binder/expression/case_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace kuzu {
namespace binder {

std::string CaseExpression::toString() const {
std::string CaseExpression::toStringInternal() const {
std::string result = "CASE ";
for (auto& caseAlternative : caseAlternatives) {
result += "WHEN " + caseAlternative->whenExpression->toString() + " THEN " +
Expand Down
4 changes: 2 additions & 2 deletions src/binder/expression/function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ std::string ScalarFunctionExpression::getUniqueName(
return result;
}

std::string ScalarFunctionExpression::toString() const {
std::string ScalarFunctionExpression::toStringInternal() const {
auto result = functionName + "(";
result += ExpressionUtil::toString(children);
result += ")";
Expand All @@ -35,7 +35,7 @@ std::string AggregateFunctionExpression::getUniqueName(
return result;
}

std::string AggregateFunctionExpression::toString() const {
std::string AggregateFunctionExpression::toStringInternal() const {
auto result = functionName + "(";
if (isDistinct()) {
result += "DISTINCT ";
Expand Down
23 changes: 22 additions & 1 deletion src/binder/expression_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "binder/expression/literal_expression.h"
#include "binder/expression/parameter_expression.h"
#include "binder/expression_visitor.h"
#include "expression_evaluator/expression_evaluator_utils.h"
#include "function/cast/vector_cast_functions.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -47,9 +48,29 @@ std::shared_ptr<Expression> ExpressionBinder::bindExpression(
if (isExpressionAggregate(expression->expressionType)) {
validateAggregationExpressionIsNotNested(*expression);
}
if (ExpressionVisitor::needFold(*expression)) {
return foldExpression(expression);
}
return expression;
}

// Replaces a constant expression with a literal expression holding its evaluated value.
//
// The literal always carries an alias: the original expression's alias when it has one,
// otherwise the original expression's string form. Without this, a folded expression could
// end up with the same identifier as another literal in the same clause, e.g. in
//     RETURN 2, 1 + 1 AS x
// where 1 + 1 folds to 2; keeping the alias (x) avoids that conflict.
std::shared_ptr<Expression> ExpressionBinder::foldExpression(
    std::shared_ptr<Expression> expression) {
    auto foldedValue = evaluator::ExpressionEvaluatorUtils::evaluateConstantExpression(
        expression, binder->memoryManager);
    auto literal = createLiteralExpression(std::move(foldedValue));
    literal->setAlias(expression->hasAlias() ? expression->getAlias() : expression->toString());
    return literal;
}

std::shared_ptr<Expression> ExpressionBinder::implicitCastIfNecessary(
const std::shared_ptr<Expression>& expression, const LogicalType& targetType) {
if (targetType.getLogicalTypeID() == LogicalTypeID::ANY || expression->dataType == targetType) {
Expand Down Expand Up @@ -127,7 +148,7 @@ void ExpressionBinder::validateAggregationExpressionIsNotNested(const Expression
if (expression.getNumChildren() == 0) {
return;
}
if (ExpressionVisitor::hasAggregateExpression(*expression.getChild(0))) {
if (ExpressionVisitor::hasAggregate(*expression.getChild(0))) {
throw BinderException(
"Expression " + expression.toString() + " contains nested aggregation.");
}
Expand Down
20 changes: 18 additions & 2 deletions src/binder/expression_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,29 @@ expression_vector ExpressionChildrenCollector::collectRelChildren(const Expressi
return result;
}

bool ExpressionVisitor::hasExpression(
// Returns true when the expression is a literal leaf, or a non-aggregate expression whose
// collected children are all (recursively) constant.
bool ExpressionVisitor::isConstant(const Expression& expression) {
    const auto type = expression.expressionType;
    // Aggregates are never reported as constant: there is no framework to fold them.
    if (type == ExpressionType::AGGREGATE_FUNCTION) {
        return false;
    }
    // A leaf is constant only if it is a literal.
    if (expression.getNumChildren() == 0) {
        return type == ExpressionType::LITERAL;
    }
    const auto children = ExpressionChildrenCollector::collectChildren(expression);
    for (const auto& child : children) {
        if (isConstant(*child)) {
            continue;
        }
        return false; // One non-constant child makes the whole expression non-constant.
    }
    return true;
}

bool ExpressionVisitor::satisfyAny(
const Expression& expression, const std::function<bool(const Expression&)>& condition) {
if (condition(expression)) {
return true;
}
for (auto& child : ExpressionChildrenCollector::collectChildren(expression)) {
if (hasExpression(*child, condition)) {
if (satisfyAny(*child, condition)) {
return true;
}
}
Expand Down
62 changes: 62 additions & 0 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ void ValueVector::copyFromVectorData(
}

void ValueVector::copyFromValue(uint64_t pos, const Value& value) {
if (value.isNull()) {
setNull(pos, true);
return;
}
setNull(pos, false);
auto dstValue = valueBuffer.get() + pos * numBytesPerValue;
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::INT64: {
Expand Down Expand Up @@ -171,6 +176,63 @@ void ValueVector::copyFromValue(uint64_t pos, const Value& value) {
}
}

std::unique_ptr<Value> ValueVector::getAsValue(uint64_t pos) {
if (isNull(pos)) {
return Value::createNullValue(dataType).copy();
}
auto value = Value::createDefaultValue(dataType).copy();
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::INT64: {
value->val.int64Val = getValue<int64_t>(pos);
} break;
case PhysicalTypeID::INT32: {
value->val.int32Val = getValue<int32_t>(pos);
} break;
case PhysicalTypeID::INT16: {
value->val.int16Val = getValue<int16_t>(pos);
} break;
case PhysicalTypeID::DOUBLE: {
value->val.doubleVal = getValue<double_t>(pos);
} break;
case PhysicalTypeID::FLOAT: {
value->val.floatVal = getValue<float_t>(pos);
} break;
case PhysicalTypeID::BOOL: {
value->val.booleanVal = getValue<bool>(pos);
} break;
case PhysicalTypeID::INTERVAL: {
value->val.intervalVal = getValue<interval_t>(pos);
} break;
case PhysicalTypeID::STRING: {
value->strVal = getValue<ku_string_t>(pos).getAsString();
} break;
case PhysicalTypeID::VAR_LIST: {
auto dataVector = ListVector::getDataVector(this);
auto listEntry = getValue<list_entry_t>(pos);
std::vector<std::unique_ptr<Value>> children;
children.reserve(listEntry.size);
for (auto i = 0u; i < listEntry.size; ++i) {
children.push_back(dataVector->getAsValue(listEntry.offset + i));
}
value->childrenSize = children.size();
value->children = std::move(children);
} break;
case PhysicalTypeID::STRUCT: {
auto& fieldVectors = StructVector::getFieldVectors(this);
std::vector<std::unique_ptr<Value>> children;
children.reserve(fieldVectors.size());
for (auto& fieldVector : fieldVectors) {
children.push_back(fieldVector->getAsValue(pos));
}
value->childrenSize = children.size();
value->children = std::move(children);
} break;
default:
throw NotImplementedException("ValueVector::getAsValue");

Check warning on line 231 in src/common/vector/value_vector.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/vector/value_vector.cpp#L230-L231

Added lines #L230 - L231 were not covered by tests
}
return value;
}

void ValueVector::resetAuxiliaryBuffer() {
switch (dataType.getPhysicalType()) {
case PhysicalTypeID::STRING: {
Expand Down
3 changes: 2 additions & 1 deletion src/expression_evaluator/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
add_library(kuzu_expression_evaluator
OBJECT
base_evaluator.cpp
case_evaluator.cpp
expression_evaluator.cpp
expression_evaluator_utils.cpp
function_evaluator.cpp
literal_evaluator.cpp
node_rel_evaluator.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "expression_evaluator/base_evaluator.h"
#include "expression_evaluator/expression_evaluator.h"

using namespace kuzu::common;

Expand Down
23 changes: 23 additions & 0 deletions src/expression_evaluator/expression_evaluator_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "expression_evaluator/expression_evaluator_utils.h"

#include "processor/expression_mapper.h"

using namespace kuzu::common;
using namespace kuzu::processor;

namespace kuzu {
namespace evaluator {

// Evaluates `expression` once and returns the single resulting entry as a Value.
//
// expression:    the expression to evaluate; an evaluator for it is obtained from
//                ExpressionMapper::getConstantEvaluator.
// memoryManager: passed through to the evaluator's init.
// The evaluator is initialised against a ResultSet constructed with 0, and its result vector is
// expected (assert only) to have exactly one selected position.
std::unique_ptr<Value> ExpressionEvaluatorUtils::evaluateConstantExpression(
    const std::shared_ptr<binder::Expression>& expression, storage::MemoryManager* memoryManager) {
    auto constantEvaluator = ExpressionMapper::getConstantEvaluator(expression);
    auto emptyResultSet = std::make_unique<ResultSet>(0);
    constantEvaluator->init(*emptyResultSet, memoryManager);
    constantEvaluator->evaluate();
    auto& resultVector = constantEvaluator->resultVector;
    auto& selVector = *resultVector->state->selVector;
    assert(selVector.selectedSize == 1);
    return resultVector->getAsValue(selVector.selectedPositions[0]);
}

} // namespace evaluator
} // namespace kuzu
23 changes: 12 additions & 11 deletions src/expression_evaluator/path_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ void PathExpressionEvaluator::init(
for (auto& fieldVector : StructVector::getFieldVectors(resultRelsDataVector)) {
resultRelsFieldVectors.push_back(fieldVector.get());
}
for (auto i = 0u; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
auto pathExpression = (PathExpression*)expression.get();
for (auto i = 0u; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
auto vectors = std::make_unique<InputVectors>();
vectors->input = children[i]->resultVector.get();
switch (child->dataType.getLogicalTypeID()) {
Expand Down Expand Up @@ -99,8 +100,8 @@ static inline uint32_t getCurrentPos(ValueVector* vector, uint32_t pos) {
void PathExpressionEvaluator::copyNodes(sel_t resultPos) {
auto listSize = 0u;
// Calculate list size.
for (auto i = 0; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
for (auto i = 0; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
switch (child->dataType.getLogicalTypeID()) {
case LogicalTypeID::NODE: {
listSize++;
Expand All @@ -119,8 +120,8 @@ void PathExpressionEvaluator::copyNodes(sel_t resultPos) {
resultNodesVector->setValue(resultPos, entry);
// Copy field vectors
offset_t resultDataPos = entry.offset;
for (auto i = 0; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
for (auto i = 0; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
auto vectors = inputVectorsPerChild[i].get();
auto inputPos = getCurrentPos(vectors->input, resultPos);
switch (child->dataType.getLogicalTypeID()) {
Expand All @@ -144,8 +145,8 @@ void PathExpressionEvaluator::copyNodes(sel_t resultPos) {
void PathExpressionEvaluator::copyRels(sel_t resultPos) {
auto listSize = 0u;
// Calculate list size.
for (auto i = 0; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
for (auto i = 0; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
switch (child->dataType.getLogicalTypeID()) {
case LogicalTypeID::REL: {
listSize++;
Expand All @@ -164,8 +165,8 @@ void PathExpressionEvaluator::copyRels(sel_t resultPos) {
resultRelsVector->setValue(resultPos, entry);
// Copy field vectors
offset_t resultDataPos = entry.offset;
for (auto i = 0; i < pathExpression->getNumChildren(); ++i) {
auto child = pathExpression->getChild(i).get();
for (auto i = 0; i < expression->getNumChildren(); ++i) {
auto child = expression->getChild(i).get();
auto vectors = inputVectorsPerChild[i].get();
auto inputPos = getCurrentPos(vectors->input, resultPos);
switch (child->dataType.getLogicalTypeID()) {
Expand Down Expand Up @@ -206,7 +207,7 @@ void PathExpressionEvaluator::copyFieldVectors(offset_t inputVectorPos,

void PathExpressionEvaluator::resolveResultVector(
const processor::ResultSet& resultSet, storage::MemoryManager* memoryManager) {
resultVector = std::make_shared<ValueVector>(pathExpression->getDataType(), memoryManager);
resultVector = std::make_shared<ValueVector>(expression->getDataType(), memoryManager);
std::vector<ExpressionEvaluator*> inputEvaluators;
inputEvaluators.reserve(children.size());
for (auto& child : children) {
Expand Down
11 changes: 0 additions & 11 deletions src/function/built_in_vector_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,6 @@ void BuiltInVectorFunctions::registerVectorFunctions() {
registerBlobFunctions();
}

// Reports whether a function call can be bound directly as a literal: the function must be a
// cast to date, timestamp or interval, and its first argument must be a STRING literal.
// The first child is only inspected once the function name has matched.
bool BuiltInVectorFunctions::canApplyStaticEvaluation(
    const std::string& functionName, const binder::expression_vector& children) {
    const bool isTemporalCast = functionName == CAST_TO_DATE_FUNC_NAME ||
                                functionName == CAST_TO_TIMESTAMP_FUNC_NAME ||
                                functionName == CAST_TO_INTERVAL_FUNC_NAME;
    if (!isTemporalCast) {
        return false;
    }
    const auto& firstChild = children[0];
    if (firstChild->expressionType != LITERAL) {
        return false;
    }
    // bind as literal
    return firstChild->dataType.getLogicalTypeID() == LogicalTypeID::STRING;
}

VectorFunctionDefinition* BuiltInVectorFunctions::matchVectorFunction(
const std::string& name, const std::vector<LogicalType>& inputTypes) {
auto& functionDefinitions = vectorFunctions.at(name);
Expand Down
Loading

0 comments on commit 57fef9a

Please sign in to comment.