diff --git a/lib/python/flame/mode/distributed/trainer.py b/lib/python/flame/mode/distributed/trainer.py
index b6d90eb9c..14769db74 100644
--- a/lib/python/flame/mode/distributed/trainer.py
+++ b/lib/python/flame/mode/distributed/trainer.py
@@ -23,15 +23,10 @@
 from flame.channel_manager import ChannelManager
 from flame.common.constants import DeviceType
 from flame.common.custom_abcmeta import ABCMeta, abstract_attribute
-from flame.common.util import (
-    MLFramework,
-    delta_weights_pytorch,
-    delta_weights_tensorflow,
-    get_ml_framework_in_use,
-    valid_frameworks,
-    weights_to_device,
-    weights_to_model_device,
-)
+from flame.common.util import (MLFramework, delta_weights_pytorch,
+                               delta_weights_tensorflow,
+                               get_ml_framework_in_use, valid_frameworks,
+                               weights_to_device, weights_to_model_device)
 from flame.config import Config
 from flame.mode.composer import Composer
 from flame.mode.message import MessageType
@@ -444,6 +439,7 @@ def save_metrics(self):
         if self.metrics:
             self.registry_client.save_metrics(self._round, self.metrics)
             logger.debug("saving metrics done")
+        self.metrics = dict()
 
     def update_metrics(self, metrics: dict[str, float]):
         """Update metrics."""
@@ -501,29 +497,33 @@ def compose(self) -> None:
         with Composer() as composer:
             self.composer = composer
 
-            task_init_cm = Tasklet("", self.init_cm)
+            task_init_cm = Tasklet("init_cm", self.init_cm)
 
-            task_internal_init = Tasklet("", self.internal_init)
+            task_internal_init = Tasklet("internal_init", self.internal_init)
 
-            task_load_data = Tasklet("", self.load_data)
+            task_load_data = Tasklet("load_data", self.load_data)
 
-            task_init = Tasklet("", self.initialize)
+            task_init = Tasklet("initialize", self.initialize)
 
-            task_member_check = Tasklet("", self._member_check, TAG_RING_ALLREDUCE)
+            task_member_check = Tasklet(
+                "member_check", self._member_check, TAG_RING_ALLREDUCE
+            )
 
-            task_allreduce = Tasklet("", self._ring_allreduce, TAG_RING_ALLREDUCE)
+            task_allreduce = Tasklet(
+                "ring_allreduce", self._ring_allreduce, TAG_RING_ALLREDUCE
+            )
 
-            task_train = Tasklet("", self.train)
+            task_train = Tasklet("train", self.train)
 
-            task_eval = Tasklet("", self.evaluate)
+            task_eval = Tasklet("evaluate", self.evaluate)
 
-            task_increment_round = Tasklet("", self.increment_round)
+            task_increment_round = Tasklet("inc_round", self.increment_round)
 
-            task_save_metrics = Tasklet("", self.save_metrics)
+            task_save_metrics = Tasklet("save_metrics", self.save_metrics)
 
-            task_save_params = Tasklet("", self.save_params)
+            task_save_params = Tasklet("save_params", self.save_params)
 
-            task_save_model = Tasklet("", self.save_model)
+            task_save_model = Tasklet("save_model", self.save_model)
 
             # create a loop object with loop exit condition function
             loop = Loop(loop_check_fn=lambda: self._work_done)
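Note: the substantive changes in this file are the Tasklet names (previously all empty strings) and the metrics reset in save_metrics(); the import block is only re-wrapped. A minimal sketch of why the names matter, using a toy Tasklet class (hypothetical, not flame's implementation), shows how a named chain makes execution traces readable:

    class Tasklet:  # toy stand-in, not flame's Tasklet
        def __init__(self, name, fn):
            self.name, self.fn, self.next = name, fn, None

        def __rshift__(self, other):
            self.next = other  # enables task_a >> task_b >> task_c chains
            return other

    task_load = Tasklet("load_data", lambda: None)
    task_train = Tasklet("train", lambda: None)
    task_eval = Tasklet("evaluate", lambda: None)
    task_load >> task_train >> task_eval

    step = task_load
    while step:
        print(f"running tasklet: {step.name}")  # before this patch: blank names
        step.fn()
        step = step.next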
diff --git a/lib/python/flame/mode/horizontal/coord_syncfl/coordinator.py b/lib/python/flame/mode/horizontal/coord_syncfl/coordinator.py
index d21ee4da3..f0e9a3011 100644
--- a/lib/python/flame/mode/horizontal/coord_syncfl/coordinator.py
+++ b/lib/python/flame/mode/horizontal/coord_syncfl/coordinator.py
@@ -225,21 +225,32 @@ def compose(self) -> None:
         with Composer() as composer:
             self.composer = composer
 
-            task_await = Tasklet("", self.await_mid_aggs_and_trainers)
+            task_await = Tasklet(
+                "await_mid_aggs_and_trainers", self.await_mid_aggs_and_trainers
+            )
 
-            task_pairing = Tasklet("", self.pair_mid_aggs_and_trainers)
+            task_pairing = Tasklet(
+                "pair_mid_aggs_and_trainers", self.pair_mid_aggs_and_trainers
+            )
 
             task_send_mid_aggs_to_top_agg = Tasklet(
-                "", self.send_selected_middle_aggregators
+                "send_selected_middle_aggregators",
+                self.send_selected_middle_aggregators,
             )
 
-            task_send_trainers_to_agg = Tasklet("", self.send_selected_trainers)
+            task_send_trainers_to_agg = Tasklet(
+                "send_selected_trainers", self.send_selected_trainers
+            )
 
-            task_send_agg_to_trainer = Tasklet("", self.send_selected_middle_aggregator)
+            task_send_agg_to_trainer = Tasklet(
+                "send_selected_middle_aggregator", self.send_selected_middle_aggregator
+            )
 
-            task_increment_round = Tasklet("", self.increment_round)
+            task_increment_round = Tasklet("inc_round", self.increment_round)
 
-            task_inform_eot = Tasklet("", self.inform_end_of_training)
+            task_inform_eot = Tasklet(
+                "inform_end_of_training", self.inform_end_of_training
+            )
 
             loop = Loop(loop_check_fn=lambda: self._work_done)
 
             (
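Note: this hunk only names the coordinator's tasklets; the pairing logic itself is unchanged and not shown here. Purely as a hypothetical illustration of what a pair_mid_aggs_and_trainers-style step can do (pair_round_robin is invented for this note), a round-robin assignment of trainers to middle aggregators looks like this:

    def pair_round_robin(mid_aggs, trainers):
        """Hypothetical pairing policy: deal trainers out to middle aggregators."""
        pairing = {agg: [] for agg in mid_aggs}
        for i, trainer in enumerate(trainers):
            pairing[mid_aggs[i % len(mid_aggs)]].append(trainer)
        return pairing

    print(pair_round_robin(["agg1", "agg2"], ["t1", "t2", "t3", "t4", "t5"]))
    # {'agg1': ['t1', 't3', 't5'], 'agg2': ['t2', 't4']}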
""" if epoch == 1 and batch_idx == 0: - if 'reduction' in kwargs.keys(): - reduction = kwargs['reduction'] + if "reduction" in kwargs.keys(): + reduction = kwargs["reduction"] else: - reduction = 'mean' # default reduction policy is mean - kwargs_wo_reduction = {key: value for key, value in kwargs.items() if key != 'reduction'} + reduction = "mean" # default reduction policy is mean + kwargs_wo_reduction = { + key: value for key, value in kwargs.items() if key != "reduction" + } - criterion = self.loss_fn(reduction='none', **kwargs_wo_reduction) + criterion = self.loss_fn(reduction="none", **kwargs_wo_reduction) loss_list = criterion(output, target) self._stat_utility += torch.square(loss_list).sum() - - if reduction == 'mean': + + if reduction == "mean": loss = loss_list.mean() - elif reduction == 'sum': + elif reduction == "sum": loss = loss_list.sum() else: criterion = self.loss_fn(**kwargs) @@ -138,23 +140,25 @@ def compose(self) -> None: with Composer() as composer: self.composer = composer - task_internal_init = Tasklet("", self.internal_init) + task_internal_init = Tasklet("internal_init", self.internal_init) - task_init_oort_variables = Tasklet("", self.init_oort_variables) + task_init_oort_variables = Tasklet( + "init_oort_variables", self.init_oort_variables + ) - task_load_data = Tasklet("", self.load_data) + task_load_data = Tasklet("load_data", self.load_data) - task_init = Tasklet("", self.initialize) + task_init = Tasklet("initialize", self.initialize) - task_get = Tasklet("", self.get, TAG_FETCH) + task_get = Tasklet("fetch", self.get, TAG_FETCH) - task_train = Tasklet("", self.train) + task_train = Tasklet("train", self.train) - task_eval = Tasklet("", self.evaluate) + task_eval = Tasklet("evaluate", self.evaluate) - task_put = Tasklet("", self.put, TAG_UPLOAD) + task_put = Tasklet("upload", self.put, TAG_UPLOAD) - task_save_metrics = Tasklet("", self.save_metrics) + task_save_metrics = Tasklet("save_metrics", self.save_metrics) # create a loop object with loop exit condition function loop = Loop(loop_check_fn=lambda: self._work_done) diff --git a/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py b/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py index 6ade1adc6..3010272ee 100644 --- a/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py +++ b/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py @@ -21,17 +21,15 @@ from datetime import datetime from diskcache import Cache + from flame.channel_manager import ChannelManager from flame.common.constants import DeviceType from flame.common.custom_abcmeta import ABCMeta, abstract_attribute -from flame.common.util import ( - MLFramework, - get_ml_framework_in_use, - valid_frameworks, - weights_to_device, - weights_to_model_device, -) +from flame.common.util import (MLFramework, get_ml_framework_in_use, + valid_frameworks, weights_to_device, + weights_to_model_device) from flame.config import Config +from flame.datasamplers import datasampler_provider from flame.mode.composer import Composer from flame.mode.message import MessageType from flame.mode.role import Role @@ -40,7 +38,6 @@ from flame.optimizers import optimizer_provider from flame.plugin import PluginManager, PluginType from flame.registries import registry_provider -from flame.datasamplers import datasampler_provider logger = logging.getLogger(__name__) @@ -147,9 +144,7 @@ def _aggregate_weights(self, tag: str) -> None: if MessageType.DATASAMPLER_METADATA in msg: self.datasampler.handle_metadata_from_trainer( - 
diff --git a/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py b/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py
index 6ade1adc6..3010272ee 100644
--- a/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py
+++ b/lib/python/flame/mode/horizontal/syncfl/top_aggregator.py
@@ -21,17 +21,15 @@
 from datetime import datetime
 
 from diskcache import Cache
+
 from flame.channel_manager import ChannelManager
 from flame.common.constants import DeviceType
 from flame.common.custom_abcmeta import ABCMeta, abstract_attribute
-from flame.common.util import (
-    MLFramework,
-    get_ml_framework_in_use,
-    valid_frameworks,
-    weights_to_device,
-    weights_to_model_device,
-)
+from flame.common.util import (MLFramework, get_ml_framework_in_use,
+                               valid_frameworks, weights_to_device,
+                               weights_to_model_device)
 from flame.config import Config
+from flame.datasamplers import datasampler_provider
 from flame.mode.composer import Composer
 from flame.mode.message import MessageType
 from flame.mode.role import Role
@@ -40,7 +38,6 @@
 from flame.optimizers import optimizer_provider
 from flame.plugin import PluginManager, PluginType
 from flame.registries import registry_provider
-from flame.datasamplers import datasampler_provider
 
 logger = logging.getLogger(__name__)
 
@@ -147,9 +144,7 @@ def _aggregate_weights(self, tag: str) -> None:
 
             if MessageType.DATASAMPLER_METADATA in msg:
                 self.datasampler.handle_metadata_from_trainer(
-                    msg[MessageType.DATASAMPLER_METADATA],
-                    end,
-                    channel,
+                    msg[MessageType.DATASAMPLER_METADATA], end, channel,
                 )
 
             logger.debug(f"{end}'s parameters trained with {count} samples")
@@ -252,6 +247,7 @@ def save_metrics(self):
         if self.metrics:
             self.registry_client.save_metrics(self._round - 1, self.metrics)
             logger.debug("saving metrics done")
+        self.metrics = dict()
 
     def increment_round(self):
         """Increment the round counter."""
diff --git a/lib/python/flame/mode/horizontal/syncfl/trainer.py b/lib/python/flame/mode/horizontal/syncfl/trainer.py
index 301222988..816331064 100644
--- a/lib/python/flame/mode/horizontal/syncfl/trainer.py
+++ b/lib/python/flame/mode/horizontal/syncfl/trainer.py
@@ -22,23 +22,18 @@
 from flame.channel_manager import ChannelManager
 from flame.common.constants import DeviceType
 from flame.common.custom_abcmeta import ABCMeta, abstract_attribute
-from flame.common.util import (
-    MLFramework,
-    delta_weights_pytorch,
-    delta_weights_tensorflow,
-    get_ml_framework_in_use,
-    valid_frameworks,
-    weights_to_device,
-    weights_to_model_device,
-)
+from flame.common.util import (MLFramework, delta_weights_pytorch,
+                               delta_weights_tensorflow,
+                               get_ml_framework_in_use, valid_frameworks,
+                               weights_to_device, weights_to_model_device)
 from flame.config import Config
 from flame.datasamplers import datasampler_provider
-from flame.privacies import privacy_provider
 from flame.mode.composer import Composer
 from flame.mode.message import MessageType
 from flame.mode.role import Role
 from flame.mode.tasklet import Loop, Tasklet
 from flame.optimizers import optimizer_provider
+from flame.privacies import privacy_provider
 from flame.registries import registry_provider
 
 logger = logging.getLogger(__name__)
@@ -202,6 +197,7 @@ def save_metrics(self):
         if self.metrics:
             self.registry_client.save_metrics(self._round - 1, self.metrics)
             logger.debug("saving metrics done")
+        self.metrics = dict()
 
     def update_metrics(self, metrics: dict[str, float]):
         """Update metrics."""
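Note: this patch adds the same `self.metrics = dict()` reset to three save_metrics() implementations. Without it, a metric reported in one round lingers in self.metrics and is re-saved in every later round even if the trainer never updates it. A minimal mimic (RegistryStub and the MetricsMixin skeleton are hypothetical stand-ins, not flame classes) demonstrates the behavior the reset prevents:

    class RegistryStub:  # hypothetical stand-in for the registry client
        def save_metrics(self, round_, metrics):
            print(f"round {round_}: {metrics}")

    class MetricsMixin:  # skeleton of the trainer-side metrics handling
        def __init__(self):
            self.metrics = dict()
            self._round = 1
            self.registry_client = RegistryStub()

        def update_metrics(self, metrics):
            self.metrics = self.metrics | metrics

        def save_metrics(self):
            if self.metrics:
                self.registry_client.save_metrics(self._round - 1, self.metrics)
            self.metrics = dict()  # the added line: start each round clean

    t = MetricsMixin()
    t.update_metrics({"accuracy": 0.7})
    t.save_metrics()  # round 0: {'accuracy': 0.7}
    t._round += 1
    t.save_metrics()  # prints nothing; without the reset, round 0's stale
                      # accuracy would be re-saved as round 1's metrics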
diff --git a/lib/python/flame/mode/hybrid/trainer.py b/lib/python/flame/mode/hybrid/trainer.py
index d3bcac86c..d38a4b944 100644
--- a/lib/python/flame/mode/hybrid/trainer.py
+++ b/lib/python/flame/mode/hybrid/trainer.py
@@ -210,31 +210,35 @@ def compose(self) -> None:
         with Composer() as composer:
             self.composer = composer
 
-            task_init_cm = Tasklet("", self.init_cm)
+            task_init_cm = Tasklet("init_cm", self.init_cm)
 
-            task_internal_init = Tasklet("", self.internal_init)
+            task_internal_init = Tasklet("internal_init", self.internal_init)
 
-            task_load_data = Tasklet("", self.load_data)
+            task_load_data = Tasklet("load_data", self.load_data)
 
-            task_init = Tasklet("", self.initialize)
+            task_init = Tasklet("initialize", self.initialize)
 
-            task_get = Tasklet("", self.get, TAG_FETCH)
+            task_get = Tasklet("fetch", self.get, TAG_FETCH)
 
-            task_member_check = Tasklet("", self._member_check, TAG_RING_ALLREDUCE)
+            task_member_check = Tasklet(
+                "member_check", self._member_check, TAG_RING_ALLREDUCE
+            )
 
-            task_allreduce = Tasklet("", self._ring_allreduce, TAG_RING_ALLREDUCE)
+            task_allreduce = Tasklet(
+                "ring_allreduce", self._ring_allreduce, TAG_RING_ALLREDUCE
+            )
 
-            task_train = Tasklet("", self.train)
+            task_train = Tasklet("train", self.train)
 
-            task_eval = Tasklet("", self.evaluate)
+            task_eval = Tasklet("evaluate", self.evaluate)
 
-            task_put = Tasklet("", self.put, TAG_UPLOAD)
+            task_put = Tasklet("upload", self.put, TAG_UPLOAD)
 
-            task_save_metrics = Tasklet("", self.save_metrics)
+            task_save_metrics = Tasklet("save_metrics", self.save_metrics)
 
-            task_save_params = Tasklet("", self.save_params)
+            task_save_params = Tasklet("save_params", self.save_params)
 
-            task_save_model = Tasklet("", self.save_model)
+            task_save_model = Tasklet("save_model", self.save_model)
 
             # create a loop object with loop exit condition function
             loop = Loop(loop_check_fn=lambda: self._work_done)
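Note: the hybrid trainer composes the same member_check and ring_allreduce tasklets as the distributed trainer. The all-reduce implementation is not part of this diff; as a rough, single-process model of the ring algorithm those names refer to (a reduce-scatter phase followed by an all-gather phase, each taking n-1 steps), under the toy constraint that the vector length is divisible by the member count:

    def ring_allreduce(data):
        """Toy ring all-reduce: data[r] is member r's vector; returns each
        member's copy of the elementwise sum."""
        n = len(data)
        clen = len(data[0]) // n  # toy constraint: length divisible by n

        def chunk(c):
            return slice(c * clen, (c + 1) * clen)

        buf = [list(v) for v in data]
        # reduce-scatter: after n-1 steps member r owns summed chunk (r+1) % n
        for step in range(n - 1):
            sends = [(r, (r - step) % n) for r in range(n)]
            payloads = [buf[r][chunk(c)] for r, c in sends]  # copy, then apply
            for (r, c), payload in zip(sends, payloads):
                dst = (r + 1) % n
                buf[dst][chunk(c)] = [
                    a + b for a, b in zip(buf[dst][chunk(c)], payload)
                ]
        # all-gather: circulate owned chunks until every member has every sum
        for step in range(n - 1):
            sends = [(r, (r + 1 - step) % n) for r in range(n)]
            payloads = [buf[r][chunk(c)] for r, c in sends]
            for (r, c), payload in zip(sends, payloads):
                buf[(r + 1) % n][chunk(c)] = payload
        return buf

    print(ring_allreduce([[1.0, 2.0], [3.0, 4.0]]))  # [[4.0, 6.0], [4.0, 6.0]]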