Skip to content

Commit

Permalink
feat: update to use wandb_name and allow id
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 committed Nov 29, 2023
1 parent f194e0b commit bfbaa88
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion examples/qwen/lora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
Expand Down
2 changes: 1 addition & 1 deletion examples/qwen/qlora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def build(self, total_num_steps):
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_run_id if self.cfg.use_wandb else None
self.cfg.wandb_name if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
Expand Down
5 changes: 2 additions & 3 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,12 +397,11 @@ def validate_config(cfg):
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
)

if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
cfg.wandb_run_id = None

LOG.warning(
"wandb_run_id is not recommended anymore. Please use wandb_name instead."
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)

# TODO
Expand Down
16 changes: 13 additions & 3 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,13 @@ def test_warmup_step_no_conflict(self):

validate_config(cfg)

def test_wandb_rename_run_id_to_name(self):

class ValidationWandbTest(ValidationTest):
"""
Validation test for wandb
"""

def test_wandb_set_run_id_to_name(self):
cfg = DictDefault(
{
"wandb_run_id": "foo",
Expand All @@ -692,12 +698,12 @@ def test_wandb_rename_run_id_to_name(self):
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"wandb_run_id is not recommended anymore. Please use wandb_name instead."
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
in record.message
for record in self._caplog.records
)

assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"

cfg = DictDefault(
{
Expand All @@ -707,11 +713,14 @@ def test_wandb_rename_run_id_to_name(self):

validate_config(cfg)

assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None

def test_wandb_sets_env(self):
cfg = DictDefault(
{
"wandb_project": "foo",
"wandb_name": "bar",
"wandb_run_id": "bat",
"wandb_entity": "baz",
"wandb_mode": "online",
"wandb_watch": "false",
Expand All @@ -725,6 +734,7 @@ def test_wandb_sets_env(self):

assert os.environ.get("WANDB_PROJECT", "") == "foo"
assert os.environ.get("WANDB_NAME", "") == "bar"
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
assert os.environ.get("WANDB_ENTITY", "") == "baz"
assert os.environ.get("WANDB_MODE", "") == "online"
assert os.environ.get("WANDB_WATCH", "") == "false"
Expand Down

0 comments on commit bfbaa88

Please sign in to comment.