Skip to content

Commit

Permalink
fix: clear metrics and tasklet alias (#466)
Browse files Browse the repository at this point in the history
After metrics are reported (usually every round, in the save_metrics method), they are cleared.
This prevents a metric from being reported more times than it should be.

Additionally, in order to properly report runtime-related metrics, the alias field was specified on all tasklets.
  • Loading branch information
GustavBaumgart committed Nov 30, 2023
1 parent e4fa6c3 commit 01b4584
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 83 deletions.
42 changes: 21 additions & 21 deletions lib/python/flame/mode/distributed/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,10 @@
from flame.channel_manager import ChannelManager
from flame.common.constants import DeviceType
from flame.common.custom_abcmeta import ABCMeta, abstract_attribute
from flame.common.util import (
MLFramework,
delta_weights_pytorch,
delta_weights_tensorflow,
get_ml_framework_in_use,
valid_frameworks,
weights_to_device,
weights_to_model_device,
)
from flame.common.util import (MLFramework, delta_weights_pytorch,
delta_weights_tensorflow,
get_ml_framework_in_use, valid_frameworks,
weights_to_device, weights_to_model_device)
from flame.config import Config
from flame.mode.composer import Composer
from flame.mode.message import MessageType
Expand Down Expand Up @@ -444,6 +439,7 @@ def save_metrics(self):
if self.metrics:
self.registry_client.save_metrics(self._round, self.metrics)
logger.debug("saving metrics done")
self.metrics = dict()

def update_metrics(self, metrics: dict[str, float]):
"""Update metrics."""
Expand Down Expand Up @@ -501,29 +497,33 @@ def compose(self) -> None:
with Composer() as composer:
self.composer = composer

task_init_cm = Tasklet("", self.init_cm)
task_init_cm = Tasklet("init_cm", self.init_cm)

task_internal_init = Tasklet("", self.internal_init)
task_internal_init = Tasklet("internal_init", self.internal_init)

task_load_data = Tasklet("", self.load_data)
task_load_data = Tasklet("load_data", self.load_data)

task_init = Tasklet("", self.initialize)
task_init = Tasklet("initialize", self.initialize)

task_member_check = Tasklet("", self._member_check, TAG_RING_ALLREDUCE)
task_member_check = Tasklet(
"member_check", self._member_check, TAG_RING_ALLREDUCE
)

task_allreduce = Tasklet("", self._ring_allreduce, TAG_RING_ALLREDUCE)
task_allreduce = Tasklet(
"ring_allreduce", self._ring_allreduce, TAG_RING_ALLREDUCE
)

task_train = Tasklet("", self.train)
task_train = Tasklet("train", self.train)

task_eval = Tasklet("", self.evaluate)
task_eval = Tasklet("evaluate", self.evaluate)

task_increment_round = Tasklet("", self.increment_round)
task_increment_round = Tasklet("inc_round", self.increment_round)

task_save_metrics = Tasklet("", self.save_metrics)
task_save_metrics = Tasklet("save_metrics", self.save_metrics)

task_save_params = Tasklet("", self.save_params)
task_save_params = Tasklet("save_params", self.save_params)

task_save_model = Tasklet("", self.save_model)
task_save_model = Tasklet("save_model", self.save_model)

# create a loop object with loop exit condition function
loop = Loop(loop_check_fn=lambda: self._work_done)
Expand Down
25 changes: 18 additions & 7 deletions lib/python/flame/mode/horizontal/coord_syncfl/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,32 @@ def compose(self) -> None:
with Composer() as composer:
self.composer = composer

task_await = Tasklet("", self.await_mid_aggs_and_trainers)
task_await = Tasklet(
"await_mid_aggs_and_trainers", self.await_mid_aggs_and_trainers
)

task_pairing = Tasklet("", self.pair_mid_aggs_and_trainers)
task_pairing = Tasklet(
"pair_mid_aggs_and_trainers", self.pair_mid_aggs_and_trainers
)

task_send_mid_aggs_to_top_agg = Tasklet(
"", self.send_selected_middle_aggregators
"send_selected_middle_aggregators",
self.send_selected_middle_aggregators,
)

task_send_trainers_to_agg = Tasklet("", self.send_selected_trainers)
task_send_trainers_to_agg = Tasklet(
"send_selected_trainers", self.send_selected_trainers
)

task_send_agg_to_trainer = Tasklet("", self.send_selected_middle_aggregator)
task_send_agg_to_trainer = Tasklet(
"send_selected_middle_aggregator", self.send_selected_middle_aggregator
)

task_increment_round = Tasklet("", self.increment_round)
task_increment_round = Tasklet("inc_round", self.increment_round)

task_inform_eot = Tasklet("", self.inform_end_of_training)
task_inform_eot = Tasklet(
"inform_end_of_training", self.inform_end_of_training
)

loop = Loop(loop_check_fn=lambda: self._work_done)
(
Expand Down
46 changes: 25 additions & 21 deletions lib/python/flame/mode/horizontal/oort/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
# SPDX-License-Identifier: Apache-2.0
"""Oort horizontal FL top level aggregator."""

import inspect
import logging
import math
import inspect

import torch
from flame.channel import VAL_CH_STATE_SEND
from flame.common.constants import DeviceType
from flame.common.util import weights_to_device
from flame.common.custom_abcmeta import abstract_attribute
from flame.common.util import weights_to_device
from flame.mode.composer import Composer
from flame.mode.horizontal.syncfl.trainer import TAG_FETCH, TAG_UPLOAD
from flame.mode.horizontal.syncfl.trainer import Trainer as BaseTrainer
Expand Down Expand Up @@ -79,7 +79,7 @@ def init_oort_variables(self) -> None:
"""Initialize Oort variables."""
self._stat_utility = 0

if 'reduction' not in inspect.signature(self.loss_fn).parameters:
if "reduction" not in inspect.signature(self.loss_fn).parameters:
msg = "Parameter 'reduction' not found in loss function "
msg += f"'{self.loss_fn.__name__}', which is required for Oort"
raise TypeError(msg)
Expand All @@ -90,26 +90,28 @@ def oort_loss(
target: torch.Tensor,
epoch: int,
batch_idx: int,
**kwargs
**kwargs,
) -> torch.Tensor:
"""
Measure the loss of a trainer during training.
The trainer's statistical utility is measured at epoch 1.
"""
if epoch == 1 and batch_idx == 0:
if 'reduction' in kwargs.keys():
reduction = kwargs['reduction']
if "reduction" in kwargs.keys():
reduction = kwargs["reduction"]
else:
reduction = 'mean' # default reduction policy is mean
kwargs_wo_reduction = {key: value for key, value in kwargs.items() if key != 'reduction'}
reduction = "mean" # default reduction policy is mean
kwargs_wo_reduction = {
key: value for key, value in kwargs.items() if key != "reduction"
}

criterion = self.loss_fn(reduction='none', **kwargs_wo_reduction)
criterion = self.loss_fn(reduction="none", **kwargs_wo_reduction)
loss_list = criterion(output, target)
self._stat_utility += torch.square(loss_list).sum()
if reduction == 'mean':

if reduction == "mean":
loss = loss_list.mean()
elif reduction == 'sum':
elif reduction == "sum":
loss = loss_list.sum()
else:
criterion = self.loss_fn(**kwargs)
Expand Down Expand Up @@ -138,23 +140,25 @@ def compose(self) -> None:
with Composer() as composer:
self.composer = composer

task_internal_init = Tasklet("", self.internal_init)
task_internal_init = Tasklet("internal_init", self.internal_init)

task_init_oort_variables = Tasklet("", self.init_oort_variables)
task_init_oort_variables = Tasklet(
"init_oort_variables", self.init_oort_variables
)

task_load_data = Tasklet("", self.load_data)
task_load_data = Tasklet("load_data", self.load_data)

task_init = Tasklet("", self.initialize)
task_init = Tasklet("initialize", self.initialize)

task_get = Tasklet("", self.get, TAG_FETCH)
task_get = Tasklet("fetch", self.get, TAG_FETCH)

task_train = Tasklet("", self.train)
task_train = Tasklet("train", self.train)

task_eval = Tasklet("", self.evaluate)
task_eval = Tasklet("evaluate", self.evaluate)

task_put = Tasklet("", self.put, TAG_UPLOAD)
task_put = Tasklet("upload", self.put, TAG_UPLOAD)

task_save_metrics = Tasklet("", self.save_metrics)
task_save_metrics = Tasklet("save_metrics", self.save_metrics)

# create a loop object with loop exit condition function
loop = Loop(loop_check_fn=lambda: self._work_done)
Expand Down
18 changes: 7 additions & 11 deletions lib/python/flame/mode/horizontal/syncfl/top_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@
from datetime import datetime

from diskcache import Cache

from flame.channel_manager import ChannelManager
from flame.common.constants import DeviceType
from flame.common.custom_abcmeta import ABCMeta, abstract_attribute
from flame.common.util import (
MLFramework,
get_ml_framework_in_use,
valid_frameworks,
weights_to_device,
weights_to_model_device,
)
from flame.common.util import (MLFramework, get_ml_framework_in_use,
valid_frameworks, weights_to_device,
weights_to_model_device)
from flame.config import Config
from flame.datasamplers import datasampler_provider
from flame.mode.composer import Composer
from flame.mode.message import MessageType
from flame.mode.role import Role
Expand All @@ -40,7 +38,6 @@
from flame.optimizers import optimizer_provider
from flame.plugin import PluginManager, PluginType
from flame.registries import registry_provider
from flame.datasamplers import datasampler_provider

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -147,9 +144,7 @@ def _aggregate_weights(self, tag: str) -> None:

if MessageType.DATASAMPLER_METADATA in msg:
self.datasampler.handle_metadata_from_trainer(
msg[MessageType.DATASAMPLER_METADATA],
end,
channel,
msg[MessageType.DATASAMPLER_METADATA], end, channel,
)

logger.debug(f"{end}'s parameters trained with {count} samples")
Expand Down Expand Up @@ -252,6 +247,7 @@ def save_metrics(self):
if self.metrics:
self.registry_client.save_metrics(self._round - 1, self.metrics)
logger.debug("saving metrics done")
self.metrics = dict()

def increment_round(self):
"""Increment the round counter."""
Expand Down
16 changes: 6 additions & 10 deletions lib/python/flame/mode/horizontal/syncfl/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,18 @@
from flame.channel_manager import ChannelManager
from flame.common.constants import DeviceType
from flame.common.custom_abcmeta import ABCMeta, abstract_attribute
from flame.common.util import (
MLFramework,
delta_weights_pytorch,
delta_weights_tensorflow,
get_ml_framework_in_use,
valid_frameworks,
weights_to_device,
weights_to_model_device,
)
from flame.common.util import (MLFramework, delta_weights_pytorch,
delta_weights_tensorflow,
get_ml_framework_in_use, valid_frameworks,
weights_to_device, weights_to_model_device)
from flame.config import Config
from flame.datasamplers import datasampler_provider
from flame.privacies import privacy_provider
from flame.mode.composer import Composer
from flame.mode.message import MessageType
from flame.mode.role import Role
from flame.mode.tasklet import Loop, Tasklet
from flame.optimizers import optimizer_provider
from flame.privacies import privacy_provider
from flame.registries import registry_provider

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -202,6 +197,7 @@ def save_metrics(self):
if self.metrics:
self.registry_client.save_metrics(self._round - 1, self.metrics)
logger.debug("saving metrics done")
self.metrics = dict()

def update_metrics(self, metrics: dict[str, float]):
"""Update metrics."""
Expand Down
30 changes: 17 additions & 13 deletions lib/python/flame/mode/hybrid/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,31 +210,35 @@ def compose(self) -> None:
with Composer() as composer:
self.composer = composer

task_init_cm = Tasklet("", self.init_cm)
task_init_cm = Tasklet("init_cm", self.init_cm)

task_internal_init = Tasklet("", self.internal_init)
task_internal_init = Tasklet("internal_init", self.internal_init)

task_load_data = Tasklet("", self.load_data)
task_load_data = Tasklet("load_data", self.load_data)

task_init = Tasklet("", self.initialize)
task_init = Tasklet("initialize", self.initialize)

task_get = Tasklet("", self.get, TAG_FETCH)
task_get = Tasklet("fetch", self.get, TAG_FETCH)

task_member_check = Tasklet("", self._member_check, TAG_RING_ALLREDUCE)
task_member_check = Tasklet(
"member_check", self._member_check, TAG_RING_ALLREDUCE
)

task_allreduce = Tasklet("", self._ring_allreduce, TAG_RING_ALLREDUCE)
task_allreduce = Tasklet(
"ring_allreduce", self._ring_allreduce, TAG_RING_ALLREDUCE
)

task_train = Tasklet("", self.train)
task_train = Tasklet("train", self.train)

task_eval = Tasklet("", self.evaluate)
task_eval = Tasklet("evaluate", self.evaluate)

task_put = Tasklet("", self.put, TAG_UPLOAD)
task_put = Tasklet("upload", self.put, TAG_UPLOAD)

task_save_metrics = Tasklet("", self.save_metrics)
task_save_metrics = Tasklet("save_metrics", self.save_metrics)

task_save_params = Tasklet("", self.save_params)
task_save_params = Tasklet("save_params", self.save_params)

task_save_model = Tasklet("", self.save_model)
task_save_model = Tasklet("save_model", self.save_model)

# create a loop object with loop exit condition function
loop = Loop(loop_check_fn=lambda: self._work_done)
Expand Down

0 comments on commit 01b4584

Please sign in to comment.