Fix token counting to allow there to be no attention mask #818

Merged
merged 3 commits on Dec 22, 2023
10 changes: 7 additions & 3 deletions llmfoundry/data/text_data.py
@@ -325,9 +325,10 @@ def get_tokens_per_batch_func(
     """

     def get_num_samples_in_batch(batch: Batch) -> int:
-        if not isinstance(batch, Mapping) or 'attention_mask' not in batch:
+        if not isinstance(batch, Mapping) or ('attention_mask' not in batch and
+                                              'input_ids' not in batch):
             raise ValueError(
-                'get_tokens_per_batch_func() requires a batch with an attention_mask key'
+                'get_tokens_per_batch_func() requires a batch with an attention_mask key or an input_ids key'
             )

         if not decoder_only and 'decoder_attention_mask' not in batch:
@@ -336,7 +337,10 @@ def get_num_samples_in_batch(batch: Batch) -> int:
             )

         # Count number of non padding tokens in batch
-        input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
+        if 'attention_mask' in batch:
+            input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
+        else:
+            input_ids_tokens = batch['input_ids'].numel()

         # For encoder decoder models only
         decoder_input_ids_tokens = 0
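
The change above lets the token-counting callback accept batches that carry only input_ids (for example, pre-tokenized text batches with no padding). A minimal sketch of the resulting behavior, using a hypothetical count_tokens helper rather than the repository function: when an attention_mask is present, only unmasked tokens are counted; otherwise every element of input_ids is treated as a real token.

import torch

# Minimal sketch of the two counting paths (hypothetical helper, not the
# repository code): prefer the attention mask when it exists, otherwise
# assume every position in input_ids is a non-padding token.
def count_tokens(batch: dict) -> int:
    if 'attention_mask' in batch:
        return int(torch.sum(batch['attention_mask']).item())
    return batch['input_ids'].numel()

# Padded batch: true lengths 3 and 2, padded to length 4 -> 5 tokens counted.
padded = {
    'input_ids': torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]]),
    'attention_mask': torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
}
assert count_tokens(padded) == 5

# Unpadded batch with no attention_mask -> all 8 positions are counted.
packed = {'input_ids': torch.tensor([[5, 6, 7, 8], [9, 10, 11, 12]])}
assert count_tokens(packed) == 8
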
22 changes: 15 additions & 7 deletions tests/data/test_dataloader.py
@@ -642,16 +642,17 @@ def test_token_counting_func(pad_token_id: int, batch_size: int,
     assert actual_token_count == expected_token_count


-@pytest.mark.parametrize(
-    'dataloader_type',
-    ['finetuning-hf', 'finetuning-streaming', 'denoising', 'text'])
+@pytest.mark.parametrize('dataloader_type,tensor_input',
+                         [('finetuning-hf', False),
+                          ('finetuning-streaming', False), ('denoising', False),
+                          ('text', True), ('text', False)])
 @pytest.mark.parametrize('pad_token_id', [100, None])
 @pytest.mark.parametrize('batch_size', [1, 8])
 @pytest.mark.parametrize('model_max_length', [1024])
 @pytest.mark.parametrize('padding_side', ['left'])
 def test_token_counting_func_dataloader_setting(
-        dataloader_type: str, pad_token_id: Optional[int], batch_size: int,
-        model_max_length: int, padding_side: str,
+        dataloader_type: str, tensor_input: bool, pad_token_id: Optional[int],
+        batch_size: int, model_max_length: int, padding_side: str,
         monkeypatch: pytest.MonkeyPatch):
     gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
     gptt.pad_token_id = pad_token_id if pad_token_id is not None else gptt.eos_token_id
@@ -661,9 +662,11 @@ def test_token_counting_func_dataloader_setting(
     batch_strings = []
     expected_token_count = 0
     for _ in range(batch_size):
+        # Get randomly different lengths if we are going to add padding
         sample_length = random.randint(
             1, model_max_length //
-            4) if pad_token_id is not None else model_max_length // 4
+            4) if (pad_token_id is not None and
+                   not tensor_input) else model_max_length // 4
         batch_strings.append(' '.join(['hello'] * sample_length))
         expected_token_count += sample_length

@@ -672,13 +675,18 @@ def test_token_counting_func_dataloader_setting(
         for b in batch_strings
     ]

+    if tensor_input:
+        batch_tokenized = [
+            torch.tensor(b['input_ids']) for b in batch_tokenized
+        ]
+
     if dataloader_type == 'denoising':
         expected_token_count += 2 * batch_size  # for the two eos tokens
         expected_token_count += 5 * batch_size  # for the corruption prefix tokens

     if dataloader_type in {'finetuning-hf', 'finetuning-streaming'}:
         for b in batch_tokenized:
-            b['labels'] = b['input_ids'].copy()
+            b['labels'] = b['input_ids'].copy()  # type: ignore
         expected_token_count *= 2
         expected_token_count += 1 * batch_size  # for the eos token
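
A note on the test change: when tensor_input is True the samples are kept at a fixed length (model_max_length // 4) instead of random lengths, presumably because a batch of raw input_ids tensors is stacked without padding, and stacking requires equal-length tensors. A small illustration of that constraint with plain PyTorch (hypothetical data, not the test's dataloader):

import torch
from torch.utils.data import default_collate

# Equal-length tensors stack cleanly into a [batch, seq_len] batch ...
same_len = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]
print(default_collate(same_len).shape)  # torch.Size([2, 3])

# ... but ragged tensors cannot be stacked without padding.
ragged = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
try:
    default_collate(ragged)
except RuntimeError as err:
    print(f'cannot collate ragged tensors: {err}')
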