diff --git a/src/nanotron/models/llama_sft.py b/src/nanotron/models/llama_sft.py
index 2cd12eb7..21657869 100644
--- a/src/nanotron/models/llama_sft.py
+++ b/src/nanotron/models/llama_sft.py
@@ -25,7 +25,6 @@
 from nanotron.config import Config, LlamaConfig, ParallelismArgs
 from nanotron.config.models_config import RandomInit, SpectralMupInit
 from nanotron.generation.generate_store import AttachableStore
-from nanotron.kernels.rope import liger_rotary_pos_emb
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
 from nanotron.nn.activations import ACT2FN
@@ -54,8 +53,7 @@
 ## support the position ids necessary for the cross attention feature. The cos &  ##
 ## sin are created in the embedding layer and propagated through the pipeline so  ##
 ## we don't have a RoPE layer in each and every decoder layer. Then in each and   ##
-## every decoder layer we apply the cos & sin to Q & K with `liger_rotary_pos_emb`##
-## from linkedin/Liger-Kernel                                                      ##
+## every decoder layer we apply the cos & sin to Q & K with `apply_rotary_pos_emb`##
 ####################################################################################

 # NOTE(tj.solergibert) Copied from https://github.com/huggingface/transformers/blob/81233c069c166af033794134bd8888783ac49ebe/src/transformers/modeling_rope_utils.py#L29
@@ -113,6 +111,39 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+# NOTE(tj.solergibert) FlashAttention RoPEs are faster (triton), but currently they don't support position_ids
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (torch.Tensor): The query tensor.
+        k (torch.Tensor): The key tensor.
+        cos (torch.Tensor): The cosine part of the rotary embedding.
+        sin (torch.Tensor): The sine part of the rotary embedding.
+        unsqueeze_dim (int, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        tuple(torch.Tensor) comprising the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 def prepare_varlen_args(position_ids):
     position_ids = position_ids.flatten()
     indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
@@ -305,8 +336,11 @@ def forward(
             .contiguous()
         )  # [3, batch_size, seq_length, n_local_q_heads, d_qk]

+        # TODO(tj.solergibert) Apply RoPE embeddings WITHOUT so many transposes...
+        query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2)
         # Apply RoPE
-        query_states, key_states = liger_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2)

         # Prepare varlen args
         cu_seqlens, max_seqlen_in_batch = prepare_varlen_args(position_ids)
@@ -345,6 +379,10 @@ def __init__(
         layer_idx: int,
     ):
         super().__init__()
+
+        # NOTE(tj.solergibert) SFT
+        self.position_embedding = LlamaRotaryEmbedding(config=config)
+
         self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.attn = CausalSelfAttention(
             config=config,
@@ -360,10 +398,12 @@ def forward(
         self,
         hidden_states: Union[torch.Tensor, TensorPointer],
         position_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
-        cos: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
-        sin: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
     ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
+        cos, sin = self.position_embedding(
+            hidden_states, position_ids
+        )  # TODO(tj.solergibert) We only need the device & dtype from hidden_states
+
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -378,8 +418,6 @@ def forward(
         return {
             "hidden_states": hidden_states,
             "position_ids": position_ids,
-            "cos": cos,
-            "sin": sin,
         }


@@ -395,22 +433,12 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_confi
         )
         self.pg = tp_pg

-        # NOTE(tj.solergibert) SFT
-        self.position_embedding = LlamaRotaryEmbedding(config=config)
-
     def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor):  # [batch_size, seq_length]
         # Format input in `[seq_length, batch_size]` to support high TP with low batch_size
         input_ids = input_ids.transpose(0, 1)
         input_embeds = self.token_embedding(input_ids)
-        # NOTE(tj.solergibert) We create the cos & sin and propagate them through the pipeline so we
-        # don't have to create the LlamaRotaryEmbedding layer in each and every decoder layer
-        # We will still send the position ids for the varlen
-        cos, sin = self.position_embedding(
-            input_embeds, position_ids
-        )  # TODO(tj.solergibert) We just need from inputs_ids the device type
-
-        return {"input_embeds": input_embeds, "position_ids": position_ids, "cos": cos, "sin": sin}
+        return {"input_embeds": input_embeds, "position_ids": position_ids}


 class LlamaModel(nn.Module):
@@ -443,7 +471,7 @@ def __init__(
                 "parallel_config": parallel_config,
             },
             module_input_keys={"input_ids", "position_ids"},
-            module_output_keys={"input_embeds", "position_ids", "cos", "sin"},
+            module_output_keys={"input_embeds", "position_ids"},
         )

         self.decoder = nn.ModuleList(
@@ -457,8 +485,8 @@ def __init__(
                         "tp_pg": parallel_context.tp_pg,
                         "layer_idx": layer_idx,
                     },
-                    module_input_keys={"hidden_states", "position_ids", "cos", "sin"},
-                    module_output_keys={"hidden_states", "position_ids", "cos", "sin"},
+                    module_input_keys={"hidden_states", "position_ids"},
+                    module_output_keys={"hidden_states", "position_ids"},
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]
diff --git a/tools/check_sft.py b/tools/check_sft.py
index 2f4d68f4..9670f6f5 100644
--- a/tools/check_sft.py
+++ b/tools/check_sft.py
@@ -19,7 +19,7 @@
 dtype = torch.bfloat16
 device = torch.device("cuda")

-PATH_TO_LLAMA = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+PATH_TO_LLAMA = "/store/swissai/a06/models/Meta-Llama-3.1-8B-Instruct"

 # NOTE(tj.solergibert) This script is for testing purposes. ONLY use 1 GPU
 DP = 1
@@ -199,7 +199,7 @@ def main():

     # Create ChatDataloaders
     train_dataset = ChatDataset(
-        dataset_path="Magpie-Align/Magpie-Pro-300K-Filtered",  # "Open-Orca/SlimOrca",
+        dataset_path="/store/swissai/a06/datasets_raw/Magpie-Pro-300K-Filtered",  # "Open-Orca/SlimOrca",
         tokenizer_name_or_path=PATH_TO_LLAMA,
         sequence_length=2048,
         train_on_completions_only=False,
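
Follow-up note (not part of the patch) on the TODO about applying RoPE without the extra transposes: the docstring of the new `apply_rotary_pos_emb` already covers this case, since `unsqueeze_dim=2` lets the rotation be applied directly to `[batch_size, seq_length, heads, head_dim]` tensors. Below is a minimal, self-contained sketch of that equivalence; the helper bodies are copied from the hunk above, while the toy shapes and the rope theta of 10000.0 are illustrative assumptions rather than the exact values `LlamaRotaryEmbedding` uses.

# Hypothetical sanity check, independent of the repo: shows that applying the RoPE
# helper with unsqueeze_dim=2 on [batch, seq, heads, head_dim] tensors matches the
# transpose -> apply(unsqueeze_dim=1) -> transpose path used in CausalSelfAttention.
import torch


def rotate_half(x):
    # Same helper as in the patch: rotates half the hidden dims of the input.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Same helper as in the patch (docstring omitted).
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


batch, seq, heads, head_dim = 2, 16, 4, 64  # assumed toy shapes
q = torch.randn(batch, seq, heads, head_dim)
k = torch.randn(batch, seq, heads, head_dim)

# Build cos/sin the usual RoPE way: positions times inverse frequencies, duplicated
# along the last dim -> [batch, seq, head_dim]. theta=10000.0 is an assumption.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.arange(seq).float()[:, None] * inv_freq[None, :]  # [seq, head_dim / 2]
emb = torch.cat((freqs, freqs), dim=-1).expand(batch, -1, -1)   # [batch, seq, head_dim]
cos, sin = emb.cos(), emb.sin()

# Path A: the patch's transpose -> RoPE -> transpose dance over [batch, heads, seq, head_dim].
qa, ka = apply_rotary_pos_emb(q.transpose(1, 2), k.transpose(1, 2), cos, sin, unsqueeze_dim=1)
qa, ka = qa.transpose(1, 2), ka.transpose(1, 2)

# Path B: apply directly on [batch, seq, heads, head_dim] with unsqueeze_dim=2.
qb, kb = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)

assert torch.allclose(qa, qb) and torch.allclose(ka, kb)

If this also holds with the cos & sin tensors produced by `LlamaRotaryEmbedding`, the transpose pair around the RoPE call could likely be dropped.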