Skip to content

Commit

Permalink
Synchronous orchestrator architecture for Flame (#379)
Browse files Browse the repository at this point in the history
Additions:
- New mode added: coord_syncfl
The new mode revolves around the new component - Coordinator.
  - Top aggregator notifies the Coordinator of training state state/finish.
  - Mid aggregators now consult the Coordinator for selected trainers.
  - Trainers now register themselves with the Coordinator.
- New component added: Coordinator
Coordinators role is to coordinate the training process. It is responsible for selecting the trainers and aggregators and sending the selected trainers to the aggregators.
New architecture implements top and middle aggregators as well as trainers with additional lifespan steps which accomodate the protocol between them and the Coordinator.
- New example added: coord_hier_syncfl

contributors: @elqurio @myungjin
  • Loading branch information
lkurija1 committed Mar 31, 2023
1 parent 0469d6e commit f6c291e
Show file tree
Hide file tree
Showing 26 changed files with 1,877 additions and 8 deletions.
22 changes: 14 additions & 8 deletions lib/python/flame/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def groupable_value(self, group_association: str = ""):

return GROUPBY_DEFAULT_GROUP


class Broker(FlameSchema):
sort_to_host: dict

Expand Down Expand Up @@ -162,18 +163,18 @@ def __init__(self, config_path: str):
task_id: str
backend: BackendType
channels: dict
hyperparameters: Hyperparameters
hyperparameters: t.Optional[Hyperparameters]
brokers: Broker
job: Job
registry: Registry
selector: Selector
registry: t.Optional[Registry]
selector: t.Optional[Selector]
optimizer: t.Optional[Optimizer] = Field(default=Optimizer())
channel_configs: t.Optional[ChannelConfigs]
dataset: str
max_run_time: int
base_model: BaseModel
base_model: t.Optional[BaseModel]
groups: t.Optional[Groups]
dependencies: list[str]
dependencies: t.Optional[list[str]]
func_tag_map: t.Optional[dict]


Expand Down Expand Up @@ -203,18 +204,23 @@ def transform_config(raw_config: dict) -> dict:
"func_tag_map": func_tag_map
}

hyperparameters = transform_hyperparameters(raw_config["hyperparameters"])
config_data = config_data | {"hyperparameters": hyperparameters}
if raw_config.get("hyperparameters", None):
hyperparameters = transform_hyperparameters(
raw_config["hyperparameters"])

config_data = config_data | {"hyperparameters": hyperparameters}

sort_to_host = transform_brokers(raw_config["brokers"])
config_data = config_data | {"brokers": sort_to_host}

config_data = config_data | {
"job": raw_config["job"],
"registry": raw_config["registry"],
"selector": raw_config["selector"],
}

if raw_config.get("registry", None):
config_data = config_data | {"registry": raw_config["registry"]}

if raw_config.get("optimizer", None):
config_data = config_data | {"optimizer": raw_config.get("optimizer")}

Expand Down
15 changes: 15 additions & 0 deletions lib/python/flame/examples/coord_hier_syncfl_mnist/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
{
"taskid": "09d06b7526964db86cf37c70e8e0cdb6bd7aa743",
"backend": "mqtt",
"brokers": [
{
"host": "localhost",
"sort": "mqtt"
},
{
"host": "localhost:10104",
"sort": "p2p"
}
],
"groupAssociation": {
"top-agg-coord-channel": "default",
"middle-agg-coord-channel": "default",
"trainer-coord-channel": "default"
},
"channels": [
{
"name": "top-agg-coord-channel",
"description": "Channel between top aggregator and coordinator",
"pair": [
"top-aggregator",
"coordinator"
],
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"funcTags": {
"top-aggregator": [
"notifyCoordinator"
],
"coordinator": [
"checkEOT"
]
}
},
{
"name": "middle-agg-coord-channel",
"description": "Channel between middle aggregator and coordinator",
"pair": [
"middle-aggregator",
"coordinator"
],
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"funcTags": {
"middle-aggregator": [
"getTrainers"
],
"coordinator": [
"selectTrainers"
]
}
},
{
"name": "trainer-coord-channel",
"description": "Channel between trainer and coordinator",
"pair": [
"trainer",
"coordinator"
],
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"funcTags": {
"trainer": [
"getAggregator"
],
"coordinator": [
"selectAggregator"
]
}
}
],
"job": {
"id": "622a358619ab59012eabeefb",
"name": "mnist"
},
"selector": {
"sort": "default",
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default-cluster",
"role": "coordinator"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""HIRE_MNIST horizontal hierarchical FL middle level aggregator for Keras."""

import logging

from flame.config import Config
from flame.mode.horizontal.coord_syncfl.coordinator import Coordinator

logger = logging.getLogger(__name__)

if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="")
parser.add_argument("config", nargs="?", default="./config.json")

args = parser.parse_args()

config = Config(args.config)

t = Coordinator(config)
t.compose()
t.run()
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
{
"taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa743",
"backend": "mqtt",
"brokers": [
{
"host": "localhost",
"sort": "mqtt"
},
{
"host": "localhost:10104",
"sort": "p2p"
}
],
"groupAssociation": {
"middle-agg-coord-channel": "default",
"param-channel": "default",
"global-channel": "default"
},
"channels": [
{
"name": "middle-agg-coord-channel",
"description": "Channel between middle aggregator and coordinator",
"pair": [
"middle-aggregator",
"coordinator"
],
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"funcTags": {
"middle-aggregator": [
"getTrainers"
],
"coordinator": [
"selectTrainers"
]
}
},
{
"description": "Model update is sent from mid aggregator to global aggregator and vice-versa",
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"name": "global-channel",
"pair": [
"top-aggregator",
"middle-aggregator"
],
"funcTags": {
"top-aggregator": [
"distribute",
"aggregate"
],
"middle-aggregator": [
"fetch",
"upload"
]
}
},
{
"description": "Model update is sent from mid aggregator to trainer and vice-versa",
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"name": "param-channel",
"pair": [
"middle-aggregator",
"trainer"
],
"funcTags": {
"middle-aggregator": [
"distribute",
"aggregate"
],
"trainer": [
"fetch",
"upload"
]
}
}
],
"dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
"dependencies": [
"numpy >= 1.2.0"
],
"hyperparameters": {
"batchSize": 32,
"learningRate": 0.01,
"rounds": 5
},
"baseModel": {
"name": "",
"version": 1
},
"job": {
"id": "622a358619ab59012eabeefb",
"name": "mnist"
},
"registry": {
"sort": "dummy",
"uri": "http://flame-mlflow:5000"
},
"selector": {
"sort": "default",
"kwargs": {}
},
"maxRunTime": 300,
"realm": "default-cluster",
"role": "middle-aggregator"
}
Loading

0 comments on commit f6c291e

Please sign in to comment.