From 245d58eff788e8d44a59d37a2d9b26d0f08a62b4 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Fri, 15 Dec 2023 16:01:27 -0500 Subject: [PATCH] Improve how server/js client handle unexpected errors (#6798) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Client fixes * fix * add changeset * commented out code * add changeset * Log error give generic message * Add client side catch * remove exception 😂 * Add test * Fix info and warning * lint * Use BaseException * Use event_callbacks * Use event_id not present --------- Co-authored-by: gradio-pr-bot --- .changeset/dirty-experts-cry.md | 6 ++ client/js/src/client.ts | 181 +++++++++++++++++++++----------- gradio/routes.py | 21 +++- 3 files changed, 142 insertions(+), 66 deletions(-) create mode 100644 .changeset/dirty-experts-cry.md diff --git a/.changeset/dirty-experts-cry.md b/.changeset/dirty-experts-cry.md new file mode 100644 index 000000000000..9298234a9b9a --- /dev/null +++ b/.changeset/dirty-experts-cry.md @@ -0,0 +1,6 @@ +--- +"@gradio/client": patch +"gradio": patch +--- + +feat:Improve how server/js client handle unexpected errors diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 8f78a78282bf..6dbf4c82e59a 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -821,84 +821,121 @@ export function api_factory( }); } else { event_id = response.event_id as string; - if (!stream_open) { - open_stream(); - } let callback = async function (_data: object): void { - const { type, status, data } = handle_message( - _data, - last_status[fn_index] - ); - - if (type === "update" && status && !complete) { - // call 'status' listeners - fire_event({ - type: "status", - endpoint: _endpoint, - fn_index, - time: new Date(), - ...status - }); - } else if (type === "complete") { - complete = status; - } else if (type === "log") { - fire_event({ - type: "log", - log: data.log, - level: data.level, - endpoint: _endpoint, - fn_index - }); - } else if (type === "generating") { - fire_event({ - type: "status", - time: new Date(), - ...status, - stage: status?.stage!, - queue: true, - endpoint: _endpoint, - fn_index - }); - } - if (data) { - fire_event({ - type: "data", - time: new Date(), - data: transform_files - ? transform_output( - data.data, - api_info, - config.root, - config.root_url - ) - : data.data, - endpoint: _endpoint, - fn_index - }); + try { + const { type, status, data } = handle_message( + _data, + last_status[fn_index] + ); + + // TODO: Find out how to print this information + // only during testing + // console.info("data", type, status, data); + + if (type == "heartbeat") { + return; + } - if (complete) { + if (type === "update" && status && !complete) { + // call 'status' listeners + fire_event({ + type: "status", + endpoint: _endpoint, + fn_index, + time: new Date(), + ...status + }); + } else if (type === "complete") { + complete = status; + } else if (type == "unexpected_error") { + console.error("Unexpected error", status?.message); + fire_event({ + type: "status", + stage: "error", + message: "An Unexpected Error Occurred!", + queue: true, + endpoint: _endpoint, + fn_index, + time: new Date() + }); + } else if (type === "log") { + fire_event({ + type: "log", + log: data.log, + level: data.level, + endpoint: _endpoint, + fn_index + }); + return; + } else if (type === "generating") { fire_event({ type: "status", time: new Date(), - ...complete, + ...status, stage: status?.stage!, queue: true, endpoint: _endpoint, fn_index }); } - } + if (data) { + fire_event({ + type: "data", + time: new Date(), + data: transform_files + ? transform_output( + data.data, + api_info, + config.root, + config.root_url + ) + : data.data, + endpoint: _endpoint, + fn_index + }); + + if (complete) { + fire_event({ + type: "status", + time: new Date(), + ...complete, + stage: status?.stage!, + queue: true, + endpoint: _endpoint, + fn_index + }); + } + } - if (status.stage === "complete" || status.stage === "error") { - if (event_callbacks[event_id]) { - delete event_callbacks[event_id]; - if (Object.keys(event_callbacks).length === 0) { - close_stream(); + if ( + status.stage === "complete" || + status.stage === "error" + ) { + if (event_callbacks[event_id]) { + delete event_callbacks[event_id]; + if (Object.keys(event_callbacks).length === 0) { + close_stream(); + } } } + } catch (e) { + console.error("Unexpected client exception", e); + fire_event({ + type: "status", + stage: "error", + message: "An Unexpected Error Occurred!", + queue: true, + endpoint: _endpoint, + fn_index, + time: new Date() + }); + close_stream(); } }; event_callbacks[event_id] = callback; + if (!stream_open) { + open_stream(); + } } }); } @@ -1014,6 +1051,14 @@ export function api_factory( event_stream = new EventSource(url); event_stream.onmessage = async function (event) { let _data = JSON.parse(event.data); + if (!("event_id" in _data)) { + await Promise.all( + Object.keys(event_callbacks).map((event_id) => + event_callbacks[event_id](_data) + ) + ); + return; + } await event_callbacks[_data.event_id](_data); }; } @@ -1583,6 +1628,20 @@ function handle_message( success: data.success } }; + case "heartbeat": + return { + type: "heartbeat" + }; + case "unexpected_error": + return { + type: "unexpected_error", + status: { + queue, + message: data.message, + stage: "error", + success: false + } + }; case "estimation": return { type: "update", diff --git a/gradio/routes.py b/gradio/routes.py index d43616bcac15..3a59adb1b3b9 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -615,7 +615,10 @@ async def sse_stream(request: fastapi.Request): except EmptyQueue: await asyncio.sleep(check_rate) if time.perf_counter() - last_heartbeat > heartbeat_rate: - message = {"msg": ServerMessage.heartbeat} + # Fix this + message = { + "msg": ServerMessage.heartbeat, + } # Need to reset last_heartbeat with perf_counter # otherwise only a single hearbeat msg will be sent # and then the stream will retry leading to infinite queue 😬 @@ -623,7 +626,8 @@ async def sse_stream(request: fastapi.Request): if blocks._queue.stopped: message = { - "msg": ServerMessage.server_stopped, + "msg": "unexpected_error", + "message": "Server stopped unexpectedly.", "success": False, } if message: @@ -644,9 +648,16 @@ async def sse_stream(request: fastapi.Request): ) ): return - except asyncio.CancelledError as e: - del blocks._queue.pending_messages_per_session[session_hash] - await blocks._queue.clean_events(session_hash=session_hash) + except BaseException as e: + message = { + "msg": "unexpected_error", + "success": False, + "message": str(e), + } + yield f"data: {json.dumps(message)}\n\n" + if isinstance(e, asyncio.CancelledError): + del blocks._queue.pending_messages_per_session[session_hash] + await blocks._queue.clean_events(session_hash=session_hash) raise e return StreamingResponse(