From 5ac936b54ae3bc4132956ba8d6a5b72a22288a16 Mon Sep 17 00:00:00 2001 From: Myungjin Lee Date: Thu, 30 Mar 2023 23:52:42 -0700 Subject: [PATCH] refactor: move syncfl files While the top-level horizontal files (top_aggregator.py, middle_aggregator.py and trainer.py) are the basis of other modes under the horizontal category, the directory structure is not well organized. These files are now moved into the horizontal/syncfl folder while keeping backward compatibility. --- .../mode/horizontal/middle_aggregator.py | 300 +--------------- .../flame/mode/horizontal/syncfl/__init__.py | 15 + .../horizontal/syncfl/middle_aggregator.py | 309 +++++++++++++++++ .../mode/horizontal/syncfl/top_aggregator.py | 323 ++++++++++++++++++ .../flame/mode/horizontal/syncfl/trainer.py | 229 +++++++++++++ .../flame/mode/horizontal/top_aggregator.py | 311 +---------------- lib/python/flame/mode/horizontal/trainer.py | 218 +----------- 7 files changed, 898 insertions(+), 807 deletions(-) create mode 100644 lib/python/flame/mode/horizontal/syncfl/__init__.py create mode 100644 lib/python/flame/mode/horizontal/syncfl/middle_aggregator.py create mode 100644 lib/python/flame/mode/horizontal/syncfl/top_aggregator.py create mode 100644 lib/python/flame/mode/horizontal/syncfl/trainer.py diff --git a/lib/python/flame/mode/horizontal/middle_aggregator.py b/lib/python/flame/mode/horizontal/middle_aggregator.py index 6731ca5dc..9cb5b2356 100644 --- a/lib/python/flame/mode/horizontal/middle_aggregator.py +++ b/lib/python/flame/mode/horizontal/middle_aggregator.py @@ -15,296 +15,14 @@ # SPDX-License-Identifier: Apache-2.0 """honrizontal FL middle level aggregator.""" -import logging -import time -from copy import deepcopy - -from diskcache import Cache - -from ...channel_manager import ChannelManager -from ...common.custom_abcmeta import ABCMeta, abstract_attribute -from ...common.util import ( - MLFramework, - delta_weights_pytorch, - delta_weights_tensorflow, - get_ml_framework_in_use, - valid_frameworks, +from flame.mode.horizontal.syncfl.middle_aggregator import ( + TAG_AGGREGATE, + TAG_DISTRIBUTE, + TAG_FETCH, + TAG_UPLOAD, + MiddleAggregator, ) -from ...config import Config -from ...optimizer.train_result import TrainResult -from ...optimizers import optimizer_provider -from ...plugin import PluginManager -from ..composer import Composer -from ..message import MessageType -from ..role import Role -from ..tasklet import Loop, Tasklet - -logger = logging.getLogger(__name__) - -TAG_DISTRIBUTE = "distribute" -TAG_AGGREGATE = "aggregate" -TAG_FETCH = "fetch" -TAG_UPLOAD = "upload" - -# 60 second wait time until a trainer appears in a channel -WAIT_TIME_FOR_TRAINER = 60 - - -class MiddleAggregator(Role, metaclass=ABCMeta): - """Middle level aggregator. - - It acts as a proxy between top level aggregator and trainer.
- """ - - @abstract_attribute - def config(self) -> Config: - """Abstract attribute for config object.""" - - def internal_init(self) -> None: - """Initialize internal state for role.""" - # global variable for plugin manager - self.plugin_manager = PluginManager() - - self.cm = ChannelManager() - self.cm(self.config) - self.cm.join_all() - - self.optimizer = optimizer_provider.get( - self.config.optimizer.sort, **self.config.optimizer.kwargs - ) - - self._round = 1 - self._work_done = False - - self.cache = Cache() - self.dataset_size = 0 - - # save distribute tag in an instance variable - self.dist_tag = TAG_DISTRIBUTE - - self.framework = get_ml_framework_in_use() - if self.framework == MLFramework.UNKNOWN: - raise NotImplementedError( - "supported ml framework not found; " - f"supported frameworks are: {valid_frameworks}" - ) - - if self.framework == MLFramework.PYTORCH: - self._delta_weights_fn = delta_weights_pytorch - - elif self.framework == MLFramework.TENSORFLOW: - self._delta_weights_fn = delta_weights_tensorflow - - def get(self, tag: str) -> None: - """Get data from remote role(s).""" - if tag == TAG_FETCH: - self._fetch_weights(tag) - if tag == TAG_AGGREGATE: - self._aggregate_weights(tag) - - def put(self, tag: str) -> None: - """Set data to remote role(s).""" - if tag == TAG_UPLOAD: - self._send_weights(tag) - if tag == TAG_DISTRIBUTE: - self._distribute_weights(tag) - - def _fetch_weights(self, tag: str) -> None: - logger.debug("calling _fetch_weights") - channel = self.cm.get_by_tag(tag) - if not channel: - logger.debug(f"[_fetch_weights] channel not found with tag {tag}") - return - - # this call waits for at least one peer to join this channel - channel.await_join() - - # one aggregator is sufficient - end = channel.one_end() - msg, _ = channel.recv(end) - - if MessageType.WEIGHTS in msg: - self.weights = msg[MessageType.WEIGHTS] - - if MessageType.EOT in msg: - self._work_done = msg[MessageType.EOT] - - if MessageType.ROUND in msg: - self._round = msg[MessageType.ROUND] - - def _distribute_weights(self, tag: str) -> None: - channel = self.cm.get_by_tag(tag) - if not channel: - logger.debug(f"channel not found for tag {tag}") - return - - # this call waits for at least one peer to join this channel - self.trainer_no_show = channel.await_join(WAIT_TIME_FOR_TRAINER) - if self.trainer_no_show: - logger.debug("channel await join timeouted") - # send dummy weights to unblock top aggregator - self._send_dummy_weights(TAG_UPLOAD) - return - - for end in channel.ends(): - logger.debug(f"sending weights to {end}") - channel.send( - end, {MessageType.WEIGHTS: self.weights, MessageType.ROUND: self._round} - ) - - def _aggregate_weights(self, tag: str) -> None: - channel = self.cm.get_by_tag(tag) - if not channel: - return - - total = 0 - # receive local model parameters from trainers - for msg, metadata in channel.recv_fifo(channel.ends()): - end, _ = metadata - if not msg: - logger.debug(f"No data from {end}; skipping it") - continue - - if MessageType.WEIGHTS in msg: - weights = msg[MessageType.WEIGHTS] - - if MessageType.DATASET_SIZE in msg: - count = msg[MessageType.DATASET_SIZE] - - logger.debug(f"{end}'s parameters trained with {count} samples") - - if weights is not None and count > 0: - total += count - tres = TrainResult(weights, count) - # save training result from trainer in a disk cache - self.cache[end] = tres - - # optimizer conducts optimization (in this case, aggregation) - global_weights = self.optimizer.do( - deepcopy(self.weights), self.cache, total=total - ) - 
if global_weights is None: - logger.debug("failed model aggregation") - time.sleep(1) - return - - # save global weights before updating it - self.prev_weights = self.weights - - # set global weights - self.weights = global_weights - self.dataset_size = total - - def _send_weights(self, tag: str) -> None: - logger.debug("calling _send_weights") - channel = self.cm.get_by_tag(tag) - if not channel: - logger.debug(f"channel not found with {tag}") - return - - # this call waits for at least one peer to join this channel - channel.await_join() - - # one aggregator is sufficient - end = channel.one_end() - - delta_weights = self._delta_weights_fn(self.weights, self.prev_weights) - - msg = { - MessageType.WEIGHTS: delta_weights, - MessageType.DATASET_SIZE: self.dataset_size, - MessageType.MODEL_VERSION: self._round, - } - channel.send(end, msg) - logger.debug("sending weights done") - - def _send_dummy_weights(self, tag: str) -> None: - channel = self.cm.get_by_tag(tag) - if not channel: - logger.debug(f"channel not found with {tag}") - return - - # this call waits for at least one peer to join this channel - channel.await_join() - - # one aggregator is sufficient - end = channel.one_end() - - dummy_msg = {MessageType.WEIGHTS: None, MessageType.DATASET_SIZE: 0} - channel.send(end, dummy_msg) - logger.debug("sending dummy weights done") - - def update_round(self): - """Update the round counter.""" - logger.debug(f"Update current round: {self._round}") - - channel = self.cm.get_by_tag(self.dist_tag) - if not channel: - logger.debug(f"channel not found for tag {self.dist_tag}") - return - - # set necessary properties to help channel decide how to select ends - channel.set_property("round", self._round) - - def inform_end_of_training(self) -> None: - """Inform all the trainers that the training is finished.""" - logger.debug("inform end of training") - - channel = self.cm.get_by_tag(self.dist_tag) - if not channel: - logger.debug(f"channel not found for tag {self.dist_tag}") - return - - channel.broadcast({MessageType.EOT: self._work_done}) - - def compose(self) -> None: - """Compose role with tasklets.""" - with Composer() as composer: - self.composer = composer - - task_internal_init = Tasklet(self.internal_init) - - task_init = Tasklet(self.initialize) - - task_load_data = Tasklet(self.load_data) - - task_put_dist = Tasklet(self.put, TAG_DISTRIBUTE) - task_put_dist.set_continue_fn(cont_fn=lambda: self.trainer_no_show) - - task_put_upload = Tasklet(self.put, TAG_UPLOAD) - - task_get_aggr = Tasklet(self.get, TAG_AGGREGATE) - - task_get_fetch = Tasklet(self.get, TAG_FETCH) - - task_eval = Tasklet(self.evaluate) - - task_update_round = Tasklet(self.update_round) - - task_end_of_training = Tasklet(self.inform_end_of_training) - - # create a loop object with loop exit condition function - loop = Loop(loop_check_fn=lambda: self._work_done) - ( - task_internal_init - >> task_load_data - >> task_init - >> loop( - task_get_fetch - >> task_put_dist - >> task_get_aggr - >> task_put_upload - >> task_eval - >> task_update_round - ) - >> task_end_of_training - ) - - def run(self) -> None: - """Run role.""" - self.composer.run() - @classmethod - def get_func_tags(cls) -> list[str]: - """Return a list of function tags defined in the middle level aggregator role.""" - return [TAG_DISTRIBUTE, TAG_AGGREGATE, TAG_FETCH, TAG_UPLOAD] +# Redirect `flame.mode.horizontal.middle_aggregator` to +# `flame.mode.horizontal.syncfl.middle_aggregator` +# This is for backwards compatibility diff --git 
a/lib/python/flame/mode/horizontal/syncfl/__init__.py b/lib/python/flame/mode/horizontal/syncfl/__init__.py new file mode 100644 index 000000000..00b0536f7 --- /dev/null +++ b/lib/python/flame/mode/horizontal/syncfl/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/lib/python/flame/mode/horizontal/syncfl/middle_aggregator.py b/lib/python/flame/mode/horizontal/syncfl/middle_aggregator.py new file mode 100644 index 000000000..77a090ffa --- /dev/null +++ b/lib/python/flame/mode/horizontal/syncfl/middle_aggregator.py @@ -0,0 +1,309 @@ +# Copyright 2022 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""honrizontal FL middle level aggregator.""" + +import logging +import time +from copy import deepcopy + +from diskcache import Cache +from flame.channel_manager import ChannelManager +from flame.common.custom_abcmeta import ABCMeta, abstract_attribute +from flame.common.util import ( + MLFramework, + delta_weights_pytorch, + delta_weights_tensorflow, + get_ml_framework_in_use, + valid_frameworks, +) +from flame.config import Config +from flame.mode.composer import Composer +from flame.mode.message import MessageType +from flame.mode.role import Role +from flame.mode.tasklet import Loop, Tasklet +from flame.optimizer.train_result import TrainResult +from flame.optimizers import optimizer_provider +from flame.plugin import PluginManager + +logger = logging.getLogger(__name__) + +TAG_DISTRIBUTE = "distribute" +TAG_AGGREGATE = "aggregate" +TAG_FETCH = "fetch" +TAG_UPLOAD = "upload" + +# 60 second wait time until a trainer appears in a channel +WAIT_TIME_FOR_TRAINER = 60 + + +class MiddleAggregator(Role, metaclass=ABCMeta): + """Middle level aggregator. + + It acts as a proxy between top level aggregator and trainer. 
+ """ + + @abstract_attribute + def config(self) -> Config: + """Abstract attribute for config object.""" + + def internal_init(self) -> None: + """Initialize internal state for role.""" + # global variable for plugin manager + self.plugin_manager = PluginManager() + + self.cm = ChannelManager() + self.cm(self.config) + self.cm.join_all() + + self.optimizer = optimizer_provider.get( + self.config.optimizer.sort, **self.config.optimizer.kwargs + ) + + self._round = 1 + self._work_done = False + + self.cache = Cache() + self.dataset_size = 0 + + # save distribute tag in an instance variable + self.dist_tag = TAG_DISTRIBUTE + + self.framework = get_ml_framework_in_use() + if self.framework == MLFramework.UNKNOWN: + raise NotImplementedError( + "supported ml framework not found; " + f"supported frameworks are: {valid_frameworks}" + ) + + if self.framework == MLFramework.PYTORCH: + self._delta_weights_fn = delta_weights_pytorch + + elif self.framework == MLFramework.TENSORFLOW: + self._delta_weights_fn = delta_weights_tensorflow + + def get(self, tag: str) -> None: + """Get data from remote role(s).""" + if tag == TAG_FETCH: + self._fetch_weights(tag) + if tag == TAG_AGGREGATE: + self._aggregate_weights(tag) + + def put(self, tag: str) -> None: + """Set data to remote role(s).""" + if tag == TAG_UPLOAD: + self._send_weights(tag) + if tag == TAG_DISTRIBUTE: + self._distribute_weights(tag) + + def _fetch_weights(self, tag: str) -> None: + logger.debug("calling _fetch_weights") + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"[_fetch_weights] channel not found with tag {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end() + msg, _ = channel.recv(end) + + if MessageType.WEIGHTS in msg: + self.weights = msg[MessageType.WEIGHTS] + + if MessageType.EOT in msg: + self._work_done = msg[MessageType.EOT] + + if MessageType.ROUND in msg: + self._round = msg[MessageType.ROUND] + + def _distribute_weights(self, tag: str) -> None: + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found for tag {tag}") + return + + # this call waits for at least one peer to join this channel + self.trainer_no_show = channel.await_join(WAIT_TIME_FOR_TRAINER) + if self.trainer_no_show: + logger.debug("channel await join timeouted") + # send dummy weights to unblock top aggregator + self._send_dummy_weights(TAG_UPLOAD) + return + + for end in channel.ends(): + logger.debug(f"sending weights to {end}") + channel.send( + end, {MessageType.WEIGHTS: self.weights, MessageType.ROUND: self._round} + ) + + def _aggregate_weights(self, tag: str) -> None: + channel = self.cm.get_by_tag(tag) + if not channel: + return + + total = 0 + # receive local model parameters from trainers + for msg, metadata in channel.recv_fifo(channel.ends()): + end, _ = metadata + if not msg: + logger.debug(f"No data from {end}; skipping it") + continue + + if MessageType.WEIGHTS in msg: + weights = msg[MessageType.WEIGHTS] + + if MessageType.DATASET_SIZE in msg: + count = msg[MessageType.DATASET_SIZE] + + logger.debug(f"{end}'s parameters trained with {count} samples") + + if weights is not None and count > 0: + total += count + tres = TrainResult(weights, count) + # save training result from trainer in a disk cache + self.cache[end] = tres + + # optimizer conducts optimization (in this case, aggregation) + global_weights = self.optimizer.do( + deepcopy(self.weights), self.cache, total=total + ) + 
if global_weights is None: + logger.debug("failed model aggregation") + time.sleep(1) + return + + # save global weights before updating it + self.prev_weights = self.weights + + # set global weights + self.weights = global_weights + self.dataset_size = total + + def _send_weights(self, tag: str) -> None: + logger.debug("calling _send_weights") + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found with {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end() + + delta_weights = self._delta_weights_fn(self.weights, self.prev_weights) + + msg = { + MessageType.WEIGHTS: delta_weights, + MessageType.DATASET_SIZE: self.dataset_size, + MessageType.MODEL_VERSION: self._round, + } + channel.send(end, msg) + logger.debug("sending weights done") + + def _send_dummy_weights(self, tag: str) -> None: + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found with {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end() + + dummy_msg = {MessageType.WEIGHTS: None, MessageType.DATASET_SIZE: 0} + channel.send(end, dummy_msg) + logger.debug("sending dummy weights done") + + def update_round(self): + """Update the round counter.""" + logger.debug(f"Update current round: {self._round}") + + channel = self.cm.get_by_tag(self.dist_tag) + if not channel: + logger.debug(f"channel not found for tag {self.dist_tag}") + return + + # set necessary properties to help channel decide how to select ends + channel.set_property("round", self._round) + + def inform_end_of_training(self) -> None: + """Inform all the trainers that the training is finished.""" + logger.debug("inform end of training") + + channel = self.cm.get_by_tag(self.dist_tag) + if not channel: + logger.debug(f"channel not found for tag {self.dist_tag}") + return + + channel.broadcast({MessageType.EOT: self._work_done}) + + def compose(self) -> None: + """Compose role with tasklets.""" + with Composer() as composer: + self.composer = composer + + task_internal_init = Tasklet(self.internal_init) + + task_init = Tasklet(self.initialize) + + task_load_data = Tasklet(self.load_data) + + task_put_dist = Tasklet(self.put, TAG_DISTRIBUTE) + task_put_dist.set_continue_fn(cont_fn=lambda: self.trainer_no_show) + + task_put_upload = Tasklet(self.put, TAG_UPLOAD) + + task_get_aggr = Tasklet(self.get, TAG_AGGREGATE) + + task_get_fetch = Tasklet(self.get, TAG_FETCH) + + task_eval = Tasklet(self.evaluate) + + task_update_round = Tasklet(self.update_round) + + task_end_of_training = Tasklet(self.inform_end_of_training) + + # create a loop object with loop exit condition function + loop = Loop(loop_check_fn=lambda: self._work_done) + ( + task_internal_init + >> task_load_data + >> task_init + >> loop( + task_get_fetch + >> task_put_dist + >> task_get_aggr + >> task_put_upload + >> task_eval + >> task_update_round + ) + >> task_end_of_training + ) + + def run(self) -> None: + """Run role.""" + self.composer.run() + + @classmethod + def get_func_tags(cls) -> list[str]: + """Return a list of function tags defined in the middle level aggregator role.""" + return [TAG_DISTRIBUTE, TAG_AGGREGATE, TAG_FETCH, TAG_UPLOAD] diff --git a/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py b/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py new file mode 100644 index 000000000..cfb1aedf1 --- 
/dev/null +++ b/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py @@ -0,0 +1,323 @@ +# Copyright 2022 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""horizontal FL top level aggregator.""" + +import logging +import time +from copy import deepcopy + +from diskcache import Cache +from flame.channel_manager import ChannelManager +from flame.common.constants import DeviceType +from flame.common.custom_abcmeta import ABCMeta, abstract_attribute +from flame.common.util import ( + MLFramework, + get_ml_framework_in_use, + mlflow_runname, + valid_frameworks, + weights_to_device, + weights_to_model_device, +) +from flame.config import Config +from flame.mode.composer import Composer +from flame.mode.message import MessageType +from flame.mode.role import Role +from flame.mode.tasklet import Loop, Tasklet +from flame.optimizer.train_result import TrainResult +from flame.optimizers import optimizer_provider +from flame.plugin import PluginManager, PluginType +from flame.registries import registry_provider + +logger = logging.getLogger(__name__) + +TAG_DISTRIBUTE = "distribute" +TAG_AGGREGATE = "aggregate" + + +class TopAggregator(Role, metaclass=ABCMeta): + """Top level Aggregator implements an ML aggregation role.""" + + @abstract_attribute + def config(self) -> Config: + """Abstract attribute for config object.""" + + @abstract_attribute + def model(self): + """Abstract attribute for model object.""" + + @abstract_attribute + def dataset(self): + """ + Abstract attribute for datset. + + dataset's type is Dataset (in flame/dataset.py). 
+ """ + + def internal_init(self) -> None: + """Initialize internal state for role.""" + # global variable for plugin manager + self.plugin_manager = PluginManager() + + self.cm = ChannelManager() + self.cm(self.config) + self.cm.join_all() + + self.registry_client = registry_provider.get(self.config.registry.sort) + # initialize registry client + self.registry_client(self.config.registry.uri, self.config.job.job_id) + + base_model = self.config.base_model + if base_model and base_model.name != "" and base_model.version > 0: + self.model = self.registry_client.load_model( + base_model.name, base_model.version + ) + + self.registry_client.setup_run(mlflow_runname(self.config)) + self.metrics = dict() + + # disk cache is used for saving memory in case model is large + self.cache = Cache() + self.optimizer = optimizer_provider.get( + self.config.optimizer.sort, **self.config.optimizer.kwargs + ) + + self._round = 1 + self._rounds = 1 + self._rounds = self.config.hyperparameters.rounds + self._work_done = False + + self.framework = get_ml_framework_in_use() + if self.framework == MLFramework.UNKNOWN: + raise NotImplementedError( + "supported ml framework not found; " + f"supported frameworks are: {valid_frameworks}" + ) + + def get(self, tag: str) -> None: + """Get data from remote role(s).""" + if tag == TAG_AGGREGATE: + self._aggregate_weights(tag) + + def _aggregate_weights(self, tag: str) -> None: + channel = self.cm.get_by_tag(tag) + if not channel: + return + + total = 0 + # receive local model parameters from trainers + for msg, metadata in channel.recv_fifo(channel.ends()): + end, _ = metadata + if not msg: + logger.debug(f"No data from {end}; skipping it") + continue + + logger.debug(f"received data from {end}") + if MessageType.WEIGHTS in msg: + weights = weights_to_model_device(msg[MessageType.WEIGHTS], self.model) + + if MessageType.DATASET_SIZE in msg: + count = msg[MessageType.DATASET_SIZE] + + logger.debug(f"{end}'s parameters trained with {count} samples") + + if weights is not None and count > 0: + total += count + tres = TrainResult(weights, count) + # save training result from trainer in a disk cache + self.cache[end] = tres + + # optimizer conducts optimization (in this case, aggregation) + global_weights = self.optimizer.do( + deepcopy(self.weights), + self.cache, + total=total, + num_trainers=len(channel.ends()), + ) + if global_weights is None: + logger.debug("failed model aggregation") + time.sleep(1) + return + + # set global weights + self.weights = global_weights + + # update model with global weights + self._update_model() + + def put(self, tag: str) -> None: + """Set data to remote role(s).""" + if tag == TAG_DISTRIBUTE: + self.dist_tag = tag + self._distribute_weights(tag) + + def _distribute_weights(self, tag: str) -> None: + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found for tag {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # before distributing weights, update it from global model + self._update_weights() + + # send out global model parameters to trainers + for end in channel.ends(): + logger.debug(f"sending weights to {end}") + channel.send( + end, + { + MessageType.WEIGHTS: weights_to_device( + self.weights, DeviceType.CPU + ), + MessageType.ROUND: self._round, + }, + ) + + def inform_end_of_training(self) -> None: + """Inform all the trainers that the training is finished.""" + channel = self.cm.get_by_tag(self.dist_tag) + if not channel: + 
logger.debug(f"channel not found for tag {self.dist_tag}") + return + + channel.broadcast({MessageType.EOT: self._work_done}) + logger.debug("done broadcasting end-of-training") + + def run_analysis(self): + """Run analysis plugins and update results to metrics.""" + logger.debug("running analyzer plugins") + + plugins = self.plugin_manager.get_plugins(PluginType.ANALYZER) + for plugin in plugins: + # get callback function and call it + func = plugin.callback() + metrics = func(self.model, self.dataset) + if not metrics: + continue + + self.update_metrics(metrics) + + def save_metrics(self): + """Save metrics in a model registry.""" + logger.debug(f"saving metrics: {self.metrics}") + if self.metrics: + self.registry_client.save_metrics(self._round - 1, self.metrics) + logger.debug("saving metrics done") + + def increment_round(self): + """Increment the round counter.""" + logger.debug(f"Incrementing current round: {self._round}") + logger.debug(f"Total rounds: {self._rounds}") + self._round += 1 + self._work_done = self._round > self._rounds + + channel = self.cm.get_by_tag(self.dist_tag) + if not channel: + logger.debug(f"channel not found for tag {self.dist_tag}") + return + + logger.debug(f"Incremented round to {self._round}") + # set necessary properties to help channel decide how to select ends + channel.set_property("round", self._round) + + def save_params(self): + """Save hyperparamets in a model registry.""" + if self.config.hyperparameters: + self.registry_client.save_params(self.config.hyperparameters) + + def save_model(self): + """Save model in a model registry.""" + if self.model: + model_name = f"{self.config.job.name}-{self.config.job.job_id}" + self.registry_client.save_model(model_name, self.model) + + def update_metrics(self, metrics: dict[str, float]): + """Update metrics.""" + self.metrics = self.metrics | metrics + + def _update_model(self): + if self.framework == MLFramework.PYTORCH: + self.model.load_state_dict(self.weights) + elif self.framework == MLFramework.TENSORFLOW: + self.model.set_weights(self.weights) + + def _update_weights(self): + if self.framework == MLFramework.PYTORCH: + self.weights = self.model.state_dict() + elif self.framework == MLFramework.TENSORFLOW: + self.weights = self.model.get_weights() + + def compose(self) -> None: + """Compose role with tasklets.""" + with Composer() as composer: + self.composer = composer + + task_internal_init = Tasklet(self.internal_init) + + task_init = Tasklet(self.initialize) + + task_load_data = Tasklet(self.load_data) + + task_put = Tasklet(self.put, TAG_DISTRIBUTE) + + task_get = Tasklet(self.get, TAG_AGGREGATE) + + task_train = Tasklet(self.train) + + task_eval = Tasklet(self.evaluate) + + task_analysis = Tasklet(self.run_analysis) + + task_save_metrics = Tasklet(self.save_metrics) + + task_increment_round = Tasklet(self.increment_round) + + task_end_of_training = Tasklet(self.inform_end_of_training) + + task_save_params = Tasklet(self.save_params) + + task_save_model = Tasklet(self.save_model) + + # create a loop object with loop exit condition function + loop = Loop(loop_check_fn=lambda: self._work_done) + ( + task_internal_init + >> task_load_data + >> task_init + >> loop( + task_put + >> task_get + >> task_train + >> task_eval + >> task_analysis + >> task_save_metrics + >> task_increment_round + ) + >> task_end_of_training + >> task_save_params + >> task_save_model + ) + + def run(self) -> None: + """Run role.""" + self.composer.run() + + @classmethod + def get_func_tags(cls) -> list[str]: + """Return a 
list of function tags defined in the top level aggregator role.""" + return [TAG_DISTRIBUTE, TAG_AGGREGATE] diff --git a/lib/python/flame/mode/horizontal/syncfl/trainer.py b/lib/python/flame/mode/horizontal/syncfl/trainer.py new file mode 100644 index 000000000..41314e330 --- /dev/null +++ b/lib/python/flame/mode/horizontal/syncfl/trainer.py @@ -0,0 +1,229 @@ +# Copyright 2022 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""horizontal FL trainer.""" + +import logging + +from flame.channel import VAL_CH_STATE_RECV, VAL_CH_STATE_SEND +from flame.channel_manager import ChannelManager +from flame.common.constants import DeviceType, TrainerState +from flame.common.custom_abcmeta import ABCMeta, abstract_attribute +from flame.common.util import ( + MLFramework, + delta_weights_pytorch, + delta_weights_tensorflow, + get_ml_framework_in_use, + mlflow_runname, + valid_frameworks, + weights_to_device, + weights_to_model_device, +) +from flame.config import Config +from flame.mode.composer import Composer +from flame.mode.message import MessageType +from flame.mode.role import Role +from flame.mode.tasklet import Loop, Tasklet +from flame.optimizers import optimizer_provider +from flame.registries import registry_provider + +logger = logging.getLogger(__name__) + +TAG_FETCH = "fetch" +TAG_UPLOAD = "upload" + + +class Trainer(Role, metaclass=ABCMeta): + """Trainer implements an ML training role.""" + + @abstract_attribute + def config(self) -> Config: + """Abstract attribute for config object.""" + + @abstract_attribute + def model(self): + """Abstract attribute for model object.""" + + @abstract_attribute + def dataset_size(self): + """Abstract attribute for size of dataset used to train.""" + + def internal_init(self) -> None: + """Initialize internal state for role.""" + self.cm = ChannelManager() + self.cm(self.config) + self.cm.join_all() + + self.registry_client = registry_provider.get(self.config.registry.sort) + # initialize registry client + self.registry_client(self.config.registry.uri, self.config.job.job_id) + + self.registry_client.setup_run(mlflow_runname(self.config)) + self.metrics = dict() + + # needed for trainer-side optimization algorithms such as fedprox + temp_opt = optimizer_provider.get( + self.config.optimizer.sort, **self.config.optimizer.kwargs + ) + self.regularizer = temp_opt.regularizer + + self._round = 1 + self._work_done = False + + self.framework = get_ml_framework_in_use() + if self.framework == MLFramework.UNKNOWN: + raise NotImplementedError( + "supported ml framework not found; " + f"supported frameworks are: {valid_frameworks}" + ) + + if self.framework == MLFramework.PYTORCH: + self._delta_weights_fn = delta_weights_pytorch + + elif self.framework == MLFramework.TENSORFLOW: + self._delta_weights_fn = delta_weights_tensorflow + + def get(self, tag: str) -> None: + """Get data from remote role(s).""" + if tag == TAG_FETCH: + self._fetch_weights(tag) + + def _fetch_weights(self, tag: 
str) -> None: + logger.debug("calling _fetch_weights") + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"[_fetch_weights] channel not found with tag {tag}") + return + + # this call waits for at least one peer joins this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end(VAL_CH_STATE_RECV) + msg, _ = channel.recv(end) + + if MessageType.WEIGHTS in msg: + self.weights = weights_to_model_device(msg[MessageType.WEIGHTS], self.model) + self._update_model() + + if MessageType.EOT in msg: + self._work_done = msg[MessageType.EOT] + + if MessageType.ROUND in msg: + self._round = msg[MessageType.ROUND] + + self.regularizer.save_state(TrainerState.PRE_TRAIN, glob_model=self.model) + logger.debug(f"work_done: {self._work_done}, round: {self._round}") + + def put(self, tag: str) -> None: + """Set data to remote role(s).""" + if tag == TAG_UPLOAD: + self._send_weights(tag) + + def _send_weights(self, tag: str) -> None: + logger.debug("calling _send_weights") + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"[_send_weights] channel not found with {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end(VAL_CH_STATE_SEND) + + self._update_weights() + self.regularizer.save_state(TrainerState.POST_TRAIN, loc_model=self.model) + + delta_weights = self._delta_weights_fn(self.weights, self.prev_weights) + + # send delta_weights to regularizer + self.regularizer.update() + + msg = { + MessageType.WEIGHTS: weights_to_device(delta_weights, DeviceType.CPU), + MessageType.DATASET_SIZE: self.dataset_size, + MessageType.MODEL_VERSION: self._round, + } + channel.send(end, msg) + logger.debug("sending weights done") + + def save_metrics(self): + """Save metrics in a model registry.""" + logger.debug(f"saving metrics: {self.metrics}") + if self.metrics: + self.registry_client.save_metrics(self._round - 1, self.metrics) + logger.debug("saving metrics done") + + def update_metrics(self, metrics: dict[str, float]): + """Update metrics.""" + self.metrics = self.metrics | metrics + + def _update_model(self): + if self.framework == MLFramework.PYTORCH: + self.model.load_state_dict(self.weights) + elif self.framework == MLFramework.TENSORFLOW: + self.model.set_weights(self.weights) + + def _update_weights(self): + # save weights before updating it + self.prev_weights = self.weights + + if self.framework == MLFramework.PYTORCH: + self.weights = self.model.state_dict() + elif self.framework == MLFramework.TENSORFLOW: + self.weights = self.model.get_weights() + + def compose(self) -> None: + """Compose role with tasklets.""" + with Composer() as composer: + self.composer = composer + + task_internal_init = Tasklet(self.internal_init) + + task_load_data = Tasklet(self.load_data) + + task_init = Tasklet(self.initialize) + + task_get = Tasklet(self.get, TAG_FETCH) + + task_train = Tasklet(self.train) + + task_eval = Tasklet(self.evaluate) + + task_put = Tasklet(self.put, TAG_UPLOAD) + + task_save_metrics = Tasklet(self.save_metrics) + + # create a loop object with loop exit condition function + loop = Loop(loop_check_fn=lambda: self._work_done) + ( + task_internal_init + >> task_load_data + >> task_init + >> loop( + task_get >> task_train >> task_eval >> task_put >> task_save_metrics + ) + ) + + def run(self) -> None: + """Run role.""" + self.composer.run() + + @classmethod + def get_func_tags(cls) -> list[str]: + """Return a list of function tags 
defined in the trainer role.""" + return [TAG_FETCH, TAG_UPLOAD] diff --git a/lib/python/flame/mode/horizontal/top_aggregator.py b/lib/python/flame/mode/horizontal/top_aggregator.py index 6da3d3d38..af14461ab 100644 --- a/lib/python/flame/mode/horizontal/top_aggregator.py +++ b/lib/python/flame/mode/horizontal/top_aggregator.py @@ -15,308 +15,13 @@ # SPDX-License-Identifier: Apache-2.0 """horizontal FL top level aggregator.""" -import logging -import time -from copy import deepcopy - -from diskcache import Cache - -from ...channel_manager import ChannelManager -from ...common.custom_abcmeta import ABCMeta, abstract_attribute -from ...common.util import ( - MLFramework, - get_ml_framework_in_use, - mlflow_runname, - valid_frameworks, - weights_to_device, - weights_to_model_device, +# pylint: disable=unused-import +from flame.mode.horizontal.syncfl.top_aggregator import ( + TAG_AGGREGATE, + TAG_DISTRIBUTE, + TopAggregator, ) -from ...common.constants import DeviceType -from ...optimizer.train_result import TrainResult -from ...optimizers import optimizer_provider -from ...plugin import PluginManager, PluginType -from ...registries import registry_provider -from ..composer import Composer -from ..message import MessageType -from ..role import Role -from ..tasklet import Loop, Tasklet -from ...config import Config - -logger = logging.getLogger(__name__) - -TAG_DISTRIBUTE = "distribute" -TAG_AGGREGATE = "aggregate" - - -class TopAggregator(Role, metaclass=ABCMeta): - """Top level Aggregator implements an ML aggregation role.""" - - @abstract_attribute - def config(self) -> Config: - """Abstract attribute for config object.""" - - @abstract_attribute - def model(self): - """Abstract attribute for model object.""" - - @abstract_attribute - def dataset(self): - """ - Abstract attribute for datset. - - dataset's type is Dataset (in flame/dataset.py). 
- """ - - def internal_init(self) -> None: - """Initialize internal state for role.""" - # global variable for plugin manager - self.plugin_manager = PluginManager() - - self.cm = ChannelManager() - self.cm(self.config) - self.cm.join_all() - - self.registry_client = registry_provider.get(self.config.registry.sort) - # initialize registry client - self.registry_client(self.config.registry.uri, self.config.job.job_id) - - base_model = self.config.base_model - if base_model and base_model.name != "" and base_model.version > 0: - self.model = self.registry_client.load_model( - base_model.name, base_model.version - ) - - self.registry_client.setup_run(mlflow_runname(self.config)) - self.metrics = dict() - - # disk cache is used for saving memory in case model is large - self.cache = Cache() - self.optimizer = optimizer_provider.get( - self.config.optimizer.sort, **self.config.optimizer.kwargs - ) - - self._round = 1 - self._rounds = 1 - self._rounds = self.config.hyperparameters.rounds - self._work_done = False - - self.framework = get_ml_framework_in_use() - if self.framework == MLFramework.UNKNOWN: - raise NotImplementedError( - "supported ml framework not found; " - f"supported frameworks are: {valid_frameworks}" - ) - - def get(self, tag: str) -> None: - """Get data from remote role(s).""" - if tag == TAG_AGGREGATE: - self._aggregate_weights(tag) - - def _aggregate_weights(self, tag: str) -> None: - channel = self.cm.get_by_tag(tag) - if not channel: - return - - total = 0 - # receive local model parameters from trainers - for msg, metadata in channel.recv_fifo(channel.ends()): - end, _ = metadata - if not msg: - logger.debug(f"No data from {end}; skipping it") - continue - - logger.debug(f"received data from {end}") - if MessageType.WEIGHTS in msg: - weights = weights_to_model_device(msg[MessageType.WEIGHTS], self.model) - - if MessageType.DATASET_SIZE in msg: - count = msg[MessageType.DATASET_SIZE] - - logger.debug(f"{end}'s parameters trained with {count} samples") - - if weights is not None and count > 0: - total += count - tres = TrainResult(weights, count) - # save training result from trainer in a disk cache - self.cache[end] = tres - - # optimizer conducts optimization (in this case, aggregation) - global_weights = self.optimizer.do(deepcopy(self.weights), - self.cache, - total=total, - num_trainers=len(channel.ends())) - if global_weights is None: - logger.debug("failed model aggregation") - time.sleep(1) - return - - # set global weights - self.weights = global_weights - - # update model with global weights - self._update_model() - - def put(self, tag: str) -> None: - """Set data to remote role(s).""" - if tag == TAG_DISTRIBUTE: - self.dist_tag = tag - self._distribute_weights(tag) - - def _distribute_weights(self, tag: str) -> None: - channel = self.cm.get_by_tag(tag) - if not channel: - logger.debug(f"channel not found for tag {tag}") - return - - # this call waits for at least one peer to join this channel - channel.await_join() - - # before distributing weights, update it from global model - self._update_weights() - - # send out global model parameters to trainers - for end in channel.ends(): - logger.debug(f"sending weights to {end}") - channel.send( - end, - { - MessageType.WEIGHTS: weights_to_device( - self.weights, DeviceType.CPU - ), - MessageType.ROUND: self._round, - }, - ) - - def inform_end_of_training(self) -> None: - """Inform all the trainers that the training is finished.""" - channel = self.cm.get_by_tag(self.dist_tag) - if not channel: - logger.debug(f"channel 
not found for tag {self.dist_tag}") - return - - channel.broadcast({MessageType.EOT: self._work_done}) - logger.debug("done broadcasting end-of-training") - - def run_analysis(self): - """Run analysis plugins and update results to metrics.""" - logger.debug("running analyzer plugins") - - plugins = self.plugin_manager.get_plugins(PluginType.ANALYZER) - for plugin in plugins: - # get callback function and call it - func = plugin.callback() - metrics = func(self.model, self.dataset) - if not metrics: - continue - - self.update_metrics(metrics) - - def save_metrics(self): - """Save metrics in a model registry.""" - logger.debug(f"saving metrics: {self.metrics}") - if self.metrics: - self.registry_client.save_metrics(self._round - 1, self.metrics) - logger.debug("saving metrics done") - - def increment_round(self): - """Increment the round counter.""" - logger.debug(f"Incrementing current round: {self._round}") - logger.debug(f"Total rounds: {self._rounds}") - self._round += 1 - self._work_done = self._round > self._rounds - - channel = self.cm.get_by_tag(self.dist_tag) - if not channel: - logger.debug(f"channel not found for tag {self.dist_tag}") - return - - logger.debug(f"Incremented round to {self._round}") - # set necessary properties to help channel decide how to select ends - channel.set_property("round", self._round) - - def save_params(self): - """Save hyperparamets in a model registry.""" - if self.config.hyperparameters: - self.registry_client.save_params(self.config.hyperparameters) - - def save_model(self): - """Save model in a model registry.""" - if self.model: - model_name = f"{self.config.job.name}-{self.config.job.job_id}" - self.registry_client.save_model(model_name, self.model) - - def update_metrics(self, metrics: dict[str, float]): - """Update metrics.""" - self.metrics = self.metrics | metrics - - def _update_model(self): - if self.framework == MLFramework.PYTORCH: - self.model.load_state_dict(self.weights) - elif self.framework == MLFramework.TENSORFLOW: - self.model.set_weights(self.weights) - - def _update_weights(self): - if self.framework == MLFramework.PYTORCH: - self.weights = self.model.state_dict() - elif self.framework == MLFramework.TENSORFLOW: - self.weights = self.model.get_weights() - - def compose(self) -> None: - """Compose role with tasklets.""" - with Composer() as composer: - self.composer = composer - - task_internal_init = Tasklet(self.internal_init) - - task_init = Tasklet(self.initialize) - - task_load_data = Tasklet(self.load_data) - - task_put = Tasklet(self.put, TAG_DISTRIBUTE) - - task_get = Tasklet(self.get, TAG_AGGREGATE) - - task_train = Tasklet(self.train) - - task_eval = Tasklet(self.evaluate) - - task_analysis = Tasklet(self.run_analysis) - - task_save_metrics = Tasklet(self.save_metrics) - - task_increment_round = Tasklet(self.increment_round) - - task_end_of_training = Tasklet(self.inform_end_of_training) - - task_save_params = Tasklet(self.save_params) - - task_save_model = Tasklet(self.save_model) - - # create a loop object with loop exit condition function - loop = Loop(loop_check_fn=lambda: self._work_done) - ( - task_internal_init - >> task_load_data - >> task_init - >> loop( - task_put - >> task_get - >> task_train - >> task_eval - >> task_analysis - >> task_save_metrics - >> task_increment_round - ) - >> task_end_of_training - >> task_save_params - >> task_save_model - ) - - def run(self) -> None: - """Run role.""" - self.composer.run() - @classmethod - def get_func_tags(cls) -> list[str]: - """Return a list of function tags 
defined in the top level aggregator role.""" - return [TAG_DISTRIBUTE, TAG_AGGREGATE] +# Redirect `flame.mode.horizontal.top_aggregator` to +# `flame.mode.horizontal.syncfl.top_aggregator` +# This is for backwards compatibility diff --git a/lib/python/flame/mode/horizontal/trainer.py b/lib/python/flame/mode/horizontal/trainer.py index d0717d0b5..7180400d6 100644 --- a/lib/python/flame/mode/horizontal/trainer.py +++ b/lib/python/flame/mode/horizontal/trainer.py @@ -15,217 +15,9 @@ # SPDX-License-Identifier: Apache-2.0 """horizontal FL trainer.""" -import logging +# pylint: disable=unused-import +from flame.mode.horizontal.syncfl.trainer import Trainer -from copy import deepcopy -from ...channel import VAL_CH_STATE_RECV, VAL_CH_STATE_SEND -from ...channel_manager import ChannelManager -from ...common.constants import TrainerState -from ...common.custom_abcmeta import ABCMeta, abstract_attribute -from ...common.util import ( - MLFramework, - delta_weights_pytorch, - delta_weights_tensorflow, - get_ml_framework_in_use, - mlflow_runname, - valid_frameworks, - weights_to_device, - weights_to_model_device, -) -from ...optimizers import optimizer_provider -from ...common.constants import DeviceType -from ...registries import registry_provider -from ..composer import Composer -from ..message import MessageType -from ..role import Role -from ..tasklet import Loop, Tasklet -from ...config import Config - -logger = logging.getLogger(__name__) - -TAG_FETCH = "fetch" -TAG_UPLOAD = "upload" - - -class Trainer(Role, metaclass=ABCMeta): - """Trainer implements an ML training role.""" - - @abstract_attribute - def config(self) -> Config: - """Abstract attribute for config object.""" - - @abstract_attribute - def model(self): - """Abstract attribute for model object.""" - - @abstract_attribute - def dataset_size(self): - """Abstract attribute for size of dataset used to train.""" - - def internal_init(self) -> None: - """Initialize internal state for role.""" - self.cm = ChannelManager() - self.cm(self.config) - self.cm.join_all() - - self.registry_client = registry_provider.get(self.config.registry.sort) - # initialize registry client - self.registry_client(self.config.registry.uri, self.config.job.job_id) - - self.registry_client.setup_run(mlflow_runname(self.config)) - self.metrics = dict() - - # needed for trainer-side optimization algorithms such as fedprox - temp_opt = optimizer_provider.get( - self.config.optimizer.sort, **self.config.optimizer.kwargs - ) - self.regularizer = temp_opt.regularizer - - self._round = 1 - self._work_done = False - - self.framework = get_ml_framework_in_use() - if self.framework == MLFramework.UNKNOWN: - raise NotImplementedError( - "supported ml framework not found; " - f"supported frameworks are: {valid_frameworks}" - ) - - if self.framework == MLFramework.PYTORCH: - self._delta_weights_fn = delta_weights_pytorch - - elif self.framework == MLFramework.TENSORFLOW: - self._delta_weights_fn = delta_weights_tensorflow - - def get(self, tag: str) -> None: - """Get data from remote role(s).""" - if tag == TAG_FETCH: - self._fetch_weights(tag) - - def _fetch_weights(self, tag: str) -> None: - logger.debug("calling _fetch_weights") - channel = self.cm.get_by_tag(tag) - if not channel: - logger.debug(f"[_fetch_weights] channel not found with tag {tag}") - return - - # this call waits for at least one peer joins this channel - channel.await_join() - - # one aggregator is sufficient - end = channel.one_end(VAL_CH_STATE_RECV) - msg, _ = channel.recv(end) - - if MessageType.WEIGHTS 
in msg: - self.weights = weights_to_model_device(msg[MessageType.WEIGHTS], self.model) - self._update_model() - - if MessageType.EOT in msg: - self._work_done = msg[MessageType.EOT] - - if MessageType.ROUND in msg: - self._round = msg[MessageType.ROUND] - - self.regularizer.save_state(TrainerState.PRE_TRAIN, glob_model = self.model) - logger.debug(f"work_done: {self._work_done}, round: {self._round}") - - def put(self, tag: str) -> None: - """Set data to remote role(s).""" - if tag == TAG_UPLOAD: - self._send_weights(tag) - - def _send_weights(self, tag: str) -> None: - logger.debug("calling _send_weights") - channel = self.cm.get_by_tag(tag) - if not channel: - logger.debug(f"[_send_weights] channel not found with {tag}") - return - - # this call waits for at least one peer to join this channel - channel.await_join() - - # one aggregator is sufficient - end = channel.one_end(VAL_CH_STATE_SEND) - - self._update_weights() - self.regularizer.save_state(TrainerState.POST_TRAIN, loc_model=self.model) - - delta_weights = self._delta_weights_fn(self.weights, self.prev_weights) - - # send delta_weights to regularizer - self.regularizer.update() - - msg = { - MessageType.WEIGHTS: weights_to_device(delta_weights, DeviceType.CPU), - MessageType.DATASET_SIZE: self.dataset_size, - MessageType.MODEL_VERSION: self._round, - } - channel.send(end, msg) - logger.debug("sending weights done") - - def save_metrics(self): - """Save metrics in a model registry.""" - logger.debug(f"saving metrics: {self.metrics}") - if self.metrics: - self.registry_client.save_metrics(self._round - 1, self.metrics) - logger.debug("saving metrics done") - - def update_metrics(self, metrics: dict[str, float]): - """Update metrics.""" - self.metrics = self.metrics | metrics - - def _update_model(self): - if self.framework == MLFramework.PYTORCH: - self.model.load_state_dict(self.weights) - elif self.framework == MLFramework.TENSORFLOW: - self.model.set_weights(self.weights) - - def _update_weights(self): - # save weights before updating it - self.prev_weights = self.weights - - if self.framework == MLFramework.PYTORCH: - self.weights = self.model.state_dict() - elif self.framework == MLFramework.TENSORFLOW: - self.weights = self.model.get_weights() - - def compose(self) -> None: - """Compose role with tasklets.""" - with Composer() as composer: - self.composer = composer - - task_internal_init = Tasklet(self.internal_init) - - task_load_data = Tasklet(self.load_data) - - task_init = Tasklet(self.initialize) - - task_get = Tasklet(self.get, TAG_FETCH) - - task_train = Tasklet(self.train) - - task_eval = Tasklet(self.evaluate) - - task_put = Tasklet(self.put, TAG_UPLOAD) - - task_save_metrics = Tasklet(self.save_metrics) - - # create a loop object with loop exit condition function - loop = Loop(loop_check_fn=lambda: self._work_done) - ( - task_internal_init - >> task_load_data - >> task_init - >> loop( - task_get >> task_train >> task_eval >> task_put >> task_save_metrics - ) - ) - - def run(self) -> None: - """Run role.""" - self.composer.run() - - @classmethod - def get_func_tags(cls) -> list[str]: - """Return a list of function tags defined in the trainer role.""" - return [TAG_FETCH, TAG_UPLOAD] +# Redirect `flame.mode.horizontal.trainer.Trainer` to +# `flame.mode.horizontal.syncfl.trainer.Trainer +# This is for backwards compatibility
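For reference, a minimal usage sketch of the backward-compatibility redirects introduced by this patch. It assumes the patched flame package is importable; the alias names (LegacyTrainer, SyncFLTrainer, and so on) are illustrative only and not part of the library.

# Illustrative check of the backward-compatibility redirects: the legacy module
# paths now simply re-export the classes that live under flame.mode.horizontal.syncfl,
# so both import paths resolve to the same class objects.
from flame.mode.horizontal.trainer import Trainer as LegacyTrainer
from flame.mode.horizontal.syncfl.trainer import Trainer as SyncFLTrainer

from flame.mode.horizontal.top_aggregator import TopAggregator as LegacyTopAggregator
from flame.mode.horizontal.syncfl.top_aggregator import TopAggregator as SyncFLTopAggregator

# Because the legacy modules only re-export, existing user code that subclasses
# Trainer or TopAggregator from the old paths keeps working unchanged.
assert LegacyTrainer is SyncFLTrainer
assert LegacyTopAggregator is SyncFLTopAggregator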