FedDyn algorithm correction #401

Merged
1 commit merged on Apr 26, 2023

6 changes: 3 additions & 3 deletions lib/python/examples/medmnist_feddyn/aggregator/main.py
@@ -13,7 +13,7 @@
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""MedMNIST FedProx aggregator for PyTorch."""
"""MedMNIST FedDyn aggregator for PyTorch."""

import logging

@@ -131,7 +131,7 @@ def train(self) -> None:
pass

def evaluate(self) -> None:
"""Evaluate (test) a model."""
"""Evaluate a model."""
self.model.eval()
loss_lst = list()
labels = torch.tensor([],device=self.device)
@@ -140,7 +140,7 @@ def evaluate(self) -> None:
for data, label in self.loader:
data, label = data.to(self.device), label.to(self.device)
output = self.model(data)
loss = self.criterion(output, label.squeeze())
loss = self.criterion(output, label.reshape(-1).long())
loss_lst.append(loss.item())
labels_pred = torch.cat([labels_pred, output.argmax(dim=1)], dim=0)
labels = torch.cat([labels, label], dim=0)
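
A note on the loss-target change in evaluate() above: the MedMNIST labels arrive as (N, 1) integer tensors, and CrossEntropyLoss expects a 1-d LongTensor of class indices. label.squeeze() collapses to a 0-d tensor whenever a batch holds a single sample, which can make the loss call fail, while label.reshape(-1).long() always yields a 1-d target with the right dtype. A minimal sketch with illustrative shapes and values (not taken from the dataset):

import torch

criterion = torch.nn.CrossEntropyLoss()
output = torch.randn(1, 9)                      # batch of one sample, 9 classes
label = torch.tensor([[3]], dtype=torch.uint8)  # (N, 1) integer label

# label.squeeze() would give a 0-d uint8 tensor here, which the loss can reject.
# Reshaping to 1-d and casting to long works for any batch size:
loss = criterion(output, label.reshape(-1).long())
print(loss.item())
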
Binary file modified lib/python/examples/medmnist_feddyn/images/accuracy.png
21 changes: 8 additions & 13 deletions lib/python/examples/medmnist_feddyn/trainer/main.py
@@ -13,10 +13,7 @@
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""MedMNIST FedProx trainer for PyTorch Using Proximal Term."""



"""MedMNIST FedDyn trainer for PyTorch."""

import logging

@@ -72,8 +69,6 @@ def __init__(self, split, filename, transform=None, as_rgb=False):
elif self.split == 'test':
self.imgs = npz_file['test_images']
self.labels = npz_file['test_labels']
else:
raise ValueError

def __len__(self):
return self.imgs.shape[0]
@@ -99,22 +94,23 @@ def __init__(self, config: Config) -> None:
self.dataset_size = 0

self.model = None
self.device = torch.device("cpu")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.train_loader = None
self.val_loader = None

self.epochs = self.config.hyperparameters.epochs
self.batch_size = self.config.hyperparameters.batch_size
self.lr = self.config.hyperparameters.learning_rate
self._round = 1
self._rounds = self.config.hyperparameters.rounds

def initialize(self) -> None:
"""Initialize role."""

self.model = CNN(num_classes=9).to(self.device)
# ensure that weight_decay = 0 for FedDyn
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=0.0)
# ensure that weight_decay = 0 for FedDyn if this parameter is specified in config file
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=0.0)
self.criterion = torch.nn.CrossEntropyLoss()

def load_data(self) -> None:
@@ -170,10 +166,9 @@ def train(self) -> None:
self.optimizer.zero_grad()
output = self.model(data)

# proximal term included in loss
loss = self.criterion(output, label.squeeze()) + self.regularizer.get_term(curr_model = self.model, prev_model = prev_model)
# regularizer term included in loss
loss = self.criterion(output, label.reshape(-1).long()) + self.regularizer.get_term(curr_model = self.model, prev_model = prev_model)

# back to normal stuff
loss_lst.append(loss.item())
loss.backward()
self.optimizer.step()
@@ -191,7 +186,7 @@ def evaluate(self) -> None:
for data, label in self.val_loader:
data, label = data.to(self.device), label.to(self.device)
output = self.model(data)
loss = self.criterion(output, label.squeeze())
loss = self.criterion(output, label.reshape(-1).long())
loss_lst.append(loss.item())
labels_pred = torch.cat([labels_pred, output.argmax(dim=1)], dim=0)
labels = torch.cat([labels, label], dim=0)
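
For context on the regularizer term now added to the training loss, FedDyn augments the task loss with a quadratic penalty toward the weights received from the aggregator and a linear correction built from the trainer's gradient history. The sketch below mirrors that formula over flattened parameter vectors; it is an illustration of the idea, not the internals of the repo's FedDynRegularizer.get_term(), and prev_grad and alpha are assumed inputs.

import torch

def feddyn_penalty(curr_model, prev_model, prev_grad, alpha):
    # w: current local weights; w_t: weights received this round; prev_grad:
    # linear correction accumulated by the regularizer across rounds.
    w = torch.cat([p.reshape(-1) for p in curr_model.parameters()])
    w_t = torch.cat([p.reshape(-1) for p in prev_model.parameters()]).detach()
    quadratic = 0.5 * alpha * torch.sum((w - w_t) ** 2)
    linear = torch.sum(prev_grad.detach() * w)
    return quadratic - linear

# loss = criterion(output, target) + feddyn_penalty(model, prev_model, prev_grad, alpha)
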
69 changes: 67 additions & 2 deletions lib/python/flame/mode/horizontal/feddyn/top_aggregator.py
@@ -16,10 +16,13 @@
"""FedDyn horizontal FL top level aggregator."""

import logging
from copy import deepcopy
import time

from flame.common.util import (MLFramework, get_ml_framework_in_use,
weights_to_device)
weights_to_device, weights_to_model_device)
from flame.common.constants import (DeviceType, TrainState)
from flame.optimizer.train_result import TrainResult
from flame.mode.composer import Composer
from flame.mode.message import MessageType
from flame.mode.tasklet import Loop, Tasklet
@@ -31,6 +34,8 @@
TAG_AGGREGATE = 'aggregate'
TAG_GET_DATATSET_SIZE = 'getDatasetSize'

PROP_ROUND_END_TIME = "round_end_time"

class TopAggregator(BaseTopAggregator):
"""FedDyn Top level Aggregator implements an ML aggregation role."""

@@ -52,6 +57,62 @@ def get(self, tag: str) -> None:
self._aggregate_weights(tag)
elif tag == TAG_GET_DATATSET_SIZE:
self.get_dataset_size(tag)

def _aggregate_weights(self, tag: str) -> None:
channel = self.cm.get_by_tag(tag)
if not channel:
return

total = 0

# receive local model parameters from trainers
for msg, metadata in channel.recv_fifo(channel.ends()):
end, timestamp = metadata
if not msg:
logger.debug(f"No data from {end}; skipping it")
continue

logger.debug(f"received data from {end}")
channel.set_end_property(end, PROP_ROUND_END_TIME, (round, timestamp))

if MessageType.WEIGHTS in msg:
weights = weights_to_model_device(msg[MessageType.WEIGHTS], self.model)

if MessageType.DATASET_SIZE in msg:
count = msg[MessageType.DATASET_SIZE]

if MessageType.DATASAMPLER_METADATA in msg:
self.datasampler.handle_metadata_from_trainer(
msg[MessageType.DATASAMPLER_METADATA],
end,
channel,
)

logger.debug(f"{end}'s parameters trained with {count} samples")

if weights is not None and count > 0:
total += count
tres = TrainResult(weights, count)
# save training result from trainer in a disk cache
self.cache[end] = tres

# optimizer conducts optimization (in this case, aggregation)
global_weights = self.optimizer.do(
deepcopy(self.cld_weights),
self.cache,
total=total,
num_trainers=len(channel.ends()),
)
if global_weights is None:
logger.debug("failed model aggregation")
time.sleep(1)
return

# set global weights
self.weights = global_weights

# update model with global weights
self._update_model()

def get_dataset_size(self, tag: str) -> None:
logger.debug("calling get_dataset_size")
@@ -96,12 +157,16 @@ def _distribute_weights(self, tag: str) -> None:
weight_dict = {end:(self.dataset_sizes[end]/total_samples) * num_trainers for end in self.dataset_sizes}

logger.debug(f"weight_dict: {weight_dict}")

self.cld_weights = self.optimizer.cld_model
if self.cld_weights == None:
self.cld_weights = self.weights

# send out global model parameters to trainers
for end in channel.ends():
logger.debug(f"sending weights to {end}")
channel.send(end, {
MessageType.WEIGHTS: weights_to_device(self.weights, DeviceType.CPU),
MessageType.WEIGHTS: weights_to_device(self.cld_weights, DeviceType.CPU),
MessageType.ROUND: self._round,
MessageType.ALPHA_ADPT: self.optimizer.alpha / weight_dict.get(end, 1)
})
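
The ALPHA_ADPT value sent above scales FedDyn's alpha per trainer by its relative share of the data, so ends holding more samples receive a smaller coefficient. A small worked example with made-up end names and sizes:

alpha = 0.01                                                # illustrative value
dataset_sizes = {"end-a": 600, "end-b": 200, "end-c": 200}  # made-up ends
total_samples = sum(dataset_sizes.values())
num_trainers = len(dataset_sizes)

weight_dict = {end: (n / total_samples) * num_trainers for end, n in dataset_sizes.items()}
alpha_adpt = {end: alpha / w for end, w in weight_dict.items()}
# weight_dict -> {'end-a': 1.8, 'end-b': 0.6, 'end-c': 0.6}
# alpha_adpt  -> roughly {'end-a': 0.0056, 'end-b': 0.0167, 'end-c': 0.0167}
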
6 changes: 4 additions & 2 deletions lib/python/flame/optimizer/feddyn.py
@@ -46,6 +46,7 @@ def __init__(self, alpha, weight_decay):
self.alpha = alpha
self.weight_decay = weight_decay
self.local_param_dict = dict()
self.cld_model = None

# override parent's self.regularizer
self.regularizer = FedDynRegularizer(self.alpha, self.weight_decay)
@@ -114,10 +115,10 @@ def do(self,
h = self.local_param_dict[end]
mean_local_param = {k:v + rate*h[k] for (k,v) in mean_local_param.items()}


# keep this model as the initial model for the next round of training
self.cld_model = {k:avg_model[k]+mean_local_param[k] for k in avg_model}

return self.cld_model
return avg_model

def add_to_hist(self, end, tres):
if end in self.local_param_dict:
@@ -130,3 +131,4 @@ def add_to_hist(self, end, tres):
# case: end was not previously recorded as active trainer
logger.debug(f"adding untracked end {end} to hist terms")
self.local_param_dict[end] = tres.weights

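The optimizer change above makes do() return the plain weight average (avg_model) as the reported global model, while caching cld_model, the average shifted by the mean of the per-trainer drift terms, as the model to distribute next round. A rough sketch of that server-side step with models as dicts of tensors, using unweighted averaging for brevity (the actual optimizer weights contributions by sample count):

def feddyn_server_update(local_models, hist_terms):
    # local_models: list of state dicts received this round
    # hist_terms: dict end -> accumulated drift term for every trainer seen so far
    m = len(local_models)
    avg_model = {k: sum(sd[k] for sd in local_models) / m for k in local_models[0]}
    mean_h = {k: sum(h[k] for h in hist_terms.values()) / len(hist_terms) for k in avg_model}
    cld_model = {k: avg_model[k] + mean_h[k] for k in avg_model}
    # avg_model is what evaluation sees; cld_model seeds the next round's training
    return avg_model, cld_model
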
17 changes: 17 additions & 0 deletions lib/python/flame/optimizer/regularizer/feddyn.py
@@ -77,3 +77,20 @@ def update(self):

# adjust prev_grad
self.prev_grad += (w_vector - w_t_vector)

def to(self, device):
"""Returns a new Regularizer that has been moved to the specified device."""
import torch
new_regularizer = FedDynRegularizer(self.alpha, self.weight_decay)
# state_dict
for key in self.state_dict:
new_regularizer.state_dict[key] = [param.to(device) for param in self.state_dict[key]]

# prev_grad
if isinstance(self.prev_grad, torch.Tensor):
new_regularizer.prev_grad = self.prev_grad.to(device)
else:
new_regularizer.prev_grad = self.prev_grad

return new_regularizer
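
A hypothetical caller-side use of the new to() helper, assuming the package is importable as flame.* and using illustrative hyperparameter values: whoever owns the regularizer can keep its cached tensors on the same device as the model before computing the penalty term.

import torch
from flame.optimizer.regularizer.feddyn import FedDynRegularizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
reg = FedDynRegularizer(0.01, 0.0)  # (alpha, weight_decay), illustrative values
reg = reg.to(device)                # returns a new regularizer whose state lives on device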