Skip to content

Commit

Permalink
Migrate advanced-pytorch example to use FDS (#2805)
Browse files Browse the repository at this point in the history
Co-authored-by: jafermarq <javier@flower.dev>
  • Loading branch information
adam-narozniak and jafermarq authored Jan 17, 2024
1 parent 7506f35 commit b00c77b
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 84 deletions.
9 changes: 5 additions & 4 deletions examples/advanced-pytorch/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Advanced Flower Example (PyTorch)

This example demonstrates an advanced federated learning setup using Flower with PyTorch. It differs from the quickstart example in the following ways:
This example demonstrates an advanced federated learning setup using Flower with PyTorch. This example uses [Flower Datasets](https://flower.dev/docs/datasets/) and it differs from the quickstart example in the following ways:

- 10 clients (instead of just 2)
- Each client holds a local dataset of 5000 training examples and 1000 test examples (note that using the `run.sh` script will only select 10 data samples by default, as the `--toy` argument is set).
Expand Down Expand Up @@ -59,12 +59,13 @@ pip install -r requirements.txt

The included `run.sh` will start the Flower server (using `server.py`),
sleep for 2 seconds to ensure that the server is up, and then start 10 Flower clients (using `client.py`) with only a small subset of the data (in order to run on any machine),
but this can be changed by removing the `--toy True` argument in the script. You can simply start everything in a terminal as follows:
but this can be changed by removing the `--toy` argument in the script. You can simply start everything in a terminal as follows:

```shell
poetry run ./run.sh
# After activating your environment
./run.sh
```

The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill).

You can also manually run `poetry run python3 server.py` and `poetry run python3 client.py` for as many clients as you want but you have to make sure that each command is ran in a different terminal window (or a different computer on the network).
You can also manually run `python3 server.py` and `python3 client.py --client-id <ID>` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network).
41 changes: 18 additions & 23 deletions examples/advanced-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
import argparse
from collections import OrderedDict
import warnings
import datasets

warnings.filterwarnings("ignore")


class CifarClient(fl.client.NumPyClient):
def __init__(
self,
trainset: torchvision.datasets,
testset: torchvision.datasets,
device: str,
trainset: datasets.Dataset,
testset: datasets.Dataset,
device: torch.device,
validation_split: int = 0.1,
):
self.device = device
Expand All @@ -41,17 +42,14 @@ def fit(self, parameters, config):
batch_size: int = config["batch_size"]
epochs: int = config["local_epochs"]

n_valset = int(len(self.trainset) * self.validation_split)
train_valid = self.trainset.train_test_split(self.validation_split)
trainset = train_valid["train"]
valset = train_valid["test"]

valset = torch.utils.data.Subset(self.trainset, range(0, n_valset))
trainset = torch.utils.data.Subset(
self.trainset, range(n_valset, len(self.trainset))
)
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(valset, batch_size=batch_size)

trainLoader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
valLoader = DataLoader(valset, batch_size=batch_size)

results = utils.train(model, trainLoader, valLoader, epochs, self.device)
results = utils.train(model, train_loader, val_loader, epochs, self.device)

parameters_prime = utils.get_model_params(model)
num_examples_train = len(trainset)
Expand All @@ -73,13 +71,13 @@ def evaluate(self, parameters, config):
return float(loss), len(self.testset), {"accuracy": float(accuracy)}


def client_dry_run(device: str = "cpu"):
def client_dry_run(device: torch.device = "cpu"):
"""Weak tests to check whether all client methods are working as expected."""

model = utils.load_efficientnet(classes=10)
trainset, testset = utils.load_partition(0)
trainset = torch.utils.data.Subset(trainset, range(10))
testset = torch.utils.data.Subset(testset, range(10))
trainset = trainset.select(range(10))
testset = testset.select(range(10))
client = CifarClient(trainset, testset, device)
client.fit(
utils.get_model_params(model),
Expand All @@ -102,7 +100,7 @@ def main() -> None:
help="Do a dry-run to check the client",
)
parser.add_argument(
"--partition",
"--client-id",
type=int,
default=0,
choices=range(0, 10),
Expand All @@ -112,9 +110,7 @@ def main() -> None:
)
parser.add_argument(
"--toy",
type=bool,
default=False,
required=False,
action='store_true',
help="Set to true to quicky run the client using only 10 datasamples. \
Useful for testing purposes. Default: False",
)
Expand All @@ -136,12 +132,11 @@ def main() -> None:
client_dry_run(device)
else:
# Load a subset of CIFAR-10 to simulate the local data partition
trainset, testset = utils.load_partition(args.partition)
trainset, testset = utils.load_partition(args.client_id)

if args.toy:
trainset = torch.utils.data.Subset(trainset, range(10))
testset = torch.utils.data.Subset(testset, range(10))

trainset = trainset.select(range(10))
testset = testset.select(range(10))
# Start Flower client
client = CifarClient(trainset, testset, device)

Expand Down
1 change: 1 addition & 0 deletions examples/advanced-pytorch/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ authors = [
[tool.poetry.dependencies]
python = ">=3.8,<3.11"
flwr = ">=1.0,<2.0"
flwr-datasets = { extras = ["vision"], version = ">=0.0.2,<1.0.0" }
torch = "1.13.1"
torchvision = "0.14.1"
validators = "0.18.2"
1 change: 1 addition & 0 deletions examples/advanced-pytorch/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
flwr>=1.0, <2.0
flwr-datasets[vision]>=0.0.2, <1.0.0
torch==1.13.1
torchvision==0.14.1
validators==0.18.2
9 changes: 3 additions & 6 deletions examples/advanced-pytorch/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,17 @@
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

# Download the CIFAR-10 dataset
python -c "from torchvision.datasets import CIFAR10; CIFAR10('./dataset', download=True)"

# Download the EfficientNetB0 model
python -c "import torch; torch.hub.load( \
'NVIDIA/DeepLearningExamples:torchhub', \
'nvidia_efficientnet_b0', pretrained=True)"

python server.py &
sleep 3 # Sleep for 3s to give the server enough time to start
python server.py --toy &
sleep 10 # Sleep for 10s to give the server enough time to start and dowload the dataset

for i in `seq 0 9`; do
echo "Starting client $i"
python client.py --partition=${i} --toy True &
python client.py --client-id=${i} --toy &
done

# Enable CTRL+C to stop all background processes
Expand Down
21 changes: 8 additions & 13 deletions examples/advanced-pytorch/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import warnings

from flwr_datasets import FederatedDataset

warnings.filterwarnings("ignore")


Expand Down Expand Up @@ -39,18 +41,13 @@ def evaluate_config(server_round: int):
def get_evaluate_fn(model: torch.nn.Module, toy: bool):
"""Return an evaluation function for server-side evaluation."""

# Load data and model here to avoid the overhead of doing it in `evaluate` itself
trainset, _, _ = utils.load_data()

n_train = len(trainset)
# Load data here to avoid the overhead of doing it in `evaluate` itself
centralized_data = utils.load_centralized_data()
if toy:
# use only 10 samples as validation set
valset = torch.utils.data.Subset(trainset, range(n_train - 10, n_train))
else:
# Use the last 5k training examples as a validation set
valset = torch.utils.data.Subset(trainset, range(n_train - 5000, n_train))
centralized_data = centralized_data.select(range(10))

valLoader = DataLoader(valset, batch_size=16)
val_loader = DataLoader(centralized_data, batch_size=16)

# The `evaluate` function will be called after every round
def evaluate(
Expand All @@ -63,7 +60,7 @@ def evaluate(
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)

loss, accuracy = utils.test(model, valLoader)
loss, accuracy = utils.test(model, val_loader)
return loss, {"accuracy": accuracy}

return evaluate
Expand All @@ -79,9 +76,7 @@ def main():
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--toy",
type=bool,
default=False,
required=False,
action='store_true',
help="Set to true to use only 10 datasamples for validation. \
Useful for testing purposes. Default: False",
)
Expand Down
77 changes: 39 additions & 38 deletions examples/advanced-pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,45 @@
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
from torch.utils.data import DataLoader

import warnings

warnings.filterwarnings("ignore")

# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from flwr_datasets import FederatedDataset

warnings.filterwarnings("ignore")

def load_data():
"""Load CIFAR-10 (training and test set)."""
transform = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)

trainset = CIFAR10("./dataset", train=True, download=True, transform=transform)
testset = CIFAR10("./dataset", train=False, download=True, transform=transform)
def load_partition(node_id, toy: bool = False):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
partition = fds.load_partition(node_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
partition_train_test = partition_train_test.with_transform(apply_transforms)
return partition_train_test["train"], partition_train_test["test"]

num_examples = {"trainset": len(trainset), "testset": len(testset)}
return trainset, testset, num_examples

def load_centralized_data():
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 10})
centralized_data = fds.load_full("test")
centralized_data = centralized_data.with_transform(apply_transforms)
return centralized_data

def load_partition(idx: int):
"""Load 1/10th of the training and test data to simulate a partition."""
assert idx in range(10)
trainset, testset, num_examples = load_data()
n_train = int(num_examples["trainset"] / 10)
n_test = int(num_examples["testset"] / 10)

train_parition = torch.utils.data.Subset(
trainset, range(idx * n_train, (idx + 1) * n_train)
)
test_parition = torch.utils.data.Subset(
testset, range(idx * n_test, (idx + 1) * n_test)
)
return (train_parition, test_parition)
def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
pytorch_transforms = Compose([
Resize(256),
CenterCrop(224),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch


def train(net, trainloader, valloader, epochs, device: str = "cpu"):
def train(net, trainloader, valloader, epochs,
device: torch.device = torch.device("cpu")):
"""Train the network on the training set."""
print("Starting training...")
net.to(device) # move model to GPU if available
Expand All @@ -53,7 +49,8 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"):
)
net.train()
for _ in range(epochs):
for images, labels in trainloader:
for batch in trainloader:
images, labels = batch["img"], batch["label"]
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
loss = criterion(net(images), labels)
Expand All @@ -74,15 +71,17 @@ def train(net, trainloader, valloader, epochs, device: str = "cpu"):
return results


def test(net, testloader, steps: int = None, device: str = "cpu"):
def test(net, testloader, steps: int = None,
device: torch.device = torch.device("cpu")):
"""Validate the network on the entire test set."""
print("Starting evalutation...")
net.to(device) # move model to GPU if available
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
net.eval()
with torch.no_grad():
for batch_idx, (images, labels) in enumerate(testloader):
for batch_idx, batch in enumerate(testloader):
images, labels = batch["img"], batch["label"]
images, labels = images.to(device), labels.to(device)
outputs = net(images)
loss += criterion(outputs, labels).item()
Expand All @@ -109,12 +108,14 @@ def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int =
entrypoint: EfficientNet model to download.
For supported entrypoints, please refer
https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/
classes: Number of classes in final classifying layer. Leave as None to get the downloaded
classes: Number of classes in final classifying layer. Leave as None to get
the downloaded
model untouched.
Returns:
EfficientNet Model
Note: One alternative implementation can be found at https://github.com/lukemelas/EfficientNet-PyTorch
Note: One alternative implementation can be found at
https://github.com/lukemelas/EfficientNet-PyTorch
"""
efficientnet = torch.hub.load(
"NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True
Expand Down

0 comments on commit b00c77b

Please sign in to comment.