Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
hououou committed Mar 26, 2024
1 parent 3a6bd7e commit c08c334
Show file tree
Hide file tree
Showing 15 changed files with 188 additions and 90 deletions.
1 change: 0 additions & 1 deletion src/common/types/value/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ namespace kuzu {
namespace common {

void Value::setDataType(const LogicalType& dataType_) {
KU_ASSERT(dataType->getLogicalTypeID() == LogicalTypeID::ANY);
dataType = dataType_.copy();
}

Expand Down
5 changes: 5 additions & 0 deletions src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ class Binder {

std::unique_ptr<BoundStatement> bind(const parser::Statement& statement);

void setInputParameters(
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameters) {
expressionBinder.parameterMap = parameters;
}

inline std::unordered_map<std::string, std::shared_ptr<common::Value>> getParameterMap() {
return expressionBinder.parameterMap;
}
Expand Down
3 changes: 2 additions & 1 deletion src/include/binder/expression/parameter_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class ParameterExpression : public Expression {
explicit ParameterExpression(
const std::string& parameterName, std::shared_ptr<common::Value> value)
: Expression{common::ExpressionType::PARAMETER,
common::LogicalType(common::LogicalTypeID::ANY), createUniqueName(parameterName)},
common::LogicalType(value->getDataType()->getLogicalTypeID()),
createUniqueName(parameterName)},
parameterName(parameterName), value{std::move(value)} {}

void cast(const common::LogicalType& type) override;
Expand Down
12 changes: 9 additions & 3 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ class ClientContext {
std::unique_ptr<QueryResult> query(std::string_view queryStatement);
void runQuery(std::string query);

void prepareInternal(PreparedStatement& preparedStatement, bool enumerateAllPlans = false,
std::string_view encodedJoin = "",
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams =
std::nullopt);

private:
std::unique_ptr<QueryResult> query(
std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans = true);
Expand All @@ -114,10 +119,11 @@ class ClientContext {

std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

std::vector<std::unique_ptr<parser::Statement>> parseQuery(std::string_view query);
std::vector<std::shared_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<PreparedStatement> prepareNoLock(parser::Statement* parsedStatement,
bool enumerateAllPlans = false, std::string_view joinOrder = std::string_view());
std::unique_ptr<PreparedStatement> prepareNoLock(
std::shared_ptr<parser::Statement> parsedStatement, bool enumerateAllPlans = false,
std::string_view joinOrder = std::string_view());

template<typename T, typename... Args>
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
Expand Down
7 changes: 4 additions & 3 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,11 @@ class Connection {

std::unique_ptr<PreparedStatement> preparedStatementWithError(std::string_view errMsg);

std::vector<std::unique_ptr<parser::Statement>> parseQuery(std::string_view query);
std::vector<std::shared_ptr<parser::Statement>> parseQuery(std::string_view query);

std::unique_ptr<PreparedStatement> prepareNoLock(parser::Statement* parsedStatement,
bool enumerateAllPlans = false, std::string_view joinOrder = std::string_view());
std::unique_ptr<PreparedStatement> prepareNoLock(
std::shared_ptr<parser::Statement> parsedStatement, bool enumerateAllPlans = false,
std::string_view joinOrder = std::string_view());

template<typename T, typename... Args>
std::unique_ptr<QueryResult> executeWithParams(PreparedStatement* preparedStatement,
Expand Down
3 changes: 3 additions & 0 deletions src/include/main/prepared_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "common/api.h"
#include "kuzu_fwd.h"
#include "parser/statement.h"
#include "query_summary.h"

namespace kuzu {
Expand Down Expand Up @@ -57,11 +58,13 @@ class PreparedStatement {
private:
bool success = true;
bool readOnly = false;
bool requiredNewTx = true;
std::string errMsg;
PreparedSummary preparedSummary;
std::unordered_map<std::string, std::shared_ptr<common::Value>> parameterMap;
std::unique_ptr<binder::BoundStatementResult> statementResult;
std::vector<std::unique_ptr<planner::LogicalPlan>> logicalPlans;
std::shared_ptr<parser::Statement> parsedStatement;
};

} // namespace main
Expand Down
2 changes: 1 addition & 1 deletion src/include/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace parser {
class Parser {

public:
static std::vector<std::unique_ptr<Statement>> parseQuery(std::string_view query);
static std::vector<std::shared_ptr<Statement>> parseQuery(std::string_view query);
};

} // namespace parser
Expand Down
2 changes: 1 addition & 1 deletion src/include/parser/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Transformer {
public:
explicit Transformer(CypherParser::Ku_StatementsContext& root) : root{root} {}

std::vector<std::unique_ptr<Statement>> transform();
std::vector<std::shared_ptr<Statement>> transform();

private:
std::unique_ptr<Statement> transformStatement(CypherParser::OC_StatementContext& ctx);
Expand Down
142 changes: 81 additions & 61 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ std::string ClientContext::getEnvVariable(const std::string& name) {
std::unique_ptr<PreparedStatement> ClientContext::prepare(std::string_view query) {
auto preparedStatement = std::unique_ptr<PreparedStatement>();
std::unique_lock<std::mutex> lck{mtx};
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
} catch (std::exception& exception) { return preparedStatementWithError(exception.what()); }
Expand All @@ -202,7 +202,7 @@ std::unique_ptr<PreparedStatement> ClientContext::prepare(std::string_view query
if (parsedStatements.empty()) {
return preparedStatementWithError("Connection Exception: Query is empty.");
}
return prepareNoLock(parsedStatements[0].get());
return prepareNoLock(parsedStatements[0]);
}

std::unique_ptr<QueryResult> ClientContext::query(std::string_view queryStatement) {
Expand All @@ -213,7 +213,7 @@ std::unique_ptr<QueryResult> ClientContext::query(
std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans) {
lock_t lck{mtx};
// parsing
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
} catch (std::exception& exception) { return queryResultWithError(exception.what()); }
Expand All @@ -223,8 +223,8 @@ std::unique_ptr<QueryResult> ClientContext::query(
std::unique_ptr<QueryResult> queryResult;
QueryResult* lastResult = nullptr;
for (auto& statement : parsedStatements) {
auto preparedStatement = prepareNoLock(
statement.get(), enumerateAllPlans /* enumerate all plans */, encodedJoin);
auto preparedStatement =
prepareNoLock(statement, enumerateAllPlans /* enumerate all plans */, encodedJoin);
auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
if (!lastResult) {
// first result of the query
Expand Down Expand Up @@ -254,8 +254,65 @@ std::unique_ptr<PreparedStatement> ClientContext::preparedStatementWithError(
return preparedStatement;
}

void ClientContext::prepareInternal(PreparedStatement& preparedStatement, bool enumerateAllPlans,
std::string_view encodedJoin,
std::optional<std::unordered_map<std::string, std::shared_ptr<common::Value>>> inputParams) {
auto parsedStatement = preparedStatement.parsedStatement;
// parsing
if (preparedStatement.requiredNewTx) {
if (parsedStatement->getStatementType() != StatementType::TRANSACTION) {
if (transactionContext->isAutoTransaction()) {
transactionContext->beginAutoTransaction(preparedStatement.readOnly);
} else {
transactionContext->validateManualTransaction(
preparedStatement.allowActiveTransaction(), preparedStatement.readOnly);
}
if (!this->getTx()->isReadOnly()) {
database->catalog->initCatalogContentForWriteTrxIfNecessary();
database->storageManager->initStatistics();
}
}
}
// binding
auto binder = Binder(this);
if (inputParams) {
binder.setInputParameters(*inputParams);
}
auto boundStatement = binder.bind(*parsedStatement);
preparedStatement.parameterMap = binder.getParameterMap();
preparedStatement.statementResult =
std::make_unique<BoundStatementResult>(boundStatement->getStatementResult()->copy());
// planning
auto planner = Planner(this);
std::vector<std::unique_ptr<LogicalPlan>> plans;
if (enumerateAllPlans) {
plans = planner.getAllPlans(*boundStatement);
} else {
plans.push_back(planner.getBestPlan(*boundStatement));
}
// optimizing
for (auto& plan : plans) {
optimizer::Optimizer::optimize(plan.get(), this);
}
if (!encodedJoin.empty()) {
std::unique_ptr<LogicalPlan> match;
for (auto& plan : plans) {
if (LogicalPlanUtil::encodeJoin(*plan) == encodedJoin) {
match = std::move(plan);
}
}
if (match == nullptr) {
throw ConnectionException(stringFormat("Cannot find a plan matching {}", encodedJoin));
}
preparedStatement.logicalPlans.push_back(std::move(match));
} else {
preparedStatement.logicalPlans = std::move(plans);
}
}

std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
Statement* parsedStatement, bool enumerateAllPlans, std::string_view encodedJoin) {
std::shared_ptr<Statement> parsedStatement, bool enumerateAllPlans,
std::string_view encodedJoin) {
auto preparedStatement = std::make_unique<PreparedStatement>();
auto compilingTimer = TimeMetric(true /* enable */);
compilingTimer.start();
Expand All @@ -276,52 +333,8 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
std::unique_ptr<ExecutionContext> executionContext;
std::unique_ptr<LogicalPlan> logicalPlan;
try {
// parsing
if (parsedStatement->getStatementType() != StatementType::TRANSACTION) {
if (transactionContext->isAutoTransaction()) {
transactionContext->beginAutoTransaction(preparedStatement->readOnly);
} else {
transactionContext->validateManualTransaction(
preparedStatement->allowActiveTransaction(), preparedStatement->readOnly);
}
if (!this->getTx()->isReadOnly()) {
database->catalog->initCatalogContentForWriteTrxIfNecessary();
database->storageManager->initStatistics();
}
}
// binding
auto binder = Binder(this);
auto boundStatement = binder.bind(*parsedStatement);
preparedStatement->parameterMap = binder.getParameterMap();
preparedStatement->statementResult =
std::make_unique<BoundStatementResult>(boundStatement->getStatementResult()->copy());
// planning
auto planner = Planner(this);
std::vector<std::unique_ptr<LogicalPlan>> plans;
if (enumerateAllPlans) {
plans = planner.getAllPlans(*boundStatement);
} else {
plans.push_back(planner.getBestPlan(*boundStatement));
}
// optimizing
for (auto& plan : plans) {
optimizer::Optimizer::optimize(plan.get(), this);
}
if (!encodedJoin.empty()) {
std::unique_ptr<LogicalPlan> match;
for (auto& plan : plans) {
if (LogicalPlanUtil::encodeJoin(*plan) == encodedJoin) {
match = std::move(plan);
}
}
if (match == nullptr) {
throw ConnectionException(
stringFormat("Cannot find a plan matching {}", encodedJoin));
}
preparedStatement->logicalPlans.push_back(std::move(match));
} else {
preparedStatement->logicalPlans = std::move(plans);
}
preparedStatement->parsedStatement = parsedStatement;
prepareInternal(*preparedStatement, enumerateAllPlans, encodedJoin);
} catch (std::exception& exception) {
preparedStatement->success = false;
preparedStatement->errMsg = exception.what();
Expand All @@ -332,8 +345,8 @@ std::unique_ptr<PreparedStatement> ClientContext::prepareNoLock(
return preparedStatement;
}

std::vector<std::unique_ptr<Statement>> ClientContext::parseQuery(std::string_view query) {
std::vector<std::unique_ptr<Statement>> statements;
std::vector<std::shared_ptr<Statement>> ClientContext::parseQuery(std::string_view query) {
std::vector<std::shared_ptr<Statement>> statements;
if (query.empty()) {
return statements;
}
Expand All @@ -355,6 +368,18 @@ std::unique_ptr<QueryResult> ClientContext::executeWithParams(PreparedStatement*
std::string errMsg = exception.what();
return queryResultWithError(errMsg);
}
// rebind
KU_ASSERT(preparedStatement->parsedStatement != nullptr);
auto& parameterMap = preparedStatement->parameterMap;
preparedStatement->requiredNewTx = transactionContext->hasActiveTransaction() ? false : true;
try {
prepareInternal(*preparedStatement, false, "", parameterMap);
} catch (std::exception& exception) {
preparedStatement->success = false;
preparedStatement->errMsg = exception.what();
this->transactionContext->rollback();
return queryResultWithError(exception.what());
}
return executeAndAutoCommitIfNecessaryNoLock(preparedStatement);
}

Expand All @@ -369,11 +394,6 @@ void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement,
if (value->getDataType()->getLogicalTypeID() == LogicalTypeID::ANY) {
value->setDataType(*expectParam->getDataType());
}
if (*expectParam->getDataType() != *value->getDataType()) {
throw Exception("Parameter " + name + " has data type " +
value->getDataType()->toString() + " but expects " +
expectParam->getDataType()->toString() + ".");
}
// The much more natural `parameterMap.at(name) = std::move(v)` fails.
// The reason is that other parts of the code rely on the existing Value object to be
// modified in-place, not replaced in this map.
Expand Down Expand Up @@ -469,7 +489,7 @@ void ClientContext::runQuery(std::string query) {
if (transactionContext->hasActiveTransaction()) {
transactionContext->commit();
}
auto parsedStatements = std::vector<std::unique_ptr<Statement>>();
auto parsedStatements = std::vector<std::shared_ptr<Statement>>();
try {
parsedStatements = parseQuery(query);
} catch (std::exception& exception) { throw ConnectionException(exception.what()); }
Expand All @@ -478,7 +498,7 @@ void ClientContext::runQuery(std::string query) {
}
try {
for (auto& statement : parsedStatements) {
auto preparedStatement = prepareNoLock(statement.get());
auto preparedStatement = prepareNoLock(statement);
auto currentQueryResult =
executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get());
if (!currentQueryResult->isSuccess()) {
Expand Down
5 changes: 3 additions & 2 deletions src/main/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,12 @@ std::unique_ptr<PreparedStatement> Connection::preparedStatementWithError(std::s
}

std::unique_ptr<PreparedStatement> Connection::prepareNoLock(
Statement* parsedStatement, bool enumerateAllPlans, std::string_view encodedJoin) {
std::shared_ptr<Statement> parsedStatement, bool enumerateAllPlans,
std::string_view encodedJoin) {
return clientContext->prepareNoLock(parsedStatement, enumerateAllPlans, encodedJoin);
}

std::vector<std::unique_ptr<Statement>> Connection::parseQuery(std::string_view query) {
std::vector<std::shared_ptr<Statement>> Connection::parseQuery(std::string_view query) {
return clientContext->parseQuery(query);
}

Expand Down
2 changes: 1 addition & 1 deletion src/parser/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using namespace antlr4;
namespace kuzu {
namespace parser {

std::vector<std::unique_ptr<Statement>> Parser::parseQuery(std::string_view query) {
std::vector<std::shared_ptr<Statement>> Parser::parseQuery(std::string_view query) {
auto inputStream = ANTLRInputStream(query);
auto parserErrorListener = ParserErrorListener();

Expand Down
4 changes: 2 additions & 2 deletions src/parser/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ using namespace kuzu::common;
namespace kuzu {
namespace parser {

std::vector<std::unique_ptr<Statement>> Transformer::transform() {
std::vector<std::unique_ptr<Statement>> statements;
std::vector<std::shared_ptr<Statement>> Transformer::transform() {
std::vector<std::shared_ptr<Statement>> statements;
for (auto& oc_Statement : root.oC_Cypher()) {
auto statement = transformStatement(*oc_Statement->oC_Statement());
if (oc_Statement->oC_AnyCypherOption()) {
Expand Down
Loading

0 comments on commit c08c334

Please sign in to comment.