Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RetryInvoker #2553

Merged
merged 3 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions src/py/flwr/common/retry_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class RetryInvoker:

Parameters
----------
wait_strategy: Generator[float, None, None]
wait_factory: Callable[[], Generator[float, None, None]]
A generator yielding successive wait times in seconds. If the generator
is finite, the giveup event will be triggered when the generator raises
`StopIteration`.
Expand All @@ -129,11 +129,11 @@ class RetryInvoker:
data class object detailing the invocation.
on_giveup: Optional[Callable[[RetryState], None]] (default: None)
A callable to be executed in the event that `max_tries` or `max_time` is
exceeded, `should_giveup` returns True, or `wait_strategy` generator raises
exceeded, `should_giveup` returns True, or `wait_factory()` generator raises
`StopInteration`. The parameter is a data class object detailing the
invocation.
jitter: Optional[Callable[[float], float]] (default: full_jitter)
A function of the value yielded by `wait_strategy` returning the actual time
A function of the value yielded by `wait_factory()` returning the actual time
to wait. This function helps distribute wait times stochastically to avoid
timing collisions across concurrent clients. Wait times are jittered by
default using the `full_jitter` function. To disable jittering, pass
Expand All @@ -145,20 +145,20 @@ class RetryInvoker:

Examples
--------
Initialize a `RetryInvoker` with exponential backoff and call a function:
Initialize a `RetryInvoker` with exponential backoff and invoke a function:

>>> invoker = RetryInvoker(
>>> exponential(),
>>> grpc.RpcError,
>>> max_tries=3,
>>> max_time=None,
>>> )
... exponential, # Or use `lambda: exponential(3, 2)` to pass arguments
... grpc.RpcError,
... max_tries=3,
... max_time=None,
... )
>>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1)
"""

def __init__(
self,
wait_strategy: Generator[float, None, None],
wait_factory: Callable[[], Generator[float, None, None]],
recoverable_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]],
max_tries: Optional[int],
max_time: Optional[float],
Expand All @@ -169,7 +169,7 @@ def __init__(
jitter: Optional[Callable[[float], float]] = full_jitter,
should_giveup: Optional[Callable[[Exception], bool]] = None,
) -> None:
self.wait_strategy = wait_strategy
self.wait_factory = wait_factory
self.recoverable_exceptions = recoverable_exceptions
self.max_tries = max_tries
self.max_time = max_time
Expand All @@ -183,8 +183,8 @@ def __init__(
def invoke(
self,
target: Callable[..., Any],
*args: Tuple[Any, ...],
**kwargs: Dict[str, Any],
*args: Any,
**kwargs: Any,
) -> Any:
"""Safely invoke the provided callable with retry mechanisms.

Expand Down Expand Up @@ -212,12 +212,12 @@ def invoke(
------
Exception
If the number of tries exceeds `max_tries`, if the total time
exceeds `max_time`, if `wait_strategy` generator raises `StopInteration`,
exceeds `max_time`, if `wait_factory()` generator raises `StopInteration`,
or if the `should_giveup` returns True for a raised exception.

Notes
-----
The time between retries is determined by the provided `wait_strategy`
The time between retries is determined by the provided `wait_factory()`
generator and can optionally be jittered using the `jitter` function.
The recoverable exceptions that trigger a retry, as well as conditions to
stop retries, are also determined by the class's initialization parameters.
Expand All @@ -230,6 +230,7 @@ def try_call_event_handler(
handler(cast(RetryState, ref_state[0]))

try_cnt = 0
wait_generator = self.wait_factory()
start = time.time()
ref_state: List[Optional[RetryState]] = [None]

Expand Down Expand Up @@ -265,7 +266,7 @@ def giveup_check(_exception: Exception) -> bool:
raise

try:
wait_time = next(self.wait_strategy)
wait_time = next(wait_generator)
if self.jitter is not None:
wait_time = self.jitter(wait_time)
if self.max_time is not None:
Expand Down
22 changes: 14 additions & 8 deletions src/py/flwr/common/retry_invoker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_successful_invocation() -> None:
backoff_handler = Mock()
giveup_handler = Mock()
invoker = RetryInvoker(
constant(0.1),
lambda: constant(0.1),
ValueError,
max_tries=None,
max_time=None,
Expand All @@ -77,7 +77,7 @@ def test_failure() -> None:
"""Check termination when unexpected exception is raised."""
# Prepare
# `constant([0.1])` generator will raise `StopIteration` after one iteration.
invoker = RetryInvoker(constant(0.1), TypeError, None, None)
invoker = RetryInvoker(lambda: constant(0.1), TypeError, None, None)

# Execute and Assert
with pytest.raises(ValueError):
Expand All @@ -88,7 +88,11 @@ def test_failure_two_exceptions(mock_sleep: MagicMock) -> None:
"""Verify one retry on a specified iterable of exceptions."""
# Prepare
invoker = RetryInvoker(
constant(0.1), (TypeError, ValueError), max_tries=2, max_time=None, jitter=None
lambda: constant(0.1),
(TypeError, ValueError),
max_tries=2,
max_time=None,
jitter=None,
)

# Execute and Assert
Expand All @@ -101,7 +105,7 @@ def test_backoff_on_failure(mock_sleep: MagicMock) -> None:
"""Verify one retry on specified exception."""
# Prepare
# `constant([0.1])` generator will raise `StopIteration` after one iteration.
invoker = RetryInvoker(constant([0.1]), ValueError, None, None, jitter=None)
invoker = RetryInvoker(lambda: constant([0.1]), ValueError, None, None, jitter=None)

# Execute and Assert
with pytest.raises(ValueError):
Expand All @@ -114,7 +118,7 @@ def test_max_tries(mock_sleep: MagicMock) -> None:
# Prepare
# Disable `jitter` to ensure 0.1s wait time.
invoker = RetryInvoker(
constant(0.1), ValueError, max_tries=2, max_time=None, jitter=None
lambda: constant(0.1), ValueError, max_tries=2, max_time=None, jitter=None
)

# Execute and Assert
Expand All @@ -132,7 +136,9 @@ def test_max_time(mock_time: MagicMock, mock_sleep: MagicMock) -> None:
0.0,
3.0,
]
invoker = RetryInvoker(constant(2), ValueError, max_tries=None, max_time=2.5)
invoker = RetryInvoker(
lambda: constant(2), ValueError, max_tries=None, max_time=2.5
)

# Execute and Assert
with pytest.raises(ValueError):
Expand All @@ -148,7 +154,7 @@ def test_event_handlers() -> None:
backoff_handler = Mock()
giveup_handler = Mock()
invoker = RetryInvoker(
constant(0.1),
lambda: constant(0.1),
ValueError,
max_tries=2,
max_time=None,
Expand All @@ -173,7 +179,7 @@ def should_give_up(exc: Exception) -> bool:
return isinstance(exc, ValueError)

invoker = RetryInvoker(
constant(0.1), ValueError, None, None, should_giveup=should_give_up
lambda: constant(0.1), ValueError, None, None, should_giveup=should_give_up
)

# Execute and Assert
Expand Down