Skip to content

Commit

Permalink
Merge pull request #2889 from kuzudb/multiple_queries
Browse files Browse the repository at this point in the history
add support for multiple query statements
  • Loading branch information
hououou committed Feb 25, 2024
2 parents 1ce0728 + 373fb2e commit 0ba33b8
Show file tree
Hide file tree
Showing 17 changed files with 2,941 additions and 2,690 deletions.
5 changes: 4 additions & 1 deletion scripts/antlr4/Cypher.g4.copy
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ grammar Cypher;
virtual void notifyNonBinaryComparison(antlr4::Token* startToken) {};
}

ku_Statements
: oC_Cypher ( SP? ';' SP? oC_Cypher )* SP? EOF ;

oC_Cypher
: SP ? oC_AnyCypherOption? SP? ( oC_Statement ) ( SP? ';' )? SP? EOF ;
: oC_AnyCypherOption? SP? ( oC_Statement ) ( SP? ';' )?;

oC_Statement
: oC_Query
Expand Down
5 changes: 4 additions & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ grammar Cypher;
virtual void notifyNonBinaryComparison(antlr4::Token* startToken) {};
}

ku_Statements
: oC_Cypher ( SP? ';' SP? oC_Cypher )* SP? EOF ;

oC_Cypher
: SP ? oC_AnyCypherOption? SP? ( oC_Statement ) ( SP? ';' )? SP? EOF ;
: oC_AnyCypherOption? SP? ( oC_Statement ) ( SP? ';' )?;

oC_Statement
: oC_Query
Expand Down
3 changes: 2 additions & 1 deletion src/binder/bind/bind_export_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ ExportedTableData Binder::extractExportData(std::string selQuery, std::string ta
auto parsedStatement = Parser::parseQuery(selQuery);
ExportedTableData exportedTableData;
exportedTableData.tableName = tableName;
KU_ASSERT(parsedStatement.size() == 1);
auto parsedQuery =
ku_dynamic_cast<const Statement*, const RegularQuery*>(parsedStatement.get());
ku_dynamic_cast<const Statement*, const RegularQuery*>(parsedStatement[0].get());
auto query = bindQuery(*parsedQuery);
auto columns = query->getStatementResult()->getColumns();
for (auto& column : columns) {
Expand Down
10 changes: 8 additions & 2 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "client_context.h"
#include "database.h"
#include "function/udf_function.h"
#include "parser/statement.h"
#include "prepared_statement.h"
#include "query_result.h"

Expand Down Expand Up @@ -142,11 +143,16 @@ class Connection {
inline ClientContext* getClientContext() { return clientContext.get(); };

private:
std::unique_ptr<QueryResult> query(std::string_view query, std::string_view encodedJoin);
std::unique_ptr<QueryResult> query(
std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans = true);

std::unique_ptr<QueryResult> queryResultWithError(std::string_view errMsg);

std::unique_ptr<PreparedStatement> prepareNoLock(std::string_view query,
std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

std::vector<std::unique_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<PreparedStatement> prepareNoLock(parser::Statement* parsedStatement,
bool enumerateAllPlans = false, std::string_view joinOrder = std::string_view());

template<typename T, typename... Args>
Expand Down
24 changes: 23 additions & 1 deletion src/include/main/query_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ struct DataTypeInfo {
*/
class QueryResult {
friend class Connection;
class QueryResultIterator {
private:
QueryResult* currentResult;

public:
explicit QueryResultIterator(QueryResult* startResult) : currentResult(startResult) {}

void operator++() {
if (currentResult) {
currentResult = currentResult->nextQueryResult.get();
}
}

bool isEnd() { return currentResult == nullptr; }

QueryResult* getCurrentResult() { return currentResult; }
};

public:
/**
Expand Down Expand Up @@ -79,11 +96,16 @@ class QueryResult {
* @return whether there are more tuples to read.
*/
KUZU_API bool hasNext() const;
std::unique_ptr<QueryResult> nextQueryResult;

std::string toSingleQueryString();
/**
* @return next flat tuple in the query result.
*/
KUZU_API std::shared_ptr<processor::FlatTuple> getNext();

/**
* @return string of query result.
*/
KUZU_API std::string toString();

/**
Expand Down
3 changes: 2 additions & 1 deletion src/include/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <memory>
#include <string_view>
#include <vector>

#include "statement.h"

Expand All @@ -11,7 +12,7 @@ namespace parser {
class Parser {

public:
static std::unique_ptr<Statement> parseQuery(std::string_view query);
static std::vector<std::unique_ptr<Statement>> parseQuery(std::string_view query);
};

} // namespace parser
Expand Down
6 changes: 3 additions & 3 deletions src/include/parser/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ struct ParsedCaseAlternative;

class Transformer {
public:
explicit Transformer(CypherParser::OC_CypherContext& root) : root{root} {}
explicit Transformer(CypherParser::Ku_StatementsContext& root) : root{root} {}

std::unique_ptr<Statement> transform();
std::vector<std::unique_ptr<Statement>> transform();

private:
std::unique_ptr<Statement> transformStatement(CypherParser::OC_StatementContext& ctx);
Expand Down Expand Up @@ -217,7 +217,7 @@ class Transformer {
std::unique_ptr<Statement> transformCommentOn(CypherParser::KU_CommentOnContext& ctx);

private:
CypherParser::OC_CypherContext& root;
CypherParser::Ku_StatementsContext& root;
};

} // namespace parser
Expand Down
83 changes: 63 additions & 20 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,45 +44,79 @@ uint64_t Connection::getMaxNumThreadForExec() {
}

std::unique_ptr<PreparedStatement> Connection::prepare(std::string_view query) {
auto preparedStatement = std::unique_ptr<PreparedStatement>();
std::unique_lock<std::mutex> lck{mtx};
return prepareNoLock(query);
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
} catch (std::exception& exception) { return preparedStatementWithError(exception.what()); }
if (parsedStatements.size() > 1) {
return preparedStatementWithError(
"Connection Exception: We do not support prepare multiple statements.");
}
if (parsedStatements.empty()) {
return preparedStatementWithError("Connection Exception: Query is empty.");
}
return prepareNoLock(parsedStatements[0].get());
}

std::unique_ptr<QueryResult> Connection::query(std::string_view query) {
lock_t lck{mtx};
auto preparedStatement = prepareNoLock(query);
return executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
std::unique_ptr<QueryResult> Connection::query(std::string_view queryStatement) {
return query(queryStatement, std::string_view() /*encodedJoin*/, false /*enumerateAllPlans */);
}

std::unique_ptr<QueryResult> Connection::query(
std::string_view query, std::string_view encodedJoin) {
std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans) {
lock_t lck{mtx};
auto preparedStatement = prepareNoLock(query, true /* enumerate all plans */, encodedJoin);
return executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
// parsing
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
} catch (std::exception& exception) { return queryResultWithError(exception.what()); }
if (parsedStatements.empty()) {
return queryResultWithError("Connection Exception: Query is empty.");
}
std::unique_ptr<QueryResult> queryResult;
QueryResult* lastResult = nullptr;
for (auto& statement : parsedStatements) {
auto preparedStatement = prepareNoLock(
statement.get(), enumerateAllPlans /* enumerate all plans */, encodedJoin);
auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
if (!lastResult) {
// first result of the query
queryResult = std::move(currentQueryResult);
lastResult = queryResult.get();
} else {
lastResult->nextQueryResult = std::move(currentQueryResult);
lastResult = lastResult->nextQueryResult.get();
}
}
return queryResult;
}

std::unique_ptr<QueryResult> Connection::queryResultWithError(std::string_view errMsg) {
auto queryResult = std::make_unique<QueryResult>();
queryResult->success = false;
queryResult->errMsg = errMsg;
queryResult->nextQueryResult = nullptr;
return queryResult;
}

std::unique_ptr<PreparedStatement> Connection::preparedStatementWithError(std::string_view errMsg) {
auto preparedStatement = std::make_unique<PreparedStatement>();
preparedStatement->success = false;
preparedStatement->errMsg = errMsg;
return preparedStatement;
}

std::unique_ptr<PreparedStatement> Connection::prepareNoLock(
std::string_view query, bool enumerateAllPlans, std::string_view encodedJoin) {
Statement* parsedStatement, bool enumerateAllPlans, std::string_view encodedJoin) {
auto preparedStatement = std::make_unique<PreparedStatement>();
if (query.empty()) {
preparedStatement->success = false;
preparedStatement->errMsg = "Connection Exception: Query is empty.";
return preparedStatement;
}
auto compilingTimer = TimeMetric(true /* enable */);
compilingTimer.start();
std::unique_ptr<Statement> statement;
try {
statement = Parser::parseQuery(query);
preparedStatement->preparedSummary.statementType = statement->getStatementType();
preparedStatement->readOnly = parser::StatementReadWriteAnalyzer().isReadOnly(*statement);
preparedStatement->preparedSummary.statementType = parsedStatement->getStatementType();
preparedStatement->readOnly =
parser::StatementReadWriteAnalyzer().isReadOnly(*parsedStatement);
if (database->systemConfig.readOnly && !preparedStatement->isReadOnly()) {
throw ConnectionException("Cannot execute write operations in a read-only database!");
}
Expand All @@ -97,7 +131,7 @@ std::unique_ptr<PreparedStatement> Connection::prepareNoLock(
std::unique_ptr<LogicalPlan> logicalPlan;
try {
// parsing
if (statement->getStatementType() != StatementType::TRANSACTION) {
if (parsedStatement->getStatementType() != StatementType::TRANSACTION) {
auto txContext = clientContext->transactionContext.get();
if (txContext->isAutoTransaction()) {
txContext->beginAutoTransaction(preparedStatement->readOnly);
Expand All @@ -114,7 +148,7 @@ std::unique_ptr<PreparedStatement> Connection::prepareNoLock(
auto binder = Binder(*database->catalog, database->memoryManager.get(),
database->storageManager.get(), database->vfs.get(), clientContext.get(),
database->extensionOptions.get());
auto boundStatement = binder.bind(*statement);
auto boundStatement = binder.bind(*parsedStatement);
preparedStatement->parameterMap = binder.getParameterMap();
preparedStatement->statementResult =
std::make_unique<BoundStatementResult>(boundStatement->getStatementResult()->copy());
Expand Down Expand Up @@ -155,6 +189,15 @@ std::unique_ptr<PreparedStatement> Connection::prepareNoLock(
return preparedStatement;
}

std::vector<std::unique_ptr<Statement>> Connection::parseQuery(std::string_view query) {
std::vector<std::unique_ptr<Statement>> statements;
if (query.empty()) {
return statements;
}
statements = Parser::parseQuery(query);
return statements;
}

void Connection::interrupt() {
clientContext->interrupt();
}
Expand Down
11 changes: 11 additions & 0 deletions src/main/query_result.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ QueryResult::QueryResult() = default;
QueryResult::QueryResult(const PreparedSummary& preparedSummary) {
querySummary = std::make_unique<QuerySummary>();
querySummary->setPreparedSummary(preparedSummary);
nextQueryResult = nullptr;
}

QueryResult::~QueryResult() = default;
Expand Down Expand Up @@ -162,6 +163,16 @@ std::shared_ptr<FlatTuple> QueryResult::getNext() {
}

std::string QueryResult::toString() {
std::string result;
QueryResultIterator it(this);
while (!it.isEnd()) {
result += it.getCurrentResult()->toSingleQueryString() + "\n";
++it;
}
return result;
}

std::string QueryResult::toSingleQueryString() {
std::string result;
if (isSuccess()) {
// print header
Expand Down
4 changes: 2 additions & 2 deletions src/parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using namespace antlr4;
namespace kuzu {
namespace parser {

std::unique_ptr<Statement> Parser::parseQuery(std::string_view query) {
std::vector<std::unique_ptr<Statement>> Parser::parseQuery(std::string_view query) {
auto inputStream = ANTLRInputStream(query);
auto parserErrorListener = ParserErrorListener();

Expand All @@ -31,7 +31,7 @@ std::unique_ptr<Statement> Parser::parseQuery(std::string_view query) {
kuzuCypherParser.addErrorListener(&parserErrorListener);
kuzuCypherParser.setErrorHandler(std::make_shared<ParserErrorStrategy>());

Transformer transformer(*kuzuCypherParser.oC_Cypher());
Transformer transformer(*kuzuCypherParser.ku_Statements());
return transformer.transform();
}

Expand Down
22 changes: 14 additions & 8 deletions src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@ using namespace kuzu::common;
namespace kuzu {
namespace parser {

std::unique_ptr<Statement> Transformer::transform() {
auto statement = transformStatement(*root.oC_Statement());
if (root.oC_AnyCypherOption()) {
auto cypherOption = root.oC_AnyCypherOption();
auto explainType =
cypherOption->oC_Explain() ? ExplainType::PHYSICAL_PLAN : ExplainType::PROFILE;
return std::make_unique<ExplainStatement>(std::move(statement), explainType);
std::vector<std::unique_ptr<Statement>> Transformer::transform() {
std::vector<std::unique_ptr<Statement>> statements;
for (auto& oc_Statement : root.oC_Cypher()) {
auto statement = transformStatement(*oc_Statement->oC_Statement());
if (oc_Statement->oC_AnyCypherOption()) {
auto cypherOption = oc_Statement->oC_AnyCypherOption();
auto explainType =
cypherOption->oC_Explain() ? ExplainType::PHYSICAL_PLAN : ExplainType::PROFILE;
statements.push_back(
std::make_unique<ExplainStatement>(std::move(statement), explainType));
continue;
}
statements.push_back(std::move(statement));
}
return statement;
return statements;
}

std::unique_ptr<Statement> Transformer::transformStatement(CypherParser::OC_StatementContext& ctx) {
Expand Down
3 changes: 2 additions & 1 deletion test/graph_test/graph_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ void PrivateGraphTest::validateQueryBestPlanJoinOrder(
std::string query, std::string expectedJoinOrder) {
auto catalog = getCatalog(*database);
auto statement = parser::Parser::parseQuery(query);
auto parsedQuery = (parser::RegularQuery*)statement.get();
ASSERT_EQ(statement.size(), 1);
auto parsedQuery = (parser::RegularQuery*)statement[0].get();
auto boundQuery =
Binder(*catalog, database->memoryManager.get(), database->storageManager.get(),
database->vfs.get(), conn->clientContext.get(), database->extensionOptions.get())
Expand Down
Loading

0 comments on commit 0ba33b8

Please sign in to comment.