create cos and sin in each decoder layer and check_sft working on todi
TJ-Solergibert committed Sep 17, 2024
1 parent f3bf21d commit 06553df
Showing 2 changed files with 52 additions and 24 deletions.
src/nanotron/models/llama_sft.py (72 changes: 50 additions & 22 deletions)
@@ -25,7 +25,6 @@
 from nanotron.config import Config, LlamaConfig, ParallelismArgs
 from nanotron.config.models_config import RandomInit, SpectralMupInit
 from nanotron.generation.generate_store import AttachableStore
-from nanotron.kernels.rope import liger_rotary_pos_emb
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
 from nanotron.nn.activations import ACT2FN
@@ -54,8 +53,7 @@
 ## support the position ids necessary for the cross attention feature. The cos & ##
 ## sin are created in the embedding layer and propagated through the pipeline so ##
 ## we don't have a RoPE layer in each and every decoder layer. Then in each and ##
-## every decoder layer we apply the cos & sin to Q & K with `liger_rotary_pos_emb`##
-## from linkedin/Liger-Kernel ##
+## every decoder layer we apply the cos & sin to Q & K with `apply_rotary_pos_emb`##
 ####################################################################################

 # NOTE(tj.solergibert) Copied from https://github.com/huggingface/transformers/blob/81233c069c166af033794134bd8888783ac49ebe/src/transformers/modeling_rope_utils.py#L29
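For reference, the default rope init pointed to by the NOTE above boils down to the standard inverse-frequency schedule. A minimal sketch of that computation (illustrative names, not the vendored code):

import torch

def default_rope_inv_freq(head_dim: int, base: float = 10000.0) -> torch.Tensor:
    # One frequency per pair of hidden dims: inv_freq[i] = 1 / base**(2i / head_dim)
    return 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))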
@@ -113,6 +111,39 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# NOTE(tj.solergibert) FlashAttention RoPEs are faster (triton), but currently they don't support position_ids
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (torch.Tensor): The query tensor.
+        k (torch.Tensor): The key tensor.
+        cos (torch.Tensor): The cosine part of the rotary embedding.
+        sin (torch.Tensor): The sine part of the rotary embedding.
+        unsqueeze_dim (int, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        tuple (torch.Tensor) comprising the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 def prepare_varlen_args(position_ids):
     position_ids = position_ids.flatten()
     indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
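Only the first lines of prepare_varlen_args fall inside this hunk. For the varlen attention path it also has to produce cu_seqlens and max_seqlen_in_batch from the packed position_ids; a sketch of how that is typically derived, assuming every packed sequence restarts its positions at 0 (helper name and details are illustrative, not this file's code):

import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor):
    # position_ids: [batch_size, seq_length], where each packed sequence restarts at 0
    flat = position_ids.flatten()
    seq_starts = torch.nonzero(flat == 0, as_tuple=True)[0].to(torch.int32)
    total = torch.tensor([flat.numel()], device=flat.device, dtype=torch.int32)
    cu_seqlens = torch.cat([seq_starts, total])  # cumulative sequence boundaries for flash-attn varlen
    max_seqlen_in_batch = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    return cu_seqlens, max_seqlen_in_batch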
@@ -305,8 +336,11 @@ def forward(
             .contiguous()
         )  # [3, batch_size, seq_length, n_local_q_heads, d_qk]

+        # TODO(tj.solergibert) Apply RoPE embeddings WITHOUT too many transpose...
+        query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2)
         # Apply RoPE
-        query_states, key_states = liger_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2)

         # Prepare varlen args
         cu_seqlens, max_seqlen_in_batch = prepare_varlen_args(position_ids)
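The TODO above asks for a way to apply RoPE without the two transposes. The apply_rotary_pos_emb docstring already hints at one option: keep q and k in [batch_size, seq_len, heads, head_dim] and broadcast cos and sin over the heads dimension with unsqueeze_dim=2. A quick equivalence check with toy tensors, assuming the apply_rotary_pos_emb defined above is in scope:

import torch

torch.manual_seed(0)
batch, seq, heads, head_dim = 2, 16, 4, 8
q = torch.randn(batch, seq, heads, head_dim)
k = torch.randn(batch, seq, heads, head_dim)
cos = torch.randn(batch, seq, head_dim)
sin = torch.randn(batch, seq, head_dim)

# Current path: transpose to [batch, heads, seq, head_dim], rotate, transpose back
q_a, k_a = apply_rotary_pos_emb(q.transpose(1, 2), k.transpose(1, 2), cos, sin)
q_a, k_a = q_a.transpose(1, 2), k_a.transpose(1, 2)

# Transpose-free path: broadcast cos/sin over the heads dimension directly
q_b, k_b = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)

assert torch.allclose(q_a, q_b) and torch.allclose(k_a, k_b)

Whether this is actually faster than the two transposes (which are views) would need profiling; the sketch only shows that the numerics match.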
@@ -345,6 +379,10 @@ def __init__(
         layer_idx: int,
     ):
         super().__init__()
+
+        # NOTE(tj.solergibert) SFT
+        self.position_embedding = LlamaRotaryEmbedding(config=config)
+
         self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.attn = CausalSelfAttention(
             config=config,
@@ -360,10 +398,12 @@ def forward(
         self,
         hidden_states: Union[torch.Tensor, TensorPointer],
         position_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
-        cos: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
-        sin: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
     ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:

+        cos, sin = self.position_embedding(
+            hidden_states, position_ids
+        )  # TODO(tj.solergibert) We only need the device type from hidden_states
+
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

@@ -378,8 +418,6 @@ def forward(
         return {
             "hidden_states": hidden_states,
             "position_ids": position_ids,
-            "cos": cos,
-            "sin": sin,
         }


@@ -395,22 +433,12 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_confi
         )
         self.pg = tp_pg

-        # NOTE(tj.solergibert) SFT
-        self.position_embedding = LlamaRotaryEmbedding(config=config)
-
     def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor):  # [batch_size, seq_length]
         # Format input in `[seq_length, batch_size]` to support high TP with low batch_size
         input_ids = input_ids.transpose(0, 1)
         input_embeds = self.token_embedding(input_ids)

-        # NOTE(tj.solergibert) We create the cos & sin and propagate them through the pipeline so we
-        # don't have to create the LlamaRotaryEmbedding layer in each and every decoder layer
-        # We will still send the position ids for the varlen
-        cos, sin = self.position_embedding(
-            input_embeds, position_ids
-        )  # TODO(tj.solergibert) We just need from inputs_ids the device type
-
-        return {"input_embeds": input_embeds, "position_ids": position_ids, "cos": cos, "sin": sin}
+        return {"input_embeds": input_embeds, "position_ids": position_ids}


 class LlamaModel(nn.Module):
@@ -443,7 +471,7 @@ def __init__(
                 "parallel_config": parallel_config,
             },
             module_input_keys={"input_ids", "position_ids"},
-            module_output_keys={"input_embeds", "position_ids", "cos", "sin"},
+            module_output_keys={"input_embeds", "position_ids"},
         )

         self.decoder = nn.ModuleList(
@@ -457,8 +485,8 @@ def __init__(
                         "tp_pg": parallel_context.tp_pg,
                         "layer_idx": layer_idx,
                     },
-                    module_input_keys={"hidden_states", "position_ids", "cos", "sin"},
-                    module_output_keys={"hidden_states", "position_ids", "cos", "sin"},
+                    module_input_keys={"hidden_states", "position_ids"},
+                    module_output_keys={"hidden_states", "position_ids"},
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]
tools/check_sft.py (4 changes: 2 additions & 2 deletions)
@@ -19,7 +19,7 @@

 dtype = torch.bfloat16
 device = torch.device("cuda")
-PATH_TO_LLAMA = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+PATH_TO_LLAMA = "/store/swissai/a06/models/Meta-Llama-3.1-8B-Instruct"

 # NOTE(tj.solergibert) This script is for testing purposes. ONLY use 1 GPU
 DP = 1
@@ -199,7 +199,7 @@ def main():

     # Create ChatDataloaders
     train_dataset = ChatDataset(
-        dataset_path="Magpie-Align/Magpie-Pro-300K-Filtered",  # "Open-Orca/SlimOrca",
+        dataset_path="/store/swissai/a06/datasets_raw/Magpie-Pro-300K-Filtered",  # "Open-Orca/SlimOrca",
         tokenizer_name_or_path=PATH_TO_LLAMA,
         sequence_length=2048,
         train_on_completions_only=False,
