diff --git a/src/include/binder/binder.h b/src/include/binder/binder.h index 8cf55fab538..85373b29248 100644 --- a/src/include/binder/binder.h +++ b/src/include/binder/binder.h @@ -87,6 +87,11 @@ class Binder { std::unique_ptr bind(const parser::Statement& statement); + void setInputParameters( + std::unordered_map> parameters) { + expressionBinder.parameterMap = parameters; + } + inline std::unordered_map> getParameterMap() { return expressionBinder.parameterMap; } diff --git a/src/include/binder/expression/parameter_expression.h b/src/include/binder/expression/parameter_expression.h index 12598a63d45..1d8c53f81bb 100644 --- a/src/include/binder/expression/parameter_expression.h +++ b/src/include/binder/expression/parameter_expression.h @@ -10,8 +10,8 @@ class ParameterExpression : public Expression { public: explicit ParameterExpression( const std::string& parameterName, std::shared_ptr value) - : Expression{common::ExpressionType::PARAMETER, - common::LogicalType(common::LogicalTypeID::ANY), createUniqueName(parameterName)}, + : Expression{common::ExpressionType::PARAMETER, common::LogicalType(*value->getDataType()), + createUniqueName(parameterName)}, parameterName(parameterName), value{std::move(value)} {} void cast(const common::LogicalType& type) override; diff --git a/src/include/main/client_context.h b/src/include/main/client_context.h index 5cdb5c891d7..bb44b290dec 100644 --- a/src/include/main/client_context.h +++ b/src/include/main/client_context.h @@ -106,6 +106,11 @@ class ClientContext { std::unique_ptr query(std::string_view queryStatement); void runQuery(std::string query); + // TODO(Jiamin): should remove after supporting ddl in manual tx + std::unique_ptr prepareTest(std::string_view query); + // only use for test framework + std::vector> parseQuery(std::string_view query); + private: std::unique_ptr query( std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans = true); @@ -114,10 +119,11 @@ class ClientContext { std::unique_ptr preparedStatementWithError(std::string_view errMsg); - std::vector> parseQuery(std::string_view query); - - std::unique_ptr prepareNoLock(parser::Statement* parsedStatement, - bool enumerateAllPlans = false, std::string_view joinOrder = std::string_view()); + std::unique_ptr prepareNoLock( + std::shared_ptr parsedStatement, bool enumerateAllPlans = false, + std::string_view joinOrder = std::string_view(), bool requireNewTx = true, + std::optional>> inputParams = + std::nullopt); template std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, @@ -133,7 +139,7 @@ class ClientContext { const std::unordered_map>& inputParams); std::unique_ptr executeAndAutoCommitIfNecessaryNoLock( - PreparedStatement* preparedStatement, uint32_t planIdx = 0u); + PreparedStatement* preparedStatement, uint32_t planIdx = 0u, bool requiredNexTx = true); void addScalarFunction(std::string name, function::function_set definitions); diff --git a/src/include/main/connection.h b/src/include/main/connection.h index 87d1afcf5bd..d0d5c477183 100644 --- a/src/include/main/connection.h +++ b/src/include/main/connection.h @@ -145,10 +145,9 @@ class Connection { std::unique_ptr preparedStatementWithError(std::string_view errMsg); - std::vector> parseQuery(std::string_view query); - - std::unique_ptr prepareNoLock(parser::Statement* parsedStatement, - bool enumerateAllPlans = false, std::string_view joinOrder = std::string_view()); + std::unique_ptr prepareNoLock( + std::shared_ptr parsedStatement, bool enumerateAllPlans = false, + std::string_view joinOrder = std::string_view()); template std::unique_ptr executeWithParams(PreparedStatement* preparedStatement, diff --git a/src/include/main/prepared_statement.h b/src/include/main/prepared_statement.h index 22d4c9f952e..20ad0974b39 100644 --- a/src/include/main/prepared_statement.h +++ b/src/include/main/prepared_statement.h @@ -7,6 +7,7 @@ #include "common/api.h" #include "kuzu_fwd.h" +#include "parser/statement.h" #include "query_summary.h" namespace kuzu { @@ -62,6 +63,7 @@ class PreparedStatement { std::unordered_map> parameterMap; std::unique_ptr statementResult; std::vector> logicalPlans; + std::shared_ptr parsedStatement; }; } // namespace main diff --git a/src/include/parser/parser.h b/src/include/parser/parser.h index 5ca9ef9ba8b..a6bed351df1 100644 --- a/src/include/parser/parser.h +++ b/src/include/parser/parser.h @@ -12,7 +12,7 @@ namespace parser { class Parser { public: - static std::vector> parseQuery(std::string_view query); + static std::vector> parseQuery(std::string_view query); }; } // namespace parser diff --git a/src/include/parser/statement.h b/src/include/parser/statement.h index 3cb56749f57..45e3d1bb633 100644 --- a/src/include/parser/statement.h +++ b/src/include/parser/statement.h @@ -13,6 +13,15 @@ class Statement { inline common::StatementType getStatementType() const { return statementType; } + inline bool requireTx() { + switch (statementType) { + case common::StatementType::TRANSACTION: + return false; + default: + return true; + } + } + private: common::StatementType statementType; }; diff --git a/src/include/parser/transformer.h b/src/include/parser/transformer.h index ce5ccd92522..392af7a358f 100644 --- a/src/include/parser/transformer.h +++ b/src/include/parser/transformer.h @@ -31,7 +31,7 @@ class Transformer { public: explicit Transformer(CypherParser::Ku_StatementsContext& root) : root{root} {} - std::vector> transform(); + std::vector> transform(); private: std::unique_ptr transformStatement(CypherParser::OC_StatementContext& ctx); diff --git a/src/main/client_context.cpp b/src/main/client_context.cpp index 6a809f69cd5..5317430ba64 100644 --- a/src/main/client_context.cpp +++ b/src/main/client_context.cpp @@ -192,9 +192,9 @@ std::string ClientContext::getEnvVariable(const std::string& name) { std::unique_ptr ClientContext::prepare(std::string_view query) { auto preparedStatement = std::unique_ptr(); std::unique_lock lck{mtx}; - auto parsedStatements = std::vector>(); + auto parsedStatements = std::vector>(); try { - parsedStatements = parseQuery(query); + parsedStatements = Parser::parseQuery(query); } catch (std::exception& exception) { return preparedStatementWithError(exception.what()); } if (parsedStatements.size() > 1) { return preparedStatementWithError( @@ -203,7 +203,25 @@ std::unique_ptr ClientContext::prepare(std::string_view query if (parsedStatements.empty()) { return preparedStatementWithError("Connection Exception: Query is empty."); } - return prepareNoLock(parsedStatements[0].get()); + return prepareNoLock(parsedStatements[0]); +} + +std::unique_ptr ClientContext::prepareTest(std::string_view query) { + auto preparedStatement = std::unique_ptr(); + std::unique_lock lck{mtx}; + auto parsedStatements = std::vector>(); + try { + parsedStatements = Parser::parseQuery(query); + } catch (std::exception& exception) { return preparedStatementWithError(exception.what()); } + if (parsedStatements.size() > 1) { + return preparedStatementWithError( + "Connection Exception: We do not support prepare multiple statements."); + } + if (parsedStatements.empty()) { + return preparedStatementWithError("Connection Exception: Query is empty."); + } + return prepareNoLock( + parsedStatements[0], false /* enumerate all plans */, "", false /*requireNewTx*/); } std::unique_ptr ClientContext::query(std::string_view queryStatement) { @@ -214,9 +232,9 @@ std::unique_ptr ClientContext::query( std::string_view query, std::string_view encodedJoin, bool enumerateAllPlans) { lock_t lck{mtx}; // parsing - auto parsedStatements = std::vector>(); + auto parsedStatements = std::vector>(); try { - parsedStatements = parseQuery(query); + parsedStatements = Parser::parseQuery(query); } catch (std::exception& exception) { return queryResultWithError(exception.what()); } if (parsedStatements.empty()) { return queryResultWithError("Connection Exception: Query is empty."); @@ -224,9 +242,10 @@ std::unique_ptr ClientContext::query( std::unique_ptr queryResult; QueryResult* lastResult = nullptr; for (auto& statement : parsedStatements) { - auto preparedStatement = prepareNoLock( - statement.get(), enumerateAllPlans /* enumerate all plans */, encodedJoin); - auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get()); + auto preparedStatement = prepareNoLock(statement, + enumerateAllPlans /* enumerate all plans */, encodedJoin, false /*requireNewTx*/); + auto currentQueryResult = executeAndAutoCommitIfNecessaryNoLock( + preparedStatement.get(), 0u, false /*requiredNexTx*/); if (!lastResult) { // first result of the query queryResult = std::move(currentQueryResult); @@ -256,7 +275,9 @@ std::unique_ptr ClientContext::preparedStatementWithError( } std::unique_ptr ClientContext::prepareNoLock( - Statement* parsedStatement, bool enumerateAllPlans, std::string_view encodedJoin) { + std::shared_ptr parsedStatement, bool enumerateAllPlans, + std::string_view encodedJoin, bool requireNewTx, + std::optional>> inputParams) { auto preparedStatement = std::make_unique(); auto compilingTimer = TimeMetric(true /* enable */); compilingTimer.start(); @@ -274,11 +295,10 @@ std::unique_ptr ClientContext::prepareNoLock( preparedStatement->preparedSummary.compilingTime = compilingTimer.getElapsedTimeMS(); return preparedStatement; } - std::unique_ptr executionContext; - std::unique_ptr logicalPlan; try { // parsing - if (parsedStatement->getStatementType() != StatementType::TRANSACTION) { + preparedStatement->parsedStatement = parsedStatement; + if (parsedStatement->requireTx()) { if (transactionContext->isAutoTransaction()) { transactionContext->beginAutoTransaction(preparedStatement->readOnly); } else { @@ -292,6 +312,9 @@ std::unique_ptr ClientContext::prepareNoLock( } // binding auto binder = Binder(this); + if (inputParams) { + binder.setInputParameters(*inputParams); + } auto boundStatement = binder.bind(*parsedStatement); preparedStatement->parameterMap = binder.getParameterMap(); preparedStatement->statementResult = @@ -323,6 +346,9 @@ std::unique_ptr ClientContext::prepareNoLock( } else { preparedStatement->logicalPlans = std::move(plans); } + if (transactionContext->isAutoTransaction() && requireNewTx) { + this->transactionContext->commit(); + } } catch (std::exception& exception) { preparedStatement->success = false; preparedStatement->errMsg = exception.what(); @@ -333,8 +359,8 @@ std::unique_ptr ClientContext::prepareNoLock( return preparedStatement; } -std::vector> ClientContext::parseQuery(std::string_view query) { - std::vector> statements; +std::vector> ClientContext::parseQuery(std::string_view query) { + std::vector> statements; if (query.empty()) { return statements; } @@ -356,7 +382,11 @@ std::unique_ptr ClientContext::executeWithParams(PreparedStatement* std::string errMsg = exception.what(); return queryResultWithError(errMsg); } - return executeAndAutoCommitIfNecessaryNoLock(preparedStatement); + // rebind + KU_ASSERT(preparedStatement->parsedStatement != nullptr); + auto rebindPreparedStatement = prepareNoLock( + preparedStatement->parsedStatement, false, "", false, preparedStatement->parameterMap); + return executeAndAutoCommitIfNecessaryNoLock(rebindPreparedStatement.get(), 0u, false); } void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement, @@ -367,14 +397,6 @@ void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement, throw Exception("Parameter " + name + " not found."); } auto expectParam = parameterMap.at(name); - if (value->getDataType()->getLogicalTypeID() == LogicalTypeID::ANY) { - value->setDataType(*expectParam->getDataType()); - } - if (*expectParam->getDataType() != *value->getDataType()) { - throw Exception("Parameter " + name + " has data type " + - value->getDataType()->toString() + " but expects " + - expectParam->getDataType()->toString() + "."); - } // The much more natural `parameterMap.at(name) = std::move(v)` fails. // The reason is that other parts of the code rely on the existing Value object to be // modified in-place, not replaced in this map. @@ -383,12 +405,11 @@ void ClientContext::bindParametersNoLock(PreparedStatement* preparedStatement, } std::unique_ptr ClientContext::executeAndAutoCommitIfNecessaryNoLock( - PreparedStatement* preparedStatement, uint32_t planIdx) { + PreparedStatement* preparedStatement, uint32_t planIdx, bool requiredNexTx) { if (!preparedStatement->isSuccess()) { return queryResultWithError(preparedStatement->errMsg); } - if (preparedStatement->preparedSummary.statementType != common::StatementType::TRANSACTION && - this->getTx() == nullptr) { + if (preparedStatement->parsedStatement->requireTx() && requiredNexTx && getTx() == nullptr) { this->transactionContext->beginAutoTransaction(preparedStatement->isReadOnly()); if (!preparedStatement->readOnly) { database->catalog->initCatalogContentForWriteTrxIfNecessary(); @@ -470,18 +491,18 @@ void ClientContext::runQuery(std::string query) { if (transactionContext->hasActiveTransaction()) { transactionContext->commit(); } - auto parsedStatements = std::vector>(); + auto parsedStatements = std::vector>(); try { - parsedStatements = parseQuery(query); + parsedStatements = Parser::parseQuery(query); } catch (std::exception& exception) { throw ConnectionException(exception.what()); } if (parsedStatements.empty()) { throw ConnectionException("Connection Exception: Query is empty."); } try { for (auto& statement : parsedStatements) { - auto preparedStatement = prepareNoLock(statement.get()); + auto preparedStatement = prepareNoLock(statement, false, "", false); auto currentQueryResult = - executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get()); + executeAndAutoCommitIfNecessaryNoLock(preparedStatement.get(), 0u, false); if (!currentQueryResult->isSuccess()) { throw ConnectionException(currentQueryResult->errMsg); } diff --git a/src/main/connection.cpp b/src/main/connection.cpp index 320a363cc58..0c9c4674670 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -53,14 +53,11 @@ std::unique_ptr Connection::preparedStatementWithError(std::s } std::unique_ptr Connection::prepareNoLock( - Statement* parsedStatement, bool enumerateAllPlans, std::string_view encodedJoin) { + std::shared_ptr parsedStatement, bool enumerateAllPlans, + std::string_view encodedJoin) { return clientContext->prepareNoLock(parsedStatement, enumerateAllPlans, encodedJoin); } -std::vector> Connection::parseQuery(std::string_view query) { - return clientContext->parseQuery(query); -} - void Connection::interrupt() { clientContext->interrupt(); } diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index d0fab07a1f7..ed9f50773e6 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -16,7 +16,7 @@ using namespace antlr4; namespace kuzu { namespace parser { -std::vector> Parser::parseQuery(std::string_view query) { +std::vector> Parser::parseQuery(std::string_view query) { auto inputStream = ANTLRInputStream(query); auto parserErrorListener = ParserErrorListener(); diff --git a/src/parser/transformer.cpp b/src/parser/transformer.cpp index 2d3d63a3648..dafcfcad7bd 100644 --- a/src/parser/transformer.cpp +++ b/src/parser/transformer.cpp @@ -10,8 +10,8 @@ using namespace kuzu::common; namespace kuzu { namespace parser { -std::vector> Transformer::transform() { - std::vector> statements; +std::vector> Transformer::transform() { + std::vector> statements; for (auto& oc_Statement : root.oC_Cypher()) { auto statement = transformStatement(*oc_Statement->oC_Statement()); if (oc_Statement->oC_AnyCypherOption()) { diff --git a/test/copy/e2e_copy_transaction_test.cpp b/test/copy/e2e_copy_transaction_test.cpp index 14451a01f10..34a286ffd93 100644 --- a/test/copy/e2e_copy_transaction_test.cpp +++ b/test/copy/e2e_copy_transaction_test.cpp @@ -66,7 +66,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { void copyNodeCSVCommitAndRecoveryTest(TransactionTestType transactionTestType) { conn->query(createPersonTableCMD); - auto preparedStatement = conn->prepare(copyPersonTableCMD); + auto preparedStatement = conn->getClientContext()->prepareTest(copyPersonTableCMD); if (!preparedStatement->success) { ASSERT_TRUE(false) << preparedStatement->errMsg; } @@ -130,7 +130,7 @@ class TinySnbCopyCSVTransactionTest : public EmptyDBTest { conn->query(createPersonTableCMD); conn->query(copyPersonTableCMD); conn->query(createKnowsTableCMD); - auto preparedStatement = conn->prepare(copyKnowsTableCMD); + auto preparedStatement = conn->getClientContext()->prepareTest(copyKnowsTableCMD); auto mapper = PlanMapper(conn->getClientContext()); auto physicalPlan = mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(), diff --git a/test/ddl/e2e_ddl_test.cpp b/test/ddl/e2e_ddl_test.cpp index 02b21b5cde5..5cfce9ddd67 100644 --- a/test/ddl/e2e_ddl_test.cpp +++ b/test/ddl/e2e_ddl_test.cpp @@ -169,7 +169,7 @@ class TinySnbDDLTest : public DBTest { } void executeQueryWithoutCommit(std::string query) { - auto preparedStatement = conn->prepare(query); + auto preparedStatement = conn->getClientContext()->prepareTest(query); auto mapper = PlanMapper(conn->getClientContext()); auto physicalPlan = mapper.mapLogicalPlanToPhysical(preparedStatement->logicalPlans[0].get(), diff --git a/test/main/prepare_test.cpp b/test/main/prepare_test.cpp index d393233934e..eb9bf945928 100644 --- a/test/main/prepare_test.cpp +++ b/test/main/prepare_test.cpp @@ -7,6 +7,48 @@ static void checkTuple(kuzu::processor::FlatTuple* tuple, const std::string& gro ASSERT_STREQ(tuple->toString().c_str(), groundTruth.c_str()); } +TEST_F(ApiTest, issueTest1) { + conn->query("CREATE NODE TABLE T(id SERIAL, name STRING, PRIMARY KEY(id));"); + conn->query("CREATE (t:T {name: \"foo\"});"); + auto preparedStatement = conn->prepare("MATCH (t:T {id: $p}) return t.name;"); + auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("p"), 0)); + ASSERT_TRUE(result->hasNext()); + checkTuple(result->getNext().get(), "foo\n"); + ASSERT_FALSE(result->hasNext()); +} + +TEST_F(ApiTest, issueTest2) { + conn->query("CREATE NODE TABLE NodeOne(id INT64, name STRING, PRIMARY KEY(id));"); + conn->query("CREATE NODE TABLE NodeTwo(id INT64, name STRING, PRIMARY KEY(id));"); + conn->query("CREATE Rel TABLE RelA(from NodeOne to NodeOne);"); + conn->query("CREATE Rel TABLE RelB(from NodeTwo to NodeOne, name String);"); + conn->query("CREATE (t: NodeOne {id:1, name: \"Alice\"});"); + conn->query("CREATE (t: NodeOne {id:2, name: \"Jack\"});"); + conn->query("CREATE (t: NodeTwo {id:3, name: \"Bob\"});"); + auto preparedStatement = conn->prepare("MATCH (a:NodeOne { id: $a_id })," + "(b:NodeTwo { id: $b_id })," + "(c: NodeOne{ id: $c_id } )" + " MERGE" + " (a)-[:RelA]->(c)," + " (b)-[r:RelB { name: $my_param }]->(c)" + " return r.name;"); + auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("a_id"), 1), + std::make_pair(std::string("b_id"), 3), std::make_pair(std::string("c_id"), 2), + std::make_pair(std::string("my_param"), "friend")); + ASSERT_TRUE(result->hasNext()); + checkTuple(result->getNext().get(), "friend\n"); + ASSERT_FALSE(result->hasNext()); +} + +TEST_F(ApiTest, issueTest) { + auto preparedStatement = conn->prepare("RETURN $1 + 1;"); + auto result = + conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), (int8_t)1)); + ASSERT_TRUE(result->hasNext()); + checkTuple(result->getNext().get(), "2\n"); + ASSERT_FALSE(result->hasNext()); +} + TEST_F(ApiTest, MultiParamsPrepare) { auto preparedStatement = conn->prepare( "MATCH (a:person) WHERE a.fName STARTS WITH $n OR a.fName CONTAINS $xx RETURN COUNT(*)"); @@ -96,9 +138,8 @@ TEST_F(ApiTest, PrepareDefaultParam) { ASSERT_FALSE(result->hasNext()); preparedStatement = conn->prepare("RETURN size($1)"); result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), 1)); - ASSERT_FALSE(result->isSuccess()); - ASSERT_STREQ( - result->getErrorMessage().c_str(), "Parameter 1 has data type INT32 but expects STRING."); + ASSERT_TRUE(result->hasNext()); + checkTuple(result->getNext().get(), "1\n"); } TEST_F(ApiTest, PrepareDefaultListParam) { @@ -109,8 +150,8 @@ TEST_F(ApiTest, PrepareDefaultListParam) { checkTuple(result->getNext().get(), "[1,1]\n"); result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "as")); ASSERT_FALSE(result->isSuccess()); - ASSERT_STREQ( - result->getErrorMessage().c_str(), "Parameter 1 has data type STRING but expects INT64."); + ASSERT_STREQ(result->getErrorMessage().c_str(), + "Binder exception: Cannot bind LIST_CREATION with parameter type INT64 and STRING."); preparedStatement = conn->prepare("RETURN [$1]"); result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), "as")); ASSERT_TRUE(result->hasNext()); @@ -127,9 +168,9 @@ TEST_F(ApiTest, PrepareDefaultStructParam) { ASSERT_TRUE(result->hasNext()); checkTuple(result->getNext().get(), "{a: 10}\n"); result = conn->execute(preparedStatement.get(), std::make_pair(std::string("1"), 1)); - ASSERT_FALSE(result->isSuccess()); - ASSERT_STREQ( - result->getErrorMessage().c_str(), "Parameter 1 has data type INT32 but expects STRING."); + ASSERT_TRUE(result->isSuccess()); + ASSERT_TRUE(result->hasNext()); + checkTuple(result->getNext().get(), "{a: 1}\n"); } TEST_F(ApiTest, PrepareDefaultMapParam) { @@ -170,9 +211,8 @@ TEST_F(ApiTest, ParamTypeError) { conn->prepare("MATCH (a:person) WHERE a.fName STARTS WITH $n RETURN COUNT(*)"); auto result = conn->execute(preparedStatement.get(), std::make_pair(std::string("n"), (int64_t)36)); - ASSERT_FALSE(result->isSuccess()); - ASSERT_STREQ( - "Parameter n has data type INT64 but expects STRING.", result->getErrorMessage().c_str()); + ASSERT_TRUE(result->hasNext()); + checkTuple(result->getNext().get(), "0\n"); } TEST_F(ApiTest, MultipleExecutionOfPreparedStatement) { diff --git a/test/test_runner/test_runner.cpp b/test/test_runner/test_runner.cpp index 0f110434235..a9d47a02ae5 100644 --- a/test/test_runner/test_runner.cpp +++ b/test/test_runner/test_runner.cpp @@ -46,9 +46,9 @@ bool TestRunner::testStatement( replaceEnv(statement->query, "AWS_S3_ACCESS_KEY_ID"); replaceEnv(statement->query, "AWS_S3_SECRET_ACCESS_KEY"); replaceEnv(statement->query, "RUN_ID"); - auto parsedStatements = std::vector>(); + auto parsedStatements = std::vector>(); try { - parsedStatements = conn.parseQuery(statement->query); + parsedStatements = conn.getClientContext()->parseQuery(statement->query); } catch (std::exception& exception) { auto errorPreparedStatement = conn.preparedStatementWithError(exception.what()); return checkLogicalPlan(errorPreparedStatement, statement, conn, 0); @@ -63,9 +63,9 @@ bool TestRunner::testStatement( } auto parsedStatement = std::move(parsedStatements[0]); if (statement->encodedJoin.empty()) { - preparedStatement = conn.prepareNoLock(parsedStatement.get(), statement->enumerate); + preparedStatement = conn.prepareNoLock(parsedStatement, statement->enumerate); } else { - preparedStatement = conn.prepareNoLock(parsedStatement.get(), true, statement->encodedJoin); + preparedStatement = conn.prepareNoLock(parsedStatement, true, statement->encodedJoin); } // Check for wrong statements if (!statement->expectedError && !statement->expectedErrorRegex && diff --git a/tools/python_api/test/test_parameter.py b/tools/python_api/test/test_parameter.py index f389297eb33..d13d68febcc 100644 --- a/tools/python_api/test/test_parameter.py +++ b/tools/python_api/test/test_parameter.py @@ -156,6 +156,27 @@ def test_param_error3(conn_db_readonly: ConnDB) -> None: with pytest.raises(RuntimeError, match="Parameters must be a dict"): conn.execute("MATCH (a:person) WHERE a.registerTime = $1 RETURN COUNT(*);", [("asd", 1, 1)]) +def test_param(conn_db_readwrite: ConnDB) -> None: + conn, db = conn_db_readwrite + conn.execute("CREATE NODE TABLE NodeOne(id INT64, name STRING, PRIMARY KEY(id));"); + conn.execute("CREATE NODE TABLE NodeTwo(id INT64, name STRING, PRIMARY KEY(id));"); + conn.execute("CREATE Rel TABLE RelA(from NodeOne to NodeOne);"); + conn.execute("CREATE Rel TABLE RelB(from NodeTwo to NodeOne, id int64, name String);"); + conn.execute("CREATE (t: NodeOne {id:1, name: \"Alice\"});"); + conn.execute("CREATE (t: NodeOne {id:2, name: \"Jack\"});"); + conn.execute("CREATE (t: NodeTwo {id:3, name: \"Bob\"});"); + result = conn.execute( + "MATCH (a:NodeOne { id: $a_id })," + "(b:NodeTwo { id: $b_id })," + "(c: NodeOne{ id: $c_id } )" + " MERGE" + " (a)-[:RelA]->(c)," + " (b)-[r:RelB { id: 2, name: $my_param }]->(c)" + " return r.*;", {"a_id": 1, "b_id": 3, "c_id": 2, "my_param": None} + ) + assert result.has_next() + assert result.get_next() == [2, None] + result.close() def test_param_error4(conn_db_readonly: ConnDB) -> None: conn, db = conn_db_readonly