updated feddyn implementation pytorch #392

Merged 1 commit on Apr 7, 2023
4 changes: 4 additions & 0 deletions lib/python/flame/channel.py
@@ -161,6 +161,10 @@ async def inner():

        result, _ = run_async(inner(), self._backend.loop())
        return result

    def all_ends(self):
        """Return a list of all end ids (needed in FedDyn to compute alpha values)."""
        return list(self._ends.keys())

    def ends_digest(self) -> str:
        """Compute a digest of ends."""
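For context, here is a minimal sketch of how an aggregator might use the new `all_ends()` method to derive per-trainer alpha values, as the docstring suggests. The helper name, the `channel` parameter, and the uniform weighting scheme are assumptions for illustration, not flame's actual implementation:

```python
def compute_alpha_per_end(channel, alpha: float = 0.01) -> dict:
    """Hypothetical helper (not flame's actual code)."""
    ends = channel.all_ends()      # ids of all registered ends
    num_ends = len(ends)
    # assume uniform dataset weights across trainers; FedDyn scales alpha
    # per client by the inverse of its relative dataset weight
    weights = {end: 1.0 / num_ends for end in ends}
    return {end: alpha / (num_ends * weights[end]) for end in ends}
```
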
10 changes: 5 additions & 5 deletions lib/python/flame/common/constants.py
@@ -48,9 +48,9 @@ class DeviceType(Enum):
    CPU = 1
    GPU = 2

-class TrainerState(Enum):
-    """Enum class for trainer state."""
+class TrainState(Enum):
+    """Enum class for train state."""

-    PRE_TRAIN = 'pre_train'
-    DURING_TRAIN = 'during_train'
-    POST_TRAIN = 'post_train'
+    PRE = 'pre'
+    DURING = 'during'
+    POST = 'post'
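To illustrate the renamed enum in use: a FedDyn trainer needs model snapshots before and after local training, so it can tag them with these states. The class and method names below are placeholders, not flame's actual API:

```python
from flame.common.constants import TrainState

class SketchTrainer:
    """Hypothetical trainer skeleton; save_weights/train_epoch are placeholders."""

    def train(self):
        self.save_weights(TrainState.PRE)    # snapshot before local updates
        for epoch in range(self.epochs):
            self.train_epoch(epoch)          # one local training pass
        self.save_weights(TrainState.POST)   # snapshot after local updates
```
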
48 changes: 48 additions & 0 deletions lib/python/flame/examples/medmnist_feddyn/README.md
@@ -0,0 +1,48 @@
## FedDyn MedMNIST Example

We use the PathMNIST dataset from [MedMNIST](https://medmnist.com/) to walk through an example of FedDyn (alpha=0.01).
The alpha value can be specified in both `template.json` files, in the `trainer` and `aggregator` folders; the relevant block is shown below.
We chose the value most commonly used in the [Federated Learning Based on Dynamic Regularization](https://arxiv.org/abs/2111.04263) paper, along with the same `weight_decay` value (0.001). The learning rate was set to 0.001, because larger values did not allow the models to train well.
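
Here is that block as it appears in `aggregator/template.json` (the trainer's `template.json` is expected to carry the same values):

```json
"optimizer": {
    "sort": "feddyn",
    "kwargs": {
        "alpha": 0.01,
        "weight_decay": 0.001
    }
}
```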

Since the `weight_decay` value is passed as a hyperparameter to the feddyn optimizer through the config file, we recommend setting the `weight_decay` of `self.optimizer` in `trainer/main.py` to 0.0, as shown below.

```python
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=0.0)
```

This example is run within the `flame` conda environment, so activate the environment first.
All commands below are run from the `medmnist_feddyn` directory.

```bash
conda activate flame
```

If everything in the two `template.json` files reflects the hyperparameters you want, go ahead and run the following command to launch the entire example:

```bash
python run.py
```

This will generate the different config files needed to run the example with 10 trainers and 1 aggregator.

All output is stored in the `output` folder that is generated at runtime.
This includes all log files and the data downloaded for the trainers/aggregator.
The aggregator folder also includes the accuracy/loss values computed on a global test set.
Deleting this folder does not affect your ability to re-run the example (in fact, if you re-run the example without deleting it, the existing `output` folder is deleted first).
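
For example, after a run you can inspect the aggregator's output; the accuracy/loss file names below assume alpha=0.01, matching what the aggregator's `main.py` writes:

```bash
ls output/aggregator
# expected to include log.txt, acc_alpha0.01.txt, and loss_alpha0.01.txt
```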

To check training progress, you can run the following command to see the global model's accuracy so far:

```bash
cat output/aggregator/log.txt | grep -i accuracy
```

Once training is finished, the command below should return 100 (or the specified number of rounds, if that value was changed):

```bash
cat output/aggregator/log.txt | grep -i accuracy | wc -l
```

We compared global test accuracy using FedDyn with alpha=0.01 against FedProx with mu=0.01 and mu=0.0 (FedProx with mu=0.0 is equivalent to FedAvg).

![acc_feddyn](images/accuracy.png)

17 changes: 17 additions & 0 deletions lib/python/flame/examples/medmnist_feddyn/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


17 changes: 17 additions & 0 deletions lib/python/flame/examples/medmnist_feddyn/aggregator/__init__.py
@@ -0,0 +1,17 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0


185 changes: 185 additions & 0 deletions lib/python/flame/examples/medmnist_feddyn/aggregator/main.py
@@ -0,0 +1,185 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""MedMNIST FedProx aggregator for PyTorch."""

import logging

from flame.common.util import get_dataset_filename
from flame.config import Config
from flame.dataset import Dataset
from flame.mode.horizontal.feddyn.top_aggregator import TopAggregator
import torch

from sklearn.metrics import accuracy_score
import numpy as np
from PIL import Image
import torchvision

logger = logging.getLogger(__name__)

# keep track of losses/accuracies of global model
fed_acc = []
fed_loss = []


class PathMNISTDataset(torch.utils.data.Dataset):

    def __init__(self, filename, transform=None, as_rgb=False):
        npz_file = np.load(filename)

        self.transform = transform
        self.as_rgb = as_rgb

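        # the aggregator evaluates on the validation split stored in the npz file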
        self.imgs = npz_file["val_images"]
        self.labels = npz_file["val_labels"]

    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, index):
        img, target = self.imgs[index], self.labels[index].astype(int)
        img = Image.fromarray(img)

        if self.as_rgb:
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, target


class CNN(torch.nn.Module):
    """CNN Class"""

    def __init__(self, num_classes):
        """Initialize."""
        super(CNN, self).__init__()
        self.num_classes = num_classes
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, 6, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(6), torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(6, 16, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(16), torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
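        # two 2x2 max-pools reduce the 28x28 inputs to 7x7, so the flattened feature size is 16 * 7 * 7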
        self.fc = torch.nn.Linear(16 * 7 * 7, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class PyTorchMedMNistAggregator(TopAggregator):
    """PyTorch MedMNist Aggregator"""

    def __init__(self, config: Config) -> None:
        self.config = config
        self.model = None
        self.dataset: Dataset = None  # Not sure why we need this.

        self.batch_size = self.config.hyperparameters.batch_size

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def initialize(self):
        """Initialize."""
        self.model = CNN(num_classes=9).to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss()

    def load_data(self) -> None:
        """Load a test dataset."""
        filename = get_dataset_filename(self.config.dataset)

        data_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406),
                                             (0.229, 0.224, 0.225))
        ])

        dataset = PathMNISTDataset(filename=filename, transform=data_transform)

        self.loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4 * torch.cuda.device_count(),
            pin_memory=True,
            drop_last=True
        )
        self.dataset_size = len(dataset)

    def train(self) -> None:
        """Train a model."""
        # Implement this if training is needed in aggregator.
        pass

    def evaluate(self) -> None:
        """Evaluate (test) a model."""
        self.model.eval()
        loss_lst = list()
        labels = torch.tensor([], device=self.device)
        labels_pred = torch.tensor([], device=self.device)
        with torch.no_grad():
            for data, label in self.loader:
                data, label = data.to(self.device), label.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, label.squeeze())
                loss_lst.append(loss.item())
                labels_pred = torch.cat([labels_pred, output.argmax(dim=1)], dim=0)
                labels = torch.cat([labels, label], dim=0)

        labels_pred = labels_pred.cpu().detach().numpy()
        labels = labels.cpu().detach().numpy()
        val_acc = accuracy_score(labels, labels_pred)

        # the loss here is less meaningful than accuracy; reported for reference only
        val_loss = sum(loss_lst) / len(loss_lst)
        self.update_metrics({
            "Val Loss": val_loss,
            "Val Accuracy": val_acc,
            "Testset Size": self.dataset_size
        })
        logger.info(f"Test Loss: {val_loss}")
        logger.info(f"Test Accuracy: {val_acc}")
        logger.info(f"Testset Size: {self.dataset_size}")

        # record losses/accuracies
        global fed_acc, fed_loss
        fed_acc.append(val_acc)
        fed_loss.append(val_loss)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('config', nargs='?', default="./config.json")

    args = parser.parse_args()

    config = Config(args.config)

    a = PyTorchMedMNistAggregator(config)
    a.compose()
    a.run()

    # write accuracy/loss records to files
    alpha = config.optimizer.kwargs['alpha']
    with open(f'acc_alpha{alpha}.txt', 'w') as acc_file:
        acc_file.write('\n'.join(map(str, fed_acc)))
    with open(f'loss_alpha{alpha}.txt', 'w') as loss_file:
        loss_file.write('\n'.join(map(str, fed_loss)))
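As the argparse block shows, the aggregator can also be launched on its own with an optional config path (normally `run.py` generates the configs and starts every role):

```bash
# the positional argument defaults to ./config.json; the path here is illustrative
python main.py config.json
```
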
81 changes: 81 additions & 0 deletions lib/python/flame/examples/medmnist_feddyn/aggregator/template.json
@@ -0,0 +1,81 @@
{
    "taskid": "49d06b7526964db86cf37c70e8e0cdb6bdeaa742",
    "backend": "mqtt",
    "brokers": [
        {
            "host": "localhost",
            "sort": "mqtt"
        },
        {
            "host": "localhost:10104",
            "sort": "p2p"
        }
    ],
    "groupAssociation": {
        "param-channel": "default"
    },
    "channels": [
        {
            "description": "Model update is sent from trainer to aggregator and vice-versa",
            "groupBy": {
                "type": "tag",
                "value": [
                    "default"
                ]
            },
            "name": "param-channel",
            "pair": [
                "trainer",
                "aggregator"
            ],
            "funcTags": {
                "aggregator": [
                    "distribute",
                    "aggregate",
                    "getDatasetSize"
                ],
                "trainer": [
                    "fetch",
                    "upload",
                    "uploadDatasetSize"
                ]
            }
        }
    ],
    "dataset": "https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/all_val.npz",
    "dependencies": [
        "numpy >= 1.2.0"
    ],
    "hyperparameters": {
        "batchSize": 50,
        "learningRate": 0.001,
        "rounds": 100,
        "epochs": 4
    },
    "baseModel": {
        "name": "",
        "version": 2
    },
    "job": {
        "id": "336a358619ab59012eabeefb",
        "name": "medmnist"
    },
    "registry": {
        "sort": "dummy",
        "uri": ""
    },
    "selector": {
        "sort": "default",
        "kwargs": {}
    },
    "optimizer": {
        "sort": "feddyn",
        "kwargs": {
            "alpha": 0.01,
            "weight_decay": 0.001
        }
    },
    "maxRunTime": 300,
    "realm": "default",
    "role": "aggregator"
}