diff --git a/src/axolotl/monkeypatch/neft_embeddings.py b/src/axolotl/monkeypatch/neft_embeddings.py
new file mode 100644
index 0000000000..524d48f8ff
--- /dev/null
+++ b/src/axolotl/monkeypatch/neft_embeddings.py
@@ -0,0 +1,65 @@
+"""
+patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+import torch
+from peft import PeftModel
+from transformers import PreTrainedModel
+
+
+def patch_neft(alpha, model):
+    embeddings = None
+    if isinstance(model, PreTrainedModel):
+        embeddings = model.get_input_embeddings()
+    if isinstance(model, PeftModel):
+        embeddings = model.base_model.get_input_embeddings()
+    if not embeddings:
+        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
+    embeddings.noisy_embedding_alpha = alpha
+    old_forward = embeddings.forward
+
+    # This hack seems to be needed to properly use a custom forward pass
+    # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
+    bound_method = neft_forward.__get__(  # pylint: disable=no-value-for-parameter
+        embeddings, embeddings.__class__
+    )
+    setattr(embeddings, "forward", bound_method)
+
+    embeddings._old_forward = old_forward  # pylint: disable=protected-access
+    return model
+
+
+def unpatch_neft(model):
+    embeddings = None
+    if isinstance(model, PreTrainedModel):
+        embeddings = model.get_input_embeddings()
+    if isinstance(model, PeftModel):
+        embeddings = model.base_model.get_input_embeddings()
+    if not embeddings:
+        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
+    if hasattr(embeddings, "_old_forward"):
+        embeddings.forward = embeddings._old_forward  # pylint: disable=protected-access
+        del embeddings._old_forward  # pylint: disable=protected-access
+        del embeddings.noisy_embedding_alpha
+
+
+def neft_forward(self, inputs: torch.Tensor):
+    embeddings = self._old_forward(inputs)  # pylint: disable=protected-access
+
+    if self.training:
+        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
+        mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
+        embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
+            -mag_norm, mag_norm
+        )
+
+    return embeddings
+
+
+def pretrain_hook(cfg, trainer):
+    if cfg.noisy_embedding_alpha:
+        trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
+
+
+def post_train_hook(cfg, trainer):
+    if cfg.noisy_embedding_alpha:
+        unpatch_neft(trainer.model)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index da98600a45..c6c9c5aff4 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -15,6 +15,7 @@
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
+from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
@@ -109,6 +110,7 @@ def terminate_handler(_, __, model):
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
 
+    pretrain_hooks(cfg, trainer)
     if cfg.flash_optimum:
         with torch.backends.cuda.sdp_kernel(
             enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -116,6 +118,7 @@ def terminate_handler(_, __, model):
             trainer.train(resume_from_checkpoint=resume_from_checkpoint)
     else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    post_train_hooks(cfg, trainer)
 
     LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
@@ -144,3 +147,23 @@ def terminate_handler(_, __, model):
         trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
 
     return model, tokenizer
+
+
+def pretrain_hooks(cfg, trainer):
+    """
+    Run hooks right before kicking off the training
+    :param cfg:
+    :param trainer:
+    :return:
+    """
+    neft_embeddings.pretrain_hook(cfg, trainer)
+
+
+def post_train_hooks(cfg, trainer):
+    """
+    Run hooks right after training completes
+    :param cfg:
+    :param trainer:
+    :return:
+    """
+    neft_embeddings.post_train_hook(cfg, trainer)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c133e9eb61..8722e8bae3 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -180,26 +180,26 @@ def load_model(
         LOG.info("patching with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-    if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
-        from axolotl.monkeypatch.llama_embeddings_hijack import (
-            replace_llama_embeddings_with_uniform_distribution,
-        )
-
-        LOG.info("patching with noisy embeddings")
-        replace_llama_embeddings_with_uniform_distribution(
-            noise_alpha=cfg.noisy_embedding_alpha
-        )
-
-    if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
-        from axolotl.monkeypatch.mistral_embeddings_hijack import (
-            replace_mistral_embeddings_with_uniform_distribution,
-        )
-
-        LOG.info("patching with noisy embeddings")
-        replace_mistral_embeddings_with_uniform_distribution(
-            noise_alpha=cfg.noisy_embedding_alpha
-        )
-
+    # if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+    #     from axolotl.monkeypatch.llama_embeddings_hijack import (
+    #         replace_llama_embeddings_with_uniform_distribution,
+    #     )
+    #
+    #     LOG.info("patching with noisy embeddings")
+    #     replace_llama_embeddings_with_uniform_distribution(
+    #         noise_alpha=cfg.noisy_embedding_alpha
+    #     )
+    #
+    # if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
+    #     from axolotl.monkeypatch.mistral_embeddings_hijack import (
+    #         replace_mistral_embeddings_with_uniform_distribution,
+    #     )
+    #
+    #     LOG.info("patching with noisy embeddings")
+    #     replace_mistral_embeddings_with_uniform_distribution(
+    #         noise_alpha=cfg.noisy_embedding_alpha
+    #     )
+    #
     if cfg.is_llama_derived_model and cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
             replace_llama_rope_with_xpos_rope,
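For reviewers who want to sanity-check the scaling, here is a minimal standalone sketch of what the patched `neft_forward` does at train time: uniform noise with magnitude `alpha / sqrt(seq_len * hidden_dim)` is added to the embedding output, and eval/inference is left untouched. The alpha value, vocab size, and tensor shapes below are illustrative assumptions, not values taken from this diff.

# Standalone sketch of the NEFT noise injection (illustrative values only).
import math

import torch
from torch import nn

alpha = 5.0  # stands in for cfg.noisy_embedding_alpha (hypothetical value)
batch, seq_len, hidden = 2, 128, 512

embed = nn.Embedding(1000, hidden)           # toy embedding layer
input_ids = torch.randint(0, 1000, (batch, seq_len))
embeddings = embed(input_ids)                # shape: (batch, seq_len, hidden)

# Same scaling as the patched forward: dims = seq_len * hidden_dim and
# mag_norm = alpha / sqrt(dims); noise is sampled uniformly in [-mag_norm, mag_norm].
mag_norm = alpha / math.sqrt(seq_len * hidden)
noisy = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)

print(f"mag_norm = {mag_norm:.5f}")          # 5 / sqrt(128 * 512) = 5 / 256 ≈ 0.01953

With these shapes the per-element noise bound is about 0.02; longer sequences or wider hidden states shrink the perturbation, while a larger alpha strengthens it, which is the knob `noisy_embedding_alpha` exposes.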