Fixes issue 5781: Enables specifying a caching directory for Examples #6803

Merged · 13 commits · Dec 19, 2023
5 changes: 5 additions & 0 deletions .changeset/spicy-wings-thank.md
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Fixes issue 5781: Enables specifying a caching directory for Examples
3 changes: 1 addition & 2 deletions gradio/helpers.py
@@ -33,7 +33,6 @@
 if TYPE_CHECKING:  # Only import for type checking (to avoid circular imports).
     from gradio.components import Component

-CACHED_FOLDER = "gradio_cached_examples"
 LOG_FILE = "log.csv"

 set_documentation_group("helpers")
@@ -248,7 +247,7 @@ def __init__(
            elem_id=elem_id,
        )

-        self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
+        self.cached_folder = utils.get_cache_folder() / str(self.dataset._id)
         self.cached_file = Path(self.cached_folder) / "log.csv"
         self.cache_examples = cache_examples
         self.run_on_click = run_on_click
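With this change, each Examples instance caches under `<cache root>/<dataset id>/log.csv`, where the cache root is the env var if set, else `./gradio_cached_examples`. A quick sketch of the resolution (the dataset id value is illustrative; the attribute names come from the diff above):

    from pathlib import Path
    from gradio import utils

    dataset_id = 7  # illustrative component id; gradio assigns this internally
    cached_folder = utils.get_cache_folder() / str(dataset_id)
    cached_file = Path(cached_folder) / "log.csv"
    print(cached_file)  # gradio_cached_examples/7/log.csv when the env var is unset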
3 changes: 1 addition & 2 deletions gradio/routes.py
@@ -54,7 +54,6 @@
 from gradio.context import Context
 from gradio.data_classes import ComponentServerBody, PredictBody, ResetBody
 from gradio.exceptions import Error
-from gradio.helpers import CACHED_FOLDER
 from gradio.oauth import attach_oauth
 from gradio.queueing import Estimation
 from gradio.route_utils import (  # noqa: F401
@@ -455,7 +454,7 @@ async def file(path_or_url: str, request: fastapi.Request):
        )
        was_uploaded = utils.is_in_or_equal(abs_path, app.uploaded_file_dir)
        is_cached_example = utils.is_in_or_equal(
-            abs_path, utils.abspath(CACHED_FOLDER)
+            abs_path, utils.abspath(utils.get_cache_folder())
        )

        if not (
4 changes: 4 additions & 0 deletions gradio/utils.py
@@ -1016,3 +1016,7 @@ def __setitem__(self, key: K, value: V) -> None:
         elif len(self) >= self.max_size:
             self.popitem(last=False)
         super().__setitem__(key, value)
+
+
+def get_cache_folder() -> Path:
+    return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples"))
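For reference, a quick sketch of how the new helper behaves at runtime (assumes this branch; the cache path below is illustrative). Because the environment is read on every call, the variable can be changed between launches:

    import os
    from gradio import utils

    print(utils.get_cache_folder())  # Path("gradio_cached_examples"), relative to the working directory
    os.environ["GRADIO_EXAMPLES_CACHE"] = "/tmp/my_example_cache"  # illustrative path
    print(utils.get_cache_folder())  # Path("/tmp/my_example_cache")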
2 changes: 1 addition & 1 deletion guides/02_building-interfaces/03_more-on-examples.md
@@ -34,7 +34,7 @@ Sometimes your app has many input components, but you would only like to provide
 ## Caching examples

 You may wish to provide some cached examples of your model for users to quickly try out, in case your model takes a while to run normally.
-If `cache_examples=True`, the `Interface` will run all of your examples through your app and save the outputs when you call the `launch()` method. This data will be saved in a directory called `gradio_cached_examples`.
+If `cache_examples=True`, the `Interface` will run all of your examples through your app and save the outputs when you call the `launch()` method. By default, this data is saved in a directory called `gradio_cached_examples` in your working directory. You can also set this directory with the `GRADIO_EXAMPLES_CACHE` environment variable, which can be either an absolute path or a path relative to your working directory.

 Whenever a user clicks on an example, the output will automatically be populated in the app now, using data from this cached directory instead of actually running the function. This is useful so users can quickly try out your model without adding any load!
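As a companion to the guide change, a minimal end-to-end sketch of the new knob (the function and cache path are illustrative, not from the PR):

    import os

    # Set before building the interface so example caching picks it up;
    # may be absolute or relative to the working directory.
    os.environ["GRADIO_EXAMPLES_CACHE"] = "/data/gradio_cache"  # illustrative path

    import gradio as gr

    def reverse(text):  # stand-in for a slow model call
        return text[::-1]

    demo = gr.Interface(reverse, "text", "text", examples=["hello", "world"], cache_examples=True)
    demo.launch()  # cached outputs land under /data/gradio_cache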
131 changes: 72 additions & 59 deletions test/test_chat_interface.py
@@ -1,10 +1,11 @@
 import tempfile
 from concurrent.futures import wait
+from pathlib import Path
+from unittest.mock import patch

 import pytest

 import gradio as gr
-from gradio import helpers


 def invalid_fn(message):
@@ -79,44 +80,52 @@ def test_events_attached(self):
        )

     def test_example_caching(self, monkeypatch):
-        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
-        chatbot = gr.ChatInterface(
-            double, examples=["hello", "hi"], cache_examples=True
-        )
-        prediction_hello = chatbot.examples_handler.load_from_cache(0)
-        prediction_hi = chatbot.examples_handler.load_from_cache(1)
-        assert prediction_hello[0].root[0] == ("hello", "hello hello")
-        assert prediction_hi[0].root[0] == ("hi", "hi hi")
+        with patch(
+            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
+        ):
+            chatbot = gr.ChatInterface(
+                double, examples=["hello", "hi"], cache_examples=True
+            )
+            prediction_hello = chatbot.examples_handler.load_from_cache(0)
+            prediction_hi = chatbot.examples_handler.load_from_cache(1)
+            assert prediction_hello[0].root[0] == ("hello", "hello hello")
+            assert prediction_hi[0].root[0] == ("hi", "hi hi")

     def test_example_caching_async(self, monkeypatch):
-        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
-        chatbot = gr.ChatInterface(
-            async_greet, examples=["abubakar", "tom"], cache_examples=True
-        )
-        prediction_hello = chatbot.examples_handler.load_from_cache(0)
-        prediction_hi = chatbot.examples_handler.load_from_cache(1)
-        assert prediction_hello[0].root[0] == ("abubakar", "hi, abubakar")
-        assert prediction_hi[0].root[0] == ("tom", "hi, tom")
+        with patch(
+            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
+        ):
+            chatbot = gr.ChatInterface(
+                async_greet, examples=["abubakar", "tom"], cache_examples=True
+            )
+            prediction_hello = chatbot.examples_handler.load_from_cache(0)
+            prediction_hi = chatbot.examples_handler.load_from_cache(1)
+            assert prediction_hello[0].root[0] == ("abubakar", "hi, abubakar")
+            assert prediction_hi[0].root[0] == ("tom", "hi, tom")

     def test_example_caching_with_streaming(self, monkeypatch):
-        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
-        chatbot = gr.ChatInterface(
-            stream, examples=["hello", "hi"], cache_examples=True
-        )
-        prediction_hello = chatbot.examples_handler.load_from_cache(0)
-        prediction_hi = chatbot.examples_handler.load_from_cache(1)
-        assert prediction_hello[0].root[0] == ("hello", "hello")
-        assert prediction_hi[0].root[0] == ("hi", "hi")
+        with patch(
+            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
+        ):
+            chatbot = gr.ChatInterface(
+                stream, examples=["hello", "hi"], cache_examples=True
+            )
+            prediction_hello = chatbot.examples_handler.load_from_cache(0)
+            prediction_hi = chatbot.examples_handler.load_from_cache(1)
+            assert prediction_hello[0].root[0] == ("hello", "hello")
+            assert prediction_hi[0].root[0] == ("hi", "hi")

     def test_example_caching_with_streaming_async(self, monkeypatch):
-        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
-        chatbot = gr.ChatInterface(
-            async_stream, examples=["hello", "hi"], cache_examples=True
-        )
-        prediction_hello = chatbot.examples_handler.load_from_cache(0)
-        prediction_hi = chatbot.examples_handler.load_from_cache(1)
-        assert prediction_hello[0].root[0] == ("hello", "hello")
-        assert prediction_hi[0].root[0] == ("hi", "hi")
+        with patch(
+            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
+        ):
+            chatbot = gr.ChatInterface(
+                async_stream, examples=["hello", "hi"], cache_examples=True
+            )
+            prediction_hello = chatbot.examples_handler.load_from_cache(0)
+            prediction_hi = chatbot.examples_handler.load_from_cache(1)
+            assert prediction_hello[0].root[0] == ("hello", "hello")
+            assert prediction_hi[0].root[0] == ("hi", "hi")

     def test_default_accordion_params(self):
         chatbot = gr.ChatInterface(
@@ -146,34 +155,38 @@ def test_setting_accordion_params(self, monkeypatch):
         assert accordion.get_config().get("label") == "MOAR"

     def test_example_caching_with_additional_inputs(self, monkeypatch):
-        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
-        chatbot = gr.ChatInterface(
-            echo_system_prompt_plus_message,
-            additional_inputs=["textbox", "slider"],
-            examples=[["hello", "robot", 100], ["hi", "robot", 2]],
-            cache_examples=True,
-        )
-        prediction_hello = chatbot.examples_handler.load_from_cache(0)
-        prediction_hi = chatbot.examples_handler.load_from_cache(1)
-        assert prediction_hello[0].root[0] == ("hello", "robot hello")
-        assert prediction_hi[0].root[0] == ("hi", "ro")
+        with patch(
+            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
+        ):
+            chatbot = gr.ChatInterface(
+                echo_system_prompt_plus_message,
+                additional_inputs=["textbox", "slider"],
+                examples=[["hello", "robot", 100], ["hi", "robot", 2]],
+                cache_examples=True,
+            )
+            prediction_hello = chatbot.examples_handler.load_from_cache(0)
+            prediction_hi = chatbot.examples_handler.load_from_cache(1)
+            assert prediction_hello[0].root[0] == ("hello", "robot hello")
+            assert prediction_hi[0].root[0] == ("hi", "ro")

     def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
-        monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
-        with gr.Blocks():
-            with gr.Accordion("Inputs"):
-                text = gr.Textbox()
-                slider = gr.Slider()
-                chatbot = gr.ChatInterface(
-                    echo_system_prompt_plus_message,
-                    additional_inputs=[text, slider],
-                    examples=[["hello", "robot", 100], ["hi", "robot", 2]],
-                    cache_examples=True,
-                )
-        prediction_hello = chatbot.examples_handler.load_from_cache(0)
-        prediction_hi = chatbot.examples_handler.load_from_cache(1)
-        assert prediction_hello[0].root[0] == ("hello", "robot hello")
-        assert prediction_hi[0].root[0] == ("hi", "ro")
+        with patch(
+            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
+        ):
+            with gr.Blocks():
+                with gr.Accordion("Inputs"):
+                    text = gr.Textbox()
+                    slider = gr.Slider()
+                    chatbot = gr.ChatInterface(
+                        echo_system_prompt_plus_message,
+                        additional_inputs=[text, slider],
+                        examples=[["hello", "robot", 100], ["hi", "robot", 2]],
+                        cache_examples=True,
+                    )
+            prediction_hello = chatbot.examples_handler.load_from_cache(0)
+            prediction_hi = chatbot.examples_handler.load_from_cache(1)
+            assert prediction_hello[0].root[0] == ("hello", "robot hello")
+            assert prediction_hi[0].root[0] == ("hi", "ro")


 class TestAPI:
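Side note on the test pattern above: because the cache folder is now resolved through an environment variable read on each call, tests could equivalently isolate the cache with pytest's monkeypatch and tmp_path instead of mock-patching gradio.utils.get_cache_folder. A sketch under that assumption, not part of this PR (double is redefined here to keep the snippet self-contained):

    import gradio as gr

    def double(message, history):
        return message + " " + message

    def test_example_caching_via_env(monkeypatch, tmp_path):
        # Route the example cache through the new env var rather than patching the helper.
        monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", str(tmp_path))
        chatbot = gr.ChatInterface(double, examples=["hello"], cache_examples=True)
        prediction = chatbot.examples_handler.load_from_cache(0)
        assert prediction[0].root[0] == ("hello", "hello hello")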
7 changes: 5 additions & 2 deletions test/test_external.py
@@ -1,4 +1,5 @@
 import os
+import tempfile
 import textwrap
 import warnings
 from pathlib import Path
@@ -356,7 +357,7 @@ def test_private_space_v4_sse_v1(self):
 class TestLoadInterfaceWithExamples:
     def test_interface_load_examples(self, tmp_path):
         test_file_dir = Path(Path(__file__).parent, "test_files")
-        with patch("gradio.helpers.CACHED_FOLDER", tmp_path):
+        with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
             gr.load(
                 name="models/google/vit-base-patch16-224",
                 examples=[Path(test_file_dir, "cheetah1.jpg")],
@@ -365,7 +366,9 @@ def test_interface_load_examples(self, tmp_path):

     def test_interface_load_cache_examples(self, tmp_path):
         test_file_dir = Path(Path(__file__).parent, "test_files")
-        with patch("gradio.helpers.CACHED_FOLDER", tmp_path):
+        with patch(
+            "gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
+        ):
             gr.load(
                 name="models/google/vit-base-patch16-224",
                 examples=[Path(test_file_dir, "cheetah1.jpg")],