Skip to content

Commit

Permalink
feat(framework) Make NodeState capture partition-id (#3695)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Jun 29, 2024
1 parent 14913bf commit 059c9eb
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 8 deletions.
6 changes: 5 additions & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def _start_client_internal(
] = None,
max_retries: Optional[int] = None,
max_wait_time: Optional[float] = None,
partition_id: Optional[int] = None,
) -> None:
"""Start a Flower client node which connects to a Flower server.
Expand Down Expand Up @@ -234,6 +235,9 @@ class `flwr.client.Client` (default: None)
The maximum duration before the client stops trying to
connect to the server in case of connection error.
If set to None, there is no limit to the total time.
partitioni_id: Optional[int] (default: None)
The data partition index associated with this node. Better suited for
prototyping purposes.
"""
if insecure is None:
insecure = root_certificates is None
Expand Down Expand Up @@ -309,7 +313,7 @@ def _on_backoff(retry_state: RetryState) -> None:
on_backoff=_on_backoff,
)

node_state = NodeState()
node_state = NodeState(partition_id=partition_id)
# run_id -> (fab_id, fab_version)
run_info: Dict[int, Tuple[str, str]] = {}

Expand Down
9 changes: 6 additions & 3 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,25 @@
"""Node state."""


from typing import Any, Dict
from typing import Any, Dict, Optional

from flwr.common import Context, RecordSet


class NodeState:
"""State of a node where client nodes execute runs."""

def __init__(self) -> None:
def __init__(self, partition_id: Optional[int]) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
self.run_contexts: Dict[int, Context] = {}
self._partition_id = partition_id

def register_context(self, run_id: int) -> None:
"""Register new run context for this node."""
if run_id not in self.run_contexts:
self.run_contexts[run_id] = Context(state=RecordSet())
self.run_contexts[run_id] = Context(
state=RecordSet(), partition_id=self._partition_id
)

def retrieve_context(self, run_id: int) -> Context:
"""Get run context given a run_id."""
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_multirun_in_node_state() -> None:
expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}

# NodeState
node_state = NodeState()
node_state = NodeState(partition_id=None)

for task in tasks:
run_id = task.run_id
Expand Down
8 changes: 8 additions & 0 deletions src/py/flwr/client/supernode/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def run_supernode() -> None:
authentication_keys=authentication_keys,
max_retries=args.max_retries,
max_wait_time=args.max_wait_time,
partition_id=args.partition_id,
)

# Graceful shutdown
Expand Down Expand Up @@ -373,6 +374,13 @@ def _parse_args_common(parser: argparse.ArgumentParser) -> None:
type=str,
help="The SuperNode's public key (as a path str) to enable authentication.",
)
parser.add_argument(
"--partition-id",
type=int,
help="The data partition index associated with this SuperNode. Better suited "
"for prototyping purposes where a SuperNode might only load a fraction of an "
"artificially partitioned dataset (e.g. using `flwr-datasets`)",
)


def _try_setup_client_authentication(
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/superlink/fleet/vce/vce_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,8 @@ def start_vce(

# Construct mapping of NodeStates
node_states: Dict[int, NodeState] = {}
for node_id in nodes_mapping:
node_states[node_id] = NodeState()
for node_id, partition_id in nodes_mapping.items():
node_states[node_id] = NodeState(partition_id=partition_id)

# Load backend config
log(DEBUG, "Supported backends: %s", list(supported_backends.keys()))
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/simulation/ray_transport/ray_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _load_app() -> ClientApp:

self.app_fn = _load_app
self.actor_pool = actor_pool
self.proxy_state = NodeState()
self.proxy_state = NodeState(partition_id=int(self.cid))

def _submit_job(self, message: Message, timeout: Optional[float]) -> Message:
"""Sumbit a message to the ActorPool."""
Expand Down

0 comments on commit 059c9eb

Please sign in to comment.