Skip to content

Commit

Permalink
Update PyTorch E2E test (adap#2072)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored and alessiomora committed Aug 30, 2023
1 parent f32e1bc commit 03b0e33
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ env:
jobs:
pytorch:
runs-on: ubuntu-22.04
timeout-minutes: 10
steps:
- uses: actions/checkout@v3
- name: Set up Python
Expand All @@ -39,7 +40,7 @@ jobs:
run: |
cd e2e/pytorch
python -c "from torchvision.datasets import CIFAR10; CIFAR10('./data', download=True)"
- name: Run tests
- name: Run edge client test
run: |
cd e2e/pytorch
./test.sh
Expand Down
27 changes: 14 additions & 13 deletions e2e/pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,25 @@ class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)

def fit(self, parameters, config):
self.set_parameters(parameters)
set_parameters(net, parameters)
train(net, trainloader, epochs=1)
return self.get_parameters(config={}), len(trainloader.dataset), {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
set_parameters(net, parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}


# Start Flower client
fl.client.start_numpy_client(
server_address="127.0.0.1:8080",
client=FlowerClient(),
)
def set_parameters(model, parameters):
params_dict = zip(model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
model.load_state_dict(state_dict, strict=True)
return

if __name__ == "__main__":
# Start Flower client
fl.client.start_numpy_client(
server_address="127.0.0.1:8080",
client=FlowerClient(),
)

0 comments on commit 03b0e33

Please sign in to comment.