From 3e91a9fc1ebd839ab0bfbefdd91a8f7a0cb240e7 Mon Sep 17 00:00:00 2001 From: Matthew Waters Date: Sun, 14 Jul 2024 13:41:17 +1000 Subject: [PATCH] deps: update to stun-proto 0.2.0 --- Cargo.toml | 2 +- librice-proto/src/conncheck.rs | 89 +++-- librice-proto/src/gathering.rs | 578 +++++++++++++++++++++------------ librice-proto/src/stream.rs | 7 +- librice/examples/icegather.rs | 7 +- librice/src/component.rs | 20 +- librice/tests/common/mod.rs | 35 +- librice/tests/stund.rs | 11 +- 8 files changed, 478 insertions(+), 271 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index efb1ecc..b6304e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ edition = "2021" rust-version = "1.68.2" [workspace.dependencies] -stun-proto = { version = "0.1.0" } +stun-proto = "0.2.0" arbitrary = { version = "1", features = ["derive"] } byteorder = "1" get_if_addrs = "0.5" diff --git a/librice-proto/src/conncheck.rs b/librice-proto/src/conncheck.rs index 89c83a7..a990c31 100644 --- a/librice-proto/src/conncheck.rs +++ b/librice-proto/src/conncheck.rs @@ -17,6 +17,7 @@ use std::time::{Duration, Instant}; use crate::candidate::{Candidate, CandidatePair, CandidateType, TcpType, TransportType}; use crate::component::ComponentConnectionState; +use byteorder::{BigEndian, ByteOrder}; use stun_proto::agent::{ HandleStunReply, StunAgent, StunAgentPollRet, StunError, TcpBuffer, Transmit, }; @@ -1337,7 +1338,6 @@ impl ConnCheckList { self.triggered .retain(|check_id| !triggered_to_remove.contains(check_id)); let nominated_ids = self.nominated.clone(); - error!("nominated ids {nominated_ids:?}"); self.pairs.retain(|check| { if nominated_ids.contains(&check.conncheck_id) { true @@ -1582,6 +1582,7 @@ impl ConnCheckListSetBuilder { checklist_i: 0, last_send_time: Instant::now() - ConnCheckListSet::MINIMUM_SET_TICK, pending_transmits: Default::default(), + pending_messages: Default::default(), } } } @@ -1596,6 +1597,13 @@ pub struct ConnCheckListSet { checklist_i: usize, last_send_time: Instant, pending_transmits: VecDeque<(usize, usize, Transmit<'static>)>, + pending_messages: VecDeque<( + usize, + usize, + Arc>, + MessageBuilder<'static>, + SocketAddr, + )>, } impl ConnCheckListSet { @@ -1662,12 +1670,12 @@ impl ConnCheckListSet { &request, transmit.from, )? { - debug!("Sending response {response:?} to {:?}", transmit.from); - let mut agent_inner = agent.lock().unwrap(); - self.pending_transmits.push_front(( + self.pending_messages.push_back(( checklist_id, local_cand.component_id, - agent_inner.send(response, transmit.from)?.into_owned(), + agent.clone(), + response.into_owned(), + transmit.from, )); return Ok(Some(HandleRecvReply::Handled)); } @@ -2311,12 +2319,12 @@ impl ConnCheckListSet { let mut agent = agent.lock().unwrap(); let transmit = agent - .send(stun_request, conncheck.pair.remote.address) + .send(stun_request, conncheck.pair.remote.address, now) .unwrap(); Ok(CheckListSetPollRet::Transmit( checklist_id, conncheck.pair.local.component_id, - transmit.into_owned(), + transmit_send(transmit), )) } @@ -2410,6 +2418,20 @@ impl ConnCheckListSet { if let Some((checklist_id, cid, transmit)) = self.pending_transmits.pop_back() { return CheckListSetPollRet::Transmit(checklist_id, cid, transmit); } + if let Some((checklist_id, cid, agent, msg, to)) = self.pending_messages.pop_back() { + debug!("Sending response {msg:?} to {:?}", to); + let mut agent_inner = agent.lock().unwrap(); + match agent_inner.send(msg, to, now) { + Ok(transmit) => { + return CheckListSetPollRet::Transmit( + checklist_id, + cid, + transmit_send(transmit), + ) + } + Err(e) => warn!("error sending: {e}"), + } + } for checklist in self.checklists.iter_mut() { if let Some(event) = checklist.poll_event() { @@ -2469,7 +2491,7 @@ impl ConnCheckListSet { return CheckListSetPollRet::Transmit( checklist.checklist_id, check.pair.local.component_id, - transmit.into_owned(), + transmit_send(transmit), ); } } @@ -2589,15 +2611,14 @@ impl ConnCheckListSet { )); let transaction_id = stun_request.transaction_id(); - self.pending_transmits.push_front(( + let agent = Arc::new(Mutex::new(agent)); + self.pending_messages.push_front(( checklist_id, check.pair.local.component_id, - agent - .send(stun_request, check.pair.remote.address) - .unwrap() - .into_owned(), + agent.clone(), + stun_request.into_owned(), + check.pair.remote.address, )); - let agent = Arc::new(Mutex::new(agent)); checklist.agents.push(agent.clone()); let mut new_check = ConnCheck::new( @@ -2689,6 +2710,24 @@ fn validate_username(username: Username, local_credentials: &Credentials) -> boo } } +pub(crate) fn transmit_send(transmit: Transmit) -> Transmit<'static> { + match transmit.transport { + TransportType::Udp => transmit.into_owned(), + TransportType::Tcp => { + let mut data = Vec::with_capacity(transmit.data.len()); + data.resize(2, 0); + BigEndian::write_u16(&mut data, transmit.data.len() as u16); + data.extend_from_slice(&transmit.data); + Transmit::new_owned( + data.into_boxed_slice(), + transmit.transport, + transmit.from, + transmit.to, + ) + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -2973,6 +3012,7 @@ mod tests { transmit: Transmit<'_>, error_response: Option, response_address: Option, + now: Instant, ) -> Option> { // XXX: assumes that tcp framing is not in use let offset = match transmit.transport { @@ -2996,9 +3036,10 @@ mod tests { ) .unwrap(), transmit.from, + now, ) .unwrap(); - return Some(transmit.into_owned()); + return Some(transmit_send(transmit)); } } } @@ -3327,6 +3368,7 @@ mod tests { transmit, self.error_response, self.response_address, + now, ) .unwrap(); info!("reply: {reply:?}"); @@ -3656,6 +3698,7 @@ mod tests { transmit, None, None, + now, ) else { unreachable!(); }; @@ -3690,6 +3733,7 @@ mod tests { transmit.into_owned(), None, None, + now, ) else { unreachable!(); }; @@ -3733,6 +3777,7 @@ mod tests { .priority(10); let mut state = state.build(); state.local_list().generate_checks(); + let now = Instant::now(); let remote_addr = SocketAddr::new(state.remote.candidate.base_address.ip(), 2000); let mut remote_cand = state.remote.candidate.clone(); remote_cand.address = remote_addr; @@ -3766,9 +3811,11 @@ mod tests { .checklist_set .incoming_data( state.local.checklist_id, - &remote_agent - .send(request, state.local.peer.candidate.base_address) - .unwrap(), + &transmit_send( + remote_agent + .send(request, state.local.peer.candidate.base_address, now) + .unwrap() + ), ) .unwrap()[0], HandleRecvReply::Handled @@ -3810,6 +3857,7 @@ mod tests { transmit.into_owned(), None, None, + now, ) else { unreachable!(); }; @@ -3845,6 +3893,7 @@ mod tests { transmit.into_owned(), None, None, + now, ) else { unreachable!(); }; @@ -3932,7 +3981,7 @@ mod tests { request.add_fingerprint().unwrap(); let local_addr = state.local.peer.stun_agent().local_addr(); - let transmit = remote_agent.send(request, local_addr).unwrap(); + let transmit = remote_agent.send(request, local_addr, now).unwrap(); info!("sending prflx request"); let reply = state @@ -4314,7 +4363,7 @@ mod tests { request.add_fingerprint().unwrap(); let local_addr = state.local.peer.stun_agent().local_addr(); - let transmit = remote_agent.send(request, local_addr).unwrap(); + let transmit = remote_agent.send(request, local_addr, now).unwrap(); info!("sending request"); let reply = state diff --git a/librice-proto/src/gathering.rs b/librice-proto/src/gathering.rs index 4b5a44e..ba2dff7 100644 --- a/librice-proto/src/gathering.rs +++ b/librice-proto/src/gathering.rs @@ -13,12 +13,10 @@ use std::net::{IpAddr, SocketAddr}; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; -use crate::candidate::{Candidate, TransportType}; -use stun_proto::agent::{ - HandleStunReply, StunAgent, StunAgentPollRet, StunError, TcpBuffer, Transmit, -}; +use crate::candidate::{Candidate, TcpType, TransportType}; +use stun_proto::agent::{HandleStunReply, StunAgent, StunAgentPollRet, StunError, Transmit}; use stun_proto::types::attribute::XorMappedAddress; -use stun_proto::types::message::{Message, TransactionId, BINDING}; +use stun_proto::types::message::{Message, MessageHeader, StunParseError, TransactionId, BINDING}; fn address_is_ignorable(ip: IpAddr) -> bool { // TODO: add is_benchmarking() and is_documentation() when they become stable @@ -33,39 +31,30 @@ fn address_is_ignorable(ip: IpAddr) -> bool { #[derive(Debug)] enum RequestProtocol { - Udp(RequestUdp), + Udp, Tcp(RequestTcp), } impl RequestProtocol { fn transport(&self) -> TransportType { match self { - RequestProtocol::Udp(_) => TransportType::Udp, + RequestProtocol::Udp => TransportType::Udp, RequestProtocol::Tcp(_) => TransportType::Tcp, } } } -#[derive(Debug)] -struct RequestUdp { - // TODO: remove this Arc>, - agent: Arc>, - request: TransactionId, -} - #[derive(Debug)] struct RequestTcp { - // TODO: remove this Arc>, - agent: Option>>, request: Option, - active_addr: SocketAddr, - asked_for_agent: bool, - tcp_buffer: TcpBuffer, + tcp_buffer: Vec, } #[derive(Debug)] struct Request { protocol: RequestProtocol, + // TODO: remove this Arc>, + agent: Arc>, base_addr: SocketAddr, server: SocketAddr, other_preference: u32, @@ -73,20 +62,15 @@ struct Request { completed: bool, } -impl Request { - fn agent(&self) -> Option>> { - match &self.protocol { - RequestProtocol::Udp(udp) => Some(udp.agent.clone()), - RequestProtocol::Tcp(tcp) => tcp.agent.clone(), - } - } - - fn request(&self) -> Option { - match &self.protocol { - RequestProtocol::Udp(udp) => Some(udp.request), - RequestProtocol::Tcp(tcp) => tcp.request, - } - } +#[derive(Debug)] +struct PendingRequest { + completed: bool, + component_id: usize, + transport_type: TransportType, + local_addr: SocketAddr, + server_addr: SocketAddr, + other_preference: u32, + agent_request_time: Option, } /// Gatherer that uses STUN to gather a list of local candidates @@ -95,8 +79,10 @@ pub struct StunGatherer { component_id: usize, requests: Vec, pending_candidates: VecDeque, + produced_candidates: VecDeque, produced_i: usize, pending_transmits: VecDeque<(usize, Transmit<'static>)>, + pending_requests: VecDeque, } /// Return value for the gather state machine @@ -122,9 +108,8 @@ impl StunGatherer { stun_servers: Vec<(TransportType, SocketAddr)>, ) -> Self { // TODO: what to do on duplicate socket or stun_server addresses? - let mut requests = vec![]; let mut pending_candidates = VecDeque::new(); - let mut pending_transmits = VecDeque::new(); + let mut pending_requests = VecDeque::new(); for (i, (socket_transport, socket_addr)) in sockets.iter().enumerate() { if address_is_ignorable(socket_addr.ip()) { continue; @@ -205,61 +190,26 @@ impl StunGatherer { if socket_addr.is_ipv6() && !stun_addr.is_ipv6() { continue; } - let mut msg = Message::builder_request(BINDING); - msg.add_fingerprint().unwrap(); - let (protocol, base_addr) = match socket_transport { - TransportType::Udp => { - let transaction_id = msg.transaction_id(); - let mut agent = StunAgent::builder(*socket_transport, *socket_addr) - .remote_addr(*stun_addr) - .build(); - trace!("adding gather request {socket_transport} from {socket_addr} to {stun_addr}"); - pending_transmits.push_front(( - component_id, - agent.send(msg, *stun_addr).unwrap().into_owned(), - )); - ( - RequestProtocol::Udp(RequestUdp { - agent: Arc::new(Mutex::new(agent)), - request: transaction_id, - }), - *socket_addr, - ) - } - TransportType::Tcp => { - let active_addr = SocketAddr::new(socket_addr.ip(), 9); - trace!( - "adding gather request {active_addr} from {socket_addr} to {stun_addr}" - ); - ( - RequestProtocol::Tcp(RequestTcp { - agent: None, - active_addr, - request: None, - asked_for_agent: false, - tcp_buffer: Default::default(), - }), - active_addr, - ) - } - }; - requests.push(Request { - protocol, - base_addr, - server: *stun_addr, - other_preference, + pending_requests.push_front(PendingRequest { component_id, + transport_type: *socket_transport, + local_addr: *socket_addr, + server_addr: *stun_addr, + other_preference, completed: false, + agent_request_time: None, }); } } Self { component_id, - requests, + requests: vec![], pending_candidates, + produced_candidates: Default::default(), produced_i: 0, - pending_transmits, + pending_transmits: Default::default(), + pending_requests, } } @@ -272,44 +222,113 @@ impl StunGatherer { /// or [`GatherPoll::Complete`] is returned. #[tracing::instrument(name = "gatherer_poll", level = "trace", ret, err, skip(self))] pub fn poll(&mut self, now: Instant) -> Result { + let mut lowest_wait = None; + if let Some(cand) = self.pending_candidates.pop_back() { + info!("produced {cand:?}"); + self.produced_candidates.push_front(cand.clone()); return Ok(GatherPoll::NewCandidate(cand)); } + + for pending_request in self.pending_requests.iter_mut() { + if pending_request.completed { + continue; + } + + let (protocol, agent, base_addr) = match pending_request.transport_type { + TransportType::Udp => { + pending_request.completed = true; + let mut msg = Message::builder_request(BINDING); + msg.add_fingerprint().unwrap(); + let mut agent = + StunAgent::builder(TransportType::Udp, pending_request.local_addr) + .remote_addr(pending_request.server_addr) + .build(); + trace!( + "adding gather request UDP from {local_addr} to {server_addr}", + local_addr = pending_request.local_addr, + server_addr = pending_request.server_addr + ); + self.pending_transmits.push_front(( + pending_request.component_id, + agent + .send(msg, pending_request.server_addr, now) + .unwrap() + .into_owned(), + )); + ( + RequestProtocol::Udp, + Arc::new(Mutex::new(agent)), + pending_request.local_addr, + ) + } + TransportType::Tcp => { + if pending_request.agent_request_time.is_none() { + let active_addr = SocketAddr::new(pending_request.local_addr.ip(), 9); + trace!( + "adding gather request TCP {active_addr} from {local_addr} to {server_addr}", + local_addr = pending_request.local_addr, + server_addr = pending_request.server_addr, + ); + pending_request.local_addr = active_addr; + pending_request.agent_request_time = Some(now); + return Ok(GatherPoll::NeedAgent( + self.component_id, + TransportType::Tcp, + active_addr, + pending_request.server_addr, + )); + } + if lowest_wait.is_none() { + lowest_wait = Some(now + Duration::from_secs(600)); + } + continue; + } + }; + self.requests.push(Request { + protocol, + base_addr, + server: pending_request.server_addr, + other_preference: pending_request.other_preference, + component_id: pending_request.component_id, + completed: false, + agent, + }); + } + if let Some((component_id, transmit)) = self.pending_transmits.pop_back() { return Ok(GatherPoll::SendData(component_id, transmit)); } - let mut lowest_wait = None; + for request in self.requests.iter_mut() { if request.completed { continue; } - let _stun_request = if let Some(request) = request.request() { - request - } else { - match request.protocol { - RequestProtocol::Udp(ref _udp) => unreachable!(), - RequestProtocol::Tcp(ref mut tcp) => { - if !tcp.asked_for_agent { - tcp.asked_for_agent = true; - return Ok(GatherPoll::NeedAgent( - self.component_id, - TransportType::Tcp, - tcp.active_addr, - request.server, - )); - } else { - if lowest_wait.is_none() { - lowest_wait = Some(now + Duration::from_secs(600)); - } - continue; + match request.protocol { + RequestProtocol::Udp => (), + RequestProtocol::Tcp(ref mut tcp) => { + if tcp.request.is_none() { + let mut msg = Message::builder_request(BINDING); + msg.add_fingerprint().unwrap(); + tcp.request = Some(msg.transaction_id()); + return Ok(GatherPoll::SendData( + self.component_id, + request + .agent + .lock() + .unwrap() + .send(msg, request.server, now)? + .into_owned(), + )); + } else { + if lowest_wait.is_none() { + lowest_wait = Some(now + Duration::from_secs(600)); } + continue; } } }; - let Some(agent) = request.agent() else { - continue; - }; - let mut agent = agent.lock().unwrap(); + let mut agent = request.agent.lock().unwrap(); match agent.poll(now) { StunAgentPollRet::TransactionCancelled(_msg) => { request.completed = true; @@ -349,6 +368,7 @@ impl StunGatherer { foundation: String, base_addr: SocketAddr, server: SocketAddr, + tcp_type: Option, ) -> Option<(Candidate, Message<'_>)> { if let Ok(xor_addr) = response.attribute::() { let stun_addr = xor_addr.addr(response.transaction_id()); @@ -356,11 +376,16 @@ impl StunGatherer { let priority = Candidate::calculate_priority( crate::candidate::CandidateType::Host, transport, - None, + tcp_type, other_preference, component_id, ); - let builder = Candidate::builder( + let stun_addr = if tcp_type == Some(TcpType::Active) { + SocketAddr::new(stun_addr.ip(), 9) + } else { + stun_addr + }; + let mut builder = Candidate::builder( component_id, crate::candidate::CandidateType::ServerReflexive, transport, @@ -370,6 +395,9 @@ impl StunGatherer { .priority(priority) .base_address(base_addr) .related_address(server); + if let Some(tcp_type) = tcp_type { + builder = builder.tcp_type(tcp_type); + } let cand = builder.build(); return Some((cand, response)); } @@ -384,111 +412,136 @@ impl StunGatherer { level = "trace", ret, err, - skip(self, data) + skip(self, transmit) + fields( + transport = %transmit.transport, + from = %transmit.from, + to = %transmit.to, + ) )] - pub fn handle_data<'a>( - &'a mut self, - data: &'a [u8], - transport: TransportType, - from: SocketAddr, - to: SocketAddr, - ) -> Result { - trace!("received {} bytes", data.len()); + pub fn handle_data<'a>(&'a mut self, transmit: &Transmit<'a>) -> Result { + trace!("received {} bytes", transmit.data.len()); + trace!("requests {:?}", self.requests); for request in self.requests.iter_mut() { if !request.completed - && request.protocol.transport() == transport - && request.server == from - && request.base_addr == to + && request.protocol.transport() == transmit.transport + && request.server == transmit.from + && request.base_addr == transmit.to { - if let Some(agent) = request.agent() { - let mut agent_inner = agent.lock().unwrap(); - let mut handled = false; - match &mut request.protocol { - RequestProtocol::Tcp(ref mut tcp) => { - tcp.tcp_buffer.push_data(data); - while let Some(data) = tcp.tcp_buffer.pull_data() { - // TODO: need to handle heteregoneous TCP data - match Message::from_bytes(&data) { - Ok(msg) => { - if let HandleStunReply::StunResponse(response) = - agent_inner.handle_stun(msg, from) - { - if let Ok(_xor_addr) = - response.attribute::() - { - request.completed = true; + let mut agent_inner = request.agent.lock().unwrap(); + let mut handled = false; + match &mut request.protocol { + RequestProtocol::Tcp(ref mut tcp) => { + tcp.tcp_buffer.extend_from_slice(&transmit.data); + match MessageHeader::from_bytes(&tcp.tcp_buffer) { + // we fail for anything that is not a BINDING response + Ok(header) => { + if !header.get_type().is_response() + || header.get_type().method() != BINDING + { + request.completed = true; + return Ok(false); + } + } + Err(StunParseError::NotStun) => { + request.completed = true; + return Ok(false); + } + _ => (), + } + match Message::from_bytes(&tcp.tcp_buffer) { + Ok(msg) => { + trace!("parsed STUN message {msg}"); + if let HandleStunReply::StunResponse(response) = + agent_inner.handle_stun(msg, transmit.from) + { + request.completed = true; + for tcp_type in [TcpType::Active, TcpType::Passive] { + let foundation = self.produced_i.to_string(); + let base_addr = match tcp_type { + TcpType::Active => { + SocketAddr::new(request.base_addr.ip(), 9) } - let foundation = self.produced_i.to_string(); - if let Some((cand, _response)) = - Self::handle_stun_response( - response, - TransportType::Tcp, - request.other_preference, - request.component_id, - foundation, - request.base_addr, - request.server, - ) + TcpType::Passive => request.base_addr, + TcpType::So => unreachable!(), + }; + if let Some((cand, _response)) = Self::handle_stun_response( + response.clone(), + TransportType::Tcp, + request.other_preference, + request.component_id, + foundation, + base_addr, + request.server, + Some(tcp_type), + ) { + for c in self + .produced_candidates + .iter() + .chain(self.pending_candidates.iter()) { - self.produced_i += 1; - for c in self.pending_candidates.iter() { - if cand.redundant_with(c) { - trace!("redundant {cand:?}"); - continue; - } + if cand.redundant_with(c) { + trace!("redundant {cand:?}"); + return Ok(true); } - self.pending_candidates.push_front(cand.clone()); } - handled = true; + self.produced_i += 1; + self.pending_candidates.push_front(cand.clone()); } } - Err(_e) => request.completed = true, + handled = true; } } + // TODO: should signal closure of the TCP connection + Err(_e) => request.completed = true, } - RequestProtocol::Udp(_udp) => match Message::from_bytes(data) { - Ok(msg) => { - if let HandleStunReply::StunResponse(response) = - agent_inner.handle_stun(msg, from) - { - if let Ok(_xor_addr) = response.attribute::() + } + RequestProtocol::Udp => match Message::from_bytes(&transmit.data) { + Ok(msg) => { + trace!("parsed STUN message {msg}"); + if let HandleStunReply::StunResponse(response) = + agent_inner.handle_stun(msg, transmit.from) + { + request.completed = true; + let foundation = self.produced_i.to_string(); + if let Some((cand, _response)) = Self::handle_stun_response( + response, + TransportType::Udp, + request.other_preference, + request.component_id, + foundation, + request.base_addr, + request.server, + None, + ) { + for c in self + .produced_candidates + .iter() + .chain(self.pending_candidates.iter()) { - request.completed = true; - } - let foundation = self.produced_i.to_string(); - if let Some((cand, _response)) = Self::handle_stun_response( - response, - TransportType::Udp, - request.other_preference, - request.component_id, - foundation, - request.base_addr, - request.server, - ) { - self.produced_i += 1; - for c in self.pending_candidates.iter() { - if cand.redundant_with(c) { - trace!("redundant {cand:?}"); - continue; - } + if cand.redundant_with(c) { + trace!("redundant {cand:?}"); + return Ok(true); } - self.pending_candidates.push_front(cand.clone()); } - handled = true; + self.produced_i += 1; + self.pending_candidates.push_front(cand.clone()); } + handled = true; } - Err(_e) => (), - }, - } - return Ok(handled); + } + Err(_e) => (), + }, } + return Ok(handled); } } - Err(StunError::ResourceNotFound) + Ok(false) } /// Provide an agent as requested through [`GatherPoll::NeedAgent`]. The transport and address /// must match the value from the corresponding [`GatherPoll::NeedAgent`]. + #[tracing::instrument(name = "gatherer_add_agent", level = "debug", skip(self, agent))] pub fn add_agent( &mut self, transport: TransportType, @@ -496,34 +549,40 @@ impl StunGatherer { remote_addr: SocketAddr, agent: Result, ) { - for request in self.requests.iter_mut() { - if transport == request.protocol.transport() - && request.base_addr == local_addr - && request.server == remote_addr + trace!("{:?}", self.pending_requests); + for request in self.pending_requests.iter_mut() { + if !request.completed + && request.agent_request_time.is_some() + && transport == request.transport_type + && request.local_addr == local_addr + && request.server_addr == remote_addr { - if let RequestProtocol::Tcp(ref mut tcp) = request.protocol { - if tcp.agent.is_none() { - match agent { - Ok(mut agent) => { - let local_addr = agent.local_addr(); - let mut msg = Message::builder_request(BINDING); - msg.add_fingerprint().unwrap(); - let transaction_id = msg.transaction_id(); - self.pending_transmits.push_front(( - self.component_id, - agent.send(msg, request.server).unwrap().into_owned(), - )); - tcp.request = Some(transaction_id); - tcp.agent = Some(Arc::new(Mutex::new(agent))); - request.base_addr = local_addr; - } - Err(_e) => { - request.completed = true; - } - } - break; + info!( + "adding agent with local addr {:?}", + agent.as_ref().map(|agent| agent.local_addr()) + ); + request.completed = true; + match agent { + Ok(agent) => { + let local_addr = agent.local_addr(); + self.requests.push(Request { + protocol: RequestProtocol::Tcp(RequestTcp { + request: None, + tcp_buffer: vec![], + }), + agent: Arc::new(Mutex::new(agent)), + base_addr: local_addr, + server: request.server_addr, + other_preference: request.other_preference, + component_id: request.component_id, + completed: false, + }); + } + Err(_e) => { + request.completed = true; } } + break; } } } @@ -629,8 +688,14 @@ mod tests { .add_attribute(&XorMappedAddress::new(public_ip, response.transaction_id())) .unwrap(); assert!(matches!(gather.poll(now), Ok(GatherPoll::WaitUntil(_)))); + let response = response.build(); gather - .handle_data(&response.build(), TransportType::Udp, stun_addr, local_addr) + .handle_data(&Transmit::new( + &*response, + TransportType::Udp, + stun_addr, + local_addr, + )) .unwrap(); } else { error!("{ret:?}"); @@ -652,4 +717,93 @@ mod tests { assert!(matches!(gather.poll(now), Ok(GatherPoll::Complete))); assert!(matches!(gather.poll(now), Ok(GatherPoll::Complete))); } + + #[test] + fn stun_tcp() { + init(); + let local_addr = "192.168.1.1:1000".parse().unwrap(); + let stun_addr = "192.168.1.2:2000".parse().unwrap(); + let public_ip = "192.168.1.3:3000".parse().unwrap(); + let mut gather = StunGatherer::new( + 1, + vec![(TransportType::Tcp, local_addr)], + vec![(TransportType::Tcp, stun_addr)], + ); + let now = Instant::now(); + /* host candidate contents checked in `host_tcp()` */ + assert!(matches!( + gather.poll(now), + Ok(GatherPoll::NewCandidate(_cand)) + )); + assert!(matches!( + gather.poll(now), + Ok(GatherPoll::NewCandidate(_cand)) + )); + let ret = gather.poll(now); + if let Ok(GatherPoll::NeedAgent(_cid, transport, from, to)) = ret { + let agent = StunAgent::builder(transport, local_addr) + .remote_addr(to) + .build(); + gather.add_agent(transport, from, to, Ok(agent)); + } else { + error!("{ret:?}"); + unreachable!(); + } + + let ret = gather.poll(now); + if let Ok(GatherPoll::SendData(_cid, transmit)) = ret { + assert_eq!(transmit.from, local_addr); + assert_eq!(transmit.to, stun_addr); + let msg = Message::from_bytes(&transmit.data).unwrap(); + assert!(msg.has_method(BINDING)); + assert!(msg.has_class(MessageClass::Request)); + let mut response = Message::builder_success(&msg); + response + .add_attribute(&XorMappedAddress::new(public_ip, response.transaction_id())) + .unwrap(); + assert!(matches!(gather.poll(now), Ok(GatherPoll::WaitUntil(_)))); + let response = response.build(); + gather + .handle_data(&Transmit::new( + &*response, + TransportType::Tcp, + stun_addr, + local_addr, + )) + .unwrap(); + } else { + error!("{ret:?}"); + unreachable!(); + } + let ret = gather.poll(now); + if let Ok(GatherPoll::NewCandidate(cand)) = ret { + let local_addr = SocketAddr::new(local_addr.ip(), 9); + let public_ip = SocketAddr::new(public_ip.ip(), 9); + assert_eq!(cand.component_id, 1); + assert_eq!(cand.candidate_type, CandidateType::ServerReflexive); + assert_eq!(cand.transport_type, TransportType::Tcp); + assert_eq!(cand.address, public_ip); + assert_eq!(cand.base_address, local_addr); + assert_eq!(cand.tcp_type, Some(TcpType::Active)); + assert_eq!(cand.extensions, vec![]); + } else { + error!("{ret:?}"); + unreachable!(); + } + let ret = gather.poll(now); + if let Ok(GatherPoll::NewCandidate(cand)) = ret { + assert_eq!(cand.component_id, 1); + assert_eq!(cand.candidate_type, CandidateType::ServerReflexive); + assert_eq!(cand.transport_type, TransportType::Tcp); + assert_eq!(cand.address, public_ip); + assert_eq!(cand.base_address, local_addr); + assert_eq!(cand.tcp_type, Some(TcpType::Passive)); + assert_eq!(cand.extensions, vec![]); + } else { + error!("{ret:?}"); + unreachable!(); + } + assert!(matches!(gather.poll(now), Ok(GatherPoll::Complete))); + assert!(matches!(gather.poll(now), Ok(GatherPoll::Complete))); + } } diff --git a/librice-proto/src/stream.rs b/librice-proto/src/stream.rs index 21f5b93..44be658 100644 --- a/librice-proto/src/stream.rs +++ b/librice-proto/src/stream.rs @@ -586,12 +586,7 @@ impl StreamState { }; // XXX: is this enough to successfully route to the gatherer over the // connection check or component received handling? - let Ok(wake) = gather.handle_data( - &transmit.data, - transmit.transport, - transmit.from, - transmit.to, - ) else { + let Ok(wake) = gather.handle_data(transmit) else { return StreamIncomingDataReply::default(); }; if wake { diff --git a/librice/examples/icegather.rs b/librice/examples/icegather.rs index d521c9b..85725a6 100644 --- a/librice/examples/icegather.rs +++ b/librice/examples/icegather.rs @@ -24,12 +24,13 @@ fn main() -> io::Result<()> { task::block_on(async move { // non-existent //let stun_servers = ["192.168.1.200:3000".parse().unwrap()].to_vec(); - let stun_servers = [(TransportType::Udp, "127.0.0.1:3478".parse().unwrap())].to_vec(); + let stun_servers = ["192.168.20.28:3478".parse().unwrap()]; //let stun_servers = ["172.253.56.127:19302".parse().unwrap()].to_vec(); let agent = Agent::builder().build(); - for (tt, ss) in stun_servers { - agent.add_stun_server(tt, ss); + for ss in stun_servers { + agent.add_stun_server(TransportType::Udp, ss); + agent.add_stun_server(TransportType::Tcp, ss); } let stream = agent.add_stream(); let _comp = stream.add_component(); diff --git a/librice/src/component.rs b/librice/src/component.rs index 356704c..3c6d532 100644 --- a/librice/src/component.rs +++ b/librice/src/component.rs @@ -13,9 +13,12 @@ use std::sync::{Arc, Mutex, Weak}; use std::task::{Poll, Waker}; +use byteorder::{BigEndian, ByteOrder}; use librice_proto::candidate::CandidatePair; pub use librice_proto::component::ComponentConnectionState; +use stun_proto::agent::Transmit; +use stun_proto::types::TransportType; use crate::agent::AgentError; use crate::socket::StunChannel; @@ -101,7 +104,22 @@ impl Component { let transmit; { let local_agent = local_agent.lock().unwrap(); - transmit = local_agent.send_data(data, to); + let stun_transmit = local_agent.send_data(data, to); + transmit = match stun_transmit.transport { + TransportType::Udp => stun_transmit, + TransportType::Tcp => { + let mut data = Vec::with_capacity(stun_transmit.data.len()); + data.resize(2, 0); + BigEndian::write_u16(&mut data, stun_transmit.data.len() as u16); + data.extend_from_slice(&stun_transmit.data); + Transmit::new_owned( + data.into_boxed_slice(), + stun_transmit.transport, + stun_transmit.from, + stun_transmit.to, + ) + } + } } channel.send_to(&transmit.data, transmit.to).await?; diff --git a/librice/tests/common/mod.rs b/librice/tests/common/mod.rs index 53b39bc..60c72ca 100644 --- a/librice/tests/common/mod.rs +++ b/librice/tests/common/mod.rs @@ -14,6 +14,7 @@ use stun_proto::agent::StunAgent; use std::fmt::Display; use std::net::SocketAddr; use std::sync::Once; +use std::time::Instant; use tracing_subscriber::EnvFilter; @@ -121,33 +122,27 @@ pub async fn stund_tcp(listener: TcpListener) -> std::io::Result<()> { while let Some(Ok(mut stream)) = incoming.next().await { debug!("stund incoming tcp connection"); async_std::task::spawn(async move { - let mut tcp_buffer = stun_proto::agent::TcpBuffer::default(); let remote_addr = stream.peer_addr().unwrap(); let mut tcp_stun_agent = StunAgent::builder(stun_proto::types::TransportType::Tcp, local_addr) .remote_addr(remote_addr) .build(); - loop { - let mut data = vec![0; 1500]; - let size = warn_on_err(stream.read(&mut data).await, 0); - if size == 0 { - debug!("TCP connection with {remote_addr} closed"); - break; - } - debug!("stund tcp received {size} bytes"); - tcp_buffer.push_data(&data[..size]); - while let Some(data) = tcp_buffer.pull_data() { - if let Some((response, to)) = - handle_incoming_data(&data, remote_addr, &mut tcp_stun_agent) - { - if let Ok(data) = tcp_stun_agent.send(response, to) { - warn_on_err(stream.write_all(&data.data).await, ()); - } - } + let mut data = vec![0; 1500]; + let size = warn_on_err(stream.read(&mut data).await, 0); + if size == 0 { + debug!("TCP connection with {remote_addr} closed"); + return; + } + debug!("stund tcp received {size} bytes"); + if let Some((response, to)) = + handle_incoming_data(&data[..size], remote_addr, &mut tcp_stun_agent) + { + if let Ok(data) = tcp_stun_agent.send(response, to, Instant::now()) { + warn_on_err(stream.write_all(&data.data).await, ()); } - // XXX: Assumes that the stun packet arrives in a single packet - stream.shutdown(std::net::Shutdown::Read).unwrap(); } + // XXX: Assumes that the stun packet arrives in a single packet + stream.shutdown(std::net::Shutdown::Read).unwrap(); }); } Ok(()) diff --git a/librice/tests/stund.rs b/librice/tests/stund.rs index 6c53226..9578c73 100644 --- a/librice/tests/stund.rs +++ b/librice/tests/stund.rs @@ -9,7 +9,6 @@ use async_std::net::UdpSocket; use async_std::net::{TcpListener, TcpStream}; -use byteorder::{BigEndian, ByteOrder}; use futures::future::{AbortHandle, Abortable, Aborted}; use futures::{AsyncReadExt, AsyncWriteExt}; @@ -60,11 +59,7 @@ fn tcp_stund() { let mut socket = TcpStream::connect(stun_addr).await.unwrap(); let msg = Message::builder_request(BINDING); let msg_bytes = msg.build(); - let msg_bytes_len = msg_bytes.len() as u16; - let mut data = vec![0; 2]; - data.extend(msg_bytes); - BigEndian::write_u16(&mut data, msg_bytes_len); - socket.write_all(&data).await.unwrap(); + socket.write_all(&msg_bytes).await.unwrap(); debug!("sent to {:?}, {:?}", stun_addr, msg); let mut buf = [0; 1500]; @@ -76,10 +71,10 @@ fn tcp_stund() { "got {} bytes, buffer contains {} bytes", read_amount, read_position ); - if read_position < 2 { + if read_position < 20 { continue; } - match Message::from_bytes(&buf[2..read_position]) { + match Message::from_bytes(&buf[..read_position]) { Ok(msg) => { debug!("received response {}", msg); break;