Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Krum and MultiKrum #1481

Merged
merged 23 commits into from
Dec 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion src/py/flwr/server/strategy/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import numpy as np

from flwr.common import NDArrays
from flwr.common import NDArray, NDArrays


def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays:
Expand Down Expand Up @@ -53,6 +53,44 @@ def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays:
return median_w


def aggregate_krum(
    results: List[Tuple[NDArrays, int]], num_malicious: int, to_keep: int
) -> NDArrays:
    """Aggregate client parameters with the Krum or MultiKrum function.

    Every client is scored by the sum of the squared distances to its
    ``n - f - 2`` closest peers, where ``n`` is the number of clients and
    ``f`` is ``num_malicious``. A lower score is better.

    Parameters
    ----------
    results : List[Tuple[NDArrays, int]]
        One ``(weights, num_examples)`` pair per client.
    num_malicious : int
        Number of clients assumed to be malicious (``f``).
    to_keep : int
        If greater than 0, MultiKrum is applied: the ``to_keep`` clients with
        the lowest scores are passed to ``aggregate``. If ``to_keep`` exceeds
        the number of clients, all clients are kept. Otherwise classical Krum
        is applied.

    Returns
    -------
    NDArrays
        The weights of the single best-scoring client (Krum) or the result of
        ``aggregate`` over the selected clients (MultiKrum).

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    if not results:
        raise ValueError("aggregate_krum requires at least one client result")

    # Create a list of weights and ignore the number of examples
    weights = [client_weights for client_weights, _ in results]

    # Compute distances between vectors
    distance_matrix = _compute_distances(weights)

    # For each client, take the n-f-2 closest parameter vectors. The first
    # entry of every sorted row is skipped: it is the smallest distance, i.e.
    # the client's zero distance to itself.
    num_closest = max(1, len(weights) - num_malicious - 2)
    closest_indices = [
        np.argsort(row)[1 : num_closest + 1].tolist()  # noqa: E203
        for row in distance_matrix
    ]

    # Compute the score for each client, that is the sum of the distances
    # of the n-f-2 closest parameter vectors
    scores = [
        np.sum(distance_matrix[i, neighbours])
        for i, neighbours in enumerate(closest_indices)
    ]

    if to_keep > 0:
        # Choose the to_keep lowest-scoring clients (MultiKrum). Slicing the
        # ascending order from the front cannot produce a negative start
        # index, so to_keep > len(scores) keeps every client instead of
        # silently dropping some. The final reversal keeps the selected
        # clients in descending-score order.
        best_indices = np.argsort(scores)[:to_keep][::-1]
        best_results = [results[i] for i in best_indices]
        return aggregate(best_results)

    # Return the weights of the client which minimizes the score (Krum)
    return weights[int(np.argmin(scores))]


def weighted_loss_avg(results: List[Tuple[int, float]]) -> float:
"""Aggregate evaluation results obtained from multiple clients."""
num_total_evaluation_examples = sum([num_examples for num_examples, _ in results])
Expand All @@ -76,3 +114,21 @@ def aggregate_qffl(
updates.append(tmp)
new_parameters = [(u - v) * 1.0 for u, v in zip(parameters, updates)]
return new_parameters


def _compute_distances(weights: List[NDArrays]) -> NDArray:
    """Compute the pairwise squared Euclidean distances between clients.

    Parameters
    ----------
    weights : List[NDArrays]
        One list of arrays per client. Each client's arrays are flattened and
        concatenated into a single vector before distances are computed.

    Returns
    -------
    NDArray
        A ``(len(weights), len(weights))`` matrix whose entry ``[i, j]`` is
        the squared L2 distance between the flattened vectors of clients
        ``i`` and ``j``.
    """
    flat_w = np.array(
        [np.concatenate(p, axis=None).ravel() for p in weights]  # type: ignore
    )
    num_vectors = len(weights)
    distance_matrix = np.zeros((num_vectors, num_vectors))
    # The matrix is symmetric with a zero diagonal, so each unordered pair is
    # computed once and mirrored instead of evaluating all n*n combinations.
    for i in range(num_vectors):
        for j in range(i + 1, num_vectors):
            delta = flat_w[i] - flat_w[j]
            squared_norm = np.linalg.norm(delta) ** 2  # type: ignore
            distance_matrix[i, j] = squared_norm
            distance_matrix[j, i] = squared_norm
    return distance_matrix
167 changes: 167 additions & 0 deletions src/py/flwr/server/strategy/krum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright 2020 Adap GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent.

[Blanchard et al., 2017].

Paper: https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf
"""


from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.common import (
FitRes,
MetricsAggregationFn,
NDArrays,
Parameters,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_proxy import ClientProxy

from .aggregate import aggregate_krum
from .fedavg import FedAvg

# Warning text logged by `Krum.__init__` when `min_fit_clients` or
# `min_evaluate_clients` is larger than `min_available_clients`.
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Setting `min_available_clients` lower than `min_fit_clients` or
`min_evaluate_clients` can cause the server to fail when there are too few clients
connected to the server. `min_available_clients` must be set to a value larger
than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.
"""

# flake8: noqa: E501
class Krum(FedAvg):
    """FedAvg-derived strategy that aggregates fit results with Krum/MultiKrum."""

    # pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        num_malicious_clients: int = 0,
        num_clients_to_keep: int = 0,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
    ) -> None:
        """Set up the Krum strategy.

        All arguments except ``num_malicious_clients`` and
        ``num_clients_to_keep`` are forwarded to ``FedAvg.__init__``.

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        num_malicious_clients : int, optional
            Number of malicious clients in the system. Defaults to 0.
        num_clients_to_keep : int, optional
            Number of clients to keep before averaging (MultiKrum).
            Defaults to 0, in which case classical Krum is applied.
        evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]]
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not accept rounds containing failures. Defaults to True.
        initial_parameters : Parameters, optional
            Initial global model parameters.
        fit_metrics_aggregation_fn : Optional[MetricsAggregationFn]
            Function used to aggregate fit metrics. Defaults to None.
        evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn]
            Function stored for aggregating evaluation metrics.
            Defaults to None.
        """
        # Warn when either per-round minimum exceeds the availability minimum.
        largest_round_minimum = max(min_fit_clients, min_evaluate_clients)
        if largest_round_minimum > min_available_clients:
            log(WARNING, WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW)

        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn
        self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn
        self.num_malicious_clients = num_malicious_clients
        self.num_clients_to_keep = num_clients_to_keep

    def __repr__(self) -> str:
        """Return a short description showing the ``accept_failures`` flag."""
        return f"Krum(accept_failures={self.accept_failures})"

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results using Krum.

        Returns ``(None, {})`` when there are no results, or when failures
        occurred and ``accept_failures`` is False. Otherwise returns the
        parameters produced by ``aggregate_krum`` together with the aggregated
        fit metrics (an empty dict if no aggregation function is set).
        """
        # Nothing to aggregate, or a failed round that must not be accepted.
        if not results or (failures and not self.accept_failures):
            return None, {}

        # Convert each client's parameters to ndarrays, paired with its
        # number of examples.
        weights_results = []
        for _, fit_res in results:
            client_ndarrays = parameters_to_ndarrays(fit_res.parameters)
            weights_results.append((client_ndarrays, fit_res.num_examples))

        krum_ndarrays = aggregate_krum(
            weights_results, self.num_malicious_clients, self.num_clients_to_keep
        )
        parameters_aggregated = ndarrays_to_parameters(krum_ndarrays)

        # Aggregate custom metrics if aggregation fn was provided
        metrics_fn = self.fit_metrics_aggregation_fn
        if not metrics_fn:
            if server_round == 1:  # Only log this warning once
                log(WARNING, "No fit_metrics_aggregation_fn provided")
            return parameters_aggregated, {}

        fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
        return parameters_aggregated, metrics_fn(fit_metrics)
Loading