Fix eval.py with lora (#965)
* just remove it?

* or not

* fix

* fix up

* clean up

* fix example yaml

* precommit

* add test
dakinggg committed Feb 9, 2024
1 parent c7c9d24 commit 12d1ca7
Showing 5 changed files with 67 additions and 82 deletions.
11 changes: 11 additions & 0 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -70,6 +70,8 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss):
def __init__(self, om_model_config: DictConfig,
tokenizer: PreTrainedTokenizerBase):
pretrained_model_name_or_path = om_model_config.pretrained_model_name_or_path
pretrained_lora_id_or_path = om_model_config.get(
'pretrained_lora_id_or_path', None)

if not om_model_config.get(
'trust_remote_code', True
@@ -249,6 +251,15 @@ def _autoset_attn_implementation_monkeypatch(
if peft_config_dict is not None:
peft_config = self._get_peft_config(peft_config_dict)

if pretrained_lora_id_or_path is not None:
if not peft_installed:
raise ValueError(
'PEFT is not installed, but lora_id_or_path was passed. Please install LLM Foundry with the peft extra to use lora_id_or_path.'
)
from peft import PeftModelForCausalLM
model = PeftModelForCausalLM.from_pretrained(
model, pretrained_lora_id_or_path)

super().__init__(
model=model,
shift_labels=True,
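As a side note on the change above: a minimal sketch of the new behavior, assuming PEFT is installed and the adapter matches the base model. The model and adapter IDs are the ones used in the example YAML and the new test later in this commit; they are illustrative here, not part of hf_causal_lm.py.

```python
# Sketch: when `pretrained_lora_id_or_path` is set in the model config,
# ComposerHFCausalLM builds the base Hugging Face model as usual and then
# wraps it with the pretrained LoRA adapter before handing it to Composer.
from transformers import AutoModelForCausalLM
from peft import PeftModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')
# Same call the diff adds inside ComposerHFCausalLM.__init__:
model = PeftModelForCausalLM.from_pretrained(base_model, 'ybelkada/opt-350m-lora')
```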
4 changes: 3 additions & 1 deletion llmfoundry/models/hf/hf_fsdp.py
@@ -7,6 +7,7 @@
import functools
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union

from composer.models.huggingface import maybe_get_underlying_model
from transformers import PreTrainedModel
from transformers.models.opt.modeling_opt import OPTDecoder

@@ -142,7 +143,8 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel,

# OPT has an extra layer of wrapping, so special case here
if isinstance(causal_base_model, OPTDecoder):
model.model._fsdp_wrap = False
underlying_model = maybe_get_underlying_model(model)
underlying_model.model._fsdp_wrap = False
model_block = hf_get_hidden_layers(causal_base_model)
lm_head = model.get_output_embeddings()
# some models (OPT) implement .get_input_embeddings for the causal subclass
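For context on why this change is needed: once the model passed in can be a PEFT wrapper, `model.model` no longer reaches the OPT decoder stack directly, so the old attribute access targeted the wrong module. A small illustrative sketch of the unwrapping behavior (it builds a throwaway LoRA wrapper rather than loading a pretrained adapter):

```python
# maybe_get_underlying_model peels a PEFT wrapper back to the underlying
# OPTForCausalLM and returns plain HF models unchanged.
from composer.models.huggingface import maybe_get_underlying_model
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')
peft_model = get_peft_model(hf_model, LoraConfig(task_type='CAUSAL_LM'))

underlying = maybe_get_underlying_model(peft_model)
assert type(underlying).__name__ == 'OPTForCausalLM'
# `underlying.model` is the OPTModel that the code above flags with
# `_fsdp_wrap = False`; unwrapping a plain HF model is a no-op:
assert maybe_get_underlying_model(hf_model) is hf_model
```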
59 changes: 4 additions & 55 deletions scripts/eval/eval.py
@@ -19,11 +19,9 @@
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from rich.traceback import install
from transformers import (AutoModelForCausalLM, PreTrainedTokenizerBase,
T5ForConditionalGeneration)
from transformers import PreTrainedTokenizerBase

install()
from llmfoundry.models import MPTForCausalLM
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_evaluators, build_logger,
@@ -34,52 +32,6 @@
log = logging.getLogger(__name__)


def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
num_retries: int) -> ComposerModel:
try:
from peft import PeftModel
except ImportError as e:
raise ImportError(
f'Error importing from peft. Run `pip install -e .[gpu,peft]`. \n {e}'
)

model_registry = {
'mpt_causal_lm': MPTForCausalLM,
'hf_causal_lm': AutoModelForCausalLM,
'hf_prefix_lm': AutoModelForCausalLM,
'hf_t5': T5ForConditionalGeneration,
}

retries = 0
composer_model_wrapper = None
while retries < num_retries and composer_model_wrapper is None:
try:
trust_remote_code = model_cfg.get('trust_remote_code', True)
use_auth_token = model_cfg.get('use_auth_token', False)
model = model_registry[model_cfg.name].from_pretrained(
model_cfg.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)

peft_model = PeftModel.from_pretrained(
model, model_cfg.pretrained_lora_id_or_path)

composer_model_wrapper = COMPOSER_MODEL_REGISTRY[model_cfg.name](
peft_model, tokenizer)
except Exception as e:
retries += 1
if retries >= num_retries:
raise e
else:
log.info(
f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
)

assert composer_model_wrapper is not None
return composer_model_wrapper


def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
fsdp_config: Optional[Dict], num_retries: int) -> ComposerModel:
init_context = process_init_device(model_cfg, fsdp_config)
@@ -174,12 +126,8 @@ def evaluate_model(
'The FSDP config block is not supported when loading ' +
'Hugging Face models in 8bit.')

if hasattr(model_cfg.model, 'pretrained_lora_id_or_path'):
composer_model = load_peft_model(model_cfg.model, tokenizer,
num_retries)
else:
composer_model = load_model(model_cfg.model, tokenizer, fsdp_config,
num_retries)
composer_model = load_model(model_cfg.model, tokenizer, fsdp_config,
num_retries)

# Now add the eval metrics
if eval_loader_config is not None:
@@ -204,6 +152,7 @@
assert composer_model is not None

log.info(f'Building trainer for {model_cfg.model_name}...')

trainer = Trainer(
run_name=run_name,
seed=seed,
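With `load_peft_model` gone, a LoRA eval goes through the same `load_model` path as any other model: the model config simply carries `pretrained_lora_id_or_path`, and `ComposerHFCausalLM` applies the adapter itself. A hedged sketch of that call follows; the config literal and `num_retries` value are illustrative, `load_model` is the function defined earlier in this file, and eval.py itself builds its tokenizer via `build_tokenizer` rather than `AutoTokenizer`.

```python
# Illustrative only: the unified load path for a LoRA eval model.
from omegaconf import DictConfig
from transformers import AutoTokenizer

model_cfg = DictConfig({
    'name': 'hf_causal_lm',
    'pretrained_model_name_or_path': 'facebook/opt-350m',
    # Handled inside ComposerHFCausalLM now, instead of an eval-only code path:
    'pretrained_lora_id_or_path': 'ybelkada/opt-350m-lora',
})
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
composer_model = load_model(model_cfg, tokenizer, fsdp_config=None, num_retries=3)
```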
22 changes: 2 additions & 20 deletions scripts/eval/yamls/hf_lora_eval.yml
@@ -2,13 +2,10 @@ max_seq_len: 2048
seed: 1
precision: amp_fp16

# If you are using one model, put it here:
model_name_or_path: EleutherAI/gpt-neo-125m
model_name_or_path: facebook/opt-350m
# If you are using a separate lora weight, put it here:
# lora weights must be compatible with the specified model
lora_id_or_path: edbeeching/gpt-neo-125M-imdb-lora # Example lora weights for gpt-neo-125m

# otherwise, write a block for each model you want to test in the `models` section
lora_id_or_path: ybelkada/opt-350m-lora # Example lora weights for opt-350m

models:
-
@@ -23,21 +20,6 @@ models:
name: ${model_name_or_path}
kwargs:
model_max_length: ${max_seq_len}
# # if you are evaluating more than one model, list them all as YAML blocks without variable interpolation
# -
# model_name: mosaicml/mpt-7b
# model:
# name: hf_causal_lm
# pretrained_model_name_or_path: mosaicml/mpt-7b
# init_device: cpu
# pretrained: true
# config_overrides:
# max_seq_len: ${max_seq_len}
# tokenizer:
# name: mosaicml/mpt-7b
# kwargs:
# model_max_length: ${max_seq_len}


device_eval_batch_size: 4

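The two top-level variables feed the single entry under `models:`; the model block itself is collapsed in this view, so the sketch below assumes it maps `${model_name_or_path}` to `pretrained_model_name_or_path` and `${lora_id_or_path}` to `pretrained_lora_id_or_path` (an assumption, since those lines are elided above).

```python
# Illustrative check of the resolved eval YAML; field names assumed as noted above.
from omegaconf import OmegaConf as om

cfg = om.load('scripts/eval/yamls/hf_lora_eval.yml')
model_cfg = cfg.models[0].model
print(model_cfg.get('pretrained_model_name_or_path'))  # facebook/opt-350m
print(model_cfg.get('pretrained_lora_id_or_path'))     # ybelkada/opt-350m-lora
```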
53 changes: 47 additions & 6 deletions tests/models/test_model.py
@@ -13,6 +13,7 @@
import torch.nn as nn
from accelerate import init_empty_weights
from composer.core.precision import Precision, get_precision_context
from composer.models.huggingface import maybe_get_underlying_model
from composer.optim import DecoupledAdamW
from composer.trainer.dist_strategy import prepare_fsdp_module
from composer.utils import dist, get_device, reproducibility
@@ -499,8 +500,18 @@ def test_loss_fn():
atol=1e-4), f'differed at step {i}'


def test_opt_wrapping():
conf = {
@pytest.mark.parametrize('peft_config', [
None,
{
'peft_type': 'LORA',
'task_type': 'CAUSAL_LM'
},
])
def test_opt_wrapping(peft_config: Optional[dict[str, str]]):
if peft_config is not None:
_ = pytest.importorskip('peft')

conf: dict[str, dict[str, Union[str, dict]]] = {
'model': {
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': 'facebook/opt-125m',
@@ -510,6 +521,9 @@ def test_opt_wrapping():
'name': 'facebook/opt-125m'
}
}
if peft_config is not None:
conf['model']['peft_config'] = peft_config

config = DictConfig(conf)

tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(config.tokenizer)
@@ -519,10 +533,37 @@ def test_opt_wrapping():
model = ComposerHFCausalLM(config.model, tokenizer)

# check that all the modules we expect are blocked from FSDP wrapping
assert not model.model.model._fsdp_wrap
assert not model.model.model.decoder._fsdp_wrap
assert not model.model.model.decoder.embed_tokens._fsdp_wrap
assert not model.model.lm_head._fsdp_wrap
underlying_model = maybe_get_underlying_model(model.model)
assert not underlying_model.model._fsdp_wrap
assert not underlying_model.model.decoder._fsdp_wrap
assert not underlying_model.model.decoder.embed_tokens._fsdp_wrap
assert not underlying_model.lm_head._fsdp_wrap


def test_lora_id():
peft = pytest.importorskip('peft')

conf: dict[str, dict[str, Union[str, dict]]] = {
'model': {
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': 'facebook/opt-350m',
'pretrained': 'false',
'pretrained_lora_id_or_path': 'ybelkada/opt-350m-lora',
},
'tokenizer': {
'name': 'facebook/opt-350m'
}
}

config = DictConfig(conf)

tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(config.tokenizer)
tokenizer = build_tokenizer(config.tokenizer.name,
tokenizer_cfg.get('kwargs', {}))

model = ComposerHFCausalLM(config.model, tokenizer)

assert isinstance(model.model, peft.PeftModelForCausalLM)


@pytest.mark.parametrize('norm_type', NORM_CLASS_REGISTRY.keys())
