
Commit

Merge pull request axolotl-ai-cloud#2 from Y-IAB/puree-schema
[axolotl-ai-cloud#3] Add Support for Puree Schema
seungduk-yanolja authored Mar 20, 2024
2 parents a0ed7a7 + 5f07d02 commit 9473307
Showing 2 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -531,6 +531,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
# Loading a dataset from Puree
# - `ds_type` must be `parquet`, since Puree always stores datasets in Parquet format.
- path: puree://dataset_id
ds_type: parquet
```
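For reference, a `puree://` entry like the one above resolves to a set of Parquet files on GCS. A minimal sketch of the equivalent direct call, assuming the `gs://puree/datasets/<dataset_id>/` layout used in `data.py` below and a hypothetical dataset ID:

```python
# Minimal sketch: what `path: puree://dataset_id` resolves to, assuming the
# gs://puree/datasets/<dataset_id>/ layout used in data.py below.
# "my-dataset" is a hypothetical ID for illustration.
from datasets import load_dataset

dataset_id = "my-dataset"
ds = load_dataset(
    "parquet",  # Puree datasets are stored as Parquet
    data_files=f"gs://puree/datasets/{dataset_id}/*.parquet",
    split=None,
    storage_options={"token": None},  # default gcsfs credentials
)
```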
- loading
24 changes: 24 additions & 0 deletions src/axolotl/utils/data.py
@@ -227,6 +227,7 @@ def for_d_in_datasets(dataset_configs):
pass

ds_from_cloud = False
ds_from_puree = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
@@ -256,6 +257,17 @@ def for_d_in_datasets(dataset_configs):
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
elif config_dataset.path.startswith("puree://"):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"puree:// paths require gcsfs to be installed"
) from exc

storage_options = {"token": None}
ds_from_puree = True

# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
@@ -336,6 +348,18 @@ def for_d_in_datasets(dataset_configs):
split=None,
storage_options=storage_options,
)
elif ds_from_puree:
ds_type = get_ds_type(config_dataset)
dataset_id = config_dataset.path.split("://")[1]
data_files = f"gs://puree/datasets/{dataset_id}/*.{ds_type}"
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=data_files,
streaming=False,
split=None,
storage_options=storage_options,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
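As a quick sanity check before training, the Parquet shards that the `puree://` branch will glob can be listed with `gcsfs` directly. A minimal sketch, using a hypothetical dataset ID:

```python
# Minimal sketch: list the Parquet shards a `puree://` dataset would load,
# assuming the gs://puree/datasets/<dataset_id>/ layout above.
# "my-dataset" is a hypothetical ID for illustration.
import gcsfs

fs = gcsfs.GCSFileSystem(token=None)  # same default-credential setup as data.py
dataset_id = "my-dataset"
files = fs.glob(f"puree/datasets/{dataset_id}/*.parquet")
print(files)
```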
