Adding support for Rotary Position Embeddings #675

Merged: 119 commits, Nov 6, 2023
Commits (119)
aa9509e
..
Oct 5, 2023
7354fcc
..
ShashankMosaicML Oct 5, 2023
6801142
..
ShashankMosaicML Oct 5, 2023
eff6270
..
ShashankMosaicML Oct 5, 2023
cdc6798
..
ShashankMosaicML Oct 5, 2023
a74afb4
..
ShashankMosaicML Oct 5, 2023
3c02585
..
ShashankMosaicML Oct 5, 2023
47f5af6
..
ShashankMosaicML Oct 5, 2023
3389f78
..
ShashankMosaicML Oct 5, 2023
c9f2154
..
ShashankMosaicML Oct 5, 2023
9db76a8
..
ShashankMosaicML Oct 5, 2023
722eb0c
..
ShashankMosaicML Oct 5, 2023
4eb9f17
..
ShashankMosaicML Oct 6, 2023
c675c22
..
ShashankMosaicML Oct 6, 2023
de765c4
..
ShashankMosaicML Oct 6, 2023
529ada8
..
ShashankMosaicML Oct 6, 2023
7d39ffc
..
ShashankMosaicML Oct 6, 2023
bb92769
..
ShashankMosaicML Oct 6, 2023
841becb
..
ShashankMosaicML Oct 6, 2023
e5d0e65
..
ShashankMosaicML Oct 6, 2023
dabd231
..
ShashankMosaicML Oct 7, 2023
7f1109a
..
ShashankMosaicML Oct 7, 2023
e98841d
removed the roformer implementation of rope
ShashankMosaicML Oct 8, 2023
dea3b03
..
ShashankMosaicML Oct 8, 2023
2927a8c
fixed all the lint errors
ShashankMosaicML Oct 8, 2023
d605fbf
..
ShashankMosaicML Oct 8, 2023
7b250f7
..
ShashankMosaicML Oct 9, 2023
196b8e1
../llmfoundry/models/mpt/modeling_mpt.py
ShashankMosaicML Oct 9, 2023
22212a1
Merge pull request #2 from mosaicml/main
ShashankMosaicML Oct 9, 2023
0c3942e
..
ShashankMosaicML Oct 10, 2023
829b2a4
..
ShashankMosaicML Oct 12, 2023
1629d1a
..
ShashankMosaicML Oct 12, 2023
eb658a3
added unit test to test rotary embeddings
ShashankMosaicML Oct 12, 2023
5cb95f6
..
ShashankMosaicML Oct 12, 2023
30aa448
..
ShashankMosaicML Oct 12, 2023
52119f5
..
ShashankMosaicML Oct 13, 2023
048b886
..
ShashankMosaicML Oct 13, 2023
c2ee0de
..
ShashankMosaicML Oct 13, 2023
9e8a1d6
..
ShashankMosaicML Oct 13, 2023
9c2a2a6
..
ShashankMosaicML Oct 13, 2023
6fa3037
..
ShashankMosaicML Oct 13, 2023
5323520
..
ShashankMosaicML Oct 13, 2023
b0960f7
Update llmfoundry/models/mpt/modeling_mpt.py
ShashankMosaicML Oct 13, 2023
f6632e1
incorporated some suggestions from the pr
ShashankMosaicML Oct 13, 2023
df749ae
..
ShashankMosaicML Oct 14, 2023
8b886ba
..
ShashankMosaicML Oct 17, 2023
76a2095
Merge pull request #3 from mosaicml/main
ShashankMosaicML Oct 17, 2023
dc58fc7
..
ShashankMosaicML Oct 17, 2023
ed5a477
..
ShashankMosaicML Oct 17, 2023
2a53de3
..
ShashankMosaicML Oct 19, 2023
34e147c
..
ShashankMosaicML Oct 19, 2023
0a9d3af
..
ShashankMosaicML Oct 19, 2023
5981ade
added mark for gpu in the rotary embedding test
ShashankMosaicML Oct 20, 2023
b179b07
Merge pull request #4 from mosaicml/main
ShashankMosaicML Oct 20, 2023
9afa082
..
ShashankMosaicML Oct 20, 2023
9835acd
..
ShashankMosaicML Oct 20, 2023
1801677
..
ShashankMosaicML Oct 21, 2023
0a38037
removed the code for hf implementation of rope
ShashankMosaicML Oct 23, 2023
d86a1a5
..
ShashankMosaicML Oct 23, 2023
7e336d2
..
ShashankMosaicML Oct 23, 2023
1897353
added tests
ShashankMosaicML Oct 24, 2023
213cd14
..
ShashankMosaicML Oct 24, 2023
68d03d3
..
ShashankMosaicML Oct 24, 2023
c0da75c
...
ShashankMosaicML Oct 24, 2023
4600415
..
ShashankMosaicML Oct 25, 2023
5ecda44
..
ShashankMosaicML Oct 25, 2023
6441180
..
ShashankMosaicML Oct 26, 2023
d71a2a0
..
ShashankMosaicML Oct 26, 2023
f952e5b
Merge pull request #5 from ShashankMosaicML/rotary_dail_imp
ShashankMosaicML Oct 26, 2023
f33ed5f
..
ShashankMosaicML Oct 26, 2023
07eafb7
Merge branch 'main' into rotary_hf_imp
ShashankMosaicML Oct 26, 2023
7efb6b1
fixed the tests after the merge
ShashankMosaicML Oct 26, 2023
3a056a8
minor change
ShashankMosaicML Oct 26, 2023
327dded
Fixed some tests failing due to a transformers library bug
ShashankMosaicML Oct 27, 2023
9c00106
added check for flash_attention before importing their rotary embedding
ShashankMosaicML Oct 27, 2023
999209c
added check for flash_attention in tests before using dail rope
ShashankMosaicML Oct 27, 2023
a681b64
fixed tests
ShashankMosaicML Oct 27, 2023
766fa75
..
ShashankMosaicML Oct 27, 2023
869be97
Merge pull request #7 from mosaicml/main
ShashankMosaicML Oct 27, 2023
21a4f31
..
ShashankMosaicML Oct 27, 2023
dbac1e0
temporary fix
ShashankMosaicML Oct 27, 2023
5d62dfe
..
ShashankMosaicML Oct 27, 2023
ca57151
..
ShashankMosaicML Oct 27, 2023
99a81a1
fixed a test
ShashankMosaicML Oct 27, 2023
b674e83
..
ShashankMosaicML Oct 27, 2023
8be09ab
minor change
ShashankMosaicML Oct 27, 2023
067439e
minor changes
ShashankMosaicML Oct 27, 2023
b325097
added documentation
ShashankMosaicML Oct 30, 2023
1b35c0b
added documentation
ShashankMosaicML Oct 30, 2023
3988b57
temp commit
ShashankMosaicML Oct 31, 2023
82ce2d9
made _set_config_defaults recursive
ShashankMosaicML Oct 31, 2023
d2930f9
minor changes
ShashankMosaicML Oct 31, 2023
b022b12
reformatted tutorial table
ShashankMosaicML Oct 31, 2023
dbc0f84
reformatted tutorial table
ShashankMosaicML Oct 31, 2023
54d8304
reformatted tutorial table
ShashankMosaicML Oct 31, 2023
c05d34d
added documentation on how to install flash attention 2
ShashankMosaicML Oct 31, 2023
c8f099d
minor changes
ShashankMosaicML Oct 31, 2023
a48afa2
minor changes
ShashankMosaicML Oct 31, 2023
fa46318
minor changes
ShashankMosaicML Oct 31, 2023
5b10164
minor changes
ShashankMosaicML Oct 31, 2023
64d4a57
minor changes
ShashankMosaicML Oct 31, 2023
04452e6
minor changes
ShashankMosaicML Oct 31, 2023
1eff648
..
ShashankMosaicML Oct 31, 2023
046fa08
Merge branch 'main' into rotary_hf_imp
ShashankMosaicML Oct 31, 2023
cceca07
Merge branch 'main' into rotary_hf_imp
ShashankMosaicML Nov 1, 2023
1e59de5
resolved some comments from the PR
ShashankMosaicML Nov 2, 2023
ac0fd40
fixed tests
ShashankMosaicML Nov 2, 2023
e59f784
modified is_flash_v2_installed
ShashankMosaicML Nov 3, 2023
c602a06
minor changes
ShashankMosaicML Nov 3, 2023
e59ec7e
Merge branch 'main' into rotary_hf_imp
ShashankMosaicML Nov 3, 2023
5744a3e
Update TUTORIAL.md
ShashankMosaicML Nov 4, 2023
0036ce7
Update TUTORIAL.md
ShashankMosaicML Nov 4, 2023
9ebd2e7
Update TUTORIAL.md
ShashankMosaicML Nov 4, 2023
4874713
Update TUTORIAL.md
ShashankMosaicML Nov 4, 2023
e0d8b75
resolved PR comments
ShashankMosaicML Nov 4, 2023
5ba6968
Merge branch 'main' into rotary_hf_imp
ShashankMosaicML Nov 6, 2023
6b680d7
Merge branch 'main' into rotary_hf_imp
ShashankMosaicML Nov 6, 2023
9204e31
Merge branch 'main' into rotary_hf_imp
ShashankMosaicML Nov 6, 2023
828d2bb
Merge branch 'main' into rotary_hf_imp
ShashankMosaicML Nov 6, 2023
49 changes: 38 additions & 11 deletions TUTORIAL.md
@@ -8,27 +8,42 @@ Forging LLMs can be quite complicated — you have to get your data prepared, se

This tutorial will provide a brief intro to the repo’s structure and underlying tools (all courtesy of MosaicML, of course), will go over a few example workflows and point you to the related resources within the repo, and will finally cover a number of FAQs that we have encountered since release.

- [LLM Foundry Tutorial](#llm-foundry-tutorial)
- [Intro](#intro)
- [How this repo is structured](#how-this-repo-is-structured)
- [Key components](#key-components)
- [Composer](#composer)
- [StreamingDataset](#streamingdataset)
- [MCLI](#mcli)
- [How the YAMLs work](#how-the-yamls-work)
- [Example Workflows](#example-workflows)
- [Workflow 1: I want to play with a HF model like MPT-7B locally](#workflow-1-i-want-to-play-with-a-hf-model-like-mpt-7b-locally)
- [Workflow 2: I want to deploy an inference endpoint with a HF model like MPT-7B](#workflow-2-i-want-to-deploy-an-inference-endpoint-with-a-hf-model-like-mpt-7b)
- [Workflow 3: I want to finetune a HF model like MPT-7B](#workflow-3-i-want-to-finetune-a-hf-model-like-mpt-7b)
- [Supervised FineTuning and Instruction FineTuning](#supervised-finetuning-and-instruction-finetuning)
- [Domain Adaptation and Sequence Length Adaptation](#domain-adaptation-and-sequence-length-adaptation)
- [Data](#data)
- [Modeling](#modeling)
- [Workflow 4: I want to train a new HF model from scratch](#workflow-4-i-want-to-train-a-new-hf-model-from-scratch)
- [FAQs](#faqs)
- [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus)
- [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do)
- [What hardware can I train on?](#what-hardware-can-i-train-on)
- [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on)
- [What is FSDP?](#what-is-fsdp)
- [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton-for-mpt-and-which-one-should-i-use)
- [Can I finetune using PEFT / LORA?](#can-i-finetune-using-peft--lora)
- [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu)
- [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer)
- [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms)
- [Common installation issues](#common-installation-issues)
- [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus)
- [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do)
- [What hardware can I train on?](#what-hardware-can-i-train-on)
- [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on)
- [What hardware can I run inference on?](#what-hardware-can-i-run-inference-on)
- [What is FSDP?](#what-is-fsdp)
- [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton--for-mpt-and-which-one-should-i-use)
- [Limitations](#limitations)
- [What is `triton-pre-mlir`?](#what-is-triton-pre-mlir)
- [Known issue with sm86+ GPUs](#known-issue-with-sm86-gpus)
- [Support for FlashAttention-2](#support-for-flashattention-2)
- [What kinds of positional embeddings does LLM Foundry support?](#what-kinds-of-positional-embeddings-does-llm-foundry-support)
- [Can I finetune using PEFT / LoRA?](#can-i-finetune-using-peft--lora)
- [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu)
- [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer)
- [TransformerEngine and amp\_fp8 support](#transformerengine-and-amp_fp8-support)
- [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms)
- [Common installation issues](#common-installation-issues)

Let’s get started!

@@ -328,6 +343,18 @@ The majority of our training setups use `triton`. -->
Updating to LLVM14 (or LLVM15) cannot be done because there are breaking changes.
What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance.

#### Support for FlashAttention-2
- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM Foundry supports FlashAttention-2. Please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention).
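
A quick way to sanity-check that a FlashAttention-2 install is visible from Python (a minimal sketch, not part of the original tutorial; it mirrors the `is_flash_v2_installed` helper added in `llmfoundry/models/layers/attention.py` later in this diff):

```python
# Minimal check for flash-attn v2 before enabling features that depend on it,
# such as rope_impl: dail below.
from packaging import version

try:
    import flash_attn
    has_flash_v2 = version.parse(flash_attn.__version__) >= version.parse('2.0.0')
except ImportError:
    has_flash_v2 = False

print(f'flash-attn v2 available: {has_flash_v2}')
```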

### What kinds of positional embeddings does LLM Foundry support?
Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get [No Positional Embedding](https://arxiv.org/pdf/2203.16634.pdf).

| Name | YAML Config | Training MFU on MPT-7B trained on 8 A100 80GB GPUs | Notes |
|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Learned Positional Embeddings | <pre>model:<br> learned_pos_emb:&nbsp;True</pre>| 65.7 | |
| ALiBi | <pre>model:<br> attn_config:<br> alibi:&nbsp;True</pre>| 64.5 | Requires Triton or Torch attention. |
| RoPE (Dao-AILab Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_impl:&nbsp;dail</pre>| 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch`, `triton`, or `flash`. |
| RoPE (Hugging&nbsp;Face Implementation) | <pre>model:<br> attn_config:<br> rope:&nbsp;True<br> rope_impl:&nbsp;hf</pre>| 62.3 | |
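
The same options can also be set programmatically. The snippet below is a hedged sketch rather than a verbatim recipe: it assumes the `MPTConfig` and `MPTForCausalLM` classes exported by `llmfoundry.models.mpt` (not shown in this diff), and the small model sizes are placeholders.

```python
# Hedged sketch: enabling RoPE via attn_config instead of a YAML file.
# Unspecified attention options fall back to attn_config_defaults,
# which this PR adds in llmfoundry/models/layers/blocks.py.
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

config = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    max_seq_len=2048,
    attn_config={
        'attn_impl': 'torch',  # 'triton' or 'flash' also work with RoPE
        'rope': True,
        'rope_theta': 10000,
        'rope_impl': 'hf',  # use 'dail' with flash-attn >= 2.0.1 on a CUDA GPU
    },
)
model = MPTForCausalLM(config)
```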

### Can I finetune using PEFT / LoRA?
- The LLM Foundry codebase does not directly have examples of PEFT or LoRA workflows. However, our MPT model is a subclass of HuggingFace `PreTrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added the required features to enable HuggingFace's [PEFT](https://huggingface.co/docs/peft/index) / [LoRA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run `scripts/train/train.py`, adding `lora` arguments to the config `.yaml`, like so:
71 changes: 58 additions & 13 deletions llmfoundry/models/layers/attention.py
@@ -5,7 +5,7 @@

import math
import warnings
from typing import Any, List, Optional, Tuple
from typing import Any, Optional

import torch
import torch.nn as nn
@@ -17,12 +17,13 @@
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


def is_flash_v2_installed():
def is_flash_v2_installed(v2_version: str = '2.0.0'):
    assert version.parse(v2_version) >= version.parse('2.0.0')
    try:
        import flash_attn as flash_attn
    except:
        return False
    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
    return version.parse(flash_attn.__version__) >= version.parse(v2_version)


def is_flash_v1_installed():
@@ -33,6 +34,16 @@ def is_flash_v1_installed():
    return version.parse(flash_attn.__version__) < version.parse('2.0.0')


# Before importing any transformers models, we need to disable transformers flash attention if
# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
# gated import otherwise.
if is_flash_v1_installed():
    import transformers
    transformers.utils.is_flash_attn_available = lambda: False

from transformers.models.llama.modeling_llama import apply_rotary_pos_emb


def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
original_is_causal: bool) -> bool:
# disable causal when it is not needed
@@ -70,7 +81,7 @@ def scaled_multihead_dot_product_attention(
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
Expand All @@ -79,7 +90,7 @@ def scaled_multihead_dot_product_attention(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:

if multiquery:
@@ -183,7 +194,7 @@ def scaled_multihead_dot_product_attention(


def check_valid_inputs(*tensors: torch.Tensor,
valid_dtypes: Optional[List[torch.dtype]] = None):
valid_dtypes: Optional[list[torch.dtype]] = None):
if valid_dtypes is None:
valid_dtypes = [torch.float16, torch.bfloat16]
for tensor in tensors:
@@ -199,7 +210,7 @@ def flash_attn_fn(
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
Expand All @@ -208,7 +219,7 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
@@ -337,7 +348,7 @@ def triton_flash_attn_fn(
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
Expand All @@ -346,7 +357,7 @@ def triton_flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
@@ -552,12 +563,13 @@ def __init__(
def forward(
self,
x: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)

@@ -581,6 +593,39 @@ def forward(
query = self.q_ln(query).to(dtype)
key = self.k_ln(key).to(dtype)

if rotary_emb_w_meta_info is not None:
    rotary_emb = rotary_emb_w_meta_info['rotary_emb']
    seq_len = rotary_emb_w_meta_info['seq_len']
    offset_info = rotary_emb_w_meta_info['offset_info']
    bsz, seqlen = query.shape[:2]
    query = query.view(bsz, seqlen, -1, self.head_dim)
    key = key.view(bsz, seqlen, -1, self.head_dim)

    if rotary_emb_w_meta_info['impl'] == 'dail':
        value = value.view(bsz, seqlen, -1, self.head_dim)

        kv = torch.stack([key, value], dim=2)
        query, kv = rotary_emb(query,
                               kv,
                               seqlen_offset=offset_info,
                               max_seqlen=seq_len)
        [key, value] = torch.unbind(kv, dim=2)

        value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
    elif rotary_emb_w_meta_info['impl'] == 'hf':
        (cos, sin) = rotary_emb(value, seq_len)
        # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        query, key = apply_rotary_pos_emb(query, key, cos, sin,
                                          offset_info)
        # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)

    query = query.view(bsz, seqlen, self.d_model)
    key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)

context, attn_weights, past_key_value = self.attn_fn(
query,
key,
@@ -677,7 +722,7 @@ def __init__(
def attn_bias_shape(
attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
prefix_lm: bool, causal: bool,
use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
if attn_impl == 'flash':
return None
elif attn_impl in ['torch', 'triton']:
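
The `forward()` changes above read a `rotary_emb_w_meta_info` dict with keys `impl`, `rotary_emb`, `offset_info`, and `seq_len`; the caller that builds it lives in `modeling_mpt.py`, which is not shown in this excerpt. The sketch below is a hypothetical reconstruction of the `hf` path only, assuming a transformers release contemporaneous with this PR, where `LlamaRotaryEmbedding(dim, max_position_embeddings, base)` returns `(cos, sin)` and `apply_rotary_pos_emb` accepts `position_ids`.

```python
# Hypothetical sketch (not the PR's modeling_mpt.py code): building the
# rotary_emb_w_meta_info dict consumed by the attention forward() above,
# using the Hugging Face ('hf') rotary implementation.
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding


def build_hf_rope_meta_info(head_dim: int,
                            seq_len: int,
                            rope_theta: int = 10000,
                            device: str = 'cpu') -> dict:
    rotary_emb = LlamaRotaryEmbedding(head_dim,
                                      max_position_embeddings=seq_len,
                                      base=rope_theta,
                                      device=device)
    # For impl == 'hf', offset_info is the position_ids tensor that forward()
    # hands to apply_rotary_pos_emb.
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
    return {
        'impl': 'hf',
        'rotary_emb': rotary_emb,
        'offset_info': position_ids,
        'seq_len': seq_len,
    }
```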
43 changes: 30 additions & 13 deletions llmfoundry/models/layers/blocks.py
@@ -12,6 +12,31 @@
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

attn_config_defaults: Dict = {
    'attn_type': 'multihead_attention',
    'attn_pdrop': 0.0,
    'attn_impl': 'triton',
    'qk_ln': False,
    'clip_qkv': None,
    'softmax_scale': None,
    'prefix_lm': False,
    'attn_uses_sequence_id': False,
    'alibi': False,
    'alibi_bias_max': 8,
    'rope': False,
    'rope_theta': 10000,
    'rope_impl': 'dail',
    'rope_dail_config': {
        'type': 'original',
        'pos_idx_in_fp32': True,
        'xpos_scale_base': 512,
    },
    'rope_hf_config': {
        'type': 'no_scaling',
        'factor': 1.0,
    },
}


class MPTBlock(nn.Module):

@@ -30,18 +55,7 @@ def __init__(
**kwargs: Any,
):
if attn_config is None:
attn_config = {
'attn_type': 'multihead_attention',
'attn_pdrop': 0.0,
'attn_impl': 'triton',
'qk_ln': False,
'clip_qkv': None,
'softmax_scale': None,
'prefix_lm': False,
'attn_uses_sequence_id': False,
'alibi': False,
'alibi_bias_max': 8,
}
attn_config = attn_config_defaults

if ffn_config is None:
ffn_config = {
@@ -58,7 +72,8 @@ def __init__(
# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id',
'alibi_bias_max'
'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl',
'rope_dail_config', 'rope_hf_config'
}
attn_config_subset_for_attn_class = {
k: v
@@ -94,6 +109,7 @@ def forward(
x: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
rotary_emb_w_meta_info: Optional[Dict] = None,
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
@@ -104,6 +120,7 @@ def forward(
a,
past_key_value=past_key_value,
attn_bias=attn_bias,
rotary_emb_w_meta_info=rotary_emb_w_meta_info,
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
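
A note on the new nested defaults: a user config only needs to override the keys it changes, including keys nested under `rope_dail_config` and `rope_hf_config`; the PR's history mentions that `_set_config_defaults` was made recursive for this. The snippet below is an illustrative stand-in for that behavior, not llmfoundry code.

```python
# Illustrative recursive-defaults merge (a stand-in, not llmfoundry's
# _set_config_defaults): user keys win, nested dicts are merged key by key.
import copy

from llmfoundry.models.layers.blocks import attn_config_defaults


def merge_with_defaults(user_cfg: dict, defaults: dict) -> dict:
    merged = copy.deepcopy(defaults)
    for key, value in user_cfg.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_with_defaults(value, merged[key])
        else:
            merged[key] = value
    return merged


user_attn_config = {
    'rope': True,
    'rope_impl': 'hf',
    'rope_hf_config': {'type': 'linear', 'factor': 2.0},
}
attn_config = merge_with_defaults(user_attn_config, attn_config_defaults)
assert attn_config['rope_hf_config']['factor'] == 2.0
assert attn_config['alibi'] is False  # untouched default
```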