Skip to content

Commit

Permalink
Merge pull request #62 from tpisto/main
Browse files Browse the repository at this point in the history
Row<T>
  • Loading branch information
psarna committed Jul 12, 2023
2 parents 056af57 + 677c956 commit b73ad11
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 43 deletions.
6 changes: 4 additions & 2 deletions bindings/c/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ use types::{DbContext, MVCCDatabaseRef, MVCCScanCursorRef, ScanCursorContext};
type Clock = clock::LocalClock;

/// cbindgen:ignore
type Db = database::Database<Clock>;
/// Note: the C bindings instantiate the database's generic row data type with `String`.
type Db = database::Database<Clock, String>;

/// cbindgen:ignore
type ScanCursor = cursor::ScanCursor<'static, Clock>;
/// Note: the C bindings instantiate the cursor's generic row data type with `String`.
type ScanCursor = cursor::ScanCursor<'static, Clock, String>;

static INIT_RUST_LOG: std::sync::Once = std::sync::Once::new();

Expand Down
2 changes: 1 addition & 1 deletion mvcc-rs/benches/my_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use mvcc_rs::clock::LocalClock;
use mvcc_rs::database::{Database, Row, RowID};
use pprof::criterion::{Output, PProfProfiler};

fn bench_db() -> Database<LocalClock> {
fn bench_db() -> Database<LocalClock, String> {
let clock = LocalClock::default();
let storage = mvcc_rs::persistent_storage::Storage::new_noop();
Database::new(clock, storage)
Expand Down
16 changes: 10 additions & 6 deletions mvcc-rs/src/cursor.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
use serde::de::DeserializeOwned;
use serde::Serialize;

use crate::clock::LogicalClock;
use crate::database::{Database, Result, Row, RowID};
use std::fmt::Debug;

#[derive(Debug)]
pub struct ScanCursor<'a, Clock: LogicalClock> {
pub db: &'a Database<Clock>,
pub struct ScanCursor<'a, Clock: LogicalClock, T: Sync + Send + Clone + Serialize + DeserializeOwned + Debug> {
pub db: &'a Database<Clock, T>,
pub row_ids: Vec<RowID>,
pub index: usize,
tx_id: u64,
}

impl<'a, Clock: LogicalClock> ScanCursor<'a, Clock> {
impl<'a, Clock: LogicalClock, T: Sync + Send + Clone + Serialize + DeserializeOwned + Debug + 'static> ScanCursor<'a, Clock, T> {
pub fn new(
db: &'a Database<Clock>,
db: &'a Database<Clock, T>,
tx_id: u64,
table_id: u64,
) -> Result<ScanCursor<'a, Clock>> {
) -> Result<ScanCursor<'a, Clock, T>> {
let row_ids = db.scan_row_ids_for_table(table_id)?;
Ok(Self {
db,
Expand All @@ -31,7 +35,7 @@ impl<'a, Clock: LogicalClock> ScanCursor<'a, Clock> {
Some(self.row_ids[self.index])
}

pub fn current_row(&self) -> Result<Option<Row>> {
pub fn current_row(&self) -> Result<Option<Row<T>>> {
if self.index >= self.row_ids.len() {
return Ok(None);
}
Expand Down
59 changes: 34 additions & 25 deletions mvcc-rs/src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::clock::LogicalClock;
use crate::errors::DatabaseError;
use crate::persistent_storage::Storage;
use crossbeam_skiplist::{SkipMap, SkipSet};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;

Expand All @@ -19,29 +21,29 @@ pub struct RowID {

#[derive(Clone, Debug, PartialEq, PartialOrd, Serialize, Deserialize)]

pub struct Row {
pub struct Row<T> {
pub id: RowID,
pub data: String,
pub data: T,
}

/// A row version.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct RowVersion {
pub struct RowVersion<T> {
begin: TxTimestampOrID,
end: Option<TxTimestampOrID>,
row: Row,
row: Row<T>,
}

pub type TxID = u64;

/// A log record contains all the versions inserted and deleted by a transaction.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LogRecord {
pub struct LogRecord<T> {
pub(crate) tx_timestamp: TxID,
row_versions: Vec<RowVersion>,
row_versions: Vec<RowVersion<T>>,
}

impl LogRecord {
impl<T> LogRecord<T> {
fn new(tx_timestamp: TxID) -> Self {
Self {
tx_timestamp,
Expand Down Expand Up @@ -254,15 +256,20 @@ impl AtomicTransactionState {
}

#[derive(Debug)]
pub struct Database<Clock: LogicalClock> {
rows: SkipMap<RowID, RwLock<Vec<RowVersion>>>,
pub struct Database<
Clock: LogicalClock,
T: Sync + Send + Clone + Serialize + Debug + DeserializeOwned,
> {
rows: SkipMap<RowID, RwLock<Vec<RowVersion<T>>>>,
txs: SkipMap<TxID, RwLock<Transaction>>,
tx_ids: AtomicU64,
clock: Clock,
storage: Storage,
}

impl<Clock: LogicalClock> Database<Clock> {
impl<Clock: LogicalClock, T: Sync + Send + Clone + Serialize + Debug + DeserializeOwned + 'static>
Database<Clock, T>
{
/// Creates a new database.
pub fn new(clock: Clock, storage: Storage) -> Self {
Self {
Expand Down Expand Up @@ -292,15 +299,15 @@ impl<Clock: LogicalClock> Database<Clock> {

/// Inserts a new row version into the database, while making sure that
/// the row version is inserted in the correct order.
fn insert_version(&self, id: RowID, row_version: RowVersion) {
fn insert_version(&self, id: RowID, row_version: RowVersion<T>) {
let versions = self.rows.get_or_insert_with(id, || RwLock::new(Vec::new()));
let mut versions = versions.value().write().unwrap();
self.insert_version_raw(&mut versions, row_version)
}

/// Inserts a new row version into the internal data structure for versions,
/// while making sure that the row version is inserted in the correct order.
fn insert_version_raw(&self, versions: &mut Vec<RowVersion>, row_version: RowVersion) {
fn insert_version_raw(&self, versions: &mut Vec<RowVersion<T>>, row_version: RowVersion<T>) {
// NOTICE: this is an insert à la insertion sort, with worst-case linear complexity.
// However, we expect the number of versions to be nearly sorted, so we deem it worthy
// to search linearly for the insertion point instead of paying the price of using
Expand Down Expand Up @@ -333,7 +340,7 @@ impl<Clock: LogicalClock> Database<Clock> {
/// * `tx_id` - the ID of the transaction in which to insert the new row.
/// * `row` - the row object containing the values to be inserted.
///
pub fn insert(&self, tx_id: TxID, row: Row) -> Result<()> {
pub fn insert(&self, tx_id: TxID, row: Row<T>) -> Result<()> {
let tx = self
.txs
.get(&tx_id)
Expand Down Expand Up @@ -370,7 +377,7 @@ impl<Clock: LogicalClock> Database<Clock> {
/// # Returns
///
/// Returns `true` if the row was successfully updated, and `false` otherwise.
pub fn update(&self, tx_id: TxID, row: Row) -> Result<bool> {
pub fn update(&self, tx_id: TxID, row: Row<T>) -> Result<bool> {
if !self.delete(tx_id, row.id)? {
return Ok(false);
}
Expand All @@ -380,7 +387,7 @@ impl<Clock: LogicalClock> Database<Clock> {

/// Inserts a row into the database with new values, first deleting
/// any old data if it existed. Bails on a delete error, e.g. write-write conflict.
pub fn upsert(&self, tx_id: TxID, row: Row) -> Result<()> {
pub fn upsert(&self, tx_id: TxID, row: Row<T>) -> Result<()> {
self.delete(tx_id, row.id)?;
self.insert(tx_id, row)
}
Expand Down Expand Up @@ -449,7 +456,7 @@ impl<Clock: LogicalClock> Database<Clock> {
///
/// Returns `Some(row)` with the row data if the row with the given `id` exists,
/// and `None` otherwise.
pub fn read(&self, tx_id: TxID, id: RowID) -> Result<Option<Row>> {
pub fn read(&self, tx_id: TxID, id: RowID) -> Result<Option<Row<T>>> {
let tx = self.txs.get(&tx_id).unwrap();
let tx = tx.value().read().unwrap();
assert_eq!(tx.state, TransactionState::Active);
Expand Down Expand Up @@ -606,7 +613,7 @@ impl<Clock: LogicalClock> Database<Clock> {
drop(tx);
// Postprocessing: inserting row versions and logging the transaction to persistent storage.
// TODO: we should probably save to persistent storage first, and only then update the in-memory structures.
let mut log_record: LogRecord = LogRecord::new(end_ts);
let mut log_record: LogRecord<T> = LogRecord::new(end_ts);
for ref id in write_set {
if let Some(row_versions) = self.rows.get(id) {
let mut row_versions = row_versions.value().write().unwrap();
Expand Down Expand Up @@ -665,6 +672,7 @@ impl<Clock: LogicalClock> Database<Clock> {
tracing::trace!("ABORT {tx}");
let write_set: Vec<RowID> = tx.write_set.iter().map(|v| *v.value()).collect();
drop(tx);

for ref id in write_set {
if let Some(row_versions) = self.rows.get(id) {
let mut row_versions = row_versions.value().write().unwrap();
Expand All @@ -674,6 +682,7 @@ impl<Clock: LogicalClock> Database<Clock> {
}
}
}

let tx = tx_unlocked.value().write().unwrap();
tx.state.store(TransactionState::Terminated);
tracing::trace!("TERMINATE {tx}");
Expand Down Expand Up @@ -765,10 +774,10 @@ impl<Clock: LogicalClock> Database<Clock> {

/// A write-write conflict happens when transaction T_m attempts to update a
/// row version that is currently being updated by an active transaction T_n.
pub(crate) fn is_write_write_conflict(
pub(crate) fn is_write_write_conflict<T>(
txs: &SkipMap<TxID, RwLock<Transaction>>,
tx: &Transaction,
rv: &RowVersion,
rv: &RowVersion<T>,
) -> bool {
match rv.end {
Some(TxTimestampOrID::TxID(rv_end)) => {
Expand All @@ -784,18 +793,18 @@ pub(crate) fn is_write_write_conflict(
}
}

pub(crate) fn is_version_visible(
pub(crate) fn is_version_visible<T>(
txs: &SkipMap<TxID, RwLock<Transaction>>,
tx: &Transaction,
rv: &RowVersion,
rv: &RowVersion<T>,
) -> bool {
is_begin_visible(txs, tx, rv) && is_end_visible(txs, tx, rv)
}

fn is_begin_visible(
fn is_begin_visible<T>(
txs: &SkipMap<TxID, RwLock<Transaction>>,
tx: &Transaction,
rv: &RowVersion,
rv: &RowVersion<T>,
) -> bool {
match rv.begin {
TxTimestampOrID::Timestamp(rv_begin_ts) => tx.begin_ts >= rv_begin_ts,
Expand All @@ -822,10 +831,10 @@ fn is_begin_visible(
}
}

fn is_end_visible(
fn is_end_visible<T>(
txs: &SkipMap<TxID, RwLock<Transaction>>,
tx: &Transaction,
rv: &RowVersion,
rv: &RowVersion<T>,
) -> bool {
match rv.end {
Some(TxTimestampOrID::Timestamp(rv_end_ts)) => tx.begin_ts < rv_end_ts,
Expand Down
4 changes: 2 additions & 2 deletions mvcc-rs/src/database/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::*;
use crate::clock::LocalClock;
use tracing_test::traced_test;

fn test_db() -> Database<LocalClock> {
fn test_db() -> Database<LocalClock, String> {
let clock = LocalClock::new();
let storage = crate::persistent_storage::Storage::new_noop();
Database::new(clock, storage)
Expand Down Expand Up @@ -721,7 +721,7 @@ fn test_storage1() {

let clock = LocalClock::new();
let storage = crate::persistent_storage::Storage::new_json_on_disk(path);
let db = Database::new(clock, storage);
let db: Database<LocalClock, String> = Database::new(clock, storage);
db.recover().unwrap();
println!("{:#?}", db);

Expand Down
10 changes: 7 additions & 3 deletions mvcc-rs/src/persistent_storage/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::fmt::Debug;

use crate::database::{LogRecord, Result};
use crate::errors::DatabaseError;

Expand Down Expand Up @@ -27,7 +31,7 @@ impl Storage {
}

impl Storage {
pub fn log_tx(&self, m: LogRecord) -> Result<()> {
pub fn log_tx<T: Serialize>(&self, m: LogRecord<T>) -> Result<()> {
match self {
Self::JsonOnDisk(path) => {
use std::io::Write;
Expand All @@ -50,7 +54,7 @@ impl Storage {
Ok(())
}

pub fn read_tx_log(&self) -> Result<Vec<LogRecord>> {
pub fn read_tx_log<T: DeserializeOwned + Debug>(&self) -> Result<Vec<LogRecord<T>>> {
match self {
Self::JsonOnDisk(path) => {
use std::io::BufRead;
Expand All @@ -59,7 +63,7 @@ impl Storage {
.open(path)
.map_err(|e| DatabaseError::Io(e.to_string()))?;

let mut records: Vec<LogRecord> = Vec::new();
let mut records: Vec<LogRecord<T>> = Vec::new();
let mut lines = std::io::BufReader::new(file).lines();
while let Some(Ok(line)) = lines.next() {
records.push(
Expand Down
11 changes: 7 additions & 4 deletions mvcc-rs/src/persistent_storage/s3.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::database::{LogRecord, Result};
use crate::errors::DatabaseError;
use aws_sdk_s3::Client;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::fmt::Debug;

#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
Expand Down Expand Up @@ -66,7 +69,7 @@ impl Replicator {
})
}

pub async fn replicate_tx(&self, record: LogRecord) -> Result<()> {
pub async fn replicate_tx<T: Serialize>(&self, record: LogRecord<T>) -> Result<()> {
let key = format!("{}-{:020}", self.prefix, record.tx_timestamp);
tracing::trace!("Replicating {key}");
let body = serde_json::to_vec(&record).map_err(|e| DatabaseError::Io(e.to_string()))?;
Expand All @@ -83,8 +86,8 @@ impl Replicator {
Ok(())
}

pub async fn read_tx_log(&self) -> Result<Vec<LogRecord>> {
let mut records: Vec<LogRecord> = Vec::new();
pub async fn read_tx_log<T: DeserializeOwned + Debug>(&self) -> Result<Vec<LogRecord<T>>> {
let mut records: Vec<LogRecord<T>> = Vec::new();
// Read all objects from the bucket, one log record is stored in one object
let mut next_token = None;
loop {
Expand Down Expand Up @@ -120,7 +123,7 @@ impl Replicator {
.collect()
.await
.map_err(|e| DatabaseError::Io(e.to_string()))?;
let record: LogRecord = serde_json::from_slice(&body.into_bytes())
let record: LogRecord<T> = serde_json::from_slice(&body.into_bytes())
.map_err(|e| DatabaseError::Io(e.to_string()))?;
records.push(record);
}
Expand Down

0 comments on commit b73ad11

Please sign in to comment.