Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FedProx strategy #1619

Merged
merged 18 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions doc/source/apiref-flwr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ server.strategy.FedOpt
.. automethod:: __init__


.. _flwr-server-strategy-FedProx-apiref:

server.strategy.FedProx
^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: flwr.server.strategy.FedProx
:members:

.. automethod:: __init__


.. _flwr-server-strategy-FedAdagrad-apiref:

server.strategy.FedAdagrad
Expand Down
8 changes: 7 additions & 1 deletion doc/source/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@

### What's new?

- **Add support for ** `workload_id` **and** `group_id` **in Driver API** ([#1595](https://github.com/adap/flower/pull/1595))
- **Add support for** `workload_id` **and** `group_id` **in Driver API** ([#1595](https://github.com/adap/flower/pull/1595))

- **Make Android example compatible with** `flwr >= 1.0.0` **and the latest versions of Android** ([#1603](https://github.com/adap/flower/pull/1603))

- **Add new `FedProx` strategy** ([#1619](https://github.com/adap/flower/pull/1619))

This [strategy](https://github.com/adap/flower/blob/main/src/py/flwr/server/strategy/fedprox.py) is almost identical to [`FedAvg`](https://github.com/adap/flower/blob/main/src/py/flwr/server/strategy/fedavg.py),
but helps users replicate what is described in this [paper](https://arxiv.org/abs/1812.06127). It essentially adds a parameter called `proximal_mu` to
regularize the local models with respect to the global models. You might want to checkout this [blog post](https://flower.dev/blog/2023-02-02-fl-starter-pack-fedprox-mnist-cnn/) for more information.
charlesbvll marked this conversation as resolved.
Show resolved Hide resolved

### Incompatible changes

Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .fedavgm import FedAvgM as FedAvgM
from .fedmedian import FedMedian as FedMedian
from .fedopt import FedOpt as FedOpt
from .fedprox import FedProx as FedProx
from .fedyogi import FedYogi as FedYogi
from .qfedavg import QFedAvg as QFedAvg
from .strategy import Strategy as Strategy
Expand All @@ -35,6 +36,7 @@
"FedAvgAndroid",
"FedAvgM",
"FedOpt",
"FedProx",
"FedYogi",
"QFedAvg",
"FedMedian",
Expand Down
315 changes: 315 additions & 0 deletions src/py/flwr/server/strategy/fedprox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated Optimization (FedProx) [Li et al., 2018] strategy.

Paper: https://arxiv.org/abs/1812.06127
"""


from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
EvaluateIns,
EvaluateRes,
FitIns,
FitRes,
MetricsAggregationFn,
NDArrays,
Parameters,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate, weighted_loss_avg
from .strategy import Strategy

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""

# flake8: noqa: E501
class FedProx(Strategy):
charlesbvll marked this conversation as resolved.
Show resolved Hide resolved
"""Configurable FedProx strategy implementation."""

# pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long
def __init__(
self,
*,
proximal_mu: float,
charlesbvll marked this conversation as resolved.
Show resolved Hide resolved
fraction_fit: float = 1.0,
fraction_evaluate: float = 1.0,
min_fit_clients: int = 2,
min_evaluate_clients: int = 2,
min_available_clients: int = 2,
evaluate_fn: Optional[
Callable[
[int, NDArrays, Dict[str, Scalar]],
Optional[Tuple[float, Dict[str, Scalar]]],
]
] = None,
on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
accept_failures: bool = True,
initial_parameters: Optional[Parameters] = None,
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
) -> None:
"""Federated Optimization strategy.

Implementation based on https://arxiv.org/abs/1812.06127

The strategy in itself will not be different than FedAvg, the client needs to be adjusted.
A proximal term needs to be added to the loss function during the training:

.. math::
\\frac{\\mu}{2} || w - w^t ||^2

Where $w^t$ are the global parameters and $w$ are the local weights the function will
be optimized with.

In PyTorch for example, the loss would go from:
danieljanes marked this conversation as resolved.
Show resolved Hide resolved

.. code:: python

loss = criterion(net(inputs), labels)

To:

.. code:: python

for local_weights, global_weights in zip(net.parameters(), global_params):
proximal_term += (local_weights - global_weights).norm(2)
loss = criterion(net(inputs), labels) + (config["proximal_mu"] / 2) * proximal_term

With `global_params` being a copy of the parameters before the training takes place.

.. code:: python

global_params = copy.deepcopy(net).parameters()

Parameters
----------
proximal_mu : float
The weight of the proximal term used in the optimization. 0.0 makes
this strategy equivalent to FedAvg, and the higher the coefficient, the more
regularization will be used (that is, the client parameters will need to be
to the server parameters during training).
fraction_fit : float, optional
Fraction of clients used during training. In case `min_fit_clients`
is larger than `fraction_fit * available_clients`, `min_fit_clients`
will still be sampled. Defaults to 1.0.
fraction_evaluate : float, optional
Fraction of clients used during validation. In case `min_evaluate_clients`
is larger than `fraction_evaluate * available_clients`, `min_evaluate_clients`
will still be sampled. Defaults to 1.0.
min_fit_clients : int, optional
Minimum number of clients used during training. Defaults to 2.
min_evaluate_clients : int, optional
Minimum number of clients used during validation. Defaults to 2.
min_available_clients : int, optional
Minimum number of total clients in the system. Defaults to 2.
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
Optional function used for validation. Defaults to None.
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure training. Defaults to None.
on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure validation. Defaults to None.
accept_failures : bool, optional
Whether or not accept rounds containing failures. Defaults to True.
initial_parameters : Parameters, optional
Initial global model parameters.
fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
Metrics aggregation function, optional.
evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
Metrics aggregation function, optional.
"""
super().__init__()

if (
min_fit_clients > min_available_clients
or min_evaluate_clients > min_available_clients
):
log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

self.proximal_mu = proximal_mu
self.fraction_fit = fraction_fit
self.fraction_evaluate = fraction_evaluate
self.min_fit_clients = min_fit_clients
self.min_evaluate_clients = min_evaluate_clients
self.min_available_clients = min_available_clients
self.evaluate_fn = evaluate_fn
self.on_fit_config_fn = on_fit_config_fn
self.on_evaluate_config_fn = on_evaluate_config_fn
self.accept_failures = accept_failures
self.initial_parameters = initial_parameters
self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn

def __repr__(self) -> str:
rep = f"FedAvg(accept_failures={self.accept_failures})"
return rep

def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]:
"""Return the sample size and the required number of available
clients."""
num_clients = int(num_available_clients * self.fraction_fit)
return max(num_clients, self.min_fit_clients), self.min_available_clients

def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]:
"""Use a fraction of available clients for evaluation."""
num_clients = int(num_available_clients * self.fraction_evaluate)
return max(num_clients, self.min_evaluate_clients), self.min_available_clients

def initialize_parameters(
self, client_manager: ClientManager
) -> Optional[Parameters]:
"""Initialize global model parameters."""
initial_parameters = self.initial_parameters
self.initial_parameters = None # Don't keep initial parameters in memory
return initial_parameters

def evaluate(
self, server_round: int, parameters: Parameters
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
"""Evaluate model parameters using an evaluation function."""
if self.evaluate_fn is None:
# No evaluation function provided
return None
parameters_ndarrays = parameters_to_ndarrays(parameters)
eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
if eval_res is None:
return None
loss, metrics = eval_res
return loss, metrics

def configure_fit(
self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
"""Configure the next round of training."""
config = {}
if self.on_fit_config_fn is not None:
# Custom fit config function provided
config = self.on_fit_config_fn(server_round)

# Add proximal_mu to config
config["proximal_mu"] = self.proximal_mu

fit_ins = FitIns(parameters, config)

# Sample clients
sample_size, min_num_clients = self.num_fit_clients(
client_manager.num_available()
)
clients = client_manager.sample(
num_clients=sample_size, min_num_clients=min_num_clients
)

# Return client/config pairs
return [(client, fit_ins) for client in clients]
danieljanes marked this conversation as resolved.
Show resolved Hide resolved

def configure_evaluate(
self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, EvaluateIns]]:
"""Configure the next round of evaluation."""
# Do not configure federated evaluation if fraction eval is 0.
if self.fraction_evaluate == 0.0:
return []

# Parameters and config
config = {}
if self.on_evaluate_config_fn is not None:
# Custom evaluation config function provided
config = self.on_evaluate_config_fn(server_round)
evaluate_ins = EvaluateIns(parameters, config)

# Sample clients
sample_size, min_num_clients = self.num_evaluation_clients(
client_manager.num_available()
)
clients = client_manager.sample(
num_clients=sample_size, min_num_clients=min_num_clients
)

# Return client/config pairs
return [(client, evaluate_ins) for client in clients]

def aggregate_fit(
self,
server_round: int,
results: List[Tuple[ClientProxy, FitRes]],
failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
"""Aggregate fit results using weighted average."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}

# Convert results
weights_results = [
(parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
for _, fit_res in results
]
parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results))

# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.fit_metrics_aggregation_fn:
fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No fit_metrics_aggregation_fn provided")

return parameters_aggregated, metrics_aggregated

def aggregate_evaluate(
self,
server_round: int,
results: List[Tuple[ClientProxy, EvaluateRes]],
failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
) -> Tuple[Optional[float], Dict[str, Scalar]]:
"""Aggregate evaluation losses using weighted average."""
if not results:
return None, {}
# Do not aggregate if there are failures and failures are not accepted
if not self.accept_failures and failures:
return None, {}

# Aggregate loss
loss_aggregated = weighted_loss_avg(
[
(evaluate_res.num_examples, evaluate_res.loss)
for _, evaluate_res in results
]
)

# Aggregate custom metrics if aggregation fn was provided
metrics_aggregated = {}
if self.evaluate_metrics_aggregation_fn:
eval_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.evaluate_metrics_aggregation_fn(eval_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No evaluate_metrics_aggregation_fn provided")

return loss_aggregated, metrics_aggregated