Skip to content

Commit

Permalink
update test for updated pretraining multipack code
Browse files (browse the repository at this point in the history)
  • Loading branch information
winglian committed Jan 5, 2024
1 parent a5eb52e commit ca36b3c
Showing 1 changed file with 27 additions and 32 deletions.
59 changes: 27 additions & 32 deletions tests/test_packed_pretraining.py
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,13 @@
"""Module for testing streaming dataset sequence packing"""
import math
import unittest
from functools import partial

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data import encode_packed_pretraining


Expand All @@ -19,21 +19,9 @@ class TestPacking(unittest.TestCase):
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
"pad_token": "[PAD]",
}
)
self.max_seq_length = 8192
self.batch_size = 6
self.sample_packing_efficiency = 1
self.data_collator_kwargs = {
"padding": True,
"pad_to_multiple_of": 64 * math.ceil(self.max_seq_length / 64),
}
self.tokenizer.pad_token = "</s>"
self.max_seq_length = 2048
self.batch_size = 2

def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
Expand All @@ -43,11 +31,19 @@ def test_packing_stream_dataset(self):
streaming=True,
)["train"]

collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=self.max_seq_length,
)

encode = partial(
encode_packed_pretraining,
self.tokenizer,
collate_fn,
max_seq_length=self.max_seq_length,
sample_packing_efficiency=self.sample_packing_efficiency,
batch_size=self.batch_size,
)

dataset = dataset.map(
Expand All @@ -57,28 +53,27 @@ def test_packing_stream_dataset(self):
remove_columns=dataset.features.keys(),
)

data_collator_fn = DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**self.data_collator_kwargs,
)

trainer_loader = DataLoader(
dataset,
batch_size=self.batch_size,
collate_fn=data_collator_fn,
batch_size=1,
collate_fn=None,
drop_last=True,
)
idx = 0
for data in trainer_loader:
if idx > 10:
break
assert data["input_ids"].shape == (self.batch_size, self.max_seq_length)
assert data["position_ids"].shape == (self.batch_size, self.max_seq_length)
assert data["labels"].shape == (self.batch_size, self.max_seq_length)
assert data["attention_mask"].shape == (
self.batch_size,
self.max_seq_length,
assert data["input_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
assert data["position_ids"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
assert data["labels"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
assert data["attention_mask"].shape == torch.Size(
[1, self.batch_size * self.max_seq_length]
)
idx += 1

Expand Down

0 comments on commit ca36b3c

Please sign in to comment.