From 24c320423af167955b6ff00ebadac21bb3cf65db Mon Sep 17 00:00:00 2001 From: xiyang Date: Sun, 12 Nov 2023 21:57:52 +0800 Subject: [PATCH] Add parsed statement visitor --- src/binder/bound_statement_visitor.cpp | 13 +- src/binder/visitor/CMakeLists.txt | 3 +- .../visitor/statement_read_write_analyzer.cpp | 18 --- src/include/binder/bound_statement_visitor.h | 6 +- .../visitor/statement_read_write_analyzer.h | 29 ---- src/include/parser/parsed_statement_visitor.h | 54 ++++++++ .../visitor/statement_read_write_analyzer.h | 30 +++++ src/main/connection.cpp | 5 +- src/parser/CMakeLists.txt | 6 +- src/parser/parsed_statement_visitor.cpp | 126 ++++++++++++++++++ src/parser/visitor/CMakeLists.txt | 8 ++ .../visitor/statement_read_write_analyzer.cpp | 12 ++ 12 files changed, 247 insertions(+), 63 deletions(-) delete mode 100644 src/binder/visitor/statement_read_write_analyzer.cpp delete mode 100644 src/include/binder/visitor/statement_read_write_analyzer.h create mode 100644 src/include/parser/parsed_statement_visitor.h create mode 100644 src/include/parser/visitor/statement_read_write_analyzer.h create mode 100644 src/parser/parsed_statement_visitor.cpp create mode 100644 src/parser/visitor/CMakeLists.txt create mode 100644 src/parser/visitor/statement_read_write_analyzer.cpp diff --git a/src/binder/bound_statement_visitor.cpp b/src/binder/bound_statement_visitor.cpp index 34df8018ca..834435871e 100644 --- a/src/binder/bound_statement_visitor.cpp +++ b/src/binder/bound_statement_visitor.cpp @@ -1,7 +1,7 @@ #include "binder/bound_statement_visitor.h" #include "binder/bound_explain.h" -#include "common/exception/not_implemented.h" +#include "binder/query/bound_regular_query.h" using namespace kuzu::common; @@ -11,7 +11,7 @@ namespace binder { void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement) { switch (statement.getStatementType()) { case StatementType::QUERY: { - visitRegularQuery((BoundRegularQuery&)statement); + visitRegularQuery(statement); } break; case StatementType::CREATE_TABLE: { visitCreateTable(statement); @@ -44,11 +44,12 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement) visitTransaction(statement); } break; default: - throw NotImplementedException("BoundStatementVisitor::visit"); + KU_UNREACHABLE; } } -void BoundStatementVisitor::visitRegularQuery(const BoundRegularQuery& regularQuery) { +void BoundStatementVisitor::visitRegularQuery(const BoundStatement& statement) { + auto& regularQuery = reinterpret_cast(statement); for (auto i = 0u; i < regularQuery.getNumSingleQueries(); ++i) { visitSingleQuery(*regularQuery.getSingleQuery(i)); } @@ -94,7 +95,7 @@ void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& reading visitLoadFrom(readingClause); } break; default: - throw NotImplementedException("BoundStatementVisitor::visitReadingClause"); + KU_UNREACHABLE; } } @@ -113,7 +114,7 @@ void BoundStatementVisitor::visitUpdatingClause(const BoundUpdatingClause& updat visitMerge(updatingClause); } break; default: - throw NotImplementedException("BoundStatementVisitor::visitUpdatingClause"); + KU_UNREACHABLE; } } diff --git a/src/binder/visitor/CMakeLists.txt b/src/binder/visitor/CMakeLists.txt index d352a04838..ddf93f29dd 100644 --- a/src/binder/visitor/CMakeLists.txt +++ b/src/binder/visitor/CMakeLists.txt @@ -1,8 +1,7 @@ add_library( kuzu_binder_visitor OBJECT - property_collector.cpp - statement_read_write_analyzer.cpp) + property_collector.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/binder/visitor/statement_read_write_analyzer.cpp b/src/binder/visitor/statement_read_write_analyzer.cpp deleted file mode 100644 index adc3cff9b2..0000000000 --- a/src/binder/visitor/statement_read_write_analyzer.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "binder/visitor/statement_read_write_analyzer.h" - -namespace kuzu { -namespace binder { - -bool StatementReadWriteAnalyzer::isReadOnly(const kuzu::binder::BoundStatement& statement) { - visit(statement); - return readOnly; -} - -void StatementReadWriteAnalyzer::visitQueryPart(const NormalizedQueryPart& queryPart) { - if (queryPart.hasUpdatingClause()) { - readOnly = false; - } -} - -} // namespace binder -} // namespace kuzu diff --git a/src/include/binder/bound_statement_visitor.h b/src/include/binder/bound_statement_visitor.h index 2af74fd2bb..c0f464dd7c 100644 --- a/src/include/binder/bound_statement_visitor.h +++ b/src/include/binder/bound_statement_visitor.h @@ -1,6 +1,6 @@ #pragma once -#include "binder/query/bound_regular_query.h" +#include "binder/query/normalized_single_query.h" #include "bound_statement.h" namespace kuzu { @@ -13,9 +13,7 @@ class BoundStatementVisitor { void visit(const BoundStatement& statement); - virtual void visitRegularQuery(const BoundRegularQuery& regularQuery); virtual void visitSingleQuery(const NormalizedSingleQuery& singleQuery); - virtual void visitQueryPart(const NormalizedQueryPart& queryPart); protected: virtual void visitCreateTable(const BoundStatement& statement) {} @@ -29,6 +27,8 @@ class BoundStatementVisitor { virtual void visitCreateMacro(const BoundStatement& statement) {} virtual void visitTransaction(const BoundStatement& statement) {} + virtual void visitRegularQuery(const BoundStatement& statement); + virtual void visitQueryPart(const NormalizedQueryPart& queryPart); void visitReadingClause(const BoundReadingClause& readingClause); virtual void visitMatch(const BoundReadingClause& readingClause) {} virtual void visitUnwind(const BoundReadingClause& readingClause) {} diff --git a/src/include/binder/visitor/statement_read_write_analyzer.h b/src/include/binder/visitor/statement_read_write_analyzer.h deleted file mode 100644 index a4cd0d96f5..0000000000 --- a/src/include/binder/visitor/statement_read_write_analyzer.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include "binder/bound_statement_visitor.h" - -namespace kuzu { -namespace binder { - -class StatementReadWriteAnalyzer final : public BoundStatementVisitor { -public: - StatementReadWriteAnalyzer() : BoundStatementVisitor(), readOnly{true} {} - - bool isReadOnly(const BoundStatement& statement); - -private: - void visitCommentOn(const BoundStatement& /*statement*/) { readOnly = false; } - void visitCopyFrom(const BoundStatement& /*statement*/) { readOnly = false; } - void visitCreateMacro(const BoundStatement& /*statement*/) { readOnly = false; } - void visitCreateTable(const BoundStatement& /*statement*/) { readOnly = false; } - void visitDropTable(const BoundStatement& /*statement*/) { readOnly = false; } - void visitAlter(const BoundStatement& /*statement*/) { readOnly = false; } - void visitStandaloneCall(const BoundStatement& /*statement*/) { readOnly = false; } - void visitQueryPart(const NormalizedQueryPart& queryPart); - -private: - bool readOnly; -}; - -} // namespace binder -} // namespace kuzu diff --git a/src/include/parser/parsed_statement_visitor.h b/src/include/parser/parsed_statement_visitor.h new file mode 100644 index 0000000000..cf7f502f9c --- /dev/null +++ b/src/include/parser/parsed_statement_visitor.h @@ -0,0 +1,54 @@ +#pragma once + +#include "statement.h" + +namespace kuzu { +namespace parser { + +class SingleQuery; +class QueryPart; +class ReadingClause; +class UpdatingClause; +class WithClause; +class ReturnClause; + +class StatementVisitor { +public: + StatementVisitor() = default; + virtual ~StatementVisitor() = default; + + void visit(const Statement& statement); + +private: + // LCOV_EXCL_START + virtual void visitQuery(const Statement& statement); + virtual void visitSingleQuery(const SingleQuery* singleQuery); + virtual void visitQueryPart(const QueryPart* queryPart); + virtual void visitReadingClause(const ReadingClause* readingClause); + virtual void visitMatch(const ReadingClause* readingClause) {} + virtual void visitUnwind(const ReadingClause* readingClause) {} + virtual void visitInQueryCall(const ReadingClause* readingClause) {} + virtual void visitLoadFrom(const ReadingClause* readingClause) {} + virtual void visitUpdatingClause(const UpdatingClause* updatingClause); + virtual void visitSet(const UpdatingClause* updatingClause) {} + virtual void visitDelete(const UpdatingClause* updatingClause) {} + virtual void visitInsert(const UpdatingClause* updatingClause) {} + virtual void visitMerge(const UpdatingClause* updatingClause) {} + virtual void visitWithClause(const WithClause* withClause) {} + virtual void visitReturnClause(const ReturnClause* returnClause) {} + + virtual void visitCreateTable(const Statement& statement) {} + virtual void visitDropTable(const Statement& statement) {} + virtual void visitAlter(const Statement& statement) {} + virtual void visitCopyFrom(const Statement& statement) {} + virtual void visitCopyTo(const Statement& statement) {} + virtual void visitStandaloneCall(const Statement& statement) {} + virtual void visitExplain(const Statement& statement); + virtual void visitCreateMacro(const Statement& statement) {} + virtual void visitCommentOn(const Statement& statement) {} + virtual void visitTransaction(const Statement& statement) {} + // LCOV_EXCL_STOP +}; + +} // namespace parser +} // namespace kuzu diff --git a/src/include/parser/visitor/statement_read_write_analyzer.h b/src/include/parser/visitor/statement_read_write_analyzer.h new file mode 100644 index 0000000000..b34fb49e60 --- /dev/null +++ b/src/include/parser/visitor/statement_read_write_analyzer.h @@ -0,0 +1,30 @@ +#pragma once + +#include "parser/parsed_statement_visitor.h" + +namespace kuzu { +namespace parser { + +class StatementReadWriteAnalyzer final : public StatementVisitor { +public: + StatementReadWriteAnalyzer() : StatementVisitor{}, readOnly{true} {} + + bool isReadOnly(const Statement& statement); + +private: + inline void visitCreateTable(const Statement& /*statement*/) { readOnly = false; } + inline void visitDropTable(const Statement& /*statement*/) { readOnly = false; } + inline void visitAlter(const Statement& /*statement*/) { readOnly = false; } + inline void visitCopyFrom(const Statement& /*statement*/) { readOnly = false; } + inline void visitStandaloneCall(const Statement& /*statement*/) { readOnly = false; } + inline void visitCreateMacro(const Statement& /*statement*/) { readOnly = false; } + inline void visitCommentOn(const Statement& /*statement*/) { readOnly = false; } + + inline void visitUpdatingClause(const UpdatingClause* /*updatingClause*/) { readOnly = false; } + +private: + bool readOnly; +}; + +} // namespace parser +} // namespace kuzu diff --git a/src/main/connection.cpp b/src/main/connection.cpp index fbb8cb604d..257450bf36 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -1,10 +1,10 @@ #include "main/connection.h" #include "binder/binder.h" -#include "binder/visitor/statement_read_write_analyzer.h" #include "common/exception/connection.h" #include "main/database.h" #include "optimizer/optimizer.h" +#include "parser//visitor/statement_read_write_analyzer.h" #include "parser/parser.h" #include "planner/operator/logical_plan_util.h" #include "planner/planner.h" @@ -96,13 +96,12 @@ std::unique_ptr Connection::prepareNoLock( try { // parsing auto statement = Parser::parseQuery(query); + preparedStatement->readOnly = parser::StatementReadWriteAnalyzer().isReadOnly(*statement); // binding auto binder = Binder(*database->catalog, database->memoryManager.get(), database->storageManager.get(), clientContext.get()); auto boundStatement = binder.bind(*statement); preparedStatement->preparedSummary.statementType = boundStatement->getStatementType(); - preparedStatement->readOnly = - binder::StatementReadWriteAnalyzer().isReadOnly(*boundStatement); preparedStatement->parameterMap = binder.getParameterMap(); preparedStatement->statementResult = boundStatement->getStatementResult()->copy(); // planning diff --git a/src/parser/CMakeLists.txt b/src/parser/CMakeLists.txt index 0446366e61..a068d7460d 100644 --- a/src/parser/CMakeLists.txt +++ b/src/parser/CMakeLists.txt @@ -1,13 +1,15 @@ add_subdirectory(antlr_parser) add_subdirectory(expression) add_subdirectory(transform) +add_subdirectory(visitor) add_library(kuzu_parser OBJECT create_macro.cpp + parsed_expression_visitor.cpp parser.cpp - transformer.cpp - parsed_expression_visitor.cpp) + parsed_statement_visitor.cpp + transformer.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/parser/parsed_statement_visitor.cpp b/src/parser/parsed_statement_visitor.cpp new file mode 100644 index 0000000000..3b62946180 --- /dev/null +++ b/src/parser/parsed_statement_visitor.cpp @@ -0,0 +1,126 @@ +#include "parser/parsed_statement_visitor.h" + +#include "common/assert.h" +#include "parser/explain_statement.h" +#include "parser/query/regular_query.h" + +using namespace kuzu::common; + +namespace kuzu { +namespace parser { + +void StatementVisitor::visit(const Statement& statement) { + switch (statement.getStatementType()) { + case StatementType::QUERY: { + visitQuery(statement); + } break; + case StatementType::CREATE_TABLE: { + visitCreateTable(statement); + } break; + case StatementType::DROP_TABLE: { + visitDropTable(statement); + } break; + case StatementType::ALTER: { + visitAlter(statement); + } break; + case StatementType::COPY_FROM: { + visitCopyFrom(statement); + } break; + case StatementType::COPY_TO: { + visitCopyTo(statement); + } break; + case StatementType::STANDALONE_CALL: { + visitStandaloneCall(statement); + } break; + case StatementType::EXPLAIN: { + visitExplain(statement); + } break; + case StatementType::CREATE_MACRO: { + visitCreateMacro(statement); + } break; + case StatementType::COMMENT_ON: { + visitCommentOn(statement); + } break; + case StatementType::TRANSACTION: { + visitTransaction(statement); + } break; + default: + KU_UNREACHABLE; + } +} + +void StatementVisitor::visitExplain(const Statement& statement) { + auto& explainStatement = reinterpret_cast(statement); + visit(*explainStatement.getStatementToExplain()); +} + +void StatementVisitor::visitQuery(const Statement& statement) { + auto& regularQuery = reinterpret_cast(statement); + for (auto i = 0u; i < regularQuery.getNumSingleQueries(); ++i) { + visitSingleQuery(regularQuery.getSingleQuery(i)); + } +} + +void StatementVisitor::visitSingleQuery(const SingleQuery* singleQuery) { + for (auto i = 0u; i < singleQuery->getNumQueryParts(); ++i) { + visitQueryPart(singleQuery->getQueryPart(i)); + } + for (auto i = 0u; i < singleQuery->getNumReadingClauses(); ++i) { + visitReadingClause(singleQuery->getReadingClause(i)); + } + for (auto i = 0u; i < singleQuery->getNumUpdatingClauses(); ++i) { + visitUpdatingClause(singleQuery->getUpdatingClause(i)); + } + visitReturnClause(singleQuery->getReturnClause()); +} + +void StatementVisitor::visitQueryPart(const QueryPart* queryPart) { + for (auto i = 0u; i < queryPart->getNumReadingClauses(); ++i) { + visitReadingClause(queryPart->getReadingClause(i)); + } + for (auto i = 0u; i < queryPart->getNumUpdatingClauses(); ++i) { + visitUpdatingClause(queryPart->getUpdatingClause(i)); + } + visitWithClause(queryPart->getWithClause()); +} + +void StatementVisitor::visitReadingClause(const ReadingClause* readingClause) { + switch (readingClause->getClauseType()) { + case ClauseType::MATCH: { + visitMatch(readingClause); + } break; + case ClauseType::UNWIND: { + visitUnwind(readingClause); + } break; + case ClauseType::IN_QUERY_CALL: { + visitInQueryCall(readingClause); + } break; + case ClauseType::LOAD_FROM: { + visitLoadFrom(readingClause); + } break; + default: + KU_UNREACHABLE; + } +} + +void StatementVisitor::visitUpdatingClause(const UpdatingClause* updatingClause) { + switch (updatingClause->getClauseType()) { + case ClauseType::SET: { + visitSet(updatingClause); + } break; + case ClauseType::DELETE_: { + visitDelete(updatingClause); + } break; + case ClauseType::INSERT: { + visitInsert(updatingClause); + } break; + case ClauseType::MERGE: { + visitMerge(updatingClause); + } break; + default: + KU_UNREACHABLE; + } +} + +} // namespace parser +} // namespace kuzu diff --git a/src/parser/visitor/CMakeLists.txt b/src/parser/visitor/CMakeLists.txt new file mode 100644 index 0000000000..a490682266 --- /dev/null +++ b/src/parser/visitor/CMakeLists.txt @@ -0,0 +1,8 @@ +add_library( + kuzu_parser_visitor + OBJECT + statement_read_write_analyzer.cpp) + +set(ALL_OBJECT_FILES + ${ALL_OBJECT_FILES} $ + PARENT_SCOPE) diff --git a/src/parser/visitor/statement_read_write_analyzer.cpp b/src/parser/visitor/statement_read_write_analyzer.cpp new file mode 100644 index 0000000000..9b559ed830 --- /dev/null +++ b/src/parser/visitor/statement_read_write_analyzer.cpp @@ -0,0 +1,12 @@ +#include "parser/visitor/statement_read_write_analyzer.h" + +namespace kuzu { +namespace parser { + +bool StatementReadWriteAnalyzer::isReadOnly(const Statement& statement) { + visit(statement); + return readOnly; +} + +} // namespace parser +} // namespace kuzu