Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add override config to Run #3730

Merged
merged 12 commits into from
Jul 8, 2024
2 changes: 1 addition & 1 deletion src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def _on_backoff(retry_state: RetryState) -> None:
run_info[run_id] = get_run(run_id)
# If get_run is None, i.e., in grpc-bidi mode
else:
run_info[run_id] = Run(run_id, "", "")
run_info[run_id] = Run(run_id, "", "", {})

# Register context for this run
node_state.register_context(run_id=run_id)
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def get_run(run_id: int) -> Run:
run_id,
get_run_response.run.fab_id,
get_run_response.run.fab_version,
dict(get_run_response.run.override_config.items()),
)

try:
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,12 +352,13 @@ def get_run(run_id: int) -> Run:
# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
if res is None:
return Run(run_id, "", "")
return Run(run_id, "", "", {})

return Run(
run_id,
res.run.fab_id,
res.run.fab_version,
dict(res.run.override_config.items()),
)

try:
Expand Down
1 change: 1 addition & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,4 @@ class Run:
run_id: int
fab_id: str
fab_version: str
override_config: Dict[str, str]
1 change: 1 addition & 0 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _init_run(self) -> None:
run_id=res.run.run_id,
fab_id=res.run.fab_id,
fab_version=res.run.fab_version,
override_config=dict(res.run.override_config.items()),
)

@property
Expand Down
10 changes: 7 additions & 3 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ def setUp(self) -> None:
for _ in range(self.num_nodes)
]
self.state.get_run.return_value = Run(
run_id=61016, fab_id="mock/mock", fab_version="v1.0.0"
run_id=61016,
fab_id="mock/mock",
fab_version="v1.0.0",
override_config={"test_key": "test_value"},
)
state_factory = MagicMock(state=lambda: self.state)
self.driver = InMemoryDriver(run_id=61016, state_factory=state_factory)
Expand All @@ -98,6 +101,7 @@ def test_get_run(self) -> None:
self.assertEqual(self.driver.run.run_id, 61016)
self.assertEqual(self.driver.run.fab_id, "mock/mock")
self.assertEqual(self.driver.run.fab_version, "v1.0.0")
self.assertEqual(self.driver.run.override_config["test_key"], "test_value")

def test_get_nodes(self) -> None:
"""Test retrieval of nodes."""
Expand Down Expand Up @@ -223,7 +227,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
# Prepare
state = StateFactory("").state()
self.driver = InMemoryDriver(
state.create_run("", ""), MagicMock(state=lambda: state)
state.create_run("", "", {}), MagicMock(state=lambda: state)
)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, SqliteState)
Expand All @@ -249,7 +253,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
# Prepare
state_factory = StateFactory(":flwr-in-memory-state:")
state = state_factory.state()
self.driver = InMemoryDriver(state.create_run("", ""), state_factory)
self.driver = InMemoryDriver(state.create_run("", "", {}), state_factory)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
assert isinstance(state, InMemoryState)

Expand Down
6 changes: 5 additions & 1 deletion src/py/flwr/server/superlink/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ def CreateRun(
"""Create run ID."""
log(DEBUG, "DriverServicer.CreateRun")
state: State = self.state_factory.state()
run_id = state.create_run(request.fab_id, request.fab_version)
run_id = state.create_run(
request.fab_id,
request.fab_version,
dict(request.override_config.items()),
)
return CreateRunResponse(run_id=run_id)

def PushTaskIns(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_successful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "")
run_id = self.state.create_run("", "", {})
request = GetRunRequest(run_id=run_id)
shared_secret = generate_shared_key(
self._client_private_key, self._server_public_key
Expand Down Expand Up @@ -359,7 +359,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._client_public_key)
)
run_id = self.state.create_run("", "")
run_id = self.state.create_run("", "", {})
request = GetRunRequest(run_id=run_id)
client_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(client_private_key, self._server_public_key)
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def register_messages_into_state(
) -> Dict[UUID, float]:
"""Register `num_messages` into the state factory."""
state: InMemoryState = state_factory.state() # type: ignore
state.run_ids[run_id] = Run(run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0")
state.run_ids[run_id] = Run(
run_id=run_id, fab_id="Mock/mock", fab_version="v1.0.0", override_config={}
)
# Artificially add TaskIns to state so they can be processed
# by the Simulation Engine logic
nodes_cycle = cycle(nodes_mapping.keys()) # we have more messages than supernodes
Expand Down
12 changes: 10 additions & 2 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,15 +275,23 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
"""Retrieve stored `node_id` filtered by `client_public_key`."""
return self.public_key_to_node_id.get(client_public_key)

def create_run(self, fab_id: str, fab_version: str) -> int:
def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
) -> int:
"""Create a new run for the specified `fab_id`, `fab_version` and `override_config`."""
# Sample a random int64 as run_id
with self.lock:
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

if run_id not in self.run_ids:
self.run_ids[run_id] = Run(
run_id=run_id, fab_id=fab_id, fab_version=fab_version
run_id=run_id,
fab_id=fab_id,
fab_version=fab_version,
override_config=override_config,
)
return run_id
log(ERROR, "Unexpected run creation failure.")
Expand Down
29 changes: 22 additions & 7 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""SQLite based implementation of server state."""


import json
import re
import sqlite3
import time
Expand Down Expand Up @@ -61,9 +62,10 @@

SQL_CREATE_TABLE_RUN = """
CREATE TABLE IF NOT EXISTS run(
run_id INTEGER UNIQUE,
fab_id TEXT,
fab_version TEXT
run_id INTEGER UNIQUE,
fab_id TEXT,
fab_version TEXT,
override_config TEXT
);
"""

Expand Down Expand Up @@ -613,7 +615,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
return node_id
return None

def create_run(self, fab_id: str, fab_version: str) -> int:
def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
) -> int:
"""Create a new run for the specified `fab_id`, `fab_version` and `override_config`."""
# Sample a random int64 as run_id
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)
Expand All @@ -622,8 +629,13 @@ def create_run(self, fab_id: str, fab_version: str) -> int:
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
# If run_id does not exist
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
query = "INSERT INTO run (run_id, fab_id, fab_version) VALUES (?, ?, ?);"
self.query(query, (run_id, fab_id, fab_version))
query = (
"INSERT INTO run (run_id, fab_id, fab_version, override_config)"
"VALUES (?, ?, ?, ?);"
)
self.query(
query, (run_id, fab_id, fab_version, json.dumps(override_config))
)
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -687,7 +699,10 @@ def get_run(self, run_id: int) -> Optional[Run]:
try:
row = self.query(query, (run_id,))[0]
return Run(
run_id=run_id, fab_id=row["fab_id"], fab_version=row["fab_version"]
run_id=run_id,
fab_id=row["fab_id"],
fab_version=row["fab_version"],
override_config=json.loads(row["override_config"]),
)
except sqlite3.IntegrityError:
log(ERROR, "`run_id` does not exist.")
Expand Down
9 changes: 7 additions & 2 deletions src/py/flwr/server/superlink/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


import abc
from typing import List, Optional, Set
from typing import Dict, List, Optional, Set
from uuid import UUID

from flwr.common.typing import Run
Expand Down Expand Up @@ -157,7 +157,12 @@ def get_node_id(self, client_public_key: bytes) -> Optional[int]:
"""Retrieve stored `node_id` filtered by `client_public_key`."""

@abc.abstractmethod
def create_run(self, fab_id: str, fab_version: str) -> int:
def create_run(
self,
fab_id: str,
fab_version: str,
override_config: Dict[str, str],
) -> int:
"""Create a new run for the specified `fab_id`, `fab_version` and `override_config`."""

@abc.abstractmethod
Expand Down
Loading