Skip to content

Commit

Permalink
modify preprocess to use pydantic models (#6181)
Browse files Browse the repository at this point in the history
* modify preprocess to use pydantic models

* changes

* add changeset

* fix

* fix

* fix typing

* save

* revert queuing changes

* fix

* fix

* notebook

* fix

* changes

* add changeset

* fix functional tests

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot committed Oct 31, 2023
1 parent e16b4ab commit 62ec207
Show file tree
Hide file tree
Showing 45 changed files with 491 additions and 623 deletions.
6 changes: 6 additions & 0 deletions .changeset/short-doodles-lose.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/uploadbutton": minor
"gradio": minor
---

feat:modify preprocess to use pydantic models
2 changes: 1 addition & 1 deletion demo/chatbot_multimodal/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatbot_multimodal"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/chatbot_multimodal/avatar.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "import time\n", "\n", "# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.\n", "\n", "\n", "def add_text(history, text):\n", " history = history + [(text, None)]\n", " return history, gr.Textbox(value=\"\", interactive=False)\n", "\n", "\n", "def add_file(history, file):\n", " history = history + [((file.name,), None)]\n", " return history\n", "\n", "\n", "def bot(history):\n", " response = \"**That's cool!**\"\n", " history[-1][1] = \"\"\n", " for character in response:\n", " history[-1][1] += character\n", " time.sleep(0.05)\n", " yield history\n", "\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot(\n", " [],\n", " elem_id=\"chatbot\",\n", " bubble_full_width=False,\n", " avatar_images=(None, (os.path.join(os.path.abspath(''), \"avatar.png\"))),\n", " )\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(\n", " scale=4,\n", " show_label=False,\n", " placeholder=\"Enter text and press enter, or upload an image\",\n", " container=False,\n", " )\n", " btn = gr.UploadButton(\"\ud83d\udcc1\", file_types=[\"image\", \"video\", \"audio\"])\n", "\n", " txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n", " bot, chatbot, chatbot, api_name=\"bot_response\"\n", " )\n", " txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n", " file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n", " bot, chatbot, chatbot\n", " )\n", "\n", "demo.queue()\n", "if __name__ == \"__main__\":\n", " demo.launch(allowed_paths=[\"avatar.png\"])\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
2 changes: 1 addition & 1 deletion demo/chatbot_multimodal/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ def bot(history):

demo.queue()
if __name__ == "__main__":
demo.launch()
demo.launch(allowed_paths=["avatar.png"])
1 change: 1 addition & 0 deletions demo/clear_components/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def evaluate_values(*args):
are_false.append(a == "#000000")
else:
are_false.append(not a)
print(args)
return all(are_false)


Expand Down
10 changes: 9 additions & 1 deletion gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def __init__(
self.postprocess = postprocess
self.tracks_progress = tracks_progress
self.concurrency_limit = concurrency_limit
self.concurrency_id = concurrency_id or id(fn)
self.concurrency_id = concurrency_id or str(id(fn))
self.batch = batch
self.max_batch_size = max_batch_size
self.total_runtime = 0
Expand Down Expand Up @@ -1260,6 +1260,14 @@ def preprocess_data(
inputs_cached = processing_utils.move_files_to_cache(
inputs[i], block
)
if getattr(block, "data_model", None) and inputs_cached is not None:
if issubclass(block.data_model, GradioModel): # type: ignore
print("block.data_model", block.data_model, block)
print("1inputs_cached", inputs_cached)
inputs_cached = block.data_model(**inputs_cached) # type: ignore
elif issubclass(block.data_model, GradioRootModel): # type: ignore
print("2inputs_cached", inputs_cached)
inputs_cached = block.data_model(root=inputs_cached) # type: ignore
processed_input.append(block.preprocess(inputs_cached))
else:
processed_input = inputs
Expand Down
19 changes: 11 additions & 8 deletions gradio/components/annotated_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,21 @@ def __init__(

def postprocess(
self,
y: tuple[
value: tuple[
np.ndarray | _Image.Image | str,
list[tuple[np.ndarray | tuple[int, int, int, int], str]],
],
]
| None,
) -> AnnotatedImageData | None:
"""
Parameters:
y: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label.
value: Tuple of base image and list of subsections, with each subsection a two-part tuple where the first element is a 4 element bounding box or a 0-1 confidence mask, and the second element is the label.
Returns:
Tuple of base image file and list of subsections, with each subsection a two-part tuple where the first element image path of the mask, and the second element is the label.
"""
if y is None:
if value is None:
return None
base_img = y[0]
base_img = value[0]
if isinstance(base_img, str):
base_img_path = base_img
base_img = np.array(_Image.open(base_img))
Expand Down Expand Up @@ -144,7 +145,7 @@ def hex_to_rgb(value):
lv = len(value)
return [int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)]

for mask, label in y[1]:
for mask, label in value[1]:
mask_array = np.zeros((base_img.shape[0], base_img.shape[1]))
if isinstance(mask, np.ndarray):
mask_array = mask
Expand Down Expand Up @@ -188,5 +189,7 @@ def hex_to_rgb(value):
def example_inputs(self) -> Any:
return {}

def preprocess(self, x: Any) -> Any:
return x
def preprocess(
self, payload: AnnotatedImageData | None
) -> AnnotatedImageData | None:
return payload
54 changes: 23 additions & 31 deletions gradio/components/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,20 +162,12 @@ def example_inputs(self) -> Any:
return "https://github.com/gradio-app/gradio/raw/main/test/test_files/audio_sample.wav"

def preprocess(
self, x: dict[str, Any] | None
self, payload: FileData | None
) -> tuple[int, np.ndarray] | str | None:
"""
Parameters:
x: dictionary with keys "path", "crop_min", "crop_max".
Returns:
audio in requested format
"""
if x is None:
return x
if payload is None:
return payload

payload: FileData = FileData(**x)
assert payload.path

# Need a unique name for the file to avoid re-using the same audio file if
# a user submits the same audio file twice
temp_file_path = Path(payload.path)
Expand Down Expand Up @@ -211,50 +203,50 @@ def preprocess(
)

def postprocess(
self, y: tuple[int, np.ndarray] | str | Path | bytes | None
) -> FileData | None | bytes:
self, value: tuple[int, np.ndarray] | str | Path | bytes | None
) -> FileData | bytes | None:
"""
Parameters:
y: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None.
value: audio data in either of the following formats: a tuple of (sample_rate, data), or a string filepath or URL to an audio file, or None.
Returns:
base64 url data
"""
if y is None:
if value is None:
return None
if isinstance(y, bytes):
if isinstance(value, bytes):
if self.streaming:
return y
return value
file_path = processing_utils.save_bytes_to_cache(
y, "audio", cache_dir=self.GRADIO_CACHE
value, "audio", cache_dir=self.GRADIO_CACHE
)
elif isinstance(y, tuple):
sample_rate, data = y
elif isinstance(value, tuple):
sample_rate, data = value
file_path = processing_utils.save_audio_to_cache(
data, sample_rate, format=self.format, cache_dir=self.GRADIO_CACHE
)
else:
if not isinstance(y, (str, Path)):
raise ValueError(f"Cannot process {y} as Audio")
file_path = str(y)
if not isinstance(value, (str, Path)):
raise ValueError(f"Cannot process {value} as Audio")
file_path = str(value)
return FileData(path=file_path)

def stream_output(
self, y, output_id: str, first_chunk: bool
self, value, output_id: str, first_chunk: bool
) -> tuple[bytes | None, Any]:
output_file = {
"path": output_id,
"is_stream": True,
}
if y is None:
if value is None:
return None, output_file
if isinstance(y, bytes):
return y, output_file
if client_utils.is_http_url_like(y["path"]):
response = requests.get(y["path"])
if isinstance(value, bytes):
return value, output_file
if client_utils.is_http_url_like(value["path"]):
response = requests.get(value["path"])
binary_data = response.content
else:
output_file["orig_name"] = y["orig_name"]
file_path = y["path"]
output_file["orig_name"] = value["orig_name"]
file_path = value["path"]
is_wav = file_path.endswith(".wav")
with open(file_path, "rb") as f:
binary_data = f.read()
Expand Down
16 changes: 7 additions & 9 deletions gradio/components/bar_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,15 @@ def create_plot(
return chart

def postprocess(
self, y: pd.DataFrame | dict | None
self, value: pd.DataFrame | dict | None
) -> AltairPlotData | dict | None:
# if None or update
if y is None or isinstance(y, dict):
return y
if value is None or isinstance(value, dict):
return value
if self.x is None or self.y is None:
raise ValueError("No value provided for required parameters `x` and `y`.")
chart = self.create_plot(
value=y,
value=value,
x=self.x,
y=self.y,
color=self.color,
Expand All @@ -288,12 +288,10 @@ def postprocess(
sort=self.sort, # type: ignore
)

return AltairPlotData(
**{"type": "altair", "plot": chart.to_json(), "chart": "bar"}
)
return AltairPlotData(type="altair", plot=chart.to_json(), chart="bar")

def example_inputs(self) -> dict[str, Any]:
return {}

def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: AltairPlotData) -> AltairPlotData:
return payload
42 changes: 22 additions & 20 deletions gradio/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@ class ComponentBase(ABC, metaclass=ComponentMeta):
EVENTS: list[EventListener | str] = []

@abstractmethod
def preprocess(self, x: Any) -> Any:
def preprocess(self, payload: Any) -> Any:
"""
Any preprocessing needed to be performed on function input.
"""
return x
return payload

@abstractmethod
def postprocess(self, y):
def postprocess(self, value):
"""
Any postprocessing needed to be performed on function output.
"""
return y
return value

@abstractmethod
def as_example(self, y):
def as_example(self, value):
"""
Return the input data in a way that can be displayed by the examples dataset component in the front-end.
Expand All @@ -88,7 +88,7 @@ def example_inputs(self) -> Any:
pass

@abstractmethod
def flag(self, x: Any | GradioDataModel, flag_dir: str | Path = "") -> str:
def flag(self, payload: Any | GradioDataModel, flag_dir: str | Path = "") -> str:
"""
Write the component's value to a format that can be stored in a csv or jsonl format for flagging.
"""
Expand All @@ -97,13 +97,13 @@ def flag(self, x: Any | GradioDataModel, flag_dir: str | Path = "") -> str:
@abstractmethod
def read_from_flag(
self,
x: Any,
payload: Any,
flag_dir: str | Path | None = None,
) -> GradioDataModel | Any:
"""
Convert the data from the csv or jsonl file into the component state.
"""
return x
return payload

@property
@abstractmethod
Expand Down Expand Up @@ -267,26 +267,26 @@ def api_info(self) -> dict[str, Any]:
f"The api_info method has not been implemented for {self.get_block_name()}"
)

def flag(self, x: Any, flag_dir: str | Path = "") -> str:
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
"""
Write the component's value to a format that can be stored in a csv or jsonl format for flagging.
"""
if self.data_model:
x = self.data_model.from_json(x)
return x.copy_to_dir(flag_dir).model_dump_json()
return x
payload = self.data_model.from_json(payload)
return payload.copy_to_dir(flag_dir).model_dump_json()
return payload

def read_from_flag(
self,
x: Any,
payload: Any,
flag_dir: str | Path | None = None,
):
"""
Convert the data from the csv or jsonl file into the component state.
"""
if self.data_model:
return self.data_model.from_json(json.loads(x))
return x
return self.data_model.from_json(json.loads(payload))
return payload


class FormComponent(Component):
Expand All @@ -295,11 +295,11 @@ def get_expected_parent(self) -> type[Form] | None:
return None
return Form

def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: Any) -> Any:
return payload

def postprocess(self, y):
return y
def postprocess(self, value):
return value


class StreamingOutput(metaclass=abc.ABCMeta):
Expand All @@ -308,7 +308,9 @@ def __init__(self, *args, **kwargs) -> None:
self.streaming: bool

@abc.abstractmethod
def stream_output(self, y, output_id: str, first_chunk: bool) -> tuple[bytes, Any]:
def stream_output(
self, value, output_id: str, first_chunk: bool
) -> tuple[bytes, Any]:
pass


Expand Down
8 changes: 4 additions & 4 deletions gradio/components/button.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ def __init__(
def skip_api(self):
return True

def preprocess(self, x: Any) -> Any:
return x
def preprocess(self, payload: str) -> str:
return payload

def postprocess(self, y):
return y
def postprocess(self, value: str) -> str:
return value

def example_inputs(self) -> Any:
return None
Loading

0 comments on commit 62ec207

Please sign in to comment.