Skip to content

Commit

Permalink
Merge pull request #2192 from kuzudb/fix-hash-node-rel
Browse files Browse the repository at this point in the history
Fix hash node rel
  • Loading branch information
andyfengHKU committed Oct 13, 2023
2 parents 2d1acf4 + 959229d commit 72dcb79
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 76 deletions.
2 changes: 2 additions & 0 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ std::unique_ptr<BoundProjectionBody> Binder::bindProjectionBody(
}
}
if (!groupByExpressions.empty()) {
// TODO(Xiyang): we can remove augment group by. But make sure we test sufficient including
// edge case and bug before release.
expression_vector augmentedGroupByExpressions = groupByExpressions;
for (auto& expression : groupByExpressions) {
if (ExpressionUtil::isNodeVariable(*expression)) {
Expand Down
13 changes: 13 additions & 0 deletions src/function/vector_hash_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ void VectorHashFunction::computeHash(ValueVector* operand, ValueVector* result)
case PhysicalTypeID::INTERVAL: {
UnaryHashFunctionExecutor::execute<interval_t, hash_t>(*operand, *result);
} break;
case PhysicalTypeID::STRUCT: {
if (operand->dataType.getLogicalTypeID() == LogicalTypeID::NODE) {
assert(0 == common::StructType::getFieldIdx(&operand->dataType, InternalKeyword::ID));
UnaryHashFunctionExecutor::execute<internalID_t, hash_t>(
*StructVector::getFieldVector(operand, 0), *result);
break;
} else if (operand->dataType.getLogicalTypeID() == LogicalTypeID::REL) {
assert(3 == StructType::getFieldIdx(&operand->dataType, InternalKeyword::ID));
UnaryHashFunctionExecutor::execute<internalID_t, hash_t>(
*StructVector::getFieldVector(operand, 3), *result);
break;
}
}
default: {
throw RuntimeException(
"Cannot hash data type " +
Expand Down
12 changes: 2 additions & 10 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct HashSlot {
*
*/
class AggregateHashTable;
using compare_function_t = std::function<bool(const uint8_t*, const uint8_t*)>;
using compare_function_t = std::function<bool(common::ValueVector*, uint32_t, const uint8_t*)>;
using update_agg_function_t =
std::function<void(AggregateHashTable*, const std::vector<common::ValueVector*>&,
const std::vector<common::ValueVector*>&, std::unique_ptr<function::AggregateFunction>&,
Expand Down Expand Up @@ -181,16 +181,8 @@ class AggregateHashTable : public BaseHashTable {

void resizeHashTableIfNecessary(uint32_t maxNumDistinctHashKeys);

template<typename type>
static bool compareEntryWithKeys(const uint8_t* keyValue, const uint8_t* entry) {
uint8_t result;
kuzu::function::Equals::operation(*(type*)keyValue, *(type*)entry, result,
nullptr /* leftVector */, nullptr /* rightVector */);
return result != 0;
}

static void getCompareEntryWithKeysFunc(
common::PhysicalTypeID physicalType, compare_function_t& func);
const common::LogicalType& logicalType, compare_function_t& func);

void updateNullAggVectorState(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
Expand Down
77 changes: 39 additions & 38 deletions src/optimizer/agg_key_dependency_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,64 +26,65 @@ void AggKeyDependencyOptimizer::visitOperator(planner::LogicalOperator* op) {

void AggKeyDependencyOptimizer::visitAggregate(planner::LogicalOperator* op) {
auto agg = (LogicalAggregate*)op;
auto [keyExpressions, payloadExpressions] =
resolveKeysAndDependentKeys(agg->getKeyExpressions());
agg->setKeyExpressions(keyExpressions);
agg->setDependentKeyExpressions(payloadExpressions);
auto [keys, dependentKeys] = resolveKeysAndDependentKeys(agg->getKeyExpressions());
agg->setKeyExpressions(keys);
agg->setDependentKeyExpressions(dependentKeys);
}

void AggKeyDependencyOptimizer::visitDistinct(planner::LogicalOperator* op) {
auto distinct = (LogicalDistinct*)op;
auto [keyExpressions, payloadExpressions] =
resolveKeysAndDependentKeys(distinct->getKeyExpressions());
distinct->setKeyExpressions(keyExpressions);
distinct->setDependentKeyExpressions(payloadExpressions);
auto [keys, dependentKeys] = resolveKeysAndDependentKeys(distinct->getKeyExpressions());
distinct->setKeyExpressions(keys);
distinct->setDependentKeyExpressions(dependentKeys);
}

std::pair<binder::expression_vector, binder::expression_vector>
AggKeyDependencyOptimizer::resolveKeysAndDependentKeys(const binder::expression_vector& keys) {
AggKeyDependencyOptimizer::resolveKeysAndDependentKeys(const expression_vector& inputKeys) {
// Consider example RETURN a.ID, a.age, COUNT(*).
// We first collect a.ID into primaryKeys. Then collect "a" into primaryVarNames.
// Finally, we loop through all group by keys to put non-primary key properties under name "a"
// into dependentKeyExpressions.

// Collect primary keys from group keys.
std::vector<binder::PropertyExpression*> primaryKeys;
for (auto& expression : keys) {
if (expression->expressionType == PROPERTY) {
auto propertyExpression = (binder::PropertyExpression*)expression.get();
if (propertyExpression->isPrimaryKey() || propertyExpression->isInternalID()) {
primaryKeys.push_back(propertyExpression);
// Collect primary variables from keys.
std::unordered_set<std::string> primaryVarNames;
for (auto& key : inputKeys) {
if (key->expressionType == PROPERTY) {
auto property = (PropertyExpression*)key.get();
if (property->isPrimaryKey() || property->isInternalID()) {
primaryVarNames.insert(property->getVariableName());
}
}
}
// Collect variable names whose primary key is part of group keys.
std::unordered_set<std::string> primaryVarNames;
for (auto& primaryKey : primaryKeys) {
primaryVarNames.insert(primaryKey->getVariableName());
}
binder::expression_vector groupExpressions;
binder::expression_vector dependentExpressions;
for (auto& expression : keys) {
if (expression->expressionType == PROPERTY) {
auto propertyExpression = (binder::PropertyExpression*)expression.get();
if (propertyExpression->isPrimaryKey() ||
propertyExpression->isInternalID()) { // NOLINT(bugprone-branch-clone): Collapsing
// is a logical error.
groupExpressions.push_back(expression);
} else if (primaryVarNames.contains(propertyExpression->getVariableName())) {
dependentExpressions.push_back(expression);
// Resolve key dependency.
binder::expression_vector keys;
binder::expression_vector dependentKeys;
for (auto& key : inputKeys) {
if (key->expressionType == PROPERTY) {
auto property = (PropertyExpression*)key.get();
if (property->isPrimaryKey() ||
property->isInternalID()) { // NOLINT(bugprone-branch-clone): Collapsing
// is a logical error.
// Primary properties are always keys.
keys.push_back(key);
} else if (primaryVarNames.contains(property->getVariableName())) {
// Properties depend on any primary property are dependent keys.
// e.g. a.age depends on a._id
dependentKeys.push_back(key);
} else {
keys.push_back(key);
}
} else if (ExpressionUtil::isNodeVariable(*key) || ExpressionUtil::isRelVariable(*key)) {
if (primaryVarNames.contains(key->getUniqueName())) {
// e.g. a depends on a._id
dependentKeys.push_back(key);
} else {
groupExpressions.push_back(expression);
keys.push_back(key);
}
} else if (ExpressionUtil::isNodeVariable(*expression) ||
ExpressionUtil::isRelVariable(*expression)) {
dependentExpressions.push_back(expression);
} else {
groupExpressions.push_back(expression);
keys.push_back(key);
}
}
return std::make_pair(std::move(groupExpressions), std::move(dependentExpressions));
return std::make_pair(std::move(keys), std::move(dependentKeys));
}

} // namespace optimizer
Expand Down
91 changes: 63 additions & 28 deletions src/processor/operator/aggregate/aggregate_hash_table.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "processor/operator/aggregate/aggregate_hash_table.h"

#include "common/null_buffer.h"
#include "common/utils.h"
#include "function/aggregate/base_count.h"
#include "function/hash/vector_hash_functions.h"
Expand Down Expand Up @@ -138,7 +139,7 @@ void AggregateHashTable::initializeFT(
for (auto& dataType : keyDataTypes) {
auto size = LogicalTypeUtils::getRowLayoutSize(dataType);
tableSchema->appendColumn(std::make_unique<ColumnSchema>(isUnflat, dataChunkPos, size));
getCompareEntryWithKeysFunc(dataType.getPhysicalType(), compareFuncs[colIdx]);
getCompareEntryWithKeysFunc(dataType, compareFuncs[colIdx]);
numBytesForKeys += size;
colIdx++;
}
Expand Down Expand Up @@ -466,7 +467,6 @@ bool AggregateHashTable::matchFlatGroupByKeys(
auto keyVector = keyVectors[i];
assert(keyVector->state->isFlat());
auto pos = keyVector->state->selVector->selectedPositions[0];
auto keyValue = keyVector->getData() + pos * keyVector->getNumBytesPerValue();
auto isKeyVectorNull = keyVector->isNull(pos);
auto isEntryKeyNull = factorizedTable->isNonOverflowColNull(
entry + factorizedTable->getTableSchema()->getNullMapOffset(), i);
Expand All @@ -478,7 +478,7 @@ bool AggregateHashTable::matchFlatGroupByKeys(
return false;
}
if (!compareFuncs[i](
keyValue, entry + factorizedTable->getTableSchema()->getColOffset(i))) {
keyVector, pos, entry + factorizedTable->getTableSchema()->getColOffset(i))) {
return false;
}
}
Expand All @@ -494,8 +494,8 @@ uint64_t AggregateHashTable::matchUnFlatVecWithFTColumn(
if (factorizedTable->hasNoNullGuarantee(colIdx)) {
for (auto i = 0u; i < numMayMatches; i++) {
auto idx = mayMatchIdxes[i];
if (compareFuncs[colIdx](vector->getData() + idx * vector->getNumBytesPerValue(),
hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
if (compareFuncs[colIdx](
vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
mayMatchIdxes[mayMatchIdx++] = idx;
} else {
noMatchIdxes[numNoMatches++] = idx;
Expand All @@ -504,7 +504,6 @@ uint64_t AggregateHashTable::matchUnFlatVecWithFTColumn(
} else {
for (auto i = 0u; i < numMayMatches; i++) {
auto idx = mayMatchIdxes[i];
auto value = vector->getData() + idx * vector->getNumBytesPerValue();
auto isEntryKeyNull = factorizedTable->isNonOverflowColNull(
hashSlotsToUpdateAggState[idx]->entry +
factorizedTable->getTableSchema()->getNullMapOffset(),
Expand All @@ -514,7 +513,7 @@ uint64_t AggregateHashTable::matchUnFlatVecWithFTColumn(
continue;
}
if (compareFuncs[colIdx](
value, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
mayMatchIdxes[mayMatchIdx++] = idx;
} else {
noMatchIdxes[numNoMatches++] = idx;
Expand All @@ -524,7 +523,6 @@ uint64_t AggregateHashTable::matchUnFlatVecWithFTColumn(
} else {
for (auto i = 0u; i < numMayMatches; i++) {
auto idx = mayMatchIdxes[i];
auto value = vector->getData() + idx * vector->getNumBytesPerValue();
auto isKeyVectorNull = vector->isNull(idx);
auto isEntryKeyNull = factorizedTable->isNonOverflowColNull(
hashSlotsToUpdateAggState[idx]->entry +
Expand All @@ -538,7 +536,8 @@ uint64_t AggregateHashTable::matchUnFlatVecWithFTColumn(
continue;
}

if (compareFuncs[colIdx](value, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
if (compareFuncs[colIdx](
vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
mayMatchIdxes[mayMatchIdx++] = idx;
} else {
noMatchIdxes[numNoMatches++] = idx;
Expand All @@ -555,7 +554,6 @@ uint64_t AggregateHashTable::matchFlatVecWithFTColumn(
uint64_t mayMatchIdx = 0;
auto pos = vector->state->selVector->selectedPositions[0];
auto isVectorNull = vector->isNull(pos);
auto value = vector->getData() + pos * vector->getNumBytesPerValue();
for (auto i = 0u; i < numMayMatches; i++) {
auto idx = mayMatchIdxes[i];
auto isEntryKeyNull = factorizedTable->isNonOverflowColNull(
Expand All @@ -569,7 +567,7 @@ uint64_t AggregateHashTable::matchFlatVecWithFTColumn(
noMatchIdxes[numNoMatches++] = idx;
continue;
}
if (compareFuncs[colIdx](value, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
if (compareFuncs[colIdx](vector, pos, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
mayMatchIdxes[mayMatchIdx++] = idx;
} else {
noMatchIdxes[numNoMatches++] = idx;
Expand Down Expand Up @@ -632,68 +630,105 @@ void AggregateHashTable::resizeHashTableIfNecessary(uint32_t maxNumDistinctHashK
}
}

template<typename T>
static bool compareEntry(common::ValueVector* vector, uint32_t vectorPos, const uint8_t* entry) {
uint8_t result;
auto key = vector->getData() + vectorPos * vector->getNumBytesPerValue();
kuzu::function::Equals::operation(
*(T*)key, *(T*)entry, result, nullptr /* leftVector */, nullptr /* rightVector */);
return result != 0;
}

static bool compareNodeEntry(
common::ValueVector* vector, uint32_t vectorPos, const uint8_t* entry) {
assert(0 == common::StructType::getFieldIdx(&vector->dataType, common::InternalKeyword::ID));
auto idVector = common::StructVector::getFieldVector(vector, 0).get();
return compareEntry<common::internalID_t>(idVector, vectorPos,
entry + common::NullBuffer::getNumBytesForNullValues(
common::StructType::getNumFields(&vector->dataType)));
}

static bool compareRelEntry(common::ValueVector* vector, uint32_t vectorPos, const uint8_t* entry) {
assert(3 == common::StructType::getFieldIdx(&vector->dataType, common::InternalKeyword::ID));
auto idVector = common::StructVector::getFieldVector(vector, 3).get();
return compareEntry<common::internalID_t>(idVector, vectorPos,
entry + sizeof(common::internalID_t) * 2 + sizeof(common::ku_string_t) +
common::NullBuffer::getNumBytesForNullValues(
common::StructType::getNumFields(&vector->dataType)));
}

void AggregateHashTable::getCompareEntryWithKeysFunc(
PhysicalTypeID physicalType, compare_function_t& func) {
switch (physicalType) {
const LogicalType& logicalType, compare_function_t& func) {
switch (logicalType.getPhysicalType()) {
case PhysicalTypeID::INTERNAL_ID: {
func = compareEntryWithKeys<nodeID_t>;
func = compareEntry<nodeID_t>;
return;
}
case PhysicalTypeID::BOOL: {
func = compareEntryWithKeys<bool>;
func = compareEntry<bool>;
return;
}
case PhysicalTypeID::INT64: {
func = compareEntryWithKeys<int64_t>;
func = compareEntry<int64_t>;
return;
}
case PhysicalTypeID::INT32: {
func = compareEntryWithKeys<int32_t>;
func = compareEntry<int32_t>;
return;
}
case PhysicalTypeID::INT16: {
func = compareEntryWithKeys<int16_t>;
func = compareEntry<int16_t>;
return;
}
case PhysicalTypeID::INT8: {
func = compareEntryWithKeys<int8_t>;
func = compareEntry<int8_t>;
return;
}
case PhysicalTypeID::UINT64: {
func = compareEntryWithKeys<uint64_t>;
func = compareEntry<uint64_t>;
return;
}
case PhysicalTypeID::UINT32: {
func = compareEntryWithKeys<uint32_t>;
func = compareEntry<uint32_t>;
return;
}
case PhysicalTypeID::UINT16: {
func = compareEntryWithKeys<uint16_t>;
func = compareEntry<uint16_t>;
return;
}
case PhysicalTypeID::UINT8: {
func = compareEntryWithKeys<uint8_t>;
func = compareEntry<uint8_t>;
return;
}
case PhysicalTypeID::DOUBLE: {
func = compareEntryWithKeys<double_t>;
func = compareEntry<double_t>;
return;
}
case PhysicalTypeID::FLOAT: {
func = compareEntryWithKeys<float_t>;
func = compareEntry<float_t>;
return;
}
case PhysicalTypeID::STRING: {
func = compareEntryWithKeys<ku_string_t>;
func = compareEntry<ku_string_t>;
return;
}
case PhysicalTypeID::INTERVAL: {
func = compareEntryWithKeys<interval_t>;
func = compareEntry<interval_t>;
return;
}
case PhysicalTypeID::STRUCT: {
if (logicalType.getLogicalTypeID() == LogicalTypeID::NODE) {
func = compareNodeEntry;
return;
} else if (logicalType.getLogicalTypeID() == LogicalTypeID::REL) {
func = compareRelEntry;
return;
}
}
default: {
throw RuntimeException(
"Cannot compare data type " + PhysicalTypeUtils::physicalTypeToString(physicalType));
"Cannot compare data type " +
PhysicalTypeUtils::physicalTypeToString(logicalType.getPhysicalType()));
}
}
}
Expand Down
Loading

0 comments on commit 72dcb79

Please sign in to comment.