Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare Statement Improvement #3140

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ class Binder {

std::unique_ptr<BoundStatement> bind(const parser::Statement& statement);

void setInputParameters(
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameters) {
expressionBinder.parameterMap = parameters;
}

inline std::unordered_map<std::string, std::shared_ptr<common::Value>> getParameterMap() {
return expressionBinder.parameterMap;
}
Expand Down
4 changes: 2 additions & 2 deletions src/include/binder/expression/parameter_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class ParameterExpression : public Expression {
public:
explicit ParameterExpression(
const std::string& parameterName, std::shared_ptr<common::Value> value)
: Expression{common::ExpressionType::PARAMETER,
common::LogicalType(common::LogicalTypeID::ANY), createUniqueName(parameterName)},
: Expression{common::ExpressionType::PARAMETER, common::LogicalType(*value->getDataType()),
createUniqueName(parameterName)},
parameterName(parameterName), value{std::move(value)} {}

void cast(const common::LogicalType& type) override;
Expand Down
20 changes: 15 additions & 5 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ class ClientContext {
std::unique_ptr<QueryResult> query(std::string_view queryStatement);
void runQuery(std::string query);

// TODO(Jiamin): should remove after supporting ddl in manual tx
std::unique_ptr<PreparedStatement> prepareTest(std::string_view query);
// only use for test framework
std::vector<std::shared_ptr<parser::Statement>> parseQuery(std::string_view query);

private:
std::unique_ptr<QueryResult> query(
std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans = true);
Expand All @@ -114,10 +119,15 @@ class ClientContext {

std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

std::vector<std::unique_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<PreparedStatement> prepareNoLock(parser::Statement* parsedStatement,
bool enumerateAllPlans = false, std::string_view joinOrder = std::string_view());
// when we do prepare, we will start a transaction for the query
// when we execute after prepare in a same context, we set requireNewTx to false and will not
// commit the transaction in prepare when we only prepare a query statement, we set requireNewTx
// to true and will commit the transaction in prepare
std::unique_ptr<PreparedStatement> prepareNoLock(
std::shared_ptr<parser::Statement> parsedStatement, bool enumerateAllPlans = false,
std::string_view joinOrder = std::string_view(), bool requireNewTx = true,
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams =
std::nullopt);

template<typename T, typename... Args>
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
Expand All @@ -133,7 +143,7 @@ class ClientContext {
const std::unordered_map<std::string, std::unique_ptr<common::Value>>& inputParams);

std::unique_ptr<QueryResult> executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx = 0u);
PreparedStatement* preparedStatement, uint32_t planIdx = 0u, bool requiredNexTx = true);

void addScalarFunction(std::string name, function::function_set definitions);

Expand Down
7 changes: 3 additions & 4 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,9 @@ class Connection {

std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

std::vector<std::unique_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<PreparedStatement> prepareNoLock(parser::Statement* parsedStatement,
bool enumerateAllPlans = false, std::string_view joinOrder = std::string_view());
std::unique_ptr<PreparedStatement> prepareNoLock(
std::shared_ptr<parser::Statement> parsedStatement, bool enumerateAllPlans = false,
std::string_view joinOrder = std::string_view());

template<typename T, typename... Args>
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
Expand Down
2 changes: 2 additions & 0 deletions src/include/main/prepared_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "common/api.h"
#include "kuzu_fwd.h"
#include "parser/statement.h"
#include "query_summary.h"

namespace kuzu {
Expand Down Expand Up @@ -62,6 +63,7 @@ class PreparedStatement {
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameterMap;
std::unique_ptr<binder::BoundStatementResult> statementResult;
std::vector<std::unique_ptr<planner::LogicalPlan>> logicalPlans;
std::shared_ptr<parser::Statement> parsedStatement;
};

} // namespace main
Expand Down
2 changes: 1 addition & 1 deletion src/include/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace parser {
class Parser {

public:
static std::vector<std::unique_ptr<Statement>> parseQuery(std::string_view query);
static std::vector<std::shared_ptr<Statement>> parseQuery(std::string_view query);
};

} // namespace parser
Expand Down
9 changes: 9 additions & 0 deletions src/include/parser/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ class Statement {

inline common::StatementType getStatementType() const { return statementType; }

inline bool requireTx() {
switch (statementType) {
case common::StatementType::TRANSACTION:
return false;
default:
return true;
}
}

private:
common::StatementType statementType;
};
Expand Down
2 changes: 1 addition & 1 deletion src/include/parser/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Transformer {
public:
explicit Transformer(CypherParser::Ku_StatementsContext& root) : root{root} {}

std::vector<std::unique_ptr<Statement>> transform();
std::vector<std::shared_ptr<Statement>> transform();

private:
std::unique_ptr<Statement> transformStatement(CypherParser::OC_StatementContext& ctx);
Expand Down
97 changes: 54 additions & 43 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,28 @@ std::string ClientContext::getEnvVariable(const std::string& name) {
}

std::unique_ptr<PreparedStatement> ClientContext::prepare(std::string_view query) {
auto preparedStatement = std::unique_ptr<PreparedStatement>();
if (query.empty()) {
return preparedStatementWithError("Connection Exception: Query is empty.");
}
std::unique_lock<std::mutex> lck{mtx};
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { return preparedStatementWithError(exception.what()); }
if (parsedStatements.size() > 1) {
return preparedStatementWithError(
"Connection Exception: We do not support prepare multiple statements.");
}
return prepareNoLock(parsedStatements[0]);
}

std::unique_ptr<PreparedStatement> ClientContext::prepareTest(std::string_view query) {
auto preparedStatement = std::unique_ptr<PreparedStatement>();
std::unique_lock<std::mutex> lck{mtx};
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { return preparedStatementWithError(exception.what()); }
if (parsedStatements.size() > 1) {
return preparedStatementWithError(
Expand All @@ -203,7 +220,8 @@ std::unique_ptr<PreparedStatement> ClientContext::prepare(std::string_view query
if (parsedStatements.empty()) {
return preparedStatementWithError("Connection Exception: Query is empty.");
}
return prepareNoLock(parsedStatements[0].get());
return prepareNoLock(
parsedStatements[0], false /* enumerate all plans */, "", false /*requireNewTx*/);
}

std::unique_ptr<QueryResult> ClientContext::query(std::string_view queryStatement) {
Expand All @@ -213,20 +231,20 @@ std::unique_ptr<QueryResult> ClientContext::query(std::string_view queryStatemen
std::unique_ptr<QueryResult> ClientContext::query(
std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans) {
lock_t lck{mtx};
// parsing
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
} catch (std::exception& exception) { return queryResultWithError(exception.what()); }
if (parsedStatements.empty()) {
if (query.empty()) {
return queryResultWithError("Connection Exception: Query is empty.");
}
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { return queryResultWithError(exception.what()); }
std::unique_ptr<QueryResult> queryResult;
QueryResult* lastResult = nullptr;
for (auto& statement : parsedStatements) {
auto preparedStatement = prepareNoLock(
statement.get(), enumerateAllPlans /* enumerate all plans */, encodedJoin);
auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
auto preparedStatement = prepareNoLock(statement,
enumerateAllPlans /* enumerate all plans */, encodedJoin, false /*requireNewTx*/);
auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock(
preparedStatement.get(), 0u, false /*requiredNexTx*/);
if (!lastResult) {
// first result of the query
queryResult = std::move(currentQueryResult);
Expand Down Expand Up @@ -256,7 +274,9 @@ std::unique_ptr<PreparedStatement> ClientContext::preparedStatementWithError(
}

std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
Statement* parsedStatement, bool enumerateAllPlans, std::string_view encodedJoin) {
std::shared_ptr<Statement> parsedStatement, bool enumerateAllPlans,
std::string_view encodedJoin, bool requireNewTx,
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams) {
auto preparedStatement = std::make_unique<PreparedStatement>();
auto compilingTimer = TimeMetric(true /* enable */);
compilingTimer.start();
Expand All @@ -267,18 +287,8 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
if (database->systemConfig.readOnly && !preparedStatement->isReadOnly()) {
throw ConnectionException("Cannot execute write operations in a read-only database!");
}
} catch (std::exception& exception) {
preparedStatement->success = false;
preparedStatement->errMsg = exception.what();
compilingTimer.stop();
preparedStatement->preparedSummary.compilingTime = compilingTimer.getElapsedTimeMS();
return preparedStatement;
}
std::unique_ptr<ExecutionContext> executionContext;
std::unique_ptr<LogicalPlan> logicalPlan;
try {
// parsing
if (parsedStatement->getStatementType() != StatementType::TRANSACTION) {
preparedStatement->parsedStatement = parsedStatement;
if (parsedStatement->requireTx()) {
if (transactionContext->isAutoTransaction()) {
transactionContext->beginAutoTransaction(preparedStatement->readOnly);
} else {
Expand All @@ -292,6 +302,9 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
}
// binding
auto binder = Binder(this);
if (inputParams) {
binder.setInputParameters(*inputParams);
}
auto boundStatement = binder.bind(*parsedStatement);
preparedStatement->parameterMap = binder.getParameterMap();
preparedStatement->statementResult =
Expand Down Expand Up @@ -323,6 +336,9 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
} else {
preparedStatement->logicalPlans = std::move(plans);
}
if (transactionContext->isAutoTransaction() && requireNewTx) {
this->transactionContext->commit();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we have a similar check when rollback?

}
} catch (std::exception& exception) {
preparedStatement->success = false;
preparedStatement->errMsg = exception.what();
Expand All @@ -333,8 +349,8 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
return preparedStatement;
}

std::vector<std::unique_ptr<Statement>> ClientContext::parseQuery(std::string_view query) {
std::vector<std::unique_ptr<Statement>> statements;
std::vector<std::shared_ptr<Statement>> ClientContext::parseQuery(std::string_view query) {
std::vector<std::shared_ptr<Statement>> statements;
if (query.empty()) {
return statements;
}
Expand All @@ -356,7 +372,11 @@ std::unique_ptr<QueryResult> ClientContext::executeWithParams(PreparedStatement*
std::string errMsg = exception.what();
return queryResultWithError(errMsg);
}
return executeAndAutoCommitIfNecessaryNoLock(preparedStatement);
// rebind
KU_ASSERT(preparedStatement->parsedStatement != nullptr);
auto rebindPreparedStatement = prepareNoLock(
preparedStatement->parsedStatement, false, "", false, preparedStatement->parameterMap);
return executeAndAutoCommitIfNecessaryNoLock(rebindPreparedStatement.get(), 0u, false);
}

void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
Expand All @@ -367,14 +387,6 @@ void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
throw Exception("Parameter " + name + " not found.");
}
auto expectParam = parameterMap.at(name);
if (value->getDataType()->getLogicalTypeID() == LogicalTypeID::ANY) {
value->setDataType(*expectParam->getDataType());
}
if (*expectParam->getDataType() != *value->getDataType()) {
throw Exception("Parameter " + name + " has data type " +
value->getDataType()->toString() + " but expects " +
expectParam->getDataType()->toString() + ".");
}
// The much more natural `parameterMap.at(name) = std::move(v)` fails.
// The reason is that other parts of the code rely on the existing Value object to be
// modified in-place, not replaced in this map.
Expand All @@ -383,12 +395,11 @@ void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
}

std::unique_ptr<QueryResult> ClientContext::executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx) {
PreparedStatement* preparedStatement, uint32_t planIdx, bool requiredNexTx) {
if (!preparedStatement->isSuccess()) {
return queryResultWithError(preparedStatement->errMsg);
}
if (preparedStatement->preparedSummary.statementType != common::StatementType::TRANSACTION &&
this->getTx() == nullptr) {
if (preparedStatement->parsedStatement->requireTx() && requiredNexTx && getTx() == nullptr) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can u branch out from this branch and try commit after prepare and we open a new transaction during execute. In this design, we should no longer need to pass requiredNexTx

this->transactionContext->beginAutoTransaction(preparedStatement->isReadOnly());
if (!preparedStatement->readOnly) {
database->catalog->initCatalogContentForWriteTrxIfNecessary();
Expand Down Expand Up @@ -470,18 +481,18 @@ void ClientContext::runQuery(std::string query) {
if (transactionContext->hasActiveTransaction()) {
transactionContext->commit();
}
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { throw ConnectionException(exception.what()); }
if (parsedStatements.empty()) {
throw ConnectionException("Connection Exception: Query is empty.");
}
try {
for (auto& statement : parsedStatements) {
auto preparedStatement = prepareNoLock(statement.get());
auto preparedStatement = prepareNoLock(statement, false, "", false);
auto currentQueryResult =
executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get(), 0u, false);
if (!currentQueryResult->isSuccess()) {
throw ConnectionException(currentQueryResult->errMsg);
}
Expand Down
7 changes: 2 additions & 5 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,11 @@ std::unique_ptr<PreparedStatement> Connection::preparedStatementWithError(std::s
}

std::unique_ptr<PreparedStatement> Connection::prepareNoLock(
Statement* parsedStatement, bool enumerateAllPlans, std::string_view encodedJoin) {
std::shared_ptr<Statement> parsedStatement, bool enumerateAllPlans,
hououou marked this conversation as resolved.
Show resolved Hide resolved
std::string_view encodedJoin) {
return clientContext->prepareNoLock(parsedStatement, enumerateAllPlans, encodedJoin);
}

std::vector<std::unique_ptr<Statement>> Connection::parseQuery(std::string_view query) {
return clientContext->parseQuery(query);
}

void Connection::interrupt() {
clientContext->interrupt();
}
Expand Down
2 changes: 1 addition & 1 deletion src/parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using namespace antlr4;
namespace kuzu {
namespace parser {

std::vector<std::unique_ptr<Statement>> Parser::parseQuery(std::string_view query) {
std::vector<std::shared_ptr<Statement>> Parser::parseQuery(std::string_view query) {
auto inputStream = ANTLRInputStream(query);
auto parserErrorListener = ParserErrorListener();

Expand Down
4 changes: 2 additions & 2 deletions src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ using namespace kuzu::common;
namespace kuzu {
namespace parser {

std::vector<std::unique_ptr<Statement>> Transformer::transform() {
std::vector<std::unique_ptr<Statement>> statements;
std::vector<std::shared_ptr<Statement>> Transformer::transform() {
std::vector<std::shared_ptr<Statement>> statements;
for (auto& oc_Statement : root.oC_Cypher()) {
auto statement = transformStatement(*oc_Statement->oC_Statement());
if (oc_Statement->oC_AnyCypherOption()) {
Expand Down
4 changes: 2 additions & 2 deletions test/copy/e2e_copy_transaction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {

void copyNodeCSVCommitAndRecoveryTest(TransactionTestType transactionTestType) {
conn->query(createPersonTableCMD);
auto preparedStatement = conn->prepare(copyPersonTableCMD);
auto preparedStatement = conn->getClientContext()->prepareTest(copyPersonTableCMD);
if (!preparedStatement->success) {
ASSERT_TRUE(false) << preparedStatement->errMsg;
}
Expand Down Expand Up @@ -130,7 +130,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
conn->query(createPersonTableCMD);
conn->query(copyPersonTableCMD);
conn->query(createKnowsTableCMD);
auto preparedStatement = conn->prepare(copyKnowsTableCMD);
auto preparedStatement = conn->getClientContext()->prepareTest(copyKnowsTableCMD);
auto mapper = PlanMapper(conn->getClientContext());
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
Expand Down
2 changes: 1 addition & 1 deletion test/ddl/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class TinySnbDDLTest : public DBTest {
}

void executeQueryWithoutCommit(std::string query) {
auto preparedStatement = conn->prepare(query);
auto preparedStatement = conn->getClientContext()->prepareTest(query);
auto mapper = PlanMapper(conn->getClientContext());
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
Expand Down
Loading
Loading