Multipack simplify for Mixtral (#1142)
winglian authored Jan 18, 2024
1 parent 1d70f24 commit 6910e6a
Showing 11 changed files with 201 additions and 430 deletions.
src/axolotl/core/trainer_builder.py (23 changes: 16 additions & 7 deletions)
@@ -12,7 +12,7 @@
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Type, Union
 
 import torch
 import transformers
@@ -37,6 +37,7 @@
     BatchSamplerDataCollatorForSeq2Seq,
     DataCollatorForSeq2Seq,
     MambaDataCollator,
+    V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
@@ -896,14 +897,22 @@ def build_collator(
         if is_eval and training_args.eval_sample_packing:
             use_batch_sampler_collator = True
 
+        collator: Type[
+            Union[
+                V2BatchSamplerDataCollatorForSeq2Seq,
+                BatchSamplerDataCollatorForSeq2Seq,
+                DataCollatorForSeq2Seq,
+            ]
+        ]
         if use_batch_sampler_collator:
-            return BatchSamplerDataCollatorForSeq2Seq(
-                self.tokenizer,
-                return_tensors="pt",
-                **kwargs,
-            )
+            if self.cfg.model_config_type == "mixtral":
+                collator = V2BatchSamplerDataCollatorForSeq2Seq
+            else:
+                collator = BatchSamplerDataCollatorForSeq2Seq
+        else:
+            collator = DataCollatorForSeq2Seq
 
-        return DataCollatorForSeq2Seq(
+        return collator(
             self.tokenizer,
             return_tensors="pt",
             **kwargs,
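
Net effect of this hunk: instead of returning early on the sample-packing path, build_collator now picks a collator class once and funnels every path through a single constructor call, with Mixtral routed to the new V2BatchSamplerDataCollatorForSeq2Seq. For orientation, here is a minimal sketch of the packed attention-mask convention the batch-sampler collators rely on, assuming (as in axolotl's multipack collators) that tokens carry 1-based sample ids rather than a plain 0/1 mask; the values are illustrative, not taken from the repo:

import torch

# Illustrative packed batch: three samples of lengths 4, 3, and 2 share one
# row; each token carries its sample's 1-based id, and 0 marks padding.
attention_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 0]])

# Per-sample lengths are recoverable from the id mask alone:
seqlens = torch.tensor(
    [(attention_mask == i).sum() for i in range(1, int(attention_mask.max()) + 1)]
)
print(seqlens)  # tensor([4, 3, 2])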
src/axolotl/monkeypatch/mixtral/__init__.py (18 changes: 4 additions & 14 deletions)
@@ -3,20 +3,10 @@
 """
 import transformers
 
+from axolotl.monkeypatch.utils import get_unpad_data
+
 
 def replace_mixtral_attn_with_multipack_flash_attn():
-    from .modeling_mixtral import (
-        MixtralMultipackFlashAttention2,
-        mixtral_decoder_layer_forward,
-        mixtral_model_forward,
-    )
-
-    transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
-        mixtral_decoder_layer_forward
-    )
-    transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
-        mixtral_model_forward
-    )
-    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
-        "flash_attention_2"
-    ] = MixtralMultipackFlashAttention2
+    transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
+    )
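
This is the simplification the commit title promises: the custom attention class and forward patches are gone, and the only hook left is the module-level _get_unpad_data that transformers' stock FlashAttention-2 path calls to turn an attention mask into the (indices, cu_seqlens, max_seqlen_in_batch) triple consumed by flash-attn's varlen kernels. Below is a sketch of what a multipack-aware replacement has to compute under the id-mask convention shown above; the real implementation lives in src/axolotl/monkeypatch/utils.py, and the function name here is hypothetical:

import torch
import torch.nn.functional as F


def get_unpad_data_sketch(attention_mask: torch.Tensor):
    """Illustrative stand-in for axolotl's get_unpad_data.

    Treating every distinct positive id as its own sequence keeps the
    varlen attention kernels from attending across packed-sample
    boundaries, even though those samples share one row.
    """
    # Length of each packed sub-sample, flattened in batch-major order.
    counts = torch.stack(
        [
            (attention_mask == i).sum(dim=-1)
            for i in range(1, int(attention_mask.max()) + 1)
        ],
        dim=-1,
    ).flatten()
    seqlens = counts[counts > 0].to(torch.int32)

    # Same contract as transformers' stock _get_unpad_data: flat indices of
    # non-padding tokens, cumulative lengths, and the longest length.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, seqlens.max().item()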