Skip to content

Commit

Permalink
fix: Huggingface Trust Remote Repo (#9535)
Browse files Browse the repository at this point in the history
* set trust_remote_code=True in HuggingFace examples
  • Loading branch information
MikhailKardash authored Jun 18, 2024
1 parent 3320107 commit 44f446c
Show file tree
Hide file tree
Showing 12 changed files with 19 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ entrypoint: >-
--seed 1337
--save_strategy steps
--save_steps 20
--trust_remote_code false
--trust_remote_code true
max_restarts: 0
2 changes: 1 addition & 1 deletion examples/hf_trainer_api/hf_image_classification/const.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ entrypoint: >-
--seed 1337
--save_strategy steps
--save_steps 20
--trust_remote_code false
--trust_remote_code true
max_restarts: 0
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ entrypoint: >-
--save_total_limit 3
--seed 1337
--save_strategy epoch
--trust_remote_code false
--trust_remote_code true
max_restarts: 0
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ entrypoint: >-
--save_strategy steps
--save_steps 20
--deepspeed ds_configs/ds_config_stage_1.json
--trust_remote_code false
--trust_remote_code true
--fp16
max_restarts: 0
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ entrypoint: >-
--seed 1337
--save_strategy steps
--save_steps 20
--trust_remote_code false
--trust_remote_code true
max_restarts: 0
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def main(det_callback, tb_callback, model_args, data_args, training_args):
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
trust_remote_code=True,
)
else:
data_files = {}
Expand All @@ -290,6 +291,7 @@ def main(det_callback, tb_callback, model_args, data_args, training_args):
data_files=data_files,
cache_dir=model_args.cache_dir,
task="image-classification",
trust_remote_code=True,
)

# If we don't have a validation split, split off a percentage of train as validation.
Expand All @@ -310,7 +312,7 @@ def main(det_callback, tb_callback, model_args, data_args, training_args):
id2label[str(i)] = label

# Load the accuracy metric from the datasets package
metric = datasets.load_metric("accuracy")
metric = datasets.load_metric("accuracy", trust_remote_code=True,)

# Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float.
Expand Down
2 changes: 1 addition & 1 deletion examples/hf_trainer_api/hf_language_modeling/adaptive.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,5 @@ entrypoint: >-
--save_steps 20
--per_device_train_batch_size 8
--per_device_eval_batch_size 8
--trust_remote_code false
--trust_remote_code true
max_restarts: 0
2 changes: 1 addition & 1 deletion examples/hf_trainer_api/hf_language_modeling/const.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ entrypoint: >-
--save_steps 20
--per_device_train_batch_size 8
--per_device_eval_batch_size 8
--trust_remote_code false
--trust_remote_code true
max_restarts: 0
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ entrypoint: >-
--save_steps 20
--per_device_train_batch_size 8
--per_device_eval_batch_size 8
--trust_remote_code false
--trust_remote_code true
max_restarts: 0
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ entrypoint: >-
--deepspeed ds_configs/ds_config_stage_1.json
--per_device_train_batch_size 8
--per_device_eval_batch_size 8
--trust_remote_code false
--trust_remote_code true
--fp16
max_restarts: 0
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ entrypoint: >-
--save_steps 20
--per_device_train_batch_size 8
--per_device_eval_batch_size 8
--trust_remote_code false
--trust_remote_code true
max_restarts: 0
6 changes: 6 additions & 0 deletions examples/hf_trainer_api/hf_language_modeling/run_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def main(det_callback, tb_callback, model_args, data_args, training_args):
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
streaming=data_args.streaming,
trust_remote_code=True,
)
if "validation" not in raw_datasets.keys():
raw_datasets["validation"] = load_dataset(
Expand All @@ -376,6 +377,7 @@ def main(det_callback, tb_callback, model_args, data_args, training_args):
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
streaming=data_args.streaming,
trust_remote_code=True,
)
raw_datasets["train"] = load_dataset(
data_args.dataset_name,
Expand All @@ -384,6 +386,7 @@ def main(det_callback, tb_callback, model_args, data_args, training_args):
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
streaming=data_args.streaming,
trust_remote_code=True,
)
else:
data_files = {}
Expand All @@ -406,6 +409,7 @@ def main(det_callback, tb_callback, model_args, data_args, training_args):
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
**dataset_args,
trust_remote_code=True,
)
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
if "validation" not in raw_datasets.keys():
Expand All @@ -416,6 +420,7 @@ def main(det_callback, tb_callback, model_args, data_args, training_args):
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
**dataset_args,
trust_remote_code=True,
)
raw_datasets["train"] = load_dataset(
extension,
Expand All @@ -424,6 +429,7 @@ def main(det_callback, tb_callback, model_args, data_args, training_args):
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
**dataset_args,
trust_remote_code=True,
)

# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
Expand Down

0 comments on commit 44f446c

Please sign in to comment.