Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add distinct aggregate over node and relationships #3236

Merged
merged 1 commit into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
auto childTypeID = child->dataType.getLogicalTypeID();
if (isDistinct &&
(childTypeID == LogicalTypeID::NODE || childTypeID == LogicalTypeID::REL)) {
throw BinderException{"DISTINCT is not supported for NODE or REL type."};
}
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryMan
: dataType{std::move(dataType)} {
if (this->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
// LCOV_EXCL_START
// Alternatively we can assign
// Alternatively we can assign a default type here but I don't think it's a good practice.
throw RuntimeException("Trying to a create a vector with ANY type. This should not happen. "
"Data type is expected to be resolved during binding.");
// LCOV_EXCL_STOP
Expand Down
22 changes: 15 additions & 7 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <unordered_set>

#include "common/assert.h"
#include "common/cast.h"
#include "common/copy_constructors.h"
#include "common/enums/expression_type.h"
#include "common/exception/internal.h"
Expand Down Expand Up @@ -65,28 +66,35 @@ class Expression : public std::enable_shared_from_this<Expression> {
common::LogicalType getDataType() const { return dataType; }
common::LogicalType& getDataTypeReference() { return dataType; }

inline bool hasAlias() const { return !alias.empty(); }
inline std::string getAlias() const { return alias; }
bool hasAlias() const { return !alias.empty(); }
std::string getAlias() const { return alias; }

inline uint32_t getNumChildren() const { return children.size(); }
inline std::shared_ptr<Expression> getChild(common::vector_idx_t idx) const {
uint32_t getNumChildren() const { return children.size(); }
std::shared_ptr<Expression> getChild(common::idx_t idx) const {
KU_ASSERT(idx < children.size());
return children[idx];
}
inline expression_vector getChildren() const { return children; }
inline void setChild(common::vector_idx_t idx, std::shared_ptr<Expression> child) {
expression_vector getChildren() const { return children; }
void setChild(common::idx_t idx, std::shared_ptr<Expression> child) {
KU_ASSERT(idx < children.size());
children[idx] = std::move(child);
}

expression_vector splitOnAND();

inline bool operator==(const Expression& rhs) const { return uniqueName == rhs.uniqueName; }
bool operator==(const Expression& rhs) const { return uniqueName == rhs.uniqueName; }

std::string toString() const { return hasAlias() ? alias : toStringInternal(); }

virtual std::unique_ptr<Expression> copy() const {
throw common::InternalException("Unimplemented expression copy().");
}

template<class TARGET>
const TARGET* constPtrCast() const {
return common::ku_dynamic_cast<const Expression*, const TARGET*>(this);
}

protected:
virtual std::string toStringInternal() const = 0;

Expand Down
26 changes: 11 additions & 15 deletions src/include/function/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace kuzu {
namespace function {

struct AggregateState {
virtual inline uint32_t getStateSize() const = 0;
virtual uint32_t getStateSize() const = 0;
virtual void moveResultToVector(common::ValueVector* outputVector, uint64_t pos) = 0;
virtual ~AggregateState() = default;

Expand Down Expand Up @@ -52,45 +52,41 @@ struct AggregateFunction final : public BaseScalarFunction {
std::move(combineFunc), std::move(finalizeFunc), isDistinct, nullptr /* bindFunc */,
std::move(paramRewriteFunc)} {}

inline uint32_t getAggregateStateSize() const {
return initialNullAggregateState->getStateSize();
}
uint32_t getAggregateStateSize() const { return initialNullAggregateState->getStateSize(); }

// NOLINTNEXTLINE(readability-make-member-function-const): Returns a non-const pointer.
inline AggregateState* getInitialNullAggregateState() {
return initialNullAggregateState.get();
}
AggregateState* getInitialNullAggregateState() { return initialNullAggregateState.get(); }

inline std::unique_ptr<AggregateState> createInitialNullAggregateState() const {
std::unique_ptr<AggregateState> createInitialNullAggregateState() const {
return initializeFunc();
}

inline void updateAllState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity,
void updateAllState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity,
storage::MemoryManager* memoryManager) const {
return updateAllFunc(state, input, multiplicity, memoryManager);
}

inline void updatePosState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity,
void updatePosState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity,
uint32_t pos, storage::MemoryManager* memoryManager) const {
return updatePosFunc(state, input, multiplicity, pos, memoryManager);
}

inline void combineState(uint8_t* state, uint8_t* otherState,
void combineState(uint8_t* state, uint8_t* otherState,
storage::MemoryManager* memoryManager) const {
return combineFunc(state, otherState, memoryManager);
}

inline void finalizeState(uint8_t* state) const { return finalizeFunc(state); }
void finalizeState(uint8_t* state) const { return finalizeFunc(state); }

inline bool isFunctionDistinct() const { return isDistinct; }
bool isFunctionDistinct() const { return isDistinct; }

inline std::unique_ptr<Function> copy() const override {
std::unique_ptr<Function> copy() const override {
return std::make_unique<AggregateFunction>(name, parameterTypeIDs, returnTypeID,
initializeFunc, updateAllFunc, updatePosFunc, combineFunc, finalizeFunc, isDistinct,
bindFunc, paramRewriteFunc);
}

inline std::unique_ptr<AggregateFunction> clone() const {
std::unique_ptr<AggregateFunction> clone() const {
return std::make_unique<AggregateFunction>(name, parameterTypeIDs, returnTypeID,
initializeFunc, updateAllFunc, updatePosFunc, combineFunc, finalizeFunc, isDistinct,
bindFunc, paramRewriteFunc);
Expand Down
57 changes: 23 additions & 34 deletions src/include/planner/operator/logical_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,14 @@ namespace planner {

class LogicalAggregate : public LogicalOperator {
public:
LogicalAggregate(binder::expression_vector keyExpressions,
binder::expression_vector aggregateExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)},
keyExpressions{std::move(keyExpressions)},
aggregateExpressions{std::move(aggregateExpressions)} {}
LogicalAggregate(binder::expression_vector keyExpressions,
binder::expression_vector dependentKeyExpressions,
binder::expression_vector aggregateExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)},
keyExpressions{std::move(keyExpressions)},
dependentKeyExpressions{std::move(dependentKeyExpressions)},
aggregateExpressions{std::move(aggregateExpressions)} {}
LogicalAggregate(binder::expression_vector keys, binder::expression_vector aggregates,
std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)}, keys{std::move(keys)},
aggregates{std::move(aggregates)} {}
LogicalAggregate(binder::expression_vector keys, binder::expression_vector dependentKeys,
binder::expression_vector aggregates, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)}, keys{std::move(keys)},
dependentKeys{std::move(dependentKeys)}, aggregates{std::move(aggregates)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;
Expand All @@ -28,42 +24,35 @@ class LogicalAggregate : public LogicalOperator {

std::string getExpressionsForPrinting() const override;

inline bool hasKeyExpressions() const { return !keyExpressions.empty(); }
inline binder::expression_vector getKeyExpressions() const { return keyExpressions; }
inline void setKeyExpressions(binder::expression_vector expressions) {
keyExpressions = std::move(expressions);
bool hasKeys() const { return !keys.empty(); }
binder::expression_vector getKeys() const { return keys; }
void setKeys(binder::expression_vector expressions) { keys = std::move(expressions); }
binder::expression_vector getDependentKeys() const { return dependentKeys; }
void setDependentKeys(binder::expression_vector expressions) {
dependentKeys = std::move(expressions);
}
inline binder::expression_vector getDependentKeyExpressions() const {
return dependentKeyExpressions;
}
inline void setDependentKeyExpressions(binder::expression_vector expressions) {
dependentKeyExpressions = std::move(expressions);
}
inline binder::expression_vector getAllKeyExpressions() const {
binder::expression_vector getAllKeys() const {
binder::expression_vector result;
result.insert(result.end(), keyExpressions.begin(), keyExpressions.end());
result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end());
result.insert(result.end(), keys.begin(), keys.end());
result.insert(result.end(), dependentKeys.begin(), dependentKeys.end());
return result;
}
inline binder::expression_vector getAggregateExpressions() const {
return aggregateExpressions;
}
binder::expression_vector getAggregates() const { return aggregates; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAggregate>(keyExpressions, dependentKeyExpressions,
aggregateExpressions, children[0]->copy());
std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAggregate>(keys, dependentKeys, aggregates, children[0]->copy());
}

private:
bool hasDistinctAggregate();
void insertAllExpressionsToGroupAndScope(f_group_pos groupPos);

private:
binder::expression_vector keyExpressions;
binder::expression_vector keys;
// A dependentKeyExpression depend on a keyExpression (e.g. a.age depends on a.ID) and will not
// be treated as a hash key during hash aggregation.
binder::expression_vector dependentKeyExpressions;
binder::expression_vector aggregateExpressions;
binder::expression_vector dependentKeys;
binder::expression_vector aggregates;
};

} // namespace planner
Expand Down
4 changes: 2 additions & 2 deletions src/include/planner/operator/logical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class LogicalOperator {
static logical_op_vector_t copy(const logical_op_vector_t& ops);

template<class TARGET>
TARGET* ptrCast() {
return common::ku_dynamic_cast<LogicalOperator*, TARGET*>(this);
const TARGET* constPtrCast() const {
return common::ku_dynamic_cast<const LogicalOperator*, const TARGET*>(this);
}

protected:
Expand Down
40 changes: 18 additions & 22 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,20 @@ using update_agg_function_t =

class AggregateHashTable : public BaseHashTable {
public:
// Used by distinct aggregate hash table only.
AggregateHashTable(storage::MemoryManager& memoryManager,
const common::logical_type_vec_t& keysDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate, std::unique_ptr<FactorizedTableSchema> tableSchema)
: AggregateHashTable(memoryManager, keysDataTypes, std::vector<common::LogicalType>(),
aggregateFunctions, numEntriesToAllocate, std::move(tableSchema)) {}
const std::vector<common::LogicalType>& keyTypes,
const std::vector<common::LogicalType>& payloadTypes, uint64_t numEntriesToAllocate,
std::unique_ptr<FactorizedTableSchema> tableSchema)
: AggregateHashTable(memoryManager, keyTypes, payloadTypes,
std::vector<std::unique_ptr<function::AggregateFunction>>{} /* empty aggregates */,
std::vector<common::LogicalType>{} /* empty distinct agg key*/, numEntriesToAllocate,
std::move(tableSchema)) {}

AggregateHashTable(storage::MemoryManager& memoryManager,
std::vector<common::LogicalType> keysDataTypes,
std::vector<common::LogicalType> payloadsDataTypes,
std::vector<common::LogicalType> keyTypes, std::vector<common::LogicalType> payloadTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate, std::unique_ptr<FactorizedTableSchema> tableSchema);
const std::vector<common::LogicalType>& distinctAggKeyTypes, uint64_t numEntriesToAllocate,
std::unique_ptr<FactorizedTableSchema> tableSchema);

uint8_t* getEntry(uint64_t idx) { return factorizedTable->getTuple(idx); }

Expand All @@ -62,8 +63,7 @@ class AggregateHashTable : public BaseHashTable {

void append(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
common::DataChunkState* leadingState,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
common::DataChunkState* leadingState, const std::vector<AggregateInput>& aggregateInputs,
uint64_t resultSetMultiplicity) {
append(flatKeyVectors, unFlatKeyVectors, std::vector<common::ValueVector*>(), leadingState,
aggregateInputs, resultSetMultiplicity);
Expand All @@ -73,8 +73,7 @@ class AggregateHashTable : public BaseHashTable {
void append(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<common::ValueVector*>& dependentKeyVectors,
common::DataChunkState* leadingState,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
common::DataChunkState* leadingState, const std::vector<AggregateInput>& aggregateInputs,
uint64_t resultSetMultiplicity);

bool isAggregateValueDistinctForGroupByKeys(
Expand Down Expand Up @@ -152,8 +151,7 @@ class AggregateHashTable : public BaseHashTable {

void updateAggStates(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
uint64_t resultSetMultiplicity);
const std::vector<AggregateInput>& aggregateInputs, uint64_t resultSetMultiplicity);

// ! This function will only be used by distinct aggregate, which assumes that all keyVectors
// are flat.
Expand Down Expand Up @@ -217,7 +215,7 @@ class AggregateHashTable : public BaseHashTable {
std::unique_ptr<HashSlot*[]> hashSlotsToUpdateAggState;

private:
std::vector<common::LogicalType> dependentKeyDataTypes;
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
std::vector<common::LogicalType> payloadTypes;
std::vector<std::unique_ptr<function::AggregateFunction>> aggregateFunctions;

//! special handling of distinct aggregate
Expand All @@ -233,13 +231,11 @@ class AggregateHashTable : public BaseHashTable {
std::unique_ptr<uint64_t[]> tmpSlotIdxes;
};

class AggregateHashTableUtils {

public:
static std::vector<std::unique_ptr<AggregateHashTable>> createDistinctHashTables(
struct AggregateHashTableUtils {
static std::unique_ptr<AggregateHashTable> createDistinctHashTable(
storage::MemoryManager& memoryManager,
const std::vector<common::LogicalType>& groupByKeyDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions);
const std::vector<common::LogicalType>& groupByKeyTypes,
const common::LogicalType& distinctKeyType);
};

} // namespace processor
Expand Down
30 changes: 20 additions & 10 deletions src/include/processor/operator/aggregate/aggregate_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,33 @@
namespace kuzu {
namespace processor {

struct AggregateInputInfo {
DataPos aggregateVectorPos;
struct AggregateInfo {
DataPos aggVectorPos;
std::vector<data_chunk_pos_t> multiplicityChunksPos;
common::LogicalType distinctAggKeyType;

AggregateInputInfo(const DataPos& vectorPos,
std::vector<data_chunk_pos_t> multiplicityChunksPos)
: aggregateVectorPos{vectorPos}, multiplicityChunksPos{std::move(multiplicityChunksPos)} {}
AggregateInputInfo(const AggregateInputInfo& other)
: AggregateInputInfo(other.aggregateVectorPos, other.multiplicityChunksPos) {}
inline std::unique_ptr<AggregateInputInfo> copy() {
return std::make_unique<AggregateInputInfo>(*this);
}
AggregateInfo(const DataPos& aggVectorPos, std::vector<data_chunk_pos_t> multiplicityChunksPos,
common::LogicalType distinctAggKeyType)
: aggVectorPos{aggVectorPos}, multiplicityChunksPos{std::move(multiplicityChunksPos)},
distinctAggKeyType{std::move(distinctAggKeyType)} {}
EXPLICIT_COPY_DEFAULT_MOVE(AggregateInfo);

private:
AggregateInfo(const AggregateInfo& other)
: aggVectorPos{other.aggVectorPos}, multiplicityChunksPos{other.multiplicityChunksPos},
distinctAggKeyType{other.distinctAggKeyType} {}
};

struct AggregateInput {
common::ValueVector* aggregateVector;
std::vector<common::DataChunk*> multiplicityChunks;

AggregateInput() = default;
EXPLICIT_COPY_DEFAULT_MOVE(AggregateInput);

private:
AggregateInput(const AggregateInput& other)
: aggregateVector{other.aggregateVector}, multiplicityChunks{other.multiplicityChunks} {}
};

} // namespace processor
Expand Down
Loading
Loading