Feat: Add dataset loading from S3, GCS
NanoCode012 committed Oct 22, 2023
1 parent 15d3a65 commit ec7ea57
Showing 3 changed files with 103 additions and 20 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -367,6 +367,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
    - path: knowrohit07/know_sql
      type: context_qa.load_v2
      train_on_split: validation

  # loading from s3 or gcs
  dataset:
    - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
  ...
```
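
For reference, a minimal sketch of what the `s3://` path above resolves to at load time, based on the loader changes in this commit. The bucket path and the `parquet` type are placeholders, and credentials are assumed to come from the default AWS profile:

```python
import aiobotocore.session
import s3fs
from datasets import load_dataset, load_from_disk

# Credentials come from ~/.aws/credentials (default profile), mirroring the loader.
storage_options = {"session": aiobotocore.session.AioSession(profile="default")}
fs = s3fs.S3FileSystem(**storage_options)

path = "s3://my-bucket/my-dataset"  # placeholder path
if fs.isdir(path):
    # a folder of arrow/parquet shards written with Dataset.save_to_disk
    ds = load_from_disk(path, storage_options=storage_options)
elif fs.isfile(path):
    # a single file; the loader infers the format from the extension (json by default)
    ds = load_dataset("parquet", data_files=path, storage_options=storage_options)
```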

- loading
7 changes: 6 additions & 1 deletion requirements.txt
@@ -11,7 +11,7 @@ deepspeed
addict
fire
PyYAML>=6.0
datasets
datasets>=2.14.0
flash-attn>=2.3.0
sentencepiece
wandb
@@ -31,3 +31,8 @@ scikit-learn==1.2.2
pynvml
art
fschat==0.2.29

# remote
s3fs
gcsfs
# adlfs
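
The remote extras are optional and the loader imports them lazily. As a rough pre-flight check, not part of this commit and with placeholder bucket paths, you could confirm a backend is installed and credentialed before training:

```python
def check_remote_path(path: str) -> bool:
    """Return True if the remote path is reachable with default credentials."""
    if path.startswith("s3://"):
        import aiobotocore.session
        import s3fs

        fs = s3fs.S3FileSystem(
            session=aiobotocore.session.AioSession(profile="default")
        )
    elif path.startswith(("gs://", "gcs://")):
        import gcsfs

        # default credentials from the environment, else anonymous access
        fs = gcsfs.GCSFileSystem(token=None)
    else:
        raise ValueError(f"unsupported scheme: {path}")
    return fs.exists(path)


print(check_remote_path("s3://my-bucket/my-dataset"))  # hypothetical bucket
```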
111 changes: 92 additions & 19 deletions src/axolotl/utils/data.py
@@ -161,30 +161,69 @@ def for_d_in_datasets(dataset_configs):
            except (FileNotFoundError, ConnectionError):
                pass

            ds_from_cloud = False
            storage_options = {}
            remote_file_system = None
            if d.path.startswith("s3://"):
                try:
                    import aiobotocore.session
                    import s3fs
                except ImportError as exc:
                    raise ImportError(
                        "s3:// paths require aiobotocore and s3fs to be installed"
                    ) from exc

                # Takes credentials from ~/.aws/credentials for default profile
                s3_session = aiobotocore.session.AioSession(profile="default")
                storage_options = {"session": s3_session}
                remote_file_system = s3fs.S3FileSystem(**storage_options)
            elif d.path.startswith("gs://") or d.path.startswith("gcs://"):
                try:
                    import gcsfs
                except ImportError as exc:
                    raise ImportError(
                        "gs:// or gcs:// paths require gcsfs to be installed"
                    ) from exc

                # gcsfs will use default credentials from the environment else anon
                # https://gcsfs.readthedocs.io/en/latest/#credentials
                storage_options = {"token": None}
                remote_file_system = gcsfs.GCSFileSystem(**storage_options)
            # TODO: Figure out how to get auth creds passed
            # elif d.path.startswith("adl://") or d.path.startswith("abfs://"):
            #     try:
            #         import adlfs
            #     except ImportError as exc:
            #         raise ImportError(
            #             "adl:// or abfs:// paths require adlfs to be installed"
            #         ) from exc

            #     # Gen 1
            #     storage_options = {
            #         "tenant_id": TENANT_ID,
            #         "client_id": CLIENT_ID,
            #         "client_secret": CLIENT_SECRET,
            #     }
            #     # Gen 2
            #     storage_options = {
            #         "account_name": ACCOUNT_NAME,
            #         "account_key": ACCOUNT_KEY,
            #     }

            #     remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
            try:
                if remote_file_system and remote_file_system.exists(d.path):
                    ds_from_cloud = True
            except (FileNotFoundError, ConnectionError):
                pass

            # prefer local dataset, even if hub exists
            local_path = Path(d.path)
            if local_path.exists():
                if local_path.is_dir():
                    # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
                    ds = load_dataset(
                        d.path,
                        name=d.name,
                        data_files=d.data_files,
                        streaming=False,
                        split=None,
                    )
                    ds = load_from_disk(d.path)
                elif local_path.is_file():
                    ds_type = "json"
                    if d.ds_type:
                        ds_type = d.ds_type
                    elif ".parquet" in d.path:
                        ds_type = "parquet"
                    elif ".arrow" in d.path:
                        ds_type = "arrow"
                    elif ".csv" in d.path:
                        ds_type = "csv"
                    elif ".txt" in d.path:
                        ds_type = "text"
                    ds_type = get_ds_type(d)
                    ds = load_dataset(
                        ds_type,
                        name=d.name,
@@ -204,6 +243,22 @@ def for_d_in_datasets(dataset_configs):
                    data_files=d.data_files,
                    token=use_auth_token,
                )
            elif ds_from_cloud and remote_file_system:
                if remote_file_system.isdir(d.path):
                    ds = load_from_disk(
                        d.path,
                        storage_options=storage_options,
                    )
                elif remote_file_system.isfile(d.path):
                    ds_type = get_ds_type(d)
                    ds = load_dataset(
                        ds_type,
                        name=d.name,
                        data_files=d.path,
                        streaming=False,
                        split=None,
                        storage_options=storage_options,
                    )
            else:
                if isinstance(d.data_files, str):
                    fp = hf_hub_download(
@@ -371,6 +426,24 @@ def for_d_in_datasets(dataset_configs):
    return dataset


def get_ds_type(ds_dict: DictDefault):
    """
    Get the dataset type from the path if it's not specified
    """
    ds_type = "json"
    if ds_dict.ds_type:
        ds_type = ds_dict.ds_type
    elif ".parquet" in ds_dict.path:
        ds_type = "parquet"
    elif ".arrow" in ds_dict.path:
        ds_type = "arrow"
    elif ".csv" in ds_dict.path:
        ds_type = "csv"
    elif ".txt" in ds_dict.path:
        ds_type = "text"
    return ds_type
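
A quick usage sketch for the new helper: it falls back to `json` and otherwise infers the loader from the file extension in the configured path. The `DictDefault` import path is assumed here rather than taken from this diff:

```python
from axolotl.utils.dict import DictDefault  # assumed location of DictDefault

print(get_ds_type(DictDefault({"path": "s3://bucket/data.parquet"})))  # -> parquet
print(get_ds_type(DictDefault({"path": "gs://bucket/rows.csv"})))  # -> csv
print(get_ds_type(DictDefault({"path": "s3://bucket/dump.jsonl"})))  # -> json (default)
print(get_ds_type(DictDefault({"ds_type": "arrow", "path": "s3://bucket/x"})))  # explicit ds_type wins
```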


def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase,
    cfg,
