
Feat(wandb): Refactor to be more flexible (#767)
* feat: update to handle wandb env vars better

* chore: rename wandb_run_id to wandb_name

* feat: add new recommendation and update config

* fix: indent and pop disabled env if project passed

* feat: test env set for wandb and recommendation

* feat: update to use wandb_name and allow id

* chore: add info to readme
NanoCode012 committed Dec 4, 2023
1 parent 58ec8b1 commit a1da39c
Showing 39 changed files with 140 additions and 50 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -659,7 +659,8 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
 wandb_project: # Your wandb project name
 wandb_entity: # A wandb Team name if using a Team
 wandb_watch:
-wandb_run_id: # Set the name of your wandb run
+wandb_name: # Set the name of your wandb run
+wandb_run_id: # Set the ID of your wandb run
 wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training

 # Where to save the full-finetuned model to
@@ -955,7 +956,7 @@ wandb_mode:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 ```
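In practice the two keys now serve different purposes: `wandb_name` sets the human-readable run name, while `wandb_run_id` pins the unique run ID (e.g. to resume a specific run). A minimal config sketch with hypothetical values:

```yaml
wandb_project: my-project      # hypothetical project name
wandb_name: llama2-lora-exp1   # display name shown in the wandb UI
wandb_run_id:                  # leave unset unless targeting a specific run ID
```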
2 changes: 1 addition & 1 deletion examples/cerebras/btlm-ft.yml
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 output_dir: btlm-out
2 changes: 1 addition & 1 deletion examples/cerebras/qlora.yml
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 batch_size: 4
2 changes: 1 addition & 1 deletion examples/code-llama/13b/lora.yml
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/code-llama/13b/qlora.yml
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/code-llama/34b/lora.yml
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/code-llama/34b/qlora.yml
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/code-llama/7b/lora.yml
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/code-llama/7b/qlora.yml
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/falcon/config-7b-lora.yml
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./falcon-7b
 batch_size: 2
2 changes: 1 addition & 1 deletion examples/falcon/config-7b-qlora.yml
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out

2 changes: 1 addition & 1 deletion examples/falcon/config-7b.yml
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./falcon-7b
 batch_size: 2
2 changes: 1 addition & 1 deletion examples/gptj/qlora.yml
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 gradient_accumulation_steps: 2
2 changes: 1 addition & 1 deletion examples/jeopardy-bot/config.yml
@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./jeopardy-bot-7b
 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/llama-2/fft_optimized.yml
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/llama-2/gptq-lora.yml
@@ -32,7 +32,7 @@ lora_target_linear:
 lora_fan_in_fan_out:
 wandb_project:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./model-out
 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/llama-2/lora.yml
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/llama-2/qlora.yml
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/llama-2/relora.yml
@@ -35,7 +35,7 @@ relora_cpu_offload: false
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/llama-2/tiny-llama.yml
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/mistral/config.yml
@@ -21,7 +21,7 @@ pad_to_sequence_len: true
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/mistral/qlora.yml
@@ -38,7 +38,7 @@ lora_target_modules:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/mpt-7b/config.yml
@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
 wandb_project: mpt-alpaca-7b
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./mpt-alpaca-7b
 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/openllama-3b/config.yml
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./openllama-out
 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/openllama-3b/lora.yml
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./lora-out
 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/openllama-3b/qlora.yml
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out
 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/phi/phi-ft.yml
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/phi/phi-qlora.yml
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/pythia-12b/config.yml
@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./pythia-12b
 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/pythia/lora.yml
@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./lora-alpaca-pythia
 gradient_accumulation_steps: 1
2 changes: 1 addition & 1 deletion examples/qwen/lora.yml
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/qwen/qlora.yml
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:

 gradient_accumulation_steps: 4
2 changes: 1 addition & 1 deletion examples/redpajama/config-3b.yml
@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
 wandb_project: redpajama-alpaca-3b
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./redpajama-alpaca-3b
 batch_size: 4
2 changes: 1 addition & 1 deletion examples/replit-3b/config-lora.yml
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
 wandb_project: lora-replit
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./lora-replit
 batch_size: 8
2 changes: 1 addition & 1 deletion examples/xgen-7b/xgen-7b-8k-qlora.yml
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
 wandb_project:
 wandb_entity:
 wandb_watch:
-wandb_run_id:
+wandb_name:
 wandb_log_model:
 output_dir: ./qlora-out

2 changes: 1 addition & 1 deletion src/axolotl/core/trainer_builder.py
@@ -647,7 +647,7 @@ def build(self, total_num_steps):
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
         training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
         training_arguments_kwargs["run_name"] = (
-            self.cfg.wandb_run_id if self.cfg.use_wandb else None
+            self.cfg.wandb_name if self.cfg.use_wandb else None
         )
         training_arguments_kwargs["optim"] = (
             self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
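For context, `run_name` is the Hugging Face `TrainingArguments` field that the wandb integration picks up as the run's display name, which is why it is now sourced from `cfg.wandb_name`. A minimal sketch of the mapping, with hypothetical values:

```python
from transformers import TrainingArguments

# Hypothetical values mirroring the builder's kwargs above
args = TrainingArguments(
    output_dir="./out",
    report_to="wandb",      # enables the wandb callback
    run_name="lora-run-1",  # now taken from cfg.wandb_name, not wandb_run_id
)
print(args.run_name)  # -> lora-run-1
```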
7 changes: 7 additions & 0 deletions src/axolotl/utils/config.py
@@ -397,6 +397,13 @@ def validate_config(cfg):
             "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
         )

+    if cfg.wandb_run_id and not cfg.wandb_name:
+        cfg.wandb_name = cfg.wandb_run_id
+
+        LOG.warning(
+            "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
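The effect is backward compatibility: a config that still sets only the legacy `wandb_run_id` gets its value mirrored into `wandb_name`, with a warning nudging users toward the new key. A standalone sketch of that fallback, using a hypothetical stand-in for the real config object:

```python
import logging

LOG = logging.getLogger(__name__)

class Cfg:  # hypothetical stand-in for axolotl's DictDefault config
    wandb_run_id = "my-old-run"  # user only set the legacy key
    wandb_name = None

cfg = Cfg()
if cfg.wandb_run_id and not cfg.wandb_name:
    cfg.wandb_name = cfg.wandb_run_id
    LOG.warning(
        "wandb_run_id sets the ID of the run. If you would like to set "
        "the name, please use wandb_name instead."
    )

assert cfg.wandb_name == "my-old-run"
```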
26 changes: 13 additions & 13 deletions src/axolotl/utils/wandb_.py
@@ -2,20 +2,20 @@

 import os

+from axolotl.utils.dict import DictDefault

-def setup_wandb_env_vars(cfg):
-    if cfg.wandb_mode and cfg.wandb_mode == "offline":
-        os.environ["WANDB_MODE"] = cfg.wandb_mode
-    elif cfg.wandb_project and len(cfg.wandb_project) > 0:
-        os.environ["WANDB_PROJECT"] = cfg.wandb_project
+
+def setup_wandb_env_vars(cfg: DictDefault):
+    for key in cfg.keys():
+        if key.startswith("wandb_"):
+            value = cfg.get(key, "")
+
+            if value and isinstance(value, str) and len(value) > 0:
+                os.environ[key.upper()] = value
+
+    # Enable wandb if project name is present
+    if cfg.wandb_project and len(cfg.wandb_project) > 0:
         cfg.use_wandb = True
-        if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
-            os.environ["WANDB_ENTITY"] = cfg.wandb_entity
-        if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
-            os.environ["WANDB_WATCH"] = cfg.wandb_watch
-        if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
-            os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
-        if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
-            os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
+        os.environ.pop("WANDB_DISABLED", None)  # Remove if present
     else:
         os.environ["WANDB_DISABLED"] = "true"
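The refactor replaces the hard-coded per-key `if` chain with a generic loop: any truthy string under a `wandb_*` config key is exported as the matching upper-cased environment variable, so new wandb options no longer require code changes here. A self-contained sketch of the loop's behavior with hypothetical config values:

```python
import os

# Hypothetical config values; a plain dict stands in for DictDefault
cfg = {
    "wandb_project": "my-project",
    "wandb_name": "lora-run-1",
    "wandb_watch": None,  # empty/None values are skipped
}

for key in cfg.keys():
    if key.startswith("wandb_"):
        value = cfg.get(key, "")
        if value and isinstance(value, str) and len(value) > 0:
            os.environ[key.upper()] = value

print(os.environ["WANDB_PROJECT"])  # -> my-project
print(os.environ["WANDB_NAME"])    # -> lora-run-1
```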
