Skip to content

Commit

Permalink
Move orchestrator architecture to flame/mode
Browse files Browse the repository at this point in the history
  • Loading branch information
lkurija1 committed Mar 29, 2023
1 parent 41d3f19 commit bdf4566
Show file tree
Hide file tree
Showing 14 changed files with 59 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2022 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""HIRE_MNIST horizontal hierarchical FL middle level aggregator for Keras."""

import logging

from flame.config import Config
from flame.mode.orchestrated.coordinator import Coordinator
# the following needs to be imported to let the flame know
# this aggregator works on tensorflow model
from tensorflow import keras

logger = logging.getLogger(__name__)




if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description='')
parser.add_argument('config', nargs='?', default="./config.json")

args = parser.parse_args()

config = Config(args.config)

t = Coordinator(config)
t.compose()
t.run()
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"name": "selector-channel",
"description": "Channel for trainer selection",
"pair": [
"aggregator",
"middle-aggregator",
"selector"
],
"groupBy": {
Expand All @@ -31,7 +31,7 @@
]
},
"funcTags": {
"aggregator": [
"middle-aggregator": [
"getTrainers"
],
"selector": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging

from flame.config import Config
from ...middle_aggregator import MiddleAggregator
from flame.mode.orchestrated.middle_aggregator import MiddleAggregator
# the following needs to be imported to let the flame know
# this aggregator works on tensorflow model
from tensorflow import keras
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import numpy as np
from flame.config import Config
from ...trainer import Trainer
from flame.mode.orchestrated.trainer import Trainer
from tensorflow import keras
from tensorflow.keras import layers

Expand Down

This file was deleted.

Empty file.
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from flame.mode.role import Role
from abc import ABCMeta
from flame.config import Config
from flame.channel_manager import ChannelManager
from flame.mode.tasklet import Tasklet, Loop
from flame.mode.composer import Composer
from .message import MessageType
from ...channel_manager import ChannelManager
from ...config import Config
from ...mode.composer import Composer
from ...mode.message import MessageType
from ...mode.role import Role
from ...mode.tasklet import Tasklet, Loop

TAG_SELECT_TRAINERS = 'selectTrainers'
TAG_REGISTER_TRAINER = 'registerTrainer'
Expand Down Expand Up @@ -90,6 +90,6 @@ def compose(self) -> None:


loop = Loop(loop_check_fn=lambda: self._work_done)
task_internal_init >> task_init >> task_await_mid_agg_and_trainers >> loop(
task_internal_init >> task_init >> loop( task_await_mid_agg_and_trainers >>
task_register_trainers >> task_send_selected_trainers
)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from flame.mode.horizontal.middle_aggregator import MiddleAggregator as BaseMiddleAggregator
from flame.mode.composer import Composer
from flame.mode.tasklet import Tasklet, Loop
from .message import MessageType
from flame.mode.message import MessageType

TAG_DISTRIBUTE = 'distribute'
TAG_AGGREGATE = 'aggregate'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
TAG_REGISTER = 'register'



class Trainer(BaseTrainer, metaclass=ABCMeta):
def notify_readiness(self):
def notify_readiness(self): # TODO: backend should do the readiness check
channel = self.cm.get_by_tag(TAG_REGISTER)
if not channel:
return
Expand All @@ -26,7 +25,7 @@ def _send_register(self, tag):
return

end = channel.one_end()
msg = {MessageType.NEW_TRAINER}
msg = {MessageType.NEW_TRAINER: True}
channel.send(end, msg)

def put(self, tag: str) -> None:
Expand All @@ -47,9 +46,9 @@ def compose(self) -> None:

task_init = Tasklet(self.initialize)

task_notify_readiness = Tasklet(self.notify_readiness)
task_notify_readiness = Tasklet(self.notify_readiness) # TODO: change naming

task_register = Tasklet(self.put, TAG_REGISTER)
task_register = Tasklet(self.put, TAG_REGISTER) # TODO: not needed because of cm.join_all()

task_get = Tasklet(self.get, TAG_FETCH)

Expand Down

0 comments on commit bdf4566

Please sign in to comment.