Support Sample packing for phi arch (axolotl-ai-cloud#586)
* phi sequence packing

* sample packing fixes

* fix linting

* fix inference and phi e2e tests

* update phi example now that sample packing works

* wandb import keeps getting moved around
winglian committed Sep 15, 2023
1 parent b3d6918 commit 67370a5
Showing 10 changed files with 1,138 additions and 23 deletions.
6 changes: 6 additions & 0 deletions .mypy.ini
@@ -8,6 +8,9 @@ ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True

+[mypy-axolotl.models.phi.*]
+ignore_errors = True

[mypy-flash_attn.*]
ignore_missing_imports = True

@@ -20,6 +23,9 @@ ignore_missing_imports = True
[mypy-peft]
ignore_missing_imports = True

+[mypy-wandb]
+ignore_missing_imports = True

[mypy-bitsandbytes]
ignore_missing_imports = True

8 changes: 4 additions & 4 deletions examples/phi/phi-ft.yml
@@ -1,6 +1,6 @@
base_model: microsoft/phi-1_5
base_model_config: microsoft/phi-1_5
-model_type: AutoModelForCausalLM
+model_type: MixFormerSequentialForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
trust_remote_code: true
@@ -18,7 +18,7 @@ val_set_size: 0.05
output_dir: ./phi-sft-out

sequence_len: 2048
-sample_packing: false # does not work with phi
+sample_packing: true
pad_to_sequence_len:

adapter:
@@ -35,10 +35,10 @@ wandb_watch:
wandb_run_id:
wandb_log_model:

-gradient_accumulation_steps: 2
+gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
-optimizer: adamw_bnb_8bit
+optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
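The key change in this example config is flipping sample_packing to true now that the phi architecture supports it. For context, sample packing concatenates several short tokenized examples into a single sequence_len-sized sequence so fewer tokens are spent on padding. The snippet below is only a minimal, generic sketch of the idea, not axolotl's multipack implementation; pack_examples and its signature are invented for illustration.

from typing import Dict, List


def _pad(ids: List[int], pos: List[int], sequence_len: int, pad_id: int) -> Dict[str, List[int]]:
    # Pad the partially filled pack out to the full sequence length.
    pad = sequence_len - len(ids)
    return {"input_ids": ids + [pad_id] * pad, "position_ids": pos + [0] * pad}


def pack_examples(
    tokenized: List[List[int]], sequence_len: int, pad_token_id: int
) -> List[Dict[str, List[int]]]:
    """Greedily pack tokenized examples into fixed-length sequences.

    position_ids restart at 0 for each original example, so attention and
    rotary embeddings can treat the packed samples as independent segments.
    """
    packed: List[Dict[str, List[int]]] = []
    input_ids: List[int] = []
    position_ids: List[int] = []
    for ids in tokenized:
        ids = ids[:sequence_len]  # never let a single example exceed the window
        if input_ids and len(input_ids) + len(ids) > sequence_len:
            packed.append(_pad(input_ids, position_ids, sequence_len, pad_token_id))
            input_ids, position_ids = [], []
        input_ids += ids
        position_ids += list(range(len(ids)))
    if input_ids:
        packed.append(_pad(input_ids, position_ids, sequence_len, pad_token_id))
    return packed

With the sequence_len: 2048 set above, this sketch would be called as pack_examples(tokenized_dataset, 2048, tokenizer.pad_token_id), collapsing many short examples into far fewer padded sequences.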
Empty file added src/axolotl/models/__init__.py
6 changes: 6 additions & 0 deletions src/axolotl/models/phi/__init__.py
@@ -0,0 +1,6 @@
"""
MixFormers model architecture used for phi models
"""

from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
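The new package re-exports the vendored phi classes under axolotl.models.phi. A rough usage sketch, assuming MixFormerSequentialForCausalLM follows the usual Hugging Face convention of being constructed from its config; the hyperparameter values below are simply the defaults from the config class added in this commit:

from axolotl.models.phi import (
    MixFormerSequentialConfig,
    MixFormerSequentialForCausalLM,
)

# Values mirror the defaults in configuration_mixformer_sequential.py;
# anything else here is purely illustrative.
config = MixFormerSequentialConfig(n_positions=2048, n_embd=1024, n_layer=20, n_head=16)
model = MixFormerSequentialForCausalLM(config)  # standard HF pattern: model built from its config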
63 changes: 63 additions & 0 deletions src/axolotl/models/phi/configuration_mixformer_sequential.py
@@ -0,0 +1,63 @@
# pylint: skip-file

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import math
from typing import Any, Dict, List, Optional, Union

from transformers import PretrainedConfig


class MixFormerSequentialConfig(PretrainedConfig):
"""MixFormer (sequential for DeepSpeed) configuration."""

model_type = "mixformer-sequential"

attribute_map = {
"max_position_embeddings": "n_positions",
"hidden_size": "n_embd",
"num_attention_heads": "n_head",
"num_hidden_layers": "n_layer",
"input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
"blocks": "architecture", # `blocks` key is for backward compatibility
}

def __init__(
self,
vocab_size: Optional[int] = 50304,
n_positions: Optional[int] = 2048,
n_embd: Optional[int] = 1024,
n_layer: Optional[int] = 20,
n_inner: Optional[int] = None,
n_head: Optional[int] = 16,
rotary_dim: Optional[int] = 32,
activation_function: Optional[str] = "gelu_new",
embd_layer: Optional[str] = "default",
architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
embd_pdrop: Optional[float] = 0.0,
resid_pdrop: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-5,
initializer_range: Optional[float] = 0.02,
tie_word_embeddings: Optional[bool] = False,
pad_vocab_size_multiple: Optional[int] = 64,
**kwargs
) -> None:
self.vocab_size = int(
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = min(rotary_dim, n_embd // n_head)
self.activation_function = activation_function
self.embd_layer = embd_layer
self.architecture = architecture
self.embd_pdrop = embd_pdrop
self.resid_pdrop = resid_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
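Note how the constructor rounds vocab_size up to the nearest multiple of pad_vocab_size_multiple and clamps rotary_dim to the per-head dimension. A quick worked check of that behavior, assuming the package imports shown above:

from axolotl.models.phi import MixFormerSequentialConfig

cfg = MixFormerSequentialConfig(vocab_size=50295)

# ceil(50295 / 64) * 64 == 786 * 64 == 50304
assert cfg.vocab_size == 50304

# rotary_dim = min(32, n_embd // n_head) = min(32, 1024 // 16) = 32
assert cfg.rotary_dim == 32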
