Feat: Add dataset loading from S3, GCS
NanoCode012 committed Nov 15, 2023
1 parent 614cff4 commit dc5fddf
Showing 3 changed files with 108 additions and 20 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -397,6 +397,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation

# loading a dataset from s3 or gcs
dataset:
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
```

- loading
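The `s3://` / `gs://` paths documented above are resolved through fsspec-backed filesystems. As a quick sanity check before training (not part of this change), the same backends can be used directly to confirm a bucket path is reachable; the bucket paths below are placeholders:

```python
# Hypothetical sanity check with placeholder bucket paths.
import s3fs
import gcsfs

# S3: with no arguments, s3fs falls back to the default AWS credential chain;
# the loader in this commit explicitly reads the "default" profile from ~/.aws/credentials.
s3 = s3fs.S3FileSystem()
print(s3.exists("s3://my-bucket/my-dataset"))

# GCS: token=None uses default environment credentials, otherwise anonymous access.
gcs = gcsfs.GCSFileSystem(token=None)
print(gcs.exists("gs://my-bucket/my-dataset"))
```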
7 changes: 6 additions & 1 deletion requirements.txt
@@ -11,7 +11,7 @@ deepspeed
addict
fire
PyYAML>=6.0
datasets
datasets>=2.14.0
flash-attn>=2.3.0
sentencepiece
wandb
@@ -32,3 +32,8 @@ pynvml
art
fschat==0.2.29
gradio

# remote filesystems
s3fs
gcsfs
# adlfs
116 changes: 97 additions & 19 deletions src/axolotl/utils/data.py
@@ -165,30 +165,74 @@ def for_d_in_datasets(dataset_configs):
except (FileNotFoundError, ConnectionError):
pass

ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc

# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith(
"gs://"
) or config_dataset.path.startswith("gcs://"):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc

# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc

# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }

# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(
config_dataset.path
):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass

# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
ds = load_from_disk(config_dataset.path)
elif local_path.is_file():
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
ds_type = get_ds_type(config_dataset)

ds = load_dataset(
ds_type,
name=config_dataset.name,
@@ -208,6 +252,22 @@ def for_d_in_datasets(dataset_configs):
data_files=config_dataset.data_files,
token=use_auth_token,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
@@ -299,6 +359,24 @@ def for_d_in_datasets(dataset_configs):
return dataset, prompters


def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type


def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase,
cfg,
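Putting the pieces together, a minimal standalone sketch (not part of the commit) of how the new cloud branch resolves an S3 path, assuming a placeholder file `s3://my-bucket/data.parquet` and credentials in the default AWS profile; the real code also handles `gs://`/`gcs://` via gcsfs and infers the builder type with `get_ds_type`:

```python
# Standalone sketch of the remote-loading flow, with a placeholder S3 path.
import aiobotocore.session
import s3fs
from datasets import load_dataset, load_from_disk

path = "s3://my-bucket/data.parquet"  # placeholder, not a real dataset

# Credentials come from the "default" profile in ~/.aws/credentials.
session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": session}
fs = s3fs.S3FileSystem(**storage_options)

if fs.isdir(path):
    # A directory of arrow/parquet files, e.g. written by Dataset.save_to_disk
    ds = load_from_disk(path, storage_options=storage_options)
elif fs.isfile(path):
    # A single file: pick the datasets builder from the extension ("parquet" here)
    ds = load_dataset(
        "parquet",
        data_files=path,
        split=None,
        storage_options=storage_options,
    )
```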
