Skip to content

Commit

Permalink
refactor: auto-formatting sdk files with black
Browse files Browse the repository at this point in the history
The following three files are revised for a future PR: mqtt.py,
channel_manager.py and config.py. As a prep, these files are
auto-formatted with black.
  • Loading branch information
myungjin committed Apr 1, 2023
1 parent f6c291e commit 3f71f82
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 115 deletions.
179 changes: 96 additions & 83 deletions lib/python/flame/backend/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,21 @@
from paho.mqtt.client import MQTTv5

from ..channel import Channel
from ..common.constants import (DEFAULT_RUN_ASYNC_WAIT_TIME, EMPTY_PAYLOAD,
MQTT_TOPIC_PREFIX, BackendEvent, CommType)
from ..common.constants import (
DEFAULT_RUN_ASYNC_WAIT_TIME,
EMPTY_PAYLOAD,
MQTT_TOPIC_PREFIX,
BackendEvent,
CommType,
)
from ..common.util import background_thread_loop, run_async
from ..proto import backend_msg_pb2 as msg_pb2
from .abstract import AbstractBackend
from .chunk_manager import ChunkManager
from .chunk_store import ChunkStore

END_STATUS_ON = 'online'
END_STATUS_OFF = 'offline'
END_STATUS_ON = "online"
END_STATUS_OFF = "offline"

# wait time of 10 sec
# clean up resources allocated for terminated end
Expand Down Expand Up @@ -98,7 +103,7 @@ async def _init_loop_stuff():
coro = _init_loop_stuff()
_, success = run_async(coro, self._loop, DEFAULT_RUN_ASYNC_WAIT_TIME)
if not success:
raise SystemError('initialization failure')
raise SystemError("initialization failure")

self._initialized = True

Expand All @@ -123,16 +128,18 @@ def configure(self, broker: str, job_id: str, task_id: str):

self._mqtt_client = mqtt.Client(self._id, protocol=MQTTv5)

self._health_check_topic = f'{MQTT_TOPIC_PREFIX}/{self._job_id}'
self._health_check_topic = f"{MQTT_TOPIC_PREFIX}/{self._job_id}"

async def _setup_mqtt_client():
_ = asyncio.create_task(self._rx_task())

self._mqtt_client.on_connect = self.on_connect
self._mqtt_client.on_message = self.on_message
self._mqtt_client.will_set(self._health_check_topic,
payload=f'{self._id}:{END_STATUS_OFF}',
qos=MqttQoS.EXACTLY_ONCE)
self._mqtt_client.will_set(
self._health_check_topic,
payload=f"{self._id}:{END_STATUS_OFF}",
qos=MqttQoS.EXACTLY_ONCE,
)

_ = AsyncioHelper(self._loop, self._mqtt_client)

Expand All @@ -142,7 +149,7 @@ async def _setup_mqtt_client():
coro = _setup_mqtt_client()
_, success = run_async(coro, self._loop, DEFAULT_RUN_ASYNC_WAIT_TIME)
if not success:
logger.error('failed to set up mqtt client')
logger.error("failed to set up mqtt client")
raise ConnectionError

def _topics_for_notify(self, channel: Channel) -> list[str]:
Expand All @@ -153,15 +160,17 @@ def _topics_for_notify(self, channel: Channel) -> list[str]:
# format for unicast topic to subscribe:
# /flame/<job_id>/<channel_name>/<groupby>/unicast/<other_role>/+/<my_role>/<my_end_id>
for comm_type in CommType:
topic = TOPIC_SEP.join([
MQTT_TOPIC_PREFIX,
self._job_id,
channel.name(),
channel.groupby(),
comm_type.name,
channel.other_role(),
'+',
])
topic = TOPIC_SEP.join(
[
MQTT_TOPIC_PREFIX,
self._job_id,
channel.name(),
channel.groupby(),
comm_type.name,
channel.other_role(),
"+",
]
)

if comm_type == CommType.UNICAST:
topic = TOPIC_SEP.join([topic, channel.my_role(), self._id])
Expand Down Expand Up @@ -192,15 +201,15 @@ def leave(self, channel: Channel) -> None:
def _handle_health_message(self, message):
health_data = str(message.payload.decode("utf-8"))
# the correct format of health data: <end_id>:<status>
(end_id, status) = health_data.split(':')[0:2]
logger.debug(f'end: {end_id}, status: {status}')
(end_id, status) = health_data.split(":")[0:2]
logger.debug(f"end: {end_id}, status: {status}")
if end_id == self._id or status == END_STATUS_ON:
# nothing to do
return

expiry = time.time() + MQTT_TIME_WAIT
self._cleanup_waits[end_id] = expiry
logger.debug(f'end: {end_id}, expiry time: {expiry}')
logger.debug(f"end: {end_id}, expiry time: {expiry}")

async def _handle_notification(self, any_msg):
msg = msg_pb2.Notify()
Expand All @@ -209,19 +218,19 @@ async def _handle_notification(self, any_msg):
if msg.end_id == self._id:
# This case happens when message is broadcast to a self-loop
# e.g., distributed topology
logger.debug('message sent to self; do nothing')
logger.debug("message sent to self; do nothing")
return

if msg.channel_name not in self._channels:
logger.debug('channel not found')
logger.debug("channel not found")
return

channel = self._channels[msg.channel_name]

if msg.type == msg_pb2.NotifyType.JOIN and not channel.has(msg.end_id):
# this is the first time to see this end,
# so let's notify my presence to the end
logger.debug('acknowledge notification')
logger.debug("acknowledge notification")
self.notify(msg.channel_name, msg_pb2.NotifyType.JOIN)

# add end to the channel
Expand All @@ -238,7 +247,7 @@ async def _handle_data(self, any_msg: Any) -> None:
if msg.end_id == self._id:
# This case happens when message is broadcast to a self-loop
# e.g., distributed topology
logger.debug('message sent to self; do nothing')
logger.debug("message sent to self; do nothing")
return

# update is needed only if end_id's termination is detected
Expand All @@ -263,7 +272,7 @@ async def _rx_task(self):
continue

logger.debug(
f'_rx_task - topic: {message.topic}; len: {len(message.payload)}'
f"_rx_task - topic: {message.topic}; len: {len(message.payload)}"
)

any_msg = Any().FromString(message.payload)
Expand All @@ -273,7 +282,7 @@ async def _rx_task(self):
elif any_msg.Is(msg_pb2.Data.DESCRIPTOR):
await self._handle_data(any_msg)
else:
logger.warning('unknown message type')
logger.warning("unknown message type")

def uid(self):
"""Return backend id."""
Expand All @@ -285,18 +294,20 @@ def eventq(self):

def on_connect(self, client, userdata, flags, rc, properties=None):
"""on_connect publishes a health check message to a mqtt broker."""
logger.debug('calling on_connect')
logger.debug("calling on_connect")

# publish health data; format: <end_id>:<status>
# status is either END_STATUS_ON or END_STATUS_OFF
client.publish(self._health_check_topic,
payload=f'{self._id}:{END_STATUS_ON}',
qos=MqttQoS.EXACTLY_ONCE)
client.publish(
self._health_check_topic,
payload=f"{self._id}:{END_STATUS_ON}",
qos=MqttQoS.EXACTLY_ONCE,
)

def on_message(self, client, userdata, message):
"""on_message receives message."""
logger.debug(
f'on_message - topic: {message.topic}; len: {len(message.payload)}'
f"on_message - topic: {message.topic}; len: {len(message.payload)}"
)

idx = len(self._rx_deque) - 1
Expand All @@ -307,18 +318,18 @@ def on_message(self, client, userdata, message):

def subscribe(self, topic) -> None:
"""Subscribe to a topic."""
logger.debug(f'subscribe topic: {topic}')
logger.debug(f"subscribe topic: {topic}")
self._mqtt_client.subscribe(topic, qos=MqttQoS.EXACTLY_ONCE)

def unsubscribe(self, topic) -> None:
"""Unsubscribe from a topic."""
logger.debug(f'unsubscribe topic: {topic}')
logger.debug(f"unsubscribe topic: {topic}")
self._mqtt_client.unsubscribe(topic)

def notify(self, channel_name, notify_type) -> bool:
"""Broadcast a notify message to a channel."""
if channel_name not in self._channels:
logger.debug(f'channel {channel_name} not found')
logger.debug(f"channel {channel_name} not found")
return False

channel = self._channels[channel_name]
Expand All @@ -334,7 +345,7 @@ def notify(self, channel_name, notify_type) -> bool:
any.Pack(msg)
payload = any.SerializeToString()

logger.debug(f'notify: publish topic: {topic}')
logger.debug(f"notify: publish topic: {topic}")
self._mqtt_client.publish(topic, payload, qos=MqttQoS.EXACTLY_ONCE)

return True
Expand All @@ -347,14 +358,14 @@ def attach_channel(self, channel):
"""Attach a channel to backend."""
self._channels[channel.name()] = channel

def create_tx_task(self,
channel_name: str,
end_id: str,
comm_type=CommType.UNICAST) -> bool:
def create_tx_task(
self, channel_name: str, end_id: str, comm_type=CommType.UNICAST
) -> bool:
"""Create asyncio task for transmission."""
if (channel_name not in self._channels
or (not self._channels[channel_name].has(end_id)
and comm_type != CommType.BROADCAST)):
if channel_name not in self._channels or (
not self._channels[channel_name].has(end_id)
and comm_type != CommType.BROADCAST
):
return False

channel = self._channels[channel_name]
Expand All @@ -364,35 +375,38 @@ def create_tx_task(self,

return True

def topic_for_pub(self,
ch: Channel,
other_id: str = "",
comm_type=CommType.BROADCAST):
def topic_for_pub(
self, ch: Channel, other_id: str = "", comm_type=CommType.BROADCAST
):
"""Return a proper topic for a given channel."""
if comm_type == CommType.BROADCAST:
topic = TOPIC_SEP.join([
MQTT_TOPIC_PREFIX,
ch.job_id(),
ch.name(),
ch.groupby(),
CommType.BROADCAST.name,
ch.my_role(),
self._id,
])
topic = TOPIC_SEP.join(
[
MQTT_TOPIC_PREFIX,
ch.job_id(),
ch.name(),
ch.groupby(),
CommType.BROADCAST.name,
ch.my_role(),
self._id,
]
)
elif comm_type == CommType.UNICAST:
topic = TOPIC_SEP.join([
MQTT_TOPIC_PREFIX,
ch.job_id(),
ch.name(),
ch.groupby(),
CommType.UNICAST.name,
ch.my_role(),
self._id,
ch.other_role(),
other_id,
])
topic = TOPIC_SEP.join(
[
MQTT_TOPIC_PREFIX,
ch.job_id(),
ch.name(),
ch.groupby(),
CommType.UNICAST.name,
ch.my_role(),
self._id,
ch.other_role(),
other_id,
]
)
else:
raise ValueError(f'unknown CommType {comm_type}')
raise ValueError(f"unknown CommType {comm_type}")

return topic

Expand Down Expand Up @@ -433,8 +447,9 @@ def send_chunks(self, topic, ch_name: str, data: bytes) -> None:

self.send_chunk(topic, ch_name, chunk, seqno, eom)

def send_chunk(self, topic: str, channel_name: str, data: bytes,
seqno: int, eom: bool) -> None:
def send_chunk(
self, topic: str, channel_name: str, data: bytes, seqno: int, eom: bool
) -> None:
"""Send a chunk."""
msg = msg_pb2.Data()
msg.end_id = self._id
Expand All @@ -447,16 +462,14 @@ def send_chunk(self, topic: str, channel_name: str, data: bytes,
any.Pack(msg)
payload = any.SerializeToString()

info = self._mqtt_client.publish(topic,
payload,
qos=MqttQoS.EXACTLY_ONCE)
info = self._mqtt_client.publish(topic, payload, qos=MqttQoS.EXACTLY_ONCE)

while not info.is_published():
logger.debug(f"waiting for publish completion: rc = {info.rc}")
retval = self._mqtt_client.loop(MQTT_LOOP_CHECK_PERIOD)
logger.debug(f"retval from loop = {retval}")

logger.debug(f'sending chunk {seqno} to {topic} is done')
logger.debug(f"sending chunk {seqno} to {topic} is done")

async def cleanup(self):
"""Clean up resources in backend."""
Expand All @@ -481,42 +494,42 @@ def __init__(self, loop, client):

def on_socket_open(self, client, userdata, sock):
"""Call a callback function when socket opens."""
logger.debug('Socket opened')
logger.debug("Socket opened")

def cb():
logger.debug('Socket is readable, calling loop_read')
logger.debug("Socket is readable, calling loop_read")
client.loop_read()

self.loop.add_reader(sock, cb)
self.misc = self.loop.create_task(self.misc_loop())

def on_socket_close(self, client, userdata, sock):
"""Call a callback function when socket closes."""
logger.debug('Socket closed')
logger.debug("Socket closed")
self.loop.remove_reader(sock)
self.misc.cancel()

def on_socket_register_write(self, client, userdata, sock):
"""Watch socket's writability."""
logger.debug('Watching socket for writability.')
logger.debug("Watching socket for writability.")

def cb():
logger.debug('Socket is writable, calling loop_write')
logger.debug("Socket is writable, calling loop_write")
client.loop_write()

self.loop.add_writer(sock, cb)

def on_socket_unregister_write(self, client, userdata, sock):
"""Stop watching socket's writability."""
logger.debug('Stop watching socket for writability.')
logger.debug("Stop watching socket for writability.")
self.loop.remove_writer(sock)

async def misc_loop(self):
"""Start misc loop."""
logger.debug('misc_loop started')
logger.debug("misc_loop started")
while self.client.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
break
logger.debug('misc_loop finished')
logger.debug("misc_loop finished")
Loading

0 comments on commit 3f71f82

Please sign in to comment.