From 79bab4af15111c8a65fbea950b86dfc6571bae40 Mon Sep 17 00:00:00 2001 From: Benjamin Winger Date: Tue, 6 Jun 2023 15:29:42 -0400 Subject: [PATCH 1/3] Fix tense --- src/include/main/connection.h | 2 +- src/main/connection.cpp | 2 +- test/main/prepare_test.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/include/main/connection.h b/src/include/main/connection.h index 941d51a89e..93976f52c6 100644 --- a/src/include/main/connection.h +++ b/src/include/main/connection.h @@ -129,7 +129,7 @@ class Connection { KUZU_API std::string getRelPropertyNames(const std::string& relTableName); /** - * @brief interrupts all queries currently executed within this connection. + * @brief interrupts all queries currently executing within this connection. */ KUZU_API void interrupt(); diff --git a/src/main/connection.cpp b/src/main/connection.cpp index 044b9be0c3..2738243ef9 100644 --- a/src/main/connection.cpp +++ b/src/main/connection.cpp @@ -316,7 +316,7 @@ void Connection::bindParametersNoLock(PreparedStatement* preparedStatement, if (expectParam->dataType != value->getDataType()) { throw Exception("Parameter " + name + " has data type " + LogicalTypeUtils::dataTypeToString(value->getDataType()) + - " but expect " + + " but expects " + LogicalTypeUtils::dataTypeToString(expectParam->dataType) + "."); } parameterMap.at(name)->copyValueFrom(*value); diff --git a/test/main/prepare_test.cpp b/test/main/prepare_test.cpp index 2b834631d6..0648c635ca 100644 --- a/test/main/prepare_test.cpp +++ b/test/main/prepare_test.cpp @@ -128,7 +128,7 @@ TEST_F(ApiTest, ParamTypeError) { conn->execute(preparedStatement.get(), std::make_pair(std::string("n"), (int64_t)36)); ASSERT_FALSE(result->isSuccess()); ASSERT_STREQ( - "Parameter n has data type INT64 but expect STRING.", result->getErrorMessage().c_str()); + "Parameter n has data type INT64 but expects STRING.", result->getErrorMessage().c_str()); } TEST_F(ApiTest, MultipleExecutionOfPreparedStatement) { From 
5cde4202229d48100c40b1cc31f95d855e88aa77 Mon Sep 17 00:00:00 2001 From: Benjamin Winger Date: Thu, 15 Jun 2023 17:18:52 -0400 Subject: [PATCH 2/3] Made functions const which didn't need to be mutable The rust API needs to call them from a const context. --- src/common/types/value.cpp | 2 +- src/include/common/types/value.h | 2 +- src/include/main/query_result.h | 10 +++++----- src/include/processor/result/flat_tuple.h | 4 ++-- src/main/query_result.cpp | 10 +++++----- src/processor/result/flat_tuple.cpp | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/common/types/value.cpp b/src/common/types/value.cpp index 460f7dd27a..c0687aa4ec 100644 --- a/src/common/types/value.cpp +++ b/src/common/types/value.cpp @@ -548,7 +548,7 @@ nodeID_t RelVal::getDstNodeID() const { return dstNodeIDVal->getValue(); } -std::string RelVal::getLabelName() { +std::string RelVal::getLabelName() const { return labelVal->getValue(); } diff --git a/src/include/common/types/value.h b/src/include/common/types/value.h index 104c427249..c1206bf1b6 100644 --- a/src/include/common/types/value.h +++ b/src/include/common/types/value.h @@ -329,7 +329,7 @@ class RelVal { /** * @return the name of the RelVal. */ - KUZU_API std::string getLabelName(); + KUZU_API std::string getLabelName() const; /** * @return the value of the RelVal in string format. */ diff --git a/src/include/main/query_result.h b/src/include/main/query_result.h index 8164cd76a1..1bb633dae2 100644 --- a/src/include/main/query_result.h +++ b/src/include/main/query_result.h @@ -54,15 +54,15 @@ class QueryResult { /** * @return name of each column in query result. */ - KUZU_API std::vector getColumnNames(); + KUZU_API std::vector getColumnNames() const; /** * @return dataType of each column in query result. */ - KUZU_API std::vector getColumnDataTypes(); + KUZU_API std::vector getColumnDataTypes() const; /** * @return num of tuples in query result. 
*/ - KUZU_API uint64_t getNumTuples(); + KUZU_API uint64_t getNumTuples() const; /** * @return query summary which stores the execution time, compiling time, plan and query * options. @@ -73,7 +73,7 @@ class QueryResult { /** * @return whether there are more tuples to read. */ - KUZU_API bool hasNext(); + KUZU_API bool hasNext() const; /** * @return next flat tuple in the query result. */ @@ -101,7 +101,7 @@ class QueryResult { const std::vector>& columns, const std::vector>>& expressionToCollectPerColumn); - void validateQuerySucceed(); + void validateQuerySucceed() const; private: // execution status diff --git a/src/include/processor/result/flat_tuple.h b/src/include/processor/result/flat_tuple.h index a90ac5280a..90da155150 100644 --- a/src/include/processor/result/flat_tuple.h +++ b/src/include/processor/result/flat_tuple.h @@ -16,13 +16,13 @@ class FlatTuple { /** * @return number of values in the FlatTuple. */ - KUZU_API uint32_t len(); + KUZU_API uint32_t len() const; /** * @param idx value index to get. * @return the value stored at idx. */ - KUZU_API common::Value* getValue(uint32_t idx); + KUZU_API common::Value* getValue(uint32_t idx) const; std::string toString(); diff --git a/src/main/query_result.cpp b/src/main/query_result.cpp index 1a0ec8c509..8e1b8d88e9 100644 --- a/src/main/query_result.cpp +++ b/src/main/query_result.cpp @@ -57,15 +57,15 @@ size_t QueryResult::getNumColumns() const { return columnDataTypes.size(); } -std::vector QueryResult::getColumnNames() { +std::vector QueryResult::getColumnNames() const { return columnNames; } -std::vector QueryResult::getColumnDataTypes() { +std::vector QueryResult::getColumnDataTypes() const { return columnDataTypes; } -uint64_t QueryResult::getNumTuples() { +uint64_t QueryResult::getNumTuples() const { return querySummary->getIsExplain() ? 
0 : factorizedTable->getTotalNumFlatTuples(); } @@ -183,7 +183,7 @@ void QueryResult::initResultTableAndIterator( iterator = std::make_unique(*factorizedTable, std::move(valuesToCollect)); } -bool QueryResult::hasNext() { +bool QueryResult::hasNext() const { validateQuerySucceed(); assert(querySummary->getIsExplain() == false); return iterator->hasNextFlatTuple(); @@ -281,7 +281,7 @@ void QueryResult::writeToCSV( file.close(); } -void QueryResult::validateQuerySucceed() { +void QueryResult::validateQuerySucceed() const { if (!success) { throw Exception(errMsg); } diff --git a/src/processor/result/flat_tuple.cpp b/src/processor/result/flat_tuple.cpp index 74ffe85125..15604c82c4 100644 --- a/src/processor/result/flat_tuple.cpp +++ b/src/processor/result/flat_tuple.cpp @@ -13,11 +13,11 @@ void FlatTuple::addValue(std::unique_ptr value) { values.push_back(std::move(value)); } -uint32_t FlatTuple::len() { +uint32_t FlatTuple::len() const { return values.size(); } -common::Value* FlatTuple::getValue(uint32_t idx) { +common::Value* FlatTuple::getValue(uint32_t idx) const { if (idx >= len()) { throw common::RuntimeException(common::StringUtils::string_format( "ValIdx is out of range. 
Number of values in flatTuple: {}, valIdx: {}.", len(), idx)); From 983c4dc1586a90ebe2a51b4934bcd7a5361ddc24 Mon Sep 17 00:00:00 2001 From: Benjamin Winger Date: Tue, 9 May 2023 13:42:22 -0400 Subject: [PATCH 3/3] Implemented Rust API --- .github/workflows/ci-workflow.yml | 13 + Makefile | 10 + examples/rust/.gitignore | 1 + examples/rust/Cargo.lock | 268 +++++++++ examples/rust/Cargo.toml | 7 + examples/rust/src/main.rs | 24 + src/include/c_api/kuzu.h | 4 +- src/include/main/query_summary.h | 4 +- tools/rust_api/.gitignore | 2 + tools/rust_api/Cargo.toml | 29 + tools/rust_api/build.rs | 93 +++ tools/rust_api/include/kuzu_rs.h | 177 ++++++ tools/rust_api/kuzu-src | 1 + tools/rust_api/src/CMakeLists.txt | 14 + tools/rust_api/src/connection.rs | 490 ++++++++++++++++ tools/rust_api/src/database.rs | 89 +++ tools/rust_api/src/error.rs | 44 ++ tools/rust_api/src/ffi.rs | 331 +++++++++++ tools/rust_api/src/kuzu_rs.cpp | 266 +++++++++ tools/rust_api/src/lib.rs | 43 ++ tools/rust_api/src/logical_type.rs | 160 ++++++ tools/rust_api/src/query_result.rs | 220 ++++++++ tools/rust_api/src/value.rs | 878 +++++++++++++++++++++++++++++ 23 files changed, 3164 insertions(+), 4 deletions(-) create mode 100644 examples/rust/.gitignore create mode 100644 examples/rust/Cargo.lock create mode 100644 examples/rust/Cargo.toml create mode 100644 examples/rust/src/main.rs create mode 100644 tools/rust_api/.gitignore create mode 100644 tools/rust_api/Cargo.toml create mode 100644 tools/rust_api/build.rs create mode 100644 tools/rust_api/include/kuzu_rs.h create mode 120000 tools/rust_api/kuzu-src create mode 100644 tools/rust_api/src/CMakeLists.txt create mode 100644 tools/rust_api/src/connection.rs create mode 100644 tools/rust_api/src/database.rs create mode 100644 tools/rust_api/src/error.rs create mode 100644 tools/rust_api/src/ffi.rs create mode 100644 tools/rust_api/src/kuzu_rs.cpp create mode 100644 tools/rust_api/src/lib.rs create mode 100644 tools/rust_api/src/logical_type.rs create 
mode 100644 tools/rust_api/src/query_result.rs create mode 100644 tools/rust_api/src/value.rs diff --git a/.github/workflows/ci-workflow.yml b/.github/workflows/ci-workflow.yml index d7f33790af..bbfd1a360f 100644 --- a/.github/workflows/ci-workflow.yml +++ b/.github/workflows/ci-workflow.yml @@ -34,6 +34,9 @@ jobs: - name: Node.js test run: CC=gcc CXX=g++ make nodejstest NUM_THREADS=32 + - name: Rust test + run: CC=gcc CXX=g++ make rusttest NUM_THREADS=32 + - name: Generate coverage report run: | lcov -c -d ./ --no-external -o cover.info &&\ @@ -131,6 +134,16 @@ jobs: - name: Check test format run: python3 scripts/run-clang-format.py --clang-format-executable /usr/bin/clang-format-11 -r test/ + rustfmt-check: + name: rustfmt check + runs-on: kuzu-self-hosted-testing + steps: + - uses: actions/checkout@v3 + + - name: Check api format + working-directory: tools/rust_api + run: cargo fmt --check + benchmark: name: benchmark needs: [gcc-build-test, clang-build-test] diff --git a/Makefile b/Makefile index fb9247c8d3..9c385e1147 100644 --- a/Makefile +++ b/Makefile @@ -107,6 +107,16 @@ nodejstest: arrow cd $(ROOT_DIR)/tools/nodejs_api/ && \ npm test +rusttest: +ifeq ($(OS),Windows_NT) + cd $(ROOT_DIR)/tools/rust_api && \ + set KUZU_TESTING=1 && \ + cargo test -- --test-threads=1 +else + cd $(ROOT_DIR)/tools/rust_api && \ + KUZU_TESTING=1 cargo test -- --test-threads=1 +endif + clean-python-api: ifeq ($(OS),Windows_NT) if exist tools\python_api\build rmdir /s /q tools\python_api\build diff --git a/examples/rust/.gitignore b/examples/rust/.gitignore new file mode 100644 index 0000000000..ea8c4bf7f3 --- /dev/null +++ b/examples/rust/.gitignore @@ -0,0 +1 @@ +/target diff --git a/examples/rust/Cargo.lock b/examples/rust/Cargo.lock new file mode 100644 index 0000000000..c88f7cd396 --- /dev/null +++ b/examples/rust/Cargo.lock @@ -0,0 +1,268 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "cc" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" + +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + +[[package]] +name = "cxx" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e88abab2f5abbe4c56e8f1fb431b784d710b709888f35755a160e62e33fe38e8" +dependencies = [ + "cc", + "cxxbridge-flags", + "cxxbridge-macro", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c0c11acd0e63bae27dcd2afced407063312771212b7a823b4fd72d633be30fb" +dependencies = [ + "cc", + "codespan-reporting", + "once_cell", + "proc-macro2", + "quote", + "scratch", + "syn 2.0.18", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d3816ed957c008ccd4728485511e3d9aaf7db419aa321e3d2c5a2f3411e36c8" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26acccf6f445af85ea056362561a24ef56cdc15fcc685f03aec50b9c702cb6d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + +[[package]] +name = "hermit-abi" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" +dependencies = [ + "libc", +] + +[[package]] 
+name = "kuzu" +version = "0.0.4" +dependencies = [ + "cxx", + "cxx-build", + "num-derive", + "num-traits", + "num_cpus", + "time", +] + +[[package]] +name = "kuzu-rust-example" +version = "0.1.0" +dependencies = [ + "kuzu", +] + +[[package]] +name = "libc" +version = "0.2.146" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" + +[[package]] +name = "link-cplusplus" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecd207c9c713c34f95a097a5b029ac2ce6010530c7b49d7fea24d977dede04f5" +dependencies = [ + "cc", +] + +[[package]] +name = "num-derive" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "876a53fff98e03a936a674b29568b0e605f06b29372c2489ff4de23f1949743d" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_cpus" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "proc-macro2" +version = "1.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec2b086b7a862cf4de201096214fa870344cf922b2b30c167badb3af3195406" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "scratch" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1792db035ce95be60c3f8853017b3999209281c24e2ba5bc8e59bf97a0c590c1" + +[[package]] +name = "serde" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "termcolor" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "time" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" +dependencies = [ + "serde", + "time-core", +] + +[[package]] +name = "time-core" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" + +[[package]] +name = "unicode-ident" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" + +[[package]] +name = "unicode-width" +version = "0.1.10" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/examples/rust/Cargo.toml b/examples/rust/Cargo.toml new file mode 100644 index 0000000000..b99893d1c7 --- /dev/null +++ b/examples/rust/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "kuzu-rust-example" +version = "0.1.0" +edition = "2021" + +[dependencies] +kuzu = {path="../../tools/rust_api"} diff --git a/examples/rust/src/main.rs b/examples/rust/src/main.rs new file mode 100644 index 0000000000..cc1ac82ee7 --- /dev/null +++ b/examples/rust/src/main.rs @@ -0,0 +1,24 @@ +use kuzu::{Connection, Database, Error}; + +fn main() -> Result<(), Error> { + let db = Database::new( + std::env::args() + .nth(1) + .expect("The first CLI argument should be the database path"), + 0, + )?; + let connection = Connection::new(&db)?; + + // Create schema. + connection.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?; + // Create nodes. 
+ connection.query("CREATE (:Person {name: 'Alice', age: 25});")?; + connection.query("CREATE (:Person {name: 'Bob', age: 30});")?; + + // Execute a simple query. + let mut result = connection.query("MATCH (a:Person) RETURN a.name AS NAME, a.age AS AGE;")?; + + // Print query result. + println!("{}", result.display()); + Ok(()) +} diff --git a/src/include/c_api/kuzu.h b/src/include/c_api/kuzu.h index 46c8faff03..70c756eead 100644 --- a/src/include/c_api/kuzu.h +++ b/src/include/c_api/kuzu.h @@ -882,12 +882,12 @@ KUZU_C_API char* kuzu_rel_val_to_string(kuzu_rel_val* rel_val); */ KUZU_C_API void kuzu_query_summary_destroy(kuzu_query_summary* query_summary); /** - * @brief Returns the compilation time of the given query summary. + * @brief Returns the compilation time of the given query summary in milliseconds. * @param query_summary The query summary to get compilation time. */ KUZU_C_API double kuzu_query_summary_get_compiling_time(kuzu_query_summary* query_summary); /** - * @brief Returns the execution time of the given query summary. + * @brief Returns the execution time of the given query summary in milliseconds. * @param query_summary The query summary to get execution time. */ KUZU_C_API double kuzu_query_summary_get_execution_time(kuzu_query_summary* query_summary); diff --git a/src/include/main/query_summary.h b/src/include/main/query_summary.h index ff977455fe..cf49f3d12d 100644 --- a/src/include/main/query_summary.h +++ b/src/include/main/query_summary.h @@ -26,11 +26,11 @@ class QuerySummary { public: /** - * @return query compiling time. + * @return query compiling time in milliseconds. */ KUZU_API double getCompilingTime() const; /** - * @return query execution time. + * @return query execution time in milliseconds. 
*/ KUZU_API double getExecutionTime() const; bool getIsExplain() const; diff --git a/tools/rust_api/.gitignore b/tools/rust_api/.gitignore new file mode 100644 index 0000000000..a9d37c560c --- /dev/null +++ b/tools/rust_api/.gitignore @@ -0,0 +1,2 @@ +target +Cargo.lock diff --git a/tools/rust_api/Cargo.toml b/tools/rust_api/Cargo.toml new file mode 100644 index 0000000000..9d54f219af --- /dev/null +++ b/tools/rust_api/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "kuzu" +version = "0.0.4" +description = "An in-process property graph database management system built for query speed and scalability" +# Note: 1.63 required for building tests +rust-version = "1.51" +readme = "../../README.md" +homepage = "http://kuzudb.com/" +repository = "https://github.com/kuzudb/kuzu" +license = "MIT" +categories = ["database"] + +edition = "2018" +links = "kuzu" + +[dependencies] +cxx = "1.0" +num-derive = "0.3" +num-traits = "0.2" +time = "0.3" + +[build-dependencies] +cxx-build = "1.0" +num_cpus = "1.0" + +[dev-dependencies] +tempdir = "0.3" +anyhow = "1" +time = {version="0.3", features=["macros"]} diff --git a/tools/rust_api/build.rs b/tools/rust_api/build.rs new file mode 100644 index 0000000000..268a9e5b7c --- /dev/null +++ b/tools/rust_api/build.rs @@ -0,0 +1,93 @@ +use std::env; +use std::path::Path; + +fn link_mode() -> &'static str { + if env::var("KUZU_SHARED").is_ok() { + "dylib" + } else { + "static" + } +} + +fn main() -> Result<(), Box> { + // There is a kuzu-src symlink pointing to the root of the repo since Cargo + // only looks at the files within the rust project when packaging crates. + // Using a symlink the library can both be built in-source and from a crate. 
+ let kuzu_root = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("kuzu-src"); + let target = env::var("PROFILE")?; + let kuzu_cmake_root = kuzu_root.join(format!("build/{target}")); + let mut command = std::process::Command::new("make"); + command + .args(&[target, format!("NUM_THREADS={}", num_cpus::get())]) + .current_dir(&kuzu_root); + let make_status = command.status()?; + assert!(make_status.success()); + + let kuzu_lib_path = kuzu_cmake_root.join("src"); + + println!("cargo:rustc-link-search=native={}", kuzu_lib_path.display()); + + let include_paths = vec![ + Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap()).join("include"), + kuzu_root.join("src/include"), + kuzu_root.join("third_party/nlohmann_json"), + kuzu_root.join("third_party/spdlog"), + ]; + for dir in ["utf8proc", "antlr4_cypher", "antlr4_runtime", "re2"] { + let lib_path = kuzu_cmake_root + .join(format!("third_party/{dir}")) + .canonicalize() + .unwrap_or_else(|_| { + panic!( + "Could not find {}/third_party/{dir}", + kuzu_cmake_root.display() + ) + }); + println!("cargo:rustc-link-search=native={}", lib_path.display()); + } + + let arrow_install = kuzu_root.join("external/build/arrow/install"); + println!( + "cargo:rustc-link-search=native={}", + arrow_install.join("lib").display() + ); + println!( + "cargo:rustc-link-search=native={}", + arrow_install.join("lib64").display() + ); + + println!("cargo:rustc-link-lib={}=kuzu", link_mode()); + if link_mode() == "static" { + println!("cargo:rustc-link-lib=dylib=stdc++"); + + println!("cargo:rustc-link-lib=static=arrow_bundled_dependencies"); + // Dependencies of arrow's bundled dependencies + // Only seems to be necessary when building tests. + // This will probably not work on windows/macOS + // openssl-sys has better cross-platform logic, but just using that doesn't work. 
+ if env::var("KUZU_TESTING").is_ok() { + println!("cargo:rustc-link-lib=dylib=ssl"); + println!("cargo:rustc-link-lib=dylib=crypto"); + } + + println!("cargo:rustc-link-lib=static=parquet"); + println!("cargo:rustc-link-lib=static=arrow"); + + println!("cargo:rustc-link-lib=static=utf8proc"); + println!("cargo:rustc-link-lib=static=antlr4_cypher"); + println!("cargo:rustc-link-lib=static=antlr4_runtime"); + println!("cargo:rustc-link-lib=static=re2"); + } + println!("cargo:rerun-if-env-changed=KUZU_SHARED"); + + println!("cargo:rerun-if-changed=include/kuzu_rs.h"); + println!("cargo:rerun-if-changed=include/kuzu_rs.cpp"); + + cxx_build::bridge("src/ffi.rs") + .file("src/kuzu_rs.cpp") + .flag_if_supported("-std=c++20") + .includes(include_paths) + .compile("kuzu_rs"); + + Ok(()) +} diff --git a/tools/rust_api/include/kuzu_rs.h b/tools/rust_api/include/kuzu_rs.h new file mode 100644 index 0000000000..ad00832d53 --- /dev/null +++ b/tools/rust_api/include/kuzu_rs.h @@ -0,0 +1,177 @@ +#pragma once +#include + +#include "main/kuzu.h" +#include "rust/cxx.h" +// Need to explicitly import some types. +// The generated C++ wrapper code needs to be able to call sizeof on PreparedStatement, +// which it can't do when it only sees forward declarations of its components. +#include +#include
+#include + +namespace kuzu_rs { + +struct TypeListBuilder { + std::vector> types; + + void insert(std::unique_ptr type) { + types.push_back(std::move(type)); + } +}; + +std::unique_ptr create_type_list(); + +struct QueryParams { + std::unordered_map> inputParams; + + void insert(const rust::Str key, std::unique_ptr value) { + inputParams.insert(std::make_pair(key, std::move(value))); + } +}; + +std::unique_ptr new_params(); + +std::unique_ptr create_logical_type(kuzu::common::LogicalTypeID id); +std::unique_ptr create_logical_type_var_list( + std::unique_ptr childType); +std::unique_ptr create_logical_type_fixed_list( + std::unique_ptr childType, uint64_t numElements); +std::unique_ptr create_logical_type_struct( + const rust::Vec& fieldNames, std::unique_ptr fieldTypes); + +const kuzu::common::LogicalType& logical_type_get_var_list_child_type( + const kuzu::common::LogicalType& logicalType); +const kuzu::common::LogicalType& logical_type_get_fixed_list_child_type( + const kuzu::common::LogicalType& logicalType); +uint64_t logical_type_get_fixed_list_num_elements(const kuzu::common::LogicalType& logicalType); + +rust::Vec logical_type_get_struct_field_names(const kuzu::common::LogicalType& value); +std::unique_ptr> logical_type_get_struct_field_types( + const kuzu::common::LogicalType& value); + +// Simple wrapper for vector of unique_ptr since cxx doesn't support functions returning a vector of +// unique_ptr +struct ValueList { + ValueList(const std::vector>& values) : values(values) {} + const std::vector>& values; + uint64_t size() const { return values.size(); } + const std::unique_ptr& get(uint64_t index) const { return values[index]; } +}; + +/* Database */ +std::unique_ptr new_database( + const std::string& databasePath, uint64_t bufferPoolSize); + +void database_set_logging_level(kuzu::main::Database& database, const std::string& level); + +/* Connection */ +std::unique_ptr database_connect(kuzu::main::Database& database); +std::unique_ptr 
connection_execute(kuzu::main::Connection& connection, + kuzu::main::PreparedStatement& query, std::unique_ptr params); + +rust::String get_node_table_names(kuzu::main::Connection& connection); +rust::String get_rel_table_names(kuzu::main::Connection& connection); +rust::String get_node_property_names(kuzu::main::Connection& connection, rust::Str tableName); +rust::String get_rel_property_names(kuzu::main::Connection& connection, rust::Str relTableName); + +/* PreparedStatement */ +rust::String prepared_statement_error_message(const kuzu::main::PreparedStatement& statement); + +/* QueryResult */ +rust::String query_result_to_string(kuzu::main::QueryResult& result); +rust::String query_result_get_error_message(const kuzu::main::QueryResult& result); + +double query_result_get_compiling_time(const kuzu::main::QueryResult& result); +double query_result_get_execution_time(const kuzu::main::QueryResult& result); + +void query_result_write_to_csv(kuzu::main::QueryResult& query_result, const rust::String& filename, + int8_t delimiter, int8_t escape_character, int8_t newline); + +std::unique_ptr> query_result_column_data_types( + const kuzu::main::QueryResult& query_result); +rust::Vec query_result_column_names(const kuzu::main::QueryResult& query_result); + +/* NodeVal/RelVal */ +struct PropertyList { + const std::vector>>& properties; + + size_t size() const { return properties.size(); } + rust::String get_name(size_t index) const { + return rust::String(this->properties[index].first); + } + const kuzu::common::Value& get_value(size_t index) const { + return *this->properties[index].second.get(); + } +}; + +template +rust::String value_get_label_name(const T& val) { + return val.getLabelName(); +} +template +std::unique_ptr value_get_properties(const T& val) { + return std::make_unique(val.getProperties()); +} + +/* NodeVal */ +std::array node_value_get_node_id(const kuzu::common::NodeVal& val); + +/* RelVal */ +std::array rel_value_get_src_id(const kuzu::common::RelVal& 
val); +std::array rel_value_get_dst_id(const kuzu::common::RelVal& val); + +/* FlatTuple */ +const kuzu::common::Value& flat_tuple_get_value( + const kuzu::processor::FlatTuple& flatTuple, uint32_t index); + +/* Value */ +rust::String value_get_string(const kuzu::common::Value& value); + +template +std::unique_ptr value_get_unique(const kuzu::common::Value& value) { + return std::make_unique(value.getValue()); +} + +int64_t value_get_interval_secs(const kuzu::common::Value& value); +int32_t value_get_interval_micros(const kuzu::common::Value& value); +int32_t value_get_date_days(const kuzu::common::Value& value); +int64_t value_get_timestamp_micros(const kuzu::common::Value& value); +std::array value_get_internal_id(const kuzu::common::Value& value); +std::unique_ptr value_get_list(const kuzu::common::Value& value); +kuzu::common::LogicalTypeID value_get_data_type_id(const kuzu::common::Value& value); +std::unique_ptr value_get_data_type(const kuzu::common::Value& value); +rust::String value_to_string(const kuzu::common::Value& val); + +std::unique_ptr create_value_string(const rust::String& value); +std::unique_ptr create_value_timestamp(const int64_t timestamp); +std::unique_ptr create_value_date(const int64_t date); +std::unique_ptr create_value_interval( + const int32_t months, const int32_t days, const int64_t micros); +std::unique_ptr create_value_null( + std::unique_ptr typ); +std::unique_ptr create_value_internal_id(uint64_t offset, uint64_t table); +std::unique_ptr create_value_node( + std::unique_ptr id_val, std::unique_ptr label_val); +std::unique_ptr create_value_rel(std::unique_ptr src_id, + std::unique_ptr dst_id, std::unique_ptr label_val); + +template +std::unique_ptr create_value(const T value) { + return std::make_unique(value); +} + +struct ValueListBuilder { + std::vector> values; + + void insert(std::unique_ptr value) { values.push_back(std::move(value)); } +}; + +std::unique_ptr get_list_value( + std::unique_ptr typ, std::unique_ptr value); 
+std::unique_ptr create_list(); + +void value_add_property(kuzu::common::Value& val, const rust::String& name, + std::unique_ptr property); + +} // namespace kuzu_rs diff --git a/tools/rust_api/kuzu-src b/tools/rust_api/kuzu-src new file mode 120000 index 0000000000..c25bddb6dd --- /dev/null +++ b/tools/rust_api/kuzu-src @@ -0,0 +1 @@ +../.. \ No newline at end of file diff --git a/tools/rust_api/src/CMakeLists.txt b/tools/rust_api/src/CMakeLists.txt new file mode 100644 index 0000000000..33300854c2 --- /dev/null +++ b/tools/rust_api/src/CMakeLists.txt @@ -0,0 +1,14 @@ +add_library(kuzu_rs STATIC + ${CMAKE_CURRENT_SOURCE_DIR}/kuzu_rs.cpp + ${KUZU_RS_BINDINGS_DIR}/sources/kuzu/src/lib.rs.cc +) + +target_include_directories(kuzu_rs + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${KUZU_RS_BINDINGS_DIR}/include + ${KUZU_RS_BINDINGS_DIR}/crate + ${KUZU_RS_BINDINGS_DIR}/sources +) + +target_link_libraries(kuzu_rs kuzu) diff --git a/tools/rust_api/src/connection.rs b/tools/rust_api/src/connection.rs new file mode 100644 index 0000000000..735f4d96e7 --- /dev/null +++ b/tools/rust_api/src/connection.rs @@ -0,0 +1,490 @@ +use crate::database::Database; +use crate::error::Error; +use crate::ffi::ffi; +use crate::query_result::QueryResult; +use crate::value::Value; +use cxx::{let_cxx_string, UniquePtr}; +use std::cell::UnsafeCell; +use std::convert::TryInto; + +/// A prepared statement is a parameterized query which can avoid planning the same query for +/// repeated execution +pub struct PreparedStatement { + statement: UniquePtr, +} + +/// Connections are used to interact with a Database instance. +/// +/// ## Concurrency +/// +/// Each connection is thread-safe, and multiple connections can connect to the same Database +/// instance in a multithreaded environment. 
+/// +/// Note that since connections require a reference to the Database, creating or using connections +/// in multiple threads cannot be done from a regular std::thread since the threads (and +/// connections) could outlive the database. This can be worked around by using a +/// [scoped thread](std::thread::scope) (Note: Introduced in Rust 1.63. For compatibility with +/// older versions of Rust, [crossbeam_utils::thread::scope](https://docs.rs/crossbeam-utils/latest/crossbeam_utils/thread/index.html) can be used instead). +/// +/// Also note that write queries can only be done one at a time; the query command will return an +/// [error](Error::FailedQuery) if another write query is in progress. +/// +/// ``` +/// # use kuzu::{Connection, Database, Value, Error}; +/// # fn main() -> anyhow::Result<()> { +/// # let temp_dir = tempdir::TempDir::new("example3")?; +/// # let db = Database::new(temp_dir.path(), 0)?; +/// let conn = Connection::new(&db)?; +/// conn.query("CREATE NODE TABLE Person(name STRING, age INT32, PRIMARY KEY(name));")?; +/// // Write queries must be done sequentially +/// conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; +/// conn.query("CREATE (:Person {name: 'Bob', age: 30});")?; +/// let (alice, bob) = std::thread::scope(|s| -> Result<(Vec, Vec), Error> { +/// let alice_thread = s.spawn(|| -> Result, Error> { +/// let conn = Connection::new(&db)?; +/// let mut result = conn.query("MATCH (a:Person) WHERE a.name = \"Alice\" RETURN a.name AS NAME, a.age AS AGE;")?; +/// Ok(result.next().unwrap()) +/// }); +/// let bob_thread = s.spawn(|| -> Result, Error> { +/// let conn = Connection::new(&db)?; +/// let mut result = conn.query( +/// "MATCH (a:Person) WHERE a.name = \"Bob\" RETURN a.name AS NAME, a.age AS AGE;", +/// )?; +/// Ok(result.next().unwrap()) +/// }); +/// Ok((alice_thread.join().unwrap()?, bob_thread.join().unwrap()?)) +/// })?; +/// +/// assert_eq!(alice, vec!["Alice".into(), 25.into()]); +/// assert_eq!(bob,
vec!["Bob".into(), 30.into()]); +/// temp_dir.close()?; +/// Ok(()) +/// # } +/// ``` +/// +/// ## Committing +/// If the connection is in AUTO_COMMIT mode, any query over the connection will be wrapped in +/// a transaction and committed (even if the query is READ_ONLY). +/// If the connection is in MANUAL transaction mode, which happens only if an application +/// manually begins a transaction (see below), then an application has to manually commit or +/// roll back the transaction by calling `commit()` or `rollback()`. +/// +/// AUTO_COMMIT is the default mode when a Connection is created. If an application calls +/// `begin_read_only_transaction()` or `begin_write_transaction()` at any point, the mode +/// switches to MANUAL. This creates an "active transaction" in the connection. When a +/// connection is in MANUAL mode and the active transaction is rolled back or committed, then +/// the active transaction is removed (so the connection no longer has an active transaction) +/// and the mode automatically switches back to AUTO_COMMIT. +/// Note: When a Connection object is dropped, if the connection has an active (manual) +/// transaction, then the active transaction is rolled back.
+/// +/// ``` +/// use kuzu::{Database, Connection}; +/// # use anyhow::Error; +/// # fn main() -> Result<(), Error> { +/// # let temp_dir = tempdir::TempDir::new("example")?; +/// # let path = temp_dir.path(); +/// let db = Database::new(path, 0)?; +/// let conn = Connection::new(&db)?; +/// // AUTO_COMMIT mode +/// conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?; +/// conn.begin_write_transaction()?; +/// // MANUAL mode (write) +/// conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; +/// conn.query("CREATE (:Person {name: 'Bob', age: 30});")?; +/// // Queries committed and mode goes back to AUTO_COMMIT +/// conn.commit()?; +/// let result = conn.query("MATCH (a:Person) RETURN a.name AS NAME, a.age AS AGE;")?; +/// assert!(result.count() == 2); +/// # temp_dir.close()?; +/// # Ok(()) +/// # } +/// ``` +/// +/// ``` +/// use kuzu::{Database, Connection}; +/// # use anyhow::Error; +/// # fn main() -> Result<(), Error> { +/// # let temp_dir = tempdir::TempDir::new("example")?; +/// # let path = temp_dir.path(); +/// let db = Database::new(path, 0)?; +/// let conn = Connection::new(&db)?; +/// // AUTO_COMMIT mode +/// conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?; +/// conn.begin_write_transaction()?; +/// // MANUAL mode (write) +/// conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; +/// conn.query("CREATE (:Person {name: 'Bob', age: 30});")?; +/// // Queries rolled back and mode goes back to AUTO_COMMIT +/// conn.rollback()?; +/// let result = conn.query("MATCH (a:Person) RETURN a.name AS NAME, a.age AS AGE;")?; +/// assert!(result.count() == 0); +/// # temp_dir.close()?; +/// # Ok(()) +/// # } +/// ``` +pub struct Connection<'a> { + // bmwinger: Access to the underlying value for synchronized functions can be done + // with (*self.conn.get()).pin_mut() + // Turning this into a function just causes lifetime issues. 
+ conn: UnsafeCell>>, +} + +// Connections are synchronized on the C++ side and should be safe to move and access across +// threads +unsafe impl<'a> Send for Connection<'a> {} +unsafe impl<'a> Sync for Connection<'a> {} + +impl<'a> Connection<'a> { + /// Creates a connection to the database. + /// + /// # Arguments + /// * `database`: A reference to the database instance to which this connection will be connected. + pub fn new(database: &'a Database) -> Result { + let db = unsafe { (*database.db.get()).pin_mut() }; + Ok(Connection { + conn: UnsafeCell::new(ffi::database_connect(db)?), + }) + } + + /// Sets the maximum number of threads to use for execution in the current connection + /// + /// # Arguments + /// * `num_threads`: The maximum number of threads to use for execution in the current connection + pub fn set_max_num_threads_for_exec(&mut self, num_threads: u64) { + self.conn + .get_mut() + .pin_mut() + .setMaxNumThreadForExec(num_threads); + } + + /// Returns the maximum number of threads used for execution in the current connection + pub fn get_max_num_threads_for_exec(&self) -> u64 { + unsafe { (*self.conn.get()).pin_mut().getMaxNumThreadForExec() } + } + + /// Prepares the given query and returns the prepared statement. + /// + /// # Arguments + /// * `query`: The query to prepare. + /// See for details on the query format + pub fn prepare(&self, query: &str) -> Result { + let_cxx_string!(query = query); + let statement = unsafe { (*self.conn.get()).pin_mut() }.prepare(&query)?; + if statement.isSuccess() { + Ok(PreparedStatement { statement }) + } else { + Err(Error::FailedPreparedStatement( + ffi::prepared_statement_error_message(&statement), + )) + } + } + + /// Executes the given query and returns the result. + /// + /// # Arguments + /// * `query`: The query to execute. + /// See for details on the query format + // TODO(bmwinger): Instead of having a Value enum in the results, perhaps QueryResult, and thus query + // should be generic. 
+ // + // E.g. + // let result: QueryResult> = conn.query("...")?; + // let result: QueryResult = conn.query("...")?; + // + // But this would really just be syntactic sugar wrapping the current system + pub fn query(&self, query: &str) -> Result { + let mut statement = self.prepare(query)?; + self.execute(&mut statement, vec![]) + } + + /// Executes the given prepared statement with args and returns the result. + /// + /// # Arguments + /// * `prepared_statement`: The prepared statement to execute + pub fn execute( + &self, + prepared_statement: &mut PreparedStatement, + params: Vec<(&str, Value)>, + ) -> Result { + // Passing and converting Values in a collection across the ffi boundary is difficult + // (std::vector cannot be constructed from rust, Vec cannot contain opaque C++ types) + // So we create an opaque parameter pack and copy the parameters into it one by one + let mut cxx_params = ffi::new_params(); + for (key, value) in params { + let ffi_value: cxx::UniquePtr = value.try_into()?; + cxx_params.pin_mut().insert(key, ffi_value); + } + let conn = unsafe { (*self.conn.get()).pin_mut() }; + let result = + ffi::connection_execute(conn, prepared_statement.statement.pin_mut(), cxx_params)?; + if !result.isSuccess() { + Err(Error::FailedQuery(ffi::query_result_get_error_message( + &result, + ))) + } else { + Ok(QueryResult { result }) + } + } + + /// Manually starts a new read-only transaction in the current connection + pub fn begin_read_only_transaction(&self) -> Result<(), Error> { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + Ok(conn.beginReadOnlyTransaction()?) + } + + /// Manually starts a new write transaction in the current connection + pub fn begin_write_transaction(&self) -> Result<(), Error> { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + Ok(conn.beginWriteTransaction()?) 
+ } + + /// Manually commits the current transaction + pub fn commit(&self) -> Result<(), Error> { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + Ok(conn.commit()?) + } + + /// Manually rolls back the current transaction + pub fn rollback(&self) -> Result<(), Error> { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + Ok(conn.rollback()?) + } + + /// Interrupts all queries currently executing within this connection + pub fn interrupt(&self) -> Result<(), Error> { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + Ok(conn.interrupt()?) + } + + /// Returns all node table names in string format + pub fn get_node_table_names(&self) -> String { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + ffi::get_node_table_names(conn) + } + + /// Returns all rel table names in string format + pub fn get_rel_table_names(&self) -> String { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + ffi::get_rel_table_names(conn) + } + + /// Returns all property names of the given table + pub fn get_node_property_names(&self, table_name: &str) -> String { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + ffi::get_node_property_names(conn, table_name) + } + + /// Returns all property names of the given table + pub fn get_rel_property_names(&self, rel_table_name: &str) -> String { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + ffi::get_rel_property_names(conn, rel_table_name) + } + + /// Sets the query timeout value of the current connection + /// + /// A value of zero (the default) disables the timeout. + pub fn set_query_timeout(&self, timeout_ms: u64) { + let conn = unsafe { (*self.conn.get()).pin_mut() }; + conn.setQueryTimeOut(timeout_ms); + } +} + +#[cfg(test)] +mod tests { + use crate::{connection::Connection, database::Database, value::Value}; + use anyhow::{Error, Result}; + // Note: Cargo runs tests in parallel by default, however kuzu does not support + // working with multiple databases in parallel. 
+ // Tests can be run serially with `cargo test -- --test-threads=1` to work around this. + + #[test] + fn test_connection_threads() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example1")?; + let db = Database::new(temp_dir.path(), 0)?; + let mut conn = Connection::new(&db)?; + conn.set_max_num_threads_for_exec(5); + assert_eq!(conn.get_max_num_threads_for_exec(), 5); + temp_dir.close()?; + Ok(()) + } + + #[test] + fn test_invalid_query() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example2")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?; + conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; + conn.query("CREATE (:Person {name: 'Bob', age: 30});")?; + + let result: Error = conn + .query("MATCH (a:Person RETURN a.name AS NAME, a.age AS AGE;") + .expect_err("Invalid syntax in query should produce an error") + .into(); + assert_eq!( + result.to_string(), + "Query execution failed: Parser exception: \ +Invalid input : expected rule oC_SingleQuery (line: 1, offset: 16) +\"MATCH (a:Person RETURN a.name AS NAME, a.age AS AGE;\" + ^^^^^^" + ); + Ok(()) + } + + #[test] + fn test_query_result() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example3")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT16, PRIMARY KEY(name));")?; + conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; + + for result in conn.query("MATCH (a:Person) RETURN a.name AS NAME, a.age AS AGE;")? 
{ + assert_eq!(result.len(), 2); + assert_eq!(result[0], Value::String("Alice".to_string())); + assert_eq!(result[1], Value::Int16(25)); + } + temp_dir.close()?; + Ok(()) + } + + #[test] + fn test_params() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example3")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT16, PRIMARY KEY(name));")?; + conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; + conn.query("CREATE (:Person {name: 'Bob', age: 30});")?; + + let mut statement = conn.prepare("MATCH (a:Person) WHERE a.age = $age RETURN a.name;")?; + for result in conn.execute(&mut statement, vec![("age", Value::Int16(25))])? { + assert_eq!(result.len(), 1); + assert_eq!(result[0], Value::String("Alice".to_string())); + } + temp_dir.close()?; + Ok(()) + } + + #[test] + fn test_params_invalid_type() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example3")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT16, PRIMARY KEY(name));")?; + conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; + conn.query("CREATE (:Person {name: 'Bob', age: 30});")?; + + let mut statement = conn.prepare("MATCH (a:Person) WHERE a.age = $age RETURN a.name;")?; + let result: Error = conn + .execute( + &mut statement, + vec![("age", Value::String("25".to_string()))], + ) + .expect_err("Age should be an int16!") + .into(); + assert_eq!( + result.to_string(), + "Query execution failed: Parameter age has data type STRING but expects INT16." 
+ ); + temp_dir.close()?; + Ok(()) + } + + #[test] + fn test_multithreaded_single_conn() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example3")?; + let db = Database::new(temp_dir.path(), 0)?; + + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT32, PRIMARY KEY(name));")?; + // Write queries must be done sequentially + conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; + conn.query("CREATE (:Person {name: 'Bob', age: 30});")?; + + let (alice, bob) = std::thread::scope(|s| -> Result<(Vec, Vec)> { + let alice_thread = s.spawn(|| -> Result> { + let mut result = conn.query("MATCH (a:Person) WHERE a.name = \"Alice\" RETURN a.name AS NAME, a.age AS AGE;")?; + Ok(result.next().unwrap()) + }); + let bob_thread = s.spawn(|| -> Result> { + let mut result = conn.query( + "MATCH (a:Person) WHERE a.name = \"Bob\" RETURN a.name AS NAME, a.age AS AGE;", + )?; + Ok(result.next().unwrap()) + }); + + Ok((alice_thread.join().unwrap()?, bob_thread.join().unwrap()?)) + })?; + + assert_eq!(alice, vec!["Alice".into(), 25.into()]); + assert_eq!(bob, vec!["Bob".into(), 30.into()]); + temp_dir.close()?; + Ok(()) + } + + #[test] + fn test_multithreaded_multiple_conn() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example3")?; + let db = Database::new(temp_dir.path(), 0)?; + + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT32, PRIMARY KEY(name));")?; + // Write queries must be done sequentially + conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; + conn.query("CREATE (:Person {name: 'Bob', age: 30});")?; + + let (alice, bob) = std::thread::scope(|s| -> Result<(Vec, Vec)> { + let alice_thread = s.spawn(|| -> Result> { + let conn = Connection::new(&db)?; + let mut result = conn.query("MATCH (a:Person) WHERE a.name = \"Alice\" RETURN a.name AS NAME, a.age AS AGE;")?; + Ok(result.next().unwrap()) + }); + let bob_thread = s.spawn(|| -> Result> { + let conn = 
Connection::new(&db)?; + let mut result = conn.query( + "MATCH (a:Person) WHERE a.name = \"Bob\" RETURN a.name AS NAME, a.age AS AGE;", + )?; + Ok(result.next().unwrap()) + }); + + Ok((alice_thread.join().unwrap()?, bob_thread.join().unwrap()?)) + })?; + + assert_eq!(alice, vec!["Alice".into(), 25.into()]); + assert_eq!(bob, vec!["Bob".into(), 30.into()]); + temp_dir.close()?; + Ok(()) + } + + #[test] + fn test_table_names() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example3")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT16, PRIMARY KEY(name));")?; + conn.query("CREATE REL TABLE Follows(FROM Person TO Person, since DATE);")?; + assert_eq!( + conn.get_node_table_names(), + "Node tables: \n\tPerson\n".to_string() + ); + assert_eq!( + conn.get_rel_table_names(), + "Rel tables: \n\tFollows\n".to_string() + ); + assert_eq!( + conn.get_node_property_names("Person"), + "Person properties: \n\tname STRING(PRIMARY KEY)\n\tage INT16\n".to_string() + ); + assert_eq!( + conn.get_rel_property_names("Follows"), + "Follows src node: Person\n\ + Follows dst node: Person\nFollows properties: \n\tsince DATE\n" + .to_string() + ); + + temp_dir.close()?; + Ok(()) + } +} diff --git a/tools/rust_api/src/database.rs b/tools/rust_api/src/database.rs new file mode 100644 index 0000000000..4680d6972f --- /dev/null +++ b/tools/rust_api/src/database.rs @@ -0,0 +1,89 @@ +use crate::error::Error; +use crate::ffi::ffi; +use cxx::{let_cxx_string, UniquePtr}; +use std::cell::UnsafeCell; +use std::fmt; +use std::path::Path; + +/// The Database class is the main class of KuzuDB. It manages all database components. 
+pub struct Database { + pub(crate) db: UnsafeCell>, +} + +unsafe impl Send for Database {} +unsafe impl Sync for Database {} + +/// Logging level of the database instance +pub enum LoggingLevel { + Debug, + Info, + Error, +} + +impl Database { + /// Creates a database object + /// + /// # Arguments: + /// * `path`: Path of the database. If the database does not already exist, it will be created. + /// * `buffer_pool_size`: Max size of the buffer pool in bytes + pub fn new>(path: P, buffer_pool_size: u64) -> Result { + let_cxx_string!(path = path.as_ref().display().to_string()); + Ok(Database { + db: UnsafeCell::new(ffi::new_database(&path, buffer_pool_size)?), + }) + } + + /// Sets the logging level of the database instance + /// + /// # Arguments + /// * `logging_level`: New logging level. + pub fn set_logging_level(&mut self, logging_level: LoggingLevel) { + let_cxx_string!( + level = match logging_level { + LoggingLevel::Debug => "debug", + LoggingLevel::Info => "info", + LoggingLevel::Error => "err", + } + ); + ffi::database_set_logging_level(self.db.get_mut().pin_mut(), &level); + } +} + +impl fmt::Debug for Database { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Database") + .field("db", &"Opaque C++ data".to_string()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use crate::database::{Database, LoggingLevel}; + use anyhow::{Error, Result}; + // Note: Cargo runs tests in parallel by default, however kuzu does not support + // working with multiple databases in parallel. + // Tests can be run serially with `cargo test -- --test-threads=1` to work around this. 
+ + #[test] + fn create_database() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example")?; + let mut db = Database::new(temp_dir.path(), 0)?; + db.set_logging_level(LoggingLevel::Debug); + db.set_logging_level(LoggingLevel::Info); + db.set_logging_level(LoggingLevel::Error); + temp_dir.close()?; + Ok(()) + } + + #[test] + fn create_database_failure() { + let result: Error = Database::new("", 0) + .expect_err("An empty string should not be a valid database path!") + .into(); + assert_eq!( + result.to_string(), + "Failed to create directory due to: filesystem error: cannot create directory: No such file or directory []" + ); + } +} diff --git a/tools/rust_api/src/error.rs b/tools/rust_api/src/error.rs new file mode 100644 index 0000000000..cd8701f21e --- /dev/null +++ b/tools/rust_api/src/error.rs @@ -0,0 +1,44 @@ +use std::fmt; + +pub enum Error { + /// Exception raised by C++ kuzu library + CxxException(cxx::Exception), + /// Message produced by kuzu when a query fails + FailedQuery(String), + /// Message produced by kuzu when a query fails to prepare + FailedPreparedStatement(String), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use Error::*; + match self { + CxxException(cxx) => write!(f, "{cxx}"), + FailedQuery(message) => write!(f, "Query execution failed: {message}"), + FailedPreparedStatement(message) => write!(f, "Query execution failed: {message}"), + } + } +} + +impl std::fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{self}") + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + use Error::*; + match self { + CxxException(cxx) => Some(cxx), + FailedQuery(_) => None, + FailedPreparedStatement(_) => None, + } + } +} + +impl From for Error { + fn from(item: cxx::Exception) -> Self { + Error::CxxException(item) + } +} diff --git a/tools/rust_api/src/ffi.rs 
b/tools/rust_api/src/ffi.rs new file mode 100644 index 0000000000..319c42af1d --- /dev/null +++ b/tools/rust_api/src/ffi.rs @@ -0,0 +1,331 @@ +#[allow(clippy::module_inception)] +#[cxx::bridge] +pub(crate) mod ffi { + // From types.h + // Note: cxx will check if values change, but not if they are added. + #[namespace = "kuzu::common"] + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + enum LogicalTypeID { + ANY = 0, + NODE = 10, + REL = 11, + RECURSIVE_REL = 12, + // SERIAL is a special data type that is used to represent a sequence of INT64 values that are + // incremented by 1 starting from 0. + SERIAL = 13, + + // fixed size types + BOOL = 22, + INT64 = 23, + INT32 = 24, + INT16 = 25, + DOUBLE = 26, + FLOAT = 27, + DATE = 28, + TIMESTAMP = 29, + INTERVAL = 30, + FIXED_LIST = 31, + + INTERNAL_ID = 40, + + ARROW_COLUMN = 41, + + // variable size types + STRING = 50, + BLOB = 51, + VAR_LIST = 52, + STRUCT = 53, + MAP = 54, + UNION = 55, + } + #[namespace = "kuzu::common"] + unsafe extern "C++" { + type LogicalTypeID; + } + + #[namespace = "kuzu::main"] + unsafe extern "C++" { + include!("kuzu/include/kuzu_rs.h"); + + type PreparedStatement; + fn isSuccess(&self) -> bool; + + #[namespace = "kuzu_rs"] + fn prepared_statement_error_message(statement: &PreparedStatement) -> String; + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + type QueryParams; + + // Simple types which cross the ffi without problems + // Non-copyable types are references so that they only need to be cloned on the + // C++ side of things + fn insert(self: Pin<&mut Self>, key: &str, value: UniquePtr); + + fn new_params() -> UniquePtr; + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + #[namespace = "kuzu::main"] + type Database; + + fn new_database( + databasePath: &CxxString, + bufferPoolSize: u64, + ) -> Result>; + + fn database_set_logging_level(database: Pin<&mut Database>, level: &CxxString); + } + + #[namespace = "kuzu::main"] + unsafe extern "C++" { + // The C++ Connection 
class includes a pointer to the database. + // We must not destroy a referenced database while a connection is open. + type Connection<'a>; + + #[namespace = "kuzu_rs"] + fn database_connect<'a>( + database: Pin<&'a mut Database>, + ) -> Result>>; + + fn prepare( + self: Pin<&mut Connection>, + query: &CxxString, + ) -> Result>; + + #[namespace = "kuzu_rs"] + fn connection_execute( + connection: Pin<&mut Connection>, + query: Pin<&mut PreparedStatement>, + params: UniquePtr, + ) -> Result>; + + fn getMaxNumThreadForExec(self: Pin<&mut Connection>) -> u64; + fn setMaxNumThreadForExec(self: Pin<&mut Connection>, num_threads: u64); + fn beginReadOnlyTransaction(self: Pin<&mut Connection>) -> Result<()>; + fn beginWriteTransaction(self: Pin<&mut Connection>) -> Result<()>; + fn commit(self: Pin<&mut Connection>) -> Result<()>; + fn rollback(self: Pin<&mut Connection>) -> Result<()>; + fn interrupt(self: Pin<&mut Connection>) -> Result<()>; + fn setQueryTimeOut(self: Pin<&mut Connection>, timeout_ms: u64); + + #[namespace = "kuzu_rs"] + fn get_node_table_names(conn: Pin<&mut Connection>) -> String; + #[namespace = "kuzu_rs"] + fn get_rel_table_names(conn: Pin<&mut Connection>) -> String; + #[namespace = "kuzu_rs"] + fn get_node_property_names(conn: Pin<&mut Connection>, node_table_name: &str) -> String; + #[namespace = "kuzu_rs"] + fn get_rel_property_names(conn: Pin<&mut Connection>, rel_table_name: &str) -> String; + } + + #[namespace = "kuzu::main"] + unsafe extern "C++" { + type QueryResult; + + #[namespace = "kuzu_rs"] + fn query_result_to_string(query_result: Pin<&mut QueryResult>) -> String; + fn isSuccess(&self) -> bool; + #[namespace = "kuzu_rs"] + fn query_result_get_error_message(query_result: &QueryResult) -> String; + fn hasNext(&self) -> bool; + fn getNext(self: Pin<&mut QueryResult>) -> SharedPtr; + + #[namespace = "kuzu_rs"] + fn query_result_get_compiling_time(result: &QueryResult) -> f64; + #[namespace = "kuzu_rs"] + fn 
query_result_get_execution_time(result: &QueryResult) -> f64; + fn getNumColumns(&self) -> usize; + fn getNumTuples(&self) -> u64; + #[namespace = "kuzu_rs"] + fn query_result_write_to_csv( + query_result: Pin<&mut QueryResult>, + filename: &String, + delimiter: i8, + escape_character: i8, + newline: i8, + ) -> Result<()>; + + #[namespace = "kuzu_rs"] + fn query_result_column_data_types( + query_result: &QueryResult, + ) -> UniquePtr>; + #[namespace = "kuzu_rs"] + fn query_result_column_names(query_result: &QueryResult) -> Vec; + } + + #[namespace = "kuzu::processor"] + unsafe extern "C++" { + type FlatTuple; + + fn len(&self) -> u32; + #[namespace = "kuzu_rs"] + fn flat_tuple_get_value(tuple: &FlatTuple, index: u32) -> &Value; + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + type ValueList<'a>; + + fn size<'a>(&'a self) -> u64; + fn get<'a>(&'a self, index: u64) -> &'a UniquePtr; + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + #[namespace = "kuzu::common"] + type LogicalType; + + #[namespace = "kuzu::common"] + fn getLogicalTypeID(&self) -> LogicalTypeID; + + fn create_logical_type(id: LogicalTypeID) -> UniquePtr; + fn create_logical_type_var_list( + child_type: UniquePtr, + ) -> UniquePtr; + fn create_logical_type_fixed_list( + child_type: UniquePtr, + num_elements: u64, + ) -> UniquePtr; + fn create_logical_type_struct( + field_names: &Vec, + types: UniquePtr, + ) -> UniquePtr; + + fn logical_type_get_var_list_child_type(value: &LogicalType) -> &LogicalType; + fn logical_type_get_fixed_list_child_type(value: &LogicalType) -> &LogicalType; + fn logical_type_get_fixed_list_num_elements(value: &LogicalType) -> u64; + fn logical_type_get_struct_field_names(value: &LogicalType) -> Vec; + fn logical_type_get_struct_field_types( + value: &LogicalType, + ) -> UniquePtr>; + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + type ValueListBuilder; + + fn insert(self: Pin<&mut ValueListBuilder>, value: UniquePtr); + fn get_list_value( + typ: 
UniquePtr, + value: UniquePtr, + ) -> UniquePtr; + fn create_list() -> UniquePtr; + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + type TypeListBuilder; + + fn insert(self: Pin<&mut TypeListBuilder>, typ: UniquePtr); + fn create_type_list() -> UniquePtr; + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + #[namespace = "kuzu::common"] + type Value; + + // only used by tests + #[allow(dead_code)] + fn value_to_string(node_value: &Value) -> String; + + #[rust_name = "get_value_bool"] + fn getValue(&self) -> bool; + #[rust_name = "get_value_i16"] + fn getValue(&self) -> i16; + #[rust_name = "get_value_i32"] + fn getValue(&self) -> i32; + #[rust_name = "get_value_i64"] + fn getValue(&self) -> i64; + #[rust_name = "get_value_float"] + fn getValue(&self) -> f32; + #[rust_name = "get_value_double"] + fn getValue(&self) -> f64; + + fn value_get_string(value: &Value) -> String; + #[rust_name = "value_get_node_val"] + fn value_get_unique(value: &Value) -> UniquePtr; + #[rust_name = "value_get_rel_val"] + fn value_get_unique(value: &Value) -> UniquePtr; + fn value_get_interval_secs(value: &Value) -> i64; + fn value_get_interval_micros(value: &Value) -> i32; + fn value_get_timestamp_micros(value: &Value) -> i64; + fn value_get_date_days(value: &Value) -> i32; + fn value_get_internal_id(value: &Value) -> [u64; 2]; + fn value_get_list(value: &Value) -> UniquePtr; + + fn value_get_data_type_id(value: &Value) -> LogicalTypeID; + fn value_get_data_type(value: &Value) -> UniquePtr; + + fn isNull(&self) -> bool; + + #[rust_name = "create_value_bool"] + fn create_value(value: bool) -> UniquePtr; + #[rust_name = "create_value_i16"] + fn create_value(value: i16) -> UniquePtr; + #[rust_name = "create_value_i32"] + fn create_value(value: i32) -> UniquePtr; + #[rust_name = "create_value_i64"] + fn create_value(value: i64) -> UniquePtr; + #[rust_name = "create_value_float"] + fn create_value(value: f32) -> UniquePtr; + #[rust_name = "create_value_double"] + fn 
create_value(value: f64) -> UniquePtr; + + fn create_value_null(typ: UniquePtr) -> UniquePtr; + fn create_value_string(value: &String) -> UniquePtr; + fn create_value_timestamp(value: i64) -> UniquePtr; + fn create_value_date(value: i64) -> UniquePtr; + fn create_value_interval(months: i32, days: i32, micros: i64) -> UniquePtr; + fn create_value_internal_id(offset: u64, table: u64) -> UniquePtr; + fn create_value_node( + id_val: UniquePtr, + label_val: UniquePtr, + ) -> UniquePtr; + fn create_value_rel( + src_id: UniquePtr, + dst_id: UniquePtr, + label_val: UniquePtr, + ) -> UniquePtr; + + fn value_add_property(value: Pin<&mut Value>, name: &String, property: UniquePtr); + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + type PropertyList<'a>; + + fn size<'a>(&'a self) -> usize; + fn get_name<'a>(&'a self, index: usize) -> String; + fn get_value<'a>(&'a self, index: usize) -> &'a Value; + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + #[namespace = "kuzu::common"] + type NodeVal; + + #[rust_name = "node_value_get_properties"] + fn value_get_properties(node_value: &NodeVal) -> UniquePtr; + fn node_value_get_node_id(value: &NodeVal) -> [u64; 2]; + #[rust_name = "node_value_get_label_name"] + fn value_get_label_name(value: &NodeVal) -> String; + } + + #[namespace = "kuzu_rs"] + unsafe extern "C++" { + #[namespace = "kuzu::common"] + type RelVal; + + #[rust_name = "rel_value_get_properties"] + fn value_get_properties(node_value: &RelVal) -> UniquePtr; + #[rust_name = "rel_value_get_label_name"] + fn value_get_label_name(value: &RelVal) -> String; + + fn rel_value_get_src_id(value: &RelVal) -> [u64; 2]; + fn rel_value_get_dst_id(value: &RelVal) -> [u64; 2]; + } +} diff --git a/tools/rust_api/src/kuzu_rs.cpp b/tools/rust_api/src/kuzu_rs.cpp new file mode 100644 index 0000000000..15788d7c2a --- /dev/null +++ b/tools/rust_api/src/kuzu_rs.cpp @@ -0,0 +1,266 @@ +#include "kuzu_rs.h" + +using kuzu::common::FixedListTypeInfo; +using kuzu::common::Interval; 
+using kuzu::common::LogicalType; +using kuzu::common::LogicalTypeID; +using kuzu::common::StructField; +using kuzu::common::Value; +using kuzu::common::VarListTypeInfo; +using kuzu::main::Connection; +using kuzu::main::Database; +using kuzu::main::SystemConfig; + +namespace kuzu_rs { + +std::unique_ptr new_params() { + return std::make_unique(); +} + +std::unique_ptr create_logical_type(kuzu::common::LogicalTypeID id) { + return std::make_unique(id); +} +std::unique_ptr create_logical_type_var_list(std::unique_ptr childType) { + return std::make_unique( + LogicalTypeID::VAR_LIST, std::make_unique(std::move(childType))); +} + +std::unique_ptr create_logical_type_fixed_list( + std::unique_ptr childType, uint64_t numElements) { + return std::make_unique(LogicalTypeID::FIXED_LIST, + std::make_unique(std::move(childType), numElements)); +} + +std::unique_ptr create_logical_type_struct( + const rust::Vec& fieldNames, std::unique_ptr fieldTypes) { + std::vector> fields; + for (auto i = 0; i < fieldNames.size(); i++) { + fields.push_back(std::make_unique( + std::string(fieldNames[i]), std::move(fieldTypes->types[i]))); + } + return std::make_unique( + LogicalTypeID::STRUCT, std::make_unique(std::move(fields))); +} + +const LogicalType& logical_type_get_var_list_child_type(const LogicalType& logicalType) { + return *kuzu::common::VarListType::getChildType(&logicalType); +} +const LogicalType& logical_type_get_fixed_list_child_type(const LogicalType& logicalType) { + return *kuzu::common::FixedListType::getChildType(&logicalType); +} +uint64_t logical_type_get_fixed_list_num_elements(const LogicalType& logicalType) { + return kuzu::common::FixedListType::getNumElementsInList(&logicalType); +} + +rust::Vec logical_type_get_struct_field_names( + const kuzu::common::LogicalType& value) { + rust::Vec names; + for (auto name : kuzu::common::StructType::getFieldNames(&value)) { + names.push_back(name); + } + return names; +} + +std::unique_ptr> 
logical_type_get_struct_field_types( + const kuzu::common::LogicalType& value) { + std::vector result; + for (auto type : kuzu::common::StructType::getFieldTypes(&value)) { + result.push_back(*type); + } + return std::make_unique>(result); +} + +std::unique_ptr new_database(const std::string& databasePath, uint64_t bufferPoolSize) { + auto systemConfig = SystemConfig(); + if (bufferPoolSize > 0) { + systemConfig.bufferPoolSize = bufferPoolSize; + } + return std::make_unique(databasePath, systemConfig); +} + +void database_set_logging_level(Database& database, const std::string& level) { + database.setLoggingLevel(level); +} + +std::unique_ptr database_connect(kuzu::main::Database& database) { + return std::make_unique(&database); +} + +std::unique_ptr connection_execute(kuzu::main::Connection& connection, + kuzu::main::PreparedStatement& query, std::unique_ptr params) { + return connection.executeWithParams(&query, params->inputParams); +} + +rust::String get_node_table_names(Connection& connection) { + return rust::String(connection.getNodeTableNames()); +} +rust::String get_rel_table_names(Connection& connection) { + return rust::String(connection.getRelTableNames()); +} +rust::String get_node_property_names(Connection& connection, rust::Str tableName) { + return rust::String(connection.getNodePropertyNames(std::string(tableName))); +} +rust::String get_rel_property_names(Connection& connection, rust::Str relTableName) { + return rust::String(connection.getRelPropertyNames(std::string(relTableName))); +} + +rust::String prepared_statement_error_message(const kuzu::main::PreparedStatement& statement) { + return rust::String(statement.getErrorMessage()); +} + +rust::String query_result_to_string(kuzu::main::QueryResult& result) { + return rust::String(result.toString()); +} + +rust::String query_result_get_error_message(const kuzu::main::QueryResult& result) { + return rust::String(result.getErrorMessage()); +} + +double query_result_get_compiling_time(const 
kuzu::main::QueryResult& result) { + return result.getQuerySummary()->getCompilingTime(); +} +double query_result_get_execution_time(const kuzu::main::QueryResult& result) { + return result.getQuerySummary()->getExecutionTime(); +} + +void query_result_write_to_csv(kuzu::main::QueryResult& query_result, const rust::String& filename, + int8_t delimiter, int8_t escape_character, int8_t newline) { + query_result.writeToCSV( + std::string(filename), (char)delimiter, (char)escape_character, (char)newline); +} + +std::unique_ptr> query_result_column_data_types( + const kuzu::main::QueryResult& query_result) { + return std::make_unique>( + query_result.getColumnDataTypes()); +} +rust::Vec query_result_column_names(const kuzu::main::QueryResult& query_result) { + rust::Vec names; + for (auto name : query_result.getColumnNames()) { + names.push_back(name); + } + return names; +} + +std::array node_value_get_node_id(const kuzu::common::NodeVal& val) { + auto internalID = val.getNodeID(); + return std::array{internalID.offset, internalID.tableID}; +} + +std::array rel_value_get_src_id(const kuzu::common::RelVal& val) { + auto internalID = val.getSrcNodeID(); + return std::array{internalID.offset, internalID.tableID}; +} +std::array rel_value_get_dst_id(const kuzu::common::RelVal& val) { + auto internalID = val.getDstNodeID(); + return std::array{internalID.offset, internalID.tableID}; +} + +rust::String value_to_string(const kuzu::common::Value& val) { + return rust::String(val.toString()); +} + +uint32_t flat_tuple_len(const kuzu::processor::FlatTuple& flatTuple) { + return flatTuple.len(); +} +const kuzu::common::Value& flat_tuple_get_value( + const kuzu::processor::FlatTuple& flatTuple, uint32_t index) { + return *flatTuple.getValue(index); +} + +rust::String value_get_string(const kuzu::common::Value& value) { + return value.getValue(); +} +int64_t value_get_interval_secs(const kuzu::common::Value& value) { + auto interval = value.getValue(); + return (interval.months * 
Interval::DAYS_PER_MONTH + interval.days) * Interval::HOURS_PER_DAY * + Interval::MINS_PER_HOUR * Interval::SECS_PER_MINUTE + // Include extra microseconds with the seconds + + interval.micros / Interval::MICROS_PER_SEC; +} +int32_t value_get_interval_micros(const kuzu::common::Value& value) { + auto interval = value.getValue(); + return interval.micros % Interval::MICROS_PER_SEC; +} +int32_t value_get_date_days(const kuzu::common::Value& value) { + return value.getValue().days; +} +int64_t value_get_timestamp_micros(const kuzu::common::Value& value) { + return value.getValue().value; +} +std::array value_get_internal_id(const kuzu::common::Value& value) { + auto internalID = value.getValue(); + return std::array{internalID.offset, internalID.tableID}; +} + +std::unique_ptr value_get_list(const kuzu::common::Value& value) { + return std::make_unique(value.getListValReference()); +} +kuzu::common::LogicalTypeID value_get_data_type_id(const kuzu::common::Value& value) { + return value.getDataType().getLogicalTypeID(); +} +std::unique_ptr value_get_data_type(const kuzu::common::Value& value) { + return std::make_unique(value.getDataType()); +} + +std::unique_ptr create_value_string(const rust::String& value) { + return std::make_unique( + LogicalType(LogicalTypeID::STRING), std::string(value)); +} +std::unique_ptr create_value_timestamp(const int64_t timestamp) { + return std::make_unique(kuzu::common::timestamp_t(timestamp)); +} +std::unique_ptr create_value_date(const int64_t date) { + return std::make_unique(kuzu::common::date_t(date)); +} +std::unique_ptr create_value_interval( + const int32_t months, const int32_t days, const int64_t micros) { + return std::make_unique(kuzu::common::interval_t(months, days, micros)); +} +std::unique_ptr create_value_null( + std::unique_ptr typ) { + return std::make_unique( + kuzu::common::Value::createNullValue(kuzu::common::LogicalType(*typ))); +} +std::unique_ptr create_value_internal_id(uint64_t offset, uint64_t table) { + 
return std::make_unique(kuzu::common::internalID_t(offset, table)); +} +std::unique_ptr create_value_node( + std::unique_ptr id_val, std::unique_ptr label_val) { + return std::make_unique( + std::make_unique(std::move(id_val), std::move(label_val))); +} + +std::unique_ptr create_value_rel(std::unique_ptr src_id, + std::unique_ptr dst_id, std::unique_ptr label_val) { + return std::make_unique(std::make_unique( + std::move(src_id), std::move(dst_id), std::move(label_val))); +} + +std::unique_ptr get_list_value( + std::unique_ptr typ, std::unique_ptr value) { + return std::make_unique(std::move(*typ.get()), std::move(value->values)); +} + +std::unique_ptr create_list() { + return std::make_unique(); +} + +std::unique_ptr create_type_list() { + return std::make_unique(); +} + +void value_add_property(kuzu::common::Value& val, const rust::String& name, + std::unique_ptr property) { + if (val.getDataType().getLogicalTypeID() == kuzu::common::LogicalTypeID::NODE) { + kuzu::common::NodeVal& nodeVal = val.getValueReference(); + nodeVal.addProperty(std::string(name), std::move(property)); + } else if (val.getDataType().getLogicalTypeID() == kuzu::common::LogicalTypeID::REL) { + kuzu::common::RelVal& relVal = val.getValueReference(); + relVal.addProperty(std::string(name), std::move(property)); + } else { + throw std::runtime_error("Internal Error! Adding property to type without properties!"); + } +} + +} // namespace kuzu_rs diff --git a/tools/rust_api/src/lib.rs b/tools/rust_api/src/lib.rs new file mode 100644 index 0000000000..a2304aca56 --- /dev/null +++ b/tools/rust_api/src/lib.rs @@ -0,0 +1,43 @@ +//! Bindings to Kùzu: an in-process property graph database management system built for query speed and scalability. +//! +//! ## Example Usage +//! ``` +//! use kuzu::{Database, Connection}; +//! # use anyhow::Error; +//! +//! # fn main() -> Result<(), Error> { +//! # let temp_dir = tempdir::TempDir::new("example")?; +//! # let path = temp_dir.path(); +//! 
let db = Database::new(path, 0)?; +//! let conn = Connection::new(&db)?; +//! conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?; +//! conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; +//! conn.query("CREATE (:Person {name: 'Bob', age: 30});")?; +//! +//! let mut result = conn.query("MATCH (a:Person) RETURN a.name AS NAME, a.age AS AGE;")?; +//! println!("{}", result.display()); +//! # temp_dir.close()?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Safety +//! +//! Generally, use of this API is safe; however, creating multiple databases in the same +//! scope is not safe. +//! If you need to access multiple databases you will need to do so in separate processes. + +mod connection; +mod database; +mod error; +mod ffi; +mod logical_type; +mod query_result; +mod value; + +pub use connection::{Connection, PreparedStatement}; +pub use database::{Database, LoggingLevel}; +pub use error::Error; +pub use logical_type::LogicalType; +pub use query_result::{CSVOptions, QueryResult}; +pub use value::{InternalID, NodeVal, RelVal, Value}; diff --git a/tools/rust_api/src/logical_type.rs b/tools/rust_api/src/logical_type.rs new file mode 100644 index 0000000000..352d3bdf72 --- /dev/null +++ b/tools/rust_api/src/logical_type.rs @@ -0,0 +1,160 @@ +use crate::ffi::ffi; + +/// Type of [Value](crate::value::Value)s produced and consumed by queries. 
+/// +/// Includes extra type information beyond what can be encoded in [Value](crate::value::Value) such as +/// struct fields and types of lists +#[derive(Clone, Debug, PartialEq)] +pub enum LogicalType { + /// Special type for use with [Value::Null](crate::value::Value::Null) + Any, + /// Corresponds to [Value::Bool](crate::value::Value::Bool) + Bool, + /// Corresponds to [Value::Int64](crate::value::Value::Int64) + Int64, + /// Corresponds to [Value::Int32](crate::value::Value::Int32) + Int32, + /// Corresponds to [Value::Int16](crate::value::Value::Int16) + Int16, + /// Corresponds to [Value::Double](crate::value::Value::Double) + Double, + /// Corresponds to [Value::Float](crate::value::Value::Float) + Float, + /// Corresponds to [Value::Date](crate::value::Value::Date) + Date, + /// Corresponds to [Value::Interval](crate::value::Value::Interval) + Interval, + /// Corresponds to [Value::Timestamp](crate::value::Value::Timestamp) + Timestamp, + /// Corresponds to [Value::InternalID](crate::value::Value::InternalID) + InternalID, + /// Corresponds to [Value::String](crate::value::Value::String) + String, + /// Corresponds to [Value::VarList](crate::value::Value::VarList) + VarList { child_type: Box }, + /// Corresponds to [Value::FixedList](crate::value::Value::FixedList) + FixedList { + child_type: Box, + num_elements: u64, + }, + /// Corresponds to [Value::Struct](crate::value::Value::Struct) + Struct { fields: Vec<(String, LogicalType)> }, + /// Corresponds to [Value::Node](crate::value::Value::Node) + Node, + /// Corresponds to [Value::Rel](crate::value::Value::Rel) + Rel, +} + +impl From<&ffi::Value> for LogicalType { + fn from(value: &ffi::Value) -> Self { + ffi::value_get_data_type(value).as_ref().unwrap().into() + } +} + +impl From<&ffi::LogicalType> for LogicalType { + fn from(logical_type: &ffi::LogicalType) -> Self { + use ffi::LogicalTypeID; + match logical_type.getLogicalTypeID() { + LogicalTypeID::ANY => LogicalType::Any, + LogicalTypeID::BOOL => 
LogicalType::Bool, + LogicalTypeID::INT16 => LogicalType::Int16, + LogicalTypeID::INT32 => LogicalType::Int32, + LogicalTypeID::INT64 => LogicalType::Int64, + LogicalTypeID::FLOAT => LogicalType::Float, + LogicalTypeID::DOUBLE => LogicalType::Double, + LogicalTypeID::STRING => LogicalType::String, + LogicalTypeID::INTERVAL => LogicalType::Interval, + LogicalTypeID::DATE => LogicalType::Date, + LogicalTypeID::TIMESTAMP => LogicalType::Timestamp, + LogicalTypeID::INTERNAL_ID => LogicalType::InternalID, + LogicalTypeID::VAR_LIST => LogicalType::VarList { + child_type: Box::new( + ffi::logical_type_get_var_list_child_type(logical_type).into(), + ), + }, + LogicalTypeID::FIXED_LIST => LogicalType::FixedList { + child_type: Box::new( + ffi::logical_type_get_fixed_list_child_type(logical_type).into(), + ), + num_elements: ffi::logical_type_get_fixed_list_num_elements(logical_type), + }, + LogicalTypeID::STRUCT => { + let names = ffi::logical_type_get_struct_field_names(logical_type); + let types = ffi::logical_type_get_struct_field_types(logical_type); + LogicalType::Struct { + fields: names + .into_iter() + .zip(types.into_iter().map(Into::::into)) + .collect(), + } + } + LogicalTypeID::NODE => LogicalType::Node, + LogicalTypeID::REL => LogicalType::Rel, + // Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one + // on the C++ side. 
+ x => panic!("Unsupported type {:?}", x), + } + } +} + +impl From<&LogicalType> for cxx::UniquePtr { + fn from(typ: &LogicalType) -> Self { + match typ { + LogicalType::Any + | LogicalType::Bool + | LogicalType::Int64 + | LogicalType::Int32 + | LogicalType::Int16 + | LogicalType::Float + | LogicalType::Double + | LogicalType::Date + | LogicalType::Timestamp + | LogicalType::Interval + | LogicalType::InternalID + | LogicalType::String + | LogicalType::Node + | LogicalType::Rel => ffi::create_logical_type(typ.id()), + LogicalType::VarList { child_type } => { + ffi::create_logical_type_var_list(child_type.as_ref().into()) + } + LogicalType::FixedList { + child_type, + num_elements, + } => ffi::create_logical_type_fixed_list(child_type.as_ref().into(), *num_elements), + LogicalType::Struct { fields } => { + let mut builder = ffi::create_type_list(); + let mut names = vec![]; + for (name, typ) in fields { + names.push(name.clone()); + builder.pin_mut().insert(typ.into()); + } + ffi::create_logical_type_struct(&names, builder) + } + } + } +} + +impl LogicalType { + pub(crate) fn id(&self) -> ffi::LogicalTypeID { + use ffi::LogicalTypeID; + match self { + LogicalType::Any => LogicalTypeID::ANY, + LogicalType::Bool => LogicalTypeID::BOOL, + LogicalType::Int16 => LogicalTypeID::INT16, + LogicalType::Int32 => LogicalTypeID::INT32, + LogicalType::Int64 => LogicalTypeID::INT64, + LogicalType::Float => LogicalTypeID::FLOAT, + LogicalType::Double => LogicalTypeID::DOUBLE, + LogicalType::String => LogicalTypeID::STRING, + LogicalType::Interval => LogicalTypeID::INTERVAL, + LogicalType::Date => LogicalTypeID::DATE, + LogicalType::Timestamp => LogicalTypeID::TIMESTAMP, + LogicalType::InternalID => LogicalTypeID::INTERNAL_ID, + LogicalType::VarList { .. } => LogicalTypeID::VAR_LIST, + LogicalType::FixedList { .. } => LogicalTypeID::FIXED_LIST, + LogicalType::Struct { .. 
} => LogicalTypeID::STRUCT, + LogicalType::Node => LogicalTypeID::NODE, + LogicalType::Rel => LogicalTypeID::REL, + } + } +} diff --git a/tools/rust_api/src/query_result.rs b/tools/rust_api/src/query_result.rs new file mode 100644 index 0000000000..032caf7ac4 --- /dev/null +++ b/tools/rust_api/src/query_result.rs @@ -0,0 +1,220 @@ +use crate::ffi::ffi; +use crate::logical_type::LogicalType; +use crate::value::Value; +use cxx::UniquePtr; +use std::convert::TryInto; +use std::fmt; + +/// Stores the result of a query execution +pub struct QueryResult { + pub(crate) result: UniquePtr, +} + +// Should be safe to move across threads, however access is not synchronized +unsafe impl Send for ffi::QueryResult {} + +/// Options for writing CSV files +pub struct CSVOptions { + delimiter: char, + escape_character: char, + newline: char, +} + +impl Default for CSVOptions { + /// Default CSV options with delimiter `,`, escape character `"` and newline `\n`. + fn default() -> Self { + CSVOptions { + delimiter: ',', + escape_character: '"', + newline: '\n', + } + } +} + +impl CSVOptions { + /// Sets the field delimiter to use when writing the CSV file. If not specified the default is + /// `,` + pub fn delimiter(mut self, delimiter: char) -> Self { + self.delimiter = delimiter; + self + } + + /// Sets the escape character to use for text containing special characters. 
+ /// If not specified the default is `"` + pub fn escape_character(mut self, escape_character: char) -> Self { + self.escape_character = escape_character; + self + } + + /// Sets the newline character + /// If not specified the default is `\n` + pub fn newline(mut self, newline: char) -> Self { + self.newline = newline; + self + } +} + +impl QueryResult { + /// Displays the query result as a string + pub fn display(&mut self) -> String { + ffi::query_result_to_string(self.result.pin_mut()) + } + + /// Returns the time spent compiling the query in milliseconds + pub fn get_compiling_time(&self) -> f64 { + ffi::query_result_get_compiling_time(self.result.as_ref().unwrap()) + } + + /// Returns the time spent executing the query in milliseconds + pub fn get_execution_time(&self) -> f64 { + ffi::query_result_get_execution_time(self.result.as_ref().unwrap()) + } + + /// Returns the number of columns in the query result. + /// + /// This corresponds to the length of each result vector yielded by the iterator. + pub fn get_num_columns(&self) -> usize { + self.result.as_ref().unwrap().getNumColumns() + } + /// Returns the number of tuples in the query result. + /// + /// This corresponds to the total number of result + /// vectors that the query result iterator will yield. 
+ pub fn get_num_tuples(&self) -> u64 { + self.result.as_ref().unwrap().getNumTuples() + } + + /// Returns the name of each column in the query result + pub fn get_column_names(&self) -> Vec { + ffi::query_result_column_names(self.result.as_ref().unwrap()) + } + /// Returns the data type of each column in the query result + pub fn get_column_data_types(&self) -> Vec { + ffi::query_result_column_data_types(self.result.as_ref().unwrap()) + .as_ref() + .unwrap() + .iter() + .map(|x| x.into()) + .collect() + } + + /// Writes the query result to a csv file + /// + /// # Arguments + /// * `path`: The path of the output csv file + /// * `options`: Custom CSV output options + /// + /// ```ignore + /// result.write_to_csv("output.csv", CSVOptions::default().delimiter(','))?; + /// ``` + pub fn write_to_csv>( + &mut self, + path: P, + options: CSVOptions, + ) -> Result<(), crate::error::Error> { + Ok(ffi::query_result_write_to_csv( + self.result.pin_mut(), + &path.as_ref().display().to_string(), + options.delimiter as i8, + options.escape_character as i8, + options.newline as i8, + )?) + } +} + +// the underlying C++ type is both data and an iterator (sort-of) +impl Iterator for QueryResult { + // each item is one result tuple, as a vector of its column values + type Item = Vec; + + // next() is the only required method + fn next(&mut self) -> Option { + if self.result.as_ref().unwrap().hasNext() { + let flat_tuple = self.result.pin_mut().getNext(); + let mut result = vec![]; + for i in 0..flat_tuple.as_ref().unwrap().len() { + let value = ffi::flat_tuple_get_value(flat_tuple.as_ref().unwrap(), i); + // TODO: Return result instead of unwrapping? + // Unfortunately, as an iterator, this would require producing + // Vec<Result<Value>>, though it would be possible to turn that into + // Result<Vec<Value>> instead, but it would lose information when multiple failures + // occur. 
+ result.push(value.try_into().unwrap()); + } + Some(result) + } else { + None + } + } +} + +impl fmt::Debug for QueryResult { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("QueryResult") + .field( + "result", + &"Opaque C++ data which whose toString method requires mutation".to_string(), + ) + .finish() + } +} + +/* TODO: QueryResult.toString() needs to be const +impl std::fmt::Display for QueryResult { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", ffi::query_result_to_string(self.result.as_ref().unwrap())) + } +} +*/ + +#[cfg(test)] +mod tests { + use crate::connection::Connection; + use crate::database::Database; + use crate::logical_type::LogicalType; + use crate::query_result::CSVOptions; + #[test] + fn test_query_result_metadata() -> anyhow::Result<()> { + let temp_dir = tempdir::TempDir::new("example")?; + let db = Database::new(temp_dir.path(), 0)?; + let connection = Connection::new(&db)?; + + // Create schema. + connection.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?; + // Create nodes. + connection.query("CREATE (:Person {name: 'Alice', age: 25});")?; + connection.query("CREATE (:Person {name: 'Bob', age: 30});")?; + + // Execute a simple query. 
+ let result = connection.query("MATCH (a:Person) RETURN a.name AS NAME, a.age AS AGE;")?; + + assert!(result.get_compiling_time() > 0.); + assert!(result.get_execution_time() > 0.); + assert_eq!(result.get_column_names(), vec!["NAME", "AGE"]); + assert_eq!( + result.get_column_data_types(), + vec![LogicalType::String, LogicalType::Int64] + ); + temp_dir.close()?; + Ok(()) + } + + #[test] + fn test_csv() -> anyhow::Result<()> { + let temp_dir = tempdir::TempDir::new("example")?; + let path = temp_dir.path(); + let db = Database::new(path, 0)?; + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?; + conn.query("CREATE (:Person {name: 'Alice', age: 25});")?; + let mut result = conn.query("MATCH (a:Person) RETURN a.name AS NAME, a.age AS AGE;")?; + result.write_to_csv( + path.join("output.csv"), + CSVOptions::default().delimiter(','), + )?; + let data = std::fs::read_to_string(path.join("output.csv"))?; + assert_eq!(data, "Alice,25\n"); + temp_dir.close()?; + Ok(()) + } +} diff --git a/tools/rust_api/src/value.rs b/tools/rust_api/src/value.rs new file mode 100644 index 0000000000..ad0bbff255 --- /dev/null +++ b/tools/rust_api/src/value.rs @@ -0,0 +1,878 @@ +use crate::ffi::ffi; +use crate::logical_type::LogicalType; +use std::cmp::Ordering; +use std::convert::{TryFrom, TryInto}; +use std::fmt; + +pub enum ConversionError { + /// Kuzu's internal date as the number of days since 1970-01-01 + Date(i32), + /// Kuzu's internal timestamp as the number of microseconds since 1970-01-01 + Timestamp(i64), +} + +impl std::fmt::Display for ConversionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{self:?}") + } +} + +impl std::fmt::Debug for ConversionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use ConversionError::*; + match self { + Date(days) => write!(f, "Could not convert Kuzu date offset of UNIX_EPOCH + {days} days to time::Date"), + 
Timestamp(us) => write!(f, "Could not convert Kuzu timestamp offset of UNIX_EPOCH + {us} microseconds to time::OffsetDateTime"), + } + } +} + +impl std::error::Error for ConversionError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +/// NodeVal represents a node in the graph and stores the nodeID, label and properties of that +/// node. +#[derive(Clone, Debug, PartialEq)] +pub struct NodeVal { + id: InternalID, + label: String, + properties: Vec<(String, Value)>, +} + +impl NodeVal { + pub fn new(id: InternalID, label: String) -> Self { + NodeVal { + id, + label, + properties: vec![], + } + } + + pub fn get_node_id(&self) -> &InternalID { + &self.id + } + + pub fn get_label_name(&self) -> &String { + &self.label + } + + /// Adds a property with the given key/value pair to the NodeVal + /// # Arguments + /// * `key`: The name of the property + /// * `value`: The value of the property + pub fn add_property(&mut self, key: String, value: Value) { + self.properties.push((key, value)); + } + + /// Returns all properties of the NodeVal + pub fn get_properties(&self) -> &Vec<(String, Value)> { + &self.properties + } +} + +fn properties_display( + f: &mut fmt::Formatter<'_>, + properties: &Vec<(String, Value)>, +) -> fmt::Result { + write!(f, "{{")?; + for (index, (name, value)) in properties.iter().enumerate() { + write!(f, "{}:{}", name, value)?; + if index < properties.len() - 1 { + write!(f, ",")?; + } + } + write!(f, "}}") +} + +impl std::fmt::Display for NodeVal { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "(label:{}, {}, ", self.label, self.id)?; + properties_display(f, &self.properties)?; + write!(f, ")") + } +} + +/// RelVal represents a relationship in the graph and stores the relID, src/dst nodes and properties of that +/// rel +#[derive(Clone, Debug, PartialEq)] +pub struct RelVal { + src_node: InternalID, + dst_node: InternalID, + label: String, + properties: Vec<(String, Value)>, +} + +impl 
RelVal { + pub fn new(src_node: InternalID, dst_node: InternalID, label: String) -> Self { + RelVal { + src_node, + dst_node, + label, + properties: vec![], + } + } + + pub fn get_src_node(&self) -> &InternalID { + &self.src_node + } + pub fn get_dst_node(&self) -> &InternalID { + &self.dst_node + } + + pub fn get_label_name(&self) -> &String { + &self.label + } + + /// Adds a property with the given key/value pair to the RelVal + /// # Arguments + /// * `key`: The name of the property + /// * `value`: The value of the property + pub fn add_property(&mut self, key: String, value: Value) { + self.properties.push((key, value)); + } + + /// Returns all properties of the RelVal + pub fn get_properties(&self) -> &Vec<(String, Value)> { + &self.properties + } +} + +impl std::fmt::Display for RelVal { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "({})-[label:{}, ", self.src_node, self.label)?; + properties_display(f, &self.properties)?; + write!(f, "]->({})", self.dst_node) + } +} + +/// Stores the table_id and offset of a node/rel. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct InternalID { + pub offset: u64, + pub table_id: u64, +} + +impl std::fmt::Display for InternalID { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}:{}", self.table_id, self.offset) + } +} + +impl PartialOrd for InternalID { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for InternalID { + fn cmp(&self, other: &Self) -> Ordering { + if self.table_id == other.table_id { + self.offset.cmp(&other.offset) + } else { + self.table_id.cmp(&other.table_id) + } + } +} + +/// Data types supported by Kùzu +/// +/// Also see +#[derive(Clone, Debug, PartialEq)] +pub enum Value { + Null(LogicalType), + Bool(bool), + Int64(i64), + Int32(i32), + Int16(i16), + Double(f64), + Float(f32), + /// Stored internally as the number of days since 1970-01-01 as a 32-bit signed integer, which + /// allows for a wider range of dates to be stored than can be represented by time::Date + /// + /// + Date(time::Date), + /// May be signed or unsigned. + /// + /// Nanosecond precision of time::Duration (if available) will not be preserved when passed to + /// queries, and results will always have at most microsecond precision. + /// + /// + Interval(time::Duration), + /// Stored internally as the number of microseconds since 1970-01-01 + /// Nanosecond precision of SystemTime (if available) will not be preserved when used. + /// + /// + Timestamp(time::OffsetDateTime), + InternalID(InternalID), + /// + String(String), + // TODO: Enforce type of contents + // LogicalType is necessary so that we can pass the correct type to the C++ API if the list is empty. + /// These must contain elements which are all the given type. + /// + VarList(LogicalType, Vec), + /// These must contain elements which are all the same type. 
+ /// + FixedList(LogicalType, Vec), + /// + Struct(Vec<(String, Value)>), + Node(NodeVal), + Rel(RelVal), +} + +impl std::fmt::Display for Value { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Value::Bool(true) => write!(f, "True"), + Value::Bool(false) => write!(f, "False"), + Value::Int16(x) => write!(f, "{x}"), + Value::Int32(x) => write!(f, "{x}"), + Value::Int64(x) => write!(f, "{x}"), + Value::Date(x) => write!(f, "{x}"), + Value::String(x) => write!(f, "{x}"), + Value::Null(_) => write!(f, ""), + Value::VarList(_, x) | Value::FixedList(_, x) => { + write!(f, "[")?; + for (i, value) in x.iter().enumerate() { + write!(f, "{}", value)?; + if i != x.len() - 1 { + write!(f, ",")?; + } + } + write!(f, "]") + } + // Note: These don't match kuzu's toString, but we probably don't want them to + Value::Interval(x) => write!(f, "{x}"), + Value::Timestamp(x) => write!(f, "{x}"), + Value::Float(x) => write!(f, "{x}"), + Value::Double(x) => write!(f, "{x}"), + Value::Struct(x) => { + write!(f, "{{")?; + for (i, (name, value)) in x.iter().enumerate() { + write!(f, "{}: {}", name, value)?; + if i != x.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, "}}") + } + Value::Node(x) => write!(f, "{x}"), + Value::Rel(x) => write!(f, "{x}"), + Value::InternalID(x) => write!(f, "{x}"), + } + } +} + +impl From<&Value> for LogicalType { + fn from(value: &Value) -> Self { + match value { + Value::Bool(_) => LogicalType::Bool, + Value::Int16(_) => LogicalType::Int16, + Value::Int32(_) => LogicalType::Int32, + Value::Int64(_) => LogicalType::Int64, + Value::Float(_) => LogicalType::Float, + Value::Double(_) => LogicalType::Double, + Value::Date(_) => LogicalType::Date, + Value::Interval(_) => LogicalType::Interval, + Value::Timestamp(_) => LogicalType::Timestamp, + Value::String(_) => LogicalType::String, + Value::Null(x) => x.clone(), + Value::VarList(x, _) => LogicalType::VarList { + child_type: Box::new(x.clone()), + }, + Value::FixedList(x, value) 
=> LogicalType::FixedList { + child_type: Box::new(x.clone()), + num_elements: value.len() as u64, + }, + Value::Struct(values) => LogicalType::Struct { + fields: values + .iter() + .map(|(name, x)| { + let typ: LogicalType = x.into(); + (name.clone(), typ) + }) + .collect(), + }, + Value::InternalID(_) => LogicalType::InternalID, + Value::Node(_) => LogicalType::Node, + Value::Rel(_) => LogicalType::Rel, + } + } +} + +impl TryFrom<&ffi::Value> for Value { + type Error = ConversionError; + + fn try_from(value: &ffi::Value) -> Result { + use ffi::LogicalTypeID; + if value.isNull() { + return Ok(Value::Null(value.into())); + } + match ffi::value_get_data_type_id(value) { + LogicalTypeID::ANY => unimplemented!(), + LogicalTypeID::BOOL => Ok(Value::Bool(value.get_value_bool())), + LogicalTypeID::INT16 => Ok(Value::Int16(value.get_value_i16())), + LogicalTypeID::INT32 => Ok(Value::Int32(value.get_value_i32())), + LogicalTypeID::INT64 => Ok(Value::Int64(value.get_value_i64())), + LogicalTypeID::FLOAT => Ok(Value::Float(value.get_value_float())), + LogicalTypeID::DOUBLE => Ok(Value::Double(value.get_value_double())), + LogicalTypeID::STRING => Ok(Value::String(ffi::value_get_string(value))), + LogicalTypeID::INTERVAL => Ok(Value::Interval(time::Duration::new( + ffi::value_get_interval_secs(value), + // Duration is constructed using nanoseconds, but kuzu stores microseconds + ffi::value_get_interval_micros(value) * 1000, + ))), + LogicalTypeID::DATE => { + let days = ffi::value_get_date_days(value); + time::Date::from_calendar_date(1970, time::Month::January, 1) + .unwrap() + .checked_add(time::Duration::days(days as i64)) + .map(Value::Date) + .ok_or(ConversionError::Date(days)) + } + LogicalTypeID::TIMESTAMP => { + let us = ffi::value_get_timestamp_micros(value); + time::OffsetDateTime::UNIX_EPOCH + .checked_add(time::Duration::microseconds(us)) + .map(Value::Timestamp) + .ok_or(ConversionError::Timestamp(us)) + } + LogicalTypeID::VAR_LIST => { + let list = 
ffi::value_get_list(value); + let mut result = vec![]; + for index in 0..list.size() { + let value: Value = list.get(index).as_ref().unwrap().try_into()?; + result.push(value); + } + if let LogicalType::VarList { child_type } = value.into() { + Ok(Value::VarList(*child_type, result)) + } else { + unreachable!() + } + } + LogicalTypeID::FIXED_LIST => { + let list = ffi::value_get_list(value); + let mut result = vec![]; + for index in 0..list.size() { + let value: Value = list.get(index).as_ref().unwrap().try_into()?; + result.push(value); + } + if let LogicalType::FixedList { child_type, .. } = value.into() { + Ok(Value::FixedList(*child_type, result)) + } else { + unreachable!() + } + } + LogicalTypeID::STRUCT => { + // Data is a list of field values in the value itself (same as list), + // with the field names stored in the DataType + let field_names = ffi::logical_type_get_struct_field_names( + ffi::value_get_data_type(value).as_ref().unwrap(), + ); + let list = ffi::value_get_list(value); + let mut result = vec![]; + for (name, index) in field_names.into_iter().zip(0..list.size()) { + let value: Value = list.get(index).as_ref().unwrap().try_into()?; + result.push((name, value)); + } + Ok(Value::Struct(result)) + } + LogicalTypeID::NODE => { + let ffi_node_val = ffi::value_get_node_val(value); + let id = ffi::node_value_get_node_id(ffi_node_val.as_ref().unwrap()); + let id = InternalID { + offset: id[0], + table_id: id[1], + }; + let label = ffi::node_value_get_label_name(ffi_node_val.as_ref().unwrap()); + let mut node_val = NodeVal::new(id, label); + let properties = ffi::node_value_get_properties(ffi_node_val.as_ref().unwrap()); + for i in 0..properties.size() { + node_val + .add_property(properties.get_name(i), properties.get_value(i).try_into()?); + } + Ok(Value::Node(node_val)) + } + LogicalTypeID::REL => { + let ffi_rel_val = ffi::value_get_rel_val(value); + let src_node = ffi::rel_value_get_src_id(ffi_rel_val.as_ref().unwrap()); + let dst_node = 
ffi::rel_value_get_dst_id(ffi_rel_val.as_ref().unwrap()); + let src_node = InternalID { + offset: src_node[0], + table_id: src_node[1], + }; + let dst_node = InternalID { + offset: dst_node[0], + table_id: dst_node[1], + }; + let label = ffi::rel_value_get_label_name(ffi_rel_val.as_ref().unwrap()); + let mut rel_val = RelVal::new(src_node, dst_node, label); + let properties = ffi::rel_value_get_properties(ffi_rel_val.as_ref().unwrap()); + for i in 0..properties.size() { + rel_val + .add_property(properties.get_name(i), properties.get_value(i).try_into()?); + } + Ok(Value::Rel(rel_val)) + } + LogicalTypeID::INTERNAL_ID => { + let internal_id = ffi::value_get_internal_id(value); + Ok(Value::InternalID(InternalID { + offset: internal_id[0], + table_id: internal_id[1], + })) + } + // Should be unreachable, as cxx will check that the LogicalTypeID enum matches the one + // on the C++ side. + x => panic!("Unsupported type {:?}", x), + } + } +} + +impl TryInto> for Value { + // Errors should occur if: + // - types are heterogeneous in lists + type Error = crate::error::Error; + + fn try_into(self) -> Result, Self::Error> { + match self { + Value::Null(typ) => Ok(ffi::create_value_null((&typ).into())), + Value::Bool(value) => Ok(ffi::create_value_bool(value)), + Value::Int16(value) => Ok(ffi::create_value_i16(value)), + Value::Int32(value) => Ok(ffi::create_value_i32(value)), + Value::Int64(value) => Ok(ffi::create_value_i64(value)), + Value::Float(value) => Ok(ffi::create_value_float(value)), + Value::Double(value) => Ok(ffi::create_value_double(value)), + Value::String(value) => Ok(ffi::create_value_string(&value)), + Value::Timestamp(value) => Ok(ffi::create_value_timestamp( + // Convert to microseconds since 1970-01-01 + (value.unix_timestamp_nanos() / 1000) as i64, + )), + Value::Date(value) => Ok(ffi::create_value_date( + // Convert to days since 1970-01-01 + (value - time::Date::from_ordinal_date(1970, 1).unwrap()).whole_days(), + )), + Value::Interval(value) => { + 
use time::Duration; + let mut interval = value; + let months = interval.whole_days() / 30; + interval -= Duration::days(months * 30); + let days = interval.whole_days(); + interval -= Duration::days(days); + let micros = interval.whole_microseconds() as i64; + Ok(ffi::create_value_interval( + months as i32, + days as i32, + micros, + )) + } + Value::VarList(typ, value) => { + let mut builder = ffi::create_list(); + for elem in value { + builder.pin_mut().insert(elem.try_into()?); + } + Ok(ffi::get_list_value( + (&LogicalType::VarList { + child_type: Box::new(typ), + }) + .into(), + builder, + )) + } + Value::FixedList(typ, value) => { + let mut builder = ffi::create_list(); + let len = value.len(); + for elem in value { + builder.pin_mut().insert(elem.try_into()?); + } + Ok(ffi::get_list_value( + (&LogicalType::FixedList { + child_type: Box::new(typ), + num_elements: len as u64, + }) + .into(), + builder, + )) + } + Value::Struct(value) => { + let typ: LogicalType = LogicalType::Struct { + fields: value + .iter() + .map(|(name, value)| { + // Each field's type is derived from its value via `Into`, which cannot + // fail, so no error handling is needed here + (name.clone(), Into::::into(value)) + }) + .collect(), + }; + + let mut builder = ffi::create_list(); + for (_, elem) in value { + builder.pin_mut().insert(elem.try_into()?); + } + + Ok(ffi::get_list_value((&typ).into(), builder)) + } + Value::InternalID(value) => { + Ok(ffi::create_value_internal_id(value.offset, value.table_id)) + } + Value::Node(value) => { + let mut node = ffi::create_value_node( + Value::InternalID(value.id).try_into()?, + Value::String(value.label).try_into()?, + ); + for (name, property) in value.properties { + ffi::value_add_property(node.pin_mut(), &name, property.try_into()?); + } + Ok(node) + } + Value::Rel(value) => { + let mut rel = ffi::create_value_rel( + Value::InternalID(value.src_node).try_into()?, + Value::InternalID(value.dst_node).try_into()?, + Value::String(value.label).try_into()?, + ); + for (name,
property) in value.properties { + ffi::value_add_property(rel.pin_mut(), &name, property.try_into()?); + } + Ok(rel) + } + } + } +} + +impl From for Value { + fn from(item: i16) -> Self { + Value::Int16(item) + } +} + +impl From for Value { + fn from(item: i32) -> Self { + Value::Int32(item) + } +} + +impl From for Value { + fn from(item: i64) -> Self { + Value::Int64(item) + } +} + +impl From for Value { + fn from(item: f32) -> Self { + Value::Float(item) + } +} + +impl From for Value { + fn from(item: f64) -> Self { + Value::Double(item) + } +} + +impl From for Value { + fn from(item: String) -> Self { + Value::String(item) + } +} + +impl From<&str> for Value { + fn from(item: &str) -> Self { + Value::String(item.to_string()) + } +} + +#[cfg(test)] +mod tests { + use crate::ffi::ffi; + use crate::{ + connection::Connection, + database::Database, + logical_type::LogicalType, + value::{InternalID, NodeVal, RelVal, Value}, + }; + use anyhow::Result; + use std::collections::HashSet; + use std::convert::TryInto; + use std::iter::FromIterator; + use time::macros::{date, datetime}; + + // Note: Cargo runs tests in parallel by default; however, kuzu does not support + // working with multiple databases in parallel. + // Tests can be run serially with `cargo test -- --test-threads=1` to work around this. + + macro_rules! type_tests { + ($($name:ident: $value:expr,)*) => { + $( + #[test] + /// Tests that the logical types are correctly converted into their C++ representation and back + fn $name() -> Result<()> { + let rust_type: LogicalType = $value; + let typ: cxx::UniquePtr = (&rust_type).try_into()?; + let new_rust_type: LogicalType = typ.as_ref().unwrap().try_into()?; + assert_eq!(new_rust_type, rust_type); + Ok(()) + } + )* + } + } + + macro_rules!
value_tests { + ($($name:ident: $value:expr,)*) => { + $( + #[test] + /// Tests that the values are correctly converted into kuzu::common::Value and back + fn $name() -> Result<()> { + let rust_value: Value = $value; + let value: cxx::UniquePtr = rust_value.clone().try_into()?; + let new_rust_value: Value = value.as_ref().unwrap().try_into()?; + assert_eq!(new_rust_value, rust_value); + Ok(()) + } + )* + } + } + + macro_rules! display_tests { + ($($name:ident: $value:expr,)*) => { + $( + #[test] + /// Tests that the Rust Display output of the values matches the string produced by ffi::value_to_string + fn $name() -> Result<()> { + let rust_value: Value = $value; + let value: cxx::UniquePtr = rust_value.clone().try_into()?; + assert_eq!(ffi::value_to_string(value.as_ref().unwrap()), format!("{rust_value}")); + Ok(()) + } + )* + } + } + + macro_rules! database_tests { + ($($name:ident: $value:expr, $decl:expr,)*) => { + $( + #[test] + /// Tests that passing the values through the database returns what we put in + fn $name() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + conn.query(&format!( + "CREATE NODE TABLE Person(name STRING, item {}, PRIMARY KEY(name));", + $decl, + ))?; + + let mut add_person = + conn.prepare("CREATE (:Person {name: $name, item: $item});")?; + conn.execute( + &mut add_person, + vec![("name", "Bob".into()), ("item", $value)], + )?; + let result = conn + .query("MATCH (a:Person) WHERE a.name = \"Bob\" RETURN a.item;")? + .next() + .unwrap(); + // TODO: Test equivalence to value constructed inside a query + assert_eq!(result[0], $value); + temp_dir.close()?; + Ok(()) + } + )* + } + } + + type_tests!
{ + convert_var_list_type: LogicalType::VarList { child_type: Box::new(LogicalType::String) }, + convert_fixed_list_type: LogicalType::FixedList { child_type: Box::new(LogicalType::Int64), num_elements: 3 }, + convert_int16_type: LogicalType::Int16, + convert_int32_type: LogicalType::Int32, + convert_int64_type: LogicalType::Int64, + convert_float_type: LogicalType::Float, + convert_double_type: LogicalType::Double, + convert_timestamp_type: LogicalType::Timestamp, + convert_date_type: LogicalType::Date, + convert_interval_type: LogicalType::Interval, + convert_string_type: LogicalType::String, + convert_bool_type: LogicalType::Bool, + convert_struct_type: LogicalType::Struct { fields: vec![("NAME".to_string(), LogicalType::String)]}, + convert_node_type: LogicalType::Node, + convert_internal_id_type: LogicalType::InternalID, + convert_rel_type: LogicalType::Rel, + } + + value_tests! { + convert_var_list: Value::VarList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), + convert_var_list_empty: Value::VarList(LogicalType::String, vec![]), + convert_fixed_list: Value::FixedList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), + convert_int16: Value::Int16(1), + convert_int32: Value::Int32(2), + convert_int64: Value::Int64(3), + convert_float: Value::Float(4.), + convert_double: Value::Double(5.), + convert_timestamp: Value::Timestamp(datetime!(2023-06-13 11:25:30 UTC)), + convert_date: Value::Date(date!(2023-06-13)), + convert_interval: Value::Interval(time::Duration::weeks(10)), + convert_string: Value::String("Hello World".to_string()), + convert_bool: Value::Bool(false), + convert_null: Value::Null(LogicalType::VarList { + child_type: Box::new(LogicalType::FixedList { child_type: Box::new(LogicalType::Int16), num_elements: 3 }) + }), + convert_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]), + convert_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }), + convert_node: 
Value::Node(NodeVal::new(InternalID { table_id: 0, offset: 0 }, "Test Label".to_string())), + convert_rel: Value::Rel(RelVal::new(InternalID { table_id: 0, offset: 0 }, InternalID { table_id: 1, offset: 0 }, "Test Label".to_string())), + } + + display_tests! { + display_var_list: Value::VarList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), + display_var_list_empty: Value::VarList(LogicalType::String, vec![]), + display_fixed_list: Value::FixedList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), + display_int16: Value::Int16(1), + display_int32: Value::Int32(2), + display_int64: Value::Int64(3), + // Float, double, interval and timestamp are omitted: they have display differences which we probably don't want to + // reconcile + display_date: Value::Date(date!(2023-06-13)), + display_string: Value::String("Hello World".to_string()), + display_bool: Value::Bool(false), + display_null: Value::Null(LogicalType::VarList { + child_type: Box::new(LogicalType::FixedList { child_type: Box::new(LogicalType::Int16), num_elements: 3 }) + }), + display_struct: Value::Struct(vec![("NAME".to_string(), "Alice".into()), ("AGE".to_string(), 25.into())]), + display_internal_id: Value::InternalID(InternalID { table_id: 0, offset: 0 }), + display_node: Value::Node(NodeVal::new(InternalID { table_id: 0, offset: 0 }, "Test Label".to_string())), + display_rel: Value::Rel(RelVal::new(InternalID { table_id: 0, offset: 0 }, InternalID { table_id: 1, offset: 0 }, "Test Label".to_string())), + } + + database_tests!
{ + // Passing these values as arguments is not yet implemented in kuzu: + // db_struct: + // Value::Struct(vec![("item".to_string(), "Knife".into()), ("count".to_string(), 1.into())]), + // "STRUCT(item STRING, count INT32)", + // db_fixed_list: Value::FixedList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), "STRING[2]", + // db_null_string: Value::Null(LogicalType::String), "STRING", + // db_null_int: Value::Null(LogicalType::Int64), "INT64", + // db_null_list: Value::Null(LogicalType::VarList { + // child_type: Box::new(LogicalType::FixedList { child_type: Box::new(LogicalType::Int16), num_elements: 3 }) + // }), "INT16[3][]", + // db_var_list_string: Value::VarList(LogicalType::String, vec!["Alice".into(), "Bob".into()]), "STRING[]", + // db_var_list_int: Value::VarList(LogicalType::Int64, vec![0i64.into(), 1i64.into(), 2i64.into()]), "INT64[]", + db_int16: Value::Int16(1), "INT16", + db_int32: Value::Int32(2), "INT32", + db_int64: Value::Int64(3), "INT64", + db_float: Value::Float(4.), "FLOAT", + db_double: Value::Double(5.), "DOUBLE", + db_timestamp: Value::Timestamp(datetime!(2023-06-13 11:25:30 UTC)), "TIMESTAMP", + db_date: Value::Date(date!(2023-06-13)), "DATE", + db_interval: Value::Interval(time::Duration::weeks(200)), "INTERVAL", + db_string: Value::String("Hello World".to_string()), "STRING", + db_bool: Value::Bool(true), "BOOLEAN", + } + + #[test] + /// Tests that the list value is correctly constructed + fn test_var_list_get() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + for result in conn.query("RETURN [\"Alice\", \"Bob\"] AS l;")? 
{ + assert_eq!(result.len(), 1); + assert_eq!( + result[0], + Value::VarList(LogicalType::String, vec!["Alice".into(), "Bob".into(),]) + ); + } + temp_dir.close()?; + Ok(()) + } + + #[test] + /// Test that the timestamp round-trips through kuzu's internal timestamp + fn test_timestamp() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + conn.query( + "CREATE NODE TABLE Person(name STRING, registerTime TIMESTAMP, PRIMARY KEY(name));", + )?; + conn.query( + "CREATE (:Person {name: \"Alice\", registerTime: timestamp(\"2011-08-20 11:25:30\")});", + )?; + let mut add_person = + conn.prepare("CREATE (:Person {name: $name, registerTime: $time});")?; + let timestamp = datetime!(2011-08-20 11:25:30 UTC); + conn.execute( + &mut add_person, + vec![ + ("name", "Bob".into()), + ("time", Value::Timestamp(timestamp)), + ], + )?; + let result: HashSet = conn + .query( + "MATCH (a:Person) WHERE a.registerTime = timestamp(\"2011-08-20 11:25:30\") RETURN a.name;", + )? 
+ .map(|x| match &x[0] { + Value::String(x) => x.clone(), + _ => unreachable!(), + }) + .collect(); + assert_eq!( + result, + HashSet::from_iter(vec!["Alice".to_string(), "Bob".to_string()]) + ); + let mut result = + conn.query("MATCH (a:Person) WHERE a.name = \"Bob\" RETURN a.registerTime;")?; + let result: time::OffsetDateTime = + if let Value::Timestamp(timestamp) = result.next().unwrap()[0] { + timestamp + } else { + panic!("Wrong type returned!") + }; + assert_eq!(result, timestamp); + temp_dir.close()?; + Ok(()) + } + + #[test] + fn test_node() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + conn.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")?; + conn.query("CREATE (:Person {name: \"Alice\", age: 25});")?; + let result = conn.query("MATCH (a:Person) RETURN a;")?.next().unwrap(); + assert_eq!( + result[0], + Value::Node(NodeVal { + id: InternalID { + table_id: 0, + offset: 0 + }, + label: "Person".to_string(), + properties: vec![ + ("name".to_string(), Value::String("Alice".to_string())), + ("age".to_string(), Value::Int64(25)) + ] + }) + ); + temp_dir.close()?; + Ok(()) + } + + #[test] + /// Test that null values are read correctly by the API + fn test_null() -> Result<()> { + let temp_dir = tempdir::TempDir::new("example")?; + let db = Database::new(temp_dir.path(), 0)?; + let conn = Connection::new(&db)?; + let result = conn.query("RETURN null")?.next(); + let result = &result.unwrap()[0]; + assert_eq!(result, &Value::Null(LogicalType::String)); + temp_dir.close()?; + Ok(()) + } +}