Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding huggingface generators for clrs text #135

Merged
merged 8 commits into from
Jul 5, 2024
147 changes: 147 additions & 0 deletions clrs/_src/clrs_text/huggingface_generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""Functions to allow for Huggingface Intergration"""
mcleish7 marked this conversation as resolved.
Show resolved Hide resolved

from typing import List, Dict
import random

from clrs import build_sampler
from clrs._src.clrs_text.clrs_utils import format_clrs_example


def clrs_generator(
    algos_and_lengths: Dict[str, List[int]],
    num_samples: int,
    use_hints: bool = False,
    seed: int = 0,
):
    """Yield a fixed number of formatted CLRS text samples.

    Huggingface datasets.Dataset generator function for creating a dataset
    of fixed size.

    Example usage:
        import datasets
        algos_and_lengths = {"insertion_sort": [16]}
        ds = datasets.Dataset.from_generator(
            clrs_generator,
            gen_kwargs={"algos_and_lengths": algos_and_lengths, "num_samples": 100},
        )

    Huggingface reference:
    https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.Dataset.from_generator

    Args:
        algos_and_lengths: keys = algorithm names
            [Must be same as in clrs.CLRS_30_ALGS_SETTINGS.keys()],
            values = list of lengths required for that algorithm
        num_samples: The size of the output dataset
        use_hints: Whether hints should be included in the question and answer
        seed: The random seed, used both for every CLRS sampler and for the
            choice of which sampler serves each row

    Yields:
        One dictionary per sample with the keys "text" (question + answer),
        "question", "answer", "algo_name", "length" and "use_hints".

    Raises:
        IndexError: if a sample is requested (num_samples > 0) but
            algos_and_lengths describes no (algorithm, length) pair.
    """
    # Build one sampler for every requested (algorithm, length) pair.
    clrs_samplers = []
    for algo_name, lengths in algos_and_lengths.items():
        for length in lengths:
            sampler, _ = build_sampler(
                algo_name,
                seed=seed,
                # NOTE(review): -1 presumably asks for an unbounded,
                # on-the-fly sampler -- confirm against build_sampler.
                num_samples=-1,
                length=length,
                track_max_steps=False,
                use_padding=False,
            )
            clrs_samplers.append((sampler, algo_name, length))

    # Use a private RNG instead of random.seed(): the sequence of choices is
    # the same, but the process-wide random state is left untouched and other
    # users of the `random` module cannot disturb this generator.
    rng = random.Random(seed)

    # For each output row, pick one of the samplers uniformly at random.
    for _ in range(num_samples):
        sampler, algo_name, length = rng.choice(clrs_samplers)
        sample = sampler.next(batch_size=1)  # get one sample from the sampler
        question, answer = format_clrs_example(
            algo_name,
            sample,
            use_hints=use_hints,
        )

        yield {
            "text": question + answer,
            "question": question,
            "answer": answer,
            "algo_name": algo_name,
            "length": length,
            "use_hints": use_hints,
        }


def clrs_infinite_generator(
    algos_and_lengths: Dict[str, List[int]], use_hints: bool = False, seed: int = 0
):
    """Yield formatted CLRS text samples without end.

    Huggingface datasets.IterableDataset generator function for creating a
    dataset of unbounded size; the consumer decides when to stop iterating.

    Example usage:
        from datasets import Features, IterableDataset, Value
        algos_and_lengths = {"insertion_sort": [16]}
        ds = IterableDataset.from_generator(
            clrs_infinite_generator,
            features=Features(
                {
                    "text": Value(dtype="string", id=None),
                    "question": Value(dtype="string", id=None),
                    "answer": Value(dtype="string", id=None),
                    "algo_name": Value(dtype="string", id=None),
                    "length": Value(dtype="int32", id=None),
                    "use_hints": Value(dtype="bool_", id=None),
                }
            ),
            gen_kwargs={"algos_and_lengths": algos_and_lengths},
        )

    Huggingface reference:
    https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.IterableDataset.from_generator

    Args:
        algos_and_lengths: keys = algorithm names
            [Must be same as in clrs.CLRS_30_ALGS_SETTINGS.keys()],
            values = list of lengths required for that algorithm
        use_hints: Whether hints should be included in the question and answer
        seed: The random seed, used both for every CLRS sampler and for the
            choice of which sampler serves each row

    Yields:
        One dictionary per sample with the keys "text" (question + answer),
        "question", "answer", "algo_name", "length" and "use_hints".

    Raises:
        IndexError: on the first requested sample if algos_and_lengths
            describes no (algorithm, length) pair.
    """
    # Build one sampler for every requested (algorithm, length) pair.
    clrs_samplers = []
    for algo_name, lengths in algos_and_lengths.items():
        for length in lengths:
            sampler, _ = build_sampler(
                algo_name,
                seed=seed,
                # NOTE(review): -1 presumably asks for an unbounded,
                # on-the-fly sampler -- confirm against build_sampler.
                num_samples=-1,
                length=length,
                track_max_steps=False,
                use_padding=False,
            )
            clrs_samplers.append((sampler, algo_name, length))

    # Use a private RNG instead of random.seed(): the sequence of choices is
    # the same, but the process-wide random state is left untouched and other
    # users of the `random` module cannot disturb this generator.
    rng = random.Random(seed)

    # For each output row, pick one of the samplers uniformly at random.
    while True:
        sampler, algo_name, length = rng.choice(clrs_samplers)
        sample = sampler.next(batch_size=1)  # get one sample from the sampler
        question, answer = format_clrs_example(
            algo_name,
            sample,
            use_hints=use_hints,
        )

        yield {
            "text": question + answer,
            "question": question,
            "answer": answer,
            "algo_name": algo_name,
            "length": length,
            "use_hints": use_hints,
        }
141 changes: 141 additions & 0 deletions clrs/_src/clrs_text/huggingface_generators_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""Tests for clrs._src.clrs_text.huggingface_generators."""
mcleish7 marked this conversation as resolved.
Show resolved Hide resolved

mcleish7 marked this conversation as resolved.
Show resolved Hide resolved
from datasets import Dataset, IterableDataset, Value, Features

from absl.testing import parameterized
import clrs
from clrs._src.clrs_text import clrs_utils, huggingface_generators


class TestFormatCLRSExamplesHFDataset(parameterized.TestCase):
    """Checks CLRS text samples served through a Huggingface Dataset.

    Based on TestFormatCLRSExamples in clrs._src.clrs_text.clrs_utils_test.py
    """

    @parameterized.product(
        algo_name=list(clrs.CLRS_30_ALGS_SETTINGS.keys()),
        lengths=[[16], [16, 32]],
        use_hints=[True, False],
    )
    def test_format(self, algo_name, lengths, use_hints):
        """Test that we can format samples from any algo into strings from a hf Dataset."""
        algos_and_lengths = {algo_name: lengths}
        ds = Dataset.from_generator(
            huggingface_generators.clrs_generator,
            gen_kwargs={
                "algos_and_lengths": algos_and_lengths,
                "num_samples": 100,
                "use_hints": use_hints,
            },
        )

        # Traces only appear when hints were requested AND the task supports
        # them (segments intersect has no hints option).
        expect_trace = use_hints and algo_name in clrs_utils.CLRS_TASKS_WITH_HINTS

        for sample in ds:
            text = sample["text"]
            question = sample["question"]
            answer = sample["answer"]

            # Metadata must echo the request. The sample's flag is read
            # straight from the row so it is compared against the test
            # parameter rather than against itself.
            self.assertEqual(algo_name, sample["algo_name"])
            self.assertEqual(use_hints, sample["use_hints"])
            self.assertIn(sample["length"], lengths)

            self.assertTrue(question.startswith(f"{algo_name}:\n"))
            self.assertTrue(question.endswith(":\n"))
            self.assertTrue(answer.endswith("\n\n"))

            self.assertTrue(text.startswith(f"{algo_name}:\n"))
            self.assertTrue(text.endswith("\n\n"))
            self.assertEqual(question + answer, text)

            if expect_trace:
                self.assertIn("trace | ", question)
                self.assertIn("initial_trace:", question)
                self.assertIn("trace | ", text)
                self.assertIn("initial_trace:", text)
            else:
                self.assertNotIn("trace | ", question)
                self.assertNotIn("initial_trace:", question)
                self.assertNotIn("trace | ", text)
                self.assertNotIn("initial_trace:", text)


class TestFormatCLRSExamplesHFIterableDataset(parameterized.TestCase):
    """Checks CLRS text samples served through a Huggingface IterableDataset.

    Based on TestFormatCLRSExamples in clrs._src.clrs_text.clrs_utils_test.py
    """

    @parameterized.product(
        algo_name=list(clrs.CLRS_30_ALGS_SETTINGS.keys()),
        lengths=[[16], [16, 32]],
        use_hints=[True, False],
    )
    def test_format(self, algo_name, lengths, use_hints):
        """Test that we can format samples from any algo into strings from a hf IterableDataset."""
        algos_and_lengths = {algo_name: lengths}
        ds = IterableDataset.from_generator(
            huggingface_generators.clrs_infinite_generator,
            features=Features(
                {
                    "text": Value(dtype="string", id=None),
                    "question": Value(dtype="string", id=None),
                    "answer": Value(dtype="string", id=None),
                    "algo_name": Value(dtype="string", id=None),
                    "length": Value(dtype="int32", id=None),
                    "use_hints": Value(dtype="bool_", id=None),
                }
            ),
            gen_kwargs={"algos_and_lengths": algos_and_lengths, "use_hints": use_hints},
        )

        # Traces only appear when hints were requested AND the task supports
        # them (segments intersect has no hints option).
        expect_trace = use_hints and algo_name in clrs_utils.CLRS_TASKS_WITH_HINTS

        ds_iterator = iter(ds)
        for _ in range(100):  # only test 100 samples as we have infinite sampling on
            sample = next(ds_iterator)
            text = sample["text"]
            question = sample["question"]
            answer = sample["answer"]

            # Metadata must echo the request. The sample's flag is read
            # straight from the row so it is compared against the test
            # parameter rather than against itself.
            self.assertEqual(algo_name, sample["algo_name"])
            self.assertEqual(use_hints, sample["use_hints"])
            self.assertIn(sample["length"], lengths)

            self.assertTrue(question.startswith(f"{algo_name}:\n"))
            self.assertTrue(question.endswith(":\n"))
            self.assertTrue(answer.endswith("\n\n"))

            self.assertTrue(text.startswith(f"{algo_name}:\n"))
            self.assertTrue(text.endswith("\n\n"))
            self.assertEqual(question + answer, text)

            if expect_trace:
                self.assertIn("trace | ", question)
                self.assertIn("initial_trace:", question)
                self.assertIn("trace | ", text)
                self.assertIn("initial_trace:", text)
            else:
                self.assertNotIn("trace | ", question)
                self.assertNotIn("initial_trace:", question)
                self.assertNotIn("trace | ", text)
                self.assertNotIn("initial_trace:", text)