Mixtral fixes 20240124 (#1192) [skip ci]
* mixtral nccl fixes

* make sure to patch for z3
winglian committed Jan 24, 2024
1 parent af02430 commit 54d2ac1
Showing 14 changed files with 71 additions and 13 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -861,7 +861,7 @@ tokens:
fsdp:
fsdp_config:

-# Deepspeed config path. e.g., deepspeed/zero3.json
+# Deepspeed config path. e.g., deepspeed_configs/zero3.json
deepspeed:

# Advanced DDP Arguments
@@ -982,11 +982,11 @@ for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usa
We provide several default deepspeed JSON configurations for ZeRO stages 1, 2, and 3.

```yaml
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
```

```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
```
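For orientation, the sketch below shows roughly what a ZeRO stage 1 DeepSpeed JSON contains, written as a Python dict and dumped to disk. The field names come from DeepSpeed's documented config schema and the `"auto"` values follow the Hugging Face integration convention; this is only an illustration, and the actual files under `deepspeed_configs/` may set additional or different options.

```python
import json

# Illustrative ZeRO-1 style config; the shipped deepspeed_configs/zero1.json may differ.
zero1_config = {
    "zero_optimization": {"stage": 1, "overlap_comm": True},
    "bf16": {"enabled": "auto"},
    "fp16": {"enabled": "auto"},
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": False,
}

# Write it out under a throwaway name so we don't clobber the shipped config.
with open("my_zero1.json", "w", encoding="utf-8") as fout:
    json.dump(zero1_config, fout, indent=4)
```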

##### FSDP
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion examples/llama-2/fft_optimized.yml
@@ -62,7 +62,7 @@ evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
-deepspeed: #deepspeed/zero2.json # multi-gpu only
+deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
2 changes: 1 addition & 1 deletion examples/mistral/Mistral-7b-example/code.ipynb
@@ -942,7 +942,7 @@
"not only optimizer states but also gradients and parameters across GPUs. The bf16 indicate mixed precision training using bfloat16.\n",
"For more information read axolotl's readme\n",
"\"\"\"\n",
"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed/zero3_bf16.json"
"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed_configs/zero3_bf16.json"
]
}
],
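The cell above describes ZeRO stage 3 (sharding optimizer states, gradients, and parameters across GPUs) with bf16 mixed precision. As a rough illustration of what a zero3_bf16-style config toggles relative to stage 1, the dict below is a sketch built from documented DeepSpeed keys; the actual `deepspeed_configs/zero3_bf16.json` may differ.

```python
import json

# Sketch of ZeRO-3 + bf16 settings (illustrative; not the shipped file).
zero3_bf16_config = {
    "zero_optimization": {
        "stage": 3,  # shard optimizer states, gradients, and parameters
        "overlap_comm": True,
        "contiguous_gradients": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "bf16": {"enabled": True},  # bfloat16 mixed precision
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

print(json.dumps(zero3_bf16_config, indent=4))
```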
2 changes: 1 addition & 1 deletion examples/mistral/Mistral-7b-example/config.yml
@@ -65,7 +65,7 @@ eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
#default deepspeed, can use more aggressive if needed like zero2, zero3
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
weight_decay: 0.0
fsdp:
fsdp_config:
2 changes: 1 addition & 1 deletion examples/mistral/README.md
@@ -8,5 +8,5 @@ accelerate launch -m axolotl.cli.train examples/mistral/config.yml

If you run into CUDA OOM, use deepspeed with config zero2.json:
```shell
-accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
+accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json
```
2 changes: 1 addition & 1 deletion examples/mistral/mixtral.yml
@@ -84,7 +84,7 @@ eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
-deepspeed: deepspeed/zero2.json
+deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
2 changes: 1 addition & 1 deletion examples/phi/README.md
@@ -3,7 +3,7 @@
Due to some nuances in the phi code, please use deepspeed when training phi for a full finetune.

```shell
-accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json

# OR

51 changes: 50 additions & 1 deletion src/axolotl/monkeypatch/mixtral/__init__.py
@@ -1,12 +1,61 @@
"""
Patches to support multipack for mixtral
"""
import torch
import transformers

from axolotl.monkeypatch.utils import get_unpad_data


def replace_mixtral_attn_with_multipack_flash_attn():
def patch_mixtral_moe_forward_zero3() -> None:
import torch.nn.functional as F

def mlp_forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
hidden_states
)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states

# Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(
routing_weights, self.top_k, dim=-1, sorted=False
)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)

hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states) # pylint: disable=invalid-name
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
y = ( # pylint: disable=invalid-name
y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

from transformers.models.mixtral.modeling_mixtral import (
MixtralBLockSparseTop2MLP,
MixtralSparseMoeBlock,
)

MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward


def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if for_zero3:
patch_mixtral_moe_forward_zero3()
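The new `patch_mixtral_moe_forward_zero3` replaces the stock Mixtral MoE dispatch with a plain loop over experts (modeled on the DeepSeek-MoE code referenced above), so every expert module is invoked on each forward pass, possibly on an empty slice. Given the commit message ("make sure to patch for z3"), the likely motivation is that this keeps ZeRO-3 parameter gathering consistent across ranks even when the router assigns an expert no tokens. A minimal usage sketch, mirroring how `src/axolotl/utils/models.py` applies the patch in this commit:

```python
# Illustrative usage (mirrors the models.py change below): opt into the
# ZeRO-3-safe MoE forward only when DeepSpeed ZeRO-3 is actually enabled.
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn

replace_mixtral_attn_with_multipack_flash_attn(for_zero3=is_deepspeed_zero3_enabled())
```

Passing `for_zero3` directly is equivalent to the `mixtral_patch_kwargs` dict used in `load_model`, since the argument defaults to `False`.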
2 changes: 1 addition & 1 deletion src/axolotl/train.py
@@ -15,7 +15,7 @@
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
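Recent transformers releases moved `is_deepspeed_zero3_enabled` from `transformers.deepspeed` to `transformers.integrations.deepspeed`, which is why the import path changes here and in `models.py` below. If code ever needed to span both module layouts, a guarded import along these lines (a sketch, not part of the commit) would do it:

```python
# Compatibility shim (illustrative): prefer the new location, fall back to the
# old module path on transformers releases that predate the move.
try:
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
    from transformers.deepspeed import is_deepspeed_zero3_enabled
```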
13 changes: 11 additions & 2 deletions src/axolotl/utils/models.py
@@ -21,7 +21,7 @@
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -333,7 +333,10 @@ def load_model(
        )

        LOG.info("patching mixtral with flash attention")
-        replace_mixtral_attn_with_multipack_flash_attn()
+        mixtral_patch_kwargs = {}
+        if is_deepspeed_zero3_enabled():
+            mixtral_patch_kwargs["for_zero3"] = True
+        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)

    if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
        from axolotl.monkeypatch.falcon import (
Expand Down Expand Up @@ -646,6 +649,12 @@ def load_model(
    needs_fa2_dtype = cfg.adapter or cfg.fsdp
    skip_prepare_model_for_kbit_training = False

+    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+        from deepspeed.utils import set_z3_leaf_modules
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
        # Qwen doesn't play nicely with LoRA if this is enabled
        skip_prepare_model_for_kbit_training = True
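`set_z3_leaf_modules` is a DeepSpeed helper that marks a module class as a ZeRO-3 "leaf", so the block's parameters are handled as a single unit instead of being hooked individually. For a sparse MoE block, the intent is to avoid ZeRO-3 getting ranks out of sync when some experts receive no tokens on a given step. A standalone sketch of the same call outside axolotl (the model id and dtype are illustrative, and a DeepSpeed version that ships this helper is assumed):

```python
import torch
from deepspeed.utils import set_z3_leaf_modules  # helper used by this commit
from transformers import AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

# Illustrative checkpoint; loading the full model needs substantial memory.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1", torch_dtype=torch.bfloat16
)

# Treat each sparse-MoE block as a single ZeRO-3 unit.
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
```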
