Skip to content

Commit

Permalink
feat: test env set for wandb and recommendation
Browse files Browse the repository at this point in the history
  • Loading branch information
NanoCode012 committed Oct 22, 2023
1 parent 190bf9c commit b648845
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Module for testing the validation module"""

import logging
import os
import unittest
from typing import Optional

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars


class ValidationTest(unittest.TestCase):
Expand Down Expand Up @@ -606,3 +608,73 @@ def test_eval_table_size_conflict_eval_packing(self):
)

validate_config(cfg)

def test_wandb_rename_run_id_to_name(self):
cfg = DictDefault(
{
"wandb_run_id": "foo",
}
)

with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"wandb_run_id is not recommended anymore. Please use wandb_name instead."
in record.message
for record in self._caplog.records
)

assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None

cfg = DictDefault(
{
"wandb_name": "foo",
}
)

validate_config(cfg)

def test_wandb_sets_env(self):
cfg = DictDefault(
{
"wandb_project": "foo",
"wandb_name": "bar",
"wandb_entity": "baz",
"wandb_mode": "online",
"wandb_watch": "false",
"wandb_log_model": "checkpoint",
}
)

validate_config(cfg)

setup_wandb_env_vars(cfg)

assert os.environ.get("WANDB_PROJECT", "") == "foo"
assert os.environ.get("WANDB_NAME", "") == "bar"
assert os.environ.get("WANDB_ENTITY", "") == "baz"
assert os.environ.get("WANDB_MODE", "") == "online"
assert os.environ.get("WANDB_WATCH", "") == "false"
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
assert os.environ.get("WANDB_DISABLED", "") != "true"

def test_wandb_set_disabled(self):
cfg = DictDefault({})

validate_config(cfg)

setup_wandb_env_vars(cfg)

assert os.environ.get("WANDB_DISABLED", "") == "true"

cfg = DictDefault(
{
"wandb_project": "foo",
}
)

validate_config(cfg)

setup_wandb_env_vars(cfg)

assert os.environ.get("WANDB_DISABLED", "") != "true"

0 comments on commit b648845

Please sign in to comment.