
hf dict cfg overrides #90

Merged · 5 commits · May 11, 2023
23 changes: 21 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -3,7 +3,7 @@

"""Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`."""

from typing import Union
from typing import Mapping, Union

from composer.metrics.nlp import (InContextLearningLMAccuracy,
InContextLearningLMExpectedCalibrationError,
@@ -47,7 +47,26 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
om_model_config.pretrained_model_name_or_path,
trust_remote_code=om_model_config.get('trust_remote_code', True),
use_auth_token=om_model_config.get('use_auth_token', False),
**om_model_config.get('config_overrides', {}))
)

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
raise ValueError(
f'config does not have attribute "{k}" to override ({k}: {v}).'
)

attr = getattr(config, k)
if isinstance(attr, Mapping):
extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
if extra_keys:
raise ValueError(
f'Config dict override got unknown keys. '
f'Extra keys: {extra_keys}. '
f'Expected (a subset of) keys: {list(attr.keys())}.')
getattr(config, k).update(v)
else:
setattr(config, k, v)

train_metrics = [
LanguageCrossEntropy(len(tokenizer)),
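For context on how the new override path is consumed: the `config_overrides` block sits under the model section of a training config, next to the pretrained path. A minimal sketch of the caller's side follows; the checkpoint path and override values are illustrative, not taken from this PR, and the keys must already exist on the loaded Hugging Face config or the new check raises a ValueError.

# Illustrative only (not part of the diff): passing config_overrides to ComposerHFCausalLM.
from omegaconf import OmegaConf

model_cfg = OmegaConf.create({
    'name': 'hf_causal_lm',
    'pretrained_model_name_or_path': '/path/to/local/checkpoint',  # hypothetical path
    'pretrained': False,
    'config_overrides': {
        'max_seq_len': 1024,                     # scalar attribute: replaced via setattr
        'attn_config': {'attn_impl': 'triton'},  # dict attribute: merged via .update()
    },
})
# ComposerHFCausalLM(model_cfg, tokenizer) applies these overrides to the AutoConfig
# before building the model, mirroring the loop added above.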
25 changes: 23 additions & 2 deletions llmfoundry/models/hf/hf_prefix_lm.py
@@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Union
from typing import Mapping, Union

from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
from omegaconf import DictConfig
@@ -69,7 +69,28 @@ class ComposerHFPrefixLM(HuggingFaceModelWithZLoss):
def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
**om_model_config.get('config_overrides', {}))
trust_remote_code=om_model_config.get('trust_remote_code', True),
use_auth_token=om_model_config.get('use_auth_token', False),
)

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
raise ValueError(
f'config does not have attribute "{k}" to override ({k}: {v}).'
)

attr = getattr(config, k)
if isinstance(attr, Mapping):
extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
if extra_keys:
raise ValueError(
f'Config dict override got unknown keys. '
f'Extra keys: {extra_keys}. '
f'Expected (a subset of) keys: {list(attr.keys())}.')
getattr(config, k).update(v)
else:
setattr(config, k, v)

# Set up the tokenizer (add tokens for denoising sentinels if needed)
if om_model_config.get('adapt_vocab_for_denoising', False):
25 changes: 23 additions & 2 deletions llmfoundry/models/hf/hf_t5.py
@@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Union
from typing import Mapping, Union

from composer.metrics.nlp import LanguageCrossEntropy, MaskedAccuracy
from omegaconf import DictConfig
@@ -58,7 +58,28 @@ class ComposerHFT5(HuggingFaceModelWithZLoss):
def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
**om_model_config.get('config_overrides', {}))
trust_remote_code=om_model_config.get('trust_remote_code', True),
use_auth_token=om_model_config.get('use_auth_token', False),
)

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
raise ValueError(
f'config does not have attribute "{k}" to override ({k}: {v}).'
)

attr = getattr(config, k)
if isinstance(attr, Mapping):
extra_keys = [_k for _k in v.keys() if _k not in attr.keys()]
if extra_keys:
raise ValueError(
f'Config dict override got unknown keys. '
f'Extra keys: {extra_keys}. '
f'Expected (a subset of) keys: {list(attr.keys())}.')
getattr(config, k).update(v)
else:
setattr(config, k, v)

if not config.is_encoder_decoder:
raise ValueError(f'Model type "hf_t5" currently only supports T5 models ' +\
5 changes: 5 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -24,6 +24,11 @@
'name': 'kaiming_normal_',
'fan_mode': 'fan_in',
'init_nonlinearity': 'relu',
'init_div_is_residual': True,
'emb_init_std': None,
'emb_init_uniform_lim': None,
'init_std': None,
'init_gain': 0.0,
}


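Worth noting why these defaults matter for this PR: the dict-override path added in the hf_* wrappers only accepts keys that already exist in the target dict, so surfacing `emb_init_std`, `init_std`, and the other entries here is what lets a nested `init_config` override (exercised by the new test below) pass validation. A small sketch, assuming this dict feeds the MPT config's default `init_config`:

# Sketch (not repo code): accepted because 'emb_init_std' now exists in init_config.
config_overrides = {'init_config': {'emb_init_std': 5}}
# A key absent from the defaults, e.g. {'init_config': {'emb_std': 5}}, would instead
# hit the "unknown keys" ValueError raised by the override loop.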
104 changes: 104 additions & 0 deletions tests/test_hf_config.py
@@ -0,0 +1,104 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Mapping

import pytest
import torch
from composer.utils import reproducibility
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from transformers import AutoConfig, AutoModelForCausalLM

from llmfoundry import COMPOSER_MODEL_REGISTRY
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils import build_tokenizer


@pytest.mark.parametrize('model_cfg_overrides', [
{
'max_seq_len': 1024
},
{
'attn_config': {
'attn_impl': 'triton'
}
},
{
'init_config': {
'emb_init_std': 5
}
},
{
'max_seq_len': 1024,
'attn_config': {
'attn_impl': 'triton'
},
'init_config': {
'emb_init_std': 5
},
},
pytest.param({'msl': 1024},
marks=pytest.mark.xfail(reason='"msl" is a ValueError',
strict=True)),
pytest.param({'attn_config': {
'attn_iml': 'triton'
}},
marks=pytest.mark.xfail(reason='"attn_impl" misspelled',
strict=True)),
])
def test_hf_config_override(
model_cfg_overrides,
conf_path='scripts/train/yamls/mpt/testing.yaml',
):
AutoConfig.register('mpt', MPTConfig)
AutoModelForCausalLM.register(MPTConfig, MPTForCausalLM)

with open(conf_path) as f:
test_cfg = om.load(f)

reproducibility.seed_all(test_cfg.seed)

# Build Model
# For fast initialization, use `meta` device
print('Initializing model...')
device = 'cpu'
test_cfg.model.init_device = device
test_cfg.device = device
test_cfg.precision = 'fp16'
test_cfg.model.attn_config = {'attn_impl': 'torch', 'alibi': True}

tokenizer = build_tokenizer(test_cfg.tokenizer)
model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
tokenizer)

# save model
tmp_dir = tempfile.TemporaryDirectory()
save_path = tmp_dir.name

tokenizer.save_pretrained(save_path)
model.config.save_pretrained(save_path)
torch.save(model.state_dict(), Path(save_path) / 'pytorch_model.bin')

# load hf causal lm model with config_overrides
hf_model_config = deepcopy(test_cfg)
model_cfg = {
'name': 'hf_causal_lm',
'pretrained_model_name_or_path': save_path,
'pretrained': False,
'config_overrides': model_cfg_overrides,
}
hf_model_config.model = model_cfg

hf_model = COMPOSER_MODEL_REGISTRY[hf_model_config.model.name](
hf_model_config.model, tokenizer=tokenizer)

for k, v in hf_model_config.model.config_overrides.items():
if isinstance(v, Mapping):
for _k, _v in v.items():
assert getattr(hf_model.config, k)[_k] == _v
else:
assert getattr(hf_model.config, k) == v
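To summarize the semantics this test pins down: mapping-valued attributes are merged key-by-key (unknown keys raise), while everything else is replaced wholesale. A self-contained sketch of that behavior, using a plain namespace in place of a Hugging Face config:

# Standalone illustration of the override semantics under test (not repo code).
from types import SimpleNamespace
from typing import Mapping

config = SimpleNamespace(max_seq_len=2048, attn_config={'attn_impl': 'torch'})
overrides = {'max_seq_len': 1024, 'attn_config': {'attn_impl': 'triton'}}

for k, v in overrides.items():
    if not hasattr(config, k):
        raise ValueError(f'config does not have attribute "{k}" to override ({k}: {v}).')
    attr = getattr(config, k)
    if isinstance(attr, Mapping):
        extra_keys = [_k for _k in v if _k not in attr]
        if extra_keys:
            raise ValueError(f'Config dict override got unknown keys: {extra_keys}')
        attr.update(v)  # nested dicts are merged, not replaced
    else:
        setattr(config, k, v)  # scalars are replaced wholesale

assert config.max_seq_len == 1024
assert config.attn_config['attn_impl'] == 'triton'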
18 changes: 13 additions & 5 deletions tests/test_model.py
@@ -64,8 +64,10 @@ def get_objs(conf_path='scripts/train/yamls/mpt/testing.yaml'):
test_cfg.device_eval_batch_size = 2
test_cfg.device_train_microbatch_size = 2

tokenizer = build_tokenizer(test_cfg.tokenizer)

model = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
test_cfg.tokenizer)
tokenizer)
# Optimizer
assert test_cfg.optimizer.name == 'decoupled_adamw'
optimizer = DecoupledAdamW(model.parameters(),
@@ -302,8 +304,10 @@ def test_determinism(attn_impl: str, precision):
test_cfg.model.init_device = 'cuda:0'
test_cfg.device = 'cuda:0'

tokenizer = build_tokenizer(test_cfg.tokenizer)

model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
test_cfg.tokenizer)
tokenizer)
model_2 = copy.deepcopy(model_1)

optimizer_1 = DecoupledAdamW(model_1.parameters(),
@@ -359,8 +363,10 @@ def test_loss_fn():
'init_std': 0.02,
}

tokenizer = build_tokenizer(test_cfg.tokenizer)

model_1 = COMPOSER_MODEL_REGISTRY[test_cfg.model.name](test_cfg.model,
test_cfg.tokenizer)
tokenizer)
model_2 = copy.deepcopy(model_1)
assert isinstance(model_1.loss_fn, torch.nn.CrossEntropyLoss)
model_2.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
@@ -412,10 +418,12 @@ def test_opt_wrapping(prefixlm):
}
config = DictConfig(conf)

tokenizer = build_tokenizer(config.tokenizer)

if prefixlm:
model = ComposerHFPrefixLM(config.model, config.tokenizer)
model = ComposerHFPrefixLM(config.model, tokenizer)
else:
model = ComposerHFCausalLM(config.model, config.tokenizer)
model = ComposerHFCausalLM(config.model, tokenizer)

# check that all the modules we expect are blocked from FSDP wrapping
assert not model.model.model._fsdp_wrap