Add FedProx strategy #1619

Merged · 18 commits · Feb 3, 2023 · Changes from 12 commits
11 changes: 11 additions & 0 deletions doc/source/apiref-flwr.rst
@@ -139,6 +139,17 @@ server.strategy.FedOpt
.. automethod:: __init__


.. _flwr-server-strategy-FedProx-apiref:

server.strategy.FedProx
^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: flwr.server.strategy.FedProx
:members:

.. automethod:: __init__


.. _flwr-server-strategy-FedAdagrad-apiref:

server.strategy.FedAdagrad
8 changes: 7 additions & 1 deletion doc/source/changelog.md
@@ -6,9 +6,15 @@

### What's new?

- - **Add support for ** `workload_id` **and** `group_id` **in Driver API** ([#1595](https://github.com/adap/flower/pull/1595))
+ - **Add support for** `workload_id` **and** `group_id` **in Driver API** ([#1595](https://github.com/adap/flower/pull/1595))

- **Make Android example compatible with** `flwr >= 1.0.0` **and the latest versions of Android** ([#1603](https://github.com/adap/flower/pull/1603))

- **Add new `FedProx` strategy** ([#1619](https://github.com/adap/flower/pull/1619))

This [strategy](https://github.com/adap/flower/blob/main/src/py/flwr/server/strategy/fedprox.py) is almost identical to [`FedAvg`](https://github.com/adap/flower/blob/main/src/py/flwr/server/strategy/fedavg.py),
but helps users replicate what is described in this [paper](https://arxiv.org/abs/1812.06127). It essentially adds a parameter called `proximal_mu` that
regularizes the local models with respect to the global model. You might want to check out this [blog post](https://flower.dev/blog/2023-02-02-fl-starter-pack-fedprox-mnist-cnn/) for more information.
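As a quick illustration (not part of the changelog itself), a minimal server-side sketch using the new strategy could look as follows; the server address, number of rounds, and `proximal_mu` value below are arbitrary placeholders:

```python
import flwr as fl

# FedProx takes the same arguments as FedAvg, plus the required
# keyword-only `proximal_mu` coefficient.
strategy = fl.server.strategy.FedProx(proximal_mu=0.1)

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)
```

Since `FedProx` subclasses `FedAvg`, all the usual `FedAvg` arguments (e.g. `fraction_fit` or `min_available_clients`) can be passed alongside `proximal_mu`.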

### Incompatible changes

2 changes: 2 additions & 0 deletions src/py/flwr/server/strategy/__init__.py
@@ -23,6 +23,7 @@
from .fedavgm import FedAvgM as FedAvgM
from .fedmedian import FedMedian as FedMedian
from .fedopt import FedOpt as FedOpt
from .fedprox import FedProx as FedProx
from .fedyogi import FedYogi as FedYogi
from .qfedavg import QFedAvg as QFedAvg
from .strategy import Strategy as Strategy
@@ -35,6 +36,7 @@
"FedAvgAndroid",
"FedAvgM",
"FedOpt",
"FedProx",
"FedYogi",
"QFedAvg",
"FedMedian",
188 changes: 188 additions & 0 deletions src/py/flwr/server/strategy/fedprox.py
@@ -0,0 +1,188 @@
# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated Optimization (FedProx) [Li et al., 2018] strategy.

Paper: https://arxiv.org/abs/1812.06127
"""


from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple

from flwr.common import FitIns, MetricsAggregationFn, NDArrays, Parameters, Scalar
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from .fedavg import FedAvg

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""

# flake8: noqa: E501
class FedProx(FedAvg):
"""Configurable FedProx strategy implementation."""

# pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long
def __init__(
self,
*,
fraction_fit: float = 1.0,
fraction_evaluate: float = 1.0,
min_fit_clients: int = 2,
min_evaluate_clients: int = 2,
min_available_clients: int = 2,
evaluate_fn: Optional[
Callable[
[int, NDArrays, Dict[str, Scalar]],
Optional[Tuple[float, Dict[str, Scalar]]],
]
] = None,
on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
accept_failures: bool = True,
initial_parameters: Optional[Parameters] = None,
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
proximal_mu: float,
) -> None:
"""Federated Optimization strategy.

Implementation based on https://arxiv.org/abs/1812.06127

The server-side strategy is essentially the same as FedAvg; it is the client that
needs to be adjusted. A proximal term must be added to the loss function during training:

.. math::
\\frac{\\mu}{2} || w - w^t ||^2

where :math:`w^t` are the global parameters and :math:`w` are the local weights
being optimized.

In PyTorch, for example, the loss would go from:

.. code:: python

loss = criterion(net(inputs), labels)

To:

.. code:: python

proximal_term = 0.0
for local_weights, global_weights in zip(net.parameters(), global_params):
    proximal_term += torch.square((local_weights - global_weights).norm(2))
loss = criterion(net(inputs), labels) + (config["proximal_mu"] / 2) * proximal_term

with `global_params` being a copy of the parameters taken before training starts:

.. code:: python

# Materialize the copy as a list so it can be re-used across batches
global_params = list(copy.deepcopy(net).parameters())

Parameters
----------
fraction_fit : float, optional
Fraction of clients used during training. In case `min_fit_clients`
is larger than `fraction_fit * available_clients`, `min_fit_clients`
will still be sampled. Defaults to 1.0.
fraction_evaluate : float, optional
Fraction of clients used during validation. In case `min_evaluate_clients`
is larger than `fraction_evaluate * available_clients`, `min_evaluate_clients`
will still be sampled. Defaults to 1.0.
min_fit_clients : int, optional
Minimum number of clients used during training. Defaults to 2.
min_evaluate_clients : int, optional
Minimum number of clients used during validation. Defaults to 2.
min_available_clients : int, optional
Minimum number of total clients in the system. Defaults to 2.
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
Optional function used for validation. Defaults to None.
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure training. Defaults to None.
on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure validation. Defaults to None.
accept_failures : bool, optional
Whether or not to accept rounds containing failures. Defaults to True.
initial_parameters : Parameters, optional
Initial global model parameters.
fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
Metrics aggregation function, optional.
evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
Metrics aggregation function, optional.
proximal_mu : float
The weight of the proximal term used in the optimization. 0.0 makes
this strategy equivalent to FedAvg, and the higher the coefficient, the
more regularization is applied (that is, the client parameters are kept
closer to the server parameters during training).
"""

if (
min_fit_clients > min_available_clients
or min_evaluate_clients > min_available_clients
):
log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

super().__init__(
fraction_fit=fraction_fit,
fraction_evaluate=fraction_evaluate,
min_fit_clients=min_fit_clients,
min_evaluate_clients=min_evaluate_clients,
min_available_clients=min_available_clients,
evaluate_fn=evaluate_fn,
on_fit_config_fn=on_fit_config_fn,
on_evaluate_config_fn=on_evaluate_config_fn,
accept_failures=accept_failures,
initial_parameters=initial_parameters,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
)
self.proximal_mu = proximal_mu

def __repr__(self) -> str:
rep = f"FedProx(accept_failures={self.accept_failures})"
return rep

def configure_fit(
self, server_round: int, parameters: Parameters, client_manager: ClientManager
) -> List[Tuple[ClientProxy, FitIns]]:
"""Configure the next round of training.

Sends the proximal factor `proximal_mu` to the clients.
"""
config = {}
if self.on_fit_config_fn is not None:
# Custom fit config function provided
config = self.on_fit_config_fn(server_round)

# Add proximal_mu to config
config["proximal_mu"] = self.proximal_mu

fit_ins = FitIns(parameters, config)

# Sample clients
sample_size, min_num_clients = self.num_fit_clients(
client_manager.num_available()
)
clients = client_manager.sample(
num_clients=sample_size, min_num_clients=min_num_clients
)

# Return client/config pairs
return [(client, fit_ins) for client in clients]
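For context, and not part of this diff: a client-side counterpart would read the `proximal_mu` value from the config dict that `configure_fit` sends and add the proximal term to its local loss. A minimal PyTorch sketch, in which `net`, `trainloader`, and the `train_fedprox` helper are hypothetical placeholders:

```python
import torch


def train_fedprox(net, trainloader, config, epochs=1, lr=0.01):
    """Hypothetical client-side training loop adding the FedProx proximal term."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    # Snapshot the global parameters received at the start of the round
    global_params = [p.detach().clone() for p in net.parameters()]
    for _ in range(epochs):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            proximal_term = 0.0
            for local_w, global_w in zip(net.parameters(), global_params):
                proximal_term += torch.square((local_w - global_w).norm(2))
            loss = criterion(net(inputs), labels)
            # `proximal_mu` arrives via the config dict built in configure_fit
            loss += (config["proximal_mu"] / 2) * proximal_term
            loss.backward()
            optimizer.step()
```

With `proximal_mu` set to 0.0 this reduces to a plain FedAvg client, matching the parameter description in the docstring above.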