Skip to content

Commit

Permalink
fix+refactor: asyncfl loss divergence (#330)
Browse files Browse the repository at this point in the history
For asyncfl, a client (trainer) should send a delta — obtained by subtracting
the original global weights from its local weights after training — rather
than the weights themselves. In the current implementation, the entire set of
local weights was sent to the server (aggregator). This causes loss divergence.

Supporting delta update requires refactoring of aggregators of
synchronous fl (horizontal/{top_aggregator.py, middle_aggregator.py})
as well as optimizers' do() function.

The changes here support delta update universally across all types of
modes (horizontal synchronous, asynchronous, and hybrid).
  • Loading branch information
myungjin committed Feb 10, 2023
1 parent d2115d0 commit 9c55902
Show file tree
Hide file tree
Showing 21 changed files with 219 additions and 112 deletions.
4 changes: 3 additions & 1 deletion lib/python/flame/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Definitions on Types."""

from typing import Union

Scalar = Union[bool, bytes, float, int, str]

Metrics = dict[str, Scalar]

# list for tensorflow, dict for pytorch
ModelWeights = Union[list, dict]
21 changes: 20 additions & 1 deletion lib/python/flame/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from contextlib import contextmanager
from enum import Enum
from threading import Thread
from typing import List
from typing import List, Union

from pip._internal.cli.main import main as pipmain

from ..config import Config
from .typing import ModelWeights

PYTORCH = 'torch'
TENSORFLOW = 'tensorflow'
Expand Down Expand Up @@ -115,3 +116,21 @@ def mlflow_runname(config: Config) -> str:
groupby_value = groupby_value + val + "-"

return config.role + '-' + groupby_value + config.task_id[:8]


def delta_weights_pytorch(a: ModelWeights,
                          b: ModelWeights) -> Union[ModelWeights, None]:
    """Return delta weights (a - b) for pytorch model weights.

    Args:
        a: dict of parameter name -> tensor (presumably a pytorch
           state_dict; e.g. freshly trained local weights).
        b: dict with the same parameter names (e.g. the global
           weights recorded before training).

    Returns:
        A dict mapping each parameter name in ``a`` to ``a[k] - b[k]``,
        or None if either input is None.
    """
    if a is None or b is None:
        return None

    # Subtract by key, not by iteration position: the previous
    # ``zip(a, b)`` pairing assumed both dicts enumerate keys in the
    # same insertion order and would silently compute wrong deltas
    # for mismatched orders.
    return {k: a[k] - b[k] for k in a}


def delta_weights_tensorflow(a: ModelWeights,
                             b: ModelWeights) -> Union[ModelWeights, None]:
    """Return delta weights (a - b) for tensorflow model weights.

    Args:
        a: list of layer weight arrays (e.g. local weights after training).
        b: list of the same length (e.g. the prior global weights).

    Returns:
        A list of elementwise layer deltas, or None if either input
        is None.
    """
    if a is None or b is None:
        return None

    return [after - before for after, before in zip(a, b)]
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
{
"taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa743",
"backend": "mqtt",
"backend": "p2p",
"brokers": [
{
"host": "flame-mosquitto",
"host": "localhost",
"sort": "mqtt"
},
{
"host": "localhost:10104",
"sort": "p2p"
}
],
"channels": [
Expand Down Expand Up @@ -64,12 +68,12 @@
"name": "mnist"
},
"registry": {
"sort": "mlflow",
"sort": "dummy",
"uri": "http://flame-mlflow:5000"
},
"selector": {
"sort": "random",
"kwargs": {"k": 1}
"sort": "default",
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default/uk/london/org2/flame",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
{
"taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa744",
"backend": "mqtt",
"backend": "p2p",
"brokers": [
{
"host": "flame-mosquitto",
"host": "localhost",
"sort": "mqtt"
},
{
"host": "localhost:10104",
"sort": "p2p"
}
],
"channels": [
Expand Down Expand Up @@ -64,12 +68,12 @@
"name": "mnist"
},
"registry": {
"sort": "mlflow",
"sort": "dummy",
"uri": "http://flame-mlflow:5000"
},
"selector": {
"sort": "random",
"kwargs": {"k": 1}
"sort": "default",
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default/us/west/org1/flame",
Expand Down
10 changes: 7 additions & 3 deletions lib/python/flame/examples/hier_mnist/top_aggregator/config.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
{
"taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa742",
"backend": "mqtt",
"backend": "p2p",
"brokers": [
{
"host": "flame-mosquitto",
"host": "localhost",
"sort": "mqtt"
},
{
"host": "localhost:10104",
"sort": "p2p"
}
],
"channels": [
Expand Down Expand Up @@ -45,7 +49,7 @@
"name": "mnist"
},
"registry": {
"sort": "mlflow",
"sort": "dummy",
"uri": "http://flame-mlflow:5000"
},
"selector": {
Expand Down
14 changes: 9 additions & 5 deletions lib/python/flame/examples/hier_mnist/trainer/config_uk.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
{
"taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa745",
"backend": "mqtt",
"backend": "p2p",
"brokers": [
{
"host": "flame-mosquitto",
"host": "localhost",
"sort": "mqtt"
},
{
"host": "localhost:10104",
"sort": "p2p"
}
],
"channels": [
Expand Down Expand Up @@ -46,12 +50,12 @@
"name": "mnist"
},
"registry": {
"sort": "mlflow",
"sort": "dummy",
"uri": "http://flame-mlflow:5000"
},
"selector": {
"sort": "random",
"kwargs": {"k": 1}
"sort": "default",
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default/uk/london/org2/machine1",
Expand Down
14 changes: 9 additions & 5 deletions lib/python/flame/examples/hier_mnist/trainer/config_us.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
{
"taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa746",
"backend": "mqtt",
"backend": "p2p",
"brokers": [
{
"host": "flame-mosquitto",
"host": "localhost",
"sort": "mqtt"
},
{
"host": "localhost:10104",
"sort": "p2p"
}
],
"channels": [
Expand Down Expand Up @@ -46,12 +50,12 @@
"name": "mnist"
},
"registry": {
"sort": "mlflow",
"sort": "dummy",
"uri": "http://flame-mlflow:5000"
},
"selector": {
"sort": "random",
"kwargs": {"k": 1}
"sort": "default",
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default/us/west/org1/machine1",
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/mnist/aggregator/config.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa742",
"backend": "mqtt",
"backend": "p2p",
"brokers": [
{
"host": "localhost",
Expand Down Expand Up @@ -38,7 +38,7 @@
"hyperparameters": {
"batchSize": 32,
"learningRate": 0.01,
"rounds": 5
"rounds": 20
},
"baseModel": {
"name": "",
Expand Down
2 changes: 1 addition & 1 deletion lib/python/flame/examples/mnist/trainer/config.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"taskid": "505f9fc483cf4df68a2409257b5fad7d3c580370",
"backend": "mqtt",
"backend": "p2p",
"brokers": [
{
"host": "localhost",
Expand Down
19 changes: 11 additions & 8 deletions lib/python/flame/mode/distributed/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
# import hashlib
import logging
from collections import OrderedDict
from copy import deepcopy

from ...channel_manager import ChannelManager
from ...common.custom_abcmeta import ABCMeta, abstract_attribute
from ...common.util import (MLFramework, get_ml_framework_in_use,
from ...common.util import (MLFramework, delta_weights_pytorch,
delta_weights_tensorflow, get_ml_framework_in_use,
mlflow_runname, valid_frameworks)
from ...registries import registry_provider
from ..composer import Composer
Expand Down Expand Up @@ -53,16 +55,14 @@ def init_cm(self) -> None:

def internal_init(self) -> None:
"""Initialize internal state for role."""

self.registry_client = registry_provider.get(self.config.registry.sort)
# initialize registry client
self.registry_client(self.config.registry.uri, self.config.job.job_id)

base_model = self.config.base_model
if base_model and base_model.name != "" and base_model.version > 0:
self.model = self.registry_client.load_model(
base_model.name, base_model.version
)
base_model.name, base_model.version)
self.ring_weights = None # latest model weights from ring all-reduce

self.registry_client.setup_run(mlflow_runname(self.config))
Expand All @@ -80,20 +80,21 @@ def internal_init(self) -> None:
if self.framework == MLFramework.UNKNOWN:
raise NotImplementedError(
"supported ml framework not found; "
f"supported frameworks are: {valid_frameworks}"
)
f"supported frameworks are: {valid_frameworks}")

if self.framework == MLFramework.PYTORCH:
self._scale_down_weights_fn = self._scale_down_weights_pytorch
self._get_send_chunk_fn = self._get_send_chunk_pytorch
self._allreduce_fn = self._allreduce_pytorch
self._allgather_fn = self._allgather_pytorch
self._delta_weights_fn = delta_weights_pytorch

elif self.framework == MLFramework.TENSORFLOW:
self._scale_down_weights_fn = self._scale_down_weights_tensorflow
self._get_send_chunk_fn = self._get_send_chunk_tensorflow
self._allreduce_fn = self._allreduce_tensorflow
self._allgather_fn = self._allgather_tensorflow
self._delta_weights_fn = delta_weights_tensorflow

def _ring_allreduce(self, tag: str) -> None:
if tag != TAG_RING_ALLREDUCE:
Expand Down Expand Up @@ -420,6 +421,9 @@ def _update_model(self):

def _update_weights(self):
"""Save weights from model."""
# save weights before updating it
self.prev_weights = deepcopy(self.weights)

if self.framework == MLFramework.PYTORCH:
self.weights = self.model.state_dict()
elif self.framework == MLFramework.TENSORFLOW:
Expand Down Expand Up @@ -484,8 +488,7 @@ def compose(self) -> None:
loop = Loop(loop_check_fn=lambda: self._work_done)
task_init_cm >> task_internal_init >> task_load_data >> task_init >> loop(
task_train >> task_allreduce >> task_eval >> task_save_metrics
>> task_increment_round
) >> task_save_params >> task_save_model
>> task_increment_round) >> task_save_params >> task_save_model

def run(self) -> None:
"""Run role."""
Expand Down
Loading

0 comments on commit 9c55902

Please sign in to comment.