Skip to content

Commit

Permalink
feat: Add support for HTTP(s) proxy (#611)
Browse files Browse the repository at this point in the history
* feat: Add support for HTTP(s) proxy

Signed off: Ivan Vaitusionak

Issue: #608

Attribute: Ivan Vaitusionak
  • Loading branch information
sushi-shi authored May 8, 2023
1 parent fe8a8e3 commit b863cec
Show file tree
Hide file tree
Showing 12 changed files with 410 additions and 63 deletions.
13 changes: 13 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions rumqttc/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new API's on v5 client for properties, eg `publish_with_props` etc
- Refactored `MqttOptions` to use `ConnectProperties` for some fields
- Other minor changes for MQTT5
- Added support for HTTP(s) proxy (#608)
- Added `proxy` feature gate
- Refactored `eventloop::network_connect` to allow setting proxy
- Added proxy options to `MqttOptions`

### Changed
- Remove `Box` on `Event::Incoming`
Expand Down
3 changes: 3 additions & 0 deletions rumqttc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ default = ["use-rustls"]
use-rustls = ["dep:tokio-rustls", "dep:rustls-pemfile", "dep:rustls-native-certs"]
use-native-tls = ["dep:tokio-native-tls", "dep:native-tls"]
websocket = ["dep:async-tungstenite", "dep:ws_stream_tungstenite", "dep:http"]
proxy = ["dep:async-http-proxy"]

[dependencies]
futures = "0.3"
Expand All @@ -41,6 +42,8 @@ tokio-native-tls = { version = "0.3.0", optional = true }
native-tls = { version = "0.2.8", optional = true }
# url
url = { version = "2", default-features = false, optional = true }
# proxy
async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true }

[dev-dependencies]
color-backtrace = "0.4"
Expand Down
2 changes: 1 addition & 1 deletion rumqttc/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ eventloop.run() //return -> Result<MqttSt>
```


Keep additional functionality like gcloud jwt auth and http connect proxy out of rumqtt
Keep additional functionality like gcloud jwt auth out of rumqtt
-------

Prevents (some) conflicts w.r.t different versions of ring. conflicts because of rustls are still possible but atleast
Expand Down
73 changes: 73 additions & 0 deletions rumqttc/examples/websocket_proxy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#[cfg(all(feature = "websocket", feature = "proxy"))]
use rumqttc::{self, AsyncClient, Proxy, ProxyAuth, ProxyType, QoS, Transport};
#[cfg(all(feature = "websocket", feature = "proxy"))]
use std::{error::Error, time::Duration};
#[cfg(all(feature = "websocket", feature = "proxy"))]
use tokio::{task, time};

#[cfg(all(feature = "websocket", feature = "proxy"))]
#[tokio::main(worker_threads = 1)]
async fn main() -> Result<(), Box<dyn Error>> {
use rumqttc::MqttOptions;

pretty_env_logger::init();

// port parameter is ignored when scheme is websocket
let mut mqttoptions = MqttOptions::new(
"clientId-aSziq39Bp3",
"ws://broker.mqttdashboard.com:8000/mqtt",
8000,
);
mqttoptions.set_transport(Transport::Ws);
mqttoptions.set_keep_alive(Duration::from_secs(60));
// Presumes that there is a proxy server already set up listening on 127.0.0.1:8100
mqttoptions.set_proxy(Proxy {
ty: ProxyType::Http,
auth: ProxyAuth::None,
addr: "127.0.0.1".into(),
port: 8100,
});

let (client, mut eventloop) = AsyncClient::new(mqttoptions, 10);
task::spawn(async move {
requests(client).await;
time::sleep(Duration::from_secs(3)).await;
});

loop {
let event = eventloop.poll().await;
match event {
Ok(notif) => {
println!("Event = {notif:?}");
}
Err(err) => {
println!("Error = {err:?}");
return Ok(());
}
}
}
}

#[cfg(all(feature = "websocket", feature = "proxy"))]
async fn requests(client: AsyncClient) {
client
.subscribe("hello/world", QoS::AtMostOnce)
.await
.unwrap();

for i in 1..=10 {
client
.publish("hello/world", QoS::ExactlyOnce, false, vec![1; i])
.await
.unwrap();

time::sleep(Duration::from_secs(1)).await;
}

time::sleep(Duration::from_secs(120)).await;
}

#[cfg(not(all(feature = "websocket", feature = "proxy")))]
fn main() {
panic!("Enable websocket and proxy feature with `--features=websocket, proxy`");
}
94 changes: 67 additions & 27 deletions rumqttc/src/eventloop.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
use crate::tls;
use crate::{framed::Network, Transport};
use crate::{Incoming, MqttState, NetworkOptions, Packet, Request, StateError};
use crate::{MqttOptions, Outgoing};

use crate::framed::N;
use crate::mqttbytes::v4::*;
#[cfg(feature = "websocket")]
use async_tungstenite::tokio::connect_async;
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
use async_tungstenite::tokio::connect_async_with_tls_connector;
use flume::{bounded, Receiver, Sender};
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::net::{lookup_host, TcpSocket, TcpStream};
use tokio::select;
use tokio::time::{self, Instant, Sleep};
#[cfg(feature = "websocket")]
use ws_stream_tungstenite::WsStream;

use std::io;
use std::net::SocketAddr;
#[cfg(unix)]
use std::path::Path;
use std::pin::Pin;
use std::time::Duration;
use std::vec::IntoIter;

#[cfg(unix)]
use {std::path::Path, tokio::net::UnixStream};

#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
use crate::tls;

#[cfg(feature = "websocket")]
use {
crate::websockets::{split_url, UrlError},
ws_stream_tungstenite::WsStream,
};

#[cfg(feature = "proxy")]
use crate::proxy::ProxyError;

/// Critical errors during eventloop polling
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
Expand All @@ -52,6 +56,12 @@ pub enum ConnectionError {
NotConnAck(Packet),
#[error("Requests done")]
RequestsDone,
#[cfg(feature = "websocket")]
#[error("Invalid Url: {0}")]
InvalidUrl(#[from] UrlError),
#[cfg(feature = "proxy")]
#[error("Proxy Connect: {0}")]
Proxy(#[from] ProxyError),
}

/// Eventloop with all the state of a connection
Expand Down Expand Up @@ -320,28 +330,53 @@ async fn network_connect(
options: &MqttOptions,
network_options: NetworkOptions,
) -> Result<Network, ConnectionError> {
let network = match options.transport() {
Transport::Tcp => {
let addr = format!("{}:{}", options.broker_addr, options.port);
let tcp_stream = socket_connect(addr, network_options).await?;
Network::new(tcp_stream, options.max_incoming_packet_size)
// Process Unix files early, as proxy is not supported for them.
#[cfg(unix)]
if matches!(options.transport(), Transport::Unix) {
let file = options.broker_addr.as_str();
let socket = UnixStream::connect(Path::new(file)).await?;
let network = Network::new(socket, options.max_incoming_packet_size);
return Ok(network);
}

// For websockets domain and port are taken directly from `broker_addr` (which is a url).
let (domain, port) = match options.transport() {
#[cfg(feature = "websocket")]
Transport::Ws => split_url(&options.broker_addr)?,
#[cfg(all(feature = "use-rustls", feature = "websocket"))]
Transport::Wss(_) => split_url(&options.broker_addr)?,
_ => options.broker_address(),
};

let tcp_stream: Box<dyn N> = {
#[cfg(feature = "proxy")]
match options.proxy() {
Some(proxy) => proxy.connect(&domain, port, network_options).await?,
None => {
let addr = format!("{domain}:{port}");
let tcp = socket_connect(addr, network_options).await?;
Box::new(tcp)
}
}
#[cfg(not(feature = "proxy"))]
{
let addr = format!("{domain}:{port}");
let tcp = socket_connect(addr, network_options).await?;
Box::new(tcp)
}
};

let network = match options.transport() {
Transport::Tcp => Network::new(tcp_stream, options.max_incoming_packet_size),
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
Transport::Tls(tls_config) => {
let addr = format!("{}:{}", options.broker_addr, options.port);
let tcp_stream = socket_connect(addr, network_options).await?;

let socket =
tls::tls_connect(&options.broker_addr, options.port, &tls_config, tcp_stream)
.await?;
Network::new(socket, options.max_incoming_packet_size)
}
#[cfg(unix)]
Transport::Unix => {
let file = options.broker_addr.as_str();
let socket = UnixStream::connect(Path::new(file)).await?;
Network::new(socket, options.max_incoming_packet_size)
}
Transport::Unix => unreachable!(),
#[cfg(feature = "websocket")]
Transport::Ws => {
let request = http::Request::builder()
Expand All @@ -350,7 +385,7 @@ async fn network_connect(
.header("Sec-WebSocket-Protocol", "mqttv3.1")
.body(())?;

let (socket, _) = connect_async(request).await?;
let (socket, _) = async_tungstenite::tokio::client_async(request, tcp_stream).await?;

Network::new(WsStream::new(socket), options.max_incoming_packet_size)
}
Expand All @@ -364,7 +399,12 @@ async fn network_connect(

let connector = tls::rustls_connector(&tls_config).await?;

let (socket, _) = connect_async_with_tls_connector(request, Some(connector)).await?;
let (socket, _) = async_tungstenite::tokio::client_async_tls_with_connector(
request,
tcp_stream,
Some(connector),
)
.await?;

Network::new(WsStream::new(socket), options.max_incoming_packet_size)
}
Expand Down
30 changes: 28 additions & 2 deletions rumqttc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,16 @@ mod eventloop;
mod framed;
pub mod mqttbytes;
mod state;
pub mod v5;

#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
mod tls;
pub mod v5;

#[cfg(feature = "websocket")]
mod websockets;

#[cfg(feature = "proxy")]
mod proxy;

pub use client::{
AsyncClient, Client, ClientError, Connection, Iter, RecvError, RecvTimeoutError, TryRecvError,
Expand All @@ -128,6 +135,9 @@ pub use tokio_rustls;
#[cfg(feature = "use-rustls")]
use tokio_rustls::rustls::{Certificate, ClientConfig, RootCertStore};

#[cfg(feature = "proxy")]
pub use proxy::{Proxy, ProxyAuth, ProxyType};

pub type Incoming = Packet;

/// Current outgoing activity on the eventloop
Expand Down Expand Up @@ -297,7 +307,7 @@ impl Transport {
}

/// TLS configuration method
#[derive(Clone)]
#[derive(Clone, Debug)]
#[cfg(any(feature = "use-rustls", feature = "use-native-tls"))]
pub enum TlsConfiguration {
#[cfg(feature = "use-rustls")]
Expand Down Expand Up @@ -426,6 +436,9 @@ pub struct MqttOptions {
/// If set to `true` MQTT acknowledgements are not sent automatically.
/// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method.
manual_acks: bool,
#[cfg(feature = "proxy")]
/// Proxy configuration.
proxy: Option<Proxy>,
}

impl MqttOptions {
Expand Down Expand Up @@ -466,6 +479,8 @@ impl MqttOptions {
inflight: 100,
last_will: None,
manual_acks: false,
#[cfg(feature = "proxy")]
proxy: None,
}
}

Expand Down Expand Up @@ -635,6 +650,17 @@ impl MqttOptions {
pub fn manual_acks(&self) -> bool {
self.manual_acks
}

#[cfg(feature = "proxy")]
pub fn set_proxy(&mut self, proxy: Proxy) -> &mut Self {
self.proxy = Some(proxy);
self
}

#[cfg(feature = "proxy")]
pub fn proxy(&self) -> Option<Proxy> {
self.proxy.clone()
}
}

#[cfg(feature = "url")]
Expand Down
Loading

0 comments on commit b863cec

Please sign in to comment.