Skip to content

Commit

Permalink
feat: asynchronous fl
Browse files Browse the repository at this point in the history
Asynchronous FL is implemented for two-tier topology and three-tier
hierarchical topology.

The main algorithm is based on the following two papers:
- https://arxiv.org/pdf/2111.04877.pdf
- https://arxiv.org/pdf/2106.06639.pdf

Two examples for asynchronous fl are also added. One is for a two-tier
topology and the other for a three-tier hierarchical topology.

This implementation includes the core algorithm but doesn't include
the SecAgg algorithm (presented in the papers), which is beyond the
scope of this change.
  • Loading branch information
myungjin committed Feb 7, 2023
1 parent a5c8fb2 commit eb4b68b
Show file tree
Hide file tree
Showing 50 changed files with 2,579 additions and 100 deletions.
4 changes: 3 additions & 1 deletion lib/python/flame/backend/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ async def _broadcast_task(self, channel):
break

end_ids = list(channel._ends.keys())
logger.debug(f"end ids for bcast = {end_ids}")
logger.debug(f"end ids for {channel.name()} bcast = {end_ids}")
for end_id in end_ids:
try:
await self.send_chunks(end_id, channel.name(), data)
Expand All @@ -374,6 +374,8 @@ async def _broadcast_task(self, channel):
await self._cleanup_end(end_id)
txq.task_done()

logger.debug(f"broadcast task for {channel.name()} terminated")

async def _unicast_task(self, channel, end_id):
txq = channel.get_txq(end_id)

Expand Down
166 changes: 126 additions & 40 deletions lib/python/flame/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@
from .common.typing import Scalar
from .common.util import run_async
from .config import GROUPBY_DEFAULT_GROUP
from .end import End
from .end import KEY_END_STATE, VAL_END_STATE_RECVD, End

logger = logging.getLogger(__name__)

KEY_CH_STATE = 'state'
VAL_CH_STATE_RECV = 'recv'
VAL_CH_STATE_SEND = 'send'


class Channel(object):
"""Channel class."""
Expand Down Expand Up @@ -117,12 +121,14 @@ async def inner() -> bool:

return result

def one_end(self) -> str:
def one_end(self, state: Union[None, str] = None) -> str:
"""Return one end out of all ends."""
return self.ends()[0]
return self.ends(state)[0]

def ends(self) -> list[str]:
def ends(self, state: Union[None, str] = None) -> list[str]:
"""Return a list of end ids."""
if state == VAL_CH_STATE_RECV or state == VAL_CH_STATE_SEND:
self.properties[KEY_CH_STATE] = state

async def inner():
selected = self._selector.select(self._ends, self.properties)
Expand Down Expand Up @@ -198,17 +204,94 @@ async def _get():

payload, status = run_async(_get(), self._backend.loop())

if self.has(end_id):
# set a property that says a message was received for the end
self._ends[end_id].set_property(KEY_END_STATE, VAL_END_STATE_RECVD)

return cloudpickle.loads(payload) if payload and status else None

def recv_fifo(self, end_ids: list[str]) -> Tuple[str, Any]:
def recv_fifo(self,
end_ids: list[str],
first_k: int = 0) -> Tuple[str, Any]:
"""Receive a message per end from a list of ends.
The message arrival order among ends is not fixed.
Messages are yielded in a FIFO manner.
This method is not thread-safe.
Parameters
----------
end_ids: a list of ends to receive a message from
first_k: an integer argument to restrict the number of ends
                   to receive a message from. The default value (= 0)
means that we'd like to receive messages from all
ends in the list. If first_k > len(end_ids),
first_k is set to len(end_ids).
Returns
-------
The function yields a pair: end id and message
"""
logger.debug(f"first_k = {first_k}, len(end_ids) = {len(end_ids)}")

first_k = min(first_k, len(end_ids))
if first_k <= 0:
# a negative value in first_k is an error
# we handle it by setting first_k as the length of the array
first_k = len(end_ids)

        # DO NOT CHANGE self.tmpq into a local variable.
        # With aiostream, a local variable update looks incorrect,
        # but with an instance variable, the variable update is
        # done correctly.
#
        # A temporary asyncio queue to store messages in a FIFO manner
self.tmpq = None

async def _put_message_to_tmpq_inner():
# self.tmpq must be created in the _backend loop
self.tmpq = asyncio.Queue()
_ = asyncio.create_task(
self._streamer_for_recv_fifo(end_ids, first_k))

async def _get_message_inner():
return await self.tmpq.get()

# first, create an asyncio task to fetch messages and put a temp queue
        # first, create an asyncio task to fetch messages and put them
        # into a temp queue;
# because a task is created within it
_, _ = run_async(_put_message_to_tmpq_inner(), self._backend.loop())

# the _get_message_inner() coroutine fetches a message from the temp
# queue; we call this coroutine first_k times
for _ in range(first_k):
result, status = run_async(_get_message_inner(),
self._backend.loop())
(end_id, payload) = result
logger.debug(f"get payload for {end_id}")

if self.has(end_id):
logger.debug(f"channel got a msg for {end_id}")
# set a property to indicate that a message was received
# for the end
self._ends[end_id].set_property(KEY_END_STATE,
VAL_END_STATE_RECVD)
else:
logger.debug(f"channel has no end id {end_id} for msg")

msg = cloudpickle.loads(payload) if payload and status else None
yield end_id, msg

async def _get(end_id) -> Tuple[str, Any]:
async def _streamer_for_recv_fifo(self, end_ids: list[str], first_k: int):
"""Read messages in a FIFO fashion.
This method reads messages from queues associated with each end
and puts first_k number of the messages into a queue;
The remaining messages are saved back into a variable (peek_buf)
of their corresponding end so that they can be read later.
"""

async def _get_inner(end_id) -> Tuple[str, Any]:
if not self.has(end_id):
# can't receive message from end_id
yield end_id, None
Expand All @@ -221,40 +304,43 @@ async def _get(end_id) -> Tuple[str, Any]:

yield end_id, payload

async def _streamer(tmpq):
runs = [_get(end_id) for end_id in end_ids]

merged = stream.merge(*runs)
async with merged.stream() as streamer:
async for result in streamer:
await tmpq.put(result)

# a temporary aysncio queue to store messages in a FIFO manner.
# we define this varialbe to make sure it is visiable
# in both _inner1() and _inner2()
tmpq = None

async def _inner1():
nonlocal tmpq
# tmpq must be created in the _backend loop
tmpq = asyncio.Queue()
_ = asyncio.create_task(_streamer(tmpq))

async def _inner2():
return await tmpq.get()

# first, create an asyncio task to fetch messages and put a temp queue
# _inner1 works as if it is a non-blocking call
# because a task is created within it
_, _ = run_async(_inner1(), self._backend.loop())

# the _inner2() coroutine fetches a message from the temp queue
# we call this coroutine the number of end_ids by iterating end_ids
for _ in end_ids:
result, status = run_async(_inner2(), self._backend.loop())
(end_id, payload) = result
msg = cloudpickle.loads(payload) if payload and status else None
yield end_id, msg
runs = [_get_inner(end_id) for end_id in end_ids]

        # DO NOT CHANGE self.count into a local variable.
        # With aiostream, a local variable update looks incorrect,
        # but with an instance variable, the variable update is
        # done correctly.
self.count = 0
merged = stream.merge(*runs)
async with merged.stream() as streamer:
logger.debug(f"0) cnt: {self.count}, first_k: {first_k}")
async for result in streamer:
(end_id, payload) = result
logger.debug(f"1) end id: {end_id}, cnt: {self.count}")

self.count += 1
logger.debug(f"2) end id: {end_id}, cnt: {self.count}")
if self.count <= first_k:
logger.debug(f"3) end id: {end_id}, cnt: {self.count}")
await self.tmpq.put(result)

else:
logger.debug(f"4) end id: {end_id}, cnt: {self.count}")
# We already put the first_k number of messages into
# a queue.
#
# Now we need to save the remaining messages which
# were already taken out from each end's rcv queue.
# In order not to lose those messages, we use peek_buf
# in end object.

# WARNING: peek_buf must be none; if not, we called
# peek() somewhere else and then called recv_fifo()
# before recv() was called.
# To detect this potential issue, assert is given here.
assert self._ends[end_id].peek_buf is None

self._ends[end_id].peek_buf = payload

def peek(self, end_id):
"""Peek rxq of end_id and return data if queue is not empty."""
Expand Down
4 changes: 4 additions & 0 deletions lib/python/flame/channel_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ def cleanup(self):
ch.cleanup()

async def _inner(backend):
# TODO: need better mechanism to wait tx completion
# as a temporary measure, sleep 5 seconds
await asyncio.sleep(5)

# clean up backend
await backend.cleanup()

Expand Down
4 changes: 4 additions & 0 deletions lib/python/flame/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,17 @@ class OptimizerType(Enum):
FEDADAGRAD = 2 # FedAdaGrad
FEDADAM = 3 # FedAdam
FEDYOGI = 4 # FedYogi
# FedBuff from https://arxiv.org/pdf/1903.03934.pdf and
# https://arxiv.org/pdf/2111.04877.pdf
FEDBUFF = 5


class SelectorType(Enum):
"""Define selector types."""

DEFAULT = 1 # default
RANDOM = 2 # random
FEDBUFF = 3 # fedbuff


REALM_SEPARATOR = '/'
Expand Down
4 changes: 4 additions & 0 deletions lib/python/flame/end.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

from .common.typing import Scalar

KEY_END_STATE = 'state'
VAL_END_STATE_RECVD = 'recvd'
VAL_END_STATE_NONE = ''


class End(object):
"""End class."""
Expand Down
15 changes: 15 additions & 0 deletions lib/python/flame/examples/async_hier_mnist/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
Loading

0 comments on commit eb4b68b

Please sign in to comment.