Skip to content

Commit

Permalink
solve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hououou committed Mar 26, 2024
1 parent c08c334 commit b3407f2
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 108 deletions.
1 change: 1 addition & 0 deletions src/common/types/value/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace kuzu {
namespace common {

void Value::setDataType(const LogicalType& dataType_) {
KU_ASSERT(dataType->getLogicalTypeID() == LogicalTypeID::ANY);
dataType = dataType_.copy();
}

Expand Down
3 changes: 1 addition & 2 deletions src/include/binder/expression/parameter_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ class ParameterExpression : public Expression {
public:
explicit ParameterExpression(
const std::string& parameterName, std::shared_ptr<common::Value> value)
: Expression{common::ExpressionType::PARAMETER,
common::LogicalType(value->getDataType()->getLogicalTypeID()),
: Expression{common::ExpressionType::PARAMETER, common::LogicalType(*value->getDataType()),
createUniqueName(parameterName)},
parameterName(parameterName), value{std::move(value)} {}

Expand Down
16 changes: 8 additions & 8 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ class ClientContext {
std::unique_ptr<QueryResult> query(std::string_view queryStatement);
void runQuery(std::string query);

void prepareInternal(PreparedStatement& preparedStatement, bool enumerateAllPlans = false,
std::string_view encodedJoin = "",
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams =
std::nullopt);
// TODO(Jiamin): should remove after supporting ddl in manual tx
std::unique_ptr<PreparedStatement> prepareTest(std::string_view query);
// only use for test framework
std::vector<std::shared_ptr<parser::Statement>> parseQuery(std::string_view query);

private:
std::unique_ptr<QueryResult> query(
Expand All @@ -119,11 +119,11 @@ class ClientContext {

std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

std::vector<std::shared_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<PreparedStatement> prepareNoLock(
std::shared_ptr<parser::Statement> parsedStatement, bool enumerateAllPlans = false,
std::string_view joinOrder = std::string_view());
std::string_view joinOrder = std::string_view(), bool requireNewTx = true,
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams =
std::nullopt);

template<typename T, typename... Args>
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
Expand All @@ -139,7 +139,7 @@ class ClientContext {
const std::unordered_map<std::string, std::unique_ptr<common::Value>>& inputParams);

std::unique_ptr<QueryResult> executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx = 0u);
PreparedStatement* preparedStatement, uint32_t planIdx = 0u, bool requiredNexTx = true);

void addScalarFunction(std::string name, function::function_set definitions);

Expand Down
2 changes: 0 additions & 2 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,6 @@ class Connection {

std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

std::vector<std::shared_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<PreparedStatement> prepareNoLock(
std::shared_ptr<parser::Statement> parsedStatement, bool enumerateAllPlans = false,
std::string_view joinOrder = std::string_view());
Expand Down
1 change: 0 additions & 1 deletion src/include/main/prepared_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class PreparedStatement {
private:
bool success = true;
bool readOnly = false;
bool requiredNewTx = true;
std::string errMsg;
PreparedSummary preparedSummary;
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameterMap;
Expand Down
9 changes: 9 additions & 0 deletions src/include/parser/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ class Statement {

inline common::StatementType getStatementType() const { return statementType; }

inline bool requireTx() {
switch (statementType) {
case common::StatementType::TRANSACTION:
return false;
default:
return true;
}
}

private:
common::StatementType statementType;
};
Expand Down
171 changes: 86 additions & 85 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ std::unique_ptr<PreparedStatement> ClientContext::prepare(std::string_view query
std::unique_lock<std::mutex> lck{mtx};
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { return preparedStatementWithError(exception.what()); }
if (parsedStatements.size() > 1) {
return preparedStatementWithError(
Expand All @@ -205,6 +205,24 @@ std::unique_ptr<PreparedStatement> ClientContext::prepare(std::string_view query
return prepareNoLock(parsedStatements[0]);
}

std::unique_ptr<PreparedStatement> ClientContext::prepareTest(std::string_view query) {
auto preparedStatement = std::unique_ptr<PreparedStatement>();
std::unique_lock<std::mutex> lck{mtx};
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { return preparedStatementWithError(exception.what()); }
if (parsedStatements.size() > 1) {
return preparedStatementWithError(
"Connection Exception: We do not support prepare multiple statements.");
}
if (parsedStatements.empty()) {
return preparedStatementWithError("Connection Exception: Query is empty.");
}
return prepareNoLock(
parsedStatements[0], false /* enumerate all plans */, "", false /*requireNewTx*/);
}

std::unique_ptr<QueryResult> ClientContext::query(std::string_view queryStatement) {
return query(queryStatement, std::string_view() /*encodedJoin*/, false /*enumerateAllPlans */);
}
Expand All @@ -215,17 +233,18 @@ std::unique_ptr<QueryResult> ClientContext::query(
// parsing
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { return queryResultWithError(exception.what()); }
if (parsedStatements.empty()) {
return queryResultWithError("Connection Exception: Query is empty.");
}
std::unique_ptr<QueryResult> queryResult;
QueryResult* lastResult = nullptr;
for (auto& statement : parsedStatements) {
auto preparedStatement =
prepareNoLock(statement, enumerateAllPlans /* enumerate all plans */, encodedJoin);
auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
auto preparedStatement = prepareNoLock(statement,
enumerateAllPlans /* enumerate all plans */, encodedJoin, false /*requireNewTx*/);
auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock(
preparedStatement.get(), 0u, false /*requiredNexTx*/);
if (!lastResult) {
// first result of the query
queryResult = std::move(currentQueryResult);
Expand Down Expand Up @@ -254,65 +273,10 @@ std::unique_ptr<PreparedStatement> ClientContext::preparedStatementWithError(
return preparedStatement;
}

void ClientContext::prepareInternal(PreparedStatement& preparedStatement, bool enumerateAllPlans,
std::string_view encodedJoin,
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams) {
auto parsedStatement = preparedStatement.parsedStatement;
// parsing
if (preparedStatement.requiredNewTx) {
if (parsedStatement->getStatementType() != StatementType::TRANSACTION) {
if (transactionContext->isAutoTransaction()) {
transactionContext->beginAutoTransaction(preparedStatement.readOnly);
} else {
transactionContext->validateManualTransaction(
preparedStatement.allowActiveTransaction(), preparedStatement.readOnly);
}
if (!this->getTx()->isReadOnly()) {
database->catalog->initCatalogContentForWriteTrxIfNecessary();
database->storageManager->initStatistics();
}
}
}
// binding
auto binder = Binder(this);
if (inputParams) {
binder.setInputParameters(*inputParams);
}
auto boundStatement = binder.bind(*parsedStatement);
preparedStatement.parameterMap = binder.getParameterMap();
preparedStatement.statementResult =
std::make_unique<BoundStatementResult>(boundStatement->getStatementResult()->copy());
// planning
auto planner = Planner(this);
std::vector<std::unique_ptr<LogicalPlan>> plans;
if (enumerateAllPlans) {
plans = planner.getAllPlans(*boundStatement);
} else {
plans.push_back(planner.getBestPlan(*boundStatement));
}
// optimizing
for (auto& plan : plans) {
optimizer::Optimizer::optimize(plan.get(), this);
}
if (!encodedJoin.empty()) {
std::unique_ptr<LogicalPlan> match;
for (auto& plan : plans) {
if (LogicalPlanUtil::encodeJoin(*plan) == encodedJoin) {
match = std::move(plan);
}
}
if (match == nullptr) {
throw ConnectionException(stringFormat("Cannot find a plan matching {}", encodedJoin));
}
preparedStatement.logicalPlans.push_back(std::move(match));
} else {
preparedStatement.logicalPlans = std::move(plans);
}
}

std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
std::shared_ptr<Statement> parsedStatement, bool enumerateAllPlans,
std::string_view encodedJoin) {
std::string_view encodedJoin, bool requireNewTx,
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams) {
auto preparedStatement = std::make_unique<PreparedStatement>();
auto compilingTimer = TimeMetric(true /* enable */);
compilingTimer.start();
Expand All @@ -330,11 +294,60 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
preparedStatement->preparedSummary.compilingTime = compilingTimer.getElapsedTimeMS();
return preparedStatement;
}
std::unique_ptr<ExecutionContext> executionContext;
std::unique_ptr<LogicalPlan> logicalPlan;
try {
// parsing
preparedStatement->parsedStatement = parsedStatement;
prepareInternal(*preparedStatement, enumerateAllPlans, encodedJoin);
if (parsedStatement->requireTx()) {
if (transactionContext->isAutoTransaction()) {
transactionContext->beginAutoTransaction(preparedStatement->readOnly);
} else {
transactionContext->validateManualTransaction(
preparedStatement->allowActiveTransaction(), preparedStatement->readOnly);
}
if (!this->getTx()->isReadOnly()) {
database->catalog->initCatalogContentForWriteTrxIfNecessary();
database->storageManager->initStatistics();
}
}
// binding
auto binder = Binder(this);
if (inputParams) {
binder.setInputParameters(*inputParams);
}
auto boundStatement = binder.bind(*parsedStatement);
preparedStatement->parameterMap = binder.getParameterMap();
preparedStatement->statementResult =
std::make_unique<BoundStatementResult>(boundStatement->getStatementResult()->copy());
// planning
auto planner = Planner(this);
std::vector<std::unique_ptr<LogicalPlan>> plans;
if (enumerateAllPlans) {
plans = planner.getAllPlans(*boundStatement);
} else {
plans.push_back(planner.getBestPlan(*boundStatement));
}
// optimizing
for (auto& plan : plans) {
optimizer::Optimizer::optimize(plan.get(), this);
}
if (!encodedJoin.empty()) {
std::unique_ptr<LogicalPlan> match;
for (auto& plan : plans) {
if (LogicalPlanUtil::encodeJoin(*plan) == encodedJoin) {
match = std::move(plan);
}
}
if (match == nullptr) {
throw ConnectionException(
stringFormat("Cannot find a plan matching {}", encodedJoin));
}
preparedStatement->logicalPlans.push_back(std::move(match));
} else {
preparedStatement->logicalPlans = std::move(plans);
}
if (transactionContext->isAutoTransaction() && requireNewTx) {
this->transactionContext->commit();
}
} catch (std::exception& exception) {
preparedStatement->success = false;
preparedStatement->errMsg = exception.what();
Expand Down Expand Up @@ -370,17 +383,9 @@ std::unique_ptr<QueryResult> ClientContext::executeWithParams(PreparedStatement*
}
// rebind
KU_ASSERT(preparedStatement->parsedStatement != nullptr);
auto& parameterMap = preparedStatement->parameterMap;
preparedStatement->requiredNewTx = transactionContext->hasActiveTransaction() ? false : true;
try {
prepareInternal(*preparedStatement, false, "", parameterMap);
} catch (std::exception& exception) {
preparedStatement->success = false;
preparedStatement->errMsg = exception.what();
this->transactionContext->rollback();
return queryResultWithError(exception.what());
}
return executeAndAutoCommitIfNecessaryNoLock(preparedStatement);
auto rebindPreparedStatement = prepareNoLock(
preparedStatement->parsedStatement, false, "", false, preparedStatement->parameterMap);
return executeAndAutoCommitIfNecessaryNoLock(rebindPreparedStatement.get(), 0u, false);
}

void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
Expand All @@ -391,9 +396,6 @@ void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
throw Exception("Parameter " + name + " not found.");
}
auto expectParam = parameterMap.at(name);
if (value->getDataType()->getLogicalTypeID() == LogicalTypeID::ANY) {
value->setDataType(*expectParam->getDataType());
}
// The much more natural `parameterMap.at(name) = std::move(v)` fails.
// The reason is that other parts of the code rely on the existing Value object to be
// modified in-place, not replaced in this map.
Expand All @@ -402,12 +404,11 @@ void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
}

std::unique_ptr<QueryResult> ClientContext::executeAndAutoCommitIfNecessaryNoLock(
PreparedStatement* preparedStatement, uint32_t planIdx) {
PreparedStatement* preparedStatement, uint32_t planIdx, bool requiredNexTx) {
if (!preparedStatement->isSuccess()) {
return queryResultWithError(preparedStatement->errMsg);
}
if (preparedStatement->preparedSummary.statementType != common::StatementType::TRANSACTION &&
this->getTx() == nullptr) {
if (preparedStatement->parsedStatement->requireTx() && requiredNexTx && getTx() == nullptr) {
this->transactionContext->beginAutoTransaction(preparedStatement->isReadOnly());
if (!preparedStatement->readOnly) {
database->catalog->initCatalogContentForWriteTrxIfNecessary();
Expand Down Expand Up @@ -491,16 +492,16 @@ void ClientContext::runQuery(std::string query) {
}
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
parsedStatements = Parser::parseQuery(query);
} catch (std::exception& exception) { throw ConnectionException(exception.what()); }
if (parsedStatements.empty()) {
throw ConnectionException("Connection Exception: Query is empty.");
}
try {
for (auto& statement : parsedStatements) {
auto preparedStatement = prepareNoLock(statement);
auto preparedStatement = prepareNoLock(statement, false, "", false);
auto currentQueryResult =
executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get(), 0u, false);
if (!currentQueryResult->isSuccess()) {
throw ConnectionException(currentQueryResult->errMsg);
}
Expand Down
4 changes: 0 additions & 4 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ std::unique_ptr<PreparedStatement> Connection::prepareNoLock(
return clientContext->prepareNoLock(parsedStatement, enumerateAllPlans, encodedJoin);
}

std::vector<std::shared_ptr<Statement>> Connection::parseQuery(std::string_view query) {
return clientContext->parseQuery(query);
}

void Connection::interrupt() {
clientContext->interrupt();
}
Expand Down
4 changes: 2 additions & 2 deletions test/copy/e2e_copy_transaction_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {

void copyNodeCSVCommitAndRecoveryTest(TransactionTestType transactionTestType) {
conn->query(createPersonTableCMD);
auto preparedStatement = conn->prepare(copyPersonTableCMD);
auto preparedStatement = conn->getClientContext()->prepareTest(copyPersonTableCMD);
if (!preparedStatement->success) {
ASSERT_TRUE(false) << preparedStatement->errMsg;
}
Expand Down Expand Up @@ -138,7 +138,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest {
conn->query(createPersonTableCMD);
conn->query(copyPersonTableCMD);
conn->query(createKnowsTableCMD);
auto preparedStatement = conn->prepare(copyKnowsTableCMD);
auto preparedStatement = conn->getClientContext()->prepareTest(copyKnowsTableCMD);
auto mapper = PlanMapper(conn->getClientContext());
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
Expand Down
2 changes: 1 addition & 1 deletion test/ddl/e2e_ddl_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class TinySnbDDLTest : public DBTest {
}

void executeQueryWithoutCommit(std::string query) {
auto preparedStatement = conn->prepare(query);
auto preparedStatement = conn->getClientContext()->prepareTest(query);
auto mapper = PlanMapper(conn->getClientContext());
auto physicalPlan =
mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(),
Expand Down
4 changes: 2 additions & 2 deletions test/main/prepare_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ TEST_F(ApiTest, issueTest2) {
" return r.name;");
auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("a_id"), 1),
std::make_pair(std::string("b_id"), 3), std::make_pair(std::string("c_id"), 2),
std::make_pair(std::string("my_param"), "null"));
std::make_pair(std::string("my_param"), "friend"));
ASSERT_TRUE(result->hasNext());
checkTuple(result->getNext().get(), "null\n");
checkTuple(result->getNext().get(), "friend\n");
ASSERT_FALSE(result->hasNext());
}

Expand Down
Loading

0 comments on commit b3407f2

Please sign in to comment.