Skip to content

Commit

Permalink
Improvements to gr.Examples: adds events as attributes and document…
Browse files Browse the repository at this point in the history
…s, them, adds `sample_labels`, and `visible` properties (#8733)

* events

* examples

* add changeset

* format

* add changeset

* add changeset

* format

* changes

* Update gradio/components/dataset.py

Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>

* Update gradio/helpers.py

Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>

* add test

* Update test/test_helpers.py

Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>

* changes

* format

* add to interface as well

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>
  • Loading branch information
3 people committed Jul 10, 2024
1 parent d15ada9 commit fb0daf3
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 8 deletions.
7 changes: 7 additions & 0 deletions .changeset/slow-candles-fail.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/dataset": minor
"gradio": minor
"website": minor
---

feat:Improvements to `gr.Examples`: adds events as attributes and documents, them, adds `sample_labels`, and `visible` properties
4 changes: 4 additions & 0 deletions gradio/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
scale: int | None = None,
min_width: int = 160,
proxy_url: str | None = None,
sample_labels: list[str] | None = None,
):
"""
Parameters:
Expand All @@ -61,6 +62,7 @@ def __init__(
scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.
min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
proxy_url: The URL of the external Space used to load this component. Set automatically when using `gr.load()`. This should not be set manually.
sample_labels: A list of labels for each sample. If provided, the length of this list should be the same as the number of samples, and these labels will be used in the UI instead of rendering the sample values.
"""
super().__init__(
visible=visible,
Expand Down Expand Up @@ -115,6 +117,7 @@ def __init__(
else:
self.headers = [c.label or "" for c in self._components]
self.samples_per_page = samples_per_page
self.sample_labels = sample_labels

def api_info(self) -> dict[str, str]:
return {"type": "integer", "description": "index of selected example"}
Expand All @@ -124,6 +127,7 @@ def get_config(self):

config["components"] = []
config["component_props"] = self.component_props
config["sample_labels"] = self.sample_labels
config["component_ids"] = []

for component in self._components:
Expand Down
26 changes: 22 additions & 4 deletions gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from gradio import components, oauth, processing_utils, routes, utils, wasm_utils
from gradio.context import Context, LocalContext, get_blocks_context
from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import EventData
from gradio.events import Dependency, EventData
from gradio.exceptions import Error
from gradio.flagging import CSVLogger

Expand All @@ -50,6 +50,9 @@ def create_examples(
postprocess: bool = True,
api_name: str | Literal[False] = "load_example",
batch: bool = False,
*,
example_labels: list[str] | None = None,
visible: bool = True,
_defer_caching: bool = False,
):
"""Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
Expand All @@ -69,6 +72,8 @@ def create_examples(
api_name=api_name,
batch=batch,
_defer_caching=_defer_caching,
example_labels=example_labels,
visible=visible,
_initiated_directly=False,
)
examples_obj.create()
Expand Down Expand Up @@ -103,6 +108,9 @@ def __init__(
postprocess: bool = True,
api_name: str | Literal[False] = "load_example",
batch: bool = False,
*,
example_labels: list[str] | None = None,
visible: bool = True,
_defer_caching: bool = False,
_initiated_directly: bool = True,
):
Expand All @@ -121,6 +129,8 @@ def __init__(
postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if `cache_examples` is not False.
api_name: Defines how the event associated with clicking on the examples appears in the API docs. Can be a string or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use the example function.
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is not False.
example_labels: A list of labels for each example. If provided, the length of this list should be the same as the number of examples, and these labels will be used in the UI instead of rendering the example values.
visible: If False, the examples component will be hidden in the UI.
"""
if _initiated_directly:
warnings.warn(
Expand Down Expand Up @@ -221,6 +231,10 @@ def __init__(
[ex for (ex, keep) in zip(example, input_has_examples) if keep]
for example in examples
]
if example_labels is not None and len(example_labels) != len(examples):
raise ValueError(
"If `example_labels` are provided, the length of `example_labels` must be the same as the number of examples."
)

self.examples = examples
self.non_none_examples = non_none_examples
Expand All @@ -233,6 +247,7 @@ def __init__(
self.postprocess = postprocess
self.api_name: str | Literal[False] = api_name
self.batch = batch
self.example_labels = example_labels

with utils.set_directory(working_directory):
self.processed_examples = []
Expand Down Expand Up @@ -265,13 +280,16 @@ def __init__(
label=label,
samples_per_page=examples_per_page,
elem_id=elem_id,
visible=visible,
sample_labels=example_labels,
)

self.cache_logger = CSVLogger(simplify_file_data=False)
self.cached_folder = utils.get_cache_folder() / str(self.dataset._id)
self.cached_file = Path(self.cached_folder) / "log.csv"
self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
self.run_on_click = run_on_click
self.cache_event: Dependency | None = None

def create(self) -> None:
"""Caches the examples if self.cache_examples is True and creates the Dataset
Expand Down Expand Up @@ -380,7 +398,7 @@ async def lazy_cache(self) -> None:
lazy_cache_fn = self.async_lazy_cache
else:
lazy_cache_fn = self.sync_lazy_cache
self.load_input_event.then(
self.cache_event = self.load_input_event.then(
lazy_cache_fn,
inputs=[self.dataset] + self.inputs,
outputs=self.outputs,
Expand Down Expand Up @@ -466,7 +484,7 @@ async def get_final_item(*args):
# create a fake dependency to process the examples and get the predictions
from gradio.events import EventListenerMethod

dependency, fn_index = blocks_config.set_event_trigger(
_, fn_index = blocks_config.set_event_trigger(
[EventListenerMethod(Context.root_block, "load")],
fn=fn,
inputs=self.inputs_with_examples, # type: ignore
Expand Down Expand Up @@ -511,7 +529,7 @@ def load_example(example_id):
] + self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)

self.load_input_event = self.dataset.click(
self.cache_event = self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=self.inputs_with_examples + self.outputs, # type: ignore
Expand Down
4 changes: 4 additions & 0 deletions gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
clear_btn: str | Button | None = "Clear",
delete_cache: tuple[int, int] | None = None,
show_progress: Literal["full", "minimal", "hidden"] = "full",
example_labels: list[str] | None = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -168,6 +169,7 @@ def __init__(
clear_btn: The button to use for clearing the inputs. Defaults to a `gr.Button("Clear", variant="secondary")`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization). Can be set to None, which hides the button.
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
show_progress: whether to show progress animation while running. Has no effect if the interface is `live`.
example_labels: A list of labels for each example. If provided, the length of this list should be the same as the number of examples, and these labels will be used in the UI instead of rendering the example values.
"""
super().__init__(
analytics_enabled=analytics_enabled,
Expand Down Expand Up @@ -320,6 +322,7 @@ def __init__(

self.examples = examples
self.examples_per_page = examples_per_page
self.example_labels = example_labels

if isinstance(submit_btn, Button):
self.submit_btn_parms = submit_btn.recover_kwargs(submit_btn.get_config())
Expand Down Expand Up @@ -879,6 +882,7 @@ def render_examples(self):
examples_per_page=self.examples_per_page,
_api_mode=self.api_mode,
batch=self.batch,
example_labels=self.example_labels,
)

def __str__(self):
Expand Down
25 changes: 25 additions & 0 deletions js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,27 @@
import { style_formatted_text } from "$lib/text";

let obj = get_object("examples");

obj["attributes"] = [
{
name: "dataset",
annotation: "gradio.Dataset",
doc: "The `gr.Dataset` component corresponding to this Examples object.",
kwargs: null
},
{
name: "load_input_event",
annotation: "gradio.events.Dependency",
doc: "The Gradio event that populates the input values when the examples are clicked. You can attach a `.then()` or a `.success()` to this event to trigger subsequent events to fire after this event.",
kwargs: null
},
{
name: "cache_event",
annotation: "gradio.events.Dependency | None",
doc: "The Gradio event that populates the cached output values when the examples are clicked. You can attach a `.then()` or a `.success()` to this event to trigger subsequent events to fire after this event. This event is `None` if `cache_examples` if False, and is the same as `load_input_event` if `cache_examples` is `'lazy'`.",
kwargs: null
}
]
</script>

<!--- Title -->
Expand Down Expand Up @@ -37,6 +58,10 @@ None
### Initialization
<ParamTable parameters={obj.parameters} />

<!--- Attributes -->
### Attributes
<ParamTable parameters={obj.attributes} />


{#if obj.demos && obj.demos.length > 0}
<!--- Demos -->
Expand Down
14 changes: 11 additions & 3 deletions js/dataset/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import { Block } from "@gradio/atoms";
import type { SvelteComponent, ComponentType } from "svelte";
import type { Gradio, SelectData } from "@gradio/utils";
import { BaseExample } from "@gradio/textbox";
export let components: string[];
export let component_props: Record<string, any>[];
export let component_map: Map<
Expand All @@ -13,6 +14,7 @@
export let label = "Examples";
export let headers: string[];
export let samples: any[][] | null = null;
export let sample_labels: string[] | null = null;
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
Expand All @@ -33,7 +35,7 @@
? `/proxy=${proxy_url}file=`
: `${root}/file=`;
let page = 0;
$: gallery = components.length < 2;
$: gallery = components.length < 2 || sample_labels !== null;
let paginate = samples ? samples.length > samples_per_page : false;
let selected_samples: any[][];
Expand All @@ -51,7 +53,7 @@
}
$: {
samples = samples || [];
samples = sample_labels ? sample_labels.map((e) => [e]) : samples || [];
paginate = samples.length > samples_per_page;
if (paginate) {
visible_pages = [];
Expand Down Expand Up @@ -146,7 +148,13 @@
on:mouseenter={() => handle_mouseenter(i)}
on:mouseleave={() => handle_mouseleave()}
>
{#if component_meta.length && component_map.get(components[0])}
{#if sample_labels}
<BaseExample
value={sample_row[0]}
selected={current_hover === i}
type="gallery"
/>
{:else if component_meta.length && component_map.get(components[0])}
<svelte:component
this={component_meta[0][0].component}
{...component_props[0]}
Expand Down
3 changes: 2 additions & 1 deletion js/dataset/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"@gradio/atoms": "workspace:^",
"@gradio/client": "workspace:^",
"@gradio/utils": "workspace:^",
"@gradio/upload": "workspace:^"
"@gradio/upload": "workspace:^",
"@gradio/textbox": "workspace:^"
},
"devDependencies": {
"@gradio/preview": "workspace:^"
Expand Down
3 changes: 3 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions test/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,28 @@ def test_some_headers(self, patched_cache_folder):
)
assert examples.dataset.headers == ["im", ""]

def test_example_labels(self, patched_cache_folder):
examples = gr.Examples(
examples=[
[5, "add", 3],
[4, "divide", 2],
[-4, "multiply", 2.5],
[0, "subtract", 1.2],
],
inputs=[
gr.Number(),
gr.Radio(["add", "divide", "multiply", "subtract"]),
gr.Number(),
],
example_labels=["add", "divide", "multiply", "subtract"],
)
assert examples.dataset.sample_labels == [
"add",
"divide",
"multiply",
"subtract",
]


def test_example_caching_relaunch(connect):
def combine(a, b):
Expand Down

0 comments on commit fb0daf3

Please sign in to comment.