diff --git a/gitbook/README.md b/gitbook/README.md
index 5c4b4d58a..642bde22a 100644
--- a/gitbook/README.md
+++ b/gitbook/README.md
@@ -1,2 +1 @@
 # Page
-
diff --git a/src/axolotl/monkeypatch/llama_embeddings_hijack.py b/src/axolotl/monkeypatch/llama_embeddings_hijack.py
deleted file mode 100644
index 654ca3ba8..000000000
--- a/src/axolotl/monkeypatch/llama_embeddings_hijack.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""
-patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
-"""
-
-import torch
-import transformers.models.llama.modeling_llama
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
-    # pylint: disable=duplicate-code
-    def noised_embed(orig_embed, noise_alpha, model):
-        def new_func(input_ids):
-            # during training, we add noise to the embedding
-            # during generation, we don't add noise to the embedding
-            if model.training:
-                embed_init = orig_embed(input_ids)
-                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
-                mag_norm = noise_alpha / torch.sqrt(dims)
-                return embed_init + torch.zeros_like(embed_init).uniform_(
-                    -mag_norm, mag_norm
-                )
-            return orig_embed(input_ids)
-
-        return new_func
-
-    def post_init(orig_post_init):
-        def new_func(self):
-            orig_post_init(self)
-            self.embed_tokens.forward = noised_embed(
-                self.embed_tokens.forward, noise_alpha, self
-            )
-
-        return new_func
-
-    transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
-        transformers.models.llama.modeling_llama.LlamaModel.post_init
-    )
diff --git a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py b/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
deleted file mode 100644
index ed5f25965..000000000
--- a/src/axolotl/monkeypatch/mistral_embeddings_hijack.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""
-patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
-"""
-
-import torch
-import transformers.models.mistral.modeling_mistral
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
-    # pylint: disable=duplicate-code
-    def noised_embed(orig_embed, noise_alpha, model):
-        def new_func(input_ids):
-            # during training, we add noise to the embedding
-            # during generation, we don't add noise to the embedding
-            if model.training:
-                embed_init = orig_embed(input_ids)
-                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
-                mag_norm = noise_alpha / torch.sqrt(dims)
-                return embed_init + torch.zeros_like(embed_init).uniform_(
-                    -mag_norm, mag_norm
-                )
-            return orig_embed(input_ids)
-
-        return new_func
-
-    def post_init(orig_post_init):
-        def new_func(self):
-            orig_post_init(self)
-            self.embed_tokens.forward = noised_embed(
-                self.embed_tokens.forward, noise_alpha, self
-            )
-
-        return new_func
-
-    transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
-        transformers.models.mistral.modeling_mistral.MistralModel.post_init
-    )
diff --git a/src/axolotl/monkeypatch/neft_embeddings.py b/src/axolotl/monkeypatch/neft_embeddings.py
new file mode 100644
index 000000000..524d48f8f
--- /dev/null
+++ b/src/axolotl/monkeypatch/neft_embeddings.py
@@ -0,0 +1,65 @@
+"""
+patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
+"""
+import torch
+from peft import PeftModel
+from transformers import PreTrainedModel
+
+
+def patch_neft(alpha, model):
+    embeddings = None
+    if isinstance(model, PreTrainedModel):
+        embeddings = model.get_input_embeddings()
+    if isinstance(model, PeftModel):
+        embeddings = model.base_model.get_input_embeddings()
+    if not embeddings:
+        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
+    embeddings.noisy_embedding_alpha = alpha
+    old_forward = embeddings.forward
+
+    # This hack seems to be needed to properly use a custom forward pass
+    # all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
+    bound_method = neft_forward.__get__(  # pylint: disable=no-value-for-parameter
+        embeddings, embeddings.__class__
+    )
+    setattr(embeddings, "forward", bound_method)
+
+    embeddings._old_forward = old_forward  # pylint: disable=protected-access
+    return model
+
+
+def unpatch_neft(model):
+    embeddings = None
+    if isinstance(model, PreTrainedModel):
+        embeddings = model.get_input_embeddings()
+    if isinstance(model, PeftModel):
+        embeddings = model.base_model.get_input_embeddings()
+    if not embeddings:
+        raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
+    if hasattr(embeddings, "_old_forward"):
+        embeddings.forward = embeddings._old_forward  # pylint: disable=protected-access
+        del embeddings._old_forward  # pylint: disable=protected-access
+        del embeddings.noisy_embedding_alpha
+
+
+def neft_forward(self, inputs: torch.Tensor):
+    embeddings = self._old_forward(inputs)  # pylint: disable=protected-access
+
+    if self.training:
+        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
+        mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
+        embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
+            -mag_norm, mag_norm
+        )
+
+    return embeddings
+
+
+def pretrain_hook(cfg, trainer):
+    if cfg.noisy_embedding_alpha:
+        trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
+
+
+def post_train_hook(cfg, trainer):
+    if cfg.noisy_embedding_alpha:
+        unpatch_neft(trainer.model)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 468d25e14..b9b0e595d 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -16,6 +16,7 @@
 
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
+from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
@@ -107,6 +108,7 @@ def terminate_handler(_, __, model):
     if cfg.group_by_length:
         LOG.info("hang tight... sorting dataset for group_by_length")
 
+    pretrain_hooks(cfg, trainer)
     if cfg.flash_optimum:
         with torch.backends.cuda.sdp_kernel(
             enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -114,6 +116,7 @@ def terminate_handler(_, __, model):
             trainer.train(resume_from_checkpoint=resume_from_checkpoint)
     else:
         trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    post_train_hooks(cfg, trainer)
 
     LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
@@ -163,3 +166,23 @@ def terminate_handler(_, __, model):
     trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
 
     return model, tokenizer
+
+
+def pretrain_hooks(cfg, trainer):
+    """
+    Run hooks right before kicking off the training
+    :param cfg:
+    :param trainer:
+    :return:
+    """
+    neft_embeddings.pretrain_hook(cfg, trainer)
+
+
+def post_train_hooks(cfg, trainer):
+    """
+    Run hooks right after training completes
+    :param cfg:
+    :param trainer:
+    :return:
+    """
+    neft_embeddings.post_train_hook(cfg, trainer)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d907a194b..d0042abc4 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -180,26 +180,6 @@ def load_model(
         LOG.info("patching with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-    if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
-        from axolotl.monkeypatch.llama_embeddings_hijack import (
-            replace_llama_embeddings_with_uniform_distribution,
-        )
-
-        LOG.info("patching with noisy embeddings")
-        replace_llama_embeddings_with_uniform_distribution(
-            noise_alpha=cfg.noisy_embedding_alpha
-        )
-
-    if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
-        from axolotl.monkeypatch.mistral_embeddings_hijack import (
-            replace_mistral_embeddings_with_uniform_distribution,
-        )
-
-        LOG.info("patching with noisy embeddings")
-        replace_mistral_embeddings_with_uniform_distribution(
-            noise_alpha=cfg.noisy_embedding_alpha
-        )
-
     if cfg.is_llama_derived_model and cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
            replace_llama_rope_with_xpos_rope,
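
For reference, the noise applied in `neft_forward` follows the NEFTune paper: during training, uniform noise in `[-mag_norm, mag_norm]` is added to the token embeddings, where `mag_norm = alpha / sqrt(L * d)`, `L` is the sequence length, `d` the embedding dimension, and `alpha` comes from the `noisy_embedding_alpha` config value consumed by the pre/post-train hooks. The sketch below is a minimal standalone illustration of that scaling and is not part of this diff; the helper name and tensor shapes are made up for the example.

```python
# Standalone sketch of the NEFT noise scaling used by neft_forward (illustrative only).
import math

import torch


def add_neft_noise(embeddings: torch.Tensor, alpha: float) -> torch.Tensor:
    """Add uniform noise bounded by alpha / sqrt(seq_len * hidden_dim)."""
    seq_len, hidden_dim = embeddings.size(1), embeddings.size(2)
    mag_norm = alpha / math.sqrt(seq_len * hidden_dim)
    return embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)


# Example: batch of 2 sequences, length 8, hidden size 16, alpha = 5
embeds = torch.randn(2, 8, 16)
noised = add_neft_noise(embeds, alpha=5.0)
assert (noised - embeds).abs().max() <= 5.0 / math.sqrt(8 * 16)
```

Unlike the removed per-architecture monkeypatches, the hook-based patch wraps whatever module `get_input_embeddings()` returns and restores the original `forward` after training, so it is not tied to Llama or Mistral modeling code.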