Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from FastAPI to Starlette #2171

Merged
merged 4 commits into from
Jul 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/source/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@

The (experimental) `start_client` argument `rest` was deprecated in favor of a new argument `transport`. `start_client(transport="rest")` will yield the same behaviour as `start_client(rest=True)` did before. All code should migrate to the new argument `transport`. The deprecated argument `rest` will be removed in a future release.

- **Migrate experimental REST API to Starlette** ([2171](https://github.com/adap/flower/pull/2171))

The (experimental) REST API used to be implemented in [FastAPI](https://fastapi.tiangolo.com/), but it has now been migrated to use [Starlette](https://www.starlette.io/) directly.

- **General improvements** ([#1872](https://github.com/adap/flower/pull/1872), [#1866](https://github.com/adap/flower/pull/1866), [#1884](https://github.com/adap/flower/pull/1884))

### Incompatible changes
Expand Down
4 changes: 2 additions & 2 deletions examples/mt-pytorch/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Multi-Tenant Federated Learning with Flower and PyTorch

This example contains highly experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch.
This example contains experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch.

## Setup

Expand All @@ -16,7 +16,7 @@ Terminal 1: start Flower server
flower-server
```

Terminal 2+3: start two clients
Terminal 2+3: start two Flower client nodes

```bash
python client.py
Expand Down
9 changes: 4 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,13 @@ iterators = "^0.0.2"
ray = { version = "==2.5.1", extras = ["default"], optional = true }
pydantic = { version = "<2.0.0", optional = true }
# Optional dependencies (REST transport layer)
requests = { version = "^2.28.2", optional = true }
fastapi = { version = "^0.95.0", optional = true }
starlette = { version = "^0.27.0", optional = true }
uvicorn = { version = "^0.21.1", extras = ["standard"], optional = true }
requests = { version = "^2.31.0", optional = true }
starlette = { version = "^0.29.0", optional = true }
uvicorn = { version = "^0.22.0", extras = ["standard"], optional = true }

[tool.poetry.extras]
simulation = ["ray", "pydantic"]
rest = ["fastapi", "requests", "starlette", "uvicorn"]
rest = ["requests", "starlette", "uvicorn"]

[tool.poetry.group.dev.dependencies]
types-dataclasses = "==0.6.6"
Expand Down
6 changes: 2 additions & 4 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,7 @@ def run_fleet_api() -> None:
# Start Fleet API
if args.fleet_api_type == TRANSPORT_TYPE_REST:
if (
importlib.util.find_spec("fastapi")
and importlib.util.find_spec("requests")
importlib.util.find_spec("requests")
and importlib.util.find_spec("starlette")
and importlib.util.find_spec("uvicorn")
) is None:
Expand Down Expand Up @@ -376,8 +375,7 @@ def run_server() -> None:
# Start Fleet API
if args.fleet_api_type == TRANSPORT_TYPE_REST:
if (
importlib.util.find_spec("fastapi")
and importlib.util.find_spec("requests")
importlib.util.find_spec("requests")
and importlib.util.find_spec("starlette")
and importlib.util.find_spec("uvicorn")
) is None:
Expand Down
24 changes: 17 additions & 7 deletions src/py/flwr/server/fleet/rest_rere/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""REST API server."""
"""Experimental REST API server."""


import sys
Expand All @@ -23,16 +23,16 @@
from flwr.server.state import State

try:
from fastapi import FastAPI, HTTPException, Request, Response
from starlette.applications import Starlette
from starlette.datastructures import Headers
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
except ModuleNotFoundError:
sys.exit(MISSING_EXTRA_REST)


app: FastAPI = FastAPI()


@app.post("/api/v0/fleet/pull-task-ins", response_class=Response)
async def pull_task_ins(request: Request) -> Response:
"""Pull TaskIns."""
_check_headers(request.headers)
Expand Down Expand Up @@ -62,7 +62,6 @@ async def pull_task_ins(request: Request) -> Response:
)


@app.post("/api/v0/fleet/push-task-res", response_class=Response)
async def push_task_res(request: Request) -> Response: # Check if token is needed here
"""Push TaskRes."""
_check_headers(request.headers)
Expand Down Expand Up @@ -92,6 +91,17 @@ async def push_task_res(request: Request) -> Response: # Check if token is need
)


routes = [
Route("/api/v0/fleet/pull-task-ins", pull_task_ins, methods=["POST"]),
Route("/api/v0/fleet/push-task-res", push_task_res, methods=["POST"]),
]

app: Starlette = Starlette(
debug=False,
routes=routes,
)


def _check_headers(headers: Headers) -> None:
"""Check if expected headers are set."""
if "content-type" not in headers:
Expand Down