Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

medmnist example update #312

Merged
merged 1 commit into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/python/flame/examples/medmnist/aggregator/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"backend": "mqtt",
"brokers": [
{
"host": "flame-mosquitto",
"host": "broker.hivemq.com",
"sort": "mqtt"
}
],
Expand All @@ -27,7 +27,7 @@
}
}
],
"dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
"dataset": "https://raw.github.com/GaoxiangLuo/flame-datasets/main/site1.npz",
"dependencies": [
"numpy >= 1.2.0"
],
Expand Down
2 changes: 1 addition & 1 deletion lib/python/flame/examples/medmnist/aggregator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,4 @@ def evaluate(self) -> None:

a = PyTorchMedMNistAggregator(config)
a.compose()
a.run()
a.run()
4 changes: 2 additions & 2 deletions lib/python/flame/examples/medmnist/trainer/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"backend": "mqtt",
"brokers": [
{
"host": "flame-mosquitto",
"host": "broker.hivemq.com",
"sort": "mqtt"
}
],
Expand All @@ -27,7 +27,7 @@
}
}
],
"dataset": "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
"dataset": "https://raw.github.com/GaoxiangLuo/flame-datasets/main/site1.npz",
"dependencies": [
"numpy >= 1.2.0"
],
Expand Down
63 changes: 53 additions & 10 deletions lib/python/flame/examples/medmnist/trainer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@

import logging
from flame.common.util import install_packages
install_packages(['scikit-learn', 'medmnist'])
install_packages(['scikit-learn'])

from flame.config import Config
from flame.mode.horizontal.trainer import Trainer
import torch
import torchvision
import medmnist
from medmnist import INFO
import numpy as np
from PIL import Image
from sklearn.metrics import accuracy_score

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,6 +55,40 @@ def forward(self, x):
x = self.fc(x)
return x

class PathMNISTDataset(torch.utils.data.Dataset):
def __init__(self, split, transform=None, as_rgb=False):
npz_file = np.load("pathmnist.npz")
self.split = split
self.transform = transform
self.as_rgb = as_rgb

if self.split == 'train':
self.imgs = npz_file['train_images']
self.labels = npz_file['train_labels']
elif self.split == 'val':
self.imgs = npz_file['val_images']
self.labels = npz_file['val_labels']
elif self.split == 'test':
self.imgs = npz_file['test_images']
self.labels = npz_file['test_labels']
else:
raise ValueError

def __len__(self):
return self.imgs.shape[0]

def __getitem__(self, index):
img, target = self.imgs[index], self.labels[index].astype(int)
img = Image.fromarray(img)

if self.as_rgb:
img = img.convert('RGB')

if self.transform is not None:
img = self.transform(img)

return img, target

class PyTorchMedMNistTrainer(Trainer):
"""PyTorch MedMNist Trainer"""

Expand All @@ -70,6 +104,8 @@ def __init__(self, config: Config) -> None:

self.epochs = self.config.hyperparameters['epochs']
self.batch_size = self.config.hyperparameters['batchSize']
self._round = 1
self._rounds = self.config.hyperparameters['rounds']

def initialize(self) -> None:
"""Initialize role."""
Expand All @@ -86,15 +122,16 @@ def load_data(self) -> None:
"MedMNIST Classification Decathlon: A Lightweight AutoML Benchmark for Medical Image Analysis".
Dataset Repo: https://github.com/MedMNIST/MedMNIST
"""
info = INFO["pathmnist"]
DataClass = getattr(medmnist, info['python_class'])

self._download()

data_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
train_dataset = DataClass(split='train', transform=data_transform, download=True)
val_dataset = DataClass(split='val', transform=data_transform, download=True)

train_dataset = PathMNISTDataset(split='train', transform=data_transform)
val_dataset = PathMNISTDataset(split='val', transform=data_transform)

self.train_loader = torch.utils.data.DataLoader(
train_dataset,
Expand All @@ -115,6 +152,11 @@ def load_data(self) -> None:

self.dataset_size = len(train_dataset)

def _download(self) -> None:
import requests
r = requests.get(self.config.dataset, allow_redirects=True)
open('pathmnist.npz', 'wb').write(r.content)

def train(self) -> None:
"""Train a model."""
self.model.load_state_dict(self.weights)
Expand All @@ -134,7 +176,7 @@ def train(self) -> None:

train_loss = sum(loss_lst) / len(loss_lst)
self.update_metrics({"Train Loss": train_loss})
logger.info(f"Epoch: {epoch+1}, Train Loss: {train_loss}")
# logger.info(f"Epoch: {epoch+1}, Train Loss: {train_loss}")

def evaluate(self) -> None:
"""Evaluate a model."""
Expand All @@ -156,9 +198,11 @@ def evaluate(self) -> None:
val_acc = accuracy_score(labels, labels_pred)

val_loss = sum(loss_lst) / len(loss_lst)
self.update_metrics({"Val Loss": val_loss, "Val Accuracy": val_acc})
self.update_metrics({"Val Loss": val_loss, "Val Accuracy": val_acc, "Testset Size": self.dataset_size})
logger.info(f"Test Loss: {val_loss}")
logger.info(f"Test Accuracy: {val_acc}")
logger.info(f"Testset Size: {self.dataset_size}")


if __name__ == "__main__":
import argparse
Expand All @@ -173,4 +217,3 @@ def evaluate(self) -> None:
t = PyTorchMedMNistTrainer(config)
t.compose()
t.run()