Skip to content

Commit

Permalink
update per PR feedback to handle deprecated sharegpt types
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Sep 26, 2023
1 parent ca9abff commit b051681
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ scikit-learn==1.2.2
pynvml
art
wandb
-fschat
+fschat==0.2.29
19 changes: 19 additions & 0 deletions src/axolotl/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,25 @@ def validate_config(cfg):
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
)

if cfg.datasets:
for idx, ds_cfg in enumerate(cfg.datasets):
if ds_cfg.type == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = "sharegpt"
if "sharegpt_simple" in ds_cfg.type:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt"
)

# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
Expand Down
23 changes: 23 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,26 @@ def test_packing(self):
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)

def test_sharegpt_deprecation(self):
    """
    Deprecated sharegpt dataset types are logged as warnings by
    validate_config and rewritten in place on the config.
    """
    # (deprecated type, expected warning fragment, expected rewritten type)
    scenarios = (
        (
            "sharegpt:chat",
            "`type: sharegpt:chat` will soon be deprecated.",
            "sharegpt",
        ),
        (
            "sharegpt_simple:load_role",
            "`type: sharegpt_simple` will soon be deprecated.",
            "sharegpt:load_role",
        ),
    )

    for legacy_type, warning_fragment, migrated_type in scenarios:
        dataset_entry = {"path": "lorem/ipsum", "type": legacy_type}
        cfg = DictDefault({"datasets": [dataset_entry]})

        with self._caplog.at_level(logging.WARNING):
            validate_config(cfg)

            # the deprecation notice must show up among the captured records
            captured = [entry.message for entry in self._caplog.records]
            assert any(warning_fragment in text for text in captured)

            # the config must now carry the replacement type
            assert cfg.datasets[0].type == migrated_type

0 comments on commit b051681

Please sign in to comment.