Support eval loader when finetuning from JSONL files in object stores (#469)

* try hardcoded ft file path

* fix eval set with remote ft data

* fix linter errors

* change dl path

* only dl rank 0

* only dl rank 0 cleanup

* Apply suggestions from code review

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

* CR fixes and encapsulate

* fix docstring

* fix docstring

* Apply suggestions from code review

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
samhavens and dakinggg authored Jul 21, 2023
1 parent ef8d414 commit 66b84bb
Showing 1 changed file with 79 additions and 35 deletions.
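For context, a minimal usage sketch (not part of this commit) of how the new remote-JSONL path might be exercised. The object-store prefix, tokenizer name, batch size, and most config values below are placeholders; the third argument to `build_finetuning_dataloader` is assumed to be the per-device batch size, and `_validate_config` may require dataset keys beyond those shown.

# Hypothetical usage sketch: point dataset.hf_name at an object-store prefix
# that holds <split>.jsonl (or .csv/.parquet); the dataloader downloads the
# file and builds a HuggingFace dataset from it. Paths and hyperparameters
# here are illustrative placeholders, not values taken from this commit.
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader

cfg = OmegaConf.create({
    'name': 'finetuning',
    'dataset': {
        'hf_name': 's3://my-bucket/my-eval-data/',  # remote prefix (placeholder)
        'split': 'eval',
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'allow_pad_trimming': False,
        'shuffle': False,
    },
    'drop_last': False,
    'num_workers': 8,
})

tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b')  # placeholder tokenizer
eval_loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size=8)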
114 changes: 79 additions & 35 deletions llmfoundry/data/finetuning/dataloader.py
@@ -3,7 +3,6 @@
 
 import logging
 import os
-import tempfile
 from typing import Union
 
 import torch
@@ -156,40 +155,7 @@ def build_finetuning_dataloader(cfg: DictConfig, tokenizer: Tokenizer,
                     'When using a HuggingFace dataset from a URL, you must set the ' + \
                     '`split` key in the dataset config.'
                 )
-            supported_extensions = ['jsonl', 'csv', 'parquet']
-            with tempfile.TemporaryDirectory() as tmp_dir:
-                for extension in supported_extensions:
-                    name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
-                    destination = str(
-                        os.path.abspath(
-                            f'{tmp_dir}/{cfg.dataset.split}.{extension}'))
-                    try:
-                        with dist.run_local_rank_zero_first():
-                            get_file(name, destination, overwrite=True)
-                    except FileNotFoundError as e:
-                        if extension == supported_extensions[-1]:
-                            raise FileNotFoundError(
-                                f'Could not find a {cfg.dataset.split} file with any of ' + \
-                                f'the supported extensions: {supported_extensions}\n' + \
-                                f'at {cfg.dataset.hf_name}/{cfg.dataset.split}'
-                            ) from e
-                        else:
-                            print(
-                                f'Could not find {name}, looking for another extension'
-                            )
-                            continue
-                    # 'json' causes special behavior in the dataset constructor
-                    cfg.dataset.hf_name = extension if extension != 'jsonl' else 'json'
-                    kwargs = cfg.dataset.get('hf_kwargs', {})
-                    kwargs['data_files'] = destination
-                    cfg.dataset['hf_kwargs'] = kwargs
-                    print(cfg.dataset)
-                    dataset = dataset_constructor.build_from_hf(
-                        cfg.dataset,
-                        max_seq_len=cfg.dataset.max_seq_len,
-                        tokenizer=tokenizer,
-                    )
-                    break
+            dataset = _build_hf_dataset_from_remote(cfg, tokenizer)
         else:
             dataset = dataset_constructor.build_from_hf(
                 cfg.dataset,
@@ -269,6 +235,84 @@ def _validate_config(dataset_cfg: DictConfig):
         )
 
 
+def _build_hf_dataset_from_remote(cfg: DictConfig, tokenizer: Tokenizer):
+    """Builds a dataset from a remote object store.
+
+    This function supports 'jsonl', 'csv', and 'parquet' file formats for the dataset. It will attempt to download
+    the dataset, then once it is downloaded, convert it into HuggingFace ``datasets`` format, and then return this
+    dataset.
+
+    The function also ensures synchronicity across multiple processes during the file download. It creates a signal
+    file that is used to synchronize the start of the download across different processes. Once the download is
+    completed, the function removes the signal file.
+
+    Args:
+        cfg (DictConfig): The configuration dictionary containing the necessary parameters to load the dataset.
+            This includes:
+                - dataset.hf_name: The path of the HuggingFace dataset to download.
+                - dataset.split: The dataset split to download (e.g., 'train', 'validation', 'test').
+                - dataset.max_seq_len: The maximum sequence length for tokenizing the dataset.
+
+        tokenizer (Tokenizer): The tokenizer to be used to tokenize the dataset.
+
+    Returns:
+        Dataset: A HuggingFace dataset built from the remote file, prepared and tokenized for fine-tuning the model.
+
+    Raises:
+        FileNotFoundError: Raised if the dataset file cannot be found with any of the supported extensions.
+    """
+    supported_extensions = ['jsonl', 'csv', 'parquet']
+    finetune_dir = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)),
+        f'downloaded_finetuning_data/{cfg.dataset.split}')
+    os.makedirs(finetune_dir, exist_ok=True)
+    for extension in supported_extensions:
+        name = f'{cfg.dataset.hf_name.strip("/")}/{cfg.dataset.split}.{extension}'
+        destination = str(
+            os.path.abspath(f'{finetune_dir}/{cfg.dataset.split}.{extension}'))
+        # Since we don't know exactly what the extension will be, since it is one of a list
+        # use a signal file to wait for instead of the desired file
+        signal_file_path = os.path.join(finetune_dir, '.the_eagle_has_landed')
+        if dist.get_local_rank() == 0:
+            try:
+                get_file(name, destination, overwrite=True)
+            except FileNotFoundError as e:
+                if extension == supported_extensions[-1]:
+                    raise FileNotFoundError(
+                        f'Could not find a {cfg.dataset.split} file with any of ' + \
+                        f'the supported extensions: {supported_extensions}\n' + \
+                        f'at {cfg.dataset.hf_name}/{cfg.dataset.split}'
+                    ) from e
+                else:
+                    print(
+                        f'Could not find {name}, looking for another extension')
+                    continue
+
+            os.makedirs(os.path.dirname(signal_file_path), exist_ok=True)
+            with open(signal_file_path, 'wb') as f:
+                f.write(b'local_rank0_completed_download')
+
+        # Avoid the collective call until the local rank zero has finished trying to download the checkpoint
+        # so that we don't timeout for large downloads. This syncs all processes on the node
+        with dist.local_rank_zero_download_and_wait(signal_file_path):
+            # Then, wait to ensure every node has finished downloading the checkpoint
+            dist.barrier()
+
+        # clean up signal file
+        if dist.get_local_rank() == 0:
+            os.remove(signal_file_path)
+        dist.barrier()
+
+        cfg.dataset.hf_name = finetune_dir
+        print(cfg.dataset)
+        dataset = dataset_constructor.build_from_hf(
+            cfg.dataset,
+            max_seq_len=cfg.dataset.max_seq_len,
+            tokenizer=tokenizer,
+        )
+        return dataset
+
+
 def _build_collate_fn(dataset_cfg: DictConfig, tokenizer: Tokenizer,
                        device_batch_size: int):
     collate_fn = Seq2SeqFinetuningCollator(
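For readers unfamiliar with the synchronization trick in `_build_hf_dataset_from_remote`, here is a standalone sketch of the same pattern. It relies only on the `composer.utils` calls already used in the diff (`get_file`, `dist.get_local_rank`, `dist.local_rank_zero_download_and_wait`, `dist.barrier`); the helper name and signal-file suffix are illustrative, not llm-foundry APIs.

# Illustrative sketch of the rank-coordination pattern above: local rank 0
# downloads the file and writes a signal file; the other local ranks wait for
# the signal file before entering the collective barrier, so a slow download
# does not trigger a barrier timeout. Helper and file names are made up.
import os

from composer.utils import dist, get_file


def download_on_local_rank_zero(remote_path: str, destination: str) -> None:
    signal_file_path = destination + '.download_complete'

    if dist.get_local_rank() == 0:
        get_file(remote_path, destination, overwrite=True)
        with open(signal_file_path, 'wb') as f:
            f.write(b'local_rank0_completed_download')

    # Non-zero local ranks block until the signal file exists, then all ranks sync.
    with dist.local_rank_zero_download_and_wait(signal_file_path):
        dist.barrier()

    # Clean up the signal file on local rank 0 and sync again before returning.
    if dist.get_local_rank() == 0:
        os.remove(signal_file_path)
    dist.barrier()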
