diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index dc8b1501e..26cc91ed5 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -815,7 +815,7 @@ def build(self, total_num_steps): train_dataset=self.train_dataset, eval_dataset=self.eval_dataset, args=training_args, - # data_collator=self.build_collator(**data_collator_kwargs), + data_collator=self.build_collator(training_args, **data_collator_kwargs), bench_data_collator=transformers.DataCollatorForSeq2Seq( self.tokenizer, return_tensors="pt", @@ -836,7 +836,10 @@ def build(self, total_num_steps): return trainer - def build_collator(self, **kwargs): + def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs): + if training_args.pretraining: + return None + if self.cfg.model_config_type == "mamba": return MambaDataCollator(tokenizer=self.tokenizer)