Skip to content

Commit

Permalink
Rename connection messages to improve consistency (#1259)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Jun 21, 2022
1 parent 6b3bd3e commit 342b5ee
Show file tree
Hide file tree
Showing 16 changed files with 193 additions and 191 deletions.
22 changes: 12 additions & 10 deletions doc/source/creating-new-messages.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ Within the :code:`ServerMessage` block:
repeated int64 l=2;
}
oneof msg {
Reconnect reconnect = 1;
GetParameters get_parameters = 2;
FitIns fit_ins = 3;
EvaluateIns evaluate_ins = 4;
ExampleIns example_ins = 5;
ReconnectIns reconnect_ins = 1;
GetPropertiesIns get_properties_ins = 2;
GetParametersIns get_parameters_ins = 3;
FitIns fit_ins = 4;
EvaluateIns evaluate_ins = 5;
ExampleIns example_ins = 6;
}
Within the ClientMessage block:
Expand All @@ -58,11 +59,12 @@ Within the ClientMessage block:
}
oneof msg {
Disconnect disconnect = 1;
ParametersRes parameters_res = 2;
FitRes fit_res = 3;
EvaluateRes evaluate_res = 4;
ExampleRes examples_res = 5;
DisconnectRes disconnect_res = 1;
GetPropertiesRes get_properties_res = 2;
GetParametersRes get_parameters_res = 3;
FitRes fit_res = 4;
EvaluateRes evaluate_res = 5;
ExampleRes examples_res = 6;
}
Make sure to also add a field of the newly created message type in :code:`oneof msg`.
Expand Down
10 changes: 5 additions & 5 deletions examples/android/client/app/src/main/proto/transport.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ enum Reason {
}

message ServerMessage {
message Reconnect { int64 seconds = 1; }
message ReconnectIns { int64 seconds = 1; }
message GetParameters {}
message FitIns {
Parameters parameters = 1;
Expand All @@ -51,7 +51,7 @@ message ServerMessage {
}
message PropertiesIns { map<string, Scalar> config = 1; }
oneof msg {
Reconnect reconnect = 1;
ReconnectIns reconnect = 1;
GetParameters get_parameters = 2;
FitIns fit_ins = 3;
EvaluateIns evaluate_ins = 4;
Expand All @@ -60,7 +60,7 @@ message ServerMessage {
}

message ClientMessage {
message Disconnect { Reason reason = 1; }
message DisconnectRes { Reason reason = 1; }
message ParametersRes { Parameters parameters = 1; }
message FitRes {
Parameters parameters = 1;
Expand All @@ -77,7 +77,7 @@ message ClientMessage {
}
message PropertiesRes { map<string, Scalar> properties = 1; }
oneof msg {
Disconnect disconnect = 1;
DisconnectRes disconnect = 1;
ParametersRes parameters_res = 2;
FitRes fit_res = 3;
EvaluateRes evaluate_res = 4;
Expand Down Expand Up @@ -107,4 +107,4 @@ message Scalar {
string string = 14;
bytes bytes = 15;
}
}
}
2 changes: 1 addition & 1 deletion examples/embedded_devices/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
# Return the number of evaluation examples and the evaluation result (loss)
metrics = {"accuracy": float(accuracy)}
return EvaluateRes(
num_examples=len(self.testset), loss=float(loss), metrics=metrics
loss=float(loss), num_examples=len(self.testset), metrics=metrics
)


Expand Down
16 changes: 8 additions & 8 deletions src/proto/flwr/proto/transport.proto
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ enum Reason {
}

message ServerMessage {
message Reconnect { int64 seconds = 1; }
message ReconnectIns { int64 seconds = 1; }
message GetPropertiesIns { map<string, Scalar> config = 1; }
message GetParametersIns { map<string, Scalar> config = 1; }
message FitIns {
Expand All @@ -54,7 +54,7 @@ message ServerMessage {
map<string, Scalar> config = 2;
}
oneof msg {
Reconnect reconnect = 1;
ReconnectIns reconnect_ins = 1;
GetPropertiesIns get_properties_ins = 2;
GetParametersIns get_parameters_ins = 3;
FitIns fit_ins = 4;
Expand All @@ -63,24 +63,24 @@ message ServerMessage {
}

message ClientMessage {
message Disconnect { Reason reason = 1; }
message DisconnectRes { Reason reason = 1; }
message GetPropertiesRes {
Status status = 1;
map<string, Scalar> properties = 2;
}
message GetParametersRes { Parameters parameters = 1; }
message FitRes {
Parameters parameters = 1;
int64 num_examples = 2;
map<string, Scalar> metrics = 5;
Parameters parameters = 2;
int64 num_examples = 3;
map<string, Scalar> metrics = 4;
}
message EvaluateRes {
int64 num_examples = 1;
float loss = 2;
int64 num_examples = 3;
map<string, Scalar> metrics = 4;
}
oneof msg {
Disconnect disconnect = 1;
DisconnectRes disconnect_res = 1;
GetPropertiesRes get_properties_res = 2;
GetParametersRes get_parameters_res = 3;
FitRes fit_res = 4;
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
EXPECTED_NUM_SERVER_MESSAGE = 10

SERVER_MESSAGE = ServerMessage()
SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect=ServerMessage.Reconnect())
SERVER_MESSAGE_RECONNECT = ServerMessage(reconnect_ins=ServerMessage.ReconnectIns())

CLIENT_MESSAGE = ClientMessage()
CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect=ClientMessage.Disconnect())
CLIENT_MESSAGE_DISCONNECT = ClientMessage(disconnect_res=ClientMessage.DisconnectRes())


def unused_tcp_port() -> int:
Expand Down Expand Up @@ -63,7 +63,7 @@ def mock_join( # type: ignore # pylint: disable=invalid-name

try:
client_message = next(request_iterator)
if client_message.HasField("disconnect"):
if client_message.HasField("disconnect_res"):
break
except StopIteration:
break
Expand Down Expand Up @@ -100,7 +100,7 @@ def run_client() -> int:
server_message = receive()

messages_received += 1
if server_message.HasField("reconnect"):
if server_message.HasField("reconnect_ins"):
send(CLIENT_MESSAGE_DISCONNECT)
break

Expand Down
14 changes: 7 additions & 7 deletions src/py/flwr/client/grpc_client/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def handle(
reconnect later (False).
"""
field = server_msg.WhichOneof("msg")
if field == "reconnect":
disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect)
if field == "reconnect_ins":
disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect_ins)
return disconnect_msg, sleep_duration, False
if field == "get_properties_ins":
return _get_properties(client, server_msg.get_properties_ins), 0, True
Expand All @@ -65,17 +65,17 @@ def handle(


def _reconnect(
reconnect_msg: ServerMessage.Reconnect,
reconnect_msg: ServerMessage.ReconnectIns,
) -> Tuple[ClientMessage, int]:
# Determine the reason for sending Disconnect message
# Determine the reason for sending DisconnectRes message
reason = Reason.ACK
sleep_duration = None
if reconnect_msg.seconds is not None:
reason = Reason.RECONNECT
sleep_duration = reconnect_msg.seconds
# Build Disconnect message
disconnect = ClientMessage.Disconnect(reason=reason)
return ClientMessage(disconnect=disconnect), sleep_duration
# Build DisconnectRes message
disconnect_res = ClientMessage.DisconnectRes(reason=reason)
return ClientMessage(disconnect_res=disconnect_res), sleep_duration


def _get_properties(
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .parameter import weights_to_parameters as weights_to_parameters
from .typing import Code as Code
from .typing import Config as Config
from .typing import Disconnect as Disconnect
from .typing import DisconnectRes as DisconnectRes
from .typing import EvaluateIns as EvaluateIns
from .typing import EvaluateRes as EvaluateRes
from .typing import FitIns as FitIns
Expand All @@ -34,7 +34,7 @@
from .typing import MetricsAggregationFn as MetricsAggregationFn
from .typing import Parameters as Parameters
from .typing import Properties as Properties
from .typing import Reconnect as Reconnect
from .typing import ReconnectIns as ReconnectIns
from .typing import Scalar as Scalar
from .typing import Status as Status
from .typing import Weights as Weights
Expand All @@ -45,7 +45,7 @@
"bytes_to_ndarray",
"Code",
"Config",
"Disconnect",
"DisconnectRes",
"EvaluateIns",
"EvaluateRes",
"FitIns",
Expand All @@ -61,7 +61,7 @@
"Parameters",
"parameters_to_weights",
"Properties",
"Reconnect",
"ReconnectIns",
"Scalar",
"Status",
"Weights",
Expand Down
44 changes: 22 additions & 22 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,45 +43,45 @@ def parameters_from_proto(msg: Parameters) -> typing.Parameters:
return typing.Parameters(tensors=tensors, tensor_type=msg.tensor_type)


# === Reconnect message ===
# === ReconnectIns message ===


def reconnect_to_proto(reconnect: typing.Reconnect) -> ServerMessage.Reconnect:
"""Serialize Reconnect to ProtoBuf message."""
if reconnect.seconds is not None:
return ServerMessage.Reconnect(seconds=reconnect.seconds)
return ServerMessage.Reconnect()
def reconnect_ins_to_proto(ins: typing.ReconnectIns) -> ServerMessage.ReconnectIns:
"""Serialize ReconnectIns to ProtoBuf message."""
if ins.seconds is not None:
return ServerMessage.ReconnectIns(seconds=ins.seconds)
return ServerMessage.ReconnectIns()


def reconnect_from_proto(msg: ServerMessage.Reconnect) -> typing.Reconnect:
"""Deserialize Reconnect from ProtoBuf message."""
return typing.Reconnect(seconds=msg.seconds)
def reconnect_ins_from_proto(msg: ServerMessage.ReconnectIns) -> typing.ReconnectIns:
"""Deserialize ReconnectIns from ProtoBuf message."""
return typing.ReconnectIns(seconds=msg.seconds)


# === Disconnect message ===
# === DisconnectRes message ===


def disconnect_to_proto(disconnect: typing.Disconnect) -> ClientMessage.Disconnect:
"""Serialize Disconnect to ProtoBuf message."""
def disconnect_res_to_proto(res: typing.DisconnectRes) -> ClientMessage.DisconnectRes:
"""Serialize DisconnectRes to ProtoBuf message."""
reason_proto = Reason.UNKNOWN
if disconnect.reason == "RECONNECT":
if res.reason == "RECONNECT":
reason_proto = Reason.RECONNECT
elif disconnect.reason == "POWER_DISCONNECTED":
elif res.reason == "POWER_DISCONNECTED":
reason_proto = Reason.POWER_DISCONNECTED
elif disconnect.reason == "WIFI_UNAVAILABLE":
elif res.reason == "WIFI_UNAVAILABLE":
reason_proto = Reason.WIFI_UNAVAILABLE
return ClientMessage.Disconnect(reason=reason_proto)
return ClientMessage.DisconnectRes(reason=reason_proto)


def disconnect_from_proto(msg: ClientMessage.Disconnect) -> typing.Disconnect:
"""Deserialize Disconnect from ProtoBuf message."""
def disconnect_res_from_proto(msg: ClientMessage.DisconnectRes) -> typing.DisconnectRes:
"""Deserialize DisconnectRes from ProtoBuf message."""
if msg.reason == Reason.RECONNECT:
return typing.Disconnect(reason="RECONNECT")
return typing.DisconnectRes(reason="RECONNECT")
if msg.reason == Reason.POWER_DISCONNECTED:
return typing.Disconnect(reason="POWER_DISCONNECTED")
return typing.DisconnectRes(reason="POWER_DISCONNECTED")
if msg.reason == Reason.WIFI_UNAVAILABLE:
return typing.Disconnect(reason="WIFI_UNAVAILABLE")
return typing.Disconnect(reason="UNKNOWN")
return typing.DisconnectRes(reason="WIFI_UNAVAILABLE")
return typing.DisconnectRes(reason="UNKNOWN")


# === GetParameters messages ===
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ class GetPropertiesRes:


@dataclass
class Reconnect:
"""Reconnect message from server to client."""
class ReconnectIns:
"""ReconnectIns message from server to client."""

seconds: Optional[int]


@dataclass
class Disconnect:
"""Disconnect message from client to server."""
class DisconnectRes:
"""DisconnectRes message from client to server."""

reason: str
Loading

0 comments on commit 342b5ee

Please sign in to comment.