diff --git a/include/mqtt/control_packet_type.hpp b/include/mqtt/control_packet_type.hpp index bc3d6993d..3ff1787c7 100644 --- a/include/mqtt/control_packet_type.hpp +++ b/include/mqtt/control_packet_type.hpp @@ -11,6 +11,7 @@ #include #include +#include namespace MQTT_NS { @@ -33,7 +34,7 @@ enum class control_packet_type : std::uint8_t { disconnect = 0b11100000, // 14 auth = 0b11110000, // 15 -}; // namespace control_packet_type +}; constexpr control_packet_type get_control_packet_type(std::uint8_t v) { return static_cast(v & 0b11110000); @@ -68,6 +69,68 @@ std::ostream& operator<<(std::ostream& os, control_packet_type val) return os; } +enum class control_packet_reserved_bits : std::uint8_t { + connect = 0b00000000, + connack = 0b00000000, + // publish = dup qos retain, + puback = 0b00000000, + pubrec = 0b00000000, + pubrel = 0b00000010, + pubcomp = 0b00000000, + subscribe = 0b00000010, + suback = 0b00000000, + unsubscribe = 0b00000010, + unsuback = 0b00000000, + pingreq = 0b00000000, + pingresp = 0b00000000, + disconnect = 0b00000000, + auth = 0b00000000, +}; + +inline optional get_control_packet_type_with_check(std::uint8_t v) { + auto cpt = static_cast(v & 0b11110000); + auto valid = + [&] { + auto rsv = static_cast(v & 0b00001111); + switch (cpt) { + case control_packet_type::connect: + return rsv == control_packet_reserved_bits::connect; + case control_packet_type::connack: + return rsv == control_packet_reserved_bits::connack; + case control_packet_type::publish: + return true; + case control_packet_type::puback: + return rsv == control_packet_reserved_bits::puback; + case control_packet_type::pubrec: + return rsv == control_packet_reserved_bits::pubrec; + case control_packet_type::pubrel: + return rsv == control_packet_reserved_bits::pubrel; + case control_packet_type::pubcomp: + return rsv == control_packet_reserved_bits::pubcomp; + case control_packet_type::subscribe: + return rsv == control_packet_reserved_bits::subscribe; + case control_packet_type::suback: + return rsv == control_packet_reserved_bits::suback; + case control_packet_type::unsubscribe: + return rsv == control_packet_reserved_bits::unsubscribe; + case control_packet_type::unsuback: + return rsv == control_packet_reserved_bits::unsuback; + case control_packet_type::pingreq: + return rsv == control_packet_reserved_bits::pingreq; + case control_packet_type::pingresp: + return rsv == control_packet_reserved_bits::pingresp; + case control_packet_type::disconnect: + return rsv == control_packet_reserved_bits::disconnect; + case control_packet_type::auth: + return rsv == control_packet_reserved_bits::auth; + default: + return false; + } + } (); + if (valid) return cpt; + return nullopt; +} + } // namespace MQTT_NS #endif // MQTT_CONTROL_PACKET_TYPE_HPP diff --git a/include/mqtt/endpoint.hpp b/include/mqtt/endpoint.hpp index fbd298922..2bcade2b1 100644 --- a/include/mqtt/endpoint.hpp +++ b/include/mqtt/endpoint.hpp @@ -4295,10 +4295,22 @@ class endpoint : public std::enable_shared_from_this(*b); - switch (get_control_packet_type(fixed_header)) { + auto cpt_opt = get_control_packet_type_with_check(fixed_header); + if (!cpt_opt) { + MQTT_LOG("mqtt_api", error) + << MQTT_ADD_VALUE(address, this) + << "invalid fixed_header ignored. " + << std::hex << static_cast(fixed_header); + throw protocol_error(); + } + switch (cpt_opt.value()) { case control_packet_type::publish: { auto buf = allocate_buffer(b, e); restore_serialized_message( @@ -4320,6 +4332,10 @@ class endpoint : public std::enable_shared_from_this(fixed_header); throw protocol_error(); break; } @@ -4410,7 +4426,15 @@ class endpoint : public std::enable_shared_from_this(*b); - switch (get_control_packet_type(fixed_header)) { + auto cpt_opt = get_control_packet_type_with_check(fixed_header); + if (!cpt_opt) { + MQTT_LOG("mqtt_api", error) + << MQTT_ADD_VALUE(address, this) + << "invalid fixed_header ignored. " + << std::hex << static_cast(fixed_header); + throw protocol_error(); + } + switch (cpt_opt.value()) { case control_packet_type::publish: { auto buf = allocate_buffer(b, e); restore_v5_serialized_message( @@ -4426,6 +4450,10 @@ class endpoint : public std::enable_shared_from_this(fixed_header); throw protocol_error(); break; } @@ -5033,9 +5061,14 @@ class endpoint : public std::enable_shared_from_this bool { - auto cpt = get_control_packet_type(fixed_header_); switch (version_) { case protocol_version::v3_1_1: switch (cpt) { @@ -5433,9 +5466,14 @@ class endpoint : public std::enable_shared_from_this(packet_id), force_move(buf), force_move(session_life_keeper), force_move(self)); }, force_move(self)