Skip to content

Commit

Permalink
feat: get snapshot and certification request payload via callback (#75)
Browse files Browse the repository at this point in the history
* feat: initial commit

* feat: allow cohort to create the certification request payload in callback

* feat: move changes to cohort_v2

* feat: wrap up using a closure get the request + snapshot before sending for certification

* feat: copy cohort_v2 temporary changes for callback to cohort

* feat: Pass the statemaps from the request created by the callback fn

* feat: use callback for installer

* fix: lint and fmt on updating rust to 1.72

* chore: updates from review comments

* fix: bring back generic trait implementation for Out of order installs

* chore: remove unused Enum variant and minor refactors based on review
  • Loading branch information
gk-kindred authored Sep 6, 2023
1 parent c802f0c commit 78f9059
Show file tree
Hide file tree
Showing 14 changed files with 347 additions and 344 deletions.
17 changes: 7 additions & 10 deletions packages/cohort_banking/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use cohort_sdk::{
cohort::Cohort,
model::{CandidateData, CertificationRequest, ClientErrorKind, Config},
model::{ClientErrorKind, Config},
};

use opentelemetry_api::{
Expand All @@ -13,9 +13,9 @@ use opentelemetry_api::{
use talos_agent::messaging::api::Decision;

use crate::{
callbacks::{oo_installer::OutOfOrderInstallerImpl, state_provider::StateProviderImpl},
callbacks::{certification_candidate_provider::CertificationCandidateProviderImpl, oo_installer::OutOfOrderInstallerImpl},
examples_support::queue_processor::Handler,
model::requests::{BusinessActionType, TransferRequest},
model::requests::{BusinessActionType, CandidateData, CertificationRequest, TransferRequest},
state::postgres::{database::Database, database_config::DatabaseConfig},
};

Expand Down Expand Up @@ -63,14 +63,12 @@ impl Handler<TransferRequest> for BankingApp {
async fn handle(&self, request: TransferRequest) -> Result<(), String> {
log::debug!("processig new banking transfer request: {:?}", request);

let request_copy = request.clone();

let statemap = vec![HashMap::from([(
BusinessActionType::TRANSFER.to_string(),
TransferRequest::new(request.from.clone(), request.to.clone(), request.amount).json(),
)])];

let request = CertificationRequest {
let certification_request = CertificationRequest {
timeout_ms: 0,
candidate: CandidateData {
readset: vec![request.from.clone(), request.to.clone()],
Expand All @@ -80,15 +78,14 @@ impl Handler<TransferRequest> for BankingApp {
};

let single_query_strategy = true;
let state_provider = StateProviderImpl {
let state_provider = CertificationCandidateProviderImpl {
database: Arc::clone(&self.database),
request: request_copy.clone(),
single_query_strategy,
};
let request_payload_callback = || state_provider.get_certification_candidate(certification_request.clone());

let oo_inst = OutOfOrderInstallerImpl {
database: Arc::clone(&self.database),
request: request_copy,
detailed_logging: false,
counter_oo_no_data_found: Arc::clone(&self.counter_oo_no_data_found),
single_query_strategy,
Expand All @@ -98,7 +95,7 @@ impl Handler<TransferRequest> for BankingApp {
.cohort_api
.as_ref()
.expect("Banking app is not initialised")
.certify(request, &state_provider, &oo_inst)
.certify(&request_payload_callback, &oo_inst)
.await
{
Ok(rsp) => {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
use std::sync::Arc;

use async_trait::async_trait;
use cohort_sdk::model::callbacks::{CapturedItemState, CapturedState, ItemStateProvider};
use cohort_sdk::model::callback::{CertificationCandidate, CertificationCandidateCallbackResponse, CertificationRequestPayload};
use rust_decimal::Decimal;
use tokio_postgres::Row;

use crate::{
model::{bank_account::BankAccount, requests::TransferRequest},
model::{bank_account::BankAccount, requests::CertificationRequest},
state::postgres::database::{Database, DatabaseError},
};

pub struct StateProviderImpl {
pub request: TransferRequest,
pub database: Arc<Database>,
pub single_query_strategy: bool,
#[derive(Debug, PartialEq, PartialOrd)]
pub struct CapturedState {
pub snapshot_version: u64,
pub items: Vec<CapturedItemState>,
}

#[derive(Debug, PartialEq, PartialOrd)]
pub struct CapturedItemState {
pub id: String,
pub version: u64,
}

impl StateProviderImpl {
pub fn account_from_row(row: &Row) -> Result<BankAccount, DatabaseError> {
impl TryFrom<&Row> for BankAccount {
type Error = DatabaseError;

fn try_from(row: &Row) -> Result<Self, Self::Error> {
Ok(BankAccount {
name: row
.try_get::<&str, String>("name")
Expand All @@ -33,20 +40,27 @@ impl StateProviderImpl {
.map_err(|e| DatabaseError::deserialise_payload(e.to_string(), "Cannot read account amount".into()))?,
})
}
}

pub struct CertificationCandidateProviderImpl {
pub database: Arc<Database>,
pub single_query_strategy: bool,
}

async fn get_state_using_two_queries(&self) -> Result<CapturedState, String> {
impl CertificationCandidateProviderImpl {
async fn get_state_using_two_queries(&self, from_account: &str, to_account: &str) -> Result<CapturedState, String> {
let list = self
.database
.query_many(
r#"SELECT ba.* FROM bank_accounts ba WHERE ba."number" = $1 OR ba."number" = $2"#,
&[&self.request.from, &self.request.to],
Self::account_from_row,
&[&from_account, &to_account],
|row| BankAccount::try_from(row),
)
.await
.map_err(|e| e.to_string())?;

if list.len() != 2 {
return Err(format!("Unable to load state of accounts: '{}' and '{}'", self.request.from, self.request.to));
return Err(format!("Unable to load state of accounts: '{:?}' and '{:?}'", from_account, to_account));
}

let snapshot_version = self
Expand All @@ -73,11 +87,10 @@ impl StateProviderImpl {
version: account.version,
})
.collect(),
abort_reason: None,
})
}

async fn get_state_using_one_query(&self) -> Result<CapturedState, String> {
async fn get_state_using_one_query(&self, from_account: &str, to_account: &str) -> Result<CapturedState, String> {
let list = self
.database
.query_many(
Expand All @@ -92,10 +105,10 @@ impl StateProviderImpl {
bank_accounts ba, cohort_snapshot cs
WHERE
ba."number" = $1 OR ba."number" = $2"#,
&[&self.request.from, &self.request.to],
&[&from_account, &to_account],
// convert RAW output into tuple (bank account, snap ver)
|row| {
let account = Self::account_from_row(row)?;
let account = BankAccount::try_from(row)?;
let snapshot_version = row
.try_get::<&str, i64>("snapshot_version")
.map_err(|e| DatabaseError::deserialise_payload(e.to_string(), "Cannot read snapshot_version".into()))?;
Expand All @@ -106,7 +119,7 @@ impl StateProviderImpl {
.map_err(|e| e.to_string())?;

if list.len() != 2 {
return Err(format!("Unable to load state of accounts: '{}' and '{}'", self.request.from, self.request.to));
return Err(format!("Unable to load state of accounts: '{:?}' and '{:?}'", from_account, to_account));
}

Ok(CapturedState {
Expand All @@ -118,18 +131,38 @@ impl StateProviderImpl {
version: tuple.0.version,
})
.collect(),
abort_reason: None,
})
}
}

#[async_trait]
impl ItemStateProvider for StateProviderImpl {
async fn get_state(&self) -> Result<CapturedState, String> {
if self.single_query_strategy {
self.get_state_using_one_query().await
pub async fn get_certification_candidate(&self, request: CertificationRequest) -> Result<CertificationCandidateCallbackResponse, String> {
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// This example doesn't handle `Cancelled` scenario.
// If user cancellation is needed, add additional logic in this fn to return `Cancelled` instead of `Proceed` in the result.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

// The order of the accounts doesn't really matter in this example as we just use these accounts to get their respective versions.
// Safe assumption made here that we have 2 items in the writeset here. As our example is to transfer between 2 accounts.
// The other alternative is to deserialize from statemap Value, which could be expensive comparitively, also we may not have statemap.
let first_account = &request.candidate.writeset[0];
let second_account = &request.candidate.writeset[1];

let state = if self.single_query_strategy {
self.get_state_using_one_query(first_account, second_account).await
} else {
self.get_state_using_two_queries().await
}
self.get_state_using_two_queries(first_account, second_account).await
}?;

let candidate = CertificationCandidate {
readset: request.candidate.readset,
writeset: request.candidate.writeset,
statemaps: request.candidate.statemap,
readvers: state.items.into_iter().map(|x| x.version).collect(),
};

Ok(CertificationCandidateCallbackResponse::Proceed(CertificationRequestPayload {
candidate,
snapshot: state.snapshot_version,
timeout_ms: request.timeout_ms,
}))
}
}
2 changes: 1 addition & 1 deletion packages/cohort_banking/src/callbacks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pub mod certification_candidate_provider;
pub mod oo_installer;
pub mod state_provider;
pub mod statemap_installer;
Loading

0 comments on commit 78f9059

Please sign in to comment.