Feat: Add dataset loading from S3, GCS (#765)
* Feat: Add dataset loading from S3, GCS

* chore: update docs

* chore: add more info on cloud loading
NanoCode012 committed Nov 16, 2023
1 parent 1bc1186 commit 3cc67d2
Showing 3 changed files with 110 additions and 21 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -426,6 +426,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation

# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
datasets:
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
```

- loading
@@ -520,7 +526,7 @@ float16: true

# A list of one or more datasets to finetune the model with
datasets:
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
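The updated comment above names `s3://` and `gs://` paths as first-class dataset sources; in the loader they are simply forwarded to the `datasets` library through `storage_options`. For reference, a minimal standalone sketch of the equivalent direct call is below — the bucket and file names are hypothetical, and S3 credentials are assumed to come from the default AWS profile, matching the behavior described in the README snippet.

```python
# Illustrative only: load a single parquet file straight from S3, the same way
# the new code path in src/axolotl/utils/data.py does. Paths are made up.
import aiobotocore.session
from datasets import load_dataset

storage_options = {"session": aiobotocore.session.AioSession(profile="default")}

ds = load_dataset(
    "parquet",                                  # type inferred from the file extension
    data_files="s3://my-bucket/train.parquet",  # hypothetical object
    split=None,
    storage_options=storage_options,            # handed to s3fs by `datasets`
)
print(ds)
```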
7 changes: 6 additions & 1 deletion requirements.txt
@@ -11,7 +11,7 @@ deepspeed
addict
fire
PyYAML>=6.0
datasets
datasets>=2.14.0
flash-attn>=2.3.0
sentencepiece
wandb
@@ -33,3 +33,8 @@ art
fschat==0.2.29
gradio
tensorboard

# remote filesystems
s3fs
gcsfs
# adlfs
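These backends are imported lazily in `data.py`, so a missing package only surfaces once a cloud path is actually configured. A purely illustrative preflight check, assuming the package names above, that fails fast before a long run:

```python
# Illustrative: verify the optional remote-filesystem packages are installed.
import importlib.util

for pkg in ("s3fs", "gcsfs"):  # "adlfs" stays commented out upstream for now
    if importlib.util.find_spec(pkg) is None:
        raise SystemExit(f"{pkg} is not installed; cloud dataset paths will fail")
```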
116 changes: 97 additions & 19 deletions src/axolotl/utils/data.py
@@ -170,30 +170,74 @@ def for_d_in_datasets(dataset_configs):
except (FileNotFoundError, ConnectionError):
pass

ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc

# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith(
"gs://"
) or config_dataset.path.startswith("gcs://"):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc

# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc

# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }

# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(
config_dataset.path
):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass

# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
ds = load_from_disk(config_dataset.path)
elif local_path.is_file():
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
ds_type = get_ds_type(config_dataset)

ds = load_dataset(
ds_type,
name=config_dataset.name,
@@ -213,6 +257,22 @@ def for_d_in_datasets(dataset_configs):
data_files=config_dataset.data_files,
token=use_auth_token,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
@@ -304,6 +364,24 @@ def for_d_in_datasets(dataset_configs):
return dataset, prompters


def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type


def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase,
cfg,
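Putting the pieces of the new branch together: the loader asks the remote filesystem whether the configured path exists, then dispatches on directory vs. file — directories go through `load_from_disk`, single files through `load_dataset` with a type chosen by `get_ds_type`. A condensed, standalone sketch of that flow, using hypothetical S3 paths and the same default-profile credential assumption as the committed code:

```python
# Condensed illustration of the new cloud branch (hypothetical paths).
import aiobotocore.session
import s3fs
from datasets import load_dataset, load_from_disk

path = "s3://my-bucket/datasets/alpaca"  # made-up example
storage_options = {"session": aiobotocore.session.AioSession(profile="default")}
fs = s3fs.S3FileSystem(**storage_options)

if fs.isdir(path):
    # A dataset directory saved with `Dataset.save_to_disk` (arrow shards + metadata)
    ds = load_from_disk(path, storage_options=storage_options)
elif fs.isfile(path):
    # A single file: pick the loader type from the extension, defaulting to json
    ext_to_type = {".parquet": "parquet", ".arrow": "arrow", ".csv": "csv", ".txt": "text"}
    ds_type = next((t for ext, t in ext_to_type.items() if ext in path), "json")
    ds = load_dataset(ds_type, data_files=path, storage_options=storage_options)
else:
    raise FileNotFoundError(f"{path} was not found on the remote filesystem")
```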
