Commit 34f9431

Python client properly handles heartbeat and log messages. Also handles responses longer than 65k (#6693)

* first commit

* newlines

* test

* Fix depends

* revert

* add changeset

* add changeset

* Lint

* queue full test

* Add code

* Update + fix

* add changeset

* Revert demo

* Typo in success

* Fix

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
3 people authored Dec 13, 2023
1 parent a3cf90e commit 34f9431
Showing 9 changed files with 197 additions and 70 deletions.
7 changes: 7 additions & 0 deletions .changeset/yummy-roses-decide.md
@@ -0,0 +1,7 @@
---
"@gradio/client": patch
"gradio": patch
"gradio_client": patch
---

fix:Python client properly handles heartbeat and log messages. Also handles responses longer than 65k
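The core of the fix, repeated across the client diffs below, is consuming the SSE stream line by line instead of accumulating raw text chunks: `aiter_lines()` reassembles each `data:` payload even when the transport splits it (as happens with responses longer than ~65k characters), and heartbeat messages are skipped rather than tripping the parser. A minimal standalone sketch of that parsing loop — the helper name and sample payloads are illustrative, not part of this commit:

```python
import json

def parse_sse_lines(lines):
    """Yield decoded SSE payloads, skipping heartbeats (illustrative helper)."""
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue  # blank separators between SSE events
        if not line.startswith("data:"):
            raise ValueError(f"Unexpected SSE line: '{line}'")
        resp = json.loads(line[5:])
        if resp["msg"] == "heartbeat":
            continue  # heartbeats keep the connection alive but carry no event
        yield resp

# A long response arrives as a single "data:" line regardless of size,
# because line-based iteration reassembles chunks split by the transport.
sample = [
    'data: {"msg": "heartbeat"}',
    'data: {"msg": "process_completed", "event_id": "abc123", "output": {}}',
]
print(list(parse_sse_lines(sample)))  # only the process_completed payload
```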
25 changes: 21 additions & 4 deletions client/js/src/client.ts
@@ -203,8 +203,16 @@ export function api_factory(
} catch (e) {
return [{ error: BROKEN_CONNECTION_MSG }, 500];
}
const output: PostResponse = await response.json();
return [output, response.status];
let output: PostResponse;
let status: number;
try {
output = await response.json();
status = response.status;
} catch (e) {
output = { error: `Could not parse server response: ${e}` };
status = 500;
}
return [output, status];
}

async function upload_files(
@@ -791,7 +799,17 @@ export function api_factory(
},
hf_token
).then(([response, status]) => {
if (status !== 200) {
if (status === 503) {
fire_event({
type: "status",
stage: "error",
message: QUEUE_FULL_MSG,
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});
} else if (status !== 200) {
fire_event({
type: "status",
stage: "error",
@@ -806,7 +824,6 @@
if (!stream_open) {
open_stream();
}

let callback = async function (_data: object): Promise<void> {
const { type, status, data } = handle_message(
_data,
55 changes: 28 additions & 27 deletions client/python/gradio_client/client.py
@@ -37,6 +37,8 @@
Communicator,
JobStatus,
Message,
QueueError,
ServerMessage,
Status,
StatusUpdate,
)
@@ -169,41 +171,38 @@ def __init__(
async def stream_messages(self) -> None:
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
buffer = ""
async with client.stream(
"GET",
self.sse_url,
params={"session_hash": self.session_hash},
headers=self.headers,
cookies=self.cookies,
) as response:
async for line in response.aiter_text():
buffer += line
while "\n\n" in buffer:
message, buffer = buffer.split("\n\n", 1)
if message.startswith("data:"):
resp = json.loads(message[5:])
if resp["msg"] == "heartbeat":
continue
elif resp["msg"] == "server_stopped":
for (
pending_messages
) in self.pending_messages_per_event.values():
pending_messages.append(resp)
return
event_id = resp["event_id"]
if event_id not in self.pending_messages_per_event:
self.pending_messages_per_event[event_id] = []
self.pending_messages_per_event[event_id].append(resp)
if resp["msg"] == "process_completed":
self.pending_event_ids.remove(event_id)
if len(self.pending_event_ids) == 0:
self.stream_open = False
return
elif message == "":
async for line in response.aiter_lines():
line = line.rstrip("\n")
if not len(line):
continue
if line.startswith("data:"):
resp = json.loads(line[5:])
if resp["msg"] == ServerMessage.heartbeat:
continue
else:
raise ValueError(f"Unexpected SSE line: '{message}'")
elif resp["msg"] == ServerMessage.server_stopped:
for (
pending_messages
) in self.pending_messages_per_event.values():
pending_messages.append(resp)
return
event_id = resp["event_id"]
if event_id not in self.pending_messages_per_event:
self.pending_messages_per_event[event_id] = []
self.pending_messages_per_event[event_id].append(resp)
if resp["msg"] == ServerMessage.process_completed:
self.pending_event_ids.remove(event_id)
if len(self.pending_event_ids) == 0:
self.stream_open = False
return
else:
raise ValueError(f"Unexpected SSE line: '{line}'")
except BaseException as e:
import traceback

@@ -218,6 +217,8 @@ async def send_data(self, data, hash_data):
headers=self.headers,
cookies=self.cookies,
)
if req.status_code == 503:
raise QueueError("Queue is full! Please try again.")
req.raise_for_status()
resp = req.json()
event_id = resp["event_id"]
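Because `send_data` now maps a 503 response to `QueueError`, callers can catch a full queue explicitly instead of receiving an opaque HTTP error. A hedged usage sketch — the Space name is a placeholder and the surrounding flow is illustrative; the test further down exercises the same path:

```python
from gradio_client import Client
from gradio_client.utils import QueueError

client = Client("user/space-name")  # placeholder Space, not from this commit
try:
    job = client.submit("Freddy", api_name="/predict")
    print(job.result())
except QueueError:
    # Raised when the server returns 503 because the queue's max_size is reached
    print("Queue is full! Please try again.")
```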
51 changes: 34 additions & 17 deletions client/python/gradio_client/utils.py
@@ -102,6 +102,20 @@ class SpaceDuplicationError(Exception):
pass


class ServerMessage(str, Enum):
send_hash = "send_hash"
queue_full = "queue_full"
estimation = "estimation"
send_data = "send_data"
process_starts = "process_starts"
process_generating = "process_generating"
process_completed = "process_completed"
log = "log"
progress = "progress"
heartbeat = "heartbeat"
server_stopped = "server_stopped"


class Status(Enum):
"""Status codes presented to client users."""

@@ -141,16 +155,17 @@ def __lt__(self, other: Status):
def msg_to_status(msg: str) -> Status:
"""Map the raw message from the backend to the status code presented to users."""
return {
"send_hash": Status.JOINING_QUEUE,
"queue_full": Status.QUEUE_FULL,
"estimation": Status.IN_QUEUE,
"send_data": Status.SENDING_DATA,
"process_starts": Status.PROCESSING,
"process_generating": Status.ITERATING,
"process_completed": Status.FINISHED,
"progress": Status.PROGRESS,
"log": Status.LOG,
}[msg]
ServerMessage.send_hash: Status.JOINING_QUEUE,
ServerMessage.queue_full: Status.QUEUE_FULL,
ServerMessage.estimation: Status.IN_QUEUE,
ServerMessage.send_data: Status.SENDING_DATA,
ServerMessage.process_starts: Status.PROCESSING,
ServerMessage.process_generating: Status.ITERATING,
ServerMessage.process_completed: Status.FINISHED,
ServerMessage.progress: Status.PROGRESS,
ServerMessage.log: Status.LOG,
ServerMessage.server_stopped: Status.FINISHED,
}[msg] # type: ignore


@dataclass
@@ -436,9 +451,14 @@ async def stream_sse_v0(
headers=headers,
cookies=cookies,
) as response:
async for line in response.aiter_text():
async for line in response.aiter_lines():
line = line.rstrip("\n")
if len(line) == 0:
continue
if line.startswith("data:"):
resp = json.loads(line[5:])
if resp["msg"] in [ServerMessage.log, ServerMessage.heartbeat]:
continue
with helper.lock:
has_progress = "progress_data" in resp
status_update = StatusUpdate(
@@ -502,7 +522,7 @@ async def stream_sse_v1(

with helper.lock:
log_message = None
if msg["msg"] == "log":
if msg["msg"] == ServerMessage.log:
log = msg.get("log")
level = msg.get("level")
if log and level:
@@ -527,13 +547,10 @@
result = [e]
helper.job.outputs.append(result)
helper.job.latest_status = status_update

if msg["msg"] == "queue_full":
raise QueueError("Queue is full! Please try again.")
elif msg["msg"] == "process_completed":
if msg["msg"] == ServerMessage.process_completed:
del pending_messages_per_event[event_id]
return msg["output"]
elif msg["msg"] == "server_stopped":
elif msg["msg"] == ServerMessage.server_stopped:
raise ValueError("Server stopped.")

except asyncio.CancelledError:
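Because `ServerMessage` mixes in `str`, raw backend strings compare (and hash) equal to the enum members — that is what lets the diff write `resp["msg"] == ServerMessage.heartbeat` against a plain JSON string, and index the `msg_to_status` dict with either form. A small sketch under that assumption:

```python
from gradio_client.utils import ServerMessage, Status, msg_to_status

# str-mixin enum members compare equal to their raw string values...
assert ServerMessage.heartbeat == "heartbeat"

# ...so msg_to_status works with a raw backend string or an enum member.
assert msg_to_status("process_completed") is Status.FINISHED
assert msg_to_status(ServerMessage.server_stopped) is Status.FINISHED
```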
15 changes: 15 additions & 0 deletions client/python/test/conftest.py
@@ -381,3 +381,18 @@ def gradio_temp_dir(monkeypatch, tmp_path):
"""
monkeypatch.setenv("GRADIO_TEMP_DIR", str(tmp_path))
return tmp_path


@pytest.fixture
def long_response_with_info():
def long_response(x):
gr.Info("Beginning long response")
time.sleep(17)
gr.Info("Done!")
return "\ta\nb" * 90000

return gr.Interface(
long_response,
None,
gr.Textbox(label="Output"),
)
44 changes: 42 additions & 2 deletions client/python/test/test_client.py
@@ -4,7 +4,7 @@
import tempfile
import time
import uuid
from concurrent.futures import CancelledError, TimeoutError
from concurrent.futures import CancelledError, TimeoutError, wait
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
@@ -21,7 +21,13 @@

from gradio_client import Client
from gradio_client.client import DEFAULT_TEMP_DIR
from gradio_client.utils import Communicator, ProgressUnit, Status, StatusUpdate
from gradio_client.utils import (
Communicator,
ProgressUnit,
QueueError,
Status,
StatusUpdate,
)

HF_TOKEN = os.getenv("HF_TOKEN") or HfFolder.get_token()

@@ -488,6 +494,40 @@ def test_return_layout_and_state_components(
assert demo.predict(api_name="/close") == 4
assert demo.predict("Ali", api_name="/greeting") == ("Hello Ali", 5)

def test_long_response_time_with_gr_info_and_big_payload(
self, long_response_with_info
):
with connect(long_response_with_info) as demo:
assert demo.predict(api_name="/predict") == "\ta\nb" * 90000

def test_queue_full_raises_error(self):
demo = gr.Interface(lambda s: f"Hello {s}", "textbox", "textbox").queue(
max_size=1
)
with connect(demo) as client:
with pytest.raises(QueueError):
job1 = client.submit("Freddy", api_name="/predict")
job2 = client.submit("Abubakar", api_name="/predict")
job3 = client.submit("Pete", api_name="/predict")
wait([job1, job2, job3])
job1.result()
job2.result()
job3.result()

def test_json_parse_error(self):
data = (
"Bonjour Olivier, tu as l'air bien r\u00e9veill\u00e9 ce matin. Tu veux que je te pr\u00e9pare tes petits-d\u00e9j.\n",
None,
)

def return_bad():
return data

demo = gr.Interface(return_bad, None, ["text", "text"])
with connect(demo) as client:
pred = client.predict(api_name="/predict")
assert pred[0] == data[0]


class TestStatusUpdates:
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
(Diffs for the remaining 3 changed files not shown.)
