From 2f3ddcfad09ccdd5e5f0de6a963e55a1b602d3d9 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 22 Oct 2023 23:13:28 +0900 Subject: [PATCH 1/7] Feat: Update to handle wandb env better --- src/axolotl/utils/wandb_.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/axolotl/utils/wandb_.py b/src/axolotl/utils/wandb_.py index 6c3af3177..b0d5f6703 100644 --- a/src/axolotl/utils/wandb_.py +++ b/src/axolotl/utils/wandb_.py @@ -2,20 +2,19 @@ import os +from axolotl.utils.dict import DictDefault -def setup_wandb_env_vars(cfg): - if cfg.wandb_mode and cfg.wandb_mode == "offline": - os.environ["WANDB_MODE"] = cfg.wandb_mode - elif cfg.wandb_project and len(cfg.wandb_project) > 0: - os.environ["WANDB_PROJECT"] = cfg.wandb_project + +def setup_wandb_env_vars(cfg: DictDefault): + for key in cfg.keys(): + if key.startswith("wandb_"): + value = cfg.get(key, "") + + if value and len(value) > 0: + os.environ[key.upper()] = value + + # Enable wandb if project name is present + if cfg.wandb_project and len(cfg.wandb_project) > 0: cfg.use_wandb = True - if cfg.wandb_entity and len(cfg.wandb_entity) > 0: - os.environ["WANDB_ENTITY"] = cfg.wandb_entity - if cfg.wandb_watch and len(cfg.wandb_watch) > 0: - os.environ["WANDB_WATCH"] = cfg.wandb_watch - if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: - os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model - if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0: - os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id else: os.environ["WANDB_DISABLED"] = "true" From 1a710479157f6d18b7590a9c040b796eff6e2077 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 22 Oct 2023 23:14:43 +0900 Subject: [PATCH 2/7] chore: rename wandb_run_id to wandb_name --- README.md | 4 ++-- examples/cerebras/btlm-ft.yml | 2 +- examples/cerebras/qlora.yml | 2 +- examples/code-llama/13b/lora.yml | 2 +- examples/code-llama/13b/qlora.yml | 2 +- examples/code-llama/34b/lora.yml | 2 +- examples/code-llama/34b/qlora.yml | 2 +- examples/code-llama/7b/lora.yml | 2 +- examples/code-llama/7b/qlora.yml | 2 +- examples/falcon/config-7b-lora.yml | 2 +- examples/falcon/config-7b-qlora.yml | 2 +- examples/falcon/config-7b.yml | 2 +- examples/gptj/qlora.yml | 2 +- examples/jeopardy-bot/config.yml | 2 +- examples/llama-2/fft_optimized.yml | 2 +- examples/llama-2/gptq-lora.yml | 2 +- examples/llama-2/lora.yml | 2 +- examples/llama-2/qlora.yml | 2 +- examples/llama-2/relora.yml | 2 +- examples/llama-2/tiny-llama.yml | 2 +- examples/mistral/config.yml | 2 +- examples/mistral/qlora.yml | 2 +- examples/mpt-7b/config.yml | 2 +- examples/openllama-3b/config.yml | 2 +- examples/openllama-3b/lora.yml | 2 +- examples/openllama-3b/qlora.yml | 2 +- examples/phi/phi-ft.yml | 2 +- examples/phi/phi-qlora.yml | 2 +- examples/pythia-12b/config.yml | 2 +- examples/pythia/lora.yml | 2 +- examples/redpajama/config-3b.yml | 2 +- examples/replit-3b/config-lora.yml | 2 +- examples/xgen-7b/xgen-7b-8k-qlora.yml | 2 +- 33 files changed, 34 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 093b8210d..8ef4b635b 100644 --- a/README.md +++ b/README.md @@ -659,7 +659,7 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server, wandb_project: # Your wandb project name wandb_entity: # A wandb Team name if using a Team wandb_watch: -wandb_run_id: # Set the name of your wandb run +wandb_name: # Set the name of your wandb run wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training # Where to save the full-finetuned model to @@ -952,7 +952,7 @@ wandb_mode: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: ``` diff --git a/examples/cerebras/btlm-ft.yml b/examples/cerebras/btlm-ft.yml index 1fea9915e..4c85a4c55 100644 --- a/examples/cerebras/btlm-ft.yml +++ b/examples/cerebras/btlm-ft.yml @@ -35,7 +35,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: btlm-out diff --git a/examples/cerebras/qlora.yml b/examples/cerebras/qlora.yml index 9f1dcc852..7b640fc27 100644 --- a/examples/cerebras/qlora.yml +++ b/examples/cerebras/qlora.yml @@ -24,7 +24,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./qlora-out batch_size: 4 diff --git a/examples/code-llama/13b/lora.yml b/examples/code-llama/13b/lora.yml index f3df1a1e2..45f66e02d 100644 --- a/examples/code-llama/13b/lora.yml +++ b/examples/code-llama/13b/lora.yml @@ -29,7 +29,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/code-llama/13b/qlora.yml b/examples/code-llama/13b/qlora.yml index 8bcd0dc78..79c684f88 100644 --- a/examples/code-llama/13b/qlora.yml +++ b/examples/code-llama/13b/qlora.yml @@ -31,7 +31,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/code-llama/34b/lora.yml b/examples/code-llama/34b/lora.yml index 2eb9df481..809e00710 100644 --- a/examples/code-llama/34b/lora.yml +++ b/examples/code-llama/34b/lora.yml @@ -29,7 +29,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/code-llama/34b/qlora.yml b/examples/code-llama/34b/qlora.yml index 3093ec01f..ed927e51e 100644 --- a/examples/code-llama/34b/qlora.yml +++ b/examples/code-llama/34b/qlora.yml @@ -31,7 +31,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/code-llama/7b/lora.yml b/examples/code-llama/7b/lora.yml index 422351d9a..37d6ae3b7 100644 --- a/examples/code-llama/7b/lora.yml +++ b/examples/code-llama/7b/lora.yml @@ -29,7 +29,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/code-llama/7b/qlora.yml b/examples/code-llama/7b/qlora.yml index f5712c009..491e07c98 100644 --- a/examples/code-llama/7b/qlora.yml +++ b/examples/code-llama/7b/qlora.yml @@ -31,7 +31,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/falcon/config-7b-lora.yml b/examples/falcon/config-7b-lora.yml index 25884410a..ef5eec1b7 100644 --- a/examples/falcon/config-7b-lora.yml +++ b/examples/falcon/config-7b-lora.yml @@ -26,7 +26,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./falcon-7b batch_size: 2 diff --git a/examples/falcon/config-7b-qlora.yml b/examples/falcon/config-7b-qlora.yml index 8e90e6614..03e2e3388 100644 --- a/examples/falcon/config-7b-qlora.yml +++ b/examples/falcon/config-7b-qlora.yml @@ -40,7 +40,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./qlora-out diff --git a/examples/falcon/config-7b.yml b/examples/falcon/config-7b.yml index fd5f63ccc..bf66d63e0 100644 --- a/examples/falcon/config-7b.yml +++ b/examples/falcon/config-7b.yml @@ -26,7 +26,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./falcon-7b batch_size: 2 diff --git a/examples/gptj/qlora.yml b/examples/gptj/qlora.yml index 57f132047..0e79bcd1d 100644 --- a/examples/gptj/qlora.yml +++ b/examples/gptj/qlora.yml @@ -21,7 +21,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./qlora-out gradient_accumulation_steps: 2 diff --git a/examples/jeopardy-bot/config.yml b/examples/jeopardy-bot/config.yml index 9dbdf6e6e..a0144ec51 100644 --- a/examples/jeopardy-bot/config.yml +++ b/examples/jeopardy-bot/config.yml @@ -19,7 +19,7 @@ lora_fan_in_fan_out: false wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./jeopardy-bot-7b gradient_accumulation_steps: 1 diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml index cf86a3e5c..e1c17c796 100644 --- a/examples/llama-2/fft_optimized.yml +++ b/examples/llama-2/fft_optimized.yml @@ -29,7 +29,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 1 diff --git a/examples/llama-2/gptq-lora.yml b/examples/llama-2/gptq-lora.yml index 61b00992f..c22f4f08f 100644 --- a/examples/llama-2/gptq-lora.yml +++ b/examples/llama-2/gptq-lora.yml @@ -32,7 +32,7 @@ lora_target_linear: lora_fan_in_fan_out: wandb_project: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./model-out gradient_accumulation_steps: 1 diff --git a/examples/llama-2/lora.yml b/examples/llama-2/lora.yml index 7d50877c7..4dfeb0079 100644 --- a/examples/llama-2/lora.yml +++ b/examples/llama-2/lora.yml @@ -29,7 +29,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/llama-2/qlora.yml b/examples/llama-2/qlora.yml index 29b756ce5..7e453e7a1 100644 --- a/examples/llama-2/qlora.yml +++ b/examples/llama-2/qlora.yml @@ -31,7 +31,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml index 8a7243d6f..9c9f6d6f4 100644 --- a/examples/llama-2/relora.yml +++ b/examples/llama-2/relora.yml @@ -35,7 +35,7 @@ relora_cpu_offload: false wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/llama-2/tiny-llama.yml b/examples/llama-2/tiny-llama.yml index 6b3fa652f..c3af7e827 100644 --- a/examples/llama-2/tiny-llama.yml +++ b/examples/llama-2/tiny-llama.yml @@ -29,7 +29,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/mistral/config.yml b/examples/mistral/config.yml index f0f7dad0a..4d116c9f8 100644 --- a/examples/mistral/config.yml +++ b/examples/mistral/config.yml @@ -21,7 +21,7 @@ pad_to_sequence_len: true wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/mistral/qlora.yml b/examples/mistral/qlora.yml index 28c5ed242..8c091e977 100644 --- a/examples/mistral/qlora.yml +++ b/examples/mistral/qlora.yml @@ -38,7 +38,7 @@ lora_target_modules: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/mpt-7b/config.yml b/examples/mpt-7b/config.yml index c9401890c..72f4e043e 100644 --- a/examples/mpt-7b/config.yml +++ b/examples/mpt-7b/config.yml @@ -21,7 +21,7 @@ lora_fan_in_fan_out: false wandb_project: mpt-alpaca-7b wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./mpt-alpaca-7b gradient_accumulation_steps: 1 diff --git a/examples/openllama-3b/config.yml b/examples/openllama-3b/config.yml index df6b26893..7809ec3d8 100644 --- a/examples/openllama-3b/config.yml +++ b/examples/openllama-3b/config.yml @@ -23,7 +23,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./openllama-out gradient_accumulation_steps: 1 diff --git a/examples/openllama-3b/lora.yml b/examples/openllama-3b/lora.yml index 7221abcbd..bddb777f8 100644 --- a/examples/openllama-3b/lora.yml +++ b/examples/openllama-3b/lora.yml @@ -29,7 +29,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./lora-out gradient_accumulation_steps: 1 diff --git a/examples/openllama-3b/qlora.yml b/examples/openllama-3b/qlora.yml index 89fbecde3..891dd48df 100644 --- a/examples/openllama-3b/qlora.yml +++ b/examples/openllama-3b/qlora.yml @@ -23,7 +23,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./qlora-out gradient_accumulation_steps: 1 diff --git a/examples/phi/phi-ft.yml b/examples/phi/phi-ft.yml index 8ed648ed6..cfacc49cc 100644 --- a/examples/phi/phi-ft.yml +++ b/examples/phi/phi-ft.yml @@ -31,7 +31,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 1 diff --git a/examples/phi/phi-qlora.yml b/examples/phi/phi-qlora.yml index 8fe5e98b1..780a2a116 100644 --- a/examples/phi/phi-qlora.yml +++ b/examples/phi/phi-qlora.yml @@ -31,7 +31,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 1 diff --git a/examples/pythia-12b/config.yml b/examples/pythia-12b/config.yml index 00693a164..e44bba745 100644 --- a/examples/pythia-12b/config.yml +++ b/examples/pythia-12b/config.yml @@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./pythia-12b gradient_accumulation_steps: 1 diff --git a/examples/pythia/lora.yml b/examples/pythia/lora.yml index b41e8197c..6681f627f 100644 --- a/examples/pythia/lora.yml +++ b/examples/pythia/lora.yml @@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./lora-alpaca-pythia gradient_accumulation_steps: 1 diff --git a/examples/redpajama/config-3b.yml b/examples/redpajama/config-3b.yml index edabd0e31..8895074ba 100644 --- a/examples/redpajama/config-3b.yml +++ b/examples/redpajama/config-3b.yml @@ -22,7 +22,7 @@ lora_fan_in_fan_out: false wandb_project: redpajama-alpaca-3b wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./redpajama-alpaca-3b batch_size: 4 diff --git a/examples/replit-3b/config-lora.yml b/examples/replit-3b/config-lora.yml index c3f448fab..82715eae5 100644 --- a/examples/replit-3b/config-lora.yml +++ b/examples/replit-3b/config-lora.yml @@ -21,7 +21,7 @@ lora_fan_in_fan_out: wandb_project: lora-replit wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./lora-replit batch_size: 8 diff --git a/examples/xgen-7b/xgen-7b-8k-qlora.yml b/examples/xgen-7b/xgen-7b-8k-qlora.yml index 524f4e993..26230c408 100644 --- a/examples/xgen-7b/xgen-7b-8k-qlora.yml +++ b/examples/xgen-7b/xgen-7b-8k-qlora.yml @@ -38,7 +38,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: output_dir: ./qlora-out From 3ee73687f87f7623d2d7aa98e05e2317d57c113b Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 22 Oct 2023 23:15:45 +0900 Subject: [PATCH 3/7] feat: add new recommendation and update config --- src/axolotl/utils/config.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index ef8025a3e..4b2bb6e0d 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -397,6 +397,14 @@ def validate_config(cfg): "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch." ) + if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0: + cfg.wandb_name = cfg.wandb_run_id + cfg.wandb_run_id = None + + LOG.warning( + "wandb_run_id is not recommended anymore. Please use wandb_name instead." + ) + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 From bcd588220d994a1eb23b0883e577386a6667813a Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 22 Oct 2023 23:35:07 +0900 Subject: [PATCH 4/7] fix: indent and pop disabled env if project passed --- src/axolotl/utils/wandb_.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/wandb_.py b/src/axolotl/utils/wandb_.py index b0d5f6703..327dd9b63 100644 --- a/src/axolotl/utils/wandb_.py +++ b/src/axolotl/utils/wandb_.py @@ -10,11 +10,12 @@ def setup_wandb_env_vars(cfg: DictDefault): if key.startswith("wandb_"): value = cfg.get(key, "") - if value and len(value) > 0: - os.environ[key.upper()] = value + if value and isinstance(value, str) and len(value) > 0: + os.environ[key.upper()] = value # Enable wandb if project name is present if cfg.wandb_project and len(cfg.wandb_project) > 0: cfg.use_wandb = True + os.environ.pop("WANDB_DISABLED", None) # Remove if present else: os.environ["WANDB_DISABLED"] = "true" From f194e0bd2e350dc4dd454ef7bbdf50911df6ff15 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 22 Oct 2023 23:35:35 +0900 Subject: [PATCH 5/7] feat: test env set for wandb and recommendation --- tests/test_validation.py | 72 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/test_validation.py b/tests/test_validation.py index 5a4ef427b..f10fffd77 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,6 +1,7 @@ """Module for testing the validation module""" import logging +import os import unittest from typing import Optional @@ -8,6 +9,7 @@ from axolotl.utils.config import validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.wandb_ import setup_wandb_env_vars class ValidationTest(unittest.TestCase): @@ -679,3 +681,73 @@ def test_warmup_step_no_conflict(self): ) validate_config(cfg) + + def test_wandb_rename_run_id_to_name(self): + cfg = DictDefault( + { + "wandb_run_id": "foo", + } + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert any( + "wandb_run_id is not recommended anymore. Please use wandb_name instead." + in record.message + for record in self._caplog.records + ) + + assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None + + cfg = DictDefault( + { + "wandb_name": "foo", + } + ) + + validate_config(cfg) + + def test_wandb_sets_env(self): + cfg = DictDefault( + { + "wandb_project": "foo", + "wandb_name": "bar", + "wandb_entity": "baz", + "wandb_mode": "online", + "wandb_watch": "false", + "wandb_log_model": "checkpoint", + } + ) + + validate_config(cfg) + + setup_wandb_env_vars(cfg) + + assert os.environ.get("WANDB_PROJECT", "") == "foo" + assert os.environ.get("WANDB_NAME", "") == "bar" + assert os.environ.get("WANDB_ENTITY", "") == "baz" + assert os.environ.get("WANDB_MODE", "") == "online" + assert os.environ.get("WANDB_WATCH", "") == "false" + assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint" + assert os.environ.get("WANDB_DISABLED", "") != "true" + + def test_wandb_set_disabled(self): + cfg = DictDefault({}) + + validate_config(cfg) + + setup_wandb_env_vars(cfg) + + assert os.environ.get("WANDB_DISABLED", "") == "true" + + cfg = DictDefault( + { + "wandb_project": "foo", + } + ) + + validate_config(cfg) + + setup_wandb_env_vars(cfg) + + assert os.environ.get("WANDB_DISABLED", "") != "true" From bfbaa88d0ace5b36d379cc27b38b30d2aa445e21 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 29 Nov 2023 23:37:48 +0900 Subject: [PATCH 6/7] feat: update to use wandb_name and allow id --- examples/qwen/lora.yml | 2 +- examples/qwen/qlora.yml | 2 +- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/utils/config.py | 5 ++--- tests/test_validation.py | 16 +++++++++++++--- 5 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/qwen/lora.yml b/examples/qwen/lora.yml index ca2061a5b..87db872a5 100644 --- a/examples/qwen/lora.yml +++ b/examples/qwen/lora.yml @@ -31,7 +31,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/examples/qwen/qlora.yml b/examples/qwen/qlora.yml index 224020b7f..d3b45c940 100644 --- a/examples/qwen/qlora.yml +++ b/examples/qwen/qlora.yml @@ -31,7 +31,7 @@ lora_fan_in_fan_out: wandb_project: wandb_entity: wandb_watch: -wandb_run_id: +wandb_name: wandb_log_model: gradient_accumulation_steps: 4 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 62e527beb..5a030cf7d 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -643,7 +643,7 @@ def build(self, total_num_steps): training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None training_arguments_kwargs["run_name"] = ( - self.cfg.wandb_run_id if self.cfg.use_wandb else None + self.cfg.wandb_name if self.cfg.use_wandb else None ) training_arguments_kwargs["optim"] = ( self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py index 4b2bb6e0d..6ae49514a 100644 --- a/src/axolotl/utils/config.py +++ b/src/axolotl/utils/config.py @@ -397,12 +397,11 @@ def validate_config(cfg): "Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch." ) - if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0: + if cfg.wandb_run_id and not cfg.wandb_name: cfg.wandb_name = cfg.wandb_run_id - cfg.wandb_run_id = None LOG.warning( - "wandb_run_id is not recommended anymore. Please use wandb_name instead." + "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." ) # TODO diff --git a/tests/test_validation.py b/tests/test_validation.py index f10fffd77..fabc23da3 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -682,7 +682,13 @@ def test_warmup_step_no_conflict(self): validate_config(cfg) - def test_wandb_rename_run_id_to_name(self): + +class ValidationWandbTest(ValidationTest): + """ + Validation test for wandb + """ + + def test_wandb_set_run_id_to_name(self): cfg = DictDefault( { "wandb_run_id": "foo", @@ -692,12 +698,12 @@ def test_wandb_rename_run_id_to_name(self): with self._caplog.at_level(logging.WARNING): validate_config(cfg) assert any( - "wandb_run_id is not recommended anymore. Please use wandb_name instead." + "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." in record.message for record in self._caplog.records ) - assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None + assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo" cfg = DictDefault( { @@ -707,11 +713,14 @@ def test_wandb_rename_run_id_to_name(self): validate_config(cfg) + assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None + def test_wandb_sets_env(self): cfg = DictDefault( { "wandb_project": "foo", "wandb_name": "bar", + "wandb_run_id": "bat", "wandb_entity": "baz", "wandb_mode": "online", "wandb_watch": "false", @@ -725,6 +734,7 @@ def test_wandb_sets_env(self): assert os.environ.get("WANDB_PROJECT", "") == "foo" assert os.environ.get("WANDB_NAME", "") == "bar" + assert os.environ.get("WANDB_RUN_ID", "") == "bat" assert os.environ.get("WANDB_ENTITY", "") == "baz" assert os.environ.get("WANDB_MODE", "") == "online" assert os.environ.get("WANDB_WATCH", "") == "false" From 646a9f080ae61156211bc73ff3a11026bbfa8016 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 4 Dec 2023 21:57:25 +0900 Subject: [PATCH 7/7] chore: add info to readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8ef4b635b..48e3c2758 100644 --- a/README.md +++ b/README.md @@ -660,6 +660,7 @@ wandb_project: # Your wandb project name wandb_entity: # A wandb Team name if using a Team wandb_watch: wandb_name: # Set the name of your wandb run +wandb_run_id: # Set the ID of your wandb run wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training # Where to save the full-finetuned model to