Skip to content

Commit

Permalink
Merge pull request #1987 from kuzudb/fix-agg-rel-bug
Browse files Browse the repository at this point in the history
Fix aggregate rel bug
  • Loading branch information
andyfengHKU committed Sep 3, 2023
2 parents bc16021 + b061b9a commit 3c2086b
Show file tree
Hide file tree
Showing 24 changed files with 294 additions and 158 deletions.
5 changes: 4 additions & 1 deletion src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_visitor.h"
#include "binder/query/return_with_clause/bound_return_clause.h"
Expand Down Expand Up @@ -30,6 +31,8 @@ static expression_vector rewriteProjectionInWithClause(const expression_vector&
result.push_back(node->getInternalIDProperty());
} else if (ExpressionUtil::isRelVariable(*expression)) {
auto rel = (RelExpression*)expression.get();
result.push_back(rel->getSrcNode()->getInternalIDProperty());
result.push_back(rel->getDstNode()->getInternalIDProperty());
result.push_back(rel->getInternalIDProperty());
} else if (ExpressionUtil::isRecursiveRelVariable(*expression)) {
auto rel = (RelExpression*)expression.get();
Expand All @@ -39,7 +42,7 @@ static expression_vector rewriteProjectionInWithClause(const expression_vector&
result.push_back(expression);
}
}
return result;
return ExpressionUtil::removeDuplication(result);
}

std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withClause) {
Expand Down
1 change: 1 addition & 0 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/query/updating_clause/bound_create_clause.h"
#include "binder/query/updating_clause/bound_delete_clause.h"
#include "binder/query/updating_clause/bound_merge_clause.h"
Expand Down
4 changes: 4 additions & 0 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/function_expression.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_binder.h"
Expand Down Expand Up @@ -93,6 +94,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
children.push_back(std::move(child));
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes, isDistinct);
if (function->paramRewriteFunc) {
function->paramRewriteFunc(children);
}
auto uniqueExpressionName =
AggregateFunctionExpression::getUniqueName(function->name, children, function->isDistinct);
if (children.empty()) {
Expand Down
1 change: 1 addition & 0 deletions src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "binder/expression/expression_util.h"
#include "binder/expression/rel_expression.h"
#include "binder/expression_binder.h"
#include "common/string_utils.h"
Expand Down
1 change: 1 addition & 0 deletions src/binder/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_library(
OBJECT
case_expression.cpp
expression.cpp
expression_util.cpp
function_expression.cpp)

set(ALL_OBJECT_FILES
Expand Down
81 changes: 0 additions & 81 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,86 +19,5 @@ expression_vector Expression::splitOnAND() {
return result;
}

bool ExpressionUtil::isExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() != dataTypeID) {
return false;
}
}
return true;
}

expression_vector ExpressionUtil::getExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
expression_vector result;
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() == dataTypeID) {
result.push_back(expression);
}
}
return result;
}

uint32_t ExpressionUtil::find(Expression* target, expression_vector expressions) {
for (auto i = 0u; i < expressions.size(); ++i) {
if (target->getUniqueName() == expressions[i]->getUniqueName()) {
return i;
}
}
return UINT32_MAX;
}

std::string ExpressionUtil::toString(const expression_vector& expressions) {
if (expressions.empty()) {
return std::string{};
}
auto result = expressions[0]->toString();
for (auto i = 1u; i < expressions.size(); ++i) {
result += "," + expressions[i]->toString();
}
return result;
}

std::string ExpressionUtil::toString(const std::vector<expression_pair>& expressionPairs) {
if (expressionPairs.empty()) {
return std::string{};
}
auto result = toString(expressionPairs[0]);
for (auto i = 1u; i < expressionPairs.size(); ++i) {
result += "," + toString(expressionPairs[i]);
}
return result;
}

std::string ExpressionUtil::toString(const expression_pair& expressionPair) {
return expressionPair.first->toString() + "=" + expressionPair.second->toString();
}

expression_vector ExpressionUtil::excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude) {
expression_set excludeSet;
for (auto& expression : expressionsToExclude) {
excludeSet.insert(expression);
}
expression_vector result;
for (auto& expression : expressions) {
if (!excludeSet.contains(expression)) {
result.push_back(expression);
}
}
return result;
}

std::vector<std::unique_ptr<common::LogicalType>> ExpressionUtil::getDataTypes(
const kuzu::binder::expression_vector& expressions) {
std::vector<std::unique_ptr<common::LogicalType>> result;
result.reserve(expressions.size());
for (auto& expression : expressions) {
result.push_back(expression->getDataType().copy());
}
return result;
}

} // namespace binder
} // namespace kuzu
101 changes: 101 additions & 0 deletions src/binder/expression/expression_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#include "binder/expression/expression_util.h"

namespace kuzu {
namespace binder {

bool ExpressionUtil::isExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() != dataTypeID) {
return false;
}
}
return true;
}

expression_vector ExpressionUtil::getExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
expression_vector result;
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() == dataTypeID) {
result.push_back(expression);
}
}
return result;
}

uint32_t ExpressionUtil::find(Expression* target, expression_vector expressions) {
for (auto i = 0u; i < expressions.size(); ++i) {
if (target->getUniqueName() == expressions[i]->getUniqueName()) {
return i;
}
}
return UINT32_MAX;
}

std::string ExpressionUtil::toString(const expression_vector& expressions) {
if (expressions.empty()) {
return std::string{};
}
auto result = expressions[0]->toString();
for (auto i = 1u; i < expressions.size(); ++i) {
result += "," + expressions[i]->toString();
}
return result;
}

std::string ExpressionUtil::toString(const std::vector<expression_pair>& expressionPairs) {
if (expressionPairs.empty()) {
return std::string{};
}
auto result = toString(expressionPairs[0]);
for (auto i = 1u; i < expressionPairs.size(); ++i) {
result += "," + toString(expressionPairs[i]);
}
return result;
}

std::string ExpressionUtil::toString(const expression_pair& expressionPair) {
return expressionPair.first->toString() + "=" + expressionPair.second->toString();
}

expression_vector ExpressionUtil::excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude) {
expression_set excludeSet;
for (auto& expression : expressionsToExclude) {
excludeSet.insert(expression);
}
expression_vector result;
for (auto& expression : expressions) {
if (!excludeSet.contains(expression)) {
result.push_back(expression);
}
}
return result;
}

std::vector<std::unique_ptr<common::LogicalType>> ExpressionUtil::getDataTypes(
const kuzu::binder::expression_vector& expressions) {
std::vector<std::unique_ptr<common::LogicalType>> result;
result.reserve(expressions.size());
for (auto& expression : expressions) {
result.push_back(expression->getDataType().copy());
}
return result;
}

expression_vector ExpressionUtil::removeDuplication(const expression_vector& expressions) {
expression_vector result;
expression_set expressionSet;
for (auto& expression : expressions) {
if (expressionSet.contains(expression)) {
continue;
}
result.push_back(expression);
expressionSet.insert(expression);
}
return result;
}

} // namespace binder
} // namespace kuzu
2 changes: 2 additions & 0 deletions src/binder/expression/function_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "binder/expression/function_expression.h"

#include "binder/expression/expression_util.h"

namespace kuzu {
namespace binder {

Expand Down
2 changes: 2 additions & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
add_subdirectory(aggregate)

add_library(kuzu_function
OBJECT
aggregate_function.cpp
Expand Down
7 changes: 7 additions & 0 deletions src/function/aggregate/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
add_library(kuzu_function_aggregate
OBJECT
count.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_function_aggregate>
PARENT_SCOPE)
43 changes: 43 additions & 0 deletions src/function/aggregate/count.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "function/aggregate/count.h"

#include "binder/expression/expression_util.h"
#include "binder/expression/node_expression.h"
#include "binder/expression/rel_expression.h"

using namespace kuzu::common;
using namespace kuzu::storage;
using namespace kuzu::binder;

namespace kuzu {
namespace function {

void CountFunction::updateAll(
uint8_t* state_, ValueVector* input, uint64_t multiplicity, MemoryManager* memoryManager) {
auto state = reinterpret_cast<CountState*>(state_);
if (input->hasNoNullsGuarantee()) {
for (auto i = 0u; i < input->state->selVector->selectedSize; ++i) {
state->count += multiplicity;
}
} else {
for (auto i = 0u; i < input->state->selVector->selectedSize; ++i) {
auto pos = input->state->selVector->selectedPositions[i];
if (!input->isNull(pos)) {
state->count += multiplicity;
}
}
}
}

void CountFunction::paramRewriteFunc(binder::expression_vector& arguments) {
assert(arguments.size() == 1);
if (ExpressionUtil::isNodeVariable(*arguments[0])) {
auto node = (NodeExpression*)arguments[0].get();
arguments[0] = node->getInternalIDProperty();
} else if (ExpressionUtil::isRelVariable(*arguments[0])) {
auto rel = (RelExpression*)arguments[0].get();
arguments[0] = rel->getInternalIDProperty();
}
}

} // namespace function
} // namespace kuzu
12 changes: 3 additions & 9 deletions src/function/built_in_aggregate_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "function/aggregate/built_in_aggregate_functions.h"

#include "function/aggregate/collect.h"
#include "function/aggregate/count.h"

using namespace kuzu::common;

Expand All @@ -23,14 +24,6 @@ AggregateFunctionDefinition* BuiltInAggregateFunctions::matchFunction(
return candidateFunctions[0];
}

std::vector<std::string> BuiltInAggregateFunctions::getFunctionNames() {
std::vector<std::string> result;
for (auto& [functionName, definitions] : aggregateFunctions) {
result.push_back(functionName);
}
return result;
}

uint32_t BuiltInAggregateFunctions::getFunctionCost(const std::vector<LogicalType>& inputTypes,
bool isDistinct, AggregateFunctionDefinition* function) {
if (inputTypes.size() != function->parameterTypeIDs.size() ||
Expand Down Expand Up @@ -89,7 +82,8 @@ void BuiltInAggregateFunctions::registerCount() {
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(COUNT_FUNC_NAME,
std::vector<LogicalTypeID>{type.getLogicalTypeID()}, LogicalTypeID::INT64,
AggregateFunctionUtil::getCountFunction(type, isDistinct), isDistinct));
AggregateFunctionUtil::getCountFunction(type, isDistinct), isDistinct,
nullptr /* bindFunc */, CountFunction::paramRewriteFunc));
}
}
aggregateFunctions.insert({COUNT_FUNC_NAME, std::move(definitions)});
Expand Down
35 changes: 0 additions & 35 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,40 +113,5 @@ struct ExpressionEquality {
}
};

struct ExpressionUtil {
static bool isExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID);
static expression_vector getExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID);

static uint32_t find(Expression* target, expression_vector expressions);

// Print as a1,a2,a3,...
static std::string toString(const expression_vector& expressions);
// Print as a1=a2, a3=a4,...
static std::string toString(const std::vector<expression_pair>& expressionPairs);
// Print as a1=a2
static std::string toString(const expression_pair& expressionPair);

static expression_vector excludeExpressions(
const expression_vector& expressions, const expression_vector& expressionsToExclude);

inline static bool isNodeVariable(const Expression& expression) {
return expression.expressionType == common::ExpressionType::VARIABLE &&
expression.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE;
}
inline static bool isRelVariable(const Expression& expression) {
return expression.expressionType == common::ExpressionType::VARIABLE &&
expression.dataType.getLogicalTypeID() == common::LogicalTypeID::REL;
}
inline static bool isRecursiveRelVariable(const Expression& expression) {
return expression.expressionType == common::ExpressionType::VARIABLE &&
expression.dataType.getLogicalTypeID() == common::LogicalTypeID::RECURSIVE_REL;
}

static std::vector<std::unique_ptr<common::LogicalType>> getDataTypes(
const expression_vector& expressions);
};

} // namespace binder
} // namespace kuzu
Loading

0 comments on commit 3c2086b

Please sign in to comment.