From de5f8f7ebaef3685ddc6be1de84b747526fd528f Mon Sep 17 00:00:00 2001 From: Kasper Ziemianek Date: Fri, 5 Jul 2024 00:36:36 +0200 Subject: [PATCH] reuse websocket connection for ceremony rounds --- .../core/bc-musig2-runner/src/lib.rs | 152 +++++++++-------- .../core/direct-rpc-client/src/lib.rs | 154 +++++++++++++++--- 2 files changed, 213 insertions(+), 93 deletions(-) diff --git a/bitacross-worker/bitacross/core/bc-musig2-runner/src/lib.rs b/bitacross-worker/bitacross/core/bc-musig2-runner/src/lib.rs index 87fbe87b63..e4df3604c1 100644 --- a/bitacross-worker/bitacross/core/bc-musig2-runner/src/lib.rs +++ b/bitacross-worker/bitacross/core/bc-musig2-runner/src/lib.rs @@ -26,7 +26,7 @@ compile_error!("feature \"std\" and feature \"sgx\" cannot be enabled at the sam use bc_enclave_registry::EnclaveRegistryLookup; use codec::Encode; use core::time::Duration; -use itc_direct_rpc_client::{Response, RpcClient, RpcClientFactory}; +use itc_direct_rpc_client::{RpcClient, RpcClientFactory}; use itc_direct_rpc_server::SendRpcResponse; use itp_ocall_api::EnclaveAttestationOCallApi; use itp_sgx_crypto::{ @@ -55,11 +55,9 @@ use litentry_primitives::{aes_encrypt_default, Address32, AesRequest, Identity, #[cfg(feature = "sgx")] use std::sync::SgxMutex as Mutex; use std::{ + collections::HashMap, string::ToString, - sync::{ - mpsc::{sync_channel, SyncSender}, - Arc, - }, + sync::{mpsc::channel, Arc}, vec, }; @@ -82,27 +80,45 @@ pub fn init_ceremonies_thread + Send + Sync + 'static, Responder: SendRpcResponse + 'static, { - let (responses_sender, responses_receiver) = sync_channel(1000); + let (responses_sender, responses_receiver) = channel(); std::thread::spawn(move || { let my_identity: Address32 = signing_key_access.retrieve_key().unwrap().public().0.into(); let identity = Identity::Substrate(my_identity); + let mut peers_map = HashMap::new(); + let mut ceremonies_to_remove = vec![]; loop { + enclave_registry.get_all().iter().for_each(|(identity, address)| { + if my_identity != *identity && !peers_map.contains_key(identity.as_ref()) { + info!("creating new connection to peer: {:?}", address); + match client_factory.create(address, responses_sender.clone()) { + Ok(client) => { + peers_map.insert(*identity.as_ref(), client); + }, + Err(e) => error!("Could not connect to peer {}, reason: {:?}", address, e), + } + } + }); { - let mut ceremonies_to_remove = vec![]; - // do not hold lock for too long let ceremonies_to_process: Vec = - ceremony_registry.lock().unwrap().keys().cloned().collect(); + if let Ok(ceremonies) = ceremony_registry.try_lock() { + ceremonies.keys().cloned().collect() + } else { + warn!("Could not determine ceremonies to process"); + vec![] + }; for ceremony_id_to_process in ceremonies_to_process { // do not hold lock for too long as logic below includes I/O - let events = if let Some(ceremony) = - ceremony_registry.lock().unwrap().get_mut(&ceremony_id_to_process) - { - ceremony.tick() + let events = if let Ok(mut ceremonies) = ceremony_registry.try_lock() { + if let Some(ceremony) = ceremonies.get_mut(&ceremony_id_to_process) { + ceremony.tick() + } else { + warn!("Could not find ceremony with id: {:?}", ceremony_id_to_process); + vec![] + } } else { - warn!("Could not find ceremony with id: {:?}", ceremony_id_to_process); - continue + vec![] }; trace!("Got ceremony {:?} events {:?}", ceremony_id_to_process, events); @@ -136,12 +152,10 @@ pub fn init_ceremonies_thread( &request, + signer_id, + &mut peers_map, ); }); }, @@ -168,12 +182,10 @@ pub fn init_ceremonies_thread( &request, + signer_id, + &mut peers_map, ); }); }, @@ -240,13 +252,7 @@ pub fn init_ceremonies_thread(&request, &mut peers_map); }); }, CeremonyEvent::CeremonyTimedOut(signers, request_aes_key) => { @@ -290,40 +296,43 @@ pub fn init_ceremonies_thread(&request, &mut peers_map); }); }, } } } - let mut ceremony_commands = ceremony_commands.lock().unwrap(); - let mut ceremony_registry = ceremony_registry.lock().unwrap(); - ceremony_commands.retain(|_, ceremony_pending_commands| { - ceremony_pending_commands.retain_mut(|c| { - c.tick(); - c.ticks_left > 0 + let ceremony_commands = ceremony_commands.try_lock(); + let ceremony_registry = ceremony_registry.try_lock(); + + if let Ok(mut ceremony_commands) = ceremony_commands { + ceremony_commands.retain(|_, ceremony_pending_commands| { + ceremony_pending_commands.retain_mut(|c| { + c.tick(); + c.ticks_left > 0 + }); + !ceremony_pending_commands.is_empty() }); - !ceremony_pending_commands.is_empty() - }); - ceremonies_to_remove.iter().for_each(|ceremony_id| { - debug!("Removing ceremony {:?}", ceremony_id); - let _ = ceremony_registry.remove_entry(ceremony_id); - let _ = ceremony_commands.remove_entry(ceremony_id); - }); + if let Ok(mut ceremonies) = ceremony_registry { + ceremonies_to_remove.iter().for_each(|ceremony_id| { + debug!("Removing ceremony {:?}", ceremony_id); + let _ = ceremonies.remove_entry(ceremony_id); + let _ = ceremony_commands.remove_entry(ceremony_id); + }); + ceremonies_to_remove = vec![]; + } else { + warn!("Could not get ceremonies lock"); + } + } else { + warn!("Could not get ceremony commands lock"); + } } std::thread::sleep(Duration::from_millis(1)) } }); - // here we will process all responses std::thread::spawn(move || { while let Ok((_id, rpc_return_value)) = responses_receiver.recv() { @@ -335,28 +344,31 @@ pub fn init_ceremonies_thread( - client_factory: &Arc, - enclave_registry: &Arc, - responses_sender: &SyncSender, +fn send_to_signer( + request: &RpcRequest, signer_id: &SignerId, + peers: &mut HashMap, +) where + ClientFactory: RpcClientFactory, +{ + if let Some(client) = peers.get_mut(signer_id) { + if let Err(e) = client.send(request) { + error!("Could not send request to signer: {:?}, reason: {:?}", signer_id, e) + } + } +} + +fn send_to_all_signers( request: &RpcRequest, + peers: &mut HashMap, ) where - ClientFactory: RpcClientFactory + Send + Sync + 'static, - ER: EnclaveRegistryLookup + Send + Sync + 'static, + ClientFactory: RpcClientFactory, { - enclave_registry.get_all().iter().for_each(|(identity, address)| { - if signer_id == identity.as_ref() { - trace!("creating new connection to peer: {:?}", address); - match client_factory.create(address, responses_sender.clone()) { - Ok(mut client) => - if let Err(e) = client.send(address, request) { - error!("Could not send request to signer: {:?}, reason: {:?}", signer_id, e) - }, - Err(e) => error!("Could not connect to peer {}, reason: {:?}", address, e), - } + for (signer_id, client) in peers.iter_mut() { + if let Err(e) = client.send(request) { + error!("Could not send request to signer: {:?}, reason: {:?}", signer_id, e) } - }); + } } fn prepare_request( diff --git a/bitacross-worker/core/direct-rpc-client/src/lib.rs b/bitacross-worker/core/direct-rpc-client/src/lib.rs index bb67de7736..69f5d95bfc 100644 --- a/bitacross-worker/core/direct-rpc-client/src/lib.rs +++ b/bitacross-worker/core/direct-rpc-client/src/lib.rs @@ -39,17 +39,27 @@ use alloc::format; use core::str::FromStr; -use itp_rpc::{Id, RpcRequest, RpcReturnValue}; +use log::{debug, error}; + +use serde_json::from_str; + +use itp_rpc::{Id, RpcRequest, RpcResponse, RpcReturnValue}; + +use itp_utils::FromHexPrefixed; use std::{ boxed::Box, error::Error, net::TcpStream, string::String, - sync::{mpsc::SyncSender, Arc}, + sync::{ + mpsc::{channel, Sender}, + Arc, + }, + time::Duration, vec::Vec, }; -use tungstenite::{client_tls_with_config, Connector, Message}; +use tungstenite::{client_tls_with_config, stream::MaybeTlsStream, Connector, Message, WebSocket}; use url::Url; use webpki::{DNSName, DNSNameRef}; @@ -92,7 +102,7 @@ pub trait RpcClientFactory { fn create( &self, url: &str, - response_sink: SyncSender, + response_sink: Sender, ) -> Result>; } @@ -104,32 +114,22 @@ impl RpcClientFactory for DirectRpcClientFactory { fn create( &self, url: &str, - response_sink: SyncSender, + response_sink: Sender, ) -> Result> { DirectRpcClient::new(url, response_sink) } } pub trait RpcClient { - fn send(&mut self, url: &str, request: &RpcRequest) -> Result<(), Box>; + fn send(&mut self, request: &RpcRequest) -> Result<(), Box>; } -pub struct DirectRpcClient {} - -impl DirectRpcClient { - pub fn new(_url: &str, _response_sink: SyncSender) -> Result> { - Ok(Self {}) - } -} - -#[derive(Clone)] -pub enum RequestParams { - Rsa(Vec), - Aes(Vec), +pub struct DirectRpcClient { + request_sink: Sender, } -impl RpcClient for DirectRpcClient { - fn send(&mut self, url: &str, request: &RpcRequest) -> Result<(), Box> { +impl DirectRpcClient { + pub fn new(url: &str, response_sink: Sender) -> Result> { let server_url = Url::from_str(url).map_err(|e| format!("Could not connect, reason: {:?}", e))?; let mut config = rustls::ClientConfig::new(); @@ -144,11 +144,119 @@ impl RpcClient for DirectRpcClient { client_tls_with_config(server_url.as_str(), stream, None, Some(connector)) .map_err(|e| format!("Could not open websocket connection: {:?}", e))?; + let (request_sender, request_receiver) = channel(); + + //it fails to perform handshake in non_blocking mode so we are setting it up after the handshake is performed + Self::switch_to_non_blocking(&mut socket); + + std::thread::spawn(move || { + loop { + // let's flush all pending requests first + while let Ok(request) = request_receiver.try_recv() { + if let Err(e) = socket.write_message(Message::Text(request)) { + error!("Could not write message to socket, reason: {:?}", e) + } + } + + if let Ok(message) = socket.read_message() { + if let Ok(Some(response)) = Self::handle_ws_message(message) { + if let Err(e) = response_sink.send(response) { + log::error!("Could not forward response, reason: {:?}", e) + }; + } + } + std::thread::sleep(Duration::from_millis(1)) + } + }); + debug!("Connected to peer: {}", url); + Ok(Self { request_sink: request_sender }) + } + + fn switch_to_non_blocking(socket: &mut WebSocket>) { + match socket.get_ref() { + MaybeTlsStream::Plain(stream) => { + stream.set_nonblocking(true).expect("set_nonblocking call failed"); + stream + .set_read_timeout(Some(Duration::from_millis(5))) + .expect("set_read_timeout call failed"); + }, + MaybeTlsStream::Rustls(stream) => { + stream.get_ref().set_nonblocking(true).expect("set_nonblocking call failed"); + stream + .get_ref() + .set_read_timeout(Some(Duration::from_millis(1))) + .expect("set_read_timeout call failed"); + }, + _ => {}, + } + } + + fn handle_ws_message(message: Message) -> Result, Box> { + match message { + Message::Text(text) => { + let rpc_response: RpcResponse = from_str(&text) + .map_err(|e| format!("Could not deserialize RpcResponse, reason: {:?}", e))?; + let return_value: RpcReturnValue = + RpcReturnValue::from_hex(&rpc_response.result) + .map_err(|e| format!("Could not deserialize value , reason: {:?}", e))?; + Ok(Some((rpc_response.id, return_value))) + }, + _ => { + log::warn!("Only text messages are supported"); + Ok(None) + }, + } + } +} + +#[derive(Clone)] +pub enum RequestParams { + Rsa(Vec), + Aes(Vec), +} + +impl RpcClient for DirectRpcClient { + fn send(&mut self, request: &RpcRequest) -> Result<(), Box> { let request = serde_json::to_string(request) .map_err(|e| format!("Could not parse RpcRequest {:?}", e))?; + self.request_sink + .send(request) + .map_err(|e| format!("Could not write message, reason: {:?}", e).into()) + } +} + +#[cfg(test)] +mod tests { + use crate::DirectRpcClient; + use itp_rpc::{Id, RpcResponse, RpcReturnValue}; + use itp_types::{DirectRequestStatus, TrustedOperationStatus, H256}; + use itp_utils::ToHexPrefixed; + use tungstenite::Message; + + #[test] + fn test_response_handling() { + let id = Id::Text( + "0x0000000000000000000000000000000000000000000000000000000000000000".to_owned(), + ); + let return_value: RpcReturnValue = RpcReturnValue::new( + vec![], + false, + DirectRequestStatus::TrustedOperationStatus( + TrustedOperationStatus::TopExecuted(vec![], true), + H256::random(), + ), + ); + let rpc_response: RpcResponse = RpcResponse { + jsonrpc: "2.0".to_owned(), + result: return_value.to_hex(), + id: id.clone(), + }; + let serialized_rpc_response = serde_json::to_string(&rpc_response).unwrap(); + let message = Message::text(serialized_rpc_response); + + let (result_id, result) = DirectRpcClient::handle_ws_message(message).unwrap().unwrap(); - log::trace!("Sending request: {:?}", request); - socket.write_message(Message::Text(request))?; - Ok(()) + assert_eq!(id, result_id); + assert_eq!(return_value, result); } }