From eb4b68b0e0d4a5248e6e3215c0e90701e3b81108 Mon Sep 17 00:00:00 2001 From: Myungjin Lee Date: Tue, 31 Jan 2023 19:54:39 -0800 Subject: [PATCH] feat: asynchronous fl Asynchronous FL is implemented for a two-tier topology and a three-tier hierarchical topology. The main algorithm is based on the following two papers: - https://arxiv.org/pdf/2111.04877.pdf - https://arxiv.org/pdf/2106.06639.pdf Two examples for asynchronous FL are also added: one for a two-tier topology and the other for a three-tier hierarchical topology. This implementation includes the core algorithm but doesn't include the SecAgg algorithm (presented in the papers), which is outside the scope of this change. --- lib/python/flame/backend/p2p.py | 4 +- lib/python/flame/channel.py | 166 +++++++++--- lib/python/flame/channel_manager.py | 4 + lib/python/flame/config.py | 4 + lib/python/flame/end.py | 4 + .../examples/async_hier_mnist/__init__.py | 15 ++ .../middle_aggregator/__init__.py | 15 ++ .../middle_aggregator/config_uk.json | 102 +++++++ .../middle_aggregator/config_us.json | 102 +++++++ .../middle_aggregator/main.py | 65 +++++ .../top_aggregator/__init__.py | 15 ++ .../top_aggregator/config.json | 77 ++++++ .../async_hier_mnist/top_aggregator/main.py | 89 +++++++ .../async_hier_mnist/trainer/__init__.py | 15 ++ .../async_hier_mnist/trainer/config_uk1.json | 78 ++++++ .../async_hier_mnist/trainer/config_uk2.json | 78 ++++++ .../async_hier_mnist/trainer/config_us1.json | 78 ++++++ .../async_hier_mnist/trainer/config_us2.json | 78 ++++++ .../examples/async_hier_mnist/trainer/main.py | 140 ++++++++++ .../flame/examples/async_mnist/__init__.py | 15 ++ .../async_mnist/aggregator/__init__.py | 15 ++ .../async_mnist/aggregator/config.json | 73 +++++ .../examples/async_mnist/aggregator/main.py | 89 +++++++ .../examples/async_mnist/trainer/__init__.py | 15 ++ .../examples/async_mnist/trainer/config1.json | 71 +++++ .../examples/async_mnist/trainer/config2.json | 71 +++++ .../examples/async_mnist/trainer/config3.json | 71 +++++ .../examples/async_mnist/trainer/config4.json | 71 +++++ .../examples/async_mnist/trainer/main.py | 140 ++++++++++ lib/python/flame/mode/composer.py | 27 +- .../flame/mode/horizontal/asyncfl/__init__.py | 15 ++ .../horizontal/asyncfl/middle_aggregator.py | 249 ++++++++++++++++++ .../mode/horizontal/asyncfl/top_aggregator.py | 205 ++++++++++++++ .../mode/horizontal/middle_aggregator.py | 8 +- .../flame/mode/horizontal/top_aggregator.py | 4 +- lib/python/flame/mode/horizontal/trainer.py | 9 +- lib/python/flame/mode/message.py | 11 +- lib/python/flame/mode/tasklet.py | 34 ++- lib/python/flame/optimizer/abstract.py | 12 +- lib/python/flame/optimizer/fedavg.py | 10 +- lib/python/flame/optimizer/fedbuff.py | 100 +++++++ lib/python/flame/optimizer/fedopt.py | 71 +++-- lib/python/flame/optimizer/train_result.py | 7 +- lib/python/flame/optimizers.py | 5 +- lib/python/flame/selector/__init__.py | 3 +- lib/python/flame/selector/default.py | 2 +- lib/python/flame/selector/fedbuff.py | 139 ++++++++++ lib/python/flame/selector/random.py | 3 +- lib/python/flame/selectors.py | 3 +- lib/python/setup.py | 2 +- 50 files changed, 2579 insertions(+), 100 deletions(-) create mode 100644 lib/python/flame/examples/async_hier_mnist/__init__.py create mode 100644 lib/python/flame/examples/async_hier_mnist/middle_aggregator/__init__.py create mode 100644 lib/python/flame/examples/async_hier_mnist/middle_aggregator/config_uk.json create mode 100644 lib/python/flame/examples/async_hier_mnist/middle_aggregator/config_us.json create mode 
100644 lib/python/flame/examples/async_hier_mnist/middle_aggregator/main.py create mode 100644 lib/python/flame/examples/async_hier_mnist/top_aggregator/__init__.py create mode 100644 lib/python/flame/examples/async_hier_mnist/top_aggregator/config.json create mode 100644 lib/python/flame/examples/async_hier_mnist/top_aggregator/main.py create mode 100644 lib/python/flame/examples/async_hier_mnist/trainer/__init__.py create mode 100644 lib/python/flame/examples/async_hier_mnist/trainer/config_uk1.json create mode 100644 lib/python/flame/examples/async_hier_mnist/trainer/config_uk2.json create mode 100644 lib/python/flame/examples/async_hier_mnist/trainer/config_us1.json create mode 100644 lib/python/flame/examples/async_hier_mnist/trainer/config_us2.json create mode 100644 lib/python/flame/examples/async_hier_mnist/trainer/main.py create mode 100644 lib/python/flame/examples/async_mnist/__init__.py create mode 100644 lib/python/flame/examples/async_mnist/aggregator/__init__.py create mode 100644 lib/python/flame/examples/async_mnist/aggregator/config.json create mode 100644 lib/python/flame/examples/async_mnist/aggregator/main.py create mode 100644 lib/python/flame/examples/async_mnist/trainer/__init__.py create mode 100644 lib/python/flame/examples/async_mnist/trainer/config1.json create mode 100644 lib/python/flame/examples/async_mnist/trainer/config2.json create mode 100644 lib/python/flame/examples/async_mnist/trainer/config3.json create mode 100644 lib/python/flame/examples/async_mnist/trainer/config4.json create mode 100644 lib/python/flame/examples/async_mnist/trainer/main.py create mode 100644 lib/python/flame/mode/horizontal/asyncfl/__init__.py create mode 100644 lib/python/flame/mode/horizontal/asyncfl/middle_aggregator.py create mode 100644 lib/python/flame/mode/horizontal/asyncfl/top_aggregator.py create mode 100644 lib/python/flame/optimizer/fedbuff.py create mode 100644 lib/python/flame/selector/fedbuff.py diff --git a/lib/python/flame/backend/p2p.py b/lib/python/flame/backend/p2p.py index 989eb26d5..3a3162fba 100644 --- a/lib/python/flame/backend/p2p.py +++ b/lib/python/flame/backend/p2p.py @@ -363,7 +363,7 @@ async def _broadcast_task(self, channel): break end_ids = list(channel._ends.keys()) - logger.debug(f"end ids for bcast = {end_ids}") + logger.debug(f"end ids for {channel.name()} bcast = {end_ids}") for end_id in end_ids: try: await self.send_chunks(end_id, channel.name(), data) @@ -374,6 +374,8 @@ async def _broadcast_task(self, channel): await self._cleanup_end(end_id) txq.task_done() + logger.debug(f"broadcast task for {channel.name()} terminated") + async def _unicast_task(self, channel, end_id): txq = channel.get_txq(end_id) diff --git a/lib/python/flame/channel.py b/lib/python/flame/channel.py index e4d64dd8a..4b3f0c937 100644 --- a/lib/python/flame/channel.py +++ b/lib/python/flame/channel.py @@ -26,10 +26,14 @@ from .common.typing import Scalar from .common.util import run_async from .config import GROUPBY_DEFAULT_GROUP -from .end import End +from .end import KEY_END_STATE, VAL_END_STATE_RECVD, End logger = logging.getLogger(__name__) +KEY_CH_STATE = 'state' +VAL_CH_STATE_RECV = 'recv' +VAL_CH_STATE_SEND = 'send' + class Channel(object): """Channel class.""" @@ -117,12 +121,14 @@ async def inner() -> bool: return result - def one_end(self) -> str: + def one_end(self, state: Union[None, str] = None) -> str: """Return one end out of all ends.""" - return self.ends()[0] + return self.ends(state)[0] - def ends(self) -> list[str]: + def ends(self, state: 
Union[None, str] = None) -> list[str]: """Return a list of end ids.""" + if state == VAL_CH_STATE_RECV or state == VAL_CH_STATE_SEND: + self.properties[KEY_CH_STATE] = state async def inner(): selected = self._selector.select(self._ends, self.properties) @@ -198,17 +204,94 @@ async def _get(): payload, status = run_async(_get(), self._backend.loop()) + if self.has(end_id): + # set a property that says a message was received for the end + self._ends[end_id].set_property(KEY_END_STATE, VAL_END_STATE_RECVD) + return cloudpickle.loads(payload) if payload and status else None - def recv_fifo(self, end_ids: list[str]) -> Tuple[str, Any]: + def recv_fifo(self, + end_ids: list[str], + first_k: int = 0) -> Tuple[str, Any]: """Receive a message per end from a list of ends. The message arrival order among ends is not fixed. Messages are yielded in a FIFO manner. This method is not thread-safe. + + Parameters + ---------- + end_ids: a list of ends to receive a message from + first_k: an integer argument to restrict the number of ends + to receive a message from. The default value (= 0) + means that we'd like to receive messages from all + ends in the list. If first_k > len(end_ids), + first_k is set to len(end_ids). + + Returns + ------- + The function yields a pair: end id and message """ + logger.debug(f"first_k = {first_k}, len(end_ids) = {len(end_ids)}") + + first_k = min(first_k, len(end_ids)) + if first_k <= 0: + # a negative value in first_k is an error + # we handle it by setting first_k to the length of the array + first_k = len(end_ids) + + # DO NOT CHANGE self.tmpq to a local variable. + # With aiostream, an update to a local variable looks incorrect, + # but with an instance variable, the update is + # done correctly. + # + # A temporary asyncio queue to store messages in a FIFO manner + self.tmpq = None + + async def _put_message_to_tmpq_inner(): + # self.tmpq must be created in the _backend loop + self.tmpq = asyncio.Queue() + _ = asyncio.create_task( + self._streamer_for_recv_fifo(end_ids, first_k)) + + async def _get_message_inner(): + return await self.tmpq.get() + + # first, create an asyncio task to fetch messages and put them into a temp queue; + # _put_message_to_tmpq_inner works as if it is a non-blocking call + # because a task is created within it + _, _ = run_async(_put_message_to_tmpq_inner(), self._backend.loop()) + + # the _get_message_inner() coroutine fetches a message from the temp + # queue; we call this coroutine first_k times + for _ in range(first_k): + result, status = run_async(_get_message_inner(), + self._backend.loop()) + (end_id, payload) = result + logger.debug(f"get payload for {end_id}") + + if self.has(end_id): + logger.debug(f"channel got a msg for {end_id}") + # set a property to indicate that a message was received + # for the end + self._ends[end_id].set_property(KEY_END_STATE, + VAL_END_STATE_RECVD) + else: + logger.debug(f"channel has no end id {end_id} for msg") + + msg = cloudpickle.loads(payload) if payload and status else None + yield end_id, msg - async def _get(end_id) -> Tuple[str, Any]: + async def _streamer_for_recv_fifo(self, end_ids: list[str], first_k: int): + """Read messages in a FIFO fashion. + + This method reads messages from queues associated with each end + and puts the first first_k messages into a queue; + the remaining messages are saved back into a variable (peek_buf) + of their corresponding end so that they can be read later. 
+ """ + + async def _get_inner(end_id) -> Tuple[str, Any]: + if not self.has(end_id): + # can't receive message from end_id + yield end_id, None @@ -221,40 +304,43 @@ async def _get(end_id) -> Tuple[str, Any]: yield end_id, payload - async def _streamer(tmpq): - runs = [_get(end_id) for end_id in end_ids] - - merged = stream.merge(*runs) - async with merged.stream() as streamer: - async for result in streamer: - await tmpq.put(result) - - # a temporary aysncio queue to store messages in a FIFO manner. - # we define this varialbe to make sure it is visiable - # in both _inner1() and _inner2() - tmpq = None - - async def _inner1(): - nonlocal tmpq - # tmpq must be created in the _backend loop - tmpq = asyncio.Queue() - _ = asyncio.create_task(_streamer(tmpq)) - - async def _inner2(): - return await tmpq.get() - - # first, create an asyncio task to fetch messages and put a temp queue - # _inner1 works as if it is a non-blocking call - # because a task is created within it - _, _ = run_async(_inner1(), self._backend.loop()) - - # the _inner2() coroutine fetches a message from the temp queue - # we call this coroutine the number of end_ids by iterating end_ids - for _ in end_ids: - result, status = run_async(_inner2(), self._backend.loop()) - (end_id, payload) = result - msg = cloudpickle.loads(payload) if payload and status else None - yield end_id, msg + runs = [_get_inner(end_id) for end_id in end_ids] + + # DO NOT CHANGE self.count to a local variable. + # With aiostream, an update to a local variable looks incorrect, + # but with an instance variable, the update is + # done correctly. + self.count = 0 + merged = stream.merge(*runs) + async with merged.stream() as streamer: + logger.debug(f"0) cnt: {self.count}, first_k: {first_k}") + async for result in streamer: + (end_id, payload) = result + logger.debug(f"1) end id: {end_id}, cnt: {self.count}") + + self.count += 1 + logger.debug(f"2) end id: {end_id}, cnt: {self.count}") + if self.count <= first_k: + logger.debug(f"3) end id: {end_id}, cnt: {self.count}") + await self.tmpq.put(result) + + else: + logger.debug(f"4) end id: {end_id}, cnt: {self.count}") + # We have already put the first first_k messages into + # a queue. + # + # Now we need to save the remaining messages which + # were already taken out from each end's rcv queue. + # In order not to lose those messages, we use peek_buf + # in the end object. + + # WARNING: peek_buf must be None; if not, we called + # peek() somewhere else and then called recv_fifo() + # before recv() was called. + # To detect this potential issue, an assert is placed here. 
+ assert self._ends[end_id].peek_buf is None + + self._ends[end_id].peek_buf = payload def peek(self, end_id): """Peek rxq of end_id and return data if queue is not empty.""" diff --git a/lib/python/flame/channel_manager.py b/lib/python/flame/channel_manager.py index dcb60dc7b..e1a156a5a 100644 --- a/lib/python/flame/channel_manager.py +++ b/lib/python/flame/channel_manager.py @@ -204,6 +204,10 @@ def cleanup(self): ch.cleanup() async def _inner(backend): + # TODO: need better mechanism to wait tx completion + # as a temporary measure, sleep 5 seconds + await asyncio.sleep(5) + # clean up backend await backend.cleanup() diff --git a/lib/python/flame/config.py b/lib/python/flame/config.py index 75f45660b..9d9ef62b7 100644 --- a/lib/python/flame/config.py +++ b/lib/python/flame/config.py @@ -90,6 +90,9 @@ class OptimizerType(Enum): FEDADAGRAD = 2 # FedAdaGrad FEDADAM = 3 # FedAdam FEDYOGI = 4 # FedYogi + # FedBuff from https://arxiv.org/pdf/1903.03934.pdf and + # https://arxiv.org/pdf/2111.04877.pdf + FEDBUFF = 5 class SelectorType(Enum): @@ -97,6 +100,7 @@ class SelectorType(Enum): DEFAULT = 1 # default RANDOM = 2 # random + FEDBUFF = 3 # fedbuff REALM_SEPARATOR = '/' diff --git a/lib/python/flame/end.py b/lib/python/flame/end.py index 5d10ab76d..6755714f1 100644 --- a/lib/python/flame/end.py +++ b/lib/python/flame/end.py @@ -20,6 +20,10 @@ from .common.typing import Scalar +KEY_END_STATE = 'state' +VAL_END_STATE_RECVD = 'recvd' +VAL_END_STATE_NONE = '' + class End(object): """End class.""" diff --git a/lib/python/flame/examples/async_hier_mnist/__init__.py b/lib/python/flame/examples/async_hier_mnist/__init__.py new file mode 100644 index 000000000..00b0536f7 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/python/flame/examples/async_hier_mnist/middle_aggregator/__init__.py b/lib/python/flame/examples/async_hier_mnist/middle_aggregator/__init__.py new file mode 100644 index 000000000..00b0536f7 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/middle_aggregator/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/python/flame/examples/async_hier_mnist/middle_aggregator/config_uk.json b/lib/python/flame/examples/async_hier_mnist/middle_aggregator/config_uk.json new file mode 100644 index 000000000..33bc7dd95 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/middle_aggregator/config_uk.json @@ -0,0 +1,102 @@ +{ + "taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa742", + "backend": "p2p", + "brokers": [ + { + "host": "localhost", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from mid aggregator to global aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "global-channel", + "pair": [ + "top-aggregator", + "middle-aggregator" + ], + "funcTags": { + "top-aggregator": [ + "distribute", + "aggregate" + ], + "middle-aggregator": [ + "fetch", + "upload" + ] + } + }, + { + "description": "Model update is sent from mid aggregator to trainer and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default/us/west/org1", + "default/uk/london/org2" + ] + }, + "name": "param-channel", + "pair": [ + "middle-aggregator", + "trainer" + ], + "funcTags": { + "middle-aggregator": [ + "distribute", + "aggregate" + ], + "trainer": [ + "fetch", + "upload" + ] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is aggregation goal for fedbuff", + "aggGoal": 1 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "http://flame-mlflow:5000" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency level", + "c": 2 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/uk/london/org2/flame", + "role": "middle-aggregator" +} diff --git a/lib/python/flame/examples/async_hier_mnist/middle_aggregator/config_us.json b/lib/python/flame/examples/async_hier_mnist/middle_aggregator/config_us.json new file mode 100644 index 000000000..d1aedef88 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/middle_aggregator/config_us.json @@ -0,0 +1,102 @@ +{ + "taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa741", + "backend": "p2p", + "brokers": [ + { + "host": "localhost", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from mid aggregator to global aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "global-channel", + "pair": [ + "top-aggregator", + "middle-aggregator" + ], + "funcTags": { + "top-aggregator": [ + "distribute", + "aggregate" + ], + "middle-aggregator": [ + "fetch", + "upload" + ] + } + }, + { + "description": "Model update is sent from mid aggregator to trainer and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default/us/west/org1", + "default/uk/london/org2" + ] + }, + "name": "param-channel", + "pair": [ + "middle-aggregator", + "trainer" + ], + "funcTags": { + "middle-aggregator": [ + "distribute", + "aggregate" + ], + "trainer": [ + "fetch", + "upload" + ] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + 
"dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is aggregation goal for fedbuff", + "aggGoal": 1 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "http://flame-mlflow:5000" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency level", + "c": 2 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/us/west/org1/flame", + "role": "middle-aggregator" +} diff --git a/lib/python/flame/examples/async_hier_mnist/middle_aggregator/main.py b/lib/python/flame/examples/async_hier_mnist/middle_aggregator/main.py new file mode 100644 index 000000000..acd5e582d --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/middle_aggregator/main.py @@ -0,0 +1,65 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""HIRE_MNIST horizontal hierarchical FL middle level aggregator for Keras.""" + +import logging + +from flame.config import Config +from flame.mode.horizontal.asyncfl.middle_aggregator import MiddleAggregator +# the following needs to be imported to let the flame know +# this aggregator works on tensorflow model +from tensorflow import keras + +logger = logging.getLogger(__name__) + + +class KerasMnistMiddleAggregator(MiddleAggregator): + """Keras Mnist Middle Level Aggregator.""" + + def __init__(self, config: Config) -> None: + """Initialize a class instance.""" + self.config = config + + def initialize(self): + """Initialize role.""" + pass + + def load_data(self) -> None: + """Load a test dataset.""" + pass + + def train(self) -> None: + """Train a model.""" + pass + + def evaluate(self) -> None: + """Evaluate (test) a model.""" + pass + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='') + parser.add_argument('config', nargs='?', default="./config.json") + + args = parser.parse_args() + + config = Config(args.config) + + a = KerasMnistMiddleAggregator(config) + a.compose() + a.run() diff --git a/lib/python/flame/examples/async_hier_mnist/top_aggregator/__init__.py b/lib/python/flame/examples/async_hier_mnist/top_aggregator/__init__.py new file mode 100644 index 000000000..00b0536f7 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/top_aggregator/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/python/flame/examples/async_hier_mnist/top_aggregator/config.json b/lib/python/flame/examples/async_hier_mnist/top_aggregator/config.json new file mode 100644 index 000000000..0bf4d24f4 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/top_aggregator/config.json @@ -0,0 +1,77 @@ +{ + "taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa740", + "backend": "p2p", + "brokers": [ + { + "host": "localhost", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from mid aggregator to global aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "global-channel", + "pair": [ + "top-aggregator", + "middle-aggregator" + ], + "funcTags": { + "top-aggregator": [ + "distribute", + "aggregate" + ], + "middle-aggregator": [ + "fetch", + "upload" + ] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 10, + "//": "aggGoal is aggregation goal for fedbuff", + "aggGoal": 1 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "http://flame-mlflow:5000" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency level", + "c": 2 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "", + "role": "top-aggregator" +} diff --git a/lib/python/flame/examples/async_hier_mnist/top_aggregator/main.py b/lib/python/flame/examples/async_hier_mnist/top_aggregator/main.py new file mode 100644 index 000000000..0e7d9914d --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/top_aggregator/main.py @@ -0,0 +1,89 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 +"""HIRE_MNIST horizontal hierarchical FL top level aggregator for Keras.""" + +import logging + +from flame.config import Config +from flame.dataset import Dataset +from flame.mode.horizontal.asyncfl.top_aggregator import TopAggregator +from tensorflow import keras +from tensorflow.keras import layers + +logger = logging.getLogger(__name__) + + +class KerasMnistTopAggregator(TopAggregator): + """Keras Mnist Top Level Aggregator.""" + + def __init__(self, config: Config) -> None: + """Initialize a class instance.""" + self.config = config + self.model = None + + self.dataset: Dataset = None + + self.num_classes = 10 + self.input_shape = (28, 28, 1) + + def initialize(self): + """Initialize role.""" + model = keras.Sequential([ + keras.Input(shape=self.input_shape), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(self.num_classes, activation="softmax"), + ]) + + model.compile(loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"]) + + self.model = model + + def load_data(self) -> None: + """Load a test dataset.""" + # Implement this if loading data is needed in aggregator + pass + + def train(self) -> None: + """Train a model.""" + # Implement this if training is needed in aggregator + pass + + def evaluate(self) -> None: + """Evaluate (test) a model.""" + # Implement this if testing is needed in aggregator + pass + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='') + parser.add_argument('config', nargs='?', default="./config.json") + + args = parser.parse_args() + + config = Config(args.config) + + a = KerasMnistTopAggregator(config) + a.compose() + a.run() diff --git a/lib/python/flame/examples/async_hier_mnist/trainer/__init__.py b/lib/python/flame/examples/async_hier_mnist/trainer/__init__.py new file mode 100644 index 000000000..00b0536f7 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/trainer/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/python/flame/examples/async_hier_mnist/trainer/config_uk1.json b/lib/python/flame/examples/async_hier_mnist/trainer/config_uk1.json new file mode 100644 index 000000000..f8fa0ce03 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/trainer/config_uk1.json @@ -0,0 +1,78 @@ +{ + "taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa745", + "backend": "p2p", + "brokers": [ + { + "host": "localhost", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from mid aggregator to trainer and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default/us/west/org1", + "default/uk/london/org2" + ] + }, + "name": "param-channel", + "pair": [ + "middle-aggregator", + "trainer" + ], + "funcTags": { + "middle-aggregator": [ + "distribute", + "aggregate" + ], + "trainer": [ + "fetch", + "upload" + ] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is aggregation goal for fedbuff", + "aggGoal": 1 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "http://flame-mlflow:5000" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency level", + "c": 2 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/uk/london/org2/machine1", + "role": "trainer" +} diff --git a/lib/python/flame/examples/async_hier_mnist/trainer/config_uk2.json b/lib/python/flame/examples/async_hier_mnist/trainer/config_uk2.json new file mode 100644 index 000000000..239e6f914 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/trainer/config_uk2.json @@ -0,0 +1,78 @@ +{ + "taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa746", + "backend": "p2p", + "brokers": [ + { + "host": "localhost", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from mid aggregator to trainer and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default/us/west/org1", + "default/uk/london/org2" + ] + }, + "name": "param-channel", + "pair": [ + "middle-aggregator", + "trainer" + ], + "funcTags": { + "middle-aggregator": [ + "distribute", + "aggregate" + ], + "trainer": [ + "fetch", + "upload" + ] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is aggregation goal for fedbuff", + "aggGoal": 1 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "http://flame-mlflow:5000" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency level", + "c": 2 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/uk/london/org2/machine2", + "role": "trainer" +} diff --git a/lib/python/flame/examples/async_hier_mnist/trainer/config_us1.json b/lib/python/flame/examples/async_hier_mnist/trainer/config_us1.json new file mode 100644 index 000000000..855ccd948 --- 
/dev/null +++ b/lib/python/flame/examples/async_hier_mnist/trainer/config_us1.json @@ -0,0 +1,78 @@ +{ + "taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa743", + "backend": "p2p", + "brokers": [ + { + "host": "localhost", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from mid aggregator to trainer and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default/us/west/org1", + "default/uk/london/org2" + ] + }, + "name": "param-channel", + "pair": [ + "middle-aggregator", + "trainer" + ], + "funcTags": { + "middle-aggregator": [ + "distribute", + "aggregate" + ], + "trainer": [ + "fetch", + "upload" + ] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is aggregation goal for fedbuff", + "aggGoal": 1 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "http://flame-mlflow:5000" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency level", + "c": 2 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/us/west/org1/machine1", + "role": "trainer" +} diff --git a/lib/python/flame/examples/async_hier_mnist/trainer/config_us2.json b/lib/python/flame/examples/async_hier_mnist/trainer/config_us2.json new file mode 100644 index 000000000..4a42bc27d --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/trainer/config_us2.json @@ -0,0 +1,78 @@ +{ + "taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa744", + "backend": "p2p", + "brokers": [ + { + "host": "localhost", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from mid aggregator to trainer and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default/us/west/org1", + "default/uk/london/org2" + ] + }, + "name": "param-channel", + "pair": [ + "middle-aggregator", + "trainer" + ], + "funcTags": { + "middle-aggregator": [ + "distribute", + "aggregate" + ], + "trainer": [ + "fetch", + "upload" + ] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is aggregation goal for fedbuff", + "aggGoal": 1 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "http://flame-mlflow:5000" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency level", + "c": 2 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/us/west/org1/machine2", + "role": "trainer" +} diff --git a/lib/python/flame/examples/async_hier_mnist/trainer/main.py b/lib/python/flame/examples/async_hier_mnist/trainer/main.py new file mode 100644 index 000000000..1b2a539b7 --- /dev/null +++ b/lib/python/flame/examples/async_hier_mnist/trainer/main.py @@ -0,0 +1,140 @@ +# Copyright 2023 Cisco Systems, Inc. 
and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""HIRE_MNIST horizontal hierarchical FL trainer for Keras.""" + +import logging +from random import randrange +from statistics import mean + +import numpy as np +from flame.config import Config +from flame.mode.horizontal.trainer import Trainer +from tensorflow import keras +from tensorflow.keras import layers + +logger = logging.getLogger(__name__) + + +class KerasMnistTrainer(Trainer): + """Keras Mnist Trainer.""" + + def __init__(self, config: Config) -> None: + """Initialize a class instance.""" + self.config = config + self.dataset_size = 0 + + self.num_classes = 10 + self.input_shape = (28, 28, 1) + + self.model = None + self._x_train = None + self._y_train = None + self._x_test = None + self._y_test = None + + self.epochs = self.config.hyperparameters['epochs'] + self.batch_size = 128 + if 'batchSize' in self.config.hyperparameters: + self.batch_size = self.config.hyperparameters['batchSize'] + + def initialize(self) -> None: + """Initialize role.""" + model = keras.Sequential([ + keras.Input(shape=self.input_shape), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(self.num_classes, activation="softmax"), + ]) + + model.compile(loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"]) + + self.model = model + + def load_data(self) -> None: + """Load data.""" + # the data, split between train and test sets + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + + split_n = 10 + index = randrange(split_n) + # reduce train sample size to reduce the runtime + x_train = np.split(x_train, split_n)[index] + y_train = np.split(y_train, split_n)[index] + x_test = np.split(x_test, split_n)[index] + y_test = np.split(y_test, split_n)[index] + + # Scale images to the [0, 1] range + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + # Make sure images have shape (28, 28, 1) + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, self.num_classes) + y_test = keras.utils.to_categorical(y_test, self.num_classes) + + self._x_train = x_train + self._y_train = y_train + self._x_test = x_test + self._y_test = y_test + + def train(self) -> None: + """Train a model.""" + history = self.model.fit(self._x_train, + self._y_train, + batch_size=self.batch_size, + epochs=self.epochs, + validation_split=0.1) + + # save dataset size so that the info can be shared with aggregator + self.dataset_size = len(self._x_train) + + loss = mean(history.history['loss']) + accuracy = mean(history.history['accuracy']) + self.update_metrics({'loss': loss, 'accuracy': accuracy}) + + def evaluate(self) -> None: + 
"""Evaluate a model.""" + score = self.model.evaluate(self._x_test, self._y_test, verbose=0) + + logger.info(f"Test loss: {score[0]}") + logger.info(f"Test accuracy: {score[1]}") + + # update metrics after each evaluation so that the metrics can be + # logged in a model registry. + self.update_metrics({'test-loss': score[0], 'test-accuracy': score[1]}) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='') + parser.add_argument('config', nargs='?', default="./config.json") + + args = parser.parse_args() + + config = Config(args.config) + + t = KerasMnistTrainer(config) + t.compose() + t.run() diff --git a/lib/python/flame/examples/async_mnist/__init__.py b/lib/python/flame/examples/async_mnist/__init__.py new file mode 100644 index 000000000..00b0536f7 --- /dev/null +++ b/lib/python/flame/examples/async_mnist/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/python/flame/examples/async_mnist/aggregator/__init__.py b/lib/python/flame/examples/async_mnist/aggregator/__init__.py new file mode 100644 index 000000000..00b0536f7 --- /dev/null +++ b/lib/python/flame/examples/async_mnist/aggregator/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/python/flame/examples/async_mnist/aggregator/config.json b/lib/python/flame/examples/async_mnist/aggregator/config.json new file mode 100644 index 000000000..a8646f70c --- /dev/null +++ b/lib/python/flame/examples/async_mnist/aggregator/config.json @@ -0,0 +1,73 @@ +{ + "taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa742", + "backend": "p2p", + "brokers": [ + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from trainer to aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "param-channel", + "pair": [ + "trainer", + "aggregator" + ], + "funcTags": { + "aggregator": [ + "distribute", + "aggregate" + ], + "trainer": [ + "fetch", + "upload" + ] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 10, + "//": "aggGoal is aggregation goal for fedbuff", + "aggGoal": 2 + }, + "baseModel": { + "name": "", + "version": 2 + }, + "job": { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency level", + "c": 4 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default", + "role": "aggregator" +} diff --git a/lib/python/flame/examples/async_mnist/aggregator/main.py b/lib/python/flame/examples/async_mnist/aggregator/main.py new file mode 100644 index 000000000..6f75e2eb0 --- /dev/null +++ b/lib/python/flame/examples/async_mnist/aggregator/main.py @@ -0,0 +1,89 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 +"""MNIST asynchronous horizontal FL aggregator for Keras.""" + +import logging + +from flame.config import Config +from flame.dataset import Dataset +from flame.mode.horizontal.asyncfl.top_aggregator import TopAggregator +from tensorflow import keras +from tensorflow.keras import layers + +logger = logging.getLogger(__name__) + + +class KerasMnistAggregator(TopAggregator): + """Keras Mnist Aggregator.""" + + def __init__(self, config: Config) -> None: + """Initialize a class instance.""" + self.config = config + self.model = None + + self.dataset: Dataset = None + + self.num_classes = 10 + self.input_shape = (28, 28, 1) + + def initialize(self): + """Initialize role.""" + model = keras.Sequential([ + keras.Input(shape=self.input_shape), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(self.num_classes, activation="softmax"), + ]) + + model.compile(loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"]) + + self.model = model + + def load_data(self) -> None: + """Load a test dataset.""" + # Implement this if loading data is needed in aggregator + pass + + def train(self) -> None: + """Train a model.""" + # Implement this if training is needed in aggregator + pass + + def evaluate(self) -> None: + """Evaluate (test) a model.""" + # Implement this if testing is needed in aggregator + pass + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='') + parser.add_argument('config', nargs='?', default="./config.json") + + args = parser.parse_args() + + config = Config(args.config) + + a = KerasMnistAggregator(config) + a.compose() + a.run() diff --git a/lib/python/flame/examples/async_mnist/trainer/__init__.py b/lib/python/flame/examples/async_mnist/trainer/__init__.py new file mode 100644 index 000000000..00b0536f7 --- /dev/null +++ b/lib/python/flame/examples/async_mnist/trainer/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/python/flame/examples/async_mnist/trainer/config1.json b/lib/python/flame/examples/async_mnist/trainer/config1.json new file mode 100644 index 000000000..4766ce89e --- /dev/null +++ b/lib/python/flame/examples/async_mnist/trainer/config1.json @@ -0,0 +1,71 @@ +{ + "taskid": "505f9fc483cf4df68a2409257b5fad7d3c580370", + "backend": "p2p", + "brokers": [ + { + "host": "broker.hivemq.com", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from trainer to aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "param-channel", + "pair": [ + "trainer", + "aggregator" + ], + "funcTags": { + "aggregator": ["distribute", "aggregate"], + "trainer": ["fetch", "upload"] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is irrelevant since it's a trainer", + "aggGoal": 2 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency (this irrelevant since it's traner)", + "c": 4 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/us/west", + "role": "trainer" +} diff --git a/lib/python/flame/examples/async_mnist/trainer/config2.json b/lib/python/flame/examples/async_mnist/trainer/config2.json new file mode 100644 index 000000000..e97b28b92 --- /dev/null +++ b/lib/python/flame/examples/async_mnist/trainer/config2.json @@ -0,0 +1,71 @@ +{ + "taskid": "505f9fc483cf4df68a2409257b5fad7d3c580371", + "backend": "p2p", + "brokers": [ + { + "host": "broker.hivemq.com", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from trainer to aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "param-channel", + "pair": [ + "trainer", + "aggregator" + ], + "funcTags": { + "aggregator": ["distribute", "aggregate"], + "trainer": ["fetch", "upload"] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is irrelevant since it's a trainer", + "aggGoal": 2 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency (this irrelevant since it's traner)", + "c": 4 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/us/west", + "role": "trainer" +} diff --git a/lib/python/flame/examples/async_mnist/trainer/config3.json b/lib/python/flame/examples/async_mnist/trainer/config3.json new file mode 100644 index 000000000..f87e0e2be --- /dev/null +++ b/lib/python/flame/examples/async_mnist/trainer/config3.json @@ -0,0 +1,71 @@ +{ + "taskid": "505f9fc483cf4df68a2409257b5fad7d3c580372", + "backend": "p2p", + "brokers": [ + { + "host": 
"broker.hivemq.com", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from trainer to aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "param-channel", + "pair": [ + "trainer", + "aggregator" + ], + "funcTags": { + "aggregator": ["distribute", "aggregate"], + "trainer": ["fetch", "upload"] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is irrelevant since it's a trainer", + "aggGoal": 2 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency (this irrelevant since it's traner)", + "c": 4 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/us/west", + "role": "trainer" +} diff --git a/lib/python/flame/examples/async_mnist/trainer/config4.json b/lib/python/flame/examples/async_mnist/trainer/config4.json new file mode 100644 index 000000000..1374553ca --- /dev/null +++ b/lib/python/flame/examples/async_mnist/trainer/config4.json @@ -0,0 +1,71 @@ +{ + "taskid": "505f9fc483cf4df68a2409257b5fad7d3c580373", + "backend": "p2p", + "brokers": [ + { + "host": "broker.hivemq.com", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "channels": [ + { + "description": "Model update is sent from trainer to aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "param-channel", + "pair": [ + "trainer", + "aggregator" + ], + "funcTags": { + "aggregator": ["distribute", "aggregate"], + "trainer": ["fetch", "upload"] + } + } + ], + "dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 32, + "learningRate": 0.01, + "rounds": 5, + "//": "aggGoal is irrelevant since it's a trainer", + "aggGoal": 2 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job" : { + "id": "622a358619ab59012eabeefb", + "name": "mnist" + }, + "registry": { + "sort": "dummy", + "uri": "" + }, + "selector": { + "sort": "fedbuff", + "kwargs": { + "//": "c: concurrency (this irrelevant since it's traner)", + "c": 4 + } + }, + "optimizer": { + "sort": "fedbuff", + "kwargs": {} + }, + "maxRunTime": 300, + "realm": "default/us/west", + "role": "trainer" +} diff --git a/lib/python/flame/examples/async_mnist/trainer/main.py b/lib/python/flame/examples/async_mnist/trainer/main.py new file mode 100644 index 000000000..0f2cbc08a --- /dev/null +++ b/lib/python/flame/examples/async_mnist/trainer/main.py @@ -0,0 +1,140 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""MNIST horizontal FL trainer for Keras.""" + +import logging +from random import randrange +from statistics import mean + +import numpy as np +from flame.config import Config +from flame.mode.horizontal.trainer import Trainer +from tensorflow import keras +from tensorflow.keras import layers + +logger = logging.getLogger(__name__) + + +class KerasMnistTrainer(Trainer): + """Keras Mnist Trainer.""" + + def __init__(self, config: Config) -> None: + """Initialize a class instance.""" + self.config = config + self.dataset_size = 0 + + self.num_classes = 10 + self.input_shape = (28, 28, 1) + + self.model = None + self._x_train = None + self._y_train = None + self._x_test = None + self._y_test = None + + self.epochs = self.config.hyperparameters['epochs'] + self.batch_size = 128 + if 'batchSize' in self.config.hyperparameters: + self.batch_size = self.config.hyperparameters['batchSize'] + + def initialize(self) -> None: + """Initialize role.""" + model = keras.Sequential([ + keras.Input(shape=self.input_shape), + layers.Conv2D(32, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Conv2D(64, kernel_size=(3, 3), activation="relu"), + layers.MaxPooling2D(pool_size=(2, 2)), + layers.Flatten(), + layers.Dropout(0.5), + layers.Dense(self.num_classes, activation="softmax"), + ]) + + model.compile(loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"]) + + self.model = model + + def load_data(self) -> None: + """Load data.""" + # the data, split between train and test sets + (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + + split_n = 10 + index = randrange(split_n) + # reduce train sample size to reduce the runtime + x_train = np.split(x_train, split_n)[index] + y_train = np.split(y_train, split_n)[index] + x_test = np.split(x_test, split_n)[index] + y_test = np.split(y_test, split_n)[index] + + # Scale images to the [0, 1] range + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + # Make sure images have shape (28, 28, 1) + x_train = np.expand_dims(x_train, -1) + x_test = np.expand_dims(x_test, -1) + + # convert class vectors to binary class matrices + y_train = keras.utils.to_categorical(y_train, self.num_classes) + y_test = keras.utils.to_categorical(y_test, self.num_classes) + + self._x_train = x_train + self._y_train = y_train + self._x_test = x_test + self._y_test = y_test + + def train(self) -> None: + """Train a model.""" + history = self.model.fit(self._x_train, + self._y_train, + batch_size=self.batch_size, + epochs=self.epochs, + validation_split=0.1) + + # save dataset size so that the info can be shared with aggregator + self.dataset_size = len(self._x_train) + + loss = mean(history.history['loss']) + accuracy = mean(history.history['accuracy']) + self.update_metrics({'loss': loss, 'accuracy': accuracy}) + + def evaluate(self) -> None: + """Evaluate a model.""" + score = self.model.evaluate(self._x_test, self._y_test, verbose=0) + + logger.info(f"Test loss: {score[0]}") + logger.info(f"Test accuracy: {score[1]}") + + # update metrics after each evaluation so that the metrics can be + # logged in a model registry. 
+ self.update_metrics({'test-loss': score[0], 'test-accuracy': score[1]}) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='') + parser.add_argument('config', nargs='?', default="./config.json") + + args = parser.parse_args() + + config = Config(args.config) + + t = KerasMnistTrainer(config) + t.compose() + t.run() diff --git a/lib/python/flame/mode/composer.py b/lib/python/flame/mode/composer.py index 99fdb519a..298c252b5 100644 --- a/lib/python/flame/mode/composer.py +++ b/lib/python/flame/mode/composer.py @@ -26,7 +26,6 @@ class Composer(object): """Composer enables composition of tasklets.""" - # def __init__(self) -> None: """Initialize the class.""" # maintain tasklet chains @@ -133,6 +132,32 @@ def run(self) -> None: visited.add(child) q.put(child) + logger.debug("end of run") + + def print(self): + """Print the chain of tasklets. + + This function is for debugging. + """ + tasklet = next(iter(self.chain)) + # get the first tasklet in the chain + root = tasklet.get_root() + + # traverse tasklets and print tasklet details + q = Queue() + q.put(root) + while not q.empty(): + tasklet = q.get() + + print("-----") + print(tasklet) + + # put unvisited children of a selected tasklet + for child in self.chain[tasklet]: + q.put(child) + print("=====") + print("done with printing chain") + class ComposerContext(object): """ComposerContext maintains a context of composer.""" diff --git a/lib/python/flame/mode/horizontal/asyncfl/__init__.py b/lib/python/flame/mode/horizontal/asyncfl/__init__.py new file mode 100644 index 000000000..00b0536f7 --- /dev/null +++ b/lib/python/flame/mode/horizontal/asyncfl/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/python/flame/mode/horizontal/asyncfl/middle_aggregator.py b/lib/python/flame/mode/horizontal/asyncfl/middle_aggregator.py new file mode 100644 index 000000000..19737ec7d --- /dev/null +++ b/lib/python/flame/mode/horizontal/asyncfl/middle_aggregator.py @@ -0,0 +1,249 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 +"""Asynchronous horizontal FL middle level aggregator.""" + +import logging +import time + +from ....channel import VAL_CH_STATE_RECV, VAL_CH_STATE_SEND +from ....optimizer.train_result import TrainResult +from ...composer import Composer +from ...message import MessageType +from ...tasklet import Loop, Tasklet +from ..middle_aggregator import (TAG_AGGREGATE, TAG_DISTRIBUTE, TAG_FETCH, + TAG_UPLOAD) +from ..middle_aggregator import MiddleAggregator as SyncMidAgg + +logger = logging.getLogger(__name__) + +# 60 second wait time until a trainer appears in a channel +WAIT_TIME_FOR_TRAINER = 60 + + +class MiddleAggregator(SyncMidAgg): + """Asynchronous middle level aggregator. + + It acts as a proxy between the top level aggregator and trainers. + """ + + def internal_init(self) -> None: + """Initialize internal state for role.""" + super().internal_init() + + self._agg_goal_cnt = 0 + self._agg_goal_weights = None + self._agg_goal = 0 + if 'aggGoal' in self.config.hyperparameters: + self._agg_goal = self.config.hyperparameters['aggGoal'] + + def _reset_agg_goal_variables(self): + logger.debug("reset agg goal variables") + # reset agg goal count + self._agg_goal_cnt = 0 + + # reset agg goal weights + self._agg_goal_weights = None + + def _fetch_weights(self, tag: str) -> None: + logger.debug("calling _fetch_weights") + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found with tag {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end(VAL_CH_STATE_RECV) + msg = channel.recv(end) + + if MessageType.WEIGHTS in msg: + self.weights = msg[MessageType.WEIGHTS] + + if MessageType.EOT in msg: + self._work_done = msg[MessageType.EOT] + + if MessageType.ROUND in msg: + self._round = msg[MessageType.ROUND] + + def _distribute_weights(self, tag: str) -> None: + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found for tag {tag}") + return + + # this call waits for at least one peer to join this channel + self.trainer_no_show = channel.await_join(WAIT_TIME_FOR_TRAINER) + if self.trainer_no_show: + logger.debug("channel await join timed out") + # send dummy weights to unblock top aggregator + self._send_dummy_weights(TAG_UPLOAD) + return + + for end in channel.ends(VAL_CH_STATE_SEND): + logger.debug(f"sending weights to {end}") + channel.send( + end, { + MessageType.WEIGHTS: self.weights, + MessageType.ROUND: self._round, + MessageType.MODEL_VERSION: self._round + }) + + def _aggregate_weights(self, tag: str) -> None: + """Aggregate local model weights asynchronously. + + This method is overridden from the one in the synchronous middle aggregator + (..middle_aggregator).
+ """ + channel = self.cm.get_by_tag(tag) + if not channel: + return + + if self._agg_goal_weights is None: + logger.debug(f"type of weights: {type(self.weights)}") + self._agg_goal_weights = self.weights.copy() + + # receive local model parameters from a trainer who arrives first + end, msg = next(channel.recv_fifo(channel.ends(VAL_CH_STATE_RECV), 1)) + if not msg: + logger.debug(f"No data from {end}; skipping it") + return + + logger.debug(f"received data from {end}") + + if MessageType.WEIGHTS in msg: + weights = msg[MessageType.WEIGHTS] + + if MessageType.DATASET_SIZE in msg: + count = msg[MessageType.DATASET_SIZE] + + if MessageType.MODEL_VERSION in msg: + version = msg[MessageType.MODEL_VERSION] + + logger.debug(f"{end}'s parameters trained with {count} samples") + + if weights is not None and count > 0: + tres = TrainResult(weights, count, version) + # save training result from trainer in a disk cache + self.cache[end] = tres + + self._agg_goal_weights = self.optimizer.do( + self.cache, + base_weights=self._agg_goal_weights, + total=count, + version=self._round) + # increment agg goal count + self._agg_goal_cnt += 1 + + if self._agg_goal_cnt < self._agg_goal: + # didn't reach the aggregation goal; return + logger.debug("didn't reach agg goal") + logger.debug( + f" current: {self._agg_goal_cnt}; agg goal: {self._agg_goal}") + return + + if self._agg_goal_weights is None: + logger.debug("failed model aggregation") + time.sleep(1) + return + + # set global weights + self.weights = self._agg_goal_weights + + self.dataset_size = count + + def _send_weights(self, tag: str) -> None: + logger.debug("calling _send_weights") + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found with {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end(VAL_CH_STATE_SEND) + channel.send( + end, { + MessageType.WEIGHTS: self.weights, + MessageType.DATASET_SIZE: self.dataset_size, + MessageType.MODEL_VERSION: self._round + }) + logger.debug("sending weights done") + + def _send_dummy_weights(self, tag: str) -> None: + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found with {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end(VAL_CH_STATE_SEND) + + dummy_msg = {MessageType.WEIGHTS: None, MessageType.DATASET_SIZE: 0} + channel.send(end, dummy_msg) + logger.debug("sending dummy weights done") + + def compose(self) -> None: + """Compose role with tasklets.""" + with Composer() as composer: + self.composer = composer + + task_internal_init = Tasklet(self.internal_init) + + task_init = Tasklet(self.initialize) + + task_load_data = Tasklet(self.load_data) + + task_reset_agg_goal_vars = Tasklet(self._reset_agg_goal_variables) + + task_put_dist = Tasklet(self.put, TAG_DISTRIBUTE) + task_put_dist.set_continue_fn(cont_fn=lambda: self.trainer_no_show) + + task_put_upload = Tasklet(self.put, TAG_UPLOAD) + + task_get_aggr = Tasklet(self.get, TAG_AGGREGATE) + + task_get_fetch = Tasklet(self.get, TAG_FETCH) + + task_eval = Tasklet(self.evaluate) + + task_update_round = Tasklet(self.update_round) + + task_end_of_training = Tasklet(self.inform_end_of_training) + + # create a loop object with loop exit condition function + loop = Loop(loop_check_fn=lambda: self._work_done) + + # create a loop object for asyncfl to manage concurrency as well as 
+ # aggregation goal + asyncfl_loop = Loop( + loop_check_fn=lambda: self._agg_goal_cnt == self._agg_goal) + + task_internal_init >> task_load_data >> task_init >> loop( + task_get_fetch >> task_reset_agg_goal_vars >> asyncfl_loop( + task_put_dist >> task_get_aggr) >> task_put_upload >> task_eval + >> task_update_round) >> task_end_of_training + + @classmethod + def get_func_tags(cls) -> list[str]: + """Return a list of function tags defined in the middle level aggregator role.""" + return [TAG_DISTRIBUTE, TAG_AGGREGATE, TAG_FETCH, TAG_UPLOAD] diff --git a/lib/python/flame/mode/horizontal/asyncfl/top_aggregator.py b/lib/python/flame/mode/horizontal/asyncfl/top_aggregator.py new file mode 100644 index 000000000..cfdc65a17 --- /dev/null +++ b/lib/python/flame/mode/horizontal/asyncfl/top_aggregator.py @@ -0,0 +1,205 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""Asynchronous horizontal FL top level aggregator.""" + +import logging +import time + +from ....channel import VAL_CH_STATE_RECV, VAL_CH_STATE_SEND +from ....optimizer.train_result import TrainResult +from ...composer import Composer +from ...message import MessageType +from ...tasklet import Loop, Tasklet +from ..top_aggregator import TAG_AGGREGATE, TAG_DISTRIBUTE +from ..top_aggregator import TopAggregator as SyncTopAgg + +logger = logging.getLogger(__name__) + + +class TopAggregator(SyncTopAgg): +    """Asynchronous top level Aggregator implements an ML aggregation role.""" + + def internal_init(self) -> None: + """Initialize internal state for role.""" + super().internal_init() + + self._agg_goal_cnt = 0 + self._agg_goal_weights = None + self._agg_goal = 0 + if 'aggGoal' in self.config.hyperparameters: + self._agg_goal = self.config.hyperparameters['aggGoal'] + + def _reset_agg_goal_variables(self): + logger.debug("reset agg goal variables") + # reset agg goal count + self._agg_goal_cnt = 0 + + # reset agg goal weights + self._agg_goal_weights = None + + def _aggregate_weights(self, tag: str) -> None: + """Aggregate local model weights asynchronously. + + This method is overridden from the one in the synchronous top aggregator + (..top_aggregator). + """ + channel = self.cm.get_by_tag(tag) + if not channel: + return + + if self._agg_goal_weights is None: + logger.debug(f"type of weights: {type(self.weights)}") + self._agg_goal_weights = self.weights.copy() + + # receive local model parameters from a trainer who arrives first + end, msg = next(channel.recv_fifo(channel.ends(VAL_CH_STATE_RECV), 1)) + if not msg: + logger.debug(f"No data from {end}; skipping it") + return + + logger.debug(f"received data from {end}") + + if MessageType.WEIGHTS in msg: + # TODO: client should send delta instead of whole weights; + # in the current implementation without delta transmission, + # fedbuff algorithm's loss function diverges.
+ # This needs refactoring of the optimizer as well as + # trainer code across all the different modes, which involves + # extensive testing of other code. + # The whole change should be done separately to avoid + # too many changes. + weights = msg[MessageType.WEIGHTS] + + if MessageType.DATASET_SIZE in msg: + count = msg[MessageType.DATASET_SIZE] + + if MessageType.MODEL_VERSION in msg: + version = msg[MessageType.MODEL_VERSION] + + logger.debug(f"{end}'s parameters trained with {count} samples") + + if weights is not None and count > 0: + tres = TrainResult(weights, count, version) + # save training result from trainer in a disk cache + self.cache[end] = tres + + self._agg_goal_weights = self.optimizer.do( + self.cache, + base_weights=self._agg_goal_weights, + total=count, + version=self._round) + # increment agg goal count + self._agg_goal_cnt += 1 + + if self._agg_goal_cnt < self._agg_goal: + # didn't reach the aggregation goal; return + logger.debug("didn't reach agg goal") + logger.debug( + f" current: {self._agg_goal_cnt}; agg goal: {self._agg_goal}") + return + + if self._agg_goal_weights is None: + logger.debug("failed model aggregation") + time.sleep(1) + return + + # set global weights + self.weights = self._agg_goal_weights + + # update model with global weights + self._update_model() + + logger.debug(f"aggregation finished for round {self._round}") + + def _distribute_weights(self, tag: str) -> None: + """Distribute a global model in an asynchronous FL fashion. + + This method is overridden from the one in the synchronous top aggregator + (..top_aggregator). + """ + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found for tag {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # before distributing weights, update them from the global model + self._update_weights() + + # send out global model parameters to trainers + for end in channel.ends(VAL_CH_STATE_SEND): + logger.debug(f"sending weights to {end}") + # we use _round to indicate a model version + channel.send( + end, { + MessageType.WEIGHTS: self.weights, + MessageType.ROUND: self._round, + MessageType.MODEL_VERSION: self._round + }) + + def compose(self) -> None: + """Compose role with tasklets.""" + with Composer() as composer: + self.composer = composer + + task_internal_init = Tasklet(self.internal_init) + + task_init = Tasklet(self.initialize) + + task_load_data = Tasklet(self.load_data) + + task_reset_agg_goal_vars = Tasklet(self._reset_agg_goal_variables) + + task_put = Tasklet(self.put, TAG_DISTRIBUTE) + + task_get = Tasklet(self.get, TAG_AGGREGATE) + + task_train = Tasklet(self.train) + + task_eval = Tasklet(self.evaluate) + + task_analysis = Tasklet(self.run_analysis) + + task_save_metrics = Tasklet(self.save_metrics) + + task_increment_round = Tasklet(self.increment_round) + + task_end_of_training = Tasklet(self.inform_end_of_training) + + task_save_params = Tasklet(self.save_params) + + task_save_model = Tasklet(self.save_model) + + # create a loop object with loop exit condition function + loop = Loop(loop_check_fn=lambda: self._work_done) + + # create a loop object for asyncfl to manage concurrency as well as + # aggregation goal + asyncfl_loop = Loop( + loop_check_fn=lambda: self._agg_goal_cnt == self._agg_goal) + + task_internal_init >> task_load_data >> task_init >> loop( + task_reset_agg_goal_vars >> asyncfl_loop( + task_put >> task_get) >> task_train >> task_eval >> + task_analysis >> task_save_metrics >> task_increment_round + ) >>
task_end_of_training >> task_save_params >> task_save_model + + @classmethod + def get_func_tags(cls) -> list[str]: + """Return a list of function tags defined in the top level aggregator role.""" + return [TAG_DISTRIBUTE, TAG_AGGREGATE] diff --git a/lib/python/flame/mode/horizontal/middle_aggregator.py b/lib/python/flame/mode/horizontal/middle_aggregator.py index 6c6af7cf6..fc7367cd1 100644 --- a/lib/python/flame/mode/horizontal/middle_aggregator.py +++ b/lib/python/flame/mode/horizontal/middle_aggregator.py @@ -69,6 +69,9 @@ def internal_init(self) -> None: self.cache = Cache() self.dataset_size = 0 + # save distribute tag in an instance variable + self.dist_tag = TAG_DISTRIBUTE + def get(self, tag: str) -> None: """Get data from remote role(s).""" if tag == TAG_FETCH: @@ -81,7 +84,6 @@ def put(self, tag: str) -> None: if tag == TAG_UPLOAD: self._send_weights(tag) if tag == TAG_DISTRIBUTE: - self.dist_tag = tag self._distribute_weights(tag) def _fetch_weights(self, tag: str) -> None: @@ -155,7 +157,7 @@ def _aggregate_weights(self, tag: str) -> None: self.cache[end] = tres # optimizer conducts optimization (in this case, aggregation) - global_weights = self.optimizer.do(self.cache, total) + global_weights = self.optimizer.do(self.cache, total=total) if global_weights is None: logger.debug("failed model aggregation") time.sleep(1) @@ -215,6 +217,8 @@ def update_round(self): def inform_end_of_training(self) -> None: """Inform all the trainers that the training is finished.""" + logger.debug("inform end of training") + channel = self.cm.get_by_tag(self.dist_tag) if not channel: logger.debug(f"channel not found for tag {self.dist_tag}") diff --git a/lib/python/flame/mode/horizontal/top_aggregator.py b/lib/python/flame/mode/horizontal/top_aggregator.py index 1d2e3d890..367bc03f6 100644 --- a/lib/python/flame/mode/horizontal/top_aggregator.py +++ b/lib/python/flame/mode/horizontal/top_aggregator.py @@ -129,7 +129,7 @@ def _aggregate_weights(self, tag: str) -> None: self.cache[end] = tres # optimizer conducts optimization (in this case, aggregation) - global_weights = self.optimizer.do(self.cache, total) + global_weights = self.optimizer.do(self.cache, total=total) if global_weights is None: logger.debug("failed model aggregation") time.sleep(1) @@ -175,6 +175,7 @@ def inform_end_of_training(self) -> None: return channel.broadcast({MessageType.EOT: self._work_done}) + logger.debug("done broadcasting end-of-training") def run_analysis(self): """Run analysis plugins and update results to metrics.""" @@ -208,6 +209,7 @@ def increment_round(self): logger.debug(f"channel not found for tag {self.dist_tag}") return + logger.debug(f"Incremented round to {self._round}") # set necessary properties to help channel decide how to select ends channel.set_property("round", self._round) diff --git a/lib/python/flame/mode/horizontal/trainer.py b/lib/python/flame/mode/horizontal/trainer.py index 280fa88de..6fd8daf8c 100644 --- a/lib/python/flame/mode/horizontal/trainer.py +++ b/lib/python/flame/mode/horizontal/trainer.py @@ -16,8 +16,8 @@ """horizontal FL trainer.""" import logging -import time +from ...channel import VAL_CH_STATE_RECV, VAL_CH_STATE_SEND from ...channel_manager import ChannelManager from ...common.custom_abcmeta import ABCMeta, abstract_attribute from ...common.util import (MLFramework, get_ml_framework_in_use, @@ -83,7 +83,7 @@ def _fetch_weights(self, tag: str) -> None: channel.await_join() # one aggregator is sufficient - end = channel.one_end() + end = channel.one_end(VAL_CH_STATE_RECV) msg 
= channel.recv(end) if MessageType.WEIGHTS in msg: @@ -114,13 +114,14 @@ def _send_weights(self, tag: str) -> None: channel.await_join() # one aggregator is sufficient - end = channel.one_end() + end = channel.one_end(VAL_CH_STATE_SEND) self._update_weights() channel.send( end, { MessageType.WEIGHTS: self.weights, - MessageType.DATASET_SIZE: self.dataset_size + MessageType.DATASET_SIZE: self.dataset_size, + MessageType.MODEL_VERSION: self._round }) logger.debug("sending weights done") diff --git a/lib/python/flame/mode/message.py b/lib/python/flame/mode/message.py index e7f79b67b..78afcf48b 100644 --- a/lib/python/flame/mode/message.py +++ b/lib/python/flame/mode/message.py @@ -28,6 +28,11 @@ class MessageType(Enum): # a digest of all the workers in distributed learning MEMBER_DIGEST = 5 - RING_WEIGHTS = 6 # global model weights in distributed learning - NEW_TRAINER = 7 # sending message for the arrival of a new trainer - IS_COMMITTER = 8 # is a trainer responsible to send weights to a new trainer in distributed learning + RING_WEIGHTS = 6 # global model weights in distributed learning + NEW_TRAINER = 7 # sending message for the arrival of a new trainer + + # a variable to indicate that a trainer is responsible to send weights + # to a new trainer joining a distributed learning job + IS_COMMITTER = 8 + + MODEL_VERSION = 9 # model version used; a non-negative integer diff --git a/lib/python/flame/mode/tasklet.py b/lib/python/flame/mode/tasklet.py index c33eb9c79..e5501af99 100644 --- a/lib/python/flame/mode/tasklet.py +++ b/lib/python/flame/mode/tasklet.py @@ -60,6 +60,16 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: self.loop_ender = None self.loop_state = LoopIndicator.NONE + def __str__(self): + """Return tasklet details.""" + starter = self.loop_starter.func.__name__ if self.loop_starter else "" + ender = self.loop_ender.func.__name__ if self.loop_ender else "" + + return f"func: {self.func.__name__}" + \ + f"\nloop_state: {self.loop_state}" + \ + f"\nloop_starter: {starter}" + \ + f"\nloop_ender: {ender}" + def __rshift__(self, other: Tasklet) -> Tasklet: """Set up connection.""" if self not in self.composer.chain: @@ -68,17 +78,17 @@ def __rshift__(self, other: Tasklet) -> Tasklet: if other not in self.composer.chain: self.composer.chain[other] = set() - # case 1: t1 >> loop(t2 >> t3) - # if t1 is self, t3 is other; t3.loop_starter is t2 - if other.loop_starter and other.loop_starter not in self.composer.chain: - self.composer.chain[other.loop_starter] = set() - if self not in self.composer.reverse_chain: self.composer.reverse_chain[self] = set() if other not in self.composer.reverse_chain: self.composer.reverse_chain[other] = set() + # case 1: t1 >> loop(t2 >> t3) + # if t1 is self, t3 is other; t3.loop_starter is t2 + if other.loop_starter and other.loop_starter not in self.composer.chain: + self.composer.chain[other.loop_starter] = set() + # same as case 1 if other.loop_starter and other.loop_starter not in self.composer.reverse_chain: self.composer.reverse_chain[other.loop_starter] = set() @@ -86,13 +96,9 @@ def __rshift__(self, other: Tasklet) -> Tasklet: if other.loop_state & LoopIndicator.END: # same as case 1 self.composer.chain[self].add(other.loop_starter) - else: - self.composer.chain[self].add(other) - - if other.loop_state & LoopIndicator.END: - # same as case 1 self.composer.reverse_chain[other.loop_starter].add(self) else: + self.composer.chain[self].add(other) self.composer.reverse_chain[other].add(self) return other @@ -193,7 +199,7 @@ def
__call__(self, ender: Tasklet) -> Tasklet: ------- ender: last tasklet in a loop """ - # composer is univercially shared across tasklets + # composer is universally shared across tasklets # let's get it from ender composer = ender.get_composer() @@ -225,6 +231,12 @@ def __call__(self, ender: Tasklet) -> Tasklet: tasklets_in_loop = composer.get_tasklets_in_loop(starter, ender) # for each tasklet in loop, loop_check_fn and loop_ender are updated for tasklet in tasklets_in_loop: + if tasklet.loop_starter and tasklet.loop_ender: + # if both loop_starter and loop_ender are already set, + # they are set for an inner loop + # so, don't update loop_starter and loop_ender in that case + continue + tasklet.loop_starter = starter tasklet.loop_check_fn = self.loop_check_fn tasklet.loop_ender = ender diff --git a/lib/python/flame/optimizer/abstract.py b/lib/python/flame/optimizer/abstract.py index 7e001ce98..6dabfe128 100644 --- a/lib/python/flame/optimizer/abstract.py +++ b/lib/python/flame/optimizer/abstract.py @@ -13,16 +13,22 @@ # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 - - """optimizer abstract class.""" from abc import ABC, abstractmethod +from typing import Union + +from diskcache import Cache class AbstractOptimizer(ABC): """Abstract base class for optimizer implementation.""" @abstractmethod - def do(self) -> None: + def do(self, + cache: Cache, + *, + base_weights=None, + total: int = 0, + version: int = 0) -> Union[list, dict]: """Abstract method to conduct optimization.""" diff --git a/lib/python/flame/optimizer/fedavg.py b/lib/python/flame/optimizer/fedavg.py index e33617e95..d2c078871 100644 --- a/lib/python/flame/optimizer/fedavg.py +++ b/lib/python/flame/optimizer/fedavg.py @@ -15,6 +15,7 @@ # SPDX-License-Identifier: Apache-2.0 """Federated Averaging optimizer.""" import logging +from typing import Union from diskcache import Cache @@ -42,7 +43,12 @@ def __init__(self): "supported ml framework not found; " f"supported frameworks are: {valid_frameworks}") - def do(self, cache: Cache, total: int): + def do(self, + cache: Cache, + *, + base_weights=None, + total: int = 0, + version: int = 0) -> Union[list, dict]: """Do aggregates models of trainers. Return: aggregated model @@ -50,7 +56,7 @@ def do(self, cache: Cache, total: int): logger.debug("calling fedavg") # reset global weights before aggregation - self.agg_weights = None + self.agg_weights = base_weights if len(cache) == 0 or total == 0: return None diff --git a/lib/python/flame/optimizer/fedbuff.py b/lib/python/flame/optimizer/fedbuff.py new file mode 100644 index 000000000..2d0ec86c1 --- /dev/null +++ b/lib/python/flame/optimizer/fedbuff.py @@ -0,0 +1,100 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""FedBuff optimizer. 
+ +The implementation is based on the following papers: +https://arxiv.org/pdf/2106.06639.pdf +https://arxiv.org/pdf/2111.04877.pdf + +The SecAgg algorithm is not in the scope of this implementation. +""" +import logging +import math +from typing import Union + +from diskcache import Cache + +from ..common.util import (MLFramework, get_ml_framework_in_use, + valid_frameworks) +from .abstract import AbstractOptimizer + +logger = logging.getLogger(__name__) + + +class FedBuff(AbstractOptimizer): + """FedBuff class.""" + + def __init__(self): + """Initialize FedBuff instance.""" + self.agg_weights = None + + ml_framework_in_use = get_ml_framework_in_use() + if ml_framework_in_use == MLFramework.PYTORCH: + self.aggregate_fn = self._aggregate_pytorch + elif ml_framework_in_use == MLFramework.TENSORFLOW: + self.aggregate_fn = self._aggregate_tensorflow + else: + raise NotImplementedError( + "supported ml framework not found; " + f"supported frameworks are: {valid_frameworks}") + + def do(self, + cache: Cache, + *, + base_weights=None, + total: int = 0, + version: int = 0) -> Union[list, dict]: + """Aggregate models of trainers. + + Return: aggregated model + """ + logger.debug("calling fedbuff") + + # reset global weights before aggregation + self.agg_weights = base_weights + + if len(cache) == 0 or total == 0: + return None + + for k in list(cache.iterkeys()): + # after popping, the item is removed from the cache + # hence, explicit cache cleanup is not needed + tres = cache.pop(k) + + logger.debug(f"agg ver: {version}, trainer ver: {tres.version}") + # rate determined based on the staleness of local model + rate = 1 / math.sqrt(1 + version - tres.version) + self.aggregate_fn(tres, rate) + + return self.agg_weights + + def _aggregate_pytorch(self, tres, rate): + logger.debug("calling _aggregate_pytorch") + + if self.agg_weights is None: + self.agg_weights = {k: v * rate for k, v in tres.weights.items()} + else: + for k, v in tres.weights.items(): + self.agg_weights[k] += v * rate + + def _aggregate_tensorflow(self, tres, rate): + logger.debug("calling _aggregate_tensorflow") + + if self.agg_weights is None: + self.agg_weights = [weight * rate for weight in tres.weights] + else: + for idx in range(len(tres.weights)): + self.agg_weights[idx] += tres.weights[idx] * rate diff --git a/lib/python/flame/optimizer/fedopt.py b/lib/python/flame/optimizer/fedopt.py index f08737b63..7156cd109 100644 --- a/lib/python/flame/optimizer/fedopt.py +++ b/lib/python/flame/optimizer/fedopt.py @@ -13,22 +13,23 @@ # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 +"""FedOPT optimizer. -"""FedOPT optimizer""" -"""https://arxiv.org/abs/2003.00295""" -from abc import abstractmethod +https://arxiv.org/abs/2003.00295""" import logging +from abc import abstractmethod +from collections import OrderedDict +from typing import Union from diskcache import Cache -from .fedavg import FedAvg from ..common.util import (MLFramework, get_ml_framework_in_use, valid_frameworks) - -from collections import OrderedDict +from .fedavg import FedAvg logger = logging.getLogger(__name__) + class FedOPT(FedAvg): """FedOPT class.""" @@ -54,14 +55,22 @@ def __init__(self, beta_1, beta_2, eta, tau): "supported ml framework not found; " f"supported frameworks are: {valid_frameworks}") - def do(self, cache: Cache, total: int): + def do(self, + cache: Cache, + *, + base_weights=None, + total: int = 0, + version: int = 0) -> Union[list, dict]: """Do aggregates models of trainers.
Return: aggregated model """ logger.debug("calling fedopt") - self.agg_weights = super().do(cache, total) + self.agg_weights = super().do(cache, + base_weights=base_weights, + total=total, + version=version) if self.agg_weights is None: return self.current_weights @@ -87,27 +96,51 @@ def _adapt_pytorch(self, average, current): self.d_t = {k: average[k] - current[k] for k in average.keys()} if self.m_t is None: - self.m_t = {k: torch.zeros_like(self.d_t[k]) for k in self.d_t.keys()} - self.m_t = {k: self.beta_1 * self.m_t[k] + (1 - self.beta_1) * self.d_t[k] for k in self.m_t.keys()} + self.m_t = { + k: torch.zeros_like(self.d_t[k]) + for k in self.d_t.keys() + } + self.m_t = { + k: self.beta_1 * self.m_t[k] + (1 - self.beta_1) * self.d_t[k] + for k in self.m_t.keys() + } if self.v_t is None: - self.v_t = {k: torch.zeros_like(self.d_t[k]) for k in self.d_t.keys()} + self.v_t = { + k: torch.zeros_like(self.d_t[k]) + for k in self.d_t.keys() + } self._delta_v_pytorch() - self.current_weights = OrderedDict({k: self.current_weights[k] + self.eta * self.m_t[k] / (torch.sqrt(self.v_t[k]) + self.tau) for k in self.current_weights.keys()}) + self.current_weights = OrderedDict({ + k: self.current_weights[k] + self.eta * self.m_t[k] / + (torch.sqrt(self.v_t[k]) + self.tau) + for k in self.current_weights.keys() + }) def _adapt_tensorflow(self, average, current): import tensorflow as tf logger.debug("calling _adapt_tensorflow") - - self.d_t = [average[idx]-current[idx] for idx in range(len(average))] + + self.d_t = [average[idx] - current[idx] for idx in range(len(average))] if self.m_t is None: - self.m_t = [tf.zeros_like(self.d_t[idx]) for idx in range(len(self.d_t))] - self.m_t = [self.beta_1 * self.m_t[idx] + (1 - self.beta_1) * self.d_t[idx] for idx in range(len(self.m_t))] + self.m_t = [ + tf.zeros_like(self.d_t[idx]) for idx in range(len(self.d_t)) + ] + self.m_t = [ + self.beta_1 * self.m_t[idx] + (1 - self.beta_1) * self.d_t[idx] + for idx in range(len(self.m_t)) + ] if self.v_t is None: - self.v_t = [tf.zeros_like(self.d_t[idx]) for idx in range(len(self.d_t))] + self.v_t = [ + tf.zeros_like(self.d_t[idx]) for idx in range(len(self.d_t)) + ] self._delta_v_tensorflow() - - self.current_weights = [self.current_weights[idx] + self.eta * self.m_t[idx] / (tf.sqrt(self.v_t[idx]) + self.tau) for idx in range(len(self.current_weights))] + + self.current_weights = [ + self.current_weights[idx] + self.eta * self.m_t[idx] / + (tf.sqrt(self.v_t[idx]) + self.tau) + for idx in range(len(self.current_weights)) + ] diff --git a/lib/python/flame/optimizer/train_result.py b/lib/python/flame/optimizer/train_result.py index c099be7f5..7314c74f8 100644 --- a/lib/python/flame/optimizer/train_result.py +++ b/lib/python/flame/optimizer/train_result.py @@ -13,15 +13,14 @@ # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 - - -"""A class that contains train result.""" +"""A class that contains train result and its meta data.""" class TrainResult(object): """TrainResult class.""" - def __init__(self, weights=None, count=0): + def __init__(self, weights=None, count=0, version=0): """Initialize.""" self.weights = weights self.count = count + self.version = version diff --git a/lib/python/flame/optimizers.py b/lib/python/flame/optimizers.py index ef1f37fbb..b3e2c4e9f 100644 --- a/lib/python/flame/optimizers.py +++ b/lib/python/flame/optimizers.py @@ -13,14 +13,14 @@ # limitations under the License. 
# # SPDX-License-Identifier: Apache-2.0 - """optimizer provider class.""" from .config import OptimizerType from .object_factory import ObjectFactory -from .optimizer.fedavg import FedAvg from .optimizer.fedadagrad import FedAdaGrad from .optimizer.fedadam import FedAdam +from .optimizer.fedavg import FedAvg +from .optimizer.fedbuff import FedBuff from .optimizer.fedyogi import FedYogi @@ -37,3 +37,4 @@ def get(self, optimizer_name, **kwargs): optimizer_provider.register(OptimizerType.FEDADAGRAD, FedAdaGrad) optimizer_provider.register(OptimizerType.FEDADAM, FedAdam) optimizer_provider.register(OptimizerType.FEDYOGI, FedYogi) +optimizer_provider.register(OptimizerType.FEDBUFF, FedBuff) diff --git a/lib/python/flame/selector/__init__.py b/lib/python/flame/selector/__init__.py index eb384bc88..ba8368bdc 100644 --- a/lib/python/flame/selector/__init__.py +++ b/lib/python/flame/selector/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 - """selector abstract class.""" from abc import ABC, abstractmethod @@ -32,7 +31,7 @@ def __init__(self, **kwargs) -> None: """Initialize an instance with keyword-based arguments.""" for key, value in kwargs.items(): setattr(self, key, value) - self.selected_ends = list() + self.selected_ends = set() @abstractmethod def select(self, ends: dict[str, End], diff --git a/lib/python/flame/selector/default.py b/lib/python/flame/selector/default.py index 0ca27bbc2..e462f4d6d 100644 --- a/lib/python/flame/selector/default.py +++ b/lib/python/flame/selector/default.py @@ -41,7 +41,7 @@ def select(self, ends: dict[str, End], if len(self.selected_ends) == 0 or round > self.round: logger.debug(f"let's select the whole ends for new round {round}") - self.selected_ends = list(ends.keys()) + self.selected_ends = set(ends.keys()) self.round = round logger.debug(f"selected ends: {self.selected_ends}") diff --git a/lib/python/flame/selector/fedbuff.py b/lib/python/flame/selector/fedbuff.py new file mode 100644 index 000000000..d3f42d399 --- /dev/null +++ b/lib/python/flame/selector/fedbuff.py @@ -0,0 +1,139 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""FedBuffSelector class.""" + +import logging +import random + +from ..channel import KEY_CH_STATE, VAL_CH_STATE_RECV, VAL_CH_STATE_SEND +from ..common.typing import Scalar +from ..end import KEY_END_STATE, VAL_END_STATE_NONE, VAL_END_STATE_RECVD, End +from . 
import AbstractSelector, SelectorReturnType + +logger = logging.getLogger(__name__) + + +class FedBuffSelector(AbstractSelector): + """A selector class for fedbuff-based asyncfl.""" + + def __init__(self, **kwargs): + """Initialize instance.""" + super().__init__(**kwargs) + + try: + self.c = kwargs['c'] + except KeyError: + raise KeyError("c (concurrency level) is not specified in config") + + self.round = 0 + + def select(self, ends: dict[str, End], + channel_props: dict[str, Scalar]) -> SelectorReturnType: + """Select ends from the given ends to meet the concurrency level. + + This select method chooses ends differently depending on what state + a channel is in. + In 'send' state, it chooses ends that are not in self.selected_ends. + In 'recv' state, it chooses all ends from self.selected_ends. + Essentially, if an end is in self.selected_ends, it means that we already + sent a message to that end. For such an end, we exclude it from + send and include it for recv in return. + """ + logger.debug("calling fedbuff select") + logger.debug(f"len(ends): {len(ends)}, c: {self.c}") + + concurrency = min(len(ends), self.c) + if concurrency == 0: + logger.debug("ends is empty") + return {} + + self.round = channel_props['round'] if 'round' in channel_props else 0 + + if KEY_CH_STATE not in channel_props: + raise KeyError(f"channel property doesn't have {KEY_CH_STATE}") + + self._cleanup_recvd_ends(ends) + results = {} + if channel_props[KEY_CH_STATE] == VAL_CH_STATE_SEND: + results = self._handle_send_state(ends, concurrency) + + elif channel_props[KEY_CH_STATE] == VAL_CH_STATE_RECV: + results = self._handle_recv_state(ends, concurrency) + + else: + state = channel_props[KEY_CH_STATE] + raise ValueError(f"unknown channel state: {state}") + + logger.debug(f"selected ends: {self.selected_ends}") + logger.debug(f"results: {results}") + + return results + + def _cleanup_recvd_ends(self, ends: dict[str, End]): + """Clean up ends whose message was already received from the selected ends.""" + logger.debug("clean up recvd ends") + logger.debug(f"ends: {ends}, selected ends: {self.selected_ends}") + for end_id in list(self.selected_ends): + if end_id not in ends: + # something happened to end of end_id + # (e.g., connection loss) + # let's remove it from selected_ends + logger.debug(f"no end id {end_id} in ends") + self.selected_ends.remove(end_id) + else: + state = ends[end_id].get_property(KEY_END_STATE) + logger.debug(f"end id {end_id} state: {state}") + if state == VAL_END_STATE_RECVD: + ends[end_id].set_property(KEY_END_STATE, + VAL_END_STATE_NONE) + self.selected_ends.remove(end_id) + + def _handle_send_state(self, ends: dict[str, End], + concurrency: int) -> SelectorReturnType: + extra = max(0, concurrency - len(self.selected_ends)) + logger.debug(f"c: {concurrency}, ends: {ends.keys()}") + candidates = [] + idx = 0 + # reservoir sampling + for end_id in ends.keys(): + if end_id in self.selected_ends: + # skip if an end is already selected + continue + + idx += 1 + if len(candidates) < extra: + candidates.append(end_id) + continue + + i = random.randrange(idx) + if i < extra: + candidates[i] = end_id + + logger.debug(f"candidates: {candidates}") + # add candidates to selected ends + self.selected_ends = set(list(self.selected_ends) + candidates) + + return {end_id: None for end_id in candidates} + + def _handle_recv_state(self, ends: dict[str, End], + concurrency: int) -> SelectorReturnType: + if len(self.selected_ends) == 0: + logger.debug(f"let's select {concurrency} ends") + self.selected_ends =
set(random.sample(list(ends), concurrency)) + + logger.debug(f"selected ends: {self.selected_ends}") + + return {key: None for key in self.selected_ends} diff --git a/lib/python/flame/selector/random.py b/lib/python/flame/selector/random.py index f50d7fe8b..2c1cbb9df 100644 --- a/lib/python/flame/selector/random.py +++ b/lib/python/flame/selector/random.py @@ -13,7 +13,6 @@ # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 - """RandomSelector class.""" import logging @@ -57,7 +56,7 @@ def select(self, ends: dict[str, End], if len(self.selected_ends) == 0 or round > self.round: logger.debug(f"let's select {k} ends for new round {round}") - self.selected_ends = random.sample(ends.keys(), k) + self.selected_ends = set(random.sample(list(ends), k)) self.round = round logger.debug(f"selected ends: {self.selected_ends}") diff --git a/lib/python/flame/selectors.py b/lib/python/flame/selectors.py index 889a19ba2..fafe66ef5 100644 --- a/lib/python/flame/selectors.py +++ b/lib/python/flame/selectors.py @@ -13,12 +13,12 @@ # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 - """selector provider class.""" from .config import SelectorType from .object_factory import ObjectFactory from .selector.default import DefaultSelector +from .selector.fedbuff import FedBuffSelector from .selector.random import RandomSelector @@ -33,3 +33,4 @@ def get(self, selector_name, **kwargs): selector_provider = SelectorProvider() selector_provider.register(SelectorType.DEFAULT, DefaultSelector) selector_provider.register(SelectorType.RANDOM, RandomSelector) +selector_provider.register(SelectorType.FEDBUFF, FedBuffSelector) diff --git a/lib/python/setup.py b/lib/python/setup.py index 17933b91f..b834e86c7 100644 --- a/lib/python/setup.py +++ b/lib/python/setup.py @@ -19,7 +19,7 @@ setup( name='flame', - version='0.0.14', + version='0.0.15', author='Flame Maintainers', author_email='flame-github-owners@cisco.com', include_package_data=True,