diff --git a/lib/python/flame/backend/mqtt.py b/lib/python/flame/backend/mqtt.py index 416d49ab2..46488905d 100644 --- a/lib/python/flame/backend/mqtt.py +++ b/lib/python/flame/backend/mqtt.py @@ -26,16 +26,21 @@ from paho.mqtt.client import MQTTv5 from ..channel import Channel -from ..common.constants import (DEFAULT_RUN_ASYNC_WAIT_TIME, EMPTY_PAYLOAD, - MQTT_TOPIC_PREFIX, BackendEvent, CommType) +from ..common.constants import ( + DEFAULT_RUN_ASYNC_WAIT_TIME, + EMPTY_PAYLOAD, + MQTT_TOPIC_PREFIX, + BackendEvent, + CommType, +) from ..common.util import background_thread_loop, run_async from ..proto import backend_msg_pb2 as msg_pb2 from .abstract import AbstractBackend from .chunk_manager import ChunkManager from .chunk_store import ChunkStore -END_STATUS_ON = 'online' -END_STATUS_OFF = 'offline' +END_STATUS_ON = "online" +END_STATUS_OFF = "offline" # wait time of 10 sec # clean up resources allocated for terminated end @@ -98,7 +103,7 @@ async def _init_loop_stuff(): coro = _init_loop_stuff() _, success = run_async(coro, self._loop, DEFAULT_RUN_ASYNC_WAIT_TIME) if not success: - raise SystemError('initialization failure') + raise SystemError("initialization failure") self._initialized = True @@ -123,16 +128,18 @@ def configure(self, broker: str, job_id: str, task_id: str): self._mqtt_client = mqtt.Client(self._id, protocol=MQTTv5) - self._health_check_topic = f'{MQTT_TOPIC_PREFIX}/{self._job_id}' + self._health_check_topic = f"{MQTT_TOPIC_PREFIX}/{self._job_id}" async def _setup_mqtt_client(): _ = asyncio.create_task(self._rx_task()) self._mqtt_client.on_connect = self.on_connect self._mqtt_client.on_message = self.on_message - self._mqtt_client.will_set(self._health_check_topic, - payload=f'{self._id}:{END_STATUS_OFF}', - qos=MqttQoS.EXACTLY_ONCE) + self._mqtt_client.will_set( + self._health_check_topic, + payload=f"{self._id}:{END_STATUS_OFF}", + qos=MqttQoS.EXACTLY_ONCE, + ) _ = AsyncioHelper(self._loop, self._mqtt_client) @@ -142,7 +149,7 @@ async def _setup_mqtt_client(): coro = _setup_mqtt_client() _, success = run_async(coro, self._loop, DEFAULT_RUN_ASYNC_WAIT_TIME) if not success: - logger.error('failed to set up mqtt client') + logger.error("failed to set up mqtt client") raise ConnectionError def _topics_for_notify(self, channel: Channel) -> list[str]: @@ -153,15 +160,17 @@ def _topics_for_notify(self, channel: Channel) -> list[str]: # format for unicast topic to subscribe: # /flame////unicast//+// for comm_type in CommType: - topic = TOPIC_SEP.join([ - MQTT_TOPIC_PREFIX, - self._job_id, - channel.name(), - channel.groupby(), - comm_type.name, - channel.other_role(), - '+', - ]) + topic = TOPIC_SEP.join( + [ + MQTT_TOPIC_PREFIX, + self._job_id, + channel.name(), + channel.groupby(), + comm_type.name, + channel.other_role(), + "+", + ] + ) if comm_type == CommType.UNICAST: topic = TOPIC_SEP.join([topic, channel.my_role(), self._id]) @@ -192,15 +201,15 @@ def leave(self, channel: Channel) -> None: def _handle_health_message(self, message): health_data = str(message.payload.decode("utf-8")) # the correct format of health data: : - (end_id, status) = health_data.split(':')[0:2] - logger.debug(f'end: {end_id}, status: {status}') + (end_id, status) = health_data.split(":")[0:2] + logger.debug(f"end: {end_id}, status: {status}") if end_id == self._id or status == END_STATUS_ON: # nothing to do return expiry = time.time() + MQTT_TIME_WAIT self._cleanup_waits[end_id] = expiry - logger.debug(f'end: {end_id}, expiry time: {expiry}') + logger.debug(f"end: {end_id}, expiry time: {expiry}") async def _handle_notification(self, any_msg): msg = msg_pb2.Notify() @@ -209,11 +218,11 @@ async def _handle_notification(self, any_msg): if msg.end_id == self._id: # This case happens when message is broadcast to a self-loop # e.g., distributed topology - logger.debug('message sent to self; do nothing') + logger.debug("message sent to self; do nothing") return if msg.channel_name not in self._channels: - logger.debug('channel not found') + logger.debug("channel not found") return channel = self._channels[msg.channel_name] @@ -221,7 +230,7 @@ async def _handle_notification(self, any_msg): if msg.type == msg_pb2.NotifyType.JOIN and not channel.has(msg.end_id): # this is the first time to see this end, # so let's notify my presence to the end - logger.debug('acknowledge notification') + logger.debug("acknowledge notification") self.notify(msg.channel_name, msg_pb2.NotifyType.JOIN) # add end to the channel @@ -238,7 +247,7 @@ async def _handle_data(self, any_msg: Any) -> None: if msg.end_id == self._id: # This case happens when message is broadcast to a self-loop # e.g., distributed topology - logger.debug('message sent to self; do nothing') + logger.debug("message sent to self; do nothing") return # update is needed only if end_id's termination is detected @@ -263,7 +272,7 @@ async def _rx_task(self): continue logger.debug( - f'_rx_task - topic: {message.topic}; len: {len(message.payload)}' + f"_rx_task - topic: {message.topic}; len: {len(message.payload)}" ) any_msg = Any().FromString(message.payload) @@ -273,7 +282,7 @@ async def _rx_task(self): elif any_msg.Is(msg_pb2.Data.DESCRIPTOR): await self._handle_data(any_msg) else: - logger.warning('unknown message type') + logger.warning("unknown message type") def uid(self): """Return backend id.""" @@ -285,18 +294,20 @@ def eventq(self): def on_connect(self, client, userdata, flags, rc, properties=None): """on_connect publishes a health check message to a mqtt broker.""" - logger.debug('calling on_connect') + logger.debug("calling on_connect") # publish health data; format: : # status is either END_STATUS_ON or END_STATUS_OFF - client.publish(self._health_check_topic, - payload=f'{self._id}:{END_STATUS_ON}', - qos=MqttQoS.EXACTLY_ONCE) + client.publish( + self._health_check_topic, + payload=f"{self._id}:{END_STATUS_ON}", + qos=MqttQoS.EXACTLY_ONCE, + ) def on_message(self, client, userdata, message): """on_message receives message.""" logger.debug( - f'on_message - topic: {message.topic}; len: {len(message.payload)}' + f"on_message - topic: {message.topic}; len: {len(message.payload)}" ) idx = len(self._rx_deque) - 1 @@ -307,18 +318,18 @@ def on_message(self, client, userdata, message): def subscribe(self, topic) -> None: """Subscribe to a topic.""" - logger.debug(f'subscribe topic: {topic}') + logger.debug(f"subscribe topic: {topic}") self._mqtt_client.subscribe(topic, qos=MqttQoS.EXACTLY_ONCE) def unsubscribe(self, topic) -> None: """Unsubscribe from a topic.""" - logger.debug(f'unsubscribe topic: {topic}') + logger.debug(f"unsubscribe topic: {topic}") self._mqtt_client.unsubscribe(topic) def notify(self, channel_name, notify_type) -> bool: """Broadcast a notify message to a channel.""" if channel_name not in self._channels: - logger.debug(f'channel {channel_name} not found') + logger.debug(f"channel {channel_name} not found") return False channel = self._channels[channel_name] @@ -334,7 +345,7 @@ def notify(self, channel_name, notify_type) -> bool: any.Pack(msg) payload = any.SerializeToString() - logger.debug(f'notify: publish topic: {topic}') + logger.debug(f"notify: publish topic: {topic}") self._mqtt_client.publish(topic, payload, qos=MqttQoS.EXACTLY_ONCE) return True @@ -347,14 +358,14 @@ def attach_channel(self, channel): """Attach a channel to backend.""" self._channels[channel.name()] = channel - def create_tx_task(self, - channel_name: str, - end_id: str, - comm_type=CommType.UNICAST) -> bool: + def create_tx_task( + self, channel_name: str, end_id: str, comm_type=CommType.UNICAST + ) -> bool: """Create asyncio task for transmission.""" - if (channel_name not in self._channels - or (not self._channels[channel_name].has(end_id) - and comm_type != CommType.BROADCAST)): + if channel_name not in self._channels or ( + not self._channels[channel_name].has(end_id) + and comm_type != CommType.BROADCAST + ): return False channel = self._channels[channel_name] @@ -364,35 +375,38 @@ def create_tx_task(self, return True - def topic_for_pub(self, - ch: Channel, - other_id: str = "", - comm_type=CommType.BROADCAST): + def topic_for_pub( + self, ch: Channel, other_id: str = "", comm_type=CommType.BROADCAST + ): """Return a proper topic for a given channel.""" if comm_type == CommType.BROADCAST: - topic = TOPIC_SEP.join([ - MQTT_TOPIC_PREFIX, - ch.job_id(), - ch.name(), - ch.groupby(), - CommType.BROADCAST.name, - ch.my_role(), - self._id, - ]) + topic = TOPIC_SEP.join( + [ + MQTT_TOPIC_PREFIX, + ch.job_id(), + ch.name(), + ch.groupby(), + CommType.BROADCAST.name, + ch.my_role(), + self._id, + ] + ) elif comm_type == CommType.UNICAST: - topic = TOPIC_SEP.join([ - MQTT_TOPIC_PREFIX, - ch.job_id(), - ch.name(), - ch.groupby(), - CommType.UNICAST.name, - ch.my_role(), - self._id, - ch.other_role(), - other_id, - ]) + topic = TOPIC_SEP.join( + [ + MQTT_TOPIC_PREFIX, + ch.job_id(), + ch.name(), + ch.groupby(), + CommType.UNICAST.name, + ch.my_role(), + self._id, + ch.other_role(), + other_id, + ] + ) else: - raise ValueError(f'unknown CommType {comm_type}') + raise ValueError(f"unknown CommType {comm_type}") return topic @@ -433,8 +447,9 @@ def send_chunks(self, topic, ch_name: str, data: bytes) -> None: self.send_chunk(topic, ch_name, chunk, seqno, eom) - def send_chunk(self, topic: str, channel_name: str, data: bytes, - seqno: int, eom: bool) -> None: + def send_chunk( + self, topic: str, channel_name: str, data: bytes, seqno: int, eom: bool + ) -> None: """Send a chunk.""" msg = msg_pb2.Data() msg.end_id = self._id @@ -447,16 +462,14 @@ def send_chunk(self, topic: str, channel_name: str, data: bytes, any.Pack(msg) payload = any.SerializeToString() - info = self._mqtt_client.publish(topic, - payload, - qos=MqttQoS.EXACTLY_ONCE) + info = self._mqtt_client.publish(topic, payload, qos=MqttQoS.EXACTLY_ONCE) while not info.is_published(): logger.debug(f"waiting for publish completion: rc = {info.rc}") retval = self._mqtt_client.loop(MQTT_LOOP_CHECK_PERIOD) logger.debug(f"retval from loop = {retval}") - logger.debug(f'sending chunk {seqno} to {topic} is done') + logger.debug(f"sending chunk {seqno} to {topic} is done") async def cleanup(self): """Clean up resources in backend.""" @@ -481,10 +494,10 @@ def __init__(self, loop, client): def on_socket_open(self, client, userdata, sock): """Call a callback function when socket opens.""" - logger.debug('Socket opened') + logger.debug("Socket opened") def cb(): - logger.debug('Socket is readable, calling loop_read') + logger.debug("Socket is readable, calling loop_read") client.loop_read() self.loop.add_reader(sock, cb) @@ -492,31 +505,31 @@ def cb(): def on_socket_close(self, client, userdata, sock): """Call a callback function when socket closes.""" - logger.debug('Socket closed') + logger.debug("Socket closed") self.loop.remove_reader(sock) self.misc.cancel() def on_socket_register_write(self, client, userdata, sock): """Watch socket's writability.""" - logger.debug('Watching socket for writability.') + logger.debug("Watching socket for writability.") def cb(): - logger.debug('Socket is writable, calling loop_write') + logger.debug("Socket is writable, calling loop_write") client.loop_write() self.loop.add_writer(sock, cb) def on_socket_unregister_write(self, client, userdata, sock): """Stop watching socket's writability.""" - logger.debug('Stop watching socket for writability.') + logger.debug("Stop watching socket for writability.") self.loop.remove_writer(sock) async def misc_loop(self): """Start misc loop.""" - logger.debug('misc_loop started') + logger.debug("misc_loop started") while self.client.loop_misc() == mqtt.MQTT_ERR_SUCCESS: try: await asyncio.sleep(1) except asyncio.CancelledError: break - logger.debug('misc_loop finished') + logger.debug("misc_loop finished") diff --git a/lib/python/flame/channel_manager.py b/lib/python/flame/channel_manager.py index 58916a274..10c19cb61 100644 --- a/lib/python/flame/channel_manager.py +++ b/lib/python/flame/channel_manager.py @@ -44,8 +44,9 @@ def custom_excepthook(exc_type, exc_value, exc_traceback): A root-cause is not identified. As a workaround, this custom hook is implemented and set to sys.excepthook """ - logger.critical("Uncaught exception:", - exc_info=(exc_type, exc_value, exc_traceback)) + logger.critical( + "Uncaught exception:", exc_info=(exc_type, exc_value, exc_traceback) + ) sys.excepthook = custom_excepthook @@ -71,7 +72,7 @@ class ChannelManager(object): def __new__(cls): """Create a singleton instance.""" if cls._instance is None: - logger.info('creating a ChannelManager instance') + logger.info("creating a ChannelManager instance") cls._instance = super(ChannelManager, cls).__new__(cls) return cls._instance @@ -89,13 +90,11 @@ def __call__(self, config: Config): self._setup_backends() - self._discovery_client = discovery_client_provider.get( - self._config.task) + self._discovery_client = discovery_client_provider.get(self._config.task) atexit.register(self.cleanup) def _setup_backends(self): - async def inner(q: asyncio.Queue) -> None: # create a coroutine task coro = self._backend_eventq_task(q) @@ -145,10 +144,13 @@ def join(self, name: str) -> bool: me = channel_config.pair[1] other = channel_config.pair[0] - groupby = channel_config.group_by.groupable_value(self._config.group_association.get(name)) + groupby = channel_config.group_by.groupable_value( + self._config.group_association.get(name) + ) - selector = selector_provider.get(self._config.selector.sort, - **self._config.selector.kwargs) + selector = selector_provider.get( + self._config.selector.sort, **self._config.selector.kwargs + ) if name in self._backends: backend = self._backends[name] @@ -156,8 +158,9 @@ def join(self, name: str) -> bool: logger.info(f"no backend found for channel {name}; use default") backend = self._backend - self._channels[name] = Channel(backend, selector, self._job_id, name, - me, other, groupby) + self._channels[name] = Channel( + backend, selector, self._job_id, name, me, other, groupby + ) self._channels[name].join() def leave(self, name): @@ -167,9 +170,9 @@ def leave(self, name): # TODO: reset_channel isn't implemented; the whole discovery module # needs to be revisited. - coro = self._discovery_client.reset_channel(self._job_id, name, - self._role, - self._backend.uid()) + coro = self._discovery_client.reset_channel( + self._job_id, name, self._role, self._backend.uid() + ) _, status = run_async(coro, self._loop, DEFAULT_RUN_ASYNC_WAIT_TIME) if status: diff --git a/lib/python/flame/config.py b/lib/python/flame/config.py index c8a374dee..0093b4a12 100644 --- a/lib/python/flame/config.py +++ b/lib/python/flame/config.py @@ -149,7 +149,6 @@ class ChannelConfigs(FlameSchema): class Config(FlameSchema): - def __init__(self, config_path: str): raw_config = read_config(config_path) transformed_config = transform_config(raw_config) @@ -197,16 +196,13 @@ def transform_config(raw_config: dict) -> dict: "task": raw_config["task"], } - channels, func_tag_map = transform_channels(config_data["role"], - raw_config["channels"]) - config_data = config_data | { - "channels": channels, - "func_tag_map": func_tag_map - } + channels, func_tag_map = transform_channels( + config_data["role"], raw_config["channels"] + ) + config_data = config_data | {"channels": channels, "func_tag_map": func_tag_map} if raw_config.get("hyperparameters", None): - hyperparameters = transform_hyperparameters( - raw_config["hyperparameters"]) + hyperparameters = transform_hyperparameters(raw_config["hyperparameters"]) config_data = config_data | {"hyperparameters": hyperparameters} @@ -225,12 +221,10 @@ def transform_config(raw_config: dict) -> dict: config_data = config_data | {"optimizer": raw_config.get("optimizer")} backends, channel_brokers = transform_channel_configs( - raw_config.get("channelConfigs", {})) + raw_config.get("channelConfigs", {}) + ) config_data = config_data | { - "channel_configs": { - "backends": backends, - "channel_brokers": channel_brokers - } + "channel_configs": {"backends": backends, "channel_brokers": channel_brokers} } config_data = config_data | { @@ -247,10 +241,7 @@ def transform_channel(raw_channel_config: dict): name = raw_channel_config["name"] pair = raw_channel_config["pair"] is_bidirectional = raw_channel_config.get("isBidirectional", True) - group_by = { - "type": "", - "value": [] - } | raw_channel_config.get("groupBy", {}) + group_by = {"type": "", "value": []} | raw_channel_config.get("groupBy", {}) func_tags = raw_channel_config.get("funcTags", {}) description = raw_channel_config.get("description", "")