example/implementation for fedprox (#339)
This example is similar to the ones seen in the FedProx paper, although it currently does not simulate stragglers and uses a different dataset/architecture.

A few things were changed so that there is a simple process for modifying trainers.
This includes a function in util.py and another class variable in the trainer containing information on the client-side regularizer.

Additionally, tests are automated (mu=1, 0.1, 0.01, 0.001, 0), so running the example generates or modifies existing files in order to provide the proper configuration for an experiment.
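
The client-side change amounts to adding the FedProx proximal term, mu/2 * ||w - w_global||^2, to the local training loss. A minimal sketch of that regularizer, assuming the trainer snapshots the freshly distributed global weights with the new `get_params_detached_pytorch` helper (the surrounding training-loop code is illustrative, not the exact trainer implementation):

```python
from flame.common.util import get_params_detached_pytorch

# before local training: snapshot the global weights, detached from the graph
global_params = get_params_detached_pytorch(model)

# inside the local training loop: add the proximal term to the task loss
for data, label in loader:
    optimizer.zero_grad()
    loss = criterion(model(data), label)
    prox_term = sum((w - w_g).norm() ** 2
                    for w, w_g in zip(model.parameters(), global_params))
    loss = loss + (mu / 2) * prox_term
    loss.backward()
    optimizer.step()
```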
GustavBaumgart committed Feb 23, 2023
1 parent 19c39eb commit a7ea2a2
Showing 23 changed files with 1,001 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/python/flame/common/util.py
@@ -72,6 +72,9 @@ def get_ml_framework_in_use():

    return ml_framework_in_use

def get_params_detached_pytorch(model):
    """Return copy of parameters of pytorch model disconnected from graph."""
    return [param.detach().clone() for param in model.parameters()]

@contextmanager
def background_thread_loop():
1 change: 1 addition & 0 deletions lib/python/flame/config.py
@@ -93,6 +93,7 @@ class OptimizerType(Enum):
    # FedBuff from https://arxiv.org/pdf/1903.03934.pdf and
    # https://arxiv.org/pdf/2111.04877.pdf
    FEDBUFF = 5
    FEDPROX = 6  # FedProx


class SelectorType(Enum):
49 changes: 49 additions & 0 deletions lib/python/flame/examples/medmnist_fedprox/README.md
@@ -0,0 +1,49 @@
## FedProx MedMNIST Example

We use the PathMNIST dataset from [MedMNIST](https://medmnist.com/) to go over an example of FedProx with different mu values (mu=1.0, 0.1, 0.01, 0.001, 0.0).
Although FedProx, unlike FedAvg, can incorporate stragglers, we do not simulate stragglers in the following tests.
Thus, we are measuring only the effects of the proximal term.
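
For reference, FedProx has each client $k$ minimize its local loss plus a proximal term,

$$F_k(w) + \frac{\mu}{2}\,\lVert w - w^t \rVert^2,$$

where $w^t$ is the current global model; a larger mu keeps local updates closer to the global model, and mu=0.0 recovers FedAvg.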

This example runs within a conda environment, so we activate the environment first.
Once you are in the `medmnist_fedprox` directory, run the following command.

```bash
conda activate flame
```

Once this is complete, you can run the example (using all 10 trainers) by selecting a value from mu=1.0, 0.1, 0.01, 0.001, 0.0 and running `python run.py <mu>`.
For example, to run with `mu=1.0`:

```bash
python run.py 1.0
```

We recommend running only one trial (or mu value) at a time.
This way you can track the progress by running the following commands:

```bash
cat aggregator/log.txt | grep -i test
```

OR

```bash
cat aggregator/log.txt | grep -i test | wc -l
```

Each round logs three lines matching "test" (test loss, test accuracy, and test-set size), so the last command will return 300 once all 100 rounds have finished running.

Once you have finished running all the mu values, the files that track accuracy/loss across rounds should be in the `aggregator` directory.
Without changing their names, run the command below to generate the figures.

```bash
python figures.py
```

The figures will be generated in the `medmnist_fedprox` directory.

We include two of them below. Overall, we found that mu=0.01 was the best value.

![acc_fedprox](images/acc_all.png)
![loss_fedprox](images/loss_all.png)

17 changes: 17 additions & 0 deletions lib/python/flame/examples/medmnist_fedprox/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


17 changes: 17 additions & 0 deletions lib/python/flame/examples/medmnist_fedprox/aggregator/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


189 changes: 189 additions & 0 deletions lib/python/flame/examples/medmnist_fedprox/aggregator/main.py
@@ -0,0 +1,189 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""MedMNIST FedProx aggregator for PyTorch."""

import logging

from flame.config import Config
from flame.dataset import Dataset # Not sure why we need this.
from flame.mode.horizontal.top_aggregator import TopAggregator
import torch

from sklearn.metrics import accuracy_score
import numpy as np
from PIL import Image
import torchvision

logger = logging.getLogger(__name__)

# keep track of losses/accuracies of global model
fed_acc = []
fed_loss = []

class PathMNISTDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None, as_rgb=False):
        npz_file = np.load("pathmnist.npz")

        self.transform = transform
        self.as_rgb = as_rgb

        self.imgs = npz_file["val_images"]
        self.labels = npz_file["val_labels"]

    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, index):
        img, target = self.imgs[index], self.labels[index].astype(int)
        img = Image.fromarray(img)

        if self.as_rgb:
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, target

class CNN(torch.nn.Module):
    """CNN Class"""

    def __init__(self, num_classes):
        """Initialize."""
        super(CNN, self).__init__()
        self.num_classes = num_classes
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, 6, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(6),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(6, 16, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # two 2x2 max-pools reduce the 28x28 MedMNIST images to 7x7 feature maps
        self.fc = torch.nn.Linear(16 * 7 * 7, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

class PyTorchMedMNistAggregator(TopAggregator):
    """PyTorch MedMNist Aggregator"""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.model = None
        self.dataset: Dataset = None  # Not sure why we need this.

        self.batch_size = self.config.hyperparameters['batchSize']

        self.device = torch.device("cpu")

    def initialize(self):
        """Initialize."""
        self.model = CNN(num_classes=9)
        self.criterion = torch.nn.CrossEntropyLoss()

    def load_data(self) -> None:
        """Load a test dataset."""
        logger.info('in load_data')
        # FIX this. easy to break right now
        self._download()

        data_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        dataset = PathMNISTDataset(transform=data_transform)

        self.loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4 * torch.cuda.device_count(),
            pin_memory=True,
            drop_last=True
        )
        self.dataset_size = len(dataset)
    def _download(self) -> None:
        import requests
        r = requests.get(self.config.dataset, allow_redirects=True)
        with open('pathmnist.npz', 'wb') as f:
            f.write(r.content)

    def train(self) -> None:
        """Train a model."""
        # Implement this if training is needed in aggregator
        pass

    def evaluate(self) -> None:
        """Evaluate (test) a model."""
        self.model.eval()
        loss_lst = list()
        labels = torch.tensor([], device=self.device)
        labels_pred = torch.tensor([], device=self.device)
        with torch.no_grad():
            for data, label in self.loader:
                data, label = data.to(self.device), label.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, label.squeeze())
                loss_lst.append(loss.item())
                labels_pred = torch.cat([labels_pred, output.argmax(dim=1)], dim=0)
                labels = torch.cat([labels, label], dim=0)

        labels_pred = labels_pred.cpu().detach().numpy()
        labels = labels.cpu().detach().numpy()
        val_acc = accuracy_score(labels, labels_pred)

        # loss here not as meaningful
        val_loss = sum(loss_lst) / len(loss_lst)
        self.update_metrics({"Val Loss": val_loss, "Val Accuracy": val_acc, "Testset Size": self.dataset_size})
        logger.info(f"Test Loss: {val_loss}")
        logger.info(f"Test Accuracy: {val_acc}")
        logger.info(f"Testset Size: {self.dataset_size}")

        # record losses/accuracies
        global fed_acc, fed_loss
        fed_acc.append(val_acc)
        fed_loss.append(val_loss)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('config', nargs='?', default="./config.json")

    args = parser.parse_args()

    config = Config(args.config)

    a = PyTorchMedMNistAggregator(config)
    a.compose()
    a.run()

    # write records to files
    mu = config.optimizer.kwargs['mu']
    with open(f'acc_mu{mu}.txt', 'w') as file1:
        file1.write('\n'.join(map(str, fed_acc)))
    with open(f'loss_mu{mu}.txt', 'w') as file2:
        file2.write('\n'.join(map(str, fed_loss)))
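
Each role's entry point takes its config file as an optional positional argument, so the aggregator can also be launched by hand (normally `run.py` orchestrates this):

```bash
python main.py config.json
```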
@@ -0,0 +1,69 @@
{
    "taskid": "49d06b7526964db86cf37c70e8e0cdb6bd7aa742",
    "backend": "mqtt",
    "brokers": [
        {
            "host": "localhost",
            "sort": "mqtt"
        },
        {
            "host": "localhost:10104",
            "sort": "p2p"
        }
    ],
    "channels": [
        {
            "description": "Model update is sent from trainer to aggregator and vice-versa",
            "groupBy": {
                "type": "tag",
                "value": [
                    "default"
                ]
            },
            "name": "param-channel",
            "pair": [
                "trainer",
                "aggregator"
            ],
            "funcTags": {
                "aggregator": ["distribute", "aggregate"],
                "trainer": ["fetch", "upload"]
            }
        }
    ],
    "dataset": "",
    "dependencies": [
        "numpy >= 1.2.0"
    ],
    "hyperparameters": {
        "batchSize": 32,
        "learningRate": 0.001,
        "rounds": 100,
        "epochs": 4
    },
    "baseModel": {
        "name": "",
        "version": 1
    },
    "job": {
        "id": "622a358619ab59012eabeefb",
        "name": "mednist"
    },
    "registry": {
        "sort": "dummy",
        "uri": ""
    },
    "selector": {
        "sort": "default",
        "kwargs": {}
    },
    "optimizer": {
        "sort": "fedprox",
        "kwargs": {
            "mu": 0
        }
    },
    "maxRunTime": 300,
    "realm": "default",
    "role": "aggregator"
}
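
Per the commit message, `run.py` generates or modifies config files like the one above so that each experiment gets the proper `mu` value. A minimal sketch of that rewrite, with the helper name and path purely illustrative (the actual `run.py` is not shown in this excerpt):

```python
import json

def set_mu(config_path: str, mu: float) -> None:
    """Rewrite optimizer.kwargs.mu in a Flame config file (illustrative helper)."""
    with open(config_path) as f:
        config = json.load(f)
    config["optimizer"]["kwargs"]["mu"] = mu
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)

# e.g., set_mu("config.json", 0.01)
```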
