From 62ec2075ccad8025a7721a08d0f29eb5a4f87fad Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 31 Oct 2023 06:48:10 -0700 Subject: [PATCH] modify preprocess to use pydantic models (#6181) * modify preprocess to use pydantic models * changes * add changeset * fix * fix * fix typing * save * revert queuing changes * fix * fix * notebook * fix * changes * add changeset * fix functional tests --------- Co-authored-by: gradio-pr-bot --- .changeset/short-doodles-lose.md | 6 ++ demo/chatbot_multimodal/run.ipynb | 2 +- demo/chatbot_multimodal/run.py | 2 +- demo/clear_components/run.py | 1 + gradio/blocks.py | 10 ++- gradio/components/annotated_image.py | 19 +++-- gradio/components/audio.py | 54 ++++++------- gradio/components/bar_plot.py | 16 ++-- gradio/components/base.py | 42 +++++----- gradio/components/button.py | 8 +- gradio/components/chatbot.py | 38 ++++----- gradio/components/checkbox.py | 6 ++ gradio/components/checkboxgroup.py | 28 ++----- gradio/components/clear_button.py | 9 ++- gradio/components/code.py | 18 ++--- gradio/components/color_picker.py | 33 ++------ gradio/components/dataframe.py | 85 ++++++++++----------- gradio/components/dataset.py | 11 +-- gradio/components/dropdown.py | 40 +++++----- gradio/components/fallback.py | 8 +- gradio/components/file.py | 56 ++++++-------- gradio/components/file_explorer.py | 36 ++++----- gradio/components/gallery.py | 14 ++-- gradio/components/highlighted_text.py | 29 ++++--- gradio/components/html.py | 8 +- gradio/components/image.py | 32 +++----- gradio/components/json_component.py | 28 +++---- gradio/components/label.py | 58 +++++++------- gradio/components/line_plot.py | 16 ++-- gradio/components/markdown.py | 16 ++-- gradio/components/model3d.py | 28 ++----- gradio/components/number.py | 38 ++++----- gradio/components/plot.py | 28 +++---- gradio/components/radio.py | 16 ++-- gradio/components/scatter_plot.py | 16 ++-- gradio/components/slider.py | 15 +--- gradio/components/state.py | 8 +- gradio/components/textbox.py | 22 +----- gradio/components/upload_button.py | 61 ++++++++++----- gradio/components/video.py | 50 ++---------- gradio/queueing.py | 7 +- js/uploadbutton/shared/UploadButton.svelte | 6 +- test/test_components.py | 65 ++++++++-------- test/test_files/audio_sample.wav | Bin 16362 -> 16136 bytes test/test_theme_sharing.py | 25 +++--- 45 files changed, 491 insertions(+), 623 deletions(-) create mode 100644 .changeset/short-doodles-lose.md diff --git a/.changeset/short-doodles-lose.md b/.changeset/short-doodles-lose.md new file mode 100644 index 000000000000..4071d75ee8f3 --- /dev/null +++ b/.changeset/short-doodles-lose.md @@ -0,0 +1,6 @@ +--- +"@gradio/uploadbutton": minor +"gradio": minor +--- + +feat:modify preprocess to use pydantic models diff --git a/demo/chatbot_multimodal/run.ipynb b/demo/chatbot_multimodal/run.ipynb index 2f88a5d6cf8e..5581c0466efd 100644 --- a/demo/chatbot_multimodal/run.ipynb +++ b/demo/chatbot_multimodal/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch(allowed_paths=[\"avatar.png\"])\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/chatbot_multimodal/run.py b/demo/chatbot_multimodal/run.py index 650905aa2bcf..f9d3b5dbeba2 100644 --- a/demo/chatbot_multimodal/run.py +++ b/demo/chatbot_multimodal/run.py @@ -51,4 +51,4 @@ def bot(history): demo.queue() if __name__ == "__main__": - demo.launch() + demo.launch(allowed_paths=["avatar.png"]) diff --git a/demo/clear_components/run.py b/demo/clear_components/run.py index d60e73621d6f..0a29ba0aee8b 100644 --- a/demo/clear_components/run.py +++ b/demo/clear_components/run.py @@ -153,6 +153,7 @@ def evaluate_values(*args): are_false.append(a == "#000000") else: are_false.append(not a) + print(args) return all(are_false) diff --git a/gradio/blocks.py b/gradio/blocks.py index 474ac520baeb..0dfe437dd634 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -341,7 +341,7 @@ def __init__( self.postprocess = postprocess self.tracks_progress = tracks_progress self.concurrency_limit = concurrency_limit - self.concurrency_id = concurrency_id or id(fn) + self.concurrency_id = concurrency_id or str(id(fn)) self.batch = batch self.max_batch_size = max_batch_size self.total_runtime = 0 @@ -1260,6 +1260,14 @@ def preprocess_data( inputs_cached = processing_utils.move_files_to_cache( inputs[i], block ) + if getattr(block, "data_model", None) and inputs_cached is not None: + if issubclass(block.data_model, GradioModel): # type: ignore + print("block.data_model", block.data_model, block) + print("1inputs_cached", inputs_cached) + inputs_cached = block.data_model(**inputs_cached) # type: ignore + elif issubclass(block.data_model, GradioRootModel): # type: ignore + print("2inputs_cached", inputs_cached) + inputs_cached = block.data_model(root=inputs_cached) # type: ignore processed_input.append(block.preprocess(inputs_cached)) else: processed_input = inputs diff --git a/gradio/components/annotated_image.py b/gradio/components/annotated_image.py index db1552832cd1..b1a050c4bb76 100644 --- a/gradio/components/annotated_image.py +++ b/gradio/components/annotated_image.py @@ -103,20 +103,21 @@ def __init__( def postprocess( self, - y: tuple[ + value: tuple[ np.ndarray | _Image.Image | str, list[tuple[np.ndarray | tuple[int, int, int, int], str]], - ], + ] + | None, ) -> AnnotatedImageData | None: """ Parameters: - y: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label. + value: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label. Returns: Tuple of base image file and list of subsections, with each subsection a two-part tuple where the first element image path of the mask, and the second element is the label. """ - if y is None: + if value is None: return None - base_img = y[0] + base_img = value[0] if isinstance(base_img, str): base_img_path = base_img base_img = np.array(_Image.open(base_img)) @@ -144,7 +145,7 @@ def hex_to_rgb(value): lv = len(value) return [int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)] - for mask, label in y[1]: + for mask, label in value[1]: mask_array = np.zeros((base_img.shape[0], base_img.shape[1])) if isinstance(mask, np.ndarray): mask_array = mask @@ -188,5 +189,7 @@ def hex_to_rgb(value): def example_inputs(self) -> Any: return {} - def preprocess(self, x: Any) -> Any: - return x + def preprocess( + self, payload: AnnotatedImageData | None + ) -> AnnotatedImageData | None: + return payload diff --git a/gradio/components/audio.py b/gradio/components/audio.py index 7b8e494c4c88..4e8478cab5df 100644 --- a/gradio/components/audio.py +++ b/gradio/components/audio.py @@ -162,20 +162,12 @@ def example_inputs(self) -> Any: return "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav" def preprocess( - self, x: dict[str, Any] | None + self, payload: FileData | None ) -> tuple[int, np.ndarray] | str | None: - """ - Parameters: - x: dictionary with keys "path", "crop_min", "crop_max". - Returns: - audio in requested format - """ - if x is None: - return x + if payload is None: + return payload - payload: FileData = FileData(**x) assert payload.path - # Need a unique name for the file to avoid re-using the same audio file if # a user submits the same audio file twice temp_file_path = Path(payload.path) @@ -211,50 +203,50 @@ def preprocess( ) def postprocess( - self, y: tuple[int, np.ndarray] | str | Path | bytes | None - ) -> FileData | None | bytes: + self, value: tuple[int, np.ndarray] | str | Path | bytes | None + ) -> FileData | bytes | None: """ Parameters: - y: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None. + value: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None. Returns: base64 url data """ - if y is None: + if value is None: return None - if isinstance(y, bytes): + if isinstance(value, bytes): if self.streaming: - return y + return value file_path = processing_utils.save_bytes_to_cache( - y, "audio", cache_dir=self.GRADIO_CACHE + value, "audio", cache_dir=self.GRADIO_CACHE ) - elif isinstance(y, tuple): - sample_rate, data = y + elif isinstance(value, tuple): + sample_rate, data = value file_path = processing_utils.save_audio_to_cache( data, sample_rate, format=self.format, cache_dir=self.GRADIO_CACHE ) else: - if not isinstance(y, (str, Path)): - raise ValueError(f"Cannot process {y} as Audio") - file_path = str(y) + if not isinstance(value, (str, Path)): + raise ValueError(f"Cannot process {value} as Audio") + file_path = str(value) return FileData(path=file_path) def stream_output( - self, y, output_id: str, first_chunk: bool + self, value, output_id: str, first_chunk: bool ) -> tuple[bytes | None, Any]: output_file = { "path": output_id, "is_stream": True, } - if y is None: + if value is None: return None, output_file - if isinstance(y, bytes): - return y, output_file - if client_utils.is_http_url_like(y["path"]): - response = requests.get(y["path"]) + if isinstance(value, bytes): + return value, output_file + if client_utils.is_http_url_like(value["path"]): + response = requests.get(value["path"]) binary_data = response.content else: - output_file["orig_name"] = y["orig_name"] - file_path = y["path"] + output_file["orig_name"] = value["orig_name"] + file_path = value["path"] is_wav = file_path.endswith(".wav") with open(file_path, "rb") as f: binary_data = f.read() diff --git a/gradio/components/bar_plot.py b/gradio/components/bar_plot.py index d92c3b4f607e..6d2de2fd569c 100644 --- a/gradio/components/bar_plot.py +++ b/gradio/components/bar_plot.py @@ -258,15 +258,15 @@ def create_plot( return chart def postprocess( - self, y: pd.DataFrame | dict | None + self, value: pd.DataFrame | dict | None ) -> AltairPlotData | dict | None: # if None or update - if y is None or isinstance(y, dict): - return y + if value is None or isinstance(value, dict): + return value if self.x is None or self.y is None: raise ValueError("No value provided for required parameters `x` and `y`.") chart = self.create_plot( - value=y, + value=value, x=self.x, y=self.y, color=self.color, @@ -288,12 +288,10 @@ def postprocess( sort=self.sort, # type: ignore ) - return AltairPlotData( - **{"type": "altair", "plot": chart.to_json(), "chart": "bar"} - ) + return AltairPlotData(type="altair", plot=chart.to_json(), chart="bar") def example_inputs(self) -> dict[str, Any]: return {} - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: AltairPlotData) -> AltairPlotData: + return payload diff --git a/gradio/components/base.py b/gradio/components/base.py index 596da58ca0ed..568c7180ddd6 100644 --- a/gradio/components/base.py +++ b/gradio/components/base.py @@ -48,21 +48,21 @@ class ComponentBase(ABC, metaclass=ComponentMeta): EVENTS: list[EventListener | str] = [] @abstractmethod - def preprocess(self, x: Any) -> Any: + def preprocess(self, payload: Any) -> Any: """ Any preprocessing needed to be performed on function input. """ - return x + return payload @abstractmethod - def postprocess(self, y): + def postprocess(self, value): """ Any postprocessing needed to be performed on function output. """ - return y + return value @abstractmethod - def as_example(self, y): + def as_example(self, value): """ Return the input data in a way that can be displayed by the examples dataset component in the front-end. @@ -88,7 +88,7 @@ def example_inputs(self) -> Any: pass @abstractmethod - def flag(self, x: Any | GradioDataModel, flag_dir: str | Path = "") -> str: + def flag(self, payload: Any | GradioDataModel, flag_dir: str | Path = "") -> str: """ Write the component's value to a format that can be stored in a csv or jsonl format for flagging. """ @@ -97,13 +97,13 @@ def flag(self, x: Any | GradioDataModel, flag_dir: str | Path = "") -> str: @abstractmethod def read_from_flag( self, - x: Any, + payload: Any, flag_dir: str | Path | None = None, ) -> GradioDataModel | Any: """ Convert the data from the csv or jsonl file into the component state. """ - return x + return payload @property @abstractmethod @@ -267,26 +267,26 @@ def api_info(self) -> dict[str, Any]: f"The api_info method has not been implemented for {self.get_block_name()}" ) - def flag(self, x: Any, flag_dir: str | Path = "") -> str: + def flag(self, payload: Any, flag_dir: str | Path = "") -> str: """ Write the component's value to a format that can be stored in a csv or jsonl format for flagging. """ if self.data_model: - x = self.data_model.from_json(x) - return x.copy_to_dir(flag_dir).model_dump_json() - return x + payload = self.data_model.from_json(payload) + return payload.copy_to_dir(flag_dir).model_dump_json() + return payload def read_from_flag( self, - x: Any, + payload: Any, flag_dir: str | Path | None = None, ): """ Convert the data from the csv or jsonl file into the component state. """ if self.data_model: - return self.data_model.from_json(json.loads(x)) - return x + return self.data_model.from_json(json.loads(payload)) + return payload class FormComponent(Component): @@ -295,11 +295,11 @@ def get_expected_parent(self) -> type[Form] | None: return None return Form - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: Any) -> Any: + return payload - def postprocess(self, y): - return y + def postprocess(self, value): + return value class StreamingOutput(metaclass=abc.ABCMeta): @@ -308,7 +308,9 @@ def __init__(self, *args, **kwargs) -> None: self.streaming: bool @abc.abstractmethod - def stream_output(self, y, output_id: str, first_chunk: bool) -> tuple[bytes, Any]: + def stream_output( + self, value, output_id: str, first_chunk: bool + ) -> tuple[bytes, Any]: pass diff --git a/gradio/components/button.py b/gradio/components/button.py index 1a4f6f71a24b..3bb17269505b 100644 --- a/gradio/components/button.py +++ b/gradio/components/button.py @@ -77,11 +77,11 @@ def __init__( def skip_api(self): return True - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: str) -> str: + return payload - def postprocess(self, y): - return y + def postprocess(self, value: str) -> str: + return value def example_inputs(self) -> Any: return None diff --git a/gradio/components/chatbot.py b/gradio/components/chatbot.py index e5efa61cf51f..88e4d1245ca5 100644 --- a/gradio/components/chatbot.py +++ b/gradio/components/chatbot.py @@ -14,8 +14,6 @@ from gradio.data_classes import FileData, GradioModel, GradioRootModel from gradio.events import Events -# from pydantic import Field, TypeAdapter - set_documentation_group("component") @@ -129,26 +127,28 @@ def __init__( ) def _preprocess_chat_messages( - self, chat_message: str | dict | None - ) -> str | tuple[str] | tuple[str, str] | None: + self, chat_message: str | FileMessage | None + ) -> str | tuple[str | None] | tuple[str | None, str] | None: if chat_message is None: return None - elif isinstance(chat_message, dict): - if chat_message.get("alt_text"): - return (chat_message["file"]["path"], chat_message["alt_text"]) + elif isinstance(chat_message, FileMessage): + if chat_message.alt_text is not None: + return (chat_message.file.path, chat_message.alt_text) else: - return (chat_message["file"]["path"],) - else: # string + return (chat_message.file.path,) + elif isinstance(chat_message, str): return chat_message + else: + raise ValueError(f"Invalid message for Chatbot component: {chat_message}") def preprocess( self, - y: list[list[str | dict | None] | tuple[str | dict | None, str | dict | None]], + payload: ChatbotData, ) -> list[list[str | tuple[str] | tuple[str, str] | None]]: - if y is None: - return y + if payload is None: + return payload processed_messages = [] - for message_pair in y: + for message_pair in payload.root: if not isinstance(message_pair, (tuple, list)): raise TypeError( f"Expected a list of lists or list of tuples. Received: {message_pair}" @@ -186,18 +186,12 @@ def _postprocess_chat_messages( def postprocess( self, - y: list[list[str | tuple[str] | tuple[str, str] | None] | tuple], + value: list[list[str | tuple[str] | tuple[str, str] | None] | tuple], ) -> ChatbotData: - """ - Parameters: - y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string or pathlib.Path filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed. - Returns: - List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed. - """ - if y is None: + if value is None: return ChatbotData(root=[]) processed_messages = [] - for message_pair in y: + for message_pair in value: if not isinstance(message_pair, (tuple, list)): raise TypeError( f"Expected a list of lists or list of tuples. Received: {message_pair}" diff --git a/gradio/components/checkbox.py b/gradio/components/checkbox.py index 2edf957f6b64..44ef1f960402 100644 --- a/gradio/components/checkbox.py +++ b/gradio/components/checkbox.py @@ -79,3 +79,9 @@ def api_info(self) -> dict[str, Any]: def example_inputs(self) -> bool: return True + + def preprocess(self, payload: bool | None) -> bool | None: + return payload + + def postprocess(self, value: bool | None) -> bool | None: + return value diff --git a/gradio/components/checkboxgroup.py b/gradio/components/checkboxgroup.py index 766c4fa48298..6762f20a8d85 100644 --- a/gradio/components/checkboxgroup.py +++ b/gradio/components/checkboxgroup.py @@ -101,21 +101,15 @@ def api_info(self) -> dict[str, Any]: } def preprocess( - self, x: list[str | int | float] + self, payload: list[str | int | float] ) -> list[str | int | float] | list[int | None]: - """ - Parameters: - x: list of selected choices - Returns: - list of selected choice values as strings or indices within choice list - """ if self.type == "value": - return x + return payload elif self.type == "index": choice_values = [value for _, value in self.choices] return [ choice_values.index(choice) if choice in choice_values else None - for choice in x + for choice in payload ] else: raise ValueError( @@ -123,19 +117,13 @@ def preprocess( ) def postprocess( - self, y: list[str | int | float] | str | int | float | None + self, value: list[str | int | float] | str | int | float | None ) -> list[str | int | float]: - """ - Parameters: - y: List of selected choice values. If a single choice is selected, it can be passed in as a string - Returns: - List of selected choices - """ - if y is None: + if value is None: return [] - if not isinstance(y, list): - y = [y] - return y + if not isinstance(value, list): + value = [value] + return value def as_example(self, input_data): if input_data is None: diff --git a/gradio/components/clear_button.py b/gradio/components/clear_button.py index fab34d488276..c482a6a05ab0 100644 --- a/gradio/components/clear_button.py +++ b/gradio/components/clear_button.py @@ -75,16 +75,17 @@ def add(self, components: None | Component | list[Component]) -> ClearButton: none = component.postprocess(None) if isinstance(none, (GradioModel, GradioRootModel)): none = none.model_dump() + print(none) none_values.append(none) clear_values = json.dumps(none_values) self.click(None, [], components, _js=f"() => {clear_values}") return self - def postprocess(self, y): - return y + def postprocess(self, value: str | None) -> str | None: + return value - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: str | None) -> str | None: + return payload def example_inputs(self) -> Any: return None diff --git a/gradio/components/code.py b/gradio/components/code.py index e8a7d20b8e5b..a35cc08225b0 100644 --- a/gradio/components/code.py +++ b/gradio/components/code.py @@ -105,20 +105,20 @@ def __init__( value=value, ) - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: Any) -> Any: + return payload - def postprocess(self, y: tuple | str | None) -> None | str: - if y is None: + def postprocess(self, value: tuple | str | None) -> None | str: + if value is None: return None - elif isinstance(y, tuple): - with open(y[0]) as file_data: + elif isinstance(value, tuple): + with open(value[0]) as file_data: return file_data.read() else: - return y.strip() + return value.strip() - def flag(self, x: Any, flag_dir: str | Path = "") -> str: - return super().flag(x, flag_dir) + def flag(self, payload: Any, flag_dir: str | Path = "") -> str: + return super().flag(payload, flag_dir) def api_info(self) -> dict[str, Any]: return {"type": "string"} diff --git a/gradio/components/color_picker.py b/gradio/components/color_picker.py index 44ebd5fc91a2..6575c9d42c0b 100644 --- a/gradio/components/color_picker.py +++ b/gradio/components/color_picker.py @@ -2,7 +2,6 @@ from __future__ import annotations -from pathlib import Path from typing import Any, Callable from gradio_client.documentation import document, set_documentation_group @@ -77,37 +76,17 @@ def __init__( def example_inputs(self) -> str: return "#000000" - def flag(self, x: Any, flag_dir: str | Path = "") -> str: - return x - - def read_from_flag(self, x: Any, flag_dir: str | Path | None = None): - return x - def api_info(self) -> dict[str, Any]: return {"type": "string"} - def preprocess(self, x: str | None) -> str | None: - """ - Any preprocessing needed to be performed on function input. - Parameters: - x: text - Returns: - text - """ - if x is None: + def preprocess(self, payload: str | None) -> str | None: + if payload is None: return None else: - return str(x) + return str(payload) - def postprocess(self, y: str | None) -> str | None: - """ - Any postprocessing needed to be performed on function output. - Parameters: - y: text - Returns: - text - """ - if y is None: + def postprocess(self, value: str | None) -> str | None: + if value is None: return None else: - return str(y) + return str(value) diff --git a/gradio/components/dataframe.py b/gradio/components/dataframe.py index b4e8afc6a770..5f7f6505a5d0 100644 --- a/gradio/components/dataframe.py +++ b/gradio/components/dataframe.py @@ -162,23 +162,16 @@ def __init__( value=value, ) - def preprocess(self, x: dict) -> pd.DataFrame | np.ndarray | list: - """ - Parameters: - x: Dictionary equivalent of DataframeData containing `headers`, `data`, and optionally `metadata` keys - Returns: - The Dataframe data in requested format - """ - value = DataframeData(**x) + def preprocess(self, payload: DataframeData) -> pd.DataFrame | np.ndarray | list: if self.type == "pandas": - if value.headers is not None: - return pd.DataFrame(value.data, columns=value.headers) + if payload.headers is not None: + return pd.DataFrame(payload.data, columns=payload.headers) else: - return pd.DataFrame(value.data) + return pd.DataFrame(payload.data) if self.type == "numpy": - return np.array(value.data) + return np.array(payload.data) elif self.type == "array": - return value.data + return payload.data else: raise ValueError( "Unknown type: " @@ -188,26 +181,27 @@ def preprocess(self, x: dict) -> pd.DataFrame | np.ndarray | list: def postprocess( self, - y: pd.DataFrame | Styler | np.ndarray | list | list[list] | dict | str | None, + value: pd.DataFrame + | Styler + | np.ndarray + | list + | list[list] + | dict + | str + | None, ) -> DataframeData | dict: - """ - Parameters: - y: dataframe in given format - Returns: - JSON object with key 'headers' for list of header names, 'data' for 2D array of string or numeric data - """ - if y is None: + if value is None: return self.postprocess(self.empty_input) - if isinstance(y, dict): - return y - if isinstance(y, (str, pd.DataFrame)): - if isinstance(y, str): - y = pd.read_csv(y) # type: ignore + if isinstance(value, dict): + return value + if isinstance(value, (str, pd.DataFrame)): + if isinstance(value, str): + value = pd.read_csv(value) # type: ignore return DataframeData( - headers=list(y.columns), # type: ignore - data=y.to_dict(orient="split")["data"], # type: ignore + headers=list(value.columns), # type: ignore + data=value.to_dict(orient="split")["data"], # type: ignore ) - elif isinstance(y, Styler): + elif isinstance(value, Styler): if semantic_version.Version(pd.__version__) < semantic_version.Version( "1.5.0" ): @@ -218,39 +212,38 @@ def postprocess( warnings.warn( "Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead." ) - df: pd.DataFrame = y.data # type: ignore - value = DataframeData( + df: pd.DataFrame = value.data # type: ignore + return DataframeData( headers=list(df.columns), data=df.to_dict(orient="split")["data"], # type: ignore - metadata=self.__extract_metadata(y), + metadata=self.__extract_metadata(value), ) - elif isinstance(y, (str, pd.DataFrame)): - df = pd.read_csv(y) if isinstance(y, str) else y # type: ignore - value = DataframeData( + elif isinstance(value, (str, pd.DataFrame)): + df = pd.read_csv(value) if isinstance(value, str) else value # type: ignore + return DataframeData( headers=list(df.columns), data=df.to_dict(orient="split")["data"], # type: ignore ) - elif isinstance(y, (np.ndarray, list)): - if len(y) == 0: + elif isinstance(value, (np.ndarray, list)): + if len(value) == 0: return self.postprocess([[]]) - if isinstance(y, np.ndarray): - y = y.tolist() - if not isinstance(y, list): + if isinstance(value, np.ndarray): + value = value.tolist() + if not isinstance(value, list): raise ValueError("output cannot be converted to list") _headers = self.headers - if len(self.headers) < len(y[0]): + if len(self.headers) < len(value[0]): _headers: list[str] = [ *self.headers, - *[str(i) for i in range(len(self.headers) + 1, len(y[0]) + 1)], + *[str(i) for i in range(len(self.headers) + 1, len(value[0]) + 1)], ] - elif len(self.headers) > len(y[0]): - _headers = self.headers[: len(y[0])] + elif len(self.headers) > len(value[0]): + _headers = self.headers[: len(value[0])] - value = DataframeData(headers=_headers, data=y) + return DataframeData(headers=_headers, data=value) else: raise ValueError("Cannot process value as a Dataframe") - return value @staticmethod def __get_cell_style(cell_id: str, cell_styles: list[dict]) -> str: diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index e96f98ce5b36..53a6c4341332 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -121,16 +121,13 @@ def get_config(self): return config - def preprocess(self, x: Any) -> Any: - """ - Any preprocessing needed to be performed on function input. - """ + def preprocess(self, payload: int) -> int | list[list] | None: if self.type == "index": - return x + return payload elif self.type == "values": - return self.samples[x] + return self.samples[payload] - def postprocess(self, samples: list[list[Any]]) -> dict: + def postprocess(self, samples: list[list]) -> dict: return { "samples": samples, "__type__": "update", diff --git a/gradio/components/dropdown.py b/gradio/components/dropdown.py index 8cc11f954db4..dcf48c2712a1 100644 --- a/gradio/components/dropdown.py +++ b/gradio/components/dropdown.py @@ -134,48 +134,48 @@ def example_inputs(self) -> Any: return self.choices[0][1] if self.choices else None def preprocess( - self, x: str | int | float | list[str | int | float] | None + self, payload: str | int | float | list[str | int | float] | None ) -> str | int | float | list[str | int | float] | list[int | None] | None: - """ - Parameters: - x: selected choice(s) - Returns: - selected choice(s) as string or index within choice list or list of string or indices - """ if self.type == "value": - return x + return payload elif self.type == "index": choice_values = [value for _, value in self.choices] - if x is None: + if payload is None: return None elif self.multiselect: - assert isinstance(x, list) + assert isinstance(payload, list) return [ choice_values.index(choice) if choice in choice_values else None - for choice in x + for choice in payload ] else: - return choice_values.index(x) if x in choice_values else None + return ( + choice_values.index(payload) if payload in choice_values else None + ) else: raise ValueError( f"Unknown type: {self.type}. Please choose from: 'value', 'index'." ) - def _warn_if_invalid_choice(self, y): - if self.allow_custom_value or y in [value for _, value in self.choices]: + def _warn_if_invalid_choice(self, value): + if self.allow_custom_value or value in [value for _, value in self.choices]: return warnings.warn( - f"The value passed into gr.Dropdown() is not in the list of choices. Please update the list of choices to include: {y} or set allow_custom_value=True." + f"The value passed into gr.Dropdown() is not in the list of choices. Please update the list of choices to include: {value} or set allow_custom_value=True." ) - def postprocess(self, y): - if y is None: + def postprocess( + self, value: str | int | float | list[str | int | float] | None + ) -> str | int | float | list[str | int | float] | None: + if value is None: return None if self.multiselect: - [self._warn_if_invalid_choice(_y) for _y in y] + if not isinstance(value, list): + value = [value] + [self._warn_if_invalid_choice(_y) for _y in value] else: - self._warn_if_invalid_choice(y) - return y + self._warn_if_invalid_choice(value) + return value def as_example(self, input_data): if self.multiselect: diff --git a/gradio/components/fallback.py b/gradio/components/fallback.py index 2a229e13c400..3c9b98a6315f 100644 --- a/gradio/components/fallback.py +++ b/gradio/components/fallback.py @@ -2,11 +2,11 @@ class Fallback(Component): - def preprocess(self, x): - return x + def preprocess(self, payload): + return payload - def postprocess(self, x): - return x + def postprocess(self, value): + return value def example_inputs(self): return {"foo": "bar"} diff --git a/gradio/components/file.py b/gradio/components/file.py index 1f1a70583b52..a68a3927a1df 100644 --- a/gradio/components/file.py +++ b/gradio/components/file.py @@ -20,6 +20,12 @@ class ListFiles(GradioRootModel): root: List[FileData] + def __getitem__(self, index): + return self.root[index] + + def __iter__(self): + return iter(self.root) + @document() class File(Component): @@ -111,13 +117,12 @@ def __init__( self.type = type self.height = height - def _process_single_file(self, f: dict[str, Any]) -> bytes | NamedString: - file_name = f["path"] - + def _process_single_file(self, f: FileData) -> NamedString | bytes: + file_name = f.path if self.type == "filepath": file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE) file.name = file_name - return NamedString(file.name) + return NamedString(file_name) elif self.type == "binary": with open(file_name, "rb") as file_data: return file_data.read() @@ -129,38 +134,25 @@ def _process_single_file(self, f: dict[str, Any]) -> bytes | NamedString: ) def preprocess( - self, x: list[dict[str, Any]] | dict[str, Any] | None + self, payload: ListFiles | FileData | None ) -> bytes | NamedString | list[bytes | NamedString] | None: - """ - Parameters: - x: List of JSON objects with filename as 'name' property and base64 data as 'data' property - Returns: - File objects in requested format - """ - if x is None: + if payload is None: return None - if self.file_count == "single": - if isinstance(x, list): - return self._process_single_file(x[0]) + if isinstance(payload, ListFiles): + return self._process_single_file(payload[0]) else: - return self._process_single_file(x) + return self._process_single_file(payload) else: - if isinstance(x, list): - return [self._process_single_file(f) for f in x] + if isinstance(payload, ListFiles): + return [self._process_single_file(f) for f in payload] else: - return [self._process_single_file(x)] + return [self._process_single_file(payload)] - def postprocess(self, y: str | list[str] | None) -> ListFiles | FileData | None: - """ - Parameters: - y: file path - Returns: - JSON object with key 'name' for filename, 'data' for base64 url, and 'size' for filesize in bytes - """ - if y is None: + def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None: + if value is None: return None - if isinstance(y, list): + if isinstance(value, list): return ListFiles( root=[ FileData( @@ -168,14 +160,14 @@ def postprocess(self, y: str | list[str] | None) -> ListFiles | FileData | None: orig_name=Path(file).name, size=Path(file).stat().st_size, ) - for file in y + for file in value ] ) else: return FileData( - path=y, - orig_name=Path(y).name, - size=Path(y).stat().st_size, + path=value, + orig_name=Path(value).name, + size=Path(value).stat().st_size, ) def as_example(self, input_data: str | list | None) -> str: diff --git a/gradio/components/file_explorer.py b/gradio/components/file_explorer.py index 1d1a092c1682..850e94766e2b 100644 --- a/gradio/components/file_explorer.py +++ b/gradio/components/file_explorer.py @@ -104,49 +104,39 @@ def __init__( def example_inputs(self) -> Any: return ["Users", "gradio", "app.py"] - def preprocess(self, x: list[list[str]] | None) -> list[str] | str | None: - """ - Parameters: - x: File path segments as a list of list of strings for each file relative to the root. - Returns: - File path selected, as an absolute path. - """ - if x is None: + def preprocess(self, payload: list[list[str]] | None) -> list[str] | str | None: + if payload is None: return None if self.file_count == "single": - if len(x) > 1: - raise ValueError(f"Expected only one file, but {len(x)} were selected.") - return self._safe_join(x[0]) + if len(payload) > 1: + raise ValueError( + f"Expected only one file, but {len(payload)} were selected." + ) + return self._safe_join(payload[0]) - return [self._safe_join(file) for file in (x)] + return [self._safe_join(file) for file in (payload)] def _strip_root(self, path): if path.startswith(self.root): return path[len(self.root) + 1 :] return path - def postprocess(self, y: str | list[str] | None) -> FileExplorerData | None: - """ - Parameters: - y: file path - Returns: - list representing filepath, where each string is a directory level relative to the root. - """ - if y is None: + def postprocess(self, value: str | list[str] | None) -> FileExplorerData | None: + if value is None: return None - files = [y] if isinstance(y, str) else y + files = [value] if isinstance(value, str) else value return FileExplorerData( root=[self._strip_root(file).split(os.path.sep) for file in files] ) @server - def ls(self, y=None) -> list[dict[str, str]] | None: + def ls(self, value=None) -> list[dict[str, str]] | None: """ Parameters: - y: file path as a list of strings for each directory level relative to the root. + value: file path as a list of strings for each directory level relative to the root. Returns: tuple of list of files in directory, then list of folders in directory """ diff --git a/gradio/components/gallery.py b/gradio/components/gallery.py index 6ab609f6509b..cdf4ad181994 100644 --- a/gradio/components/gallery.py +++ b/gradio/components/gallery.py @@ -124,20 +124,20 @@ def __init__( def postprocess( self, - y: list[np.ndarray | _Image.Image | str] + value: list[np.ndarray | _Image.Image | str] | list[tuple[np.ndarray | _Image.Image | str, str]] | None, ) -> GalleryData: """ Parameters: - y: list of images, or list of (image, caption) tuples + value: list of images, or list of (image, caption) tuples Returns: list of string file paths to images in temp directory """ - if y is None: + if value is None: return GalleryData(root=[]) output = [] - for img in y: + for img in value: caption = None if isinstance(img, (tuple, list)): img, caption = img @@ -160,8 +160,10 @@ def postprocess( output.append(entry) return GalleryData(root=output) - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: GalleryData | None) -> GalleryData | None: + if payload is None or not payload.root: + return None + return payload def example_inputs(self) -> Any: return [ diff --git a/gradio/components/highlighted_text.py b/gradio/components/highlighted_text.py index dc3115c89490..ac0c7d96ac58 100644 --- a/gradio/components/highlighted_text.py +++ b/gradio/components/highlighted_text.py @@ -99,27 +99,27 @@ def example_inputs(self) -> Any: return {"value": [{"token": "Hello", "class_or_confidence": "1"}]} def postprocess( - self, y: list[tuple[str, str | float | None]] | dict | None + self, value: list[tuple[str, str | float | None]] | dict | None ) -> HighlightedTextData | None: """ Parameters: - y: List of (word, category) tuples, or a dictionary of two keys: "text", and "entities", which itself is a list of dictionaries, each of which have the keys: "entity" (or "entity_group"), "start", and "end" + value: List of (word, category) tuples, or a dictionary of two keys: "text", and "entities", which itself is a list of dictionaries, each of which have the keys: "entity" (or "entity_group"), "start", and "end" Returns: List of (word, category) tuples """ - if y is None: + if value is None: return None - if isinstance(y, dict): + if isinstance(value, dict): try: - text = y["text"] - entities = y["entities"] + text = value["text"] + entities = value["entities"] except KeyError as ke: raise ValueError( "Expected a dictionary with keys 'text' and 'entities' " "for the value of the HighlightedText component." ) from ke if len(entities) == 0: - y = [(text, None)] + value = [(text, None)] else: list_format = [] index = 0 @@ -132,11 +132,11 @@ def postprocess( ) index = entity["end"] list_format.append((text[index:], None)) - y = list_format + value = list_format if self.combine_adjacent: output = [] running_text, running_category = None, None - for text, category in y: + for text, category in value: if running_text is None: running_text = text running_category = category @@ -160,8 +160,13 @@ def postprocess( ) else: return HighlightedTextData( - root=[HighlightedToken(token=o[0], class_or_confidence=o[1]) for o in y] + root=[ + HighlightedToken(token=o[0], class_or_confidence=o[1]) + for o in value + ] ) - def preprocess(self, x: Any) -> Any: - return super().preprocess(x) + def preprocess(self, payload: HighlightedTextData | None) -> dict | None: + if payload is None: + return None + return payload.model_dump() diff --git a/gradio/components/html.py b/gradio/components/html.py index b439d063bbb3..8534446a2d12 100644 --- a/gradio/components/html.py +++ b/gradio/components/html.py @@ -62,11 +62,11 @@ def __init__( def example_inputs(self) -> Any: return "

Hello

" - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: str | None) -> str | None: + return payload - def postprocess(self, y): - return y + def postprocess(self, value: str | None) -> str | None: + return value def api_info(self) -> dict[str, Any]: return {"type": "string"} diff --git a/gradio/components/image.py b/gradio/components/image.py index d2e7691646d8..dc525775a3b7 100644 --- a/gradio/components/image.py +++ b/gradio/components/image.py @@ -38,6 +38,7 @@ class Image(StreamingInput, Component): Events.select, Events.upload, ] + data_model = FileData def __init__( @@ -141,38 +142,25 @@ def __init__( value=value, ) - def preprocess(self, x: dict | None) -> np.ndarray | _Image.Image | str | None: - """ - Parameters: - x: FileData containing an image path pointing to the user's image - Returns: - image in requested format, or (if tool == "sketch") a dict of image and mask in requested format - """ - if x is None: - return x - - im = _Image.open(x["path"]) + def preprocess( + self, payload: FileData | None + ) -> np.ndarray | _Image.Image | str | None: + if payload is None: + return payload + im = _Image.open(payload.path) with warnings.catch_warnings(): warnings.simplefilter("ignore") im = im.convert(self.image_mode) - return image_utils.format_image( im, cast(Literal["numpy", "pil", "filepath"], self.type), self.GRADIO_CACHE ) def postprocess( - self, y: np.ndarray | _Image.Image | str | Path | None + self, value: np.ndarray | _Image.Image | str | Path | None ) -> FileData | None: - """ - Parameters: - y: image as a numpy array, PIL Image, string/Path filepath, or string URL - Returns: - base64 url data - """ - if y is None: + if value is None: return None - - return FileData(path=image_utils.save_image(y, self.GRADIO_CACHE)) + return FileData(path=image_utils.save_image(value, self.GRADIO_CACHE)) def check_streamable(self): if self.streaming and self.sources != ("webcam"): diff --git a/gradio/components/json_component.py b/gradio/components/json_component.py index 8b5fda5b7ea2..c444e828ebfa 100644 --- a/gradio/components/json_component.py +++ b/gradio/components/json_component.py @@ -69,31 +69,25 @@ def __init__( value=value, ) - def postprocess(self, y: dict | list | str | None) -> dict | list | None: - """ - Parameters: - y: either a string filepath to a JSON file, or a Python list or dict that can be converted to JSON - Returns: - JSON output in Python list or dict format - """ - if y is None: + def postprocess(self, value: dict | list | str | None) -> dict | list | None: + if value is None: return None - if isinstance(y, str): - return json.loads(y) + if isinstance(value, str): + return json.loads(value) else: - return y + return value - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: dict | list | str | None) -> dict | list | str | None: + return payload def example_inputs(self) -> Any: return {"foo": "bar"} - def flag(self, x: Any, flag_dir: str | Path = "") -> str: - return json.dumps(x) + def flag(self, payload: Any, flag_dir: str | Path = "") -> str: + return json.dumps(payload) - def read_from_flag(self, x: Any, flag_dir: str | Path | None = None): - return json.loads(x) + def read_from_flag(self, payload: Any, flag_dir: str | Path | None = None): + return json.loads(payload) def api_info(self) -> dict[str, Any]: return {"type": {}, "description": "any valid json"} diff --git a/gradio/components/label.py b/gradio/components/label.py index 0ab154c8b410..3d437a504e8b 100644 --- a/gradio/components/label.py +++ b/gradio/components/label.py @@ -22,7 +22,7 @@ class LabelConfidence(GradioModel): class LabelData(GradioModel): - label: Union[str, int, float] + label: Optional[Union[str, int, float]] = None confidences: Optional[List[LabelConfidence]] = None @@ -91,44 +91,46 @@ def __init__( ) def postprocess( - self, y: dict[str, float] | str | float | None + self, value: dict[str, float] | str | float | None ) -> LabelData | dict | None: - """ - Parameters: - y: a dictionary mapping labels to confidence value, or just a string/numerical label by itself - Returns: - Object with key 'label' representing primary label, and key 'confidences' representing a list of label-confidence pairs - """ - if y is None or y == {}: + if value is None or value == {}: return {} - if isinstance(y, str) and y.endswith(".json") and Path(y).exists(): - return LabelData(**json.loads(Path(y).read_text())) - if isinstance(y, (str, float, int)): - return LabelData(label=str(y)) - if isinstance(y, dict): - if "confidences" in y and isinstance(y["confidences"], dict): - y = y["confidences"] - y = {c["label"]: c["confidence"] for c in y} - sorted_pred = sorted(y.items(), key=operator.itemgetter(1), reverse=True) + if isinstance(value, str) and value.endswith(".json") and Path(value).exists(): + return LabelData(**json.loads(Path(value).read_text())) + if isinstance(value, (str, float, int)): + return LabelData(label=str(value)) + if isinstance(value, dict): + if "confidences" in value and isinstance(value["confidences"], dict): + value = value["confidences"] + value = {c["label"]: c["confidence"] for c in value} + sorted_pred = sorted( + value.items(), key=operator.itemgetter(1), reverse=True + ) if self.num_top_classes is not None: sorted_pred = sorted_pred[: self.num_top_classes] return LabelData( - **{ - "label": sorted_pred[0][0], - "confidences": [ - {"label": pred[0], "confidence": pred[1]} - for pred in sorted_pred - ], - } + label=sorted_pred[0][0], + confidences=[ + LabelConfidence(label=pred[0], confidence=pred[1]) + for pred in sorted_pred + ], ) raise ValueError( "The `Label` output interface expects one of: a string label, or an int label, a " "float label, or a dictionary whose keys are labels and values are confidences. " - f"Instead, got a {type(y)}" + f"Instead, got a {type(value)}" ) - def preprocess(self, x: Any) -> Any: - return x + def preprocess( + self, payload: LabelData | None + ) -> dict[str, float] | str | float | None: + if payload is None: + return None + if payload.confidences is None: + return payload.label + return { + d["label"]: d["confidence"] for d in payload.model_dump()["confidences"] + } def example_inputs(self) -> Any: return { diff --git a/gradio/components/line_plot.py b/gradio/components/line_plot.py index 4d7692701a64..ab6b6172c8cf 100644 --- a/gradio/components/line_plot.py +++ b/gradio/components/line_plot.py @@ -287,15 +287,15 @@ def create_plot( return chart def postprocess( - self, y: pd.DataFrame | dict | None + self, value: pd.DataFrame | dict | None ) -> AltairPlotData | dict | None: # if None or update - if y is None or isinstance(y, dict): - return y + if value is None or isinstance(value, dict): + return value if self.x is None or self.y is None: raise ValueError("No value provided for required parameters `x` and `y`.") chart = self.create_plot( - value=y, + value=value, x=self.x, y=self.y, color=self.color, @@ -318,12 +318,10 @@ def postprocess( width=self.width, ) - return AltairPlotData( - **{"type": "altair", "plot": chart.to_json(), "chart": "line"} - ) + return AltairPlotData(type="altair", plot=chart.to_json(), chart="line") def example_inputs(self) -> Any: return None - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, value: AltairPlotData | None) -> AltairPlotData | None: + return value diff --git a/gradio/components/markdown.py b/gradio/components/markdown.py index f9cd602ea712..2e13aa1a26f0 100644 --- a/gradio/components/markdown.py +++ b/gradio/components/markdown.py @@ -75,24 +75,18 @@ def __init__( value=value, ) - def postprocess(self, y: str | None) -> str | None: - """ - Parameters: - y: markdown representation - Returns: - HTML rendering of markdown - """ - if y is None: + def postprocess(self, value: str | None) -> str | None: + if value is None: return None - unindented_y = inspect.cleandoc(y) + unindented_y = inspect.cleandoc(value) return unindented_y def as_example(self, input_data: str | None) -> str: postprocessed = self.postprocess(input_data) return postprocessed if postprocessed else "" - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: str | None) -> str | None: + return payload def example_inputs(self) -> Any: return "# Hello!" diff --git a/gradio/components/model3d.py b/gradio/components/model3d.py index c62be35de73e..f3977b9ecf73 100644 --- a/gradio/components/model3d.py +++ b/gradio/components/model3d.py @@ -93,27 +93,15 @@ def __init__( value=value, ) - def preprocess(self, x: dict[str, str] | None) -> str | None: - """ - Parameters: - x: JSON object with filename as 'name' property and base64 data as 'data' property - Returns: - string file path to temporary file with the 3D image model - """ - if x is None: - return x - return x["path"] + def preprocess(self, payload: FileData | None) -> str | None: + if payload is None: + return payload + return payload.path - def postprocess(self, y: str | Path | None) -> FileData | None: - """ - Parameters: - y: path to the model - Returns: - file name mapped to base64 url data - """ - if y is None: - return y - return FileData(path=str(y)) + def postprocess(self, value: str | Path | None) -> FileData | None: + if value is None: + return value + return FileData(path=str(value)) def as_example(self, input_data: str | None) -> str: return Path(input_data).name if input_data else "" diff --git a/gradio/components/number.py b/gradio/components/number.py index 72143b515a7e..aa8088653e10 100644 --- a/gradio/components/number.py +++ b/gradio/components/number.py @@ -108,33 +108,21 @@ def _round_to_precision(num: float | int, precision: int | None) -> float | int: else: return round(num, precision) - def preprocess(self, x: float | None) -> float | None: - """ - Parameters: - x: numeric input - Returns: - number representing function input - """ - if x is None: + def preprocess(self, payload: float | None) -> float | None: + if payload is None: return None - elif self.minimum is not None and x < self.minimum: - raise Error(f"Value {x} is less than minimum value {self.minimum}.") - elif self.maximum is not None and x > self.maximum: - raise Error(f"Value {x} is greater than maximum value {self.maximum}.") - return self._round_to_precision(x, self.precision) - - def postprocess(self, y: float | None) -> float | None: - """ - Any postprocessing needed to be performed on function output. - - Parameters: - y: numeric output - Returns: - number representing function output - """ - if y is None: + elif self.minimum is not None and payload < self.minimum: + raise Error(f"Value {payload} is less than minimum value {self.minimum}.") + elif self.maximum is not None and payload > self.maximum: + raise Error( + f"Value {payload} is greater than maximum value {self.maximum}." + ) + return self._round_to_precision(payload, self.precision) + + def postprocess(self, value: float | None) -> float | None: + if value is None: return None - return self._round_to_precision(y, self.precision) + return self._round_to_precision(value, self.precision) def api_info(self) -> dict[str, str]: return {"type": "number"} diff --git a/gradio/components/plot.py b/gradio/components/plot.py index 19b50ecd67c6..673b51dd59de 100644 --- a/gradio/components/plot.py +++ b/gradio/components/plot.py @@ -97,36 +97,30 @@ def get_config(self): config["bokeh_version"] = bokeh_version return config - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: PlotData | None) -> PlotData | None: + return payload def example_inputs(self) -> Any: return None - def postprocess(self, y) -> PlotData | None: - """ - Parameters: - y: plot data - Returns: - plot type mapped to plot base64 data - """ + def postprocess(self, value) -> PlotData | None: import matplotlib.figure - if y is None: + if value is None: return None - if isinstance(y, (ModuleType, matplotlib.figure.Figure)): # type: ignore + if isinstance(value, (ModuleType, matplotlib.figure.Figure)): # type: ignore dtype = "matplotlib" - out_y = processing_utils.encode_plot_to_base64(y) - elif "bokeh" in y.__module__: + out_y = processing_utils.encode_plot_to_base64(value) + elif "bokeh" in value.__module__: dtype = "bokeh" from bokeh.embed import json_item # type: ignore - out_y = json.dumps(json_item(y)) + out_y = json.dumps(json_item(value)) else: - is_altair = "altair" in y.__module__ + is_altair = "altair" in value.__module__ dtype = "altair" if is_altair else "plotly" - out_y = y.to_json() - return PlotData(**{"type": dtype, "plot": out_y}) + out_y = value.to_json() + return PlotData(type=dtype, plot=out_y) class AltairPlot: diff --git a/gradio/components/radio.py b/gradio/components/radio.py index 37c7ecec7dfa..3b57b64969fb 100644 --- a/gradio/components/radio.py +++ b/gradio/components/radio.py @@ -94,28 +94,30 @@ def __init__( def example_inputs(self) -> Any: return self.choices[0][1] if self.choices else None - def preprocess(self, x: str | int | float | None) -> str | int | float | None: + def preprocess(self, payload: str | int | float | None) -> str | int | float | None: """ Parameters: - x: selected choice + payload: selected choice Returns: value of the selected choice as string or index within choice list """ if self.type == "value": - return x + return payload elif self.type == "index": - if x is None: + if payload is None: return None else: choice_values = [value for _, value in self.choices] - return choice_values.index(x) if x in choice_values else None + return ( + choice_values.index(payload) if payload in choice_values else None + ) else: raise ValueError( f"Unknown type: {self.type}. Please choose from: 'value', 'index'." ) - def postprocess(self, y): - return y + def postprocess(self, value: str | int | float | None) -> str | int | float | None: + return value def api_info(self) -> dict[str, Any]: return { diff --git a/gradio/components/scatter_plot.py b/gradio/components/scatter_plot.py index ca23744cdd68..18f6df639be6 100644 --- a/gradio/components/scatter_plot.py +++ b/gradio/components/scatter_plot.py @@ -310,15 +310,15 @@ def create_plot( return chart def postprocess( - self, y: pd.DataFrame | dict | None + self, value: pd.DataFrame | dict | None ) -> AltairPlotData | dict | None: # if None or update - if y is None or isinstance(y, dict): - return y + if value is None or isinstance(value, dict): + return value if self.x is None or self.y is None: raise ValueError("No value provided for required parameters `x` and `y`.") chart = self.create_plot( - value=y, + value=value, x=self.x, y=self.y, color=self.color, @@ -343,12 +343,10 @@ def postprocess( y_lim=self.y_lim, ) - return AltairPlotData( - **{"type": "altair", "plot": chart.to_json(), "chart": "scatter"} - ) + return AltairPlotData(type="altair", plot=chart.to_json(), chart="scatter") def example_inputs(self) -> Any: return None - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: AltairPlotData | None) -> AltairPlotData | None: + return payload diff --git a/gradio/components/slider.py b/gradio/components/slider.py index 9261fdfe65a8..85775f276280 100644 --- a/gradio/components/slider.py +++ b/gradio/components/slider.py @@ -114,15 +114,8 @@ def get_random_value(self): value = round(value, n_decimals) return value - def postprocess(self, y: float | None) -> float | None: - """ - Any postprocessing needed to be performed on function output. - Parameters: - y: numeric output - Returns: - numeric output or minimum number if None - """ - return self.minimum if y is None else y + def postprocess(self, value: float | None) -> float: + return self.minimum if value is None else value - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: float) -> float: + return payload diff --git a/gradio/components/state.py b/gradio/components/state.py index dc3334c7608e..c8b9c67f2c56 100644 --- a/gradio/components/state.py +++ b/gradio/components/state.py @@ -46,11 +46,11 @@ def __init__( ) from err super().__init__(value=self.value) - def preprocess(self, x: Any) -> Any: - return x + def preprocess(self, payload: Any) -> Any: + return payload - def postprocess(self, y): - return y + def postprocess(self, value: Any) -> Any: + return value def api_info(self) -> dict[str, Any]: return {"type": {}, "description": "any valid json"} diff --git a/gradio/components/textbox.py b/gradio/components/textbox.py index fa3a1e00e7ec..c16cb651b379 100644 --- a/gradio/components/textbox.py +++ b/gradio/components/textbox.py @@ -117,25 +117,11 @@ def __init__( self.rtl = rtl self.text_align = text_align - def preprocess(self, x: str | None) -> str | None: - """ - Preprocesses input (converts it to a string) before passing it to the function. - Parameters: - x: text - Returns: - text - """ - return None if x is None else str(x) + def preprocess(self, payload: str | None) -> str | None: + return None if payload is None else str(payload) - def postprocess(self, y: str | None) -> str | None: - """ - Postproccess the function output y by converting it to a str before passing it to the frontend. - Parameters: - y: function output to postprocess. - Returns: - text - """ - return None if y is None else str(y) + def postprocess(self, value: str | None) -> str | None: + return None if value is None else str(value) def api_info(self) -> dict[str, Any]: return {"type": "string"} diff --git a/gradio/components/upload_button.py b/gradio/components/upload_button.py index ebf280b00c7f..f1dc5c090c7e 100644 --- a/gradio/components/upload_button.py +++ b/gradio/components/upload_button.py @@ -4,6 +4,7 @@ import tempfile import warnings +from pathlib import Path from typing import Any, Callable, List, Literal from gradio_client.documentation import document, set_documentation_group @@ -19,6 +20,12 @@ class ListFiles(GradioRootModel): root: List[FileData] + def __getitem__(self, index): + return self.root[index] + + def __iter__(self): + return iter(self.root) + @document() class UploadButton(Component): @@ -87,6 +94,10 @@ def __init__( raise ValueError( f"Parameter file_types must be a list. Received {file_types.__class__.__name__}" ) + if self.file_count == "multiple": + self.data_model = ListFiles + else: + self.data_model = FileData self.size = size self.file_types = file_types self.label = label @@ -118,12 +129,12 @@ def example_inputs(self) -> Any: "https://github.com/gradio-app/gradio/raw/main/test/test_files/sample_file.pdf" ] - def _process_single_file(self, f: dict[str, Any]) -> bytes | NamedString: - file_name = f["path"] + def _process_single_file(self, f: FileData) -> bytes | NamedString: + file_name = f.path if self.type == "filepath": file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE) file.name = file_name - return NamedString(file.name) + return NamedString(file_name) elif self.type == "binary": with open(file_name, "rb") as file_data: return file_data.read() @@ -135,30 +146,42 @@ def _process_single_file(self, f: dict[str, Any]) -> bytes | NamedString: ) def preprocess( - self, x: list[dict[str, Any]] | dict[str, Any] | None + self, payload: ListFiles | FileData | None ) -> bytes | NamedString | list[bytes | NamedString] | None: - """ - Parameters: - x: List of JSON objects with filename as 'name' property and base64 data as 'data' property - Returns: - File objects in requested format - """ - if x is None: + if payload is None: return None if self.file_count == "single": - if isinstance(x, list): - return self._process_single_file(x[0]) + if isinstance(payload, ListFiles): + return self._process_single_file(payload[0]) else: - return self._process_single_file(x) + return self._process_single_file(payload) else: - if isinstance(x, list): - return [self._process_single_file(f) for f in x] + if isinstance(payload, ListFiles): + return [self._process_single_file(f) for f in payload] else: - return [self._process_single_file(x)] + return [self._process_single_file(payload)] - def postprocess(self, y): - return super().postprocess(y) + def postprocess(self, value: str | list[str] | None) -> ListFiles | FileData | None: + if value is None: + return None + if isinstance(value, list): + return ListFiles( + root=[ + FileData( + path=file, + orig_name=Path(file).name, + size=Path(file).stat().st_size, + ) + for file in value + ] + ) + else: + return FileData( + path=value, + orig_name=Path(value).name, + size=Path(value).stat().st_size, + ) @property def skip_api(self): diff --git a/gradio/components/video.py b/gradio/components/video.py index 7758b6d2f0f5..f4e0b9009b75 100644 --- a/gradio/components/video.py +++ b/gradio/components/video.py @@ -43,7 +43,7 @@ class Video(Component): """ data_model = VideoData - input_data_model = FileData + EVENTS = [ Events.change, Events.clear, @@ -155,18 +155,11 @@ def __init__( value=value, ) - def preprocess(self, x: dict | VideoData) -> str | None: - """ - Parameters: - x: A tuple of (video file data, subtitle file data) or just video file data. - Returns: - A string file path or URL to the preprocessed video. Subtitle file data is ignored. - """ - if x is None: + def preprocess(self, payload: VideoData | None) -> str | None: + if payload is None: return None - data: VideoData = VideoData(**x) if isinstance(x, dict) else x - assert data.video.path - file_name = Path(data.video.path) + assert payload.video.path + file_name = Path(payload.video.path) uploaded_format = file_name.suffix.replace(".", "") needs_formatting = self.format is not None and uploaded_format != self.format flip = self.sources == ["webcam"] and self.mirror_webcam @@ -221,24 +214,6 @@ def preprocess(self, x: dict | VideoData) -> str | None: def postprocess( self, y: str | Path | tuple[str | Path, str | Path | None] | None ) -> VideoData | None: - """ - Processes a video to ensure that it is in the correct format before returning it to the front end. - Parameters: - y: video data in either of the following formats: a tuple of (video filepath, optional subtitle filepath), or just a filepath or URL to an video file, or None. - Returns: - a tuple with the two dictionary, reresent to video and (optional) subtitle, which following formats: - - The first dictionary represents the video file and contains the following keys: - - 'name': a file path to a temporary copy of the processed video. - - 'data': None - - 'is_file': True - - The second dictionary represents the subtitle file and contains the following keys: - - 'name': None - - 'data': Base64 encode the processed subtitle data. - - 'is_file': False - - If subtitle is None, returns (video, None). - - If both video and subtitle are None, returns None. - """ - if y is None or y == [None, None] or y == (None, None): return None if isinstance(y, (str, Path)): @@ -269,14 +244,6 @@ def postprocess( def _format_video(self, video: str | Path | None) -> FileData | None: """ Processes a video to ensure that it is in the correct format. - Parameters: - video: video data in either of the following formats: a string filepath or URL to an video file, or None. - Returns: - a dictionary with the following keys: - - - 'name': a file path to a temporary copy of the processed video. - - 'data': None - - 'is_file': True """ if video is None: return None @@ -328,13 +295,6 @@ def _format_video(self, video: str | Path | None) -> FileData | None: def _format_subtitle(self, subtitle: str | Path | None) -> FileData | None: """ Convert subtitle format to VTT and process the video to ensure it meets the HTML5 requirements. - Parameters: - subtitle: subtitle path in either of the VTT and SRT format. - Returns: - a dictionary with the following keys: - - 'name': None - - 'data': base64-encoded subtitle data. - - 'is_file': False """ def srt_to_vtt(srt_file_path, vtt_file_path): diff --git a/gradio/queueing.py b/gradio/queueing.py index d2de530fa588..ae347512daee 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -173,8 +173,13 @@ def get_events_in_batch(self) -> tuple[list[Event] | None, bool]: if concurrency_limit is None or existing_worker_count < concurrency_limit: batch = block_fn.batch if batch: - remaining_worker_count = concurrency_limit - existing_worker_count batch_size = block_fn.max_batch_size + if concurrency_limit is None: + remaining_worker_count = batch_size - 1 + else: + remaining_worker_count = ( + concurrency_limit - existing_worker_count + ) rest_of_batch = [ event for event in self.event_queue[index:] diff --git a/js/uploadbutton/shared/UploadButton.svelte b/js/uploadbutton/shared/UploadButton.svelte index 22fad81425f5..ed5f5a2509b8 100644 --- a/js/uploadbutton/shared/UploadButton.svelte +++ b/js/uploadbutton/shared/UploadButton.svelte @@ -52,9 +52,9 @@ all_file_data = (await upload(all_file_data, root))?.filter( (x) => x !== null ) as FileData[]; - dispatch("change", all_file_data); - dispatch("upload", all_file_data); - value = all_file_data; + value = file_count === "single" ? all_file_data?.[0] : all_file_data + dispatch("change", value); + dispatch("upload", value); } async function loadFilesFromUpload(e: Event): Promise { diff --git a/test/test_components.py b/test/test_components.py index 1b1731e14511..026b85e787ae 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -30,6 +30,8 @@ import gradio as gr from gradio import processing_utils, utils +from gradio.components.dataframe import DataframeData +from gradio.components.video import VideoData from gradio.data_classes import FileData os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -480,9 +482,9 @@ def test_component_functions(self): """ dropdown_input = gr.Dropdown(["a", "b", ("c", "c full")], multiselect=True) assert dropdown_input.preprocess("a") == "a" - assert dropdown_input.postprocess("a") == "a" + assert dropdown_input.postprocess("a") == ["a"] assert dropdown_input.preprocess("c full") == "c full" - assert dropdown_input.postprocess("c full") == "c full" + assert dropdown_input.postprocess("c full") == ["c full"] # When a Gradio app is loaded with gr.load, the tuples are converted to lists, # so we need to test that case as well @@ -558,7 +560,7 @@ def test_component_functions(self, gradio_temp_dir): type: pil, file, filepath, numpy """ - img = dict(FileData(path="test/test_files/bus.png")) + img = FileData(path="test/test_files/bus.png") image_input = gr.Image() image_input = gr.Image(type="filepath") @@ -605,12 +607,12 @@ def test_component_functions(self, gradio_temp_dir): # Output functionalities image_output = gr.Image(type="pil") processed_image = image_output.postprocess( - PIL.Image.open(img["path"]) + PIL.Image.open(img.path) ).model_dump() assert processed_image is not None if processed_image is not None: processed = PIL.Image.open(cast(dict, processed_image).get("path", "")) - source = PIL.Image.open(img["path"]) + source = PIL.Image.open(img.path) assert processed.size == source.size def test_in_interface_as_output(self): @@ -696,7 +698,7 @@ def test_component_functions(self, gradio_temp_dir): Preprocess, postprocess serialize, get_config, deserialize type: filepath, numpy, file """ - x_wav = deepcopy(media_data.BASE64_AUDIO) + x_wav = FileData(path=media_data.BASE64_AUDIO["path"]) audio_input = gr.Audio() output1 = audio_input.preprocess(x_wav) assert output1[0] == 8000 @@ -735,11 +737,6 @@ def test_component_functions(self, gradio_temp_dir): "_selectable": False, } assert audio_input.preprocess(None) is None - x_wav["is_example"] = True - x_wav["crop_min"], x_wav["crop_max"] = 1, 4 - output2 = audio_input.preprocess(x_wav) - assert output2 is not None - assert output1 != output2 audio_input = gr.Audio(type="filepath") assert isinstance(audio_input.preprocess(x_wav), str) @@ -821,21 +818,21 @@ def generate_noise(duration): assert iface(100).endswith(".wav") def test_audio_preprocess_can_be_read_by_scipy(self, gradio_temp_dir): - x_wav = { - "path": processing_utils.save_base64_to_cache( + x_wav = FileData( + path=processing_utils.save_base64_to_cache( media_data.BASE64_MICROPHONE["data"], cache_dir=gradio_temp_dir - ), - } + ) + ) audio_input = gr.Audio(type="filepath") output = audio_input.preprocess(x_wav) wavfile.read(output) def test_prepost_process_to_mp3(self, gradio_temp_dir): - x_wav = { - "path": processing_utils.save_base64_to_cache( + x_wav = FileData( + path=processing_utils.save_base64_to_cache( media_data.BASE64_MICROPHONE["data"], cache_dir=gradio_temp_dir - ), - } + ) + ) audio_input = gr.Audio(type="filepath", format="mp3") output = audio_input.preprocess(x_wav) assert output.endswith("mp3") @@ -850,13 +847,13 @@ def test_component_functions(self): """ Preprocess, serialize, get_config, value """ - x_file = deepcopy(media_data.BASE64_FILE) + x_file = FileData(path=media_data.BASE64_FILE["path"]) file_input = gr.File() - output = file_input.preprocess({"path": x_file["path"]}) + output = file_input.preprocess(x_file) assert isinstance(output, str) - input1 = file_input.preprocess({"path": x_file["path"]}) - input2 = file_input.preprocess({"path": x_file["path"]}) + input1 = file_input.preprocess(x_file) + input2 = file_input.preprocess(x_file) assert input1 == input1.name # Testing backwards compatibility assert input1 == input2 assert Path(input1).name == "sample_file.pdf" @@ -884,7 +881,7 @@ def test_component_functions(self): assert file_input.preprocess(None) is None assert file_input.preprocess(x_file) is not None - zero_size_file = {"path": "document.txt", "size": 0} + zero_size_file = FileData(path="document.txt", size=0) temp_file = file_input.preprocess(zero_size_file) assert not Path(temp_file.name).exists() @@ -933,13 +930,13 @@ def test_component_functions(self): """ preprocess """ - x_file = deepcopy(media_data.BASE64_FILE) + x_file = FileData(path=media_data.BASE64_FILE["path"]) upload_input = gr.UploadButton() - input = upload_input.preprocess({"path": x_file}) + input = upload_input.preprocess(x_file) assert isinstance(input, str) - input1 = upload_input.preprocess({"path": x_file}) - input2 = upload_input.preprocess({"path": x_file}) + input1 = upload_input.preprocess(x_file) + input2 = upload_input.preprocess(x_file) assert input1 == input1.name # Testing backwards compatibility assert input1 == input2 @@ -961,7 +958,7 @@ def test_component_functions(self): "metadata": None, } dataframe_input = gr.Dataframe(headers=["Name", "Age", "Member"]) - output = dataframe_input.preprocess(x_data) + output = dataframe_input.preprocess(DataframeData(**x_data)) assert output["Age"][1] == 24 assert not output["Member"][0] assert dataframe_input.postprocess(x_data) == x_data @@ -998,7 +995,7 @@ def test_component_functions(self): "column_widths": [], } dataframe_input = gr.Dataframe() - output = dataframe_input.preprocess(x_data) + output = dataframe_input.preprocess(DataframeData(**x_data)) assert output["Age"][1] == 24 with pytest.raises(ValueError): gr.Dataframe(type="unknown") @@ -1315,7 +1312,9 @@ def test_component_functions(self): """ Preprocess, serialize, deserialize, get_config """ - x_video = {"video": {"path": deepcopy(media_data.BASE64_VIDEO)["path"]}} + x_video = VideoData( + video=FileData(path=deepcopy(media_data.BASE64_VIDEO)["path"]) + ) video_input = gr.Video() x_video = processing_utils.move_files_to_cache([x_video], video_input)[0] @@ -1357,8 +1356,6 @@ def test_component_functions(self): "_selectable": False, } assert video_input.preprocess(None) is None - x_video["is_example"] = True - assert video_input.preprocess(x_video) is not None video_input = gr.Video(format="avi") output_video = video_input.preprocess(x_video) assert output_video[-3:] == "avi" @@ -1468,7 +1465,7 @@ def test_video_postprocess_converts_to_playable_format(self): @patch("gradio.components.video.FFmpeg") def test_video_preprocessing_flips_video_for_webcam(self, mock_ffmpeg): # Ensures that the cached temp video file is not used so that ffmpeg is called for each test - x_video = {"video": deepcopy(media_data.BASE64_VIDEO)} + x_video = VideoData(video=FileData(path=media_data.BASE64_VIDEO["path"])) video_input = gr.Video(sources=["webcam"]) _ = video_input.preprocess(x_video) diff --git a/test/test_files/audio_sample.wav b/test/test_files/audio_sample.wav index c4d8a40e21c1f0135e88b80a5dd635200d0c4ecc..495f00d86826e503f2cdc264d67f592e41263972 100644 GIT binary patch delta 14 VcmaD=-%-aJ( delta 242 zcmeCEdsWXGy+yP07qx@X1UnN-U~m@N^6c;RDKgrw%rn?w311ScEv=GmbWFsI0i=82kK2AUxoVJDpxEccG PG3*U-_V)vdqu2}p76U;y diff --git a/test/test_theme_sharing.py b/test/test_theme_sharing.py index 7fd4eaf22941..237ed0f6f02a 100644 --- a/test/test_theme_sharing.py +++ b/test/test_theme_sharing.py @@ -288,18 +288,19 @@ def test_get_next_version(self, mock): ) assert next_version == "3.20.2" - @pytest.mark.flaky - def test_theme_download(self): - assert ( - gr.themes.Base.from_hub("gradio/dracula_test@0.0.2").to_dict() - == dracula.to_dict() - ) - - with gr.Blocks(theme="gradio/dracula_test@0.0.2") as demo: - pass - - assert demo.theme.to_dict() == dracula.to_dict() - assert demo.theme.name == "gradio/dracula_test" + ## Commenting out until after 4.0 Spaces are up + # @pytest.mark.flaky + # def test_theme_download(self): + # assert ( + # gr.themes.Base.from_hub("gradio/dracula_test@0.0.2").to_dict() + # == dracula.to_dict() + # ) + + # with gr.Blocks(theme="gradio/dracula_test@0.0.2") as demo: + # pass + + # assert demo.theme.to_dict() == dracula.to_dict() + # assert demo.theme.name == "gradio/dracula_test" def test_theme_download_raises_error_if_theme_does_not_exist(self): with pytest.raises(