Skip to content

Commit

Permalink
Align Driver API and Fleet API (#1546)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Jan 13, 2023
1 parent 685ce14 commit 9572253
Show file tree
Hide file tree
Showing 28 changed files with 851 additions and 841 deletions.
70 changes: 39 additions & 31 deletions examples/mt-pytorch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,9 @@
import random
import time

from flwr.driver import (
Driver,
GetNodesResponse,
GetNodesRequest,
Task,
Result,
CreateTasksRequest,
CreateTasksResponse,
GetResultsRequest,
GetResultsResponse,
TaskAssignment,
)
from flwr.common import ServerMessage, FitIns, ndarrays_to_parameters
from flwr.driver import Driver
from flwr.common import ServerMessage, FitIns, ndarrays_to_parameters, serde
from flwr.proto import driver_pb2, task_pb2, node_pb2, transport_pb2

from task import Net, get_parameters, set_parameters

Expand All @@ -34,58 +24,76 @@
print(f"Commencing server round {server_round + 1}")

# Get a list of node IDs from the server
get_nodes_req = GetNodesRequest()
get_nodes_req = driver_pb2.GetNodesRequest()

# ---------------------------------------------------------------------- Driver SDK
get_nodes_res: GetNodesResponse = driver.get_nodes(req=get_nodes_req)
get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes(req=get_nodes_req)
# ---------------------------------------------------------------------- Driver SDK

# Sample three nodes (or a single node if fewer than three are available)
all_node_ids: List[int] = get_nodes_res.node_ids
print(f"Got {len(all_node_ids)} node IDs")
num_node_ids_to_sample = 3 if len(all_node_ids) >= 3 else 1
sampled_node_ids: List[int] = random.sample(all_node_ids, num_node_ids_to_sample)
print(f"Sampled {len(sampled_node_ids)} node IDs")
print(f"Sampled {len(sampled_node_ids)} node IDs: {sampled_node_ids}")

time.sleep(sleep_time)

# Schedule a task for each sampled node
fit_ins: FitIns = FitIns(parameters=parameters, config={})
task = Task(task_id=123, legacy_server_message=ServerMessage(fit_ins=fit_ins))
task_assignment: TaskAssignment = TaskAssignment(
task=task, node_ids=sampled_node_ids
server_message = ServerMessage(fit_ins=fit_ins)
server_message_proto: transport_pb2.ServerMessage = serde.server_message_to_proto(
server_message=server_message
)
create_tasks_req = CreateTasksRequest(task_assignments=[task_assignment])
task_ins_set: List[task_pb2.TaskIns] = []
for sampled_node_id in sampled_node_ids:
new_task_ins = task_pb2.TaskIns(
task_id="", # Will be created and set by the DriverAPI
task=task_pb2.Task(
producer=node_pb2.Node(node_id=0, anonymous=True),
consumer=node_pb2.Node(node_id=sampled_node_id, anonymous=False),
legacy_server_message=server_message_proto,
),
)
task_ins_set.append(new_task_ins)

push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_set=task_ins_set)

# ---------------------------------------------------------------------- Driver SDK
create_tasks_res: CreateTasksResponse = driver.create_tasks(req=create_tasks_req)
push_task_ins_res: driver_pb2.PushTaskInsResponse = driver.push_task_ins(
req=push_task_ins_req
)
# ---------------------------------------------------------------------- Driver SDK

print(f"Scheduled {len(create_tasks_res.task_ids)} tasks")
print(
f"Scheduled {len(push_task_ins_res.task_ids)} tasks: {push_task_ins_res.task_ids}"
)

time.sleep(sleep_time)

# Wait for results
task_ids: List[int] = create_tasks_res.task_ids
all_results: List[Result] = []
task_ids: List[str] = push_task_ins_res.task_ids
all_task_res: List[task_pb2.TaskRes] = []
while True:
get_results_req = GetResultsRequest(task_ids=task_ids)
pull_task_res_req = driver_pb2.PullTaskResRequest(task_ids=task_ids)

# ------------------------------------------------------------------ Driver SDK
get_results_res: GetResultsResponse = driver.get_results(req=get_results_req)
pull_task_res_res: driver_pb2.PullTaskResResponse = driver.pull_task_res(
req=pull_task_res_req
)
# ------------------------------------------------------------------ Driver SDK

results: List[Result] = get_results_res.results
print(f"Got {len(get_results_res.results)} results")
task_res_set: List[task_pb2.TaskRes] = pull_task_res_res.task_res_set
print(f"Got {len(task_res_set)} results")

time.sleep(sleep_time)

all_results += results
if len(all_results) == len(task_ids):
all_task_res += task_res_set
if len(all_task_res) == len(task_ids):
break

# "Aggregate" results
node_messages = [result.legacy_client_message for result in all_results]
node_messages = [task_res.task.legacy_client_message for task_res in all_task_res]
print(f"Received {len(node_messages)} results")

time.sleep(sleep_time)
Expand Down
19 changes: 11 additions & 8 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,23 @@ service Driver {
rpc GetNodes(GetNodesRequest) returns (GetNodesResponse) {}

// Create one or more tasks
rpc CreateTasks(CreateTasksRequest) returns (CreateTasksResponse) {}
rpc PushTaskIns(PushTaskInsRequest) returns (PushTaskInsResponse) {}

// Get task results
rpc GetResults(GetResultsRequest) returns (GetResultsResponse) {}
rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {}
}

// GetNodes messages
message GetNodesRequest {}
message GetNodesResponse { repeated uint64 node_ids = 1; }

// CreateTasks messages
message CreateTasksRequest { repeated TaskAssignment task_assignments = 1; }
message CreateTasksResponse { repeated uint64 task_ids = 1; }
// PushTaskIns messages
message PushTaskInsRequest { repeated TaskIns task_ins_set = 1; }
message PushTaskInsResponse { repeated string task_ids = 2; }

// GetResults messages
message GetResultsRequest { repeated uint64 task_ids = 1; }
message GetResultsResponse { repeated Result results = 1; }
// PullTaskRes messages
// Request message of the Driver API `PullTaskRes` RPC ("Get task results").
message PullTaskResRequest {
// NOTE(review): the meaning of node_id on the driver side is not visible in
// this file; the mt-pytorch example driver leaves it unset — confirm intent.
uint64 node_id = 1;
// IDs of the tasks whose results are requested, e.g. the `task_ids` returned
// in `PushTaskInsResponse`.
repeated string task_ids = 2;
}
message PullTaskResResponse { repeated TaskRes task_res_set = 1; }
35 changes: 25 additions & 10 deletions src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,32 @@ package flwr.proto;
import "flwr/proto/task.proto";

service Fleet {
// Get tasks
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse) {}
// Retrieve one or more tasks, if possible
//
// HTTP API path: /api/v1/fleet/pull-task-ins
rpc PullTaskIns(PullTaskInsRequest) returns (PullTaskInsResponse) {}

// Get results
rpc CreateResults(CreateResultsRequest) returns (CreateResultsResponse) {}
// Complete one or more tasks, if possible
//
// HTTP API path: /api/v1/fleet/push-task-res
rpc PushTaskRes(PushTaskResRequest) returns (PushTaskResResponse) {}
}

// GetTasks messages
message GetTasksRequest {}
message GetTasksResponse { repeated Task task = 1; }
// PullTaskIns messages
// Request message of the Fleet API `PullTaskIns` RPC.
message PullTaskInsRequest {
// NOTE(review): presumably the ID of the node asking for task instructions —
// confirm against the Fleet API servicer.
uint64 node_id = 1;
// NOTE(review): presumably restricts the pull to specific task IDs; behaviour
// for an empty list is not defined here — confirm.
repeated string task_ids = 2;
}
// Response message of the Fleet API `PullTaskIns` RPC.
message PullTaskInsResponse {
// NOTE(review): presumably a hint telling the caller when to reconnect —
// semantics are not defined in this file; confirm.
Reconnect reconnect = 1;
// Task instructions returned to the caller; may be empty (the RPC retrieves
// tasks only "if possible").
repeated TaskIns task_ins_set = 2;
}

// PushTaskRes messages
message PushTaskResRequest { repeated TaskRes task_res_set = 1; }
// Response message of the Fleet API `PushTaskRes` RPC.
message PushTaskResResponse {
// NOTE(review): presumably a hint telling the caller when to reconnect —
// semantics are not defined in this file; confirm.
Reconnect reconnect = 1;
// NOTE(review): presumably maps a task ID to a per-task status code; neither
// the key nor the value semantics are defined here — confirm.
map<string, uint32> results = 2;
}

// CreateResults messages
message CreateResultsRequest { repeated Result results = 1; }
message CreateResultsResponse {}
// Wraps a single uint64 value; embedded in `PullTaskInsResponse` and
// `PushTaskResResponse`.
// NOTE(review): the unit of `reconnect` is not stated here — presumably a
// delay in seconds; confirm.
message Reconnect { uint64 reconnect = 1; }
23 changes: 23 additions & 0 deletions src/proto/flwr/proto/node.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2022 Adap GmbH. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ==============================================================================

syntax = "proto3";

package flwr.proto;

// Identifies a participant by numeric ID. Used in task.proto as the
// `producer` and `consumer` of a `Task`.
message Node {
// Numeric identifier of the node.
uint64 node_id = 1;
// NOTE(review): presumably marks a node without a stable identity, in which
// case `node_id` would carry no meaning (the example driver sends
// node_id=0, anonymous=True as producer) — confirm.
bool anonymous = 2;
}
22 changes: 14 additions & 8 deletions src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,27 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/node.proto";
import "flwr/proto/transport.proto";

message Task {
uint64 task_id = 1;
Node producer = 1;
Node consumer = 2;
string created_at = 3;
string delivered_at = 4;
string ttl = 5;
repeated string ancestry = 6;

ServerMessage legacy_server_message = 101 [ deprecated = true ];
ClientMessage legacy_client_message = 102 [ deprecated = true ];
}

message TaskAssignment {
Task task = 1;
repeated uint64 node_ids = 2;
message TaskIns {
string task_id = 1;
Task task = 2;
}

message Result {
uint64 task_id = 1;

ClientMessage legacy_client_message = 101 [ deprecated = true ];
message TaskRes {
string task_id = 1;
Task task = 2;
}
18 changes: 0 additions & 18 deletions src/py/flwr/driver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,7 @@


from .driver import Driver as Driver
from .messages import CreateTasksRequest as CreateTasksRequest
from .messages import CreateTasksResponse as CreateTasksResponse
from .messages import GetNodesRequest as GetNodesRequest
from .messages import GetNodesResponse as GetNodesResponse
from .messages import GetResultsRequest as GetResultsRequest
from .messages import GetResultsResponse as GetResultsResponse
from .messages import Result as Result
from .messages import Task as Task
from .messages import TaskAssignment as TaskAssignment

__all__ = [
"Driver",
"CreateTasksRequest",
"CreateTasksResponse",
"GetNodesRequest",
"GetNodesResponse",
"GetResultsRequest",
"GetResultsResponse",
"Task",
"TaskAssignment",
"Result",
]
51 changes: 22 additions & 29 deletions src/py/flwr/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,8 @@

from flwr.common.grpc import create_channel
from flwr.common.logger import log
from flwr.driver import serde
from flwr.proto import driver_pb2, driver_pb2_grpc

from .messages import (
CreateTasksRequest,
CreateTasksResponse,
GetNodesRequest,
GetNodesResponse,
GetResultsRequest,
GetResultsResponse,
)

DEFAULT_SERVER_ADDRESS_DRIVER = "[::]:9091"

ERROR_MESSAGE_DRIVER_NOT_CONNECTED = """
Expand Down Expand Up @@ -80,39 +70,42 @@ def disconnect(self) -> None:
channel.close()
log(INFO, "[Driver] Disconnected")

def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse:
def get_nodes(self, req: driver_pb2.GetNodesRequest) -> driver_pb2.GetNodesResponse:
"""Get client IDs."""

# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")

# Serialize, call Driver API, deserialize
req_proto = serde.get_nodes_request_to_proto(req)
res_proto: driver_pb2.GetNodesResponse = self.stub.GetNodes(request=req_proto)
return serde.get_nodes_response_from_proto(res_proto)
# Call Driver API
res: driver_pb2.GetNodesResponse = self.stub.GetNodes(request=req)
return res

def create_tasks(self, req: CreateTasksRequest) -> CreateTasksResponse:
def push_task_ins(
self, req: driver_pb2.PushTaskInsRequest
) -> driver_pb2.PushTaskInsResponse:
"""Schedule tasks."""

# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")

# Serialize, call Driver API, deserialize
req_proto = serde.create_tasks_request_to_proto(req)
res_proto: driver_pb2.CreateTasksResponse = self.stub.CreateTasks(
request=req_proto
)
return serde.create_tasks_response_from_proto(res_proto)
# Call Driver API
res: driver_pb2.PushTaskInsResponse = self.stub.PushTaskIns(request=req)
return res

def get_results(self, req: GetResultsRequest) -> GetResultsResponse:
def pull_task_res(
self, req: driver_pb2.PullTaskResRequest
) -> driver_pb2.PullTaskResResponse:
"""Get task results."""

# Check if channel is open
if self.stub is None:
log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED)
raise Exception("`Driver` instance not connected")

# Serialize, call Driver API, deserialize
req_proto = serde.get_results_request_to_proto(req)
res_proto: driver_pb2.GetResultsResponse = self.stub.GetResults(
request=req_proto
)
return serde.get_results_response_from_proto(res_proto)
# Call Driver API
res: driver_pb2.PullTaskResResponse = self.stub.PullTaskRes(request=req)
return res
Loading

0 comments on commit 9572253

Please sign in to comment.