diff --git a/src/include/planner/planner.h b/src/include/planner/planner.h index 1a7c133db5..0036ac1e93 100644 --- a/src/include/planner/planner.h +++ b/src/include/planner/planner.h @@ -38,6 +38,16 @@ class Planner { static std::unique_ptr planExplain(const catalog::Catalog& catalog, const storage::NodesStatisticsAndDeletedIDs& nodesStatistics, const storage::RelsStatistics& relsStatistics, const BoundStatement& statement); + + static std::vector> getAllQueryPlans( + const catalog::Catalog& catalog, + const storage::NodesStatisticsAndDeletedIDs& nodesStatistics, + const storage::RelsStatistics& relsStatistics, const BoundStatement& statement); + + static std::vector> getAllExplainPlans( + const catalog::Catalog& catalog, + const storage::NodesStatisticsAndDeletedIDs& nodesStatistics, + const storage::RelsStatistics& relsStatistics, const BoundStatement& statement); }; } // namespace planner diff --git a/src/planner/planner.cpp b/src/planner/planner.cpp index 41f0c25c5f..1ddfa3b701 100644 --- a/src/planner/planner.cpp +++ b/src/planner/planner.cpp @@ -18,7 +18,6 @@ #include "planner/logical_plan/logical_operator/logical_drop_property.h" #include "planner/logical_plan/logical_operator/logical_drop_table.h" #include "planner/logical_plan/logical_operator/logical_explain.h" -#include "planner/logical_plan/logical_operator/logical_in_query_call.h" #include "planner/logical_plan/logical_operator/logical_rename_property.h" #include "planner/logical_plan/logical_operator/logical_rename_table.h" #include "planner/logical_plan/logical_operator/logical_standalone_call.h" @@ -78,16 +77,16 @@ std::unique_ptr Planner::getBestPlan(const Catalog& catalog, std::vector> Planner::getAllPlans(const Catalog& catalog, const NodesStatisticsAndDeletedIDs& nodesStatistics, const RelsStatistics& relsStatistics, const BoundStatement& statement) { - // We enumerate all plans for our testing framework. This API should only be used for QUERY - // but not DDL or COPY. - assert(statement.getStatementType() == StatementType::QUERY); - auto planner = QueryPlanner(catalog, nodesStatistics, relsStatistics); - std::vector> plans; - for (auto& plan : planner.getAllPlans(statement)) { - // Avoid sharing operator across plans. - plans.push_back(plan->deepCopy()); + // We enumerate all plans for our testing framework. This API should only be used for QUERY, + // EXPLAIN, but not DDL or COPY. + switch (statement.getStatementType()) { + case StatementType::QUERY: + return getAllQueryPlans(catalog, nodesStatistics, relsStatistics, statement); + case StatementType::EXPLAIN: + return getAllExplainPlans(catalog, nodesStatistics, relsStatistics, statement); + default: + throw NotImplementedException("Planner::getAllPlans"); } - return plans; } std::unique_ptr Planner::planCreateNodeTable(const BoundStatement& statement) { @@ -207,5 +206,33 @@ std::unique_ptr Planner::planExplain(const Catalog& catalog, return plan; } +std::vector> Planner::getAllQueryPlans(const catalog::Catalog& catalog, + const storage::NodesStatisticsAndDeletedIDs& nodesStatistics, + const storage::RelsStatistics& relsStatistics, const BoundStatement& statement) { + auto planner = QueryPlanner(catalog, nodesStatistics, relsStatistics); + std::vector> plans; + for (auto& plan : planner.getAllPlans(statement)) { + // Avoid sharing operator across plans. + plans.push_back(plan->deepCopy()); + } + return plans; +} + +std::vector> Planner::getAllExplainPlans( + const catalog::Catalog& catalog, const storage::NodesStatisticsAndDeletedIDs& nodesStatistics, + const storage::RelsStatistics& relsStatistics, const BoundStatement& statement) { + auto& explainStatement = reinterpret_cast(statement); + auto statementToExplain = explainStatement.getStatementToExplain(); + auto plans = getAllPlans(catalog, nodesStatistics, relsStatistics, *statementToExplain); + for (auto& plan : plans) { + auto logicalExplain = make_shared(plan->getLastOperator(), + statement.getStatementResult()->getSingleExpressionToCollect(), + explainStatement.getExplainType(), + explainStatement.getStatementToExplain()->getStatementResult()->getColumns()); + plan->setLastOperator(std::move(logicalExplain)); + } + return plans; +} + } // namespace planner } // namespace kuzu diff --git a/test/test_files/tinysnb/explain/explain.test b/test/test_files/tinysnb/explain/explain.test index e8ce06b07e..88e71d84a7 100644 --- a/test/test_files/tinysnb/explain/explain.test +++ b/test/test_files/tinysnb/explain/explain.test @@ -21,6 +21,7 @@ -LOG ExplainQuery -STATEMENT EXPLAIN MATCH (p:npytable) RETURN p.id +-ENUMERATE ---- ok -LOG ProfileDDL @@ -37,4 +38,5 @@ -LOG ProfileQuery -STATEMENT Profile MATCH (p:npytable) RETURN p.id +-ENUMERATE ---- ok diff --git a/tools/benchmark/benchmark.cpp b/tools/benchmark/benchmark.cpp index fcad0d42ad..e8af3240f8 100644 --- a/tools/benchmark/benchmark.cpp +++ b/tools/benchmark/benchmark.cpp @@ -22,8 +22,7 @@ void Benchmark::loadBenchmark(const std::string& benchmarkPath) { auto queryConfigs = testing::TestHelper::parseTestFile(benchmarkPath); assert(queryConfigs.size() == 1); auto queryConfig = queryConfigs[0].get(); - query = config.enableProfile ? "PROFILE " : ""; - query += queryConfig->query; + query = queryConfig->query; name = queryConfig->name; expectedOutput = queryConfig->expectedTuples; encodedJoin = queryConfig->encodedJoin; @@ -33,6 +32,10 @@ std::unique_ptr Benchmark::run() const { return conn->query(query, encodedJoin); } +std::unique_ptr Benchmark::runWithProfile() const { + return conn->query("PROFILE " + query, encodedJoin); +} + void Benchmark::logQueryInfo( std::ofstream& log, uint32_t runNum, std::vector& actualOutput) const { log << "Run Num: " << runNum << std::endl; diff --git a/tools/benchmark/benchmark_runner.cpp b/tools/benchmark/benchmark_runner.cpp index a73fdd4984..73f7bf3375 100644 --- a/tools/benchmark/benchmark_runner.cpp +++ b/tools/benchmark/benchmark_runner.cpp @@ -1,6 +1,7 @@ #include "benchmark_runner.h" #include +#include #include "spdlog/spdlog.h" @@ -62,6 +63,7 @@ void BenchmarkRunner::runBenchmark(Benchmark* benchmark) const { spdlog::info("Warm up"); benchmark->run(); } + profileQueryIfEnabled(benchmark); std::vector runTimes(config->numRuns); for (auto i = 0u; i < config->numRuns; ++i) { auto queryResult = benchmark->run(); @@ -74,5 +76,16 @@ void BenchmarkRunner::runBenchmark(Benchmark* benchmark) const { &runTimes[0], config->numRuns, config->numRuns /* numRunsToAverage */))); } +void BenchmarkRunner::profileQueryIfEnabled(Benchmark* benchmark) const { + if (config->enableProfile && !config->outputPath.empty()) { + auto profileInfo = benchmark->runWithProfile(); + std::ofstream profileFile( + config->outputPath + "/" + benchmark->name + "_profile.txt", std::ios_base::app); + profileFile << profileInfo->getNext()->toString() << std::endl; + profileFile.flush(); + profileFile.close(); + } +} + } // namespace benchmark } // namespace kuzu diff --git a/tools/benchmark/include/benchmark.h b/tools/benchmark/include/benchmark.h index da03de6196..e45bdd7c0d 100644 --- a/tools/benchmark/include/benchmark.h +++ b/tools/benchmark/include/benchmark.h @@ -15,6 +15,7 @@ class Benchmark { Benchmark(const std::string& benchmarkPath, main::Database* database, BenchmarkConfig& config); std::unique_ptr run() const; + std::unique_ptr runWithProfile() const; void log(uint32_t runNum, main::QueryResult& queryResult) const; private: diff --git a/tools/benchmark/include/benchmark_runner.h b/tools/benchmark/include/benchmark_runner.h index 8b30659b8f..ab409a3308 100644 --- a/tools/benchmark/include/benchmark_runner.h +++ b/tools/benchmark/include/benchmark_runner.h @@ -22,6 +22,8 @@ class BenchmarkRunner { void runBenchmark(Benchmark* benchmark) const; + void profileQueryIfEnabled(Benchmark* benchmark) const; + public: std::unique_ptr config; std::unique_ptr database;