diff --git a/requirements.txt b/requirements.txt
index a3e87dacc..3d7dc778c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,4 +32,4 @@ scikit-learn==1.2.2
 pynvml
 art
 wandb
-fschat
+fschat==0.2.29
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 3a574cefc..1dfdab260 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -278,6 +278,29 @@ def validate_config(cfg):
                     "`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
                 )
 
+    if cfg.datasets:
+        for idx, ds_cfg in enumerate(cfg.datasets):
+            # Skip entries with no `type` set so the equality and substring
+            # checks below only ever run against a non-empty value.
+            if not ds_cfg.type:
+                continue
+            if ds_cfg.type == "sharegpt:chat":
+                LOG.warning(
+                    PendingDeprecationWarning(
+                        "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
+                    )
+                )
+                cfg.datasets[idx].type = "sharegpt"
+            if "sharegpt_simple" in ds_cfg.type:
+                LOG.warning(
+                    PendingDeprecationWarning(
+                        "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
+                    )
+                )
+                cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
+                    "sharegpt_simple", "sharegpt"
+                )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/tests/test_validation.py b/tests/test_validation.py
index d7935c1a5..b9a57c2e9 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -374,3 +374,26 @@ def test_merge_lora_no_bf16_fail(self):
         )
 
         validate_config(cfg)
+
+    def test_sharegpt_deprecation(self):
+        cfg = DictDefault(
+            {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
+        )
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "`type: sharegpt:chat` will soon be deprecated." in record.message
+                for record in self._caplog.records
+            )
+        assert cfg.datasets[0].type == "sharegpt"
+
+        cfg = DictDefault(
+            {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
+        )
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "`type: sharegpt_simple` will soon be deprecated." in record.message
+                for record in self._caplog.records
+            )
+        assert cfg.datasets[0].type == "sharegpt:load_role"