swap the data collator for evals if not using sample packing (#1076)
* swap the data collator for evals if not using sample packing

* drop last from dataloader to help with issues with evals
winglian authored Jan 10, 2024
1 parent ec02b7c commit ead34c5
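The change keeps two collators on the trainer, one for training and one for evaluation, and swaps them only while the eval dataloader is being built, so packed training batches and unpacked eval batches can be collated differently. A minimal, self-contained sketch of that pattern follows (class and argument names here are illustrative; the actual change is in the diff below):

from typing import Optional

from torch.utils.data import DataLoader, Dataset
from transformers import Trainer


class CollatorSwappingTrainer(Trainer):
    """Illustrative sketch only: swap collators when building the eval dataloader."""

    def __init__(self, *args, eval_data_collator=None, **kwargs):
        self.eval_data_collator = eval_data_collator
        super().__init__(*args, **kwargs)
        # Trainer stores the collator passed as `data_collator`; remember it
        # so it can be restored after the eval dataloader has been built.
        self.train_data_collator = self.data_collator

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if self.eval_data_collator is not None:
            # Temporarily swap in the eval collator, let the parent class build
            # the dataloader, then restore the training collator.
            self.data_collator = self.eval_data_collator
            dataloader = super().get_eval_dataloader(eval_dataset)
            self.data_collator = self.train_data_collator
            return dataloader
        return super().get_eval_dataloader(eval_dataset)
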
Showing 1 changed file with 42 additions and 4 deletions.
46 changes: 42 additions & 4 deletions src/axolotl/core/trainer_builder.py
@@ -1,3 +1,4 @@
+# pylint: disable=too-many-lines
 """
 Builder for the training args and trainer
 """
@@ -137,10 +138,19 @@ class AxolotlTrainer(Trainer):
     args = None  # type: AxolotlTrainingArguments
     tag_names = ["axolotl"]
 
-    def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
+    def __init__(
+        self,
+        *_args,
+        num_epochs=1,
+        bench_data_collator=None,
+        eval_data_collator=None,
+        **kwargs
+    ):
         self.num_epochs = num_epochs
         self.bench_data_collator = bench_data_collator
-        super().__init__(*args, **kwargs)
+        self.eval_data_collator = eval_data_collator
+        super().__init__(*_args, **kwargs)
+        self.train_data_collator = self.data_collator
 
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -239,6 +249,16 @@ def get_train_dataloader(self) -> DataLoader:
         return super().get_train_dataloader()
 
     def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+        if self.args.sample_packing and self.args.eval_sample_packing is False:
+            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
+                self.eval_data_collator
+            )
+            dataloader = super().get_eval_dataloader(eval_dataset)
+            self.data_collator = (  # pylint: disable=attribute-defined-outside-init
+                self.train_data_collator
+            )
+            return dataloader
+
         if self.args.sample_packing and self.args.eval_sample_packing is not False:
             eval_dataset = (
                 eval_dataset if eval_dataset is not None else self.eval_dataset
@@ -269,6 +289,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
             return self.accelerator.prepare_data_loader(
                 DataLoader(eval_dataset, **dataloader_params)
             )
+
         return super().get_eval_dataloader(eval_dataset)
 
     def _get_bench_sampler(
@@ -651,6 +672,12 @@ def build(self, total_num_steps):
             training_arguments_kwargs[
                 "dataloader_prefetch_factor"
             ] = self.cfg.dataloader_prefetch_factor
+        if self.cfg.dataloader_drop_last is not None:
+            training_arguments_kwargs[
+                "dataloader_drop_last"
+            ] = self.cfg.dataloader_drop_last
+        elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
+            training_arguments_kwargs["dataloader_drop_last"] = True
 
         if self.cfg.val_set_size == 0:
             # no eval set, so don't eval
@@ -831,6 +858,9 @@ def build(self, total_num_steps):
             eval_dataset=self.eval_dataset,
             args=training_args,
             data_collator=self.build_collator(training_args, **data_collator_kwargs),
+            eval_data_collator=self.build_collator(
+                training_args, is_eval=True, **data_collator_kwargs
+            ),
             bench_data_collator=transformers.DataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
@@ -851,14 +881,22 @@ def build(self, total_num_steps):
 
         return trainer
 
-    def build_collator(self, training_args: AxolotlTrainingArguments, **kwargs):
+    def build_collator(
+        self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
+    ):
         if training_args.pretraining:
             return None
 
         if self.cfg.model_config_type == "mamba":
             return MambaDataCollator(tokenizer=self.tokenizer)
 
-        if training_args.sample_packing:
+        use_batch_sampler_collator = False
+        if is_eval is False and training_args.sample_packing:
+            use_batch_sampler_collator = True
+        if is_eval and training_args.eval_sample_packing:
+            use_batch_sampler_collator = True
+
+        if use_batch_sampler_collator:
             return BatchSamplerDataCollatorForSeq2Seq(
                 self.tokenizer,
                 return_tensors="pt",
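The second bullet in the commit message corresponds to the training-arguments hunk above: if the user has not set dataloader_drop_last explicitly and training uses sample packing while evaluation does not, the builder forces dataloader_drop_last to True. A hedged restatement of that precedence as a standalone helper (the function name is hypothetical; the cfg fields are the ones read in the diff):

def resolve_dataloader_drop_last(cfg) -> bool:
    # An explicit user setting always wins.
    if cfg.dataloader_drop_last is not None:
        return cfg.dataloader_drop_last
    # Otherwise drop the last (possibly ragged) batch when training packs
    # samples but evaluation does not, matching the builder logic above.
    return bool(cfg.sample_packing and cfg.eval_sample_packing is False)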
