Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hash node rel #2192

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ std::unique_ptr<BoundProjectionBody> Binder::bindProjectionBody(
}
}
if (!groupByExpressions.empty()) {
// TODO(Xiyang): we can remove augment group by. But make sure we test sufficient including
// edge case and bug before release.
expression_vector augmentedGroupByExpressions = groupByExpressions;
for (auto& expression : groupByExpressions) {
if (ExpressionUtil::isNodeVariable(*expression)) {
Expand Down
13 changes: 13 additions & 0 deletions src/function/vector_hash_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ void VectorHashFunction::computeHash(ValueVector* operand, ValueVector* result)
case PhysicalTypeID::INTERVAL: {
UnaryHashFunctionExecutor::execute<interval_t, hash_t>(*operand, *result);
} break;
case PhysicalTypeID::STRUCT: {
if (operand->dataType.getLogicalTypeID() == LogicalTypeID::NODE) {
assert(0 == common::StructType::getFieldIdx(&operand->dataType, InternalKeyword::ID));
UnaryHashFunctionExecutor::execute<internalID_t, hash_t>(
*StructVector::getFieldVector(operand, 0), *result);
break;
} else if (operand->dataType.getLogicalTypeID() == LogicalTypeID::REL) {
assert(3 == StructType::getFieldIdx(&operand->dataType, InternalKeyword::ID));
UnaryHashFunctionExecutor::execute<internalID_t, hash_t>(
*StructVector::getFieldVector(operand, 3), *result);
break;
}
}
default: {
throw RuntimeException(
"Cannot hash data type " +
Expand Down
12 changes: 2 additions & 10 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct HashSlot {
*
*/
class AggregateHashTable;
using compare_function_t = std::function<bool(const uint8_t*, const uint8_t*)>;
using compare_function_t = std::function<bool(common::ValueVector*, uint32_t, const uint8_t*)>;
using update_agg_function_t =
std::function<void(AggregateHashTable*, const std::vector<common::ValueVector*>&,
const std::vector<common::ValueVector*>&, std::unique_ptr<function::AggregateFunction>&,
Expand Down Expand Up @@ -181,16 +181,8 @@ class AggregateHashTable : public BaseHashTable {

void resizeHashTableIfNecessary(uint32_t maxNumDistinctHashKeys);

template<typename type>
static bool compareEntryWithKeys(const uint8_t* keyValue, const uint8_t* entry) {
uint8_t result;
kuzu::function::Equals::operation(*(type*)keyValue, *(type*)entry, result,
nullptr /* leftVector */, nullptr /* rightVector */);
return result != 0;
}

static void getCompareEntryWithKeysFunc(
common::PhysicalTypeID physicalType, compare_function_t& func);
const common::LogicalType& logicalType, compare_function_t& func);

void updateNullAggVectorState(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
Expand Down
77 changes: 39 additions & 38 deletions src/optimizer/agg_key_dependency_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,64 +26,65 @@ void AggKeyDependencyOptimizer::visitOperator(planner::LogicalOperator* op) {

void AggKeyDependencyOptimizer::visitAggregate(planner::LogicalOperator* op) {
auto agg = (LogicalAggregate*)op;
auto [keyExpressions, payloadExpressions] =
resolveKeysAndDependentKeys(agg->getKeyExpressions());
agg->setKeyExpressions(keyExpressions);
agg->setDependentKeyExpressions(payloadExpressions);
auto [keys, dependentKeys] = resolveKeysAndDependentKeys(agg->getKeyExpressions());
agg->setKeyExpressions(keys);
agg->setDependentKeyExpressions(dependentKeys);
}

void AggKeyDependencyOptimizer::visitDistinct(planner::LogicalOperator* op) {
auto distinct = (LogicalDistinct*)op;
auto [keyExpressions, payloadExpressions] =
resolveKeysAndDependentKeys(distinct->getKeyExpressions());
distinct->setKeyExpressions(keyExpressions);
distinct->setDependentKeyExpressions(payloadExpressions);
auto [keys, dependentKeys] = resolveKeysAndDependentKeys(distinct->getKeyExpressions());
distinct->setKeyExpressions(keys);
distinct->setDependentKeyExpressions(dependentKeys);
}

std::pair<binder::expression_vector, binder::expression_vector>
AggKeyDependencyOptimizer::resolveKeysAndDependentKeys(const binder::expression_vector& keys) {
AggKeyDependencyOptimizer::resolveKeysAndDependentKeys(const expression_vector& inputKeys) {
// Consider example RETURN a.ID, a.age, COUNT(*).
// We first collect a.ID into primaryKeys. Then collect "a" into primaryVarNames.
// Finally, we loop through all group by keys to put non-primary key properties under name "a"
// into dependentKeyExpressions.

// Collect primary keys from group keys.
std::vector<binder::PropertyExpression*> primaryKeys;
for (auto& expression : keys) {
if (expression->expressionType == PROPERTY) {
auto propertyExpression = (binder::PropertyExpression*)expression.get();
if (propertyExpression->isPrimaryKey() || propertyExpression->isInternalID()) {
primaryKeys.push_back(propertyExpression);
// Collect primary variables from keys.
std::unordered_set<std::string> primaryVarNames;
for (auto& key : inputKeys) {
if (key->expressionType == PROPERTY) {
auto property = (PropertyExpression*)key.get();
if (property->isPrimaryKey() || property->isInternalID()) {
primaryVarNames.insert(property->getVariableName());
}
}
}
// Collect variable names whose primary key is part of group keys.
std::unordered_set<std::string> primaryVarNames;
for (auto& primaryKey : primaryKeys) {
primaryVarNames.insert(primaryKey->getVariableName());
}
binder::expression_vector groupExpressions;
binder::expression_vector dependentExpressions;
for (auto& expression : keys) {
if (expression->expressionType == PROPERTY) {
auto propertyExpression = (binder::PropertyExpression*)expression.get();
if (propertyExpression->isPrimaryKey() ||
propertyExpression->isInternalID()) { // NOLINT(bugprone-branch-clone): Collapsing
// is a logical error.
groupExpressions.push_back(expression);
} else if (primaryVarNames.contains(propertyExpression->getVariableName())) {
dependentExpressions.push_back(expression);
// Resolve key dependency.
binder::expression_vector keys;
binder::expression_vector dependentKeys;
for (auto& key : inputKeys) {
if (key->expressionType == PROPERTY) {
auto property = (PropertyExpression*)key.get();
if (property->isPrimaryKey() ||
property->isInternalID()) { // NOLINT(bugprone-branch-clone): Collapsing
// is a logical error.
// Primary properties are always keys.
keys.push_back(key);
} else if (primaryVarNames.contains(property->getVariableName())) {
// Properties depend on any primary property are dependent keys.
// e.g. a.age depends on a._id
dependentKeys.push_back(key);
} else {
keys.push_back(key);
}
} else if (ExpressionUtil::isNodeVariable(*key) || ExpressionUtil::isRelVariable(*key)) {
if (primaryVarNames.contains(key->getUniqueName())) {
// e.g. a depends on a._id
dependentKeys.push_back(key);
} else {
groupExpressions.push_back(expression);
keys.push_back(key);
}
} else if (ExpressionUtil::isNodeVariable(*expression) ||
ExpressionUtil::isRelVariable(*expression)) {
dependentExpressions.push_back(expression);
} else {
groupExpressions.push_back(expression);
keys.push_back(key);
}
}
return std::make_pair(std::move(groupExpressions), std::move(dependentExpressions));
return std::make_pair(std::move(keys), std::move(dependentKeys));
}

} // namespace optimizer
Expand Down
91 changes: 63 additions & 28 deletions src/processor/operator/aggregate/aggregate_hash_table.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "processor/operator/aggregate/aggregate_hash_table.h"

#include "common/null_buffer.h"
#include "common/utils.h"
#include "function/aggregate/base_count.h"
#include "function/hash/vector_hash_functions.h"
Expand Down Expand Up @@ -138,7 +139,7 @@
for (auto& dataType : keyDataTypes) {
auto size = LogicalTypeUtils::getRowLayoutSize(dataType);
tableSchema->appendColumn(std::make_unique<ColumnSchema>(isUnflat, dataChunkPos, size));
getCompareEntryWithKeysFunc(dataType.getPhysicalType(), compareFuncs[colIdx]);
getCompareEntryWithKeysFunc(dataType, compareFuncs[colIdx]);
numBytesForKeys += size;
colIdx++;
}
Expand Down Expand Up @@ -466,7 +467,6 @@
auto keyVector = keyVectors[i];
assert(keyVector->state->isFlat());
auto pos = keyVector->state->selVector->selectedPositions[0];
auto keyValue = keyVector->getData() + pos * keyVector->getNumBytesPerValue();
auto isKeyVectorNull = keyVector->isNull(pos);
auto isEntryKeyNull = factorizedTable->isNonOverflowColNull(
entry + factorizedTable->getTableSchema()->getNullMapOffset(), i);
Expand All @@ -478,7 +478,7 @@
return false;
}
if (!compareFuncs[i](
keyValue, entry + factorizedTable->getTableSchema()->getColOffset(i))) {
keyVector, pos, entry + factorizedTable->getTableSchema()->getColOffset(i))) {
return false;
}
}
Expand All @@ -494,8 +494,8 @@
if (factorizedTable->hasNoNullGuarantee(colIdx)) {
for (auto i = 0u; i < numMayMatches; i++) {
auto idx = mayMatchIdxes[i];
if (compareFuncs[colIdx](vector->getData() + idx * vector->getNumBytesPerValue(),
hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
if (compareFuncs[colIdx](
vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
mayMatchIdxes[mayMatchIdx++] = idx;
} else {
noMatchIdxes[numNoMatches++] = idx;
Expand All @@ -504,7 +504,6 @@
} else {
for (auto i = 0u; i < numMayMatches; i++) {
auto idx = mayMatchIdxes[i];
auto value = vector->getData() + idx * vector->getNumBytesPerValue();
auto isEntryKeyNull = factorizedTable->isNonOverflowColNull(
hashSlotsToUpdateAggState[idx]->entry +
factorizedTable->getTableSchema()->getNullMapOffset(),
Expand All @@ -514,7 +513,7 @@
continue;
}
if (compareFuncs[colIdx](
value, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {

Check warning on line 516 in src/processor/operator/aggregate/aggregate_hash_table.cpp

View check run for this annotation

Codecov / codecov/patch

src/processor/operator/aggregate/aggregate_hash_table.cpp#L516

Added line #L516 was not covered by tests
mayMatchIdxes[mayMatchIdx++] = idx;
} else {
noMatchIdxes[numNoMatches++] = idx;
Expand All @@ -524,7 +523,6 @@
} else {
for (auto i = 0u; i < numMayMatches; i++) {
auto idx = mayMatchIdxes[i];
auto value = vector->getData() + idx * vector->getNumBytesPerValue();
auto isKeyVectorNull = vector->isNull(idx);
auto isEntryKeyNull = factorizedTable->isNonOverflowColNull(
hashSlotsToUpdateAggState[idx]->entry +
Expand All @@ -538,7 +536,8 @@
continue;
}

if (compareFuncs[colIdx](value, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
if (compareFuncs[colIdx](
vector, idx, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {

Check warning on line 540 in src/processor/operator/aggregate/aggregate_hash_table.cpp

View check run for this annotation

Codecov / codecov/patch

src/processor/operator/aggregate/aggregate_hash_table.cpp#L539-L540

Added lines #L539 - L540 were not covered by tests
mayMatchIdxes[mayMatchIdx++] = idx;
} else {
noMatchIdxes[numNoMatches++] = idx;
Expand All @@ -555,7 +554,6 @@
uint64_t mayMatchIdx = 0;
auto pos = vector->state->selVector->selectedPositions[0];
auto isVectorNull = vector->isNull(pos);
auto value = vector->getData() + pos * vector->getNumBytesPerValue();
for (auto i = 0u; i < numMayMatches; i++) {
auto idx = mayMatchIdxes[i];
auto isEntryKeyNull = factorizedTable->isNonOverflowColNull(
Expand All @@ -569,7 +567,7 @@
noMatchIdxes[numNoMatches++] = idx;
continue;
}
if (compareFuncs[colIdx](value, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
if (compareFuncs[colIdx](vector, pos, hashSlotsToUpdateAggState[idx]->entry + colOffset)) {
mayMatchIdxes[mayMatchIdx++] = idx;
} else {
noMatchIdxes[numNoMatches++] = idx;
Expand Down Expand Up @@ -632,68 +630,105 @@
}
}

template<typename T>
static bool compareEntry(common::ValueVector* vector, uint32_t vectorPos, const uint8_t* entry) {
uint8_t result;
auto key = vector->getData() + vectorPos * vector->getNumBytesPerValue();
kuzu::function::Equals::operation(
*(T*)key, *(T*)entry, result, nullptr /* leftVector */, nullptr /* rightVector */);
return result != 0;
}

static bool compareNodeEntry(
common::ValueVector* vector, uint32_t vectorPos, const uint8_t* entry) {
assert(0 == common::StructType::getFieldIdx(&vector->dataType, common::InternalKeyword::ID));
auto idVector = common::StructVector::getFieldVector(vector, 0).get();
return compareEntry<common::internalID_t>(idVector, vectorPos,
entry + common::NullBuffer::getNumBytesForNullValues(
common::StructType::getNumFields(&vector->dataType)));
}

static bool compareRelEntry(common::ValueVector* vector, uint32_t vectorPos, const uint8_t* entry) {
assert(3 == common::StructType::getFieldIdx(&vector->dataType, common::InternalKeyword::ID));
auto idVector = common::StructVector::getFieldVector(vector, 3).get();
return compareEntry<common::internalID_t>(idVector, vectorPos,
entry + sizeof(common::internalID_t) * 2 + sizeof(common::ku_string_t) +
common::NullBuffer::getNumBytesForNullValues(
common::StructType::getNumFields(&vector->dataType)));
}

void AggregateHashTable::getCompareEntryWithKeysFunc(
PhysicalTypeID physicalType, compare_function_t& func) {
switch (physicalType) {
const LogicalType& logicalType, compare_function_t& func) {
switch (logicalType.getPhysicalType()) {
case PhysicalTypeID::INTERNAL_ID: {
func = compareEntryWithKeys<nodeID_t>;
func = compareEntry<nodeID_t>;
return;
}
case PhysicalTypeID::BOOL: {
func = compareEntryWithKeys<bool>;
func = compareEntry<bool>;
return;
}
case PhysicalTypeID::INT64: {
func = compareEntryWithKeys<int64_t>;
func = compareEntry<int64_t>;
return;
}
case PhysicalTypeID::INT32: {
func = compareEntryWithKeys<int32_t>;
func = compareEntry<int32_t>;
return;
}
case PhysicalTypeID::INT16: {
func = compareEntryWithKeys<int16_t>;
func = compareEntry<int16_t>;

Check warning on line 680 in src/processor/operator/aggregate/aggregate_hash_table.cpp

View check run for this annotation

Codecov / codecov/patch

src/processor/operator/aggregate/aggregate_hash_table.cpp#L680

Added line #L680 was not covered by tests
return;
}
case PhysicalTypeID::INT8: {
func = compareEntryWithKeys<int8_t>;
func = compareEntry<int8_t>;
return;
}
case PhysicalTypeID::UINT64: {
func = compareEntryWithKeys<uint64_t>;
func = compareEntry<uint64_t>;
return;
}
case PhysicalTypeID::UINT32: {
func = compareEntryWithKeys<uint32_t>;
func = compareEntry<uint32_t>;
return;
}
case PhysicalTypeID::UINT16: {
func = compareEntryWithKeys<uint16_t>;
func = compareEntry<uint16_t>;
return;
}
case PhysicalTypeID::UINT8: {
func = compareEntryWithKeys<uint8_t>;
func = compareEntry<uint8_t>;
return;
}
case PhysicalTypeID::DOUBLE: {
func = compareEntryWithKeys<double_t>;
func = compareEntry<double_t>;

Check warning on line 704 in src/processor/operator/aggregate/aggregate_hash_table.cpp

View check run for this annotation

Codecov / codecov/patch

src/processor/operator/aggregate/aggregate_hash_table.cpp#L704

Added line #L704 was not covered by tests
return;
}
case PhysicalTypeID::FLOAT: {
func = compareEntryWithKeys<float_t>;
func = compareEntry<float_t>;
return;
}
case PhysicalTypeID::STRING: {
func = compareEntryWithKeys<ku_string_t>;
func = compareEntry<ku_string_t>;
return;
}
case PhysicalTypeID::INTERVAL: {
func = compareEntryWithKeys<interval_t>;
func = compareEntry<interval_t>;

Check warning on line 716 in src/processor/operator/aggregate/aggregate_hash_table.cpp

View check run for this annotation

Codecov / codecov/patch

src/processor/operator/aggregate/aggregate_hash_table.cpp#L716

Added line #L716 was not covered by tests
return;
}
case PhysicalTypeID::STRUCT: {
if (logicalType.getLogicalTypeID() == LogicalTypeID::NODE) {
func = compareNodeEntry;
return;
} else if (logicalType.getLogicalTypeID() == LogicalTypeID::REL) {
func = compareRelEntry;
return;
}
}
default: {
throw RuntimeException(
"Cannot compare data type " + PhysicalTypeUtils::physicalTypeToString(physicalType));
"Cannot compare data type " +
PhysicalTypeUtils::physicalTypeToString(logicalType.getPhysicalType()));

Check warning on line 731 in src/processor/operator/aggregate/aggregate_hash_table.cpp

View check run for this annotation

Codecov / codecov/patch

src/processor/operator/aggregate/aggregate_hash_table.cpp#L730-L731

Added lines #L730 - L731 were not covered by tests
}
}
}
Expand Down
Loading