Skip to content

Commit

Permalink
Merge pull request #341 from kodegenix/master
Browse files Browse the repository at this point in the history
Support for manual acks
  • Loading branch information
de-sh authored Feb 9, 2022
2 parents 7e793da + 35e341b commit 6406090
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 12 deletions.
77 changes: 77 additions & 0 deletions rumqttc/examples/async_manual_acks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use tokio::{task, time};

use rumqttc::{self, AsyncClient, Event, EventLoop, Incoming, MqttOptions, QoS};
use std::error::Error;
use std::time::Duration;

fn create_conn() -> (AsyncClient, EventLoop) {
let mut mqttoptions = MqttOptions::new("test-1", "localhost", 1883);
mqttoptions
.set_keep_alive(Duration::from_secs(5))
.set_manual_acks(true)
.set_clean_session(false);

AsyncClient::new(mqttoptions, 10)
}


#[tokio::main(worker_threads = 1)]
async fn main() -> Result<(), Box<dyn Error>> {
pretty_env_logger::init();

// create mqtt connection with clean_session = false and manual_acks = true
let (client, mut eventloop) = create_conn();

// subscribe example topic
client
.subscribe("hello/world", QoS::AtLeastOnce)
.await
.unwrap();

task::spawn(async move {
// send some messages to example topic and disconnect
requests(client.clone()).await;
client.disconnect().await.unwrap()
});

loop {
// get subscribed messages without acking
let event = eventloop.poll().await;
println!("{:?}", event);
if let Err(_err) = event {
// break loop on disconnection
break;
}
}

// create new broker connection
let (client, mut eventloop) = create_conn();

loop {
// previously published messages should be republished after reconnection.
let event = eventloop.poll().await;
println!("{:?}", event);
match event {
Ok(Event::Incoming(Incoming::Publish(publish))) => {
// this time we will ack incoming publishes.
// Its important not to block eventloop as this can cause deadlock.
let c = client.clone();
tokio::spawn(async move {
c.ack(&publish).await.unwrap();
});
}
_ => {}
}
}
}

async fn requests(client: AsyncClient) {
for i in 1..=10 {
client
.publish("hello/world", QoS::AtLeastOnce, false, vec![1; i])
.await
.unwrap();

time::sleep(Duration::from_secs(1)).await;
}
}
57 changes: 57 additions & 0 deletions rumqttc/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,33 @@ impl AsyncClient {
Ok(())
}

/// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set.
pub async fn ack(
&self,
publish: &Publish
) -> Result<(), ClientError>
{
let ack = get_ack_req(publish);

if let Some(ack) = ack {
self.request_tx.send(ack).await?;
}
Ok(())
}

/// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set.
pub fn try_ack(
&self,
publish: &Publish
) -> Result<(), ClientError>
{
let ack = get_ack_req(publish);
if let Some(ack) = ack {
self.request_tx.try_send(ack)?;
}
Ok(())
}

/// Sends a MQTT Publish to the eventloop
pub async fn publish_bytes<S>(
&self,
Expand Down Expand Up @@ -186,6 +213,15 @@ impl AsyncClient {
}
}

fn get_ack_req(publish: &Publish) -> Option<Request> {
let ack = match publish.qos {
QoS::AtMostOnce => return None,
QoS::AtLeastOnce => Request::PubAck(PubAck::new(publish.pkid)),
QoS::ExactlyOnce => Request::PubRec(PubRec::new(publish.pkid))
};
Some(ack)
}

/// `Client` to communicate with MQTT eventloop `Connection`.
///
/// Client is cloneable and can be used to synchronously Publish, Subscribe.
Expand Down Expand Up @@ -240,6 +276,27 @@ impl Client {
Ok(())
}

/// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set.
pub fn ack(
&self,
publish: &Publish
) -> Result<(), ClientError>
{
pollster::block_on(self.client.ack(publish))?;
Ok(())
}

/// Sends a MQTT PubAck to the eventloop. Only needed in if `manual_acks` flag is set.
pub fn try_ack(
&self,
publish: &Publish
) -> Result<(), ClientError>
{
self.client.try_ack(publish)?;
Ok(())
}


/// Sends a MQTT Subscribe to the eventloop
pub fn subscribe<S: Into<String>>(&mut self, topic: S, qos: QoS) -> Result<(), ClientError> {
pollster::block_on(self.client.subscribe(topic, qos))?;
Expand Down
3 changes: 2 additions & 1 deletion rumqttc/src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ impl EventLoop {
let pending = Vec::new();
let pending = pending.into_iter();
let max_inflight = options.inflight;
let manual_acks = options.manual_acks;

EventLoop {
options,
state: MqttState::new(max_inflight),
state: MqttState::new(max_inflight, manual_acks),
requests_tx,
requests_rx,
pending,
Expand Down
17 changes: 17 additions & 0 deletions rumqttc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,9 @@ pub struct MqttOptions {
last_will: Option<LastWill>,
/// Connection timeout
conn_timeout: u64,
/// If set to `true` MQTT acknowledgements are not sent automatically.
/// Every incoming publish packet must be manually acknowledged with `client.ack(...)` method.
manual_acks: bool,
}

impl MqttOptions {
Expand All @@ -358,6 +361,7 @@ impl MqttOptions {
inflight: 100,
last_will: None,
conn_timeout: 5,
manual_acks: false,
}
}

Expand Down Expand Up @@ -491,6 +495,17 @@ impl MqttOptions {
pub fn connection_timeout(&self) -> u64 {
self.conn_timeout
}

/// set manual acknowledgements
pub fn set_manual_acks(&mut self, manual_acks: bool) -> &mut Self {
self.manual_acks = manual_acks;
self
}

/// get manual acknowledgements
pub fn manual_acks(&self) -> bool {
self.manual_acks
}
}

#[cfg(feature = "url")]
Expand Down Expand Up @@ -657,6 +672,7 @@ impl std::convert::TryFrom<url::Url> for MqttOptions {
inflight,
last_will: None,
conn_timeout,
manual_acks: false
})
}
}
Expand All @@ -679,6 +695,7 @@ impl Debug for MqttOptions {
.field("inflight", &self.inflight)
.field("last_will", &self.last_will)
.field("conn_timeout", &self.conn_timeout)
.field("manual_acks", &self.manual_acks)
.finish()
}
}
Expand Down
87 changes: 76 additions & 11 deletions rumqttc/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,15 @@ pub struct MqttState {
pub events: VecDeque<Event>,
/// Write buffer
pub write: BytesMut,
/// Indicates if acknowledgements should be send immediately
pub manual_acks: bool,
}

impl MqttState {
/// Creates new mqtt state. Same state should be used during a
/// connection for persistent sessions while new state should
/// instantiated for clean sessions
pub fn new(max_inflight: u16) -> Self {
pub fn new(max_inflight: u16, manual_acks: bool) -> Self {
MqttState {
await_pingresp: false,
collision_ping_count: 0,
Expand All @@ -98,6 +100,7 @@ impl MqttState {
// TODO: Optimize these sizes later
events: VecDeque::with_capacity(100),
write: BytesMut::with_capacity(10 * 1024),
manual_acks
}
}

Expand Down Expand Up @@ -145,6 +148,8 @@ impl MqttState {
Request::Unsubscribe(unsubscribe) => self.outgoing_unsubscribe(unsubscribe)?,
Request::PingReq => self.outgoing_ping()?,
Request::Disconnect => self.outgoing_disconnect()?,
Request::PubAck(puback) => self.outgoing_puback(puback)?,
Request::PubRec(pubrec) => self.outgoing_pubrec(pubrec)?,
_ => unimplemented!(),
};

Expand Down Expand Up @@ -194,19 +199,19 @@ impl MqttState {
match qos {
QoS::AtMostOnce => Ok(()),
QoS::AtLeastOnce => {
let pkid = publish.pkid;
PubAck::new(pkid).write(&mut self.write)?;
let event = Event::Outgoing(Outgoing::PubAck(pkid));
self.events.push_back(event);

if !self.manual_acks {
let puback = PubAck::new(publish.pkid);
self.outgoing_puback(puback)?
}
Ok(())
}
QoS::ExactlyOnce => {
let pkid = publish.pkid;
PubRec::new(pkid).write(&mut self.write)?;
self.incoming_pub[pkid as usize] = Some(pkid);
let event = Event::Outgoing(Outgoing::PubRec(pkid));
self.events.push_back(event);
if !self.manual_acks {
let pubrec = PubRec::new(pkid);
self.outgoing_pubrec(pubrec)?;
}
Ok(())
}
}
Expand Down Expand Up @@ -347,6 +352,20 @@ impl MqttState {
Ok(())
}

fn outgoing_puback(&mut self, puback: PubAck) -> Result<(), StateError> {
puback.write(&mut self.write)?;
let event = Event::Outgoing(Outgoing::PubAck(puback.pkid));
self.events.push_back(event);
Ok(())
}

fn outgoing_pubrec(&mut self, pubrec: PubRec) -> Result<(), StateError> {
pubrec.write(&mut self.write)?;
let event = Event::Outgoing(Outgoing::PubRec(pubrec.pkid));
self.events.push_back(event);
Ok(())
}

/// check when the last control packet/pingreq packet is received and return
/// the status which tells if keep alive time has exceeded
/// NOTE: status will be checked for zero keepalive times also
Expand Down Expand Up @@ -468,7 +487,7 @@ impl MqttState {
#[cfg(test)]
mod test {
use super::{MqttState, StateError};
use crate::{Incoming, MqttOptions, Request};
use crate::{Event, Incoming, MqttOptions, Outgoing, Request};
use mqttbytes::v4::*;
use mqttbytes::*;

Expand All @@ -492,7 +511,7 @@ mod test {
}

fn build_mqttstate() -> MqttState {
MqttState::new(100)
MqttState::new(100, false)
}

#[test]
Expand Down Expand Up @@ -570,6 +589,52 @@ mod test {
assert_eq!(pkid, 3);
}

#[test]
fn incoming_publish_should_be_acked() {
let mut mqtt = build_mqttstate();

// QoS0, 1, 2 Publishes
let publish1 = build_incoming_publish(QoS::AtMostOnce, 1);
let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2);
let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3);

mqtt.handle_incoming_publish(&publish1).unwrap();
mqtt.handle_incoming_publish(&publish2).unwrap();
mqtt.handle_incoming_publish(&publish3).unwrap();

if let Event::Outgoing(Outgoing::PubAck(pkid)) = mqtt.events[0] {
assert_eq!(pkid, 2);
} else {
panic!("missing puback")
}

if let Event::Outgoing(Outgoing::PubRec(pkid)) = mqtt.events[1] {
assert_eq!(pkid, 3);
} else {
panic!("missing PubRec")
}
}

#[test]
fn incoming_publish_should_not_be_acked_with_manual_acks() {
let mut mqtt = build_mqttstate();
mqtt.manual_acks = true;

// QoS0, 1, 2 Publishes
let publish1 = build_incoming_publish(QoS::AtMostOnce, 1);
let publish2 = build_incoming_publish(QoS::AtLeastOnce, 2);
let publish3 = build_incoming_publish(QoS::ExactlyOnce, 3);

mqtt.handle_incoming_publish(&publish1).unwrap();
mqtt.handle_incoming_publish(&publish2).unwrap();
mqtt.handle_incoming_publish(&publish3).unwrap();

let pkid = mqtt.incoming_pub[3].unwrap();
assert_eq!(pkid, 3);

assert!(mqtt.events.is_empty());
}

#[test]
fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() {
let mut mqtt = build_mqttstate();
Expand Down

0 comments on commit 6406090

Please sign in to comment.