Skip to content

Commit

Permalink
Improve tasklet composition extensibility (#397)
Browse files Browse the repository at this point in the history
This PR improves the extensibility of tasklet composition by allowing the developer to extend the tasklet set of the original `Composer`: new tasklets are declared under a `CloneComposer` context, and the composer is then used to look up the originally implemented tasklets by alias and perform composition on them.
Implemented tasklet/composer operations:
- `insert_before` which inserts a tasklet before the target tasklet
- `insert_after` which inserts a tasklet after the target tasklet
- `replace_with` which replaces the target tasklet with the new tasklet
- `remove` which removes the target tasklet

The implemented operations work with loop edges, except for `remove` which will raise NotImplementedError if the target tasklet has loop edges.

This PR demonstrates the extensibility of tasklet composition by refactoring the existing coord_syncfl mode to use the new extensibility (demonstrating the reduction in LOC needed to achieve same results).
  • Loading branch information
lkurija1 committed Apr 21, 2023
1 parent 3a73322 commit f8f1da9
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 139 deletions.
197 changes: 197 additions & 0 deletions lib/python/flame/mode/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from queue import Queue
from types import TracebackType
from typing import Optional, Type
from flame.mode.enums import LoopIndicator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -216,6 +217,202 @@ def print(self):
print("=====")
print("done with printing chain")

def get_tasklet(self, alias: str):
    """Return the first tasklet whose alias equals ``alias``.

    The chain is walked breadth-first, starting from the root reported by
    the first tasklet stored in ``self.chain``.

    Args:
        alias: alias of the tasklet to look up.

    Returns:
        The first tasklet dequeued during the walk whose ``alias`` matches.

    Raises:
        ValueError: if the walk finishes without finding a match.
    """
    first = next(iter(self.chain))

    pending = Queue()
    pending.put(first.get_root())
    while not pending.empty():
        current = pending.get()
        if current.alias == alias:
            return current
        for successor in self.chain[current]:
            pending.put(successor)

    # the queue drained without a match
    raise ValueError(f"Tasklet with alias {alias} not found")

def remove_tasklet(self, alias: str):
    """Remove a tasklet from the composer and stitch the chain around it.

    Every parent of the removed tasklet inherits the removed tasklet's
    children, and ``self.reverse_chain`` is rebuilt afterwards (including
    when the removed tasklet is the root).

    Args:
        alias: alias of the tasklet to remove.

    Raises:
        NotImplementedError: if the tasklet has a truthy ``loop_state``
            (i.e. it is flagged as a loop begin and/or end).
        ValueError: if no tasklet with the given alias is linked in the
            chain.
    """
    # removing a loop boundary would require re-assigning loop attributes
    target = self.get_tasklet(alias)
    if target.loop_state:
        raise NotImplementedError("Can't handle loop ends removal yet")

    root = next(iter(self.chain)).get_root()

    if root.alias == alias:
        self.chain.pop(root)
        # keep the reverse mapping in sync with the forward chain
        self.reverse_chain = self.get_reverse_chain()
        return None

    orphans = self.chain[target]
    stitched = False
    for children in self.chain.values():
        if target in children:
            # bypass the removed tasklet: parent -> removed tasklet's children
            children.remove(target)
            children.update(orphans)
            stitched = True

    if not stitched:
        raise ValueError(f"Tasklet with alias {alias} not found")

    # no parent references the tasklet any more, so it can be dropped
    self.chain.pop(target)
    self.reverse_chain = self.get_reverse_chain()

    return None

def _update_chain(self, new_order):
self.chain = {k: self.chain[k] for k in new_order}
self.reverse_chain = self.get_reverse_chain()

def insert(self, alias, new_tasklet, after=False):
    """Insert ``new_tasklet`` next to the tasklet identified by ``alias``.

    Args:
        alias: alias of the existing tasklet used as the insertion point.
        new_tasklet: tasklet to add to the chain.
        after: when truthy, insert after the existing tasklet; otherwise
            (the default) insert before it.
    """
    place = self._insert_after if after else self._insert_before
    place(alias, new_tasklet)

def _insert_before(self, alias, new_tasklet):
initial_chain_order = list(self.chain)
tasklet = next(iter(self.chain))
root = tasklet.get_root()

for t in self.chain:
if t.alias == new_tasklet.alias:
raise ValueError(f"Tasklet with alias '{new_tasklet.alias}' already exists")

if root.alias == alias:
self.chain[new_tasklet] = {root}
# update the chain order
updated_chain_order = initial_chain_order.copy()
updated_chain_order.insert(initial_chain_order.index(root), new_tasklet)
self._update_chain(updated_chain_order)

return None

q = Queue()
q.put(root)
while not q.empty():
parent_t = q.get()

for child_t in self.chain[parent_t]:
if child_t.alias == alias:
# handle new_tasklet loop insertion
if child_t is child_t.loop_starter: # new_tasklet is new loop starter
new_tasklet.update_loop_attrs(
check_fn=child_t.loop_check_fn,
state=child_t.loop_state,
starter=new_tasklet,
ender=child_t.loop_ender
)
child_t.update_loop_attrs(state=LoopIndicator.NONE)
elif child_t is child_t.loop_ender:
new_tasklet.update_loop_attrs(
check_fn=child_t.loop_check_fn,
state=LoopIndicator.NONE,
starter=child_t.loop_starter,
ender=child_t
)

# relink the chain
self.chain[parent_t].remove(child_t)
self.chain[parent_t].add(new_tasklet)
# insert the new tasklet
self.chain[new_tasklet] = {child_t}
# update the chain order
updated_chain_order = initial_chain_order.copy()
updated_chain_order.insert(initial_chain_order.index(child_t), new_tasklet)
self._update_chain(updated_chain_order)

# update the loop if new_tasklet is in a loop
if new_tasklet.loop_state:
start, end = new_tasklet, new_tasklet.loop_ender
tasklets_in_loop = self.get_tasklets_in_loop(start, end)

for t in tasklets_in_loop:
t.update_loop_attrs(
starter=new_tasklet,
ender=new_tasklet.loop_ender
)

return None

for child in self.chain[parent_t]:
q.put(child)

def _insert_after(self, alias, new_tasklet):
initial_chain_order = list(self.chain)
tasklet = next(iter(self.chain))
root = tasklet.get_root()

for t in self.chain:
if t.alias == new_tasklet.alias:
raise ValueError(f"Tasklet with alias '{new_tasklet.alias}' already exists")

q = Queue()
q.put(root)
while not q.empty():
parent_t = q.get()

if parent_t.alias == alias:
# handle new_tasklet loop insertion
if parent_t is parent_t.loop_ender: # new_tasklet is new loop ender
new_tasklet.update_loop_attrs(
check_fn=parent_t.loop_check_fn,
state=parent_t.loop_state,
starter=parent_t.loop_starter,
ender=new_tasklet
)
parent_t.update_loop_attrs(state=LoopIndicator.NONE)
elif parent_t is parent_t.loop_starter:
new_tasklet.update_loop_attrs(
check_fn=parent_t.loop_check_fn,
state=LoopIndicator.NONE,
starter=parent_t.loop_ender,
ender=parent_t.loop_ender
)

self.chain[new_tasklet] = {child_t for child_t in self.chain[parent_t]}
self.chain[parent_t] = {new_tasklet}
# update the chain order
updated_chain_order = initial_chain_order.copy()
updated_chain_order.insert(initial_chain_order.index(parent_t) + 1, new_tasklet)
self._update_chain(updated_chain_order)

# update the loop if new_tasklet is in a loop
if new_tasklet.loop_state:
start, end = new_tasklet.loop_starter, new_tasklet.loop_ender
tasklets_in_loop = self.get_tasklets_in_loop(start, end)

for t in tasklets_in_loop:
t.update_loop_attrs(
starter=new_tasklet.loop_starter,
ender=new_tasklet.loop_ender
)

return None

for child in self.chain[parent_t]:
q.put(child)

def get_reverse_chain(self):
    """Build the child -> parents mapping that mirrors ``self.chain``.

    Returns:
        A dict with the same keys as ``self.chain``; each value is the set
        of keys whose child set in ``self.chain`` contains that key.
    """
    parents_of = {}
    for tasklet in self.chain:
        parents_of[tasklet] = set()

    for parent, children in self.chain.items():
        for child in children:
            parents_of[child].add(parent)

    return parents_of


class CloneComposer(object):
"""CloneComposer clones composer object."""
Expand Down
9 changes: 9 additions & 0 deletions lib/python/flame/mode/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from enum import Flag, auto


class LoopIndicator(Flag):
    """Flag enum holding the loop begin and loop end markers.

    Being a ``Flag``, members can be combined with ``|``; ``NONE`` is falsy.
    """

    # no loop marker set
    NONE = 0
    # loop begin marker
    BEGIN = auto()
    # loop end marker
    END = auto()
59 changes: 17 additions & 42 deletions lib/python/flame/mode/horizontal/coord_syncfl/middle_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging

from flame.mode.composer import Composer
from flame.mode.composer import CloneComposer
from flame.mode.horizontal.syncfl.middle_aggregator import (
TAG_AGGREGATE,
TAG_DISTRIBUTE,
Expand All @@ -27,7 +27,7 @@
MiddleAggregator as BaseMiddleAggregator,
)
from flame.mode.message import MessageType
from flame.mode.tasklet import Loop, Tasklet
from flame.mode.tasklet import Tasklet

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,53 +99,28 @@ def _handle_no_trainer(self):

def compose(self) -> None:
"""Compose role with tasklets."""
with Composer() as composer:
self.composer = composer

task_internal_init = Tasklet("", self.internal_init)

task_init = Tasklet("", self.initialize)
super().compose()

task_load_data = Tasklet("", self.load_data)
with CloneComposer(self.composer) as composer:
self.composer = composer

task_get_trainers = Tasklet("", self._get_trainers)

task_no_trainer = Tasklet("", self._handle_no_trainer)
task_no_trainer.set_continue_fn(cont_fn=lambda: self.no_trainer)

task_put_dist = Tasklet("", self.put, TAG_DISTRIBUTE)

task_put_upload = Tasklet("", self.put, TAG_UPLOAD)

task_get_aggr = Tasklet("", self.get, TAG_AGGREGATE)

task_get_fetch = Tasklet("", self.get, TAG_FETCH)

task_eval = Tasklet("", self.evaluate)

task_update_round = Tasklet("", self.update_round)

task_end_of_training = Tasklet("", self.inform_end_of_training)

loop = Loop(loop_check_fn=lambda: self._work_done)
(
task_internal_init
>> task_load_data
>> task_init
>> loop(
task_get_trainers
>> task_no_trainer
>> task_get_fetch
>> task_put_dist
>> task_get_aggr
>> task_put_upload
>> task_eval
>> task_update_round
)
>> task_end_of_training
)
self.composer.get_tasklet("fetch").insert_before(task_no_trainer)
task_no_trainer.insert_before(task_get_trainers)

@classmethod
def get_func_tags(cls) -> list[str]:
"""Return a list of function tags defined in the top level aggregator role."""
return [TAG_AGGREGATE, TAG_DISTRIBUTE, TAG_FETCH, TAG_UPLOAD, TAG_COORDINATE]
"""Return a list of function tags defined in the top level
aggregator role.
"""
return [
TAG_AGGREGATE,
TAG_DISTRIBUTE,
TAG_FETCH,
TAG_UPLOAD,
TAG_COORDINATE,
]
52 changes: 7 additions & 45 deletions lib/python/flame/mode/horizontal/coord_syncfl/top_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import logging

from flame.mode.composer import Composer
from flame.mode.composer import CloneComposer
from flame.mode.horizontal.syncfl.top_aggregator import TAG_AGGREGATE, TAG_DISTRIBUTE
from flame.mode.horizontal.syncfl.top_aggregator import (
TopAggregator as BaseTopAggregator,
Expand Down Expand Up @@ -62,54 +62,16 @@ def get_coordinated_ends(self):

def compose(self) -> None:
"""Compose role with tasklets."""
with Composer() as composer:
self.composer = composer

task_internal_init = Tasklet("", self.internal_init)

task_init = Tasklet("", self.initialize)

task_load_data = Tasklet("", self.load_data)

task_get_coord_ends = Tasklet("", self.get_coordinated_ends)

task_put = Tasklet("", self.put, TAG_DISTRIBUTE)
super().compose()

task_get = Tasklet("", self.get, TAG_AGGREGATE)

task_train = Tasklet("", self.train)

task_eval = Tasklet("", self.evaluate)

task_analysis = Tasklet("", self.run_analysis)

task_save_metrics = Tasklet("", self.save_metrics)

task_increment_round = Tasklet("", self.increment_round)
with CloneComposer(self.composer) as composer:
self.composer = composer

task_save_params = Tasklet("", self.save_params)
task_get_coord_ends = Tasklet("get_coord_ends", self.get_coordinated_ends)

task_save_model = Tasklet("", self.save_model)

# create a loop object with loop exit condition function
loop = Loop(loop_check_fn=lambda: self._work_done)
(
task_internal_init
>> task_load_data
>> task_init
>> loop(
task_get_coord_ends
>> task_put
>> task_get
>> task_train
>> task_eval
>> task_analysis
>> task_save_metrics
>> task_increment_round
)
>> task_save_params
>> task_save_model
)
self.composer.get_tasklet("distribute").insert_before(task_get_coord_ends)
self.composer.get_tasklet("inform_end_of_training").remove()

@classmethod
def get_func_tags(cls) -> list[str]:
Expand Down
Loading

0 comments on commit f8f1da9

Please sign in to comment.