Skip to content

Commit

Permalink
Prepare Statement Improvement (#3140)
Browse files Browse the repository at this point in the history
Prepare Statement Improvement
  • Loading branch information
hououou committed Mar 28, 2024
1 parent c747899 commit 015bf23
Show file tree
Hide file tree
Showing 19 changed files with 234 additions and 155 deletions.
5 changes: 5 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ class Binder {

std::unique_ptr<BoundStatement> bind(const parser::Statement& statement);

void setInputParameters(
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameters) {
expressionBinder.parameterMap = parameters;
}

inline std::unordered_map<std::string, std::shared_ptr<common::Value>> getParameterMap() {
return expressionBinder.parameterMap;
}
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/expression/parameter_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class ParameterExpression : public Expression {
public:
explicit ParameterExpression(
const std::string& parameterName, std::shared_ptr<common::Value> value)
: Expression{common::ExpressionType::PARAMETER,
common::LogicalType(common::LogicalTypeID::ANY), createUniqueName(parameterName)},
: Expression{common::ExpressionType::PARAMETER, common::LogicalType(*value->getDataType()),
createUniqueName(parameterName)},
parameterName(parameterName), value{std::move(value)} {}

void cast(const common::LogicalType& type) override;
Expand Down
20 changes: 15 additions & 5 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ class ClientContext {
std::unique_ptr<QueryResult> query(std::string_view queryStatement);
void runQuery(std::string query);

// TODO(Jiamin): should remove after supporting ddl in manual tx
std::unique_ptr<PreparedStatement> prepareTest(std::string_view query);
// only use for test framework
std::vector<std::shared_ptr<parser::Statement>> parseQuery(std::string_view query);

private:
std::unique_ptr<QueryResult> query(
std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans = true);
Expand All @@ -114,10 +119,15 @@ class ClientContext {

std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

std::vector<std::unique_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<PreparedStatement> prepareNoLock(parser::Statement* parsedStatement,
bool enumerateAllPlans = false, std::string_view joinOrder = std::string_view());
// when we do prepare, we will start a transaction for the query
// when we execute after prepare in a same context, we set requireNewTx to false and will not
// commit the transaction in prepare when we only prepare a query statement, we set requireNewTx
// to true and will commit the transaction in prepare
std::unique_ptr<PreparedStatement> prepareNoLock(
std::shared_ptr<parser::Statement> parsedStatement, bool enumerateAllPlans = false,
std::string_view joinOrder = std::string_view(), bool requireNewTx = true,
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams =
std::nullopt);

template<typename T, typename... Args>
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
Expand All @@ -133,7 +143,7 @@ class ClientContext {
const std::unordered_map<std::string, std::unique_ptr<common::Value>>& inputParams);

std::unique_ptr<QueryResult> executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx = 0u);
PreparedStatement* preparedStatement, uint32_t planIdx = 0u, bool requiredNexTx = true);

void addScalarFunction(std::string name, function::function_set definitions);

Expand Down
7 changes: 3 additions & 4 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,9 @@ class Connection {

std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

std::vector<std::unique_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<PreparedStatement> prepareNoLock(parser::Statement* parsedStatement,
bool enumerateAllPlans = false, std::string_view joinOrder = std::string_view());
std::unique_ptr<PreparedStatement> prepareNoLock(
std::shared_ptr<parser::Statement> parsedStatement, bool enumerateAllPlans = false,
std::string_view joinOrder = std::string_view());

template<typename T, typename... Args>
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
Expand Down
2 changes: 2 additions & 0 deletions src/include/main/prepared_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "common/api.h"
#include "kuzu_fwd.h"
#include "parser/statement.h"
#include "query_summary.h"

namespace kuzu {
Expand Down Expand Up @@ -62,6 +63,7 @@ class PreparedStatement {
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameterMap;
std::unique_ptr<binder::BoundStatementResult> statementResult;
std::vector<std::unique_ptr<planner::LogicalPlan>> logicalPlans;
std::shared_ptr<parser::Statement> parsedStatement;
};

} // namespace main
Expand Down
2 changes: 1 addition & 1 deletion src/include/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace parser {
class Parser {

public:
static std::vector<std::unique_ptr<Statement>> parseQuery(std::string_view query);
static std::vector<std::shared_ptr<Statement>> parseQuery(std::string_view query);
};

} // namespace parser
Expand Down
9 changes: 9 additions & 0 deletions src/include/parser/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ class Statement {

inline common::StatementType getStatementType() const { return statementType; }

inline bool requireTx() {
switch (statementType) {
case common::StatementType::TRANSACTION:
return false;
default:
return true;
}
}

private:
common::StatementType statementType;
};
Expand Down
2 changes: 1 addition & 1 deletion src/include/parser/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Transformer {
public:
explicit Transformer(CypherParser::Ku_StatementsContext& root) : root{root} {}

std::vector<std::unique_ptr<Statement>> transform();
std::vector<std::shared_ptr<Statement>> transform();

private:
std::unique_ptr<Statement> transformStatement(CypherParser::OC_StatementContext& ctx);
Expand Down
97 changes: 54 additions & 43 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,28 @@ std::string ClientContext::getEnvVariable(const std::string& name) {
}

std::unique_ptr<PreparedStatement> ClientContext::prepare(std::string_view query) {
auto preparedStatement = std::unique_ptr<PreparedStatement>();
if (query.empty()) {
return preparedStatementWithError("Connection Exception: Query is empty.");
}
std::unique_lock<std::mutex> lck{mtx};
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { return preparedStatementWithError(exception.what()); }
if (parsedStatements.size() > 1) {
return preparedStatementWithError(
"Connection Exception: We do not support prepare multiple statements.");
}
return prepareNoLock(parsedStatements[0]);
}

std::unique_ptr<PreparedStatement> ClientContext::prepareTest(std::string_view query) {
auto preparedStatement = std::unique_ptr<PreparedStatement>();
std::unique_lock<std::mutex> lck{mtx};
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { return preparedStatementWithError(exception.what()); }
if (parsedStatements.size() > 1) {
return preparedStatementWithError(
Expand All @@ -203,7 +220,8 @@ std::unique_ptr<PreparedStatement> ClientContext::prepare(std::string_view query
if (parsedStatements.empty()) {
return preparedStatementWithError("Connection Exception: Query is empty.");
}
return prepareNoLock(parsedStatements[0].get());
return prepareNoLock(
parsedStatements[0], false /* enumerate all plans */, "", false /*requireNewTx*/);
}

std::unique_ptr<QueryResult> ClientContext::query(std::string_view queryStatement) {
Expand All @@ -213,20 +231,20 @@ std::unique_ptr<QueryResult> ClientContext::query(std::string_view queryStatemen
std::unique_ptr<QueryResult> ClientContext::query(
std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans) {
lock_t lck{mtx};
// parsing
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
} catch (std::exception& exception) { return queryResultWithError(exception.what()); }
if (parsedStatements.empty()) {
if (query.empty()) {
return queryResultWithError("Connection Exception: Query is empty.");
}
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { return queryResultWithError(exception.what()); }
std::unique_ptr<QueryResult> queryResult;
QueryResult* lastResult = nullptr;
for (auto& statement : parsedStatements) {
auto preparedStatement = prepareNoLock(
statement.get(), enumerateAllPlans /* enumerate all plans */, encodedJoin);
auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
auto preparedStatement = prepareNoLock(statement,
enumerateAllPlans /* enumerate all plans */, encodedJoin, false /*requireNewTx*/);
auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock(
preparedStatement.get(), 0u, false /*requiredNexTx*/);
if (!lastResult) {
// first result of the query
queryResult = std::move(currentQueryResult);
Expand Down Expand Up @@ -256,7 +274,9 @@ std::unique_ptr<PreparedStatement> ClientContext::preparedStatementWithError(
}

std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
Statement* parsedStatement, bool enumerateAllPlans, std::string_view encodedJoin) {
std::shared_ptr<Statement> parsedStatement, bool enumerateAllPlans,
std::string_view encodedJoin, bool requireNewTx,
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams) {
auto preparedStatement = std::make_unique<PreparedStatement>();
auto compilingTimer = TimeMetric(true /* enable */);
compilingTimer.start();
Expand All @@ -267,18 +287,8 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
if (database->systemConfig.readOnly && !preparedStatement->isReadOnly()) {
throw ConnectionException("Cannot execute write operations in a read-only database!");
}
} catch (std::exception& exception) {
preparedStatement->success = false;
preparedStatement->errMsg = exception.what();
compilingTimer.stop();
preparedStatement->preparedSummary.compilingTime = compilingTimer.getElapsedTimeMS();
return preparedStatement;
}
std::unique_ptr<ExecutionContext> executionContext;
std::unique_ptr<LogicalPlan> logicalPlan;
try {
// parsing
if (parsedStatement->getStatementType() != StatementType::TRANSACTION) {
preparedStatement->parsedStatement = parsedStatement;
if (parsedStatement->requireTx()) {
if (transactionContext->isAutoTransaction()) {
transactionContext->beginAutoTransaction(preparedStatement->readOnly);
} else {
Expand All @@ -292,6 +302,9 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
}
// binding
auto binder = Binder(this);
if (inputParams) {
binder.setInputParameters(*inputParams);
}
auto boundStatement = binder.bind(*parsedStatement);
preparedStatement->parameterMap = binder.getParameterMap();
preparedStatement->statementResult =
Expand Down Expand Up @@ -323,6 +336,9 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
} else {
preparedStatement->logicalPlans = std::move(plans);
}
if (transactionContext->isAutoTransaction() && requireNewTx) {
this->transactionContext->commit();
}
} catch (std::exception& exception) {
preparedStatement->success = false;
preparedStatement->errMsg = exception.what();
Expand All @@ -333,8 +349,8 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
return preparedStatement;
}

std::vector<std::unique_ptr<Statement>> ClientContext::parseQuery(std::string_view query) {
std::vector<std::unique_ptr<Statement>> statements;
std::vector<std::shared_ptr<Statement>> ClientContext::parseQuery(std::string_view query) {
std::vector<std::shared_ptr<Statement>> statements;
if (query.empty()) {
return statements;
}
Expand All @@ -356,7 +372,11 @@ std::unique_ptr<QueryResult> ClientContext::executeWithParams(PreparedStatement*
std::string errMsg = exception.what();
return queryResultWithError(errMsg);
}
return executeAndAutoCommitIfNecessaryNoLock(preparedStatement);
// rebind
KU_ASSERT(preparedStatement->parsedStatement != nullptr);
auto rebindPreparedStatement = prepareNoLock(
preparedStatement->parsedStatement, false, "", false, preparedStatement->parameterMap);
return executeAndAutoCommitIfNecessaryNoLock(rebindPreparedStatement.get(), 0u, false);
}

void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
Expand All @@ -367,14 +387,6 @@ void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
throw Exception("Parameter " + name + " not found.");
}
auto expectParam = parameterMap.at(name);
if (value->getDataType()->getLogicalTypeID() == LogicalTypeID::ANY) {
value->setDataType(*expectParam->getDataType());
}
if (*expectParam->getDataType() != *value->getDataType()) {
throw Exception("Parameter " + name + " has data type " +
value->getDataType()->toString() + " but expects " +
expectParam->getDataType()->toString() + ".");
}
// The much more natural `parameterMap.at(name) = std::move(v)` fails.
// The reason is that other parts of the code rely on the existing Value object to be
// modified in-place, not replaced in this map.
Expand All @@ -383,12 +395,11 @@ void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
}

std::unique_ptr<QueryResult> ClientContext::executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx) {
PreparedStatement* preparedStatement, uint32_t planIdx, bool requiredNexTx) {
if (!preparedStatement->isSuccess()) {
return queryResultWithError(preparedStatement->errMsg);
}
if (preparedStatement->preparedSummary.statementType != common::StatementType::TRANSACTION &&
this->getTx() == nullptr) {
if (preparedStatement->parsedStatement->requireTx() && requiredNexTx && getTx() == nullptr) {
this->transactionContext->beginAutoTransaction(preparedStatement->isReadOnly());
if (!preparedStatement->readOnly) {
database->catalog->initCatalogContentForWriteTrxIfNecessary();
Expand Down Expand Up @@ -470,18 +481,18 @@ void ClientContext::runQuery(std::string query) {
if (transactionContext->hasActiveTransaction()) {
transactionContext->commit();
}
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { throw ConnectionException(exception.what()); }
if (parsedStatements.empty()) {
throw ConnectionException("Connection Exception: Query is empty.");
}
try {
for (auto& statement : parsedStatements) {
auto preparedStatement = prepareNoLock(statement.get());
auto preparedStatement = prepareNoLock(statement, false, "", false);
auto currentQueryResult =
executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get(), 0u, false);
if (!currentQueryResult->isSuccess()) {
throw ConnectionException(currentQueryResult->errMsg);
}
Expand Down
7 changes: 2 additions & 5 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,11 @@ std::unique_ptr<PreparedStatement> Connection::preparedStatementWithError(std::s
}

std::unique_ptr<PreparedStatement> Connection::prepareNoLock(
Statement* parsedStatement, bool enumerateAllPlans, std::string_view encodedJoin) {
std::shared_ptr<Statement> parsedStatement, bool enumerateAllPlans,
std::string_view encodedJoin) {
return clientContext->prepareNoLock(parsedStatement, enumerateAllPlans, encodedJoin);
}

std::vector<std::unique_ptr<Statement>> Connection::parseQuery(std::string_view query) {
return clientContext->parseQuery(query);
}

void Connection::interrupt() {
clientContext->interrupt();
}
Expand Down
2 changes: 1 addition & 1 deletion src/parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using namespace antlr4;
namespace kuzu {
namespace parser {

std::vector<std::unique_ptr<Statement>> Parser::parseQuery(std::string_view query) {
std::vector<std::shared_ptr<Statement>> Parser::parseQuery(std::string_view query) {
auto inputStream = ANTLRInputStream(query);
auto parserErrorListener = ParserErrorListener();

Expand Down
4 changes: 2 additions & 2 deletions src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ using namespace kuzu::common;
namespace kuzu {
namespace parser {

std::vector<std::unique_ptr<Statement>> Transformer::transform() {
std::vector<std::unique_ptr<Statement>> statements;
std::vector<std::shared_ptr<Statement>> Transformer::transform() {
std::vector<std::shared_ptr<Statement>> statements;
for (auto& oc_Statement : root.oC_Cypher()) {
auto statement = transformStatement(*oc_Statement->oC_Statement());
if (oc_Statement->oC_AnyCypherOption()) {
Expand Down
4 changes: 2 additions & 2 deletions test/copy/e2e_copy_transaction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {

void copyNodeCSVCommitAndRecoveryTest(TransactionTestType transactionTestType) {
conn->query(createPersonTableCMD);
auto preparedStatement = conn->prepare(copyPersonTableCMD);
auto preparedStatement = conn->getClientContext()->prepareTest(copyPersonTableCMD);
if (!preparedStatement->success) {
ASSERT_TRUE(false) << preparedStatement->errMsg;
}
Expand Down Expand Up @@ -130,7 +130,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
conn->query(createPersonTableCMD);
conn->query(copyPersonTableCMD);
conn->query(createKnowsTableCMD);
auto preparedStatement = conn->prepare(copyKnowsTableCMD);
auto preparedStatement = conn->getClientContext()->prepareTest(copyKnowsTableCMD);
auto mapper = PlanMapper(conn->getClientContext());
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
Expand Down
2 changes: 1 addition & 1 deletion test/ddl/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class TinySnbDDLTest : public DBTest {
}

void executeQueryWithoutCommit(std::string query) {
auto preparedStatement = conn->prepare(query);
auto preparedStatement = conn->getClientContext()->prepareTest(query);
auto mapper = PlanMapper(conn->getClientContext());
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
Expand Down
Loading

0 comments on commit 015bf23

Please sign in to comment.