Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
yinqiwen committed Dec 12, 2023
1 parent 828c25c commit 6fcde16
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 19 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ pki-types = { package = "rustls-pki-types", version = "1" }
s2n-quic = { version = "1", optional = true }
quinn = { version = "0.10", optional = true }


[target.'cfg(not(target_os = "windows"))'.dependencies]
aws-lc-rs = { version = "1.5", features = ["bindgen"], optional = true }

[target.'cfg(not(target_arch = "arm"))'.dependencies]
mimalloc = { version = "*" }


[features]
default = ["s2n_quic"]
s2n_quic = ["dep:s2n-quic", "dep:aws-lc-rs"]
Expand Down
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ Rust practice project

## Features

- QUIC/TLS Transport
- HTTP/Socks5 Proxy
- QUIC/TLS Transport
- HTTP/Socks5 Proxy


# Getting Started
Expand Down Expand Up @@ -40,5 +40,4 @@ $ ./rsnova --role client --cert ./cert.pem --listen 127.0.0.1:48100 --remote tl
```

**Use Proxy**
Now you can configure `socks5://127.0.0.1:48100` or `http://127.0.0.1:48100` as the proxy for your browser/tools.

Now you can configure `socks5://127.0.0.1:48100` or `http://127.0.0.1:48100` as the proxy for your browser/tools.
11 changes: 11 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ mod mux;
mod tunnel;
mod utils;

#[cfg(not(target_arch = "arm"))]
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;

#[derive(ValueEnum, Clone, Debug)]
enum Protocol {
Tls,
Expand Down Expand Up @@ -66,6 +70,9 @@ struct Args {
#[clap(default_value = "1048576", long)]
thread_stack_size: usize,

#[clap(default_value = "30", long)]
idle_timeout_secs: usize,

#[clap(default_value = "mydomain.io", long)]
tls_host: String,

Expand Down Expand Up @@ -117,6 +124,7 @@ async fn service_main(args: &Args) -> anyhow::Result<()> {
args.cert.as_ref().unwrap(),
&args.tls_host,
args.concurrent,
args.idle_timeout_secs,
)
.await?
}
Expand All @@ -126,6 +134,7 @@ async fn service_main(args: &Args) -> anyhow::Result<()> {
args.cert.as_ref().unwrap(),
&args.tls_host,
args.concurrent,
args.idle_timeout_secs,
)
.await?
}
Expand Down Expand Up @@ -160,6 +169,7 @@ async fn service_main(args: &Args) -> anyhow::Result<()> {
&args.listen,
args.cert.as_ref().unwrap(),
args.key.as_ref().unwrap(),
args.idle_timeout_secs,
)
.await
{
Expand All @@ -171,6 +181,7 @@ async fn service_main(args: &Args) -> anyhow::Result<()> {
&args.listen,
args.cert.as_ref().unwrap(),
args.key.as_ref().unwrap(),
args.idle_timeout_secs,
)
.await
{
Expand Down
3 changes: 2 additions & 1 deletion src/tunnel/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ impl<T: MuxConnection> MuxClientTrait for MuxClient<T> {
pub(crate) async fn mux_client_loop<T: MuxClientTrait>(
mut client: T,
mut receiver: mpsc::UnboundedReceiver<Message>,
idle_timeout_secs: usize,
) where
<T as MuxClientTrait>::SendStream: 'static,
<T as MuxClientTrait>::RecvStream: 'static,
Expand Down Expand Up @@ -130,7 +131,7 @@ pub(crate) async fn mux_client_loop<T: MuxClientTrait>(
&mut recv,
&mut send,
);
if let Err(e) = stream.transfer().await {
if let Err(e) = stream.transfer(idle_timeout_secs).await {
tracing::error!("transfer finish:{}", e);
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/tunnel/s2n_quic_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ use super::client::MuxConnection;
use super::Message;

pub struct S2NQuicConnection {
// inner: Option<quinn::Connection>,
// endpoint: Arc<quinn::Endpoint>,
inner: Option<s2n_quic::Connection>,
endpoint: Arc<s2n_quic::client::Client>,
}
Expand Down Expand Up @@ -77,6 +75,7 @@ impl MuxClient<S2NQuicConnection> {
cert_path: &Path,
host: &String,
count: usize,
idle_timeout_secs: usize,
) -> anyhow::Result<mpsc::UnboundedSender<Message>> {
match url.scheme() {
"quic" => {
Expand Down Expand Up @@ -107,7 +106,7 @@ impl MuxClient<S2NQuicConnection> {
}
client.conns.push(quic_conn);
}
tokio::spawn(mux_client_loop(client, receiver));
tokio::spawn(mux_client_loop(client, receiver, idle_timeout_secs));
Ok(sender)
}
_ => Err(anyhow!("unsupported schema:{:?}", url.scheme())),
Expand Down Expand Up @@ -145,6 +144,7 @@ pub async fn new_quic_client(
cert_path: &Path,
host: &String,
count: usize,
idle_timeout_secs: usize,
) -> anyhow::Result<mpsc::UnboundedSender<Message>> {
MuxClient::<S2NQuicConnection>::from(url, cert_path, host, count).await
MuxClient::<S2NQuicConnection>::from(url, cert_path, host, count, idle_timeout_secs).await
}
6 changes: 5 additions & 1 deletion src/tunnel/s2n_quic_remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub async fn start_quic_remote_server(
listen: &SocketAddr,
cert_path: &Path,
key_path: &Path,
idle_timeout_secs: usize,
) -> Result<()> {
let io = s2n_quic::provider::io::tokio::Builder::default()
.with_receive_address(*listen)?
Expand All @@ -30,7 +31,10 @@ pub async fn start_quic_remote_server(
metrics::increment_gauge!("quic_server_proxy_streams", 1.0);
let (mut recv_stream, mut send_stream) = stream.split();
tokio::spawn(async move {
if let Err(e) = handle_server_stream(&mut recv_stream, &mut send_stream).await {
if let Err(e) =
handle_server_stream(&mut recv_stream, &mut send_stream, idle_timeout_secs)
.await
{
tracing::error!("failed: {reason}", reason = e.to_string());
}
metrics::decrement_gauge!("quic_server_proxy_streams", 1.0);
Expand Down
9 changes: 5 additions & 4 deletions src/tunnel/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,18 @@ where
}
}

pub async fn transfer(&mut self) -> Result<()> {
pub async fn transfer(&mut self, idle_timeout_secs: usize) -> Result<()> {
let state = Arc::new(TransferState::new());
let client_to_server = timeout_copy(
&mut self.local_reader,
&mut self.remote_writer,
DEFAULT_TIMEOUT_SECS,
idle_timeout_secs as u64,
state.clone(),
);
let server_to_client = timeout_copy(
&mut self.remote_reader,
&mut self.local_writer,
DEFAULT_TIMEOUT_SECS,
idle_timeout_secs as u64,
state.clone(),
);
try_join(client_to_server, server_to_client).await?;
Expand All @@ -136,6 +136,7 @@ where
pub async fn handle_server_stream<'a, LR: AsyncReadExt + Unpin, LW: AsyncWriteExt + Unpin>(
mut lr: &'a mut LR,
lw: &'a mut LW,
idle_timeout_secs: usize,
) -> Result<()> {
let timeout_secs = Duration::from_secs(DEFAULT_TIMEOUT_SECS);
match timeout(timeout_secs, event::read_event(&mut lr)).await? {
Expand Down Expand Up @@ -163,7 +164,7 @@ pub async fn handle_server_stream<'a, LR: AsyncReadExt + Unpin, LW: AsyncWriteEx
tokio::net::tcp::ReadHalf<'_>,
tokio::net::tcp::WriteHalf<'_>,
> = Stream::new(lr, lw, &mut remote_receiver, &mut remote_sender);
stream.transfer().await
stream.transfer(idle_timeout_secs).await
}
}
}
6 changes: 4 additions & 2 deletions src/tunnel/tls_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ impl MuxClient<TlsConnection> {
cert_path: &Path,
host: &String,
count: usize,
idle_timeout_secs: usize,
) -> anyhow::Result<mpsc::UnboundedSender<Message>> {
match url.scheme() {
"tls" => {
Expand Down Expand Up @@ -98,7 +99,7 @@ impl MuxClient<TlsConnection> {
}
client.conns.push(tls_conn);
}
tokio::spawn(mux_client_loop(client, receiver));
tokio::spawn(mux_client_loop(client, receiver, idle_timeout_secs));
Ok(sender)
}
_ => Err(anyhow!("unsupported schema:{:?}", url.scheme())),
Expand Down Expand Up @@ -147,6 +148,7 @@ pub async fn new_tls_client(
cert_path: &Path,
host: &String,
count: usize,
idle_timeout_secs: usize,
) -> anyhow::Result<mpsc::UnboundedSender<Message>> {
MuxClient::<TlsConnection>::from(url, cert_path, host, count).await
MuxClient::<TlsConnection>::from(url, cert_path, host, count, idle_timeout_secs).await
}
9 changes: 7 additions & 2 deletions src/tunnel/tls_remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub async fn start_tls_remote_server(
listen: &SocketAddr,
cert_path: &Path,
key_path: &Path,
idle_timeout_secs: usize,
) -> Result<()> {
// let key = fs::read(key_path.clone()).context("failed to read private key")?;
// let key = rsa_private_keys(&mut BufReader::new(File::open(key_path)?))
Expand Down Expand Up @@ -86,7 +87,7 @@ pub async fn start_tls_remote_server(
let fut = async move {
let stream = acceptor.accept(stream).await?;
tracing::info!("TLS connection incoming");
handle_tls_connection(stream, conn_id).await?;
handle_tls_connection(stream, conn_id, idle_timeout_secs).await?;
Ok(()) as Result<()>
};

Expand All @@ -102,6 +103,7 @@ pub async fn start_tls_remote_server(
async fn handle_tls_connection(
conn: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
id: u32,
idle_timeout_secs: usize,
) -> Result<()> {
let (r, w) = tokio::io::split(conn);
let mux_conn = mux::Connection::new(r, w, mux::Mode::Server, id);
Expand All @@ -112,7 +114,10 @@ async fn handle_tls_connection(
tokio::spawn(async move {
let stream_id = stream.id();
let (mut stream_reader, mut stream_writer) = tokio::io::split(stream);
if let Err(e) = handle_server_stream(&mut stream_reader, &mut stream_writer).await {
if let Err(e) =
handle_server_stream(&mut stream_reader, &mut stream_writer, idle_timeout_secs)
.await
{
tracing::error!(
"[{}/{}]failed: {reason}",
id,
Expand Down

0 comments on commit 6fcde16

Please sign in to comment.