Allow overrides for nested PretrainedConfig (#1089)
dakinggg authored and KuuCi committed Apr 18, 2024
1 parent 6092a8e commit 4326d19
Showing 2 changed files with 46 additions and 3 deletions.
16 changes: 14 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -11,8 +11,8 @@
 from composer.models.huggingface import peft_installed
 from composer.utils import dist
 from omegaconf import DictConfig
-from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
-                          PreTrainedTokenizerBase)
+from transformers import (AutoConfig, AutoModelForCausalLM, PretrainedConfig,
+                          PreTrainedModel, PreTrainedTokenizerBase)
 
 from llmfoundry.metrics import (DEFAULT_CAUSAL_LM_EVAL_METRICS,
                                 DEFAULT_CAUSAL_LM_TRAIN_METRICS)
@@ -161,6 +161,18 @@ def _autoset_attn_implementation_monkeypatch(
             elif attr is None and isinstance(v, Mapping):
                 setattr(config, k, {})
                 getattr(config, k).update(v)
+            elif isinstance(attr, PretrainedConfig):
+                if not isinstance(v, Mapping):
+                    raise ValueError(
+                        f'Expected a dictionary for config override {k}, but got {v}.'
+                    )
+
+                for _k, _v in v.items():
+                    if not hasattr(attr, _k):
+                        raise ValueError(
+                            f'config does not have attribute "{_k}" to override ({k}: {_k}: {_v}).'
+                        )
+                    setattr(attr, _k, _v)
             else:
                 setattr(config, k, v)
 
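To make the new branch concrete, the following is a minimal, self-contained sketch of the override paths it distinguishes. The InnerConfig/OuterConfig classes, their fields, and the inlined loop are illustrative stand-ins, not llm-foundry's actual API:

from transformers import PretrainedConfig


class InnerConfig(PretrainedConfig):
    """Toy nested config, standing in for e.g. Dbrx's ffn_config."""

    def __init__(self, hidden_size: int = 8, num_experts: int = 4, **kwargs):
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        super().__init__(**kwargs)


class OuterConfig(PretrainedConfig):
    """Toy top-level config holding a nested PretrainedConfig attribute."""

    def __init__(self, **kwargs):
        self.inner_config = InnerConfig()
        super().__init__(**kwargs)


config = OuterConfig()
overrides = {'inner_config': {'hidden_size': 16}}

for k, v in overrides.items():
    attr = getattr(config, k, None)
    if isinstance(attr, PretrainedConfig):
        # New behavior: update the nested config in place, field by field,
        # rejecting keys the nested config does not already have.
        for _k, _v in v.items():
            if not hasattr(attr, _k):
                raise ValueError(f'config does not have attribute "{_k}"')
            setattr(attr, _k, _v)
    else:
        setattr(config, k, v)

assert config.inner_config.hidden_size == 16  # overridden
assert config.inner_config.num_experts == 4  # other defaults preserved
assert isinstance(config.inner_config, PretrainedConfig)  # still a config

Updating field by field is what preserves the untouched defaults and the config's class, which the new test below asserts directly.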
33 changes: 32 additions & 1 deletion tests/models/hf/test_hf_config.py
@@ -12,7 +12,7 @@
 import torch
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, PretrainedConfig
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils import build_tokenizer
@@ -205,3 +205,34 @@ def test_rope_scaling_override():
     # This would error if the config isn't parsed into a proper dictionary
     model.get_metadata()
     assert model.config.rope_scaling == {'type': 'dynamic', 'factor': 0.5}
+
+
+@pytest.mark.skipif('HUGGING_FACE_HUB_TOKEN' not in os.environ,
+                    reason='CI does not have access to Dbrx')
+def test_nested_override():
+    model_cfg = {
+        'name': 'hf_causal_lm',
+        'pretrained_model_name_or_path': 'databricks/dbrx-instruct',
+        'config_overrides': {
+            'ffn_config': {
+                'ffn_hidden_size': 500,
+            }
+        },
+        'use_auth_token': True,
+        'pretrained': False,
+        'init_device': 'meta',
+    }
+    model_cfg = om.create(model_cfg)
+
+    model = build_composer_model(
+        name=model_cfg.name,
+        cfg=model_cfg,
+        tokenizer=None,  # type: ignore
+    )
+
+    # The value we changed
+    assert model.config.ffn_config.ffn_hidden_size == 500
+    # Ensure we still have a config, and haven't replaced it with a dictionary
+    assert isinstance(model.config.ffn_config, PretrainedConfig)
+    # Ensure the other values still exist and are not set back to their defaults
+    assert model.config.ffn_config.moe_num_experts == 16
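The isinstance assertion is the heart of this test: before the change, a nested override fell through to the plain setattr branch and replaced the nested PretrainedConfig with the override dict itself, losing both attribute access and every field not named in the override. A small sketch of that old failure mode, using bare PretrainedConfig instances and made-up field values rather than Dbrx's real config:

from transformers import PretrainedConfig

# Stand-in nested config; field names mirror the test, values are made up.
outer = PretrainedConfig()
outer.ffn_config = PretrainedConfig(ffn_hidden_size=3584, moe_num_experts=16)

# Old behavior: the override dict replaced the nested config wholesale.
setattr(outer, 'ffn_config', {'ffn_hidden_size': 500})

try:
    _ = outer.ffn_config.ffn_hidden_size
except AttributeError:
    print('ffn_config is now a plain dict, so attribute access fails')
print(outer.ffn_config.get('moe_num_experts'))  # None: other fields are gone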