Skip to content

Commit

Permalink
feat(framework) Introduce RunTracker (#3561)
Browse files Browse the repository at this point in the history
Co-authored-by: Charles Beauville <charels@flower.ai>
  • Loading branch information
jafermarq and Charles Beauville authored Jun 8, 2024
1 parent 70844c4 commit 018dd45
Showing 1 changed file with 45 additions and 20 deletions.
65 changes: 45 additions & 20 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# ==============================================================================
"""Flower client app."""

import signal
import sys
import time
from dataclasses import dataclass
from logging import DEBUG, ERROR, INFO, WARN
from typing import Callable, ContextManager, Optional, Tuple, Type, Union

Expand All @@ -37,7 +39,7 @@
)
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, exponential
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential

from .grpc_client.connection import grpc_connection
from .grpc_rere_client.connection import grpc_request_response
Expand Down Expand Up @@ -263,6 +265,29 @@ def _load_client_app() -> ClientApp:
transport, server_address
)

run_tracker = _RunTracker()

def _on_sucess(retry_state: RetryState) -> None:
if retry_state.tries > 1:
log(
INFO,
"Connection successful after %.2f seconds and %s tries.",
retry_state.elapsed_time,
retry_state.tries,
)
if run_tracker.create_node:
run_tracker.create_node()

def _on_backoff(retry_state: RetryState) -> None:
if retry_state.tries == 1:
log(WARN, "Connection attempt failed, retrying...")
else:
log(
DEBUG,
"Connection attempt failed, retrying in %.2f seconds",
retry_state.actual_wait,
)

retry_invoker = RetryInvoker(
wait_gen_factory=exponential,
recoverable_exceptions=connection_error_type,
Expand All @@ -278,25 +303,8 @@ def _load_client_app() -> ClientApp:
if retry_state.tries > 1
else None
),
on_success=lambda retry_state: (
log(
INFO,
"Connection successful after %.2f seconds and %s tries.",
retry_state.elapsed_time,
retry_state.tries,
)
if retry_state.tries > 1
else None
),
on_backoff=lambda retry_state: (
log(WARN, "Connection attempt failed, retrying...")
if retry_state.tries == 1
else log(
DEBUG,
"Connection attempt failed, retrying in %.2f seconds",
retry_state.actual_wait,
)
),
on_success=_on_sucess,
on_backoff=_on_backoff,
)

node_state = NodeState()
Expand Down Expand Up @@ -579,3 +587,20 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
)

return connection, address, error_type


@dataclass
class _RunTracker:
create_node: Optional[Callable[[], None]] = None
interrupt: bool = False

def register_signal_handler(self) -> None:
"""Register handlers for exit signals."""

def signal_handler(sig, frame): # type: ignore
# pylint: disable=unused-argument
self.interrupt = True
raise StopIteration from None

signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

0 comments on commit 018dd45

Please sign in to comment.