Skip to content

Commit

Permalink
Merge pull request #766 from redboltz/add_protocol_error_check
Browse files Browse the repository at this point in the history
Added protocol error checking code.
  • Loading branch information
redboltz authored Dec 16, 2020
2 parents 5333681 + 223c833 commit 364be87
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 4 deletions.
65 changes: 64 additions & 1 deletion include/mqtt/control_packet_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <ostream>

#include <mqtt/namespace.hpp>
#include <mqtt/optional.hpp>

namespace MQTT_NS {

Expand All @@ -33,7 +34,7 @@ enum class control_packet_type : std::uint8_t {
disconnect = 0b11100000, // 14
auth = 0b11110000, // 15

}; // namespace control_packet_type
};

constexpr control_packet_type get_control_packet_type(std::uint8_t v) {
return static_cast<control_packet_type>(v & 0b11110000);
Expand Down Expand Up @@ -68,6 +69,68 @@ std::ostream& operator<<(std::ostream& os, control_packet_type val)
return os;
}

enum class control_packet_reserved_bits : std::uint8_t {
connect = 0b00000000,
connack = 0b00000000,
// publish = dup qos retain,
puback = 0b00000000,
pubrec = 0b00000000,
pubrel = 0b00000010,
pubcomp = 0b00000000,
subscribe = 0b00000010,
suback = 0b00000000,
unsubscribe = 0b00000010,
unsuback = 0b00000000,
pingreq = 0b00000000,
pingresp = 0b00000000,
disconnect = 0b00000000,
auth = 0b00000000,
};

inline optional<control_packet_type> get_control_packet_type_with_check(std::uint8_t v) {
auto cpt = static_cast<control_packet_type>(v & 0b11110000);
auto valid =
[&] {
auto rsv = static_cast<control_packet_reserved_bits>(v & 0b00001111);
switch (cpt) {
case control_packet_type::connect:
return rsv == control_packet_reserved_bits::connect;
case control_packet_type::connack:
return rsv == control_packet_reserved_bits::connack;
case control_packet_type::publish:
return true;
case control_packet_type::puback:
return rsv == control_packet_reserved_bits::puback;
case control_packet_type::pubrec:
return rsv == control_packet_reserved_bits::pubrec;
case control_packet_type::pubrel:
return rsv == control_packet_reserved_bits::pubrel;
case control_packet_type::pubcomp:
return rsv == control_packet_reserved_bits::pubcomp;
case control_packet_type::subscribe:
return rsv == control_packet_reserved_bits::subscribe;
case control_packet_type::suback:
return rsv == control_packet_reserved_bits::suback;
case control_packet_type::unsubscribe:
return rsv == control_packet_reserved_bits::unsubscribe;
case control_packet_type::unsuback:
return rsv == control_packet_reserved_bits::unsuback;
case control_packet_type::pingreq:
return rsv == control_packet_reserved_bits::pingreq;
case control_packet_type::pingresp:
return rsv == control_packet_reserved_bits::pingresp;
case control_packet_type::disconnect:
return rsv == control_packet_reserved_bits::disconnect;
case control_packet_type::auth:
return rsv == control_packet_reserved_bits::auth;
default:
return false;
}
} ();
if (valid) return cpt;
return nullopt;
}

} // namespace MQTT_NS

#endif // MQTT_CONTROL_PACKET_TYPE_HPP
44 changes: 41 additions & 3 deletions include/mqtt/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4295,10 +4295,22 @@ class endpoint : public std::enable_shared_from_this<endpoint<Mutex, LockGuard,
"Iterators provided to restore_serialized_message() must be random access iterators."
);

MQTT_LOG("mqtt_api", info)
<< MQTT_ADD_VALUE(address, this)
<< "restore_serialized_message(b, e)";

if (b == e) return;

auto fixed_header = static_cast<std::uint8_t>(*b);
switch (get_control_packet_type(fixed_header)) {
auto cpt_opt = get_control_packet_type_with_check(fixed_header);
if (!cpt_opt) {
MQTT_LOG("mqtt_api", error)
<< MQTT_ADD_VALUE(address, this)
<< "invalid fixed_header ignored. "
<< std::hex << static_cast<int>(fixed_header);
throw protocol_error();
}
switch (cpt_opt.value()) {
case control_packet_type::publish: {
auto buf = allocate_buffer(b, e);
restore_serialized_message(
Expand All @@ -4320,6 +4332,10 @@ class endpoint : public std::enable_shared_from_this<endpoint<Mutex, LockGuard,
);
} break;
default:
MQTT_LOG("mqtt_api", error)
<< MQTT_ADD_VALUE(address, this)
<< "invalid control packet type. "
<< std::hex << static_cast<int>(fixed_header);
throw protocol_error();
break;
}
Expand Down Expand Up @@ -4410,7 +4426,15 @@ class endpoint : public std::enable_shared_from_this<endpoint<Mutex, LockGuard,
if (b == e) return;

auto fixed_header = static_cast<std::uint8_t>(*b);
switch (get_control_packet_type(fixed_header)) {
auto cpt_opt = get_control_packet_type_with_check(fixed_header);
if (!cpt_opt) {
MQTT_LOG("mqtt_api", error)
<< MQTT_ADD_VALUE(address, this)
<< "invalid fixed_header ignored. "
<< std::hex << static_cast<int>(fixed_header);
throw protocol_error();
}
switch (cpt_opt.value()) {
case control_packet_type::publish: {
auto buf = allocate_buffer(b, e);
restore_v5_serialized_message(
Expand All @@ -4426,6 +4450,10 @@ class endpoint : public std::enable_shared_from_this<endpoint<Mutex, LockGuard,
);
} break;
default:
MQTT_LOG("mqtt_api", error)
<< MQTT_ADD_VALUE(address, this)
<< "invalid control packet type. "
<< std::hex << static_cast<int>(fixed_header);
throw protocol_error();
break;
}
Expand Down Expand Up @@ -5033,9 +5061,14 @@ class endpoint : public std::enable_shared_from_this<endpoint<Mutex, LockGuard,
);
}
else {
auto cpt_opt = get_control_packet_type_with_check(fixed_header_);
if (!cpt_opt) {
call_protocol_error_handlers();
return;
}
auto cpt = cpt_opt.value();
auto check =
[&]() -> bool {
auto cpt = get_control_packet_type(fixed_header_);
switch (version_) {
case protocol_version::v3_1_1:
switch (cpt) {
Expand Down Expand Up @@ -5433,9 +5466,14 @@ class endpoint : public std::enable_shared_from_this<endpoint<Mutex, LockGuard,
force_move(session_life_keeper),
force_move(buf),
[
this,
handler = force_move(handler)
]
(std::size_t packet_id, buffer buf, any session_life_keeper, this_type_sp self) mutable {
if (packet_id == 0) {
call_protocol_error_handlers();
return;
}
handler(static_cast<packet_id_t>(packet_id), force_move(buf), force_move(session_life_keeper), force_move(self));
},
force_move(self)
Expand Down

0 comments on commit 364be87

Please sign in to comment.