feat: FedGFT implementation (#453)
The FedGFT federated learning algorithm uses a custom trainer and aggregator in mode/horizontal due to its modified communication protocol.
FedGFT also focuses on minimizing bias during training, so we track additional metrics beyond the standard performance ones.

An example and its results are posted in the examples/compas_fedgft folder.
GustavBaumgart committed Oct 20, 2023
1 parent 0cae4a5 commit 47c78af
Showing 20 changed files with 1,298 additions and 10 deletions.
52 changes: 52 additions & 0 deletions lib/python/examples/compas_fedgft/README.md
@@ -0,0 +1,52 @@
## FedGFT COMPAS Example

We use the COMPAS dataset to go over an example of [FedGFT](https://arxiv.org/abs/2305.09931).
FedGFT, similar to FedAvg, averages the weights of the trainers after training.
However, it also modifies the client-side loss function to ensure that the training also optimizes for fairness.
Fairness can be measured as statistical parity (represented as `SP` in the config file), equal opportunity (`EOP`), or well-calibration (`CAL`).
The algorithm uses a hyperparameter gamma in the range [0, inf) to specify how strongly bias minimization is weighted.
A value of 0 means the loss function is not modified, so there is no mathematical difference between FedAvg and FedGFT with gamma=0.
The gamma value can be specified in both `template.json` files in the `trainer` and `aggregator` folders.
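
To make this concrete, here is a minimal sketch of such a penalized local loss for the statistical-parity case. It is illustrative only: the function name, the 0/1 group encoding, and the exact penalty form (a squared, l2-style gap, in the spirit of the `reg: l2` option) are assumptions rather than Flame's actual implementation.

```python
import torch.nn.functional as F

def penalized_loss(log_probs, targets, groups, gamma):
    # standard classification loss (the model already outputs log-probabilities)
    ce = F.nll_loss(log_probs, targets)
    # statistical parity gap: difference in positive-prediction rates
    # between the two protected groups (assumed encoded as 0 and 1)
    pos_probs = log_probs.exp()[:, 1]
    gap = pos_probs[groups == 0].mean() - pos_probs[groups == 1].mean()
    # gamma=0 recovers plain FedAvg training; larger gamma emphasizes fairness
    return ce + gamma * gap.pow(2)
```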

We chose the most strict value to enforce fairness (gamma=50) used in the [Mitigating Group Bias in Federated Learning: Beyond Local Fairness](https://arxiv.org/abs/2305.09931) paper.
We also chose to measure the effects of minimization on statistical parity (`SP`).

This example runs in the conda environment, so activate the environment first by running the following command.

```bash
conda activate flame
```

If everything in the two `template.json` files reflects the hyperparameters you want, go ahead and run the following command from the `examples` directory to start the FedGFT example:

```bash
python run.py compas_fedgft
```

This will generate the different config files needed to run the example with 10 trainers and 1 aggregator.

All output that is generated during runtime will be stored in the `compas_fedgft/output` folder.
This includes all log files and data that were downloaded for the trainers/aggregator.
If you wish to log the metrics, you can set the registry to `local` in the aggregator's `template.json` (see the snippet after this paragraph).
Re-running the experiment will automatically delete the current `output` folder and recreate it with the files generated in the current run.
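
For reference, that change amounts to swapping the registry block in the aggregator's `template.json` from `"sort": "dummy"` to something like the snippet below; the `uri` value here is only a placeholder, not a path the example requires.

```json
"registry": {
    "sort": "local",
    "uri": "/tmp/flame/registry"
}
```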

To check on training progress, you can run the following command to see the global model's accuracy as it is logged each round:

```bash
cat compas_fedgft/output/aggregator/log.txt | grep -i accuracy
```

Once the model is done training, the command below should return 50 (or the number of rounds specified, if that was changed).

```bash
cat compas_fedgft/output/aggregator/log.txt | grep -i accuracy | wc -l
```

Below is a comparison of global bias, accuracy, and loss values using gamma=0 and gamma=50.

![bias](images/bias.png)

![accuracy](images/accuracy.png)

![loss](images/loss.png)
142 changes: 142 additions & 0 deletions lib/python/examples/compas_fedgft/aggregator/main.py
@@ -0,0 +1,142 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""COMPAS FedGFT aggregator for PyTorch."""

import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from flame.common.util import get_dataset_filename
from flame.config import Config
from flame.mode.horizontal.fedgft.top_aggregator import TopAggregator

logger = logging.getLogger(__name__)


class LogisticRegression(nn.Module):
# simple logistic-regression classifier over tabular features
def __init__(self, in_features=120, n_classes=2):
super().__init__()
self.classifier = nn.Sequential(
nn.Linear(in_features=in_features, out_features=n_classes),
)

def forward(self, x):
logits = self.classifier(x)
# Since NLL or KL divergence uses log-probability as input, we need to use log_softmax
probs = F.log_softmax(logits, dim=1)
return probs


class CompasDataset(torch.utils.data.Dataset):
def __init__(self, filename):
npz_file = np.load(filename)
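# the .npz archive holds three aligned arrays:
#   "dataset" (feature matrix), "target" (labels), "group" (protected attribute)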
self.dataset = npz_file["dataset"]
self.targets = npz_file["target"]
self.group = npz_file["group"]

def __len__(self):
return self.dataset.shape[0]

def __getitem__(self, index):
return self.dataset[index, :], self.targets[index], self.group[index]


class PyTorchCompasAggregator(TopAggregator):
"""PyTorch COMPAS FedGFT Aggregator"""

def __init__(self, config: Config) -> None:
self.config = config
self.model = None
self.dataset = None

self.batch_size = self.config.hyperparameters.batch_size

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def initialize(self):
"""Initialize."""
self.model = LogisticRegression(in_features=10).to(self.device)
self.criterion = torch.nn.CrossEntropyLoss()

def load_data(self) -> None:
"""Load a test dataset."""
filename = get_dataset_filename(self.config.dataset)
dataset = CompasDataset(filename=filename)

self.test_loader = torch.utils.data.DataLoader(
dataset, batch_size=self.batch_size, shuffle=True
)
self.dataset_size = len(dataset)

def pre_process(self) -> None:
"""Log and report Bias."""
curr_bias = abs(self.optimizer.get_bias())
logger.info(f"Bias: {curr_bias}")
self.update_metrics({"bias": curr_bias})

def train(self) -> None:
"""Train a model."""
# Implement this if training is needed in aggregator
pass

def evaluate(self) -> None:
"""Evaluate (test) a model."""
self.model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target, group in self.test_loader:
data, target, group = (
data.to(self.device),
target.to(self.device),
group.to(self.device),
)
output = self.model(data)
test_loss += F.nll_loss(
output, target, reduction="sum"
).item() # sum up batch loss
pred = output.argmax(
dim=1, keepdim=True
) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()

total = len(self.test_loader.dataset)
test_loss /= total
test_accuracy = correct / total

logger.info(f"Loss: {test_loss}")
logger.info(f"Accuracy: {test_accuracy}")

self.update_metrics({"test-loss": test_loss, "test-accuracy": test_accuracy})


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="")
parser.add_argument("config", nargs="?", default="./config.json")

args = parser.parse_args()

config = Config(args.config)

a = PyTorchCompasAggregator(config)
a.compose()
a.run()
84 changes: 84 additions & 0 deletions lib/python/examples/compas_fedgft/aggregator/template.json
@@ -0,0 +1,84 @@
{
"taskid": "49d06b7526964db86cf37c70e8e0cdb6bdeaa742",
"backend": "p2p",
"brokers": [
{
"host": "localhost",
"sort": "mqtt"
},
{
"host": "localhost:10104",
"sort": "p2p"
}
],
"groupAssociation": {
"param-channel": "default"
},
"channels": [
{
"description": "Model update is sent from trainer to aggregator and vice-versa",
"groupBy": {
"type": "tag",
"value": [
"default"
]
},
"name": "param-channel",
"pair": [
"trainer",
"aggregator"
],
"funcTags": {
"aggregator": [
"distribute",
"aggregate",
"aggregateBias",
"distributeBias"
],
"trainer": [
"fetch",
"upload",
"fetchBias",
"uploadBias"
]
}
}
],
"dataset": "https://github.com/GustavBaumgart/flame-datasets/raw/main/compas/test.npz",
"dependencies": [
"numpy >= 1.2.0"
],
"hyperparameters": {
"batchSize": 256,
"learningRate": 0.002,
"rounds": 50,
"epochs": 1
},
"baseModel": {
"name": "",
"version": 2
},
"job": {
"id": "336a358619ab59012eabeefb",
"name": "medmnist"
},
"registry": {
"sort": "dummy",
"uri": ""
},
"selector": {
"sort": "default",
"kwargs": {}
},
"optimizer": {
"sort": "fedgft",
"kwargs": {
"fair": "SP",
"gamma": 50,
"reg": "l2"
}
},
"maxRunTime": 300,
"realm": "default",
"role": "aggregator"
}
Binary file added lib/python/examples/compas_fedgft/images/accuracy.png
Binary file added lib/python/examples/compas_fedgft/images/bias.png
Binary file added lib/python/examples/compas_fedgft/images/loss.png
