Adding more token encoding types #1254

Merged · 16 commits · Jun 6, 2024
9 changes: 8 additions & 1 deletion llmfoundry/data/__init__.py
@@ -1,7 +1,12 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
from llmfoundry.data.data import (
SUPPORTED_MDS_ENCODING_TYPES,
ConcatTokensDataset,
NoConcatDataset,
stream_remote_local_validate,
)
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.data.finetuning import (
Seq2SeqFinetuningCollator,
@@ -55,4 +60,6 @@
'auto_packing_ratio',
'profile_packing',
'ConcatenatedSequenceCollatorWrapper',
'stream_remote_local_validate',
'SUPPORTED_MDS_ENCODING_TYPES',
]
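
For reference, a quick sketch of what the expanded export list enables; nothing here is new API beyond the names re-exported above:

```python
# The names added to __all__ above are now importable from llmfoundry.data.
from llmfoundry.data import (
    SUPPORTED_MDS_ENCODING_TYPES,
    stream_remote_local_validate,
)

print(SUPPORTED_MDS_ENCODING_TYPES)
# ['int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64']
```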
49 changes: 42 additions & 7 deletions llmfoundry/data/data.py
@@ -5,16 +5,30 @@
import os
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Union
from typing import Dict, Iterable, Optional, Union

import datasets as hf_datasets
import numpy as np
from numpy.typing import NDArray
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

__all__ = [
'ConcatTokensDataset',
'NoConcatDataset',
'stream_remote_local_validate',
'SUPPORTED_MDS_ENCODING_TYPES',
]

SUPPORTED_MDS_ENCODING_TYPES = [
'int8',
'int16',
'int32',
'int64',
'uint8',
'uint16',
'uint32',
'uint64',
]


@@ -97,14 +111,14 @@ def __init__(
)

@abstractmethod
def __iter__(self) -> Iterable[Dict[str, bytes]]:
def __iter__(self) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]:
pass


class ConcatTokensDataset(AbstractConcatTokensDataset):
"""An IterableDataset that returns token samples for MDSWriter.

Returns dicts of {'tokens': bytes}
Returns dicts of {'tokens': ndarray:int32}

To use data created by this class and written to MDS format:

@@ -119,7 +133,7 @@ class ConcatTokensDataset(AbstractConcatTokensDataset):
# note, you need to copy the numpy array because the original is non-writeable
# and torch does not support non-writeable tensors, so you get a scary warning and
# if you do try to write to the tensor you get undefined behavior
tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int32).copy())
print(tokenizer.decode(tokens))
```
"""
@@ -136,7 +150,7 @@ def __init__(
self.hf_dataset = hf_dataset
super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)

def __iter__(self) -> Iterable[Dict[str, bytes]]:
def __iter__(self) -> Iterable[Dict[str, NDArray]]:
buffer = []
for sample in self.hf_dataset:
encoded = self.tokenizer(
@@ -150,6 +164,27 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
concat_sample = buffer[:self.max_length]
buffer = buffer[self.max_length:] if self.should_wrap else []
yield {
# convert to bytes to store in MDS binary format
'tokens': np.asarray(concat_sample).tobytes(),
# convert to ndarray to store in MDS format
'tokens': np.asarray(concat_sample, dtype=np.int32),
}


def stream_remote_local_validate(
remote: Optional[str],
local: Optional[str],
split: Optional[str],
):
"""Check that, if needed, the local/split directory exists.

Args:
remote (Optional[str]): Remote path to the dataset.
local (Optional[str]): Local path to the dataset.
split (Optional[str]): Subdirectory specifying which dataset split to use, if any.
"""
if remote is None or (local == remote):
if local is not None and os.path.isdir(local):
contents = set(os.listdir(local))
if split is not None and split not in contents:
raise ValueError(
f'Local directory {local} does not contain split {split}',
)
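
To make the `bytes` → `ndarray:int32` switch concrete, here is a minimal round-trip sketch; the output path and token IDs are illustrative placeholders, not values from this PR:

```python
# Minimal sketch: write int32 token samples with the new ndarray encoding
# and read them back. '/tmp/mds_tokens' and the token IDs are placeholders.
import numpy as np
from streaming import MDSWriter, StreamingDataset

columns = {'tokens': 'ndarray:int32'}  # previously {'tokens': 'bytes'}
with MDSWriter(out='/tmp/mds_tokens', columns=columns) as writer:
    writer.write({'tokens': np.asarray([101, 2023, 102], dtype=np.int32)})

ds = StreamingDataset(local='/tmp/mds_tokens', shuffle=False)
# The sample comes back as an int32 ndarray directly; no np.frombuffer needed.
print(ds[0]['tokens'], ds[0]['tokens'].dtype)
```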
34 changes: 16 additions & 18 deletions llmfoundry/data/finetuning/tasks.py
@@ -59,6 +59,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from streaming import Stream, StreamingDataset
from transformers import PreTrainedTokenizerBase

from llmfoundry.data import (
SUPPORTED_MDS_ENCODING_TYPES,
stream_remote_local_validate,
)
from llmfoundry.data.finetuning.collator import (
_HF_IGNORE_INDEX,
stitch_turns_decoder_only,
@@ -494,26 +498,13 @@ def is_valid_ift_example(
return True


def _stream_remote_local_validate(
remote: Optional[str],
local: Optional[str],
split: Optional[str],
):
if remote is None or (local == remote):
if local is not None and os.path.isdir(local):
contents = set(os.listdir(local))
if split is not None and split not in contents:
raise ValueError(
f'Local directory {local} does not contain split {split}',
)


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.

Args:
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
tokenize samples.
token_encoding_type (str): The encoding type of the tokenized samples. Defaults to 'int64'.
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
@@ -574,6 +565,7 @@ class StreamingFinetuningDataset(StreamingDataset):
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
token_encoding_type: str = 'int64',
streams: Optional[Sequence[Stream]] = None,
local: Optional[str] = None,
remote: Optional[str] = None,
@@ -606,11 +598,17 @@ def __init__(
f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}',
)

if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
raise ValueError(
f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}',
)
self.token_encoding_type = token_encoding_type

if streams is None:
_stream_remote_local_validate(remote, local, split)
stream_remote_local_validate(remote, local, split)
else:
for stream in streams:
_stream_remote_local_validate(
stream_remote_local_validate(
stream.remote,
stream.local,
split,
@@ -656,11 +654,11 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
if isinstance(sample['input_ids'], bytes):
sample['input_ids'] = np.frombuffer(
sample['input_ids'],
dtype=np.int64,
dtype=getattr(np, self.token_encoding_type),
)[:self.max_seq_len].tolist().copy()
sample['labels'] = np.frombuffer(
sample['labels'],
dtype=np.int64,
dtype=getattr(np, self.token_encoding_type),
)[:self.max_seq_len].tolist().copy()
elif isinstance(sample['input_ids'], np.ndarray):
sample['input_ids'] = sample['input_ids'][:self.max_seq_len
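
The `getattr(np, self.token_encoding_type)` lookup in `__getitem__` above decodes legacy bytes-encoded samples under whichever dtype the shards were written with. A standalone illustration, with made-up token IDs and `'uint16'` assumed as the write-time encoding:

```python
# Illustrative: decoding a bytes-encoded sample under a configured dtype,
# mirroring the getattr(np, token_encoding_type) lookup in __getitem__.
import numpy as np

token_encoding_type = 'uint16'  # assumed to match how the shards were written
max_seq_len = 8

raw = np.asarray([101, 2023, 102], dtype=np.uint16).tobytes()
input_ids = np.frombuffer(
    raw,
    dtype=getattr(np, token_encoding_type),
)[:max_seq_len].tolist()
print(input_ids)  # [101, 2023, 102]
```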
45 changes: 33 additions & 12 deletions llmfoundry/data/text_data.py
@@ -4,7 +4,6 @@
"""Build a StreamingTextDataset dataset and dataloader for training."""

import inspect
import os
from itertools import islice
from typing import (
Any,
@@ -25,6 +24,10 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry import registry
from llmfoundry.data import (
SUPPORTED_MDS_ENCODING_TYPES,
stream_remote_local_validate,
)
from llmfoundry.utils.registry_utils import construct_from_registry

__all__ = [
@@ -41,6 +44,7 @@ class StreamingTextDataset(StreamingDataset):
tokenizer (Tokenizer): HuggingFace tokenizer to
tokenize samples.
max_seq_len (int): The max sequence length of each sample.
token_encoding_type (str): The encoding type of the tokenized samples. Defaults to 'int64'.
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
@@ -106,6 +110,7 @@ def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
max_seq_len: int,
token_encoding_type: str = 'int64',
streams: Optional[Sequence[Stream]] = None,
remote: Optional[str] = None,
local: Optional[str] = None,
@@ -137,13 +142,21 @@ def __init__(
f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}',
)

if local is not None and (remote is None or (local == remote)):
if os.path.isdir(local):
contents = set(os.listdir(local))
if split not in contents:
raise ValueError(
f'local directory {local} does not contain split {split}',
)
if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
raise ValueError(
f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}',
)
self.token_encoding_type = token_encoding_type

if streams is None:
stream_remote_local_validate(remote, local, split)
else:
for stream in streams:
stream_remote_local_validate(
stream.remote,
stream.local,
split,
)

# TODO: discover where yamls are being converted incorrect, but temporary workaround
if isinstance(shuffle_block_size, float):
@@ -197,10 +210,18 @@ def _read_binary_tokenized_sample(
self,
sample: Dict[str, Any],
) -> torch.Tensor:
return torch.from_numpy(
np.frombuffer(sample['tokens'],
dtype=np.int64)[:self.max_seq_len].copy(),
)
# Modeling code still expects int64 tensors.
if isinstance(sample['tokens'], np.ndarray):
return torch.from_numpy(
sample['tokens'][:self.max_seq_len].copy(),
).to(torch.int64)
else:
return torch.from_numpy(
np.frombuffer(
sample['tokens'],
dtype=getattr(np, self.token_encoding_type),
)[:self.max_seq_len].copy(),
).to(torch.int64)

# How to process a sample
def __getitem__(self,
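
Both branches of `_read_binary_tokenized_sample` normalize to `int64`, since the modeling code expects that dtype regardless of the storage encoding. A side-by-side sketch with illustrative token values:

```python
# Sketch of the two decode paths: new shards hold ndarrays, legacy shards
# hold raw bytes; both end up as int64 tensors for the model.
import numpy as np
import torch

max_seq_len = 4
new_sample = {'tokens': np.asarray([1, 2, 3, 4, 5], dtype=np.int32)}
legacy_sample = {'tokens': np.asarray([1, 2, 3, 4, 5], dtype=np.int64).tobytes()}

t_new = torch.from_numpy(new_sample['tokens'][:max_seq_len].copy()).to(torch.int64)
t_legacy = torch.from_numpy(
    np.frombuffer(legacy_sample['tokens'], dtype=np.int64)[:max_seq_len].copy(),
).to(torch.int64)

assert torch.equal(t_new, t_legacy)
```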
12 changes: 9 additions & 3 deletions scripts/data_prep/convert_dataset_hf.py
@@ -12,6 +12,8 @@

import datasets as hf_datasets
import psutil
import torch
from numpy.typing import NDArray
from streaming import MDSWriter
from torch.utils.data import DataLoader, Dataset, IterableDataset
from tqdm import tqdm
@@ -338,7 +340,7 @@ def build_dataloader(
def generate_samples(
loader: DataLoader,
truncate_num_samples: Optional[int] = None,
) -> Iterable[Dict[str, bytes]]:
) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]:
"""Generator over samples of a dataloader.

Args:
@@ -356,7 +358,11 @@ def main(args: Namespace) -> None:
if truncate_num_samples is not None and n_samples == truncate_num_samples:
return
n_samples += 1
yield {k: v[idx] for k, v in batch.items()}
yield {
k:
v[idx].numpy() if isinstance(v[idx], torch.Tensor) else v[idx]
for k, v in batch.items()
}


def main(args: Namespace) -> None:
@@ -377,7 +383,7 @@ def main(args: Namespace) -> None:
tokenizer = build_tokenizer(args.tokenizer, args.tokenizer_kwargs)
# we will enforce length, so suppress warnings about sequences too long for the model
tokenizer.model_max_length = int(1e30)
columns = {'tokens': 'bytes'}
columns = {'tokens': 'ndarray:int32'}
else:
mode = ConcatMode.NO_CONCAT
tokenizer = None
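
The tensor-to-numpy conversion added to `generate_samples` is needed because the default DataLoader collation turns the dataset's ndarray samples into torch tensors, while the `ndarray:int32` MDS column expects numpy arrays back. A standalone sketch with an illustrative batch:

```python
# Illustrative: undoing DataLoader collation so MDSWriter receives ndarrays.
import numpy as np
import torch

batch = {'tokens': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int32)}
idx = 0
sample = {
    k: v[idx].numpy() if isinstance(v[idx], torch.Tensor) else v[idx]
    for k, v in batch.items()
}
assert isinstance(sample['tokens'], np.ndarray)
print(sample['tokens'].dtype)  # int32
```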
14 changes: 10 additions & 4 deletions scripts/data_prep/convert_dataset_json.py
@@ -6,9 +6,11 @@
from argparse import ArgumentParser, Namespace
from enum import Enum
from glob import glob
from typing import Dict, Iterable, Optional
from typing import Dict, Iterable, Optional, Union

import datasets as hf_datasets
import torch
from numpy.typing import NDArray
from streaming import MDSWriter
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm
@@ -143,7 +145,7 @@ def build_hf_dataset(
def generate_samples(
loader: DataLoader,
truncate_num_samples: Optional[int] = None,
) -> Iterable[Dict[str, bytes]]:
) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]:
"""Generator over samples of a dataloader.

Args:
@@ -161,7 +163,11 @@ def generate_samples(
if truncate_num_samples is not None and n_samples == truncate_num_samples:
return
n_samples += 1
yield {k: v[idx] for k, v in batch.items()}
yield {
k:
v[idx].numpy() if isinstance(v[idx], torch.Tensor) else v[idx]
for k, v in batch.items()
}


def main(args: Namespace) -> None:
@@ -175,7 +181,7 @@ def main(args: Namespace) -> None:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
# we will enforce length, so suppress warnings about sequences too long for the model
tokenizer.model_max_length = int(1e30)
columns = {'tokens': 'bytes'}
columns = {'tokens': 'ndarray:int32'}
else:
mode = ConcatMode.NO_CONCAT
tokenizer = None
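
Shards produced by either conversion script can then be consumed with a matching `token_encoding_type`. A hedged consumption sketch; the tokenizer name and local path are placeholders, and constructor arguments beyond those visible in this diff are assumptions about the existing `StreamingTextDataset` interface:

```python
# Sketch: reading back shards written with the new ndarray:int32 columns.
from transformers import AutoTokenizer

from llmfoundry.data.text_data import StreamingTextDataset

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
ds = StreamingTextDataset(
    tokenizer=tokenizer,
    max_seq_len=2048,
    token_encoding_type='int32',  # must match the write-time encoding
    local='/tmp/my-converted-data',
    split='train',
)
sample = ds[0]  # pre-tokenized data yields an int64 torch.Tensor
```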