diff --git a/.changeset/hot-taxis-jump.md b/.changeset/hot-taxis-jump.md new file mode 100644 index 000000000000..efd4398e411c --- /dev/null +++ b/.changeset/hot-taxis-jump.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Run pre/post processing in threadpool diff --git a/gradio/blocks.py b/gradio/blocks.py index 489a2ec0d5ea..42ca4cba3b87 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1521,6 +1521,9 @@ def handle_streaming_diffs( return data + def run_fn_batch(self, fn, batch, fn_index, state): + return [fn(fn_index, list(i), state) for i in zip(*batch)] + async def process_api( self, fn_index: int, @@ -1565,10 +1568,14 @@ async def process_api( raise ValueError( f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})" ) - - inputs = [ - self.preprocess_data(fn_index, list(i), state) for i in zip(*inputs) - ] + inputs = await anyio.to_thread.run_sync( + self.run_fn_batch, + self.preprocess_data, + inputs, + fn_index, + state, + limiter=self.limiter, + ) result = await self.call_function( fn_index, list(zip(*inputs)), @@ -1579,9 +1586,14 @@ async def process_api( in_event_listener, ) preds = result["prediction"] - data = [ - self.postprocess_data(fn_index, list(o), state) for o in zip(*preds) - ] + data = await anyio.to_thread.run_sync( + self.run_fn_batch, + self.postprocess_data, + preds, + fn_index, + state, + limiter=self.limiter, + ) data = list(zip(*data)) is_generating, iterator = None, None else: @@ -1589,7 +1601,9 @@ async def process_api( if old_iterator: inputs = [] else: - inputs = self.preprocess_data(fn_index, inputs, state) + inputs = await anyio.to_thread.run_sync( + self.preprocess_data, fn_index, inputs, state, limiter=self.limiter + ) was_generating = old_iterator is not None result = await self.call_function( fn_index, @@ -1600,7 +1614,13 @@ async def process_api( event_data, in_event_listener, ) - data = self.postprocess_data(fn_index, result["prediction"], state) + data = await anyio.to_thread.run_sync( + self.postprocess_data, + fn_index, # type: ignore + result["prediction"], + state, + limiter=self.limiter, + ) is_generating, iterator = result["is_generating"], result["iterator"] if is_generating or was_generating: run = id(old_iterator) if was_generating else id(iterator) diff --git a/gradio/components/gallery.py b/gradio/components/gallery.py index 414d7719d98e..7a5b8952505e 100644 --- a/gradio/components/gallery.py +++ b/gradio/components/gallery.py @@ -2,6 +2,7 @@ from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Callable, List, Literal, Optional, Tuple, Union from urllib.parse import urlparse @@ -165,7 +166,8 @@ def postprocess( if value is None: return GalleryData(root=[]) output = [] - for img in value: + + def _save(img): url = None caption = None orig_name = None @@ -194,11 +196,14 @@ def postprocess( orig_name = img.name else: raise ValueError(f"Cannot process type as image: {type(img)}") - entry = GalleryImage( + return GalleryImage( image=FileData(path=file_path, url=url, orig_name=orig_name), caption=caption, ) - output.append(entry) + + with ThreadPoolExecutor() as executor: + for o in executor.map(_save, value): + output.append(o) return GalleryData(root=output) @staticmethod diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 2d86d1d8216f..94ad52c9ff28 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -135,7 +135,7 @@ def save_pil_to_cache( temp_dir = Path(cache_dir) / hash_bytes(bytes_data) temp_dir.mkdir(exist_ok=True, parents=True) filename = str((temp_dir / f"{name}.{format}").resolve()) - img.save(filename, pnginfo=get_pil_metadata(img)) + (temp_dir / f"{name}.{format}").resolve().write_bytes(bytes_data) return filename