From 35454860152c422593d94bf50dfeb535daf5bd68 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 4 Jul 2024 18:45:46 -0400
Subject: [PATCH 1/7] add support for optimi_adamw optimizer w kahan summation

---
 src/axolotl/core/trainer_builder.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index ec175454e..dd8d2052e 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1396,6 +1396,31 @@ def build(self, total_num_steps):

         trainer_kwargs = {}

+        if self.cfg.optimizer == "optimi_adamw":
+            from optimi import AdamW
+
+            optimi_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
+            if "weight_decay" in training_arguments_kwargs:
+                optimi_kwargs["weight_decay"] = training_arguments_kwargs[
+                    "weight_decay"
+                ]
+
+            if (
+                "adam_beta1" in training_arguments_kwargs
+                and "adam_beta2" in training_arguments_kwargs
+            ):
+                optimi_kwargs["betas"] = (
+                    training_arguments_kwargs["adam_beta1"],
+                    training_arguments_kwargs["adam_beta2"],
+                )
+
+            trainer_kwargs["optimizers"] = (
+                AdamW(params=self.model.parameters(), **optimi_kwargs),
+                None,
+            )
+            # Set default so transformers doesn't throw
+            training_arguments_kwargs["optim"] = "adamw_hf"
+
         if self.cfg.optimizer == "lion_pytorch":
             from lion_pytorch import Lion


From 903eff22f81a1bcd3d636cfd18b73b29b853f7ae Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 4 Jul 2024 18:47:25 -0400
Subject: [PATCH 2/7] pydantic validator for optimi_adamw

---
 src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 3cac4f839..3d0b02752 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -341,7 +341,7 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[OptimizerNames, Literal["lion_pytorch"]]
+        Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
     ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
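[Editor's note - illustrative sketch, not part of the patch series]
The "kahan summation" named in the PATCH 1/7 subject is compensated summation,
which torch-optimi uses to keep low-precision (e.g. bfloat16) weight updates
close to full-precision results. The idea, shown on a plain Python sum rather
than on optimizer updates:

    def kahan_sum(values):
        """Compensated (Kahan) summation of an iterable of floats."""
        total = 0.0
        compensation = 0.0  # running record of lost low-order bits
        for value in values:
            adjusted = value - compensation        # re-apply previously lost error
            new_total = total + adjusted           # low-order bits may be dropped here
            compensation = (new_total - total) - adjusted  # measure what was dropped
            total = new_total
        return total

optimi applies the same compensation to each parameter update, which is what
makes its AdamW attractive for low-precision training runs.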
From 251f15c29a5988fea088f8773eab0de632ef8cf7 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 4 Jul 2024 19:01:30 -0400
Subject: [PATCH 3/7] workaround for setting optimizer for fsdp

---
 src/axolotl/core/trainer_builder.py | 60 ++++++++++++++---------------
 1 file changed, 29 insertions(+), 31 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index dd8d2052e..e41391680 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -226,6 +226,12 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "whether to use sequential sampling for curriculum learning"},
     )
+    alternate_optimizer: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate optimizer to the HF trainer"
+        },
+    )


 @dataclass
@@ -285,7 +291,10 @@ def __init__(
         self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

     def create_optimizer(self):
-        if self.args.loraplus_lr_ratio is None:
+        if (
+            self.args.loraplus_lr_ratio is None
+            and self.args.alternate_optimizer is None
+        ):
             return super().create_optimizer()

         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
@@ -295,15 +304,24 @@ def create_optimizer(self):
                 opt_model,
             )

-            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                opt_model,
-                optimizer_cls,
-                optimizer_kwargs,
-                loraplus_lr_ratio,
-                loraplus_lr_embedding,
-            )
+            if self.args.loraplus_lr_ratio is not None:
+                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+                loraplus_lr_embedding = getattr(
+                    self.args, "loraplus_lr_embedding", None
+                )
+                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                    opt_model,
+                    optimizer_cls,
+                    optimizer_kwargs,
+                    loraplus_lr_ratio,
+                    loraplus_lr_embedding,
+                )
+            else:
+                from optimi import AdamW
+
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    AdamW(opt_model.parameters(), **optimizer_kwargs)
+                )

         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -1397,29 +1415,9 @@ def build(self, total_num_steps):
         trainer_kwargs = {}

         if self.cfg.optimizer == "optimi_adamw":
-            from optimi import AdamW
-
-            optimi_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
-            if "weight_decay" in training_arguments_kwargs:
-                optimi_kwargs["weight_decay"] = training_arguments_kwargs[
-                    "weight_decay"
-                ]
-
-            if (
-                "adam_beta1" in training_arguments_kwargs
-                and "adam_beta2" in training_arguments_kwargs
-            ):
-                optimi_kwargs["betas"] = (
-                    training_arguments_kwargs["adam_beta1"],
-                    training_arguments_kwargs["adam_beta2"],
-                )
-
-            trainer_kwargs["optimizers"] = (
-                AdamW(params=self.model.parameters(), **optimi_kwargs),
-                None,
-            )
             # Set default so transformers doesn't throw
             training_arguments_kwargs["optim"] = "adamw_hf"
+            training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer

         if self.cfg.optimizer == "lion_pytorch":
             from lion_pytorch import Lion


From b137d007878a23863e82800e364acc715ab49680 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 4 Jul 2024 20:11:39 -0400
Subject: [PATCH 4/7] make sure to install optimizer packages

---
 cicd/Dockerfile.jinja | 4 ++--
 docker/Dockerfile     | 4 ++--
 setup.py              | 6 ++++++
 3 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 96c312ddc..7749f0904 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
diff --git a/docker/Dockerfile b/docker/Dockerfile
index cdb6d177a..be58d0354 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
diff --git a/setup.py b/setup.py
index 58d279475..33d829e9d 100644
--- a/setup.py
+++ b/setup.py
@@ -104,5 +104,11 @@ def parse_requirements():
         "galore": [
             "galore_torch",
         ],
+        "optimizers": [
+            "galore_torch",
+            "lion-pytorch==0.1.2",
+            "lomo-optim==0.1.1",
+            "torch-optimi==0.2.1",
+        ],
     },
 )
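[Editor's note - illustrative usage, not part of the patch series]
With the trainer wiring from PATCH 1-3 and the new `optimizers` install extra
from PATCH 4/7, a run opts in by setting `optimizer: optimi_adamw` in its
config. A minimal sketch of the relevant keys, written as the DictDefault
structure the e2e tests use; the model, learning rate, and scheduler values
are placeholders, not recommendations:

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "base_model": "JackFram/llama-68m",  # placeholder tiny model
            "optimizer": "optimi_adamw",  # the new option added by this series
            "learning_rate": 0.00001,
            "lr_scheduler": "cosine",
            "weight_decay": 0.01,  # applied via the decay/no-decay groups, see PATCH 5/7
        }
    )

The equivalent keys in a YAML config file select the same code path through
build() and create_optimizer().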
From a3cc7445eeee00bd397649aa2806256d19c2e7fa Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 9 Jul 2024 14:55:02 -0400
Subject: [PATCH 5/7] make sure to have parity for model parameters passed to optimizer

---
 src/axolotl/core/trainer_builder.py | 26 +++++++++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index e41391680..107c881f0 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -293,12 +293,32 @@ def __init__(
     def create_optimizer(self):
         if (
             self.args.loraplus_lr_ratio is None
-            and self.args.alternate_optimizer is None
+            and self.args.alternate_optimizer != "optimi_adamw"
         ):
             return super().create_optimizer()

         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
+
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
                 opt_model,
@@ -316,11 +336,11 @@ def create_optimizer(self):
                 loraplus_lr_ratio,
                 loraplus_lr_embedding,
             )
-        else:
+        elif self.args.alternate_optimizer == "optimi_adamw":
             from optimi import AdamW

             self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                AdamW(opt_model.parameters(), **optimizer_kwargs)
+                AdamW(optimizer_grouped_parameters, **optimizer_kwargs)
             )


From 7d89b0527f43ee4e06deba3ebe299cedb5e31042 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 13 Jul 2024 14:43:44 -0400
Subject: [PATCH 6/7] add smoke test for optimi_adamw optimizer

---
 tests/e2e/test_lora_llama.py |  6 ++--
 tests/e2e/test_optimizers.py | 67 ++++++++++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+), 3 deletions(-)
 create mode 100644 tests/e2e/test_optimizers.py

diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index c79652bef..4c6fdaaa9 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -34,8 +34,8 @@ def test_lora(self, temp_dir):
             "sequence_len": 1024,
             "load_in_8bit": True,
             "adapter": "lora",
-            "lora_r": 32,
-            "lora_alpha": 64,
+            "lora_r": 8,
+            "lora_alpha": 16,
             "lora_dropout": 0.05,
             "lora_target_linear": True,
             "val_set_size": 0.1,
@@ -50,7 +50,7 @@ def test_lora(self, temp_dir):
                     "type": "alpaca",
                 },
             ],
-            "num_epochs": 2,
+            "num_epochs": 1,
             "micro_batch_size": 8,
             "gradient_accumulation_steps": 1,
             "output_dir": temp_dir,
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
new file mode 100644
index 000000000..119dd3d7c
--- /dev/null
+++ b/tests/e2e/test_optimizers.py
@@ -0,0 +1,67 @@
+"""
+E2E tests for custom optimizers using Llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestCustomOptimizers(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_optimi_adamw(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "optimi_adamw",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()


From 449e25f9116aac5659b6bfd7c7a79833f456e908 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 14 Jul 2024 08:02:58 -0400
Subject: [PATCH 7/7] don't use foreach optimi by default

---
 src/axolotl/core/trainer_builder.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 107c881f0..5391904fc 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -340,7 +340,9 @@ def create_optimizer(self):
             from optimi import AdamW

             self.optimizer = (  # pylint: disable=attribute-defined-outside-init
-                AdamW(optimizer_grouped_parameters, **optimizer_kwargs)
+                AdamW(
+                    optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
+                )
             )
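[Editor's note - illustrative sketch, not part of the patch series]
After PATCH 5/7 and 7/7, the optimi_adamw path builds the optimizer from
decay/no-decay parameter groups with the foreach implementation disabled. A
standalone approximation of that final construction with a toy model and
placeholder hyperparameters (the real code derives the groups from
get_decay_parameter_names, which also excludes norm layers; splitting on bias
names here is a simplification):

    import torch
    from optimi import AdamW

    model = torch.nn.Linear(16, 16, dtype=torch.bfloat16)  # toy stand-in for the model

    decay, no_decay = [], []
    for name, param in model.named_parameters():
        (no_decay if name.endswith("bias") else decay).append(param)

    optimizer = AdamW(
        [
            {"params": decay, "weight_decay": 0.01},  # placeholder weight decay
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=1e-5,  # placeholder learning rate
        betas=(0.9, 0.99),  # placeholder betas
        foreach=False,  # PATCH 7/7: don't use the foreach implementation by default
    )

Per optimi's documentation, Kahan summation kicks in automatically for the
low-precision (bfloat16) parameters above, which is the motivation for the
series.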