convert exponential notation lr to floats (axolotl-ai-cloud#771)
winglian committed Oct 22, 2023
1 parent 77caa94 commit b32745a
Showing 2 changed files with 42 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/axolotl/utils/config.py
@@ -119,6 +119,9 @@ def normalize_config(cfg):
         or (cfg.model_type and "mistral" in cfg.model_type.lower())
     )
 
+    if isinstance(cfg.learning_rate, str):
+        cfg.learning_rate = float(cfg.learning_rate)
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)

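Why the coercion is needed: PyYAML implements YAML 1.1, whose implicit float resolver only matches scalars that contain a decimal point (and a signed exponent), so a learning rate written as a bare 5e-5 in a YAML config is loaded as a string rather than a float. A minimal sketch of that behavior (illustrative only, not part of this commit; assumes PyYAML's safe_load):

import yaml

# YAML 1.1 resolves a bare "5e-5" as a plain string: no decimal point,
# so the float resolver does not match.
cfg = yaml.safe_load("learning_rate: 5e-5")
assert isinstance(cfg["learning_rate"], str)

# With a decimal point and a signed exponent, the same value is a float.
assert isinstance(yaml.safe_load("lr: 5.0e-5")["lr"], float)

# The added check in normalize_config coerces the string form:
if isinstance(cfg["learning_rate"], str):
    cfg["learning_rate"] = float(cfg["learning_rate"])
assert cfg["learning_rate"] == 5e-5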
39 changes: 39 additions & 0 deletions tests/test_normalize_config.py
@@ -0,0 +1,39 @@
"""
Test classes for checking functionality of the cfg normalization
"""
import unittest

from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault


class NormalizeConfigTestCase(unittest.TestCase):
"""
test class for normalize_config checks
"""

def _get_base_cfg(self):
return DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)

def test_lr_as_float(self):
cfg = (
self._get_base_cfg()
| DictDefault( # pylint: disable=unsupported-binary-operation
{
"learning_rate": "5e-5",
}
)
)

normalize_config(cfg)

assert cfg.learning_rate == 0.00005

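Side note on the final assertion: the string "5e-5" and the literal 0.00005 parse to the same IEEE-754 double, so the equality check is exact rather than approximate. A one-line standalone check (illustrative, not part of the commit):

assert float("5e-5") == 0.00005 == 5e-5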