Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Mar 21, 2023
1 parent d9d279a commit 3bc81d9
Show file tree
Hide file tree
Showing 15 changed files with 98 additions and 57 deletions.
3 changes: 2 additions & 1 deletion src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ struct LoggerConstants {
};

// Tuning constants used by the planner when estimating plan cardinality and cost.
struct EnumeratorKnobs {
// NOTE(review): this listing carries both this generic constant and the two split constants below;
// presumably it is superseded by them -- confirm there are no remaining users before relying on it.
static constexpr double PREDICATE_SELECTIVITY = 0.1;
// Cardinality multiplier applied when a filter predicate is not an equality comparison.
static constexpr double NON_EQUALITY_PREDICATE_SELECTIVITY = 0.1;
// Cardinality multiplier applied for an equality filter whose operands are not primary keys, and
// for a hash join that has groups to flatten on the probe side.
static constexpr double EQUALITY_PREDICATE_SELECTIVITY = 0.01;
// Cost multiplier applied to the probe plan when a hash join has groups to flatten on the probe side.
static constexpr double FLAT_PROBE_PENALTY = 10;
};

Expand Down
1 change: 0 additions & 1 deletion src/include/main/kuzu_fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ namespace testing {
class ApiTest;
class BaseGraphTest;
class TestHelper;
class TestHelper;
} // namespace testing

namespace benchmark {
Expand Down
4 changes: 0 additions & 4 deletions src/include/planner/join_order_enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ class JoinOrderEnumerator {
void planFiltersForHashJoin(binder::expression_vector& predicates, LogicalPlan& plan);

void appendScanNode(std::shared_ptr<NodeExpression>& node, LogicalPlan& plan);
// Appends an expressions scan over `indexExpression` to `plan`, then an index scan of `node` on
// top of it. The definition asserts that `plan` is empty on entry.
void appendIndexScanNode(std::shared_ptr<NodeExpression>& node,
std::shared_ptr<Expression> indexExpression, LogicalPlan& plan);

// Decides whether extending `rel` from `boundNode` in `direction` must go to a new factorization
// group. Per the comment on the definition, that is the case when the extend might increase
// cardinality (i.e. n * m).
bool needExtendToNewGroup(
RelExpression& rel, NodeExpression& boundNode, common::RelDirection direction);
Expand All @@ -120,8 +118,6 @@ class JoinOrderEnumerator {
std::vector<std::unique_ptr<LogicalPlan>>& buildPlans);
static void appendCrossProduct(LogicalPlan& probePlan, LogicalPlan& buildPlan);

// Returns the property sub-expressions of `expression` whose variable name equals the unique name
// of `variable`, de-duplicated by the property expression's unique name.
binder::expression_vector getPropertiesForVariable(
Expression& expression, Expression& variable);
uint64_t getExtensionRate(
const RelExpression& rel, const NodeExpression& boundNode, common::RelDirection direction);

Expand Down
2 changes: 1 addition & 1 deletion src/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ add_library(kuzu_optimizer
OBJECT
asp_optimizer.cpp
factorization_rewriter.cpp
index_nested_loop_join_optimizer.cpp
filter_push_down_optimizer.cpp
logical_operator_collector.cpp
logical_operator_visitor.cpp
optimizer.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "optimizer/index_nested_loop_join_optimizer.h"
#include "optimizer/filter_push_down_optimizer.h"

#include "binder/expression/property_expression.h"
#include "planner/logical_plan/logical_operator/logical_expressions_scan.h"
Expand Down Expand Up @@ -96,14 +96,6 @@ std::shared_ptr<LogicalOperator> FilterPushDownOptimizer::visitCrossProductRepla
node, primaryKeyEqualityComparison->getChild(1), op->getChild(0));
indexScan->computeFlatSchema();
// Append right branch (except for node table scan) to left branch
// auto rightOp = op->getChild(1);
// while (rightOp->getNumChildren() != 0) {
// if (rightOp->getChild(0)->getOperatorType() == LogicalOperatorType::SCAN_NODE) {
// rightOp->setChild(0, std::move(indexScan));
// break;
// }
// rightOp = rightOp->getChild(0);
// }
trivialSubPlan[trivialSubPlan.size() - 2]->setChild(0, std::move(indexScan));
for (auto i = 0; i < trivialSubPlan.size() - 1; ++i) {
trivialSubPlan[i]->computeFlatSchema();
Expand Down
8 changes: 1 addition & 7 deletions src/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

#include "optimizer/asp_optimizer.h"
#include "optimizer/factorization_rewriter.h"
#include "optimizer/index_nested_loop_join_optimizer.h"
#include "optimizer/filter_push_down_optimizer.h"
#include "optimizer/projection_push_down_optimizer.h"
#include "optimizer/remove_factorization_rewriter.h"
#include "optimizer/remove_unnecessary_join_optimizer.h"
Expand All @@ -18,22 +18,16 @@ void Optimizer::optimize(planner::LogicalPlan* plan) {
auto removeUnnecessaryJoinOptimizer = RemoveUnnecessaryJoinOptimizer();
removeUnnecessaryJoinOptimizer.rewrite(plan);

auto a = plan->toString();

auto filterPushDownOptimizer = FilterPushDownOptimizer();
filterPushDownOptimizer.rewrite(plan);

auto b = plan->toString();

// ASP optimizer should be applied after optimizers that manipulate hash join.
auto aspOptimizer = ASPOptimizer();
aspOptimizer.rewrite(plan);

auto projectionPushDownOptimizer = ProjectionPushDownOptimizer();
projectionPushDownOptimizer.rewrite(plan);

auto c = plan->toString();

auto factorizationRewriter = FactorizationRewriter();
factorizationRewriter.rewrite(plan);
}
Expand Down
33 changes: 1 addition & 32 deletions src/planner/join_order_enumerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,19 +431,6 @@ void JoinOrderEnumerator::appendScanNode(std::shared_ptr<NodeExpression>& node,
plan.setLastOperator(std::move(scan));
}

// Builds an index lookup for `node` on an empty `plan`: first an expressions scan over
// `indexExpression`, then a LogicalIndexScanNode stacked on top of that scan.
void JoinOrderEnumerator::appendIndexScanNode(std::shared_ptr<NodeExpression>& node,
    std::shared_ptr<Expression> indexExpression, LogicalPlan& plan) {
    // This helper only knows how to start a plan, never how to extend one.
    assert(plan.isEmpty());
    QueryPlanner::appendExpressionsScan(expression_vector{indexExpression}, plan);

    // The expressions scan that was just appended becomes the child of the index scan.
    auto indexScan = make_shared<LogicalIndexScanNode>(
        node, std::move(indexExpression), plan.getLastOperator());
    indexScan->computeFactorizedSchema();

    // Update cardinality: pin the multiplier of the group looked up by the node's internal ID
    // property name to 1.
    indexScan->getSchema()->getGroup(node->getInternalIDPropertyName())->setMultiplier(1);

    plan.setLastOperator(std::move(indexScan));
}

// When extend might increase cardinality (i.e. n * m), we extend to a new factorization group.
bool JoinOrderEnumerator::needExtendToNewGroup(
RelExpression& rel, NodeExpression& boundNode, RelDirection direction) {
Expand Down Expand Up @@ -510,7 +497,7 @@ void JoinOrderEnumerator::appendHashJoin(const expression_vector& joinNodeIDs, J
probePlan.increaseCost(probePlan.getCardinality() + buildPlan.getCardinality());
if (!groupsPosToFlattenOnProbeSide.empty()) {
probePlan.multiplyCardinality(
buildPlan.getCardinality() * EnumeratorKnobs::PREDICATE_SELECTIVITY);
buildPlan.getCardinality() * EnumeratorKnobs::EQUALITY_PREDICATE_SELECTIVITY);
probePlan.multiplyCost(EnumeratorKnobs::FLAT_PROBE_PENALTY);
}
probePlan.setLastOperator(std::move(hashJoin));
Expand Down Expand Up @@ -562,24 +549,6 @@ void JoinOrderEnumerator::appendCrossProduct(LogicalPlan& probePlan, LogicalPlan
probePlan.setLastOperator(std::move(crossProduct));
}

// Collects the property sub-expressions of `expression` that belong to `variable`.
//
// A sub-expression is kept when its variable name equals `variable`'s unique name. Each distinct
// property (identified by its unique name) is returned once, in first-seen order.
expression_vector JoinOrderEnumerator::getPropertiesForVariable(
    Expression& expression, Expression& variable) {
    expression_vector result;
    std::unordered_set<std::string> matchedPropertyNames; // remove duplication
    for (auto& expr : expression.getSubPropertyExpressions()) {
        // NOTE(review): assumes every sub-property expression really is a PropertyExpression, as
        // the original unchecked cast did -- the named cast just makes that downcast explicit.
        auto propertyExpression = static_cast<PropertyExpression*>(expr.get());
        if (propertyExpression->getVariableName() != variable.getUniqueName()) {
            continue;
        }
        // insert() reports whether the name was newly added, so one lookup both tests for and
        // records the property.
        if (matchedPropertyNames.insert(propertyExpression->getUniqueName()).second) {
            result.push_back(expr);
        }
    }
    return result;
}

uint64_t JoinOrderEnumerator::getExtensionRate(
const RelExpression& rel, const NodeExpression& boundNode, RelDirection direction) {
double numBoundNodes = 0;
Expand Down
15 changes: 13 additions & 2 deletions src/planner/query_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,16 +317,27 @@ void QueryPlanner::appendFilters(
}
}

static bool isPrimaryKey(const Expression& expression) {
if (expression.expressionType != common::ExpressionType::PROPERTY) {
return false;
}
return ((PropertyExpression&)expression).isPrimaryKey();
}

// Appends a filter on `predicate` to `plan` and updates the plan's cardinality estimate.
//
// Cardinality rules:
//   - equality where either operand is a primary-key property: cardinality is set to 1;
//   - any other equality: multiplied by EQUALITY_PREDICATE_SELECTIVITY;
//   - every other predicate: multiplied by NON_EQUALITY_PREDICATE_SELECTIVITY (exactly once).
void QueryPlanner::appendFilter(const std::shared_ptr<Expression>& predicate, LogicalPlan& plan) {
    planSubqueryIfNecessary(predicate, plan);
    auto filter = make_shared<LogicalFilter>(predicate, plan.getLastOperator());
    // Append the flattens the filter asks for, then re-point the filter at whatever the plan's
    // last operator is afterwards.
    QueryPlanner::appendFlattens(filter->getGroupsPosToFlatten(), plan);
    filter->setChild(0, plan.getLastOperator());
    filter->computeFactorizedSchema();
    if (predicate->expressionType == common::EQUALS) {
        if (isPrimaryKey(*predicate->getChild(0)) || isPrimaryKey(*predicate->getChild(1))) {
            plan.setCardinality(1);
        } else {
            plan.multiplyCardinality(EnumeratorKnobs::EQUALITY_PREDICATE_SELECTIVITY);
        }
    } else {
        // Apply the non-equality selectivity a single time; stacking a second selectivity factor
        // here would shrink the estimate to the equality level.
        plan.multiplyCardinality(EnumeratorKnobs::NON_EQUALITY_PREDICATE_SELECTIVITY);
    }
    plan.setLastOperator(std::move(filter));
}
Expand Down
2 changes: 2 additions & 0 deletions src/planner/subplans_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ void SubPlansTable::PlanSet::addPlan(std::unique_ptr<LogicalPlan> plan) {
auto currentPlan = plans[idx].get();
if (currentPlan->getCost() > plan->getCost()) {
plans[idx] = std::move(plan);
schemaToPlanIdx.erase(schema);
schemaToPlanIdx.insert({schema, idx});
}
}
}
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_subdirectory(common)
add_subdirectory(copy)
add_subdirectory(demo_db)
add_subdirectory(main)
add_subdirectory(optimizer)
add_subdirectory(parser)
add_subdirectory(processor)
add_subdirectory(runner)
Expand Down
3 changes: 3 additions & 0 deletions test/include/test_helper/test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class TestHelper {
return KUZU_ROOT_DIRECTORY + std::string("/") + path;
}

// Prepares `query` on `conn` and returns the first logical plan of the prepared statement,
// transferring ownership of that plan to the caller.
static std::unique_ptr<planner::LogicalPlan> getLogicalPlan(
const std::string& query, Connection& conn);

private:
static void initializeConnection(TestQueryConfig* config, Connection& conn);
static bool testQuery(TestQueryConfig* config, Connection& conn);
Expand Down
1 change: 1 addition & 0 deletions test/optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_kuzu_test(optimizer_test optimizer_test.cpp)
67 changes: 67 additions & 0 deletions test/optimizer/optimizer_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "graph_test/graph_test.h"

namespace kuzu {
namespace testing {

// Fixture for optimizer tests: runs against the tinysnb dataset and exposes the root operator of
// the logical plan produced for a query.
class OptimizerTest : public DBTest {
public:
    // Dataset directory, resolved relative to the Kuzu root.
    std::string getInputDir() override {
        return TestHelper::appendKuzuRootPath("dataset/tinysnb/");
    }

    // Plans `query` on the fixture's connection and returns the plan's last (root) operator.
    std::shared_ptr<planner::LogicalOperator> getRoot(const std::string& query) {
        auto logicalPlan = TestHelper::getLogicalPlan(query, *conn);
        return logicalPlan->getLastOperator();
    }
};

TEST_F(OptimizerTest, FilterPushDownTest) {
    using OpType = planner::LogicalOperatorType;
    // Expected operator types from the root downwards, following child 0 at every step.
    const OpType expectedChain[] = {OpType::PROJECTION, OpType::SCAN_NODE_PROPERTY,
        OpType::FILTER, OpType::SCAN_NODE_PROPERTY, OpType::FILTER,
        OpType::SCAN_NODE_PROPERTY};
    const auto numExpected = sizeof(expectedChain) / sizeof(expectedChain[0]);

    auto current = getRoot("MATCH (a:person) WHERE a.ID < 0 AND a.fName='Alice' RETURN a.gender;");
    ASSERT_EQ(current->getOperatorType(), expectedChain[0]);
    // Descend only as far as the expectations go; the last checked operator's child is never read.
    for (auto depth = 1u; depth < numExpected; ++depth) {
        current = current->getChild(0);
        ASSERT_EQ(current->getOperatorType(), expectedChain[depth]);
    }
}

TEST_F(OptimizerTest, IndexScanTest) {
    using OpType = planner::LogicalOperatorType;
    // Expected operator types from the root downwards, following child 0 at every step.
    const OpType expectedChain[] = {OpType::PROJECTION, OpType::SCAN_NODE_PROPERTY,
        OpType::FILTER, OpType::SCAN_NODE_PROPERTY, OpType::INDEX_SCAN_NODE};
    const auto numExpected = sizeof(expectedChain) / sizeof(expectedChain[0]);

    auto current = getRoot("MATCH (a:person) WHERE a.ID = 0 AND a.fName='Alice' RETURN a.gender;");
    ASSERT_EQ(current->getOperatorType(), expectedChain[0]);
    // Descend only as far as the expectations go; the last checked operator's child is never read.
    for (auto depth = 1u; depth < numExpected; ++depth) {
        current = current->getChild(0);
        ASSERT_EQ(current->getOperatorType(), expectedChain[depth]);
    }
}

TEST_F(OptimizerTest, RemoveUnnecessaryJoinTest) {
    using OpType = planner::LogicalOperatorType;
    // Expected operator types from the root downwards, following child 0 at every step.
    const OpType expectedChain[] = {
        OpType::PROJECTION, OpType::EXTEND, OpType::FLATTEN, OpType::SCAN_NODE};
    const auto numExpected = sizeof(expectedChain) / sizeof(expectedChain[0]);

    auto current = getRoot("MATCH (a:person)-[e:knows]->(b:person) RETURN e.date;");
    ASSERT_EQ(current->getOperatorType(), expectedChain[0]);
    // Descend only as far as the expectations go; the last checked operator's child is never read.
    for (auto depth = 1u; depth < numExpected; ++depth) {
        current = current->getChild(0);
        ASSERT_EQ(current->getOperatorType(), expectedChain[depth]);
    }
}

TEST_F(OptimizerTest, ProjectionPushDownJoinTest) {
    using OpType = planner::LogicalOperatorType;

    // Root of the plan must be a projection.
    const auto root = getRoot(
        "MATCH (a:person)-[e:knows]->(b:person) WHERE a.age > 0 AND b.age>0 RETURN a.ID, b.ID;");
    ASSERT_EQ(root->getOperatorType(), OpType::PROJECTION);

    // Directly below it sits the hash join.
    const auto join = root->getChild(0);
    ASSERT_EQ(join->getOperatorType(), OpType::HASH_JOIN);

    // The join's child at index 1 must itself be a projection.
    const auto joinSecondChild = join->getChild(1);
    ASSERT_EQ(joinSecondChild->getOperatorType(), OpType::PROJECTION);
}

} // namespace testing
} // namespace kuzu
5 changes: 5 additions & 0 deletions test/test_helper/test_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,10 @@ bool TestHelper::testQuery(TestQueryConfig* config, Connection& conn) {
return numPassedPlans == numPlans;
}

// Prepares `query` on `conn` and hands back the first logical plan of the prepared statement.
// Ownership of that plan moves to the caller; the prepared statement itself is discarded.
std::unique_ptr<planner::LogicalPlan> TestHelper::getLogicalPlan(
    const std::string& query, kuzu::main::Connection& conn) {
    auto preparedStatement = conn.prepare(query);
    auto& firstPlan = preparedStatement->logicalPlans[0];
    // firstPlan is an element owned by the statement, so it has to be moved out explicitly.
    return std::move(firstPlan);
}

} // namespace testing
} // namespace kuzu

0 comments on commit 3bc81d9

Please sign in to comment.