Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parsed statement visitor #2396

Merged
merged 1 commit into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/binder/bound_statement_visitor.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "binder/bound_statement_visitor.h"

#include "binder/bound_explain.h"
#include "common/exception/not_implemented.h"
#include "binder/query/bound_regular_query.h"

using namespace kuzu::common;

Expand All @@ -11,7 +11,7 @@ namespace binder {
void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement) {
switch (statement.getStatementType()) {
case StatementType::QUERY: {
visitRegularQuery((BoundRegularQuery&)statement);
visitRegularQuery(statement);
} break;
case StatementType::CREATE_TABLE: {
visitCreateTable(statement);
Expand Down Expand Up @@ -44,11 +44,12 @@ void BoundStatementVisitor::visit(const kuzu::binder::BoundStatement& statement)
visitTransaction(statement);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visit");
KU_UNREACHABLE;
}
}

void BoundStatementVisitor::visitRegularQuery(const BoundRegularQuery& regularQuery) {
void BoundStatementVisitor::visitRegularQuery(const BoundStatement& statement) {
auto& regularQuery = reinterpret_cast<const BoundRegularQuery&>(statement);
for (auto i = 0u; i < regularQuery.getNumSingleQueries(); ++i) {
visitSingleQuery(*regularQuery.getSingleQuery(i));
}
Expand Down Expand Up @@ -94,7 +95,7 @@ void BoundStatementVisitor::visitReadingClause(const BoundReadingClause& reading
visitLoadFrom(readingClause);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visitReadingClause");
KU_UNREACHABLE;
}
}

Expand All @@ -113,7 +114,7 @@ void BoundStatementVisitor::visitUpdatingClause(const BoundUpdatingClause& updat
visitMerge(updatingClause);
} break;
default:
throw NotImplementedException("BoundStatementVisitor::visitUpdatingClause");
KU_UNREACHABLE;
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/binder/visitor/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
add_library(
kuzu_binder_visitor
OBJECT
property_collector.cpp
statement_read_write_analyzer.cpp)
property_collector.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_binder_visitor>
Expand Down
18 changes: 0 additions & 18 deletions src/binder/visitor/statement_read_write_analyzer.cpp

This file was deleted.

6 changes: 3 additions & 3 deletions src/include/binder/bound_statement_visitor.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include "binder/query/bound_regular_query.h"
#include "binder/query/normalized_single_query.h"
#include "bound_statement.h"

namespace kuzu {
Expand All @@ -13,9 +13,7 @@ class BoundStatementVisitor {

void visit(const BoundStatement& statement);

virtual void visitRegularQuery(const BoundRegularQuery& regularQuery);
virtual void visitSingleQuery(const NormalizedSingleQuery& singleQuery);
virtual void visitQueryPart(const NormalizedQueryPart& queryPart);

protected:
virtual void visitCreateTable(const BoundStatement& statement) {}
Expand All @@ -29,6 +27,8 @@ class BoundStatementVisitor {
virtual void visitCreateMacro(const BoundStatement& statement) {}
virtual void visitTransaction(const BoundStatement& statement) {}

virtual void visitRegularQuery(const BoundStatement& statement);
virtual void visitQueryPart(const NormalizedQueryPart& queryPart);
void visitReadingClause(const BoundReadingClause& readingClause);
virtual void visitMatch(const BoundReadingClause& readingClause) {}
virtual void visitUnwind(const BoundReadingClause& readingClause) {}
Expand Down
29 changes: 0 additions & 29 deletions src/include/binder/visitor/statement_read_write_analyzer.h

This file was deleted.

54 changes: 54 additions & 0 deletions src/include/parser/parsed_statement_visitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#pragma once

#include "statement.h"

namespace kuzu {
namespace parser {

class SingleQuery;
class QueryPart;
class ReadingClause;
class UpdatingClause;
class WithClause;
class ReturnClause;

class StatementVisitor {
public:
StatementVisitor() = default;
virtual ~StatementVisitor() = default;

Check warning on line 18 in src/include/parser/parsed_statement_visitor.h

View check run for this annotation

Codecov / codecov/patch

src/include/parser/parsed_statement_visitor.h#L18

Added line #L18 was not covered by tests

void visit(const Statement& statement);

private:
// LCOV_EXCL_START
virtual void visitQuery(const Statement& statement);
virtual void visitSingleQuery(const SingleQuery* singleQuery);
virtual void visitQueryPart(const QueryPart* queryPart);
virtual void visitReadingClause(const ReadingClause* readingClause);
virtual void visitMatch(const ReadingClause* readingClause) {}
virtual void visitUnwind(const ReadingClause* readingClause) {}
virtual void visitInQueryCall(const ReadingClause* readingClause) {}
virtual void visitLoadFrom(const ReadingClause* readingClause) {}
virtual void visitUpdatingClause(const UpdatingClause* updatingClause);
virtual void visitSet(const UpdatingClause* updatingClause) {}
virtual void visitDelete(const UpdatingClause* updatingClause) {}
virtual void visitInsert(const UpdatingClause* updatingClause) {}
virtual void visitMerge(const UpdatingClause* updatingClause) {}
virtual void visitWithClause(const WithClause* withClause) {}
virtual void visitReturnClause(const ReturnClause* returnClause) {}

virtual void visitCreateTable(const Statement& statement) {}
virtual void visitDropTable(const Statement& statement) {}
virtual void visitAlter(const Statement& statement) {}
virtual void visitCopyFrom(const Statement& statement) {}
virtual void visitCopyTo(const Statement& statement) {}
virtual void visitStandaloneCall(const Statement& statement) {}
virtual void visitExplain(const Statement& statement);
virtual void visitCreateMacro(const Statement& statement) {}
virtual void visitCommentOn(const Statement& statement) {}
virtual void visitTransaction(const Statement& statement) {}
// LCOV_EXCL_STOP
};

} // namespace parser
} // namespace kuzu
30 changes: 30 additions & 0 deletions src/include/parser/visitor/statement_read_write_analyzer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include "parser/parsed_statement_visitor.h"

namespace kuzu {
namespace parser {

class StatementReadWriteAnalyzer final : public StatementVisitor {
public:
StatementReadWriteAnalyzer() : StatementVisitor{}, readOnly{true} {}

bool isReadOnly(const Statement& statement);

private:
inline void visitCreateTable(const Statement& /*statement*/) { readOnly = false; }
inline void visitDropTable(const Statement& /*statement*/) { readOnly = false; }
inline void visitAlter(const Statement& /*statement*/) { readOnly = false; }
inline void visitCopyFrom(const Statement& /*statement*/) { readOnly = false; }
inline void visitStandaloneCall(const Statement& /*statement*/) { readOnly = false; }
inline void visitCreateMacro(const Statement& /*statement*/) { readOnly = false; }
inline void visitCommentOn(const Statement& /*statement*/) { readOnly = false; }

inline void visitUpdatingClause(const UpdatingClause* /*updatingClause*/) { readOnly = false; }

private:
bool readOnly;
};

} // namespace parser
} // namespace kuzu
5 changes: 2 additions & 3 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "main/connection.h"

#include "binder/binder.h"
#include "binder/visitor/statement_read_write_analyzer.h"
#include "common/exception/connection.h"
#include "main/database.h"
#include "optimizer/optimizer.h"
#include "parser//visitor/statement_read_write_analyzer.h"
#include "parser/parser.h"
#include "planner/operator/logical_plan_util.h"
#include "planner/planner.h"
Expand Down Expand Up @@ -96,13 +96,12 @@ std::unique_ptr<PreparedStatement> Connection::prepareNoLock(
try {
// parsing
auto statement = Parser::parseQuery(query);
preparedStatement->readOnly = parser::StatementReadWriteAnalyzer().isReadOnly(*statement);
// binding
auto binder = Binder(*database->catalog, database->memoryManager.get(),
database->storageManager.get(), clientContext.get());
auto boundStatement = binder.bind(*statement);
preparedStatement->preparedSummary.statementType = boundStatement->getStatementType();
preparedStatement->readOnly =
binder::StatementReadWriteAnalyzer().isReadOnly(*boundStatement);
preparedStatement->parameterMap = binder.getParameterMap();
preparedStatement->statementResult = boundStatement->getStatementResult()->copy();
// planning
Expand Down
6 changes: 4 additions & 2 deletions src/parser/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
add_subdirectory(antlr_parser)
add_subdirectory(expression)
add_subdirectory(transform)
add_subdirectory(visitor)

add_library(kuzu_parser
OBJECT
create_macro.cpp
parsed_expression_visitor.cpp
parser.cpp
transformer.cpp
parsed_expression_visitor.cpp)
parsed_statement_visitor.cpp
transformer.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_parser>
Expand Down
126 changes: 126 additions & 0 deletions src/parser/parsed_statement_visitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include "parser/parsed_statement_visitor.h"

#include "common/assert.h"
#include "parser/explain_statement.h"
#include "parser/query/regular_query.h"

using namespace kuzu::common;

namespace kuzu {
namespace parser {

void StatementVisitor::visit(const Statement& statement) {
switch (statement.getStatementType()) {
case StatementType::QUERY: {
visitQuery(statement);
} break;
case StatementType::CREATE_TABLE: {
visitCreateTable(statement);
} break;
case StatementType::DROP_TABLE: {
visitDropTable(statement);
} break;
case StatementType::ALTER: {
visitAlter(statement);
} break;
case StatementType::COPY_FROM: {
visitCopyFrom(statement);
} break;
case StatementType::COPY_TO: {
visitCopyTo(statement);
} break;
case StatementType::STANDALONE_CALL: {
visitStandaloneCall(statement);
} break;
case StatementType::EXPLAIN: {
visitExplain(statement);
} break;
case StatementType::CREATE_MACRO: {
visitCreateMacro(statement);
} break;
case StatementType::COMMENT_ON: {
visitCommentOn(statement);
} break;
case StatementType::TRANSACTION: {
visitTransaction(statement);
} break;
default:

Check warning on line 47 in src/parser/parsed_statement_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/parser/parsed_statement_visitor.cpp#L47

Added line #L47 was not covered by tests
KU_UNREACHABLE;
}
}

void StatementVisitor::visitExplain(const Statement& statement) {
auto& explainStatement = reinterpret_cast<const ExplainStatement&>(statement);
visit(*explainStatement.getStatementToExplain());
}

void StatementVisitor::visitQuery(const Statement& statement) {
auto& regularQuery = reinterpret_cast<const RegularQuery&>(statement);
for (auto i = 0u; i < regularQuery.getNumSingleQueries(); ++i) {
visitSingleQuery(regularQuery.getSingleQuery(i));
}
}

void StatementVisitor::visitSingleQuery(const SingleQuery* singleQuery) {
for (auto i = 0u; i < singleQuery->getNumQueryParts(); ++i) {
visitQueryPart(singleQuery->getQueryPart(i));
}
for (auto i = 0u; i < singleQuery->getNumReadingClauses(); ++i) {
visitReadingClause(singleQuery->getReadingClause(i));
}
for (auto i = 0u; i < singleQuery->getNumUpdatingClauses(); ++i) {
visitUpdatingClause(singleQuery->getUpdatingClause(i));
}
visitReturnClause(singleQuery->getReturnClause());
}

void StatementVisitor::visitQueryPart(const QueryPart* queryPart) {
for (auto i = 0u; i < queryPart->getNumReadingClauses(); ++i) {
visitReadingClause(queryPart->getReadingClause(i));
}
for (auto i = 0u; i < queryPart->getNumUpdatingClauses(); ++i) {
visitUpdatingClause(queryPart->getUpdatingClause(i));
}
visitWithClause(queryPart->getWithClause());
}

void StatementVisitor::visitReadingClause(const ReadingClause* readingClause) {
switch (readingClause->getClauseType()) {
case ClauseType::MATCH: {
visitMatch(readingClause);
} break;
case ClauseType::UNWIND: {
visitUnwind(readingClause);
} break;
case ClauseType::IN_QUERY_CALL: {
visitInQueryCall(readingClause);
} break;
case ClauseType::LOAD_FROM: {
visitLoadFrom(readingClause);
} break;
default:

Check warning on line 101 in src/parser/parsed_statement_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/parser/parsed_statement_visitor.cpp#L101

Added line #L101 was not covered by tests
KU_UNREACHABLE;
}
}

void StatementVisitor::visitUpdatingClause(const UpdatingClause* updatingClause) {
switch (updatingClause->getClauseType()) {
case ClauseType::SET: {
visitSet(updatingClause);
} break;
case ClauseType::DELETE_: {
visitDelete(updatingClause);
} break;
case ClauseType::INSERT: {
visitInsert(updatingClause);
} break;
case ClauseType::MERGE: {
visitMerge(updatingClause);
} break;
default:

Check warning on line 120 in src/parser/parsed_statement_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/parser/parsed_statement_visitor.cpp#L106-L120

Added lines #L106 - L120 were not covered by tests
KU_UNREACHABLE;
}
}

Check warning on line 123 in src/parser/parsed_statement_visitor.cpp

View check run for this annotation

Codecov / codecov/patch

src/parser/parsed_statement_visitor.cpp#L123

Added line #L123 was not covered by tests

} // namespace parser
} // namespace kuzu
Loading