Skip to content

Commit

Permalink
events
Browse files Browse the repository at this point in the history
  • Loading branch information
abidlabs committed Jul 9, 2024
1 parent 936c713 commit db4a4e8
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 6 deletions.
4 changes: 4 additions & 0 deletions gradio/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
scale: int | None = None,
min_width: int = 160,
proxy_url: str | None = None,
sample_labels: list[str] | None = None,
):
"""
Parameters:
Expand All @@ -61,6 +62,7 @@ def __init__(
scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.
min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
proxy_url: The URL of the external Space used to load this component. Set automatically when using `gr.load()`. This should not be set manually.
sample_labels: An list of labels for each sample. If provided, the length of this list should be the same as the number of samples, and these labels will be used in the UI instead of rendering the sample values.
"""
super().__init__(
visible=visible,
Expand Down Expand Up @@ -115,6 +117,7 @@ def __init__(
else:
self.headers = [c.label or "" for c in self._components]
self.samples_per_page = samples_per_page
self.sample_labels = sample_labels

def api_info(self) -> dict[str, str]:
return {"type": "integer", "description": "index of selected example"}
Expand All @@ -124,6 +127,7 @@ def get_config(self):

config["components"] = []
config["component_props"] = self.component_props
config["sample_labels"] = self.sample_labels
config["component_ids"] = []

for component in self._components:
Expand Down
27 changes: 23 additions & 4 deletions gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from gradio import components, oauth, processing_utils, routes, utils, wasm_utils
from gradio.context import Context, LocalContext, get_blocks_context
from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import EventData
from gradio.events import Dependency, EventData
from gradio.exceptions import Error
from gradio.flagging import CSVLogger

Expand All @@ -50,6 +50,9 @@ def create_examples(
postprocess: bool = True,
api_name: str | Literal[False] = "load_example",
batch: bool = False,
*,
example_labels: list[str] | None = None,
visible: bool = True,
_defer_caching: bool = False,
):
"""Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
Expand All @@ -69,6 +72,8 @@ def create_examples(
api_name=api_name,
batch=batch,
_defer_caching=_defer_caching,
example_labels=example_labels,
visible=visible,
_initiated_directly=False,
)
examples_obj.create()
Expand Down Expand Up @@ -103,6 +108,9 @@ def __init__(
postprocess: bool = True,
api_name: str | Literal[False] = "load_example",
batch: bool = False,
*,
example_labels: list[str] | None = None,
visible: bool = True,
_defer_caching: bool = False,
_initiated_directly: bool = True,
):
Expand All @@ -121,6 +129,8 @@ def __init__(
postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if `cache_examples` is not False.
api_name: Defines how the event associated with clicking on the examples appears in the API docs. Can be a string or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use the example function.
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is not False.
example_labels: An list of labels for each example. If provided, the length of this list should be the same as the number of examples, and these labels will be used in the UI instead of rendering the example values.
visible: If False, the examples component will be hidden in the UI.
"""
if _initiated_directly:
warnings.warn(
Expand Down Expand Up @@ -221,6 +231,10 @@ def __init__(
[ex for (ex, keep) in zip(example, input_has_examples) if keep]
for example in examples
]
if example_labels is not None and len(example_labels) != len(inputs_with_examples):
raise ValueError(
"The length of `example_labels` should be the same as the number of examples."
)

self.examples = examples
self.non_none_examples = non_none_examples
Expand All @@ -233,6 +247,7 @@ def __init__(
self.postprocess = postprocess
self.api_name: str | Literal[False] = api_name
self.batch = batch
self.example_labels = example_labels

with utils.set_directory(working_directory):
self.processed_examples = []
Expand Down Expand Up @@ -265,13 +280,17 @@ def __init__(
label=label,
samples_per_page=examples_per_page,
elem_id=elem_id,
visible=visible,
sample_labels=example_labels,
)

self.cache_logger = CSVLogger(simplify_file_data=False)
self.cached_folder = utils.get_cache_folder() / str(self.dataset._id)
self.cached_file = Path(self.cached_folder) / "log.csv"
self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
self.run_on_click = run_on_click
self.cache_event: Dependency | None = None


def create(self) -> None:
"""Caches the examples if self.cache_examples is True and creates the Dataset
Expand Down Expand Up @@ -380,7 +399,7 @@ async def lazy_cache(self) -> None:
lazy_cache_fn = self.async_lazy_cache
else:
lazy_cache_fn = self.sync_lazy_cache
self.load_input_event.then(
self.cache_event = self.load_input_event.then(
lazy_cache_fn,
inputs=[self.dataset] + self.inputs,
outputs=self.outputs,
Expand Down Expand Up @@ -466,7 +485,7 @@ async def get_final_item(*args):
# create a fake dependency to process the examples and get the predictions
from gradio.events import EventListenerMethod

dependency, fn_index = blocks_config.set_event_trigger(
_, fn_index = blocks_config.set_event_trigger(
[EventListenerMethod(Context.root_block, "load")],
fn=fn,
inputs=self.inputs_with_examples, # type: ignore
Expand Down Expand Up @@ -511,7 +530,7 @@ def load_example(example_id):
] + self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)

self.load_input_event = self.dataset.click(
self.cache_event = self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=self.inputs_with_examples + self.outputs, # type: ignore
Expand Down
21 changes: 21 additions & 0 deletions js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,27 @@
import { style_formatted_text } from "$lib/text";

let obj = get_object("examples");

obj["attributes"] = [
{
name: "dataset",
annotation: "gradio.Dataset",
doc: "The dataset component corresponding to this Examples helper class.",
kwargs: null
},
{
name: "load_input_event",
annotation: "gradio.events.Dependency",
doc: "The Gradio event that populates the input values when the examples are clicked. You can attach a .then() or a .success() to this event to trigger subsequent events to fire after this event.",
kwargs: null
},
{
name: "cache_event",
annotation: "gradio.events.Dependency | None",
doc: "The Gradio event that populates the cached output values when the examples are clicked. You can attach a .then() or a .success() to this event to trigger subsequent events to fire after this event. This event is not defined if cache_exm",
kwargs: null
}
]
</script>

<!--- Title -->
Expand Down
12 changes: 10 additions & 2 deletions js/dataset/Index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
export let label = "Examples";
export let headers: string[];
export let samples: any[][] | null = null;
export let sample_labels: string[] | null = null;
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
Expand All @@ -33,7 +34,7 @@
? `/proxy=${proxy_url}file=`
: `${root}/file=`;
let page = 0;
$: gallery = components.length < 2;
$: gallery = components.length < 2 || sample_labels !== null;
let paginate = samples ? samples.length > samples_per_page : false;
let selected_samples: any[][];
Expand Down Expand Up @@ -84,7 +85,13 @@
}[][] = [];
async function get_component_meta(selected_samples: any[][]): Promise<void> {
component_meta = await Promise.all(
if (sample_labels !== null) {
component_meta = await Promise.all(sample_labels.map(async (label) => ([{
value: label,
component: (await component_map.get("textbox"))?.default as ComponentType<SvelteComponent>
}])));
} else {
component_meta = await Promise.all(
selected_samples &&
selected_samples.map(
async (sample_row) =>
Expand All @@ -99,6 +106,7 @@
)
)
);
}
}
$: component_map, get_component_meta(selected_samples);
Expand Down

0 comments on commit db4a4e8

Please sign in to comment.