feat: asynchronous fl #323

Merged (1 commit, Feb 7, 2023)
4 changes: 3 additions & 1 deletion lib/python/flame/backend/p2p.py
@@ -363,7 +363,7 @@ async def _broadcast_task(self, channel)
break

end_ids = list(channel._ends.keys())
logger.debug(f"end ids for bcast = {end_ids}")
logger.debug(f"end ids for {channel.name()} bcast = {end_ids}")
for end_id in end_ids:
try:
await self.send_chunks(end_id, channel.name(), data)
@@ -374,6 +374,8 @@ async def _broadcast_task(self, channel)
await self._cleanup_end(end_id)
txq.task_done()

logger.debug(f"broadcast task for {channel.name()} terminated")

async def _unicast_task(self, channel, end_id):
txq = channel.get_txq(end_id)

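The two new debug lines bracket the lifetime of a per-channel broadcast task: one tags each fan-out with the channel name, the other records when the task drains its tx queue and exits. For context, a minimal standalone sketch of this drain-and-fan-out pattern; the names and the `None` sentinel protocol are illustrative, not flame's actual backend API:

```python
import asyncio

async def broadcast_task(name: str, txq: asyncio.Queue, ends: dict) -> None:
    """Drain a channel's tx queue and fan each message out to every end."""
    while True:
        data = await txq.get()
        if data is None:  # sentinel: channel is closing
            txq.task_done()
            break
        for end_id, rxq in list(ends.items()):
            await rxq.put(data)  # stand-in for send_chunks(end_id, name, data)
        txq.task_done()
    print(f"broadcast task for {name} terminated")

async def main() -> None:
    txq = asyncio.Queue()
    ends = {"end-1": asyncio.Queue(), "end-2": asyncio.Queue()}
    task = asyncio.create_task(broadcast_task("param-channel", txq, ends))
    await txq.put(b"model-weights")
    await txq.put(None)  # request termination
    await task

asyncio.run(main())
```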
166 changes: 126 additions & 40 deletions lib/python/flame/channel.py
@@ -26,10 +26,14 @@
from .common.typing import Scalar
from .common.util import run_async
from .config import GROUPBY_DEFAULT_GROUP
from .end import End
from .end import KEY_END_STATE, VAL_END_STATE_RECVD, End

logger = logging.getLogger(__name__)

KEY_CH_STATE = 'state'
VAL_CH_STATE_RECV = 'recv'
VAL_CH_STATE_SEND = 'send'


class Channel(object):
"""Channel class."""
@@ -117,12 +121,14 @@ async def inner() -> bool:

return result

def one_end(self) -> str:
def one_end(self, state: Union[None, str] = None) -> str:
"""Return one end out of all ends."""
return self.ends()[0]
return self.ends(state)[0]

def ends(self) -> list[str]:
def ends(self, state: Union[None, str] = None) -> list[str]:
"""Return a list of end ids."""
if state == VAL_CH_STATE_RECV or state == VAL_CH_STATE_SEND:
self.properties[KEY_CH_STATE] = state

async def inner():
selected = self._selector.select(self._ends, self.properties)
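The new `state` argument threads a direction hint ('recv' or 'send') into the channel's shared properties dict so that a selector can take it into account when picking ends. A standalone sketch of that flow with an illustrative selector; `DemoSelector` and the per-end `recvd` flag are assumptions for this example, not flame's actual selector interface:

```python
KEY_CH_STATE = 'state'
VAL_CH_STATE_RECV = 'recv'
VAL_CH_STATE_SEND = 'send'

class DemoSelector:
    """Illustrative selector: filters ends based on the channel's state hint."""
    def select(self, ends: dict, props: dict) -> list[str]:
        if props.get(KEY_CH_STATE) == VAL_CH_STATE_RECV:
            # e.g., for receiving, skip ends already heard from this round
            return [eid for eid, e in ends.items() if not e.get('recvd')]
        return list(ends.keys())

ends = {'end-a': {'recvd': True}, 'end-b': {'recvd': False}}
props = {}

# ends('recv') stores the hint before delegating to the selector
props[KEY_CH_STATE] = VAL_CH_STATE_RECV
print(DemoSelector().select(ends, props))  # ['end-b']
```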
@@ -198,17 +204,94 @@ async def _get():

payload, status = run_async(_get(), self._backend.loop())

if self.has(end_id):
# set a property that says a message was received for the end
self._ends[end_id].set_property(KEY_END_STATE, VAL_END_STATE_RECVD)

return cloudpickle.loads(payload) if payload and status else None

def recv_fifo(self, end_ids: list[str]) -> Tuple[str, Any]:
def recv_fifo(self,
end_ids: list[str],
first_k: int = 0) -> Tuple[str, Any]:
"""Receive a message per end from a list of ends.
The message arrival order among ends is not fixed.
Messages are yielded in a FIFO manner.
This method is not thread-safe.
Parameters
----------
end_ids: a list of ends to receive a message from
first_k: an integer argument to restrict the number of ends
to receive a messagae from. The default value (= 0)
means that we'd like to receive messages from all
ends in the list. If first_k > len(end_ids),
first_k is set to len(end_ids).
Returns
-------
The function yields a pair: end id and message
"""
logger.debug(f"first_k = {first_k}, len(end_ids) = {len(end_ids)}")

first_k = min(first_k, len(end_ids))
        if first_k <= 0:
            # first_k = 0 (the default) or a negative value means
            # receiving messages from all ends in the list
            first_k = len(end_ids)

        # DO NOT CHANGE self.tmpq to a local variable.
        # With aiostream, updates to a local variable look incorrect,
        # but with an instance variable, the update is done correctly.
        #
        # A temporary asyncio queue to store messages in a FIFO manner
self.tmpq = None

async def _put_message_to_tmpq_inner():
# self.tmpq must be created in the _backend loop
self.tmpq = asyncio.Queue()
_ = asyncio.create_task(
self._streamer_for_recv_fifo(end_ids, first_k))

async def _get_message_inner():
return await self.tmpq.get()

        # first, create an asyncio task to fetch messages and put them into a temp queue
# _put_message_to_tmpq_inner works as if it is a non-blocking call
# because a task is created within it
_, _ = run_async(_put_message_to_tmpq_inner(), self._backend.loop())

# the _get_message_inner() coroutine fetches a message from the temp
# queue; we call this coroutine first_k times
for _ in range(first_k):
result, status = run_async(_get_message_inner(),
self._backend.loop())
(end_id, payload) = result
logger.debug(f"get payload for {end_id}")

if self.has(end_id):
logger.debug(f"channel got a msg for {end_id}")
# set a property to indicate that a message was received
# for the end
self._ends[end_id].set_property(KEY_END_STATE,
VAL_END_STATE_RECVD)
else:
logger.debug(f"channel has no end id {end_id} for msg")

msg = cloudpickle.loads(payload) if payload and status else None
yield end_id, msg

async def _get(end_id) -> Tuple[str, Any]:
async def _streamer_for_recv_fifo(self, end_ids: list[str], first_k: int):
"""Read messages in a FIFO fashion.
This method reads messages from queues associated with each end
and puts first_k number of the messages into a queue;
The remaining messages are saved back into a variable (peek_buf)
of their corresponding end so that they can be read later.
"""

async def _get_inner(end_id) -> Tuple[str, Any]:
if not self.has(end_id):
# can't receive message from end_id
yield end_id, None
@@ -221,40 +304,43 @@ async def _get(end_id) -> Tuple[str, Any]:

yield end_id, payload

async def _streamer(tmpq):
runs = [_get(end_id) for end_id in end_ids]

merged = stream.merge(*runs)
async with merged.stream() as streamer:
async for result in streamer:
await tmpq.put(result)

        # a temporary asyncio queue to store messages in a FIFO manner.
        # we define this variable to make sure it is visible
# in both _inner1() and _inner2()
tmpq = None

async def _inner1():
nonlocal tmpq
# tmpq must be created in the _backend loop
tmpq = asyncio.Queue()
_ = asyncio.create_task(_streamer(tmpq))

async def _inner2():
return await tmpq.get()

# first, create an asyncio task to fetch messages and put a temp queue
# _inner1 works as if it is a non-blocking call
# because a task is created within it
_, _ = run_async(_inner1(), self._backend.loop())

# the _inner2() coroutine fetches a message from the temp queue
# we call this coroutine the number of end_ids by iterating end_ids
for _ in end_ids:
result, status = run_async(_inner2(), self._backend.loop())
(end_id, payload) = result
msg = cloudpickle.loads(payload) if payload and status else None
yield end_id, msg
runs = [_get_inner(end_id) for end_id in end_ids]

        # DO NOT CHANGE self.count to a local variable.
        # With aiostream, updates to a local variable look incorrect,
        # but with an instance variable, the update is done correctly.
self.count = 0
merged = stream.merge(*runs)
async with merged.stream() as streamer:
logger.debug(f"0) cnt: {self.count}, first_k: {first_k}")
async for result in streamer:
(end_id, payload) = result
logger.debug(f"1) end id: {end_id}, cnt: {self.count}")

self.count += 1
logger.debug(f"2) end id: {end_id}, cnt: {self.count}")
if self.count <= first_k:
logger.debug(f"3) end id: {end_id}, cnt: {self.count}")
await self.tmpq.put(result)

else:
logger.debug(f"4) end id: {end_id}, cnt: {self.count}")
                    # We already put the first_k messages into
                    # the queue.
                    #
                    # Now we need to save the remaining messages, which
                    # were already taken out of each end's rcv queue.
                    # In order not to lose those messages, we use peek_buf
                    # in the End object.

                    # WARNING: peek_buf must be None; if not, peek() was
                    # called somewhere else and then recv_fifo() was called
                    # before recv() was called. To detect this potential
                    # issue, an assert is given here.
assert self._ends[end_id].peek_buf is None

self._ends[end_id].peek_buf = payload

def peek(self, end_id):
"""Peek rxq of end_id and return data if queue is not empty."""
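Taken together, recv_fifo and _streamer_for_recv_fifo merge the per-end receive streams, hand the caller the first first_k messages in arrival order, and park any extra messages in each end's peek_buf for a later recv(). A runnable toy model of that mechanism, using plain asyncio queues as stand-ins for ends (no flame imports; all names here are illustrative):

```python
import asyncio

async def demo() -> None:
    rxqs = {f"end-{i}": asyncio.Queue() for i in range(4)}
    peek_buf = {}           # stand-in for each End's peek_buf
    tmpq = asyncio.Queue()  # temporary FIFO queue, as in recv_fifo
    first_k = 2

    async def pump(end_id: str, q: asyncio.Queue) -> None:
        # one reader per end; arrival order decides the FIFO order
        tmpq.put_nowait((end_id, await q.get()))

    for i, q in enumerate(rxqs.values()):  # simulate trainers finishing
        q.put_nowait(f"update-{i}")

    tasks = [asyncio.create_task(pump(e, q)) for e, q in rxqs.items()]
    for count in range(1, len(rxqs) + 1):
        end_id, payload = await tmpq.get()
        if count <= first_k:
            print("yield", end_id, payload)  # consumed this round
        else:
            assert end_id not in peek_buf    # mirrors the peek_buf check
            peek_buf[end_id] = payload       # readable by a later recv()
    await asyncio.gather(*tasks)
    print("buffered for later:", peek_buf)

asyncio.run(demo())
```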
4 changes: 4 additions & 0 deletions lib/python/flame/channel_manager.py
@@ -204,6 +204,10 @@ def cleanup(self):
ch.cleanup()

async def _inner(backend):
            # TODO: need a better mechanism to wait for tx completion;
            # as a temporary measure, sleep 5 seconds
await asyncio.sleep(5)

# clean up backend
await backend.cleanup()

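The fixed 5-second sleep is a stopgap for the TODO above. One plausible shape for the "better mechanism" is to wait on each tx queue's join(), which unblocks once every queued item has been marked task_done(), as _broadcast_task does after each send. A sketch under that assumption; drain_then_cleanup and its arguments are hypothetical, not flame's API:

```python
import asyncio

async def drain_then_cleanup(txqs, backend_cleanup) -> None:
    """Wait until every queued tx item is task_done(), then clean up."""
    await asyncio.gather(*(q.join() for q in txqs))
    await backend_cleanup()

# usage sketch: instead of a fixed sleep, the cleanup path could run
#   await drain_then_cleanup(all_tx_queues, backend.cleanup)
```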
4 changes: 4 additions & 0 deletions lib/python/flame/config.py
@@ -90,13 +90,17 @@ class OptimizerType(Enum):
FEDADAGRAD = 2 # FedAdaGrad
FEDADAM = 3 # FedAdam
FEDYOGI = 4 # FedYogi
# FedBuff from https://arxiv.org/pdf/1903.03934.pdf and
# https://arxiv.org/pdf/2111.04877.pdf
FEDBUFF = 5


class SelectorType(Enum):
"""Define selector types."""

DEFAULT = 1 # default
RANDOM = 2 # random
FEDBUFF = 3 # fedbuff


REALM_SEPARATOR = '/'
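With FEDBUFF added to both enums, a job can request buffered asynchronous aggregation by name. A small usage sketch, assuming the package is importable as flame.config (matching the lib/python/flame layout); enum lookup by member name is how a string from a job config typically maps onto these values:

```python
from flame.config import OptimizerType, SelectorType

# look up enum members by name, e.g. when parsing a config string
optimizer = OptimizerType["FEDBUFF"]  # OptimizerType.FEDBUFF, value 5
selector = SelectorType["FEDBUFF"]    # SelectorType.FEDBUFF, value 3

print(optimizer, optimizer.value)
print(selector, selector.value)
```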
4 changes: 4 additions & 0 deletions lib/python/flame/end.py
@@ -20,6 +20,10 @@

from .common.typing import Scalar

KEY_END_STATE = 'state'
VAL_END_STATE_RECVD = 'recvd'
VAL_END_STATE_NONE = ''


class End(object):
"""End class."""
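These constants let a channel mark an End once a message has been received from it, which is what the recv()/recv_fifo() changes above do via set_property. A minimal stand-in End showing that lifecycle; flame's real End class carries more state than this:

```python
KEY_END_STATE = 'state'
VAL_END_STATE_RECVD = 'recvd'
VAL_END_STATE_NONE = ''

class End:
    """Minimal stand-in for flame's End."""
    def __init__(self) -> None:
        self.properties = {KEY_END_STATE: VAL_END_STATE_NONE}

    def set_property(self, key: str, value: str) -> None:
        self.properties[key] = value

end = End()
# what the channel now does after a message arrives from this end:
end.set_property(KEY_END_STATE, VAL_END_STATE_RECVD)
print(end.properties[KEY_END_STATE])  # 'recvd'
```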
15 changes: 15 additions & 0 deletions lib/python/flame/examples/async_hier_mnist/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,15 @@ (a second new __init__.py adding the same 15-line Apache-2.0 license header; its file path is not shown)