Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add remote code option to allow execution of DBRX tokenizer #1106

Merged
merged 1 commit into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions scripts/data_prep/convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def parse_args() -> Namespace:
help='If true, reprocess the input_folder to mds format. Otherwise, ' +
'only reprocess upon changes to the input folder or dataset creation parameters.',
)
parser.add_argument(
'--trust-remote-code',
type=bool,
required=False,
default=False,
help='If true, allows custom code to be executed to load the tokenizer',
)

parsed = parser.parse_args()

Expand All @@ -124,7 +131,8 @@ def parse_args() -> Namespace:
parser.error(
'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.'
)
tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer)
tokenizer = AutoTokenizer.from_pretrained(
parsed.tokenizer, trust_remote_code=parsed.trust_remote_code)
parsed.eos_text = tokenizer.eos_token

# now that we have validated them, change BOS/EOS to strings
Expand Down Expand Up @@ -171,6 +179,7 @@ def get_task_args(
bos_text: str,
no_wrap: bool,
compression: str,
trust_remote_code: bool,
) -> Iterable:
"""Get download_and_convert arguments split across n_groups.

Expand All @@ -187,6 +196,7 @@ def get_task_args(
bos_text (str): Text to prepend to each example to separate concatenated samples
no_wrap: (bool): Whether to let text examples wrap across multiple training examples
compression (str): The compression algorithm to use for MDS writing
trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
"""
num_objects = len(object_names)
objs_per_group = math.ceil(num_objects / n_groups)
Expand All @@ -202,6 +212,7 @@ def get_task_args(
bos_text,
no_wrap,
compression,
trust_remote_code,
)


Expand All @@ -223,6 +234,7 @@ def download_and_convert(
bos_text: str,
no_wrap: bool,
compression: str,
trust_remote_code: bool,
):
"""Downloads and converts text fies to MDS format.

Expand All @@ -236,6 +248,7 @@ def download_and_convert(
bos_text (str): Text to prepend to each example to separate concatenated samples
no_wrap: (bool): Whether to let text examples wrap across multiple training examples
compression (str): The compression algorithm to use for MDS writing
trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
"""
object_store = maybe_create_object_store_from_uri(input_folder)

Expand All @@ -244,7 +257,8 @@ def download_and_convert(
downloading_iter = DownloadingIterable(object_names=file_names,
output_folder=tmp_dir,
object_store=object_store)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, trust_remote_code=trust_remote_code)
tokenizer.model_max_length = 5000000000 # Hack to prevent warnings from HuggingFace

# Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up
Expand Down Expand Up @@ -353,6 +367,7 @@ def convert_text_to_mds(
processes: int,
args_str: str,
reprocess: bool,
trust_remote_code: bool,
):
"""Convert a folder of text files to MDS format.

Expand All @@ -368,6 +383,7 @@ def convert_text_to_mds(
processes (int): The number of processes to use.
args_str (str): String representation of the arguments
reprocess (bool): Whether to always reprocess the given folder of text files
trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
"""
is_remote_output = is_remote_path(output_folder)

Expand Down Expand Up @@ -396,7 +412,7 @@ def convert_text_to_mds(
# Download and convert the text files in parallel
args = get_task_args(object_names, local_output_folder, input_folder,
processes, tokenizer_name, concat_tokens, eos_text,
bos_text, no_wrap, compression)
bos_text, no_wrap, compression, trust_remote_code)
with ProcessPoolExecutor(max_workers=processes) as executor:
list(executor.map(download_and_convert_starargs, args))

Expand All @@ -405,7 +421,7 @@ def convert_text_to_mds(
else:
download_and_convert(object_names, local_output_folder, input_folder,
tokenizer_name, concat_tokens, eos_text, bos_text,
no_wrap, compression)
no_wrap, compression, trust_remote_code)

# Write a done file with the args and object names
write_done_file(local_output_folder, args_str, object_names)
Expand Down Expand Up @@ -462,6 +478,7 @@ def _args_str(original_args: Namespace) -> str:
compression=args.compression,
processes=args.processes,
reprocess=args.reprocess,
trust_remote_code=args.trust_remote_code,
args_str=_args_str(args))
except Exception as e:
if mosaicml_logger is not None:
Expand Down
3 changes: 3 additions & 0 deletions tests/a_scripts/data_prep/test_convert_text_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def call_convert_text_to_mds() -> None:
processes=processes,
args_str='Namespace()',
reprocess=False,
trust_remote_code=False,
)

call_convert_text_to_mds()
Expand Down Expand Up @@ -195,6 +196,7 @@ def call_convert_text_to_mds(reprocess: bool):
processes=1,
args_str='Namespace()',
reprocess=reprocess,
trust_remote_code=False,
)

# Create input text data
Expand Down Expand Up @@ -234,6 +236,7 @@ def test_input_folder_not_exist(tmp_path: pathlib.Path):
processes=1,
args_str='Namespace()',
reprocess=False,
trust_remote_code=False,
)


Expand Down
Loading