Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix aggregate rel bug #1987

Merged
merged 1 commit into from
Sep 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_visitor.h"
#include "binder/query/return_with_clause/bound_return_clause.h"
Expand Down Expand Up @@ -30,6 +31,8 @@ static expression_vector rewriteProjectionInWithClause(const expression_vector&
result.push_back(node->getInternalIDProperty());
} else if (ExpressionUtil::isRelVariable(*expression)) {
auto rel = (RelExpression*)expression.get();
result.push_back(rel->getSrcNode()->getInternalIDProperty());
result.push_back(rel->getDstNode()->getInternalIDProperty());
result.push_back(rel->getInternalIDProperty());
} else if (ExpressionUtil::isRecursiveRelVariable(*expression)) {
auto rel = (RelExpression*)expression.get();
Expand All @@ -39,7 +42,7 @@ static expression_vector rewriteProjectionInWithClause(const expression_vector&
result.push_back(expression);
}
}
return result;
return ExpressionUtil::removeDuplication(result);
}

std::unique_ptr<BoundWithClause> Binder::bindWithClause(const WithClause& withClause) {
Expand Down
1 change: 1 addition & 0 deletions src/binder/bind/bind_updating_clause.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/query/updating_clause/bound_create_clause.h"
#include "binder/query/updating_clause/bound_delete_clause.h"
#include "binder/query/updating_clause/bound_merge_clause.h"
Expand Down
4 changes: 4 additions & 0 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "binder/binder.h"
#include "binder/expression/expression_util.h"
#include "binder/expression/function_expression.h"
#include "binder/expression/literal_expression.h"
#include "binder/expression_binder.h"
Expand Down Expand Up @@ -93,6 +94,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
children.push_back(std::move(child));
}
auto function = builtInFunctions->matchFunction(functionName, childrenTypes, isDistinct);
if (function->paramRewriteFunc) {
function->paramRewriteFunc(children);
}
auto uniqueExpressionName =
AggregateFunctionExpression::getUniqueName(function->name, children, function->isDistinct);
if (children.empty()) {
Expand Down
1 change: 1 addition & 0 deletions src/binder/bind_expression/bind_property_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "binder/expression/expression_util.h"
#include "binder/expression/rel_expression.h"
#include "binder/expression_binder.h"
#include "common/string_utils.h"
Expand Down
1 change: 1 addition & 0 deletions src/binder/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_library(
OBJECT
case_expression.cpp
expression.cpp
expression_util.cpp
function_expression.cpp)

set(ALL_OBJECT_FILES
Expand Down
81 changes: 0 additions & 81 deletions src/binder/expression/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,86 +19,5 @@ expression_vector Expression::splitOnAND() {
return result;
}

bool ExpressionUtil::isExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() != dataTypeID) {
return false;
}
}
return true;
}

// Collects, in their original order, the expressions whose logical type is `dataTypeID`.
expression_vector ExpressionUtil::getExpressionsWithDataType(
    const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
    expression_vector matched;
    for (const auto& candidate : expressions) {
        if (candidate->dataType.getLogicalTypeID() != dataTypeID) {
            continue;
        }
        matched.push_back(candidate);
    }
    return matched;
}

// Returns the position of the first entry in `expressions` whose unique name equals the unique
// name of `target`, or UINT32_MAX when there is no such entry.
uint32_t ExpressionUtil::find(Expression* target, expression_vector expressions) {
    auto position = 0u;
    for (const auto& candidate : expressions) {
        if (target->getUniqueName() == candidate->getUniqueName()) {
            return position;
        }
        ++position;
    }
    // Sentinel meaning "not found".
    return UINT32_MAX;
}

// Renders the expressions as "a1,a2,a3,..."; an empty vector renders as an empty string.
std::string ExpressionUtil::toString(const expression_vector& expressions) {
    std::string joined;
    for (auto idx = 0u; idx < expressions.size(); ++idx) {
        // The separator goes between entries only, never before the first one.
        if (idx != 0) {
            joined += ",";
        }
        joined += expressions[idx]->toString();
    }
    return joined;
}

// Renders the pairs as "a1=a2,a3=a4,..."; an empty vector renders as an empty string.
std::string ExpressionUtil::toString(const std::vector<expression_pair>& expressionPairs) {
    std::string joined;
    for (auto idx = 0u; idx < expressionPairs.size(); ++idx) {
        // The separator goes between entries only, never before the first one.
        if (idx != 0) {
            joined += ",";
        }
        joined += toString(expressionPairs[idx]);
    }
    return joined;
}

// Renders a single pair as "first=second".
std::string ExpressionUtil::toString(const expression_pair& expressionPair) {
    std::string text = expressionPair.first->toString();
    text += "=";
    text += expressionPair.second->toString();
    return text;
}

// Returns the entries of `expressions`, in their original order, that are not present in
// `expressionsToExclude`.
expression_vector ExpressionUtil::excludeExpressions(
    const expression_vector& expressions, const expression_vector& expressionsToExclude) {
    // Build a lookup set once so each membership test is a set lookup.
    expression_set excluded;
    for (const auto& toExclude : expressionsToExclude) {
        excluded.insert(toExclude);
    }
    expression_vector kept;
    for (const auto& candidate : expressions) {
        if (excluded.contains(candidate)) {
            continue;
        }
        kept.push_back(candidate);
    }
    return kept;
}

// Returns one copied LogicalType per expression, in the same order as `expressions`.
std::vector<std::unique_ptr<common::LogicalType>> ExpressionUtil::getDataTypes(
    const kuzu::binder::expression_vector& expressions) {
    std::vector<std::unique_ptr<common::LogicalType>> dataTypes;
    dataTypes.reserve(expressions.size());
    for (auto idx = 0u; idx < expressions.size(); ++idx) {
        dataTypes.push_back(expressions[idx]->getDataType().copy());
    }
    return dataTypes;
}

} // namespace binder
} // namespace kuzu
101 changes: 101 additions & 0 deletions src/binder/expression/expression_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#include "binder/expression/expression_util.h"

namespace kuzu {
namespace binder {

bool ExpressionUtil::isExpressionsWithDataType(
const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
for (auto& expression : expressions) {
if (expression->dataType.getLogicalTypeID() != dataTypeID) {
return false;
}
}
return true;
}

// Collects, in their original order, the expressions whose logical type is `dataTypeID`.
expression_vector ExpressionUtil::getExpressionsWithDataType(
    const expression_vector& expressions, common::LogicalTypeID dataTypeID) {
    expression_vector matched;
    for (const auto& candidate : expressions) {
        if (candidate->dataType.getLogicalTypeID() != dataTypeID) {
            continue;
        }
        matched.push_back(candidate);
    }
    return matched;
}

Check warning on line 25 in src/binder/expression/expression_util.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression_util.cpp#L25

Added line #L25 was not covered by tests

// Returns the index of the first entry in `expressions` whose unique name equals the unique
// name of `target`, or UINT32_MAX when no entry matches.
// NOTE(review): `expressions` is taken by value, so callers passing an lvalue copy the whole
// vector; confirm whether a const reference was intended.
uint32_t ExpressionUtil::find(Expression* target, expression_vector expressions) {
for (auto i = 0u; i < expressions.size(); ++i) {
// Expressions are matched by unique name, not by pointer identity.
if (target->getUniqueName() == expressions[i]->getUniqueName()) {
return i;

Check warning on line 30 in src/binder/expression/expression_util.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression_util.cpp#L27-L30

Added lines #L27 - L30 were not covered by tests
}
}
// Sentinel meaning "not found".
return UINT32_MAX;
}

// Renders the expressions as "a1,a2,a3,..."; an empty vector renders as an empty string.
std::string ExpressionUtil::toString(const expression_vector& expressions) {
    std::string joined;
    for (auto idx = 0u; idx < expressions.size(); ++idx) {
        // The separator goes between entries only, never before the first one.
        if (idx != 0) {
            joined += ",";
        }
        joined += expressions[idx]->toString();
    }
    return joined;
}

// Renders the pairs as "a1=a2,a3=a4,..."; an empty vector renders as an empty string.
std::string ExpressionUtil::toString(const std::vector<expression_pair>& expressionPairs) {
    std::string joined;
    for (auto idx = 0u; idx < expressionPairs.size(); ++idx) {
        // The separator goes between entries only, never before the first one.
        if (idx != 0) {
            joined += ",";
        }
        joined += toString(expressionPairs[idx]);
    }
    return joined;
}

// Renders a single pair as "first=second".
std::string ExpressionUtil::toString(const expression_pair& expressionPair) {
    std::string text = expressionPair.first->toString();
    text += "=";
    text += expressionPair.second->toString();
    return text;
}

// Returns the entries of `expressions`, in their original order, that are not present in
// `expressionsToExclude`.
expression_vector ExpressionUtil::excludeExpressions(
    const expression_vector& expressions, const expression_vector& expressionsToExclude) {
    // Build a lookup set once so each membership test is a set lookup.
    expression_set excluded;
    for (const auto& toExclude : expressionsToExclude) {
        excluded.insert(toExclude);
    }
    expression_vector kept;
    for (const auto& candidate : expressions) {
        if (excluded.contains(candidate)) {
            continue;
        }
        kept.push_back(candidate);
    }
    return kept;
}

Check warning on line 75 in src/binder/expression/expression_util.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression_util.cpp#L75

Added line #L75 was not covered by tests

// Returns one copied LogicalType per expression, in the same order as `expressions`.
std::vector<std::unique_ptr<common::LogicalType>> ExpressionUtil::getDataTypes(
    const kuzu::binder::expression_vector& expressions) {
    std::vector<std::unique_ptr<common::LogicalType>> dataTypes;
    dataTypes.reserve(expressions.size());
    for (auto idx = 0u; idx < expressions.size(); ++idx) {
        dataTypes.push_back(expressions[idx]->getDataType().copy());
    }
    return dataTypes;
}

Check warning on line 85 in src/binder/expression/expression_util.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression_util.cpp#L85

Added line #L85 was not covered by tests

// Returns `expressions` with duplicates removed, keeping the first occurrence of each entry
// and preserving the original relative order. Equality is whatever `expression_set` uses.
expression_vector ExpressionUtil::removeDuplication(const expression_vector& expressions) {
    expression_vector result;
    expression_set seen;
    for (const auto& expression : expressions) {
        // insert() reports whether the element was newly added, so a single lookup both
        // tests membership and records the expression.
        if (seen.insert(expression).second) {
            result.push_back(expression);
        }
    }
    return result;
}

Check warning on line 98 in src/binder/expression/expression_util.cpp

View check run for this annotation

Codecov / codecov/patch

src/binder/expression/expression_util.cpp#L98

Added line #L98 was not covered by tests

} // namespace binder
} // namespace kuzu
2 changes: 2 additions & 0 deletions src/binder/expression/function_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "binder/expression/function_expression.h"

#include "binder/expression/expression_util.h"

namespace kuzu {
namespace binder {

Expand Down
2 changes: 2 additions & 0 deletions src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
add_subdirectory(aggregate)

add_library(kuzu_function
OBJECT
aggregate_function.cpp
Expand Down
7 changes: 7 additions & 0 deletions src/function/aggregate/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Object library holding the aggregate function implementations in this directory
# (currently only count.cpp).
add_library(kuzu_function_aggregate
OBJECT
count.cpp)

# Append this target's object files to ALL_OBJECT_FILES and publish the updated list to the
# parent directory scope.
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_function_aggregate>
PARENT_SCOPE)
43 changes: 43 additions & 0 deletions src/function/aggregate/count.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "function/aggregate/count.h"

#include "binder/expression/expression_util.h"
#include "binder/expression/node_expression.h"
#include "binder/expression/rel_expression.h"

using namespace kuzu::common;
using namespace kuzu::storage;
using namespace kuzu::binder;

namespace kuzu {
namespace function {

// Adds `multiplicity` to the running count once for every selected, non-null position of
// `input`. `state_` is reinterpreted as a CountState; `memoryManager` is not used in this body.
void CountFunction::updateAll(
uint8_t* state_, ValueVector* input, uint64_t multiplicity, MemoryManager* memoryManager) {
auto state = reinterpret_cast<CountState*>(state_);
if (input->hasNoNullsGuarantee()) {
// No per-position null check is made here: every selected position is counted.
// NOTE(review): this loop is equivalent to a single
// `count += multiplicity * selectedSize`; confirm whether the loop form is intentional.
for (auto i = 0u; i < input->state->selVector->selectedSize; ++i) {
state->count += multiplicity;
}
} else {
// Nulls are possible: check each selected position and skip the null ones.
for (auto i = 0u; i < input->state->selVector->selectedSize; ++i) {
auto pos = input->state->selVector->selectedPositions[i];
if (!input->isNull(pos)) {
state->count += multiplicity;

Check warning on line 25 in src/function/aggregate/count.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/aggregate/count.cpp#L22-L25

Added lines #L22 - L25 were not covered by tests
}
}
}
}

// Rewrites the single argument of COUNT in place: a node variable or a rel variable is
// replaced by its internal ID property. Any other argument is left untouched.
// `arguments` must contain exactly one expression (checked by assert only).
void CountFunction::paramRewriteFunc(binder::expression_vector& arguments) {
    assert(arguments.size() == 1);
    if (ExpressionUtil::isNodeVariable(*arguments[0])) {
        // The predicate above establishes the expression kind, so a named static downcast
        // is used instead of a C-style cast.
        auto node = static_cast<NodeExpression*>(arguments[0].get());
        arguments[0] = node->getInternalIDProperty();
    } else if (ExpressionUtil::isRelVariable(*arguments[0])) {
        auto rel = static_cast<RelExpression*>(arguments[0].get());
        arguments[0] = rel->getInternalIDProperty();
    }
}

} // namespace function
} // namespace kuzu
12 changes: 3 additions & 9 deletions src/function/built_in_aggregate_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "function/aggregate/built_in_aggregate_functions.h"

#include "function/aggregate/collect.h"
#include "function/aggregate/count.h"

using namespace kuzu::common;

Expand All @@ -23,14 +24,6 @@ AggregateFunctionDefinition* BuiltInAggregateFunctions::matchFunction(
return candidateFunctions[0];
}

// Returns the name of every registered aggregate function, in the iteration order of
// `aggregateFunctions`.
std::vector<std::string> BuiltInAggregateFunctions::getFunctionNames() {
    std::vector<std::string> names;
    for (auto& entry : aggregateFunctions) {
        // Only the key (the function name) is needed; the definitions are ignored.
        names.push_back(entry.first);
    }
    return names;
}

uint32_t BuiltInAggregateFunctions::getFunctionCost(const std::vector<LogicalType>& inputTypes,
bool isDistinct, AggregateFunctionDefinition* function) {
if (inputTypes.size() != function->parameterTypeIDs.size() ||
Expand Down Expand Up @@ -89,7 +82,8 @@ void BuiltInAggregateFunctions::registerCount() {
for (auto isDistinct : std::vector<bool>{true, false}) {
definitions.push_back(std::make_unique<AggregateFunctionDefinition>(COUNT_FUNC_NAME,
std::vector<LogicalTypeID>{type.getLogicalTypeID()}, LogicalTypeID::INT64,
AggregateFunctionUtil::getCountFunction(type, isDistinct), isDistinct));
AggregateFunctionUtil::getCountFunction(type, isDistinct), isDistinct,
nullptr /* bindFunc */, CountFunction::paramRewriteFunc));
}
}
aggregateFunctions.insert({COUNT_FUNC_NAME, std::move(definitions)});
Expand Down
35 changes: 0 additions & 35 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,40 +113,5 @@ struct ExpressionEquality {
}
};

// Stateless helpers for inspecting, filtering and printing bound expressions.
struct ExpressionUtil {
    // True when every expression has logical type `dataTypeID`.
    static bool isExpressionsWithDataType(
        const expression_vector& expressions, common::LogicalTypeID dataTypeID);
    // The subset of `expressions` whose logical type is `dataTypeID`.
    static expression_vector getExpressionsWithDataType(
        const expression_vector& expressions, common::LogicalTypeID dataTypeID);

    // Position of `target` within `expressions`.
    static uint32_t find(Expression* target, expression_vector expressions);

    // Formats as a1,a2,a3,...
    static std::string toString(const expression_vector& expressions);
    // Formats as a1=a2, a3=a4,...
    static std::string toString(const std::vector<expression_pair>& expressionPairs);
    // Formats as a1=a2
    static std::string toString(const expression_pair& expressionPair);

    // `expressions` without the entries listed in `expressionsToExclude`.
    static expression_vector excludeExpressions(
        const expression_vector& expressions, const expression_vector& expressionsToExclude);

    // True for a VARIABLE expression whose logical type is NODE.
    static bool isNodeVariable(const Expression& expression) {
        if (expression.expressionType != common::ExpressionType::VARIABLE) {
            return false;
        }
        return expression.dataType.getLogicalTypeID() == common::LogicalTypeID::NODE;
    }
    // True for a VARIABLE expression whose logical type is REL.
    static bool isRelVariable(const Expression& expression) {
        if (expression.expressionType != common::ExpressionType::VARIABLE) {
            return false;
        }
        return expression.dataType.getLogicalTypeID() == common::LogicalTypeID::REL;
    }
    // True for a VARIABLE expression whose logical type is RECURSIVE_REL.
    static bool isRecursiveRelVariable(const Expression& expression) {
        if (expression.expressionType != common::ExpressionType::VARIABLE) {
            return false;
        }
        return expression.dataType.getLogicalTypeID() == common::LogicalTypeID::RECURSIVE_REL;
    }

    // The logical types of `expressions`, one per expression.
    static std::vector<std::unique_ptr<common::LogicalType>> getDataTypes(
        const expression_vector& expressions);
};

} // namespace binder
} // namespace kuzu
Loading