Deprecate triton, prefix lm, llama attention patch, and text denoising; Make ComposerHFT5 experimental (mosaicml#1007)

* Deprecate features and mark experimental

* fix typo

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
irenedea and dakinggg committed Mar 4, 2024
1 parent e524fdc commit e7f3c18
Showing 6 changed files with 45 additions and 0 deletions.
5 changes: 5 additions & 0 deletions llmfoundry/data/denoising.py
@@ -6,6 +6,7 @@
import logging
import random
import sys
import warnings
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import numpy as np
@@ -20,6 +21,7 @@
from llmfoundry.data.text_data import (StreamingTextDataset,
get_tokens_per_batch_func)
from llmfoundry.models import utils
from llmfoundry.utils.warnings import VersionedDeprecationWarning

__all__ = ['MixtureOfDenoisersCollator', 'build_text_denoising_dataloader']

@@ -429,6 +431,9 @@ def build_text_denoising_dataloader(
        padding/waste rates for different `cfg.dataset.packing_ratio` choices,
        given a starting workload YAML.
    """
    warnings.warn(
        VersionedDeprecationWarning('Text denoising is deprecated.',
                                    remove_version='0.7.0'))
    assert cfg.name == 'text_denoising', f'Tried to build_denoising text dataloader with cfg.name={cfg.name}'

    collate_fn = MixtureOfDenoisersCollator(
6 changes: 6 additions & 0 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -28,6 +28,7 @@
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights
from llmfoundry.utils.config_utils import pop_config
from llmfoundry.utils.warnings import VersionedDeprecationWarning

if TYPE_CHECKING:
from peft import PeftConfig
@@ -285,6 +286,11 @@ def _patch_attention_type(model: PreTrainedModel,
f'attention_patch_type is only supported for llama models, but got {model.config.model_type}'
)

warnings.warn(
VersionedDeprecationWarning(
'Attention patches for Llama models are deprecated. We recommend `use_flash_attention_2: True` for Llama models.',
remove_version='0.7.0'))

log.debug(
f'Patching llama attention with {attention_patch_type} attention')
from transformers.models.llama.modeling_llama import LlamaAttention
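The warning text recommends Flash Attention 2 in place of the patch. As a rough sketch (not part of this commit), a Llama model sub-config following that recommendation could look like the block below; apart from `use_flash_attention_2`, the keys are typical llm-foundry YAML fields with placeholder values that may differ per setup.

```python
from omegaconf import OmegaConf

# Illustrative replacement for the deprecated attention_patch_type path on a
# Llama checkpoint. The checkpoint name is a placeholder; the recommendation
# from the warning above is only `use_flash_attention_2: True`.
model_cfg = OmegaConf.create({
    'name': 'hf_causal_lm',
    'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
    'pretrained': True,
    'use_flash_attention_2': True,  # replaces attention_patch_type
})
```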
4 changes: 4 additions & 0 deletions llmfoundry/models/hf/hf_t5.py
@@ -5,6 +5,7 @@

from __future__ import annotations

import warnings
from typing import Mapping

from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
@@ -17,6 +18,7 @@
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.utils import (adapt_tokenizer_for_denoising,
init_empty_weights)
from llmfoundry.utils.warnings import ExperimentalWarning

__all__ = ['ComposerHFT5']

@@ -57,6 +59,8 @@ class ComposerHFT5(HuggingFaceModelWithZLoss):

    def __init__(self, om_model_config: DictConfig,
                 tokenizer: PreTrainedTokenizerBase):
        warnings.warn(ExperimentalWarning(feature_name='ComposerHFT5'))

        config = AutoConfig.from_pretrained(
            om_model_config.pretrained_model_name_or_path,
            trust_remote_code=om_model_config.get('trust_remote_code', True),
6 changes: 6 additions & 0 deletions llmfoundry/models/hf/model_wrapper.py
@@ -5,6 +5,7 @@

from __future__ import annotations

import warnings
from collections import UserDict
from typing import TYPE_CHECKING, List, Mapping, Optional

@@ -16,6 +17,7 @@
from transformers.utils.generic import ModelOutput

from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
from llmfoundry.utils.warnings import VersionedDeprecationWarning

if TYPE_CHECKING:
from peft import PeftConfig
@@ -93,6 +95,10 @@ def loss(self, outputs: ModelOutput, batch: Mapping):
        if self.z_loss == 0.0:
            return loss

        warnings.warn(
            VersionedDeprecationWarning('z-loss is deprecated.',
                                        remove_version='0.7.0'))

        # Add a z_loss to the standard loss
        logits_flat = logits.view(-1, logits.size(-1))
        labels_flat = batch['labels'].view(-1)
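For context on what is being deprecated here: the z-loss is an auxiliary term of the form coeff * mean(logsumexp(logits)**2), added to the cross-entropy loss to keep the softmax normalizer small. A standalone sketch, assuming the same flattened logits/labels shapes the wrapper uses:

```python
import torch

_HF_IGNORE_INDEX = -100  # Hugging Face's default ignore index for labels


def z_loss_term(logits: torch.Tensor, labels: torch.Tensor,
                coeff: float) -> torch.Tensor:
    """Sketch of the auxiliary z-loss: coeff * mean(logsumexp(logits)**2),
    computed over non-ignored token positions."""
    logits_flat = logits.view(-1, logits.size(-1))
    labels_flat = labels.view(-1)
    log_z = torch.logsumexp(logits_flat[labels_flat != _HF_IGNORE_INDEX],
                            dim=1)
    return coeff * (log_z**2).mean()
```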
11 changes: 11 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -214,6 +214,17 @@ def _validate_config(self) -> None:
        if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
            raise ValueError(
                f"Unknown attn_impl={self.attn_config['attn_impl']}")
        if self.attn_config['prefix_lm']:
            warnings.warn(
                VersionedDeprecationWarning(
                    'Support for Prefix Language Models is deprecated.',
                    remove_version='0.7.0'))
        if self.attn_config['attn_impl'] == 'triton':
            warnings.warn(
                VersionedDeprecationWarning(
                    'Support for triton attention is deprecated. Please use torch or flash attention.',
                    remove_version='0.7.0'))

        if self.attn_config['prefix_lm'] and self.attn_config[
                'attn_impl'] not in ['torch', 'triton']:
            raise NotImplementedError(
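A minimal sketch of how the new config-level deprecation surfaces, assuming `MPTConfig` fills unspecified `attn_config` keys with its defaults during validation:

```python
import pytest

from llmfoundry.models.mpt import MPTConfig
from llmfoundry.utils.warnings import VersionedDeprecationWarning

# Building a config with the deprecated triton implementation should now emit
# the versioned deprecation; 'flash' or 'torch' remain the supported options.
with pytest.warns(VersionedDeprecationWarning):
    MPTConfig(attn_config={'attn_impl': 'triton'})
```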
13 changes: 13 additions & 0 deletions llmfoundry/utils/warnings.py
@@ -25,3 +25,16 @@ class VersionedDeprecationWarning(DeprecationWarning):
    def __init__(self, message: str, remove_version: str) -> None:
        super().__init__(message +
                         f' It will be removed in version {remove_version}.')


class ExperimentalWarning(Warning):
    """A warning for experimental features.

    Attributes:
        feature_name (str): The name of the experimental feature.
    """

    def __init__(self, feature_name: str) -> None:
        super().__init__(
            f'{feature_name} is experimental and may change with future versions.'
        )
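
Taken together, the two classes give downstream code distinct categories to filter on. A short sketch (not part of this commit) of the composed messages and one possible filtering policy:

```python
import warnings

from llmfoundry.utils.warnings import (ExperimentalWarning,
                                       VersionedDeprecationWarning)

# Message: 'Text denoising is deprecated. It will be removed in version 0.7.0.'
warnings.warn(
    VersionedDeprecationWarning('Text denoising is deprecated.',
                                remove_version='0.7.0'))

# Message: 'ComposerHFT5 is experimental and may change with future versions.'
warnings.warn(ExperimentalWarning(feature_name='ComposerHFT5'))

# One possible policy: fail fast on deprecated code paths while showing
# experimental-feature notices only once per call site.
warnings.filterwarnings('error', category=VersionedDeprecationWarning)
warnings.filterwarnings('once', category=ExperimentalWarning)
```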
