feature: hybrid mode with ring allreduce (cisco-open#170)
A hybrid mode that combines distributed learning and federated learning
is implemented. This implementation is based on PR cisco-open#131.

The current implementation has a bug that can cause a deadlock when
trainers arrive late. Addressing the issue will be handled in a
separate PR, as it requires substantial changes in the backend, channel,
channel manager, etc.
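
For reference, a minimal, self-contained sketch of what the hybrid trainer does numerically: each trainer scales its weights by its share of the total sample count (see _scale_down_weights_fn(self.total_count) in the diff below), and a ring-allreduce then sums the scaled weights, which yields a dataset-size-weighted average. The worker count, weight shapes, and dataset sizes below are made up, and the ring is simulated in a single process; this is an illustration, not flame's implementation.

import numpy as np

def simulated_ring_allreduce(tensors: list[np.ndarray]) -> list[np.ndarray]:
    """Sum equal-shape 1-D tensors, one per worker, via a simulated ring."""
    n = len(tensors)
    # Each worker's tensor is split into n chunks, one per ring slot.
    chunks = [np.array_split(t.copy(), n) for t in tensors]

    # Reduce-scatter: after n-1 steps, worker i holds the full sum of
    # chunk (i + 1) % n.
    for step in range(n - 1):
        sends = [(i, (i - step) % n, chunks[i][(i - step) % n].copy())
                 for i in range(n)]
        for i, c, payload in sends:
            chunks[(i + 1) % n][c] = chunks[(i + 1) % n][c] + payload

    # Allgather: circulate the fully reduced chunks so every worker
    # ends up with the complete summed tensor.
    for step in range(n - 1):
        sends = [(i, (i + 1 - step) % n, chunks[i][(i + 1 - step) % n].copy())
                 for i in range(n)]
        for i, c, payload in sends:
            chunks[(i + 1) % n][c] = payload

    return [np.concatenate(c) for c in chunks]

# Hypothetical trainers: model weights plus local dataset sizes.
weights = [np.random.rand(8) for _ in range(3)]
sizes = [100, 300, 600]
total_count = sum(sizes)

# Scale down by each trainer's share of total_count, then allreduce-sum.
scaled = [w * (s / total_count) for w, s in zip(weights, sizes)]
reduced = simulated_ring_allreduce(scaled)

# Every worker ends with the same dataset-size-weighted average.
expected = sum(w * (s / total_count) for w, s in zip(weights, sizes))
assert all(np.allclose(r, expected) for r in reduced)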

Co-authored-by: Gaoxiang Luo <luo00042@umn.edu>
myungjin and GaoxiangLuo authored Jul 8, 2022
1 parent 33d0ed2 commit 53594f9
Showing 10 changed files with 150 additions and 72 deletions.
2 changes: 1 addition & 1 deletion examples/distributed_training/dataset_1.json
@@ -9,7 +9,7 @@

"dataFormat": "npy",

"realm": "default/us/org1",
"realm": "default/us/org",

"isPublic": true
}
2 changes: 1 addition & 1 deletion examples/distributed_training/dataset_2.json
@@ -9,7 +9,7 @@

"dataFormat": "npy",

"realm": "default/us/org2",
"realm": "default/us/org",

"isPublic": true
}
2 changes: 1 addition & 1 deletion examples/distributed_training/dataset_3.json
@@ -9,7 +9,7 @@

"dataFormat": "npy",

"realm": "default/us/org3",
"realm": "default/us/org",

"isPublic": true
}
4 changes: 2 additions & 2 deletions lib/python/flame/examples/hybrid/aggregator/config.json
@@ -45,8 +45,8 @@
"name": "hybrid_mnist"
},
"registry": {
"sort": "mlflow",
"uri": "http://flame-mlflow:5000"
"sort": "dummy",
"uri": ""
},
"selector": {
"sort": "default",
@@ -23,7 +23,7 @@
"trainer"
],
"funcTags": {
"trainer": ["receive", "send"]
"trainer": ["ring_allreduce"]
}
},
{
@@ -63,8 +63,8 @@
"name": "hybrid_mnist"
},
"registry": {
"sort": "mlflow",
"uri": "http://flame-mlflow:5000"
"sort": "dummy",
"uri": ""
},
"selector": {
"sort": "default",
@@ -75,6 +75,26 @@
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default/eu/german",
"role": "trainer"
"realm": "default/eu/org",
"role": "trainer",
"channelConfigs": {
"param-channel": {
"backend": "mqtt",
"brokers": [
{
"host": "flame-mosquitto2",
"sort": "mqtt"
}
]
},
"global-channel": {
"backend": "mqtt",
"brokers": [
{
"host": "flame-mosquitto",
"sort": "mqtt"
}
]
}
}
}
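
The per-channel broker override is the essence of the hybrid setup: the ring-allreduce traffic on param-channel stays on a site-local broker (flame-mosquitto2 here), while global-channel carries the federated exchange over the shared broker (flame-mosquitto). Below is a hypothetical snippet of how such a config could be consumed, purely for illustration; the fallback to a job-wide broker list is an assumption, and this is not flame's channel manager.

import json

with open("config.json") as f:  # one of the trainer configs above
    config = json.load(f)

def broker_host(channel_name: str) -> str:
    # Per-channel override from the new "channelConfigs" section, if any.
    override = config.get("channelConfigs", {}).get(channel_name)
    if override:
        return override["brokers"][0]["host"]
    # Assumed fallback: a job-wide broker list elsewhere in the config.
    return config["brokers"][0]["host"]

print(broker_host("param-channel"))   # flame-mosquitto2 (local ring traffic)
print(broker_host("global-channel"))  # flame-mosquitto (federated traffic)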
@@ -23,7 +23,7 @@
"trainer"
],
"funcTags": {
"trainer": ["receive", "send"]
"trainer": ["ring_allreduce"]
}
},
{
@@ -63,8 +63,8 @@
"name": "hybrid_mnist"
},
"registry": {
"sort": "mlflow",
"uri": "http://flame-mlflow:5000"
"sort": "dummy",
"uri": ""
},
"selector": {
"sort": "default",
@@ -75,6 +75,26 @@
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default/eu/uk",
"role": "trainer"
"realm": "default/eu/org",
"role": "trainer",
"channelConfigs": {
"param-channel": {
"backend": "mqtt",
"brokers": [
{
"host": "flame-mosquitto2",
"sort": "mqtt"
}
]
},
"global-channel": {
"backend": "mqtt",
"brokers": [
{
"host": "flame-mosquitto",
"sort": "mqtt"
}
]
}
}
}
@@ -23,7 +23,7 @@
"trainer"
],
"funcTags": {
"trainer": ["receive", "send"]
"trainer": ["ring_allreduce"]
}
},
{
@@ -63,8 +63,8 @@
"name": "hybrid_mnist"
},
"registry": {
"sort": "mlflow",
"uri": "http://flame-mlflow:5000"
"sort": "dummy",
"uri": ""
},
"selector": {
"sort": "default",
@@ -75,6 +75,26 @@
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default/us/east",
"role": "trainer"
"realm": "default/us/org",
"role": "trainer",
"channelConfigs": {
"param-channel": {
"backend": "mqtt",
"brokers": [
{
"host": "flame-mosquitto2",
"sort": "mqtt"
}
]
},
"global-channel": {
"backend": "mqtt",
"brokers": [
{
"host": "flame-mosquitto",
"sort": "mqtt"
}
]
}
}
}
@@ -23,7 +23,7 @@
"trainer"
],
"funcTags": {
"trainer": ["receive", "send"]
"trainer": ["ring_allreduce"]
}
},
{
@@ -63,8 +63,8 @@
"name": "hybrid_mnist"
},
"registry": {
"sort": "mlflow",
"uri": "http://flame-mlflow:5000"
"sort": "dummy",
"uri": ""
},
"selector": {
"sort": "default",
@@ -75,7 +75,7 @@
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default/us/west",
"realm": "default/us/org",
"role": "trainer",
"channelConfigs": {
"param-channel": {
35 changes: 21 additions & 14 deletions lib/python/flame/mode/distributed/trainer.py
@@ -70,13 +70,13 @@ def internal_init(self) -> None:

self.is_committer = False
self.ends_of_ring = None
+ self.total_count = 0

self.framework = get_ml_framework_in_use()
if self.framework == MLFramework.UNKNOWN:
raise NotImplementedError(
"supported ml framework not found; "
f"supported frameworks are: {valid_frameworks}"
)
f"supported frameworks are: {valid_frameworks}")

if self.framework == MLFramework.PYTORCH:
self._scale_down_weights_fn = self._scale_down_weights_pytorch
@@ -99,16 +99,15 @@ def _ring_allreduce(self, tag: str) -> None:
logger.debug(f"channel not found with tag {tag}")
return

- success, total_data_count = self._member_check(channel)
- logger.debug(f"member check: {success}")
- logger.debug(f"total_data_count: {total_data_count}")
- if not success:
+ self._member_check(channel)
+ if not self.can_ring_allreduce():
# members don't agree, we can't do ring-allreduce
logger.debug("ring was not formed")
return

self._update_weights()

- self._scale_down_weights_fn(total_data_count)
+ self._scale_down_weights_fn(self.total_count)

self._do_ring_allreduce(channel)

@@ -330,7 +329,10 @@ def _handle_member_check(self, channel, end, digest) -> tuple[bool, int]:

return True, dataset_size

- def _member_check(self, channel) -> tuple[bool, int]:
+ def _member_check(self, channel) -> None:
+     # reset ends in the ring
+     self.ends_of_ring = None

digest = channel.ends_digest()
if digest == "":
# This means that there is no ends in the channel,
@@ -339,7 +341,7 @@ def _member_check(self, channel) -> tuple[bool, int]:
self.is_committer = True
self._update_weights()
self._update_model()
- return False, 0
+ return

msg = {
MessageType.MEMBER_DIGEST: digest,
@@ -354,15 +356,15 @@
logger.debug(f"member check msg = {msg}")
channel.broadcast(msg)

- total_count = self.dataset_size
+ self.total_count = self.dataset_size
ends = channel.ends()
for end in ends:
success, size = self._handle_member_check(channel, end, digest)
if not success:
logger.debug(f"_handle_member_check failed for {end}")
- return False, 0
+ return

- total_count += size
+ self.total_count += size

my_taskid = channel.get_backend_id()
ends.append(my_taskid)
@@ -374,10 +376,15 @@
# if work is done, then no further distributed learning needed
self._work_done = (self._round > self._rounds)
if self._work_done:
- return False, 0
+ return

# save ends that agree to form a ring
self.ends_of_ring = ends
- return True, total_count
+ return

+ def can_ring_allreduce(self) -> bool:
+     """Return true if a ring is formed for ring-allreduce."""
+     return self.ends_of_ring is not None

def get(self, tag: str) -> None:
"""Get data from remote role(s)."""
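
The reworked _member_check above keeps its result in object state (ends_of_ring, total_count) instead of returning a tuple, and callers now gate the allreduce on can_ring_allreduce(). The agreement test itself rests on a digest of the member list: a ring forms only when every trainer reports the same digest. Below is a simplified sketch of that idea, not flame's code; the hashing scheme and function names are illustrative.

import hashlib

def ends_digest(end_ids: list[str]) -> str:
    """Digest of the member list as one trainer currently sees it."""
    return hashlib.sha1(",".join(sorted(end_ids)).encode()).hexdigest()

def ring_agrees(my_ends: list[str], peer_digests: list[str]) -> bool:
    """A ring can form only if every peer reports the same member digest."""
    mine = ends_digest(my_ends)
    return all(d == mine for d in peer_digests)

# A trainer that joined late sees a different member list, so digests
# disagree and no ring is formed for this round.
on_time = ["t1", "t2", "t3"]
late = ["t1", "t2", "t3", "t4"]
assert ring_agrees(on_time, [ends_digest(on_time)] * 2)
assert not ring_agrees(on_time, [ends_digest(on_time), ends_digest(late)])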