refactor NEFT patch to be more reusable, similar to trl's implementation #796

Merged · 3 commits · Oct 29, 2023
1 change: 0 additions & 1 deletion gitbook/README.md
@@ -1,2 +1 @@
# Page

40 changes: 0 additions & 40 deletions src/axolotl/monkeypatch/llama_embeddings_hijack.py

This file was deleted.

40 changes: 0 additions & 40 deletions src/axolotl/monkeypatch/mistral_embeddings_hijack.py

This file was deleted.

65 changes: 65 additions & 0 deletions src/axolotl/monkeypatch/neft_embeddings.py
@@ -0,0 +1,65 @@
"""
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
from peft import PeftModel
from transformers import PreTrainedModel


def patch_neft(alpha, model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
embeddings.noisy_embedding_alpha = alpha
old_forward = embeddings.forward

# This hack seems to be needed to properly use a custom forward pass
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
embeddings, embeddings.__class__
)
setattr(embeddings, "forward", bound_method)

embeddings._old_forward = old_forward # pylint: disable=protected-access
return model


def unpatch_neft(model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
if hasattr(embeddings, "_old_forward"):
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
del embeddings._old_forward # pylint: disable=protected-access
del embeddings.noisy_embedding_alpha


def neft_forward(self, inputs: torch.Tensor):
embeddings = self._old_forward(inputs) # pylint: disable=protected-access

if self.training:
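        # NEFTune: add uniform noise in [-mag_norm, mag_norm], where
        # mag_norm = alpha / sqrt(seq_len * hidden_dim), per the paper linked above.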
        dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
        mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
        embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
            -mag_norm, mag_norm
        )

    return embeddings


def pretrain_hook(cfg, trainer):
    if cfg.noisy_embedding_alpha:
        trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)


def post_train_hook(cfg, trainer):
    if cfg.noisy_embedding_alpha:
        unpatch_neft(trainer.model)
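
As a quick illustration of how the new patch can be reused outside axolotl's trainer hooks, here is a minimal sketch (not part of this diff) that applies and removes the patch around a small Hugging Face model; "gpt2" and alpha=5 are arbitrary illustrative choices.

import torch
from transformers import AutoModelForCausalLM

from axolotl.monkeypatch.neft_embeddings import patch_neft, unpatch_neft

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any PreTrainedModel works
model = patch_neft(5, model)

model.train()
input_ids = torch.tensor([[1, 2, 3, 4]])
noisy = model.get_input_embeddings()(input_ids)  # uniform noise added while training

model.eval()
clean = model.get_input_embeddings()(input_ids)  # no noise outside training mode

unpatch_neft(model)  # restores the original embedding forward

Because the patch targets get_input_embeddings() rather than an architecture-specific embedding class, the same code path covers llama, mistral, and any other PreTrainedModel, which is what allows the per-architecture hijack modules above to be deleted.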
23 changes: 23 additions & 0 deletions src/axolotl/train.py
@@ -16,6 +16,7 @@

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
@@ -107,13 +108,15 @@ def terminate_handler(_, __, model):
    if cfg.group_by_length:
        LOG.info("hang tight... sorting dataset for group_by_length")

    pretrain_hooks(cfg, trainer)
    if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=True, enable_mem_efficient=True
        ):
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    post_train_hooks(cfg, trainer)

    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

@@ -163,3 +166,23 @@ def terminate_handler(_, __, model):
    trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))

    return model, tokenizer


def pretrain_hooks(cfg, trainer):
    """
    Run hooks right before kicking off the training
    :param cfg:
    :param trainer:
    :return:
    """
    neft_embeddings.pretrain_hook(cfg, trainer)


def post_train_hooks(cfg, trainer):
    """
    Run hooks right after training completes
    :param cfg:
    :param trainer:
    :return:
    """
    neft_embeddings.post_train_hook(cfg, trainer)
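
The hooks are deliberately thin wrappers around trainer.train(). The following rough sketch (not from the PR) shows the wrapping pattern the modified train() follows; run_training is a hypothetical helper, and cfg/trainer are assumed to be axolotl's config object and a Hugging Face Trainer.

from axolotl.train import post_train_hooks, pretrain_hooks


def run_training(cfg, trainer, resume_from_checkpoint=None):
    # patches the model's input embeddings when cfg.noisy_embedding_alpha is set
    pretrain_hooks(cfg, trainer)
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    # restores the original embedding forward after training completes
    post_train_hooks(cfg, trainer)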
20 changes: 0 additions & 20 deletions src/axolotl/utils/models.py
@@ -180,26 +180,6 @@ def load_model(
LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.llama_embeddings_hijack import (
replace_llama_embeddings_with_uniform_distribution,
)

LOG.info("patching with noisy embeddings")
replace_llama_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)

if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.mistral_embeddings_hijack import (
replace_mistral_embeddings_with_uniform_distribution,
)

LOG.info("patching with noisy embeddings")
replace_mistral_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)

if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,