Skip to content

Commit

Permalink
Fix hash node rel
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Oct 12, 2023
1 parent 96b3f7e commit 53acf45
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 92 deletions.
2 changes: 2 additions & 0 deletions src/binder/bind/bind_projection_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ std::unique_ptr<BoundProjectionBody> Binder::bindProjectionBody(
}
}
if (!groupByExpressions.empty()) {
// TODO(Xiyang): we can remove the augmented group by. But first make sure it is sufficiently
// tested, including edge cases and bugs, before release.
expression_vector augmentedGroupByExpressions = groupByExpressions;
for (auto& expression : groupByExpressions) {
if (ExpressionUtil::isNodeVariable(*expression)) {
Expand Down
55 changes: 41 additions & 14 deletions src/function/vector_hash_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,73 @@ void VectorHashFunction::computeHash(ValueVector* operand, ValueVector* result)
switch (operand->dataType.getPhysicalType()) {
case PhysicalTypeID::INTERNAL_ID: {
UnaryHashFunctionExecutor::execute<internalID_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::BOOL: {
UnaryHashFunctionExecutor::execute<bool, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::INT64: {
UnaryHashFunctionExecutor::execute<int64_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::INT32: {
UnaryHashFunctionExecutor::execute<int32_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::INT16: {
UnaryHashFunctionExecutor::execute<int16_t, hash_t>(*operand, *result);
} break;
}
return;

Check warning on line 33 in src/function/vector_hash_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_hash_functions.cpp#L33

Added line #L33 was not covered by tests
case PhysicalTypeID::INT8: {
UnaryHashFunctionExecutor::execute<int8_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::UINT64: {
UnaryHashFunctionExecutor::execute<uint64_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::UINT32: {
UnaryHashFunctionExecutor::execute<uint32_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::UINT16: {
UnaryHashFunctionExecutor::execute<uint16_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::UINT8: {
UnaryHashFunctionExecutor::execute<uint8_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::DOUBLE: {
UnaryHashFunctionExecutor::execute<double, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::FLOAT: {
UnaryHashFunctionExecutor::execute<float_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::STRING: {
UnaryHashFunctionExecutor::execute<ku_string_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::INTERVAL: {
UnaryHashFunctionExecutor::execute<interval_t, hash_t>(*operand, *result);
} break;
}
return;
case PhysicalTypeID::STRUCT: {
if (operand->dataType.getLogicalTypeID() == LogicalTypeID::NODE) {

Check warning on line 71 in src/function/vector_hash_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_hash_functions.cpp#L70-L71

Added lines #L70 - L71 were not covered by tests
assert(0 == common::StructType::getFieldIdx(&operand->dataType, InternalKeyword::ID));
auto idVector = StructVector::getFieldVector(operand, 0);
UnaryHashFunctionExecutor::execute<internalID_t, hash_t>(*idVector, *result);

Check warning on line 74 in src/function/vector_hash_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_hash_functions.cpp#L74

Added line #L74 was not covered by tests
return;
} else if (operand->dataType.getLogicalTypeID() == LogicalTypeID::REL) {

Check warning on line 76 in src/function/vector_hash_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_hash_functions.cpp#L76

Added line #L76 was not covered by tests
assert(3 == StructType::getFieldIdx(&operand->dataType, InternalKeyword::ID));
auto idVector = StructVector::getFieldVector(operand, 3);
UnaryHashFunctionExecutor::execute<internalID_t, hash_t>(*idVector, *result);

Check warning on line 79 in src/function/vector_hash_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_hash_functions.cpp#L79

Added line #L79 was not covered by tests
return;
}
}
default: {
throw RuntimeException(
"Cannot hash data type " +
Expand Down
12 changes: 2 additions & 10 deletions src/include/processor/operator/aggregate/aggregate_hash_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct HashSlot {
*
*/
class AggregateHashTable;
using compare_function_t = std::function<bool(const uint8_t*, const uint8_t*)>;
using compare_function_t = std::function<bool(common::ValueVector*, uint32_t, const uint8_t*)>;
using update_agg_function_t =
std::function<void(AggregateHashTable*, const std::vector<common::ValueVector*>&,
const std::vector<common::ValueVector*>&, std::unique_ptr<function::AggregateFunction>&,
Expand Down Expand Up @@ -181,16 +181,8 @@ class AggregateHashTable : public BaseHashTable {

void resizeHashTableIfNecessary(uint32_t maxNumDistinctHashKeys);

template<typename type>
static bool compareEntryWithKeys(const uint8_t* keyValue, const uint8_t* entry) {
uint8_t result;
kuzu::function::Equals::operation(*(type*)keyValue, *(type*)entry, result,
nullptr /* leftVector */, nullptr /* rightVector */);
return result != 0;
}

static void getCompareEntryWithKeysFunc(
common::PhysicalTypeID physicalType, compare_function_t& func);
const common::LogicalType& logicalType, compare_function_t& func);

void updateNullAggVectorState(const std::vector<common::ValueVector*>& flatKeyVectors,
const std::vector<common::ValueVector*>& unFlatKeyVectors,
Expand Down
77 changes: 39 additions & 38 deletions src/optimizer/agg_key_dependency_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,64 +26,65 @@ void AggKeyDependencyOptimizer::visitOperator(planner::LogicalOperator* op) {

void AggKeyDependencyOptimizer::visitAggregate(planner::LogicalOperator* op) {
auto agg = (LogicalAggregate*)op;
auto [keyExpressions, payloadExpressions] =
resolveKeysAndDependentKeys(agg->getKeyExpressions());
agg->setKeyExpressions(keyExpressions);
agg->setDependentKeyExpressions(payloadExpressions);
auto [keys, dependentKeys] = resolveKeysAndDependentKeys(agg->getKeyExpressions());
agg->setKeyExpressions(keys);
agg->setDependentKeyExpressions(dependentKeys);
}

void AggKeyDependencyOptimizer::visitDistinct(planner::LogicalOperator* op) {
auto distinct = (LogicalDistinct*)op;
auto [keyExpressions, payloadExpressions] =
resolveKeysAndDependentKeys(distinct->getKeyExpressions());
distinct->setKeyExpressions(keyExpressions);
distinct->setDependentKeyExpressions(payloadExpressions);
auto [keys, dependentKeys] = resolveKeysAndDependentKeys(distinct->getKeyExpressions());
distinct->setKeyExpressions(keys);
distinct->setDependentKeyExpressions(dependentKeys);
}

std::pair<binder::expression_vector, binder::expression_vector>
AggKeyDependencyOptimizer::resolveKeysAndDependentKeys(const binder::expression_vector& keys) {
AggKeyDependencyOptimizer::resolveKeysAndDependentKeys(const expression_vector& inputKeys) {
// Consider example RETURN a.ID, a.age, COUNT(*).
// We first collect a.ID into primaryKeys. Then collect "a" into primaryVarNames.
// Finally, we loop through all group by keys to put non-primary key properties under name "a"
// into dependentKeyExpressions.

// Collect primary keys from group keys.
std::vector<binder::PropertyExpression*> primaryKeys;
for (auto& expression : keys) {
if (expression->expressionType == PROPERTY) {
auto propertyExpression = (binder::PropertyExpression*)expression.get();
if (propertyExpression->isPrimaryKey() || propertyExpression->isInternalID()) {
primaryKeys.push_back(propertyExpression);
// Collect primary variables from keys.
std::unordered_set<std::string> primaryVarNames;
for (auto& key : inputKeys) {
if (key->expressionType == PROPERTY) {
auto property = (PropertyExpression*)key.get();
if (property->isPrimaryKey() || property->isInternalID()) {
primaryVarNames.insert(property->getVariableName());
}
}
}
// Collect variable names whose primary key is part of group keys.
std::unordered_set<std::string> primaryVarNames;
for (auto& primaryKey : primaryKeys) {
primaryVarNames.insert(primaryKey->getVariableName());
}
binder::expression_vector groupExpressions;
binder::expression_vector dependentExpressions;
for (auto& expression : keys) {
if (expression->expressionType == PROPERTY) {
auto propertyExpression = (binder::PropertyExpression*)expression.get();
if (propertyExpression->isPrimaryKey() ||
propertyExpression->isInternalID()) { // NOLINT(bugprone-branch-clone): Collapsing
// is a logical error.
groupExpressions.push_back(expression);
} else if (primaryVarNames.contains(propertyExpression->getVariableName())) {
dependentExpressions.push_back(expression);
// Resolve key dependency.
binder::expression_vector keys;
binder::expression_vector dependentKeys;
for (auto& key : inputKeys) {
if (key->expressionType == PROPERTY) {
auto property = (PropertyExpression*)key.get();
if (property->isPrimaryKey() ||
property->isInternalID()) { // NOLINT(bugprone-branch-clone): Collapsing
// is a logical error.
// Primary properties are always keys.
keys.push_back(key);
} else if (primaryVarNames.contains(property->getVariableName())) {
// Properties that depend on a primary property are dependent keys.
// e.g. a.age depends on a._id
dependentKeys.push_back(key);
} else {
keys.push_back(key);
}
} else if (ExpressionUtil::isNodeVariable(*key) || ExpressionUtil::isRelVariable(*key)) {
if (primaryVarNames.contains(key->getUniqueName())) {
// e.g. a depends on a._id
dependentKeys.push_back(key);
} else {
groupExpressions.push_back(expression);
keys.push_back(key);

Check warning on line 81 in src/optimizer/agg_key_dependency_optimizer.cpp

View check run for this annotation

Codecov / codecov/patch

src/optimizer/agg_key_dependency_optimizer.cpp#L81

Added line #L81 was not covered by tests
}
} else if (ExpressionUtil::isNodeVariable(*expression) ||
ExpressionUtil::isRelVariable(*expression)) {
dependentExpressions.push_back(expression);
} else {
groupExpressions.push_back(expression);
keys.push_back(key);
}
}
return std::make_pair(std::move(groupExpressions), std::move(dependentExpressions));
return std::make_pair(std::move(keys), std::move(dependentKeys));
}

} // namespace optimizer
Expand Down
Loading

0 comments on commit 53acf45

Please sign in to comment.