Skip to content

Commit

Permalink
Merge pull request #1 from elqurio/development
Browse files Browse the repository at this point in the history
Make sdk config backwards compatible. (cisco-open#355)
  • Loading branch information
openwithcode authored Mar 3, 2023
2 parents 66b6db5 + a627cbd commit c150b73
Show file tree
Hide file tree
Showing 16 changed files with 41 additions and 33 deletions.
14 changes: 11 additions & 3 deletions lib/python/flame/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
"""Config parser."""

from enum import Enum
from pydantic import Field
import typing as t
from pydantic import Field
from pydantic import BaseModel as pydBaseModel
import json

Expand Down Expand Up @@ -82,6 +82,8 @@ class Registry(FlameSchema):
class Selector(FlameSchema):
sort: SelectorType = Field(default=SelectorType.DEFAULT)
kwargs: dict = Field(default={})


class Optimizer(FlameSchema):
sort: OptimizerType = Field(default=OptimizerType.DEFAULT)
kwargs: dict = Field(default={})
Expand All @@ -97,6 +99,7 @@ class Hyperparameters(FlameSchema):
learning_rate: t.Optional[float] = Field(alias="learningRate")
rounds: int
epochs: int
aggregation_goal: t.Optional[int] = Field(alias="aggGoal", default=None)


class Groups(FlameSchema):
Expand Down Expand Up @@ -148,6 +151,12 @@ class ChannelConfigs(FlameSchema):


class Config(FlameSchema):
def __init__(self, config_path: str):
raw_config = read_config(config_path)
transformed_config = transform_config(raw_config)

super().__init__(**transformed_config)

role: str
realm: str
task: t.Optional[str] = Field(default="local")
Expand All @@ -174,8 +183,7 @@ def read_config(filename: str) -> dict:
return json.loads(f.read())


def load_config(filename: str) -> Config:
raw_config = read_config(filename)
def transform_config(raw_config: dict) -> dict:
config_data = {
"role": raw_config["role"],
"realm": raw_config["realm"],
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/adult/aggregator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import torch
import torch.nn as nn
from flame.config import Config, load_config
from flame.config import Config
from flame.mode.horizontal.top_aggregator import TopAggregator

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,7 +88,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

t = PyTorchAdultAggregator(config)
t.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/adult/trainer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch.optim as optim
from flame.common.constants import DATA_FOLDER_PATH
from flame.common.util import install_packages
from flame.config import Config, load_config
from flame.config import Config
from flame.mode.horizontal.trainer import Trainer

install_packages(['scikit-learn'])
Expand Down Expand Up @@ -149,7 +149,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

t = PyTorchAdultTrainer(config)
t.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/dist_mnist/trainer/keras/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from statistics import mean

import numpy as np
from flame.config import Config, load_config
from flame.config import Config
from flame.mode.distributed.trainer import Trainer
from tensorflow import keras
from tensorflow.keras import layers
Expand Down Expand Up @@ -132,7 +132,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

t = KerasMnistTrainer(config)
t.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/dist_mnist/trainer/pytorch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
from flame.config import Config, load_config
from flame.config import Config
from flame.mode.distributed.trainer import Trainer
from torchvision import datasets, transforms

Expand Down Expand Up @@ -145,7 +145,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

t = PyTorchMnistTrainer(config)
t.compose()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import logging

from flame.config import Config, load_config
from flame.config import Config
from flame.mode.horizontal.middle_aggregator import MiddleAggregator
# the following needs to be imported to let the flame know
# this aggregator works on tensorflow model
Expand Down Expand Up @@ -58,7 +58,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

a = KerasMnistMiddleAggregator(config)
a.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/hier_mnist/top_aggregator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import logging

from flame.config import Config, load_config
from flame.config import Config
from flame.dataset import Dataset
from flame.mode.horizontal.top_aggregator import TopAggregator
from tensorflow import keras
Expand Down Expand Up @@ -82,7 +82,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

a = KerasMnistTopAggregator(config)
a.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/hier_mnist/trainer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from statistics import mean

import numpy as np
from flame.config import Config, load_config
from flame.config import Config
from flame.mode.horizontal.trainer import Trainer
from tensorflow import keras
from tensorflow.keras import layers
Expand Down Expand Up @@ -131,7 +131,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

t = KerasMnistTrainer(config)
t.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/hybrid/aggregator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import logging

from flame.config import Config, load_config
from flame.config import Config
from flame.dataset import Dataset
from flame.mode.horizontal.top_aggregator import TopAggregator
from tensorflow import keras
Expand Down Expand Up @@ -81,7 +81,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

a = KerasMnistAggregator(config)
a.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/hybrid/trainer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from statistics import mean

import numpy as np
from flame.config import Config, load_config
from flame.config import Config
from flame.mode.hybrid.trainer import Trainer
from tensorflow import keras
from tensorflow.keras import layers
Expand Down Expand Up @@ -131,7 +131,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

t = KerasMnistTrainer(config)
t.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/medmnist/aggregator/pytorch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import logging

from flame.config import Config, load_config
from flame.config import Config
from flame.dataset import Dataset # Not sure why we need this.
from flame.mode.horizontal.top_aggregator import TopAggregator
import torch
Expand Down Expand Up @@ -85,7 +85,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

a = PyTorchMedMNistAggregator(config)
a.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/medmnist/trainer/pytorch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from flame.common.util import install_packages
install_packages(['scikit-learn'])

from flame.config import Config, load_config
from flame.config import Config
from flame.mode.horizontal.trainer import Trainer
import torch
import torchvision
Expand Down Expand Up @@ -212,7 +212,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

t = PyTorchMedMNistTrainer(config)
t.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/mnist/aggregator/keras/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import logging

from flame.config import Config, load_config
from flame.config import Config
from flame.dataset import Dataset
from flame.mode.horizontal.top_aggregator import TopAggregator
from tensorflow import keras
Expand Down Expand Up @@ -82,7 +82,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

a = KerasMnistAggregator(config)
a.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/mnist/aggregator/pytorch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from flame.config import Config, load_config
from flame.config import Config
from flame.dataset import Dataset
from flame.mode.horizontal.top_aggregator import TopAggregator
from torchvision import datasets, transforms
Expand Down Expand Up @@ -143,7 +143,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

a = PyTorchMnistAggregator(config)
a.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/mnist/trainer/keras/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from statistics import mean

import numpy as np
from flame.config import Config, load_config
from flame.config import Config
from flame.mode.horizontal.trainer import Trainer
from tensorflow import keras
from tensorflow.keras import layers
Expand Down Expand Up @@ -131,7 +131,7 @@ def evaluate(self) -> None:

args = parser.parse_args()

config = load_config(args.config)
config = Config(args.config)

t = KerasMnistTrainer(config)
t.compose()
Expand Down
4 changes: 2 additions & 2 deletions lib/python/flame/examples/mnist/trainer/pytorch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
from flame.config import Config, load_config
from flame.config import Config
from flame.mode.horizontal.trainer import Trainer


Expand Down Expand Up @@ -146,7 +146,7 @@ def evaluate(self) -> None:
parser.add_argument('config', nargs='?', default="./config.json")

args = parser.parse_args()
config = load_config(args.config)
config = Config(args.config)

t = PyTorchMnistTrainer(config)
t.compose()
Expand Down

0 comments on commit c150b73

Please sign in to comment.