Commit c61cf73

v-chen_data committed Jul 13, 2024
1 parent 541b62b

Showing 5 changed files with 93 additions and 61 deletions.
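This commit replaces `argparse.Namespace` plumbing with explicit keyword arguments: `convert_dataset_json` and `convert_dataset_json_from_args` now declare typed parameters, and every caller (the typer CLI, the standalone script, and the tests) passes keywords directly. A minimal sketch of the pattern, with hypothetical function names standing in for the real ones:

from argparse import Namespace
from typing import Optional

# Before: options were bundled into a Namespace and read back via
# attribute access, hiding the real signature from type checkers.
def convert_old(args: Namespace) -> None:
    print(args.path, args.split)

convert_old(Namespace(path='data.jsonl', split='train'))

# After: the options are explicit, typed parameters, so callers,
# type checkers, and IDEs can see exactly what the function accepts.
def convert_new(path: str, split: str, tokenizer: Optional[str] = None) -> None:
    print(path, split, tokenizer)

convert_new(path='data.jsonl', split='train')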
llmfoundry/cli/cli.py (3 additions, 5 deletions)
@@ -1,7 +1,6 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from argparse import Namespace
 from typing import Optional
 
 import typer
@@ -28,7 +27,7 @@ def train(
 
 
 @app.command(name='convert_dataset_json')
-def convert_dataset_json_cli(
+def convert_dataset_json(
     path: str = typer.Option(
         ...,
         '--path',
@@ -71,19 +70,18 @@ def convert_dataset_json_cli(
         help='Number of workers for data loading',
     ),  # type: ignore
 ):
-    args = Namespace(
+    convert_dataset_json_from_args(
         path=path,
+        split=split,
         out_root=out_root,
         compression=compression,
         concat_tokens=concat_tokens,
-        split=split,
         tokenizer=tokenizer,
         bos_text=bos_text,
         eos_text=eos_text,
         no_wrap=no_wrap,
         num_workers=num_workers,
     )
-    convert_dataset_json_from_args(args)
 
 
 if __name__ == '__main__':
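Condensed, the new CLI wiring looks like the sketch below (the import path is inferred from the file layout; the help strings, defaults, and the reduced option list are abridged assumptions, since the real command declares all ten options):

from typing import Optional

import typer

from llmfoundry.data_prep.convert_dataset_json import convert_dataset_json_from_args

app = typer.Typer()

@app.command(name='convert_dataset_json')
def convert_dataset_json(
    path: str = typer.Option(..., '--path', help='Path to the input JSON dataset'),
    out_root: str = typer.Option(..., '--out_root', help='Output root directory'),
    split: str = typer.Option('train', '--split', help='Split to convert'),
    compression: Optional[str] = typer.Option(None, '--compression'),
    concat_tokens: Optional[int] = typer.Option(None, '--concat_tokens'),
):
    # typer has already parsed the options, so the body is a single
    # pass-through call with no intermediate Namespace.
    convert_dataset_json_from_args(
        path=path,
        out_root=out_root,
        compression=compression,
        concat_tokens=concat_tokens,
        split=split,
    )

if __name__ == '__main__':
    app()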
llmfoundry/data_prep/convert_dataset_json.py (60 additions, 28 deletions)
@@ -3,7 +3,6 @@
 
 """Streaming dataset conversion scripts for json files."""
 import os
-from argparse import Namespace
 from enum import Enum
 from glob import glob
 from typing import Optional
@@ -91,33 +90,44 @@ def build_hf_dataset(
     return dataset
 
 
-def convert_dataset_json(args: Namespace) -> None:
+def convert_dataset_json(
+    path: str,
+    out_root: str,
+    compression: Optional[str],
+    concat_tokens: Optional[int],
+    split: str,
+    tokenizer: Optional[str] = None,
+    bos_text: str = '',
+    eos_text: str = '',
+    no_wrap: bool = False,
+    num_workers: Optional[int] = None,
+) -> None:
     """Main: create C4/pile streaming dataset.
 
     Args:
         args (Namespace): Commandline arguments.
     """
-    if args.concat_tokens is not None:
+    if concat_tokens is not None:
         mode = ConcatMode.CONCAT_TOKENS
-        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
+        built_tokenizer = AutoTokenizer.from_pretrained(tokenizer)
         # we will enforce length, so suppress warnings about sequences too long for the model
-        tokenizer.model_max_length = int(1e30)
+        built_tokenizer.model_max_length = int(1e30)
         columns = {'tokens': 'ndarray:int32'}
     else:
         mode = ConcatMode.NO_CONCAT
-        tokenizer = None
+        built_tokenizer = None
         columns = {'text': 'str'}
 
     # Get samples
     dataset = build_hf_dataset(
-        path=args.path,
-        split=args.split,
+        path=path,
+        split=split,
         mode=mode,
-        max_length=args.concat_tokens,
-        bos_text=args.bos_text,
-        eos_text=args.eos_text,
-        no_wrap=args.no_wrap,
-        tokenizer=tokenizer,
+        max_length=concat_tokens,
+        bos_text=bos_text,
+        eos_text=eos_text,
+        no_wrap=no_wrap,
+        tokenizer=built_tokenizer,
     )
 
     print('here')
@@ -130,34 +140,56 @@ def convert_dataset_json(args: Namespace) -> None:
     print(f'It will finish at a value below 100% if tokenizing')
     with MDSWriter(
         columns=columns,
-        out=os.path.join(args.out_root),
-        compression=args.compression,
+        out=os.path.join(out_root),
+        compression=compression,
     ) as out:
         for sample in tqdm(dataset):
             out.write(sample)
 
 
-def convert_dataset_json_from_args(args: Namespace) -> None:
-    if os.path.isdir(args.out_root) and len(
-        set(os.listdir(args.out_root)).intersection(set(args.split)),
+def convert_dataset_json_from_args(
+    path: str,
+    out_root: str,
+    compression: Optional[str],
+    concat_tokens: Optional[int],
+    split: str,
+    tokenizer: Optional[str] = None,
+    bos_text: Optional[str] = None,
+    eos_text: Optional[str] = None,
+    no_wrap: bool = False,
+    num_workers: Optional[int] = None,
+) -> None:
+    if os.path.isdir(out_root) and len(
+        set(os.listdir(out_root)).intersection(set(split)),
     ) > 0:
         raise ValueError(
-            f'--out_root={args.out_root} contains {os.listdir(args.out_root)} which cannot overlap with the requested splits {args.splits}.',
+            f'--out_root={out_root} contains {os.listdir(out_root)} which cannot overlap with the requested splits {split}.',
         )
 
     # Make sure we have needed concat options
     if (
-        args.concat_tokens is not None and
-        isinstance(args.concat_tokens, int) and args.tokenizer is None
+        concat_tokens is not None and isinstance(concat_tokens, int) and
+        tokenizer is None
    ):
-        args.error(
+        raise ValueError(
            'When setting --concat_tokens, you must specify a --tokenizer',
        )
 
     # now that we have validated them, change BOS/EOS to strings
-    if args.bos_text is None:
-        args.bos_text = ''
-    if args.eos_text is None:
-        args.eos_text = ''
 
-    convert_dataset_json(args)
+    if bos_text is None:
+        bos_text = ''
+    if eos_text is None:
+        eos_text = ''
+
+    convert_dataset_json(
+        path=path,
+        out_root=out_root,
+        compression=compression,
+        concat_tokens=concat_tokens,
+        split=split,
+        tokenizer=tokenizer,
+        bos_text=bos_text,
+        eos_text=eos_text,
+        no_wrap=no_wrap,
+        num_workers=num_workers,
+    )
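Taken together, the refactored surface can be driven directly from Python. A minimal usage sketch (the input file, output directory, tokenizer name, and EOS string are hypothetical; the keyword names match the new signature above):

from llmfoundry.data_prep.convert_dataset_json import convert_dataset_json_from_args

# Tokenizing conversion: setting concat_tokens requires a tokenizer,
# which the wrapper now enforces by raising a plain ValueError (the old
# Namespace-based code called args.error, which Namespace does not define).
convert_dataset_json_from_args(
    path='data/my_dataset.jsonl',         # hypothetical input JSONL
    out_root='data/my-dataset-mds',       # hypothetical output directory
    compression=None,
    concat_tokens=2048,
    split='train',
    tokenizer='EleutherAI/gpt-neox-20b',  # hypothetical HF tokenizer name
    bos_text='',
    eos_text='<|endoftext|>',             # hypothetical EOS string
    no_wrap=False,
    num_workers=None,
)

Note that `bos_text=None` and `eos_text=None` are still accepted by this wrapper, which normalizes them to empty strings before delegating to `convert_dataset_json`, whose parameters are now plain `str` defaults.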
scripts/data_prep/convert_dataset_json.py (12 additions, 1 deletion)
@@ -35,4 +35,15 @@ def parse_args() -> Namespace:
 
 
 if __name__ == '__main__':
-    convert_dataset_json_from_args(parse_args())
+    args = parse_args()
+    convert_dataset_json_from_args(
+        path=args.path,
+        out_root=args.out_root,
+        compression=args.compression,
+        concat_tokens=args.concat_tokens,
+        split=args.split,
+        tokenizer=args.tokenizer,
+        bos_text=args.bos_text,
+        eos_text=args.eos_text,
+        no_wrap=args.no_wrap,
+    )
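Note that this entry point forwards nine of the ten keyword arguments; `num_workers` is not passed, so it falls back to the `None` default declared on `convert_dataset_json_from_args`.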
tests/a_scripts/data_prep/test_convert_dataset_json.py (9 additions, 14 deletions)
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-from argparse import Namespace
 from pathlib import Path
 
 from llmfoundry.data_prep import convert_dataset_json
@@ -12,18 +11,14 @@ def test_json_script_from_api(tmp_path: Path):
     # test calling it directly
     path = os.path.join(tmp_path, 'my-copy-arxiv-1')
     convert_dataset_json(
-        Namespace(
-            **{
-                'path': 'scripts/data_prep/example_data/arxiv.jsonl',
-                'out_root': path,
-                'compression': None,
-                'split': 'train',
-                'concat_tokens': None,
-                'bos_text': None,
-                'eos_text': None,
-                'no_wrap': False,
-                'num_workers': None,
-            },
-        ),
+        path='scripts/data_prep/example_data/arxiv.jsonl',
+        out_root=path,
+        compression=None,
+        split='train',
+        concat_tokens=None,
+        bos_text='',
+        eos_text='',
+        no_wrap=False,
+        num_workers=None,
     )
     assert os.path.exists(path)
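The test switches from `bos_text=None`/`eos_text=None` to empty strings because it calls `convert_dataset_json` directly, whose parameters are now typed `str` with `''` defaults; the None-to-empty-string normalization lives only in `convert_dataset_json_from_args`.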
tests/data_utils.py (9 additions, 13 deletions)
@@ -269,19 +269,15 @@ def create_arxiv_dataset(path: Path) -> str:
         arxiv_path = os.path.join('scripts', arxiv_path)
 
     convert_dataset_json(
-        Namespace(
-            **{
-                'path': arxiv_path,
-                'out_root': arxiv_dir,
-                'compression': None,
-                'split': downloaded_split,
-                'concat_tokens': None,
-                'bos_text': None,
-                'eos_text': None,
-                'no_wrap': False,
-                'num_workers': None,
-            },
-        ),
+        path=arxiv_path,
+        out_root=arxiv_dir,
+        compression=None,
+        split=downloaded_split,
+        concat_tokens=None,
+        bos_text='',
+        eos_text='',
+        no_wrap=False,
+        num_workers=None,
     )
 
     return arxiv_dir
