Commit 29271a3: fixing dataset issues
Satrat committed Oct 10, 2023
1 parent f1441cb

Showing 5 changed files with 34 additions and 28 deletions.
@@ -16,7 +16,9 @@
 from sparseml.experimental.sparsegpt.dispatch import evaluate_perplexity, load_model
 from sparseml.experimental.sparsegpt.llama2 import load_data
 from sparseml.experimental.sparsegpt.main import sequential
+from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general
 from sparseml.transformers.sparsification.obcq.obcq import one_shot
+from sparseml.transformers.sparsification.obcq.utils.helpers import llama_forward
 
 
 dataset = "open_platypus"
@@ -101,8 +103,8 @@ def run_experimental_obcq(experimental_args):
     torch.cuda.empty_cache()
 
     _, testloader, _ = load_data(experimental_args, data_sequence_length)
-    prod_perplexity = evaluate_perplexity(
-        experimental_args, prod_model, testloader, device, max_samples_per_iteration=8
+    prod_perplexity = ppl_eval_general(
+        llama_forward, prod_model, testloader, device, max_samples_per_iteration=8
     )
     print(
         f"Experimental Perplexity: {exp_perplexity}, "
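
Note on the change above (and the matching OPT change below): production-model perplexity is now computed by ppl_eval_general, which takes a model-specific forward callable (llama_forward here, opt_forward for OPT) instead of the experimental evaluate_perplexity helper that took the whole args namespace. Below is a minimal sketch of that pluggable-forward pattern; ppl_eval_sketch is illustrative, not the sparseml implementation, it assumes forward_fn returns logits of shape (batch, seq_len, vocab), and it treats max_samples_per_iteration as a plain sample cap, which may differ from the real helper's batching semantics.

import math

import torch
import torch.nn.functional as F


def ppl_eval_sketch(forward_fn, model, dataloader, device, max_samples_per_iteration=8):
    # Accumulate token-level negative log-likelihood, then exponentiate.
    total_nll, total_tokens = 0.0, 0
    for idx, sample in enumerate(dataloader):
        if idx >= max_samples_per_iteration:
            break
        input_ids = sample.to(device)
        # forward_fn hides architecture-specific details (llama_forward vs
        # opt_forward), keeping the evaluation loop model-agnostic.
        logits = forward_fn(model, input_ids)
        shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
        shift_labels = input_ids[:, 1:].reshape(-1)
        loss = F.cross_entropy(shift_logits, shift_labels)
        total_nll += loss.item() * shift_labels.numel()
        total_tokens += shift_labels.numel()
    return math.exp(total_nll / total_tokens)
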
@@ -17,7 +17,9 @@
 from sparseml.experimental.sparsegpt.dispatch import evaluate_perplexity, load_model
 from sparseml.experimental.sparsegpt.main import sequential
 from sparseml.experimental.sparsegpt.opt import load_data
+from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general
 from sparseml.transformers.sparsification.obcq.obcq import one_shot
+from sparseml.transformers.sparsification.obcq.utils.helpers import opt_forward
 
 
 dataset = "c4"
@@ -99,10 +101,10 @@ def run_experimental_obcq(experimental_args):
         device=prod_args.device,
         recipe_file=prod_args.recipe,
     )
 
     experimental_args.dataset = "wikitext2"
     _, testloader, _ = load_data(experimental_args, data_sequence_length)
-    prod_perplexity = evaluate_perplexity(
-        experimental_args, prod_model, testloader, device, max_samples_per_iteration=8
+    prod_perplexity = ppl_eval_general(
+        opt_forward, prod_model, testloader, device, max_samples_per_iteration=8
     )
     print(
         f"Experimental Perplexity: {exp_perplexity}, "
43 changes: 21 additions & 22 deletions src/sparseml/transformers/data/base_llm.py
@@ -17,7 +17,6 @@
 
 import torch
 from datasets import load_dataset
-from torch.nn import Module
 from torch.utils.data import Dataset
 from transformers import AutoTokenizer
 
@@ -27,7 +26,7 @@
 class TransformersDataset(RegistryMixin, Dataset):
     def __init__(
         self,
-        model: Module,
+        model: str,
         seqlen: int,
         nsamples: int,
         path: str,
@@ -36,6 +35,7 @@ def __init__(
         split: str = "train",
         use_max_tokens: bool = True,
         split_percent_to_use: float = 1.0,
+        shuffle: bool = True,
         **kwargs,
     ):
         self.tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
@@ -51,38 +51,37 @@ def __init__(
 
         random.seed(seed)
         data = list(dataset)
-        random.shuffle(data)
         data_to_use = int(split_percent_to_use * len(data))
-        data = data[-data_to_use:] if self._split_from_end else data[:data_to_use]
+        self._data = data[-data_to_use:] if self._split_from_end else data[:data_to_use]
         if not self._nsamples:
             self._nsamples = len(dataset)
 
-        if use_max_tokens:
-            data_length = int(min(len(data), self._seqlen * self._nsamples / 50))
-        else:
-            data_length = self._nsamples
-
-        if self._split_from_end:
-            self._data = data[-data_length:]
-        else:
-            self._data = data[:data_length]
+        if shuffle:
+            random.shuffle(self._data)
+        self._data = self._data[: self._nsamples]
 
     def create_dataloader(self, data, join_on=None):
         self.loader = []
         if self._use_max_tokens:
-            full_encoded = self.tokenizer(join_on.join(data), return_tensors="pt")[
+            data_idx = 0
+            encoder = self.tokenizer(join_on.join(data), return_tensors="pt")[
                 "input_ids"
             ][0]
-            for sample_idx in range(self._nsamples):
-                start_idx = sample_idx * self._seqlen
+            while self._nsamples is None or len(self.loader) < self._nsamples:
+                start_idx = data_idx * self._seqlen
                 end_idx = start_idx + self._seqlen
-                if end_idx > len(full_encoded):
-                    self._nsamples = sample_idx
+                if start_idx >= len(encoder):
                     break
-                tokenized_sample = self._add_end_token(full_encoded[start_idx:end_idx])
-                tokenized_sample = self._add_end_token(tokenized_sample)
+                elif end_idx >= len(encoder):
+                    sequence = encoder[start_idx:]
+                else:
+                    sequence = encoder[start_idx:end_idx]
+                data_idx += 1
+
+                tokenized_sample = self._add_end_token(sequence)
                 tokenized_sample = torch.unsqueeze(tokenized_sample, dim=0)
                 self.loader.append(tokenized_sample)
+                if data_idx >= len(data):
+                    break
         else:
             for sample in data:
                 tokenized_sample = self.tokenizer(
@@ -111,7 +110,7 @@ def _add_end_token(self, tokenized_sample):
         return tokenized_sample
 
     def __len__(self):
-        return self._nsamples
+        return len(self.loader)
 
     def __item__(self, idx):
         return self.loader[idx]
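
The rewritten create_dataloader loop above changes how the concatenated token stream is windowed: it walks the stream with a data index instead of iterating range(self._nsamples), keeps a final window shorter than seqlen instead of truncating the sample count, and stops once enough samples are collected or the data runs out. A standalone sketch of just that windowing behavior, with tokenization and end-token handling elided (chunk_tokens is illustrative, not part of sparseml):

def chunk_tokens(encoded, seqlen, nsamples=None):
    # Cut a flat token sequence into seqlen-sized windows; the trailing
    # partial window is kept rather than dropped.
    chunks, data_idx = [], 0
    while nsamples is None or len(chunks) < nsamples:
        start_idx = data_idx * seqlen
        end_idx = start_idx + seqlen
        if start_idx >= len(encoded):
            break
        chunks.append(encoded[start_idx:end_idx])
        data_idx += 1
    return chunks


print(chunk_tokens(list(range(10)), seqlen=4))
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

This also explains the __len__ change: since the loop can exit early or keep a short final window, len(self.loader) is the reliable sample count rather than the requested self._nsamples.
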
2 changes: 2 additions & 0 deletions src/sparseml/transformers/data/c4.py
@@ -29,6 +29,8 @@ def __init__(
         split_percent_to_use: float = 1.0,
     ):
         kwargs = {"data_files": {split: "en/c4-train.00000-of-01024.json.gz"}}
+        if split_percent_to_use > 0.2:
+            split_percent_to_use = 0.2
 
         super().__init__(
             model=model,
             seqlen=seqlen,
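
The new guard caps how much of the single C4 shard is tokenized, presumably because en/c4-train.00000-of-01024.json.gz is large enough that 20% already yields ample calibration text. The two added lines behave the same as this one-liner:

split_percent_to_use = min(split_percent_to_use, 0.2)
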
3 changes: 2 additions & 1 deletion src/sparseml/transformers/data/wikitext.py
@@ -37,8 +37,9 @@ def __init__(
             seed=seed,
             split=split,
             split_percent_to_use=split_percent_to_use,
+            shuffle=False,
         )
 
         join_on = "\n\n" if split == "test" else " "
-        processed_data = [sample["text"] for sample in self._data]
+        processed_data = [str(sample["text"]) for sample in self._data]
         self.create_dataloader(processed_data, join_on=join_on)
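
Two details worth noting in the wikitext change: shuffle=False preserves document order, which matters because create_dataloader joins the samples back into one contiguous stream before windowing, and str(...) guards against rows whose "text" field is not already a string. A tiny illustration with made-up rows:

samples = [{"text": "Paragraph one."}, {"text": ""}, {"text": "Paragraph two."}]
processed_data = [str(sample["text"]) for sample in samples]  # str() guards non-string rows
print("\n\n".join(processed_data))  # the test split joins on "\n\n"
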
