Adding more token encoding types (#1254)
* add more token encoding types

* add tests

* ft support, tests

* linting is shortening my lifespan

* long tensor

* feedback

* import
---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
(cherry picked from commit 42c2d9a)
snarayan21 authored and v-chen_data committed Jun 7, 2024
1 parent c53c105 commit 245646f
Showing 11 changed files with 350 additions and 78 deletions.
9 changes: 8 additions & 1 deletion llmfoundry/data/__init__.py
@@ -1,7 +1,12 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

-from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
+from llmfoundry.data.data import (
+    SUPPORTED_MDS_ENCODING_TYPES,
+    ConcatTokensDataset,
+    NoConcatDataset,
+    stream_remote_local_validate,
+)
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.data.finetuning import (
Seq2SeqFinetuningCollator,
@@ -55,4 +60,6 @@
'auto_packing_ratio',
'profile_packing',
'ConcatenatedSequenceCollatorWrapper',
'stream_remote_local_validate',
'SUPPORTED_MDS_ENCODING_TYPES',
]
50 changes: 43 additions & 7 deletions llmfoundry/data/data.py
@@ -5,16 +5,31 @@
import os
import warnings
from abc import ABC, abstractmethod
-from typing import Dict, Iterable, Union
+from typing import Dict, Iterable, Optional, Union

import datasets as hf_datasets
import numpy as np
from numpy.typing import NDArray
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

__all__ = [
'AbstractConcatTokensDataset',
'ConcatTokensDataset',
'NoConcatDataset',
'stream_remote_local_validate',
'SUPPORTED_MDS_ENCODING_TYPES',
]

SUPPORTED_MDS_ENCODING_TYPES = [
'int8',
'int16',
'int32',
'int64',
'uint8',
'uint16',
'uint32',
'uint64',
]


@@ -97,14 +112,14 @@ def __init__(
)

@abstractmethod
-    def __iter__(self) -> Iterable[Dict[str, bytes]]:
+    def __iter__(self) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]:
pass


class ConcatTokensDataset(AbstractConcatTokensDataset):
"""An IterableDataset that returns token samples for MDSWriter.
-    Returns dicts of {'tokens': bytes}
+    Returns dicts of {'tokens': ndarray:int32}
To use data created by this class and written to MDS format:
@@ -119,7 +134,7 @@ class ConcatTokensDataset(AbstractConcatTokensDataset):
# note, you need to copy the numpy array because the original is non-writeable
# and torch does not support non-writeable tensors, so you get a scary warning and
# if you do try to write to the tensor you get undefined behavior
-    tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
+    tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int32).copy())
print(tokenizer.decode(tokens))
```
"""
@@ -136,7 +151,7 @@ def __init__(
self.hf_dataset = hf_dataset
super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)

-    def __iter__(self) -> Iterable[Dict[str, bytes]]:
+    def __iter__(self) -> Iterable[Dict[str, NDArray]]:
buffer = []
for sample in self.hf_dataset:
encoded = self.tokenizer(
@@ -150,6 +165,27 @@ def __iter__(self) -> Iterable[Dict[str, NDArray]]:
concat_sample = buffer[:self.max_length]
buffer = buffer[self.max_length:] if self.should_wrap else []
            yield {
-                # convert to bytes to store in MDS binary format
-                'tokens': np.asarray(concat_sample).tobytes(),
+                # convert to ndarray to store in MDS format
+                'tokens': np.asarray(concat_sample, dtype=np.int32),
            }


def stream_remote_local_validate(
remote: Optional[str],
local: Optional[str],
split: Optional[str],
):
"""Check that, if needed, the local/split directory exists.
Args:
remote (Optional[str]): Remote path to the dataset.
local (Optional[str]): Local path to the dataset.
split (Optional[str]): Subdirectory specifying which dataset split to use, if any.
"""
if remote is None or (local == remote):
if local is not None and os.path.isdir(local):
contents = set(os.listdir(local))
if split is not None and split not in contents:
raise ValueError(
f'Local directory {local} does not contain split {split}',
)
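
Both helpers above are now part of the public `llmfoundry.data` API. A minimal sketch of how they compose (the dataset path is hypothetical; behavior follows the diff above):

```python
# Sketch: validate an encoding type and a local dataset layout before streaming.
from llmfoundry.data.data import (
    SUPPORTED_MDS_ENCODING_TYPES,
    stream_remote_local_validate,
)

token_encoding_type = 'uint16'
if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
    raise ValueError(f'Unsupported token encoding type: {token_encoding_type}')

# Raises ValueError if /data/c4 exists locally but has no 'train' subdirectory;
# it is a no-op when a distinct remote is configured.
stream_remote_local_validate(remote=None, local='/data/c4', split='train')
```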
36 changes: 18 additions & 18 deletions llmfoundry/data/finetuning/tasks.py
@@ -59,6 +59,10 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from streaming import Stream, StreamingDataset
from transformers import PreTrainedTokenizerBase

from llmfoundry.data import (
SUPPORTED_MDS_ENCODING_TYPES,
stream_remote_local_validate,
)
from llmfoundry.data.finetuning.collator import (
_HF_IGNORE_INDEX,
stitch_turns_decoder_only,
@@ -494,26 +498,15 @@ def is_valid_ift_example(
return True


-def _stream_remote_local_validate(
-    remote: Optional[str],
-    local: Optional[str],
-    split: Optional[str],
-):
-    if remote is None or (local == remote):
-        if local is not None and os.path.isdir(local):
-            contents = set(os.listdir(local))
-            if split is not None and split not in contents:
-                raise ValueError(
-                    f'Local directory {local} does not contain split {split}',
-                )


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.
Args:
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
tokenize samples.
token_encoding_type (str): The encoding type of the tokenized samples. This is only used
for legacy datasets that have been written directly as 'bytes' instead of numpy
arrays. Types are auto-inferred for numpy arrays. Defaults to 'int64'.
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
@@ -574,6 +567,7 @@ class StreamingFinetuningDataset(StreamingDataset):
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
token_encoding_type: str = 'int64',
streams: Optional[Sequence[Stream]] = None,
local: Optional[str] = None,
remote: Optional[str] = None,
@@ -606,11 +600,17 @@ def __init__(
f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}',
)

if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
raise ValueError(
f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}',
)
self.token_encoding_type = token_encoding_type

if streams is None:
-            _stream_remote_local_validate(remote, local, split)
+            stream_remote_local_validate(remote, local, split)
else:
for stream in streams:
-                _stream_remote_local_validate(
+                stream_remote_local_validate(
stream.remote,
stream.local,
split,
@@ -656,11 +656,11 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
if isinstance(sample['input_ids'], bytes):
sample['input_ids'] = np.frombuffer(
sample['input_ids'],
-                dtype=np.int64,
+                dtype=getattr(np, self.token_encoding_type),
)[:self.max_seq_len].tolist().copy()
sample['labels'] = np.frombuffer(
sample['labels'],
-                dtype=np.int64,
+                dtype=getattr(np, self.token_encoding_type),
)[:self.max_seq_len].tolist().copy()
elif isinstance(sample['input_ids'], np.ndarray):
sample['input_ids'] = sample['input_ids'][:self.max_seq_len
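
For legacy finetuning shards whose `input_ids` and `labels` were written directly as `bytes`, the new `token_encoding_type` argument tells the dataset how to reinterpret them. A hedged usage sketch (the tokenizer name and local path are placeholders; remaining constructor defaults are assumed):

```python
# Sketch: load a legacy bytes-encoded finetuning dataset whose tokens
# were serialized as int32 rather than the old int64 default.
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
dataset = StreamingFinetuningDataset(
    tokenizer=tokenizer,
    token_encoding_type='int32',  # must appear in SUPPORTED_MDS_ENCODING_TYPES
    local='/tmp/my-finetuning-data',  # hypothetical path
    split='train',
    max_seq_len=2048,
)
sample = dataset[0]  # e.g. {'input_ids': [...], 'labels': [...]} for pre-tokenized shards
```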
47 changes: 35 additions & 12 deletions llmfoundry/data/text_data.py
@@ -4,7 +4,6 @@
"""Build a StreamingTextDataset dataset and dataloader for training."""

import inspect
-import os
from itertools import islice
from typing import (
Any,
@@ -25,6 +24,10 @@
from transformers import PreTrainedTokenizerBase

from llmfoundry import registry
from llmfoundry.data import (
SUPPORTED_MDS_ENCODING_TYPES,
stream_remote_local_validate,
)
from llmfoundry.utils.registry_utils import construct_from_registry

__all__ = [
@@ -41,6 +44,9 @@ class StreamingTextDataset(StreamingDataset):
tokenizer (Tokenizer): HuggingFace tokenizer to
tokenize samples.
max_seq_len (int): The max sequence length of each sample.
token_encoding_type (str): The encoding type of the tokenized samples. This is only used
for legacy datasets that have been written directly as 'bytes' instead of numpy
arrays. Types are auto-inferred for numpy arrays. Defaults to 'int64'.
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
@@ -106,6 +112,7 @@ def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
max_seq_len: int,
token_encoding_type: str = 'int64',
streams: Optional[Sequence[Stream]] = None,
remote: Optional[str] = None,
local: Optional[str] = None,
@@ -137,13 +144,21 @@ def __init__(
f'StreamingTextDataset() got an unexpected keyword argument: {kwargs}',
)

-        if local is not None and (remote is None or (local == remote)):
-            if os.path.isdir(local):
-                contents = set(os.listdir(local))
-                if split not in contents:
-                    raise ValueError(
-                        f'local directory {local} does not contain split {split}',
-                    )
+        if token_encoding_type not in SUPPORTED_MDS_ENCODING_TYPES:
+            raise ValueError(
+                f'The token_encoding_type must be one of {SUPPORTED_MDS_ENCODING_TYPES}, but got {token_encoding_type}',
+            )
+        self.token_encoding_type = token_encoding_type
+
+        if streams is None:
+            stream_remote_local_validate(remote, local, split)
+        else:
+            for stream in streams:
+                stream_remote_local_validate(
+                    stream.remote,
+                    stream.local,
+                    split,
+                )

# TODO: discover where yamls are being converted incorrect, but temporary workaround
if isinstance(shuffle_block_size, float):
@@ -197,10 +212,18 @@ def _read_binary_tokenized_sample(
self,
sample: Dict[str, Any],
) -> torch.Tensor:
-        return torch.from_numpy(
-            np.frombuffer(sample['tokens'],
-                          dtype=np.int64)[:self.max_seq_len].copy(),
-        )
+        # Modeling code still expects int64 tensors.
+        if isinstance(sample['tokens'], np.ndarray):
+            return torch.from_numpy(
+                sample['tokens'][:self.max_seq_len].copy(),
+            ).to(torch.int64)
+        else:
+            return torch.from_numpy(
+                np.frombuffer(
+                    sample['tokens'],
+                    dtype=getattr(np, self.token_encoding_type),
+                )[:self.max_seq_len].copy(),
+            ).to(torch.int64)

# How to process a sample
def __getitem__(self,
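
The decode path above is easy to exercise in isolation. A standalone sketch of the same logic (the dtype is chosen so `torch.from_numpy` accepts it; values are illustrative):

```python
# Sketch: both legacy bytes samples and new ndarray samples decode to the
# int64 tensors the modeling code expects.
import numpy as np
import torch

max_seq_len = 4
token_encoding_type = 'int32'  # how the legacy bytes were written

legacy_sample = {'tokens': np.arange(8, dtype=np.int32).tobytes()}
ndarray_sample = {'tokens': np.arange(8, dtype=np.int32)}

for sample in (legacy_sample, ndarray_sample):
    if isinstance(sample['tokens'], np.ndarray):
        tokens = sample['tokens'][:max_seq_len].copy()
    else:
        # frombuffer returns a read-only view; copy before handing to torch.
        tokens = np.frombuffer(
            sample['tokens'],
            dtype=getattr(np, token_encoding_type),
        )[:max_seq_len].copy()
    print(torch.from_numpy(tokens).to(torch.int64))  # tensor([0, 1, 2, 3])
```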
17 changes: 17 additions & 0 deletions scripts/data_prep/README.md
@@ -35,6 +35,23 @@ python convert_dataset_json.py \

Where `--path` can be a single JSON file or a folder containing JSON files. `--split` denotes the intended split (HF defaults to `train`).

### Raw text files

Using the `convert_text_to_mds.py` script, we convert a [text file](https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt) containing the complete works of William Shakespeare.

<!--pytest.mark.skip-->
```bash
# Convert the raw text file to StreamingDataset format
mkdir shakespeare && cd shakespeare
curl -O https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt
cd ..
python convert_text_to_mds.py \
--output_folder my-copy-shakespeare \
--input_folder shakespeare \
--concat_tokens 2048 --tokenizer EleutherAI/gpt-neox-20b \
--compression zstd
```
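
To spot-check the conversion, something like the following should work (a sketch; it assumes the `streaming` package and the output folder from the command above, and that `convert_text_to_mds.py` now writes the same `ndarray:int32` column as `convert_dataset_hf.py` below):

```python
# Sketch: read back the first converted sample and confirm the token dtype.
import numpy as np
from streaming import StreamingDataset

ds = StreamingDataset(local='my-copy-shakespeare', shuffle=False)
tokens = np.asarray(ds[0]['tokens'])
print(tokens.dtype, tokens.shape)  # expect int32, (2048,) with --concat_tokens 2048
```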

## Converting a finetuning dataset
Using the `convert_finetuning_dataset.py` script, you can run a command such as:
<!--pytest.mark.skip-->
12 changes: 9 additions & 3 deletions scripts/data_prep/convert_dataset_hf.py
@@ -12,6 +12,8 @@

import datasets as hf_datasets
import psutil
import torch
from numpy.typing import NDArray
from streaming import MDSWriter
from torch.utils.data import DataLoader, Dataset, IterableDataset
from tqdm import tqdm
@@ -338,7 +340,7 @@ def build_dataloader(
def generate_samples(
loader: DataLoader,
truncate_num_samples: Optional[int] = None,
-) -> Iterable[Dict[str, bytes]]:
+) -> Iterable[Union[Dict[str, bytes], Dict[str, NDArray]]]:
"""Generator over samples of a dataloader.
Args:
@@ -356,7 +358,11 @@
if truncate_num_samples is not None and n_samples == truncate_num_samples:
return
n_samples += 1
-            yield {k: v[idx] for k, v in batch.items()}
+            yield {
+                k:
+                    v[idx].numpy() if isinstance(v[idx], torch.Tensor) else v[idx]
+                for k, v in batch.items()
+            }


def main(args: Namespace) -> None:
@@ -377,7 +383,7 @@
tokenizer = build_tokenizer(args.tokenizer, args.tokenizer_kwargs)
# we will enforce length, so suppress warnings about sequences too long for the model
tokenizer.model_max_length = int(1e30)
-        columns = {'tokens': 'bytes'}
+        columns = {'tokens': 'ndarray:int32'}
else:
mode = ConcatMode.NO_CONCAT
tokenizer = None
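
Downstream of `generate_samples`, the converted token arrays are written with the new `ndarray:int32` column. A toy sketch of that write path (the output path is hypothetical; the real script wires this through `ConcatTokensDataset` and a `DataLoader`):

```python
# Sketch: write int32 token samples to an MDS dataset, as the converter
# script now does via MDSWriter.
import numpy as np
from streaming import MDSWriter

columns = {'tokens': 'ndarray:int32'}
samples = [{'tokens': np.arange(4, dtype=np.int32)} for _ in range(3)]

with MDSWriter(out='/tmp/toy-mds', columns=columns, compression='zstd') as writer:
    for sample in samples:
        writer.write(sample)
```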
(Diffs for the remaining 5 changed files are not shown.)
