Skip to content

Commit

Permalink
Add distinct aggregate over node and relationships
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Apr 9, 2024
1 parent 3e626f3 commit b71b0a7
Show file tree
Hide file tree
Showing 28 changed files with 418 additions and 350 deletions.
5 changes: 0 additions & 5 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
expression_vector children;
for (auto i = 0u; i < parsedExpression.getNumChildren(); ++i) {
auto child = bindExpression(*parsedExpression.getChild(i));
auto childTypeID = child->dataType.getLogicalTypeID();
if (isDistinct &&
(childTypeID == LogicalTypeID::NODE || childTypeID == LogicalTypeID::REL)) {
throw BinderException{"DISTINCT is not supported for NODE or REL type."};
}
childrenTypes.push_back(child->dataType);
children.push_back(std::move(child));
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ValueVector::ValueVector(LogicalType dataType, storage::MemoryManager* memoryMan
: dataType{std::move(dataType)} {
if (this->dataType.getLogicalTypeID() == LogicalTypeID::ANY) {
// LCOV_EXCL_START
// Alternatively we can assign
// Alternatively we can assign a default type here but I don't think it's a good practice.
throw RuntimeException("Trying to a create a vector with ANY type. This should not happen. "
"Data type is expected to be resolved during binding.");
// LCOV_EXCL_STOP
Expand Down
22 changes: 15 additions & 7 deletions src/include/binder/expression/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <unordered_set>

#include "common/assert.h"
#include "common/cast.h"
#include "common/copy_constructors.h"
#include "common/enums/expression_type.h"
#include "common/exception/internal.h"
Expand Down Expand Up @@ -65,28 +66,35 @@ class Expression : public std::enable_shared_from_this<Expression> {
common::LogicalType getDataType() const { return dataType; }
common::LogicalType& getDataTypeReference() { return dataType; }

inline bool hasAlias() const { return !alias.empty(); }
inline std::string getAlias() const { return alias; }
bool hasAlias() const { return !alias.empty(); }
std::string getAlias() const { return alias; }

inline uint32_t getNumChildren() const { return children.size(); }
inline std::shared_ptr<Expression> getChild(common::vector_idx_t idx) const {
uint32_t getNumChildren() const { return children.size(); }
std::shared_ptr<Expression> getChild(common::idx_t idx) const {
KU_ASSERT(idx < children.size());
return children[idx];
}
inline expression_vector getChildren() const { return children; }
inline void setChild(common::vector_idx_t idx, std::shared_ptr<Expression> child) {
expression_vector getChildren() const { return children; }
void setChild(common::idx_t idx, std::shared_ptr<Expression> child) {
KU_ASSERT(idx < children.size());
children[idx] = std::move(child);
}

expression_vector splitOnAND();

inline bool operator==(const Expression& rhs) const { return uniqueName == rhs.uniqueName; }
bool operator==(const Expression& rhs) const { return uniqueName == rhs.uniqueName; }

std::string toString() const { return hasAlias() ? alias : toStringInternal(); }

// Returns a copy of this expression. The base implementation does not support copying and
// always throws common::InternalException.
// NOTE(review): presumably expression subclasses that can be copied override this -- confirm.
virtual std::unique_ptr<Expression> copy() const {
throw common::InternalException("Unimplemented expression copy().");
}

// Views this expression as a pointer-to-const TARGET by routing `this` through
// common::ku_dynamic_cast.
// NOTE(review): ku_dynamic_cast is presumably a checked downcast; its failure behaviour is
// not visible here -- confirm against common/cast.h.
template<class TARGET>
const TARGET* constPtrCast() const {
    const Expression* self = this;
    return common::ku_dynamic_cast<const Expression*, const TARGET*>(self);
}

protected:
virtual std::string toStringInternal() const = 0;

Expand Down
26 changes: 11 additions & 15 deletions src/include/function/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace kuzu {
namespace function {

struct AggregateState {
virtual inline uint32_t getStateSize() const = 0;
virtual uint32_t getStateSize() const = 0;
virtual void moveResultToVector(common::ValueVector* outputVector, uint64_t pos) = 0;
virtual ~AggregateState() = default;

Expand Down Expand Up @@ -52,45 +52,41 @@ struct AggregateFunction final : public BaseScalarFunction {
std::move(combineFunc), std::move(finalizeFunc), isDistinct, nullptr /* bindFunc */,
std::move(paramRewriteFunc)} {}

inline uint32_t getAggregateStateSize() const {
return initialNullAggregateState->getStateSize();
}
uint32_t getAggregateStateSize() const { return initialNullAggregateState->getStateSize(); }

// NOLINTNEXTLINE(readability-make-member-function-const): Returns a non-const pointer.
inline AggregateState* getInitialNullAggregateState() {
return initialNullAggregateState.get();
}
AggregateState* getInitialNullAggregateState() { return initialNullAggregateState.get(); }

inline std::unique_ptr<AggregateState> createInitialNullAggregateState() const {
std::unique_ptr<AggregateState> createInitialNullAggregateState() const {
return initializeFunc();
}

inline void updateAllState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity,
void updateAllState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity,
storage::MemoryManager* memoryManager) const {
return updateAllFunc(state, input, multiplicity, memoryManager);
}

inline void updatePosState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity,
void updatePosState(uint8_t* state, common::ValueVector* input, uint64_t multiplicity,
uint32_t pos, storage::MemoryManager* memoryManager) const {
return updatePosFunc(state, input, multiplicity, pos, memoryManager);
}

inline void combineState(uint8_t* state, uint8_t* otherState,
void combineState(uint8_t* state, uint8_t* otherState,
storage::MemoryManager* memoryManager) const {
return combineFunc(state, otherState, memoryManager);
}

inline void finalizeState(uint8_t* state) const { return finalizeFunc(state); }
void finalizeState(uint8_t* state) const { return finalizeFunc(state); }

inline bool isFunctionDistinct() const { return isDistinct; }
bool isFunctionDistinct() const { return isDistinct; }

inline std::unique_ptr<Function> copy() const override {
std::unique_ptr<Function> copy() const override {
return std::make_unique<AggregateFunction>(name, parameterTypeIDs, returnTypeID,
initializeFunc, updateAllFunc, updatePosFunc, combineFunc, finalizeFunc, isDistinct,
bindFunc, paramRewriteFunc);
}

inline std::unique_ptr<AggregateFunction> clone() const {
std::unique_ptr<AggregateFunction> clone() const {
return std::make_unique<AggregateFunction>(name, parameterTypeIDs, returnTypeID,
initializeFunc, updateAllFunc, updatePosFunc, combineFunc, finalizeFunc, isDistinct,
bindFunc, paramRewriteFunc);
Expand Down
57 changes: 23 additions & 34 deletions src/include/planner/operator/logical_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,14 @@ namespace planner {

class LogicalAggregate : public LogicalOperator {
public:
LogicalAggregate(binder::expression_vector keyExpressions,
binder::expression_vector aggregateExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)},
keyExpressions{std::move(keyExpressions)},
aggregateExpressions{std::move(aggregateExpressions)} {}
LogicalAggregate(binder::expression_vector keyExpressions,
binder::expression_vector dependentKeyExpressions,
binder::expression_vector aggregateExpressions, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)},
keyExpressions{std::move(keyExpressions)},
dependentKeyExpressions{std::move(dependentKeyExpressions)},
aggregateExpressions{std::move(aggregateExpressions)} {}
LogicalAggregate(binder::expression_vector keys, binder::expression_vector aggregates,
std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)}, keys{std::move(keys)},
aggregates{std::move(aggregates)} {}
LogicalAggregate(binder::expression_vector keys, binder::expression_vector dependentKeys,
binder::expression_vector aggregates, std::shared_ptr<LogicalOperator> child)
: LogicalOperator{LogicalOperatorType::AGGREGATE, std::move(child)}, keys{std::move(keys)},
dependentKeys{std::move(dependentKeys)}, aggregates{std::move(aggregates)} {}

void computeFactorizedSchema() override;
void computeFlatSchema() override;
Expand All @@ -28,42 +24,35 @@ class LogicalAggregate : public LogicalOperator {

std::string getExpressionsForPrinting() const override;

inline bool hasKeyExpressions() const { return !keyExpressions.empty(); }
inline binder::expression_vector getKeyExpressions() const { return keyExpressions; }
inline void setKeyExpressions(binder::expression_vector expressions) {
keyExpressions = std::move(expressions);
bool hasKeys() const { return !keys.empty(); }
binder::expression_vector getKeys() const { return keys; }
void setKeys(binder::expression_vector expressions) { keys = std::move(expressions); }
binder::expression_vector getDependentKeys() const { return dependentKeys; }
void setDependentKeys(binder::expression_vector expressions) {
dependentKeys = std::move(expressions);
}
inline binder::expression_vector getDependentKeyExpressions() const {
return dependentKeyExpressions;
}
inline void setDependentKeyExpressions(binder::expression_vector expressions) {
dependentKeyExpressions = std::move(expressions);
}
inline binder::expression_vector getAllKeyExpressions() const {
binder::expression_vector getAllKeys() const {
binder::expression_vector result;
result.insert(result.end(), keyExpressions.begin(), keyExpressions.end());
result.insert(result.end(), dependentKeyExpressions.begin(), dependentKeyExpressions.end());
result.insert(result.end(), keys.begin(), keys.end());
result.insert(result.end(), dependentKeys.begin(), dependentKeys.end());
return result;
}
inline binder::expression_vector getAggregateExpressions() const {
return aggregateExpressions;
}
binder::expression_vector getAggregates() const { return aggregates; }

inline std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAggregate>(keyExpressions, dependentKeyExpressions,
aggregateExpressions, children[0]->copy());
std::unique_ptr<LogicalOperator> copy() override {
return make_unique<LogicalAggregate>(keys, dependentKeys, aggregates, children[0]->copy());
}

private:
bool hasDistinctAggregate();
void insertAllExpressionsToGroupAndScope(f_group_pos groupPos);

private:
binder::expression_vector keyExpressions;
binder::expression_vector keys;
// A dependent key depends on a key (e.g. a.age depends on a.ID) and will not be treated as a
// hash key during hash aggregation.
binder::expression_vector dependentKeyExpressions;
binder::expression_vector aggregateExpressions;
binder::expression_vector dependentKeys;
binder::expression_vector aggregates;
};

} // namespace planner
Expand Down
4 changes: 2 additions & 2 deletions src/include/planner/operator/logical_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class LogicalOperator {
static logical_op_vector_t copy(const logical_op_vector_t& ops);

template<class TARGET>
TARGET* ptrCast() {
return common::ku_dynamic_cast<LogicalOperator*, TARGET*>(this);
const TARGET* constPtrCast() const {
return common::ku_dynamic_cast<const LogicalOperator*, const TARGET*>(this);
}

protected:
Expand Down
40 changes: 18 additions & 22 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,20 @@ using update_agg_function_t =

class AggregateHashTable : public BaseHashTable {
public:
// Used by distinct aggregate hash table only.
AggregateHashTable(storage::MemoryManager& memoryManager,
const common::logical_type_vec_t& keysDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate, std::unique_ptr<FactorizedTableSchema> tableSchema)
: AggregateHashTable(memoryManager, keysDataTypes, std::vector<common::LogicalType>(),
aggregateFunctions, numEntriesToAllocate, std::move(tableSchema)) {}
const std::vector<common::LogicalType>& keyTypes,
const std::vector<common::LogicalType>& payloadTypes, uint64_t numEntriesToAllocate,
std::unique_ptr<FactorizedTableSchema> tableSchema)
: AggregateHashTable(memoryManager, keyTypes, payloadTypes,
std::vector<std::unique_ptr<function::AggregateFunction>>{} /* empty aggregates */,
std::vector<common::LogicalType>{} /* empty distinct agg key*/, numEntriesToAllocate,
std::move(tableSchema)) {}

AggregateHashTable(storage::MemoryManager& memoryManager,
std::vector<common::LogicalType> keysDataTypes,
std::vector<common::LogicalType> payloadsDataTypes,
std::vector<common::LogicalType> keyTypes, std::vector<common::LogicalType> payloadTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions,
uint64_t numEntriesToAllocate, std::unique_ptr<FactorizedTableSchema> tableSchema);
const std::vector<common::LogicalType>& distinctAggKeyTypes, uint64_t numEntriesToAllocate,
std::unique_ptr<FactorizedTableSchema> tableSchema);

uint8_t* getEntry(uint64_t idx) { return factorizedTable->getTuple(idx); }

Expand All @@ -62,8 +63,7 @@ class AggregateHashTable : public BaseHashTable {

void append(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
common::DataChunkState* leadingState,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
common::DataChunkState* leadingState, const std::vector<AggregateInput>& aggregateInputs,
uint64_t resultSetMultiplicity) {
append(flatKeyVectors, unFlatKeyVectors, std::vector<common::ValueVector*>(), leadingState,
aggregateInputs, resultSetMultiplicity);
Expand All @@ -73,8 +73,7 @@ class AggregateHashTable : public BaseHashTable {
void append(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<common::ValueVector*>& dependentKeyVectors,
common::DataChunkState* leadingState,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
common::DataChunkState* leadingState, const std::vector<AggregateInput>& aggregateInputs,
uint64_t resultSetMultiplicity);

bool isAggregateValueDistinctForGroupByKeys(
Expand Down Expand Up @@ -152,8 +151,7 @@ class AggregateHashTable : public BaseHashTable {

void updateAggStates(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
const std::vector<std::unique_ptr<AggregateInput>>& aggregateInputs,
uint64_t resultSetMultiplicity);
const std::vector<AggregateInput>& aggregateInputs, uint64_t resultSetMultiplicity);

// ! This function will only be used by distinct aggregate, which assumes that all keyVectors
// are flat.
Expand Down Expand Up @@ -217,7 +215,7 @@ class AggregateHashTable : public BaseHashTable {
std::unique_ptr<HashSlot*[]> hashSlotsToUpdateAggState;

private:
std::vector<common::LogicalType> dependentKeyDataTypes;
std::vector<common::LogicalType> payloadTypes;
std::vector<std::unique_ptr<function::AggregateFunction>> aggregateFunctions;

//! special handling of distinct aggregate
Expand All @@ -233,13 +231,11 @@ class AggregateHashTable : public BaseHashTable {
std::unique_ptr<uint64_t[]> tmpSlotIdxes;
};

class AggregateHashTableUtils {

public:
static std::vector<std::unique_ptr<AggregateHashTable>> createDistinctHashTables(
struct AggregateHashTableUtils {
static std::unique_ptr<AggregateHashTable> createDistinctHashTable(
storage::MemoryManager& memoryManager,
const std::vector<common::LogicalType>& groupByKeyDataTypes,
const std::vector<std::unique_ptr<function::AggregateFunction>>& aggregateFunctions);
const std::vector<common::LogicalType>& groupByKeyTypes,
const common::LogicalType& distinctKeyType);
};

} // namespace processor
Expand Down
30 changes: 20 additions & 10 deletions src/include/processor/operator/aggregate/aggregate_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,33 @@
namespace kuzu {
namespace processor {

struct AggregateInputInfo {
DataPos aggregateVectorPos;
struct AggregateInfo {
DataPos aggVectorPos;
std::vector<data_chunk_pos_t> multiplicityChunksPos;
common::LogicalType distinctAggKeyType;

AggregateInputInfo(const DataPos& vectorPos,
std::vector<data_chunk_pos_t> multiplicityChunksPos)
: aggregateVectorPos{vectorPos}, multiplicityChunksPos{std::move(multiplicityChunksPos)} {}
AggregateInputInfo(const AggregateInputInfo& other)
: AggregateInputInfo(other.aggregateVectorPos, other.multiplicityChunksPos) {}
inline std::unique_ptr<AggregateInputInfo> copy() {
return std::make_unique<AggregateInputInfo>(*this);
}
AggregateInfo(const DataPos& aggVectorPos, std::vector<data_chunk_pos_t> multiplicityChunksPos,
common::LogicalType distinctAggKeyType)
: aggVectorPos{aggVectorPos}, multiplicityChunksPos{std::move(multiplicityChunksPos)},
distinctAggKeyType{std::move(distinctAggKeyType)} {}
EXPLICIT_COPY_DEFAULT_MOVE(AggregateInfo);

private:
AggregateInfo(const AggregateInfo& other)
: aggVectorPos{other.aggVectorPos}, multiplicityChunksPos{other.multiplicityChunksPos},
distinctAggKeyType{other.distinctAggKeyType} {}
};

// Inputs handed to one aggregate: a pointer to the vector being aggregated and a list of
// data chunk pointers. Both are raw pointers that the private copy constructor copies
// shallowly, so copies refer to the same vector and chunks as the original.
// NOTE(review): multiplicityChunks presumably holds the chunks whose sizes contribute to the
// tuple multiplicity -- confirm against the aggregate operators.
struct AggregateInput {
    // Default to nullptr so that a default-initialized AggregateInput never carries an
    // indeterminate pointer.
    common::ValueVector* aggregateVector = nullptr;
    std::vector<common::DataChunk*> multiplicityChunks;

    AggregateInput() = default;
    // NOTE(review): macro defined in common/copy_constructors.h; presumably provides an
    // explicit copy() plus defaulted move operations -- confirm.
    EXPLICIT_COPY_DEFAULT_MOVE(AggregateInput);

private:
    // Member-wise (shallow) copy; kept private so copies are made only through the API
    // generated by the macro above.
    AggregateInput(const AggregateInput& other)
        : aggregateVector{other.aggregateVector}, multiplicityChunks{other.multiplicityChunks} {}
};

} // namespace processor
Expand Down
Loading

0 comments on commit b71b0a7

Please sign in to comment.