Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tasklet composition extensibility #397

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 197 additions & 0 deletions lib/python/flame/mode/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from queue import Queue
from types import TracebackType
from typing import Optional, Type
from flame.mode.enums import LoopIndicator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -216,6 +217,202 @@ def print(self):
print("=====")
print("done with printing chain")

def get_tasklet(self, alias: str):
    """Return the tasklet registered under ``alias``.

    The search is breadth-first over ``self.chain``, starting from the root
    reported by the first tasklet stored in the chain.

    Raises:
        ValueError: if the traversal runs out of tasklets without a match.
    """
    first = next(iter(self.chain))

    pending = Queue()
    pending.put(first.get_root())

    while True:
        if pending.empty():
            # Every reachable tasklet was inspected and none matched.
            raise ValueError(f"Tasklet with alias {alias} not found")

        current = pending.get()
        if current.alias == alias:
            return current

        for successor in self.chain[current]:
            pending.put(successor)

def remove_tasklet(self, alias: str):
    """Remove the tasklet named ``alias`` and stitch the chain around it.

    When the tasklet is the root, it is simply dropped from ``self.chain``.
    Otherwise the first parent found (breadth-first from the root) that
    lists it as a child has that child replaced by the removed tasklet's
    own children. In both cases ``self.reverse_chain`` is rebuilt so it
    never refers to the removed tasklet.

    Raises:
        NotImplementedError: if the tasklet has a truthy ``loop_state``,
            i.e. it is marked as a loop boundary.
        ValueError: if no tasklet with ``alias`` is found.
    """
    # Loop starters/enders cannot be removed yet; reject them up front.
    target = self.get_tasklet(alias)
    if target.loop_state:
        raise NotImplementedError("Can't handle loop ends removal yet")

    tasklet = next(iter(self.chain))
    root = tasklet.get_root()

    if root.alias == alias:
        self.chain.pop(root)
        # Keep the reverse index in step with the chain, exactly as the
        # non-root path below does; otherwise it still names the old root.
        self.reverse_chain = self.get_reverse_chain()
        return None

    q = Queue()
    q.put(root)
    while not q.empty():
        parent_t = q.get()

        for child_t in self.chain[parent_t]:
            if child_t.alias == alias:  # if child for removal
                # stitch the chain: the parent adopts the grandchildren
                self.chain[parent_t].remove(child_t)
                self.chain[parent_t].update(self.chain[child_t])
                # remove the child from the chain
                self.chain.pop(child_t)

                self.reverse_chain = self.get_reverse_chain()

                # Returning here also ends the iteration over the set that
                # was just mutated, so the loop never resumes on it.
                return None

        for child in self.chain[parent_t]:
            q.put(child)

    raise ValueError(f"Tasklet with alias {alias} not found")

def _update_chain(self, new_order):
    """Re-key ``self.chain`` in ``new_order`` and rebuild the reverse chain.

    ``new_order`` is an iterable of tasklets that are already keys of
    ``self.chain``; the resulting dict keeps only those keys, in that order.
    """
    reordered = {}
    for tasklet in new_order:
        reordered[tasklet] = self.chain[tasklet]

    self.chain = reordered
    self.reverse_chain = self.get_reverse_chain()

def insert(self, alias, new_tasklet, after=False):
    """Insert ``new_tasklet`` next to the tasklet named ``alias``.

    By default the new tasklet is placed before the named tasklet; pass a
    truthy ``after`` to place it after the named tasklet instead.
    """
    place = self._insert_after if after else self._insert_before
    place(alias, new_tasklet)

def _insert_before(self, alias, new_tasklet):
    """Insert ``new_tasklet`` immediately before the tasklet named ``alias``.

    If the named tasklet is the root, ``new_tasklet`` becomes a new entry
    whose only child is the old root. Otherwise the first parent found
    (breadth-first from the root) that has the named tasklet as a child is
    relinked to point at ``new_tasklet``, which in turn points at the named
    tasklet. ``self.chain`` is re-ordered so ``new_tasklet`` sits directly
    ahead of the named tasklet, and the reverse chain is rebuilt.

    Raises:
        ValueError: if a tasklet with ``new_tasklet.alias`` already exists,
            or if no reachable tasklet is named ``alias``.
    """
    initial_chain_order = list(self.chain)
    tasklet = next(iter(self.chain))
    root = tasklet.get_root()

    for t in self.chain:
        if t.alias == new_tasklet.alias:
            raise ValueError(f"Tasklet with alias '{new_tasklet.alias}' already exists")

    if root.alias == alias:
        # NOTE(review): this path does not touch any loop attributes, even
        # if the root is a loop starter -- confirm that is intended.
        self.chain[new_tasklet] = {root}
        # update the chain order
        updated_chain_order = initial_chain_order.copy()
        updated_chain_order.insert(initial_chain_order.index(root), new_tasklet)
        self._update_chain(updated_chain_order)

        return None

    q = Queue()
    q.put(root)
    while not q.empty():
        parent_t = q.get()

        for child_t in self.chain[parent_t]:
            if child_t.alias == alias:
                # handle new_tasklet loop insertion
                if child_t is child_t.loop_starter:  # new_tasklet is new loop starter
                    new_tasklet.update_loop_attrs(
                        check_fn=child_t.loop_check_fn,
                        state=child_t.loop_state,
                        starter=new_tasklet,
                        ender=child_t.loop_ender,
                    )
                    # The old starter hands its loop state over to new_tasklet.
                    child_t.update_loop_attrs(state=LoopIndicator.NONE)
                elif child_t is child_t.loop_ender:
                    # new_tasklet lands inside the loop, just ahead of its ender.
                    new_tasklet.update_loop_attrs(
                        check_fn=child_t.loop_check_fn,
                        state=LoopIndicator.NONE,
                        starter=child_t.loop_starter,
                        ender=child_t,
                    )

                # relink the chain
                self.chain[parent_t].remove(child_t)
                self.chain[parent_t].add(new_tasklet)
                # insert the new tasklet
                self.chain[new_tasklet] = {child_t}
                # update the chain order
                updated_chain_order = initial_chain_order.copy()
                updated_chain_order.insert(initial_chain_order.index(child_t), new_tasklet)
                self._update_chain(updated_chain_order)

                # update the loop if new_tasklet is in a loop
                if new_tasklet.loop_state:
                    start, end = new_tasklet, new_tasklet.loop_ender
                    tasklets_in_loop = self.get_tasklets_in_loop(start, end)

                    for t in tasklets_in_loop:
                        t.update_loop_attrs(
                            starter=new_tasklet,
                            ender=new_tasklet.loop_ender,
                        )

                # Returning here also stops iterating the set mutated above.
                return None

        for child in self.chain[parent_t]:
            q.put(child)

    # Nothing reachable from the root carries this alias: fail loudly, as
    # get_tasklet() and remove_tasklet() do, instead of silently doing nothing.
    raise ValueError(f"Tasklet with alias {alias} not found")

def _insert_after(self, alias, new_tasklet):
    """Insert ``new_tasklet`` immediately after the tasklet named ``alias``.

    The named tasklet is located breadth-first from the root. Its current
    children become the children of ``new_tasklet``, and ``new_tasklet``
    becomes its only child. ``self.chain`` is re-ordered so ``new_tasklet``
    sits directly behind the named tasklet, and the reverse chain is rebuilt.

    Raises:
        ValueError: if a tasklet with ``new_tasklet.alias`` already exists,
            or if no reachable tasklet is named ``alias``.
    """
    initial_chain_order = list(self.chain)
    tasklet = next(iter(self.chain))
    root = tasklet.get_root()

    for t in self.chain:
        if t.alias == new_tasklet.alias:
            raise ValueError(f"Tasklet with alias '{new_tasklet.alias}' already exists")

    q = Queue()
    q.put(root)
    while not q.empty():
        parent_t = q.get()

        if parent_t.alias == alias:
            # handle new_tasklet loop insertion
            if parent_t is parent_t.loop_ender:  # new_tasklet is new loop ender
                new_tasklet.update_loop_attrs(
                    check_fn=parent_t.loop_check_fn,
                    state=parent_t.loop_state,
                    starter=parent_t.loop_starter,
                    ender=new_tasklet,
                )
                # The old ender hands its loop state over to new_tasklet.
                parent_t.update_loop_attrs(state=LoopIndicator.NONE)
            elif parent_t is parent_t.loop_starter:
                # new_tasklet lands inside the loop, right behind its
                # starter, so its starter is that same loop starter (the
                # mirror of the loop-ender branch in _insert_before), not
                # the loop ender.
                new_tasklet.update_loop_attrs(
                    check_fn=parent_t.loop_check_fn,
                    state=LoopIndicator.NONE,
                    starter=parent_t.loop_starter,
                    ender=parent_t.loop_ender,
                )

            # new_tasklet takes over the anchor's children...
            self.chain[new_tasklet] = set(self.chain[parent_t])
            # ...and becomes the anchor's only child.
            self.chain[parent_t] = {new_tasklet}
            # update the chain order
            updated_chain_order = initial_chain_order.copy()
            updated_chain_order.insert(initial_chain_order.index(parent_t) + 1, new_tasklet)
            self._update_chain(updated_chain_order)

            # update the loop if new_tasklet is in a loop
            if new_tasklet.loop_state:
                start, end = new_tasklet.loop_starter, new_tasklet.loop_ender
                tasklets_in_loop = self.get_tasklets_in_loop(start, end)

                for t in tasklets_in_loop:
                    t.update_loop_attrs(
                        starter=new_tasklet.loop_starter,
                        ender=new_tasklet.loop_ender,
                    )

            return None

        for child in self.chain[parent_t]:
            q.put(child)

    # Nothing reachable from the root carries this alias: fail loudly, as
    # get_tasklet() and remove_tasklet() do, instead of silently doing nothing.
    raise ValueError(f"Tasklet with alias {alias} not found")

def get_reverse_chain(self):
    """Build the child-to-parents view of ``self.chain``.

    Returns a new dict with the same keys (in the same order) as
    ``self.chain``; each key maps to the set of tasklets that list it as a
    child. Tasklets that are nobody's child map to an empty set.
    """
    parents_of = {}
    for tasklet in self.chain:
        parents_of[tasklet] = set()

    for parent, children in self.chain.items():
        for child in children:
            parents_of[child].add(parent)

    return parents_of


class CloneComposer(object):
"""CloneComposer clones composer object."""
Expand Down
9 changes: 9 additions & 0 deletions lib/python/flame/mode/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from enum import Flag, auto


class LoopIndicator(Flag):
    """LoopIndicator is a flag class that contains loop begin and end flags."""

    # NONE is the zero value: no loop flag is set.
    NONE = 0
    # BEGIN and END take their values from auto().
    BEGIN = auto()
    END = auto()
59 changes: 17 additions & 42 deletions lib/python/flame/mode/horizontal/coord_syncfl/middle_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging

from flame.mode.composer import Composer
from flame.mode.composer import CloneComposer
from flame.mode.horizontal.syncfl.middle_aggregator import (
TAG_AGGREGATE,
TAG_DISTRIBUTE,
Expand All @@ -27,7 +27,7 @@
MiddleAggregator as BaseMiddleAggregator,
)
from flame.mode.message import MessageType
from flame.mode.tasklet import Loop, Tasklet
from flame.mode.tasklet import Tasklet

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,53 +99,28 @@ def _handle_no_trainer(self):

def compose(self) -> None:
"""Compose role with tasklets."""
with Composer() as composer:
self.composer = composer

task_internal_init = Tasklet("", self.internal_init)

task_init = Tasklet("", self.initialize)
super().compose()

task_load_data = Tasklet("", self.load_data)
with CloneComposer(self.composer) as composer:
self.composer = composer

task_get_trainers = Tasklet("", self._get_trainers)

task_no_trainer = Tasklet("", self._handle_no_trainer)
task_no_trainer.set_continue_fn(cont_fn=lambda: self.no_trainer)

task_put_dist = Tasklet("", self.put, TAG_DISTRIBUTE)

task_put_upload = Tasklet("", self.put, TAG_UPLOAD)

task_get_aggr = Tasklet("", self.get, TAG_AGGREGATE)

task_get_fetch = Tasklet("", self.get, TAG_FETCH)

task_eval = Tasklet("", self.evaluate)

task_update_round = Tasklet("", self.update_round)

task_end_of_training = Tasklet("", self.inform_end_of_training)

loop = Loop(loop_check_fn=lambda: self._work_done)
(
task_internal_init
>> task_load_data
>> task_init
>> loop(
task_get_trainers
>> task_no_trainer
>> task_get_fetch
>> task_put_dist
>> task_get_aggr
>> task_put_upload
>> task_eval
>> task_update_round
)
>> task_end_of_training
)
self.composer.get_tasklet("fetch").insert_before(task_no_trainer)
task_no_trainer.insert_before(task_get_trainers)

@classmethod
def get_func_tags(cls) -> list[str]:
"""Return a list of function tags defined in the top level aggregator role."""
return [TAG_AGGREGATE, TAG_DISTRIBUTE, TAG_FETCH, TAG_UPLOAD, TAG_COORDINATE]
"""Return a list of function tags defined in the top level
aggregator role.
"""
return [
TAG_AGGREGATE,
TAG_DISTRIBUTE,
TAG_FETCH,
TAG_UPLOAD,
TAG_COORDINATE,
]
52 changes: 7 additions & 45 deletions lib/python/flame/mode/horizontal/coord_syncfl/top_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging

from flame.mode.composer import Composer
from flame.mode.composer import CloneComposer
from flame.mode.horizontal.syncfl.top_aggregator import TAG_AGGREGATE, TAG_DISTRIBUTE
from flame.mode.horizontal.syncfl.top_aggregator import (
TopAggregator as BaseTopAggregator,
Expand Down Expand Up @@ -62,54 +62,16 @@ def get_coordinated_ends(self):

def compose(self) -> None:
"""Compose role with tasklets."""
with Composer() as composer:
self.composer = composer

task_internal_init = Tasklet("", self.internal_init)

task_init = Tasklet("", self.initialize)

task_load_data = Tasklet("", self.load_data)

task_get_coord_ends = Tasklet("", self.get_coordinated_ends)

task_put = Tasklet("", self.put, TAG_DISTRIBUTE)
super().compose()

task_get = Tasklet("", self.get, TAG_AGGREGATE)

task_train = Tasklet("", self.train)

task_eval = Tasklet("", self.evaluate)

task_analysis = Tasklet("", self.run_analysis)

task_save_metrics = Tasklet("", self.save_metrics)

task_increment_round = Tasklet("", self.increment_round)
with CloneComposer(self.composer) as composer:
self.composer = composer

task_save_params = Tasklet("", self.save_params)
task_get_coord_ends = Tasklet("get_coord_ends", self.get_coordinated_ends)

task_save_model = Tasklet("", self.save_model)

# create a loop object with loop exit condition function
loop = Loop(loop_check_fn=lambda: self._work_done)
(
task_internal_init
>> task_load_data
>> task_init
>> loop(
task_get_coord_ends
>> task_put
>> task_get
>> task_train
>> task_eval
>> task_analysis
>> task_save_metrics
>> task_increment_round
)
>> task_save_params
>> task_save_model
)
self.composer.get_tasklet("distribute").insert_before(task_get_coord_ends)
self.composer.get_tasklet("inform_end_of_training").remove()

@classmethod
def get_func_tags(cls) -> list[str]:
Expand Down
Loading