Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add access mode option to Rust API #2233

Merged
merged 1 commit into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/main/database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ static void getLockFileFlagsAndType(
AccessMode accessMode, bool createNew, int& flags, FileLockType& lock) {
flags = accessMode == AccessMode::READ_ONLY ? O_RDONLY : O_RDWR;
if (createNew) {
assert(flags == O_RDWR);
flags |= O_CREAT;
}
lock = accessMode == AccessMode::READ_ONLY ? FileLockType::READ_LOCK : FileLockType::WRITE_LOCK;
Expand Down
3 changes: 2 additions & 1 deletion tools/rust_api/include/kuzu_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ std::unique_ptr<std::vector<kuzu::common::LogicalType>> logical_type_get_struct_

/* Database */
std::unique_ptr<kuzu::main::Database> new_database(const std::string& databasePath,
uint64_t bufferPoolSize, uint64_t maxNumThreads, bool enableCompression);
uint64_t bufferPoolSize, uint64_t maxNumThreads, bool enableCompression,
kuzu::main::AccessMode accessMode);

void database_set_logging_level(kuzu::main::Database& database, const std::string& level);

Expand Down
49 changes: 48 additions & 1 deletion tools/rust_api/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ pub enum LoggingLevel {
Error,
}

#[derive(Clone, Debug)]
pub enum AccessMode {
ReadWrite,
ReadOnly,
}

impl From<AccessMode> for ffi::AccessMode {
fn from(other: AccessMode) -> Self {
match other {
AccessMode::ReadWrite => ffi::AccessMode::READ_WRITE,
AccessMode::ReadOnly => ffi::AccessMode::READ_ONLY,
}
}
}

#[derive(Clone, Debug)]
/// Configuration options for the database.
pub struct SystemConfig {
Expand All @@ -36,6 +51,7 @@ pub struct SystemConfig {
/// When true, new columns will be compressed if possible
/// Defaults to true
enable_compression: bool,
access_mode: AccessMode,
}

impl Default for SystemConfig {
Expand All @@ -44,6 +60,7 @@ impl Default for SystemConfig {
buffer_pool_size: 0,
max_num_threads: 0,
enable_compression: true,
access_mode: AccessMode::ReadWrite,
}
}
}
Expand All @@ -61,6 +78,10 @@ impl SystemConfig {
self.enable_compression = enable_compression;
self
}
pub fn access_mode(mut self, access_mode: AccessMode) -> Self {
self.access_mode = access_mode;
self
}
}

impl Database {
Expand All @@ -77,6 +98,7 @@ impl Database {
config.buffer_pool_size,
config.max_num_threads,
config.enable_compression,
config.access_mode.into(),
)?),
})
}
Expand Down Expand Up @@ -107,7 +129,8 @@ impl fmt::Debug for Database {

#[cfg(test)]
mod tests {
use crate::database::{Database, LoggingLevel, SystemConfig};
use crate::connection::Connection;
use crate::database::{AccessMode, Database, LoggingLevel, SystemConfig};
use anyhow::{Error, Result};
// Note: Cargo runs tests in parallel by default, however kuzu does not support
// working with multiple databases in parallel.
Expand Down Expand Up @@ -156,4 +179,28 @@ mod tests {
.starts_with("Failed to create directory due to"));
}
}

#[test]
fn test_database_read_only() -> Result<()> {
let temp_dir = tempfile::tempdir()?;
// Create database first so that it can be opened read-only
{
Database::new(temp_dir.path(), SystemConfig::default())?;
}
let db = Database::new(
temp_dir.path(),
SystemConfig::default().access_mode(AccessMode::ReadOnly),
)?;
let conn = Connection::new(&db)?;
let result: Error = conn
.query("CREATE NODE TABLE Person(name STRING, age INT64, PRIMARY KEY(name));")
.expect_err("Invalid syntax in query should produce an error")
.into();

assert_eq!(
result.to_string(),
"Cannot execute write operations in a read-only access mode database!"
);
Ok(())
}
}
14 changes: 14 additions & 0 deletions tools/rust_api/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,24 @@ pub(crate) mod ffi {
MAP = 54,
UNION = 55,
}

#[namespace = "kuzu::common"]
unsafe extern "C++" {
type LogicalTypeID;
}

#[namespace = "kuzu::main"]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AccessMode {
READ_ONLY = 0,
READ_WRITE = 1,
}

#[namespace = "kuzu::main"]
unsafe extern "C++" {
type AccessMode;
}

#[namespace = "kuzu::main"]
unsafe extern "C++" {
include!("kuzu/include/kuzu_rs.h");
Expand Down Expand Up @@ -84,6 +97,7 @@ pub(crate) mod ffi {
bufferPoolSize: u64,
maxNumThreads: u64,
enableCompression: bool,
access_mode: AccessMode,
) -> Result<UniquePtr<Database>>;

fn database_set_logging_level(database: Pin<&mut Database>, level: &CxxString);
Expand Down
3 changes: 2 additions & 1 deletion tools/rust_api/src/kuzu_rs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,15 @@ std::unique_ptr<std::vector<kuzu::common::LogicalType>> logical_type_get_struct_
}

std::unique_ptr<Database> new_database(const std::string& databasePath, uint64_t bufferPoolSize,
uint64_t maxNumThreads, bool enableCompression) {
uint64_t maxNumThreads, bool enableCompression, kuzu::main::AccessMode accessMode) {
auto systemConfig = SystemConfig();
if (bufferPoolSize > 0) {
systemConfig.bufferPoolSize = bufferPoolSize;
}
if (maxNumThreads > 0) {
systemConfig.maxNumThreads = maxNumThreads;
}
systemConfig.accessMode = accessMode;
systemConfig.enableCompression = enableCompression;
return std::make_unique<Database>(databasePath, systemConfig);
}
Expand Down
2 changes: 1 addition & 1 deletion tools/rust_api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ mod query_result;
mod value;

pub use connection::{Connection, PreparedStatement};
pub use database::{Database, LoggingLevel, SystemConfig};
pub use database::{AccessMode, Database, LoggingLevel, SystemConfig};
pub use error::Error;
pub use logical_type::LogicalType;
pub use query_result::{CSVOptions, QueryResult};
Expand Down