Skip to content

Commit

Permalink
Add CreateWorkload (#2251)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <daniel@flower.dev>
Co-authored-by: Charles Beauville <charles@adap.com>
  • Loading branch information
3 people authored Sep 7, 2023
1 parent 0c4289d commit 9e645a1
Show file tree
Hide file tree
Showing 17 changed files with 389 additions and 69 deletions.
10 changes: 8 additions & 2 deletions examples/mt-pytorch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,14 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload(
req=driver_pb2.CreateWorkloadRequest()
)
# -------------------------------------------------------------------------- Driver SDK

workload_id = create_workload_res.workload_id
print(f"Created workload id {workload_id}")

history = History()
for server_round in range(num_rounds):
print(f"Commencing server round {server_round + 1}")
Expand Down Expand Up @@ -83,7 +89,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
# loop and wait until enough client nodes are available.
while True:
# Get a list of node IDs from the server
get_nodes_req = driver_pb2.GetNodesRequest()
get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id)

# ---------------------------------------------------------------------- Driver SDK
get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes(
Expand Down Expand Up @@ -117,7 +123,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
new_task_ins = task_pb2.TaskIns(
task_id="", # Do not set, will be created and set by the DriverAPI
group_id="",
workload_id="",
workload_id=workload_id,
task=task_pb2.Task(
producer=node_pb2.Node(
node_id=0,
Expand Down
9 changes: 8 additions & 1 deletion src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import "flwr/proto/node.proto";
import "flwr/proto/task.proto";

service Driver {
// Request workload_id
rpc CreateWorkload(CreateWorkloadRequest) returns (CreateWorkloadResponse) {}

// Return a set of nodes
rpc GetNodes(GetNodesRequest) returns (GetNodesResponse) {}

Expand All @@ -31,8 +34,12 @@ service Driver {
rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {}
}

// CreateWorkload messages
// The request carries no fields; the response returns the workload_id string.
message CreateWorkloadRequest {}
message CreateWorkloadResponse { string workload_id = 1; }

// GetNodes messages
message GetNodesRequest {}
message GetNodesRequest { string workload_id = 1; }
message GetNodesResponse { repeated uint64 node_ids = 1; }

// PushTaskIns messages
Expand Down
13 changes: 13 additions & 0 deletions src/py/flwr/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ def disconnect(self) -> None:
channel.close()
log(INFO, "[Driver] Disconnected")

def create_workload(
    self, req: driver_pb2.CreateWorkloadRequest
) -> driver_pb2.CreateWorkloadResponse:
    """Forward a CreateWorkload request to the Driver API.

    Parameters
    ----------
    req : driver_pb2.CreateWorkloadRequest
        The request message passed to the stub's `CreateWorkload` call.

    Returns
    -------
    driver_pb2.CreateWorkloadResponse
        The response returned by the stub, unchanged.

    Raises
    ------
    Exception
        If `self.stub` is `None`, i.e. no stub is available to call.
    """
    stub = self.stub

    # Without a stub there is nothing to call: log the error, then fail.
    if stub is None:
        log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
        raise Exception("`Driver` instance not connected")

    response: driver_pb2.CreateWorkloadResponse = stub.CreateWorkload(request=req)
    return response

def get_nodes(self, req: driver_pb2.GetNodesRequest) -> driver_pb2.GetNodesResponse:
"""Get client IDs."""
# Check if channel is open
Expand Down
10 changes: 9 additions & 1 deletion src/py/flwr/driver/driver_client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class DriverClientManager(ClientManager):

def __init__(self, driver: Driver) -> None:
    """Keep a reference to `driver` and start with empty bookkeeping."""
    # Proxies are stored under the string form of their node ID.
    self.clients: Dict[str, ClientProxy] = {}
    # An empty string marks that no workload ID has been obtained yet.
    self.workload_id = ""
    self.driver = driver

def __len__(self) -> int:
Expand Down Expand Up @@ -137,11 +138,18 @@ def _update_nodes(self) -> None:
node id is then converted into a `DriverClientProxy` instance and stored in the
`clients` dictionary with node id as key.
"""
get_nodes_res = self.driver.get_nodes(req=driver_pb2.GetNodesRequest())
if self.workload_id == "":
self.workload_id = self.driver.create_workload(
driver_pb2.CreateWorkloadRequest()
).workload_id
get_nodes_res = self.driver.get_nodes(
req=driver_pb2.GetNodesRequest(workload_id=self.workload_id)
)
all_node_ids = get_nodes_res.node_ids
for node_id in all_node_ids:
self.clients[str(node_id)] = DriverClientProxy(
node_id=node_id,
driver=self.driver,
anonymous=False,
workload_id=self.workload_id,
)
5 changes: 3 additions & 2 deletions src/py/flwr/driver/driver_client_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
class DriverClientProxy(ClientProxy):
"""Flower client proxy which delegates work using the Driver API."""

def __init__(self, node_id: int, driver: Driver, anonymous: bool):
def __init__(self, node_id: int, driver: Driver, anonymous: bool, workload_id: str):
super().__init__(str(node_id))
self.node_id = node_id
self.driver = driver
self.workload_id = workload_id
self.anonymous = anonymous

def get_properties(
Expand Down Expand Up @@ -103,7 +104,7 @@ def _send_receive_msg(
task_ins = task_pb2.TaskIns(
task_id="",
group_id="",
workload_id="",
workload_id=self.workload_id,
task=task_pb2.Task(
producer=node_pb2.Node(
node_id=0,
Expand Down
16 changes: 12 additions & 4 deletions src/py/flwr/driver/driver_client_proxy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def test_get_properties(self) -> None:
)
]
)
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
)
request_properties: Config = {"tensor_type": "str"}
ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns(
config=request_properties
Expand Down Expand Up @@ -95,7 +97,9 @@ def test_get_parameters(self) -> None:
)
]
)
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
)
get_parameters_ins = GetParametersIns(config={})

# Execute
Expand Down Expand Up @@ -129,7 +133,9 @@ def test_fit(self) -> None:
)
]
)
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
)
parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))])
ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {})

Expand Down Expand Up @@ -163,7 +169,9 @@ def test_evaluate(self) -> None:
)
]
)
client = DriverClientProxy(node_id=1, driver=self.driver, anonymous=True)
client = DriverClientProxy(
node_id=1, driver=self.driver, anonymous=True, workload_id=""
)
parameters = flwr.common.Parameters(tensors=[], tensor_type="np")
evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {})

Expand Down
50 changes: 35 additions & 15 deletions src/py/flwr/proto/driver_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions src/py/flwr/proto/driver_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,34 @@ import typing_extensions

DESCRIPTOR: google.protobuf.descriptor.FileDescriptor

# Type stub for the CreateWorkloadRequest message: it declares no fields, so
# the constructor takes no arguments.
# NOTE(review): this file looks like generator output for driver.proto —
# confirm before editing by hand.
class CreateWorkloadRequest(google.protobuf.message.Message):
"""CreateWorkload"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(self,
) -> None: ...
global___CreateWorkloadRequest = CreateWorkloadRequest

# Type stub for the CreateWorkloadResponse message: a single text field,
# `workload_id`, settable only as a keyword argument.
# NOTE(review): this file looks like generator output for driver.proto —
# confirm before editing by hand.
class CreateWorkloadResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
WORKLOAD_ID_FIELD_NUMBER: builtins.int
workload_id: typing.Text
def __init__(self,
*,
workload_id: typing.Text = ...,
) -> None: ...
# ClearField accepts the field name either as str or as bytes.
def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ...
global___CreateWorkloadResponse = CreateWorkloadResponse

# Type stub for the GetNodesRequest message: a single text field,
# `workload_id`, settable only as a keyword argument.
# NOTE(review): this file looks like generator output for driver.proto —
# confirm before editing by hand.
class GetNodesRequest(google.protobuf.message.Message):
"""GetNodes messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
WORKLOAD_ID_FIELD_NUMBER: builtins.int
workload_id: typing.Text
def __init__(self,
*,
workload_id: typing.Text = ...,
) -> None: ...
# ClearField accepts the field name either as str or as bytes.
def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ...
global___GetNodesRequest = GetNodesRequest

class GetNodesResponse(google.protobuf.message.Message):
Expand Down
34 changes: 34 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ def __init__(self, channel):
Args:
channel: A grpc.Channel.
"""
self.CreateWorkload = channel.unary_unary(
'/flwr.proto.Driver/CreateWorkload',
request_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString,
)
self.GetNodes = channel.unary_unary(
'/flwr.proto.Driver/GetNodes',
request_serializer=flwr_dot_proto_dot_driver__pb2.GetNodesRequest.SerializeToString,
Expand All @@ -34,6 +39,13 @@ def __init__(self, channel):
class DriverServicer(object):
"""Missing associated documentation comment in .proto file."""

def CreateWorkload(self, request, context):
"""Request workload_id
"""
# Default handler: `request` is not read. The call is marked UNIMPLEMENTED on
# the gRPC context and NotImplementedError is raised.
# NOTE(review): presumably a concrete servicer overrides this — confirm.
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetNodes(self, request, context):
"""Return a set of nodes
"""
Expand All @@ -58,6 +70,11 @@ def PullTaskRes(self, request, context):

def add_DriverServicer_to_server(servicer, server):
rpc_method_handlers = {
'CreateWorkload': grpc.unary_unary_rpc_method_handler(
servicer.CreateWorkload,
request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.FromString,
response_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.SerializeToString,
),
'GetNodes': grpc.unary_unary_rpc_method_handler(
servicer.GetNodes,
request_deserializer=flwr_dot_proto_dot_driver__pb2.GetNodesRequest.FromString,
Expand All @@ -83,6 +100,23 @@ def add_DriverServicer_to_server(servicer, server):
class Driver(object):
"""Missing associated documentation comment in .proto file."""

@staticmethod
def CreateWorkload(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
# Issues a unary-unary call to '/flwr.proto.Driver/CreateWorkload' through
# grpc.experimental, serializing `request` as CreateWorkloadRequest and
# parsing the reply as CreateWorkloadResponse.
# The trailing arguments are passed positionally and `insecure` comes before
# `call_credentials`, unlike this function's own parameter order.
# NOTE(review): presumably this mirrors grpc.experimental.unary_unary's
# positional parameters — confirm before reordering.
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateWorkload',
flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString,
flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetNodes(request,
target,
Expand Down
13 changes: 13 additions & 0 deletions src/py/flwr/proto/driver_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import grpc

class DriverStub:
def __init__(self, channel: grpc.Channel) -> None: ...
CreateWorkload: grpc.UnaryUnaryMultiCallable[
flwr.proto.driver_pb2.CreateWorkloadRequest,
flwr.proto.driver_pb2.CreateWorkloadResponse]
"""Request workload_id"""

GetNodes: grpc.UnaryUnaryMultiCallable[
flwr.proto.driver_pb2.GetNodesRequest,
flwr.proto.driver_pb2.GetNodesResponse]
Expand All @@ -25,6 +30,14 @@ class DriverStub:


class DriverServicer(metaclass=abc.ABCMeta):
# Stub declaration only: it fixes the handler's signature (a
# CreateWorkloadRequest plus the servicer context in, a CreateWorkloadResponse
# out) and has no behavior.
@abc.abstractmethod
def CreateWorkload(self,
request: flwr.proto.driver_pb2.CreateWorkloadRequest,
context: grpc.ServicerContext,
) -> flwr.proto.driver_pb2.CreateWorkloadResponse:
"""Request workload_id"""
pass

@abc.abstractmethod
def GetNodes(self,
request: flwr.proto.driver_pb2.GetNodesRequest,
Expand Down
13 changes: 12 additions & 1 deletion src/py/flwr/server/driver/driver_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from flwr.common.logger import log
from flwr.proto import driver_pb2_grpc
from flwr.proto.driver_pb2 import (
CreateWorkloadRequest,
CreateWorkloadResponse,
GetNodesRequest,
GetNodesResponse,
PullTaskResRequest,
Expand All @@ -48,9 +50,18 @@ def GetNodes(
"""Get available nodes."""
log(INFO, "DriverServicer.GetNodes")
state: State = self.state_factory.state()
all_ids: Set[int] = state.get_nodes()
all_ids: Set[int] = state.get_nodes(request.workload_id)
return GetNodesResponse(node_ids=list(all_ids))

def CreateWorkload(
    self, request: CreateWorkloadRequest, context: grpc.ServicerContext
) -> CreateWorkloadResponse:
    """Return a response carrying the ID from the state's `create_workload()`.

    Neither `request` nor `context` is read by this handler.
    """
    log(INFO, "DriverServicer.CreateWorkload")
    new_workload_id = self.state_factory.state().create_workload()
    return CreateWorkloadResponse(workload_id=new_workload_id)

def PushTaskIns(
self, request: PushTaskInsRequest, context: grpc.ServicerContext
) -> PushTaskInsResponse:
Expand Down
Loading

0 comments on commit 9e645a1

Please sign in to comment.