Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename some gRPC components to improve consistency #1491

Merged
merged 3 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 25 additions & 26 deletions src/py/flwr/server/grpc_server/flower_service_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
"""Servicer for FlowerService.

Relevant knowledge for reading this modules code:
- https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
- https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
"""


from typing import Callable, Iterator

import grpc
Expand All @@ -25,36 +27,34 @@
from flwr.proto import transport_pb2_grpc
from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.client_manager import ClientManager
from flwr.server.grpc_server.grpc_bridge import GRPCBridge, InsWrapper, ResWrapper
from flwr.server.grpc_server.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper
from flwr.server.grpc_server.grpc_client_proxy import GrpcClientProxy


def default_bridge_factory() -> GRPCBridge:
"""Return GRPCBridge instance."""
return GRPCBridge()
def default_bridge_factory() -> GrpcBridge:
"""Return GrpcBridge instance."""
return GrpcBridge()


def default_grpc_client_factory(cid: str, bridge: GRPCBridge) -> GrpcClientProxy:
def default_grpc_client_proxy_factory(cid: str, bridge: GrpcBridge) -> GrpcClientProxy:
"""Return GrpcClientProxy instance."""
return GrpcClientProxy(cid=cid, bridge=bridge)


def register_client(
def register_client_proxy(
client_manager: ClientManager,
client: GrpcClientProxy,
client_proxy: GrpcClientProxy,
context: grpc.ServicerContext,
) -> bool:
"""Try registering GrpcClientProxy with ClientManager."""
is_success = client_manager.register(client)

is_success = client_manager.register(client_proxy)
if is_success:

def rpc_termination_callback() -> None:
client.bridge.close()
client_manager.unregister(client)
client_proxy.bridge.close()
client_manager.unregister(client_proxy)

context.add_callback(rpc_termination_callback)

return is_success


Expand All @@ -64,33 +64,32 @@ class FlowerServiceServicer(transport_pb2_grpc.FlowerServiceServicer):
def __init__(
self,
client_manager: ClientManager,
grpc_bridge_factory: Callable[[], GRPCBridge] = default_bridge_factory,
grpc_client_factory: Callable[
[str, GRPCBridge], GrpcClientProxy
] = default_grpc_client_factory,
grpc_bridge_factory: Callable[[], GrpcBridge] = default_bridge_factory,
grpc_client_proxy_factory: Callable[
[str, GrpcBridge], GrpcClientProxy
] = default_grpc_client_proxy_factory,
) -> None:
self.client_manager: ClientManager = client_manager
self.grpc_bridge_factory = grpc_bridge_factory
self.client_factory = grpc_client_factory
self.client_proxy_factory = grpc_client_proxy_factory

def Join( # pylint: disable=invalid-name
self,
request_iterator: Iterator[ClientMessage],
context: grpc.ServicerContext,
) -> Iterator[ServerMessage]:
"""Method will be invoked by each GrpcClientProxy which participates in
the network.
"""Invoked by each gRPC client which participates in the network.

Protocol:
- The first message is sent from the server to the client
- Both ServerMessage and ClientMessage are message "wrappers"
wrapping the actual message
- The Join method is (pretty much) protocol unaware
- The first message is sent from the server to the client
- Both `ServerMessage` and `ClientMessage` are message "wrappers"
wrapping the actual message
- The `Join` method is (pretty much) unaware of the protocol
"""
peer: str = context.peer()
bridge = self.grpc_bridge_factory()
client = self.client_factory(peer, bridge)
is_success = register_client(self.client_manager, client, context)
client_proxy = self.client_proxy_factory(peer, bridge)
is_success = register_client_proxy(self.client_manager, client_proxy, context)

if is_success:
# Get iterators
Expand Down
22 changes: 12 additions & 10 deletions src/py/flwr/server/grpc_server/flower_service_servicer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.
# ==============================================================================
"""Tests for FlowerServiceServicer."""


import unittest
from unittest.mock import MagicMock, call

from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.grpc_server.flower_service_servicer import (
FlowerServiceServicer,
register_client,
register_client_proxy,
)
from flwr.server.grpc_server.grpc_bridge import InsWrapper, ResWrapper

Expand Down Expand Up @@ -53,7 +55,7 @@ def setUp(self) -> None:
]
self.ins_wrapper_iterator = iter(self.ins_wrappers)

# Mock for GRPCBridge
# Mock for GrpcBridge
self.grpc_bridge_mock = MagicMock()
self.grpc_bridge_mock.ins_wrapper_iterator.return_value = (
self.ins_wrapper_iterator
Expand All @@ -67,20 +69,20 @@ def setUp(self) -> None:
self.grpc_client_proxy_mock = MagicMock()
self.grpc_client_proxy_mock.cid = CLIENT_CID

self.client_factory_mock = MagicMock()
self.client_factory_mock.return_value = self.grpc_client_proxy_mock
self.client_proxy_factory_mock = MagicMock()
self.client_proxy_factory_mock.return_value = self.grpc_client_proxy_mock

self.client_manager_mock = MagicMock()

def test_register_client(self) -> None:
"""Test register_client function."""
def test_register_client_proxy(self) -> None:
"""Test register_client_proxy function."""
# Prepare
self.client_manager_mock.register.return_value = True

# Execute
register_client(
register_client_proxy(
client_manager=self.client_manager_mock,
client=self.grpc_client_proxy_mock,
client_proxy=self.grpc_client_proxy_mock,
context=self.context_mock,
)

Expand All @@ -107,7 +109,7 @@ def test_join(self) -> None:
servicer = FlowerServiceServicer(
client_manager=self.client_manager_mock,
grpc_bridge_factory=self.grpc_bridge_factory_mock,
grpc_client_factory=self.client_factory_mock,
grpc_client_proxy_factory=self.client_proxy_factory_mock,
)

# Execute
Expand All @@ -124,7 +126,7 @@ def test_join(self) -> None:
assert len(self.client_messages) == num_server_messages
assert self.grpc_client_proxy_mock.cid == CLIENT_CID

self.client_factory_mock.assert_called_once_with(
self.client_proxy_factory_mock.assert_called_once_with(
CLIENT_CID, self.grpc_bridge_mock
)

Expand Down
13 changes: 7 additions & 6 deletions src/py/flwr/server/grpc_server/grpc_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides class GRPCBridge."""
"""Provides class GrpcBridge."""


from dataclasses import dataclass
from enum import Enum
Expand All @@ -37,8 +38,8 @@ class ResWrapper:
client_message: ClientMessage


class GRPCBridgeClosed(Exception):
"""Error signaling that GRPCBridge is closed."""
class GrpcBridgeClosed(Exception):
"""Error signaling that GrpcBridge is closed."""


class Status(Enum):
Expand All @@ -51,8 +52,8 @@ class Status(Enum):
CLOSED = 5


class GRPCBridge:
"""GRPCBridge holding res_wrapper and ins_wrapper.
class GrpcBridge:
"""GrpcBridge holding res_wrapper and ins_wrapper.

For understanding this class it is recommended to understand how
the threading.Condition class works. See here:
Expand All @@ -74,7 +75,7 @@ def _is_closed(self) -> bool:

def _raise_if_closed(self) -> None:
if self._status == Status.CLOSED:
raise GRPCBridgeClosed()
raise GrpcBridgeClosed()

def _transition(self, next_status: Status) -> None:
"""Validate status transition and set next status.
Expand Down
26 changes: 13 additions & 13 deletions src/py/flwr/server/grpc_server/grpc_bridge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for GRPCBridge class."""
"""Tests for GrpcBridge class."""


import time
Expand All @@ -21,15 +21,15 @@

from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.grpc_server.grpc_bridge import (
GRPCBridge,
GRPCBridgeClosed,
GrpcBridge,
GrpcBridgeClosed,
InsWrapper,
ResWrapper,
)


def start_worker(
rounds: int, bridge: GRPCBridge, results: List[ClientMessage]
rounds: int, bridge: GrpcBridge, results: List[ClientMessage]
) -> Thread:
"""Simulate processing loop with five calls."""

Expand All @@ -41,7 +41,7 @@ def _worker() -> None:
res_wrapper = bridge.request(
InsWrapper(server_message=ServerMessage(), timeout=None)
)
except GRPCBridgeClosed:
except GrpcBridgeClosed:
break

results.append(res_wrapper.client_message)
Expand All @@ -58,7 +58,7 @@ def test_workflow_successful() -> None:
rounds = 5
client_messages_received: List[ClientMessage] = []

bridge = GRPCBridge()
bridge = GrpcBridge()
ins_wrapper_iterator = bridge.ins_wrapper_iterator()

worker_thread = start_worker(rounds, bridge, client_messages_received)
Expand Down Expand Up @@ -88,12 +88,12 @@ def test_workflow_close() -> None:
rounds = 5
client_messages_received: List[ClientMessage] = []

bridge = GRPCBridge()
bridge = GrpcBridge()
ins_wrapper_iterator = bridge.ins_wrapper_iterator()

worker_thread = start_worker(rounds, bridge, client_messages_received)

raised_error: Union[GRPCBridgeClosed, StopIteration, None] = None
raised_error: Union[GrpcBridgeClosed, StopIteration, None] = None

# Execute
for i in range(rounds):
Expand All @@ -109,7 +109,7 @@ def test_workflow_close() -> None:
# on next invocation.
bridge.close()

except GRPCBridgeClosed as err:
except GrpcBridgeClosed as err:
raised_error = err
break
except StopIteration as err:
Expand All @@ -133,12 +133,12 @@ def test_ins_wrapper_iterator_close_while_blocking() -> None:
rounds = 5
client_messages_received: List[ClientMessage] = []

bridge = GRPCBridge()
bridge = GrpcBridge()
ins_wrapper_iterator = bridge.ins_wrapper_iterator()

worker_thread = start_worker(rounds, bridge, client_messages_received)

raised_error: Union[GRPCBridgeClosed, StopIteration, None] = None
raised_error: Union[GrpcBridgeClosed, StopIteration, None] = None

def close_bridge_delayed(secs: int) -> None:
"""Close brige after {secs} second(s)."""
Expand All @@ -160,7 +160,7 @@ def close_bridge_delayed(secs: int) -> None:
if i < 2:
bridge.set_res_wrapper(ResWrapper(ClientMessage()))

except GRPCBridgeClosed as err:
except GrpcBridgeClosed as err:
raised_error = err
break
except StopIteration as err:
Expand All @@ -172,4 +172,4 @@ def close_bridge_delayed(secs: int) -> None:

# Assert
assert len(client_messages_received) == 2
assert isinstance(raised_error, GRPCBridgeClosed)
assert isinstance(raised_error, GrpcBridgeClosed)
7 changes: 4 additions & 3 deletions src/py/flwr/server/grpc_server/grpc_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,23 @@
# ==============================================================================
"""gRPC-based Flower ClientProxy implementation."""


from typing import Optional

from flwr import common
from flwr.common import serde
from flwr.proto.transport_pb2 import ClientMessage, ServerMessage
from flwr.server.client_proxy import ClientProxy
from flwr.server.grpc_server.grpc_bridge import GRPCBridge, InsWrapper, ResWrapper
from flwr.server.grpc_server.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper


class GrpcClientProxy(ClientProxy):
"""Flower client proxy which delegates over the network using gRPC."""
"""Flower ClientProxy that uses gRPC to delegate tasks over the network."""

def __init__(
self,
cid: str,
bridge: GRPCBridge,
bridge: GrpcBridge,
):
super().__init__(cid)
self.bridge = bridge
Expand Down