Skip to content

Commit

Permalink
example/implementation for FedBalancer, with a new datasampler category
Browse files Browse the repository at this point in the history
Add a new category, "datasampler", which selects trainers' data at FL rounds.

Add FedBalancer (Jaemin Shin et al., FedBalancer: Data and Pace Control for Efficient Federated Learning on Heterogeneous Clients, MobiSys'22) as a new datasampler, which actively selects more important training samples of trainers to speed up global FL.
Implement a control scheme of "deadline", which is only used for fedbalancer's sample selection at this version. Deadline-based round termination will be supported in later updates.

Refer to lib/python/flame/examples/fedbalancer_mnist/ for an example of running fedbalancer

Things that the current version of fedbalancer does not support:
- Advanced trainer selection with Oort proposed in FedBalancer
- Other FL modes: hybrid, hierarchical
  • Loading branch information
jaemin-shin committed Apr 4, 2023
1 parent c71ffc0 commit c7acbdb
Show file tree
Hide file tree
Showing 19 changed files with 1,510 additions and 8 deletions.
6 changes: 3 additions & 3 deletions lib/python/flame/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def determine_ml_framework_in_use():
"""Determine which ml framework in use."""
global ml_framework_in_use

if PYTORCH in sys.modules:
ml_framework_in_use = MLFramework.PYTORCH
elif TENSORFLOW in sys.modules:
if TENSORFLOW in sys.modules:
ml_framework_in_use = MLFramework.TENSORFLOW
elif PYTORCH in sys.modules:
ml_framework_in_use = MLFramework.PYTORCH


def get_ml_framework_in_use():
Expand Down
16 changes: 16 additions & 0 deletions lib/python/flame/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ class SelectorType(str, Enum):
OORT = "oort"


class DataSamplerType(str, Enum):
    """Define datasampler types."""

    # Pass-through sampler: trainers train on their full local dataset.
    DEFAULT = "default"
    # FedBalancer (Shin et al., MobiSys'22): selects more important samples.
    FEDBALANCER = "fedbalancer"


class Job(FlameSchema):
job_id: str = Field(alias="id")
name: str
Expand All @@ -90,6 +97,11 @@ class Selector(FlameSchema):
kwargs: dict = Field(default={})


class DataSampler(FlameSchema):
    """Schema for the datasampler section of a job configuration."""

    # Which datasampler implementation to instantiate; defaults to pass-through.
    sort: DataSamplerType = Field(default=DataSamplerType.DEFAULT)
    # Algorithm-specific keyword arguments; hyperparameters are merged in
    # during config transformation.
    kwargs: dict = Field(default={})


class Optimizer(FlameSchema):
sort: OptimizerType = Field(default=OptimizerType.DEFAULT)
kwargs: dict = Field(default={})
Expand Down Expand Up @@ -172,6 +184,7 @@ def __init__(self, config_path: str):
job: Job
registry: t.Optional[Registry]
selector: t.Optional[Selector]
datasampler: t.Optional[DataSampler]
optimizer: t.Optional[Optimizer] = Field(default=Optimizer())
dataset: str
max_run_time: int
Expand Down Expand Up @@ -213,9 +226,12 @@ def transform_config(raw_config: dict) -> dict:
sort_to_host = transform_brokers(raw_config["brokers"])
config_data = config_data | {"brokers": sort_to_host}

raw_config["datasampler"]["kwargs"].update(hyperparameters)

config_data = config_data | {
"job": raw_config["job"],
"selector": raw_config["selector"],
"datasampler": raw_config["datasampler"],
}

if raw_config.get("registry", None):
Expand Down
80 changes: 80 additions & 0 deletions lib/python/flame/datasampler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""datasampler abstract class."""

from abc import ABC, abstractmethod
from torch.utils.data import Dataset
from typing import Any
from datetime import datetime

from flame.channel import Channel


class AbstractTrainerDataSampler(ABC):
    """Base interface for datasamplers running on the trainer side."""

    def __init__(self, **kwargs) -> None:
        """Store every keyword argument as an instance attribute."""
        for name in kwargs:
            setattr(self, name, kwargs[name])

    @abstractmethod
    def sample(self, dataset: Dataset, **kwargs) -> Dataset:
        """Select training samples from the trainer's dataset.
        Parameters
        ----------
        dataset: PyTorch Dataset of a trainer to select samples from
        kwargs: other arguments specific to each datasampler algorithm
        Returns
        -------
        dataset: PyTorch Dataset that only contains selected samples
        """

    @abstractmethod
    def load_dataset(self, dataset: Dataset) -> Dataset:
        """Prepare the dataset instance for use by this datasampler."""

    @abstractmethod
    def get_metadata(self) -> dict[str, Any]:
        """Build the metadata to send to the aggregator-side datasampler."""

    @abstractmethod
    def handle_metadata_from_aggregator(self, metadata: dict[str, Any]) -> None:
        """Consume metadata received from the aggregator-side datasampler."""


class AbstractAggregatorDataSampler(ABC):
    """Base interface for datasamplers running on the aggregator side."""

    def __init__(self, **kwargs) -> None:
        """Store every keyword argument as an instance attribute."""
        for name in kwargs:
            setattr(self, name, kwargs[name])

    @abstractmethod
    def get_metadata(self, channel: Any, end: str, round: int) -> dict[str, Any]:
        """Build the metadata to send to the trainer-side datasampler."""

    @abstractmethod
    def handle_metadata_from_trainer(
        self,
        metadata: dict[str, Any],
        end: str,
        channel: Channel,
    ) -> None:
        """Consume metadata received from a trainer-side datasampler."""
82 changes: 82 additions & 0 deletions lib/python/flame/datasampler/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""DefaultDataSampler class."""

import logging
from torch.utils.data import Dataset
from typing import Any
from datetime import datetime

from flame.channel import Channel
from flame.common.util import MLFramework, get_ml_framework_in_use, valid_frameworks
from flame.datasampler import AbstractTrainerDataSampler, AbstractAggregatorDataSampler

logger = logging.getLogger(__name__)


class DefaultTrainerDataSampler(AbstractTrainerDataSampler):
    """A default trainer-side datasampler class.

    Pass-through implementation: keeps the trainer's full dataset and
    exchanges no metadata with the aggregator.
    """

    def __init__(self, **kwargs):
        """Initialize instance.

        Forward kwargs to the base class so they are stored as instance
        attributes, consistent with AbstractTrainerDataSampler.__init__.
        """
        super().__init__(**kwargs)

    def sample(self, dataset: Dataset, **kwargs) -> Dataset:
        """Return all dataset from the given dataset.

        Parameters
        ----------
        dataset: PyTorch Dataset of a trainer
        kwargs: ignored by the default datasampler

        Returns
        -------
        dataset: the same Dataset, with every sample retained

        Raises
        ------
        NotImplementedError: if the ML framework in use is not PyTorch
        """
        logger.debug("calling default datasampler")

        # Guard clause: only PyTorch datasets are supported here.
        ml_framework_in_use = get_ml_framework_in_use()
        if ml_framework_in_use != MLFramework.PYTORCH:
            raise NotImplementedError(
                "supported ml framework not found; "
                f"supported frameworks are: {valid_frameworks}"
            )

        logger.debug(f"sampled data: {len(dataset)}")
        return dataset

    # NOTE: return annotation fixed from "-> None" to match both the actual
    # return value and the AbstractTrainerDataSampler contract.
    def load_dataset(self, dataset: Dataset) -> Dataset:
        """Return the dataset unchanged; no per-sample indexing is needed."""
        return dataset

    def get_metadata(self) -> dict[str, Any]:
        """Return empty metadata; the default sampler sends nothing."""
        return {}

    def handle_metadata_from_aggregator(self, metadata: dict[str, Any]) -> None:
        """Ignore aggregator metadata; the default sampler keeps no state."""
        pass


class DefaultAggregatorDataSampler(AbstractAggregatorDataSampler):
    """A default aggregator-side datasampler class.

    No-op counterpart of DefaultTrainerDataSampler: exchanges no metadata
    with trainers.
    """

    def __init__(self, **kwargs):
        """Initialize instance.

        Forward kwargs to the base class so they are stored as instance
        attributes, consistent with AbstractAggregatorDataSampler.__init__.
        """
        super().__init__(**kwargs)

    def get_metadata(self, channel: Any, end: str, round: int) -> dict[str, Any]:
        """Return empty metadata; the default sampler sends nothing."""
        return {}

    def handle_metadata_from_trainer(
        self,
        metadata: dict[str, Any],
        end: str,
        channel: Channel,
    ) -> None:
        """Ignore trainer metadata; the default sampler keeps no state."""
        pass
Loading

0 comments on commit c7acbdb

Please sign in to comment.