diff --git a/src/forward_traffic.rs b/src/forward_traffic.rs index 6251895..f7231e3 100644 --- a/src/forward_traffic.rs +++ b/src/forward_traffic.rs @@ -100,37 +100,38 @@ async fn maybe_timeout( /// Forward all complete datagrams in `buffer` to `udp_out`. /// Returns the number of processed bytes. -async fn forward_datagrams_in_buffer(udp_out: &UdpSocket, buffer: &[u8]) -> io::Result { - let mut header_start = 0; +async fn forward_datagrams_in_buffer(udp_out: &UdpSocket, mut buffer: &[u8]) -> io::Result { + let original_buffer_len = buffer.len(); loop { - let header_end = header_start + HEADER_LEN; - // "parse" the header - let header = match buffer.get(header_start..header_end) { - Some(header) => <[u8; HEADER_LEN]>::try_from(header).unwrap(), - // Buffer does not contain entire header for next datagram - None => break Ok(header_start), - }; - let datagram_len = usize::from(u16::from_be_bytes(header)); - let datagram_start = header_end; - let datagram_end = datagram_start + datagram_len; - - let datagram_data = match buffer.get(datagram_start..datagram_end) { - Some(datagram_data) => datagram_data, + let (datagram_data, tail) = match split_first_datagram(buffer) { + Some(data_tuple) => data_tuple, // The buffer does not contain the entire datagram - None => break Ok(header_start), + None => break Ok(original_buffer_len - buffer.len()), }; let udp_write_len = udp_out.send(datagram_data).await?; assert_eq!( - udp_write_len, datagram_len, + udp_write_len, + datagram_data.len(), "Did not send entire UDP datagram" ); - log::trace!("Forwarded {} byte TCP->UDP", datagram_len); + log::trace!("Forwarded {} byte TCP->UDP", datagram_data.len()); - header_start = datagram_end; + buffer = tail; } } +/// Parses the header at the beginning of the `buffer` and if it contains a full +/// `udp-to-tcp` datagram it splits the buffer and returns the datagram data and +/// buffer tail as two separate slices: `(datagram_data, tail)` +fn split_first_datagram(buffer: &[u8]) -> Option<(&[u8], &[u8])> { + let (header, tail) = buffer.split_first_chunk::()?; + let datagram_len = usize::from(u16::from_be_bytes(*header)); + let datagram_data = tail.get(..datagram_len)?; + let tail = tail.get(datagram_len..)?; + Some((datagram_data, tail)) +} + /// Reads datagrams from `udp_in` and writes them (with the 16 bit header containing the length) /// to `tcp_out` indefinitely, or until an IO error happens on either socket. async fn process_udp2tcp( diff --git a/src/lib.rs b/src/lib.rs index ad39936..7f55048 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -80,6 +80,7 @@ #![forbid(unsafe_code)] #![deny(clippy::all)] +#![feature(slice_first_last_chunk)] pub mod tcp2udp; pub mod udp2tcp;