
example/implementation for FedBalancer, with a new sampler category #380

Merged · 1 commit merged into cisco-open:main on Apr 6, 2023

Conversation

jaemin-shin (Collaborator):

Add a new category, "sampler", which selects trainers' data at FL rounds.

Add FedBalancer (Jaemin Shin et al., "FedBalancer: Data and Pace Control for Efficient Federated Learning on Heterogeneous Clients," MobiSys '22) as a new sampler, which actively selects the more important training samples on trainers to speed up global FL. Implement a "deadline" control scheme, which in this version is used only for FedBalancer's sample selection. Deadline-based round termination will be supported in later updates.

Refer to lib/python/flame/examples/fedbalancer_mnist/ for an example of running FedBalancer.
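For a rough sense of the new category, here is a minimal sketch of how a trainer could apply a sampler to its dataset each round. The sample(dataset, **kwargs) signature follows the abstract interface quoted later in this thread; the trainer wiring below is purely illustrative and not the actual flame code.

# Illustrative sketch only; not the actual flame trainer/sampler classes.
class PassThroughSampler:
    """Default behaviour: keep every sample."""

    def sample(self, dataset, **kwargs):
        return dataset


class ToyTrainer:
    def __init__(self, datasampler, dataset):
        self.datasampler = datasampler
        self.dataset = dataset

    def train_one_round(self):
        # the sampler decides which samples are used in this round
        selected = self.datasampler.sample(self.dataset)
        print(f"training on {len(selected)} of {len(self.dataset)} samples")


trainer = ToyTrainer(PassThroughSampler(), dataset=list(range(100)))
trainer.train_one_round()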

Things that the current version of FedBalancer does not support:

  • Advanced trainer selection with Oort proposed in FedBalancer
  • Other FL modes: hybrid, hierarchical


codecov-commenter commented on Mar 29, 2023

Codecov Report

Merging #380 (adc2b06) into main (c71ffc0) will not change coverage.
The diff coverage is n/a.


@@           Coverage Diff           @@
##             main     #380   +/-   ##
=======================================
  Coverage   15.15%   15.15%           
=======================================
  Files          48       48           
  Lines        2778     2778           
=======================================
  Hits          421      421           
  Misses       2328     2328           
  Partials       29       29           


myungjin (Contributor) left a comment:

left several comments.

 if PYTORCH in sys.modules:
     ml_framework_in_use = MLFramework.PYTORCH
-elif TENSORFLOW in sys.modules:
+if TENSORFLOW in sys.modules:
Contributor:

why is this reordering needed?

Collaborator Author:

Because loading the default sampler in the TensorFlow Keras examples (e.g., hybrid) causes ml_framework_in_use to be recognized as PyTorch, which results in an error. But I agree that this is a bad workaround for the problem. Could you suggest how we should solve it?

Contributor:

What is causing the problem? And how does this solve it? I don't see it.

Collaborator Author:

Removed this update! Now it's working fine without it.

Collaborator Author:

The problem was that the datasampler classes were importing PyTorch, which made the Keras examples fail because their ml_framework_in_use was recognized as PyTorch. Changed the datasampler classes so that they do not import PyTorch when loaded.
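As a minimal sketch of the lazy-import approach (the class and method names here are hypothetical, not the actual flame datasampler code), torch is imported only inside the method that needs it, so merely importing the module does not pull PyTorch into sys.modules for Keras-based examples:

# Hypothetical sketch: defer the torch import into the method body.
class LossBasedSampler:
    def compute_sample_losses(self, model, dataset, loss_fn, device: str):
        import torch  # imported only when a PyTorch trainer actually calls this

        losses = []
        model.eval()
        with torch.no_grad():
            for x, y in dataset:
                output = model(x.to(device))
                losses.append(loss_fn(output, y.to(device)).item())
        return losses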

@@ -167,6 +179,7 @@ def __init__(self, config_path: str):
job: Job
registry: Registry
selector: Selector
sampler: Sampler
Contributor:

Let's make it an option, not a required field. In the case of a missing config entry, we should use a default value.

This way, not all config files need to be updated for this PR.
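A minimal sketch of that idea (the field and class names below are assumptions, not the actual flame config loader): the sampler entry falls back to a default when the key is absent, so existing config files keep working unchanged.

# Hypothetical sketch: optional sampler entry with a default value.
from dataclasses import dataclass, field

DEFAULT_SAMPLER = {"sort": "default", "kwargs": {}}

@dataclass
class JobConfig:
    job_id: str
    selector: dict
    sampler: dict = field(default_factory=lambda: dict(DEFAULT_SAMPLER))

# a config file written before this PR simply omits "sampler"
cfg = JobConfig(job_id="abc", selector={"sort": "random", "kwargs": {"k": 1}})
print(cfg.sampler)  # {'sort': 'default', 'kwargs': {}}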

Collaborator Author:

Done!

@@ -60,6 +60,10 @@
"k": 1
}
},
"sampler": {
Contributor:

Should we call it "datasampler" or "data-sampler"?

Contributor:

Also, if this is an option, not every config needs to be modified for this PR.

Collaborator Author:

Changed it to datasampler. Also, removed unnecessary config file updates.

@@ -36,6 +36,7 @@
from ...optimizers import optimizer_provider
Contributor:

The files (top_aggregator.py and trainer.py) have moved to the syncfl folder due to a recent change. The changes made in these files need to be made to the respective files in the syncfl folder.

Collaborator Author:

Done

@@ -118,9 +123,11 @@ def _aggregate_weights(self, tag: str) -> None:
return

total = 0

self.selected_ends = channel.ends()
Contributor:

Should this be updated only for the ends that send their message correctly?

Perhaps you should define self.selected_ends = list() here, and then call self.selected_ends.append(end) somewhere between lines 134-156?
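As an illustration of that pattern (the channel calls below are assumptions, not the exact flame API), starting from an empty list and appending only the ends whose messages actually arrive would look roughly like this:

# Hypothetical sketch: record only the ends that delivered a message this round.
def collect_round_results(channel):
    selected_ends = []
    weights_by_end = {}

    for end in channel.ends():           # assumed: list of connected ends
        msg = channel.recv(end)          # assumed: returns None on failure/timeout
        if msg is None:
            continue                     # skip ends that did not send correctly
        selected_ends.append(end)
        weights_by_end[end] = msg.get("weights")

    return selected_ends, weights_by_end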

jaemin-shin (Collaborator Author) commented on Apr 4, 2023:

This comment actually helped me find a part that I had implemented incorrectly. Thank you very much :)

The selected_ends information is needed for deadline control, which selects an appropriate deadline for the clients selected in a round. Thus, deadline control should happen at the beginning of the round, not at the end of the round as in the previous code. Moved the datasampler variable updates to the beginning of the round.

@@ -38,3 +38,6 @@ class MessageType(Enum):
MODEL_VERSION = 9 # model version used; an non-negative integer

STAT_UTILITY = 10 # measured utility of a trainer based on Oort

SAMPLER_TRAINER_METADATA = 11 # sampler metadata of trainers sent to aggregator
Contributor:

Do you need both types of message?

Would it be fine to have a single type called, e.g., SAMPLER_METADATA?

The interpretation of the metadata can be done by each role that receives it.
I feel this fine-grained message type may be restrictive. For example, what if we need another message that needs to be sent and received by the middle aggregator or coordinator?

Collaborator Author:

I agree. Changed it into a single message type.
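A small sketch of what the single-type approach looks like in practice (the handler names and payload contents are illustrative; only the DATASAMPLER_METADATA value comes from this thread): one generic metadata message, interpreted by whichever role receives it.

# Illustrative sketch: one generic metadata message type, interpreted per role.
from enum import Enum

class MessageType(Enum):
    DATASAMPLER_METADATA = 12  # single generic type instead of per-direction types

def handle_message(role: str, msg_type: MessageType, payload: dict) -> None:
    if msg_type is MessageType.DATASAMPLER_METADATA:
        if role == "aggregator":
            print("aggregator consumes trainer metadata:", payload)
        elif role == "trainer":
            print("trainer consumes aggregator metadata:", payload)

handle_message("trainer", MessageType.DATASAMPLER_METADATA, {"deadline": 60.0})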

from typing import Any
from datetime import datetime

from ..channel import Channel

Collaborator Author:

Done

from .sampler.fedbalancer import FedBalancerTrainerSampler, FedBalancerAggregatorSampler


class SamplerTrainerProvider(ObjectFactory):
Contributor:

Can we combine TrainerSampler and AggregatorSampler and put them into a single class?

Having them separate can cause a potential issue: e.g., the aggregator chooses the default while the trainer chooses fedbalancer?

Collaborator Author:

I could combine them into a single class, and I agree with your concern. However, this is something that we discussed and decided on earlier, as TrainerSampler and AggregatorSampler do not have any functionality in common. Do you think it is better to combine them despite that?
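For what it's worth, a sketch of one way to keep the two classes separate while avoiding the mismatch (the default classes and the registry below are hypothetical, not the actual provider code): register both sides under the same config name, so a single config value always yields a consistent pair.

# Hypothetical sketch: one config key selects a matching trainer/aggregator pair.
class DefaultTrainerSampler: ...
class DefaultAggregatorSampler: ...
class FedBalancerTrainerSampler: ...
class FedBalancerAggregatorSampler: ...

SAMPLER_REGISTRY = {
    "default": (DefaultTrainerSampler, DefaultAggregatorSampler),
    "fedbalancer": (FedBalancerTrainerSampler, FedBalancerAggregatorSampler),
}

def create_samplers(name: str):
    trainer_cls, aggregator_cls = SAMPLER_REGISTRY[name]
    return trainer_cls(), aggregator_cls()

trainer_sampler, aggregator_sampler = create_samplers("fedbalancer")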

jaemin-shin force-pushed the fedbalancer branch 4 times, most recently from c7acbdb to 5efd46f on April 4, 2023 at 06:20
lkurija1 (Contributor) left a comment:

Left a couple of notes, seems ok overall.


Comment on lines 62 to 67
except KeyError:
    raise KeyError(
        "one of the parameters among {w, lss, dss, p, noise_factor} "
        + "is not specified in config,\nrecommended set of parameters"
        + "are {20, 0.05, 0.05, 1.0, 0.0}"
    )
Contributor:

Not sure how to feel about this, it's further explaining the issue, but it's also generalising it, hiding the real problem. Not a big deal though.

Collaborator Author:

Can you explain what you mean by "generalising it, hiding the real problem"?
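For context, one way to keep the recommended values in the message while still naming the exact missing key would be something like the sketch below (illustrative only, not part of this PR):

# Hypothetical sketch: report the specific missing FedBalancer parameter(s).
RECOMMENDED = {"w": 20, "lss": 0.05, "dss": 0.05, "p": 1.0, "noise_factor": 0.0}

def read_fedbalancer_kwargs(kwargs: dict) -> dict:
    missing = [key for key in RECOMMENDED if key not in kwargs]
    if missing:
        raise KeyError(
            f"fedbalancer parameter(s) {missing} not specified in config; "
            f"recommended values are {RECOMMENDED}"
        )
    return {key: kwargs[key] for key in RECOMMENDED}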

Comment on lines 90 to 94
if ml_framework_in_use != MLFramework.PYTORCH:
    raise NotImplementedError(
        "supported ml framework not found; "
        f"supported frameworks are: {valid_frameworks}"
    )
Contributor:

Isn't TENSORFLOW in valid_frameworks as well?

Collaborator Author:

Oh yes, thanks for this! Changed the error message here.

Comment on lines 109 to 122
if not self._is_first_selected_round and max_trainable_size > len(dataset):
    sampled_indices = list(range(len(dataset)))
    sampled_dataset = dataset
else:
    # sample indexes with underthreshold loss (ut) and overthreshold loss (ot)
    ut_indices, ot_indices = [], []

    # read through the sample loss list and parse it in ut or ot indices lists,
    # based on the loss threshold value
    for idx, item in enumerate(self._sample_loss_list):
        if item < self._loss_threshold:
            ut_indices.append(idx)
        else:
            ot_indices.append(idx)
Contributor:

I like this, simple and clean
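For readers unfamiliar with FedBalancer: the two index lists are then combined into the round's training set, roughly along the lines of the sketch below. The mixing ratio shown is an assumption for illustration, not the exact rule from the paper or this PR: all over-threshold samples are kept, plus a random portion of the under-threshold ones.

# Rough illustration only; the under-threshold fraction here is an assumption.
import random

def combine_indices(ut_indices, ot_indices, ut_fraction=0.25):
    n_ut = int(len(ut_indices) * ut_fraction)
    return ot_indices + random.sample(ut_indices, n_ut)

print(combine_indices(list(range(80)), list(range(80, 100))))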

GustavBaumgart previously approved these changes on Apr 5, 2023.

GustavBaumgart (Collaborator) left a comment:

Overall lgtm, left minor comments.



 def install_package(package: str) -> bool:
-    if pipmain(['install', package]) == 0:
+    if pipmain(["install", package]) == 0:
Collaborator:

This apparently has the potential to change the LOG_LEVEL of the terminal. Not necessarily relevant to this PR though.

Comment on lines 35 to 45
"""Abstract method to sample data.

Parameters
----------
dataset: PyTorch Dataset of a trainer to select samples from
kwargs: other arguments specific to each datasampler algorithm

Returns
-------
dataset: PyTorch Dataset that only contains selected samples
"""
Collaborator:

Is there a guard put up so that this does not get initialized for tensorflow?

Collaborator:

It appears this is left to the child classes right now, right? Is it the intention to enforce that AbstractDataSampler only be used for pytorch? You can put the guard at this level if that's the case.

Collaborator Author:

For the current version, there is no guard for that, but we made the datasamplers not import torch themselves, which would otherwise cause problems in the tensorflow examples. Changed the description here to remove the word "PyTorch", as we could also support tensorflow datasets in the future.
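If a guard were wanted at the abstract level later, it could look roughly like the sketch below (the framework string and the check location are assumptions, not the flame implementation):

# Hypothetical sketch of a framework guard in the abstract datasampler class.
from abc import ABC, abstractmethod

class AbstractDataSampler(ABC):
    def __init__(self, ml_framework: str):
        if ml_framework != "pytorch":
            raise NotImplementedError(
                "this datasampler currently supports only pytorch datasets"
            )

    @abstractmethod
    def sample(self, dataset, **kwargs):
        """Return a dataset that contains only the selected samples."""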

@@ -0,0 +1,24 @@
from torch.utils.data import Dataset
Collaborator:

Add copyright

Comment on lines +37 to +39
# generate json files
job_id = "622a358619ab59012eabeefb"
task_id = 0
Collaborator:

Is there a reason why you prefer hardcoding these ids?

Collaborator:

(I randomly generate them now, but I don't know if it's a significant difference)

Collaborator Author:

Nope, I just used a random id for the example, without a particular reason.
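If randomly generated ids are preferred, a one-liner in the same 24-hex-character style would be something like the sketch below (the format is an assumption based on the hardcoded value above):

# Hypothetical sketch: generate a 24-character hex job id instead of hardcoding it.
import secrets

job_id = secrets.token_hex(12)  # e.g. a string like "622a358619ab59012eabeefb"
task_id = 0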

@@ -40,3 +40,5 @@ class MessageType(Enum):
STAT_UTILITY = 10 # measured utility of a trainer based on Oort

COORDINATED_ENDS = 11 # ends coordinated by a coordinator

DATASAMPLER_METADATA = 12 # datasampler metadata
Collaborator:

Depending on which PR is merged first, this might lead to a merge conflict.

Collaborator:

Feel free to check out run.py for the latest approach I am taking for this. I think yours is slightly more concise, so it may be better to keep it this way; I just figured I would share since this is more similar to the fedprox example's run.py.

Collaborator Author:

Thanks! It seems like we have slightly different styles in run.py. I think we can keep it as is for now, but maybe we should consider unifying the style of run.py across the examples...

myungjin previously approved these changes on Apr 5, 2023.

myungjin (Contributor) left a comment:

I left a few small comments. I am approving this PR, but you can address them now or in a separate PR. I am fine either way.

dataset: Any,
loss_fn: Any,
model: Any,
device: Any,
Contributor:

why is device's type Any? Should it be a string?

Collaborator Author:

Changed it to a string!

if end in self._round_communication_duration_history.keys():
    expected_end_complete_time[end] = (
        np.mean(self._round_communication_duration_history[end][-5:])
        + np.mean(self._round_epoch_train_duration_history[end][-5:])
Contributor:

Why the last five elements? It may be good to state the rationale behind this magic number.

Contributor:

I guess the size of these arrays will keep increasing over rounds.
Would it incur memory pressure later?
If you are only interested in the last five elements, it may be good to remove the first element once an array's size reaches 5.

Collaborator Author:

Changed this to maintain only a list of length 5 for each! Also added a comment explaining the choice of the last five elements: it is based on the authors' implementation at https://github.com/jaemin-shin/FedBalancer, commit 1d187c88de9b5b43e28c988b2423e9f616c80610.

To answer the question as an author: we chose to use the last five history items to adapt to changes in a user's status in the real world. For example, if a user moves location and the network connection becomes unstable, we need to handle that change, and looking at recent data helps. Other policies, such as an exponential moving average, could also work toward this goal.
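A minimal sketch of the bounded-history version (variable and function names are illustrative, not the exact PR code): a deque with maxlen=5 keeps memory constant and makes the mean track only recent rounds.

# Illustrative sketch: keep only the last five duration measurements per end.
from collections import defaultdict, deque

import numpy as np

HISTORY_LEN = 5

comm_history = defaultdict(lambda: deque(maxlen=HISTORY_LEN))
train_history = defaultdict(lambda: deque(maxlen=HISTORY_LEN))

def record_round(end: str, comm_duration: float, epoch_train_duration: float) -> None:
    comm_history[end].append(comm_duration)
    train_history[end].append(epoch_train_duration)

def expected_complete_time(end: str) -> float:
    # mean over at most the last five observed rounds
    return float(np.mean(comm_history[end]) + np.mean(train_history[end]))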

Add a new category, "datasampler", which selects trainers' data at FL rounds.

Add FedBalancer (Jaemin Shin et al., "FedBalancer: Data and Pace Control for Efficient Federated Learning on Heterogeneous Clients," MobiSys '22) as a new datasampler, which actively selects the more important training samples on trainers to speed up global FL.
Implement a "deadline" control scheme, which in this version is used only for FedBalancer's sample selection. Deadline-based round termination will be supported in later updates.

Refer to lib/python/flame/examples/fedbalancer_mnist/ for an example of running FedBalancer.

Things that the current version of FedBalancer does not support:
- Advanced trainer selection with Oort proposed in FedBalancer
- Other FL modes: hybrid, hierarchical
myungjin (Contributor) left a comment:

lgtm

jaemin-shin merged commit e4044b2 into cisco-open:main on Apr 6, 2023