Chunk file reads and tokenization for text to mds conversion #1240

Merged: 5 commits, May 28, 2024
Changes from all commits
65 changes: 43 additions & 22 deletions llmfoundry/data/data.py
@@ -4,6 +4,7 @@
"""Datasets for converting to MDS Shards."""
import os
import warnings
from abc import ABC, abstractmethod
from typing import Dict, Iterable, Union

import datasets as hf_datasets
@@ -35,39 +36,20 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
            yield {'text': sample['text'].encode('utf-8')}


class AbstractConcatTokensDataset(ABC, IterableDataset):
    """Abstract class for an IterableDataset that tokenizes and concatenates
    text samples on the fly.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int,
        bos_text: str,
        eos_text: str,
        no_wrap: bool,
    ):
        self.tokenizer = tokenizer
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        self.max_length = max_length
@@ -114,8 +96,47 @@ def __init__(
                'in duplicated special tokens. Please be sure this is what you intend.',
            )

    @abstractmethod
    def __iter__(self) -> Iterable[Dict[str, bytes]]:
        pass


class ConcatTokensDataset(AbstractConcatTokensDataset):
    """An IterableDataset that returns token samples for MDSWriter.

    Returns dicts of {'tokens': bytes}

    To use data created by this class and written to MDS format:

    ```python
    import numpy as np
    import torch
    from streaming.base import StreamingDataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('your/tokenizer')
    ds = StreamingDataset(local='mds-data-folder', split='val')

    # note: copy the numpy array, because the original is non-writeable and
    # torch does not support non-writeable tensors, so writing to one is
    # undefined behavior.
    tokens = torch.from_numpy(np.frombuffer(ds[0]['tokens'], dtype=np.int64).copy())
    print(tokenizer.decode(tokens))
    ```
    """

    def __init__(
        self,
        hf_dataset: Union[hf_datasets.IterableDataset, hf_datasets.Dataset],
        tokenizer: PreTrainedTokenizerBase,
        max_length: int,
        bos_text: str,
        eos_text: str,
        no_wrap: bool,
    ):
        self.hf_dataset = hf_dataset
        super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)

    def __iter__(self) -> Iterable[Dict[str, bytes]]:
        buffer = []
        for sample in self.hf_dataset:
            encoded = self.tokenizer(
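The split is easier to see outside the diff: AbstractConcatTokensDataset now owns the tokenizer setup and the BOS/EOS and wrapping bookkeeping (bos_tokens, eos_tokens, should_wrap, max_length), so a concrete subclass only stores its input source and implements `__iter__`. A minimal sketch of the pattern with a hypothetical in-memory subclass (the class name and list input are illustrative, not part of this PR):

```python
from typing import Dict, Iterable, List

import numpy as np
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.data import AbstractConcatTokensDataset


class ConcatTokensFromListDataset(AbstractConcatTokensDataset):
    """Hypothetical example: tokenizes an in-memory list of strings."""

    def __init__(
        self,
        texts: List[str],
        tokenizer: PreTrainedTokenizerBase,
        max_length: int,
        bos_text: str,
        eos_text: str,
        no_wrap: bool,
    ):
        self.texts = texts
        super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)

    def __iter__(self) -> Iterable[Dict[str, bytes]]:
        buffer: List[int] = []
        for text in self.texts:
            encoded = self.tokenizer(text, truncation=False, padding=False)
            # The base class precomputes bos_tokens/eos_tokens from the
            # configured bos_text/eos_text.
            buffer += self.bos_tokens + encoded['input_ids'] + self.eos_tokens
            while len(buffer) >= self.max_length:
                concat_sample = buffer[:self.max_length]
                # should_wrap (i.e. not no_wrap) carries leftover tokens into
                # the next sample instead of dropping them.
                buffer = buffer[self.max_length:] if self.should_wrap else []
                yield {'tokens': np.asarray(concat_sample).tobytes()}
```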
7 changes: 2 additions & 5 deletions llmfoundry/utils/data_prep_utils.py
@@ -108,7 +108,7 @@ def __init__(
        output_folder: str,
        object_store: Optional[ObjectStore],
    ):
        """Iterable that downloads files before yielding the local filename.

        If object_store is None, input_folder_prefix is treated as a local path.
@@ -138,7 +138,4 @@ def __iter__(self):
                    object_name=object_name,
                    output_filename=output_filename,
                )
                # Previously the whole file was read into memory here and
                # yielded as {'text': txt}; now only the local path is
                # yielded, so consumers can read the file in chunks.
                yield output_filename
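After this change, DownloadingIterable yields local file paths rather than file contents, leaving it to the consumer to decide how to read each file. A rough consumption sketch; the bucket URI, object names, and folders are placeholders, and the first constructor parameter (cut off by the collapsed hunk above) is assumed here to be `object_names`:

```python
from composer.utils import maybe_create_object_store_from_uri

from llmfoundry.utils.data_prep_utils import DownloadingIterable

object_store = maybe_create_object_store_from_uri('s3://my-bucket/raw-text')
downloading_iter = DownloadingIterable(
    object_names=['raw-text/a.txt', 'raw-text/b.txt'],
    output_folder='/tmp/local-copies',
    object_store=object_store,
)

# Each iteration downloads one file and yields its local path; nothing here
# reads file contents into memory.
for local_path in downloading_iter:
    print(local_path)
```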
74 changes: 69 additions & 5 deletions scripts/data_prep/convert_text_to_mds.py
@@ -7,9 +7,11 @@
import tempfile
from argparse import ArgumentParser, Namespace
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from glob import glob
from typing import Dict, Iterable, List, Tuple, cast

import numpy as np
import psutil
from composer.utils import (
    ObjectStore,
@@ -18,9 +20,9 @@
)
from streaming import MDSWriter
from tqdm import tqdm
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.data.data import AbstractConcatTokensDataset
from llmfoundry.utils import maybe_create_mosaicml_logger
from llmfoundry.utils.data_prep_utils import (
    DownloadingIterable,
@@ -37,6 +39,68 @@
DONE_FILENAME = '.text_to_mds_conversion_done'


class ConcatTokensFromFilesDataset(AbstractConcatTokensDataset):
    """An IterableDataset that returns token samples for MDSWriter from files.

    Returns dicts of {'tokens': bytes}

    Each file is considered one continuous sequence.
    """

    def __init__(
        self,
        files: Iterable[str],
        tokenizer: PreTrainedTokenizerBase,
        max_length: int,
        bos_text: str,
        eos_text: str,
        no_wrap: bool,
    ):
        self.files = files
        super().__init__(tokenizer, max_length, bos_text, eos_text, no_wrap)

    def __iter__(self) -> Iterable[Dict[str, bytes]]:
        buffer = []
        for file in self.files:
            with open(file, 'r') as f:
                buffer += self.bos_tokens
                first_chunk = True
                # Read the file in 1MB chunks to avoid loading it all into
                # memory at once: iter(callable, sentinel) keeps calling
                # f.read(1000000) until it returns the empty-string sentinel.
                for chunk in iter(partial(f.read, 1000000), ''):
                    # Tokenize the chunk.
                    encoded = self.tokenizer(
                        chunk,
                        truncation=False,
                        padding=False,
                    )
                    iids = encoded['input_ids']

                    # If the tokenizer adds a BOS token on every call, strip
                    # it from all chunks after the first: the whole file is a
                    # single sequence, not one sequence per chunk.
                    if not first_chunk:
                        if iids[0] == self.tokenizer.bos_token_id:
                            iids = iids[1:]

                    # Add the tokens to the buffer.
                    buffer += iids
                    while len(buffer) >= self.max_length:
                        concat_sample = buffer[:self.max_length]
                        buffer = buffer[self.max_length:] if self.should_wrap else []
                        yield {'tokens': np.asarray(concat_sample).tobytes()}

                    first_chunk = False

                # Add the EOS token to the buffer to separate files.
                buffer += self.eos_tokens

        # Yield any remaining samples of size max_length.
        while len(buffer) >= self.max_length:
            concat_sample = buffer[:self.max_length]
            buffer = buffer[self.max_length:] if self.should_wrap else []
            yield {'tokens': np.asarray(concat_sample).tobytes()}


def parse_args() -> Namespace:
    """Parse commandline arguments."""
    parser = ArgumentParser(
@@ -277,8 +341,8 @@ def download_and_convert(

    # Use the ConcatTokensFromFilesDataset from LLM-foundry to concatenate
    # sequences of tokens up to the maximum sequence length.
    dataset = ConcatTokensFromFilesDataset(
        files=downloading_iter,
        max_length=concat_tokens,
        tokenizer=tokenizer,
        eos_text=eos_text,
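End to end, download_and_convert now wires the pieces together roughly as in this condensed sketch. The tokenizer name, paths, and writer settings are illustrative (the real script derives them from parsed arguments), and the import line stands in for the fact that the class is defined inside this script:

```python
from glob import glob

from streaming import MDSWriter
from transformers import AutoTokenizer

from convert_text_to_mds import ConcatTokensFromFilesDataset  # defined above

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

dataset = ConcatTokensFromFilesDataset(
    files=glob('/tmp/local-copies/*.txt'),  # or a DownloadingIterable
    tokenizer=tokenizer,
    max_length=2048,
    bos_text='',
    eos_text='<|endoftext|>',
    no_wrap=False,
)

# Every sample is exactly max_length token ids serialized as raw bytes.
with MDSWriter(columns={'tokens': 'bytes'}, out='/tmp/mds-out',
               compression='zstd') as out:
    for sample in dataset:
        out.write(sample)
```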