Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Introduce RunTracker #3561

Merged
merged 1 commit into from
Jun 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 45 additions & 20 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# ==============================================================================
"""Flower client app."""

import signal
import sys
import time
from dataclasses import dataclass
from logging import DEBUG, ERROR, INFO, WARN
from typing import Callable, ContextManager, Optional, Tuple, Type, Union

Expand All @@ -37,7 +39,7 @@
)
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, exponential
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential

from .grpc_client.connection import grpc_connection
from .grpc_rere_client.connection import grpc_request_response
Expand Down Expand Up @@ -263,6 +265,29 @@ def _load_client_app() -> ClientApp:
transport, server_address
)

run_tracker = _RunTracker()

# Retry callback passed to RetryInvoker as `on_success`.
# Logs at INFO level when the successful call needed more than one try, and
# invokes `run_tracker.create_node` when that callback has been set.
# NOTE(review): the identifier is spelled `_on_sucess` (sic); it is referenced
# as `on_success=_on_sucess`, so any rename must change both places together.
# NOTE(review): indentation is not preserved in this view, so it cannot be
# determined here whether the `create_node` check is nested inside the
# `tries > 1` branch or runs on every success -- confirm against the file.
def _on_sucess(retry_state: RetryState) -> None:
# Only report when at least one retry happened.
if retry_state.tries > 1:
log(
INFO,
"Connection successful after %.2f seconds and %s tries.",
retry_state.elapsed_time,
retry_state.tries,
)
# `create_node` is Optional on the tracker; skip the call while it is unset.
if run_tracker.create_node:
run_tracker.create_node()

def _on_backoff(retry_state: RetryState) -> None:
    """Log a failed connection attempt before the next retry.

    The first failure is reported at WARN level; every later failure is
    reported at DEBUG level together with `retry_state.actual_wait`.
    """
    is_first_failure = retry_state.tries == 1
    if is_first_failure:
        log(WARN, "Connection attempt failed, retrying...")
        return

    # Subsequent failures are logged at DEBUG and include the wait time.
    log(
        DEBUG,
        "Connection attempt failed, retrying in %.2f seconds",
        retry_state.actual_wait,
    )

retry_invoker = RetryInvoker(
wait_gen_factory=exponential,
recoverable_exceptions=connection_error_type,
Expand All @@ -278,25 +303,8 @@ def _load_client_app() -> ClientApp:
if retry_state.tries > 1
else None
),
on_success=lambda retry_state: (
log(
INFO,
"Connection successful after %.2f seconds and %s tries.",
retry_state.elapsed_time,
retry_state.tries,
)
if retry_state.tries > 1
else None
),
on_backoff=lambda retry_state: (
log(WARN, "Connection attempt failed, retrying...")
if retry_state.tries == 1
else log(
DEBUG,
"Connection attempt failed, retrying in %.2f seconds",
retry_state.actual_wait,
)
),
on_success=_on_sucess,
on_backoff=_on_backoff,
)

node_state = NodeState()
Expand Down Expand Up @@ -579,3 +587,20 @@ def _init_connection(transport: Optional[str], server_address: str) -> Tuple[
)

return connection, address, error_type


@dataclass
class _RunTracker:
    """Mutable state shared between the client run and its callbacks.

    `create_node` is an optional zero-argument callback and `interrupt`
    records whether an exit signal has been received.
    """

    create_node: Optional[Callable[[], None]] = None
    interrupt: bool = False

    def register_signal_handler(self) -> None:
        """Install a handler for SIGINT and SIGTERM.

        The handler sets `interrupt` to True and raises `StopIteration`
        with the exception context suppressed.
        """

        def _handle_exit(signum, stack_frame):  # type: ignore
            # pylint: disable=unused-argument
            self.interrupt = True
            raise StopIteration from None

        # Same handler for both exit signals, registered in the same order.
        for exit_signal in (signal.SIGINT, signal.SIGTERM):
            signal.signal(exit_signal, _handle_exit)