Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fastai E2E test #2093

Merged
merged 12 commits into from
Jul 12, 2023
39 changes: 39 additions & 0 deletions .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ jobs:
cd e2e/scikit-learn
python simulation.py


opacus:
runs-on: ubuntu-22.04
timeout-minutes: 10
Expand Down Expand Up @@ -289,3 +290,41 @@ jobs:
run: |
cd e2e/opacus
python simulation.py


fastai:
runs-on: ubuntu-22.04
timeout-minutes: 10
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.8.17
- name: Install
run: |
python -m pip install -U pip==22.3.1
python -m pip install -U setuptools==65.6.3
python -m pip install poetry==1.3.2
poetry config virtualenvs.create false
- name: Install dependencies
run: |
cd e2e/fastai
python -m poetry install
- name: Cache Datasets
uses: actions/cache@v2
with:
path: "~/.fastai"
key: fastai-datasets
- name: Download Datasets
run: |
cd e2e/fastai
python -c "from fastai.vision.all import *; untar_data(URLs.MNIST)"
- name: Run edge client test
run: |
cd e2e/fastai
./test.sh
- name: Run virtual client test
run: |
cd e2e/fastai
python simulation.py
5 changes: 5 additions & 0 deletions e2e/fastai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Flower with FastAI testing

This directory is used for testing Flower with FastAI by using a simple MNIST recognition task.

It uses the `FedAvg` strategy.
57 changes: 57 additions & 0 deletions e2e/fastai/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import warnings
from collections import OrderedDict

import flwr as fl
import numpy as np
import torch
from fastai.vision.all import *


warnings.filterwarnings("ignore", category=UserWarning)

# Download MNIST dataset
path = untar_data(URLs.MNIST)

# Load dataset
dls = ImageDataLoaders.from_folder(
path, valid_pct=0.5, train="training", valid="testing", num_workers=0
)

subset_size = 100 # Or whatever
selected_train = np.random.choice(dls.train_ds.items, subset_size, replace=False)
selected_valid = np.random.choice(dls.valid_ds.items, subset_size, replace=False)
# Swap in the subset for the whole thing (Note: this mutates dls, so re-initialize before full training!)
dls.train = dls.test_dl(selected_train, with_labels=True)
dls.valid = dls.test_dl(selected_valid, with_labels=True)

# Define model
learn = vision_learner(dls, squeezenet1_1, metrics=error_rate)


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in learn.model.state_dict().items()]

def set_parameters(self, parameters):
params_dict = zip(learn.model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
learn.model.load_state_dict(state_dict, strict=True)

def fit(self, parameters, config):
self.set_parameters(parameters)
learn.fit(1)
return self.get_parameters(config={}), len(dls.train), {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, error_rate = learn.validate()
return loss, len(dls.valid), {"accuracy": 1 - error_rate}


if __name__ == "__main__":
# Start Flower client
fl.client.start_numpy_client(
server_address="127.0.0.1:8080",
client=FlowerClient(),
)
14 changes: 14 additions & 0 deletions e2e/fastai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[build-system]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "quickstart_fastai"
version = "0.1.0"
description = "Fastai Federated Learning E2E test with Flower"
authors = ["The Flower Authors <hello@flower.dev>"]

[tool.poetry.dependencies]
python = ">=3.7,<3.10"
flwr = { path = "../../", develop = true, extras = ["simulation"] }
fastai = "^2.7.10"
7 changes: 7 additions & 0 deletions e2e/fastai/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import flwr as fl

hist = fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=3),
)
assert (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 1
14 changes: 14 additions & 0 deletions e2e/fastai/simulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import flwr as fl

from client import FlowerClient

def client_fn(cid):
_ = cid
return FlowerClient()

hist = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
)
assert (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) > 0.98
5 changes: 5 additions & 0 deletions e2e/fastai/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/

bash ../test.sh
Loading