Commit
fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728)
winglian committed Oct 14, 2023
1 parent 7f2027d commit 3553172
Showing 2 changed files with 11 additions and 5 deletions.
12 changes: 8 additions & 4 deletions src/axolotl/prompt_strategies/alpaca_chat.py
@@ -1,6 +1,6 @@
-"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
+"""Module for Alpaca prompt strategy classes"""
 
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple
 
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
@@ -9,9 +9,13 @@
 from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
-def load(tokenizer, cfg):
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    prompt_style = PromptStyle.CHAT.value
+    if ds_cfg and "conversation" in ds_cfg:
+        prompt_style = ds_cfg["conversation"]
+
     return AlpacaPromptTokenizingStrategy(
-        AlpacaPrompter(PromptStyle.CHAT.value),
+        AlpacaPrompter(prompt_style),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
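
The new ds_cfg argument lets a per-dataset config pick the conversation format (e.g. chatml) instead of the hard-coded chat prompt style. A minimal usage sketch, assuming axolotl is installed; the tokenizer and cfg values below are illustrative stand-ins, not part of this commit:

    # Hypothetical sketch: select the chatml conversation style for an alpaca
    # dataset via the new ds_cfg argument. "gpt2" and SimpleNamespace are
    # placeholder values, not axolotl's real setup.
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    from axolotl.prompt_strategies.alpaca_chat import load

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
    cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)  # stand-in config

    # "conversation" overrides the default PromptStyle.CHAT.value when present.
    strategy = load(tokenizer, cfg, ds_cfg={"conversation": "chatml"})
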
4 changes: 3 additions & 1 deletion src/axolotl/utils/trainer.py
@@ -423,7 +423,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     )
 
     # Phi doesn't want the attention_mask feature when training
-    if "CodeGenTokenizer" in tokenizer.__class__.__name__:
+    if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
+        cfg.is_mistral_derived_model and cfg.flash_attention
+    ):
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
             eval_dataset = eval_dataset.remove_columns("attention_mask")
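
With this change, mistral-derived models using flash attention drop the attention_mask column before packing, the same treatment already given to Phi's CodeGenTokenizer. A small self-contained sketch of what remove_columns does here, using toy data rather than real tokenizer output:

    # Toy demonstration (assumes the `datasets` library) of the column drop
    # performed by process_datasets_for_packing on the mistral + flash-attention
    # path; the values are illustrative, not real tokenized samples.
    from datasets import Dataset

    ds = Dataset.from_dict(
        {
            "input_ids": [[1, 2, 3]],
            "labels": [[1, 2, 3]],
            "attention_mask": [[1, 1, 1]],
        }
    )

    ds = ds.remove_columns("attention_mask")
    print(ds.column_names)  # ['input_ids', 'labels']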
