Mixtral fixes 20240124 #1192

Merged 2 commits on Jan 24, 2024
README.md (6 changes: 3 additions & 3 deletions)
@@ -861,7 +861,7 @@ tokens:
fsdp:
fsdp_config:

-# Deepspeed config path. e.g., deepspeed/zero3.json
+# Deepspeed config path. e.g., deepspeed_configs/zero3.json
deepspeed:

# Advanced DDP Arguments
@@ -982,11 +982,11 @@ for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usa
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.

```yaml
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
```

```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
```

##### FSDP
File renamed without changes. (×4)
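
The four renames are presumably the DeepSpeed JSON configs (zero1, zero2, zero3, zero3_bf16) moving under `deepspeed_configs/`, matching the path updates above. As a quick sanity check of such a config (a sketch assuming the post-rename repo layout, not something this PR adds), the ZeRO stage can be read straight out of the JSON:

```python
# Illustrative only: assumes deepspeed_configs/zero1.json is present in the
# current working directory (e.g. when run from the axolotl repo root).
import json

with open("deepspeed_configs/zero1.json") as config_file:
    ds_config = json.load(config_file)

# DeepSpeed ZeRO configs keep the stage under zero_optimization.stage.
print(ds_config["zero_optimization"]["stage"])  # expected: 1
```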
examples/llama-2/fft_optimized.yml (2 changes: 1 addition & 1 deletion)
@@ -62,7 +62,7 @@ evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
-deepspeed: #deepspeed/zero2.json # multi-gpu only
+deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
examples/mistral/Mistral-7b-example/code.ipynb (2 changes: 1 addition & 1 deletion)
@@ -942,7 +942,7 @@
"not only optimizer states but also gradients and parameters across GPUs. The bf16 indicate mixed precision training using bfloat16.\n",
"For more information read axolotl's readme\n",
"\"\"\"\n",
"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed/zero3_bf16.json"
"!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed_configs/zero3_bf16.json"
]
}
],
examples/mistral/Mistral-7b-example/config.yml (2 changes: 1 addition & 1 deletion)
@@ -65,7 +65,7 @@ eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
#default deepspeed, can use more aggresive if needed like zero2, zero3
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
weight_decay: 0.0
fsdp:
fsdp_config:
examples/mistral/README.md (2 changes: 1 addition & 1 deletion)
@@ -8,5 +8,5 @@ accelerate launch -m axolotl.cli.train examples/mistral/config.yml

If you run into CUDA OOM, use deepspeed with config zero2.json:
```shell
-accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
+accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json
```
examples/mistral/mixtral.yml (2 changes: 1 addition & 1 deletion)
@@ -84,7 +84,7 @@ eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
-deepspeed: deepspeed/zero2.json
+deepspeed: deepspeed_configs/zero2.json
weight_decay: 0.0
fsdp:
fsdp_config:
examples/phi/README.md (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.

```shell
-accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json

# OR

src/axolotl/monkeypatch/mixtral/__init__.py (51 changes: 50 additions & 1 deletion)
@@ -1,12 +1,61 @@
"""
Patches to support multipack for mixtral
"""
+import torch
import transformers

from axolotl.monkeypatch.utils import get_unpad_data


-def replace_mixtral_attn_with_multipack_flash_attn():
+def patch_mixtral_moe_forward_zero3() -> None:
+    import torch.nn.functional as F
+
+    def mlp_forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+    # Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
+    def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(
+            routing_weights, self.top_k, dim=-1, sorted=False
+        )
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+        y = torch.empty_like(hidden_states)  # pylint: disable=invalid-name
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+        y = (  # pylint: disable=invalid-name
+            y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
+        ).sum(dim=1)
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+    from transformers.models.mixtral.modeling_mixtral import (
+        MixtralBLockSparseTop2MLP,
+        MixtralSparseMoeBlock,
+    )
+
+    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralSparseMoeBlock.forward = moe_forward
+
+
+def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
    transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
        get_unpad_data
    )
+    if for_zero3:
+        patch_mixtral_moe_forward_zero3()
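
The ZeRO-3 variant of the MoE forward replaces conditional expert dispatch with a dense loop: each token is duplicated `top_k` times, every expert is called on its (possibly empty) slice, and the outputs are recombined with the renormalized router weights, so every expert's parameters participate in every forward pass. The toy sketch below (a plain `nn.Linear` gate and experts with made-up sizes, not the real Mixtral modules) walks through the same routing math:

```python
# Minimal illustration of the dense top-k routing used by moe_forward above;
# all module names and sizes here are toy stand-ins.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, hidden_dim = 2, 4, 8
num_experts, top_k = 4, 2

gate = nn.Linear(hidden_dim, num_experts, bias=False)
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])

hidden_states = torch.randn(batch_size, seq_len, hidden_dim)
flat = hidden_states.view(-1, hidden_dim)              # (tokens, hidden)

router_logits = gate(flat)                             # (tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
topk_weight = topk_weight.to(flat.dtype)

# Duplicate each token once per selected expert, run every expert on its
# (possibly empty) slice, then mix the results with the router weights.
flat = flat.repeat_interleave(top_k, dim=0)            # (tokens * top_k, hidden)
out = torch.empty_like(flat)
flat_topk_idx = topk_idx.view(-1)
for i in range(num_experts):
    mask = flat_topk_idx == i
    out[mask] = experts[i](flat[mask])
out = (out.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final = out.reshape(batch_size, seq_len, hidden_dim)
print(final.shape)                                     # torch.Size([2, 4, 8])
```

The mixed expert outputs come back in the original `(batch, seq, hidden)` shape, just as `moe_forward` returns `final_hidden_states`; because each expert is invoked unconditionally inside the loop, ZeRO-3 sees a predictable forward pass for every expert's parameters.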
src/axolotl/train.py (2 changes: 1 addition & 1 deletion)
@@ -15,7 +15,7 @@
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
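
`is_deepspeed_zero3_enabled` moved from `transformers.deepspeed` to `transformers.integrations.deepspeed` in newer transformers releases, which is why the import path changes here. For code that has to straddle both versions, a minimal fallback (an illustration, not something this PR adds) could look like:

```python
# Hypothetical compatibility shim; axolotl itself simply targets the new path.
try:
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
except ImportError:  # older transformers still expose the original location
    from transformers.deepspeed import is_deepspeed_zero3_enabled

print(is_deepspeed_zero3_enabled())  # False unless a ZeRO-3 config is active
```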
src/axolotl/utils/models.py (13 changes: 11 additions & 2 deletions)
@@ -21,7 +21,7 @@
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -333,7 +333,10 @@ def load_model(
        )

        LOG.info("patching mixtral with flash attention")
-        replace_mixtral_attn_with_multipack_flash_attn()
+        mixtral_patch_kwargs = {}
+        if is_deepspeed_zero3_enabled():
+            mixtral_patch_kwargs["for_zero3"] = True
+        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)

    if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
        from axolotl.monkeypatch.falcon import (
@@ -646,6 +649,12 @@ def load_model(
    needs_fa2_dtype = cfg.adapter or cfg.fsdp
    skip_prepare_model_for_kbit_training = False

+    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+        from deepspeed.utils import set_z3_leaf_modules
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
        # Qwen doesn't play nicely with LoRA if this is enabled
        skip_prepare_model_for_kbit_training = True
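
`set_z3_leaf_modules` marks each `MixtralSparseMoeBlock` so that DeepSpeed ZeRO-3 gathers the block's parameters as a single leaf, which sidesteps problems when only a subset of experts runs in a given forward pass. A rough standalone sketch of the same call outside axolotl (the checkpoint name and launch setup are assumptions, and a real run needs a ZeRO-3 launcher and sharded memory for Mixtral):

```python
# Sketch only: mirrors what load_model now does when ZeRO-3 is detected.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Flag every sparse-MoE block as a ZeRO-3 leaf before training starts.
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
```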