Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CreateWorkload #2251

Merged
merged 23 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions examples/mt-pytorch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,14 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload(
req=driver_pb2.CreateWorkloadRequest()
)
# -------------------------------------------------------------------------- Driver SDK

workload_id = create_workload_res.workload_id
print(f"Created workload id {workload_id}")

history = History()
for server_round in range(num_rounds):
print(f"Commencing server round {server_round + 1}")
Expand Down Expand Up @@ -83,7 +89,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# loop and wait until enough client nodes are available.
while True:
# Get a list of node ID's from the server
get_nodes_req = driver_pb2.GetNodesRequest()
get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id)

# ---------------------------------------------------------------------- Driver SDK
get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes(
Expand Down Expand Up @@ -117,7 +123,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
new_task_ins = task_pb2.TaskIns(
task_id="", # Do not set, will be created and set by the DriverAPI
group_id="",
workload_id="",
workload_id=workload_id,
task=task_pb2.Task(
producer=node_pb2.Node(
node_id=0,
Expand Down
9 changes: 8 additions & 1 deletion src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import "flwr/proto/node.proto";
import "flwr/proto/task.proto";

service Driver {
// Request workload_id
rpc CreateWorkload(CreateWorkloadRequest) returns (CreateWorkloadResponse) {}

// Return a set of nodes
rpc GetNodes(GetNodesRequest) returns (GetNodesResponse) {}

Expand All @@ -31,8 +34,12 @@ service Driver {
rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {}
}

// CreateWorkload
message CreateWorkloadRequest {}
message CreateWorkloadResponse { string workload_id = 1; }

// GetNodes messages
message GetNodesRequest {}
message GetNodesRequest { string workload_id = 1; }
message GetNodesResponse { repeated uint64 node_ids = 1; }

// PushTaskIns messages
Expand Down
13 changes: 13 additions & 0 deletions src/py/flwr/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ def disconnect(self) -> None:
channel.close()
log(INFO, "[Driver] Disconnected")

def create_workload(
self, req: driver_pb2.CreateWorkloadRequest
) -> driver_pb2.CreateWorkloadResponse:
"""Request for workload ID."""
# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")

# Call Driver API
res: driver_pb2.CreateWorkloadResponse = self.stub.CreateWorkload(request=req)
return res

def get_nodes(self, req: driver_pb2.GetNodesRequest) -> driver_pb2.GetNodesResponse:
"""Get client IDs."""
# Check if channel is open
Expand Down
8 changes: 7 additions & 1 deletion src/py/flwr/driver/driver_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class DriverClientManager(ClientManager):

def __init__(self, driver: Driver) -> None:
self.driver = driver
self.workload_id = driver.create_workload(
driver_pb2.CreateWorkloadRequest()
).workload_id
panh99 marked this conversation as resolved.
Show resolved Hide resolved
self.clients: Dict[str, ClientProxy] = {}

def __len__(self) -> int:
Expand Down Expand Up @@ -137,11 +140,14 @@ def _update_nodes(self) -> None:
node id is then converted into a `DriverClientProxy` instance and stored in the
`clients` dictionary with node id as key.
"""
get_nodes_res = self.driver.get_nodes(req=driver_pb2.GetNodesRequest())
get_nodes_res = self.driver.get_nodes(
req=driver_pb2.GetNodesRequest(workload_id=self.workload_id)
)
all_node_ids = get_nodes_res.node_ids
for node_id in all_node_ids:
self.clients[str(node_id)] = DriverClientProxy(
node_id=node_id,
driver=self.driver,
anonymous=False,
workload_id=self.workload_id,
)
5 changes: 3 additions & 2 deletions src/py/flwr/driver/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
class DriverClientProxy(ClientProxy):
"""Flower client proxy which delegates work using the Driver API."""

def __init__(self, node_id: int, driver: Driver, anonymous: bool):
def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: str):
super().__init__(str(node_id))
self.node_id = node_id
self.driver = driver
self.workload_id = workload_id
self.anonymous = anonymous

def get_properties(
Expand Down Expand Up @@ -103,7 +104,7 @@ def _send_receive_msg(
task_ins = task_pb2.TaskIns(
task_id="",
group_id="",
workload_id="",
workload_id=self.workload_id,
task=task_pb2.Task(
producer=node_pb2.Node(
node_id=0,
Expand Down
16 changes: 12 additions & 4 deletions src/py/flwr/driver/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def test_get_properties(self) -> None:
)
]
)
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
)
request_properties: Config = {"tensor_type": "str"}
ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns(
config=request_properties
Expand Down Expand Up @@ -95,7 +97,9 @@ def test_get_parameters(self) -> None:
)
]
)
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
)
get_parameters_ins = GetParametersIns(config={})

# Execute
Expand Down Expand Up @@ -129,7 +133,9 @@ def test_fit(self) -> None:
)
]
)
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
)
parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))])
ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {})

Expand Down Expand Up @@ -163,7 +169,9 @@ def test_evaluate(self) -> None:
)
]
)
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
)
parameters = flwr.common.Parameters(tensors=[], tensor_type="np")
evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {})

Expand Down
50 changes: 35 additions & 15 deletions src/py/flwr/proto/driver_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions src/py/flwr/proto/driver_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,34 @@ import typing_extensions

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

class CreateWorkloadRequest(google.protobuf.message.Message):
"""CreateWorkload"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(self,
) -> None: ...
global___CreateWorkloadRequest = CreateWorkloadRequest

class CreateWorkloadResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
WORKLOAD_ID_FIELD_NUMBER: builtins.int
workload_id: typing.Text
def __init__(self,
*,
workload_id: typing.Text = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ...
global___CreateWorkloadResponse = CreateWorkloadResponse

class GetNodesRequest(google.protobuf.message.Message):
"""GetNodes messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
WORKLOAD_ID_FIELD_NUMBER: builtins.int
workload_id: typing.Text
def __init__(self,
*,
workload_id: typing.Text = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ...
global___GetNodesRequest = GetNodesRequest

class GetNodesResponse(google.protobuf.message.Message):
Expand Down
34 changes: 34 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ def __init__(self, channel):
Args:
channel: A grpc.Channel.
"""
self.CreateWorkload = channel.unary_unary(
'/flwr.proto.Driver/CreateWorkload',
request_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString,
)
self.GetNodes = channel.unary_unary(
'/flwr.proto.Driver/GetNodes',
request_serializer=flwr_dot_proto_dot_driver__pb2.GetNodesRequest.SerializeToString,
Expand All @@ -34,6 +39,13 @@ def __init__(self, channel):
class DriverServicer(object):
"""Missing associated documentation comment in .proto file."""

def CreateWorkload(self, request, context):
"""Request workload_id
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetNodes(self, request, context):
"""Return a set of nodes
"""
Expand All @@ -58,6 +70,11 @@ def PullTaskRes(self, request, context):

def add_DriverServicer_to_server(servicer, server):
rpc_method_handlers = {
'CreateWorkload': grpc.unary_unary_rpc_method_handler(
servicer.CreateWorkload,
request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.FromString,
response_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.SerializeToString,
),
'GetNodes': grpc.unary_unary_rpc_method_handler(
servicer.GetNodes,
request_deserializer=flwr_dot_proto_dot_driver__pb2.GetNodesRequest.FromString,
Expand All @@ -83,6 +100,23 @@ def add_DriverServicer_to_server(servicer, server):
class Driver(object):
"""Missing associated documentation comment in .proto file."""

@staticmethod
def CreateWorkload(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateWorkload',
flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString,
flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetNodes(request,
target,
Expand Down
13 changes: 13 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import grpc

class DriverStub:
def __init__(self, channel: grpc.Channel) -> None: ...
CreateWorkload: grpc.UnaryUnaryMultiCallable[
flwr.proto.driver_pb2.CreateWorkloadRequest,
flwr.proto.driver_pb2.CreateWorkloadResponse]
"""Request workload_id"""

GetNodes: grpc.UnaryUnaryMultiCallable[
flwr.proto.driver_pb2.GetNodesRequest,
flwr.proto.driver_pb2.GetNodesResponse]
Expand All @@ -25,6 +30,14 @@ class DriverStub:


class DriverServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
def CreateWorkload(self,
request: flwr.proto.driver_pb2.CreateWorkloadRequest,
context: grpc.ServicerContext,
) -> flwr.proto.driver_pb2.CreateWorkloadResponse:
"""Request workload_id"""
pass

@abc.abstractmethod
def GetNodes(self,
request: flwr.proto.driver_pb2.GetNodesRequest,
Expand Down
13 changes: 12 additions & 1 deletion src/py/flwr/server/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from flwr.common.logger import log
from flwr.proto import driver_pb2_grpc
from flwr.proto.driver_pb2 import (
CreateWorkloadRequest,
CreateWorkloadResponse,
GetNodesRequest,
GetNodesResponse,
PullTaskResRequest,
Expand All @@ -48,9 +50,18 @@ def GetNodes(
"""Get available nodes."""
log(INFO, "DriverServicer.GetNodes")
state: State = self.state_factory.state()
all_ids: Set[int] = state.get_nodes()
all_ids: Set[int] = state.get_nodes(request.workload_id)
return GetNodesResponse(node_ids=list(all_ids))

def CreateWorkload(
self, request: CreateWorkloadRequest, context: grpc.ServicerContext
) -> CreateWorkloadResponse:
"""Create workload ID."""
log(INFO, "DriverServicer.CreateWorkload")
state: State = self.state_factory.state()
workload_id = state.create_workload()
return CreateWorkloadResponse(workload_id=workload_id)

def PushTaskIns(
self, request: PushTaskInsRequest, context: grpc.ServicerContext
) -> PushTaskInsResponse:
Expand Down
Loading
Loading