Commit adc2b06

example/implementation for FedBalancer, with a new datasampler category
Add a new category, "datasampler", which selects the data that trainers use at each FL round.

Add FedBalancer (Jaemin Shin et al., "FedBalancer: Data and Pace Control for Efficient Federated Learning on Heterogeneous Clients," MobiSys '22) as a new datasampler, which actively selects the more important training samples from each trainer to speed up global FL training.
Implement a "deadline" control scheme, which in this version is used only for FedBalancer's sample selection. Deadline-based round termination will be supported in later updates.

Refer to lib/python/flame/examples/fedbalancer_mnist/ for an example of running FedBalancer; a minimal config sketch is also shown below.

Things that the current version of FedBalancer does not support:
- Advanced trainer selection with Oort, as proposed in the FedBalancer paper
- Other FL modes: hybrid, hierarchical
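
As a rough illustration of how the new category is surfaced to users, the sketch below shows what the datasampler section of a job config might look like, expressed here as a Python dict. The "sort" values come from the DataSamplerType enum added in config.py in this commit; any concrete kwarg names would be assumptions, so kwargs is left empty.

# Illustrative sketch only (not part of this commit): "sort" may be "default"
# or "fedbalancer" per DataSamplerType; hyperparameters from the job config
# are merged into "kwargs" by transform_config().
datasampler_config = {
    "sort": "fedbalancer",
    "kwargs": {},
}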
jaemin-shin committed Apr 5, 2023
1 parent c71ffc0 commit adc2b06
Showing 20 changed files with 1,554 additions and 30 deletions.
61 changes: 36 additions & 25 deletions lib/python/flame/common/util.py
@@ -29,8 +29,8 @@
from .typing import ModelWeights
from .constants import DeviceType

PYTORCH = 'torch'
TENSORFLOW = 'tensorflow'
PYTORCH = "torch"
TENSORFLOW = "tensorflow"


class MLFramework(Enum):
@@ -43,7 +43,8 @@ class MLFramework(Enum):

ml_framework_in_use = MLFramework.UNKNOWN
valid_frameworks = [
framework.name.lower() for framework in MLFramework
framework.name.lower()
for framework in MLFramework
if framework != MLFramework.UNKNOWN
]

@@ -73,13 +74,16 @@ def get_ml_framework_in_use():

return ml_framework_in_use


def get_params_detached_pytorch(model):
"""Return copy of parameters of pytorch model disconnected from graph."""
return [param.detach().clone() for param in model.parameters()]


def get_params_as_vector_pytorch(params):
"""Return the list of parameters passed in concatenated into one vector."""
import torch

vector = None
for param in params:
if not isinstance(vector, torch.Tensor):
@@ -88,37 +92,39 @@ def get_params_as_vector_pytorch(params):
vector = torch.cat((vector, param.reshape(-1)), 0)
return vector


def get_dataset_filename(link):
"""Return path for file location"""
# currently only supports https and local file
if link.startswith('https://'):
if link.startswith("https://"):
import requests

r = requests.get(link, allow_redirects=True)

try:
filename = link.split('/')[-1]
open(filename, 'wb').write(r.content)
filename = link.split("/")[-1]
open(filename, "wb").write(r.content)
except:
filename = 'data'
open(filename, 'wb').write(r.content)
filename = "data"
open(filename, "wb").write(r.content)

return filename
elif link.startswith('file://'):

elif link.startswith("file://"):
return link[7:]

raise TypeError('link format not supported; use either https:// or file://')
raise TypeError("link format not supported; use either https:// or file://")


@contextmanager
def background_thread_loop():

def run_forever(loop):
asyncio.set_event_loop(loop)
loop.run_forever()

_loop = asyncio.new_event_loop()

_thread = Thread(target=run_forever, args=(_loop, ), daemon=True)
_thread = Thread(target=run_forever, args=(_loop,), daemon=True)
_thread.start()
yield _loop

@@ -134,11 +140,11 @@ def run_async(coro, loop, timeout=None):
def install_packages(packages: List[str]) -> None:
for package in packages:
if not install_package(package):
print(f'Failed to install package: {package}')
print(f"Failed to install package: {package}")


def install_package(package: str) -> bool:
if pipmain(['install', package]) == 0:
if pipmain(["install", package]) == 0:
return True

return False
@@ -151,20 +157,22 @@ def mlflow_runname(config: Config) -> str:
if val in config.realm:
groupby_value = groupby_value + val + "-"

return config.role + '-' + groupby_value + config.task_id[:8]
return config.role + "-" + groupby_value + config.task_id[:8]


def delta_weights_pytorch(a: ModelWeights,
b: ModelWeights) -> Union[ModelWeights, None]:
def delta_weights_pytorch(
a: ModelWeights, b: ModelWeights
) -> Union[ModelWeights, None]:
"""Return delta weights for pytorch model weights."""
if a is None or b is None:
return None

return {x: a[x] - b[y] for (x, y) in zip(a, b)}


def delta_weights_tensorflow(a: ModelWeights,
b: ModelWeights) -> Union[ModelWeights, None]:
def delta_weights_tensorflow(
a: ModelWeights, b: ModelWeights
) -> Union[ModelWeights, None]:
"""Return delta weights for tensorflow model weights."""
if a is None or b is None:
return None
@@ -174,27 +182,30 @@ def delta_weights_tensorflow(a: ModelWeights,

def get_pytorch_device(dtype: DeviceType):
import torch

if dtype == DeviceType.CPU:
device_name = "cpu"
elif dtype == DeviceType.GPU:
device_name = "cuda"
else:
raise TypeError(f"Device type {dtype} is not supported.")

return torch.device(device_name)


def weights_to_device(weights, dtype: DeviceType):
"""Send model weights to device type dtype."""

framework = get_ml_framework_in_use()
if framework == MLFramework.TENSORFLOW:
return weights
elif framework == MLFramework.PYTORCH:
torch_device = get_pytorch_device(dtype)
return {name: weights[name].to(torch_device) for name in weights}

return None


def weights_to_model_device(weights, model):
"""Send model weights to same device as model"""
framework = get_ml_framework_in_use()
@@ -204,5 +215,5 @@ def weights_to_model_device(weights, model):
# make assumption all tensors are on same device
torch_device = next(model.parameters()).device
return {name: weights[name].to(torch_device) for name in weights}

return None
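
For orientation, a minimal usage sketch of the pytorch helpers touched above; the toy model and the skipped training step are assumptions, not code from this commit.

# Minimal usage sketch (illustrative): flatten detached parameters into one
# vector and compute per-key weight deltas between two state dicts.
import torch

from flame.common.util import (
    delta_weights_pytorch,
    get_params_as_vector_pytorch,
    get_params_detached_pytorch,
)

model = torch.nn.Linear(4, 2)                  # toy model (assumption)
params = get_params_detached_pytorch(model)    # list of detached tensors
vector = get_params_as_vector_pytorch(params)  # one flat 1-D tensor

before = {k: v.clone() for k, v in model.state_dict().items()}
# ... a training step would go here ...
delta = delta_weights_pytorch(model.state_dict(), before)  # per-key a - b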
17 changes: 17 additions & 0 deletions lib/python/flame/config.py
@@ -75,6 +75,13 @@ class SelectorType(str, Enum):
OORT = "oort"


class DataSamplerType(str, Enum):
"""Define datasampler types."""

DEFAULT = "default"
FEDBALANCER = "fedbalancer"


class Job(FlameSchema):
job_id: str = Field(alias="id")
name: str
@@ -90,6 +97,11 @@ class Selector(FlameSchema):
kwargs: dict = Field(default={})


class DataSampler(FlameSchema):
sort: DataSamplerType = Field(default=DataSamplerType.DEFAULT)
kwargs: dict = Field(default={})


class Optimizer(FlameSchema):
sort: OptimizerType = Field(default=OptimizerType.DEFAULT)
kwargs: dict = Field(default={})
@@ -172,6 +184,7 @@ def __init__(self, config_path: str):
job: Job
registry: t.Optional[Registry]
selector: t.Optional[Selector]
datasampler: t.Optional[DataSampler] = Field(default=DataSampler())
optimizer: t.Optional[Optimizer] = Field(default=Optimizer())
dataset: str
max_run_time: int
@@ -224,6 +237,10 @@ def transform_config(raw_config: dict) -> dict:
if raw_config.get("optimizer", None):
config_data = config_data | {"optimizer": raw_config.get("optimizer")}

if raw_config.get("datasampler", None):
raw_config["datasampler"]["kwargs"].update(hyperparameters)
config_data = config_data | {"datasampler": raw_config.get("datasampler")}

config_data = config_data | {
"dataset": raw_config.get("dataset", ""),
"max_run_time": raw_config.get("maxRunTime", 300),
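
To make the merge in transform_config() above concrete, a small sketch of what happens to the datasampler entry; the hyperparameter keys shown are hypothetical placeholders.

# Illustrative sketch of the kwargs merge in transform_config() above.
raw_config = {"datasampler": {"sort": "fedbalancer", "kwargs": {}}}
hyperparameters = {"epochs": 1, "rounds": 50}  # hypothetical keys

if raw_config.get("datasampler", None):
    raw_config["datasampler"]["kwargs"].update(hyperparameters)

# raw_config["datasampler"]["kwargs"] is now {"epochs": 1, "rounds": 50}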
78 changes: 78 additions & 0 deletions lib/python/flame/datasampler/__init__.py
@@ -0,0 +1,78 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""datasampler abstract class."""

from abc import ABC, abstractmethod
from typing import Any

from flame.channel import Channel


class AbstractDataSampler(ABC):
    class AbstractTrainerDataSampler(ABC):
        """Abstract base class for trainer-side datasampler implementation."""

        def __init__(self, **kwargs) -> None:
            """Initialize an instance with keyword-based arguments."""
            for key, value in kwargs.items():
                setattr(self, key, value)

        @abstractmethod
        def sample(self, dataset: Any, **kwargs) -> Any:
            """Abstract method to sample data.

            Parameters
            ----------
            dataset: Dataset of a trainer to select samples from
            kwargs: other arguments specific to each datasampler algorithm

            Returns
            -------
            dataset: Dataset that only contains selected samples
            """

        @abstractmethod
        def load_dataset(self, dataset: Any) -> Any:
            """Process dataset instance for datasampler."""

        @abstractmethod
        def get_metadata(self) -> dict[str, Any]:
            """Return metadata to send to aggregator-side datasampler."""

        @abstractmethod
        def handle_metadata_from_aggregator(self, metadata: dict[str, Any]) -> None:
            """Handle aggregator metadata for datasampler."""

    class AbstractAggregatorDataSampler(ABC):
        """Abstract base class for aggregator-side datasampler implementation."""

        def __init__(self, **kwargs) -> None:
            """Initialize an instance with keyword-based arguments."""
            for key, value in kwargs.items():
                setattr(self, key, value)

        @abstractmethod
        def get_metadata(self, end: str, round: int) -> dict[str, Any]:
            """Return metadata to send to trainer-side datasampler."""

        @abstractmethod
        def handle_metadata_from_trainer(
            self,
            metadata: dict[str, Any],
            end: str,
            channel: Channel,
        ) -> None:
            """Handle trainer metadata for datasampler."""
77 changes: 77 additions & 0 deletions lib/python/flame/datasampler/default.py
@@ -0,0 +1,77 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""DefaultDataSampler class."""

import logging
from typing import Any

from flame.channel import Channel
from flame.datasampler import AbstractDataSampler

logger = logging.getLogger(__name__)


class DefaultDataSampler(AbstractDataSampler):
    def __init__(self) -> None:
        self.trainer_data_sampler = DefaultDataSampler.DefaultTrainerDataSampler()
        self.aggregator_data_sampler = DefaultDataSampler.DefaultAggregatorDataSampler()

    class DefaultTrainerDataSampler(AbstractDataSampler.AbstractTrainerDataSampler):
        """A default trainer-side datasampler class."""

        def __init__(self, **kwargs):
            """Initialize instance."""
            super().__init__()

        def sample(self, dataset: Any, **kwargs) -> Any:
            """Return all dataset from the given dataset."""
            logger.debug("calling default datasampler")

            return dataset

        def load_dataset(self, dataset: Any) -> None:
            """Change dataset instance to return index with each sample."""
            return dataset

        def get_metadata(self) -> dict[str, Any]:
            """Return metadata to send to aggregator-side datasampler."""
            return {}

        def handle_metadata_from_aggregator(self, metadata: dict[str, Any]) -> None:
            """Handle aggregator metadata for datasampler."""
            pass

    class DefaultAggregatorDataSampler(
        AbstractDataSampler.AbstractAggregatorDataSampler
    ):
        """A default aggregator-side datasampler class."""

        def __init__(self, **kwargs):
            """Initialize instance."""
            super().__init__()

        def get_metadata(self, end: str, round: int) -> dict[str, Any]:
            """Return metadata to send to trainer-side datasampler."""
            return {}

        def handle_metadata_from_trainer(
            self,
            metadata: dict[str, Any],
            end: str,
            channel: Channel,
        ) -> None:
            """Handle trainer metadata for datasampler."""
            pass
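
Finally, a rough sketch of how the two sides of a datasampler might be exercised during one round; the role classes that actually drive these calls are outside the files shown here, so the variable names and call order are assumptions.

# Rough per-round flow (illustrative only).
from flame.datasampler.default import DefaultDataSampler

sampler = DefaultDataSampler()
raw_dataset = list(range(10))  # stand-in for a trainer's dataset

# Trainer side: prepare the dataset, select samples, and collect metadata
# to ship to the aggregator along with the model update.
dataset = sampler.trainer_data_sampler.load_dataset(raw_dataset)
selected = sampler.trainer_data_sampler.sample(dataset)
trainer_meta = sampler.trainer_data_sampler.get_metadata()

# Aggregator side: consume trainer metadata, then produce metadata for the
# next round (a real call would pass a flame Channel instead of None).
sampler.aggregator_data_sampler.handle_metadata_from_trainer(
    trainer_meta, end="trainer-id", channel=None
)
next_meta = sampler.aggregator_data_sampler.get_metadata(end="trainer-id", round=1)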