feat: 'continue' primitive in tasklet #261

Merged: 1 commit, Nov 4, 2022
26 changes: 21 additions & 5 deletions lib/python/flame/channel.py
@@ -275,15 +275,31 @@ def join(self):
"""Join the channel."""
self._backend.join(self)

def await_join(self):
"""Wait for at least one peer joins a channel."""
def await_join(self, timeout=None) -> bool:
"""Wait for at least one peer joins a channel.

async def _inner():
If timeout value is set, it will wait until timeout occurs.
Returns a boolean value to indicate whether timeout occurred or not.

Parameters
----------
timeout: a timeout value; default: None
"""

async def _inner() -> bool:
"""Return True if timeout occurs; otherwise False."""
logger.debug("waiting for join")
await self.await_join_event.wait()
try:
await asyncio.wait_for(self.await_join_event.wait(), timeout)
except asyncio.TimeoutError:
logger.debug("timeout occurred")
return True
logger.debug("at least one peer joined")
return False

_, _ = run_async(_inner(), self._backend.loop())
timeouted, _ = run_async(_inner(), self._backend.loop())
logger.debug(f"timeouted = {timeouted}")
return timeouted

def is_rxq_empty(self, end_id: str) -> bool:
"""Return true if rxq is empty; otherwise, false."""
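The core of this change is mapping asyncio.wait_for's TimeoutError to a boolean. A minimal stand-alone sketch of the same pattern, runnable outside flame (the event and timeout below are illustrative, not part of the diff):

import asyncio

async def wait_with_timeout(event: asyncio.Event, timeout=None) -> bool:
    """Return True if the wait timed out; otherwise False."""
    try:
        # with timeout=None, wait_for() waits indefinitely,
        # preserving the old await_join() behavior
        await asyncio.wait_for(event.wait(), timeout)
    except asyncio.TimeoutError:
        return True
    return False

async def main():
    event = asyncio.Event()  # never set, so the wait must time out
    timeouted = await wait_with_timeout(event, timeout=0.1)
    print(f"timeouted = {timeouted}")  # prints: timeouted = True

asyncio.run(main())
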
12 changes: 9 additions & 3 deletions lib/python/flame/mode/composer.py
@@ -99,9 +99,15 @@ def run(self) -> None:
            # execute tasklet
            tasklet.do()

-            if tasklet.is_last_in_loop() and not tasklet.is_loop_done():
-                # we reached the last tasklet of a loop
-                # but the loop exit condition is not met
+            if tasklet.is_continue() or (tasklet.is_last_in_loop()
+                                         and not tasklet.is_loop_done()):
+                # we are here due to one of the following conditions:
+                #
+                # condition 1: the tasklet's continue condition is met,
+                #              so we skip the remaining tasklets in the loop
+                #              and go back to the start of the loop
+                # condition 2: we reached the last tasklet of a loop
+                #              but the loop exit condition is not met
                start, end = tasklet.loop_starter, tasklet
                tasklets_in_loop = self.get_tasklets_in_loop(start, end)

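To make the new branch concrete, a toy rendering of the control flow follows (MiniTasklet and run_loop are hypothetical stand-ins; the real Composer.run walks a tasklet graph rather than a list):

class MiniTasklet:
    """Toy stand-in for Tasklet, just enough to show the control flow."""

    def __init__(self, fn, cont_fn=None):
        self.fn, self.cont_fn = fn, cont_fn

    def do(self):
        self.fn()

    def is_continue(self):
        return self.cont_fn() if callable(self.cont_fn) else False

def run_loop(tasklets, loop_done_fn):
    """Run tasklets in order; a 'continue' jumps back to the loop start."""
    i = 0
    while i < len(tasklets):
        t = tasklets[i]
        t.do()
        is_last = i == len(tasklets) - 1
        if t.is_continue() or (is_last and not loop_done_fn()):
            i = 0  # skip the remaining tasklets, restart from the loop starter
        else:
            i += 1

state = {"skip_once": True, "rounds": 0}

def put():
    print("put")

def get():
    print("get")
    state["rounds"] += 1

def cont():
    skip = state["skip_once"]
    state["skip_once"] = False
    return skip

run_loop([MiniTasklet(put, cont_fn=cont), MiniTasklet(get)],
         loop_done_fn=lambda: state["rounds"] >= 2)
# prints: put, put, get, put, get -- the first 'get' was skipped
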
36 changes: 31 additions & 5 deletions lib/python/flame/mode/horizontal/middle_aggregator.py
@@ -37,6 +37,9 @@
TAG_FETCH = 'fetch'
TAG_UPLOAD = 'upload'

+# 60 second wait time until a trainer appears in a channel
+WAIT_TIME_FOR_TRAINER = 60
+

class MiddleAggregator(Role, metaclass=ABCMeta):
    """Middle level aggregator.
@@ -111,7 +114,12 @@ def _distribute_weights(self, tag: str) -> None:
            return

        # this call waits for at least one peer to join this channel
-        channel.await_join()
+        self.trainer_no_show = channel.await_join(WAIT_TIME_FOR_TRAINER)
+        if self.trainer_no_show:
+            logger.debug("channel await join timeouted")
+            # send dummy weights to unblock top aggregator
+            self._send_dummy_weights(TAG_UPLOAD)
+            return

        for end in channel.ends():
            logger.debug(f"sending weights to {end}")
@@ -137,13 +145,14 @@ def _aggregate_weights(self, tag: str) -> None:

            if MessageType.DATASET_SIZE in msg:
                count = msg[MessageType.DATASET_SIZE]
-                total += count

            logger.debug(f"{end}'s parameters trained with {count} samples")

-            tres = TrainResult(weights, count)
-            # save training result from trainer in a disk cache
-            self.cache[end] = tres
+            if weights is not None and count > 0:
+                total += count
+                tres = TrainResult(weights, count)
+                # save training result from trainer in a disk cache
+                self.cache[end] = tres

        # optimizer conducts optimization (in this case, aggregation)
        global_weights = self.optimizer.do(self.cache, total)
@@ -176,6 +185,22 @@ def _send_weights(self, tag: str) -> None:
        })
        logger.debug("sending weights done")

+    def _send_dummy_weights(self, tag: str) -> None:
+        channel = self.cm.get_by_tag(tag)
+        if not channel:
+            logger.debug(f"channel not found with {tag}")
+            return
+
+        # this call waits for at least one peer to join this channel
+        channel.await_join()
+
+        # one aggregator is sufficient
+        end = channel.one_end()
+
+        dummy_msg = {MessageType.WEIGHTS: None, MessageType.DATASET_SIZE: 0}
+        channel.send(end, dummy_msg)
+        logger.debug("sending dummy weights done")
+
    def update_round(self):
        """Update the round counter."""
        logger.debug(f"Update current round: {self._round}")
@@ -209,6 +234,7 @@ def compose(self) -> None:
            task_load_data = Tasklet(self.load_data)

            task_put_dist = Tasklet(self.put, TAG_DISTRIBUTE)
+            task_put_dist.set_continue_fn(cont_fn=lambda: self.trainer_no_show)

            task_put_upload = Tasklet(self.put, TAG_UPLOAD)

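The dummy message acts as a sentinel: WEIGHTS of None and DATASET_SIZE of 0 unblock the upstream receiver, while the new "weights is not None and count > 0" guard keeps the sentinel out of aggregation. A stand-alone sketch of that pattern with a plain queue (all names made up for illustration):

from queue import Queue

def send_result(q: Queue, weights, count):
    q.put({"weights": weights, "dataset_size": count})

def aggregate_one(q: Queue):
    total, cache = 0, {}
    msg = q.get()  # would block forever if the no-show path sent nothing
    if msg["weights"] is not None and msg["dataset_size"] > 0:
        total += msg["dataset_size"]
        cache["end-0"] = msg
    return total, cache

q = Queue()
send_result(q, None, 0)   # trainer no-show: send the dummy sentinel
print(aggregate_one(q))   # (0, {}): receiver unblocked, nothing cached
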
4 changes: 2 additions & 2 deletions lib/python/flame/mode/horizontal/top_aggregator.py
@@ -119,11 +119,11 @@ def _aggregate_weights(self, tag: str) -> None:

            if MessageType.DATASET_SIZE in msg:
                count = msg[MessageType.DATASET_SIZE]
-                total += count

            logger.debug(f"{end}'s parameters trained with {count} samples")

-            if weights is not None:
+            if weights is not None and count > 0:
+                total += count
                tres = TrainResult(weights, count)
                # save training result from trainer in a disk cache
                self.cache[end] = tres
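The same guard lands here because the optimizer weights each cached result by its dataset size: a None-weights dummy would break the arithmetic, and moving "total += count" under the guard keeps the denominator consistent with the cache. A minimal weighted-average sketch with made-up numbers:

def weighted_avg(results):
    """results: list of (weights, count); weights is a scalar here."""
    kept = [(w, c) for w, c in results if w is not None and c > 0]
    total = sum(c for _, c in kept)
    return sum(w * c for w, c in kept) / total

# two real trainers plus one dummy (None, 0) from a no-show path
print(weighted_avg([(0.2, 100), (0.4, 300), (None, 0)]))  # 0.35
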
35 changes: 31 additions & 4 deletions lib/python/flame/mode/tasklet.py
@@ -13,14 +13,14 @@
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
-
"""flame tasklet."""

from __future__ import annotations

import logging
from enum import Flag, auto
from queue import Queue
+from typing import Callable

from .composer import ComposerContext

@@ -38,10 +38,22 @@ class LoopIndicator(Flag):
class Tasklet(object):
    """Tasklet is a class for defining a unit of work."""

-    def __init__(self, func, *args) -> None:
-        """Initialize the class."""
+    def __init__(self, func: Callable, *args, **kwargs) -> None:
+        """Initialize the class.
+
+        Parameters
+        ----------
+        func: a method that will be executed as a tasklet
+        *args: positional arguments for method func
+        **kwargs: keyword arguments for method func
+        """
+        if not callable(func):
+            raise TypeError(f"{func} is not callable")
+
        self.func = func
        self.args = args
+        self.kwargs = kwargs
+        self.cont_fn = None
        self.loop_check_fn = None
        self.composer = ComposerContext.get_composer()
        self.loop_starter = None
@@ -85,6 +97,10 @@ def __rshift__(self, other: Tasklet) -> Tasklet:

        return other

+    def set_continue_fn(self, cont_fn: Callable) -> None:
+        """Set continue function."""
+        self.cont_fn = cont_fn
+
    def get_composer(self):
        """Return composer object."""
        return self.composer
@@ -125,7 +141,7 @@ def get_ender(self) -> Tasklet:

    def do(self) -> None:
        """Execute tasklet."""
-        self.func(*self.args)
+        self.func(*self.args, **self.kwargs)

    def is_loop_done(self) -> bool:
        """Return if loop is done."""
@@ -138,6 +154,13 @@ def is_last_in_loop(self) -> bool:
"""Return if the tasklet is the last one in a loop."""
return self.loop_state & LoopIndicator.END

def is_continue(self) -> bool:
"""Return True if continue condition is met and otherwise False."""
if not callable(self.cont_fn):
return False

return self.cont_fn()


class Loop(object):
"""Loop class."""
@@ -149,6 +172,9 @@ def __init__(self, loop_check_fn=None) -> None:
        ----------
        loop_check_fn: a function object to check loop exit conditions
        """
+        if not callable(loop_check_fn):
+            raise TypeError(f"{loop_check_fn} is not callable")
+
        self.loop_check_fn = loop_check_fn

    def __call__(self, ender: Tasklet) -> Tasklet:
@@ -199,6 +225,7 @@ def __call__(self, ender: Tasklet) -> Tasklet:
        tasklets_in_loop = composer.get_tasklets_in_loop(starter, ender)
        # for each tasklet in loop, loop_check_fn and loop_ender are updated
        for tasklet in tasklets_in_loop:
+            tasklet.loop_starter = starter
            tasklet.loop_check_fn = self.loop_check_fn
            tasklet.loop_ender = ender

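Putting the tasklet changes together, a usage sketch of the updated API (assuming it runs inside a composer context the way the compose() methods above do; the evaluate function and its arguments are illustrative):

from flame.mode.composer import Composer
from flame.mode.tasklet import Loop, Tasklet

def evaluate(dataset, batch_size=32):
    print(f"evaluating {dataset} with batch_size={batch_size}")

with Composer() as composer:
    # positional and keyword arguments are both forwarded to func by do()
    task_eval = Tasklet(evaluate, "val-split", batch_size=64)

    # when this returns True, the composer jumps back to the loop start
    task_eval.set_continue_fn(cont_fn=lambda: False)

    # Tasklet(123) or Loop(loop_check_fn=None) now raise TypeError
    loop = Loop(loop_check_fn=lambda: True)  # exit after one pass
    loop(task_eval)  # task_eval is both starter and ender of the loop
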
2 changes: 1 addition & 1 deletion lib/python/setup.py
@@ -19,7 +19,7 @@

setup(
    name='flame',
-    version='0.0.11',
+    version='0.0.12',
    author='Flame Maintainers',
    author_email='flame-github-owners@cisco.com',
    include_package_data=True,