diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 41ee21f5..f9b311a1 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -106,6 +106,8 @@ def setup_model(args, tokenizer, train_loader, grad_accum): quantization_config=bnb_config, ) else: + from .mods.models import hf as mods_models_hf + mods_models_hf.inject_padding_free_fa2() model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, attn_implementation="flash_attention_2", @@ -499,7 +501,7 @@ def main(args): avg_sample_len=dataset.get_lengths().mean(), effective_batch_size=args.effective_batch_size, max_batch_len_per_gpu=args.max_batch_len, - is_padding=not args.is_granite, + is_padding=not (args.is_granite or args.flatten_with_posid), dataset=dataset, pad_id=tokenizer.pad_token_id, seed=args.seed, @@ -518,6 +520,7 @@ def main(args): samples_per_gpu=args.samples_per_gpu, sampler=args.sampler, seed=args.seed, + flatten_with_posid=args.flatten_with_posid ) if args.local_rank == 0: @@ -744,6 +747,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs): os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py" ), ) + parser.add_argument("--flatten_with_posid", action="store_true", default=False) args = parser.parse_args() set_random_seed(args.seed) main(args) diff --git a/src/instructlab/training/mods/__init__.py b/src/instructlab/training/mods/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/instructlab/training/mods/models/__init__.py b/src/instructlab/training/mods/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/instructlab/training/mods/models/hf/__init__.py b/src/instructlab/training/mods/models/hf/__init__.py new file mode 100644 index 00000000..417acab2 --- /dev/null +++ b/src/instructlab/training/mods/models/hf/__init__.py @@ -0,0 +1,6 @@ +from . import llama +from . import mistral + +def inject_padding_free_fa2(): + llama.inject_padding_free_fa2() + mistral.inject_padding_free_fa2() \ No newline at end of file diff --git a/src/instructlab/training/mods/models/hf/llama.py b/src/instructlab/training/mods/models/hf/llama.py new file mode 100644 index 00000000..463395fd --- /dev/null +++ b/src/instructlab/training/mods/models/hf/llama.py @@ -0,0 +1,214 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
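The `inject_padding_free_fa2()` call added to `setup_model` above has to run before `AutoModelForCausalLM.from_pretrained`, because transformers (at the version this patch targets) resolves the attention class from `LLAMA_ATTENTION_CLASSES` / `MISTRAL_ATTENTION_CLASSES` when the model is constructed. A minimal sketch of that ordering, assuming the package is importable as `instructlab.training`; the checkpoint name and dtype below are placeholders, not values taken from this patch:

import torch
from transformers import AutoModelForCausalLM

from instructlab.training.mods.models import hf as mods_models_hf

# Swap the registered flash_attention_2 classes for the padding-free variants
# *before* the model is built, so the patched attention is picked up.
mods_models_hf.inject_padding_free_fa2()

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",          # placeholder checkpoint, for illustration only
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)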
+ +from typing import Optional, Tuple + +import torch +from transformers.models.llama.modeling_llama import ( + LLAMA_ATTENTION_CLASSES, + apply_rotary_pos_emb, + LlamaFlashAttention2 +) +from transformers.cache_utils import Cache, StaticCache +from transformers.utils import logging, is_flash_attn_2_available +logger = logging.get_logger(__name__) + +from .utils import prepare_fa2_from_position_ids + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input # noqa + +class PaddingFreeLlamaFlashAttention2(LlamaFlashAttention2): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, position_ids=position_ids + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, position_ids=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Used to compute cu_seqlen + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if (position_ids[:,-1]==position_ids.size(1)-1).all(): + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + else: + logger.warning_once("Using PaddingFree FlashAttention2!") + batch_size = query_states.size(0) + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + return attn_output + +def inject_padding_free_fa2(): + logger.warning_once("Injecting PaddingFree to LlamaFlashAttention2...") + LLAMA_ATTENTION_CLASSES["flash_attention_2"] = PaddingFreeLlamaFlashAttention2 \ No newline at end of file diff --git a/src/instructlab/training/mods/models/hf/mistral.py b/src/instructlab/training/mods/models/hf/mistral.py new file mode 100644 index 00000000..4715eb5d --- /dev/null +++ b/src/instructlab/training/mods/models/hf/mistral.py @@ -0,0 +1,338 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mistral model.""" + +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +import torch +from transformers.models.mistral.modeling_mistral import ( + MISTRAL_ATTENTION_CLASSES, + apply_rotary_pos_emb, + MistralFlashAttention2, + repeat_kv +) +from transformers.cache_utils import Cache, StaticCache +from transformers.utils import logging, is_flash_attn_2_available +logger = logging.get_logger(__name__) + +from .utils import prepare_fa2_from_position_ids + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input # noqa + import inspect + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + +class PaddingFreeMistralFlashAttention2(MistralFlashAttention2): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += cache_position[0] + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." 
+ ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + position_ids=position_ids + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + position_ids=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
+ + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Used to compute cu_seqlen + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if (position_ids[:,-1]==position_ids.size(1)-1).all(): + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + else: + logger.warning_once("Using PaddingFree FlashAttention2!") + batch_size = query_states.size(0) + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + 
attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + return attn_output + +def inject_padding_free_fa2(): + logger.warning_once("Injecting PaddingFree to MistralFlashAttention2...") + MISTRAL_ATTENTION_CLASSES["flash_attention_2"] = PaddingFreeMistralFlashAttention2 \ No newline at end of file diff --git a/src/instructlab/training/mods/models/hf/utils.py b/src/instructlab/training/mods/models/hf/utils.py new file mode 100644 index 00000000..60050a54 --- /dev/null +++ b/src/instructlab/training/mods/models/hf/utils.py @@ -0,0 +1,14 @@ +import torch + +def prepare_fa2_from_position_ids(query, key, value, position_ids, query_length): + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + cu_seq_lens = torch.cat(( + indices_q[position_ids==0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32) + )) + max_length = position_ids.max()+1 + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) \ No newline at end of file diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py index d4acf22d..9cbd21b7 100644 --- a/src/instructlab/training/token_dataset.py +++ b/src/instructlab/training/token_dataset.py @@ -87,15 +87,17 @@ def setup_dataloader( packing_max_batch_len=60000, samples_per_gpu=None, sampler="multipack", + flatten_with_posid=False, seed=47, ) -> DataLoader: collate_fn = make_collate_fn( - pad_token_id, is_granite=is_granite, max_batch_len=max_batch_len + pad_token_id, is_granite=is_granite, max_batch_len=max_batch_len, flatten_with_posid=flatten_with_posid ) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) lengths = dataset.get_lengths() + if sampler == "multipack": sampler = MultipackDistributedBatchSampler( batch_max_length=packing_max_batch_len, @@ -103,7 +105,7 @@ def setup_dataloader( num_replicas=world_size, rank=rank, seed=seed, - padding=not is_granite, + padding=not (is_granite or flatten_with_posid), ) sampler = {"batch_sampler": sampler} elif sampler == "distributed": diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index cf86237f..d2a87438 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -104,10 +104,10 @@ def __init__(self, *args, **kwargs): break -def make_collate_fn(pad_token_id, is_granite=False, max_batch_len=60000): +def make_collate_fn(pad_token_id, is_granite=False, flatten_with_posid=False, max_batch_len=60000): rank = int(os.environ["RANK"]) if is_granite: - + assert not flatten_with_posid def pad_collate_fn(batch): lens = np.array([len(item["input_ids"]) for item in batch]) @@ -135,6 +135,43 @@ def pad_collate_fn(batch): "num_loss_counted_tokens": num_loss_counted_tokens, } + elif flatten_with_posid: + + def pad_collate_fn(batch): + lens = np.array([len(item["input_ids"]) for item in batch]) + max_len = max(lens) + 
+            input_ids = []
+            labels = []
+            position_ids = []
+            has_label = "labels" in batch[0]
+
+            for idx in range(len(batch)):
+                input_ids.append(batch[idx]["input_ids"])
+                _label = batch[idx]["labels"] if has_label else batch[idx]["input_ids"].clone()
+                _label[0] = -100
+                labels.append(_label)
+                position_ids += list(range(len(batch[idx]["input_ids"])))
+
+            input_ids = torch.cat(input_ids).unsqueeze(0)
+            labels = torch.cat(labels).unsqueeze(0)
+            position_ids = torch.tensor(position_ids, dtype=torch.long).unsqueeze(0)
+
+            num_loss_counted_tokens = (labels != -100).sum()
+
+            print(
+                f"\033[96m total tokens: {max_len * len(batch)} num samples: {len(batch)} num padding tokens: {max_len * len(batch) - lens.sum()} - rank: {rank} "
+                f"max len: {max_len} min len: {min(lens)} avg len: {lens.mean()} "
+                f"num_loss_counted_tokens: {num_loss_counted_tokens}\033[0m"
+            )
+
+            return {
+                "input_ids": input_ids,
+                "labels": labels,
+                "position_ids": position_ids,
+                "num_loss_counted_tokens": num_loss_counted_tokens,
+            }
+
     else:
 
         def pad_collate_fn(batch):
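To make the padding-free path above concrete: with `--flatten_with_posid`, the collate function concatenates every sample in a batch into a single row, and `prepare_fa2_from_position_ids` later recovers the per-sample boundaries from the points where `position_ids` resets to 0. A small hand-worked sketch of that computation; the batch of lengths 3 and 2 is made up purely for illustration:

import torch

# Two samples of lengths 3 and 2, flattened into one row by the collate fn.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])

flat = position_ids.flatten()                                  # [0, 1, 2, 0, 1]
indices_q = torch.arange(flat.size(0), dtype=torch.int32)      # [0, 1, 2, 3, 4]
cu_seq_lens = torch.cat((
    indices_q[flat == 0],                                      # sequence starts: [0, 3]
    torch.tensor(flat.size(), dtype=torch.int32),              # total tokens:    [5]
))
print(cu_seq_lens)          # tensor([0, 3, 5], dtype=torch.int32)
print(int(flat.max()) + 1)  # 3 -> max_seqlen passed to flash_attn_varlen_func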