From eec349a9e1f41f0ccce0f15e53cb836ae0c52f89 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 5 Jan 2024 12:16:45 -0500
Subject: [PATCH] fix hardcoded data collator fix for multipack pretraining

---
 src/axolotl/core/trainer_builder.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b75766d043..8088f7baf5 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -814,7 +814,7 @@ def build(self, total_num_steps):
             train_dataset=self.train_dataset,
             eval_dataset=self.eval_dataset,
             args=training_args,
-            # data_collator=self.build_collator(**data_collator_kwargs),
+            data_collator=self.build_collator(training_args, **data_collator_kwargs),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -835,7 +835,10 @@ def build(self, total_num_steps):
 
         return trainer
 
-    def build_collator(self, **kwargs):
+    def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
+        if training_args.pretraining:
+            return None
+
         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)
 