diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index ca511996f..226988037 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -197,7 +197,7 @@ def get_train_dataloader(self) -> DataLoader: } if self.args.dataloader_prefetch_factor: dataloader_params[ - "dataloader_prefetch_factor" + "prefetch_factor" ] = self.args.dataloader_prefetch_factor sampler = self._get_train_sampler() @@ -234,7 +234,7 @@ def get_eval_dataloader( } if self.args.dataloader_prefetch_factor: dataloader_params[ - "dataloader_prefetch_factor" + "prefetch_factor" ] = self.args.dataloader_prefetch_factor if isinstance(eval_sampler, BatchSampler): @@ -268,9 +268,7 @@ def get_bench_dataloader( "pin_memory": self.args.dataloader_pin_memory, } if self.args.dataloader_prefetch_factor: - dataloader_params[ - "dataloader_prefetch_factor" - ] = self.args.dataloader_prefetch_factor + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor if not isinstance(bench_dataset, torch.utils.data.IterableDataset): dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)